---
language:
- vi
- en
license: apache-2.0
tags:
- medical-vision
- medgemma
- vision-language-model
- lora
- unsloth
- spider-dataset
---
# medgemma-spider-finetuned
A MedGemma model fine-tuned on the Spider dataset for medical image analysis, where each patient sample contains multiple images.
## 📋 Model Information
- **Base Model**: `google/medgemma-4b-it`
- **Fine-tuning Method**: LoRA (Low-Rank Adaptation)
- **Dataset**: Spider dataset in series format (1 patient = multiple images)
- **Training Framework**: Unsloth (advertised as roughly 2x faster training)
## 📂 Folder Structure
```
output_medgemma_spider/
├── final_model/ # Full merged model (large)
├── lora_adapters/ # LoRA adapters only (recommended, lightweight)
├── checkpoint-*/ # Training checkpoints
├── trainer_state.json # Training state
└── eval_metrics.json # Evaluation metrics
```
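Before uploading a training run or loading from a local copy, it can help to confirm that every entry in the layout above was produced. The following is a minimal sketch, assuming the folder and file names shown in the tree. `missing_artifacts`, `EXPECTED_DIRS` and `EXPECTED_FILES` are hypothetical helpers, not part of Unsloth or Transformers.

```python
from pathlib import Path
from typing import List

# Hypothetical lists mirroring the tree above; checkpoint-*/ is matched with a
# glob because the step numbers vary between runs.
EXPECTED_DIRS = ["final_model", "lora_adapters"]
EXPECTED_FILES = ["trainer_state.json", "eval_metrics.json"]


def missing_artifacts(root: str) -> List[str]:
    """Return the expected entries that are absent from an output folder."""
    base = Path(root)
    missing = [d + "/" for d in EXPECTED_DIRS if not (base / d).is_dir()]
    missing += [f for f in EXPECTED_FILES if not (base / f).is_file()]
    # At least one checkpoint-<step>/ directory is expected after training.
    if not any(p.is_dir() for p in base.glob("checkpoint-*")):
        missing.append("checkpoint-*/")
    return missing


if __name__ == "__main__":
    print(missing_artifacts("output_medgemma_spider"))
```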
## 🚀 Usage
### 1️⃣ Load LoRA Adapters (Recommended - Lightweight)
```python
from unsloth import FastVisionModel

# Load the base model in 4-bit together with the LoRA adapters
model, processor = FastVisionModel.from_pretrained(
    model_name="ImNotTam/medgemma-spider-finetuned",
    subfolder="lora_adapters",
    load_in_4bit=True,
)

# Switch Unsloth to inference mode
FastVisionModel.for_inference(model)

# One user turn: all images of a patient, followed by the question
image_paths = ["path/to/image1.png", "path/to/image2.png"]  # add more paths as needed
question = "What do you see in these images?"
messages = [
    {
        "role": "user",
        "content": [{"type": "image", "image": p} for p in image_paths]
        + [{"type": "text", "text": question}],
    }
]

# Tokenize text and images; return_dict=True returns a dict
# (input_ids, attention_mask, pixel_values, ...) that can be unpacked into generate()
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to("cuda")

# Generate, then decode only the new tokens so the prompt is not echoed back
outputs = model.generate(**inputs, max_new_tokens=512)
prompt_len = inputs["input_ids"].shape[-1]
response = processor.decode(outputs[0][prompt_len:], skip_special_tokens=True)
print(response)
```
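The message construction above places every image of one patient before the question in a single user turn. A minimal sketch of that step as a reusable function follows. `build_messages` and `MAX_IMAGES` are hypothetical names, and truncating to 18 images (the Max Images per Sample value in Training Details) is an assumption, since the card does not say how longer series were handled.

```python
from typing import Any, Dict, List

MAX_IMAGES = 18  # assumption: matches "Max Images per Sample" used in training


def build_messages(image_paths: List[str], question: str) -> List[Dict[str, Any]]:
    """Build one user turn: all images of a patient first, then the question."""
    if not image_paths:
        raise ValueError("at least one image is required")
    if len(image_paths) > MAX_IMAGES:
        # Assumption: keep the first MAX_IMAGES images of the series.
        image_paths = image_paths[:MAX_IMAGES]
    content: List[Dict[str, Any]] = [{"type": "image", "image": p} for p in image_paths]
    content.append({"type": "text", "text": question})
    return [{"role": "user", "content": content}]


msgs = build_messages([f"slice_{i:02d}.png" for i in range(20)], "Describe the findings.")
print(len(msgs[0]["content"]))  # 19 = 18 images + 1 text item
```

The returned list can be passed directly as `messages` to `processor.apply_chat_template`.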
### 2️⃣ Load Final Model (Full Model)
```python
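from transformers import AutoModelForImageTextToText, AutoProcessor

# Load the merged (base + LoRA) weights; Unsloth and PEFT are not needed here
model = AutoModelForImageTextToText.from_pretrained(
    "ImNotTam/medgemma-spider-finetuned",
    subfolder="final_model",
    device_map="auto",
    torch_dtype="auto",
)
processor = AutoProcessor.from_pretrained(
    "ImNotTam/medgemma-spider-finetuned",
    subfolder="final_model",
)

# Build `messages` and run apply_chat_template / generate / decode as in 1️⃣,
# but move inputs with .to(model.device) and skip FastVisionModel.for_inference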
```
### 3️⃣ Continue Training from LoRA Adapters
```python
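from trl import SFTTrainer
from unsloth import FastVisionModel
from unsloth.trainer import UnslothVisionDataCollator

# Load the base model together with the existing LoRA adapters
model, processor = FastVisionModel.from_pretrained(
    model_name="ImNotTam/medgemma-spider-finetuned",
    subfolder="lora_adapters",
    load_in_4bit=True,
)

# LoRA config for continued training (same hyperparameters as the original run).
# Note: if the loaded model already carries LoRA adapters, Unsloth may skip this
# step and keep training the existing adapters instead of adding new ones.
model = FastVisionModel.get_peft_model(
    model,
    r=24,
    lora_alpha=48,
    lora_dropout=0.1,
    finetune_vision_layers=True,
    finetune_language_layers=True,
)
FastVisionModel.for_training(model)  # make sure the model is in training mode

# Train on new data (your_new_dataset is a placeholder for your own conversations)
trainer = SFTTrainer(
    model=model,
    tokenizer=processor,
    data_collator=UnslothVisionDataCollator(model, processor),
    train_dataset=your_new_dataset,
    # ... training args
)
trainer.train()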
```
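`your_new_dataset` above is a placeholder. A minimal sketch of turning patient records into the user/assistant conversation format used in Unsloth's vision fine-tuning notebooks is below. The input keys (`images`, `question`, `answer`), the sample records and `to_conversation` are all hypothetical. Unsloth's examples put PIL images in the `image` field, so depending on the collator version you may need to open file paths with `PIL.Image.open` first.

```python
from typing import Any, Dict


def to_conversation(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Convert one hypothetical patient record into a user/assistant conversation.

    Assumed input schema: {"images": [...], "question": str, "answer": str}.
    """
    user_content = [{"type": "image", "image": img} for img in sample["images"]]
    user_content.append({"type": "text", "text": sample["question"]})
    return {
        "messages": [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": [{"type": "text", "text": sample["answer"]}]},
        ]
    }


# Hypothetical raw records; replace with your own patient series.
raw_records = [
    {"images": ["p1_a.png", "p1_b.png"], "question": "Any abnormality?", "answer": "No."},
]
your_new_dataset = [to_conversation(r) for r in raw_records]
```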
## 📊 Training Details
- **LoRA Rank**: 24
- **LoRA Alpha**: 48
- **LoRA Dropout**: 0.1
- **Batch Size**: 2 (per device)
- **Gradient Accumulation**: 12 steps
- **Effective Batch Size**: 24 (2 × 12, per device)
- **Learning Rate**: 2e-4
- **Max Sequence Length**: 1280 tokens
- **Max Images per Sample**: 18
- **Epochs**: 7
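The derived numbers above follow from simple arithmetic. With standard LoRA, each adapted weight becomes `W' = W + (alpha / r) · B·A`, so `lora_alpha = 48` and `r = 24` give a scale of 2.0, and an adapted `d_out × d_in` matrix adds `r · (d_in + d_out)` trainable parameters. The sketch below assumes standard (non-rsLoRA) scaling and a single device; the 2048 × 2048 matrix is a hypothetical size for illustration, not a MedGemma dimension.

```python
def effective_batch_size(per_device: int, grad_accum: int, num_devices: int = 1) -> int:
    """Number of samples contributing to one optimizer step."""
    return per_device * grad_accum * num_devices


def lora_scaling(alpha: float, r: int) -> float:
    """Standard LoRA scale applied to the low-rank update B @ A."""
    return alpha / r


def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters for one adapted weight: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)


print(effective_batch_size(2, 12))  # 24
print(lora_scaling(48, 24))  # 2.0
# Hypothetical 2048 x 2048 projection, for illustration only:
print(lora_param_count(2048, 2048, 24), "vs", 2048 * 2048)  # 98304 vs 4194304
```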
## 💡 Recommendations
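- **For Inference**: Use `lora_adapters/` (small download; loaded on top of the base model in 4-bit)
- **For Production**: Use `final_model/` (merged weights; loads with plain `transformers`, no Unsloth or PEFT required)
- **For Continued Training**: Load `lora_adapters/` and apply a LoRA config before training (see 3️⃣)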
## 📦 Requirements
```bash
pip install unsloth transformers torch trl pillow accelerate
```
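To confirm the environment before running the examples, the standard library's `importlib.metadata` can report which packages are installed. This is a minimal sketch: `REQUIRED` and `check_requirements` are hypothetical names, and the check only confirms that each distribution is present, not that versions are compatible.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, List, Optional

# Packages used in this card; accelerate is needed by transformers for device_map="auto".
REQUIRED = ["unsloth", "transformers", "torch", "trl", "pillow", "accelerate"]


def check_requirements(packages: List[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    report = check_requirements(REQUIRED)
    missing = [name for name, ver in report.items() if ver is None]
    print("missing:", ", ".join(missing) if missing else "none")
```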
## 📄 License
Apache 2.0