ViT-ISR-Tiny: Vision Transformer for ×4 Image Super-Resolution
A Vision Transformer for ×4 image super-resolution, built entirely from scratch in PyTorch and trained on 76,716 real-world images from the LSDIR benchmark. It is a tiny model with fewer than a million parameters.
Try the Live Demo
What it does
Takes a low-resolution image and reconstructs a ×4 higher-resolution version using global self-attention, which captures long-range spatial relationships across the entire image rather than only the local neighborhoods that CNN convolutions see.
Architecture
The pipeline: patch embed → transformer blocks → reshape → PixelShuffle upsample → RGB output.
| Component | Details |
|---|---|
| Patch size | 2×2 → 1,024 tokens per image |
| Embedding dim | 64 |
| Transformer blocks | 6 (pre-norm, residual) |
| Attention heads | 4 |
| MLP hidden dim | 256 |
| Upsampling | 3× PixelShuffle ×2 (×8 from the token grid, i.e. ×4 relative to the input) |
| Parameters | ~786K · ~3 MB (fp32) |
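The table can be turned into a minimal PyTorch sketch of the pipeline. This is not the repository's `ImageSRTransformer`: the class name `TinySRViT`, the strided-conv patch embedding, the learned positional embedding, GELU activations, and the 3×3 conv before each PixelShuffle are all assumptions. It also assumes a 64×64 input, since 1,024 tokens of 2×2 patches implies a 32×32 token grid for a square image.

```python
import torch
import torch.nn as nn


class TinySRViT(nn.Module):
    """Hypothetical sketch: patch embed -> transformer -> reshape -> PixelShuffle -> RGB."""

    def __init__(self, img_size=64, patch=2, dim=64, depth=6, heads=4, mlp_dim=256):
        super().__init__()
        self.grid = img_size // patch  # 32x32 grid -> 1,024 tokens
        # Non-overlapping 2x2 patches projected to `dim` channels (assumed strided conv)
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
        # Learned positional embedding, one vector per token (assumption)
        self.pos = nn.Parameter(torch.zeros(1, self.grid * self.grid, dim))
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=mlp_dim, dropout=0.0,
            activation="gelu", batch_first=True, norm_first=True,  # pre-norm blocks
        )
        self.blocks = nn.TransformerEncoder(layer, num_layers=depth, enable_nested_tensor=False)
        self.norm = nn.LayerNorm(dim)
        # 3 stages of x2 = x8 from the token grid = patch size (2) * SR scale (4)
        stages = []
        for _ in range(3):
            # Conv expands channels 4x, PixelShuffle(2) folds them into 2x height and width
            stages += [nn.Conv2d(dim, dim * 4, 3, padding=1), nn.PixelShuffle(2), nn.GELU()]
        self.upsample = nn.Sequential(*stages)
        self.to_rgb = nn.Conv2d(dim, 3, 3, padding=1)

    def forward(self, x):
        b = x.shape[0]
        t = self.patch_embed(x)                      # (B, 64, 32, 32)
        t = t.flatten(2).transpose(1, 2) + self.pos  # (B, 1024, 64) token sequence
        t = self.norm(self.blocks(t))                # global self-attention over all tokens
        t = t.transpose(1, 2).reshape(b, -1, self.grid, self.grid)  # back to (B, 64, 32, 32)
        return self.to_rgb(self.upsample(t))         # (B, 3, 256, 256)


model = TinySRViT()
print(sum(p.numel() for p in model.parameters()))  # ~0.81M with these assumed layer choices
with torch.no_grad():
    print(model(torch.rand(1, 3, 64, 64)).shape)  # torch.Size([1, 3, 256, 256])
```

With these choices the sketch lands in the same sub-million range as the table's ~786K, but not on that exact figure, since the real layer choices are not documented here. Most of the budget sits in the upsampling convs rather than the transformer blocks.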
Why pre-norm: Applying LayerNorm before each sub-layer (attention and MLP) leaves the residual path as an unnormalized identity connection, which stabilizes training in deep stacks.
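A hand-written pre-norm block makes the "clean residual path" concrete: LayerNorm is applied only to each sub-layer's input, so `x` flows through the additions untouched. The class name `PreNormBlock` and the use of `nn.MultiheadAttention` with a GELU MLP are assumptions for illustration, not the repository's code.

```python
import torch
import torch.nn as nn


class PreNormBlock(nn.Module):
    """Hypothetical pre-norm block: x + Attn(LN(x)), then x + MLP(LN(x))."""

    def __init__(self, dim=64, heads=4, mlp_dim=256):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, dim))

    def forward(self, x):
        # Only the sub-layer input is normalized; the residual stream x is never rescaled.
        # (A post-norm block would instead compute x = norm1(x + attn(x)).)
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x


block = PreNormBlock()
tokens = torch.randn(1, 1024, 64)  # 1,024 tokens of width 64
print(block(tokens).shape)  # torch.Size([1, 1024, 64])
```

One consequence of this layout: if both sub-layers output zero, the block is exactly the identity, so a deep stack starts close to a plain residual path.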
Why PixelShuffle: Learned upsampling. The model decides what to put in the new pixels by redistributing channel information into space, rather than stretching the image with interpolation.
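`nn.PixelShuffle(r)` rearranges a tensor of shape (N, C·r², H, W) into (N, C, H·r, W·r), so each group of r² channels becomes an r×r block of pixels. The toy tensor below shows the mapping; the second part shows one ×2 stage, where the preceding 64→256 conv is an assumption about how the channels are produced.

```python
import torch
import torch.nn as nn

# Toy case: 4 channels at 1x1 become 1 channel at 2x2.
# Channel index i*2 + j lands at spatial position (i, j).
x = torch.arange(4.0).reshape(1, 4, 1, 1)
y = nn.PixelShuffle(2)(x)
print(y.squeeze().tolist())  # [[0.0, 1.0], [2.0, 3.0]]

# One learned x2 stage (assumed design): the conv predicts 4x the channels,
# then PixelShuffle folds them into twice the height and width.
stage = nn.Sequential(nn.Conv2d(64, 256, 3, padding=1), nn.PixelShuffle(2))
feat = torch.rand(1, 64, 32, 32)
print(stage(feat).shape)  # torch.Size([1, 64, 64, 64])
```

Because the conv weights decide what goes into each of the four channels, the new pixel values are learned rather than interpolated.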
Why patch size 2: Finer tokens preserve more spatial detail for reconstruction. Larger patches are cheaper but lose the high-frequency information SR depends on.
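A back-of-the-envelope check of that trade-off, assuming a 64×64 input (what 1,024 tokens of 2×2 patches implies for a square image). `token_stats` is a hypothetical helper: it counts tokens and the entries in one head's attention matrix, which grows with the square of the token count.

```python
def token_stats(height, width, patch):
    """Return (tokens, attention scores per head) for non-overlapping patch x patch tiles."""
    if height % patch or width % patch:
        raise ValueError("image must tile evenly into patches")
    tokens = (height // patch) * (width // patch)
    return tokens, tokens * tokens  # every token attends to every other token


for p in (2, 4, 8):
    tokens, scores = token_stats(64, 64, p)
    # p * p * 3 raw RGB values are projected into the same 64-dim embedding
    print(f"patch {p}: {tokens} tokens, {scores:,} scores per head, {p * p * 3} values per token")
```

For patch size 2 this gives 1,024 tokens and 1,048,576 attention scores per head; patch size 4 cuts that to 256 tokens and 65,536 scores, but each 64-dim token then has to summarize 48 pixel values instead of 12.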
Training
| Setting | Value |
|---|---|
| Dataset | LSDIR (76,716 train / 4,263 test) |
| Optimizer | AdamW (lr=2e-4) |
| Loss | L1 |
| Mixed precision | fp16 AMP |
| Batch size | 16 |
| Degradation | Bicubic ×4 downscaling |
| Hardware | RTX 4060 Laptop (8 GB) |
Test PSNR: 23.30 dB, evaluated on the held-out test split (never seen during training or checkpoint selection).
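For reference, PSNR is 10·log10(MAX² / MSE). The sketch below assumes RGB images scaled to [0, 1] and averages the error over all pixels and channels. The README does not say whether the reported number uses RGB or the Y channel, or whether borders are cropped (both are common SR conventions), so `psnr` here is a hypothetical helper, not the evaluation script.

```python
import torch


def psnr(sr, hr, max_val=1.0):
    """PSNR in dB: 10 * log10(max_val^2 / MSE), for images in [0, max_val]."""
    mse = torch.mean((sr.clamp(0, max_val) - hr) ** 2)
    return (10 * torch.log10(max_val ** 2 / mse)).item()


hr = torch.full((1, 3, 8, 8), 0.1)
sr = torch.zeros_like(hr)
print(psnr(sr, hr))  # MSE = 0.01, so about 20 dB
```

Each +1 dB corresponds to roughly a 21% drop in MSE, which is why small PSNR gaps between SR models are still meaningful.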
Usage
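```python
import torch
from huggingface_hub import hf_hub_download
from model_architecture import ImageSRTransformer  # model definition from model_architecture.py

# Download the best checkpoint (weights + optimizer state) and load it on CPU
weights_path = hf_hub_download("Sathya77/ViT-ISR-Tiny-LSDIR", "sr_best.pt")
checkpoint = torch.load(weights_path, map_location="cpu")

model = ImageSRTransformer()
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # switch to inference mode
```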
Files
| File | Description |
|---|---|
| `sr_best.pt` | Best checkpoint (weights + optimizer state) |
| `model_weights.pt` | Weights only, for inference |
| `config.json` | Architecture config |
| `training_config.json` | Training config and results |