File size: 4,594 Bytes
a5a9cc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d943770
 
a5a9cc1
 
 
 
 
d943770
a5a9cc1
 
 
 
 
 
 
 
 
 
 
 
 
d943770
 
 
a5a9cc1
 
 
 
 
 
 
 
 
d943770
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import json
from pathlib import Path

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from safetensors.torch import load_file
from transformers import AutoImageProcessor, AutoModel
from huggingface_hub import snapshot_download

MODEL_REPO = "nzs234/siglip2-so400m-aesthetic-scorer-v1"
CACHE_DIR = Path("./model_cache")


def infer_feature_dim(vision):
    """Return the width of the image-feature vector produced by *vision*.

    Lookup order:
      1. ``vision.config`` and then ``vision.config.vision_config``; on each,
         ``projection_dim`` is tried before ``hidden_size``. The first value
         that is a positive ``int`` wins.
      2. ``vision.visual_projection.out_features`` when that attribute is an
         ``nn.Linear``.

    Raises:
        ValueError: if none of the above yields a dimension.
    """
    top_cfg = getattr(vision, "config", None)
    nested_cfg = None if top_cfg is None else getattr(top_cfg, "vision_config", None)

    for candidate in (top_cfg, nested_cfg):
        if candidate is None:
            continue
        for attr_name in ("projection_dim", "hidden_size"):
            dim = getattr(candidate, attr_name, None)
            if isinstance(dim, int) and dim > 0:
                return dim

    # Configs gave nothing usable: fall back to an explicit projection layer.
    projection = getattr(vision, "visual_projection", None)
    if not isinstance(projection, nn.Linear):
        raise ValueError("cannot infer feature dim")
    return int(projection.out_features)


class Regressor(nn.Module):
    """Vision backbone followed by an MLP head that outputs a score in (0, 1).

    The head is a fixed stack: LayerNorm, then four
    ``Linear -> ReLU -> BatchNorm1d -> Dropout`` stages of widths
    ``hidden_dim, 512, 256, 128``, then ``Linear(128, 32) -> ReLU -> Linear(32, 1)``.
    Modules are created in that exact order so the ``nn.Sequential`` indices
    (and therefore the state-dict keys) stay stable.

    Args:
        backbone_dir: Local directory passed to ``AutoModel.from_pretrained``
            (loaded with ``local_files_only=True``).
        hidden_dim: Width of the first hidden layer.
        dropout: Base dropout rate. Non-positive values fall back to 0.3; the
            result is clamped to [0, 0.8]. Stages 1 and 2 use the base rate,
            stage 3 uses 0.67x and stage 4 uses 0.33x of it.
    """

    def __init__(self, backbone_dir: str, hidden_dim: int = 2048, dropout: float = 0.2):
        super().__init__()
        self.vision = AutoModel.from_pretrained(backbone_dir, local_files_only=True)
        in_features = infer_feature_dim(self.vision)

        def clamp_rate(rate):
            # Keep every dropout probability inside [0.0, 0.8].
            return float(max(0.0, min(0.8, rate)))

        widths = (int(hidden_dim), 512, 256, 128)
        base_rate = clamp_rate(dropout if dropout > 0 else 0.3)
        rates = (
            base_rate,
            base_rate,
            clamp_rate(base_rate * 0.67),
            clamp_rate(base_rate * 0.33),
        )

        layers = [nn.LayerNorm(in_features)]
        prev_width = in_features
        for width, rate in zip(widths, rates):
            layers.append(nn.Linear(prev_width, width))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(width))
            layers.append(nn.Dropout(rate))
            prev_width = width
        layers.append(nn.Linear(prev_width, 32))
        layers.append(nn.ReLU())
        layers.append(nn.Linear(32, 1))
        self.head = nn.Sequential(*layers)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed *pixel_values*, L2-normalise, run the head, apply a sigmoid.

        Feature extraction prefers ``vision.get_image_features``; if that
        returns a non-tensor, ``image_embeds`` is used when present, else
        ``pooler_output``. Without ``get_image_features`` the backbone is
        called directly and ``pooler_output`` is used, falling back to the
        first token of ``last_hidden_state`` when it is missing or ``None``.
        """
        encoder = self.vision
        if not hasattr(encoder, "get_image_features"):
            outputs = encoder(pixel_values=pixel_values)
            pooled = getattr(outputs, "pooler_output", None)
            if pooled is not None:
                embeddings = pooled
            else:
                embeddings = outputs.last_hidden_state[:, 0, :]
        else:
            embeddings = encoder.get_image_features(pixel_values=pixel_values)
            if not isinstance(embeddings, torch.Tensor):
                if hasattr(embeddings, "image_embeds"):
                    embeddings = embeddings.image_embeds
                else:
                    embeddings = embeddings.pooler_output

        # Row-wise L2 normalisation; the epsilon guards against a zero norm.
        row_norm = embeddings.norm(dim=1, keepdim=True) + 1e-8
        unit = embeddings / row_norm
        logits = self.head(unit).squeeze(-1)
        return torch.sigmoid(logits)


# --- Model bootstrap: runs once at import time -------------------------------
# Fetch the model repo into CACHE_DIR, read its metadata, then build the
# processor and the Regressor from the "backbone" sub-directory and load the
# trained head weights from "head.safetensors".
print("Downloading model repo snapshot...")
local_repo = snapshot_download(repo_id=MODEL_REPO, repo_type="model", local_dir=str(CACHE_DIR))
local_repo = Path(local_repo)
meta = json.loads((local_repo / "metadata.json").read_text(encoding="utf-8"))
# Both sections are optional; the .get() defaults below apply when absent.
model_cfg = meta.get("model", {})
data_cfg = meta.get("data", {})

processor = AutoImageProcessor.from_pretrained(str(local_repo / "backbone"), local_files_only=True, use_fast=False)
model = Regressor(
    backbone_dir=str(local_repo / "backbone"),
    hidden_dim=int(model_cfg.get("hidden_dim", 2048)),
    dropout=float(model_cfg.get("dropout", 0.2)),
)
head_state = load_file(str(local_repo / "head.safetensors"), device="cpu")
# strict=False tolerates key mismatches, but a silent mismatch would leave
# head layers at their random initialisation and make every score
# meaningless. Keep the lenient load, but report what did not line up.
_head_load_result = model.head.load_state_dict(head_state, strict=False)
if _head_load_result.missing_keys:
    print(
        "WARNING: head parameters not found in head.safetensors "
        f"(left at random init): {list(_head_load_result.missing_keys)}"
    )
if _head_load_result.unexpected_keys:
    print(
        "WARNING: unexpected keys in head.safetensors (ignored): "
        f"{list(_head_load_result.unexpected_keys)}"
    )
# Inference mode for BatchNorm/Dropout layers.
model.eval()

# Range used to map the model's [0, 1] output back to the dataset's score scale.
score_min = float(data_cfg.get("score_min", 1.0))
score_max = float(data_cfg.get("score_max", 9.0))


def predict(img: Image.Image):
    """Score one image and return a human-readable result string.

    Args:
        img: PIL image, or ``None`` when nothing was supplied.

    Returns:
        ``"error: no image"`` when *img* is ``None``; otherwise
        ``"score_<N> (raw=<x.xxxx>)"`` where ``raw`` is the model output
        rescaled from [0, 1] to [score_min, score_max] and ``N`` is that value
        rounded and clamped to the integer score range.
    """
    if img is None:
        return "error: no image"

    rgb_img = img if img.mode == "RGB" else img.convert("RGB")
    inputs = processor(images=rgb_img, return_tensors="pt")
    with torch.inference_mode():
        unit_score = model(inputs["pixel_values"]).item()

    # Clamp to [0, 1] before rescaling (argument order kept as-is on purpose).
    unit_score = max(0.0, min(1.0, float(unit_score)))
    span = score_max - score_min
    pred_score = unit_score * span + score_min

    lowest, highest = int(score_min), int(score_max)
    bucket = max(lowest, min(highest, int(round(pred_score))))
    return f"score_{bucket} (raw={pred_score:.4f})"


# Minimal Gradio UI: one image input, one text output, and a button that runs
# predict(). Components are laid out in the order they are created below.
with gr.Blocks() as demo:
    gr.Markdown("# SigLIP2 Aesthetic Scorer Demo")
    # type="pil" matches predict(), which expects a PIL image (or None).
    inp = gr.Image(type="pil", label="Image")
    out = gr.Textbox(label="Result")
    btn = gr.Button("Predict")
    # On click, pass the image to predict() and show its returned string.
    btn.click(fn=predict, inputs=[inp], outputs=[out])


# Start the server at import time, bound to 0.0.0.0:7860 with share=True.
# NOTE(review): share=True asks Gradio for a public share link; confirm that
# is intended for the deployment target.
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)