import os import json from fastapi import Request from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse from gradio import Server from gradio.oauth import attach_oauth, _get_valid_oauth_info_from_session from openai import AsyncOpenAI from pydantic import BaseModel from typing import List # Initialize gradio.Server (which is a subclass of FastAPI) app = Server() # Delegate Hugging Face OAuth to Gradio's battle-tested implementation. # attach_oauth(app) registers /login/huggingface, /login/callback, /logout, AND # the SessionMiddleware. After it runs, the user's token + profile are stored # in request.session['oauth_info'] under the keys authlib returns # (access_token, expires_at, userinfo, ...). attach_oauth(app) # Define request schemas class Message(BaseModel): role: str content: str class ChatRequest(BaseModel): messages: List[Message] model: str = "zai-org/GLM-5.2:fireworks-ai" temperature: float = 0.7 max_tokens: int = 2048 @app.get("/me") async def me(request: Request): """Expose the current user's profile (or null) to the frontend.""" oauth_info = _get_valid_oauth_info_from_session(request.session) or {} user_info = oauth_info.get("userinfo") or {} if not oauth_info or not user_info: return JSONResponse({"user": None}) return JSONResponse({"user": user_info}) @app.post("/api/chat") async def chat_endpoint(request: Request, payload: ChatRequest): # Determine which token pays for this inference: # 1. The logged-in user's OAuth access token (bills the USER — the goal). # 2. Fall back to the host's HF_TOKEN (bills the Space owner) if configured. oauth_info = _get_valid_oauth_info_from_session(request.session) or {} user_token = oauth_info.get("access_token", "") or "" host_token = os.environ.get("HF_TOKEN", "").strip() api_key = (user_token or host_token).strip() # Track who is being billed so the UI can show it. billed_to = "user" if user_token else ("host" if host_token else "none") async def error_generator(msg): yield "data: " + json.dumps({"error": msg}) + "\n\n" yield "data: [DONE]\n\n" headers = { "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", } if not api_key: return StreamingResponse( error_generator( "Not signed in. Click the Hugging Face button in the sidebar to sign in " "and run inference billed to your own HF account." ), media_type="text/event-stream", headers=headers, ) # Initialize AsyncOpenAI client pointing at Hugging Face Router. # When authenticated with the user's token, the HF router bills that user. client = AsyncOpenAI( base_url="https://router.huggingface.co/v1", api_key=api_key, ) async def event_generator(): try: # Call completion API asynchronously with stream=True stream = await client.chat.completions.create( model="zai-org/GLM-5.2:fireworks-ai", messages=[{"role": m.role, "content": m.content} for m in payload.messages], temperature=payload.temperature, max_tokens=payload.max_tokens, stream=True, ) # Yield chunks as they arrive asynchronously async for chunk in stream: if chunk.choices and len(chunk.choices) > 0: delta_content = chunk.choices[0].delta.content if delta_content: yield f"data: {json.dumps({'content': delta_content, 'billedTo': billed_to})}\n\n" # Signal stream completion yield "data: [DONE]\n\n" except Exception as e: # Yield any error occurring during streaming yield f"data: {json.dumps({'error': str(e)})}\n\n" yield "data: [DONE]\n\n" # Configure headers to bypass buffer layers in reverse proxies (like Nginx/Cloudflare) return StreamingResponse(event_generator(), media_type="text/event-stream", headers=headers) @app.get("/", response_class=HTMLResponse) async def homepage(): html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html") if not os.path.exists(html_path): return HTMLResponse("