Spaces:
Runtime error
Runtime error
| import json | |
| from pathlib import Path | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| from safetensors.torch import load_file | |
| from transformers import AutoImageProcessor, AutoModel | |
| from huggingface_hub import snapshot_download | |
# Hub repository id of the scorer; used below as the `repo_id` of the snapshot download.
MODEL_REPO: str = "nzs234/siglip2-so400m-aesthetic-scorer-v1"

# Local directory (relative to the working directory) the snapshot is downloaded into.
CACHE_DIR: Path = Path("./model_cache")
def infer_feature_dim(vision):
    """Return the width of the image feature vector produced by *vision*.

    Probing order (first hit wins):

    1. ``vision.config.projection_dim``, then ``vision.config.hidden_size``
    2. the same two attributes on ``vision.config.vision_config``
    3. ``vision.visual_projection.out_features`` when that attribute is an
       ``nn.Linear``

    A config attribute only counts when it is a positive ``int``.

    Raises:
        ValueError: if none of the probes yields a dimension.
    """
    top_cfg = getattr(vision, "config", None)
    nested_cfg = None if top_cfg is None else getattr(top_cfg, "vision_config", None)

    # Lazily enumerate candidate values in priority order; missing configs
    # are skipped entirely, missing attributes surface as None.
    candidates = (
        getattr(source, attr, None)
        for source in (top_cfg, nested_cfg)
        if source is not None
        for attr in ("projection_dim", "hidden_size")
    )
    for value in candidates:
        if isinstance(value, int) and value > 0:
            return value

    # Fall back to the output width of an explicit projection layer.
    projection = getattr(vision, "visual_projection", None)
    if isinstance(projection, nn.Linear):
        return int(projection.out_features)

    raise ValueError("cannot infer feature dim")
class Regressor(nn.Module):
    """Vision backbone followed by an MLP head that emits a score in (0, 1).

    The head is ``LayerNorm`` -> four ``Linear/ReLU/BatchNorm1d/Dropout``
    stages (widths ``hidden_dim``, 512, 256, 128) -> ``Linear(128, 32)`` ->
    ``ReLU`` -> ``Linear(32, 1)``. The layer order fixes the ``nn.Sequential``
    indices, which in turn are the keys of the head's state dict.

    Args:
        backbone_dir: Directory handed to ``AutoModel.from_pretrained`` with
            ``local_files_only=True``.
        hidden_dim: Width of the first hidden layer.
        dropout: Base dropout probability. Non-positive values fall back to
            0.3; the result is clamped to ``[0.0, 0.8]``. The third and fourth
            stages use 0.67x and 0.33x of the base value.
    """

    def __init__(self, backbone_dir: str, hidden_dim: int = 2048, dropout: float = 0.2):
        super().__init__()
        self.vision = AutoModel.from_pretrained(backbone_dir, local_files_only=True)

        def clamp_p(p):
            # Keep every dropout probability inside [0.0, 0.8].
            return float(max(0.0, min(0.8, p)))

        base_p = clamp_p(dropout if dropout > 0 else 0.3)
        stage_widths = (int(hidden_dim), 512, 256, 128)
        stage_dropouts = (base_p, base_p, clamp_p(base_p * 0.67), clamp_p(base_p * 0.33))

        in_dim = infer_feature_dim(self.vision)
        layers = [nn.LayerNorm(in_dim)]
        for width, p in zip(stage_widths, stage_dropouts):
            layers.extend(
                [
                    nn.Linear(in_dim, width),
                    nn.ReLU(),
                    nn.BatchNorm1d(width),
                    nn.Dropout(p),
                ]
            )
            in_dim = width
        # Final bottleneck and scalar output.
        layers.extend([nn.Linear(in_dim, 32), nn.ReLU(), nn.Linear(32, 1)])
        self.head = nn.Sequential(*layers)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Score a batch of preprocessed images; returns sigmoid outputs.

        Features come from ``get_image_features`` when the backbone has it,
        otherwise from the backbone's ``pooler_output`` (or the first token of
        ``last_hidden_state`` when no pooled output is available). They are
        L2-normalised along dim 1 before entering the head.
        """
        backbone = self.vision
        if hasattr(backbone, "get_image_features"):
            embeds = backbone.get_image_features(pixel_values=pixel_values)
            if not isinstance(embeds, torch.Tensor):
                # Structured output: unwrap whichever embedding field exists.
                if hasattr(embeds, "image_embeds"):
                    embeds = embeds.image_embeds
                else:
                    embeds = embeds.pooler_output
        else:
            outputs = backbone(pixel_values=pixel_values)
            pooled = getattr(outputs, "pooler_output", None)
            if pooled is not None:
                embeds = pooled
            else:
                embeds = outputs.last_hidden_state[:, 0, :]

        # The epsilon guards against division by zero for all-zero features.
        lengths = embeds.norm(dim=1, keepdim=True) + 1e-8
        unit_embeds = embeds / lengths
        logits = self.head(unit_embeds).squeeze(-1)
        return torch.sigmoid(logits)
# ---------------------------------------------------------------------------
# One-time startup: download the model repo, then build the processor and
# the scorer from the files inside it.
# ---------------------------------------------------------------------------
print("Downloading model repo snapshot...")
local_repo = snapshot_download(repo_id=MODEL_REPO, repo_type="model", local_dir=str(CACHE_DIR))
local_repo = Path(local_repo)

# metadata.json carries the head hyper-parameters ("model") and the score
# range ("data"); missing sections fall back to the defaults used below.
meta = json.loads((local_repo / "metadata.json").read_text(encoding="utf-8"))
model_cfg = meta.get("model", {})
data_cfg = meta.get("data", {})

processor = AutoImageProcessor.from_pretrained(str(local_repo / "backbone"), local_files_only=True, use_fast=False)
model = Regressor(
    backbone_dir=str(local_repo / "backbone"),
    hidden_dim=int(model_cfg.get("hidden_dim", 2048)),
    dropout=float(model_cfg.get("dropout", 0.2)),
)

head_state = load_file(str(local_repo / "head.safetensors"), device="cpu")
# strict=False tolerates key mismatches, but a mismatch means part of the
# head keeps its random initialisation. Report it instead of hiding it.
_head_load = model.head.load_state_dict(head_state, strict=False)
if _head_load.missing_keys:
    print(
        f"WARNING: {len(_head_load.missing_keys)} head tensor(s) not found in "
        f"head.safetensors (left at initial values): {_head_load.missing_keys}"
    )
if _head_load.unexpected_keys:
    print(
        f"WARNING: {len(_head_load.unexpected_keys)} tensor(s) in head.safetensors "
        f"do not match the head and were ignored: {_head_load.unexpected_keys}"
    )
# Inference only: switch Dropout/BatchNorm to evaluation behaviour.
model.eval()

# Range used to map the model's [0, 1] output back to a score.
score_min = float(data_cfg.get("score_min", 1.0))
score_max = float(data_cfg.get("score_max", 9.0))
def predict(img: Image.Image):
    """Score one image and return a display string for the UI.

    Args:
        img: PIL image, or ``None`` when nothing was supplied.

    Returns:
        ``"error: no image"`` when *img* is ``None``; otherwise
        ``"score_<int> (raw=<float>)"`` where the raw value is the model
        output rescaled from [0, 1] to [``score_min``, ``score_max``] and the
        integer is that value rounded and clamped to the same range.
    """
    if img is None:
        return "error: no image"

    rgb = img if img.mode == "RGB" else img.convert("RGB")
    batch = processor(images=rgb, return_tensors="pt")

    with torch.inference_mode():
        unit_score = model(batch["pixel_values"]).item()

    # Clamp to [0, 1] before rescaling to the configured score range.
    unit_score = max(0.0, min(1.0, float(unit_score)))
    span = score_max - score_min
    pred_score = unit_score * span + score_min

    lowest, highest = int(score_min), int(score_max)
    score_int = max(lowest, min(highest, int(round(pred_score))))
    return f"score_{score_int} (raw={pred_score:.4f})"
# Gradio UI: one image input, one text output, and a button that runs predict().
with gr.Blocks() as demo:
    gr.Markdown("# SigLIP2 Aesthetic Scorer Demo")
    image_input = gr.Image(type="pil", label="Image")
    result_box = gr.Textbox(label="Result")
    predict_button = gr.Button("Predict")
    predict_button.click(fn=predict, inputs=[image_input], outputs=[result_box])

# Serve on all interfaces at port 7860.
# NOTE(review): share=True presumably also requests a public share link;
# confirm that is intended for this deployment.
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)