# Patch a gradio_client bug: `"const" in schema` raises TypeError when the
# schema is a bool rather than a dict.  The patch has to be installed BEFORE
# gradio itself is imported, which is why the imports below are not grouped
# in the usual stdlib / third-party order.
import gradio_client.utils as _gcu

_original_get_type = _gcu.get_type


def _patched_get_type(schema):
    """Return ``"Any"`` for non-dict schemas, else defer to the original."""
    if not isinstance(schema, dict):
        return "Any"
    return _original_get_type(schema)


_gcu.get_type = _patched_get_type

# now import gradio normally
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import math
import numpy as np
from PIL import Image


class PatchEmbedding(nn.Module):
    """Turn an image into a sequence of patch tokens.

    A convolution with ``kernel_size == stride == patch_size`` projects each
    non-overlapping patch to ``embed_dim`` channels; a learned positional
    embedding with ``(img_size // patch_size) ** 2`` entries is then added.
    """

    def __init__(self, in_channels=3, embed_dim=192, patch_size=4, img_size=64):
        super().__init__()
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(
            torch.randn(1, num_patches, embed_dim) * 0.02)
        nn.init.trunc_normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        """Embed ``x``; shapes in the comments are for the default config."""
        x = self.proj(x)        # [B,3,64,64] → [B,192,16,16]
        x = x.flatten(2)        # → [B,192,256]
        x = x.transpose(1, 2)   # → [B,256,192]
        x = x + self.pos_embed
        return x


class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product self-attention with separate Q/K/V projections.

    Dropout (``attn_drop``) is applied to the softmax attention weights.
    ``embed_dim`` must be divisible by ``num_heads``.
    """

    def __init__(self, embed_dim=192, num_heads=6, attn_drop=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, \
            f"embed_dim {embed_dim} must be divisible by num_heads {num_heads}"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(attn_drop)

    def forward(self, x):
        """Attend over the token axis of ``x`` ([B, N, embed_dim])."""
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        B, N, D = q.shape

        # Split channels into heads: [B, N, D] → [B, heads, N, head_dim].
        q = q.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        scores = (q @ k.transpose(2, 3)) / math.sqrt(self.head_dim)
        weights = self.attn_drop(F.softmax(scores, dim=-1))

        # Merge heads back: [B, heads, N, head_dim] → [B, N, D].
        output = (weights @ v).transpose(1, 2).contiguous()
        output = output.reshape(B, N, self.embed_dim)
        output = self.out_proj(output)
        return output


class GDFN(nn.Module):
    """Gated feed-forward block with 3×3 depthwise convolutions.

    Tokens are projected to two ``embed_dim * ffn_expansion``-wide branches,
    each branch is reshaped to a ``grid_size × grid_size`` map and passed
    through a depthwise 3×3 convolution, and the result is
    ``branch1 * gelu(branch2)`` projected back to ``embed_dim``.

    The number of input tokens must equal ``grid_size ** 2``.
    """

    def __init__(self, embed_dim, ffn_expansion=4, grid_size=16):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_expansion = ffn_expansion
        self.grid_size = grid_size
        self.hidden = embed_dim * ffn_expansion
        self.g_proj = nn.Linear(embed_dim, 2 * self.hidden)
        # groups == channels → depthwise convolution.
        self.path_1 = nn.Conv2d(
            self.hidden, self.hidden, 3, padding=1, groups=self.hidden)
        self.path_2 = nn.Conv2d(
            self.hidden, self.hidden, 3, padding=1, groups=self.hidden)
        self.out_proj = nn.Linear(self.hidden, embed_dim)

    def forward(self, x):
        """Apply the gated feed-forward to ``x`` ([B, grid_size², embed_dim])."""
        x = self.g_proj(x)
        x1, x2 = x.chunk(2, dim=-1)
        B, N, C = x1.shape

        # Tokens → 2-D feature map so the convolutions see spatial neighbours.
        x1 = x1.transpose(1, 2).reshape(B, C, self.grid_size, self.grid_size)
        x2 = x2.transpose(1, 2).reshape(B, C, self.grid_size, self.grid_size)
        x1 = self.path_1(x1)
        x2 = self.path_2(x2)

        # Feature map → tokens.
        x1 = x1.flatten(2).transpose(1, 2)
        x2 = x2.flatten(2).transpose(1, 2)

        x = x1 * F.gelu(x2)
        x = self.out_proj(x)
        return x


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention + GDFN, each with a learned scale.

    Both residual branches are multiplied by a per-channel parameter
    (``gamma1`` / ``gamma2``) initialised to 1e-4.

    ``grid_size`` is forwarded to the GDFN; it defaults to 16, the value the
    block previously always used, so existing callers are unaffected.
    """

    def __init__(self, embed_dim=192, num_heads=6, attn_drop=0.1, grid_size=16):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, attn_drop)
        self.ffn = GDFN(embed_dim, grid_size=grid_size)
        self.gamma1 = nn.Parameter(1e-4 * torch.ones(embed_dim))
        self.gamma2 = nn.Parameter(1e-4 * torch.ones(embed_dim))

    def forward(self, x):
        """Run both residual sub-layers on ``x`` ([B, N, embed_dim])."""
        x = x + self.gamma1 * self.attn(self.norm1(x))
        x = x + self.gamma2 * self.ffn(self.norm2(x))
        return x


class ImageSRTransformer(nn.Module):
    """Constant-resolution multi-stage ViT for image super-resolution.

    The token grid stays at ``img_size // patch_size`` per side through every
    stage; only the channel width changes between stages (via a linear
    projection).  The outputs of all stages are concatenated channel-wise and
    decoded by a convolutional head with four PixelShuffle(2) steps (×16 on
    the token grid).  The head output is added to a bicubic upscale of the
    input, so the network predicts a residual.

    Args:
        embed_dims: channel width of each stage.
        num_heads: attention heads of each stage.
        depths: number of transformer blocks in each stage.
        patch_size: side of the square patches.
        img_size: side of the (square) input image.
    """

    def __init__(self,
                 embed_dims=(192, 256, 288, 384),
                 num_heads=(6, 8, 6, 8),
                 depths=(3, 3, 3, 3),
                 patch_size=4,
                 img_size=64):
        super().__init__()
        # Tuple defaults avoid the shared-mutable-default pitfall; store a list
        # so the attribute keeps the type it had before.
        self.embed_dims = list(embed_dims)
        self.grid_size = img_size // patch_size

        self.patch_embed = PatchEmbedding(
            embed_dim=embed_dims[0], patch_size=patch_size, img_size=img_size)

        self.stages = nn.ModuleList([
            nn.ModuleList([
                # Pass the real grid size so GDFN reshapes match the token
                # count for non-default img_size / patch_size as well.
                TransformerBlock(
                    embed_dims[i], num_heads[i], grid_size=self.grid_size)
                for _ in range(depths[i])
            ])
            for i in range(len(embed_dims))
        ])

        self.projections = nn.ModuleList([
            nn.Linear(embed_dims[i], embed_dims[i + 1])
            for i in range(len(embed_dims) - 1)
        ])

        # The head consumes the concatenation of every stage output, so its
        # input width is the sum of the stage widths (1120 for the defaults).
        head_width = 192
        head_layers = [
            nn.Conv2d(sum(embed_dims), head_width, 3, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Four ×2 PixelShuffle steps; each conv emits 4× channels so the
        # shuffle returns to head_width.  Layer order (and therefore the
        # state_dict keys head.0 … head.10) is unchanged.
        for _ in range(4):
            head_layers.append(nn.Conv2d(head_width, head_width * 4, 3, padding=1))
            head_layers.append(nn.PixelShuffle(2))
        head_layers.append(nn.Conv2d(head_width, 3, 3, padding=1))
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Super-resolve ``x`` ([B, 3, img_size, img_size])."""
        lr = x
        x = self.patch_embed(x)

        stage_outputs = []
        last_stage = len(self.stages) - 1
        for i, stage in enumerate(self.stages):
            for block in stage:
                x = block(x)
            stage_outputs.append(x)
            # Widen the channels for the next stage (none after the last).
            if i < last_stage:
                x = self.projections[i](x)

        x = torch.cat(stage_outputs, dim=-1)
        B, N, C = x.shape
        x = x.transpose(1, 2).reshape(B, C, self.grid_size, self.grid_size)
        x = self.head(x)

        # Upscale the input to whatever size the head produced (256×256 for
        # the default configuration) instead of a hard-coded size.
        base = F.interpolate(
            lr, size=x.shape[-2:], mode="bicubic", align_corners=False)
        return base + x


def psnr(pred, target):
    """Return PSNR in dB as a 0-dim tensor, using a peak signal value of 1.0.

    A small epsilon (1e-8) is added to the MSE so identical inputs do not
    divide by zero.
    """
    mse = torch.mean((pred - target) ** 2)
    return 10 * torch.log10(1.0 / (mse + 1e-8))


def ssim(pred, target, window_size=11):
    """Return the mean SSIM of two image batches as a Python float.

    Local statistics come from a uniform (average-pooling) window of
    ``window_size`` with stride 1 and ``window_size // 2`` padding.
    The stabilising constants are 0.01² and 0.03² — presumably chosen for
    images in the [0, 1] range; confirm against callers.
    """
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    pad = window_size // 2

    mu1 = F.avg_pool2d(pred, window_size, 1, pad)
    mu2 = F.avg_pool2d(target, window_size, 1, pad)
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2

    # Variances / covariance via E[x²] − E[x]².
    s1 = F.avg_pool2d(pred ** 2, window_size, 1, pad) - mu1_sq
    s2 = F.avg_pool2d(target ** 2, window_size, 1, pad) - mu2_sq
    s12 = F.avg_pool2d(pred * target, window_size, 1, pad) - mu1_mu2

    num = (2 * mu1_mu2 + C1) * (2 * s12 + C2)
    den = (mu1_sq + mu2_sq + C1) * (s1 + s2 + C2)
    return (num / den).mean().item()


# ── Load model ────────────────────────────────────────────────────────────────
model = ImageSRTransformer()
# NOTE(security): weights_only=False unpickles arbitrary Python objects.
# Only load checkpoint files from a trusted source.
checkpoint = torch.load(
    "sr_best_v4_resumed.pt",
    map_location="cpu",
    weights_only=False
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
print(f"Model loaded — val PSNR: {checkpoint['val_psnr']:.2f} dB")


# ── Inference ─────────────────────────────────────────────────────────────────
_HR_SIZE = 256  # side of the ground-truth centre crop
_LR_SIZE = 64   # side of the low-resolution model input


def run_sr(img_pil):
    """Run the ×4 super-resolution demo on one uploaded image.

    The image is converted to RGB, upscaled if either side is below 256,
    centre-cropped to 256×256 (the ground truth), bicubically downscaled to
    64×64 (the model input) and super-resolved.  PSNR/SSIM are reported for
    the model output and for a bilinear upscale of the 64×64 input.

    Args:
        img_pil: the uploaded PIL image, or ``None`` when the input widget
            has been cleared.

    Returns:
        ``(lr_display, sr, gt, metrics_markdown)`` — three PIL images and a
        markdown string.  For a ``None`` input: ``(None, None, None, "")``.
    """
    # Clearing the image widget fires the change event with None.
    if img_pil is None:
        return None, None, None, ""

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    img_pil = img_pil.convert("RGB")
    w, h = img_pil.size

    # ensure minimum 256×256 for centre crop
    if w < _HR_SIZE or h < _HR_SIZE:
        scale_up = max(_HR_SIZE / w, _HR_SIZE / h)
        # int() can truncate a float product such as 255.99999 to 255, which
        # would make the crop below run off the image; clamp to the minimum.
        new_size = (
            max(_HR_SIZE, int(w * scale_up)),
            max(_HR_SIZE, int(h * scale_up)),
        )
        img_pil = img_pil.resize(new_size, Image.BICUBIC)
        w, h = img_pil.size

    # centre crop 256×256 → ground truth
    left = (w - _HR_SIZE) // 2
    top = (h - _HR_SIZE) // 2
    gt = img_pil.crop((left, top, left + _HR_SIZE, top + _HR_SIZE))

    # bicubic downscale → 64×64 LR
    lr = gt.resize((_LR_SIZE, _LR_SIZE), Image.BICUBIC)

    lr_t = TF.to_tensor(lr).unsqueeze(0).to(device)
    gt_t = TF.to_tensor(gt).unsqueeze(0).to(device)

    with torch.no_grad():
        # bfloat16 autocast is only enabled when running on CUDA.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
                            enabled=(device == "cuda")):
            sr_t = model(lr_t)
        sr_t = sr_t.float().clamp(0, 1)

    # Baseline shown next to the SR output: plain bilinear upscale.
    lr_display_t = F.interpolate(
        lr_t.float(), size=(_HR_SIZE, _HR_SIZE),
        mode="bilinear", align_corners=False)

    psnr_lr = psnr(lr_display_t, gt_t).item()
    ssim_lr = ssim(lr_display_t, gt_t)
    psnr_sr = psnr(sr_t, gt_t).item()
    ssim_sr = ssim(sr_t, gt_t)

    def to_pil(t):
        return TF.to_pil_image(t.squeeze(0).cpu())

    # The "+" format flag prints the real sign, so a regression is shown as
    # "-0.50" rather than "+-0.50".
    metrics = (
        f"**LR baseline** — PSNR: {psnr_lr:.2f} dB | SSIM: {ssim_lr:.4f}\n\n"
        f"**SR output** — PSNR: {psnr_sr:.2f} dB | SSIM: {ssim_sr:.4f}\n\n"
        f"**Improvement** — ΔPSNR: {psnr_sr - psnr_lr:+.2f} dB | "
        f"ΔSSIM: {ssim_sr - ssim_lr:+.4f}"
    )

    return to_pil(lr_display_t), to_pil(sr_t), gt, metrics


# ── Example images ────────────────────────────────────────────────────────────
# 6 examples chosen to showcase V4 strengths:
EXAMPLES = [
    ["examples/urban.png"],
    ["examples/aerial.png"],
    ["examples/architecture.png"],
    ["examples/nature.png"],
    ["examples/portrait.png"],
    ["examples/texture.png"],
]
# ── CSS ───────────────────────────────────────────────────────────────────────
CSS = """
body, .gradio-container {
    background: #0a0c10 !important;
    color: #c9d1d9 !important;
    font-family: 'Inter', system-ui, sans-serif !important;
}
.header-block {
    text-align: center;
    padding: 2rem 1rem 1rem;
    border-bottom: 1px solid #1e2128;
    margin-bottom: 1.5rem;
}
.header-title {
    font-size: 1.8rem;
    font-weight: 700;
    color: #2dd4bf;
    letter-spacing: -0.02em;
    margin-bottom: 0.25rem;
}
.header-sub {
    font-size: 0.85rem;
    color: #6e7681;
}
.hint-text {
    font-size: 0.78rem;
    color: #6e7681;
    margin-top: 8px;
    line-height: 1.5;
}
.metric-box textarea, .metric-box .prose {
    background: #111318 !important;
    border: 1px solid #1e2128 !important;
    color: #c9d1d9 !important;
    border-radius: 8px !important;
    font-family: 'JetBrains Mono', monospace !important;
    font-size: 0.82rem !important;
    padding: 12px !important;
}
.run-btn {
    background: #2dd4bf !important;
    color: #0a0c10 !important;
    border: none !important;
    font-weight: 700 !important;
    font-size: 0.9rem !important;
    border-radius: 8px !important;
    padding: 0.6rem 2rem !important;
}
.run-btn:hover {
    background: #5eead4 !important;
}
.accordion {
    background: #111318 !important;
    border: 1px solid #1e2128 !important;
    border-radius: 8px !important;
    margin-top: 1.5rem !important;
}
.accordion .label-wrap {
    color: #6e7681 !important;
    font-size: 0.8rem !important;
    letter-spacing: 0.08em !important;
    text-transform: uppercase !important;
}
.image-frame img {
    border-radius: 8px !important;
    border: 1px solid #1e2128 !important;
}
"""

# ── Architecture info (collapsible) ──────────────────────────────────────────
# Markdown: headings, paragraphs and table rows each need their own line.
ARCH_INFO = """
> "Isotropic constant-resolution hierarchical ViT with inter-stage dense feature aggregation — eliminating spatial bottlenecks while preserving coordinate integrity throughout all processing stages."

### Key architectural decisions

**Isotropic token grid** — constant 16×16 spatial resolution across all 4 transformer stages.
Zero patch merging, zero token downsampling. Every token maps to the same 4×4 pixel region from input to output.

**Hierarchical embed dims [192 → 256 → 288 → 384]** — representational capacity scales with feature complexity. Early stages learn local edges and textures (192-dim is sufficient). Deep stages reason about global scene semantics (384-dim is necessary).

**Inter-stage macro concatenation** — outputs from all 4 stages concatenated directly to the reconstruction head: `cat([h1, h2, h3, h4]) → [B, 256, 1120]`. The head receives low-level edge maps (h1) and high-level semantic context (h4) simultaneously.

**GDFN feed-forward** — replaces standard MLPs with Gated Depthwise Feed-Forward Networks. Each token sees its 3×3 spatial neighborhood during the MLP step. Local spatial context injected at every attention layer.

**Bicubic skip connection** — `output = F.interpolate(lr, 256×256) + vit_residual`. Model learns residual correction only, not full reconstruction from scratch.

### Results

| Benchmark | Avg PSNR | Avg SSIM |
|-----------|------|------|
| DIV2K validation | 25.20 dB | 0.8298 |
"""

# ── Gradio UI ─────────────────────────────────────────────────────────────────
with gr.Blocks(css=CSS, title="Dense-Iso-ViT SR") as demo:
    # Header markup uses the .header-* classes declared in CSS above.
    gr.HTML("""
    <div class="header-block">
        <div class="header-title">Dense-Iso-ViT</div>
        <div class="header-sub">Constant-Resolution Hierarchical Vision Transformer for ×4 Image Super-Resolution</div>
    </div>
    """)

    with gr.Row():
        # Left column: upload, run button and usage hint.
        with gr.Column(scale=1):
            input_img = gr.Image(
                type="pil",
                label="Upload any image",
                elem_classes=["image-frame"],
            )
            run_btn = gr.Button(
                "Run ×4 Super-Resolution",
                variant="primary",
                elem_classes=["run-btn"],
            )
            gr.Markdown(
                "💡 For best results, upload a low-resolution image "
                "(256×256 or smaller). The model upscales it ×4.",
                elem_classes=["hint-text"],
            )

        # Right column: the three result images side by side, metrics below.
        with gr.Column(scale=3):
            with gr.Row():
                lr_out = gr.Image(
                    label="LR Input (bilinear upscaled for display)",
                    elem_classes=["image-frame"],
                )
                sr_out = gr.Image(
                    label="SR Output — Dense-Iso-ViT",
                    elem_classes=["image-frame"],
                )
                gt_out = gr.Image(
                    label="Ground Truth (original crop)",
                    elem_classes=["image-frame"],
                )
            metrics_out = gr.Markdown(elem_classes=["metric-box"])

    gr.Examples(
        examples=EXAMPLES,
        inputs=[input_img],
        label="Examples — showing V4 strengths",
    )

    with gr.Accordion(
        "Architecture details — Dense-Iso-ViT",
        open=False,
        elem_classes=["accordion"],
    ):
        gr.Markdown(ARCH_INFO)

    # Inference runs both on an explicit button press and whenever the input
    # image changes (upload, example selection, clear).
    run_btn.click(
        fn=run_sr,
        inputs=[input_img],
        outputs=[lr_out, sr_out, gt_out, metrics_out],
    )
    input_img.change(
        fn=run_sr,
        inputs=[input_img],
        outputs=[lr_out, sr_out, gt_out, metrics_out],
    )

if __name__ == "__main__":
    demo.launch(show_api=False)