"""
app.py - Hugging Face Spaces entry point.

Automatically downloads model checkpoints on first run and serves a Gradio
interface for AI object removal.
"""

import os
import sys
import traceback

# Ensure project root is on the path so the local modules imported below
# resolve regardless of the working directory the app is launched from.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# -- Step 1: Download checkpoints if needed ----------------------------------
print("Checking model checkpoints...")
from download_checkpoints import main as download_checkpoints  # noqa: E402

download_checkpoints()

# -- Step 2: Load the pipeline -----------------------------------------------
import gradio as gr  # noqa: E402
from pipeline import ObjectRemovalPipeline  # noqa: E402
from config import CLIP_SIMILARITY_THRESHOLD  # noqa: E402

print("Initializing AI models (GroundingDINO, SAM, CLIP)...")
pipeline = ObjectRemovalPipeline()
print("All models loaded. Ready!")


# -- Step 3: Define Gradio callback ------------------------------------------
def _to_path(item):
    """Return a filesystem path for a single Gradio file value.

    Strings and ``os.PathLike`` objects are converted with ``os.fspath``.
    Anything else (e.g. a temp-file wrapper, which some Gradio versions pass
    for ``gr.File`` inputs) is assumed to expose its on-disk path as ``.name``;
    if it has no ``.name`` the value is returned unchanged.
    """
    if isinstance(item, (str, os.PathLike)):
        return os.fspath(item)
    return getattr(item, "name", item)


def predict(scene_img_path, object_img_paths, threshold):
    """Run object removal for one Gradio request.

    Args:
        scene_img_path: Filepath of the uploaded scene image (the ``gr.Image``
            input uses ``type="filepath"``), or ``None`` if nothing was uploaded.
        object_img_paths: Value of the multi-file ``gr.File`` input: a list of
            file values (paths or file wrappers), a single value, or empty/None.
        threshold: Value of the threshold slider, forwarded unchanged to
            ``pipeline.run``.

    Returns:
        A ``(result_image, status_message)`` tuple. ``result_image`` is what
        ``pipeline.run`` returned, or ``None`` on a validation failure or error.
    """
    if scene_img_path is None:
        return None, "Error: Please upload a scene image."
    if not object_img_paths:
        return None, "Error: Please upload at least one object reference image."

    # Accept a lone value as well as a list, then normalise every entry to a
    # plain path string before handing it to the pipeline.
    if not isinstance(object_img_paths, (list, tuple)):
        object_img_paths = [object_img_paths]
    object_paths = [_to_path(item) for item in object_img_paths]

    try:
        result_pil = pipeline.run(
            scene_path=scene_img_path,
            object_paths=object_paths,
            threshold=threshold,
            save_debug=False,  # avoid cluttering cloud storage
        )
    except Exception as exc:  # top-level UI boundary: report instead of crashing
        # The full traceback goes to the server log for the app owner; end
        # users only see a short message, not internal paths and stack frames.
        traceback.print_exc()
        return None, (
            f"Error: {type(exc).__name__}: {exc}\n"
            "See the application logs for the full traceback."
        )
    return result_pil, "Done! Object(s) removed successfully."


# -- Step 4: Build Gradio UI -------------------------------------------------
with gr.Blocks(title="AI Object Eraser") as demo:
    gr.Markdown("# AI Object Eraser")
    gr.Markdown(
        "Upload a **scene photo** and one or more **reference photos** of the objects "
        "you want to remove. The AI will detect and erase them for you."
    )

    with gr.Row():
        with gr.Column():
            scene_input = gr.Image(
                label="Scene Image",
                type="filepath",
            )
            object_input = gr.File(
                label="Object Reference(s) to Remove",
                file_count="multiple",
            )
            threshold_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=CLIP_SIMILARITY_THRESHOLD,
                step=0.05,
                label="Detection Sensitivity (lower = more detections)",
            )
            run_btn = gr.Button("Remove Objects", variant="primary")

        with gr.Column():
            output_img = gr.Image(label="Result")
            status_box = gr.Textbox(label="Status")

    run_btn.click(
        fn=predict,
        inputs=[scene_input, object_input, threshold_slider],
        outputs=[output_img, status_box],
    )

# -- Step 5: Launch ----------------------------------------------------------
if __name__ == "__main__":
    # Port 7860 is mandatory for Hugging Face Spaces
    demo.launch(server_name="0.0.0.0", server_port=7860)