"""Gradio workflow that expands a concept into an image prompt and renders it."""

import os
import tempfile
import uuid

import gradio as gr
from gradio.context import LocalContext
from huggingface_hub import InferenceClient
from huggingface_hub import get_token as hf_get_token

# Prompt used whenever the caller supplies no usable text.
_DEFAULT_PROMPT = "a ginger cat wearing a tiny wizard hat reading a spellbook"

# Placeholder token seen in local development; it must never be sent to a provider.
_MOCK_OAUTH_TOKEN = "mock-oauth-token-for-local-dev"


def _get_user_token() -> str | None:
    """Return the Hugging Face token to use for the current request.

    Looks for ``oauth_info["access_token"]`` in the session of the request held
    by ``LocalContext``. The local-dev mock token is ignored. If no usable
    session token is found, falls back to ``hf_get_token()`` (locally saved
    token / ``HF_TOKEN`` environment variable).

    Returns:
        The token string, or ``None`` when no token could be resolved.
    """
    try:
        request = LocalContext.request.get(None)
        if request is not None:
            session = getattr(request, "session", None) or {}
            oauth_info = session.get("oauth_info") or {}
            token = oauth_info.get("access_token")
            # Skip the local dev mock token and fall through to hf_get_token().
            if token and token != _MOCK_OAUTH_TOKEN:
                return token
    except Exception:
        # Deliberate best-effort: any failure while reading the session
        # (no request context, no session support, unexpected shape) just
        # means "no OAuth token here", so the fallback below is tried.
        pass

    # Fallback: use the locally saved HF token (hf auth login / HF_TOKEN env var).
    try:
        return hf_get_token()
    except Exception:
        return None


def _fallback_prompt(concept: str) -> str:
    """Build the generic prompt returned when the language model is unusable."""
    return f"A detailed, high-quality, professional commercial product photograph of {concept}"


def _strip_wrapping_quotes(text: str) -> str:
    """Trim whitespace and remove one pair of matching surrounding quotes."""
    text = text.strip()
    # Require two characters so a lone quote is not treated as a quoted pair.
    if len(text) >= 2 and text[0] == text[-1] and text[0] in ('"', "'"):
        return text[1:-1]
    return text


def generate_prompt(concept: str) -> str:
    """Expand a simple concept into a detailed image prompt.

    Sends the concept to the NVIDIA Nemotron model through the "together"
    inference provider, authenticating with the token from
    ``_get_user_token()``.

    Args:
        concept: Short description of the desired image.

    Returns:
        The model's prompt with surrounding whitespace and one pair of wrapping
        quotes removed. A default prompt is returned when ``concept`` is empty
        or blank, and a generic product-photograph prompt is returned when the
        model call fails or yields no text.
    """
    if not concept or not str(concept).strip():
        return _DEFAULT_PROMPT

    try:
        client = InferenceClient(
            provider="together",
            api_key=_get_user_token(),
        )

        system_instruction = (
            "You are an expert prompt engineer for text-to-image models. "
            "Your task is to take a simple concept and expand it into a detailed, "
            "vivid, and high-quality image prompt for FLUX.1-dev. "
            "Describe the scene, lighting, materials, and aesthetic in detail. "
            "Provide ONLY the final prompt text. Do not include any introductory or concluding text, "
            "do not provide multiple options, and do not wrap the prompt in quotes."
        )

        messages = [
            {"role": "system", "content": system_instruction},
            {"role": "user", "content": f"Concept: {concept}"},
        ]

        response = client.chat_completion(
            model="nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-NVFP4",
            messages=messages,
            temperature=0.7,
            max_tokens=256,
        )

        content = response.choices[0].message.content
        # A missing completion must not be stringified into the literal "None".
        prompt = "" if content is None else _strip_wrapping_quotes(str(content))
        if prompt:
            return prompt
        print("Nemotron model returned an empty prompt; using fallback")
    except Exception as e:
        print(f"Error calling Nemotron model: {e}")

    return _fallback_prompt(concept)


def generate_z_image(prompt: str) -> dict:
    """Generate an image for ``prompt`` with the Tongyi-MAI/Z-Image-Turbo model.

    Authenticates with the token from ``_get_user_token()``, saves the
    resulting image as a uniquely named PNG in the system temp directory, and
    describes that file in a dictionary.

    Args:
        prompt: Text prompt for the image; a default prompt is substituted when
            it is empty or blank.

    Returns:
        A dict with the keys ``"path"``, ``"url"`` and ``"is_file"`` pointing
        at the saved PNG.

    Raises:
        Exception: Whatever the inference call or the file save raised, after
            the error has been printed.
    """
    if not prompt or not str(prompt).strip():
        prompt = _DEFAULT_PROMPT

    try:
        client = InferenceClient(
            provider="auto",
            api_key=_get_user_token(),
        )

        image = client.text_to_image(
            prompt,
            model="Tongyi-MAI/Z-Image-Turbo",
        )

        # NOTE(review): the saved file is never deleted by this module.
        filepath = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.png")
        image.save(filepath)
    except Exception as e:
        print(f"Error calling Z-Image-Turbo model: {e}")
        # Bare raise re-raises the active exception without adding a frame.
        raise

    return {
        "path": filepath,
        "url": f"/gradio_api/file={filepath}",
        "is_file": True,
    }


demo = gr.Workflow(bind=[generate_prompt, generate_z_image])

if __name__ == "__main__":
    demo.launch()