# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
"""
MiniMax VL family HuggingFace-compatible Processor, ImageProcessor, VideoProcessor.
"""

import math
from typing import List, Tuple

import torch
from torchvision.transforms import InterpolationMode
from transformers import BatchFeature
from transformers.image_processing_utils_fast import (
    BaseImageProcessorFast,
    group_images_by_shape,
    reorder_images,
)
from transformers.image_utils import PILImageResampling, SizeDict
from transformers.processing_utils import (
    ImagesKwargs,
    Unpack,
)
from transformers.utils import TensorType

# Largest accepted long-side / short-side ratio for an input image.
MAX_RATIO = 200


def round_by_factor(number: float, factor: int) -> int:
    """Return the multiple of ``factor`` nearest to ``number``."""
    return round(number / factor) * factor


def ceil_by_factor(number: float, factor: int) -> int:
    """Return the smallest multiple of ``factor`` that is >= ``number``."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: float, factor: int) -> int:
    """Return the largest multiple of ``factor`` that is <= ``number``."""
    return math.floor(number / factor) * factor


def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 4 * 28 * 28,
    max_pixels: int = 451584,
) -> tuple[int, int]:
    """Choose a resize target whose sides are multiples of ``factor``.

    The sides are first rounded to the nearest multiple of ``factor``. If the
    resulting area exceeds ``max_pixels`` the image is scaled down (sides
    floored to ``factor``); if it is below ``min_pixels`` it is scaled up
    (sides ceiled to ``factor``). Both scalings approximately preserve the
    input aspect ratio. Every returned side is at least ``factor``.

    Args:
        height: Input height in pixels.
        width: Input width in pixels.
        factor: Each output side is a multiple of this value.
        min_pixels: Lower bound on the output area before upscaling kicks in.
        max_pixels: Upper bound on the output area before downscaling kicks in.

    Returns:
        ``(resized_height, resized_width)``.

    Raises:
        ValueError: If either side is not positive, or if the aspect ratio
            exceeds ``MAX_RATIO``.
    """
    # Explicit check: a zero side would otherwise surface as ZeroDivisionError.
    if height <= 0 or width <= 0:
        raise ValueError(
            f"height and width must be positive, got height={height}, width={width}"
        )
    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, "
            f"got {aspect_ratio}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        # Clamp to ``factor``: with a small ``max_pixels`` the short side can
        # floor to 0, which would yield an empty resize target.
        h_bar = max(factor, floor_by_factor(height / beta, factor))
        w_bar = max(factor, floor_by_factor(width / beta, factor))
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar


# ==============================================================================
# MiniMax M3 VL Image Processor Fast (Fast Mode - Torch based)
# ==============================================================================


class MiniMaxM3VLImageProcessorKwargs(ImagesKwargs, total=False):
    """Extra per-call kwargs accepted by ``MiniMaxM3VLImageProcessor``."""

    patch_size: int  # spatial side of one patch, in pixels
    temporal_patch_size: int  # frames folded into one patch
    merge_size: int  # patches per side of one merge window
    max_pixels: int  # area bound passed to smart_resize


class MiniMaxM3VLImageProcessor(BaseImageProcessorFast):
    """Torch-based image processor producing flattened patches plus grid sizes.

    Outputs ``pixel_values`` (patch rows of all images concatenated along
    dim 0) and ``image_grid_thw`` (one ``[grid_t, grid_h, grid_w]`` row per
    image).
    """

    do_resize = True
    resample = PILImageResampling.BICUBIC
    size = {"height": 672, "width": 672}  # required by base class validation, not used as resize bound
    default_to_square = False
    do_rescale = True
    rescale_factor = 1 / 255
    do_normalize = True
    image_mean = [0.48145466, 0.4578275, 0.40821073]
    image_std = [0.26862954, 0.26130258, 0.27577711]
    do_convert_rgb = True
    patch_size = 14
    temporal_patch_size = 2
    merge_size = 2
    max_pixels = 451584  # 672*672
    valid_kwargs = MiniMaxM3VLImageProcessorKwargs
    model_input_names = ["pixel_values", "image_grid_thw"]

    def __init__(self, **kwargs: Unpack[MiniMaxM3VLImageProcessorKwargs]):
        super().__init__(**kwargs)

    def preprocess(
        self, images, **kwargs: Unpack[MiniMaxM3VLImageProcessorKwargs]
    ) -> BatchFeature:
        """Delegate to the base class; declared to expose the typed kwargs."""
        return super().preprocess(images, **kwargs)

    def _preprocess(
        self,
        images: List[torch.Tensor],
        do_resize: bool,
        size: SizeDict,
        resample: PILImageResampling | InterpolationMode | int | None,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: float | List[float] | None,
        image_std: float | List[float] | None,
        patch_size: int,
        temporal_patch_size: int,
        merge_size: int,
        max_pixels: int,
        disable_grouping: bool | None,
        return_tensors: str | TensorType | None,
        **kwargs,
    ) -> BatchFeature:
        """Resize, normalize and patchify a list of images.

        ``size`` is accepted for base-class compatibility but is not used; the
        resize target comes from ``smart_resize`` with
        ``factor = patch_size * merge_size`` and ``max_pixels``.

        Returns:
            BatchFeature with ``pixel_values`` of shape
            ``(total_patches, channel * temporal_patch_size * patch_size**2)``
            and ``image_grid_thw`` of shape ``(num_images, 3)``.
        """
        # Pass 1: resize, batching images that share a shape.
        grouped_images, grouped_images_index = group_images_by_shape(
            images, disable_grouping=disable_grouping
        )
        resized_images_grouped = {}
        factor = patch_size * merge_size
        for shape, stacked_images in grouped_images.items():
            height, width = stacked_images.shape[-2:]
            if do_resize:
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=factor,
                    max_pixels=max_pixels,
                )
                stacked_images = self.resize(
                    stacked_images,
                    size=SizeDict(height=resized_height, width=resized_width),
                    resample=resample,
                )
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        # Pass 2: regroup by the post-resize shape, then normalize and patchify.
        grouped_images, grouped_images_index = group_images_by_shape(
            resized_images, disable_grouping=disable_grouping
        )
        processed_images_grouped = {}
        processed_grids = {}
        for shape, stacked_images in grouped_images.items():
            resized_height, resized_width = stacked_images.shape[-2:]
            patches = self.rescale_and_normalize(
                stacked_images,
                do_rescale,
                rescale_factor,
                do_normalize,
                image_mean,
                image_std,
            )
            # A 4-D batch has no frame axis; insert one so images are
            # treated as single-frame clips.
            if patches.ndim == 4:
                patches = patches.unsqueeze(1)
            # Pad the frame axis to a multiple of temporal_patch_size by
            # repeating the last frame.
            if patches.shape[1] % temporal_patch_size != 0:
                repeats = patches[:, -1:].repeat(
                    1,
                    temporal_patch_size - (patches.shape[1] % temporal_patch_size),
                    1,
                    1,
                    1,
                )
                patches = torch.cat([patches, repeats], dim=1)
            batch_size, grid_t, channel = patches.shape[:3]
            grid_t = grid_t // temporal_patch_size
            grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
            patches = patches.view(
                batch_size,
                grid_t,
                temporal_patch_size,
                channel,
                grid_h // merge_size,
                merge_size,
                patch_size,
                grid_w // merge_size,
                merge_size,
                patch_size,
            )
            # Reorder to (batch, t, h_blocks, w_blocks, merge_h, merge_w,
            # channel, temporal, patch_h, patch_w): patches of one
            # merge_size x merge_size window become consecutive rows.
            patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
            flatten_patches = patches.reshape(
                batch_size,
                grid_t * grid_h * grid_w,
                channel * temporal_patch_size * patch_size * patch_size,
            )
            processed_images_grouped[shape] = flatten_patches
            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size

        # Restore the caller's image order before concatenating.
        processed_images = reorder_images(
            processed_images_grouped, grouped_images_index
        )
        processed_grids = reorder_images(processed_grids, grouped_images_index)
        pixel_values = torch.cat(processed_images, dim=0)
        image_grid_thw = torch.tensor(processed_grids, dtype=torch.long)

        return BatchFeature(
            data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw},
            tensor_type=return_tensors,
        )

    def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
        """Return ``grid_h * grid_w`` patches for an image of the given size.

        Mirrors the resize done in ``_preprocess`` (it always applies
        ``smart_resize``, regardless of ``do_resize``). Values in
        ``images_kwargs`` override the instance defaults for ``patch_size``,
        ``merge_size`` and ``max_pixels``.
        """
        images_kwargs = images_kwargs or {}
        patch_size = images_kwargs.get("patch_size", self.patch_size)
        merge_size = images_kwargs.get("merge_size", self.merge_size)
        max_pixels = images_kwargs.get("max_pixels", self.max_pixels)
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=patch_size * merge_size,
            max_pixels=max_pixels,
        )
        grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
        return grid_h * grid_w