Andres77872's picture
Update app.py
0249c44 verified
Raw
History Blame Contribute Delete
3.15 kB
import gradio as gr
import torch
from PIL import Image
import requests
import spaces
from transformers import AutoProcessor, Idefics3ForConditionalGeneration, TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
# Hugging Face Hub checkpoint used for both the processor and the model weights.
base_model_id = "Andres77872/SmolVLM-500M-anime-caption-v0.2"

# The processor supplies the chat template, the image preprocessor and the
# tokenizer used for decoding in caption_anime_image_stream.
processor = AutoProcessor.from_pretrained(base_model_id)

# Load the weights once at import time in bfloat16, then move them to the first
# CUDA device.
model = Idefics3ForConditionalGeneration.from_pretrained(
    base_model_id, torch_dtype=torch.bfloat16
)
model = model.to("cuda:0")
class StopOnTokens(StoppingCriteria):
    """Stop generation once ``stop_sequence`` appears in the most recent output.

    On each generation step only the last ``window_tokens`` token ids of every
    row are decoded. The per-step cost therefore stays constant instead of
    growing with the prompt + completion length; decoding the full sequence on
    every step makes a generation quadratic in its length.

    Args:
        tokenizer: Tokenizer used to turn token ids back into text; must provide
            ``batch_decode``.
        stop_sequence: Text that ends generation for a row once it has been
            produced.
        window_tokens: Number of trailing tokens decoded per step. Defaults to
            the UTF-8 byte length of ``stop_sequence`` plus 10. A token that
            contributes to the stop sequence contributes at least one byte of
            it, so a just-completed stop sequence always fits in this window.
    """

    def __init__(self, tokenizer, stop_sequence, window_tokens=None):
        super().__init__()
        self.tokenizer = tokenizer
        self.stop_sequence = stop_sequence
        if window_tokens is None:
            window_tokens = len(stop_sequence.encode("utf-8")) + 10
        self.window_tokens = window_tokens

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        """Return one bool per batch row: True where the stop sequence was generated.

        The result is created on ``input_ids.device`` because ``generate`` ORs
        it with its own per-row ``is_done`` tensor on that device.
        """
        # Slicing past the start is safe: short sequences are taken whole.
        tail = input_ids[:, -self.window_tokens:]
        texts = self.tokenizer.batch_decode(tail, skip_special_tokens=True)
        return torch.tensor(
            [self.stop_sequence in text for text in texts],
            dtype=torch.bool,
            device=input_ids.device,
        )
@spaces.GPU
def caption_anime_image_stream(image):
    """Stream a caption for an anime-style image, yielding the text as it grows.

    Args:
        image: PIL image from the Gradio ``Image`` input, or ``None`` when no
            image was uploaded.

    Yields:
        The caption generated so far, with surrounding whitespace stripped.
        Generation is greedy and stops at ``</RATING>`` or after 1024 new
        tokens.

    Raises:
        Any exception raised by ``model.generate`` in the worker thread. It is
        re-raised here once the stream has been closed, so the UI shows an
        error instead of waiting forever for more text.
    """
    import threading

    if image is None:
        yield "Please upload an image."
        return

    question = "describe the image"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        }
    ]

    # Never ask the image processor to resize beyond its configured maximum.
    max_image_size = processor.image_processor.max_image_size["longest_edge"]
    size = processor.image_processor.size.copy()
    if "longest_edge" in size and size["longest_edge"] > max_image_size:
        size["longest_edge"] = max_image_size

    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=[prompt], images=[[image]], return_tensors='pt', padding=True, size=size)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    stop_sequence = "</RATING>"
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    custom_stopping_criteria = StoppingCriteriaList([
        StopOnTokens(processor.tokenizer, stop_sequence)
    ])
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        do_sample=False,
        max_new_tokens=1024,
        pad_token_id=processor.tokenizer.pad_token_id,
        stopping_criteria=custom_stopping_criteria,
    )

    errors = []

    def _generate():
        # Runs in the worker thread. No torch.no_grad() here: grad mode is
        # thread-local, so a no_grad block in the calling thread would not
        # reach this thread, and model.generate already disables grad itself.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            # generate() only signals end-of-stream on success. Without this,
            # the consumer loop below would block forever on the queue.
            streamer.end()

    generation_thread = threading.Thread(target=_generate)
    generation_thread.start()
    caption = ""
    for new_text in streamer:
        caption += new_text
        yield caption.strip()
    generation_thread.join()
    if errors:
        raise errors[0]
# Output streams automatically because caption_anime_image_stream is a
# generator; no extra option is needed for that. Flagging is disabled so user
# uploads and captions are not written to the local flagging directory.
# ("auto" would flag every single submission.)
demo = gr.Interface(
    caption_anime_image_stream,
    inputs=gr.Image(type="pil", label="Anime Image"),
    outputs=gr.Textbox(lines=8, label="Caption"),
    title="SmolVLM-500M-anime-caption-v0.2 Demo",
    description="Upload an anime-style image to generate a caption.",
    allow_flagging="never",
)
demo.queue()

if __name__ == "__main__":
    demo.launch()