
title: Dense-Iso-ViT SR
emoji: 🔭
colorFrom: blue
colorTo: green
sdk: gradio
sdk_version: 5.16.0
python_version: '3.10'
app_file: app.py
pinned: true
license: mit
tags:
  - image-super-resolution
  - vision-transformer
  - super-resolution
  - isotropic-vit
  - dense-connections
  - GDFN
  - pytorch
  - image-restoration
  - computer-vision
  - encoder-only-vision-transformer-super-resolution
  - constant-spatial-resolution
  - DenseNet-style-feature-propagation
  - isotropic-token-grid
  - gated-depthwise-feed-forward-network

Dense-Iso-ViT

Constant-Resolution Hierarchical Vision Transformer for ×4 Image Super-Resolution


What is Dense-Iso-ViT?

Dense-Iso-ViT is a pure Vision Transformer built from scratch for ×4 image super-resolution. The core design idea is to keep the spatial token grid constant throughout all transformer stages (no patch merging, no spatial compression) and to connect all stage outputs directly to the reconstruction head using DenseNet-style dense concatenation.

The name captures the two central ideas:

  • Dense: dense inter-stage feature aggregation (the DenseNet principle applied to transformers)
  • Iso: isotropic spatial resolution, a constant 16×16 token grid across all 4 stages

Architecture

How it works

Input [B, 3, 64, 64]
  → PatchEmbedding (patch=4) → [B, 256, 192]   (16×16 grid, fixed throughout)

  → Stage 1: 3× TransformerBlock(embed=192, heads=6, GDFN) → h1 [B, 256, 192]
  → Linear(192→256)

  → Stage 2: 3× TransformerBlock(embed=256, heads=8, GDFN) → h2 [B, 256, 256]
  → Linear(256→288)

  → Stage 3: 3× TransformerBlock(embed=288, heads=6, GDFN) → h3 [B, 256, 288]
  → Linear(288→384)

  → Stage 4: 3× TransformerBlock(embed=384, heads=8, GDFN) → h4 [B, 256, 384]

  → Dense concat: cat([h1, h2, h3, h4]) → [B, 256, 1120] → [B, 1120, 16, 16]
  → SR Head: fusion conv → 4× PixelShuffle → [B, 3, 256, 256]
  → + F.interpolate(lr_img, 256×256) bilinear skip
  → Output [B, 3, 256, 256]
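
The shape flow above can be checked with plain arithmetic. The following is a minimal sketch in pure Python; the function name `trace_shapes` and its parameters are hypothetical and only mirror the numbers in the diagram, they are not taken from the repository.

```python
def trace_shapes(batch=1, lr_size=64, patch=4, scale=4,
                 embed_dims=(192, 256, 288, 384)):
    """Return (name, shape) pairs for the main tensors in the diagram."""
    grid = lr_size // patch              # 64 // 4 = 16 tokens per side
    tokens = grid * grid                 # 16 * 16 = 256 tokens, fixed in every stage
    shapes = [("input", (batch, 3, lr_size, lr_size))]
    for i, dim in enumerate(embed_dims, start=1):
        shapes.append((f"h{i}", (batch, tokens, dim)))
    fused = sum(embed_dims)              # 192 + 256 + 288 + 384 = 1120
    shapes.append(("dense_concat", (batch, tokens, fused)))
    shapes.append(("feature_map", (batch, fused, grid, grid)))
    hr = lr_size * scale                 # 64 * 4 = 256
    shapes.append(("output", (batch, 3, hr, hr)))
    return shapes


for name, shape in trace_shapes():
    print(f"{name:13s} {shape}")
```

Note that the head has to upsample the 16×16 feature map by a factor of 16 (patch size 4 times SR scale 4) to reach 256×256.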

Design decisions and reasoning

Isotropic token grid: constant 16×16 spatial resolution

All 4 transformer stages operate on the same 256 tokens (16×16 grid). There is no patch merging and no token downsampling at any point. Every token maps to the same 4×4 pixel region from input through to reconstruction, so spatial coordinates are preserved exactly throughout the network.
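
To make the fixed token-to-pixel correspondence concrete, here is a small sketch that maps a token index to the low-resolution patch it covers and to the matching region of the ×4 output. It assumes row-major token ordering, which the README does not state; `token_regions` is a hypothetical helper.

```python
def token_regions(token_idx, grid=16, patch=4, scale=4):
    """Map a token index to its pixel boxes as (top, left, bottom, right).

    Returns the box in the 64x64 input and the box in the 256x256 output.
    Assumes tokens are laid out row by row (an assumption, not from the README).
    """
    row, col = divmod(token_idx, grid)
    lr_box = (row * patch, col * patch, (row + 1) * patch, (col + 1) * patch)
    hr_box = tuple(v * scale for v in lr_box)
    return lr_box, hr_box


print(token_regions(0))    # first token: top-left 4x4 input patch
print(token_regions(17))   # second row, second column
print(token_regions(255))  # last token: bottom-right patch
```

Because no stage changes the grid, this mapping is the same for h1 through h4.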

Hierarchical embed dims [192, 256, 288, 384]

Representational capacity increases with depth. Early stages learn local edges and textures, for which 192 dimensions is sufficient. Deep stages reason about global scene structure and semantics, and 384 dimensions gives more capacity where the task is genuinely harder. Linear projections between stages change the channel dimension without affecting spatial resolution.
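
The inter-stage projections can be sketched as follows. This is an illustrative PyTorch snippet, not the repository's code: the transformer blocks are omitted, and the tensor passed through is a random stand-in for the patch-embedded input.

```python
import torch
import torch.nn as nn

embed_dims = [192, 256, 288, 384]

# One Linear per stage boundary: 192->256, 256->288, 288->384.
projections = nn.ModuleList(
    [nn.Linear(d_in, d_out) for d_in, d_out in zip(embed_dims[:-1], embed_dims[1:])]
)

x = torch.randn(2, 256, embed_dims[0])   # [B, tokens, channels]
stage_outputs = [x]                      # stand-in for h1 (blocks omitted here)
with torch.no_grad():
    for proj in projections:
        x = proj(x)                      # acts on the last dim only; 256 tokens unchanged
        stage_outputs.append(x)

for i, h in enumerate(stage_outputs, start=1):
    print(f"h{i}: {tuple(h.shape)}")
```

`nn.Linear` only transforms the last dimension, which is why the token count stays at 256 while the channel width grows.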

Dense inter-stage concatenation

Outputs from all 4 stages are concatenated before the reconstruction head: cat([h1, h2, h3, h4]) → [B, 256, 1120]. The head receives low-level edge information (stage 1) and high-level semantic context (stage 4) simultaneously, without early-stage features being filtered through subsequent blocks. Inspired by DenseNet's feature reuse principle, applied across transformer stages.
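
A minimal PyTorch sketch of the concatenation and the reshape into a feature map follows. The head shown here is hypothetical: it reads "4× PixelShuffle" as four ×2 steps (16 → 32 → 64 → 128 → 256), and the internal width of 64 channels is an assumption, since the README does not give the head's layer sizes.

```python
import torch
import torch.nn as nn

batch, grid = 2, 16
dims = [192, 256, 288, 384]

# Stand-ins for h1..h4; in the real model these come from the four stages.
stage_outputs = [torch.randn(batch, grid * grid, d) for d in dims]

# Dense concat along the channel axis: [B, 256, 1120].
fused = torch.cat(stage_outputs, dim=-1)

# Tokens -> feature map: [B, 1120, 16, 16]. Assumes row-major token order.
fmap = fused.transpose(1, 2).reshape(batch, sum(dims), grid, grid)

# Hypothetical head: fusion conv, four (conv + PixelShuffle(2)) steps, RGB conv.
width = 64  # assumed width, not taken from the README
layers = [nn.Conv2d(sum(dims), width, kernel_size=3, padding=1)]
for _ in range(4):
    layers += [nn.Conv2d(width, 4 * width, kernel_size=3, padding=1),
               nn.PixelShuffle(2)]      # 4*width channels -> width, spatial x2
layers.append(nn.Conv2d(width, 3, kernel_size=3, padding=1))
head = nn.Sequential(*layers)

with torch.no_grad():
    residual = head(fmap)               # [B, 3, 256, 256]

print(tuple(fused.shape), tuple(fmap.shape), tuple(residual.shape))
```

The reshape keeps each token at its grid position, so `fmap[b, :, r, c]` holds the concatenated features of token `r * 16 + c`.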

Gated Depthwise Feed-Forward Network (GDFN)

Standard transformer MLPs process each token independently, with no spatial awareness in the feed-forward step. GDFN replaces this with gated depthwise 3×3 convolutions, giving each token access to its 8 spatial neighbors during the feed-forward computation. Local spatial context is injected in every transformer block, at almost no extra parameter cost.

# GDFN: replaces standard Linear → SiLU → Linear
g_proj   = Linear(embed_dim, 2 * hidden)
path_1   = Conv2d(hidden, hidden, 3, padding=1, groups=hidden)
path_2   = Conv2d(hidden, hidden, 3, padding=1, groups=hidden)
out_proj = Linear(hidden, embed_dim)

# gate: path_1 × GELU(path_2)
x = out_proj(path_1(x) * F.gelu(path_2(x)))
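
The pseudocode above leaves out how `g_proj` feeds the two paths and how tokens are reshaped for the convolutions. One plausible runnable version is sketched below. It assumes the `2 * hidden` output of `g_proj` is split into two halves, one per path, and that tokens are reshaped to a 16×16 map in row-major order; both are assumptions, as is the `hidden` size used in the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GDFN(nn.Module):
    """Gated depthwise feed-forward sketch (layer names follow the README)."""

    def __init__(self, embed_dim, hidden, grid=16):
        super().__init__()
        self.grid = grid
        self.g_proj = nn.Linear(embed_dim, 2 * hidden)
        # groups=hidden makes each 3x3 conv depthwise: one filter per channel.
        self.path_1 = nn.Conv2d(hidden, hidden, 3, padding=1, groups=hidden)
        self.path_2 = nn.Conv2d(hidden, hidden, 3, padding=1, groups=hidden)
        self.out_proj = nn.Linear(hidden, embed_dim)

    def _to_map(self, t):
        # [B, N, C] -> [B, C, grid, grid]
        b, _, c = t.shape
        return t.transpose(1, 2).reshape(b, c, self.grid, self.grid)

    def forward(self, x):                               # x: [B, N, embed_dim]
        x1, x2 = self.g_proj(x).chunk(2, dim=-1)        # assumed split into two branches
        gated = self.path_1(self._to_map(x1)) * F.gelu(self.path_2(self._to_map(x2)))
        gated = gated.flatten(2).transpose(1, 2)        # back to [B, N, hidden]
        return self.out_proj(gated)


ffn = GDFN(embed_dim=192, hidden=384)                   # hidden=384 is an assumed value
with torch.no_grad():
    y = ffn(torch.randn(2, 256, 192))
print(tuple(y.shape))
```

Each depthwise path adds only `hidden * 9` weights plus `hidden` biases, which is small next to the two linear projections.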

Bilinear skip connection

output = F.interpolate(lr, 256×256) + vit_residual

The model learns a residual correction on top of a bilinear upscale of the input rather than a full reconstruction from scratch. This leads to faster convergence and more stable training.
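
The skip connection can be written out as runnable PyTorch. In this sketch `vit_residual` is a zero tensor standing in for the network's output, and `align_corners=False` is an assumed setting that the README does not specify.

```python
import torch
import torch.nn.functional as F

lr = torch.rand(1, 3, 64, 64)                 # low-resolution input
vit_residual = torch.zeros(1, 3, 256, 256)    # stand-in for the transformer branch output

# Bilinear upscale of the input to the target size.
base = F.interpolate(lr, size=(256, 256), mode="bilinear", align_corners=False)
output = base + vit_residual

print(tuple(output.shape))
```

With a zero residual the output is exactly the bilinear upscale, so the network only has to learn the detail that interpolation misses.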

Summary

| Component | Choice |
|---|---|
| Attention | Global self-attention, O(N²), N=256 |
| Spatial resolution | Constant 16×16 throughout |
| Embed dims | [192, 256, 288, 384] |
| Heads | [6, 8, 6, 8] |
| Depths | [3, 3, 3, 3], 12 blocks total |
| Feed-forward | GDFN (gated depthwise conv) |
| Feature routing | Dense concat of all 4 stages |
| Skip | Bilinear upscale + residual |
| Parameters | 23.8M |
| Model size | 90.99 MB (fp32) |
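
A few of the numbers in this summary can be cross-checked with simple arithmetic. The sketch below assumes that "MB" means 2^20 bytes and that the checkpoint holds only fp32 weights (4 bytes each); neither is stated in the README.

```python
embed_dims = [192, 256, 288, 384]
heads = [6, 8, 6, 8]
n_tokens = 16 * 16

# Per-head channel width in each stage.
head_dims = [d // h for d, h in zip(embed_dims, heads)]
assert all(d % h == 0 for d, h in zip(embed_dims, heads))

# Global self-attention scores per head per image: N x N.
attn_entries = n_tokens ** 2

# Parameter count implied by the fp32 file size (assumes MB = 2**20 bytes).
size_mb = 90.99
params_from_size = size_mb * 2**20 / 4

print(head_dims)                        # [32, 32, 48, 48]
print(attn_entries)                     # 65536
print(f"{params_from_size / 1e6:.2f}M")
```

Under those assumptions the file size implies roughly 23.85M parameters, in line with the 23.8M listed above.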

Results

| Benchmark | PSNR | SSIM |
|---|---|---|
| DIV2K validation | 25.20 dB | 0.8298 |
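
For readers unfamiliar with the metric, PSNR is 10·log10(MAX² / MSE). The sketch below is a plain-Python illustration of that formula, not the evaluation script used for the table; the README does not say whether PSNR was computed on RGB or on the luminance channel, or whether borders were cropped.

```python
import math


def psnr(pred, target, max_val=1.0):
    """PSNR in dB between two equal-length sequences of pixel values."""
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)
    if mse == 0:
        return float("inf")
    return 10.0 * math.log10(max_val ** 2 / mse)


def mse_for_psnr(db, max_val=1.0):
    """Mean squared error that corresponds to a given PSNR."""
    return max_val ** 2 / 10 ** (db / 10)


# A constant error of 0.1 on a [0, 1] scale gives an MSE of 0.01.
print(round(psnr([0.0, 0.0], [0.1, 0.1]), 6))
# MSE implied by the reported 25.20 dB, assuming pixel values in [0, 1].
print(mse_for_psnr(25.20))
```

On a [0, 1] pixel scale, 25.20 dB corresponds to a mean squared error of about 0.003.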

Keywords

encoder-only vision transformer super-resolution · DenseNet-style skip connections and feature propagation · constant spatial resolution across 4 hierarchical stages · isotropic token grid no patch merging · gated depthwise feed-forward network GDFN · ×4 sub-pixel convolution upscale shuffling · dense inter-stage feature aggregation · pure ViT image restoration · Dense-Iso-ViT


License

MIT: free to use, modify, and distribute with attribution.