Spaces:
Running on Zero
Running on Zero
| """ | |
| Gradio ZeroGPU Space for Qwen/Qwen-AgentWorld-35B-A3B (Qwen3-VL, image+text). | |
| Two interfaces: | |
| * a ChatInterface for humans | |
| * a clean, simple-typed API endpoint api_name="generate" for programmatic | |
| use via the raw REST route POST /gradio_api/call/generate | |
| (ChatInterface's own endpoints are not reliably API-callable, so we expose | |
| our own.) | |
| Runs on Hugging Face ZeroGPU; the 35B MoE is loaded 4-bit so it fits a slot. | |
| """ | |
| import os | |
| import urllib.request | |
| from io import BytesIO | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from PIL import Image | |
| from transformers import ( | |
| AutoModelForImageTextToText, | |
| AutoProcessor, | |
| BitsAndBytesConfig, | |
| ) | |
# Checkpoint to serve; the MODEL_ID environment variable overrides the default.
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen-AgentWorld-35B-A3B")

# 4-bit NF4 weights with double quantization; matmuls are computed in bfloat16.
quant = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

# The processor bundles the chat template, image preprocessing and tokenizer.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
tokenizer = processor.tokenizer

# Load the quantized weights once at import time and switch to inference mode.
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    quantization_config=quant,
)
model.eval()
def _load_image(src: str, timeout: float = 30.0) -> Image.Image:
    """Load an image from an http(s) URL or a local path and return it as RGB.

    Args:
        src: An ``http://`` / ``https://`` URL, or otherwise a filesystem path.
        timeout: Seconds to wait on the remote host before giving up. Only
            used for URL sources.

    Returns:
        A decoded ``PIL.Image.Image`` in RGB mode.

    Raises:
        OSError: If the download fails or times out, or if the file cannot
            be opened or decoded (``urllib.error.URLError`` is an ``OSError``
            subclass).
    """
    if src.startswith(("http://", "https://")):
        # Without a timeout a stalled remote host would block the request
        # indefinitely.
        with urllib.request.urlopen(src, timeout=timeout) as resp:
            payload = resp.read()
        return Image.open(BytesIO(payload)).convert("RGB")

    # Image.open() is lazy and keeps the file handle open; convert() returns
    # an independent copy, so the handle can be closed deterministically here.
    with Image.open(src) as img:
        return img.convert("RGB")
# ZeroGPU only attaches a GPU to functions wrapped in @spaces.GPU. `duration`
# is the per-call GPU time budget in seconds; 120 is a starting point to tune.
@spaces.GPU(duration=120)
def _infer(messages, images, max_new_tokens, temperature) -> str:
    """Render the chat template, run generation and return the new text.

    Args:
        messages: Chat messages in the format accepted by
            ``processor.apply_chat_template``.
        images: List of images matching the ``{"type": "image"}`` placeholders
            in ``messages``; an empty list means text-only.
        max_new_tokens: Upper bound on generated tokens (coerced to ``int``).
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding.

    Returns:
        The decoded completion, without the prompt and without special tokens.
    """
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[prompt], images=images or None, return_tensors="pt"
    ).to(model.device)

    # `pad or eos` would discard a legitimate pad id of 0, so test for None.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    gen_kwargs = {"max_new_tokens": int(max_new_tokens), "pad_token_id": pad_id}
    temperature = float(temperature)
    if temperature > 0:
        # Sampling knobs are only meaningful (and only passed) when sampling.
        gen_kwargs.update(
            do_sample=True,
            temperature=max(temperature, 0.01),
            top_p=0.8,
        )
    else:
        gen_kwargs["do_sample"] = False

    out = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + completion; keep only the completion tokens.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
| # --- Human chat handler ---------------------------------------------------- | |
def chat(message, history, max_new_tokens=512, temperature=0.7):
    """Handle one chat turn and return the model's reply.

    Args:
        message: Either a dict with ``"text"`` and ``"files"`` keys, or any
            other value, which is stringified and treated as text-only.
        history: Earlier turns as dicts with ``"role"`` and ``"content"``.
            Only turns whose content is a non-blank string are forwarded;
            everything else is dropped.
        max_new_tokens: Generation length limit passed to ``_infer``.
        temperature: Sampling temperature passed to ``_infer``.

    Returns:
        The generated reply text.
    """
    if isinstance(message, dict):
        text = message.get("text", "")
        files = message.get("files", [])
    else:
        text = str(message)
        files = []

    # Keep only the plain-text turns of the earlier conversation.
    paired = ((turn, turn.get("content")) for turn in history)
    conversation = [
        {"role": turn["role"], "content": body}
        for turn, body in paired
        if isinstance(body, str) and body.strip()
    ]

    # One image placeholder per attached file, followed by the text part.
    images = [_load_image(path) for path in files]
    parts = [{"type": "image"} for _ in images]
    parts.append({"type": "text", "text": text})
    conversation.append({"role": "user", "content": parts})

    return _infer(conversation, images, max_new_tokens, temperature)
| # --- Programmatic API endpoint --------------------------------------------- | |
def generate(text: str, image_url: str = "", max_new_tokens: float = 512,
             temperature: float = 0.7) -> str:
    """Single-turn generation for the programmatic API.

    Args:
        text: The user prompt. ``None`` is treated as an empty prompt.
        image_url: Source of an optional image. Empty, blank or ``None``
            means text-only; surrounding whitespace is ignored.
        max_new_tokens: Generation length limit passed to ``_infer``.
        temperature: Sampling temperature passed to ``_infer``.

    Returns:
        The generated text.
    """
    # API clients may send null or whitespace-padded strings; normalize so a
    # blank value means "no image" and a padded URL is still seen as a URL.
    source = (image_url or "").strip()

    # NOTE(review): _load_image opens any non-http(s) value as a local file
    # path, so callers of this public endpoint can point it at files on the
    # host. Consider restricting this endpoint to http(s) URLs.
    images = [_load_image(source)] if source else []

    user = [{"type": "image"} for _ in images]
    user.append({"type": "text", "text": text or ""})
    messages = [{"role": "user", "content": user}]
    return _infer(messages, images, max_new_tokens, temperature)
with gr.Blocks() as demo:
    gr.Markdown(
        "# Qwen-AgentWorld-35B-A3B (ZeroGPU)\nImage + text, 4-bit. "
        "Programmatic API: `api_name=\"/generate\"`."
    )

    # Human-facing multimodal chat; the two sliders feed chat()'s extra args.
    gr.ChatInterface(
        fn=chat,
        type="messages",
        multimodal=True,
        additional_inputs=[
            gr.Slider(
                minimum=64, maximum=2048, value=512, step=64,
                label="max_new_tokens",
            ),
            gr.Slider(
                minimum=0.0, maximum=1.5, value=0.7, step=0.1,
                label="temperature",
            ),
        ],
    )

    # Hidden components exist only to give the "generate" endpoint a schema
    # of plain strings and numbers.
    a_text = gr.Textbox(visible=False)
    a_img = gr.Textbox(visible=False)
    a_max = gr.Number(value=512, visible=False)
    a_temp = gr.Number(value=0.7, visible=False)
    a_out = gr.Textbox(visible=False)
    a_btn = gr.Button(visible=False)
    a_btn.click(
        fn=generate,
        inputs=[a_text, a_img, a_max, a_temp],
        outputs=a_out,
        api_name="generate",
    )

if __name__ == "__main__":
    # queue() returns the same Blocks instance, so this matches the chained form.
    demo.queue()
    demo.launch()