Upload modeling_vila.py
Browse files- modeling_vila.py +9 -1
modeling_vila.py
CHANGED
|
@@ -212,6 +212,7 @@ class VILAPretrainedModel(PreTrainedModel):
|
|
| 212 |
self.vision_tower = self.vision_tower.cuda()
|
| 213 |
# setting device_map to "auto" can automatically shard the llm across different devices
|
| 214 |
self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)
|
|
|
|
| 215 |
|
| 216 |
# NOTE(ligeng): hard code to set padding_side to left
|
| 217 |
self.tokenizer.padding_side = "left"
|
|
@@ -221,6 +222,12 @@ class VILAPretrainedModel(PreTrainedModel):
|
|
| 221 |
self.post_config()
|
| 222 |
self.is_loaded = True
|
| 223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
assert (
|
| 225 |
self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
|
| 226 |
), "At least one of the components must be instantiated."
|
|
@@ -628,7 +635,7 @@ class VILAForCasualLM(VILAPretrainedModel):
|
|
| 628 |
self.encoders[name].end_tokens = None
|
| 629 |
|
| 630 |
# Extract text and media embeddings
|
| 631 |
-
text_embeds = self.llm.model.embed_tokens(input_ids)
|
| 632 |
if media is not None:
|
| 633 |
media_embeds = self.__embed_media_tokens(media, media_config)
|
| 634 |
else:
|
|
@@ -712,6 +719,7 @@ class VILAForCasualLM(VILAPretrainedModel):
|
|
| 712 |
dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
|
| 713 |
embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
|
| 714 |
continue
|
|
|
|
| 715 |
embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
|
| 716 |
return embeds
|
| 717 |
|
|
|
|
| 212 |
self.vision_tower = self.vision_tower.cuda()
|
| 213 |
# setting device_map to "auto" can automatically shard the llm across different devices
|
| 214 |
self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)
|
| 215 |
+
self.llm_model_embed_tokens = self.llm.model.embed_tokens
|
| 216 |
|
| 217 |
# NOTE(ligeng): hard code to set padding_side to left
|
| 218 |
self.tokenizer.padding_side = "left"
|
|
|
|
| 222 |
self.post_config()
|
| 223 |
self.is_loaded = True
|
| 224 |
|
| 225 |
+
self.llm_only_need_embed = kwargs.get("llm_only_need_embed", False)
|
| 226 |
+
if self.llm_only_need_embed:
|
| 227 |
+
print("We only need the embed_tokens in llm.")
|
| 228 |
+
del self.llm
|
| 229 |
+
self.llm = None
|
| 230 |
+
|
| 231 |
assert (
|
| 232 |
self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
|
| 233 |
), "At least one of the components must be instantiated."
|
|
|
|
| 635 |
self.encoders[name].end_tokens = None
|
| 636 |
|
| 637 |
# Extract text and media embeddings
|
| 638 |
+
text_embeds = self.llm_model_embed_tokens(input_ids)
|
| 639 |
if media is not None:
|
| 640 |
media_embeds = self.__embed_media_tokens(media, media_config)
|
| 641 |
else:
|
|
|
|
| 719 |
dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
|
| 720 |
embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
|
| 721 |
continue
|
| 722 |
+
media[name] = [a.to(torch.bfloat16) for a in media[name]]
|
| 723 |
embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
|
| 724 |
return embeds
|
| 725 |
|