Spaces:
Running
Running
SathyaSantosh77 commited on
Commit Β·
a043ce4
1
Parent(s): 66e1339
fix zerogpu device handling
Browse files
app.py
CHANGED
|
@@ -182,30 +182,23 @@ def ssim(pred, target, window_size=11):
|
|
| 182 |
|
| 183 |
|
| 184 |
# ββ Load model ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
model = ImageSRTransformer().to(DEVICE)
|
| 188 |
checkpoint = torch.load(
|
| 189 |
"sr_best_v4_resumed.pt",
|
| 190 |
-
map_location=
|
| 191 |
weights_only=False
|
| 192 |
)
|
| 193 |
model.load_state_dict(checkpoint["model_state_dict"])
|
| 194 |
model.eval()
|
| 195 |
-
print(f"Model loaded β
|
| 196 |
-
print(f"Best val PSNR: {checkpoint['val_psnr']:.2f} dB")
|
| 197 |
|
| 198 |
|
| 199 |
# ββ Inference βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 200 |
@spaces.GPU
|
| 201 |
def run_sr(img_pil):
|
| 202 |
-
"""
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
2. Bicubic downscale to 64Γ64 β LR
|
| 206 |
-
3. Runs Dense-Iso-ViT SR
|
| 207 |
-
4. Returns (lr_display, sr_output, ground_truth, metrics_str)
|
| 208 |
-
"""
|
| 209 |
w, h = img_pil.size
|
| 210 |
if w < 256 or h < 256:
|
| 211 |
scale = max(256 / w, 256 / h)
|
|
@@ -213,42 +206,33 @@ def run_sr(img_pil):
|
|
| 213 |
(int(w * scale), int(h * scale)), Image.BICUBIC)
|
| 214 |
w, h = img_pil.size
|
| 215 |
|
| 216 |
-
# centre crop 256Γ256
|
| 217 |
left = (w - 256) // 2
|
| 218 |
top = (h - 256) // 2
|
| 219 |
gt = img_pil.crop((left, top, left + 256, top + 256))
|
| 220 |
-
|
| 221 |
-
# LR β bicubic 64Γ64
|
| 222 |
lr = gt.resize((64, 64), Image.BICUBIC)
|
| 223 |
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
gt_t = TF.to_tensor(gt).unsqueeze(0).to(DEVICE)
|
| 227 |
|
| 228 |
with torch.no_grad():
|
| 229 |
-
with torch.autocast(device_type=
|
|
|
|
|
|
|
| 230 |
sr_t = model(lr_t)
|
| 231 |
sr_t = sr_t.float().clamp(0, 1)
|
| 232 |
|
| 233 |
-
# LR display β bilinear upscale to 256 for side-by-side
|
| 234 |
lr_display_t = F.interpolate(
|
| 235 |
lr_t.float(), size=(256, 256),
|
| 236 |
mode="bilinear", align_corners=False)
|
| 237 |
|
| 238 |
-
# metrics β LR baseline vs SR
|
| 239 |
psnr_lr = psnr(lr_display_t, gt_t).item()
|
| 240 |
ssim_lr = ssim(lr_display_t, gt_t)
|
| 241 |
psnr_sr = psnr(sr_t, gt_t).item()
|
| 242 |
ssim_sr = ssim(sr_t, gt_t)
|
| 243 |
|
| 244 |
-
# to PIL
|
| 245 |
def to_pil(t):
|
| 246 |
return TF.to_pil_image(t.squeeze(0).cpu())
|
| 247 |
|
| 248 |
-
lr_img = to_pil(lr_display_t)
|
| 249 |
-
sr_img = to_pil(sr_t)
|
| 250 |
-
gt_img = gt
|
| 251 |
-
|
| 252 |
metrics = (
|
| 253 |
f"**LR baseline** β PSNR: {psnr_lr:.2f} dB | SSIM: {ssim_lr:.4f}\n\n"
|
| 254 |
f"**SR output** β PSNR: {psnr_sr:.2f} dB | SSIM: {ssim_sr:.4f}\n\n"
|
|
@@ -256,7 +240,7 @@ def run_sr(img_pil):
|
|
| 256 |
f"ΞSSIM: +{ssim_sr - ssim_lr:.4f}"
|
| 257 |
)
|
| 258 |
|
| 259 |
-
return
|
| 260 |
|
| 261 |
|
| 262 |
# ββ Example images ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -383,7 +367,6 @@ body, .gradio-container {
|
|
| 383 |
|
| 384 |
# ββ Architecture info (collapsible) ββββββββββββββββββββββββββββββββββββββββββ
|
| 385 |
ARCH_INFO = """
|
| 386 |
-
## Dense-Iso-ViT core claim
|
| 387 |
|
| 388 |
> "Isotropic constant-resolution hierarchical ViT with inter-stage dense feature aggregation β eliminating spatial bottlenecks while preserving coordinate integrity throughout all processing stages."
|
| 389 |
|
|
@@ -405,7 +388,6 @@ DRCT uses spatial downsampling and local block-level residuals. Dense-Iso-ViT ma
|
|
| 405 |
### Results
|
| 406 |
| Benchmark | PSNR | SSIM |
|
| 407 |
|-----------|------|------|
|
| 408 |
-
| LSDIR test | 24.11 dB | β |
|
| 409 |
| DIV2K validation | 25.20 dB | 0.8298 |
|
| 410 |
|
| 411 |
"""
|
|
|
|
| 182 |
|
| 183 |
|
| 184 |
# ββ Load model ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 185 |
+
model = ImageSRTransformer()
|
|
|
|
|
|
|
| 186 |
checkpoint = torch.load(
|
| 187 |
"sr_best_v4_resumed.pt",
|
| 188 |
+
map_location="cpu",
|
| 189 |
weights_only=False
|
| 190 |
)
|
| 191 |
model.load_state_dict(checkpoint["model_state_dict"])
|
| 192 |
model.eval()
|
| 193 |
+
print(f"Model loaded β val PSNR: {checkpoint['val_psnr']:.2f} dB")
|
|
|
|
| 194 |
|
| 195 |
|
| 196 |
# ββ Inference βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 197 |
@spaces.GPU
|
| 198 |
def run_sr(img_pil):
|
| 199 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 200 |
+
model.to(device)
|
| 201 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
w, h = img_pil.size
|
| 203 |
if w < 256 or h < 256:
|
| 204 |
scale = max(256 / w, 256 / h)
|
|
|
|
| 206 |
(int(w * scale), int(h * scale)), Image.BICUBIC)
|
| 207 |
w, h = img_pil.size
|
| 208 |
|
|
|
|
| 209 |
left = (w - 256) // 2
|
| 210 |
top = (h - 256) // 2
|
| 211 |
gt = img_pil.crop((left, top, left + 256, top + 256))
|
|
|
|
|
|
|
| 212 |
lr = gt.resize((64, 64), Image.BICUBIC)
|
| 213 |
|
| 214 |
+
lr_t = TF.to_tensor(lr).unsqueeze(0).to(device)
|
| 215 |
+
gt_t = TF.to_tensor(gt).unsqueeze(0).to(device)
|
|
|
|
| 216 |
|
| 217 |
with torch.no_grad():
|
| 218 |
+
with torch.autocast(device_type="cuda",
|
| 219 |
+
dtype=torch.bfloat16,
|
| 220 |
+
enabled=(device == "cuda")):
|
| 221 |
sr_t = model(lr_t)
|
| 222 |
sr_t = sr_t.float().clamp(0, 1)
|
| 223 |
|
|
|
|
| 224 |
lr_display_t = F.interpolate(
|
| 225 |
lr_t.float(), size=(256, 256),
|
| 226 |
mode="bilinear", align_corners=False)
|
| 227 |
|
|
|
|
| 228 |
psnr_lr = psnr(lr_display_t, gt_t).item()
|
| 229 |
ssim_lr = ssim(lr_display_t, gt_t)
|
| 230 |
psnr_sr = psnr(sr_t, gt_t).item()
|
| 231 |
ssim_sr = ssim(sr_t, gt_t)
|
| 232 |
|
|
|
|
| 233 |
def to_pil(t):
|
| 234 |
return TF.to_pil_image(t.squeeze(0).cpu())
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
metrics = (
|
| 237 |
f"**LR baseline** β PSNR: {psnr_lr:.2f} dB | SSIM: {ssim_lr:.4f}\n\n"
|
| 238 |
f"**SR output** β PSNR: {psnr_sr:.2f} dB | SSIM: {ssim_sr:.4f}\n\n"
|
|
|
|
| 240 |
f"ΞSSIM: +{ssim_sr - ssim_lr:.4f}"
|
| 241 |
)
|
| 242 |
|
| 243 |
+
return to_pil(lr_display_t), to_pil(sr_t), gt, metrics
|
| 244 |
|
| 245 |
|
| 246 |
# ββ Example images ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 367 |
|
| 368 |
# ββ Architecture info (collapsible) ββββββββββββββββββββββββββββββββββββββββββ
|
| 369 |
ARCH_INFO = """
|
|
|
|
| 370 |
|
| 371 |
> "Isotropic constant-resolution hierarchical ViT with inter-stage dense feature aggregation β eliminating spatial bottlenecks while preserving coordinate integrity throughout all processing stages."
|
| 372 |
|
|
|
|
| 388 |
### Results
|
| 389 |
| Benchmark | PSNR | SSIM |
|
| 390 |
|-----------|------|------|
|
|
|
|
| 391 |
| DIV2K validation | 25.20 dB | 0.8298 |
|
| 392 |
|
| 393 |
"""
|