"""Nanbeige vision-language model: SigLIP vision tower -> pooled projector -> causal LM."""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoModelForCausalLM,
    PreTrainedModel,
    SiglipVisionModel,
    SiglipImageProcessor,
)

from .configuration_nanbeige_vlm import NanbeigeVLMConfig


class PooledProjector(nn.Module):
    """Pool a square grid of vision patch features 2x2 and project to the LLM width.

    The patch sequence ``(B, N, C)`` is reshaped into a ``side x side`` grid
    (``side = sqrt(N)``), replicate-padded by one row/column on the
    bottom/right, average-pooled with a 2x2 window and stride 2, then passed
    through a two-layer MLP (Linear -> GELU -> Linear).

    Args:
        vision_hidden_size: Channel width ``C`` of the incoming patch features.
        llm_hidden_size: Output width (the language model's embedding size).
    """

    def __init__(self, vision_hidden_size, llm_hidden_size):
        super().__init__()
        # Patch-grid side of the default vision tower (27 x 27 = 729 patches).
        # Retained as a public attribute; forward() infers the actual side from
        # the number of patches it receives.
        self.grid_size = 27
        self.proj = nn.Sequential(
            nn.Linear(vision_hidden_size, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )

    def forward(self, image_features):
        """Return pooled, projected image tokens.

        Args:
            image_features: Tensor of shape ``(B, N, C)``; ``N`` must be a
                non-zero perfect square.

        Returns:
            Tensor of shape ``(B, ceil(side / 2) ** 2, llm_hidden_size)``.

        Raises:
            ValueError: If ``N`` is zero or not a perfect square.
        """
        batch, num_patches, channels = image_features.shape
        side = math.isqrt(num_patches)
        if side == 0 or side * side != num_patches:
            raise ValueError(
                "PooledProjector expects a non-empty square patch grid, "
                f"got {num_patches} patches"
            )
        x = image_features.permute(0, 2, 1).reshape(batch, channels, side, side)
        # One replicated row/column lets an odd side (e.g. 27) pool to
        # ceil(side / 2) instead of dropping its last row/column; for an even
        # side the extra row/column lies outside every 2x2 stride-2 window.
        x = F.pad(x, (0, 1, 0, 1), mode="replicate")
        x = F.avg_pool2d(x, kernel_size=2, stride=2)
        x = x.flatten(2).permute(0, 2, 1)
        return self.proj(x)


class NanbeigeVLM(PreTrainedModel):
    """Vision-language wrapper: SigLIP vision tower -> PooledProjector -> causal LM.

    Only ``mm_projector`` is built in ``__init__``; the vision tower, image
    processor and language model are loaded lazily by :meth:`set_tokenizer`
    from ``config.vision_model_id`` and ``config.llm_model_id``.
    """

    config_class = NanbeigeVLMConfig
    _tied_weights_keys = []
    _no_split_modules = []

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load the model, defaulting ``ignore_mismatched_sizes`` to ``True``.

        NOTE(review): with this default, checkpoint tensors whose shapes do not
        match the freshly built modules are re-initialised instead of raising,
        so a wrong config can silently discard weights.
        """
        kwargs.setdefault("ignore_mismatched_sizes", True)
        return super().from_pretrained(*args, **kwargs)

    @property
    def all_tied_weights_keys(self):
        """Report that this model has no tied weights."""
        return {}

    def __init__(self, config: NanbeigeVLMConfig):
        super().__init__(config)
        # Defaults match the original hard-coded widths; a config may override
        # them to pair the projector with different backbones.
        vision_hidden_size = getattr(config, "vision_hidden_size", None) or 1152
        llm_hidden_size = getattr(config, "llm_hidden_size", None) or 2560
        self.mm_projector = PooledProjector(vision_hidden_size, llm_hidden_size)
        self.image_token_id = config.image_token_id

        # Populated by set_tokenizer().
        self._tokenizer = None
        self._processor = None
        self._submodels_loaded = False
        self.vision_tower = None
        self.language_model = None

    def _device_and_dtype(self):
        """Return the device and dtype of the projector weights."""
        param = next(self.mm_projector.parameters())
        return param.device, param.dtype

    def set_tokenizer(self, tokenizer):
        """Attach ``tokenizer``; on first call also load the vision tower and LM.

        The image processor is (re)loaded from ``config.vision_model_id`` on
        every call. The vision tower (frozen, eval mode) and the language model
        (eval mode) are loaded once, on the projector's device and dtype. The
        LM token embeddings are resized to ``len(tokenizer)`` whenever their
        size differs, so a later tokenizer with another vocabulary stays
        consistent.

        Args:
            tokenizer: Hugging Face tokenizer; expected to contain
                ``config.image_token_id`` for :meth:`describe` to work.
        """
        self._tokenizer = tokenizer
        self._processor = SiglipImageProcessor.from_pretrained(
            self.config.vision_model_id
        )

        if not self._submodels_loaded:
            device, dtype = self._device_and_dtype()
            self.vision_tower = SiglipVisionModel.from_pretrained(
                self.config.vision_model_id, torch_dtype=dtype
            ).to(device).eval()
            self.vision_tower.requires_grad_(False)
            # NOTE: trust_remote_code executes model code fetched for
            # config.llm_model_id; only point the config at trusted repos.
            self.language_model = AutoModelForCausalLM.from_pretrained(
                self.config.llm_model_id,
                trust_remote_code=True,
                torch_dtype=dtype,
            ).to(device).eval()
            self._submodels_loaded = True

        embedding_rows = self.language_model.get_input_embeddings().weight.shape[0]
        if embedding_rows != len(tokenizer):
            self.language_model.resize_token_embeddings(len(tokenizer))

    @torch.no_grad()
    def describe(self, image, prompt="Describe the image.", max_new_tokens=256):
        """Generate text about ``image`` with greedy decoding.

        The image token (looked up from ``image_token_id``) is prefixed to the
        prompt unless the prompt already contains it. The first occurrence's
        embedding is replaced by the projected image embeddings before calling
        ``generate``.

        Args:
            image: Image input accepted by ``SiglipImageProcessor``.
            prompt: Instruction text placed after the image token.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The decoded generation with special tokens removed.

        Raises:
            RuntimeError: If :meth:`set_tokenizer` has not been called.
            ValueError: If the tokenized prompt contains no image token.
        """
        if self._processor is None:
            raise RuntimeError("Call set_tokenizer() first.")
        device, dtype = self._device_and_dtype()

        pixel_values = self._processor(
            images=image, return_tensors="pt"
        ).pixel_values.to(device, dtype=dtype)
        image_features = self.vision_tower(pixel_values=pixel_values).last_hidden_state
        image_embeds = self.mm_projector(image_features)

        # Without the image token in the text there is no slot for the image
        # embeddings and the LM would answer without seeing the image.
        image_token = self._tokenizer.convert_ids_to_tokens(self.image_token_id)
        if image_token and image_token in prompt:
            full_prompt = prompt
        else:
            full_prompt = f"{image_token}\n{prompt}"
        input_ids = self._tokenizer(full_prompt, return_tensors="pt").input_ids.to(device)
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        positions = (input_ids[0] == self.image_token_id).nonzero(as_tuple=True)[0]
        if positions.numel() == 0:
            raise ValueError(
                f"Prompt does not tokenize to image token id {self.image_token_id}; "
                "the tokenizer must register the image token."
            )
        p = positions[0].item()
        # Splice: one image-token embedding -> all projected image embeddings (batch 1).
        inputs_embeds = torch.cat(
            [
                inputs_embeds[0, :p],
                image_embeds[0].to(inputs_embeds.dtype),
                inputs_embeds[0, p + 1:],
            ],
            dim=0,
        ).unsqueeze(0)

        output_ids = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=torch.ones(
                inputs_embeds.shape[:2], dtype=torch.long, device=device
            ),
            max_new_tokens=max_new_tokens,
            do_sample=False,
            repetition_penalty=1.3,
            eos_token_id=self._tokenizer.eos_token_id,
            pad_token_id=self._tokenizer.eos_token_id,
        )
        return self._tokenizer.decode(output_ids[0], skip_special_tokens=True)