nzs234 committed on
Commit
73f654d
·
verified ·
1 Parent(s): eac58a9

Make model loading lazy at first inference to avoid startup crashes/timeouts

Browse files
Files changed (1) hide show
  1. app.py +52 -26
app.py CHANGED
@@ -1,5 +1,6 @@
1
- import json
2
- from pathlib import Path
 
3
 
4
  import gradio as gr
5
  import torch
@@ -9,8 +10,15 @@ from safetensors.torch import load_file
9
  from transformers import AutoImageProcessor, AutoModel
10
  from huggingface_hub import snapshot_download
11
 
12
- MODEL_REPO = "nzs234/siglip2-so400m-aesthetic-scorer-v1"
13
- CACHE_DIR = Path("./model_cache")
 
 
 
 
 
 
 
14
 
15
 
16
  def infer_feature_dim(vision):
@@ -75,33 +83,51 @@ class Regressor(nn.Module):
75
  return torch.sigmoid(x)
76
 
77
 
78
- print("Downloading model repo snapshot...")
79
- local_repo = snapshot_download(repo_id=MODEL_REPO, repo_type="model", local_dir=str(CACHE_DIR))
80
- local_repo = Path(local_repo)
81
- meta = json.loads((local_repo / "metadata.json").read_text(encoding="utf-8"))
82
- model_cfg = meta.get("model", {})
83
- data_cfg = meta.get("data", {})
84
-
85
- processor = AutoImageProcessor.from_pretrained(str(local_repo / "backbone"), local_files_only=True, use_fast=False)
86
- model = Regressor(
87
- backbone_dir=str(local_repo / "backbone"),
88
- hidden_dim=int(model_cfg.get("hidden_dim", 2048)),
89
- dropout=float(model_cfg.get("dropout", 0.2)),
90
- )
91
- head_state = load_file(str(local_repo / "head.safetensors"), device="cpu")
92
- model.head.load_state_dict(head_state, strict=False)
93
- model.eval()
94
-
95
- score_min = float(data_cfg.get("score_min", 1.0))
96
- score_max = float(data_cfg.get("score_max", 9.0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
 
99
  def predict(img: Image.Image):
100
  if img is None:
101
  return "error: no image"
102
- if img.mode != "RGB":
103
- img = img.convert("RGB")
104
- proc = processor(images=img, return_tensors="pt")
 
 
 
 
105
  with torch.inference_mode():
106
  pred_01 = model(proc["pixel_values"]).item()
107
  pred_01 = max(0.0, min(1.0, float(pred_01)))
 
1
+ import json
2
+ import threading
3
+ from pathlib import Path
4
 
5
  import gradio as gr
6
  import torch
 
10
  from transformers import AutoImageProcessor, AutoModel
11
  from huggingface_hub import snapshot_download
12
 
13
+ MODEL_REPO = "nzs234/siglip2-so400m-aesthetic-scorer-v1"
14
+ CACHE_DIR = Path("./model_cache")
15
+ _STATE_LOCK = threading.Lock()
16
+ _MODEL_READY = False
17
+ _MODEL_ERR = ""
18
+ processor = None
19
+ model = None
20
+ score_min = 1.0
21
+ score_max = 9.0
22
 
23
 
24
  def infer_feature_dim(vision):
 
83
  return torch.sigmoid(x)
84
 
85
 
86
def _ensure_loaded():
    """Load the image processor, regressor and score range on first use.

    Downloads the ``MODEL_REPO`` snapshot into ``CACHE_DIR``, reads
    ``metadata.json``, builds the processor and ``Regressor`` from the
    ``backbone`` sub-directory, loads ``head.safetensors`` into the model
    head, and then publishes the results into the module-level globals
    ``processor``, ``model``, ``score_min`` and ``score_max``.

    Calls made after a successful load return immediately. A failed load
    leaves ``_MODEL_READY`` false, so the next call tries again.

    Raises:
        Exception: whatever the download, metadata parsing or model
            construction raised. A description of the failure is stored in
            ``_MODEL_ERR`` before the exception is re-raised.
    """
    global _MODEL_READY, _MODEL_ERR, processor, model, score_min, score_max

    # Cheap check first so already-initialised calls never touch the lock.
    if _MODEL_READY:
        return

    with _STATE_LOCK:
        # Re-check while holding the lock: another caller may have finished
        # loading while this one was waiting.
        if _MODEL_READY:
            return

        try:
            print("Downloading model repo snapshot...")
            local_repo = Path(
                snapshot_download(
                    repo_id=MODEL_REPO,
                    repo_type="model",
                    local_dir=str(CACHE_DIR),
                )
            )
            meta = json.loads((local_repo / "metadata.json").read_text(encoding="utf-8"))
            model_cfg = meta.get("model", {})
            data_cfg = meta.get("data", {})
            backbone_dir = str(local_repo / "backbone")

            # Build everything into locals; the globals are only replaced
            # once every step has succeeded, so a failure part-way through
            # never leaves a half-initialised processor/model behind.
            loaded_processor = AutoImageProcessor.from_pretrained(
                backbone_dir,
                local_files_only=True,
                use_fast=False,
            )
            loaded_model = Regressor(
                backbone_dir=backbone_dir,
                hidden_dim=int(model_cfg.get("hidden_dim", 2048)),
                dropout=float(model_cfg.get("dropout", 0.2)),
            )
            head_state = load_file(str(local_repo / "head.safetensors"), device="cpu")
            # strict=False tolerates key mismatches; report them instead of
            # silently running with a partially initialised head.
            incompatible = loaded_model.head.load_state_dict(head_state, strict=False)
            if incompatible.missing_keys or incompatible.unexpected_keys:
                print(
                    "Warning: head weights do not fully match the model head: "
                    f"missing={list(incompatible.missing_keys)} "
                    f"unexpected={list(incompatible.unexpected_keys)}"
                )
            loaded_model.eval()

            loaded_score_min = float(data_cfg.get("score_min", 1.0))
            loaded_score_max = float(data_cfg.get("score_max", 9.0))
        except Exception as e:
            # Some exceptions stringify to ""; fall back to repr so the
            # stored error is never blank.
            _MODEL_ERR = str(e) or repr(e)
            raise

        processor = loaded_processor
        model = loaded_model
        score_min = loaded_score_min
        score_max = loaded_score_max
        # Set the ready flag last so the lock-free fast path above can only
        # succeed once all the other globals are in place.
        _MODEL_READY = True
        _MODEL_ERR = ""
        print("Model loaded.")
119
 
120
 
121
  def predict(img: Image.Image):
122
  if img is None:
123
  return "error: no image"
124
+ try:
125
+ _ensure_loaded()
126
+ except Exception:
127
+ return f"error: model load failed: {_MODEL_ERR}"
128
+ if img.mode != "RGB":
129
+ img = img.convert("RGB")
130
+ proc = processor(images=img, return_tensors="pt")
131
  with torch.inference_mode():
132
  pred_01 = model(proc["pixel_values"]).item()
133
  pred_01 = max(0.0, min(1.0, float(pred_01)))