Spaces:
Sleeping
Sleeping
| import os | |
| import spaces # HF ZeroGPU — must be imported before torch. | |
| import logging | |
| import shutil | |
| import subprocess | |
| import sys | |
| import tempfile | |
| import threading | |
| import time | |
| from contextlib import contextmanager | |
| from pathlib import Path | |
| from typing import Any | |
| import cv2 | |
| import hydra | |
| import numpy as np | |
| import torch | |
| from dotenv import load_dotenv | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from gradio import Server | |
| from gradio.data_classes import FileData | |
| from hydra.core.global_hydra import GlobalHydra | |
| from PIL import Image | |
| load_dotenv() | |
| # =========================================== | |
| # LOGGING | |
| # =========================================== | |
# Send INFO-and-above records to stdout with a timestamped format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
# Module-level logger for this file.
logger = logging.getLogger(__name__)
| # =========================================== | |
| # PATHS | |
| # =========================================== | |
# Directory containing this file; it is also used as the project root.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = current_dir
# Put the project root first on sys.path before the `eneas` imports below.
# (presumably the package lives next to this file — confirm)
if project_root not in sys.path:
    sys.path.insert(0, project_root)
try:
    from eneas.segmentation import UniqueInstanceSegmenter, GenericCategorySegmenter
    from eneas.segmentation.model_manager import ModelManager
except ImportError as e:
    # Log the failure for the startup log; the ImportError still propagates.
    logger.error(f"Error importing ENEAS: {e}")
    raise e
| # =========================================== | |
| # HYDRA CONTEXT MANAGER | |
| # =========================================== | |
@contextmanager
def hydra_environment(config_module: str):
    """Context manager that initializes Hydra for ``config_module``.

    Any existing GlobalHydra state is cleared before initialization and again
    on exit, so successive ``with`` blocks can use different config modules.

    Args:
        config_module: Dotted module path passed to
            ``hydra.initialize_config_module``.

    Raises:
        Exception: Anything raised during initialization or inside the
            ``with`` body is logged and re-raised unchanged.
    """
    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()
    try:
        logger.info(f"Initializing Hydra for module: {config_module}")
        hydra.initialize_config_module(config_module=config_module, version_base="1.2")
        yield
    except Exception as e:
        # Errors from the `with` body are thrown in at the yield, so they are
        # logged here too before propagating.
        logger.error(f"Hydra initialization error: {e}")
        raise
    finally:
        if GlobalHydra.instance().is_initialized():
            GlobalHydra.instance().clear()
| # =========================================== | |
| # CONSTANTS & ENV | |
| # =========================================== | |
# Label of the sampling mode used on the hosted Space.
SPACE_SAMPLING = "1 FPS (Space limit)"
# Locally the user picks the sampling rate and there is no frame cap.
# The HF Space sets SAMPLING_LOCKED = True (forces 1 FPS) and MAX_FRAMES = 150.
SAMPLING_LOCKED = True
MAX_FRAMES = 150
# NOTE(review): hard-coded here, while log text elsewhere in this file refers
# to an ENEAS_PREFETCH_MODELS environment variable — confirm which is intended.
PREFETCH_MODELS = True
# Ollama endpoint and binary location; both can be overridden via environment.
OLLAMA_HOST = os.getenv("OLLAMA_HOST", "127.0.0.1:11434")
OLLAMA_URL = f"http://{OLLAMA_HOST}"
OLLAMA_BIN = os.getenv("OLLAMA_BIN", os.path.join(project_root, "bin", "ollama"))
# Server log and startup sentinel files live in the system temp directory.
OLLAMA_LOG_PATH = os.path.join(tempfile.gettempdir(), "ollama_serve.log")
HF_READY_FLAG = os.path.join(tempfile.gettempdir(), "eneas_hf_ready")
OLLAMA_READY_FLAG = os.path.join(tempfile.gettempdir(), "eneas_ollama_ready")
# Rendered result videos are written here; created at import time.
OUTPUT_BASE_DIR = str(Path(current_dir) / "gradio_outputs")
os.makedirs(OUTPUT_BASE_DIR, exist_ok=True)
# Ollama model tags pulled at startup.
VLM_MODELS = [
    "qwen3-vl:4b-instruct-q8_0",
    "qwen3-vl:2b-instruct-q8_0"
]
class CompletedStartupTask:
    """Stand-in for a startup thread that has already finished.

    Exposes the ``is_alive`` / ``join`` subset of ``threading.Thread`` so code
    holding either a real thread or this placeholder can treat them alike.
    """

    def is_alive(self) -> bool:
        """Return False: there is no work in progress."""
        return False

    def join(self, timeout=None) -> None:
        """Return immediately; ``timeout`` is accepted for Thread.join parity."""
        return None
| # =========================================== | |
| # OLLAMA FUNCTIONS (run inside the GPU worker) | |
| # =========================================== | |
def get_ollama_env():
    """Build the environment mapping used to launch Ollama processes.

    Copies the current environment, points Ollama at OLLAMA_HOST, allows all
    origins, sets HOME to the working directory, and appends ``<cwd>/lib`` to
    LD_LIBRARY_PATH for the extracted binary's libraries.
    """
    workdir = os.getcwd()
    local_libs = f"{workdir}/lib"
    env = os.environ.copy()
    env.update({
        "OLLAMA_HOST": OLLAMA_HOST,
        "OLLAMA_ORIGINS": "*",
        "HOME": workdir,
    })
    # Append to an existing search path rather than replacing it.
    inherited = env.get("LD_LIBRARY_PATH")
    env["LD_LIBRARY_PATH"] = local_libs if inherited is None else f"{inherited}:{local_libs}"
    return env
def is_ollama_server_running() -> bool:
    """Return True when a curl probe of OLLAMA_URL yields HTTP 200.

    Any failure (curl missing, timeout, etc.) is treated as "not running".
    """
    try:
        # -w prints only the status code; the response body goes to /dev/null.
        probe_cmd = ["curl", "-s", "-o", "/dev/null", "-w", "%{http_code}", OLLAMA_URL]
        probe = subprocess.run(probe_cmd, capture_output=True, text=True, timeout=5)
        status_code = probe.stdout.strip()
        return status_code == "200"
    except Exception:
        return False
def start_ollama_server_gpu():
    """
    Start Ollama server inside the GPU worker.
    This ensures Ollama detects and uses the GPU.

    The server's stdout/stderr are redirected to OLLAMA_LOG_PATH so GPU
    detection / layer offload decisions can be inspected on failure.

    Returns:
        bool: True if the server is (or becomes) reachable, False on timeout
        or on any error while launching it.
    """
    if is_ollama_server_running():
        logger.info("Ollama server is already running.")
        return True
    logger.info("Starting Ollama server inside GPU context...")
    max_retries = 30
    try:
        env = get_ollama_env()
        # The child gets its own copy of the descriptor, so the parent's
        # handle is closed as soon as Popen returns instead of being leaked.
        with open(OLLAMA_LOG_PATH, "w") as ollama_log:
            subprocess.Popen(
                [OLLAMA_BIN, "serve"],
                env=env,
                stdout=ollama_log,
                stderr=subprocess.STDOUT
            )
        # Poll once per second until the server answers.
        for i in range(max_retries):
            if is_ollama_server_running():
                logger.info(f"Ollama server started successfully in {i+1} seconds.")
                return True
            time.sleep(1)
        logger.error(f"Ollama server failed to start within {max_retries} seconds.")
        return False
    except Exception as e:
        logger.error(f"Failed to start Ollama server: {e}")
        return False
def load_model_into_vram(model_name: str) -> bool:
    """
    Load model into VRAM for faster inference.
    Uses keep_alive=-1 to keep model loaded.

    Args:
        model_name: Name of the Ollama model to load
    Returns:
        bool: True if model loaded successfully, False on curl failure,
        an error reply, a timeout, or any other exception.
    """
    import json  # function-scope import: json is not imported at module level

    logger.info(f"Loading model {model_name} into VRAM (num_ctx=8192)...")
    t0 = time.time()
    options = {"num_ctx": 8192}
    try:
        # Send a minimal request to trigger model loading. The body is built
        # with json.dumps so quotes/backslashes in the name cannot break it.
        warmup_body = json.dumps(
            {"model": model_name, "prompt": "hi", "stream": False, "options": options}
        )
        result = subprocess.run(
            ["curl", "-s", f"{OLLAMA_URL}/api/generate", "-d", warmup_body],
            capture_output=True,
            text=True,
            timeout=240  # Model load into VRAM on ZeroGPU can be slow on first call
        )
        logger.info(f"VLM /api/generate returned in {time.time()-t0:.1f}s.")
        if result.returncode != 0:
            # curl -s is silent on connection failures; only the exit code tells.
            logger.error(f"Error loading model: curl exited with code {result.returncode}")
            dump_ollama_log("VLM load error")
            return False
        try:
            reply = json.loads(result.stdout)
            # Only a top-level "error" key is a failure; the word "error" in
            # generated text is not.
            failed = isinstance(reply, dict) and "error" in reply
        except ValueError:
            # Not JSON: fall back to the plain substring check.
            failed = "error" in result.stdout.lower()
        if failed:
            logger.error(f"Error loading model: {result.stdout}")
            dump_ollama_log("VLM load error")
            return False
        # Set keep_alive to -1 to keep model in VRAM
        keep_alive_body = json.dumps(
            {"model": model_name, "keep_alive": -1, "options": options}
        )
        subprocess.run(
            ["curl", "-s", f"{OLLAMA_URL}/api/generate", "-d", keep_alive_body],
            capture_output=True,
            timeout=10
        )
        logger.info(f"Model {model_name} loaded into VRAM successfully in {time.time()-t0:.1f}s.")
        return True
    except subprocess.TimeoutExpired:
        logger.error(f"Timeout while loading model into VRAM after {time.time()-t0:.1f}s.")
        dump_ollama_log("VLM load timeout")
        return False
    except Exception as e:
        logger.error(f"Error loading model into VRAM: {e}")
        dump_ollama_log("VLM load exception")
        return False
def log_active_models():
    """Log the raw /api/ps reply, i.e. the models Ollama reports as loaded."""
    try:
        ps_cmd = ["curl", "-s", f"{OLLAMA_URL}/api/ps"]
        reply = subprocess.run(ps_cmd, capture_output=True, text=True, timeout=5)
        logger.info(f"Active models in VRAM: {reply.stdout}")
    except Exception as e:
        # Purely diagnostic: never let a failed query break the caller.
        logger.warning(f"Could not get active models: {e}")
def dump_ollama_log(tag: str, n: int = 60):
    """Emit the last ``n`` lines of the Ollama server log as one warning.

    Args:
        tag: Short label placed in the log prefix to identify the call site.
        n: Number of trailing lines to include.
    """
    try:
        with open(OLLAMA_LOG_PATH) as log_file:
            all_lines = log_file.readlines()
        tail = all_lines[-n:]
        logger.warning(f"[{tag}] ollama server log (last {len(tail)} lines):\n" + "".join(tail))
    except Exception as e:
        # Best effort: a missing or unreadable log only produces a warning.
        logger.warning(f"[DEBUG {tag}] could not read ollama log: {e}")
def wait_for_startup_flag(path: str, label: str, timeout: int = 600):
    """Block until the sentinel file ``path`` exists or ``timeout`` elapses.

    Polls every 3 seconds, logging a progress line each time. Returns
    immediately when PREFETCH_MODELS is false. Gives up silently on timeout.

    Args:
        path: Sentinel file written when the startup task finishes.
        label: Human-readable task name used in the progress log line.
        timeout: Maximum number of seconds to wait.
    """
    if not PREFETCH_MODELS:
        return
    started = time.time()
    while True:
        if os.path.exists(path):
            break
        if time.time() - started >= timeout:
            break
        logger.info(f"Waiting for {label} startup to finish...")
        time.sleep(3)
def ensure_model_pulled(model_name: str) -> bool:
    """Make sure the model is on disk by calling Ollama's /api/pull.

    The startup pull runs in a main-process thread, but ZeroGPU forks the
    worker (threads don't survive the fork), so this is called again before
    loading the model.

    Args:
        model_name: Name of the Ollama model to pull.
    Returns:
        bool: True if the pull request succeeded, False on curl failure, an
        error reply, or any exception (including timeout).
    """
    import json  # function-scope import: json is not imported at module level

    try:
        # json.dumps keeps the body valid whatever characters the name holds.
        pull_body = json.dumps({"model": model_name, "stream": False})
        r = subprocess.run(
            ["curl", "-s", f"{OLLAMA_URL}/api/pull", "-d", pull_body],
            capture_output=True, text=True, timeout=600,
        )
        if r.returncode != 0:
            # With -s, a refused connection gives empty stdout; without this
            # check it would be reported as a successful pull.
            logger.error(f"Pull failed for {model_name}: curl exited with code {r.returncode}")
            return False
        if '"error"' in r.stdout.lower():
            logger.error(f"Pull failed for {model_name}: {r.stdout[:300]}")
            return False
        logger.info(f"Model {model_name} present/pulled.")
        return True
    except Exception as e:
        logger.warning(f"Pull check failed for {model_name}: {e}")
        return False
def ensure_ollama_ready_gpu(model_name: str) -> bool:
    """
    Bring Ollama to a ready state for ``model_name``.
    MUST be called inside the GPU worker.

    Runs, in order: server start, model pull, model load into VRAM, then
    logs the active models.

    Args:
        model_name: Name of the Ollama model to use
    Returns:
        bool: True once every step has succeeded
    Raises:
        RuntimeError: If any of the three setup steps reports failure
    """
    logger.info(f"Ensuring Ollama is ready with GPU for model: {model_name}")

    # Step 1: the server must run inside the GPU worker to see the GPU.
    server_up = start_ollama_server_gpu()
    if not server_up:
        raise RuntimeError("Failed to start Ollama server with GPU")

    # Step 1b: the startup pull thread does not survive the ZeroGPU fork, so
    # make sure the model is on disk here.
    on_disk = ensure_model_pulled(model_name)
    if not on_disk:
        raise RuntimeError(f"Failed to pull model {model_name}")

    # Step 2: warm the model into VRAM.
    in_vram = load_model_into_vram(model_name)
    if not in_vram:
        raise RuntimeError(f"Failed to load model {model_name} into VRAM")

    # Step 3: record what is actually loaded.
    log_active_models()
    logger.info("Ollama is ready with GPU support!")
    return True
| # =========================================== | |
| # STARTUP: DOWNLOAD BINARY AND MODELS (CPU) | |
| # =========================================== | |
OLLAMA_VERSION = "0.30.2"  # pinned for reproducible Space builds; no effect locally (only downloaded on the Space, with ENEAS_PREFETCH_MODELS=1)


def download_ollama_binary():
    """Download and extract the pinned Ollama release if OLLAMA_BIN is missing.

    The archive is fetched into, and extracted in, the current working
    directory, then removed whether or not the download succeeded.

    Returns:
        bool: True if the binary already existed or was installed, False if
        any step failed.
    """
    if os.path.exists(OLLAMA_BIN):
        logger.info("Ollama binary already exists.")
        return True
    logger.info("Downloading Ollama binary (ZST)...")
    archive = "ollama.tar.zst"
    url = f"https://github.com/ollama/ollama/releases/download/v{OLLAMA_VERSION}/ollama-linux-amd64.tar.zst"
    try:
        # -f makes curl fail on HTTP errors instead of saving the error page
        # as if it were the archive.
        subprocess.run(
            ["curl", "-fL", url, "-o", archive],
            check=True,
            timeout=300
        )
        subprocess.run(["tar", "--zstd", "-xf", archive], check=True)
        subprocess.run(["chmod", "+x", OLLAMA_BIN], check=True)
        logger.info("Ollama binary downloaded and extracted successfully.")
        return True
    except Exception as e:
        logger.error(f"Failed to download Ollama binary: {e}")
        return False
    finally:
        # Cleanup on every path so a partial download never lingers.
        try:
            os.remove(archive)
        except OSError:
            pass
def pull_ollama_models():
    """
    Pull Ollama models at startup (runs on CPU).
    This pre-downloads the models so they're ready when GPU is available.

    A temporary ``ollama serve`` process is started for the pulls and is
    always stopped afterwards, even if a pull raises.
    """
    logger.info("Pre-downloading Ollama models...")
    env = get_ollama_env()
    # Output is discarded: an unread PIPE can fill up and block the server.
    server_process = subprocess.Popen(
        [OLLAMA_BIN, "serve"],
        env=env,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL
    )
    try:
        # Give the server a head start, then poll for up to 20 more seconds.
        time.sleep(5)
        for _ in range(20):
            if is_ollama_server_running():
                break
            time.sleep(1)
        for model in VLM_MODELS:
            logger.info(f"Pulling model: {model}")
            try:
                completed = subprocess.run(
                    [OLLAMA_BIN, "pull", model],
                    env=env,
                    timeout=600,
                    capture_output=True
                )
            except Exception as e:
                logger.warning(f"Failed to pull model {model}: {e}")
                continue
            # A non-zero exit is a failed pull, not a success.
            if completed.returncode == 0:
                logger.info(f"Model {model} pulled successfully.")
            else:
                logger.warning(f"Failed to pull model {model}: exit code {completed.returncode}")
    finally:
        # Stop server (we'll restart it inside GPU context later)
        server_process.terminate()
        try:
            server_process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            server_process.kill()
    logger.info("Ollama models pre-download complete.")
def setup_ollama_startup():
    """Setup Ollama at startup: download binary and pull models.

    Always writes OLLAMA_READY_FLAG when finished, successful or not, so
    anything waiting on the flag is released.
    """
    try:
        if not download_ollama_binary() and not os.path.exists(OLLAMA_BIN):
            # Without a binary the pull would only fail launching the server.
            logger.error("Ollama binary unavailable; skipping model pre-download.")
            return
        pull_ollama_models()
    except Exception as e:
        logger.error(f"Error during Ollama startup setup: {e}")
    finally:
        open(OLLAMA_READY_FLAG, "w").close()
def setup_hf_models():
    """
    Fetch the HuggingFace models and the SAM2 checkpoint at startup.
    This prevents ZeroGPU timeouts during the first inference.

    Errors are logged, not raised; HF_READY_FLAG is written in every case.
    """
    logger.info("Starting HuggingFace models download (Warm-up)...")
    # (log line, repo id) pairs, fetched in this order.
    hub_downloads = (
        ("Downloading SeC-4B...", "OpenIXCLab/SeC-4B"),  # Heavy, ~15GB
        ("Downloading Florence-2...", "microsoft/Florence-2-large"),  # Grounding
        ("Downloading SigLIP...", "google/siglip2-base-patch16-naflex"),  # For Generic Category
    )
    try:
        manager = ModelManager()
        for announcement, repo_id in hub_downloads:
            logger.info(announcement)
            manager.download(repo_id)
        # SAM2 checkpoint comes from a direct URL rather than a repo id.
        logger.info("Downloading SAM2 checkpoint...")
        manager.download_url(
            "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
            "sam2.1_hiera_large.pt"
        )
        logger.info("All HuggingFace models downloaded successfully.")
    except Exception as e:
        logger.error(f"Error during HF model download: {e}")
    finally:
        open(HF_READY_FLAG, "w").close()
| # =========================================== | |
| # STARTUP: PARALLEL MODEL DOWNLOADS | |
| # =========================================== | |
# Kick off model downloads at import time, or install finished placeholders.
if not PREFETCH_MODELS:
    logger.info("Skipping background model downloads. Set ENEAS_PREFETCH_MODELS=1 to prefetch at startup.")
    # Placeholders exposing is_alive()/join() like an already-finished thread.
    t_hf = CompletedStartupTask()
    t_ollama = CompletedStartupTask()
else:
    logger.info("Starting parallel model downloads at startup...")
    # Remove ready flags left by a previous run so waiters block until this
    # run's downloads have finished.
    for _flag in (HF_READY_FLAG, OLLAMA_READY_FLAG):
        try:
            os.remove(_flag)
        except FileNotFoundError:
            pass
    # Daemon threads: both downloads run concurrently in the background.
    t_hf = threading.Thread(target=setup_hf_models, daemon=True)
    t_ollama = threading.Thread(target=setup_ollama_startup, daemon=True)
    t_hf.start()
    t_ollama.start()
| # NOTE: never touch torch.cuda at module level. On ZeroGPU the GPU is only | |
| # attached inside @spaces.GPU workers; initializing CUDA in the main process | |
| # corrupts NVML and crashes later .to(cuda) calls. Use pick_device() at call time. | |
def pick_device() -> str:
    """Return "cuda" when torch sees a CUDA device right now, else "cpu".

    Detection happens on every call because on HF ZeroGPU the GPU is only
    attached inside the worker; a module-level check would read 'cpu'.
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
| # =========================================== | |
| # FRAME PROCESSING | |
| # =========================================== | |
| def target_fps_from_sampling(sampling_mode: str) -> float | None: | |
| if not sampling_mode or sampling_mode == SPACE_SAMPLING: | |
| return None | |
| if sampling_mode.endswith(" FPS"): | |
| try: | |
| return float(sampling_mode.split(" ", 1)[0]) | |
| except ValueError: | |
| return None | |
| return None | |
| def resolve_target_fps(sampling_mode: str, detected_fps: float) -> float | None: | |
| """None → keep every frame (Native). A number → resample to that FPS. | |
| Unknown / Space limit → 1 FPS.""" | |
| if sampling_mode and sampling_mode.lower().startswith("native"): | |
| return None | |
| fps = target_fps_from_sampling(sampling_mode) | |
| return 1.0 if fps is None else fps | |
def sampling_widget_html(prefix: str) -> str:
    """Per-environment frame-sampling control injected into index.html.

    Args:
        prefix: Prepended to the ``<select>`` element id (``{prefix}-sampling``).
    Returns:
        A static note when SAMPLING_LOCKED is set, otherwise a ``<select>``
        offering the selectable sampling rates.
    """
    if SAMPLING_LOCKED:
        # Read the cap from MAX_FRAMES so the note cannot drift from the
        # limit that is actually enforced.
        return f'<div class="sampling-note">1 FPS · up to {MAX_FRAMES} frames.</div>'
    return (
        f'<select id="{prefix}-sampling" class="sampling-select" aria-label="Frame sampling">'
        '<option value="Native FPS" selected>Native FPS</option>'
        '<option value="10 FPS">10 FPS</option>'
        '<option value="5 FPS">5 FPS</option>'
        '<option value="2 FPS">2 FPS</option>'
        '</select>'
    )
def sampling_metadata(
    sampling_mode: str,
    source_fps: float,
    output_fps: float,
    source_frames: int,
    saved_frames: int,
    is_video: bool,
) -> dict:
    """Bundle frame-extraction facts into the dict used for status messages.

    An empty ``sampling_mode`` is replaced by SPACE_SAMPLING; "sampled" is
    True when fewer frames were saved than the source contained.
    """
    effective_mode = sampling_mode or SPACE_SAMPLING
    was_sampled = saved_frames < source_frames
    meta = {}
    meta["sampling_mode"] = effective_mode
    meta["source_fps"] = source_fps
    meta["output_fps"] = output_fps
    meta["source_frames"] = source_frames
    meta["saved_frames"] = saved_frames
    meta["is_video"] = is_video
    meta["sampled"] = was_sampled
    return meta
def frame_extraction_message(meta: dict) -> str:
    """Format the one-line extraction summary from a sampling metadata dict.

    Wording depends on whether the input was a video and whether only part
    of the source was kept.
    """
    saved = meta["saved_frames"]
    output_fps = meta["output_fps"]
    # Whole source kept: short form.
    if not meta["sampled"]:
        if meta["is_video"]:
            return f"{saved:,} frames extracted at {output_fps:.2f} FPS."
        return f"{saved:,} images loaded."
    # Only part of the source kept: mention the original size too.
    source = meta["source_frames"]
    if not meta["is_video"]:
        return f"Using first {saved:,} images from {source:,} uploaded files."
    return f"Using first {saved:,} sampled frames from {source:,} source frames · {output_fps:.2f} FPS."
def uploaded_file_path(file_obj: Any) -> str:
    """Return the filesystem path carried by an upload value.

    Accepts None (-> ""), a plain string path, a dict, or any object: the
    first non-empty "path" then "name" is used. Dicts fall back to "";
    other objects fall back to ``str(file_obj)``.
    """
    if file_obj is None:
        return ""
    if isinstance(file_obj, str):
        return file_obj
    if isinstance(file_obj, dict):
        for key in ("path", "name"):
            value = file_obj.get(key)
            if value:
                return value
        return ""
    for attr in ("path", "name"):
        value = getattr(file_obj, attr, None)
        if value:
            return value
    return str(file_obj)
def uploaded_file_name(file_obj: Any) -> str:
    """Return the user-facing filename of an upload value.

    Preference order: "orig_name", then the basename of "name", then the
    basename of "path". Dicts fall back to ""; other objects fall back to
    ``str(file_obj)``. A plain string is treated as a path.
    """
    if file_obj is None:
        return ""
    if isinstance(file_obj, str):
        return os.path.basename(file_obj)
    is_mapping = isinstance(file_obj, dict)

    def field(key):
        # Uniform read for both dict-style and attribute-style uploads.
        return file_obj.get(key) if is_mapping else getattr(file_obj, key, None)

    original = field("orig_name")
    if original:
        return original
    for key in ("name", "path"):
        base = os.path.basename(field(key) or "")
        if base:
            return base
    return "" if is_mapping else str(file_obj)
def uploaded_file_mime(file_obj: Any) -> str:
    """Return the upload's "mime_type" (dict key or attribute), or ""."""
    if file_obj is None:
        return ""
    if isinstance(file_obj, dict):
        mime = file_obj.get("mime_type")
    else:
        mime = getattr(file_obj, "mime_type", None)
    return mime or ""
def uploaded_file_is_video(file_obj: Any, video_extensions: set[str]) -> bool:
    """Tell whether an upload is a video by MIME type or file extension.

    A MIME type starting with "video/" is decisive; otherwise the lower-cased
    extension of the display name and of the path are checked against
    ``video_extensions``.
    """
    if uploaded_file_mime(file_obj).lower().startswith("video/"):
        return True
    candidates = (uploaded_file_name(file_obj), uploaded_file_path(file_obj))
    for candidate in candidates:
        if not candidate:
            continue
        extension = os.path.splitext(candidate)[1].lower()
        if extension in video_extensions:
            return True
    return False
def file_content_looks_like_video(path: str) -> bool:
    """Probe a file with OpenCV to decide whether it is a multi-frame video.

    Fallback for browser uploads that arrive without filename or mime.
    Returns True when the first frame decodes and either the reported frame
    count exceeds one or a second frame can also be read.
    """
    if not (path and os.path.exists(path)):
        return False
    capture = cv2.VideoCapture(path)
    try:
        if not capture.isOpened():
            return False
        first_ok, _ = capture.read()
        if not first_ok:
            return False
        reported = capture.get(cv2.CAP_PROP_FRAME_COUNT)
        if reported and not np.isnan(reported) and reported > 1:
            return True
        # Frame count missing or unhelpful: try to decode a second frame.
        second_ok, _ = capture.read()
        return bool(second_ok)
    finally:
        capture.release()
def uploaded_file_sort_key(file_obj: Any) -> str:
    """Sort key for uploads: the display name, or the path when it is empty."""
    display_name = uploaded_file_name(file_obj)
    if display_name:
        return display_name
    return uploaded_file_path(file_obj)
def normalize_uploaded_files(input_data: Any) -> list:
    """Coerce an upload value into a new list with None entries removed.

    None becomes [], a list/tuple is filtered, anything else is wrapped in a
    single-element list.
    """
    if input_data is None:
        return []
    if not isinstance(input_data, (list, tuple)):
        return [input_data]
    return list(filter(lambda item: item is not None, input_data))
def generic_completion_message(detections: int, frames: int, meta: dict) -> str:
    """Format the completion status line for a generic segmentation run.

    When ``meta["sampled"]`` is set, the message also states how many source
    frames (video) or uploaded files (images) the processed subset came from.
    """
    # Whole input processed: short form, no source total needed.
    if not meta["sampled"]:
        return f"Completed · {detections:,} detections across {frames:,} frames."
    if not meta["is_video"]:
        return (
            f"Completed · {detections:,} detections across first {frames:,} images "
            f"from {meta['source_frames']:,} uploaded files."
        )
    return (
        f"Completed · {detections:,} detections across first {frames:,} sampled frames "
        f"from {meta['source_frames']:,} source frames."
    )
def process_inputs_to_frames(input_data, output_folder: str, sampling_mode: str = SPACE_SAMPLING) -> tuple:
    """
    Turn an upload (one video or a set of images) into numbered JPEG frames.

    ``output_folder`` is wiped and recreated. When SAMPLING_LOCKED is set the
    requested ``sampling_mode`` is replaced by SPACE_SAMPLING. Videos are
    resampled per the sampling mode and capped at MAX_FRAMES saved frames.
    Image uploads are sorted by name and capped at the same maximum. Frames
    are always numbered contiguously from ``frame_00000.jpg``.

    Args:
        input_data: A single upload value or a list/tuple of them.
        output_folder: Directory that receives the frames.
        sampling_mode: Sampling label (e.g. "Native FPS", "5 FPS").
    Returns:
        tuple: ``(output_folder, frame_paths, output_fps, meta)`` where
        ``meta`` is the dict built by ``sampling_metadata``.
    """
    if os.path.exists(output_folder):
        shutil.rmtree(output_folder)
    os.makedirs(output_folder)
    if SAMPLING_LOCKED:
        sampling_mode = SPACE_SAMPLING
    frame_paths = []
    detected_fps = 30.0
    output_fps = 1.0
    source_frames = 0
    video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.webm'}
    input_list = normalize_uploaded_files(input_data)
    if not input_list:
        meta = sampling_metadata(SPACE_SAMPLING, detected_fps, output_fps, 0, 0, False)
        return output_folder, [], output_fps, meta
    # Only the first upload decides whether this is a video or an image set.
    first_input = input_list[0]
    first_file = uploaded_file_path(first_input)
    is_video = uploaded_file_is_video(first_input, video_extensions) or file_content_looks_like_video(first_file)
    logger.info(
        "Resolved input: path=%s name=%s mime=%s is_video=%s",
        first_file,
        uploaded_file_name(first_input),
        uploaded_file_mime(first_input),
        is_video,
    )
    if is_video:
        source_count = 0
        saved_count = 0
        cap = cv2.VideoCapture(first_file)
        try:
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames_original = cap.get(cv2.CAP_PROP_FRAME_COUNT)
            # Keep the defaults when the container reports nothing usable.
            if video_fps > 0 and not np.isnan(video_fps):
                detected_fps = video_fps
            if total_frames_original > 0 and not np.isnan(total_frames_original):
                source_frames = int(total_frames_original)
            target_fps = resolve_target_fps(sampling_mode, detected_fps)
            if target_fps is None:
                frame_interval = 1
                output_fps = detected_fps
            else:
                frame_interval = max(1, int(round(detected_fps / target_fps)))
                output_fps = float(target_fps)
            logger.info(f"Sampling '{sampling_mode}': {detected_fps:.2f} FPS source -> interval {frame_interval}, output {output_fps:.2f} FPS")
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                # Keep every frame_interval-th source frame.
                if source_count % frame_interval == 0:
                    filename = f"frame_{saved_count:05d}.jpg"
                    filepath = os.path.join(output_folder, filename)
                    cv2.imwrite(filepath, frame)
                    frame_paths.append(filepath)
                    saved_count += 1
                    if saved_count >= MAX_FRAMES:
                        logger.warning(f"Max frame limit ({MAX_FRAMES}) reached.")
                        break
                source_count += 1
        finally:
            # Release the capture even if decoding or writing raised.
            cap.release()
        # Fall back to the number of frames actually read.
        source_frames = source_frames or source_count
    else:
        detected_fps = 5.0
        output_fps = detected_fps
        input_list.sort(key=uploaded_file_sort_key)
        source_frames = len(input_list)
        for f in input_list[:MAX_FRAMES]:
            path = uploaded_file_path(f)
            try:
                # Context manager closes the source file once it is decoded.
                with Image.open(path) as src:
                    img = src.convert("RGB")
                # Number by frames saved so far, so a skipped file leaves no
                # gap in the frame sequence.
                filename = f"frame_{len(frame_paths):05d}.jpg"
                filepath = os.path.join(output_folder, filename)
                img.save(filepath)
                frame_paths.append(filepath)
            except Exception as e:
                logger.warning(f"Skipping file {path}: {e}")
    meta = sampling_metadata(
        sampling_mode,
        detected_fps,
        output_fps,
        source_frames,
        len(frame_paths),
        is_video,
    )
    return output_folder, frame_paths, output_fps, meta
def transcode_to_h264(src: str, dst: str) -> bool:
    """Re-encode ``src`` into ``dst`` as H.264 / yuv420p using ffmpeg.

    Returns True when ffmpeg succeeded and ``dst`` exists and is non-empty.
    Returns False (after a warning) when ffmpeg fails or is not installed.
    """
    ffmpeg_cmd = [
        "ffmpeg", "-y", "-loglevel", "error",
        "-i", src,
        "-c:v", "libx264", "-pix_fmt", "yuv420p",
        "-movflags", "+faststart",
        dst,
    ]
    try:
        subprocess.run(ffmpeg_cmd, check=True)
        produced = os.path.exists(dst) and os.path.getsize(dst) > 0
        return produced
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        logger.warning(f"H.264 transcode failed ({e}); falling back to raw output.")
        return False
def create_video_overlay(frames_folder: str, masks_dict: dict, output_path: str, fps: float) -> str | None:
    """Render the frames with mask overlays into a video file.

    Args:
        frames_folder: Directory containing the ``.jpg`` frames; they are
            processed in sorted filename order.
        masks_dict: Maps the frame's position in that order to either one
            mask array or a list of mask arrays; positive pixels are tinted.
        output_path: Destination video path.
        fps: Frame rate of the output video.
    Returns:
        ``output_path``, or None when the folder holds no ``.jpg`` frames.
    Raises:
        RuntimeError: If a frame file cannot be decoded.
    """
    frame_files = sorted(f for f in os.listdir(frames_folder) if f.endswith(".jpg"))
    if not frame_files:
        return None
    first = cv2.imread(os.path.join(frames_folder, frame_files[0]))
    if first is None:
        # imread returns None instead of raising on unreadable files.
        raise RuntimeError(f"Could not read frame {frame_files[0]}")
    h, w = first.shape[:2]
    logger.info(f"Constructing output video at {fps} FPS...")
    # OpenCV can only encode MPEG-4 Part 2 here, which browsers won't play —
    # write a raw file then transcode to H.264 below.
    raw_path = f"{output_path}.raw.mp4"
    out = cv2.VideoWriter(raw_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    mask_color = np.array([255, 100, 0], dtype=np.uint8)
    try:
        for i, fname in enumerate(frame_files):
            frame = cv2.imread(os.path.join(frames_folder, fname))
            if frame is None:
                raise RuntimeError(f"Could not read frame {fname}")
            mask_overlay = np.zeros_like(frame)
            if i in masks_dict:
                masks = masks_dict[i]
                if isinstance(masks, np.ndarray):
                    mask_overlay[masks > 0] = mask_color
                elif isinstance(masks, list):
                    for m in masks:
                        mask_overlay[m > 0] = mask_color
            # Blend only when something was painted on this frame.
            if np.any(mask_overlay):
                frame = cv2.addWeighted(frame, 1.0, mask_overlay, 0.5, 0)
            out.write(frame)
    finally:
        # Always finalize the container, even when a frame failed.
        out.release()
    if transcode_to_h264(raw_path, output_path):
        os.remove(raw_path)
    else:
        # Transcode failed: serve the raw file under the expected name.
        shutil.move(raw_path, output_path)
    return output_path
def run_unique_segmentation(input_files, points, prompt, encoder_size, offload, frame_idx, frames_dir_cache, original_fps):
    """Segment one instance across the frames and render an overlay video.

    Args:
        input_files: Uploads used to (re)extract frames when no usable
            ``frames_dir_cache`` is given.
        points: Click points on the annotation frame; take priority over text.
        prompt: Text prompt used when no points are given.
        encoder_size: Passed to ``UniqueInstanceSegmenter`` as ``sam_encoder``.
        offload: Passed to the segmenter as ``offload_frames_to_gpu``.
        frame_idx: Index of the annotation frame.
        frames_dir_cache: Directory of previously extracted frames, if any.
        original_fps: Frame rate for the output video.
    Returns:
        tuple: ``(video_path_or_None, status_html, stats_html)``. Failures
        are reported through the status HTML rather than raised.
    """
    have_cache = bool(frames_dir_cache) and os.path.exists(frames_dir_cache)
    if not have_cache and not input_files:
        return None, status_html("idle", "Process input first."), unique_stats_html()
    # Validate the prompt before frame extraction and model loading.
    has_text = bool(prompt and prompt.strip())
    if not points and not has_text:
        return None, status_html("idle", "Provide points (click image) or text."), unique_stats_html()

    owned_temp_dir = None  # set only when frames are extracted in here
    try:
        if not have_cache:
            owned_temp_dir = tempfile.mkdtemp()
            frames_dir_cache, _, original_fps, _ = process_inputs_to_frames(input_files, owned_temp_dir, SPACE_SAMPLING)
        wait_for_startup_flag(HF_READY_FLAG, "Hugging Face models")
        sec_config_module = "eneas.vendor.SeC.inference.sam2.configs"
        try:
            with hydra_environment(sec_config_module):
                device = pick_device()
                logger.info(f"Starting Unique Segmentation on {device}...")
                segmenter = UniqueInstanceSegmenter(
                    sam_encoder=encoder_size,
                    device=device
                )
                annotation_frame = f"frame_{int(frame_idx):05d}.jpg"
                # Points win over text when both are supplied.
                prompt_kwargs = {"points": points} if points else {"text": prompt}
                t_start = time.time()
                result = segmenter.segment(
                    frames_path=frames_dir_cache,
                    annotation_frame=annotation_frame,
                    offload_frames_to_gpu=offload,
                    **prompt_kwargs
                )
                elapsed = time.time() - t_start
                out_vid = os.path.join(OUTPUT_BASE_DIR, f"unique_result_{int(time.time() * 1000)}.mp4")
                final_path = create_video_overlay(frames_dir_cache, result.masks, out_vid, fps=original_fps)
                mask_count = len(result.masks)
                return (
                    final_path,
                    status_html("ok", f"Completed · {result.num_frames} frames segmented."),
                    unique_stats_html(result.num_frames, mask_count, elapsed),
                )
        except Exception as e:
            # logger.exception records the traceback with the message.
            logger.exception(f"Error in unique seg: {e}")
            return None, status_html("err", str(e)), unique_stats_html()
    finally:
        # The result video lives in OUTPUT_BASE_DIR, so frames extracted here
        # are no longer needed. A caller-supplied cache is left alone.
        if owned_temp_dir is not None:
            shutil.rmtree(owned_temp_dir, ignore_errors=True)
def run_generic_segmentation(input_files, category, accept, reject, vlm_model, sampling_mode):
    """Segment every instance of a category and render an overlay video.

    Args:
        input_files: Uploaded video or images.
        category: Category name to segment; must be non-blank.
        accept: Passed to the segmenter as ``accept_threshold``.
        reject: Passed to the segmenter as ``reject_threshold``.
        vlm_model: Ollama model name used by the segmenter.
        sampling_mode: Frame-sampling label; frame extraction replaces it
            with the Space default when sampling is locked.
    Returns:
        tuple: ``(video_path_or_None, status_html, stats_html)``. Failures
        are reported through the status HTML rather than raised.
    """
    if not input_files:
        return None, status_html("idle", "No input."), stats_html()
    if not category or not category.strip():
        return None, status_html("idle", "Specify a category."), stats_html()
    wait_for_startup_flag(HF_READY_FLAG, "Hugging Face models")
    wait_for_startup_flag(OLLAMA_READY_FLAG, "Ollama setup")
    sam2_config_module = "eneas.vendor.sam2"
    temp_dir = None
    try:
        ensure_ollama_ready_gpu(vlm_model)
        temp_dir = tempfile.mkdtemp()
        # Honour the caller's sampling choice; an empty value means the
        # Space default.
        frames_dir, frame_paths, output_fps, meta = process_inputs_to_frames(
            input_files, temp_dir, sampling_mode or SPACE_SAMPLING
        )
        n_frames = len(frame_paths)
        with hydra_environment(sam2_config_module):
            logger.info(f"Starting Generic Segmentation with {vlm_model}...")
            segmenter = GenericCategorySegmenter(
                device=pick_device(),
                vlm_model=vlm_model
            )
            t_start = time.time()
            result = segmenter.segment(
                frames_path=frames_dir,
                category=category,
                accept_threshold=accept,
                reject_threshold=reject
            )
            elapsed = time.time() - t_start
            out_vid = os.path.join(OUTPUT_BASE_DIR, f"generic_result_{int(time.time() * 1000)}.mp4")
            final_path = create_video_overlay(frames_dir, result.masks, out_vid, fps=output_fps)
            count = sum(len(v) for v in result.metadata['detections'].values())
            return (
                final_path,
                status_html("ok", generic_completion_message(count, n_frames, meta)),
                stats_html(n_frames, count, elapsed),
            )
    except Exception as e:
        # logger.exception records the traceback with the message.
        logger.exception(f"Error generic seg: {e}")
        dump_ollama_log("generic error")
        return None, status_html("err", str(e)), stats_html()
    finally:
        # The result video lives in OUTPUT_BASE_DIR; the extracted frames are
        # scratch data and would otherwise pile up in the temp directory.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
| # =========================================== | |
| # HTML SNIPPETS (status/stats) | |
| # =========================================== | |
def status_html(state: str, msg: str) -> str:
    """Render the status row: coloured dot, message, and current time.

    Args:
        state: 'idle' | 'live' | 'ok' | 'err'. Unknown states, and 'err',
            use the 'idle' dot class; 'err' also gets a colour override.
        msg: Message text; HTML-escaped before insertion. None is treated
            as an empty message.
    Returns:
        The HTML snippet for the status row.
    """
    import html  # function-scope import: html is not imported at module level

    dot_cls = state if state in ("idle", "live", "ok") else "idle"
    color_override = ' style="background:var(--sl-signal)"' if state == "err" else ""
    # Messages often carry raw exception text, so escape &, < and > to stop
    # them being interpreted as markup.
    safe = html.escape(msg or "", quote=False)
    return (
        f'<div class="status-row">'
        f'<span class="status-dot {dot_cls}"{color_override}></span>'
        f'<span class="status-msg">{safe}</span>'
        f'<span class="status-time">{time.strftime("%H:%M:%S")}</span>'
        f'</div>'
    )
| def _latency_cell(frames: int, secs: float) -> str: | |
| lat = (secs / frames) if frames > 0 else 0.0 | |
| return f'<div class="cell"><div class="k">Latency / frame</div><div class="v">{lat:.3f} s</div></div>' | |
def stats_html(frames: int = 0, dets: int = 0, secs: float = 0.0) -> str:
    """Render the stats panel: Frames, Detections, Latency / frame, Inference.

    Called with no arguments it yields the zeroed panel.
    """
    cells = [
        f'<div class="cell"><div class="k">Frames</div><div class="v">{frames}</div></div>',
        f'<div class="cell"><div class="k">Detections</div><div class="v">{dets}</div></div>',
        _latency_cell(frames, secs),
        f'<div class="cell"><div class="k">Inference</div><div class="v">{secs:.1f} s</div></div>',
    ]
    return '<div class="stats">' + "".join(cells) + '</div>'
def unique_stats_html(frames: int = 0, masks: int = 0, secs: float = 0.0) -> str:
    """Render the stats panel: Frames, Masks, Latency / frame, Inference.

    Same layout as ``stats_html`` but the second cell is labelled 'Masks'.
    Called with no arguments it yields the zeroed panel.
    """
    cells = [
        f'<div class="cell"><div class="k">Frames</div><div class="v">{frames}</div></div>',
        f'<div class="cell"><div class="k">Masks</div><div class="v">{masks}</div></div>',
        _latency_cell(frames, secs),
        f'<div class="cell"><div class="k">Inference</div><div class="v">{secs:.1f} s</div></div>',
    ]
    return '<div class="stats">' + "".join(cells) + '</div>'
| # =========================================== | |
| # CUSTOM GRADIO SERVER APP | |
| # =========================================== | |
def file_data(path: str | None) -> FileData | None:
    """Wrap ``path`` in a ``FileData``.

    Returns ``None`` when ``path`` is empty/``None`` or does not exist on disk.
    """
    if path and os.path.exists(path):
        return FileData(path=path)
    return None
def normalize_points(points: Any) -> list[tuple[int, int]]:
    """Coerce client-supplied points into a list of integer ``(x, y)`` tuples.

    Each entry is expected to be an indexable pair whose first two items can
    be converted to ``float``; the values are rounded to the nearest integer.
    Malformed entries are skipped rather than aborting the whole request.

    Args:
        points: Iterable of point-like entries, or a falsy value.

    Returns:
        The successfully converted points, in input order. Empty when
        ``points`` is falsy.
    """
    if not points:
        return []
    cleaned: list[tuple[int, int]] = []
    for point in points:
        try:
            # The shape checks live inside the try: len() on a scalar entry
            # raises TypeError, which should skip the entry, not crash.
            if not point or len(point) < 2:
                continue
            x, y = float(point[0]), float(point[1])
            cleaned.append((int(round(x)), int(round(y))))
        except (TypeError, ValueError, OverflowError, KeyError, IndexError):
            # OverflowError: round(inf); ValueError: round(nan) or
            # non-numeric text; KeyError/IndexError: entry not indexable by 0/1.
            continue
    return cleaned
# Server instance that the static mounts below are attached to.
app = Server()

# Front-end assets are mounted only when the folder exists next to this file.
assets_dir = Path(current_dir, "assets")
if assets_dir.exists():
    app.mount("/assets", StaticFiles(directory=str(assets_dir)), name="assets")

# The outputs folder is created if missing, then always mounted.
outputs_dir = Path(current_dir, OUTPUT_BASE_DIR)
outputs_dir.mkdir(exist_ok=True)
app.mount("/outputs", StaticFiles(directory=str(outputs_dir)), name="outputs")
async def homepage():
    """Serve ``index.html`` with cache-busted asset URLs and sampling widgets.

    Returns a 500 ``HTMLResponse`` when ``index.html`` is missing next to this
    file; otherwise the processed page with ``Cache-Control: no-store``.
    """
    index_file = Path(current_dir) / "index.html"
    if not index_file.exists():
        return HTMLResponse("<h1>index.html is missing</h1>", status_code=500)

    page = index_file.read_text(encoding="utf-8")

    # Cache-bust static assets by file mtime so browsers always pick up the
    # latest CSS/JS after a deploy (no stale-cache surprises).
    for name in ("eneas.css", "eneas.js"):
        asset_file = assets_dir / name
        if not asset_file.exists():
            continue
        url = f"/assets/{name}"
        stamp = int(asset_file.stat().st_mtime)
        page = page.replace(url, f"{url}?v={stamp}")

    # Swap each placeholder comment for the rendered sampling widget.
    for mode in ("unique", "generic"):
        page = page.replace(f"<!--SAMPLING:{mode}-->", sampling_widget_html(mode))

    return HTMLResponse(page, headers={"Cache-Control": "no-store"})
def extract_unique_api(input_files: Any = None, sampling: str = SPACE_SAMPLING) -> tuple[FileData | None, dict]:
    """Extract frames from the uploaded inputs and return the first frame.

    Args:
        input_files: Uploaded inputs, forwarded to ``process_inputs_to_frames``.
        sampling: Sampling setting, forwarded to ``process_inputs_to_frames``.

    Returns:
        ``(first_frame, info)``. ``first_frame`` is ``None`` when nothing was
        uploaded or no frames were produced. ``info`` always carries
        ``success``, ``status_html`` and ``stats_html``; once extraction has
        run it also carries ``frames_dir``, ``frame_count`` and ``fps``.

    Raises:
        Whatever ``process_inputs_to_frames`` raises; the freshly created
        temporary directory is removed before the exception propagates.
    """
    if not input_files:
        return None, {
            "success": False,
            "status_html": status_html("idle", "No files uploaded yet."),
            "stats_html": unique_stats_html(),
        }

    temp_dir = tempfile.mkdtemp()
    try:
        frames_dir, frame_paths, output_fps, meta = process_inputs_to_frames(input_files, temp_dir, sampling)
    except Exception:
        # mkdtemp() never cleans up on its own, and on this path nobody else
        # holds a reference to the directory, so remove it before re-raising.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise

    if not frame_paths:
        return None, {
            "success": False,
            "frames_dir": frames_dir,
            "frame_count": 0,
            "fps": output_fps,
            "status_html": status_html("idle", "No frames extracted."),
            "stats_html": unique_stats_html(),
        }

    return file_data(frame_paths[0]), {
        "success": True,
        "frames_dir": frames_dir,
        "frame_count": len(frame_paths),
        "fps": output_fps,
        "status_html": status_html("ok", frame_extraction_message(meta)),
        "stats_html": unique_stats_html(),
    }
def reference_frame_api(frames_dir: str, frame_idx: int = 0) -> tuple[FileData | None, dict]:
    """Look up ``frame_<idx>.jpg`` (five-digit, zero-padded) inside ``frames_dir``.

    A falsy ``frame_idx`` is treated as 0 and negative indices are clamped
    to 0. Returns ``(file, {"success": True, "frame_idx": idx})`` when the
    frame exists, otherwise ``(None, {"success": False, "error": ...})``.
    """
    if frames_dir:
        idx = int(frame_idx or 0)
        if idx < 0:
            idx = 0
        candidate = os.path.join(frames_dir, f"frame_{idx:05d}.jpg")
        if os.path.exists(candidate):
            return file_data(candidate), {"success": True, "frame_idx": idx}
        return None, {"success": False, "error": "Reference frame not found."}
    return None, {"success": False, "error": "Extract frames first."}
def preview_generic_api(input_files: Any = None) -> tuple[FileData | None, dict]:
    """First-frame thumbnail for the Generic input, extracted server-side with
    OpenCV so it works for any codec the browser can't decode (mpeg4, HEVC...).

    Only the first uploaded item is previewed. A non-video item is returned
    as-is; for a video the first frame is written to a fresh temporary
    directory as ``generic_preview.jpg``.

    Returns:
        ``(preview, info)`` where ``info["success"]`` tells whether a preview
        was produced and ``info["error"]`` is set on extraction failures.
    """
    input_list = normalize_uploaded_files(input_files)
    if not input_list:
        return None, {"success": False}

    first = input_list[0]
    path = uploaded_file_path(first)
    video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.webm'}
    is_video = uploaded_file_is_video(first, video_extensions) or file_content_looks_like_video(path)
    if not is_video:
        return file_data(path), {"success": True}

    try:
        cap = cv2.VideoCapture(path)
        try:
            ok, frame = cap.read()
        finally:
            # Release the capture even when read() raises.
            cap.release()
        if not ok or frame is None:
            return None, {"success": False, "error": "Could not read first frame."}

        out_dir = tempfile.mkdtemp()
        out_path = os.path.join(out_dir, "generic_preview.jpg")
        # imwrite reports failure through its return value; without this check
        # a failed write would be reported as success with no preview file.
        if not cv2.imwrite(out_path, frame):
            shutil.rmtree(out_dir, ignore_errors=True)
            return None, {"success": False, "error": "Could not write preview image."}
        return file_data(out_path), {"success": True}
    except Exception as e:
        logger.error(f"Generic preview extraction failed: {e}")
        return None, {"success": False, "error": str(e)}
def run_unique_api(
    input_files: Any = None,
    points: Any = None,
    prompt: str = "",
    encoder_size: str = "long-large",
    offload: bool = False,
    frame_idx: int = 0,
    frames_dir_cache: str | None = None,
    original_fps: float = 1.0,
) -> tuple[FileData | None, dict]:
    """API wrapper around ``run_unique_segmentation``.

    Sanitises the request values before delegating: points go through
    ``normalize_points``; falsy ``encoder_size``, ``frame_idx`` and
    ``original_fps`` fall back to "long-large", 0 and 1.0 respectively.

    Returns:
        ``(video, info)`` where ``info`` holds ``success`` (true when a video
        path came back), ``status_html`` and ``stats_html``.
    """
    clicks = normalize_points(points)
    encoder = encoder_size or "long-large"
    keep_offloaded = bool(offload)
    ref_idx = int(frame_idx or 0)
    fps = float(original_fps or 1.0)

    video_path, status, stats = run_unique_segmentation(
        input_files, clicks, prompt, encoder, keep_offloaded, ref_idx, frames_dir_cache, fps
    )

    result_file = file_data(video_path)
    info = {
        "success": bool(video_path),
        "status_html": status,
        "stats_html": stats,
    }
    return result_file, info
def run_generic_api(
    input_files: Any = None,
    category: str = "",
    accept: float = 0.30,
    reject: float = 0.10,
    vlm_model: str = VLM_MODELS[0],
    sampling: str = SPACE_SAMPLING,
) -> tuple[FileData | None, dict]:
    """API wrapper around ``run_generic_segmentation``.

    Thresholds are coerced to ``float`` and a falsy ``vlm_model`` falls back
    to the first entry of ``VLM_MODELS`` before delegating.

    Returns:
        ``(video, info)`` where ``info`` holds ``success`` (true when a video
        path came back), ``status_html`` and ``stats_html``.
    """
    accept_thr = float(accept)
    reject_thr = float(reject)
    model_name = vlm_model or VLM_MODELS[0]

    video_path, status, stats = run_generic_segmentation(
        input_files, category, accept_thr, reject_thr, model_name, sampling
    )

    result_file = file_data(video_path)
    info = {
        "success": bool(video_path),
        "status_html": status,
        "stats_html": stats,
    }
    return result_file, info
# Launch the server only when this file is executed directly (not on import).
if __name__ == "__main__":
    app.launch(show_error=True)