"""Gradio demo for SpatialAlign: text-to-video generation with a Wan pipeline."""

import os
import warnings

import gradio as gr
# NOTE(review): `spaces` is imported ahead of `torch`; the `spaces` package asks
# to be imported before CUDA-related packages — confirm against the ZeroGPU docs.
import spaces
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video
from huggingface_hub import snapshot_download

# --- Model locations ---------------------------------------------------------
WAN_REPO_ID = "dsr2026/wan"
WAN_LOCAL_DIR = "../wan_model"
LORA_PATH = "./step=02400.lora_only.ckpt"  # "" to disable
OUT_DIR = "outputs"

# --- Generation settings -----------------------------------------------------
NEGATIVE_PROMPT = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, "
    "images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, "
    "incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, "
    "misshapen limbs, fused fingers, still picture, messy background, three legs, many people "
    "in the background, walking backwards"
)
NUM_FRAMES = 81
NUM_INFERENCE_STEPS = 50
FPS = 16
QUALITY = 5

# --- Lazily populated process-wide caches ------------------------------------
_PIPE = None  # WanVideoPipeline, built on the first generate() call
_MODEL_FILES = None  # list of verified checkpoint paths

# --- UI content --------------------------------------------------------------
DSR_EXAMPLES = [
    "In a quiet forest clearing, a squirrel is on the left of a lamp, then the squirrel scampers to the right of the lamp.",
    "At the edge of a sunny meadow, a dog is on the left of a bucket, then the dog runs to the right of the bucket.",
    "On a rocky hillside with moss, a fox is on the left of a chair, then the fox sprints to the right of the chair.",
]

# Rendered as Markdown, so the single line breaks below are folded into spaces.
INTRODUCTION = '''
This demo is for SpatialAlign: Aligning Dynamic Spatial Relationships in Video Generation.
Users can specify a Dynamic Spatial Relationship prompt to generate videos, following the template:
〈scene〉, the 〈animal〉 is 〈initial SSR〉 the 〈static object〉, then the 〈animal〉 〈verb〉 〈final SSR〉 the 〈static object〉.
Here, the choice of SSR can be from ['the left of', 'the right of', 'the top of'].
For the initial SSR, an 'on' should be put in the front.
For the final SSR, a 'to' should be put in the front.
Examples are provided for better reference.
'''


def _ensure_wan_downloaded():
    """Download the Wan snapshot once and return the paths of its model files.

    The result is cached in ``_MODEL_FILES`` so later calls return immediately
    without calling ``snapshot_download`` again.

    Returns:
        A list with the paths of the DiT, T5 text-encoder and VAE checkpoints
        inside the downloaded snapshot directory.

    Raises:
        FileNotFoundError: If any expected checkpoint file is absent after the
            download; the message lists every missing path.
    """
    global _MODEL_FILES
    if _MODEL_FILES is not None:
        return _MODEL_FILES

    local_dir = snapshot_download(
        repo_id=WAN_REPO_ID,
        local_dir=WAN_LOCAL_DIR,
        local_dir_use_symlinks=False,
    )
    model_files = [
        os.path.join(local_dir, "diffusion_pytorch_model.safetensors"),
        os.path.join(local_dir, "models_t5_umt5-xxl-enc-bf16.pth"),
        os.path.join(local_dir, "Wan2.1_VAE.pth"),
    ]

    missing = [p for p in model_files if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError("Missing model files:\n" + "\n".join(missing))

    # Cache only after verification so a failed attempt is retried next call.
    _MODEL_FILES = model_files
    return _MODEL_FILES


def _load_pipeline():
    """Build the Wan video pipeline, applying the LoRA checkpoint if present."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_files = _ensure_wan_downloaded()

    # Weights are first loaded on the CPU in bfloat16; the pipeline itself is
    # then created for `device`.
    manager = ModelManager(device="cpu")
    manager.load_models(model_files, torch_dtype=torch.bfloat16)

    if LORA_PATH:
        if os.path.exists(LORA_PATH):
            manager.load_lora(LORA_PATH, lora_alpha=1.0)
        else:
            # Stay best-effort (keep running), but no longer silently: without
            # the LoRA the outputs come from the base model only.
            warnings.warn(
                f"LoRA checkpoint not found at {LORA_PATH!r}; "
                "continuing without it.",
                RuntimeWarning,
                stacklevel=2,
            )

    pipe = WanVideoPipeline.from_model_manager(
        manager, torch_dtype=torch.bfloat16, device=device
    )
    pipe.enable_vram_management(num_persistent_param_in_dit=None)
    return pipe


@spaces.GPU(duration=240)
def generate(prompt: str, seed: int):
    """Generate a video for ``prompt`` and return the path of the saved MP4.

    The pipeline is built on the first call and reused afterwards via the
    module-level ``_PIPE`` cache.

    Args:
        prompt: Text prompt describing the video.
        seed: Random seed; converted with ``int()``. ``None`` is treated as 0.

    Returns:
        Path of the written ``.mp4`` file inside ``OUT_DIR``, or ``None`` when
        the prompt is empty or whitespace-only.
    """
    global _PIPE
    if not prompt or not prompt.strip():
        return None

    # NOTE(review): the seed field is assumed to arrive as None when the user
    # clears it; int(None) would raise TypeError, so fall back to 0.
    seed = 0 if seed is None else int(seed)

    if _PIPE is None:
        _PIPE = _load_pipeline()

    os.makedirs(OUT_DIR, exist_ok=True)
    # The file name depends only on the seed, so a later run with the same
    # seed overwrites the earlier file.
    out_path = os.path.join(OUT_DIR, f"seed{seed:02d}.mp4")

    video = _PIPE(
        prompt=prompt,
        negative_prompt=NEGATIVE_PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        seed=seed,
        tiled=False,
        num_frames=NUM_FRAMES,
    )
    save_video(video, out_path, fps=FPS, quality=QUALITY)
    return out_path


# The `#main-row` rule must be a sibling of the button rule, not nested inside
# it; nested, it would only target a `#main-row` located inside a button.
CSS = """
/* Make example buttons look like clickable prompt cards */
#examples-col button {
  white-space: normal !important;
  text-align: left !important;
  line-height: 1.35 !important;
  padding: 12px 12px !important;
  border-radius: 10px !important;
}

#main-row {
  align-items: flex-start !important;
}
"""

# Gradio 6.x: pass css=... to launch() to avoid the warning.
# Three-column layout: introduction | clickable example prompts | generation form.
with gr.Blocks(title="SpatialAlign Demo", css=CSS) as demo:
    gr.Markdown("## SpatialAlign Demo")

    # Example buttons are created first (to keep layout order); their click
    # handlers are bound further down, once `prompt` is defined.
    example_buttons = []

    with gr.Row(elem_id="main-row"):
        with gr.Column(scale=4):
            gr.Markdown("### Introduction")
            gr.Markdown(INTRODUCTION)

        with gr.Column(scale=3, elem_id="examples-col"):
            gr.Markdown("### Prompt Examples (click to fill prompt)")
            for p in DSR_EXAMPLES:
                b = gr.Button(p)
                example_buttons.append((b, p))

        with gr.Column(scale=3):
            gr.Markdown("### Generate")
            prompt = gr.Textbox(
                label="prompt",
                lines=6,
                placeholder="Describe a dynamic spatial relationship...",
            )
            seed = gr.Number(label="seed", value=0, precision=0)
            btn = gr.Button("Generate")
            vid = gr.Video(label="output")

    btn.click(generate, inputs=[prompt, seed], outputs=vid)

    # Bind events after `prompt` exists (avoids a NameError; keeps layout order).
    # `_p=p` freezes each button's own text; a plain `lambda: p` would make
    # every button return the last example.
    for b, p in example_buttons:
        b.click(fn=lambda _p=p: _p, inputs=None, outputs=prompt)

demo.queue().launch(
    server_name="0.0.0.0",
    server_port=7860,
    ssr_mode=False,
    # Gradio 6.x: move `css=CSS` here (from gr.Blocks above) to avoid the warning.
    # css=CSS,
)