# Source: Hugging Face Space by nzs234, commit 482b4b9
# ("Stabilize Space: switch to simple Interface mode, single text output, share=True").
import json
import threading
from pathlib import Path
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from safetensors.torch import load_file
from transformers import AutoImageProcessor, AutoModel
from huggingface_hub import snapshot_download
# Hub repo that holds metadata.json, the backbone/ folder and head.safetensors.
MODEL_REPO = "nzs234/siglip2-so400m-aesthetic-scorer-v1"
# Local directory the repo snapshot is downloaded into.
CACHE_DIR = Path("./model_cache")

# Lazy-load state shared by _ensure_loaded() and predict().
_STATE_LOCK = threading.Lock()
_MODEL_READY, _MODEL_ERR = False, ""
processor = model = None

# Range used to map the model's [0, 1] output back to the user-facing scale;
# replaced by the values in metadata.json once the model has loaded.
score_min, score_max = 1.0, 9.0
def infer_feature_dim(vision):
    """Return the width of the embedding produced by ``vision``.

    Probes, in order: ``vision.config.projection_dim``, ``vision.config.hidden_size``,
    then the same two attributes on ``vision.config.vision_config``; the first positive
    int wins. Falls back to ``out_features`` of a ``visual_projection`` ``nn.Linear``.

    Raises:
        ValueError: if none of the probes yields a dimension.
    """
    config = getattr(vision, "config", None)
    nested = None if config is None else getattr(config, "vision_config", None)
    for source in (c for c in (config, nested) if c is not None):
        for attr in ("projection_dim", "hidden_size"):
            dim = getattr(source, attr, None)
            if isinstance(dim, int) and dim > 0:
                return dim

    projection = getattr(vision, "visual_projection", None)
    if not isinstance(projection, nn.Linear):
        raise ValueError("cannot infer feature dim")
    return int(projection.out_features)
class Regressor(nn.Module):
    """Vision backbone plus an MLP head mapping each image to a score in (0, 1).

    Args:
        backbone_dir: local directory of the backbone; loaded with
            ``local_files_only=True`` so nothing is downloaded here.
        hidden_dim: width of the head's first hidden layer.
        dropout: base dropout rate. A non-positive value falls back to 0.3 and the
            result is clamped to [0, 0.8]; the third and fourth dropout layers use
            0.67x and 0.33x of that base rate.

    The module order inside ``self.head`` fixes the ``nn.Sequential`` index keys
    ("0.weight", "1.weight", ...) that ``head.safetensors`` is loaded against, so the
    layer sequence must stay exactly as built below.
    """

    def __init__(self, backbone_dir: str, hidden_dim: int = 2048, dropout: float = 0.2):
        super().__init__()
        self.vision = AutoModel.from_pretrained(backbone_dir, local_files_only=True)
        in_dim = infer_feature_dim(self.vision)

        def _clamp(rate):
            return float(max(0.0, min(0.8, rate)))

        base_rate = _clamp(dropout if dropout > 0 else 0.3)
        widths = [in_dim, int(hidden_dim), 512, 256, 128]
        rates = [base_rate, base_rate, _clamp(base_rate * 0.67), _clamp(base_rate * 0.33)]

        layers = [nn.LayerNorm(in_dim)]
        # Four Linear -> ReLU -> BatchNorm1d -> Dropout stages: in_dim -> h1 -> 512 -> 256 -> 128.
        for fan_in, fan_out, rate in zip(widths, widths[1:], rates):
            layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU(), nn.BatchNorm1d(fan_out), nn.Dropout(rate)])
        # Final 128 -> 32 -> 1 projection without normalisation or dropout.
        layers.extend([nn.Linear(widths[-1], 32), nn.ReLU(), nn.Linear(32, 1)])
        self.head = nn.Sequential(*layers)

    def _pooled_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return one embedding row per image from the backbone."""
        backbone = self.vision
        if not hasattr(backbone, "get_image_features"):
            outputs = backbone(pixel_values=pixel_values)
            pooled = getattr(outputs, "pooler_output", None)
            # Without a pooled output, fall back to the first token of the last hidden state.
            return pooled if pooled is not None else outputs.last_hidden_state[:, 0, :]
        embeds = backbone.get_image_features(pixel_values=pixel_values)
        if isinstance(embeds, torch.Tensor):
            return embeds
        # Some transformers versions wrap the features in an output object.
        return embeds.image_embeds if hasattr(embeds, "image_embeds") else embeds.pooler_output

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Score a batch of preprocessed images; returns one sigmoid output per image."""
        embeds = self._pooled_features(pixel_values)
        # L2-normalise each row; the epsilon guards against an all-zero embedding.
        unit = embeds / (embeds.norm(dim=1, keepdim=True) + 1e-8)
        logits = self.head(unit).squeeze(-1)
        return torch.sigmoid(logits)
def _load_head_weights(head, state):
    """Load ``state`` into ``head``, failing loudly when trained weights are missing.

    ``strict=False`` tolerates extra keys and absent BatchNorm ``num_batches_tracked``
    counters. Any other missing key would leave part of the head at its random
    initialisation, so every score would be meaningless; that case raises instead of
    passing silently.

    Raises:
        RuntimeError: if ``state`` lacks any head key other than ``num_batches_tracked``.
    """
    result = head.load_state_dict(state, strict=False)
    missing = [k for k in result.missing_keys if not k.endswith("num_batches_tracked")]
    if missing:
        raise RuntimeError(
            f"head.safetensors does not match the regressor head: "
            f"{len(missing)} missing key(s), e.g. {missing[:5]}"
        )
    if result.unexpected_keys:
        print(
            f"Warning: ignoring {len(result.unexpected_keys)} unexpected key(s) "
            f"in head.safetensors, e.g. {result.unexpected_keys[:5]}"
        )


def _ensure_loaded():
    """Download and initialise the scorer once; later calls return immediately.

    Fetches ``MODEL_REPO`` into ``CACHE_DIR`` and reads ``metadata.json``:
    - the ``model`` section supplies the head hyper-parameters;
    - the ``data`` section supplies the score range.

    It then builds the image processor and ``Regressor`` from ``backbone/`` and loads
    the trained head from ``head.safetensors``. Module globals are published only after
    every step has succeeded, so a failed attempt never leaves a half-built model behind.
    Each call retries the full load until it succeeds.

    Raises:
        Exception: whatever the failing step raised. Its type and message are also
            stored in ``_MODEL_ERR`` so ``predict`` can report them.
    """
    global _MODEL_READY, _MODEL_ERR, processor, model, score_min, score_max
    # Double-checked locking: skip the lock on the common already-loaded path.
    if _MODEL_READY:
        return
    with _STATE_LOCK:
        if _MODEL_READY:
            return
        try:
            print("Downloading model repo snapshot...")
            local_repo = Path(
                snapshot_download(repo_id=MODEL_REPO, repo_type="model", local_dir=str(CACHE_DIR))
            )
            meta = json.loads((local_repo / "metadata.json").read_text(encoding="utf-8"))
            model_cfg = meta.get("model", {})
            data_cfg = meta.get("data", {})
            backbone_dir = str(local_repo / "backbone")

            new_processor = AutoImageProcessor.from_pretrained(
                backbone_dir, local_files_only=True, use_fast=False
            )
            new_model = Regressor(
                backbone_dir=backbone_dir,
                hidden_dim=int(model_cfg.get("hidden_dim", 2048)),
                dropout=float(model_cfg.get("dropout", 0.2)),
            )
            head_state = load_file(str(local_repo / "head.safetensors"), device="cpu")
            _load_head_weights(new_model.head, head_state)
            # eval() freezes BatchNorm/Dropout; in training mode BatchNorm1d rejects a batch of one image.
            new_model.eval()
            new_min = float(data_cfg.get("score_min", 1.0))
            new_max = float(data_cfg.get("score_max", 9.0))
        except Exception as e:
            # Include the type: str() of some exceptions (e.g. KeyError) is cryptic or empty.
            _MODEL_ERR = f"{type(e).__name__}: {e}"
            raise
        processor, model = new_processor, new_model
        score_min, score_max = new_min, new_max
        _MODEL_ERR = ""
        # Set last so the lock-free fast path never sees a partially published state.
        _MODEL_READY = True
        print("Model loaded.")
def predict(img):
    """Score one image for the Gradio UI and describe the result as text.

    Failures are returned as strings rather than raised, so they show up in the
    output box:
    - ``"error: no image"`` when nothing was uploaded;
    - ``"error: model load failed: ..."`` when loading failed.

    On success the result is ``"score_<n> (raw=<x.xxxx>)"``. The model's [0, 1]
    output is rescaled to ``[score_min, score_max]``, and ``n`` is that value rounded
    and clamped to ``[int(score_min), int(score_max)]``.
    """
    if img is None:
        return "error: no image"
    try:
        _ensure_loaded()
    except Exception:
        # _ensure_loaded stores the failure reason in _MODEL_ERR before re-raising.
        return f"error: model load failed: {_MODEL_ERR}"

    rgb = img.convert("RGB") if img.mode != "RGB" else img
    batch = processor(images=rgb, return_tensors="pt")
    with torch.inference_mode():
        prob = model(batch["pixel_values"]).item()

    # Clamp to [0, 1] before mapping onto the configured score range.
    unit = max(0.0, min(1.0, float(prob)))
    raw_score = unit * (score_max - score_min) + score_min
    lowest, highest = int(score_min), int(score_max)
    bucket = max(lowest, min(highest, int(round(raw_score))))
    return f"score_{bucket} (raw={raw_score:.4f})"
# Minimal UI: one PIL image in, one text line ("score_N (raw=...)" or "error: ...") out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Image"),
    outputs=gr.Textbox(label="Result"),
    title="SigLIP2 Aesthetic Scorer Demo",
    description="Upload image and get score_1..score_9",
    # Gradio 5 replaced the deprecated `allow_flagging` argument with `flagging_mode`.
    flagging_mode="never",
)
# NOTE(review): Hugging Face Spaces ignores share=True (Gradio warns); it only has an effect when run locally.
demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False, ssr_mode=False, share=True)