---
language:
- en
license: mit
tags:
- vision
- image-to-text
- image-captioning
- gpt2
- vision-transformer
- flickr8k
- multimodal
- cross-attention
pipeline_tag: image-to-text
datasets:
- jxie/flickr8k
---

# Vision-GPT: Multimodal Image Captioning

A lightweight multimodal model combining GPT-2 and Vision Transformer for image captioning, built for **Smart India Hackathon 2025**.

## 🎯 Architecture

**Sparse Cross-Attention Fusion** (inspired by Llama 3.2)

- **GPT-2** (124M params) - Language model backbone ❄️ frozen
- **ViT-B/16** (87M params) - Visual encoder ❄️ frozen
- **Cross-Attention + Perceiver Resampler** (11M params) - Vision-language fusion 🔥 trainable
  - Cross-attention inserted at layers 3, 6, 9
  - Perceiver Resampler for efficient visual token compression

**Total**: 222M params | **Trainable**: 11M params (5%)

## 📊 Training Details

- **Dataset**: Flickr8k
- **Epochs**: 2
- **Final Loss**: 2.632
- **Strategy**: Freeze the pretrained GPT-2 and ViT models; train only the cross-attention layers and the Perceiver Resampler
- **Hardware**: Single GPU

## 📦 Model Versions

### FP32 (Full Precision)

- **Size**: 0.89 GB
- **Precision**: 32-bit floating point
- **Use case**: Maximum accuracy, research
- **Path**: `model_fp32/model_checkpoint.pth`

### FP16 (Half Precision)

- **Size**: 0.52 GB
- **Precision**: 16-bit floating point
- **Use case**: Faster inference, reduced memory use (FP16 weights take half the bytes of FP32 weights)
- **Path**: `model_fp16/model_checkpoint.pth`
- **Space saved**: 0.37 GB (41.9% smaller checkpoint than FP32)

## 🚀 Quick Start

### Installation

```bash
pip install torch torchvision transformers pillow huggingface-hub
```

### Load Model (FP32)

```python
import torch
from huggingface_hub import hf_hub_download

# Download checkpoint
checkpoint_path = hf_hub_download(
    repo_id="gurumurthy3/vision-gpt-flickr8k_v2",
    filename="model_fp32/model_checkpoint.pth"
)

# Load checkpoint
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model_state_dict = checkpoint['model_state_dict']

# Load your model architecture and weights
# model.load_state_dict(model_state_dict)
# model.eval()
```

### Load Model (FP16 - Faster Inference)

```python
import torch
from huggingface_hub import hf_hub_download

# Download FP16 checkpoint
checkpoint_path = hf_hub_download(
    repo_id="gurumurthy3/vision-gpt-flickr8k_v2",
    filename="model_fp16/model_checkpoint.pth"
)

# Load checkpoint
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Load model
# model.load_state_dict(checkpoint['model_state_dict'])
# model.half()  # Ensure model is in FP16
# model.eval()

# For GPU inference with FP16
# model = model.to('cuda')
# images = images.to('cuda').half()
```

### Generate Caption

```python
import torch
from PIL import Image
from torchvision import transforms

# `model` is the model instance built and loaded in the steps above

# Image preprocessing (ImageNet mean/std normalization)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

# Load and preprocess image
image = Image.open("your_image.jpg").convert('RGB')
image_tensor = transform(image).unsqueeze(0)  # add batch dimension

# Generate caption
with torch.no_grad():
    caption = model.generate(image_tensor, max_length=50)

print(f"Caption: {caption}")
```

## 🎨 Demo

Try it live: [Multimodal GPT-2 Demo](https://huggingface.co/spaces/gurumurthy3/Multimodal-Gpt2-Demo)

## 🏗️ Model Architecture

```
Input Image (224×224)
        ↓
ViT-B/16 Encoder ❄️ (87M params)
        ↓
Perceiver Resampler 🔥 (compress to 64 tokens)
        ↓
Cross-Attention Layers 🔥 (at layers 3, 6, 9)
        ↓
GPT-2 ❄️ (124M params)
        ↓
Generated Caption
```

## ⚠️ Limitations

- Trained only on Flickr8k (limited domain coverage)
- English captions only
- Best for images similar to the Flickr8k dataset (people, activities, scenes)
- May generate generic captions for out-of-domain images

## 📝 Citation

```bibtex
@misc{vision-gpt-flickr8k-2025,
  author = {gurumurthy3},
  title = {Vision-GPT: Multimodal Image Captioning with Sparse Cross-Attention},
  year = {2025},
  publisher = {Hugging Face},
  journal = {Hugging Face Model Hub},
  howpublished = {\url{https://huggingface.co/gurumurthy3/vision-gpt-flickr8k_v2}}
}
```

## 🙏 Acknowledgments

- **Llama 3.2 Vision** for sparse cross-attention inspiration
- **OpenAI** for GPT-2
- **Google Research** for Vision Transformer (ViT)
- **Flickr8k Dataset** by Hodosh et al.
- **Smart India Hackathon 2025**

## 📄 License

MIT License

---

Built with ❤️ for Smart India Hackathon 2025
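
## 🧮 Appendix: Checking the Numbers

The parameter counts and checkpoint sizes quoted in this card can be cross-checked with a little arithmetic. The sketch below is only an illustration, not part of the released code. It assumes 1 GB = 1e9 bytes and that a checkpoint holds nothing but weights; the helper `weights_size_gb` and all variable names are hypothetical. All input figures are the rounded values from this card.

```python
# Figures quoted in this model card (parameters in millions, sizes in GB).
PARAMS_M = {"gpt2": 124, "vit_b16": 87, "fusion": 11}
FP32_GB, FP16_GB = 0.89, 0.52


def weights_size_gb(params_millions, bytes_per_param):
    """Size of the raw weights alone, assuming 1 GB = 1e9 bytes."""
    return params_millions * 1e6 * bytes_per_param / 1e9


total_m = sum(PARAMS_M.values())                   # total parameters (millions)
trainable_pct = 100 * PARAMS_M["fusion"] / total_m # share of trainable params

est_fp32 = weights_size_gb(total_m, 4)  # 4 bytes per FP32 weight
est_fp16 = weights_size_gb(total_m, 2)  # 2 bytes per FP16 weight

saved_gb = FP32_GB - FP16_GB            # from the reported checkpoint sizes
saved_pct = 100 * saved_gb / FP32_GB

print(f"total params: {total_m}M, trainable: {trainable_pct:.2f}%")
print(f"estimated weights: FP32 {est_fp32:.3f} GB, FP16 {est_fp16:.3f} GB")
print(f"reported saving: {saved_gb:.2f} GB ({saved_pct:.1f}%)")
```

Under these assumptions the FP32 estimate lines up with the reported 0.89 GB, while the reported FP16 checkpoint (0.52 GB) is larger than a pure half-precision estimate. One possible explanation is that the FP16 file keeps some tensors or extra entries in a wider format, but this card does not say. The saving computed from the rounded sizes also differs slightly from the quoted 41.9%, which presumably was derived from unrounded file sizes.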