SkyAsl commited on
Commit
5633111
·
verified ·
1 Parent(s): 4a765e2

Upload 2 files

Browse files
configuration_nanbeige_vlm.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class NanbeigeVLMConfig(PretrainedConfig):
5
+ model_type = "nanbeige_vlm"
6
+
7
+ def __init__(
8
+ self,
9
+ vision_model_id="google/siglip-so400m-patch14-384",
10
+ llm_model_id="Nanbeige/Nanbeige4.1-3B",
11
+ image_token_id=None,
12
+ **kwargs,
13
+ ):
14
+ super().__init__(**kwargs)
15
+ self.vision_model_id = vision_model_id
16
+ self.llm_model_id = llm_model_id
17
+ self.image_token_id = image_token_id
modeling_nanbeige_vlm.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import (
4
+ AutoModelForCausalLM,
5
+ AutoTokenizer,
6
+ PreTrainedModel,
7
+ SiglipVisionModel,
8
+ SiglipImageProcessor,
9
+ )
10
+ from .configuration_nanbeige_vlm import NanbeigeVLMConfig
11
+
12
+
13
+ class NanbeigeVLM(PreTrainedModel):
14
+ config_class = NanbeigeVLMConfig
15
+
16
+ def __init__(self, config: NanbeigeVLMConfig):
17
+ super().__init__(config)
18
+
19
+ self.vision_tower = SiglipVisionModel.from_pretrained(
20
+ config.vision_model_id, torch_dtype=torch.bfloat16
21
+ )
22
+ self.vision_tower.requires_grad_(False)
23
+
24
+ vision_hidden_size = self.vision_tower.config.hidden_size
25
+
26
+ try:
27
+ self.language_model = AutoModelForCausalLM.from_pretrained(
28
+ config.llm_model_id,
29
+ trust_remote_code=True,
30
+ torch_dtype=torch.bfloat16,
31
+ attn_implementation="flash_attention_2",
32
+ )
33
+ except (ImportError, ValueError):
34
+ self.language_model = AutoModelForCausalLM.from_pretrained(
35
+ config.llm_model_id,
36
+ trust_remote_code=True,
37
+ torch_dtype=torch.bfloat16,
38
+ )
39
+
40
+ llm_hidden_size = self.language_model.config.hidden_size
41
+
42
+ self.mm_projector = nn.Sequential(
43
+ nn.Linear(vision_hidden_size, llm_hidden_size),
44
+ nn.GELU(),
45
+ nn.Linear(llm_hidden_size, llm_hidden_size),
46
+ ).to(torch.bfloat16)
47
+
48
+ self.image_token_id = config.image_token_id
49
+ self._tokenizer = None
50
+ self._processor = None
51
+
52
+ def set_tokenizer(self, tokenizer):
53
+ self._tokenizer = tokenizer
54
+ self._processor = SiglipImageProcessor.from_pretrained(self.config.vision_model_id)
55
+ if self.image_token_id is None:
56
+ self.image_token_id = tokenizer.convert_tokens_to_ids("<image>")
57
+
58
+ def _merge_image_embeddings(self, input_ids, pixel_values):
59
+ image_features = self.vision_tower(pixel_values=pixel_values).last_hidden_state
60
+ image_embeds = self.mm_projector(image_features)
61
+ num_image_tokens = image_embeds.shape[1]
62
+
63
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
64
+
65
+ batch_size = input_ids.shape[0]
66
+ merged_embeds, merged_mask = [], []
67
+
68
+ for i in range(batch_size):
69
+ positions = (input_ids[i] == self.image_token_id).nonzero(as_tuple=True)[0]
70
+ if len(positions) == 0:
71
+ merged_embeds.append(inputs_embeds[i])
72
+ merged_mask.append(torch.ones(inputs_embeds.shape[1], device=input_ids.device))
73
+ continue
74
+
75
+ pos = positions[0].item()
76
+ img_mask = torch.ones(num_image_tokens, device=input_ids.device)
77
+ seq_mask = torch.ones(inputs_embeds.shape[1], device=input_ids.device)
78
+
79
+ merged_embeds.append(
80
+ torch.cat([inputs_embeds[i, :pos], image_embeds[i], inputs_embeds[i, pos + 1:]], dim=0)
81
+ )
82
+ merged_mask.append(
83
+ torch.cat([seq_mask[:pos], img_mask, seq_mask[pos + 1:]])
84
+ )
85
+
86
+ return torch.stack(merged_embeds, dim=0), torch.stack(merged_mask, dim=0)
87
+
88
+ def forward(self, input_ids, pixel_values, attention_mask=None, labels=None):
89
+ assert self.image_token_id is not None, "Call set_tokenizer() before forward()."
90
+
91
+ image_features = self.vision_tower(pixel_values=pixel_values).last_hidden_state
92
+ image_embeds = self.mm_projector(image_features)
93
+ num_image_tokens = image_embeds.shape[1]
94
+
95
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
96
+
97
+ batch_size = input_ids.shape[0]
98
+ merged_embeds, merged_mask, merged_labels = [], [], []
99
+
100
+ for i in range(batch_size):
101
+ positions = (input_ids[i] == self.image_token_id).nonzero(as_tuple=True)[0]
102
+ if len(positions) == 0:
103
+ merged_embeds.append(inputs_embeds[i])
104
+ if attention_mask is not None:
105
+ merged_mask.append(attention_mask[i])
106
+ if labels is not None:
107
+ merged_labels.append(labels[i])
108
+ continue
109
+
110
+ pos = positions[0].item()
111
+ merged_embeds.append(
112
+ torch.cat([inputs_embeds[i, :pos], image_embeds[i], inputs_embeds[i, pos + 1:]], dim=0)
113
+ )
114
+
115
+ if attention_mask is not None:
116
+ img_mask = torch.ones(num_image_tokens, device=attention_mask.device, dtype=attention_mask.dtype)
117
+ merged_mask.append(
118
+ torch.cat([attention_mask[i, :pos], img_mask, attention_mask[i, pos + 1:]])
119
+ )
120
+
121
+ if labels is not None:
122
+ img_labels = torch.full((num_image_tokens,), -100, device=labels.device, dtype=labels.dtype)
123
+ merged_labels.append(
124
+ torch.cat([labels[i, :pos], img_labels, labels[i, pos + 1:]])
125
+ )
126
+
127
+ combined_embeds = torch.stack(merged_embeds, dim=0)
128
+ combined_mask = torch.stack(merged_mask, dim=0) if attention_mask is not None else None
129
+ combined_labels = torch.stack(merged_labels, dim=0) if labels is not None else None
130
+
131
+ return self.language_model(
132
+ inputs_embeds=combined_embeds,
133
+ attention_mask=combined_mask,
134
+ labels=combined_labels,
135
+ )
136
+
137
+ @torch.no_grad()
138
+ def describe(self, image, prompt="Describe the image.", max_new_tokens=512, do_sample=False, temperature=0.6, top_p=0.95):
139
+ assert self._tokenizer is not None, "Call set_tokenizer() before describe()."
140
+ assert self._processor is not None
141
+
142
+ device = next(self.parameters()).device
143
+
144
+ pixel_values = self._processor(images=image, return_tensors="pt").pixel_values.to(device, dtype=torch.bfloat16)
145
+
146
+ full_prompt = f"<image>\n{prompt}"
147
+ input_ids = self._tokenizer(full_prompt, return_tensors="pt").input_ids.to(device)
148
+
149
+ combined_embeds, attention_mask = self._merge_image_embeddings(input_ids, pixel_values)
150
+
151
+ generate_kwargs = dict(
152
+ inputs_embeds=combined_embeds,
153
+ attention_mask=attention_mask,
154
+ max_new_tokens=max_new_tokens,
155
+ do_sample=do_sample,
156
+ eos_token_id=self._tokenizer.eos_token_id,
157
+ pad_token_id=self._tokenizer.eos_token_id,
158
+ )
159
+ if do_sample:
160
+ generate_kwargs["temperature"] = temperature
161
+ generate_kwargs["top_p"] = top_p
162
+
163
+ output_ids = self.language_model.generate(**generate_kwargs)
164
+ return self._tokenizer.decode(output_ids[0], skip_special_tokens=True)