# Source: Hugging Face Space "qwen-agentworld-35b-zerogpu", file openai_shim.py
# (uploaded by chabab with huggingface_hub; commit a5f84e5, verified; 4.74 kB).
"""
Local OpenAI-compatible /v1 shim in front of the ZeroGPU Gradio Space.
Talks to the Space's dedicated api_name="/generate" endpoint over the raw
Gradio REST route (no gradio_client, so the Space's broken schema introspection
doesn't matter).

Usage:

    pip install fastapi "uvicorn[standard]" httpx
    export HF_TOKEN=hf_xxx      # your Pro token -> your ZeroGPU quota
    python openai_shim.py       # listens on :11346 (override with PORT)

Then point any OpenAI client at http://localhost:11346/v1
"""
import json
import os
import time
import uuid
from typing import List, Optional, Union
import httpx
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
def _build_headers(token: str) -> dict:
    """Return the HTTP headers sent to the Space for the given HF token.

    The ``Authorization`` header is only included when *token* is non-blank.
    An empty token would otherwise yield the value ``"Bearer "`` (with a
    trailing space), which is not a usable credential and, to my knowledge,
    is rejected by httpx's HTTP/1.1 layer as an illegal header value before
    the request is even sent.

    Args:
        token: Hugging Face access token; may be empty or padded with
            whitespace (e.g. a trailing newline from ``$(cat token.txt)``).

    Returns:
        A fresh dict of request headers.
    """
    headers = {}
    cleaned = (token or "").strip()
    if cleaned:
        headers["Authorization"] = f"Bearer {cleaned}"
    headers["Content-Type"] = "application/json"
    return headers


# Subdomain of the target Space (``<owner>-<space-name>``); override with SPACE_SUB.
SPACE_SUB = os.environ.get("SPACE_SUB", "chabab-qwen-agentworld-35b-zerogpu")
# Raw Gradio REST route of the Space's "/generate" endpoint.
BASE = f"https://{SPACE_SUB}.hf.space/gradio_api/call/generate"
# Model id advertised by /v1/models and used as the default request model.
MODEL_NAME = os.environ.get("MODEL_ID", "Qwen/Qwen-AgentWorld-35B-A3B")
# Local TCP port the shim listens on.
PORT = int(os.environ.get("PORT", "11346"))
# Headers attached to every request made to the Space.
HEADERS = _build_headers(os.environ.get("HF_TOKEN", ""))
class ChatMessage(BaseModel):
    """A single entry of the ``messages`` array in a chat request."""

    # Speaker label supplied by the client; it is copied verbatim into the
    # flattened prompt as a "<role>: " prefix.
    role: str

    # Either plain text, or a list of content-part dicts such as
    # {"type": "text", ...} / {"type": "image_url", ...}.
    content: Union[str, List[dict]]
class ChatRequest(BaseModel):
    """Request body accepted by ``POST /v1/chat/completions``."""

    # Echoed back in the response; defaults to the configured model name.
    model: str = MODEL_NAME

    # Conversation so far, oldest message first.
    messages: List[ChatMessage]

    # Sampling temperature forwarded to the Space.
    temperature: float = 0.7

    # Generation length limit forwarded to the Space.
    max_tokens: int = 512

    # When true the reply is delivered as a server-sent event stream.
    stream: bool = False
def _flatten(messages: "List[ChatMessage]") -> tuple:
    """Collapse a chat history into ``(prompt_text, image_url)``.

    Every message becomes one ``"<role>: <text>"`` line; lines are joined with
    newlines. For multi-part content the text parts are joined with single
    spaces. Only the LAST message is searched for an image: the first
    non-empty ``image_url`` found there is returned, otherwise ``""``.

    Args:
        messages: Objects exposing ``role`` and ``content``, where ``content``
            is either a string or a list of content-part dicts.

    Returns:
        A ``(prompt_text, image_url)`` tuple of two strings.
    """
    image_url = ""
    lines = []
    last_index = len(messages) - 1
    for index, message in enumerate(messages):
        if isinstance(message.content, str):
            lines.append(f"{message.role}: {message.content}")
            continue
        texts = []
        for part in message.content:
            kind = part.get("type")
            if kind == "text":
                # An explicit ``"text": null`` must not break the join below.
                texts.append(part.get("text") or "")
            elif kind == "image_url" and index == last_index and not image_url:
                # Accept both {"image_url": {"url": ...}} and the bare-string
                # form {"image_url": "..."}; ignore malformed parts instead of
                # failing the whole request with a KeyError/TypeError.
                ref = part.get("image_url")
                if isinstance(ref, dict):
                    ref = ref.get("url")
                if isinstance(ref, str):
                    image_url = ref
        lines.append(f"{message.role}: {' '.join(texts)}")
    return "\n".join(lines), image_url
def _parse_sse_result(lines):
    """Extract the output carried by the "complete" event of a Gradio SSE stream.

    Args:
        lines: Iterable of decoded SSE lines (without trailing newlines).

    Returns:
        The first element of the "complete" payload when it is a list, the
        payload itself otherwise; ``""`` when it is empty/falsy or when the
        stream ends without a "complete" event.

    Raises:
        HTTPException: 502 when the stream reports an "error" event.
    """
    event = None
    for line in lines:
        if line.startswith("event:"):
            event = line.split(":", 1)[1].strip()
            continue
        if event == "error":
            # Surface whatever detail the Space attached to the error event.
            detail = line[5:].strip() if line.startswith("data:") else ""
            if detail and detail != "null":
                raise HTTPException(502, f"Space returned an error event: {detail}")
            raise HTTPException(502, "Space returned an error event")
        if event == "complete" and line.startswith("data:"):
            payload = json.loads(line[5:].strip())
            if isinstance(payload, list):
                # Guard against an empty output list instead of IndexError.
                payload = payload[0] if payload else ""
            # The "complete" event is final; no need to keep reading.
            return payload or ""
    if event == "error":
        # Stream ended right after announcing an error, with no data line.
        raise HTTPException(502, "Space returned an error event")
    return ""


def _call_space(text: str, image_url: str, max_tokens: int, temperature: float) -> str:
    """Run one generation on the Space and return its text output.

    Uses the two-step Gradio REST protocol: a POST to ``BASE`` enqueues the
    job and yields an ``event_id``; a streaming GET on ``BASE/<event_id>``
    delivers server-sent events until a "complete" (or "error") event.

    Args:
        text: Flattened prompt.
        image_url: Image URL, or ``""`` when there is none.
        max_tokens: Generation length limit.
        temperature: Sampling temperature.

    Returns:
        The generated text, or ``""`` if the Space produced no output.

    Raises:
        HTTPException: 502 when the Space cannot be reached, answers with an
            HTTP error status, sends a malformed response, or reports an
            "error" event.
    """
    # Positional inputs of the Space's "/generate" endpoint.
    body = {"data": [text, image_url, max_tokens, temperature]}
    try:
        with httpx.Client(timeout=300) as client:
            submit = client.post(BASE, headers=HEADERS, json=body)
            submit.raise_for_status()
            event_id = submit.json()["event_id"]
            with client.stream("GET", f"{BASE}/{event_id}", headers=HEADERS) as stream:
                # Without this check an HTTP error page would be parsed as an
                # empty event stream and silently produce an empty answer.
                stream.raise_for_status()
                return _parse_sse_result(stream.iter_lines())
    except httpx.HTTPStatusError as exc:
        raise HTTPException(
            502, f"Space responded with HTTP {exc.response.status_code}"
        ) from exc
    except httpx.HTTPError as exc:
        raise HTTPException(502, f"Could not reach the Space: {exc}") from exc
    except (KeyError, ValueError) as exc:
        # Missing "event_id" or a body / event payload that is not valid JSON.
        raise HTTPException(502, "Space returned a malformed response") from exc
# ASGI application object; the route decorators below register handlers on it.
app = FastAPI(title="OpenAI shim -> ZeroGPU Space")
@app.get("/v1/models")
def models():
    """Return an OpenAI-style model list containing the single configured model."""
    entry = {"id": MODEL_NAME, "object": "model"}
    return {"object": "list", "data": [entry]}
@app.post("/v1/chat/completions")
def chat_completions(req: ChatRequest):
    """Answer a chat-completion request by delegating to the Space.

    The whole answer is obtained from the Space first. A non-streaming
    request gets a single ``chat.completion`` object; a streaming request
    gets the same answer as one content chunk, followed by a terminating
    "stop" chunk and the ``[DONE]`` sentinel.
    """
    prompt, image = _flatten(req.messages)
    answer = _call_space(prompt, image, req.max_tokens, req.temperature)
    cid = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    if not req.stream:
        choice = {
            "index": 0,
            "message": {"role": "assistant", "content": answer},
            "finish_reason": "stop",
        }
        # Token accounting is not available from the Space; report zeros.
        usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        return {
            "id": cid,
            "object": "chat.completion",
            "created": created,
            "model": req.model,
            "choices": [choice],
            "usage": usage,
        }

    def make_chunk(delta, finish_reason):
        # Shared envelope of every streamed chunk.
        return {
            "id": cid,
            "object": "chat.completion.chunk",
            "created": created,
            "model": req.model,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
        }

    frames = (
        make_chunk({"role": "assistant", "content": answer}, None),
        make_chunk({}, "stop"),
    )

    def sse():
        for frame in frames:
            yield f"data: {json.dumps(frame)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(sse(), media_type="text/event-stream")
if __name__ == "__main__":
    # The bind address is configurable through HOST, like PORT above. The
    # default stays 0.0.0.0 (all interfaces) for backward compatibility; set
    # HOST=127.0.0.1 to keep the shim, which spends your HF_TOKEN quota,
    # reachable from this machine only.
    uvicorn.run(app, host=os.environ.get("HOST", "0.0.0.0"), port=PORT)