Instructions to use nvidia/omnivinci with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use nvidia/omnivinci with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("nvidia/omnivinci", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import copy | |
| import json | |
| import logging | |
| import numpy as np | |
| import os | |
| import os.path | |
| import os.path as osp | |
| import shutil | |
| import warnings | |
| from abc import ABC | |
| from collections import OrderedDict, defaultdict, deque | |
| from copy import deepcopy | |
| from itertools import chain | |
| from threading import Thread | |
| from typing import Any, Dict, List, Optional, Tuple, Union | |
| import torch | |
| import torch.distributed as dist | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision | |
| from einops import rearrange | |
| from PIL import Image | |
| from transformers import ( | |
| AutoConfig, | |
| AutoModel, | |
| AutoProcessor, | |
| AutoTokenizer, | |
| GenerationConfig, | |
| LogitsProcessor, | |
| PretrainedConfig, | |
| PreTrainedModel, | |
| Qwen2Config, | |
| Qwen2ForCausalLM, | |
| Qwen2PreTrainedModel, | |
| TextIteratorStreamer, | |
| WhisperFeatureExtractor, | |
| ) | |
| from transformers.modeling_outputs import CausalLMOutputWithPast | |
| from transformers.modeling_utils import ContextManagers, no_init_weights | |
| from .auto_processor import VILAProcessor | |
| from .base_projector import MultimodalProjector, MultimodalProjectorConfig | |
| from .sound_base_projector import SoundMultimodalProjector, SoundMultimodalProjectorConfig | |
| from .speech_base_projector import SpeechMultimodalProjector, SpeechMultimodalProjectorConfig | |
| from .builder import build_llm_and_tokenizer | |
| from .configuration_vila import VILAConfig | |
| from .constants import * | |
| from .conversation import SeparatorStyle, default_conversation | |
| from .distributed import all_gather as vila_all_gather | |
| from .media import extract_media | |
| from .media_encoder import BasicImageEncoder, BasicVideoEncoder, TSPVideoEncoder, BasicSoundEncoder, CacheFeatures | |
| from .mm_utils import process_image, process_images | |
| from .model_utils_packing import set_seqlens_in_batch | |
| from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerDynamicS2, SiglipVisionTowerS2 | |
| from .tokenizer_utils import tokenize_conversation | |
| from .utils import get_model_config, load_tokenizer_then_handle_media_tokens_and_chat_template | |
| from .constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS_VILA, NUM_EXTRA_TOKENS_XVILA | |
| from .qwen_audio_encoder import Qwen2AudioTower | |
| import whisper | |
| from .audio_encoder import AudioTower | |
def build_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
    """Create the vision multimodal projector.

    Returns ``None`` when ``model_type_or_path`` is ``None``. When
    ``config.resume_path`` is truthy, ``model_type_or_path`` must be an
    existing path and the projector is loaded from it; otherwise the value is
    passed to ``MultimodalProjectorConfig`` and a new projector is built.
    """
    if model_type_or_path is None:
        return None

    if not config.resume_path:
        # Fresh projector: the argument selects the projector type.
        projector_config = MultimodalProjectorConfig(model_type_or_path)
        return MultimodalProjector(projector_config, config)

    # Resuming: the argument is a checkpoint location that has to exist.
    assert os.path.exists(model_type_or_path), f"Resume mm projector path {model_type_or_path} does not exist!"
    return MultimodalProjector.from_pretrained(model_type_or_path, config)
def build_speech_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
    """Build speech multimodal projector from path or configuration.

    Args:
        model_type_or_path: Checkpoint path (when resuming) or projector type
            string; ``None`` disables the projector.
        config: Model configuration. ``config.resume_path`` selects between
            loading and fresh construction; ``config.model_dtype`` gives the
            dtype, either as a dtype object or as a string expression.

    Returns:
        The projector, or ``None`` when ``model_type_or_path`` is ``None``.
    """
    if model_type_or_path is None:
        return None

    # ``model_dtype`` can already be a dtype object (``from_pretrained`` stores
    # the raw ``torch_dtype`` argument on the config); only strings need
    # resolving. This mirrors ``build_sound_mm_projector``.
    model_dtype = config.model_dtype
    if isinstance(model_dtype, str):
        # NOTE(security): eval() executes the config string; only load configs
        # from trusted sources.
        model_dtype = eval(model_dtype)

    if config.resume_path:
        assert os.path.exists(model_type_or_path), f"Resume speech mm projector path {model_type_or_path} does not exist!"
        return SpeechMultimodalProjector.from_pretrained(model_type_or_path, config, torch_dtype=model_dtype)

    speech_mm_projector_cfg = SpeechMultimodalProjectorConfig(model_type_or_path)
    return SpeechMultimodalProjector(speech_mm_projector_cfg, config).to(model_dtype)
def build_sound_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
    """Create the sound multimodal projector.

    Returns ``None`` when ``model_type_or_path`` is ``None``. The dtype comes
    from ``config.model_dtype``; a string value is evaluated to obtain the
    dtype object. With a truthy ``config.resume_path`` the projector is loaded
    from the existing path, otherwise it is built from a
    ``SoundMultimodalProjectorConfig`` and cast to the dtype.
    """
    if model_type_or_path is None:
        return None

    dtype_spec = config.model_dtype
    target_dtype = eval(dtype_spec) if type(dtype_spec) is str else dtype_spec

    if not config.resume_path:
        projector_config = SoundMultimodalProjectorConfig(model_type_or_path)
        fresh_projector = SoundMultimodalProjector(projector_config, config)
        return fresh_projector.to(target_dtype)

    assert os.path.exists(model_type_or_path), f"Resume sound mm projector path {model_type_or_path} does not exist!"
    return SoundMultimodalProjector.from_pretrained(model_type_or_path, config, torch_dtype=target_dtype)
def check_dot_in_model_path(model_path: str):
    """Report whether the model path contains a ``.`` character.

    For an existing directory the absolute path is inspected; for anything
    else the string is inspected as given.
    """
    inspected = osp.abspath(model_path) if osp.isdir(model_path) else model_path
    return "." in inspected
def get_vila_version(model_path: str) -> str:
    """Return the first known VILA version tag found in ``model_path``.

    Matching is case-insensitive and follows the order of the tag list below;
    ``None`` is returned when no tag occurs in the path.
    """
    known_tags = ("vila1.5", "vila-u", "longvila", "nvila", "vila-m3")
    lowered_path = model_path.lower()
    return next((tag for tag in known_tags if tag in lowered_path), None)
def generate_jinja_template(conv_mode: str) -> str:
    """Return the Jinja chat template source for a conversation mode.

    Supported modes are ``"vicuna_v1"``, ``"llama_3"`` and ``"hermes_2"``.
    Each template declares a system prompt, the two role prefixes and a
    separator, then renders every entry of ``messages``. The ``vicuna_v1`` and
    ``llama_3`` templates also emit the assistant prefix when the last message
    comes from the user.

    Raises:
        NotImplementedError: For any other ``conv_mode``.
    """
    if conv_mode == "vicuna_v1":
        return """{% set system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. " %}
{% set roles = ["user", "assistant"] %}
{% set sep = " " %}
{{ system_prompt }}
{% for message in messages %}
{% if message['role'] == roles[0] %}
{{ "USER: " }}{{ sep }}{{ message['content'] }}{{ sep }}
{% else %}
{{ "ASSISTANT: " }}{{ sep }}{{ message['content'] }}{{ sep }}
{% endif %}
{% endfor %}
{% if messages[-1]['role'] == 'user' %}
{{ "ASSISTANT:" }}
{% endif %}
"""
    elif conv_mode == "llama_3":
        return """{% set system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|>" %}
{% set roles = ["<|start_header_id|>user<|end_header_id|>\\n\\n", "<|start_header_id|>assistant<|end_header_id|>\\n\\n"]%}
{% set sep = "<|eot_id|>" %}
{{ system_prompt }}
{% for message in messages %}
{% if message['role'] == 'user' %}
{{ roles[0] }}{{ message['content'] }}{{ sep }}
{% else %}
{{ roles[1] }}{{ message['content'] }}{{ sep }}
{% endif %}
{% endfor %}
{% if messages[-1]['role'] == 'user' %}
{{ roles[1] }}
{% endif %}
"""
    elif conv_mode == "hermes_2":
        # NOTE(review): unlike the two templates above, this one has no
        # trailing assistant prefix — confirm that this is intended.
        return """{% set system_prompt = "<|im_start|>system\nAnswer the questions." %}
{% set roles = ["<|im_start|>user\n", "<|im_start|>assistant\n"] %}
{% set sep = "<|im_end|>" %}
{{ system_prompt }}{{ sep }}
{% for message in messages %}
{% if message['role'] == 'user' %}
{{ roles[0] }}{{ message['content'] }}{{ sep }}
{% else %}
{{ roles[1] }}{{ message['content'] }}{{ sep }}
{% endif %}
{% endfor %}"""
    else:
        raise NotImplementedError(f"Jinja template generation is not implemented for {conv_mode}.")
def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
    """Create the vision tower and record its feature width on ``config``.

    Returns ``None`` when ``model_name_or_path`` is ``None``. When resuming
    (truthy ``config.resume_path``, and the path does not contain "radio") the
    tower family is taken from the first architecture name of the saved
    config; otherwise it is taken from the path string. Only SigLIP towers are
    supported; ``config.s2`` / ``config.dynamic_s2`` pick the variant.

    Raises:
        NotImplementedError: If the tower name does not contain "siglip".
    """
    # No path means no vision tower at all.
    if model_name_or_path is None:
        return None

    tower_name = model_name_or_path
    if config.resume_path and "radio" not in model_name_or_path:
        assert os.path.exists(model_name_or_path), f"Resume vision tower path {model_name_or_path} does not exist!"
        saved_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        tower_name = saved_cfg.architectures[0].lower()

    multi_scale = getattr(config, "s2", False)
    dynamic_multi_scale = getattr(config, "dynamic_s2", False)

    if "siglip" not in tower_name:
        raise NotImplementedError(f"Unknown vision tower: {model_name_or_path}")

    if dynamic_multi_scale:
        tower_cls = SiglipVisionTowerDynamicS2
    elif multi_scale:
        tower_cls = SiglipVisionTowerS2
    else:
        tower_cls = SiglipVisionTower
    tower = tower_cls(model_name_or_path, config)

    # The S2 variants expose the width on the tower itself; the plain tower
    # exposes it through its config.
    if multi_scale or dynamic_multi_scale:
        config.mm_hidden_size = tower.hidden_size
    else:
        config.mm_hidden_size = tower.config.hidden_size
    return tower
def build_audio_tower(model_name_or_path: str, config: PretrainedConfig, encoder_type: str) -> PreTrainedModel:
    """Create the audio tower used for either sound or speech input.

    ``encoder_type`` must be ``"sound"`` or ``"speech"``. Returns ``None``
    when ``model_name_or_path`` is ``None``. Otherwise a ``Qwen2AudioTower``
    is built and 1280 is stored as ``config.sound_hidden_size`` or
    ``config.speech_hidden_size``, depending on ``encoder_type``.
    """
    assert encoder_type in ["sound", "speech"]

    # No path means no tower.
    if model_name_or_path is None:
        return None

    # Qwen2AudioTower is the only encoder wired up here.
    tower = Qwen2AudioTower(model_name_or_path, config)
    feature_dim = 1280

    hidden_size_field = {"sound": "sound_hidden_size", "speech": "speech_hidden_size"}.get(encoder_type)
    if hidden_size_field is None:
        raise NotImplementedError(f"Not implemented for this encoder: {model_name_or_path}")
    setattr(config, hidden_size_field, feature_dim)
    return tower
class VILAPretrainedModel(PreTrainedModel):
    """Base model combining an LLM with vision, speech and sound components."""

    # Configuration class associated with this model.
    config_class = VILAConfig
    # NOTE(review): Hugging Face models conventionally use "inputs_embeds";
    # confirm that "input_embeds" is intentional.
    main_input_name = "input_embeds"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    # Module classes listed here are declared as not to be split.
    _no_split_modules = ["Qwen2DecoderLayer", "SiglipEncoderLayer"]
def __init__(self, config: VILAConfig, *args, **kwargs):
    """Build all sub-components of the model from ``config``.

    Args:
        config: Top-level configuration; ``get_model_config`` must yield
            seven component configs from it.
        *args: Accepted for signature compatibility; not used here.
        **kwargs: ``device_map`` (default ``"auto"``) is forwarded to the LLM
            builder and, when it is ``"auto"`` or ``"cuda"``, the towers and
            projectors are moved to CUDA. ``llm_only_need_embed`` (default
            ``False``) drops the LLM after construction.

    Raises:
        ValueError: If ``get_model_config`` does not return seven entries.
    """
    super().__init__(config)
    self.config = config

    cfgs = get_model_config(config)
    if len(cfgs) != 7:
        raise ValueError(
            "`llm_cfg` `mm_projector_cfg` `speech_mm_projector_cfg` `sound_mm_projector_cfg` `vision_tower_cfg` `speech_tower_cfg` `sound_tower_cfg` not found in the config."
        )
    (
        llm_cfg,
        vision_tower_cfg,
        speech_tower_cfg,
        sound_tower_cfg,
        mm_projector_cfg,
        speech_mm_projector_cfg,
        sound_mm_projector_cfg,
    ) = cfgs

    # loading on auto by default
    device_map = kwargs.get("device_map", "auto")

    self.mm_projector = build_mm_projector(mm_projector_cfg, config)
    self.vision_tower = build_vision_tower(vision_tower_cfg, config)

    if speech_tower_cfg:
        self.speech_tower = build_audio_tower(speech_tower_cfg, config, encoder_type="speech")
        self.speech_mm_projector = build_speech_mm_projector(speech_mm_projector_cfg, config)
    if sound_tower_cfg:
        self.sound_tower = build_audio_tower(sound_tower_cfg, config, encoder_type="sound")
        self.sound_mm_projector = build_sound_mm_projector(sound_mm_projector_cfg, config)

    if device_map in ["auto", "cuda"]:
        # Every builder may return None, so each component is only moved when
        # it exists. Components that were never built end up as None, as
        # before.
        for component_name in (
            "mm_projector",
            "vision_tower",
            "speech_tower",
            "sound_tower",
            "speech_mm_projector",
            "sound_mm_projector",
        ):
            component = getattr(self, component_name, None)
            setattr(self, component_name, component.cuda() if component is not None else None)

    # device_map "auto" can automatically shard the llm across devices
    self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)
    self.llm_model_embed_tokens = self.llm.model.embed_tokens

    self.tokenizer.padding_side = "left"

    self.vocab_size = len(self.tokenizer)
    self.update_vocab_size = lambda: setattr(self, "vocab_size", len(self.tokenizer))

    self.encoders = {}
    for name in ["image", "video", "speech", "sound"]:
        encoder_config = getattr(self.config, f"{name}_encoder")
        if isinstance(encoder_config, str):
            encoder_config = json.loads(encoder_config)
        # NOTE(review): this compares against the string "True"; a boolean
        # True would not match — confirm the config always stores a string.
        if encoder_config.get("embed_time", False) == "True":
            if "trope_dim" not in encoder_config and encoder_config.get("time_embed_type", "") in ["pixel", "lang"]:
                encoder_config["trope_dim"] = self.config.hidden_size // 2
                print(f"Warning: trope_dim not found in config, defaulting to hidden_size // 2: {encoder_config['trope_dim']}")

        # The target entry is not an encoder argument; tolerate its absence
        # instead of raising KeyError.
        encoder_config.pop("_target_", None)
        if name == "video":
            self.encoders[name] = TSPVideoEncoder(parent=self, **encoder_config)
        elif name == "image":
            self.encoders[name] = BasicImageEncoder(self)
        else:
            self.encoders[name] = BasicSoundEncoder(parent=self, **encoder_config)

    self.post_config()
    self.is_loaded = True

    self.llm_only_need_embed = kwargs.get("llm_only_need_embed", False)
    if self.llm_only_need_embed:
        print("We only need the embed_tokens in llm.")
        del self.llm
        self.llm = None
        torch.cuda.empty_cache()

    # getattr keeps this check from raising AttributeError for components
    # that were never assigned.
    assert any(
        getattr(self, component_name, None) is not None
        for component_name in ("llm", "vision_tower", "speech_tower", "mm_projector", "speech_mm_projector")
    ), "At least one of the components must be instantiated."
@classmethod
def copy_or_symlink_directory(cls, model_path, output_dir, copy=True):
    """Mirror every entry of ``model_path`` into ``output_dir``.

    Args:
        model_path: Directory whose direct children are mirrored.
        output_dir: Destination directory; created when missing.
        copy: When true, files and directories are copied; otherwise a
            symlink to each source entry is created.

    Any entry already present at a destination path is removed first.
    """
    os.makedirs(output_dir, exist_ok=True)

    for item in os.listdir(model_path):
        src_path = os.path.join(model_path, item)
        dst_path = os.path.join(output_dir, item)

        # lexists() also reports dangling symlinks, which exists() misses; a
        # stale link left in place would make the copy write through it or
        # make os.symlink fail.
        if os.path.lexists(dst_path):
            if os.path.islink(dst_path):
                os.unlink(dst_path)
            elif os.path.isdir(dst_path):
                shutil.rmtree(dst_path)
            else:
                os.remove(dst_path)

        if copy:
            if os.path.isdir(src_path):
                shutil.copytree(src_path, dst_path)
            else:
                shutil.copy2(src_path, dst_path)
            print(f"Copied {src_path} to {dst_path}")
        else:
            os.symlink(src_path, dst_path)
            print(f"Created symlink from {src_path} to {dst_path}")
@classmethod
def copy_remote_py_files(cls, output_dir, copy=True):
    """Place this module's sibling ``.py``/``.jinja`` files into ``output_dir``.

    An ``INSTRUCTIONS.md`` next to this file is written to
    ``output_dir/README.md``, followed by whatever README content was already
    there.

    Args:
        output_dir: Destination directory (must already exist).
        copy: When true the files are copied; otherwise they are symlinked.
    """
    current_folder = os.path.dirname(os.path.abspath(__file__))

    for file_name in os.listdir(current_folder):
        full_file_name = os.path.join(current_folder, file_name)

        if file_name == "INSTRUCTIONS.md":
            dst_fname = os.path.join(output_dir, "README.md")
            old_readme = ""
            if os.path.exists(dst_fname):
                # Context manager so the handle is closed before the file is
                # reopened for writing.
                with open(dst_fname, encoding="utf-8") as existing:
                    old_readme = existing.read()
            with open(full_file_name, encoding="utf-8") as src, open(dst_fname, "w", encoding="utf-8") as dst:
                dst.write(src.read())
                dst.write(old_readme)
            print("[HF] README", full_file_name, "to", dst_fname)

        if file_name.endswith((".py", ".jinja")) and os.path.isfile(full_file_name):
            if copy:
                shutil.copy(full_file_name, output_dir)
                print("[HF] copying", full_file_name, "to", output_dir)
            else:
                # symlink to ease development
                link_path = os.path.join(output_dir, file_name)
                # lexists() also catches a dangling link from an earlier run.
                if os.path.lexists(link_path):
                    os.remove(link_path)
                os.symlink(full_file_name, link_path)
                print("[HF] linking", full_file_name, "to", output_dir)
def save_pretrained(self, output_dir, state_dict=None, **kwargs):
    """Save every available component into its own sub-directory.

    Args:
        output_dir: Root directory. Components go to ``llm``,
            ``vision_tower``, ``speech_tower``, ``sound_tower``,
            ``mm_projector``, ``speech_mm_projector`` and
            ``sound_mm_projector`` beneath it.
        state_dict: Optional full-model state dict; defaults to
            ``self.state_dict()``.
        **kwargs: Accepted for signature compatibility; not used here.

    After the components, the top-level config is saved to ``output_dir`` and
    the remote-code files are copied there.
    """
    if state_dict is None:
        state_dict = self.state_dict()

    def _component_state_dict(prefix):
        """Return the entries under ``prefix`` with the prefix stripped.

        Selection is by key prefix rather than substring, so that for example
        ``sound_mm_projector.*`` is not mistaken for ``mm_projector.*`` and
        ``llm_model_embed_tokens.*`` is not mistaken for ``llm.*``. A prefix
        that follows a wrapper segment (``module.<prefix>...``) is accepted
        too.
        """
        wrapped_marker = "." + prefix
        selected = OrderedDict()
        for key, value in state_dict.items():
            if key.startswith(prefix):
                selected[key[len(prefix):]] = value
                continue
            position = key.find(wrapped_marker)
            if position != -1:
                selected[key[position + len(wrapped_marker):]] = value
        return selected

    if getattr(self, "tokenizer", None):
        self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))

    if self.get_llm():
        print(f"saving llm to {osp.join(output_dir, 'llm')}")
        self.llm.config._name_or_path = osp.join(output_dir, "llm")
        self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=_component_state_dict("llm."))
        self.config.llm_cfg = self.llm.config

    if self.get_vision_tower():
        print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
        self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower")
        self.vision_tower.vision_tower.save_pretrained(
            os.path.join(output_dir, "vision_tower"),
            state_dict=_component_state_dict("vision_tower.vision_tower."),
        )
        self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower"))
        self.config.vision_tower_cfg = self.vision_tower.config
        if hasattr(self.config.vision_tower_cfg, "auto_map"):
            if "radio" not in self.get_vision_tower().__class__.__name__.lower():
                delattr(self.config.vision_tower_cfg, "auto_map")

    if self.get_speech_tower():
        print(f"saving speech_tower to {osp.join(output_dir, 'speech_tower')}")
        self.speech_tower.config._name_or_path = osp.join(output_dir, "speech_tower").replace(
            "tmp-checkpoint", "checkpoint"
        )
        self.speech_tower.audio_tower.save_pretrained(
            os.path.join(output_dir, "speech_tower"),
            state_dict=_component_state_dict("speech_tower.audio_tower."),
        )
        self.config.speech_tower_cfg = self.speech_tower.config

    if self.get_sound_tower():
        print(f"saving sound_tower to {osp.join(output_dir, 'sound_tower')}")
        self.sound_tower.config._name_or_path = osp.join(output_dir, "sound_tower").replace(
            "tmp-checkpoint", "checkpoint"
        )
        self.sound_tower.audio_tower.save_pretrained(
            os.path.join(output_dir, "sound_tower"),
            state_dict=_component_state_dict("sound_tower.audio_tower."),
        )
        self.config.sound_tower_cfg = self.sound_tower.config

    if self.get_mm_projector():
        print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
        self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector")
        self.mm_projector.save_pretrained(
            os.path.join(output_dir, "mm_projector"),
            state_dict=_component_state_dict("mm_projector."),
        )
        self.config.mm_projector_cfg = self.mm_projector.config

    if self.get_speech_mm_projector():
        print(f"saving speech_mm_projector to {osp.join(output_dir, 'speech_mm_projector')}")
        self.speech_mm_projector.config._name_or_path = osp.join(output_dir, "speech_mm_projector").replace(
            "tmp-checkpoint", "checkpoint"
        )
        self.speech_mm_projector.save_pretrained(
            os.path.join(output_dir, "speech_mm_projector"),
            state_dict=_component_state_dict("speech_mm_projector."),
        )
        self.config.speech_mm_projector_cfg = self.speech_mm_projector.config

    if self.get_sound_mm_projector():
        print(f"saving sound_mm_projector to {osp.join(output_dir, 'sound_mm_projector')}")
        self.sound_mm_projector.config._name_or_path = osp.join(output_dir, "sound_mm_projector").replace(
            "tmp-checkpoint", "checkpoint"
        )
        self.sound_mm_projector.save_pretrained(
            os.path.join(output_dir, "sound_mm_projector"),
            state_dict=_component_state_dict("sound_mm_projector."),
        )
        self.config.sound_mm_projector_cfg = self.sound_mm_projector.config

    # update and save top-level config
    self.config._name_or_path = output_dir
    self.config.architectures = [self.__class__.__name__]
    self.config.save_pretrained(output_dir)

    # copy .py and README for next loading
    self.copy_remote_py_files(output_dir)
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Optional[str] = None,
    *model_args,
    config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    ignore_mismatched_sizes: bool = False,
    force_download: bool = False,
    local_files_only: bool = False,
    token: Optional[Union[str, bool]] = None,
    revision: str = "main",
    use_safetensors: Optional[bool] = None,
    weights_only: bool = True,
    **kwargs,
):
    """Instantiate the model from a saved configuration.

    The configuration is always re-read from
    ``pretrained_model_name_or_path``; the ``config`` argument and the other
    named loading options are accepted for signature compatibility but are
    not used here. The model is then created with
    ``cls._from_config(config, **kwargs)``.

    A non-``None`` ``torch_dtype`` keyword is stored on the config as both
    ``torch_dtype`` and ``model_dtype`` exactly as given; a string value is
    evaluated into a dtype object before being forwarded.
    """
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)

    torch_dtype = kwargs.get("torch_dtype", None)
    if torch_dtype is not None:
        config.torch_dtype = torch_dtype
        config.model_dtype = torch_dtype
        if isinstance(torch_dtype, str):
            # NOTE(security): eval() executes the caller-supplied string;
            # pass an actual dtype object when the value is not trusted.
            kwargs["torch_dtype"] = eval(torch_dtype)
    return cls._from_config(config, **kwargs)
def init_llm(self, llm_config, config, *args, **kwargs):
    """Create the language model and tokenizer and store them on ``self``.

    All arguments are forwarded to ``build_llm_and_tokenizer``. Also records
    ``pad_token_list``, ``vocab_size`` and the ``update_vocab_size`` callback,
    and resets ``grammar_compiler`` to ``None``.

    Returns:
        The ``(llm, tokenizer)`` pair now stored on ``self``.
    """
    llm, tokenizer = build_llm_and_tokenizer(llm_config, config, *args, **kwargs)
    self.llm = llm
    self.tokenizer = tokenizer

    # NOTE(review): the third entry is a token produced by tokenize(), while
    # the first two are ids — confirm that consumers expect this mix.
    qwen_end_token = self.tokenizer.tokenize("<|endoftext|>")[0]  # for Qwen
    self.pad_token_list = (
        self.tokenizer.pad_token_id,
        self.tokenizer.eos_token_id,
        qwen_end_token,
    )

    self.vocab_size = len(self.tokenizer)
    self.update_vocab_size = lambda: setattr(self, "vocab_size", len(self.tokenizer))

    # The XGrammar compiler is created lazily, only when JSON output is
    # requested during inference.
    self.grammar_compiler = None
    return self.llm, self.tokenizer
def post_config(self):
    """Sync the train/eval mode with the LLM and fill missing config fields.

    The model takes over ``self.llm.training`` and switches to the matching
    mode. Then every ``*_cfg`` field of ``self.config`` that is still unset
    is filled from the corresponding component's ``config``, provided that
    component exists.
    """
    self.training = self.llm.training
    if self.training:
        self.train()
    else:
        self.eval()

    # (config field, attribute holding the component)
    component_fields = (
        ("llm_cfg", "llm"),
        ("vision_tower_cfg", "vision_tower"),
        ("mm_projector_cfg", "mm_projector"),
        ("speech_tower_cfg", "speech_tower"),
        ("sound_tower_cfg", "sound_tower"),
        ("speech_mm_projector_cfg", "speech_mm_projector"),
        ("sound_mm_projector_cfg", "sound_mm_projector"),
    )
    for cfg_field, attr_name in component_fields:
        if getattr(self.config, cfg_field, None) is not None:
            continue
        # A component can be absent or explicitly set to None (the
        # constructor assigns None to unbuilt towers), so hasattr() is not a
        # sufficient guard here.
        component = getattr(self, attr_name, None)
        if component is not None:
            setattr(self.config, cfg_field, component.config)
def get_llm(self):
    """Return the language model, or ``None`` when it is not set.

    If the attribute holds a list, its first element is returned.
    """
    llm = getattr(self, "llm", None)
    return llm[0] if isinstance(llm, list) else llm

def get_lm_head(self):
    """Return the ``lm_head`` of the language model, or ``None`` if absent."""
    return getattr(self.get_llm(), "lm_head", None)

def get_vision_tower(self):
    """Return the vision tower (first element if a list), or ``None``."""
    vision_tower = getattr(self, "vision_tower", None)
    return vision_tower[0] if isinstance(vision_tower, list) else vision_tower

def get_speech_tower(self):
    """Return the speech tower (first element if a list), or ``None``."""
    speech_tower = getattr(self, "speech_tower", None)
    return speech_tower[0] if isinstance(speech_tower, list) else speech_tower

def get_sound_tower(self):
    """Return the sound tower (first element if a list), or ``None``."""
    sound_tower = getattr(self, "sound_tower", None)
    return sound_tower[0] if isinstance(sound_tower, list) else sound_tower

def get_mm_projector(self):
    """Return the vision projector (first element if a list), or ``None``."""
    mm_projector = getattr(self, "mm_projector", None)
    return mm_projector[0] if isinstance(mm_projector, list) else mm_projector

def get_sound_mm_projector(self):
    """Return the sound projector (first element if a list), or ``None``."""
    sound_mm_projector = getattr(self, "sound_mm_projector", None)
    return sound_mm_projector[0] if isinstance(sound_mm_projector, list) else sound_mm_projector

# The second, identical definition of get_speech_tower that used to follow
# here was removed; it only shadowed the one above.

def get_speech_mm_projector(self):
    """Return the speech projector (first element if a list), or ``None``."""
    speech_mm_projector = getattr(self, "speech_mm_projector", None)
    return speech_mm_projector[0] if isinstance(speech_mm_projector, list) else speech_mm_projector
def freezed_module_patch(self):
    """Put frozen sub-modules back into eval mode during training.

    Huggingface calls model.train() at each training_step. To keep the
    expected behaviour of layers such as dropout or batchnorm, every present
    component whose ``tune_*`` config flag is false is switched to eval().
    Nothing happens when the model is not in training mode, and the language
    model is left untouched.
    """
    if not self.training:
        return

    # (accessor, config flag that marks the component as trainable)
    frozen_candidates = (
        (self.get_vision_tower, "tune_vision_tower"),
        (self.get_speech_tower, "tune_speech_tower"),
        (self.get_sound_tower, "tune_sound_tower"),
        (self.get_mm_projector, "tune_mm_projector"),
        (self.get_speech_mm_projector, "tune_speech_mm_projector"),
        (self.get_sound_mm_projector, "tune_sound_mm_projector"),
    )
    for fetch_component, tune_flag in frozen_candidates:
        component = fetch_component()
        if component and not getattr(self.config, tune_flag, False):
            component.eval()
| class VILAForCausalLM(VILAPretrainedModel): | |
def __init__(self, config: VILAConfig, *args, **kwargs):
    """Initialize by forwarding all arguments unchanged to the parent constructor."""
    super().__init__(config, *args, **kwargs)
def merge_features_for_dynamic_s2(self, image_features, block_sizes):
    """Regroup per-block features into one multi-scale feature map per image.

    ``image_features`` is consumed front to back. For every entry of
    ``block_sizes``:

    * ``None``: a single block is taken, reshaped to ``1 x C x H x W`` and
      repeated along the channel axis once per scale; its block size becomes
      ``(1, 1)``.
    * otherwise: for each scale but the last, ``(scale // scales[0]) ** 2``
      blocks are merged into one map; for the last scale,
      ``block_size[0] * block_size[1]`` blocks are merged. All maps are
      resized (area interpolation in float32) to the size of the map at
      ``resize_output_to_scale_idx`` and concatenated along channels.

    Returns:
        ``(features_per_image, new_block_sizes)``.
    """
    tower = self.get_vision_tower()
    scales = tower.scales
    target_idx = tower.resize_output_to_scale_idx

    merged_per_image = []
    merged_block_sizes = []
    cursor = 0  # index of the next unconsumed block in image_features

    for block_size in block_sizes:
        if block_size is None:
            single = image_features[cursor : cursor + 1]
            side = int(single.shape[1] ** 0.5)
            single = rearrange(single, "1 (h w) c -> 1 c h w", h=side)
            merged_per_image.append(single.repeat(1, len(scales), 1, 1))
            merged_block_sizes.append((1, 1))
            cursor += 1
            continue

        maps_per_scale = []
        for scale in scales[:-1]:
            splits = scale // scales[0]
            block_count = splits**2
            maps_per_scale.append(
                self.merge_chessboard(
                    image_features[cursor : cursor + block_count],
                    num_split_h=splits,
                    num_split_w=splits,
                )
            )  # 1 * C * H * W
            cursor += block_count

        last_block_count = block_size[0] * block_size[1]
        maps_per_scale.append(
            self.merge_chessboard(
                image_features[cursor : cursor + last_block_count],
                num_split_h=block_size[0],
                num_split_w=block_size[1],
            )
        )  # 1 * C * H * W
        cursor += last_block_count

        # Bring every scale to a common spatial size, then stack channels.
        output_size = maps_per_scale[target_idx].shape[-2:]
        resized_maps = []
        for feature_map in maps_per_scale:
            resized = F.interpolate(feature_map.to(torch.float32), size=output_size, mode="area")
            resized_maps.append(resized.to(feature_map.dtype))
        merged_per_image.append(torch.cat(resized_maps, dim=1))

        if target_idx in (len(scales) - 1, -1):
            merged_block_sizes.append(block_size)
        else:
            target_splits = scales[target_idx] // scales[0]
            merged_block_sizes.append((target_splits, target_splits))

    assert cursor == len(image_features)
    return merged_per_image, merged_block_sizes
@staticmethod
def split_chessboard(x, num_split_h, num_split_w):
    """Cut a batch of feature maps into a grid of tiles stacked on the batch axis.

    Args:
        x: Tensor of shape ``b * c * h * w``.
        num_split_h: Number of tiles along the height.
        num_split_w: Number of tiles along the width.

    Returns:
        Tensor of shape
        ``(b * num_split_h * num_split_w) * c * (h / num_split_h) * (w / num_split_w)``.
        Tiles are ordered row by row; each tile contributes all ``b`` samples.
    """
    # Declared static because the function takes no ``self`` while callers
    # invoke it as ``self.split_chessboard(...)``.
    B, C, H, W = x.shape
    assert H % num_split_h == 0 and W % num_split_w == 0
    tile_h, tile_w = H // num_split_h, W // num_split_w
    tiles = [
        x[:, :, i * tile_h : (i + 1) * tile_h, j * tile_w : (j + 1) * tile_w]
        for i in range(num_split_h)
        for j in range(num_split_w)
    ]
    return torch.cat(tiles, dim=0)
@staticmethod
def merge_chessboard(x, num_split_h, num_split_w):
    """Reassemble tiles stacked on the batch axis into whole feature maps.

    Args:
        x: Either ``b * n * c`` (``n`` is reshaped into a square grid) or an
            already spatial 4-D tensor in ``b * c * h * w`` layout. The batch
            axis holds ``num_split_h * num_split_w`` tiles, row by row.
        num_split_h: Number of tiles along the height.
        num_split_w: Number of tiles along the width.

    Returns:
        Tensor of shape ``b' * c * (h * num_split_h) * (w * num_split_w)``
        with ``b' = b / (num_split_h * num_split_w)``.
    """
    # Declared static because the function takes no ``self`` while callers
    # invoke it as ``self.merge_chessboard(...)``.
    batch = x.shape[0]
    if x.dim() == 3:
        side = int(x.shape[1] ** 0.5)
        x = rearrange(x, "b (h w) c -> b c h w", h=side, w=side)

    tile_count = num_split_h * num_split_w
    assert batch % tile_count == 0
    per_tile = batch // tile_count

    # Join the tiles of each grid row along the width, then the rows along
    # the height.
    grid_rows = [
        torch.cat(
            [
                x[(i * num_split_w + j) * per_tile : (i * num_split_w + j + 1) * per_tile]
                for j in range(num_split_w)
            ],
            dim=-1,
        )
        for i in range(num_split_h)
    ]
    return torch.cat(grid_rows, dim=-2)
def encode_video(self, inp, block_sizes: Optional[Optional[Tuple[int, ...]]] = None, mm_info: Optional[dict] = None, num_frames: Optional[List[int]] = None):
    """Encode a batch of videos into projected visual features.

    Each entry of ``inp`` is either a raw frame tensor (its first dimension is
    used as the frame count) or a ``CacheFeatures`` object whose
    ``value['features']`` holds precomputed tower features. Raw entries go
    through the vision tower, cached ones are spliced back in at their original
    batch position, and the combined features go through the multimodal
    projector. ``mm_info`` and ``num_frames`` are accepted but not read here.

    Raises:
        ValueError: in the ``dynamic_s2`` path when ``block_sizes`` is given.
    """
    bs = len(inp)
    inp_block_sizes = block_sizes

    # Separate precomputed features from raw frame tensors in one pass.
    cached_items, cached_positions, raw_clips = [], [], []
    for position, item in enumerate(inp):
        if type(item) == CacheFeatures:
            cached_items.append(item)
            cached_positions.append(position)
        else:
            raw_clips.append(item)

    frames_per_clip = [clip.shape[0] for clip in raw_clips]
    images = torch.cat(raw_clips, dim=0) if len(raw_clips) > 0 else []

    if block_sizes is None:
        block_sizes = [None] * len(images)

    def _restore_cached(features):
        """Re-insert cached features at their batch positions and concatenate."""
        if len(cached_items) == 0:
            return features
        if len(features) > 0:
            features = torch.split(features, frames_per_clip)
        ordered = []
        cached_cursor = 0
        raw_cursor = 0
        for position in range(bs):
            if position in cached_positions:
                cached = cached_items[cached_cursor].value['features']
                ordered.append(cached.to(self.device, self.dtype))
                cached_cursor += 1
            else:
                ordered.append(features[raw_cursor])
                raw_cursor += 1
        assert len(ordered) == bs
        return torch.cat(ordered, dim=0)

    if not getattr(self.config, "dynamic_s2", False):
        image_features = self.get_vision_tower()(images) if len(images) > 0 else []
        image_features = _restore_cached(image_features)
        return self.get_mm_projector()(image_features)

    # dynamic_s2 path: tile the merged multi-scale features before projecting.
    if len(images) > 0:
        tower_out = self.get_vision_tower()(images)
        merged, new_block_sizes = self.merge_features_for_dynamic_s2(tower_out, block_sizes)
        tiles = [
            self.split_chessboard(feat, grid[0], grid[1])
            for feat, grid in zip(merged, new_block_sizes)
        ]  # list of B * C * H * W tensors
        image_features = torch.cat(
            [rearrange(tile, "b c h w -> b (h w) c") for tile in tiles], dim=0
        )  # B * N * C
    else:
        image_features = []
    image_features = _restore_cached(image_features)

    if inp_block_sizes is not None:
        raise ValueError(f"inp_block_sizes is not None: {inp_block_sizes}")
    # Every entry is treated as a single 1x1 block from here on.
    new_block_sizes = [(1, 1)] * len(image_features)

    image_features = image_features.to(self.device, self.dtype)
    image_features = self.get_mm_projector()(image_features)

    tiles_per_entry = [grid[0] * grid[1] for grid in new_block_sizes]
    outputs = []
    for chunk, grid in zip(image_features.split(tiles_per_entry, dim=0), new_block_sizes):
        stitched = self.merge_chessboard(chunk, grid[0], grid[1])  # 1 * C * H * W
        outputs.append(rearrange(stitched, "1 c h w -> (h w) c"))  # N * C
    # Stack into one tensor only when every entry has the same token count.
    if all(feat.shape[0] == outputs[0].shape[0] for feat in outputs):
        outputs = torch.stack(outputs, dim=0)
    return outputs
def encode_images(self, images, block_sizes: Optional[Optional[Tuple[int, ...]]] = None, mm_info: Optional[dict] = None, num_frames: Optional[List[int]] = None):
    """Encode images with the vision tower and project them for the language model.

    Without ``config.dynamic_s2`` this is simply tower followed by projector.
    With it, the tower output is merged per image, cut into tiles, projected,
    stitched back together and flattened to ``N * C`` per image; the per-image
    results are stacked into one tensor only if they all have the same length.
    ``mm_info`` and ``num_frames`` are accepted but not read here.
    """
    if block_sizes is None:
        block_sizes = [None] * len(images)

    if not getattr(self.config, "dynamic_s2", False):
        tower_out = self.get_vision_tower()(images)
        return self.get_mm_projector()(tower_out)

    tower_out = self.get_vision_tower()(images)
    per_image, grids = self.merge_features_for_dynamic_s2(tower_out, block_sizes)
    tiles = [
        self.split_chessboard(feat, grid[0], grid[1]) for feat, grid in zip(per_image, grids)
    ]  # list of B * C * H * W tensors
    tokens = torch.cat([rearrange(tile, "b c h w -> b (h w) c") for tile in tiles], dim=0)  # B * N * C
    tokens = self.get_mm_projector()(tokens)

    tiles_per_image = [grid[0] * grid[1] for grid in grids]
    outputs = []
    for chunk, grid in zip(tokens.split(tiles_per_image, dim=0), grids):
        stitched = self.merge_chessboard(chunk, grid[0], grid[1])  # 1 * C * H * W
        outputs.append(rearrange(stitched, "1 c h w -> (h w) c"))  # N * C
    if all(feat.shape[0] == outputs[0].shape[0] for feat in outputs):
        outputs = torch.stack(outputs, dim=0)
    return outputs
def encode_sound(self, sounds, mm_info: Optional[dict] = None):
    """Encode audio with the sound tower and project it for the language model.

    The tower returns ``(audio_features, audio_output_lengths)``. For a
    projector whose type contains ``"mlp_downsample"`` the features are split
    per clip, stacked, and projected afterwards; for any other projector they
    are projected first and then split per clip. When the tower reports no
    lengths the features are not split. ``mm_info`` is accepted but not read.
    """
    audio_features, audio_output_lengths = self.get_sound_tower()(sounds)

    # Decide whether the projector downsamples (and so must run after batching).
    use_fea_downsample = False
    if getattr(self.config, "sound_mm_projector", "") != "":
        use_fea_downsample = "mlp_downsample" in getattr(self.config, "sound_mm_projector", "")
    else:
        # NOTE(review): this fallback is assumed to apply only when
        # `sound_mm_projector` is unset/empty - confirm against upstream.
        projector_cfg = getattr(self.config, "sound_mm_projector_cfg", None)
        if projector_cfg is not None:
            if type(projector_cfg) == dict:
                use_fea_downsample = "mlp_downsample" in projector_cfg["sound_mm_projector_type"]
            elif type(projector_cfg) == SoundMultimodalProjectorConfig:
                use_fea_downsample = "mlp_downsample" in projector_cfg.sound_mm_projector_type

    if not use_fea_downsample:
        audio_features = self.get_sound_mm_projector()(audio_features)

    if audio_output_lengths is not None:
        # Split the flat feature sequence back into one chunk per clip.
        per_clip = []
        offset = 0
        for clip_len in audio_output_lengths:
            per_clip.append(audio_features[offset : offset + clip_len])
            offset += clip_len
        audio_features = per_clip

    if use_fea_downsample:
        audio_features = torch.stack(audio_features, dim=0)
        audio_features = self.get_sound_mm_projector()(audio_features)

    return audio_features
def train(self, mode: bool = True):
    """Set training (``mode=True``) or evaluation (``mode=False``) mode.

    Delegates to the parent class implementation and returns ``self`` so that
    calls can be chained.
    """
    super().train(mode)
    return self
def _embed(
    self,
    input_ids: torch.Tensor,
    media: Dict[str, List[torch.Tensor]],
    media_config: Dict[str, Dict[str, Any]],
    labels: Optional[torch.Tensor],
    attention_mask: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fuse text-token embeddings with encoded media into padded model inputs.

    Every media placeholder token in ``input_ids`` is replaced by the next
    embedding of the matching modality; the inserted positions are labelled
    ``IGNORE_INDEX``. When interleaving is enabled in the config and
    ``media["video_info"]`` is present, the audio track of each video is woven
    into that video's embedding segment by segment, and the corresponding
    sound placeholders are dropped instead of being expanded a second time.

    Args:
        input_ids: token ids, one row per sample.
        media: modality name -> list of media items; may also carry
            ``"video_info"`` / ``"audio_info"`` metadata. May be ``None``.
        media_config: modality name -> encoder configuration.
        labels: optional labels; defaults to all ``IGNORE_INDEX``.
        attention_mask: optional boolean mask; defaults to all ones.

    Returns:
        The ``(inputs, labels, attention_mask)`` triple produced by
        ``__batchify_sequence`` after truncation.

    Raises:
        ValueError: if a video with audio has no matching sound embedding, or
            if some media embeddings are left unconsumed.
    """
    # Work on copies: metadata entries are removed from `media` below.
    media = copy.deepcopy(media)
    media_config = copy.deepcopy(media_config)

    labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX)
    attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool)

    # Sequence parallelism is switched off here; the guarded branches stay inert.
    PROCESS_GROUP_MANAGER = None
    if PROCESS_GROUP_MANAGER is not None:
        for name in media:
            self.encoders[name].end_tokens = None

    # Extract text and media embeddings
    text_embeds = self.llm_model_embed_tokens(input_ids)

    mm_info = {}
    video_info = None
    audio_info = None
    if media is not None:
        # The metadata entries are not media; move them into `mm_info` before encoding.
        if "video_info" in media:
            video_info = media.pop("video_info")
            mm_info['video_info'] = video_info
        if "audio_info" in media:
            audio_info = media.pop("audio_info")
            mm_info['audio_info'] = audio_info
        media_embeds = self.__embed_media_tokens(media, media_config, mm_info)
    else:
        # no media was provided, so we just use an empty dict
        media_embeds = {}

    if PROCESS_GROUP_MANAGER is not None:
        media_embeds_video = []
        for i, images in enumerate(media_embeds["video"]):
            num_video_frame = media["video"][i].shape[0]
            media_embeds_video += torch.unbind(images.reshape(num_video_frame, -1, images.shape[-1]))
        media_embeds["video"] = deque(media_embeds_video)

    # This is a workaround to make sure the dummy embeddings are consumed
    while media_embeds.get("dummy"):
        dummy_embed = media_embeds["dummy"].popleft()
        text_embeds += torch.sum(dummy_embed) * 0

    # Number of sound embeddings that were merged into a video below.
    video_sound_embeds_idx = 0
    sep_embed = self.encoders["video"].embed_tokens("\n")
    text_embeds = text_embeds.to(self.dtype)
    sep_embed = sep_embed.to(text_embeds.dtype)

    # Based on segment_aud_indices_list and segment_vis_indices_list, build
    # interleaved visual/audio embeddings for every video that has audio.
    if video_info is not None and self.config.load_audio_in_video and self.config.interleaved_vis_aud_in_video:
        assert self.encoders["video"].end_tokens is None, "end_tokens must be None for interleaved vis-aud in video"
        new_video_embeds = deque()
        video_embeds_idx = 0
        for k in range(len(video_info)):
            if video_info[k] is None:
                continue
            for i in range(len(video_info[k])):
                clip_info = video_info[k][i]
                if not clip_info["has_audio"]:
                    new_video_embeds.append(media_embeds["video"][video_embeds_idx])
                    video_embeds_idx += 1
                    continue

                # Check bounds for sound embeddings
                if video_sound_embeds_idx >= len(media_embeds["sound"]):
                    raise ValueError(f"Sound embeddings index {video_sound_embeds_idx} out of bounds for video_info[{k}][{i}]")

                segment_aud_indices_list = clip_info["segment_aud_indices_list"]
                segment_vis_indices_list = clip_info["segment_vis_indices_list"]
                video_embed = media_embeds["video"][video_embeds_idx]
                sound_embed = media_embeds["sound"][video_sound_embeds_idx]

                vis_fea_len_per_frame = video_embed.shape[0] / clip_info["expected_frame_count"]
                aud_fea_len_per_stft_frame = sound_embed.shape[0] / audio_info[k][i]["new_audio_n_stft_frames"]
                # `k` indexes samples and `i` indexes the videos of sample `k`.
                is_last_video = k == len(video_info) - 1 and i == len(video_info[k]) - 1
                last_segment = len(segment_vis_indices_list) - 1

                vis_end = 0
                aud_end = 0
                _new_video_embed = []
                for j in range(len(segment_vis_indices_list)):
                    _vis_aud_fea = []
                    if len(segment_vis_indices_list[j]) > 0:
                        # Feature offset just past the last frame of this segment,
                        # clamped to the available features.
                        _vis_fea_end = int(np.ceil((segment_vis_indices_list[j][-1] + 1) * vis_fea_len_per_frame))
                        _vis_fea_end = min(_vis_fea_end, video_embed.shape[0])
                        if j == last_segment and is_last_video and _vis_fea_end != video_embed.shape[0]:
                            print(f"Warning: The number of last interleaved video features does not match the video feature length. Expected: {video_embed.shape[0]}, Got: {_vis_fea_end}")
                            _vis_fea_end = video_embed.shape[0]
                        _vis_aud_fea.append(video_embed[vis_end:_vis_fea_end])
                        vis_end = _vis_fea_end
                    # NOTE(review): separators are assumed to be appended even when a
                    # segment has no visual/audio features - confirm against upstream.
                    _vis_aud_fea.append(sep_embed)
                    if len(segment_aud_indices_list[j]) > 0:
                        _aud_fea_end = int(np.ceil(segment_aud_indices_list[j][-1] * aud_fea_len_per_stft_frame))
                        # Ensure we don't exceed the available features
                        _aud_fea_end = min(_aud_fea_end, sound_embed.shape[0])
                        _vis_aud_fea.append(sound_embed[aud_end:_aud_fea_end])
                        aud_end = _aud_fea_end
                    _vis_aud_fea.append(sep_embed)
                    _new_video_embed.append(torch.cat(_vis_aud_fea, dim=0))
                video_sound_embeds_idx += 1
                new_video_embeds.append(torch.cat(_new_video_embed, dim=0))
                video_embeds_idx += 1

        assert len(new_video_embeds) == len(media_embeds["video"]), "The number of new video embeddings does not match the number of original video embeddings."
        media_embeds["video"] = new_video_embeds

    # Remove padding. The token ids are filtered with the same mask as the
    # embeddings and labels so positions stay aligned for left-padded batches.
    batch_size = labels.shape[0]
    token_ids = [input_ids[k][attention_mask[k]].tolist() for k in range(batch_size)]
    text_embeds = [text_embeds[k][attention_mask[k]] for k in range(batch_size)]
    labels = [labels[k][attention_mask[k]] for k in range(batch_size)]

    # Build inverse mapping from token ID to media name
    media_tokens = {token_id: name for name, token_id in self.tokenizer.media_token_ids.items()}

    # Fuse text and media embeddings
    inputs_m, labels_m = [], []
    sound_embeds_idx = 0
    for k in range(batch_size):
        ids_k = token_ids[k]
        seq_len = len(labels[k])
        inputs_mk, labels_mk = [], []
        pos = 0
        while pos < seq_len:
            token = ids_k[pos]
            if token in media_tokens:
                name = media_tokens[token] if PROCESS_GROUP_MANAGER is None else "video"
                if token == self.tokenizer.media_token_ids["sound"]:
                    if self.config.interleaved_vis_aud_in_video and sound_embeds_idx < video_sound_embeds_idx:
                        # This sound was already woven into its video: discard
                        # the embedding and drop the placeholder.
                        media_embeds[name].popleft()
                        sound_embeds_idx += 1
                        pos += 1
                        continue
                    sound_embeds_idx += 1

                end = pos + 1
                embed = media_embeds[name].popleft()
                label = torch.full([embed.shape[0]], IGNORE_INDEX, device=labels[k].device, dtype=labels[k].dtype)
            else:
                # Consume the whole run of plain text tokens at once.
                end = pos
                while end < seq_len and ids_k[end] not in media_tokens:
                    end += 1
                embed = text_embeds[k][pos:end]
                label = labels[k][pos:end]

            inputs_mk.append(embed)
            labels_mk.append(label)
            pos = end
        inputs_m.append(torch.cat(inputs_mk, dim=0))
        labels_m.append(torch.cat(labels_mk, dim=0))
    inputs, labels = inputs_m, labels_m

    inputs[0] += sep_embed.mean() * 0  # dummy embedding

    # Check if all media embeddings are consumed
    for name in media_embeds:
        if media_embeds[name]:
            raise ValueError(f"Not all {name} embeddings are consumed! Still {len(media_embeds[name])} left.")

    # Truncate sequences to `model_max_length` as media embeddings are inserted
    inputs, labels = self.__truncate_sequence(inputs, labels)

    # Pad sequences to the longest one in the batch
    return self.__batchify_sequence(inputs, labels)
def __embed_media_tokens(
    self,
    media: Dict[str, List[torch.Tensor]],
    media_config: Dict[str, Dict[str, Any]],
    mm_info,
) -> Dict[str, List[torch.Tensor]]:
    """Run every modality's encoder and collect the resulting embeddings.

    Args:
        media: modality name -> list of media items.
        media_config: modality name -> configuration passed to the encoder.
        mm_info: metadata dict; for audio, its ``"audio_info"`` entries are
            updated in place with the padded length and STFT frame count.

    Returns:
        A ``defaultdict(deque)`` mapping modality name to its embeddings. In
        training mode, embeddings of dummy inputs (created so that every rank
        calls the encoder) are stored under ``"dummy"``.
    """
    embeds = defaultdict(deque)

    def _audio_to_stft(audio_clips, feature_extractor, max_samples, max_duration):
        """Pad/trim each clip, convert it to STFT features and annotate its info entry.

        Returns the feature list and the number of clips consumed.
        """
        stft_media = []
        aud_idx = 0
        for _audio_info in mm_info["audio_info"]:
            if _audio_info is None:
                continue
            for _mm_idx in range(len(_audio_info)):
                _audio = audio_clips[aud_idx]
                if type(_audio) is torch.Tensor:
                    device = _audio.device
                    dtype = _audio.dtype
                    _audio = _audio.cpu().float()
                else:
                    # Non-tensor audio carries no device/dtype; use the model's.
                    device = self.device
                    dtype = self.dtype
                _audio = whisper.pad_or_trim(_audio, length=max_samples)
                aud_idx += 1
                stft_features = feature_extractor(
                    _audio,
                    sampling_rate=self.config.audio_sampling_rate,
                    return_attention_mask=True,
                    padding="max_length",
                    return_tensors="pt",
                ).to(device, dtype)
                stft_media.append(stft_features)
                if _audio_info[_mm_idx] != "dummy":
                    _audio_info[_mm_idx]["new_audio_chunk_length"] = max_duration
                    _audio_info[_mm_idx]["new_audio_n_samples"] = max_samples
                    _audio_info[_mm_idx]["audio_end_sample_sec"] = _audio_info[_mm_idx]["audio_start_sec"] + max_duration
                    _audio_info[_mm_idx]["new_audio_n_stft_frames"] = stft_features["input_features"].shape[-1]
        return stft_media, aud_idx

    if self.config.unified_audio_encoder:
        assert len(media["speech"]) == 0

    for name in media:
        _encoder = self.encoders[name]
        if name in ["speech", "sound"] and self.config.unified_audio_encoder:
            _encoder = self.encoders["sound"]

        if self.training:
            # Gather metainfo of media objects from all ranks
            if name in ["speech", "sound"]:
                info = []
                if type(media.get(name, {})) is dict:
                    for _dict in media.get(name, {}):
                        info.append({k: {"shape": v.shape, "dtype": v.dtype} for k, v in _dict.items()})
                elif type(media.get(name, {})) is list:
                    info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
                else:
                    raise ValueError(f"Unsupported media type: {type(media.get(name, {}))}")

                infos_list = vila_all_gather(info)
                infos = list(chain(*infos_list))

                # The entire batch does not contain any media objects of this type.
                if not infos:
                    continue

                # For audio encoding the batch size must be the same on all ranks;
                # pad this rank's batch with dummy items up to the max batch size.
                max_batch_size = max(len(_info) for _info in infos_list)
                missing_batch_size = max_batch_size - len(info)

                _media = media.get(name, [])
                _medias = list(chain(vila_all_gather(_media)))
                if missing_batch_size > 0:
                    for _ in range(missing_batch_size):
                        # use one of the media tensors to create a dummy tensor
                        if type(media.get(name, {})) is dict:
                            _dummy = {k: v.clone().to(device=self.device) for k, v in _medias[0].items()}
                        elif type(media.get(name, {})) is list:
                            if type(_medias[0]) is torch.Tensor:
                                _dummy = _medias[0].clone().to(device=self.device)
                            elif type(_medias[0]) is np.ndarray:
                                _dummy = _medias[0].copy()
                            else:
                                raise ValueError(f"Unsupported media type: {type(_medias[0])}")
                        else:
                            raise ValueError(f"Unsupported media type: {type(media.get(name, {}))}")
                        _media.append(_dummy)
                        mm_info["audio_info"].append(["dummy"])

                # Align the length of all audio samples in the batch: round up to a
                # multiple of 30 seconds, capped by the configured chunk length.
                cur_batch_max_audio_samples = max(len(_audio) for _audio in _medias)
                cur_batch_max_audio_samples = int(np.ceil(cur_batch_max_audio_samples / (self.config.audio_sampling_rate * 30)) * (self.config.audio_sampling_rate * 30))
                cur_batch_max_audio_samples = min(cur_batch_max_audio_samples, self.config.audio_chunk_length * self.config.audio_sampling_rate)
                cur_batch_max_audio_duration = cur_batch_max_audio_samples // self.config.audio_sampling_rate

                whisper_feature_extractor = WhisperFeatureExtractor.from_pretrained(
                    self.config._name_or_path, chunk_length=cur_batch_max_audio_duration, sampling_rate=self.config.audio_sampling_rate, hop_length=self.config.audio_hop_length
                )

                new_media, aud_idx = _audio_to_stft(
                    _media, whisper_feature_extractor, cur_batch_max_audio_samples, cur_batch_max_audio_duration
                )
                assert aud_idx == len(_media), "The number of audio info does not match the number of audio samples."
                _media = new_media

                _fea = _encoder(_media, media_config[name], mm_info)

                # Features past the real items belong to dummies; they are consumed later.
                embeds["dummy"].extend(_fea[len(info) :])
                if len(info) > 0:
                    embeds[name] = deque(_fea[: len(info)])
            else:
                # Gather metainfo of media objects from all ranks
                info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
                infos = list(chain(vila_all_gather(info)))

                # The entire batch does not contain any media objects of this type.
                if not infos:
                    continue

                # Create a dummy tensor to ensure the encoder is called, otherwise the training will hang.
                if media.get(name) is None or len(media[name]) == 0:
                    dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
                    embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
                    continue
                embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
        else:
            # Skip audio pre-processing when there is no audio at all; `max()`
            # below would otherwise fail on an empty sequence.
            if name == "sound" and len(media[name]) > 0:
                all_audio_chunk_lengths = []
                for _sample_idx in range(len(media[name])):
                    for _mm_idx in range(len(mm_info["audio_info"][_sample_idx])):
                        _new_audio_chunk_length = mm_info["audio_info"][_sample_idx][_mm_idx]["new_audio_chunk_length"]
                        all_audio_chunk_lengths.append(_new_audio_chunk_length)
                cur_batch_max_audio_duration = max(all_audio_chunk_lengths)
                cur_batch_max_audio_samples = cur_batch_max_audio_duration * self.config.audio_sampling_rate

                whisper_feature_extractor = WhisperFeatureExtractor.from_pretrained(
                    self.config._name_or_path, chunk_length=cur_batch_max_audio_duration, sampling_rate=self.config.audio_sampling_rate, hop_length=self.config.audio_hop_length
                )

                assert len(all_audio_chunk_lengths) == len(media[name]), "The number of audio chunk lengths does not match the number of audio samples."
                media[name], _ = _audio_to_stft(
                    media.get(name, []), whisper_feature_extractor, cur_batch_max_audio_samples, cur_batch_max_audio_duration
                )

            if len(media[name]) > 0:
                embeds[name] = deque(_encoder(media[name], media_config[name], mm_info))
    return embeds
def __truncate_sequence(
    self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Clip sequences to the tokenizer's ``model_max_length`` during training.

    Outside training mode, or when no input exceeds the limit, both lists are
    returned untouched. Otherwise a warning is emitted and every input and
    label sequence is cut to the limit.
    """
    if not self.training:
        return inputs, labels

    limit = self.tokenizer.model_max_length
    if not any(len(seq) > limit for seq in inputs):
        return inputs, labels

    warnings.warn(f"Truncating sequences to `model_max_length` ({limit}).")
    clipped_inputs = [seq[:limit] for seq in inputs]
    clipped_labels = [lab[:limit] for lab in labels]
    return clipped_inputs, clipped_labels
def __batchify_sequence(
    self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pad variable-length sequences to a common length and stack them.

    Padding goes on the side given by ``self.tokenizer.padding_side``
    (``"right"``, anything else pads on the left). Input padding is zeros,
    label padding is ``IGNORE_INDEX``, and the attention mask is ``False`` on
    padded positions.

    Args:
        inputs: per-sample embeddings of shape ``(length, hidden)``.
        labels: per-sample label vectors of shape ``(length,)``.

    Returns:
        ``(inputs, labels, attention_mask)`` with shapes
        ``(batch, max_length, hidden)``, ``(batch, max_length)`` and
        ``(batch, max_length)``.
    """
    batch_size = len(inputs)
    device = inputs[0].device
    hidden_size = inputs[0].shape[1]
    max_length = max(inputs[k].shape[0] for k in range(batch_size))
    pad_right = self.tokenizer.padding_side == "right"
    attention_mask = torch.ones((batch_size, max_length), dtype=torch.bool, device=device)

    inputs_p, labels_p = [], []
    for k in range(batch_size):
        seq_len = inputs[k].shape[0]
        size_pk = max_length - seq_len
        inputs_pk = torch.zeros((size_pk, hidden_size), dtype=inputs[k].dtype, device=device)
        labels_pk = torch.full((size_pk,), IGNORE_INDEX, dtype=labels[k].dtype, device=device)
        # Move labels next to the padding for both padding sides, without
        # writing back into the caller's list.
        labels_k = labels[k].to(device)
        if pad_right:
            attention_mask[k, seq_len:] = False
            inputs_pk = torch.cat([inputs[k], inputs_pk], dim=0)
            labels_pk = torch.cat([labels_k, labels_pk], dim=0)
        else:
            # Index by padding size: `[: -seq_len]` would select nothing when
            # seq_len == 0 and leave a fully padded row unmasked.
            attention_mask[k, :size_pk] = False
            inputs_pk = torch.cat([inputs_pk, inputs[k]], dim=0)
            labels_pk = torch.cat([labels_pk, labels_k], dim=0)
        inputs_p.append(inputs_pk)
        labels_p.append(labels_pk)

    inputs = torch.stack(inputs_p, dim=0)
    labels = torch.stack(labels_p, dim=0)
    return inputs, labels, attention_mask
| def repack_multimodal_data(self, inputs_embeds, attention_mask, position_ids, labels): | |
| # Handle sequence parallelism | |
| PROCESS_GROUP_MANAGER = None | |
| # We do re-sharding instead of packing here to ensure the sequence length is the same across all ranks. | |
| if PROCESS_GROUP_MANAGER is not None: | |
| sp_degree = PROCESS_GROUP_MANAGER.sp_degree | |
| sp_rank = PROCESS_GROUP_MANAGER.sp_rank | |
| sp_group = PROCESS_GROUP_MANAGER.sp_pg | |
| ring_degree = PROCESS_GROUP_MANAGER.ring_degree | |
| ring_rank = PROCESS_GROUP_MANAGER.ring_rank | |
| ring_type = PROCESS_GROUP_MANAGER.ring_type | |
| ulysses_degree = PROCESS_GROUP_MANAGER.ulysses_degree | |
| ulysses_rank = PROCESS_GROUP_MANAGER.ulysses_rank | |
| bs, shard_seqlen = position_ids.shape | |
| sp_seq_len = [torch.zeros(1, dtype=torch.int64, device=position_ids.device) for _ in range(sp_degree)] | |
| dist.all_gather(sp_seq_len, torch.tensor(shard_seqlen, device=position_ids.device), group=sp_group) | |
| sp_seq_len_cat = torch.cat(sp_seq_len, dim=0) | |
| if sp_rank == 0: | |
| original_start_id = 0 | |
| else: | |
| original_start_id = torch.sum(sp_seq_len_cat[:sp_rank]).item() | |
| original_end_id = torch.sum(sp_seq_len_cat[: sp_rank + 1]).item() | |
| # Gather attention_mask, position_ids, labels and input_embeds | |
| all_inputs_embeds = torch.zeros( | |
| bs, | |
| torch.sum(sp_seq_len_cat), | |
| inputs_embeds.shape[-1], | |
| dtype=inputs_embeds.dtype, | |
| device=inputs_embeds.device, | |
| ).contiguous() | |
| all_inputs_embeds[:, original_start_id:original_end_id, :] += inputs_embeds | |
| dist.barrier(group=sp_group) | |
| dist.all_reduce(all_inputs_embeds, group=sp_group) | |
| dist.barrier(group=sp_group) | |
| attention_mask_list = [ | |
| torch.zeros((bs, sp_seq_len[i]), dtype=attention_mask.dtype, device=attention_mask.device) | |
| for i in range(sp_degree) | |
| ] | |
| position_ids_list = [ | |
| torch.zeros((bs, sp_seq_len[i]), dtype=position_ids.dtype, device=position_ids.device) | |
| for i in range(sp_degree) | |
| ] | |
| labels_list = [ | |
| torch.zeros((bs, sp_seq_len[i]), dtype=labels.dtype, device=labels.device) for i in range(sp_degree) | |
| ] | |
| dist.all_gather(attention_mask_list, attention_mask, group=sp_group) | |
| dist.all_gather(position_ids_list, position_ids, group=sp_group) | |
| dist.all_gather(labels_list, labels, group=sp_group) | |
| effective_seqlen_list = [attention_mask_list[i].sum(dim=-1) for i in range(sp_degree)] | |
| effective_seqlen = torch.stack(effective_seqlen_list, dim=-1) | |
| effective_seqlen_batch_list = torch.unbind(effective_seqlen, dim=0) | |
| global_attention_mask_list = [] | |
| global_position_ids_list = [] | |
| global_labels_list = [] | |
| global_inputs_embeds_list = [] | |
| for i in range(bs): | |
| global_attention_mask_batch_list = [] | |
| global_position_ids_batch_list = [] | |
| global_labels_batch_list = [] | |
| global_inputs_embeds_batch_list = [] | |
| for j in range(sp_degree): | |
| eff_len = effective_seqlen_batch_list[i][j] | |
| prev_len = torch.sum(sp_seq_len_cat[:j]).item() if j > 0 else 0 | |
| global_attention_mask_batch_list.append(attention_mask_list[j][i, :eff_len]) | |
| global_position_ids_batch_list.append(position_ids_list[j][i, :eff_len]) | |
| global_labels_batch_list.append(labels_list[j][i, :eff_len]) | |
| global_inputs_embeds_batch_list.append(all_inputs_embeds[i, prev_len : prev_len + eff_len, :]) | |
| global_attention_mask_list.append(torch.cat(global_attention_mask_batch_list, dim=0)) | |
| global_position_ids_list.append(torch.cat(global_position_ids_batch_list, dim=0)) | |
| global_labels_list.append(torch.cat(global_labels_batch_list, dim=0)) | |
| global_inputs_embeds_list.append(torch.cat(global_inputs_embeds_batch_list, dim=0)) | |
| global_attention_mask = torch.nn.utils.rnn.pad_sequence( | |
| global_attention_mask_list, batch_first=True, padding_value=False | |
| ) | |
| global_position_ids = torch.nn.utils.rnn.pad_sequence( | |
| global_position_ids_list, batch_first=True, padding_value=-1 | |
| ) | |
| global_labels = torch.nn.utils.rnn.pad_sequence( | |
| global_labels_list, batch_first=True, padding_value=IGNORE_INDEX | |
| ) | |
| global_inputs_embeds = torch.nn.utils.rnn.pad_sequence( | |
| global_inputs_embeds_list, batch_first=True, padding_value=0 | |
| ) | |
| # Re-shard the inputs | |
| if ring_degree > 1: | |
| total_effective_seqlen = torch.sum(effective_seqlen, dim=1) | |
| new_seqlen_per_rank = total_effective_seqlen // sp_degree | |
| assert torch.all( | |
| total_effective_seqlen % sp_degree == 0 | |
| ), "total_effective_seqlen must be divisible by sp_degree" | |
| max_new_seqlen = torch.max(new_seqlen_per_rank).item() | |
| new_attention_mask = torch.zeros( | |
| (bs, max_new_seqlen), dtype=global_attention_mask.dtype, device=global_attention_mask.device | |
| ) | |
| new_position_ids = torch.zeros( | |
| (bs, max_new_seqlen), dtype=global_position_ids.dtype, device=global_position_ids.device | |
| ) | |
| new_labels = torch.full( | |
| (bs, max_new_seqlen), IGNORE_INDEX, dtype=global_labels.dtype, device=global_labels.device | |
| ) | |
| new_inputs_embeds = torch.zeros( | |
| (bs, max_new_seqlen, global_inputs_embeds.shape[-1]), | |
| dtype=global_inputs_embeds.dtype, | |
| device=global_inputs_embeds.device, | |
| ) | |
| if ring_type == "ring_varlen": | |
| for i in range(bs): | |
| start_idx = new_seqlen_per_rank[i] * sp_rank | |
| end_idx = start_idx + new_seqlen_per_rank[i] | |
| new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx] | |
| new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx] | |
| new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx] | |
| new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[ | |
| i, start_idx:end_idx, : | |
| ] | |
| elif ring_type == "zigzag_ring_varlen": | |
| chunk_size = total_effective_seqlen // (2 * sp_degree) | |
| for i in range(bs): | |
| # Zigzag pattern indices | |
| if sp_degree == ring_degree: | |
| forward_rank_idx = sp_rank | |
| backward_rank_idx = 2 * sp_degree - sp_rank - 1 | |
| else: | |
| ulysses_offset = ulysses_rank * ring_degree * 2 | |
| forward_rank_idx = ring_rank + ulysses_offset | |
| backward_rank_idx = sp_degree - ring_rank - 1 + ulysses_offset | |
| # Calculate start and end indices for the forward and backward zigzag | |
| start_idx_fwd = forward_rank_idx * chunk_size[i] | |
| end_idx_fwd = start_idx_fwd + chunk_size[i] | |
| start_idx_bwd = backward_rank_idx * chunk_size[i] | |
| end_idx_bwd = start_idx_bwd + chunk_size[i] | |
| # Fill new tensors with zigzag data | |
| new_attention_mask[i, : chunk_size[i]] = global_attention_mask[i, start_idx_fwd:end_idx_fwd] | |
| new_attention_mask[i, chunk_size[i] : 2 * chunk_size[i]] = global_attention_mask[ | |
| i, start_idx_bwd:end_idx_bwd | |
| ] | |
| new_position_ids[i, : chunk_size[i]] = global_position_ids[i, start_idx_fwd:end_idx_fwd] | |
| new_position_ids[i, chunk_size[i] : 2 * chunk_size[i]] = global_position_ids[ | |
| i, start_idx_bwd:end_idx_bwd | |
| ] | |
| new_labels[i, : chunk_size[i]] = global_labels[i, start_idx_fwd:end_idx_fwd] | |
| new_labels[i, chunk_size[i] : 2 * chunk_size[i]] = global_labels[i, start_idx_bwd:end_idx_bwd] | |
| new_inputs_embeds[i, : chunk_size[i], :] = global_inputs_embeds[i, start_idx_fwd:end_idx_fwd, :] | |
| new_inputs_embeds[i, chunk_size[i] : 2 * chunk_size[i], :] = global_inputs_embeds[ | |
| i, start_idx_bwd:end_idx_bwd, : | |
| ] | |
| else: | |
| raise ValueError(f"Invalid ring_type: {ring_type}") | |
| else: | |
| global_seq_len = global_attention_mask.shape[-1] | |
| seq_len_sharded = global_seq_len // sp_degree | |
| start_idx_reshard = seq_len_sharded * sp_rank | |
| end_idx_reshard = start_idx_reshard + seq_len_sharded if sp_rank < sp_degree - 1 else global_seq_len | |
| new_attention_mask = torch.narrow( | |
| global_attention_mask, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard | |
| ) | |
| new_position_ids = torch.narrow( | |
| global_position_ids, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard | |
| ) | |
| new_labels = torch.narrow(global_labels, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard) | |
| new_inputs_embeds = torch.narrow( | |
| global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard | |
| ) | |
| return new_inputs_embeds, new_attention_mask, new_position_ids, new_labels | |
| device = inputs_embeds.device | |
| batch_size = inputs_embeds.shape[0] | |
| seqlens = [attention_mask[k].sum().item() for k in range(batch_size)] | |
| # Pack all sequences together | |
| inputs_embeds_p = [inputs_embeds[k][attention_mask[k]] for k in range(batch_size)] | |
| attention_mask_p = [torch.ones(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)] | |
| position_ids_p = [torch.arange(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)] | |
| labels_p = [labels[k][attention_mask[k]] for k in range(batch_size)] | |
| # Add one dummy token at the end of the packed sequence to ensure that `_get_unpacked_data` will be called | |
| inputs_embeds_p.append(torch.zeros(1, inputs_embeds.shape[-1], dtype=inputs_embeds.dtype, device=device)) | |
| attention_mask_p.append(torch.tensor([0], dtype=torch.int, device=device)) | |
| position_ids_p.append(torch.tensor([0], dtype=torch.int, device=device)) | |
| labels_p.append(torch.tensor([IGNORE_INDEX], dtype=torch.int, device=device)) | |
| # Mask the first token of each sequence to avoid contamination | |
| for label in labels_p: | |
| label[0] = IGNORE_INDEX | |
| # Batch the data | |
| inputs_embeds_p = torch.cat(inputs_embeds_p, dim=0).unsqueeze(0) | |
| attention_mask_p = torch.cat(attention_mask_p, dim=0).unsqueeze(0) | |
| position_ids_p = torch.cat(position_ids_p, dim=0).unsqueeze(0) | |
| labels_p = torch.cat(labels_p, dim=0).unsqueeze(0) | |
| if hasattr( | |
| self, "pad_to_multiple_of" | |
| ): # related to quantization, please refer to ModelArguments for more information. | |
| assert len(labels_p.shape) == 2 | |
| batch_size, max_length, cur_length = labels_p.shape[0], labels_p.shape[1], labels_p.shape[1] | |
| hidden_size = inputs_embeds_p.shape[-1] | |
| if max_length % self.pad_to_multiple_of != 0: | |
| max_length = ((max_length // self.pad_to_multiple_of) + 1) * self.pad_to_multiple_of | |
| difference = max_length - cur_length | |
| inputs_embeds_p = torch.cat( | |
| ( | |
| inputs_embeds_p, | |
| torch.full((batch_size, difference, hidden_size), self.llm.pad_token_id).to(inputs_embeds_p), | |
| ), | |
| dim=1, | |
| ) | |
| labels_p = torch.cat((labels_p, torch.full((batch_size, difference), IGNORE_INDEX).to(labels_p)), dim=1) | |
| attention_mask_p = torch.cat( | |
| ( | |
| attention_mask_p, | |
| torch.zeros((batch_size, difference), dtype=torch.bool).to(attention_mask_p), | |
| ), | |
| dim=1, | |
| ) | |
| position_ids_p = torch.cat( | |
| (position_ids_p, torch.full((batch_size, difference), -1).to(position_ids_p)), dim=1 | |
| ) | |
| return inputs_embeds_p, attention_mask_p, position_ids_p, labels_p | |
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    media: Optional[Dict[str, List[torch.Tensor]]] = None,
    images: Optional[torch.FloatTensor] = None,
    media_config: Optional[Dict[str, Dict[str, Any]]] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    packing: bool = True,
    force_packing: bool = False,
    seqlens_in_batch: Optional[torch.LongTensor] = None,
    dpo_forward: bool = False,
    **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Embed the inputs (if needed), optionally repack them, and run `self.llm`.

    Args:
        input_ids: Token ids handed to `self._embed` when `inputs_embeds` is not given.
        media: Mapping from media name to a list of tensors, handed to `self._embed`.
        images: Deprecated alias; converted to `media={"image": images}`.
        media_config: Per-media options handed to `self._embed`; defaults to an
            empty `defaultdict(dict)`.
        pixel_values: Accepted for signature compatibility; not read by this method.
        attention_mask: Attention mask; replaced by the one `self._embed` returns
            when embeddings are computed here.
        position_ids: Forwarded to the repacking step and to `self.llm`.
        past_key_values: Forwarded to `self.llm`.
        inputs_embeds: Precomputed embeddings; when given, `self._embed` is skipped.
        labels: Training targets; replaced by the ones `self._embed` returns when
            embeddings are computed here.
        packing: Repack the batch when the module is in training mode and
            `dpo_forward` is False.
        force_packing: Repack the batch regardless of training mode.
        seqlens_in_batch: Per-sample sequence lengths; derived from
            `attention_mask` when omitted.
        dpo_forward: Return `(logits, labels)` instead of the model output and
            skip training-time packing.
        **kwargs: Forwarded unchanged to `self.llm`.

    Returns:
        The output of `self.llm`, or the tuple `(outputs.logits, labels)` when
        `dpo_forward` is True.

    Raises:
        ValueError: If both `media` and `images` are given, or if packing is
            requested while neither `attention_mask` nor `seqlens_in_batch`
            is available.
    """
    self.freezed_module_patch()

    if images is not None:
        if media is not None:
            raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
        # Use the warnings machinery instead of print so the notice is shown once
        # per call site and can be filtered by the caller.
        warnings.warn(
            "The 'images' argument is deprecated. Please use 'media' instead.",
            FutureWarning,
            stacklevel=2,
        )
        media = {"image": images}

    if media_config is None:
        media_config = defaultdict(dict)

    if inputs_embeds is None:
        inputs_embeds, labels, attention_mask = self._embed(input_ids, media, media_config, labels, attention_mask)

    if force_packing or (packing and self.training and not dpo_forward):
        if seqlens_in_batch is None:
            if attention_mask is None:
                # Without this guard the failure surfaces as an opaque TypeError
                # from torch.sum(None).
                raise ValueError(
                    "Packing requires either 'attention_mask' or 'seqlens_in_batch' to be provided."
                )
            seqlens_in_batch = torch.sum(attention_mask, dim=1)
        set_seqlens_in_batch(seqlens_in_batch)

        (inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data(
            inputs_embeds, attention_mask, position_ids, labels
        )

    outputs = self.llm(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        labels=labels,
        **kwargs,
    )

    # When the config lists time-token ids, the loss computed by the LLM is
    # replaced during training with the soft cross-entropy over those tokens.
    if self.training and getattr(self.config, "time_token_ids", []):
        outputs.loss = soft_cross_entropy(
            outputs.logits,
            labels,
            soft_tokens=self.config.time_token_ids,
            std=self.config.soft_ce_std,
        )

    if dpo_forward:
        return outputs.logits, labels

    return outputs
def generate(
    self,
    input_ids: Optional[torch.FloatTensor] = None,
    media: Optional[Dict[str, List[torch.Tensor]]] = None,
    media_config: Optional[Dict[str, Dict[str, Any]]] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    return_output_ids_only: bool = True,
    **generation_kwargs,
) -> torch.LongTensor:
    """Embed a multimodal prompt and generate token ids with `self.llm`.

    input_tokens: <image> describe the image
    media:        [Tensor(1, 3, 384, 384), ]
    ----------->
    input_tokens: 36000          001 002 003 004
    input_emds:   <media emd>    001 002 003 004

    Args:
        input_ids: Prompt token ids, handed to `self._embed`.
        media: Mapping from media name to a list of tensors.
        media_config: Per-media options; defaults to an empty `defaultdict(dict)`,
            matching what `forward` does.
        attention_mask: Attention mask for the prompt; replaced by the one
            returned from `self._embed`.
        return_output_ids_only: When True (default) return exactly what
            `self.llm.generate` returned. When False, prepend the prompt ids to
            each returned sequence.
        **generation_kwargs: Forwarded unchanged to `self.llm.generate`.

    Returns:
        The ids from `self.llm.generate`, optionally prefixed with the prompt ids.

    Raises:
        ValueError: If the prompt ids must be prepended but the number of
            returned sequences is not a multiple of the prompt batch size.
    """
    if media_config is None:
        media_config = defaultdict(dict)

    inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask)
    output_ids = self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)

    if return_output_ids_only:
        return output_ids

    # Prepend the prompt to every returned sequence. The number of sequences per
    # prompt can be set through `generation_config`, through a plain
    # `num_return_sequences` kwarg, or through the LLM's own default config, so
    # derive the repeat factor from the actual output batch size instead of
    # inspecting only one of those sources.
    repeats, leftover = divmod(output_ids.shape[0], input_ids.shape[0])
    if repeats < 1 or leftover:
        raise ValueError(
            f"Cannot align {output_ids.shape[0]} generated sequences with {input_ids.shape[0]} prompts."
        )
    prompt_ids = input_ids if repeats == 1 else input_ids.repeat_interleave(repeats, dim=0)
    return torch.cat([prompt_ids, output_ids], dim=-1)
def generate_content(
    self,
    prompt: Union[str, List],
    generation_config: Optional[GenerationConfig] = None,
    response_format=None,
) -> str:
    """Generate a text response for a single-turn prompt that may reference media.

    The prompt is wrapped in a one-message conversation, media is pulled out of
    it with `extract_media`, each media type is preprocessed, the conversation
    is tokenized, and `self.generate` is run. If generation raises `ValueError`
    while sampling is enabled, it is retried once with greedy decoding.

    Args:
        prompt: The user prompt (a string or a list) placed in the conversation.
        generation_config: Generation settings; the model default is used when
            omitted. The object passed in is never modified.
        response_format: Currently unused; no logits processor is built from it.

    Returns:
        The decoded response with special tokens skipped and surrounding
        whitespace stripped.

    Raises:
        ValueError: For an unsupported media type, or when generation fails and
            sampling was not enabled.
    """
    conversation = [{"from": "human", "value": prompt}]

    # No structured-generation processor is created here; `None` is forwarded.
    xgr_logits_processor = None

    # Extract media from the conversation
    media = extract_media(conversation, self.config)

    # Process media
    media_config = defaultdict(dict)
    for name in media:
        if name == "image":
            if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
                self.config.image_processor = self.vision_tower.image_processor
                if self.config.image_aspect_ratio == "dynamic":
                    images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
                    # One image placeholder per leading-dimension entry of the processed image.
                    conversation[0]["value"] = conversation[0]["value"].replace(
                        DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
                    )
                else:
                    if isinstance(self.config.s2_scales, str):
                        self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
                    images, block_sizes = process_image(
                        media["image"][0], self.config, None, enable_dynamic_s2=True
                    )
                    images = images.half()
                    media_config[name]["block_sizes"] = [block_sizes]
            else:
                images = process_images(media["image"], self.vision_tower.image_processor, self.config).half()
            media[name] = list(images)
        elif name == "video":
            if self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1:
                media[name] = [
                    process_images(
                        images,
                        self.vision_tower.image_processor,
                        self.config,
                        enable_dynamic_res=True,
                        max_tiles=self.config.video_max_tiles,
                    ).half()
                    for images in media[name]
                ]
            elif self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1:
                self.config.image_processor = self.vision_tower.image_processor
                if isinstance(self.config.s2_scales, str):
                    self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
                media[name] = [
                    torch.cat(
                        [
                            process_image(
                                image,
                                self.config,
                                None,
                                enable_dynamic_s2=True,
                                max_tiles=self.config.video_max_tiles,
                            )[0].half()
                            for image in images
                        ]
                    )
                    for images in media[name]
                ]
            else:
                # NOTE(review): unlike the other branches, this one does not cast to
                # half precision - confirm that is intentional.
                media[name] = [
                    process_images(images, self.vision_tower.image_processor, self.config)
                    for images in media[name]
                ]
        elif name == "speech":
            media[name] = list(media["speech"])
        elif name == "sound":
            sounds = media["sound"]
            for sound in sounds:
                if isinstance(sound, dict):
                    # Values are replaced in place; the key set does not change.
                    for k, v in sound.items():
                        sound[k] = v.half()
            media[name] = list(sounds)
        elif name == "video_info":
            media[name] = [media["video_info"]]
        elif name == "audio_info":
            media[name] = [media["audio_info"]]
        else:
            raise ValueError(f"Unsupported media type: {name}")

    # Tokenize the conversation
    input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).unsqueeze(0).cuda()

    # Set up the generation config. `default_generation_config` may be exposed
    # either as a property or as a plain method; support both so that a bound
    # method is never passed on as the config.
    if generation_config is None:
        default_config = self.default_generation_config
        generation_config = default_config() if callable(default_config) else default_config

    def _run(config):
        # Single place that invokes generation, shared by the first try and the retry.
        return self.generate(
            input_ids=input_ids,
            media=media,
            media_config=media_config,
            generation_config=config,
            logits_processor=xgr_logits_processor,  # structured generation
        )

    # Generate the response
    try:
        output_ids = _run(generation_config)
    except ValueError:
        if not generation_config.do_sample:
            raise
        logging.warning("Generation failed with sampling, retrying with greedy decoding.")
        # Work on a copy so the caller's config object keeps its sampling setting.
        greedy_config = copy.deepcopy(generation_config)
        greedy_config.do_sample = False
        output_ids = _run(greedy_config)

    # Decode the response
    response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    return response
def default_generation_config(self) -> GenerationConfig:
    """Build a generation config with token ids and length filled in from the tokenizer.

    Starts from a deep copy of `self.generation_config` (or a fresh
    `GenerationConfig` when that is unset) so the stored config is not modified.

    NOTE(review): other code in this class refers to
    `self.default_generation_config` as an attribute - confirm whether this is
    meant to be wrapped in `@property`.

    Returns:
        A new `GenerationConfig` whose `max_length`, `pad_token_id`,
        `bos_token_id` and `eos_token_id` are populated.

    Raises:
        ValueError: If the tokenizer has no EOS token.
    """
    tokenizer = self.tokenizer
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        raise ValueError("Tokenizer must have an EOS token")

    base_config = self.generation_config
    generation_config = copy.deepcopy(base_config if base_config is not None else GenerationConfig())

    # A max_length still equal to the library default is treated as "not set"
    # and replaced by the tokenizer's limit.
    if generation_config.max_length == GenerationConfig().max_length:
        generation_config.max_length = tokenizer.model_max_length

    # Compare against None explicitly: 0 is a valid token id and must not be
    # replaced by the EOS id.
    if generation_config.pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
        generation_config.pad_token_id = eos_token_id if pad_token_id is None else pad_token_id
    if generation_config.bos_token_id is None:
        bos_token_id = tokenizer.bos_token_id
        generation_config.bos_token_id = eos_token_id if bos_token_id is None else bos_token_id
    if generation_config.eos_token_id is None:
        generation_config.eos_token_id = eos_token_id
    return generation_config