| """Gemma3 Pi0.6 (270m VLM) modeling""" |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import AutoModelForCausalLM, AutoModelForImageTextToText |
| from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration |
|
|
| from .configuration_gemma3_pi06 import Gemma3Pi06Config |
|
|
|
|
class Gemma3Pi06ForConditionalGeneration(Gemma3ForConditionalGeneration):
    """
    Gemma3 Pi0.6 - VLM with 270m language model.

    Combines vision components from gemma-3-4b-pt with language model from gemma-3-270m.

    Loading has two modes (see ``from_pretrained``):

    * assembly: the target only provides a config; weights are pulled from the
      two base models named by ``config.vlm_base_model`` and
      ``config.llm_base_model``;
    * checkpoint: the target already holds weights saved by ``save_pretrained``
      and is loaded through the regular parent loader.
    """

    config_class = Gemma3Pi06Config

    # Standard (non-variant) weight file names looked for in a local directory
    # to decide whether it is a saved checkpoint rather than a config-only
    # assembly recipe.
    _WEIGHT_FILE_NAMES = (
        "model.safetensors",
        "model.safetensors.index.json",
        "pytorch_model.bin",
        "pytorch_model.bin.index.json",
    )

    def __init__(self, config: Gemma3Pi06Config):
        """Build the parent model, then install a fresh projector weight and norm.

        Args:
            config: Model configuration; ``vision_config.hidden_size``,
                ``text_config.hidden_size`` and ``text_config.rms_norm_eps``
                are read here.
        """
        super().__init__(config)

        vision_hidden = config.vision_config.hidden_size
        llm_hidden = config.text_config.hidden_size

        # The projector maps vision features (vision_hidden) to the language
        # model width (llm_hidden). It is randomly initialized (std 0.02).
        projector = self.model.multi_modal_projector
        projector.mm_input_projection_weight = nn.Parameter(
            torch.randn(vision_hidden, llm_hidden) * 0.02
        )
        projector.mm_soft_emb_norm = nn.LayerNorm(
            vision_hidden, eps=config.text_config.rms_norm_eps
        )

    @classmethod
    def _has_local_weights(cls, pretrained_model_name_or_path) -> bool:
        """Return True if the argument is a local directory containing weight files."""
        import os

        if not isinstance(pretrained_model_name_or_path, (str, os.PathLike)):
            return False
        if not os.path.isdir(pretrained_model_name_or_path):
            return False
        return any(
            os.path.isfile(os.path.join(pretrained_model_name_or_path, name))
            for name in cls._WEIGHT_FILE_NAMES
        )

    @staticmethod
    def _extend_rows(weight: torch.Tensor, target_rows: int) -> torch.Tensor:
        """Return a detached copy of ``weight`` padded with random rows.

        If ``weight`` has fewer than ``target_rows`` rows, the extra rows are
        drawn from N(0, 0.02^2); the existing rows are copied unchanged.
        Otherwise a plain detached copy is returned.
        """
        # Detach so that slice-assigning a Parameter does not attach the new
        # tensor to an autograd graph.
        weight = weight.detach()
        current_rows = weight.shape[0]
        if current_rows >= target_rows:
            return weight.clone()

        extended = torch.randn(
            target_rows,
            weight.shape[1],
            dtype=weight.dtype,
            device=weight.device,
        ) * 0.02
        extended[:current_rows] = weight
        return extended

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load model with weights from two sources:
        - Vision tower + processor from VLM base (gemma-3-4b-pt)
        - Language model from LLM base (gemma-3-270m)

        If ``pretrained_model_name_or_path`` is a local directory that already
        contains weight files (i.e. it was written by ``save_pretrained``), or
        the caller passes ``_from_checkpoint=True``, the saved weights are
        loaded through the parent ``from_pretrained`` instead of re-assembling
        the model from the base checkpoints.

        Recognized keyword arguments in assembly mode:
            config_kwargs (dict): forwarded to ``Gemma3Pi06Config.from_pretrained``.
            torch_dtype: dtype used to load the base models
                (default ``torch.bfloat16``).
            trust_remote_code (bool): forwarded to the base-model loaders
                (default ``True``, the historical behavior).
            device_map: if present, not ``None`` and not ``'auto'``, the
                assembled model is moved with ``model.to(device_map)``.

        Returns:
            The loaded model instance.
        """
        # Pop (not get) the private flag so it is never forwarded to the
        # parent loader, which does not know about it.
        from_checkpoint = kwargs.pop('_from_checkpoint', False)
        if from_checkpoint or cls._has_local_weights(pretrained_model_name_or_path):
            return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        config = Gemma3Pi06Config.from_pretrained(
            pretrained_model_name_or_path,
            **kwargs.get('config_kwargs', {})
        )

        vlm_base = config.vlm_base_model
        llm_base = config.llm_base_model
        vision_hidden = config.vision_config.hidden_size
        llm_hidden = config.text_config.hidden_size
        torch_dtype = kwargs.get('torch_dtype', torch.bfloat16)
        # NOTE(security): the base-model names come from the loaded config and
        # trust_remote_code allows those repos to execute code. The default
        # stays True for backward compatibility; pass trust_remote_code=False
        # when the config is not fully trusted.
        trust_remote_code = kwargs.get('trust_remote_code', True)

        print("Loading Gemma3Pi06 model:")
        print(f"  Vision components from: {vlm_base}")
        print(f"  Language model from: {llm_base}")

        model = cls(config)

        # [1/3] Vision tower.
        print(f"  [1/3] Loading vision tower from {vlm_base}...")
        vlm_model = AutoModelForImageTextToText.from_pretrained(
            vlm_base,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
        )
        model.model.vision_tower.load_state_dict(vlm_model.model.vision_tower.state_dict())
        # load_state_dict copied the values, so the donor model can be released
        # before the language model is loaded.
        del vlm_model
        print("  ✓ Vision tower loaded")

        print(f"  ⚠ Multi-modal projector randomly initialized ({vision_hidden} -> {llm_hidden})")

        # [2/3] Language model.
        print(f"  [2/3] Loading language model from {llm_base}...")
        llm_model = AutoModelForCausalLM.from_pretrained(
            llm_base,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
        )

        llm_vocab_size = llm_model.model.embed_tokens.weight.shape[0]
        vlm_vocab_size = config.text_config.vocab_size
        needs_extension = llm_vocab_size < vlm_vocab_size
        loaded_suffix = " (vocab extended for image tokens)" if needs_extension else ""

        llm_state_dict = llm_model.model.state_dict()
        if needs_extension:
            print(f"  ⚠ Extending embed_tokens: {llm_vocab_size} -> {vlm_vocab_size}")
            llm_state_dict['embed_tokens.weight'] = cls._extend_rows(
                llm_state_dict['embed_tokens.weight'], vlm_vocab_size
            )

        model.model.language_model.load_state_dict(llm_state_dict)
        print(f"  ✓ Language model loaded{loaded_suffix}")

        # [3/3] Output head.
        print("  [3/3] Loading lm_head...")
        if needs_extension:
            print(f"  ⚠ Extending lm_head: {llm_vocab_size} -> {vlm_vocab_size}")
        lm_head_weight = cls._extend_rows(llm_model.lm_head.weight, vlm_vocab_size)
        model.lm_head.weight.data = lm_head_weight
        print(f"  ✓ lm_head loaded{loaded_suffix}")

        del llm_state_dict, llm_model

        # cls(config) builds parameters in the default dtype while the head was
        # just replaced by a tensor in the loaded dtype; cast everything to the
        # loaded dtype so the assembled model has one consistent dtype.
        model = model.to(dtype=lm_head_weight.dtype)

        device_map = kwargs.get('device_map')
        if device_map is not None and device_map != 'auto':
            model = model.to(device_map)

        print("✓ Gemma3Pi06 model loaded successfully")

        return model

    def save_pretrained(self, save_directory, **kwargs):
        """Save the model so that ``from_pretrained`` reloads the saved weights.

        No extra marker is written: ``from_pretrained`` recognizes a saved
        checkpoint by the weight files the parent implementation writes into
        ``save_directory``.
        """
        # Drop the private flag if a caller supplies it; it must not be
        # forwarded to the parent implementation.
        kwargs.pop('_from_checkpoint', None)
        return super().save_pretrained(save_directory, **kwargs)
|
|