"""Gradio Space for Flickr image captioning (trained on Flickr8k + Flickr30k).

At startup, downloads the trained checkpoint + vocabulary from the
companion model repo on the Hub (reusing the local cache when present):
https://huggingface.co/OmarGamal48812/flickr-captioning

After that, every request runs locally inside the Space, on the device
selected by ``get_device()``.
"""
|
|
| from __future__ import annotations |
|
|
| import io |
| import os |
| import pickle |
| from pathlib import Path |
|
|
| |
| |
| |
| |
| |
| import gradio_client.utils as _gcu |
|
|
# gradio_client's ``_json_schema_to_python_type`` / ``get_type`` turn JSON
# schemas into type-hint strings (used, presumably, when Gradio builds its API
# info). They assume every (sub-)schema is a dict, but JSON Schema also allows
# boolean schemas (e.g. ``additionalProperties: true``) — presumably the case
# these guards exist for. Non-dict schemas are reported as "Any"; everything
# else is delegated to the original functions unchanged.
#
# If this module is executed again in the same process (e.g. Gradio's reload
# mode), the gradio_client attributes already hold our wrappers. Recover the
# real functions through the marker attribute set further down instead of
# stacking a new wrapper on top of the previous one on every reload.
_PATCH_ORIGINAL_ATTR = "_flickr_patched_original"

_orig_jspt = getattr(
    _gcu._json_schema_to_python_type,
    _PATCH_ORIGINAL_ATTR,
    _gcu._json_schema_to_python_type,
)
_orig_get_type = getattr(_gcu.get_type, _PATCH_ORIGINAL_ATTR, _gcu.get_type)


def _safe_jspt(schema, defs=None):
    """Guarded ``gradio_client.utils._json_schema_to_python_type``.

    Returns ``"Any"`` when *schema* is not a dict; otherwise delegates to the
    original converter with the same ``defs`` argument.
    """
    if not isinstance(schema, dict):
        return "Any"
    return _orig_jspt(schema, defs)


def _safe_get_type(schema):
    """Guarded ``gradio_client.utils.get_type``: ``"Any"`` for non-dict schemas."""
    if not isinstance(schema, dict):
        return "Any"
    return _orig_get_type(schema)


# Remember what each wrapper wraps so a re-run of this module can unwrap it.
setattr(_safe_jspt, _PATCH_ORIGINAL_ATTR, _orig_jspt)
setattr(_safe_get_type, _PATCH_ORIGINAL_ATTR, _orig_get_type)

# Install the wrappers. This relies on gradio_client resolving these names via
# its module globals at call time (including its own recursive calls).
# NOTE(review): re-check after upgrading gradio_client.
_gcu._json_schema_to_python_type = _safe_jspt
_gcu.get_type = _safe_get_type
| |
|
|
| import gradio as gr |
| import matplotlib.pyplot as plt |
| import torch |
| from huggingface_hub import hf_hub_download |
| from PIL import Image |
|
|
| from src.inference import encode_image, generate_beam, load_attention_model |
| from src.utils import get_device |
| from src.visualize import attention_heatmap_for_image, plot_attention_heatmap |
| from src.vocabulary import Vocabulary |
|
|
|
|
# Runtime configuration. Each value can be overridden through an environment
# variable of the Space; the literal fallbacks apply otherwise.
MODEL_REPO = os.getenv("MODEL_REPO", "OmarGamal48812/flickr-captioning")  # Hub repo id
DEFAULT_BEAM_WIDTH = int(os.getenv("BEAM_WIDTH", "5"))  # initial beam size k
DEFAULT_MAX_LEN = int(os.getenv("MAX_LEN", "20"))  # passed as max_len to decoding
|
|
|
|
| |
| |
| |
|
|
# Optional: pin the Hub artifacts to an immutable commit hash or tag. When the
# variable is unset (or empty) this is None, and hf_hub_download resolves the
# repo's default revision exactly as before.
MODEL_REVISION = os.environ.get("MODEL_REVISION") or None

print(f"[gradio] downloading checkpoint + vocab from {MODEL_REPO} ...")
CHECKPOINT_PATH = hf_hub_download(
    repo_id=MODEL_REPO,
    filename="attention_gru_glove.pth",
    revision=MODEL_REVISION,
)
VOCAB_PATH = hf_hub_download(
    repo_id=MODEL_REPO,
    filename="vocab.pkl",
    revision=MODEL_REVISION,
)


print("[gradio] loading model ...")
_device = get_device()
# SECURITY: pickle.load executes arbitrary code embedded in the file. That is
# only acceptable because vocab.pkl comes from a trusted repo. MODEL_REPO is
# env-configurable, so pointing it elsewhere means running that repo's code;
# pin MODEL_REVISION to a reviewed commit in production.
# NOTE(review): the .pth checkpoint is presumably read with torch.load inside
# load_attention_model, which has the same concern unless weights_only=True.
with open(VOCAB_PATH, "rb") as f:
    _vocab: Vocabulary = pickle.load(f)
_encoder, _decoder, _cfg = load_attention_model(CHECKPOINT_PATH, len(_vocab), _device)
print(f"[gradio] model loaded on {_device}")
|
|
|
|
| |
| |
| |
|
|
|
|
def _fig_to_pil(fig: plt.Figure, dpi: int = 110) -> Image.Image:
    """Rasterise a Matplotlib figure into an RGB PIL image, then close it.

    Args:
        fig: Figure to render. It is closed even if rendering fails, so a
            long-running server does not accumulate open pyplot figures.
        dpi: Output resolution forwarded to ``Figure.savefig``.

    Returns:
        A fully decoded RGB ``PIL.Image.Image`` that does not depend on the
        intermediate PNG buffer.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
    finally:
        # Figures created through pyplot stay registered until closed; closing
        # in `finally` prevents a per-request leak when savefig raises.
        plt.close(fig)
    buf.seek(0)
    with Image.open(buf) as png:
        # convert() decodes the pixels into a new image, so closing `png`
        # (and dropping `buf`) afterwards is safe.
        return png.convert("RGB")
|
|
|
|
def caption_image(image: Image.Image, beam_width: int = DEFAULT_BEAM_WIDTH):
    """Caption an uploaded image and render its per-word attention heatmap.

    Args:
        image: The uploaded PIL image, or ``None`` when the input is empty.
        beam_width: Beam size ``k`` for beam search. Normalised to ``int``
            (UI/API payloads may deliver numbers as floats, which would also
            print as ``k=5.0``); ``None`` falls back to ``DEFAULT_BEAM_WIDTH``.

    Returns:
        ``(best_md, alts_md, heatmap_img)`` for the three output components:
        Markdown with the first beam's caption and score, Markdown listing the
        remaining beams (``""`` when there are none), and the attention heatmap
        as a PIL image (``None`` when nothing could be captioned).
    """
    if image is None:
        return "Please upload an image.", "", None

    beam_width = DEFAULT_BEAM_WIDTH if beam_width is None else int(beam_width)

    # Pure inference: disable autograd so no graphs are built while decoding
    # (a harmless no-op if the helpers already do this internally).
    with torch.no_grad():
        tensor = encode_image(image, _device)
        beams = generate_beam(
            _encoder, _decoder, tensor, _vocab,
            beam_width=beam_width, max_len=DEFAULT_MAX_LEN,
        )
    if not beams:
        return "Could not generate a caption for this image.", "", None

    # beams[0] is shown as the best caption — presumably generate_beam returns
    # beams sorted by score (TODO confirm).
    best = beams[0]
    best_md = (
        f"### {best.caption}\n\n"
        f"*(beam search k={beam_width}, normalized log-prob = {best.score:.3f})*"
    )

    alts_md = ""
    if len(beams) > 1:
        alts_md = "\n".join(
            ["**Alternative captions:**"]
            + [
                f"{i}. `{b.score:+.3f}` {b.caption}"
                for i, b in enumerate(beams[1:], start=2)
            ]
        )

    # The heatmap helper is not given `beams`; it presumably decodes the image
    # on its own, so its caption (used in the title) may differ from the
    # beam-search result above. NOTE(review): confirm that is intended.
    with torch.no_grad():
        caption_str, tokens, alphas, img_tensor = attention_heatmap_for_image(
            encoder=_encoder,
            decoder=_decoder,
            image=image,
            vocab=_vocab,
            device=_device,
            max_len=DEFAULT_MAX_LEN,
        )
    fig = plot_attention_heatmap(
        img_tensor, tokens, alphas,
        title=f"Attention · {caption_str}",
    )
    heatmap_img = _fig_to_pil(fig)

    return best_md, alts_md, heatmap_img
|
|
|
|
| |
| |
| |
|
|
|
|
# Markdown blurb rendered under the page title. It is an f-string evaluated
# once at import time: the architecture sizes and RNN cell type come from the
# loaded checkpoint's config (`_cfg`), and the link from MODEL_REPO. The
# dataset sizes and the BLEU-4 score are hard-coded text, so keep them in sync
# by hand if the checkpoint is retrained.
DESCRIPTION = f"""
### What this does
Upload an image and the model will generate a caption **plus** a per-word
attention heatmap showing which part of the image influenced each word.

### Model
ResNet50 + Bahdanau attention + `{_cfg.rnn_type.upper()}Cell` decoder + GloVe 300d
(embed={_cfg.embed_size}, hidden={_cfg.hidden_size}, attention={_cfg.attention_dim}).
Trained on Flickr8k + Flickr30k (~39,874 images, 185K captions). Test-set BLEU-4 = **0.3093** with beam search k=5.

Checkpoint: [{MODEL_REPO}](https://huggingface.co/{MODEL_REPO})
"""
|
|
with gr.Blocks(title="Flickr Image Captioning") as demo:
    # Page header: title followed by the model/dataset blurb.
    gr.Markdown("# Flickr Image Captioning")
    gr.Markdown(DESCRIPTION)

    # Two equal-width columns: inputs on the left, text results on the right.
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(label="Upload image", type="pil")
            beam_slider = gr.Slider(
                label="Beam width",
                value=DEFAULT_BEAM_WIDTH,
                minimum=1,
                maximum=10,
                step=1,
            )
            run_btn = gr.Button("Caption this image", variant="primary")

        with gr.Column(scale=1):
            caption_md = gr.Markdown(label="Generated caption")
            alts_md = gr.Markdown(label="Alternatives")

    # Full-width, read-only image holding the attention visualisation.
    heatmap_image = gr.Image(
        label="Per-word attention heatmaps",
        type="pil",
        interactive=False,
    )

    # Clicking the button captions the current image with the chosen beam width.
    run_btn.click(
        fn=caption_image,
        inputs=[image_input, beam_slider],
        outputs=[caption_md, alts_md, heatmap_image],
    )
    # Any change to the image input blanks the previous results so stale
    # output is never shown next to a different picture.
    image_input.change(
        outputs=[caption_md, alts_md, heatmap_image],
        fn=lambda: ("", "", None),
    )
|
|
|
|
if __name__ == "__main__":
    # Defaults match the original hard-coded values (all interfaces, port
    # 7860 — the Hugging Face Spaces default). Gradio's standard
    # GRADIO_SERVER_NAME / GRADIO_SERVER_PORT variables can now override them
    # for local runs or other hosts instead of being silently ignored.
    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
        ssr_mode=False,
        show_api=False,
    )
|
|