# gemma-3-0.7b-vlm-custom / modeling_gemma3_pi06.py
# Uploaded by sonsus via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 86096d3 (verified), 5.93 kB.
"""Gemma3 Pi0.6 (270m VLM) modeling"""
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration
from .configuration_gemma3_pi06 import Gemma3Pi06Config
class Gemma3Pi06ForConditionalGeneration(Gemma3ForConditionalGeneration):
    """
    Gemma3 Pi0.6 - VLM with 270m language model.

    Combines vision components from gemma-3-4b-pt with language model from
    gemma-3-270m. The base model IDs are read from the config
    (``vlm_base_model`` / ``llm_base_model``); the multi-modal projector is
    re-created with shape (vision hidden size, text hidden size) and randomly
    initialized.
    """

    config_class = Gemma3Pi06Config

    # Config attribute written by ``save_pretrained`` so a saved checkpoint is
    # reloaded from its own weights instead of being rebuilt from the bases.
    _CHECKPOINT_FLAG = "pi06_from_checkpoint"

    # Weight files whose presence in a local directory marks a saved
    # checkpoint (also covers checkpoints saved before the config flag existed).
    _WEIGHT_FILES = (
        "model.safetensors",
        "model.safetensors.index.json",
        "pytorch_model.bin",
        "pytorch_model.bin.index.json",
    )

    # device_map strings that name a dispatch strategy rather than a single
    # device; they cannot be handed to ``nn.Module.to``.
    _DISPATCH_DEVICE_MAPS = frozenset(
        {"auto", "balanced", "balanced_low_0", "sequential"}
    )

    def __init__(self, config: Gemma3Pi06Config):
        """Build the architecture described by ``config`` (270m-sized LLM).

        Args:
            config: Pi0.6 config; ``vision_config.hidden_size`` and
                ``text_config.hidden_size`` set the projector shape
                (e.g. 1152 -> 640 for the 270m LLM).
        """
        super().__init__(config)
        vision_hidden = config.vision_config.hidden_size
        llm_hidden = config.text_config.hidden_size
        projector = self.model.multi_modal_projector
        # Small random init: the projector is not loaded from either base model.
        projector.mm_input_projection_weight = nn.Parameter(
            torch.randn(vision_hidden, llm_hidden) * 0.02
        )
        # NOTE(review): the stock Gemma3 projector presumably uses an RMS norm
        # with the vision eps here; LayerNorm is kept so that checkpoints saved
        # with this class (weight + bias) keep loading — confirm intent.
        projector.mm_soft_emb_norm = nn.LayerNorm(
            vision_hidden, eps=config.text_config.rms_norm_eps
        )

    @classmethod
    def _is_saved_checkpoint(cls, pretrained_model_name_or_path, config) -> bool:
        """Return True if the source should be loaded as a saved checkpoint."""
        import os

        if getattr(config, cls._CHECKPOINT_FLAG, False):
            return True
        if pretrained_model_name_or_path is None:
            return False
        path = os.fspath(pretrained_model_name_or_path)
        return os.path.isdir(path) and any(
            os.path.isfile(os.path.join(path, name)) for name in cls._WEIGHT_FILES
        )

    @staticmethod
    def _pad_vocab_rows(weight: torch.Tensor, num_rows: int) -> torch.Tensor:
        """Return a copy of ``weight`` extended to ``num_rows`` rows.

        Existing rows are copied unchanged; the new rows (e.g. image tokens)
        are drawn from randn * 0.02 with ``weight``'s dtype and device.
        """
        weight = weight.detach()
        padded = torch.randn(
            num_rows, weight.shape[1], dtype=weight.dtype, device=weight.device
        ) * 0.02
        padded[: weight.shape[0]] = weight
        return padded

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a Gemma3Pi06 model.

        Two modes:

        * Saved checkpoint: used when ``_from_checkpoint=True`` is passed,
          when the config carries the flag written by ``save_pretrained``, or
          when a local directory contains standard weight files. Loading is
          delegated to the parent class's ``from_pretrained``.
        * Initial creation (otherwise): the model is assembled from two sources
          named in the config:
            - vision tower from ``vlm_base_model`` (gemma-3-4b-pt)
            - language model + lm_head from ``llm_base_model`` (gemma-3-270m)
          The multi-modal projector stays randomly initialized, and vocab rows
          the LLM lacks (image tokens) are randomly initialized.

        Keyword arguments read here:
            _from_checkpoint: force checkpoint mode.
            config_kwargs: overrides for ``Gemma3Pi06Config.from_pretrained``.
            config: used as-is when it is already a ``Gemma3Pi06Config``.
            torch_dtype: dtype used to load the base models (default
                bfloat16). When given explicitly as a ``torch.dtype``, the
                assembled model is also cast to it.
            device_map: a single device (str / int / torch.device) moves the
                assembled model there. Dispatch strategies ("auto", ...) and
                dict maps are not applied in creation mode.

        Raises:
            ValueError: if the LLM vocabulary is larger than the config's.
        """
        # Pop private/extra kwargs so they are never forwarded to model __init__.
        explicit_checkpoint = kwargs.pop("_from_checkpoint", False)
        config_kwargs = kwargs.pop("config_kwargs", None) or {}
        config = kwargs.pop("config", None)
        if not isinstance(config, Gemma3Pi06Config):
            config = Gemma3Pi06Config.from_pretrained(
                pretrained_model_name_or_path, **config_kwargs
            )

        if explicit_checkpoint or cls._is_saved_checkpoint(
            pretrained_model_name_or_path, config
        ):
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **kwargs
            )

        vlm_base = config.vlm_base_model
        llm_base = config.llm_base_model
        source_dtype = kwargs.get("torch_dtype", torch.bfloat16)
        print("Loading Gemma3Pi06 model:")
        print(f" Vision components from: {vlm_base}")
        print(f" Language model from: {llm_base}")

        model = cls(config)

        # [1/3] Vision tower from the VLM base.
        print(f" [1/3] Loading vision tower from {vlm_base}...")
        vlm_model = AutoModelForImageTextToText.from_pretrained(
            vlm_base,
            trust_remote_code=True,
            torch_dtype=source_dtype,
            low_cpu_mem_usage=True,
        )
        model.model.vision_tower.load_state_dict(
            vlm_model.model.vision_tower.state_dict()
        )
        del vlm_model  # release the donor VLM before loading the next model
        print(f" ✓ Vision tower loaded")
        vision_hidden = config.vision_config.hidden_size
        llm_hidden = config.text_config.hidden_size
        print(
            f" ⚠ Multi-modal projector randomly initialized "
            f"({vision_hidden} -> {llm_hidden})"
        )

        # [2/3] Language model from the LLM base, padding vocab for image tokens.
        print(f" [2/3] Loading language model from {llm_base}...")
        llm_model = AutoModelForCausalLM.from_pretrained(
            llm_base,
            trust_remote_code=True,
            torch_dtype=source_dtype,
            low_cpu_mem_usage=True,
        )
        llm_state_dict = llm_model.model.state_dict()
        llm_vocab_size = llm_state_dict["embed_tokens.weight"].shape[0]
        vlm_vocab_size = config.text_config.vocab_size
        if llm_vocab_size > vlm_vocab_size:
            raise ValueError(
                f"LLM vocab ({llm_vocab_size}) is larger than the config vocab "
                f"({vlm_vocab_size}); cannot fit it into the VLM."
            )
        needs_padding = llm_vocab_size < vlm_vocab_size
        if needs_padding:
            print(f" ⚠ Extending embed_tokens: {llm_vocab_size} -> {vlm_vocab_size}")
            llm_state_dict["embed_tokens.weight"] = cls._pad_vocab_rows(
                llm_state_dict["embed_tokens.weight"], vlm_vocab_size
            )
        model.model.language_model.load_state_dict(llm_state_dict)
        suffix = " (vocab extended for image tokens)" if needs_padding else ""
        print(f" ✓ Language model loaded{suffix}")

        # [3/3] lm_head, padded the same way.
        print(f" [3/3] Loading lm_head...")
        lm_head_weight = llm_model.lm_head.weight.detach()
        if needs_padding:
            print(f" ⚠ Extending lm_head: {llm_vocab_size} -> {vlm_vocab_size}")
            lm_head_weight = cls._pad_vocab_rows(lm_head_weight, vlm_vocab_size)
        with torch.no_grad():
            # copy_ casts into the model's own dtype/device. Rebinding `.data`
            # would leave lm_head in the source dtype (bf16 by default) while
            # the rest of the model stays in the default dtype, and forward
            # would fail. If lm_head is tied to embed_tokens, this writes the
            # shared Parameter.
            model.lm_head.weight.copy_(lm_head_weight)
        del llm_model, llm_state_dict, lm_head_weight
        print(f" ✓ lm_head loaded{suffix}")

        # Honour an explicitly requested dtype for the assembled model.
        requested_dtype = kwargs.get("torch_dtype")
        if isinstance(requested_dtype, torch.dtype):
            model = model.to(requested_dtype)

        # Move to a single device if one was named. Dispatch strategies and
        # dict maps are left on the default device.
        device_map = kwargs.get("device_map")
        if (
            isinstance(device_map, (str, int, torch.device))
            and device_map not in cls._DISPATCH_DEVICE_MAPS
        ):
            model = model.to(device_map)

        print(f"✓ Gemma3Pi06 model loaded successfully")
        return model

    def save_pretrained(self, save_directory, **kwargs):
        """Save model and config, marking the config as a Pi0.6 checkpoint.

        The marker is stored as a config attribute so that it is serialized
        with the config. ``from_pretrained`` then loads the saved weights
        instead of rebuilding from the base models. Passing the marker as a
        ``save_pretrained`` kwarg does not persist it.
        """
        # assumes Gemma3Pi06Config forwards unknown kwargs to PretrainedConfig,
        # so the attribute survives a save/load round-trip — TODO confirm
        setattr(self.config, self._CHECKPOINT_FLAG, True)
        return super().save_pretrained(save_directory, **kwargs)