Image-Text-to-Text
Transformers
Safetensors
English
Chinese
ristretto
feature-extraction
conversational
custom_code
Instructions to use LiAutoAD/Ristretto-3B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use LiAutoAD/Ristretto-3B with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-text-to-text", model="LiAutoAD/Ristretto-3B", trust_remote_code=True)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},
            {"type": "text", "text": "What animal is on the candy?"}
        ]
    },
]
pipe(text=messages)

# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("LiAutoAD/Ristretto-3B", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use LiAutoAD/Ristretto-3B with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "LiAutoAD/Ristretto-3B"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "LiAutoAD/Ristretto-3B",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Describe this image in one sentence."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
                        }
                    }
                ]
            }
        ]
    }'
Use Docker
docker model run hf.co/LiAutoAD/Ristretto-3B
- SGLang
How to use LiAutoAD/Ristretto-3B with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "LiAutoAD/Ristretto-3B" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "LiAutoAD/Ristretto-3B",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Describe this image in one sentence."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
                        }
                    }
                ]
            }
        ]
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "LiAutoAD/Ristretto-3B" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "LiAutoAD/Ristretto-3B",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Describe this image in one sentence."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
                        }
                    }
                ]
            }
        ]
    }'
- Docker Model Runner
How to use LiAutoAD/Ristretto-3B with Docker Model Runner:
docker model run hf.co/LiAutoAD/Ristretto-3B
| # -------------------------------------------------------- | |
| # Ristretto | |
| # Copyright (c) 2025 LiAutoAD | |
| # Licensed under The MIT License | |
| # -------------------------------------------------------- | |
| import copy | |
| from typing import Any, List, Optional, Tuple, Union | |
| import torch.distributed as dist | |
| import torch.utils.checkpoint | |
| import transformers | |
| from torch import nn | |
| from torch.nn import CrossEntropyLoss | |
| from transformers import (GenerationConfig, LlamaConfig, | |
| LlamaForCausalLM, PretrainedConfig, | |
| Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, | |
| SiglipVisionModel) | |
| from transformers.modeling_outputs import CausalLMOutputWithPast | |
| from transformers.modeling_utils import PreTrainedModel | |
| from transformers.trainer_pt_utils import LabelSmoother | |
| from transformers.utils import logging | |
| from .conversation import get_conv_template | |
| from .projector import TokenAdaptiveProjector | |
# Label id re-exported from transformers' LabelSmoother.ignore_index; it is not
# referenced anywhere else in this chunk (presumably used by training code elsewhere).
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
# Module logger, forced to INFO so the logger.info(...) calls in this module
# (e.g. the ones in RistrettoConfig.__init__) are emitted.
logger = logging.get_logger(__name__)
logger.setLevel(logging.INFO)
def version_cmp(v1, v2, op='eq'):
    """Compare two version strings using a comparison function from ``operator``.

    Args:
        v1: left-hand version string, e.g. ``transformers.__version__``.
        v2: right-hand version string.
        op: name of an ``operator`` function such as ``'eq'``, ``'ge'``, ``'lt'``.

    Returns:
        The result of ``operator.<op>(parse(v1), parse(v2))``.
    """
    import operator
    from packaging.version import parse as parse_version

    # Resolve the comparison first, so an unknown ``op`` fails before any parsing.
    compare = getattr(operator, op)
    left = parse_version(v1)
    right = parse_version(v2)
    return compare(left, right)
class RistrettoConfig(PretrainedConfig):
    """Composite configuration for the Ristretto vision-language model.

    Holds a vision-tower config (only ``siglip_vision_model`` is accepted) and a
    language-model config (``LlamaForCausalLM`` or ``Qwen2ForCausalLM``, chosen by
    ``llm_config['architectures'][0]``), plus the image-handling options that
    ``RistrettoModel`` reads.
    """

    model_type = 'ristretto'
    is_composition = True

    def __init__(
            self,
            vision_config=None,
            llm_config=None,
            pad2square=False,
            select_layer=-1,
            force_image_size=None,
            num_image_token=256,
            template=None,
            dynamic_image_size=False,
            use_thumbnail=False,
            min_dynamic_patch=1,
            max_dynamic_patch=6,
            **kwargs):
        """Build the composite config.

        Args:
            vision_config: dict (or an already-built ``PretrainedConfig``) for the
                vision tower. ``None`` means ``{'model_type': 'siglip_vision_model'}``.
            llm_config: dict (or an already-built ``PretrainedConfig``) for the LLM.
                ``None`` means ``{'architectures': ['Qwen2ForCausalLM']}``.
            pad2square, dynamic_image_size, use_thumbnail, min_dynamic_patch,
            max_dynamic_patch: image preprocessing options, stored as-is.
            select_layer: vision hidden-state index used as image features
                (``-1`` = last hidden state).
            force_image_size: overrides ``vision_config.image_size`` when set.
            num_image_token: default number of image tokens per tile.
            template: conversation template name.
            **kwargs: forwarded to ``PretrainedConfig``.

        Raises:
            ValueError: if the vision model type or LLM architecture is unsupported.
        """
        super().__init__(**kwargs)

        # Fresh defaults per call instead of shared mutable default arguments.
        if vision_config is None:
            vision_config = dict(model_type='siglip_vision_model')
        elif isinstance(vision_config, PretrainedConfig):
            vision_config = vision_config.to_dict()
        if llm_config is None:
            llm_config = dict(architectures=['Qwen2ForCausalLM'])
        elif isinstance(llm_config, PretrainedConfig):
            llm_config = llm_config.to_dict()

        vision_type = vision_config.get('model_type')
        if vision_type == 'siglip_vision_model':
            self.vision_config = SiglipVisionConfig(**vision_config)
        else:
            raise ValueError('Unsupported architecture: {}'.format(vision_type))

        # ``architectures`` may be absent or None (e.g. a freshly built LLM config).
        llm_arch = (llm_config.get('architectures') or [None])[0]
        if llm_arch == 'LlamaForCausalLM':
            self.llm_config = LlamaConfig(**llm_config)
        elif llm_arch == 'Qwen2ForCausalLM':
            self.llm_config = Qwen2Config(**llm_config)
        else:
            raise ValueError('Unsupported architecture: {}'.format(llm_arch))

        self.pad2square = pad2square
        self.select_layer = select_layer
        self.force_image_size = force_image_size
        self.num_image_token = num_image_token
        self.template = template
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch

        logger.info(f'vision_select_layer: {self.select_layer}')
        logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
        logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}')

    def to_dict(self):
        """Serialize this instance to a Python dictionary.

        Overrides :meth:`PretrainedConfig.to_dict` so the nested vision and LLM
        configs are emitted as plain dicts rather than config objects.

        Returns:
            `Dict[str, Any]`: all attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output['vision_config'] = self.vision_config.to_dict()
        output['llm_config'] = self.llm_config.to_dict()
        output['model_type'] = self.__class__.model_type
        output['pad2square'] = self.pad2square
        output['select_layer'] = self.select_layer
        output['force_image_size'] = self.force_image_size
        output['num_image_token'] = self.num_image_token
        output['template'] = self.template
        output['dynamic_image_size'] = self.dynamic_image_size
        output['use_thumbnail'] = self.use_thumbnail
        output['min_dynamic_patch'] = self.min_dynamic_patch
        output['max_dynamic_patch'] = self.max_dynamic_patch
        return output
class RistrettoModel(PreTrainedModel):
    """Ristretto vision-language model: a vision tower, a projector and a causal LLM.

    Image tiles go through ``vision_model`` and then ``projector``. The resulting
    embeddings overwrite the LLM input embeddings at every position whose token id
    equals ``self.img_context_token_id``; ``chat``/``batch_chat`` set that id from
    the tokenizer.
    """

    config_class = RistrettoConfig
    main_input_name = 'pixel_values'
    _no_split_modules = ['SiglipVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer']
    _supports_flash_attn_2 = True
    _keys_to_ignore_on_save = []

    def __init__(self, config: RistrettoConfig, vision_model=None, language_model=None):
        """Build (or adopt) the vision tower and LLM, and create the projector.

        Args:
            config: the composite ``RistrettoConfig``.
            vision_model: optional pre-built vision tower; otherwise a
                ``SiglipVisionModel`` is created from ``config.vision_config``.
            language_model: optional pre-built LLM; otherwise a Llama or Qwen2
                causal LM is created from ``config.llm_config``.

        Raises:
            NotImplementedError: if a sub-model must be built and its vision
                model type or LLM architecture is unsupported.
        """
        super().__init__(config)
        assert version_cmp(transformers.__version__, '4.37.0', 'ge')
        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.template = config.template
        self.num_image_token = config.num_image_token
        self.llm_arch_name = config.llm_config.architectures[0]
        self.vision_model_type = config.vision_config.model_type

        if vision_model is not None:
            self.vision_model = vision_model
        elif config.vision_config.model_type == 'siglip_vision_model':
            self.vision_model = SiglipVisionModel(config.vision_config)
        else:
            raise NotImplementedError(f'{config.vision_config.model_type} is not implemented.')

        if language_model is not None:
            self.language_model = language_model
        elif config.llm_config.architectures[0] == 'LlamaForCausalLM':
            self.language_model = LlamaForCausalLM(config.llm_config)
        elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
            self.language_model = Qwen2ForCausalLM(config.llm_config)
        else:
            raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')

        self.projector = TokenAdaptiveProjector(
            vit_hidden_size=config.vision_config.hidden_size,
            llm_hidden_size=config.llm_config.hidden_size,
            num_image_token=self.num_image_token,
        )
        self.img_context_token_id = None
        self.conv_template = get_conv_template(self.template)
        self.system_message = self.conv_template.system_message
        self.num_samples = 0

    @staticmethod
    def _dist_is_ready():
        """Return True when torch.distributed is available and a process group is initialized."""
        return dist.is_available() and dist.is_initialized()

    def forward(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            num_image_tokens: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the LLM on text embeddings with image features spliced in.

        Args:
            pixel_values: image tiles for the vision tower; dim 0 is the tile count.
            input_ids: (B, N) token ids; positions equal to ``img_context_token_id``
                receive image features.
            image_flags: per-tile flags; only tiles flagged 1 contribute features.
                ``None`` uses every tile.
            num_image_tokens: optional per-sample image-token counts; all entries
                must be equal, and the value is passed to the projector.
            labels: (B, N) targets; positions equal to CrossEntropyLoss's
                ``ignore_index`` are excluded from the loss.
            attention_mask, position_ids, past_key_values, use_cache,
            output_attentions, output_hidden_states: forwarded to the LLM.
            return_dict: return a ``CausalLMOutputWithPast`` instead of a tuple.

        Returns:
            ``CausalLMOutputWithPast``, or ``(loss, logits, ...)`` / ``(logits, ...)``
            when ``return_dict`` is False.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        num_image_token = None
        if num_image_tokens is not None:
            assert num_image_tokens.unique().shape[0] == 1, 'num_image_tokens must be the same for all samples in a batch'
            num_image_token = num_image_tokens[0].item()

        input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
        vit_embeds = self.extract_feature(pixel_values, num_image_token)
        if image_flags is not None:
            image_flags = image_flags.squeeze(-1)
            vit_embeds = vit_embeds[image_flags == 1]
        vit_batch_size = pixel_values.shape[0]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        if self._dist_is_ready() and dist.get_rank() == 0:
            print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')

        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.img_context_token_id)
        # ``* 0.0 +`` keeps the text-embedding graph attached while replacing the values.
        try:
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
            ignore_flag = False
        except Exception as e:
            # Deliberate best-effort: on a placeholder/feature count mismatch, fill what
            # fits and zero this batch's loss below instead of aborting training.
            vit_embeds = vit_embeds.reshape(-1, C)
            print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                  f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
            ignore_flag = True

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # With return_dict=False the LLM returns a plain tuple whose first item is the logits.
        logits = outputs.logits if return_dict else outputs[0]

        loss = None
        if labels is not None:
            # Keep every loss tensor on the logits device (matters under model parallelism).
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss(reduction='none')
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Per-sample token weight 1 / sqrt(#supervised tokens in that sample).
            loss_token_mask = shift_labels != loss_fct.ignore_index
            loss_token_num = loss_token_mask.sum(dim=1, keepdim=True).float()
            loss_token_weight = 1. / (loss_token_num.expand_as(shift_labels) ** 0.5 + 1e-6)
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            loss_token_weight = loss_token_weight.view(-1)
            loss_mask = loss_token_mask.view(-1).float()

            loss = loss_fct(shift_logits, shift_labels)
            all_token_weight = (loss_token_weight * loss_mask).sum()
            world_size = 1
            if self._dist_is_ready():
                # Normalize by the total weight across all ranks.
                dist.all_reduce(all_token_weight, op=dist.ReduceOp.SUM)
                world_size = dist.get_world_size()
            loss = (loss * loss_token_weight * loss_mask).sum() / (all_token_weight + 1e-6)
            # Hack for DDP training, since the loss is reduced in the forward function:
            # scaling by world size compensates for DDP averaging gradients across ranks.
            loss = loss * world_size
            if ignore_flag:
                loss = loss * 0.0

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def extract_feature(self, pixel_values, num_image_token=None):
        """Encode image tiles and project them into the LLM embedding space.

        Uses the vision tower's last hidden state when ``select_layer == -1``,
        otherwise ``hidden_states[select_layer]``. ``num_image_token`` is passed
        to the projector as-is (``None`` presumably means the projector default).
        """
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]
        vit_embeds = self.projector(vit_embeds, num_image_token=num_image_token)
        return vit_embeds

    def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
                   history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
                   IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
        """Answer several single-turn questions in one left-padded ``generate`` call.

        Args:
            tokenizer: tokenizer that knows ``IMG_CONTEXT_TOKEN`` and supports padding.
            pixel_values: image tiles of all questions concatenated along dim 0, or ``None``.
            questions: question strings, one per sample.
            generation_config: dict of ``generate`` kwargs; it is copied, not modified.
            num_patches_list: one list per question holding the tile count of each of
                its images; each count expands the next ``<image>`` placeholder.
                It may be omitted only when ``pixel_values`` is ``None``.
            history, return_history: multi-turn is unsupported; must be ``None``/``False``.
            IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN: image markup tokens.
            verbose: print the ViT batch size.
            image_counts: deprecated alias for ``num_patches_list``.

        Returns:
            Decoded responses, each cut at the template separator and stripped.

        Raises:
            NotImplementedError: if ``history`` or ``return_history`` is set.
            ValueError: if ``pixel_values`` is given without ``num_patches_list``.
        """
        if history is not None or return_history:
            print('Now multi-turn chat is not supported in batch_chat.')
            raise NotImplementedError

        if image_counts is not None:
            num_patches_list = image_counts
            print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
        if num_patches_list is None:
            if pixel_values is not None:
                raise ValueError('`num_patches_list` is required when `pixel_values` is given.')
            # Text-only batch: no image placeholders to expand.
            num_patches_list = [[] for _ in questions]

        self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        queries = []
        for idx, sample_patches in enumerate(num_patches_list):
            question = questions[idx]
            if pixel_values is not None and '<image>' not in question:
                question = '<image>\n' + question
            template = get_conv_template(self.template)
            template.system_message = self.system_message
            template.append_message(template.roles[0], question)
            template.append_message(template.roles[1], None)
            query = template.get_prompt()
            for num_patches in sample_patches:
                image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
                query = query.replace('<image>', image_tokens, 1)
            queries.append(query)

        # Left padding is needed for batched decoding; restore the caller's setting afterwards.
        original_padding_side = tokenizer.padding_side
        tokenizer.padding_side = 'left'
        try:
            model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
        finally:
            tokenizer.padding_side = original_padding_side
        input_ids = model_inputs['input_ids'].to(self.device)
        attention_mask = model_inputs['attention_mask'].to(self.device)

        # Copy so the caller's dict is not modified.
        generation_config = dict(generation_config)
        generation_config['eos_token_id'] = tokenizer.convert_tokens_to_ids(template.sep)

        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
        responses = [response.split(template.sep)[0].strip() for response in responses]
        return responses

    def chat(self, tokenizer, pixel_values, question, generation_config, num_image_token=None, history=None,
             return_history=False, num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
             IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False):
        """Answer one question, optionally continuing a multi-turn conversation.

        Args:
            tokenizer: tokenizer that knows ``IMG_CONTEXT_TOKEN``.
            pixel_values: image tiles (dim 0 = tile count) or ``None`` for text-only.
            question: the user question; ``<image>\\n`` is prepended on the first turn
                when images are given and no placeholder is present.
            generation_config: dict of ``generate`` kwargs; it is copied, not modified.
            num_image_token: image tokens per tile; defaults to ``self.num_image_token``.
            history: list of ``(question, answer)`` pairs; it is copied, not modified.
            return_history: also return the extended history.
            num_patches_list: tile count per ``<image>`` placeholder; defaults to all
                tiles for a single image. Must sum to ``len(pixel_values)``.
            IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN: image markup tokens.
            verbose: print the ViT batch size and the prompt/response.

        Returns:
            The response string, or ``(response, history)`` when ``return_history``.
        """
        if history is None and pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question

        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

        self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

        template = get_conv_template(self.template)
        template.system_message = self.system_message

        # Copy so the caller's history list is not extended in place.
        history = [] if history is None else list(history)
        for old_question, old_answer in history:
            template.append_message(template.roles[0], old_question)
            template.append_message(template.roles[1], old_answer)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        if num_image_token is None:
            num_image_token = self.num_image_token
        for num_patches in num_patches_list:
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)

        model_inputs = tokenizer(query, return_tensors='pt')
        input_ids = model_inputs['input_ids'].to(self.device)
        attention_mask = model_inputs['attention_mask'].to(self.device)

        # Copy so the caller's dict is not modified.
        generation_config = dict(generation_config)
        generation_config['eos_token_id'] = tokenizer.eos_token_id
        generation_config['pad_token_id'] = tokenizer.pad_token_id

        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_image_token=num_image_token,
            **generation_config
        )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
        response = response.split(template.sep)[0].strip()
        history.append((question, response))

        if return_history:
            return response, history

        query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
        query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
        if verbose:
            print(query_to_print, response)
        return response

    @torch.no_grad()
    def generate(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            num_image_token: Optional[int] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate token ids with image features spliced into the prompt embeddings.

        Args:
            pixel_values: image tiles to encode, or ``None``.
            input_ids: (B, N) prompt ids containing ``img_context_token_id`` placeholders.
            attention_mask: forwarded to the LLM's ``generate``.
            visual_features: precomputed image embeddings; used instead of encoding
                ``pixel_values`` (and also when ``pixel_values`` is ``None``).
            num_image_token: passed to the projector when encoding ``pixel_values``.
            generation_config, output_hidden_states, **generate_kwargs: forwarded to
                the LLM's ``generate``.
            return_dict: accepted for interface compatibility; not used.

        Returns:
            Whatever the LLM's ``generate`` returns (generated token ids by default).
        """
        assert self.img_context_token_id is not None
        input_embeds = self.language_model.get_input_embeddings()(input_ids)

        if pixel_values is not None or visual_features is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                vit_embeds = self.extract_feature(pixel_values, num_image_token)

            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)
            selected = (input_ids.reshape(B * N) == self.img_context_token_id)
            assert selected.sum() != 0
            # Match device and dtype: index assignment requires identical dtypes.
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(
                device=input_embeds.device, dtype=input_embeds.dtype)
            input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            use_cache=True,
            **generate_kwargs,
        )
        return outputs