---
language:
- vi
- en
license: apache-2.0
tags:
- medical-vision
- medgemma
- vision-language-model
- lora
- unsloth
- spider-dataset
---

# medgemma-spider-finetuned

A MedGemma model fine-tuned on the Spider dataset for medical image analysis, where each patient sample consists of multiple images.

## 📋 Model Information

- **Base Model**: `google/medgemma-4b-it`
- **Fine-tuning Method**: LoRA (Low-Rank Adaptation)
- **Dataset**: Spider dataset (series format: one patient, multiple images)
- **Training Framework**: Unsloth (advertised as roughly 2x faster training)

## 📂 Folder Structure
```
output_medgemma_spider/
├── final_model/           # Full merged model (large)
├── lora_adapters/         # LoRA adapters only (recommended, lightweight)
├── checkpoint-*/          # Training checkpoints
├── trainer_state.json     # Training state
└── eval_metrics.json      # Evaluation metrics
```
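
Because the repository holds both the merged model and the adapters, it can be useful to fetch only one subfolder. The following is a minimal sketch, assuming `huggingface_hub` is installed; the helper names `patterns_for` and `download_subfolder` are hypothetical and not part of this repository.

```python
def patterns_for(subfolder: str) -> list[str]:
    """Build the glob pattern that matches every file under one subfolder."""
    return [f"{subfolder.strip('/')}/*"]


def download_subfolder(repo_id: str, subfolder: str) -> str:
    """Download a single subfolder of a Hub repo and return the local snapshot path."""
    # Imported lazily so the pattern helper works without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    return snapshot_download(repo_id=repo_id, allow_patterns=patterns_for(subfolder))


if __name__ == "__main__":
    # Fetches only the lightweight adapters, skipping final_model/ and checkpoints.
    local_dir = download_subfolder("ImNotTam/medgemma-spider-finetuned", "lora_adapters")
    print(local_dir)
```

The returned path is the snapshot root; the adapter files sit in its `lora_adapters/` subdirectory.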

## 🚀 Usage

### 1️⃣ Load LoRA Adapters (Recommended - Lightweight)

```python
from unsloth import FastVisionModel

model, processor = FastVisionModel.from_pretrained(
    model_name="ImNotTam/medgemma-spider-finetuned",
    subfolder="lora_adapters",
    load_in_4bit=True,
)

# Enable inference mode
FastVisionModel.for_inference(model)

# Prepare input with multiple images
image_paths = ["path/to/image1.png", "path/to/image2.png"]  # add more paths as needed
question = "What do you see in these images?"

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": img_path} for img_path in image_paths
        ] + [{"type": "text", "text": question}]
    }
]

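# Generate response
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,  # needed so that **inputs unpacks into keyword arguments
    return_tensors="pt",
).to("cuda")

outputs = model.generate(**inputs, max_new_tokens=512)
# Decode only the newly generated tokens, skipping the echoed prompt
prompt_length = inputs["input_ids"].shape[-1]
response = processor.decode(outputs[0][prompt_length:], skip_special_tokens=True)
print(response)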
```
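
The multi-image message layout above can be wrapped in a small helper. This is a sketch only: `build_messages` is a hypothetical name, and capping inputs at 18 images is an assumption that inference should respect the "Max Images per Sample" value used in training.

```python
MAX_IMAGES_PER_SAMPLE = 18  # assumption: mirror the training-time cap


def build_messages(image_paths, question, max_images=MAX_IMAGES_PER_SAMPLE):
    """Return a single-turn chat message holding all images followed by the question."""
    if not image_paths:
        raise ValueError("At least one image path is required.")
    if len(image_paths) > max_images:
        raise ValueError(f"Got {len(image_paths)} images; the cap is {max_images}.")

    # One image entry per path, then the text prompt as the last content item.
    content = [{"type": "image", "image": path} for path in image_paths]
    content.append({"type": "text", "text": question})
    return [{"role": "user", "content": content}]


messages = build_messages(
    ["path/to/image1.png", "path/to/image2.png"],
    "What do you see in these images?",
)
print(len(messages[0]["content"]))
```

The result can be passed to `processor.apply_chat_template` in place of the hand-written `messages` list.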

### 2️⃣ Load Final Model (Full Model)

```python
from transformers import AutoModelForVision2Seq, AutoProcessor

model = AutoModelForVision2Seq.from_pretrained(
    "ImNotTam/medgemma-spider-finetuned",
    subfolder="final_model",
    device_map="auto",
    torch_dtype="auto"
)
processor = AutoProcessor.from_pretrained(
    "ImNotTam/medgemma-spider-finetuned",
    subfolder="final_model"
)

# Use same inference code as above
```

### 3️⃣ Continue Training from LoRA Adapters

```python
from unsloth import FastVisionModel
from trl import SFTTrainer

# Load LoRA adapter
model, processor = FastVisionModel.from_pretrained(
    model_name="ImNotTam/medgemma-spider-finetuned",
    subfolder="lora_adapters",
    load_in_4bit=True,
)

# Add a new LoRA config to continue training
model = FastVisionModel.get_peft_model(
    model,
    r=24,
    lora_alpha=48,
    lora_dropout=0.1,
    finetune_vision_layers=True,
    finetune_language_layers=True,
)

# Train on new data
trainer = SFTTrainer(
    model=model,
    tokenizer=processor,
    train_dataset=your_new_dataset,
    # ... training args
)
trainer.train()
```

## 📊 Training Details

- **LoRA Rank**: 24
- **LoRA Alpha**: 48
- **LoRA Dropout**: 0.1
- **Batch Size**: 2 (per device)
- **Gradient Accumulation**: 12 steps
- **Effective Batch Size**: 24
- **Learning Rate**: 2e-4
- **Max Sequence Length**: 1280
- **Max Images per Sample**: 18
- **Epochs**: 7
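
The effective batch size follows from the per-device batch size and the gradient accumulation steps. The arithmetic is sketched below, assuming a single training device (the device count is not stated here) and the common LoRA convention that updates are scaled by `alpha / rank`.

```python
per_device_batch_size = 2
gradient_accumulation_steps = 12
num_devices = 1  # assumption: single-GPU training

# Samples contributing to each optimizer step
effective_batch_size = per_device_batch_size * gradient_accumulation_steps * num_devices

lora_rank = 24
lora_alpha = 48
# Standard LoRA scaling factor applied to the low-rank update (assumes no rank-stabilized variant)
lora_scaling = lora_alpha / lora_rank

print(f"effective batch size: {effective_batch_size}")
print(f"LoRA scaling factor: {lora_scaling}")
```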

## 💡 Recommendations

- **For Inference**: Use `lora_adapters/` (lightweight, fast)
- **For Production**: Use `final_model/` (full model)
- **For Continued Training**: Load `lora_adapters/` + add new LoRA config

## 📦 Requirements

```bash
pip install unsloth transformers torch trl pillow
```

## 📄 License

Apache 2.0