"""Gradio demo serving a SigLIP2-based aesthetic score regressor.

The model repo is downloaded lazily on the first prediction request; the
backbone, its image processor and a separately stored regression head
(``head.safetensors``) are then loaded once and reused for later requests.
"""

import json
import threading
import traceback
from pathlib import Path

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from safetensors.torch import load_file
from transformers import AutoImageProcessor, AutoModel
from huggingface_hub import snapshot_download

MODEL_REPO = "nzs234/siglip2-so400m-aesthetic-scorer-v1"
CACHE_DIR = Path("./model_cache")

# Serializes lazy model loading so concurrent requests load it only once.
_STATE_LOCK = threading.Lock()
_MODEL_READY = False
_MODEL_ERR = ""  # message of the most recent load failure, shown in the UI

processor = None
model = None
# Range used to map the model's [0, 1] output back onto the rating scale;
# overwritten from metadata.json ("data" section) when the model loads.
score_min = 1.0
score_max = 9.0


def infer_feature_dim(vision):
    """Return the dimensionality of the image features produced by *vision*.

    Looks for a positive int ``projection_dim`` or ``hidden_size`` on the
    model's config and then on ``config.vision_config``; falls back to the
    output size of a ``visual_projection`` linear layer.

    Raises:
        ValueError: if no dimension can be determined.
    """
    cfg = getattr(vision, "config", None)
    candidates = [cfg, getattr(cfg, "vision_config", None) if cfg is not None else None]
    for obj in candidates:
        if obj is None:
            continue
        for k in ("projection_dim", "hidden_size"):
            v = getattr(obj, k, None)
            if isinstance(v, int) and v > 0:
                return v
    proj = getattr(vision, "visual_projection", None)
    if isinstance(proj, nn.Linear):
        return int(proj.out_features)
    raise ValueError("cannot infer feature dim")


class Regressor(nn.Module):
    """Vision backbone plus an MLP head that outputs a score in (0, 1).

    Args:
        backbone_dir: local directory of the pretrained backbone
            (loaded with ``local_files_only=True``).
        hidden_dim: width of the head's first hidden layer.
        dropout: base dropout rate; values <= 0 fall back to 0.3 and the
            result is clamped to [0, 0.8]. Dropout only affects training
            (the app puts the model in eval mode).
    """

    def __init__(self, backbone_dir: str, hidden_dim: int = 2048, dropout: float = 0.2):
        super().__init__()
        self.vision = AutoModel.from_pretrained(backbone_dir, local_files_only=True)
        feat_dim = infer_feature_dim(self.vision)

        h1 = int(hidden_dim)
        h2, h3, h4, h5 = 512, 256, 128, 32
        d1 = float(max(0.0, min(0.8, dropout if dropout > 0 else 0.3)))
        d2 = d1
        # Deeper layers use progressively less dropout (~2/3 and ~1/3 of d1).
        d3 = float(max(0.0, min(0.8, d1 * 0.67)))
        d4 = float(max(0.0, min(0.8, d1 * 0.33)))

        # NOTE: layer order defines the state-dict keys ("0.weight", "1.weight",
        # ...) that head.safetensors must match -- do not reorder.
        self.head = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.Linear(feat_dim, h1), nn.ReLU(), nn.BatchNorm1d(h1), nn.Dropout(d1),
            nn.Linear(h1, h2), nn.ReLU(), nn.BatchNorm1d(h2), nn.Dropout(d2),
            nn.Linear(h2, h3), nn.ReLU(), nn.BatchNorm1d(h3), nn.Dropout(d3),
            nn.Linear(h3, h4), nn.ReLU(), nn.BatchNorm1d(h4), nn.Dropout(d4),
            nn.Linear(h4, h5), nn.ReLU(),
            nn.Linear(h5, 1),
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return one sigmoid score per image in the batch.

        *pixel_values* is presumably the ``(batch, 3, H, W)`` tensor produced
        by the image processor -- TODO confirm for non-default processors.
        """
        if hasattr(self.vision, "get_image_features"):
            feats = self.vision.get_image_features(pixel_values=pixel_values)
            # Some transformers versions return a ModelOutput, not a tensor.
            if not isinstance(feats, torch.Tensor):
                feats = feats.image_embeds if hasattr(feats, "image_embeds") else feats.pooler_output
        else:
            out = self.vision(pixel_values=pixel_values)
            if hasattr(out, "pooler_output") and out.pooler_output is not None:
                feats = out.pooler_output
            else:
                feats = out.last_hidden_state[:, 0, :]
        # L2-normalize the embedding before the head.
        feats = feats / (feats.norm(dim=1, keepdim=True) + 1e-8)
        x = self.head(feats).squeeze(-1)
        return torch.sigmoid(x)


def _load_head_state(head: nn.Module, state: dict) -> None:
    """Load *state* into *head*, failing loudly if any weights are missing.

    ``strict=False`` alone would silently leave missing layers randomly
    initialized and the app would serve meaningless scores. Keys saved with
    a uniform ``head.`` prefix are accepted as well.

    Raises:
        RuntimeError: if the checkpoint lacks weights the head needs.
    """
    if state and all(k.startswith("head.") for k in state):
        state = {k[len("head."):]: v for k, v in state.items()}
    result = head.load_state_dict(state, strict=False)
    # BatchNorm's step counter is irrelevant for inference; tolerate its absence.
    missing = [k for k in result.missing_keys if not k.endswith("num_batches_tracked")]
    if missing:
        raise RuntimeError(f"head checkpoint is missing weights: {missing}")
    if result.unexpected_keys:
        print(f"Warning: ignoring unexpected head weights: {result.unexpected_keys}")


def _ensure_loaded():
    """Download and load the processor, backbone and head exactly once.

    On success, populates the module-level ``processor``, ``model``,
    ``score_min`` and ``score_max`` and then sets ``_MODEL_READY``. On
    failure, stores the message in ``_MODEL_ERR``, leaves those globals
    untouched, and re-raises the exception.
    """
    global _MODEL_READY, _MODEL_ERR, processor, model, score_min, score_max
    # Fast path once loading has succeeded; avoids taking the lock.
    if _MODEL_READY:
        return
    with _STATE_LOCK:
        # Re-check: another request may have finished loading while we waited.
        if _MODEL_READY:
            return
        try:
            print("Downloading model repo snapshot...")
            local_repo = Path(
                snapshot_download(repo_id=MODEL_REPO, repo_type="model", local_dir=str(CACHE_DIR))
            )
            meta = json.loads((local_repo / "metadata.json").read_text(encoding="utf-8"))
            model_cfg = meta.get("model") or {}
            data_cfg = meta.get("data") or {}
            backbone_dir = str(local_repo / "backbone")

            new_processor = AutoImageProcessor.from_pretrained(
                backbone_dir, local_files_only=True, use_fast=False
            )
            new_model = Regressor(
                backbone_dir=backbone_dir,
                hidden_dim=int(model_cfg.get("hidden_dim", 2048)),
                dropout=float(model_cfg.get("dropout", 0.2)),
            )
            head_state = load_file(str(local_repo / "head.safetensors"), device="cpu")
            _load_head_state(new_model.head, head_state)
            new_model.eval()

            new_min = float(data_cfg.get("score_min", 1.0))
            new_max = float(data_cfg.get("score_max", 9.0))
        except Exception as e:
            _MODEL_ERR = str(e) or type(e).__name__
            raise
        # Publish only fully loaded objects; the ready flag is set last so the
        # unlocked fast path never observes a half-initialized state.
        processor, model = new_processor, new_model
        score_min, score_max = new_min, new_max
        _MODEL_ERR = ""
        _MODEL_READY = True
        print("Model loaded.")


def predict(img: Image.Image):
    """Score *img*, streaming ``(result, status)`` string pairs to the UI.

    Yields status updates while the model loads and runs, then a final
    ``("score_<n> (raw=<x>)", "status: done")`` pair, where ``n`` is the
    rounded score clamped to the configured range. Load and inference
    failures are reported through the yielded strings instead of raised.
    """
    if img is None:
        yield "error: no image", "status: please upload image first"
        return

    yield "", "status: starting"
    try:
        if not _MODEL_READY:
            yield "", "status: loading model (first run takes longer)"
            _ensure_loaded()
    except Exception:
        yield f"error: model load failed: {_MODEL_ERR}", "status: failed"
        return

    yield "", "status: model ready, running inference"
    try:
        if img.mode != "RGB":
            img = img.convert("RGB")
        proc = processor(images=img, return_tensors="pt")
        with torch.inference_mode():
            pred_01 = model(proc["pixel_values"]).item()
    except Exception as e:
        traceback.print_exc()
        yield f"error: inference failed: {e}", "status: failed"
        return

    # The head ends in a sigmoid; clamp defensively before rescaling.
    pred_01 = max(0.0, min(1.0, float(pred_01)))
    pred_score = pred_01 * (score_max - score_min) + score_min
    score_int = int(round(pred_score))
    score_int = max(int(score_min), min(int(score_max), score_int))
    yield f"score_{score_int} (raw={pred_score:.4f})", "status: done"


with gr.Blocks() as demo:
    gr.Markdown("# SigLIP2 Aesthetic Scorer Demo")
    inp = gr.Image(type="pil", label="Image")
    out = gr.Textbox(label="Result")
    status = gr.Textbox(label="Status", value="status: idle")
    btn = gr.Button("Predict")
    btn.click(fn=predict, inputs=[inp], outputs=[out, status])

# Handle one request at a time; all requests share the single global model.
demo.queue(default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)