# Spaces: Running  (hosting-page status text captured with the source; not part of the program)
# Patch a gradio_client bug: `"const" in <bool>` raises TypeError, so any
# schema that is not a dict is answered with "Any" before the stock
# get_type() gets to see it.
import gradio_client.utils as _gcu

_original_get_type = _gcu.get_type


def _patched_get_type(schema):
    """Return "Any" for non-dict schemas; defer to the stock ``get_type`` otherwise."""
    return _original_get_type(schema) if isinstance(schema, dict) else "Any"


_gcu.get_type = _patched_get_type
| # now import gradio normally | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.transforms.functional as TF | |
| import math | |
| import numpy as np | |
| from PIL import Image | |
class PatchEmbedding(nn.Module):
    """Convert an image into patch tokens and add a learned position term.

    A convolution whose kernel and stride both equal ``patch_size`` projects
    each non-overlapping patch to ``embed_dim`` channels. The resulting grid
    is flattened to ``[B, num_patches, embed_dim]`` and a learnable
    ``[1, num_patches, embed_dim]`` positional embedding is added, where
    ``num_patches = (img_size // patch_size) ** 2``.
    """

    def __init__(self, in_channels=3, embed_dim=192, patch_size=4, img_size=64):
        super().__init__()
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        patches_per_side = img_size // patch_size
        # Module/parameter creation order is kept as-is so seeded
        # initialisation draws the same random numbers.
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.pos_embed = nn.Parameter(
            0.02 * torch.randn(1, patches_per_side ** 2, embed_dim)
        )
        nn.init.trunc_normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        """Map ``[B, C, H, W]`` images to ``[B, num_patches, embed_dim]`` tokens."""
        # With the defaults: [B,3,64,64] -> [B,192,16,16] -> [B,192,256] -> [B,256,192]
        tokens = self.proj(x).flatten(2).transpose(1, 2)
        return tokens + self.pos_embed
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over ``[B, N, embed_dim]`` tokens.

    Queries, keys and values come from three separate linear projections of
    the same input. Dropout is applied to the softmax attention weights, and
    the concatenated heads pass through a final linear projection.
    """

    def __init__(self, embed_dim=192, num_heads=6, attn_drop=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, \
            f"embed_dim {embed_dim} must be divisible by num_heads {num_heads}"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(attn_drop)

    def _split_heads(self, t):
        """Reshape ``[B, N, embed_dim]`` to ``[B, num_heads, N, head_dim]``."""
        batch, tokens, _ = t.shape
        return t.reshape(batch, tokens, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Return the attended tokens, same shape as ``x``."""
        batch, tokens, _ = x.shape

        queries = self._split_heads(self.q_proj(x))
        keys = self._split_heads(self.k_proj(x))
        values = self._split_heads(self.v_proj(x))

        # [B, heads, N, N] similarity, scaled by sqrt(head_dim).
        logits = torch.matmul(queries, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn = self.attn_drop(F.softmax(logits, dim=-1))

        # Back to [B, N, heads, head_dim], then merge the heads.
        merged = torch.matmul(attn, values).transpose(1, 2).contiguous()
        merged = merged.reshape(batch, tokens, self.embed_dim)
        return self.out_proj(merged)
class GDFN(nn.Module):
    """Gated feed-forward network with depthwise 3x3 convolutions on the token grid.

    The input ``[B, N, embed_dim]`` is projected to ``2 * hidden`` channels
    (``hidden = embed_dim * ffn_expansion``) and split into two halves. Each
    half is laid out as a ``grid_size x grid_size`` map, passed through its
    own depthwise 3x3 convolution, and flattened back to tokens. The first
    half is multiplied by the GELU of the second, then projected back to
    ``embed_dim``. ``N`` must equal ``grid_size ** 2`` for the reshape to work.
    """

    def __init__(self, embed_dim, ffn_expansion = 4, grid_size = 16):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_expansion = ffn_expansion
        self.grid_size = grid_size
        self.hidden = embed_dim * ffn_expansion

        hidden = self.hidden
        self.g_proj = nn.Linear(embed_dim, 2 * hidden)
        # groups == channels makes these depthwise convolutions.
        self.path_1 = nn.Conv2d(hidden, hidden, 3, padding=1, groups=hidden)
        self.path_2 = nn.Conv2d(hidden, hidden, 3, padding=1, groups=hidden)
        self.out_proj = nn.Linear(hidden, embed_dim)

    def _depthwise(self, conv, tokens):
        """Apply ``conv`` to ``tokens`` viewed as a 2-D grid; return tokens again."""
        batch, _, channels = tokens.shape
        grid = tokens.transpose(1, 2).reshape(
            batch, channels, self.grid_size, self.grid_size
        )
        return conv(grid).flatten(2).transpose(1, 2)

    def forward(self, x):
        """Return the gated feed-forward output, same shape as ``x``."""
        value, gate = self.g_proj(x).chunk(2, dim=-1)
        value = self._depthwise(self.path_1, value)
        gate = self._depthwise(self.path_2, gate)
        return self.out_proj(value * F.gelu(gate))
class TransformerBlock(nn.Module):
    """Pre-norm transformer block with learnable per-channel residual scales.

    Computes ``x + gamma1 * attn(norm1(x))`` followed by
    ``x + gamma2 * ffn(norm2(x))``. ``gamma1``/``gamma2`` are vectors of
    length ``embed_dim`` initialised to ``1e-4``.

    Args:
        embed_dim: Token width.
        num_heads: Number of attention heads.
        attn_drop: Dropout probability on the attention weights.
        grid_size: Side length of the square token grid, forwarded to the
            GDFN feed-forward. Defaults to 16, the value that was previously
            always used implicitly.
    """

    def __init__(self, embed_dim=192, num_heads=6, attn_drop=0.1, grid_size=16):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, attn_drop)
        # The feed-forward reshapes tokens to a grid, so it must know the real
        # grid size rather than silently assuming its own default.
        self.ffn = GDFN(embed_dim, grid_size=grid_size)
        self.gamma1 = nn.Parameter(1e-4 * torch.ones(embed_dim))
        self.gamma2 = nn.Parameter(1e-4 * torch.ones(embed_dim))

    def forward(self, x):
        """Apply attention and feed-forward residual branches to ``x``."""
        x = x + self.gamma1 * self.attn(self.norm1(x))
        x = x + self.gamma2 * self.ffn(self.norm2(x))
        return x


class ImageSRTransformer(nn.Module):
    """Constant-resolution multi-stage transformer for image super-resolution.

    The image is patch-embedded once. Every stage then runs ``depths[i]``
    transformer blocks at width ``embed_dims[i]`` on the same token grid,
    with a linear projection between consecutive stages. The outputs of all
    stages are concatenated channel-wise, reshaped to a feature map, and fed
    to a convolutional head with four ``PixelShuffle(2)`` steps (16x the token
    grid, i.e. 4x the input when ``patch_size`` is 4). The head output is added
    to a bicubic upsample of the input.

    Args:
        embed_dims: Token width per stage.
        num_heads: Attention heads per stage (indexed alongside ``embed_dims``).
        depths: Number of blocks per stage (indexed alongside ``embed_dims``).
        patch_size: Side length of each square input patch.
        img_size: Side length of the (square) input image.
    """

    def __init__(self,
                 embed_dims=(192, 256, 288, 384),
                 num_heads=(6, 8, 6, 8),
                 depths=(3, 3, 3, 3),
                 patch_size=4,
                 img_size=64):
        super().__init__()
        # Immutable defaults avoid sharing one list object between instances;
        # the attribute is still exposed as a list.
        embed_dims = list(embed_dims)
        self.embed_dims = embed_dims
        self.grid_size = img_size // patch_size

        self.patch_embed = PatchEmbedding(
            embed_dim=embed_dims[0], patch_size=patch_size, img_size=img_size
        )
        self.stages = nn.ModuleList([
            nn.ModuleList([
                TransformerBlock(dim, num_heads[i], grid_size=self.grid_size)
                for _ in range(depths[i])
            ])
            for i, dim in enumerate(embed_dims)
        ])
        self.projections = nn.ModuleList([
            nn.Linear(d_in, d_out)
            for d_in, d_out in zip(embed_dims, embed_dims[1:])
        ])

        # The head consumes the concatenation of every stage output, so its
        # input width is the sum of the stage widths (1120 for the defaults).
        # The layer order is unchanged, so checkpoint keys head.0 .. head.10
        # still match.
        head_layers = [
            nn.Conv2d(sum(embed_dims), 192, 3, padding=1),
            nn.LeakyReLU(0.2),
        ]
        for _ in range(4):
            head_layers += [nn.Conv2d(192, 768, 3, padding=1), nn.PixelShuffle(2)]
        head_layers.append(nn.Conv2d(192, 3, 3, padding=1))
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Super-resolve ``x``.

        Args:
            x: Image batch ``[B, 3, img_size, img_size]``.

        Returns:
            Tensor with the head's spatial size (``[B, 3, 256, 256]`` for the
            default configuration): bicubic upsample of ``x`` plus the
            predicted residual.
        """
        lr = x
        x = self.patch_embed(x)

        stage_outputs = []
        last_stage = len(self.stages) - 1
        for i, stage in enumerate(self.stages):
            for block in stage:
                x = block(x)
            stage_outputs.append(x)
            if i < last_stage:
                x = self.projections[i](x)

        x = torch.cat(stage_outputs, dim=-1)
        B, N, C = x.shape
        x = x.transpose(1, 2).reshape(B, C, self.grid_size, self.grid_size)
        x = self.head(x)

        # Upsample the input to whatever size the head produced instead of a
        # fixed 256x256, so non-default image sizes stay consistent.
        base = F.interpolate(lr, size=x.shape[-2:], mode="bicubic", align_corners=False)
        return base + x
def psnr(pred, target, max_val=1.0):
    """Return the peak signal-to-noise ratio between two tensors, in dB.

    Args:
        pred: Predicted tensor.
        target: Reference tensor, broadcast-compatible with ``pred``.
        max_val: Peak signal value of the data (1.0 for images scaled to
            [0, 1], 255 for 8-bit intensities). Defaults to 1.0, which
            reproduces the previous behaviour exactly.

    Returns:
        A 0-dim tensor ``10 * log10(max_val**2 / (mse + 1e-8))``. The small
        epsilon keeps the result finite when the inputs are identical.
    """
    mse = torch.mean((pred - target) ** 2)
    return 10 * torch.log10(max_val ** 2 / (mse + 1e-8))
def ssim(pred, target, window_size=11):
    """Return the mean structural similarity of two image batches as a float.

    Local statistics are taken with a uniform ``window_size`` average pool
    (stride 1, padding ``window_size // 2``) rather than a Gaussian window.
    The stabilising constants are ``0.01**2`` and ``0.03**2``; presumably the
    inputs are expected in [0, 1] - confirm against callers.
    """
    c1, c2 = 0.01 ** 2, 0.03 ** 2
    pad = window_size // 2

    def local_mean(t):
        return F.avg_pool2d(t, window_size, 1, pad)

    mean_p = local_mean(pred)
    mean_t = local_mean(target)
    mean_p_sq = mean_p ** 2
    mean_t_sq = mean_t ** 2
    mean_pt = mean_p * mean_t

    # Variances and covariance via E[xy] - E[x]E[y].
    var_p = local_mean(pred ** 2) - mean_p_sq
    var_t = local_mean(target ** 2) - mean_t_sq
    cov_pt = local_mean(pred * target) - mean_pt

    numerator = (2 * mean_pt + c1) * (2 * cov_pt + c2)
    denominator = (mean_p_sq + mean_t_sq + c1) * (var_p + var_t + c2)
    return (numerator / denominator).mean().item()
# ── Load model ────────────────────────────────────────────────────────────────
model = ImageSRTransformer()
# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary Python
# objects. Only ever point this at a trusted checkpoint file; switch to
# weights_only=True if the checkpoint is known to hold plain tensors/scalars.
checkpoint = torch.load(
    "sr_best_v4_resumed.pt",
    map_location="cpu",
    weights_only=False,
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# The validation PSNR is only informational, so a checkpoint without it
# should not stop the app from starting.
_val_psnr = checkpoint.get("val_psnr")
if _val_psnr is None:
    print("Model loaded")
else:
    print(f"Model loaded — val PSNR: {_val_psnr:.2f} dB")
# ── Inference ─────────────────────────────────────────────────────────────────
def run_sr(img_pil):
    """Run the demo pipeline on one uploaded image.

    The image is converted to RGB, upscaled if either side is below 256,
    centre-cropped to 256x256 (the ground truth), bicubically downscaled to
    64x64 (the LR input) and super-resolved back to 256x256 by ``model``.
    PSNR/SSIM are reported for a bilinear upscale of the LR input and for
    the model output, both against the ground-truth crop.

    Args:
        img_pil: PIL image, or ``None`` when no image is present.

    Returns:
        ``(lr_display, sr, gt, metrics)``: three PIL images and a Markdown
        string. For a ``None`` input, ``(None, None, None, "")``.
    """
    if img_pil is None:
        # The function is also wired to the input's change event, which may
        # deliver None (e.g. when the image is cleared).
        return None, None, None, ""

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    img_pil = img_pil.convert("RGB")
    w, h = img_pil.size
    # ensure minimum 256×256 for centre crop
    if w < 256 or h < 256:
        scale_up = max(256 / w, 256 / h)
        # Round, then clamp to 256: truncating with int() can land on 255
        # through float error, which would make the crop pad with black.
        img_pil = img_pil.resize(
            (max(256, round(w * scale_up)), max(256, round(h * scale_up))),
            Image.BICUBIC,
        )
        w, h = img_pil.size

    # centre crop 256×256 → ground truth
    left = (w - 256) // 2
    top = (h - 256) // 2
    gt = img_pil.crop((left, top, left + 256, top + 256))

    # bicubic downscale → 64×64 LR
    lr = gt.resize((64, 64), Image.BICUBIC)
    lr_t = TF.to_tensor(lr).unsqueeze(0).to(device)
    gt_t = TF.to_tensor(gt).unsqueeze(0).to(device)

    with torch.no_grad():
        # Autocast is only enabled on CUDA; passing the real device type
        # keeps the disabled CPU path free of CUDA-availability warnings.
        with torch.autocast(device_type=device, dtype=torch.bfloat16,
                            enabled=(device == "cuda")):
            sr_t = model(lr_t)
    sr_t = sr_t.float().clamp(0, 1)

    lr_display_t = F.interpolate(
        lr_t.float(), size=(256, 256), mode="bilinear", align_corners=False)

    psnr_lr = psnr(lr_display_t, gt_t).item()
    ssim_lr = ssim(lr_display_t, gt_t)
    psnr_sr = psnr(sr_t, gt_t).item()
    ssim_sr = ssim(sr_t, gt_t)

    def to_pil(t):
        return TF.to_pil_image(t.squeeze(0).cpu())

    # The "+" format flag prints the real sign, so a regression shows as
    # "-0.12" rather than "+-0.12".
    metrics = (
        f"**LR baseline** — PSNR: {psnr_lr:.2f} dB | SSIM: {ssim_lr:.4f}\n\n"
        f"**SR output** — PSNR: {psnr_sr:.2f} dB | SSIM: {ssim_sr:.4f}\n\n"
        f"**Improvement** — ΔPSNR: {psnr_sr - psnr_lr:+.2f} dB | "
        f"ΔSSIM: {ssim_sr - ssim_lr:+.4f}"
    )
    return to_pil(lr_display_t), to_pil(sr_t), gt, metrics
# ── Example images ────────────────────────────────────────────────────────────
# 6 examples chosen to showcase V4 strengths; each row holds the single
# image-path input that gr.Examples feeds to the upload component.
EXAMPLES = [
    [f"examples/{name}.png"]
    for name in (
        "urban",
        "aerial",
        "architecture",
        "nature",
        "portrait",
        "texture",
    )
]
# ── CSS ───────────────────────────────────────────────────────────────────────
# Dark theme for the demo page. The class names below are attached to
# components through `elem_classes` (and to the header through raw HTML).
CSS = """
body, .gradio-container {
    background: #0a0c10 !important;
    color: #c9d1d9 !important;
    font-family: 'Inter', system-ui, sans-serif !important;
}
.header-block {
    text-align: center;
    padding: 2rem 1rem 1rem;
    border-bottom: 1px solid #1e2128;
    margin-bottom: 1.5rem;
}
.header-title {
    font-size: 1.8rem;
    font-weight: 700;
    color: #2dd4bf;
    letter-spacing: -0.02em;
    margin-bottom: 0.25rem;
}
.header-sub {
    font-size: 0.85rem;
    color: #6e7681;
}
.hint-text {
    font-size: 0.78rem;
    color: #6e7681;
    margin-top: 8px;
    line-height: 1.5;
}
.metric-box textarea, .metric-box .prose {
    background: #111318 !important;
    border: 1px solid #1e2128 !important;
    color: #c9d1d9 !important;
    border-radius: 8px !important;
    font-family: 'JetBrains Mono', monospace !important;
    font-size: 0.82rem !important;
    padding: 12px !important;
}
.run-btn {
    background: #2dd4bf !important;
    color: #0a0c10 !important;
    border: none !important;
    font-weight: 700 !important;
    font-size: 0.9rem !important;
    border-radius: 8px !important;
    padding: 0.6rem 2rem !important;
}
.run-btn:hover { background: #5eead4 !important; }
.accordion {
    background: #111318 !important;
    border: 1px solid #1e2128 !important;
    border-radius: 8px !important;
    margin-top: 1.5rem !important;
}
.accordion .label-wrap {
    color: #6e7681 !important;
    font-size: 0.8rem !important;
    letter-spacing: 0.08em !important;
    text-transform: uppercase !important;
}
.image-frame img {
    border-radius: 8px !important;
    border: 1px solid #1e2128 !important;
}
"""
# ── Architecture info (collapsible) ──────────────────────────────────────────
# Markdown shown inside the "Architecture details" accordion. Paragraphs are
# separated by blank lines so each one renders as its own block.
ARCH_INFO = """
> "Isotropic constant-resolution hierarchical ViT with inter-stage dense feature aggregation — eliminating spatial bottlenecks while preserving coordinate integrity throughout all processing stages."

### Key architectural decisions

**Isotropic token grid** — constant 16×16 spatial resolution across all 4 transformer stages. Zero patch merging, zero token downsampling. Every token maps to the same 4×4 pixel region from input to output.

**Hierarchical embed dims [192 → 256 → 288 → 384]** — representational capacity scales with feature complexity. Early stages learn local edges and textures (192-dim is sufficient). Deep stages reason about global scene semantics (384-dim is necessary).

**Inter-stage macro concatenation** — outputs from all 4 stages concatenated directly to the reconstruction head: `cat([h1, h2, h3, h4]) → [B, 256, 1120]`. The head receives low-level edge maps (h1) and high-level semantic context (h4) simultaneously.

**GDFN feed-forward** — replaces standard MLPs with Gated Depthwise Feed-Forward Networks. Each token sees its 3×3 spatial neighborhood during the MLP step. Local spatial context injected at every attention layer.

**Bicubic skip connection** — `output = F.interpolate(lr, 256×256) + vit_residual`. Model learns residual correction only, not full reconstruction from scratch.

### Results

| Benchmark | Avg PSNR | Avg SSIM |
|-----------|------|------|
| DIV2K validation | 25.20 dB | 0.8298 |
"""
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Layout: header, then a row with the upload column (input, button, hint) and
# a wider results column (LR / SR / GT images above the metrics), followed by
# the example gallery and a collapsible architecture description.
with gr.Blocks(css=CSS, title="Dense-Iso-ViT SR") as demo:
    gr.HTML("""
    <div class="header-block">
        <div class="header-title">Dense-Iso-ViT</div>
        <div class="header-sub">
            Constant-Resolution Hierarchical Vision Transformer for ×4 Image Super-Resolution
        </div>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            input_img = gr.Image(
                type="pil",
                label="Upload any image",
                elem_classes=["image-frame"],
            )
            run_btn = gr.Button(
                "Run ×4 Super-Resolution",
                variant="primary",
                elem_classes=["run-btn"],
            )
            # NOTE(review): run_sr centre-crops to 256×256 and downsamples to
            # 64×64 before upscaling, so this hint does not quite describe
            # what happens to the upload - confirm the intended wording.
            gr.Markdown(
                "💡 For best results, upload a low-resolution image "
                "(256×256 or smaller). The model upscales it ×4.",
                elem_classes=["hint-text"],
            )

        with gr.Column(scale=3):
            with gr.Row():
                lr_out = gr.Image(
                    label="LR Input (bilinear upscaled for display)",
                    elem_classes=["image-frame"],
                )
                sr_out = gr.Image(
                    label="SR Output — Dense-Iso-ViT",
                    elem_classes=["image-frame"],
                )
                gt_out = gr.Image(
                    label="Ground Truth (original crop)",
                    elem_classes=["image-frame"],
                )
            metrics_out = gr.Markdown(elem_classes=["metric-box"])

    gr.Examples(
        examples=EXAMPLES,
        inputs=[input_img],
        label="Examples — showing V4 strengths",
    )

    with gr.Accordion(
        "Architecture details — Dense-Iso-ViT",
        open=False,
        elem_classes=["accordion"],
    ):
        gr.Markdown(ARCH_INFO)

    # Inference runs both on an explicit button press and whenever the input
    # image changes (upload or example selection).
    run_btn.click(
        fn=run_sr,
        inputs=[input_img],
        outputs=[lr_out, sr_out, gt_out, metrics_out],
    )
    input_img.change(
        fn=run_sr,
        inputs=[input_img],
        outputs=[lr_out, sr_out, gt_out, metrics_out],
    )


if __name__ == "__main__":
    demo.launch(show_api=False)