Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| import requests | |
| import spaces | |
| from transformers import AutoProcessor, Idefics3ForConditionalGeneration, TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList | |
# Hub repository providing both the processor (tokenizer + image processor)
# and the fine-tuned SmolVLM weights used by this demo.
base_model_id = "Andres77872/SmolVLM-500M-anime-caption-v0.2"

processor = AutoProcessor.from_pretrained(base_model_id)

# Load the weights in bfloat16 first, then move the module onto the first
# CUDA device (nn.Module.to returns the same module).
model = Idefics3ForConditionalGeneration.from_pretrained(
    base_model_id, torch_dtype=torch.bfloat16
)
model = model.to("cuda:0")
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that halts generation once a stop string appears.

    Each check decodes only the last ``tail_tokens`` token ids of every
    sequence in the batch (special tokens skipped) and searches that text for
    ``stop_sequence``. Decoding a fixed-size tail keeps the per-step cost
    constant, instead of re-decoding the entire sequence on every token.
    """

    def __init__(self, tokenizer, stop_sequence, tail_tokens=None):
        """Create the criterion.

        Args:
            tokenizer: object with a ``decode(ids, skip_special_tokens=...)``
                method, used to turn token ids back into text.
            stop_sequence: text whose appearance in the output stops generation.
            tail_tokens: number of trailing token ids decoded per check.
                Defaults to ``len(stop_sequence) + 10``. That is sufficient as
                long as the recent tokens each decode to at least one
                character; the extra 10 tokens give slack for tokens that
                decode to nothing.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.stop_sequence = stop_sequence
        if tail_tokens is None:
            tail_tokens = len(stop_sequence) + 10
        self.tail_tokens = tail_tokens

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        """Return one flag per sequence: True where ``stop_sequence`` was found.

        Args:
            input_ids: token ids so far, one row per sequence (presumably
                prompt + generated tokens, as passed by ``generate``).
            scores: prediction scores for the current step (unused).

        Returns:
            Boolean tensor with one entry per row of ``input_ids``, on the same
            device, which is the shape recent transformers versions expect.
        """
        finished = [
            self.stop_sequence
            in self.tokenizer.decode(row[-self.tail_tokens:], skip_special_tokens=True)
            for row in input_ids
        ]
        return torch.tensor(finished, dtype=torch.bool, device=input_ids.device)
# ZeroGPU Spaces require at least one @spaces.GPU function (the `spaces`
# import was otherwise unused); on other hardware the decorator has no effect.
@spaces.GPU
def caption_anime_image_stream(image):
    """Stream a caption for an anime-style image, yielding the growing text.

    ``model.generate`` runs on a background thread and feeds a
    ``TextIteratorStreamer``; this generator forwards the decoded text so the
    Gradio textbox updates while tokens are produced. Generation is greedy
    (``do_sample=False``). It stops after 1024 new tokens or once the decoded
    output contains ``"</RATING>"``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing was uploaded.

    Yields:
        str: the caption generated so far, whitespace-stripped. Each yield is
        cumulative, not a delta.

    Raises:
        Exception: any error raised by ``model.generate`` on the worker thread
        is re-raised here after the stream is closed, instead of leaving the
        consumer blocked on the streamer forever.
    """
    import threading

    if image is None:
        yield "Please upload an image."
        return

    question = "describe the image"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        }
    ]

    # Clamp the resize target so it never exceeds the processor's maximum edge.
    max_image_size = processor.image_processor.max_image_size["longest_edge"]
    size = processor.image_processor.size.copy()
    if "longest_edge" in size and size["longest_edge"] > max_image_size:
        size["longest_edge"] = max_image_size

    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=[prompt], images=[[image]], return_tensors='pt', padding=True, size=size)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    stop_sequence = "</RATING>"
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    custom_stopping_criteria = StoppingCriteriaList([
        StopOnTokens(processor.tokenizer, stop_sequence)
    ])
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        do_sample=False,
        max_new_tokens=1024,
        pad_token_id=processor.tokenizer.pad_token_id,
        stopping_criteria=custom_stopping_criteria,
    )

    errors = []

    def _generate():
        # Run generation on the worker thread and forward any failure to the consumer.
        try:
            # Grad mode is thread-local, so enter no_grad on the thread that
            # actually runs generate().
            with torch.no_grad():
                model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            # generate() only ends the streamer when it finishes normally;
            # without this the `for ... in streamer` loop below never returns.
            streamer.end()

    generation_thread = threading.Thread(target=_generate)
    generation_thread.start()

    caption = ""
    for new_text in streamer:
        caption += new_text
        yield caption.strip()
    generation_thread.join()

    if errors:
        raise errors[0]
# Gradio UI: a PIL image in, a streamed caption out. Streaming needs no extra
# option; Gradio streams the yields of a generator function such as
# caption_anime_image_stream.
demo = gr.Interface(
    caption_anime_image_stream,
    inputs=gr.Image(type="pil", label="Anime Image"),
    outputs=gr.Textbox(lines=8, label="Caption"),
    title="SmolVLM-500M-anime-caption-v0.2 Demo",
    description="Upload an anime-style image to generate a caption.",
    # "auto": no flag button is shown, but every submission is flagged
    # (saved) automatically. NOTE(review): newer Gradio releases replace this
    # argument with `flagging_mode`; confirm against the pinned version.
    allow_flagging="auto",
    examples=None,
)

# Enable the request queue, through which streamed (generator) updates are delivered.
demo.queue()

# Launch only when run as a script (as Spaces does), so importing this module
# does not start a server.
if __name__ == "__main__":
    demo.launch()