chabab's picture
Upload folder using huggingface_hub
a5f84e5 verified
Raw
History Blame Contribute Delete
4.46 kB
"""
Gradio ZeroGPU Space for Qwen/Qwen-AgentWorld-35B-A3B (Qwen3-VL, image+text).
Two interfaces:
* a ChatInterface for humans
* a clean, simple-typed API endpoint api_name="generate" for programmatic
use via the raw REST route POST /gradio_api/call/generate
(ChatInterface's own endpoints are not reliably API-callable, so we expose
our own.)
Runs on Hugging Face ZeroGPU; the 35B MoE is loaded 4-bit so it fits a slot.
"""
import os
import urllib.request
from io import BytesIO
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
AutoModelForImageTextToText,
AutoProcessor,
BitsAndBytesConfig,
)
# Checkpoint to serve; can be overridden with the MODEL_ID environment variable.
MODEL_ID = os.environ.get("MODEL_ID", "Qwen/Qwen-AgentWorld-35B-A3B")
# bitsandbytes 4-bit quantization: NF4 weights with double quantization,
# computation in bfloat16.
quant = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
# Processor and model are loaded once at import time and shared by the
# request handlers defined below.
# NOTE(review): trust_remote_code=True executes Python shipped in the model
# repo -- only point MODEL_ID at repositories you trust.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
# NOTE(review): device_map="cuda" runs at import time, outside any
# @spaces.GPU call; this presumably relies on ZeroGPU deferring the real CUDA
# placement -- confirm it behaves as expected with bitsandbytes 4-bit weights.
# NOTE(review): recent transformers releases rename `torch_dtype` to `dtype`;
# confirm against the pinned version.
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    quantization_config=quant,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model.eval()  # inference only: put modules in evaluation mode
# Tokenizer used for the pad/EOS ids and for decoding generated tokens.
tokenizer = processor.tokenizer
def _load_image(src: str, timeout: float = 30.0,
                max_bytes: int = 20 * 1024 * 1024) -> Image.Image:
    """Load an image as RGB from an http(s) URL or a local file path.

    Args:
        src: an ``http://`` / ``https://`` URL, or a path on the local
            filesystem (e.g. a file uploaded through the Gradio UI).
        timeout: seconds to wait on the network; without a timeout a
            stalled server could block the request indefinitely.
        max_bytes: largest body accepted from a URL, so an arbitrary remote
            file cannot exhaust the process's memory.

    Returns:
        A fully loaded ``PIL.Image.Image`` in RGB mode.

    Raises:
        ValueError: if a URL's body is larger than *max_bytes*.
        OSError: on network or file errors (``urllib.error.URLError`` is a
            subclass), or if Pillow cannot identify the image format.
    """
    if src.startswith(("http://", "https://")):
        # Some image hosts reject urllib's default "Python-urllib/x.y" agent.
        req = urllib.request.Request(
            src,
            headers={"User-Agent": "Mozilla/5.0 (compatible; qwen-agentworld-space)"},
        )
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            # Read one byte past the limit so oversized bodies are detectable.
            data = resp.read(max_bytes + 1)
        if len(data) > max_bytes:
            raise ValueError(f"image at {src!r} exceeds {max_bytes} bytes")
        source = BytesIO(data)
    else:
        source = src
    # The context manager closes any file Pillow opened once convert() has
    # produced an independent, fully loaded copy of the pixels.
    with Image.open(source) as img:
        return img.convert("RGB")
def _infer(messages, images, max_new_tokens, temperature) -> str:
"""Runs entirely inside a @spaces.GPU process (real GPU attached)."""
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = processor(
text=[prompt], images=images or None, return_tensors="pt"
).to(model.device)
out = model.generate(
**inputs,
max_new_tokens=int(max_new_tokens),
do_sample=temperature > 0,
temperature=max(float(temperature), 0.01),
top_p=0.8,
pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
)
gen = out[0][inputs["input_ids"].shape[1]:]
return tokenizer.decode(gen, skip_special_tokens=True)
# --- Human chat handler ----------------------------------------------------
def _file_path(item):
    """Return the local path carried by a Gradio file value, else None.

    Accepts a plain path string, a dict with a ``"path"`` key, or a
    ``(path, ...)`` tuple/list. NOTE(review): which of these shapes Gradio
    actually sends depends on its version -- confirm against the installed one.
    """
    if isinstance(item, str):
        return item
    if isinstance(item, dict):
        path = item.get("path")
    elif isinstance(item, (tuple, list)) and item:
        path = item[0]
    else:
        return None
    return path if isinstance(path, str) else None


@spaces.GPU(duration=120)
def chat(message, history, max_new_tokens=512, temperature=0.7):
    """ChatInterface handler: answer the latest message given the history.

    Args:
        message: ``{"text": ..., "files": [...]}`` from the multimodal
            textbox, or a plain string.
        history: prior turns as ``{"role", "content"}`` dicts
            (``type="messages"``). Non-empty text turns are replayed as-is;
            images the user uploaded earlier are re-attached so follow-up
            questions can still refer to them. Other turns are skipped.
        max_new_tokens: generation budget, forwarded to ``_infer``.
        temperature: sampling temperature, forwarded to ``_infer``.

    Returns:
        The model's reply text.
    """
    if isinstance(message, dict):
        text = message.get("text") or ""
        files = message.get("files") or []
    else:
        text, files = str(message), []

    messages, images = [], []
    for turn in history or []:
        content = turn.get("content")
        if isinstance(content, str):
            if content.strip():
                messages.append({"role": turn["role"], "content": content})
            continue
        # Non-text turn. Re-attach earlier user image uploads; without this
        # the model loses sight of the image on every follow-up question.
        path = _file_path(content)
        if turn.get("role") != "user" or path is None:
            continue
        try:
            images.append(_load_image(path))
        except OSError:
            # Not an image, or the temp file is gone: skip it (best effort).
            continue
        messages.append({"role": "user", "content": [{"type": "image"}]})

    # Current turn: image placeholders first, then the text. `images` is
    # filled in the same order the placeholders appear in `messages`.
    paths = [p for p in map(_file_path, files) if p is not None]
    new_images = [_load_image(p) for p in paths]
    images.extend(new_images)
    user = [{"type": "image"} for _ in new_images] + [{"type": "text", "text": text}]
    messages.append({"role": "user", "content": user})
    return _infer(messages, images, max_new_tokens, temperature)
# --- Programmatic API endpoint ---------------------------------------------
@spaces.GPU(duration=120)
def generate(text: str, image_url: str = "", max_new_tokens: float = 512,
             temperature: float = 0.7) -> str:
    """Single-turn generation for the programmatic ``generate`` endpoint.

    Pass ``image_url=''`` for a text-only prompt.

    Args:
        text: the user prompt (``None`` is treated as empty).
        image_url: optional ``http://`` / ``https://`` URL of one image.
        max_new_tokens: generation budget, forwarded to ``_infer``.
        temperature: sampling temperature, forwarded to ``_infer``.

    Returns:
        The model's reply text.

    Raises:
        gr.Error: if *image_url* is non-empty but not an http(s) URL.
    """
    url = (image_url or "").strip()
    images = []
    if url:
        # This endpoint takes untrusted input from any API caller. Anything
        # that is not an http(s) URL would be opened by _load_image as a path
        # on the Space's own filesystem, so reject it up front.
        # NOTE(review): fetching arbitrary http(s) URLs is still an SSRF
        # vector (internal hosts); add an allow-list if that matters here.
        if not url.startswith(("http://", "https://")):
            raise gr.Error("image_url must be an http:// or https:// URL")
        images.append(_load_image(url))
    user = [{"type": "image"} for _ in images]
    user.append({"type": "text", "text": text or ""})
    messages = [{"role": "user", "content": user}]
    return _infer(messages, images, max_new_tokens, temperature)
# UI: a multimodal chat for humans plus a hidden, API-only "generate" event.
with gr.Blocks() as demo:
    gr.Markdown("# Qwen-AgentWorld-35B-A3B (ZeroGPU)\nImage + text, 4-bit. "
                "Programmatic API: `api_name=\"/generate\"`.")
    # The two sliders are passed to chat() after (message, history), i.e. as
    # its max_new_tokens and temperature arguments, in this order.
    gr.ChatInterface(
        fn=chat,
        type="messages",
        multimodal=True,
        additional_inputs=[
            gr.Slider(64, 2048, value=512, step=64, label="max_new_tokens"),
            gr.Slider(0.0, 1.5, value=0.7, step=0.1, label="temperature"),
        ],
    )
    # API-only endpoint with simple types -> clean schema, raw /call works.
    # Every component is hidden; they exist only to give generate() typed
    # inputs (text, image_url, max_new_tokens, temperature) and an output.
    a_text = gr.Textbox(visible=False)
    a_img = gr.Textbox(visible=False)
    a_max = gr.Number(value=512, visible=False)
    a_temp = gr.Number(value=0.7, visible=False)
    a_out = gr.Textbox(visible=False)
    a_btn = gr.Button(visible=False)
    # The invisible button is never clicked in the UI; registering its click
    # event with api_name="generate" is what publishes the endpoint.
    a_btn.click(generate, [a_text, a_img, a_max, a_temp], a_out, api_name="generate")
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server.
    demo.queue().launch()