File size: 6,135 Bytes
73f654d
 
 
482b4b9
 
7f8b3d8
 
 
a5a9cc1
 
 
7f8b3d8
 
 
 
 
eb58ffc
7f8b3d8
 
 
 
 
 
 
 
 
eb58ffc
 
 
 
 
 
 
 
 
a5a9cc1
73f654d
 
 
 
 
 
 
 
 
a5a9cc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73f654d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5a9cc1
 
796e086
d943770
482b4b9
73f654d
 
 
482b4b9
73f654d
 
 
f63e025
 
 
 
 
 
482b4b9
 
d943770
482b4b9
 
 
 
 
 
 
 
d943770
482b4b9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import json
import threading
from pathlib import Path

import gradio as gr
import gradio_client.utils as gc_utils
import torch
import torch.nn as nn
from PIL import Image
from safetensors.torch import load_file
from transformers import AutoImageProcessor, AutoModel
from huggingface_hub import snapshot_download

# Workaround for gradio/gradio_client schema bug on some Spaces runtimes:
# json_schema_to_python_type may pass boolean schema nodes into get_type().
# Keep handles to the original implementations so the bool-tolerant wrappers
# defined below can delegate every non-boolean schema to them.
_orig_get_type = gc_utils.get_type
_orig_json_schema_to_python_type = gc_utils._json_schema_to_python_type


def _safe_get_type(schema):
    if isinstance(schema, bool):
        return "Any"
    return _orig_get_type(schema)


# Swap the bool-tolerant wrapper in for gradio_client's own get_type.
gc_utils.get_type = _safe_get_type


def _safe_json_schema_to_python_type(schema, defs=None):
    if isinstance(schema, bool):
        return "Any"
    return _orig_json_schema_to_python_type(schema, defs)


# Swap the bool-tolerant wrapper in for gradio_client's private schema converter.
gc_utils._json_schema_to_python_type = _safe_json_schema_to_python_type

# Hub repo id handed to snapshot_download() when the model is first needed.
MODEL_REPO = "nzs234/siglip2-so400m-aesthetic-scorer-v1"
# Local directory the repo snapshot is downloaded into.
CACHE_DIR = Path("./model_cache")
# Serialises the one-time initialisation performed by _ensure_loaded().
_STATE_LOCK = threading.Lock()
# Flipped to True once _ensure_loaded() has completed successfully.
_MODEL_READY = False
# Text of the most recent load failure; empty string when there is none.
_MODEL_ERR = ""
# Filled in by _ensure_loaded(): the image processor and the Regressor instance.
processor = None
model = None
# Score range predict() uses to rescale the model's 0..1 output. These are
# fallbacks; _ensure_loaded() overwrites them from the "data" section of
# metadata.json when present.
score_min = 1.0
score_max = 9.0


def infer_feature_dim(vision):
    """Work out the width of the feature vector produced by ``vision``.

    Lookup order:
      1. ``vision.config`` and then ``vision.config.vision_config``; on each,
         ``projection_dim`` is tried before ``hidden_size`` and the first
         positive ``int`` wins.
      2. ``vision.visual_projection.out_features`` when that attribute is an
         ``nn.Linear``.

    Args:
        vision: The backbone object to inspect.

    Returns:
        The inferred feature dimension.

    Raises:
        ValueError: If none of the lookups yields a dimension.
    """
    top_cfg = getattr(vision, "config", None)
    # getattr on None simply falls back to the default, so no explicit guard
    # is needed when the backbone has no config at all.
    nested_cfg = getattr(top_cfg, "vision_config", None)

    for candidate in (top_cfg, nested_cfg):
        if candidate is None:
            continue
        for attr in ("projection_dim", "hidden_size"):
            dim = getattr(candidate, attr, None)
            if isinstance(dim, int) and dim > 0:
                return dim

    projection = getattr(vision, "visual_projection", None)
    if not isinstance(projection, nn.Linear):
        raise ValueError("cannot infer feature dim")
    return int(projection.out_features)


class Regressor(nn.Module):
    """Pretrained vision backbone followed by an MLP head squashed by a sigmoid.

    ``forward`` takes pixel values, extracts one feature vector per image from
    the backbone, L2-normalises it, runs it through ``head`` and returns
    ``sigmoid`` of the result, i.e. one value in (0, 1) per image.
    """

    def __init__(self, backbone_dir: str, hidden_dim: int = 2048, dropout: float = 0.2):
        """Load the backbone from ``backbone_dir`` and build the regression head.

        Args:
            backbone_dir: Local directory passed to ``AutoModel.from_pretrained``
                (loaded with ``local_files_only=True``).
            hidden_dim: Width of the first hidden layer of the head.
            dropout: Base dropout rate. Non-positive values fall back to 0.3;
                every rate is clamped to [0.0, 0.8].
        """
        super().__init__()
        self.vision = AutoModel.from_pretrained(backbone_dir, local_files_only=True)
        feat_dim = infer_feature_dim(self.vision)

        def _clamp_rate(rate):
            return float(max(0.0, min(0.8, rate)))

        # Hidden widths of the four Linear/ReLU/BatchNorm/Dropout stages.
        stage_widths = (int(hidden_dim), 512, 256, 128)
        base_rate = _clamp_rate(dropout if dropout > 0 else 0.3)
        # The first two stages share the base rate; later stages use less.
        stage_rates = (
            base_rate,
            base_rate,
            _clamp_rate(base_rate * 0.67),
            _clamp_rate(base_rate * 0.33),
        )

        # The layer order fixes the nn.Sequential indices, which are what the
        # head's state-dict keys are based on, so it must not be changed.
        layers = [nn.LayerNorm(feat_dim)]
        width_in = feat_dim
        for width_out, rate in zip(stage_widths, stage_rates):
            layers += [
                nn.Linear(width_in, width_out),
                nn.ReLU(),
                nn.BatchNorm1d(width_out),
                nn.Dropout(rate),
            ]
            width_in = width_out
        layers += [nn.Linear(width_in, 32), nn.ReLU(), nn.Linear(32, 1)]
        self.head = nn.Sequential(*layers)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Score a batch of images.

        Args:
            pixel_values: Input tensor forwarded to the backbone as
                ``pixel_values``.

        Returns:
            ``sigmoid`` of the head output with the trailing dimension squeezed.
        """
        backbone = self.vision
        if hasattr(backbone, "get_image_features"):
            feats = backbone.get_image_features(pixel_values=pixel_values)
            if not isinstance(feats, torch.Tensor):
                # The call returned an output object rather than a tensor.
                if hasattr(feats, "image_embeds"):
                    feats = feats.image_embeds
                else:
                    feats = feats.pooler_output
        else:
            outputs = backbone(pixel_values=pixel_values)
            feats = getattr(outputs, "pooler_output", None)
            if feats is None:
                # No pooled output: take the first token of the last hidden state.
                feats = outputs.last_hidden_state[:, 0, :]

        # L2-normalise each feature vector; the epsilon avoids division by zero.
        scale = feats.norm(dim=1, keepdim=True) + 1e-8
        logits = self.head(feats / scale).squeeze(-1)
        return torch.sigmoid(logits)


def _ensure_loaded():
    """Download the model snapshot and initialise the global inference state once.

    On success the module globals ``processor``, ``model``, ``score_min`` and
    ``score_max`` are populated, ``_MODEL_ERR`` is cleared and ``_MODEL_READY``
    is set. Later calls return immediately. Initialisation is serialised by
    ``_STATE_LOCK``.

    Nothing is published to the globals until every step has succeeded, so a
    failed attempt never leaves a half-initialised processor/model pair behind;
    the next call simply retries.

    Raises:
        Exception: Whatever the download, metadata parsing or model
            construction raised. A description of the failure is stored in
            ``_MODEL_ERR`` before the exception is re-raised.
    """
    global _MODEL_READY, _MODEL_ERR, processor, model, score_min, score_max
    if _MODEL_READY:
        return
    with _STATE_LOCK:
        # Re-check under the lock: another caller may have finished loading
        # while this one was waiting.
        if _MODEL_READY:
            return
        try:
            print("Downloading model repo snapshot...")
            local_repo = Path(
                snapshot_download(repo_id=MODEL_REPO, repo_type="model", local_dir=str(CACHE_DIR))
            )
            meta = json.loads((local_repo / "metadata.json").read_text(encoding="utf-8"))
            model_cfg = meta.get("model", {})
            data_cfg = meta.get("data", {})
            backbone_dir = str(local_repo / "backbone")

            new_processor = AutoImageProcessor.from_pretrained(
                backbone_dir, local_files_only=True, use_fast=False
            )
            new_model = Regressor(
                backbone_dir=backbone_dir,
                hidden_dim=int(model_cfg.get("hidden_dim", 2048)),
                dropout=float(model_cfg.get("dropout", 0.2)),
            )
            head_state = load_file(str(local_repo / "head.safetensors"), device="cpu")
            # strict=False tolerates key mismatches, but a mismatch means part
            # of the head keeps its random initialisation, so report it
            # instead of discarding the result.
            load_result = new_model.head.load_state_dict(head_state, strict=False)
            if load_result.missing_keys or load_result.unexpected_keys:
                print(
                    "Warning: head.safetensors does not fully match the head "
                    f"(missing={list(load_result.missing_keys)}, "
                    f"unexpected={list(load_result.unexpected_keys)})"
                )
            new_model.eval()

            new_score_min = float(data_cfg.get("score_min", 1.0))
            new_score_max = float(data_cfg.get("score_max", 9.0))
        except Exception as e:
            # Some exceptions stringify to "", which would make the message
            # shown to the user empty; fall back to the repr in that case.
            _MODEL_ERR = str(e) or repr(e)
            raise

        # Publish everything only after full success. The ready flag goes last
        # because the lock-free fast path above relies on it.
        processor = new_processor
        model = new_model
        score_min = new_score_min
        score_max = new_score_max
        _MODEL_ERR = ""
        _MODEL_READY = True
        print("Model loaded.")


def predict(img):
    """Score one image and format the result for display.

    Args:
        img: The uploaded image, or ``None`` when nothing was provided. It is
            converted to RGB when its ``mode`` is anything else.

    Returns:
        ``"score_<int> (raw=<float>)"`` on success, ``"error: no image"`` when
        ``img`` is ``None``, or ``"error: model load failed: ..."`` when the
        model could not be initialised.
    """
    if img is None:
        return "error: no image"

    try:
        _ensure_loaded()
    except Exception:
        return f"error: model load failed: {_MODEL_ERR}"

    rgb = img if img.mode == "RGB" else img.convert("RGB")
    inputs = processor(images=rgb, return_tensors="pt")
    with torch.inference_mode():
        unit = float(model(inputs["pixel_values"]).item())

    # Clamp to [0, 1], then rescale onto the configured score range.
    unit = max(0.0, min(1.0, unit))
    span = score_max - score_min
    pred_score = unit * span + score_min

    # Nearest integer bucket, kept inside the integer bounds of the range.
    lowest, highest = int(score_min), int(score_max)
    bucket = max(lowest, min(highest, int(round(pred_score))))
    return f"score_{bucket} (raw={pred_score:.4f})"


# Single-image UI: the uploaded picture is handed to predict() and the string
# it returns is shown in a textbox.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Image"),  # predict() uses PIL's .mode / .convert()
    outputs=gr.Textbox(label="Result"),
    title="SigLIP2 Aesthetic Scorer Demo",
    description="Upload image and get score_1..score_9",
    allow_flagging="never",
)

# Runs at import time (there is no __main__ guard): serve on all interfaces, port 7860.
# NOTE(review): share=True also asks Gradio for a public share link — confirm
# that is wanted for this deployment.
demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False, ssr_mode=False, share=True)