SathyaSantosh77 committed on
Commit
a043ce4
·
1 Parent(s): 66e1339

fix zerogpu device handling

Browse files
Files changed (1) hide show
  1. app.py +12 -30
app.py CHANGED
@@ -182,30 +182,23 @@ def ssim(pred, target, window_size=11):
182
 
183
 
184
  # ── Load model ────────────────────────────────────────────────────────────────
185
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
186
-
187
- model = ImageSRTransformer().to(DEVICE)
188
  checkpoint = torch.load(
189
  "sr_best_v4_resumed.pt",
190
- map_location=DEVICE,
191
  weights_only=False
192
  )
193
  model.load_state_dict(checkpoint["model_state_dict"])
194
  model.eval()
195
- print(f"Model loaded — device: {DEVICE}")
196
- print(f"Best val PSNR: {checkpoint['val_psnr']:.2f} dB")
197
 
198
 
199
  # ── Inference ─────────────────────────────────────────────────────────────────
200
  @spaces.GPU
201
  def run_sr(img_pil):
202
- """
203
- Takes any PIL image.
204
- 1. Crops centre 256×256 → ground truth
205
- 2. Bicubic downscale to 64×64 → LR
206
- 3. Runs Dense-Iso-ViT SR
207
- 4. Returns (lr_display, sr_output, ground_truth, metrics_str)
208
- """
209
  w, h = img_pil.size
210
  if w < 256 or h < 256:
211
  scale = max(256 / w, 256 / h)
@@ -213,42 +206,33 @@ def run_sr(img_pil):
213
  (int(w * scale), int(h * scale)), Image.BICUBIC)
214
  w, h = img_pil.size
215
 
216
- # centre crop 256×256
217
  left = (w - 256) // 2
218
  top = (h - 256) // 2
219
  gt = img_pil.crop((left, top, left + 256, top + 256))
220
-
221
- # LR — bicubic 64×64
222
  lr = gt.resize((64, 64), Image.BICUBIC)
223
 
224
- # tensors
225
- lr_t = TF.to_tensor(lr).unsqueeze(0).to(DEVICE)
226
- gt_t = TF.to_tensor(gt).unsqueeze(0).to(DEVICE)
227
 
228
  with torch.no_grad():
229
- with torch.autocast(device_type=DEVICE, dtype=torch.bfloat16):
 
 
230
  sr_t = model(lr_t)
231
  sr_t = sr_t.float().clamp(0, 1)
232
 
233
- # LR display — bilinear upscale to 256 for side-by-side
234
  lr_display_t = F.interpolate(
235
  lr_t.float(), size=(256, 256),
236
  mode="bilinear", align_corners=False)
237
 
238
- # metrics — LR baseline vs SR
239
  psnr_lr = psnr(lr_display_t, gt_t).item()
240
  ssim_lr = ssim(lr_display_t, gt_t)
241
  psnr_sr = psnr(sr_t, gt_t).item()
242
  ssim_sr = ssim(sr_t, gt_t)
243
 
244
- # to PIL
245
  def to_pil(t):
246
  return TF.to_pil_image(t.squeeze(0).cpu())
247
 
248
- lr_img = to_pil(lr_display_t)
249
- sr_img = to_pil(sr_t)
250
- gt_img = gt
251
-
252
  metrics = (
253
  f"**LR baseline** — PSNR: {psnr_lr:.2f} dB | SSIM: {ssim_lr:.4f}\n\n"
254
  f"**SR output** — PSNR: {psnr_sr:.2f} dB | SSIM: {ssim_sr:.4f}\n\n"
@@ -256,7 +240,7 @@ def run_sr(img_pil):
256
  f"ΔSSIM: +{ssim_sr - ssim_lr:.4f}"
257
  )
258
 
259
- return lr_img, sr_img, gt_img, metrics
260
 
261
 
262
  # ── Example images ────────────────────────────────────────────────────────────
@@ -383,7 +367,6 @@ body, .gradio-container {
383
 
384
  # ── Architecture info (collapsible) ──────────────────────────────────────────
385
  ARCH_INFO = """
386
- ## Dense-Iso-ViT core claim
387
 
388
  > "Isotropic constant-resolution hierarchical ViT with inter-stage dense feature aggregation — eliminating spatial bottlenecks while preserving coordinate integrity throughout all processing stages."
389
 
@@ -405,7 +388,6 @@ DRCT uses spatial downsampling and local block-level residuals. Dense-Iso-ViT ma
405
  ### Results
406
  | Benchmark | PSNR | SSIM |
407
  |-----------|------|------|
408
- | LSDIR test | 24.11 dB | — |
409
  | DIV2K validation | 25.20 dB | 0.8298 |
410
 
411
  """
 
182
 
183
 
184
  # ── Load model ────────────────────────────────────────────────────────────────
185
+ model = ImageSRTransformer()
 
 
186
  checkpoint = torch.load(
187
  "sr_best_v4_resumed.pt",
188
+ map_location="cpu",
189
  weights_only=False
190
  )
191
  model.load_state_dict(checkpoint["model_state_dict"])
192
  model.eval()
193
+ print(f"Model loaded — val PSNR: {checkpoint['val_psnr']:.2f} dB")
 
194
 
195
 
196
  # ── Inference ─────────────────────────────────────────────────────────────────
197
  @spaces.GPU
198
  def run_sr(img_pil):
199
+ device = "cuda" if torch.cuda.is_available() else "cpu"
200
+ model.to(device)
201
+
 
 
 
 
202
  w, h = img_pil.size
203
  if w < 256 or h < 256:
204
  scale = max(256 / w, 256 / h)
 
206
  (int(w * scale), int(h * scale)), Image.BICUBIC)
207
  w, h = img_pil.size
208
 
 
209
  left = (w - 256) // 2
210
  top = (h - 256) // 2
211
  gt = img_pil.crop((left, top, left + 256, top + 256))
 
 
212
  lr = gt.resize((64, 64), Image.BICUBIC)
213
 
214
+ lr_t = TF.to_tensor(lr).unsqueeze(0).to(device)
215
+ gt_t = TF.to_tensor(gt).unsqueeze(0).to(device)
 
216
 
217
  with torch.no_grad():
218
+ with torch.autocast(device_type="cuda",
219
+ dtype=torch.bfloat16,
220
+ enabled=(device == "cuda")):
221
  sr_t = model(lr_t)
222
  sr_t = sr_t.float().clamp(0, 1)
223
 
 
224
  lr_display_t = F.interpolate(
225
  lr_t.float(), size=(256, 256),
226
  mode="bilinear", align_corners=False)
227
 
 
228
  psnr_lr = psnr(lr_display_t, gt_t).item()
229
  ssim_lr = ssim(lr_display_t, gt_t)
230
  psnr_sr = psnr(sr_t, gt_t).item()
231
  ssim_sr = ssim(sr_t, gt_t)
232
 
 
233
  def to_pil(t):
234
  return TF.to_pil_image(t.squeeze(0).cpu())
235
 
 
 
 
 
236
  metrics = (
237
  f"**LR baseline** — PSNR: {psnr_lr:.2f} dB | SSIM: {ssim_lr:.4f}\n\n"
238
  f"**SR output** — PSNR: {psnr_sr:.2f} dB | SSIM: {ssim_sr:.4f}\n\n"
 
240
  f"ΔSSIM: +{ssim_sr - ssim_lr:.4f}"
241
  )
242
 
243
+ return to_pil(lr_display_t), to_pil(sr_t), gt, metrics
244
 
245
 
246
  # ── Example images ────────────────────────────────────────────────────────────
 
367
 
368
  # ── Architecture info (collapsible) ──────────────────────────────────────────
369
  ARCH_INFO = """
 
370
 
371
  > "Isotropic constant-resolution hierarchical ViT with inter-stage dense feature aggregation — eliminating spatial bottlenecks while preserving coordinate integrity throughout all processing stages."
372
 
 
388
  ### Results
389
  | Benchmark | PSNR | SSIM |
390
  |-----------|------|------|
 
391
  | DIV2K validation | 25.20 dB | 0.8298 |
392
 
393
  """