"""Gemma3 Pi0.6 (270m VLM) modeling""" import torch import torch.nn as nn from transformers import AutoModelForCausalLM, AutoModelForImageTextToText from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration from .configuration_gemma3_pi06 import Gemma3Pi06Config class Gemma3Pi06ForConditionalGeneration(Gemma3ForConditionalGeneration): """ Gemma3 Pi0.6 - VLM with 270m language model. Combines vision components from gemma-3-4b-pt with language model from gemma-3-270m. """ config_class = Gemma3Pi06Config def __init__(self, config: Gemma3Pi06Config): # Initialize with the config (creates architecture with 270m LLM size) super().__init__(config) # Reinitialize projector for correct dimensions # Vision hidden: 1152 -> LLM hidden: 640 (for 270m) vision_hidden = config.vision_config.hidden_size llm_hidden = config.text_config.hidden_size # Recreate projector with correct dimensions self.model.multi_modal_projector.mm_input_projection_weight = nn.Parameter( torch.randn(vision_hidden, llm_hidden) * 0.02 ) self.model.multi_modal_projector.mm_soft_emb_norm = nn.LayerNorm( vision_hidden, eps=config.text_config.rms_norm_eps ) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): """ Load model with weights from two sources: - Vision tower + processor from VLM base (gemma-3-4b-pt) - Language model from LLM base (gemma-3-270m) """ # If loading from a saved checkpoint (not initial creation) if kwargs.get('_from_checkpoint', False): return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) # Load config config = Gemma3Pi06Config.from_pretrained( pretrained_model_name_or_path, **kwargs.get('config_kwargs', {}) ) # Get base model IDs vlm_base = config.vlm_base_model llm_base = config.llm_base_model print(f"Loading Gemma3Pi06 model:") print(f" Vision components from: {vlm_base}") print(f" Language model from: {llm_base}") # Initialize model with config model = cls(config) # Load vision tower and projector from VLM print(f" [1/3] Loading vision tower from {vlm_base}...") vlm_model = AutoModelForImageTextToText.from_pretrained( vlm_base, trust_remote_code=True, torch_dtype=kwargs.get('torch_dtype', torch.bfloat16), low_cpu_mem_usage=True, ) # Copy vision tower weights model.model.vision_tower.load_state_dict(vlm_model.model.vision_tower.state_dict()) print(f" ✓ Vision tower loaded") # Note: projector will be randomly initialized (new dimensions) print(f" ⚠ Multi-modal projector randomly initialized (1152 -> 640)") # Load language model from LLM print(f" [2/3] Loading language model from {llm_base}...") llm_model = AutoModelForCausalLM.from_pretrained( llm_base, trust_remote_code=True, torch_dtype=kwargs.get('torch_dtype', torch.bfloat16), low_cpu_mem_usage=True, ) # Copy language model weights with vocab size handling llm_vocab_size = llm_model.model.embed_tokens.weight.shape[0] # 262144 vlm_vocab_size = config.text_config.vocab_size # 262208 (includes image tokens) # Load LLM state dict llm_state_dict = llm_model.model.state_dict() # Handle embed_tokens: extend with random init for image tokens if llm_vocab_size < vlm_vocab_size: print(f" ⚠ Extending embed_tokens: {llm_vocab_size} -> {vlm_vocab_size}") llm_embed = llm_state_dict['embed_tokens.weight'] # Create extended embedding with same dtype extended_embed = torch.randn( vlm_vocab_size, llm_embed.shape[1], dtype=llm_embed.dtype, device=llm_embed.device ) * 0.02 # Copy original embeddings extended_embed[:llm_vocab_size] = llm_embed llm_state_dict['embed_tokens.weight'] = extended_embed model.model.language_model.load_state_dict(llm_state_dict) print(f" ✓ Language model loaded (vocab extended for image tokens)") # Copy lm_head with vocab size handling print(f" [3/3] Loading lm_head...") llm_lm_head = llm_model.lm_head.weight if llm_vocab_size < vlm_vocab_size: print(f" ⚠ Extending lm_head: {llm_vocab_size} -> {vlm_vocab_size}") # Create extended lm_head extended_lm_head = torch.randn( vlm_vocab_size, llm_lm_head.shape[1], dtype=llm_lm_head.dtype, device=llm_lm_head.device ) * 0.02 # Copy original weights extended_lm_head[:llm_vocab_size] = llm_lm_head model.lm_head.weight.data = extended_lm_head else: model.lm_head.weight.data = llm_lm_head print(f" ✓ lm_head loaded (vocab extended for image tokens)") # Move to device if specified if 'device_map' in kwargs: device_map = kwargs['device_map'] if device_map != 'auto': model = model.to(device_map) print(f"✓ Gemma3Pi06 model loaded successfully") return model def save_pretrained(self, save_directory, **kwargs): """Save model with special marker to load correctly""" # Mark this as a checkpoint so from_pretrained doesn't try to reload from bases kwargs['_from_checkpoint'] = True return super().save_pretrained(save_directory, **kwargs)