---
license: mit
language:
- en
tags:
- medical-image-segmentation
- retinal-oct
- transunet
- efficientnet
- pytorch
- onnx
- ophthalmology
- attention-mechanism
- uncertainty-quantification
- mc-dropout
- miccai
- domain-adaptation
datasets:
- DUKE-DME
- AROI
- UMN-AMD
- UMN-DME
metrics:
- dice
pipeline_tag: image-segmentation
---

# Attention-Guided TransUNet for Multi-Class Retinal Fluid Segmentation

**Author:** Animesh Kumar | Newcastle University MSc Advanced Computer Science 2025–26

**Target Venue:** OMIA 2026 Workshop at MICCAI + medRxiv preprint

**Framework:** PyTorch | **Compute:** Google Colab H100

**Live Demo:** [HuggingFace Space](https://huggingface.co/spaces/animeshakr/oct-fluid-segmentation) | **Code:** [GitHub](https://github.com/Animesh-Kr/oct-fluid-segmentation) | **DOI:** [Zenodo](https://doi.org/10.5281/zenodo.19808008)

---

## Clinical Motivation

Retinal fluid is a primary biomarker in major vision-threatening diseases: diabetic macular oedema (DME), neovascular (wet) age-related macular degeneration (AMD) and central serous chorioretinopathy. The model segments three fluid types:

| Fluid Type | Full Name | Clinical Significance |
|------------|-----------|-----------------------|
| **IRF** | Intraretinal Fluid | Sign of active disease in DME and wet AMD; typically prompts anti-VEGF treatment soon after detection. |
| **SRF** | Subretinal Fluid | Associated with neovascular AMD and central serous chorioretinopathy. Volume informs treatment frequency. |
| **PED** | Pigment Epithelial Detachment | Elevation of the retinal pigment epithelium; a key marker of AMD progression. |

Manual segmentation by expert graders takes **20–40 minutes per volume**, with up to **15% inter-grader variability** on small fluid pockets. This model produces uncertainty-aware segmentation masks intended to reduce clinician workload and make treatment monitoring more consistent.

---

## Architecture

### Two Models: Dual Ensemble

| Component | V2S (Small) | V2L (Large) |
|-----------|-------------|-------------|
| Encoder | EfficientNetV2S | EfficientNetV2L |
| Encoder channels | s1=24, s2=48, s3=64, s4=160, bot=256 | s1=32, s2=64, s3=96, s4=192, bot=640 |
| Transformer d_model | 256 | 512 |
| Attention heads | 16 | 16 |
| Transformer layers | 2 | 2 |
| Total parameters | ~22M | ~127M |
| Phase B val Dice | 0.7443 (seed=42) | 0.7913 (seed=123) |

### Novel Contributions

1. **UCUS: Uncertainty-Weighted Clinical Urgency Score.** Combines a volume score, a foveal multiplier, boundary uncertainty and an uncertainty discount into a single triage band: Monitor / Review / Urgent.
2. **Dual Uncertainty Estimation.** Combines MC Dropout variance (20 forward passes) with inter-model disagreement between the V2S and V2L predictions.
3. **Source-Adaptive BatchNorm (SA-BN).** Keeps separate batch-norm statistics per scanner source (DUKE, AROI, UMN-AMD, UMN-DME, OPTIMA), enabling cross-scanner domain adaptation without retraining.
4. **Multi-Source Four-Dataset Evaluation.** Evaluates simultaneously across 4 independent acquisition sources, with a per-source Dice breakdown.

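
The SA-BN and dual-uncertainty ideas above can be illustrated with a minimal PyTorch sketch. Everything in it is hypothetical: `SourceAdaptiveBN`, `ToySegNet`, `dual_uncertainty`, the equal weighting `alpha=0.5` and the use of total-variation distance as the disagreement measure are illustrative assumptions, not the repository's implementation, and the toy network only stands in for V2S/V2L.

```python
import torch
import torch.nn as nn


class SourceAdaptiveBN(nn.Module):
    """Hypothetical SA-BN layer: one BatchNorm2d per scanner source, selected by source_id."""

    def __init__(self, num_features: int, num_sources: int):
        super().__init__()
        self.bns = nn.ModuleList([nn.BatchNorm2d(num_features) for _ in range(num_sources)])

    def forward(self, x: torch.Tensor, source_id: int) -> torch.Tensor:
        # Normalise with the running statistics (and affine parameters) of this source only
        return self.bns[source_id](x)


class ToySegNet(nn.Module):
    """Toy stand-in for V2S/V2L: conv -> SA-BN -> ReLU -> dropout -> 4-class head."""

    def __init__(self, num_sources: int = 5, num_classes: int = 4, p: float = 0.3):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.sabn = SourceAdaptiveBN(8, num_sources)
        self.drop = nn.Dropout2d(p)
        self.head = nn.Conv2d(8, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor, source_id: int) -> torch.Tensor:
        h = torch.relu(self.sabn(self.conv(x), source_id))
        return self.head(self.drop(h))


def enable_mc_dropout(model: nn.Module) -> None:
    """Eval mode everywhere (BN uses running stats) except dropout, which stays stochastic."""
    model.eval()
    for m in model.modules():
        if isinstance(m, (nn.Dropout, nn.Dropout2d)):
            m.train()


@torch.no_grad()
def dual_uncertainty(model_s, model_l, x, source_id, n_passes=20, alpha=0.5):
    """Return ensemble probabilities (B, C, H, W) and a combined uncertainty map (B, H, W)."""
    samples = []
    for model in (model_s, model_l):
        enable_mc_dropout(model)
        # n_passes stochastic forward passes -> (T, B, C, H, W)
        samples.append(torch.stack(
            [torch.softmax(model(x, source_id), dim=1) for _ in range(n_passes)]))
    # MC Dropout term: per-pixel std across passes, averaged over classes and both models
    mc_std = 0.5 * (samples[0].std(dim=0) + samples[1].std(dim=0)).mean(dim=1)
    p_s, p_l = samples[0].mean(dim=0), samples[1].mean(dim=0)
    # Inter-model term: total-variation distance between mean predictions, in [0, 1]
    disagreement = 0.5 * (p_s - p_l).abs().sum(dim=1)
    return 0.5 * (p_s + p_l), alpha * mc_std + (1 - alpha) * disagreement


if __name__ == "__main__":
    torch.manual_seed(0)
    v2s, v2l = ToySegNet(), ToySegNet()
    scan = torch.randn(1, 1, 64, 64)        # small stand-in for a 1x512x512 B-scan
    probs, unc = dual_uncertainty(v2s, v2l, scan, source_id=0)
    mask = probs.argmax(dim=1)              # (1, 64, 64), labels 0-3
    print(mask.shape, float(unc.mean()))
```

Keeping the model in eval mode while re-enabling only the dropout layers is a common way to run MC Dropout: the stochastic passes sample dropout masks, but batch-norm statistics (here, the per-source ones) are not updated.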
### Architecture Data Flow

```
Input OCT B-scan (1×512×512)
        │
        ▼
EfficientNetV2L Encoder (5 stages, ImageNet pretrained)
        │  skip connections s1-s4
        ▼
Transformer Bottleneck (2× MHA, d_model=512, 16 heads, learnable pos encoding)
        │
        ▼
Attention-Gate Decoder (4 levels, SA-BatchNorm per decoder block)
        │
        ▼
MC Dropout (p=0.3) + Output Head (4 classes)
        │
        ├── Segmentation Mask (BG/IRF/SRF/PED)
        ├── Uncertainty Heatmap (pixel-wise std)
        └── UCUS Score (clinical triage)
```

---

## Datasets

| Dataset | Size | Annotated Classes | Disease | Scanner |
|---------|------|-------------------|---------|---------|
| DUKE DME | 10 subjects, 110 B-scans | IRF only | DME | Spectralis |
| AROI | 24 patients | IRF + SRF + PED | AMD | Zeiss Cirrus |
| UMN AMD | 24 subjects | SRF (binary) | AMD | Spectralis |
| UMN DME | 29 subjects | IRF (binary) | DME | Spectralis |

**Unified label space:** 0 = Background, 1 = IRF, 2 = SRF, 3 = PED

**Split after preprocessing:**

- Train: 4983 fluid-only slices
- Validation: 552 slices
- Test: 503 slices

---

## Training Protocol

### Phase A: Decoder Only (5 epochs)

- Encoder frozen at ImageNet weights
- LR = 1e-3, Adam, batch size 8
- Target: val_dice > 0.50

### Phase B: Fine-tuning (25 epochs)

- Encoder blocks 3-5 unfrozen
- LR = 1e-4, WarmupCosineDecay (5-epoch warmup)
- Batch size 4, early stopping with patience = 7
- Loss: Dice + 0.5 × CrossEntropy

### Seeds Trained

- V2S: 42, 123, 2024
- V2L: 42, 123, 2024

---

## Results

### Multi-Seed Validation Dice (mean ± std across seeds 42/123/2024)

| Model | IRF | SRF | PED | Mean Fluid |
|-------|-----|-----|-----|------------|
| V2S | 0.8658 ± 0.0067 | 0.8272 ± 0.0046 | 0.5184 ± 0.0093 | 0.7371 ± 0.0052 |
| **V2L** | **0.9158 ± 0.0034** | **0.8560 ± 0.0034** | **0.5811 ± 0.0175** | **0.7843 ± 0.0058** |

### Test Set Results (503 slices, 4 sources)

| Metric | Mean | Std |
|--------|------|-----|
| dice_IRF | 0.2043 | ±0.3482 |
| dice_SRF | 0.1712 | ±0.2359 |
| dice_PED | 0.4463 | ±0.4611 |
| dice_mean_fluid | 0.2739 | ±0.2161 |

> **Note:** The low test Dice is attributed to domain shift across 4 independent sources, which makes the multi-source test set a hard benchmark. V2L alone achieves 0.4511 on the ablation.

### Per-Source Breakdown

| Source | IRF | SRF | PED | Mean Fluid |
|--------|-----|-----|-----|------------|
| AROI | 0.054 | 0.299 | 0.144 | 0.166 |
| DUKE | 0.071 | 0.000 | **0.902** | 0.324 |
| UMN | 0.381 | 0.176 | 0.409 | 0.322 |

> DUKE SRF = 0.000 is expected: the DUKE dataset contains only IRF annotations.
> DUKE has no PED annotations either, so the PED score of 0.902 does not measure detection of annotated PED; it most likely reflects how slices with an empty PED ground truth are scored.

### Clinical Safety Metrics

| Metric | Value | Significance |
|--------|-------|--------------|
| Inter-grader human ceiling | 0.9030 | Dice agreement between human graders; a practical upper bound for automated systems |
| Model match threshold (95% of ceiling) | 0.8579 | Target for clinical deployment |
| Uncertainty ratio | **1.34×** | p=3.77e-05 ✅ |
| SRF volume correlation | **r=0.778** | p=6.33e-04 ✅ |
| PED volume correlation | **r=0.841** | p=8.64e-05 ✅ |
| Total fluid correlation | r=0.562 | p=2.93e-02 ✅ |

An uncertainty ratio of 1.34× means the model's mean uncertainty is **1.34× higher at pixels where human experts disagree** than elsewhere. The difference is statistically significant (p = 3.77e-05), suggesting that model uncertainty tracks genuine clinical ambiguity.

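
The metrics above can be made concrete with a short NumPy sketch. It assumes Dice is computed per class on integer masks in the unified label space, and that the uncertainty ratio divides mean uncertainty at grader-disagreement pixels by mean uncertainty at grader-agreement pixels. The function names and the empty-mask convention (NaN when a class is absent from both masks) are assumptions; the card does not say which convention produced its numbers.

```python
import numpy as np

# Unified label space from this card: 0=BG, 1=IRF, 2=SRF, 3=PED
FLUID_CLASSES = {1: "IRF", 2: "SRF", 3: "PED"}


def dice_per_class(pred: np.ndarray, gt: np.ndarray) -> dict:
    """Per-class Dice on integer label masks.

    Assumed convention: a class absent from both masks scores NaN
    (so it can be excluded from averages) rather than 1.0 or 0.0.
    """
    scores = {}
    for label, name in FLUID_CLASSES.items():
        p, g = pred == label, gt == label
        denom = int(p.sum() + g.sum())
        if denom == 0:
            scores[name] = float("nan")
        else:
            scores[name] = float(2.0 * np.logical_and(p, g).sum() / denom)
    return scores


def uncertainty_ratio(uncertainty: np.ndarray, grader_a: np.ndarray,
                      grader_b: np.ndarray) -> float:
    """Mean uncertainty where two graders' labels differ / mean where they agree."""
    disagree = grader_a != grader_b
    if disagree.all() or not disagree.any():
        raise ValueError("need pixels with both grader agreement and disagreement")
    return float(uncertainty[disagree].mean() / uncertainty[~disagree].mean())


# The model-match threshold is 95% of the inter-grader ceiling
MATCH_THRESHOLD = 0.95 * 0.9030  # 0.85785, reported as 0.8579
```

Whether empty masks score NaN, 1.0 or 0.0 can move per-source averages considerably, which matters for sources such as DUKE that have no SRF or PED annotations.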
### Ablation Study

| Variant | Mean Dice | Std |
|---------|-----------|-----|
| V2S only (no MC, no TTA) | 0.338 | ±0.332 |
| V2S + MC Dropout | 0.141 | ±0.144 |
| V2S + TTA | 0.121 | ±0.115 |
| **V2L only (no MC, no TTA)** | **0.449** | ±0.304 |
| V2S + V2L ensemble | 0.415 | ±0.318 |
| V2S + V2L + MC Dropout | 0.279 | ±0.219 |
| Full (V2S+V2L+MC+TTA) | 0.293 | ±0.218 |

### INT8 Quantisation (Phase 5B)

| Model | FP32 | INT8 | Compression | Method |
|-------|------|------|-------------|--------|
| V2L | 510 MB | 132 MB | **3.9×** | Per-tensor symmetric int8 |
| V2S | 91 MB | 24 MB | **3.8×** | Per-tensor symmetric int8 |

---

## Files in This Repository

| File | Size | Description |
|------|------|-------------|
| `ckpt_phaseB_V2L_s123.pth` | 1526 MB | Best V2L checkpoint (val_dice=0.7913, epoch=25) |
| `ckpt_phaseB_V2L_s2024.pth` | 1526 MB | Second V2L checkpoint (val_dice=0.7841, epoch=24) |
| `ckpt_phaseB_V2S_s42.pth` | 271 MB | Best V2S checkpoint (val_dice=0.7443, epoch=34) |
| `ckpt_phaseB_V2L_s123_int8.pth` | 132 MB | INT8-quantised V2L (3.9× compression) |
| `ckpt_phaseB_V2S_s42_int8.pth` | 24 MB | INT8-quantised V2S (3.8× compression) |
| `deployment/slot1_v2l_seed2024.onnx` | — | ONNX export, ready for TensorRT/OpenVINO |
| `deployment/slot2_v2l_seed123.onnx` | — | ONNX export, ready for TensorRT/OpenVINO |
| `demo_results.json` | 87 MB | 20 precomputed demo samples (5 per source) |

---

## Usage

### Load Best Model (V2L seed=123)

```python
from huggingface_hub import hf_hub_download
import torch

path = hf_hub_download(
    repo_id="animeshakr/oct-fluid-segmentation",
    filename="ckpt_phaseB_V2L_s123.pth",
)
# The checkpoint may hold non-tensor Python objects that PyTorch >= 2.6
# (weights_only=True by default) refuses to unpickle; only disable it for trusted files.
ck = torch.load(path, map_location="cpu", weights_only=False)
print(f"val_dice: {ck['val_dice']:.4f}")  # 0.7913
print(f"epoch: {ck['epoch']}")            # 25
```

### Load INT8 Quantised Model (Edge Deployment)

```python
from huggingface_hub import hf_hub_download
import torch

path = hf_hub_download(
    repo_id="animeshakr/oct-fluid-segmentation",
    filename="ckpt_phaseB_V2L_s123_int8.pth",
)
qsd = torch.load(path, map_location="cpu", weights_only=False)
# Inspect the stored entries (the exact key layout of int8 weights and scales is not documented here)
for name, value in list(qsd.items())[:10]:
    print(name, getattr(value, "dtype", type(value)))
# Dequantise at inference: weight_fp32 = weight_int8.float() * scale
# 132 MB vs 510 MB original: 3.9x smaller
```

### ONNX Inference

```python
from huggingface_hub import hf_hub_download
import onnxruntime as ort
import numpy as np

REPO = "animeshakr/oct-fluid-segmentation"
# The weights live in an external-data file (.onnx.data) that must sit next to the
# .onnx file before the session is created; hf_hub_download keeps both in one folder.
hf_hub_download(repo_id=REPO, filename="deployment/slot2_v2l_seed123.onnx.data")
path = hf_hub_download(repo_id=REPO, filename="deployment/slot2_v2l_seed123.onnx")

sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
# Dummy input; replace with a preprocessed OCT B-scan of shape (1, 1, 512, 512)
x = np.random.randn(1, 1, 512, 512).astype(np.float32)
out = sess.run(None, {sess.get_inputs()[0].name: x})[0]  # per-class scores, (1, 4, 512, 512)
pred_mask = out.argmax(axis=1)[0]  # (512, 512) with values 0-3 (BG/IRF/SRF/PED)
```

---

## Citation

```bibtex
@misc{kumar2026octseg,
  title       = {Attention-Guided TransUNet for Multi-Class Retinal Fluid Segmentation in OCT with MC Dropout Uncertainty Quantification},
  author      = {Kumar, Animesh A.},
  institution = {Newcastle University, UK},
  year        = {2026},
  note        = {MSc Advanced Computer Science dissertation. Targeting OMIA 2026 Workshop at MICCAI.}
}
```

---

## Related Papers

- Bogunovic et al. (2019), RETOUCH Challenge, IEEE TMI: benchmark definition
- Ronneberger et al. (2015), U-Net: segmentation baseline
- Schlemper et al. (2019), Attention U-Net: attention gate mechanism
- Chen et al. (2021), TransUNet: transformer bottleneck design
- Rasti et al. (2022), RetiFluidNet: reported state of the art on RETOUCH

---

## License

MIT License (see the [GitHub repository](https://github.com/Animesh-Kr/oct-fluid-segmentation/blob/main/LICENSE))