File size: 4,463 Bytes
020f30f
a5f84e5
020f30f
a5f84e5
 
 
 
 
 
 
 
020f30f
 
 
a5f84e5
 
020f30f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5f84e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
020f30f
 
a5f84e5
 
020f30f
 
 
 
 
 
 
 
 
 
a5f84e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
020f30f
 
a5f84e5
 
 
 
 
 
 
 
 
020f30f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
Gradio ZeroGPU Space for Qwen/Qwen-AgentWorld-35B-A3B (Qwen3-VL, image+text).

Two interfaces:
  * a ChatInterface for humans
  * a clean, simple-typed API endpoint  api_name="generate"  for programmatic
    use via the raw REST route  POST /gradio_api/call/generate
    (ChatInterface's own endpoints are not reliably API-callable, so we expose
    our own.)

Runs on Hugging Face ZeroGPU; the 35B MoE is loaded 4-bit so it fits a slot.
"""

import os
import urllib.request
from io import BytesIO

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
    BitsAndBytesConfig,
)

# Hugging Face model repo to serve; overridable through the MODEL_ID
# environment variable.
MODEL_ID = os.environ.get("MODEL_ID", "Qwen/Qwen-AgentWorld-35B-A3B")

# bitsandbytes settings: 4-bit NF4 weights with double quantization, using
# bfloat16 as the compute dtype.
quant = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Processor and model are created once at import time and shared by every
# request handler in this module.
# NOTE(review): trust_remote_code=True is passed to both loaders — confirm the
# model repo is trusted before pointing MODEL_ID at a different one.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    quantization_config=quant,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model.eval()
# Shortcut to the processor's tokenizer; used for pad/eos ids and decoding.
tokenizer = processor.tokenizer


def _load_image(src: str, timeout: float = 30.0) -> Image.Image:
    """Load an image from an HTTP(S) URL or a local path and return it as RGB.

    Args:
        src: An ``http://`` / ``https://`` URL, or a local filesystem path.
        timeout: Seconds to wait on the network before giving up. Without a
            limit a stalled server would block the caller indefinitely.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        OSError: If the download fails or times out (``urllib.error.URLError``
            is a subclass), or if the data cannot be opened as an image.
    """
    if src.startswith(("http://", "https://")):
        with urllib.request.urlopen(src, timeout=timeout) as resp:
            payload = resp.read()
        return Image.open(BytesIO(payload)).convert("RGB")

    # NOTE(review): every non-URL string is opened as a local path. That is
    # required for uploaded files, but callers passing user-supplied strings
    # should confirm that local-path access is acceptable.
    with Image.open(src) as img:
        # convert() returns an independent image, so the source file can be
        # closed as soon as the conversion is done.
        return img.convert("RGB")


def _infer(messages, images, max_new_tokens, temperature) -> str:
    """Render the chat prompt, run generation, and return the new text.

    Both callers in this module are decorated with ``@spaces.GPU``; this
    function moves its inputs to ``model.device`` and is meant to run there.

    Args:
        messages: Chat messages accepted by ``processor.apply_chat_template``.
        images: List of images matching the image placeholders in
            ``messages``; an empty list means a text-only request.
        max_new_tokens: Upper bound on generated tokens (coerced to ``int``).
        temperature: Sampling temperature (coerced to ``float``). Values
            ``<= 0`` select greedy decoding; positive values are floored at
            0.01 and sampled with ``top_p=0.8``.

    Returns:
        The decoded continuation only (prompt tokens are stripped), with
        special tokens removed.
    """
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[prompt], images=images or None, return_tensors="pt"
    ).to(model.device)

    # Compare against None explicitly: a plain `or` would throw away a
    # legitimate pad token id of 0 and silently fall back to EOS.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    temperature = float(temperature)
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "pad_token_id": pad_id,
    }
    if temperature > 0:
        # Sampling knobs are only meaningful (and only passed) when sampling.
        gen_kwargs.update(
            do_sample=True,
            temperature=max(temperature, 0.01),
            top_p=0.8,
        )
    else:
        gen_kwargs["do_sample"] = False

    out = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_len = inputs["input_ids"].shape[1]
    gen = out[0][prompt_len:]
    return tokenizer.decode(gen, skip_special_tokens=True)


# --- Human chat handler ----------------------------------------------------
@spaces.GPU(duration=120)
def chat(message, history, max_new_tokens=512, temperature=0.7):
    """Handle one chat turn and return the model's reply.

    ``message`` is either a dict carrying ``"text"`` and ``"files"`` entries
    or any other value, which is stringified and treated as text only.
    Earlier turns from ``history`` are replayed only when their content is a
    non-blank string; every other kind of content is skipped.
    """
    if isinstance(message, dict):
        prompt_text = message.get("text", "")
        attachments = message.get("files", [])
    else:
        prompt_text = str(message)
        attachments = []

    conversation = []
    for past in history:
        body = past.get("content")
        if not isinstance(body, str) or not body.strip():
            continue
        conversation.append({"role": past["role"], "content": body})

    pictures = []
    for path in attachments:
        pictures.append(_load_image(path))

    # One image placeholder per attachment, followed by the text part.
    parts = [{"type": "image"} for _ in pictures]
    parts.append({"type": "text", "text": prompt_text})
    conversation.append({"role": "user", "content": parts})

    return _infer(conversation, pictures, max_new_tokens, temperature)


# --- Programmatic API endpoint ---------------------------------------------
@spaces.GPU(duration=120)
def generate(text: str, image_url: str = "", max_new_tokens: float = 512,
             temperature: float = 0.7) -> str:
    """Run a single-turn request; leave image_url empty for text-only.

    When ``image_url`` is non-empty it is loaded through ``_load_image`` and
    placed ahead of the text in the user message.
    """
    pictures = []
    if image_url:
        pictures.append(_load_image(image_url))

    # One image placeholder per loaded image, followed by the text part.
    parts = [{"type": "image"} for _ in pictures]
    parts.append({"type": "text", "text": text})

    return _infer(
        [{"role": "user", "content": parts}],
        pictures,
        max_new_tokens,
        temperature,
    )


# One Blocks app hosts both the human-facing chat UI and the hidden
# components that back the programmatic `generate` endpoint.
with gr.Blocks() as demo:
    gr.Markdown("# Qwen-AgentWorld-35B-A3B (ZeroGPU)\nImage + text, 4-bit. "
                "Programmatic API: `api_name=\"/generate\"`.")
    # Multimodal chat UI backed by `chat`; the two sliders supply its
    # max_new_tokens and temperature arguments.
    gr.ChatInterface(
        fn=chat,
        type="messages",
        multimodal=True,
        additional_inputs=[
            gr.Slider(64, 2048, value=512, step=64, label="max_new_tokens"),
            gr.Slider(0.0, 1.5, value=0.7, step=0.1, label="temperature"),
        ],
    )

    # API-only endpoint with simple types -> clean schema, raw /call works.
    # All of these components are hidden; they exist only to give `generate`
    # typed inputs (text, image URL, max tokens, temperature) and an output.
    a_text = gr.Textbox(visible=False)
    a_img = gr.Textbox(visible=False)
    a_max = gr.Number(value=512, visible=False)
    a_temp = gr.Number(value=0.7, visible=False)
    a_out = gr.Textbox(visible=False)
    a_btn = gr.Button(visible=False)
    # The click event on the hidden button is what registers the endpoint
    # under api_name="generate".
    a_btn.click(generate, [a_text, a_img, a_max, a_temp], a_out, api_name="generate")


# Enable the request queue and start the server when run as a script.
if __name__ == "__main__":
    demo.queue().launch()