# SathyaSantosh77
# clean UI
# 8860868
# Raw
# History Blame Contribute Delete
# 15.9 kB
# Patch gradio_client bug — evaluating `"const" in schema` raises TypeError when the schema is a bool
import gradio_client.utils as _gcu
# Keep a handle on the library implementation so the wrapper can delegate.
_original_get_type = _gcu.get_type


def _patched_get_type(schema):
    """Return the type name for *schema*, tolerating non-dict schemas.

    Dict schemas are forwarded unchanged to the original
    ``gradio_client.utils.get_type``; anything else is reported as ``"Any"``
    without calling into the library.
    """
    if isinstance(schema, dict):
        return _original_get_type(schema)
    return "Any"


# Install the wrapper in place of the library function.
_gcu.get_type = _patched_get_type
# now import gradio normally
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import math
import numpy as np
from PIL import Image
class PatchEmbedding(nn.Module):
    """Turn an image into a sequence of position-aware patch tokens.

    A strided convolution (kernel size == stride == ``patch_size``) embeds
    each non-overlapping patch, and a learned positional embedding with one
    row per patch is added to the resulting tokens.

    Args:
        in_channels: Number of channels in the input image.
        embed_dim: Length of each token vector.
        patch_size: Side length of a square patch, in pixels.
        img_size: Side length of the square input that the positional
            embedding is sized for.
    """

    def __init__(self, in_channels=3, embed_dim=192, patch_size=4, img_size=64):
        super().__init__()
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        patches_per_side = img_size // patch_size
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        # One learned position vector per patch, drawn at small scale.
        self.pos_embed = nn.Parameter(
            torch.randn(1, patches_per_side ** 2, embed_dim) * 0.02
        )

        # Re-initialise the projection: truncated-normal weights, zero bias.
        nn.init.trunc_normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        """Embed ``x`` ([B, C, H, W]) into tokens ([B, N, embed_dim]).

        With the default arguments: [B, 3, 64, 64] -> [B, 192, 16, 16]
        -> [B, 192, 256] -> [B, 256, 192].
        """
        tokens = self.proj(x).flatten(2).transpose(1, 2)
        return tokens + self.pos_embed
class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product self-attention over a token sequence.

    Queries, keys and values each come from their own linear projection of
    the input, are split into ``num_heads`` heads of ``head_dim`` channels,
    attended independently, merged back and passed through an output
    projection.

    Args:
        embed_dim: Token width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        attn_drop: Dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim=192, num_heads=6, attn_drop=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, (
            f"embed_dim {embed_dim} must be divisible by num_heads {num_heads}"
        )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(attn_drop)

    def _split_heads(self, t):
        """Reshape [B, N, D] into per-head layout [B, heads, N, head_dim]."""
        batch, tokens, _ = t.shape
        return t.reshape(batch, tokens, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply self-attention to ``x`` ([B, N, D]) and return [B, N, D]."""
        batch, tokens = x.shape[:2]

        queries = self._split_heads(self.q_proj(x))
        keys = self._split_heads(self.k_proj(x))
        values = self._split_heads(self.v_proj(x))

        # [B, heads, N, N] similarity, scaled by sqrt(head_dim).
        logits = torch.matmul(queries, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn = self.attn_drop(F.softmax(logits, dim=-1))

        # Bring heads back next to their channels before merging them.
        merged = torch.matmul(attn, values).transpose(1, 2).contiguous()
        merged = merged.reshape(batch, tokens, self.embed_dim)
        return self.out_proj(merged)
class GDFN(nn.Module):
    """Gated feed-forward network with depthwise 3x3 convolutions.

    Tokens are projected to two hidden branches. Each branch is laid out as a
    ``grid_size`` x ``grid_size`` feature map, filtered by its own depthwise
    3x3 convolution and flattened back to tokens. The first branch is then
    multiplied by the GELU of the second and projected back to ``embed_dim``.

    Args:
        embed_dim: Token width at input and output.
        ffn_expansion: Hidden width multiplier (hidden = embed_dim * this).
        grid_size: Side length of the square token grid; the number of
            tokens must equal ``grid_size ** 2``.
    """

    def __init__(self, embed_dim, ffn_expansion=4, grid_size=16):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_expansion = ffn_expansion
        self.grid_size = grid_size
        self.hidden = embed_dim * ffn_expansion

        hidden = self.hidden
        # A single projection yields both branches; they are split in forward().
        self.g_proj = nn.Linear(embed_dim, 2 * hidden)
        self.path_1 = nn.Conv2d(hidden, hidden, 3, padding=1, groups=hidden)
        self.path_2 = nn.Conv2d(hidden, hidden, 3, padding=1, groups=hidden)
        self.out_proj = nn.Linear(hidden, embed_dim)

    def _spatial_filter(self, tokens, conv):
        """Run ``conv`` over ``tokens`` ([B, N, C]) arranged as a 2-D grid."""
        batch, _, channels = tokens.shape
        grid = tokens.transpose(1, 2).reshape(
            batch, channels, self.grid_size, self.grid_size
        )
        return conv(grid).flatten(2).transpose(1, 2)

    def forward(self, x):
        """Apply the gated feed-forward step to ``x`` ([B, N, embed_dim])."""
        value, gate = self.g_proj(x).chunk(2, dim=-1)
        value = self._spatial_filter(value, self.path_1)
        gate = self._spatial_filter(gate, self.path_2)
        return self.out_proj(value * F.gelu(gate))
class TransformerBlock(nn.Module):
    """Pre-norm transformer block with scaled residual branches.

    Each of the two sub-layers (self-attention, then the gated feed-forward
    network) runs on a LayerNorm of its input. Its output is multiplied by a
    learnable per-channel scale (``gamma1`` / ``gamma2``, initialised to
    1e-4) before being added back to the residual stream.

    Args:
        embed_dim: Token width.
        num_heads: Number of attention heads.
        attn_drop: Dropout probability on the attention weights.
        grid_size: Side length of the square token grid handed to the
            feed-forward network. Defaults to 16, the value the
            feed-forward network previously always received implicitly.
    """

    def __init__(self, embed_dim=192, num_heads=6, attn_drop=0.1, grid_size=16):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, attn_drop)
        # Pass the grid size through explicitly so non-16x16 token grids work.
        self.ffn = GDFN(embed_dim, grid_size=grid_size)
        # Small initial scales keep each block close to identity at the start.
        self.gamma1 = nn.Parameter(1e-4 * torch.ones(embed_dim))
        self.gamma2 = nn.Parameter(1e-4 * torch.ones(embed_dim))

    def forward(self, x):
        """Apply attention and feed-forward residual updates to ``x``."""
        x = x + self.gamma1 * self.attn(self.norm1(x))
        x = x + self.gamma2 * self.ffn(self.norm2(x))
        return x
class ImageSRTransformer(nn.Module):
    """Multi-stage transformer for image super-resolution.

    The input is patch-embedded once, then passed through several stages of
    transformer blocks. The token grid keeps the same size throughout; only
    the token width changes, via a linear projection between consecutive
    stages. The outputs of all stages are concatenated channel-wise, laid out
    as a feature map and decoded by a convolutional head that upsamples by
    2 four times (x16 relative to the token grid). The decoded image is
    added to a bicubic upscale of the input.

    Args:
        embed_dims: Token width of each stage.
        num_heads: Attention heads of each stage.
        depths: Number of transformer blocks in each stage.
        patch_size: Patch side length used by the patch embedding.
        img_size: Side length of the square input image.

    Note:
        Blocks are built without an explicit grid size, so their
        feed-forward layers use their own default. ``img_size // patch_size``
        must match that default (16 for the defaults here).
    """

    def __init__(self,
                 embed_dims=(192, 256, 288, 384),
                 num_heads=(6, 8, 6, 8),
                 depths=(3, 3, 3, 3),
                 patch_size=4,
                 img_size=64):
        super().__init__()
        # Stored as a list, as before, regardless of the sequence type given.
        self.embed_dims = list(embed_dims)
        self.grid_size = img_size // patch_size
        self.patch_embed = PatchEmbedding(
            embed_dim=embed_dims[0], patch_size=patch_size, img_size=img_size)

        self.stages = nn.ModuleList([
            nn.ModuleList([
                TransformerBlock(embed_dims[i], num_heads[i])
                for _ in range(depths[i])
            ])
            for i in range(len(embed_dims))
        ])

        # One linear map between each pair of consecutive stage widths.
        self.projections = nn.ModuleList([
            nn.Linear(embed_dims[i], embed_dims[i + 1])
            for i in range(len(embed_dims) - 1)
        ])

        # The head sees every stage's output concatenated, so its input width
        # is the sum of the stage widths (1120 for the defaults).
        head_layers = [
            nn.Conv2d(sum(embed_dims), 192, 3, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Four x2 steps: conv to 4x the channels, then PixelShuffle trades
        # those channels for a 2x larger feature map.
        for _ in range(4):
            head_layers.append(nn.Conv2d(192, 768, 3, padding=1))
            head_layers.append(nn.PixelShuffle(2))
        head_layers.append(nn.Conv2d(192, 3, 3, padding=1))
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Super-resolve ``x`` ([B, 3, H, W]) and return the upscaled image."""
        lr = x
        x = self.patch_embed(x)

        stage_outputs = []
        last_stage = len(self.stages) - 1
        for i, stage in enumerate(self.stages):
            for block in stage:
                x = block(x)
            stage_outputs.append(x)
            # Widen the tokens for the next stage; the final stage has none.
            if i < last_stage:
                x = self.projections[i](x)

        x = torch.cat(stage_outputs, dim=-1)
        B, N, C = x.shape
        x = x.transpose(1, 2).reshape(B, C, self.grid_size, self.grid_size)
        x = self.head(x)

        # Size the bicubic base from the head output instead of assuming
        # 256x256, so the sum is valid for any configuration.
        base = F.interpolate(
            lr, size=tuple(x.shape[-2:]), mode="bicubic", align_corners=False)
        return base + x
def psnr(pred, target):
    """Return the peak signal-to-noise ratio in dB, assuming a peak of 1.0.

    The mean squared error is offset by 1e-8 so identical inputs give a
    finite value. The result is a 0-dim tensor.
    """
    squared_error = (pred - target).pow(2)
    mse = squared_error.mean()
    return 10 * torch.log10(1.0 / (mse + 1e-8))
def ssim(pred, target, window_size=11):
    """Return the mean structural similarity of two image batches as a float.

    Local means, variances and the covariance are estimated with a uniform
    (average-pool) window of side ``window_size``, stride 1 and padding
    ``window_size // 2``. The per-pixel SSIM map is averaged over all
    elements.
    """
    c1 = 0.01 ** 2
    c2 = 0.03 ** 2
    pad = window_size // 2

    def local_mean(t):
        return F.avg_pool2d(t, window_size, 1, pad)

    mean_p = local_mean(pred)
    mean_t = local_mean(target)
    mean_p_sq = mean_p ** 2
    mean_t_sq = mean_t ** 2
    mean_pt = mean_p * mean_t

    # E[x^2] - E[x]^2 for the variances, E[xy] - E[x]E[y] for the covariance.
    var_p = local_mean(pred ** 2) - mean_p_sq
    var_t = local_mean(target ** 2) - mean_t_sq
    cov_pt = local_mean(pred * target) - mean_pt

    numerator = (2 * mean_pt + c1) * (2 * cov_pt + c2)
    denominator = (mean_p_sq + mean_t_sq + c1) * (var_p + var_t + c2)
    return torch.mean(numerator / denominator).item()
# ── Load model ────────────────────────────────────────────────────────────────
# Build the network with its default configuration and restore the trained
# weights on the CPU; run_sr() moves the model to the GPU when one exists.
model = ImageSRTransformer()
# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. Only load checkpoint files you trust; if
# this checkpoint holds nothing but tensors and plain Python values,
# switching to weights_only=True is the safer setting.
checkpoint = torch.load(
    "sr_best_v4_resumed.pt",
    map_location="cpu",
    weights_only=False,
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# The validation PSNR is informational only, so a checkpoint without it
# should not stop the app after the weights have loaded.
_val_psnr = checkpoint.get("val_psnr")
if _val_psnr is None:
    print("Model loaded")
else:
    print(f"Model loaded — val PSNR: {float(_val_psnr):.2f} dB")
# ── Inference ─────────────────────────────────────────────────────────────────
def run_sr(img_pil):
    """Run the x4 super-resolution demo on one uploaded image.

    The image is converted to RGB, enlarged if necessary so both sides are at
    least 256 px, and centre-cropped to 256x256 to serve as ground truth.
    That crop is bicubically downscaled to 64x64 and fed through the model.
    Both the model output and a bilinear upscale of the 64x64 input are
    scored against the ground truth with PSNR and SSIM.

    Args:
        img_pil: PIL image from the Gradio input, or ``None`` when the input
            component is empty (for example after the user clears it).

    Returns:
        A 4-tuple ``(lr_image, sr_image, gt_image, metrics_markdown)``. When
        ``img_pil`` is ``None`` the three images are ``None`` and the text
        is an empty string, which clears the outputs.
    """
    # The `change` event also fires when the image is removed; there is
    # nothing to process then, so clear the outputs instead of crashing.
    if img_pil is None:
        return None, None, None, ""

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    img_pil = img_pil.convert("RGB")
    w, h = img_pil.size

    # Ensure a minimum of 256×256 for the centre crop.
    if w < 256 or h < 256:
        scale_up = max(256 / w, 256 / h)
        # Float rounding can leave w * scale_up just below 256 (e.g. w = 49),
        # which int() would truncate to 255 and make the crop overrun the
        # image; clamp each side to at least 256.
        img_pil = img_pil.resize(
            (max(256, int(w * scale_up)), max(256, int(h * scale_up))),
            Image.BICUBIC)
        w, h = img_pil.size

    # Centre crop 256×256 → ground truth.
    left = (w - 256) // 2
    top = (h - 256) // 2
    gt = img_pil.crop((left, top, left + 256, top + 256))

    # Bicubic downscale → 64×64 LR.
    lr = gt.resize((64, 64), Image.BICUBIC)

    lr_t = TF.to_tensor(lr).unsqueeze(0).to(device)
    gt_t = TF.to_tensor(gt).unsqueeze(0).to(device)

    with torch.no_grad():
        # bfloat16 autocast is enabled only on CUDA; on CPU it is a no-op.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
                            enabled=(device == "cuda")):
            sr_t = model(lr_t)
        sr_t = sr_t.float().clamp(0, 1)
        # Baseline for display and metrics: plain bilinear upscale of the LR.
        lr_display_t = F.interpolate(
            lr_t.float(), size=(256, 256), mode="bilinear", align_corners=False)

    psnr_lr = psnr(lr_display_t, gt_t).item()
    ssim_lr = ssim(lr_display_t, gt_t)
    psnr_sr = psnr(sr_t, gt_t).item()
    ssim_sr = ssim(sr_t, gt_t)

    def to_pil(t):
        return TF.to_pil_image(t.squeeze(0).cpu())

    # The `+` format flag prints the real sign, so a regression shows as
    # "-0.50" rather than "+-0.50".
    metrics = (
        f"**LR baseline** — PSNR: {psnr_lr:.2f} dB | SSIM: {ssim_lr:.4f}\n\n"
        f"**SR output** — PSNR: {psnr_sr:.2f} dB | SSIM: {ssim_sr:.4f}\n\n"
        f"**Improvement** — ΔPSNR: {psnr_sr - psnr_lr:+.2f} dB | "
        f"ΔSSIM: {ssim_sr - ssim_lr:+.4f}"
    )
    return to_pil(lr_display_t), to_pil(sr_t), gt, metrics
# ── Example images ────────────────────────────────────────────────────────────
# 6 examples chosen to showcase V4 strengths:
# Each row holds exactly one path because the examples widget feeds a single
# image input.
EXAMPLES = [
    [f"examples/{stem}.png"]
    for stem in (
        "urban",
        "aerial",
        "architecture",
        "nature",
        "portrait",
        "texture",
    )
]
# ── CSS ───────────────────────────────────────────────────────────────────────
# Stylesheet handed to gr.Blocks(css=...). The class selectors below
# (.header-block, .hint-text, .metric-box, .run-btn, .accordion,
# .image-frame) correspond to the elem_classes / HTML classes used when the
# UI is built.
CSS = """
body, .gradio-container {
background: #0a0c10 !important;
color: #c9d1d9 !important;
font-family: 'Inter', system-ui, sans-serif !important;
}
.header-block {
text-align: center;
padding: 2rem 1rem 1rem;
border-bottom: 1px solid #1e2128;
margin-bottom: 1.5rem;
}
.header-title {
font-size: 1.8rem;
font-weight: 700;
color: #2dd4bf;
letter-spacing: -0.02em;
margin-bottom: 0.25rem;
}
.header-sub {
font-size: 0.85rem;
color: #6e7681;
}
.hint-text {
font-size: 0.78rem;
color: #6e7681;
margin-top: 8px;
line-height: 1.5;
}
.metric-box textarea, .metric-box .prose {
background: #111318 !important;
border: 1px solid #1e2128 !important;
color: #c9d1d9 !important;
border-radius: 8px !important;
font-family: 'JetBrains Mono', monospace !important;
font-size: 0.82rem !important;
padding: 12px !important;
}
.run-btn {
background: #2dd4bf !important;
color: #0a0c10 !important;
border: none !important;
font-weight: 700 !important;
font-size: 0.9rem !important;
border-radius: 8px !important;
padding: 0.6rem 2rem !important;
}
.run-btn:hover { background: #5eead4 !important; }
.accordion {
background: #111318 !important;
border: 1px solid #1e2128 !important;
border-radius: 8px !important;
margin-top: 1.5rem !important;
}
.accordion .label-wrap {
color: #6e7681 !important;
font-size: 0.8rem !important;
letter-spacing: 0.08em !important;
text-transform: uppercase !important;
}
.image-frame img {
border-radius: 8px !important;
border: 1px solid #1e2128 !important;
}
"""
# ── Architecture info (collapsible) ──────────────────────────────────────────
# Markdown shown inside the collapsible "Architecture details" panel.
# Paragraphs are separated by blank lines so Markdown renders them as
# separate blocks, and the skip-connection entry names the bicubic
# interpolation the model actually uses.
ARCH_INFO = """
> "Isotropic constant-resolution hierarchical ViT with inter-stage dense feature aggregation — eliminating spatial bottlenecks while preserving coordinate integrity throughout all processing stages."

### Key architectural decisions

**Isotropic token grid** — constant 16×16 spatial resolution across all 4 transformer stages. Zero patch merging, zero token downsampling. Every token maps to the same 4×4 pixel region from input to output.

**Hierarchical embed dims [192 → 256 → 288 → 384]** — representational capacity scales with feature complexity. Early stages learn local edges and textures (192-dim is sufficient). Deep stages reason about global scene semantics (384-dim is necessary).

**Inter-stage macro concatenation** — outputs from all 4 stages concatenated directly to the reconstruction head: `cat([h1, h2, h3, h4]) → [B, 256, 1120]`. The head receives low-level edge maps (h1) and high-level semantic context (h4) simultaneously.

**GDFN feed-forward** — replaces standard MLPs with Gated Depthwise Feed-Forward Networks. Each token sees its 3×3 spatial neighborhood during the MLP step. Local spatial context injected at every attention layer.

**Bicubic skip connection** — `output = F.interpolate(lr, 256×256, mode="bicubic") + vit_residual`. Model learns residual correction only, not full reconstruction from scratch.

### Results

| Benchmark | Avg PSNR | Avg SSIM |
|-----------|------|------|
| DIV2K validation | 25.20 dB | 0.8298 |
"""
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Page layout: a header, then a row with the upload/controls column on the
# left and the three result images plus metrics on the right, followed by
# the example gallery and a collapsed architecture description.
with gr.Blocks(css=CSS, title="Dense-Iso-ViT SR") as demo:
    gr.HTML("""
    <div class="header-block">
        <div class="header-title">Dense-Iso-ViT</div>
        <div class="header-sub">
            Constant-Resolution Hierarchical Vision Transformer for ×4 Image Super-Resolution
        </div>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            input_img = gr.Image(
                type="pil",
                label="Upload any image",
                elem_classes=["image-frame"],
            )
            run_btn = gr.Button(
                "Run ×4 Super-Resolution",
                variant="primary",
                elem_classes=["run-btn"],
            )
            gr.Markdown(
                "💡 For best results, upload a low-resolution image "
                "(256×256 or smaller). The model upscales it ×4.",
                elem_classes=["hint-text"],
            )

        with gr.Column(scale=3):
            with gr.Row():
                lr_out = gr.Image(
                    label="LR Input (bilinear upscaled for display)",
                    elem_classes=["image-frame"],
                )
                sr_out = gr.Image(
                    label="SR Output — Dense-Iso-ViT",
                    elem_classes=["image-frame"],
                )
                gt_out = gr.Image(
                    label="Ground Truth (original crop)",
                    elem_classes=["image-frame"],
                )
            metrics_out = gr.Markdown(elem_classes=["metric-box"])

    gr.Examples(
        examples=EXAMPLES,
        inputs=[input_img],
        label="Examples — showing V4 strengths",
    )

    with gr.Accordion(
        "Architecture details — Dense-Iso-ViT",
        open=False,
        elem_classes=["accordion"],
    ):
        gr.Markdown(ARCH_INFO)

    # Inference runs both on an explicit button press and whenever the input
    # image changes (upload, example selection, or clearing).
    run_btn.click(
        fn=run_sr,
        inputs=[input_img],
        outputs=[lr_out, sr_out, gt_out, metrics_out],
    )
    input_img.change(
        fn=run_sr,
        inputs=[input_img],
        outputs=[lr_out, sr_out, gt_out, metrics_out],
    )

if __name__ == "__main__":
    demo.launch(show_api=False)