
RetinaSense-ViT: Complete Run Guide

Read this file at the start of any new session. Everything needed to understand, run, test, and extend this project is documented here.


0. Local Setup (Start Here If Cloning for the First Time)

The project was originally developed on a remote GPU server. All paths are now relative, so the code runs on any machine. Run this once after cloning:

git clone https://github.com/Tanishq74/retina-sense
cd retina-sense
bash setup.sh

setup.sh does everything automatically:

  1. Installs all Python dependencies (torch CPU build + timm, gradio, fastapi, etc.)
  2. Creates outputs_v3/ and data/ directories
  3. Downloads best_model.pth (331MB) from Hugging Face Hub → outputs_v3/
  4. Downloads efficientnet_b3.pth (45MB) from Hugging Face Hub → outputs_v3/ensemble/
  5. Verifies configs/ JSON files are present (they are committed to git)

After setup completes:

python app.py    # Gradio web demo → http://localhost:7860 (also generates public URL)

What is and isn't in git:

| File | In git? | How to get it |
|---|---|---|
| All .py scripts | Yes | git clone |
| configs/*.json | Yes | git clone (temperature, thresholds, norm stats) |
| outputs_v3/best_model.pth | No (331MB) | bash setup.sh → auto-downloads from HF |
| outputs_v3/ensemble/efficientnet_b3.pth | No (45MB) | bash setup.sh → auto-downloads from HF |
| outputs_v3/ood_detector.npz | No | Only on GPU server; app runs fine without it (OOD check skipped) |
| data/*.csv | No (large) | Only on GPU server; only needed for retraining |
| preprocessed_cache_v3/ | No (multi-GB) | Only on GPU server; only needed for retraining |

HF model repo: https://huggingface.co/tanishq74/retinasense-vit


1. Project Overview

RetinaSense-ViT is a deep learning system for retinal disease classification from fundus photographs, featuring an ensemble architecture, uncertainty-guided clinical triage, and domain-adversarial training.

| Property | Value |
|---|---|
| Task | Multi-class classification, 5 diseases |
| Classes | Normal (0), Diabetes/DR (1), Glaucoma (2), Cataract (3), AMD (4) |
| Inference model | ViT-Base/16 + EfficientNet-B3 ensemble + TTA |
| Ensemble performance | Accuracy 74.7% (vs 49.1% ViT-only), macro AUC 0.951 |
| Dataset | 8,540 images: APTOS (3,662) + ODIR (4,878) |
| Split | 70/15/15 train/calib/test (stratified) |
| Best checkpoint | outputs_v3/best_model.pth (epoch 24) |
| GitHub | https://github.com/Tanishq74/retina-sense |
| App features | Ensemble + TTA + Attention Rollout + MC Dropout + Clinical Triage |

2. Directory Structure

<repo-root>/
│
├── app.py                       # Gradio web demo (Ensemble + TTA + Triage)
├── api/
│   └── main.py                  # FastAPI REST server
│
├── ─── TRAINING SCRIPTS ───
├── retinasense_v3.py            # Main ViT training script (1220 lines)
├── train_ensemble.py            # EfficientNet-B3 ensemble training
├── train_dann.py                # NEW: Domain-Adversarial Neural Network (GPU)
├── kfold_cv.py                  # 5-fold cross-validation (GPU)
├── knowledge_distillation.py    # KD + ONNX export (GPU)
│
├── ─── IMPROVEMENT MODULES ───
├── unified_preprocessing.py     # NEW: Unified CLAHE pipeline (replaces domain-conditional)
├── retfound_backbone.py         # NEW: RETFound foundation model backbone
├── enhanced_augmentation.py     # NEW: CutMix, elastic deform, class-aware augmentation
├── prepare_datasets.py          # NEW: Download/prep 5 additional public datasets
│
├── ─── ANALYSIS & XAI ───
├── gradcam_v3.py                # Attention Rollout XAI
├── eval_dashboard.py            # Full evaluation suite
├── mc_dropout_uncertainty.py    # MC Dropout uncertainty
├── integrated_gradients_xai.py  # Integrated Gradients XAI
├── fairness_analysis.py         # Domain fairness analysis
│
├── ─── DATA & CONFIG ───
├── configs/
│   ├── fundus_norm_stats.json   # mean=[0.4298,0.2784,0.1559] std=[0.2857,0.2065,0.1465]
│   ├── temperature.json         # T=0.6438
│   └── thresholds.json          # Per-class thresholds
├── data/                        # CSVs (on GPU server only)
├── preprocessed_cache_v3/       # .npy image cache (on GPU server only)
│
├── ─── MODEL WEIGHTS (not in git) ───
├── outputs_v3/
│   ├── best_model.pth           # ViT-Base/16 checkpoint (331MB, from HF)
│   ├── ensemble/
│   │   └── efficientnet_b3.pth  # EfficientNet-B3 checkpoint (45MB, from HF)
│   ├── evaluation/              # Phase 1A outputs (7 files)
│   ├── uncertainty/             # Phase 1B outputs (6 files)
│   ├── xai/                     # Phase 1C outputs (23 files)
│   ├── fairness/                # Phase 1D outputs (7 files)
│   ├── gradcam/                 # Attention Rollout heatmaps (22 files)
│   └── dann/                    # DANN outputs (after training)
│
├── ─── DOCUMENTATION ───
├── RUN.md                       # This file
├── ARCHITECTURE_DOCUMENT.md     # System architecture
├── FUNCTIONAL_DOCUMENT.md       # Functional specification
├── FUNCTIONAL_TEST_CASE_DOCUMENT.md
├── IEEE_RESEARCH_PAPER.md       # Research paper draft
├── FINAL_COMPREHENSIVE_REPORT.md
├── Dockerfile                   # Docker deployment
└── requirements_deploy.txt      # Deployment dependencies

3. Model Architecture

Defined as MultiTaskViT in retinasense_v3.py:

import torch.nn as nn
import timm

class MultiTaskViT(nn.Module):
    def __init__(self):
        super().__init__()
        # Backbone: ViT-Base/16, pretrained ImageNet-21k; num_classes=0 -> 768-dim pooled features
        self.backbone = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
        self.drop = nn.Dropout(0.3)

        # Disease classification head
        self.disease_head = nn.Sequential(
            nn.Linear(768, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, 5),
        )

        # Severity sub-classification head (shared backbone)
        self.severity_head = nn.Sequential(
            nn.Linear(768, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, 5),
        )

    def forward(self, x):
        feat = self.drop(self.backbone(x))
        # Returns: (disease_logits [B,5], severity_logits [B,5])
        # ALWAYS unpack both, even if only using disease_logits
        return self.disease_head(feat), self.severity_head(feat)

EfficientNet-B3 (Ensemble Member)

import torch.nn as nn
import timm

class EfficientNetB3(nn.Module):
    def __init__(self):
        super().__init__()
        # num_classes=0 -> 1536-dim pooled features
        self.backbone = timm.create_model('efficientnet_b3', num_classes=0)
        self.drop = nn.Dropout(0.3)
        self.head = nn.Sequential(
            nn.Linear(1536, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, 5),
        )

    def forward(self, x):
        # Returns: logits [B,5], a single tensor, NOT a tuple
        return self.head(self.drop(self.backbone(x)))

Ensemble Inference (how app.py works)

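import torch
import torch.nn.functional as F

T = 0.6438  # temperature from temperature.json

# vit_model (MultiTaskViT) and eff_model (EfficientNetB3) are loaded and in eval mode;
# x is a preprocessed batch [B,3,224,224]
with torch.no_grad():
    vit_disease_logits, _ = vit_model(x)   # tuple: unpack, keep disease logits
    eff_logits = eff_model(x)              # single tensor
    vit_probs = F.softmax(vit_disease_logits / T, dim=1)
    eff_probs = F.softmax(eff_logits / T, dim=1)

# Weighted ensemble
ensemble_probs = 0.35 * vit_probs + 0.65 * eff_probs
pred = ensemble_probs.argmax(dim=1)

# TTA: average over 4 augmented versions (original, h-flip, v-flip, ±10° rotation)
# MC Dropout: 15 stochastic passes for uncertainty estimation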

DANNMultiTaskViT (Domain-Adversarial, in train_dann.py)

class DANNMultiTaskViT(nn.Module):
    backbone = ViT-Base/16 (768-dim)
    disease_head  = 768 → 512 → 256 → 5    (same as MultiTaskViT)
    severity_head = 768 → 256 → 5           (same as MultiTaskViT)
    domain_head   = 768 → GRL → 256 → 128 → 2   (NEW: APTOS vs ODIR)
    # GRL = Gradient Reversal Layer (negates gradients to force domain invariance)

4. Installed Packages

On the GPU server, all packages are already installed (on local machines, setup.sh installs a CPU build of torch instead):

torch==2.8.0+cu128        torchvision==0.23.0+cu128
timm==1.0.25              pytorch-lightning==2.6.0
gradio==6.9.0             captum==0.8.0
onnx==1.20.1              onnxruntime==1.24.3
fastapi==0.133.1          uvicorn==0.41.0
fpdf2==2.8.7              opencv-python==4.11.0.86
torchmetrics==1.7.4

If anything is missing:

pip install -r requirements_deploy.txt
pip install captum gradio fpdf2

5. What Has Already Been Run (Do NOT Re-Run)

These scripts completed successfully on a previous GPU session and their outputs are saved:

| Script | Status | Output location | Key result |
|---|---|---|---|
| retinasense_v3.py | DONE | outputs_v3/best_model.pth | Epoch 24, F1=0.854 |
| gradcam_v3.py | DONE | outputs_v3/gradcam/ (22 files) | Attention Rollout heatmaps |
| eval_dashboard.py | DONE | outputs_v3/evaluation/ (7 files) | Acc=49.1%, AUC=0.893 |
| mc_dropout_uncertainty.py | DONE | outputs_v3/uncertainty/ (6 files) | T=30 MC passes |
| integrated_gradients_xai.py | DONE | outputs_v3/xai/ (23 files) | Pearson r=0.196 |
| fairness_analysis.py | DONE | outputs_v3/fairness/ (7 files) | APTOS ECE=0.51 |
| train_ensemble.py | DONE | outputs_v3/ensemble/ (2 files) | Acc=74.7%, AUC=0.951 |

6. What Still Needs to Run (Requires GPU)

Recommended execution order:

  1. python unified_preprocessing.py (rebuild cache; fixes domain shift)
  2. python prepare_datasets.py --all (expand dataset; optional but recommended)
  3. python train_dann.py (domain-adversarial training; main accuracy improvement)
  4. python kfold_cv.py (cross-validation for paper)
  5. python knowledge_distillation.py (model compression)

6A-NEW. Unified Preprocessing + Cache Rebuild (~30 min)

python unified_preprocessing.py

What it does: Rebuilds the entire image cache using a single CLAHE pipeline for ALL images (APTOS, ODIR, REFUGE2). This eliminates the domain-conditional preprocessing (Ben Graham for APTOS, CLAHE for ODIR) that caused the domain shift.

Outputs:

  • preprocessed_cache_unified/: new .npy cache with consistent preprocessing
  • configs/fundus_norm_stats_unified.json: recomputed normalization stats
  • data/*_unified.csv: updated CSVs with new cache paths

After running, update the training scripts to use the unified cache and the new norm stats.


6B-NEW. Dataset Expansion (optional, ~1-2 hours download + preprocess)

python prepare_datasets.py --list              # show available datasets
python prepare_datasets.py --instructions      # download instructions
python prepare_datasets.py --dataset eyepacs --raw-dir ./data/eyepacs
python prepare_datasets.py --dataset refuge --raw-dir ./data/refuge
python prepare_datasets.py --dataset adam --raw-dir ./data/adam
python prepare_datasets.py --merge             # combine all into unified splits

Available datasets:

| Dataset | Images | Classes added | Impact |
|---|---|---|---|
| EyePACS | ~35,000 | DR + Normal | Massive DR/Normal boost |
| MESSIDOR-2 | 1,748 | DR grades | More DR diversity |
| REFUGE | ~1,200 | Glaucoma + Normal | 20× more Glaucoma samples |
| ADAM (iChallenge-AMD) | ~1,200 | AMD + Normal | 30× more AMD samples |
| ORIGA | ~650 | Glaucoma + Normal | More Glaucoma diversity |

6C-NEW. Domain-Adversarial Training (~2-3 hours on H100)

python train_dann.py

What it does: Trains a Domain-Adversarial Neural Network (DANN) with gradient reversal that forces the ViT backbone to learn features that are predictive of disease but NOT predictive of which dataset (APTOS vs ODIR) the image came from.

Key components:

  • GradientReversalLayer: reverses gradients with Ganin schedule (lambda ramps 0→1)
  • DANNMultiTaskViT: disease head + severity head + domain discriminator
  • Loss: disease_loss + 0.2*severity_loss + alpha*lambda*domain_loss
  • Warm-starts from outputs_v3/best_model.pth
  • Same training recipe as v3: AdamW, LLRD, OneCycleLR, Focal Loss, Mixup
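A minimal sketch of a gradient reversal layer with the Ganin et al. schedule λ(p) = 2/(1 + exp(−γp)) − 1, γ = 10, where p is training progress in [0, 1]. The names `GradReverse`, `grad_reverse`, and `ganin_lambda` are illustrative and are not taken from train_dann.py.

```python
import math

import torch


class GradReverse(torch.autograd.Function):
    """Identity on the forward pass; multiplies the gradient by -lambd on the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the gradient flowing back into the backbone.
        # lambd is a plain float, so it gets no gradient (None).
        return grad_output.neg() * ctx.lambd, None


def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)


def ganin_lambda(progress, gamma=10.0):
    """Ganin et al. schedule: 0 at progress=0, approaching 1 as progress -> 1."""
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0


# Hypothetical usage inside a training loop:
#   p = global_step / total_steps
#   domain_logits = domain_head(grad_reverse(features, ganin_lambda(p)))
```

Scaling by λ inside the GRL (as here) changes only the gradient that reaches the backbone. Scaling the domain loss by λ (as in the loss formula above) also scales the discriminator's own gradient. Check train_dann.py to see which placement it uses.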

Outputs → outputs_v3/dann/:

  • best_model.pth: DANN-trained checkpoint
  • history.json: per-epoch metrics
  • dashboard.png: training curves (disease + domain accuracy)

Expected results: Domain-discriminator accuracy should fall toward chance (~50%), meaning the backbone features no longer reveal which dataset an image came from. Disease accuracy should improve, especially on APTOS DR images.


6D-NEW. RETFound Backbone (alternative to 6C)

# In your training script, replace:
from retinasense_v3 import MultiTaskViT
# With:
from retfound_backbone import MultiTaskRetFound, setup_retfound

# Download RETFound weights (once)
setup_retfound()

# Create model with retinal-pretrained backbone
model = MultiTaskRetFound(retfound_path='./weights/RETFound_cfp_weights.pth')

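What it does: Swaps the ImageNet-pretrained ViT for RETFound, a retinal foundation model pretrained with masked autoencoding on about 1.6 million retinal images, which should give much better features for retinal pathology. Note that the published RETFound MAE weights use a ViT-Large/16 backbone (1024-dim features), not ViT-Base/16, so the feature dimension differs from the 768 used by MultiTaskViT.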


6E. K-Fold Cross-Validation (~2 hours on H100)

python kfold_cv.py

What it does: 5-fold stratified CV on the full train+calib pool (7,265 images). Produces the mean ± std of each metric across folds for the paper.

Config (inside kfold_cv.py):

N_FOLDS=5, N_EPOCHS=30, PATIENCE=8, BATCH_SIZE=32
BASE_LR=3e-4, LLRD_DECAY=0.75, MIXUP_ALPHA=0.4, FOCAL_GAMMA=1.0
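LLRD_DECAY=0.75 refers to layer-wise learning-rate decay: layers closer to the input get geometrically smaller learning rates than the heads. Below is a minimal sketch of the per-depth rates. It assumes 12 transformer blocks, embeddings at depth 0, and the heads at the top receiving the full BASE_LR. The exact grouping inside kfold_cv.py may differ.

```python
BASE_LR = 3e-4
LLRD_DECAY = 0.75
NUM_BLOCKS = 12  # ViT-Base/16 has 12 transformer blocks


def llrd_learning_rates(base_lr=BASE_LR, decay=LLRD_DECAY, num_blocks=NUM_BLOCKS):
    """Return learning rates indexed by depth.

    depth 0                = patch/position embeddings
    depths 1..num_blocks   = transformer blocks
    depth num_blocks + 1   = classification heads (full base_lr)
    """
    top = num_blocks + 1
    return [base_lr * decay ** (top - depth) for depth in range(top + 1)]


lrs = llrd_learning_rates()
# lrs[-1] = 3e-4 for the heads; lrs[0] (embeddings) = 3e-4 * 0.75**13, about 7.1e-6
```

Each depth's parameters would then form one AdamW parameter group with its own `lr`.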

Outputs → outputs_v3/kfold/:

  • fold_1_best.pth ... fold_5_best.pth
  • kfold_results.json: mean ± std per metric
  • fold_comparison.png: bar chart per fold
  • perclass_f1_boxplot.png: per-class F1 boxplot
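A minimal sketch of the mean ± std aggregation behind kfold_results.json. It assumes the per-fold metrics are available as a list of dicts and that the sample standard deviation (n − 1 denominator) is used. kfold_cv.py may use the population std instead. The helper names are hypothetical.

```python
import statistics


def summarize_folds(fold_metrics):
    """Aggregate per-fold metric dicts into {metric: (mean, std)}.

    fold_metrics: list of dicts, one per fold, e.g. {"accuracy": ..., "macro_f1": ...}.
    Uses the sample standard deviation (n - 1 denominator).
    """
    summary = {}
    for key in fold_metrics[0]:
        values = [fold[key] for fold in fold_metrics]
        summary[key] = (statistics.mean(values), statistics.stdev(values))
    return summary


def format_summary(summary):
    # One "metric: mean ± std" line per metric, ready to paste into Table 3
    return "\n".join(f"{k}: {m:.3f} ± {s:.3f}" for k, (m, s) in summary.items())
```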

Expected results:

Accuracy:          60–75% ± ~5%
Balanced Accuracy: 75–85% ± ~3%
Macro F1:          0.65–0.75 ± ~0.03
Macro AUC:         0.90–0.96 ± ~0.02

Troubleshooting:

| Problem | Fix |
|---|---|
| CUDA OOM | Reduce BATCH_SIZE from 32 to 16 in kfold_cv.py |
| Cache miss | Check preprocessed_cache_v3/ has .npy files |
| KeyError: mean_rgb | data/fundus_norm_stats.json must exist |
| Arch mismatch | MultiTaskViT in kfold_cv.py must match retinasense_v3.py |

After completion: update SESSION_CONTEXT.md with mean ± std numbers.


6F. Knowledge Distillation + ONNX Export (~30 min on H100)

cd <repo-root>    # /teamspace/studios/this_studio on the original GPU server
python knowledge_distillation.py

What it does: Compresses ViT-Base (86M params) → ViT-Tiny (5.7M params) using knowledge distillation, then exports to ONNX and quantizes to INT8 for CPU deployment.

Config (inside knowledge_distillation.py):

KD_ALPHA=0.3 (30% CE loss + 70% KL distillation loss)
KD_TEMP=4.0  (softens teacher logits for distillation)
Student: vit_tiny_patch16_224 (192-dim, ~5.7M params)
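With KD_ALPHA=0.3 and KD_TEMP=4.0, the usual Hinton-style objective is L = α·CE(student, y) + (1 − α)·T²·KL(p_teacher^T ‖ p_student^T). Below is a NumPy sketch for a single example. The T² factor keeps the soft-target gradients on the same scale as the CE term. It is the standard choice, but whether knowledge_distillation.py applies it is an assumption.

```python
import numpy as np

KD_ALPHA = 0.3
KD_TEMP = 4.0


def softmax(z):
    z = z - z.max()  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()


def kd_loss(student_logits, teacher_logits, label, alpha=KD_ALPHA, temp=KD_TEMP):
    """alpha * hard-label CE + (1 - alpha) * T^2 * KL(teacher_T || student_T)."""
    student_logits = np.asarray(student_logits, dtype=float)
    teacher_logits = np.asarray(teacher_logits, dtype=float)
    ce = -np.log(softmax(student_logits)[label])   # CE against the true label, T=1
    p_t = softmax(teacher_logits / temp)           # softened teacher distribution
    p_s = softmax(student_logits / temp)           # softened student distribution
    kl = float(np.sum(p_t * (np.log(p_t) - np.log(p_s))))
    return alpha * ce + (1 - alpha) * temp ** 2 * kl
```

When the student matches the teacher exactly, the KL term vanishes and only the 30% CE term remains.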

Outputs → outputs_v3/distillation/:

  • student_best.pth: distilled ViT-Tiny checkpoint
  • retinasense_student.onnx: ONNX model (opset 17, dynamic batch)
  • retinasense_student_int8.onnx: INT8 quantized (~6MB)
  • distillation_results.json: accuracy comparison + CPU benchmark
  • distillation_curves.png: training curves

Expected results:

Teacher (ViT-Base):   Acc~74%, AUC~0.95, Size=331MB
Student (ViT-Tiny):   Acc~68-72%, AUC~0.91-0.94, Size~23MB
INT8 quantized:       Acc~67-71%, Size~6MB, CPU inference ~80ms

7. Running the Web Applications

7A. Gradio Web Demo

python app.py
  • Opens on port 7860 (also generates a public shareable URL)
  • Features:
    • Ensemble prediction: ViT-Base/16 (35%) + EfficientNet-B3 (65%)
    • Test-Time Augmentation: 4 augmented versions averaged (h-flip, v-flip, ±10° rotation)
    • Attention Rollout heatmap: ViT attention visualization
    • MC Dropout uncertainty: 15 stochastic passes (epistemic + aleatoric split)
    • Clinical triage: AUTO-SCREEN / PRIORITY REVIEW / URGENT / RESCAN
    • Model disagreement detection: shows when ViT and EfficientNet disagree
    • OOD detection: Mahalanobis distance (gracefully skipped if npz missing)
    • Downloadable clinical report: .txt file with full analysis
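The epistemic/aleatoric split above is commonly computed from the MC Dropout samples with an entropy decomposition:

* total = H(mean p)
* aleatoric = mean of H(p_t)
* epistemic = total − aleatoric, which is the mutual information

Below is a NumPy sketch that assumes the 15 passes are stacked into an array of shape [passes, classes]. app.py may use a different formula, for example a variance-based one.

```python
import numpy as np


def entropy(p, axis=-1, eps=1e-12):
    return -np.sum(p * np.log(p + eps), axis=axis)


def mc_uncertainty(mc_probs):
    """Split predictive uncertainty from MC Dropout samples.

    mc_probs: array [n_passes, n_classes] of softmax outputs, one row per
    stochastic forward pass (dropout kept active at inference).
    """
    mc_probs = np.asarray(mc_probs, dtype=float)
    mean_p = mc_probs.mean(axis=0)
    total = entropy(mean_p)                 # predictive entropy
    aleatoric = entropy(mc_probs).mean()    # expected per-pass entropy
    epistemic = total - aleatoric           # mutual information (pass disagreement)
    return mean_p, float(total), float(aleatoric), float(epistemic)
```

If every pass gives the same distribution, the epistemic term is zero. It grows when individually confident passes disagree with each other.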

What app.py does internally:

  1. Loads outputs_v3/best_model.pth (ViT-Base/16)
  2. Loads outputs_v3/ensemble/efficientnet_b3.pth (EfficientNet-B3)
  3. Loads configs: temperature.json, thresholds.json, fundus_norm_stats.json
  4. Preprocessing: crop borders → resize 224 (INTER_AREA) → CLAHE → circular mask → normalize
  5. Runs TTA (4 augmentations) through both models
  6. Computes ensemble probabilities (35/65 weighted average)
  7. Runs MC Dropout (15 passes) for uncertainty
  8. Computes triage level based on confidence + uncertainty + model agreement
  9. Generates Attention Rollout heatmap
  10. Generates clinical recommendation + downloadable report

Preprocessing pipeline (CRITICAL: must match training):

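import cv2
import numpy as np
from torchvision import transforms

def preprocess_image(img_pil):
    img = np.array(img_pil.convert('RGB'))  # PIL -> uint8 RGB array [H,W,3]
    img = crop_black_borders(img)           # remove dark padding
    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA)
    img = apply_clahe(img)                  # CLAHE on L-channel in LAB
    img = apply_circular_mask(img)          # zero pixels outside fundus circle (r=0.48)
    norm = transforms.Normalize(mean=[0.4298, 0.2784, 0.1559], std=[0.2857, 0.2065, 0.1465])
    return norm(transforms.ToTensor()(img))  # float tensor [3,224,224]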
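The border-crop and circular-mask helpers called above are not spelled out in this guide. Below is a NumPy sketch of both, built on two assumptions. First, a "black" pixel is one whose mean RGB intensity is at or below a threshold (10 here). Second, r=0.48 is a fraction of the shorter image side. CLAHE is left out because its clip limit and tile size are not documented here. Check _crop_black_borders() and _apply_circular_mask() in app.py for the real parameters.

```python
import numpy as np


def crop_black_borders(img, thresh=10):
    """Crop a uint8 HxWx3 image to the bounding box of pixels brighter than `thresh`."""
    gray = img.mean(axis=2)
    mask = gray > thresh
    if not mask.any():          # all-dark image: nothing to crop to
        return img
    rows = np.where(mask.any(axis=1))[0]
    cols = np.where(mask.any(axis=0))[0]
    return img[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]


def apply_circular_mask(img, radius_frac=0.48):
    """Zero every pixel outside a centred circle of radius radius_frac * min(H, W)."""
    h, w = img.shape[:2]
    yy, xx = np.ogrid[:h, :w]
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    radius = radius_frac * min(h, w)
    inside = (yy - cy) ** 2 + (xx - cx) ** 2 <= radius ** 2
    out = img.copy()
    out[~inside] = 0            # 2-D mask broadcasts over the colour channels
    return out
```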

Clinical Triage System:

| Level | Criteria | Action |
|---|---|---|
| AUTO-SCREEN | Confidence > 70%, low uncertainty, models agree | Routine re-screening |
| PRIORITY REVIEW | Confidence 40–70%, or elevated uncertainty | Specialist within 2 weeks |
| URGENT SPECIALIST | Confidence < 40%, or high uncertainty, or models disagree | Specialist within 48 hours |
| RESCAN NEEDED | OOD detected | Image quality issue, rescan |
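Below is a minimal sketch of the triage rules in the table above. The table gives no numeric cut-offs for "elevated" or "high" uncertainty, so `elevated_unc` and `high_unc` are hypothetical placeholders. app.py's actual values and the order of its checks may differ.

```python
def triage(confidence, uncertainty, models_agree, is_ood,
           elevated_unc=0.3, high_unc=0.6):
    """Map one prediction to a triage level.

    confidence:   top ensemble probability in [0, 1]
    uncertainty:  MC Dropout uncertainty score (scale is an assumption)
    models_agree: True if ViT and EfficientNet-B3 predict the same class
    is_ood:       True if the OOD detector flagged the image
    elevated_unc, high_unc: HYPOTHETICAL thresholds, not taken from app.py
    """
    if is_ood:
        return "RESCAN NEEDED"
    if confidence < 0.40 or uncertainty >= high_unc or not models_agree:
        return "URGENT SPECIALIST"
    if confidence <= 0.70 or uncertainty >= elevated_unc:
        return "PRIORITY REVIEW"
    return "AUTO-SCREEN"
```

The OOD check runs first because no prediction on an out-of-distribution image should be trusted, however confident it is.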

7B. FastAPI REST Server

cd <repo-root>    # /teamspace/studios/this_studio on the original GPU server
python -m uvicorn api.main:app --host 0.0.0.0 --port 8000

Or with auto-reload during development:

python -m uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload

Endpoints:

| Method | Endpoint | Description |
|---|---|---|
| GET | /health | Returns {"status": "healthy", "model": "loaded"} |
| POST | /predict | Single image inference |
| POST | /predict/batch | Batch inference (multiple images) |
| GET | /docs | Interactive Swagger UI |

Test the API:

# Health check
curl http://localhost:8000/health

# Single prediction (replace with a real fundus image path)
curl -X POST http://localhost:8000/predict \
  -F "file=@/path/to/fundus_image.jpg"

# Response format:
{
  "predicted_class": "Diabetes/DR",
  "predicted_label": 1,
  "confidence": 0.819,
  "probabilities": {"Normal": 0.08, "Diabetes/DR": 0.82, "Glaucoma": 0.03, "Cataract": 0.04, "AMD": 0.03},
  "severity": "Moderate",
  "ood_score": 12.4,
  "is_ood": false,
  "inference_time_ms": 370
}

7C. Docker Deployment

cd <repo-root>    # /teamspace/studios/this_studio on the original GPU server

# Build the image
docker build -t retinasense .

# Run the container
docker run -p 8000:8000 retinasense

# With GPU support
docker run --gpus all -p 8000:8000 retinasense

8. Running XAI / Attention Rollout on New Images

cd <repo-root>    # /teamspace/studios/this_studio on the original GPU server
python gradcam_v3.py

What it does: Runs Attention Rollout (Abnar & Zuidema 2020) on 20 sample images (4 per class), saves heatmap overlays to outputs_v3/gradcam/.

Key settings in gradcam_v3.py:

discard_ratio = 0.97      # discard bottom 97% of attention (sharp focus)
alpha = 0.7               # heatmap overlay strength
power_stretch = 0.4       # contrast enhancement (np.power(spatial, 0.4))

Note: Standard Grad-CAM does NOT work on this ViT (CLS token problem: zero patch gradients). Always use ViTAttentionRollout, never GradCAM, for this model.
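Attention Rollout (Abnar & Zuidema, 2020) multiplies head-averaged attention matrices across layers, adding the identity to account for residual connections. Below is a NumPy sketch that uses the discard ratio and power stretch listed above. It assumes per-layer attention arrays of shape [heads, tokens, tokens], with the CLS token at index 0 and a 14×14 patch grid. Mean head fusion and always keeping the CLS-to-CLS entry are further assumptions, and gradcam_v3.py may differ on both.

```python
import numpy as np


def attention_rollout(attentions, discard_ratio=0.97, power=0.4):
    """attentions: list of [heads, tokens, tokens] arrays, one per layer (tokens = 1 + 196)."""
    n_tokens = attentions[0].shape[-1]
    result = np.eye(n_tokens)
    for attn in attentions:
        a = attn.mean(axis=0)                      # fuse heads by averaging
        flat = a.flatten()                         # copy, so edits don't touch `a`
        k = int(flat.size * discard_ratio)
        drop = np.argsort(flat)[:k]                # indices of the weakest links
        drop = drop[drop != 0]                     # keep CLS->CLS (flat index 0)
        flat[drop] = 0.0
        a = flat.reshape(a.shape)
        a = a + np.eye(n_tokens)                   # account for residual connections
        a = a / a.sum(axis=-1, keepdims=True)      # re-normalise rows
        result = a @ result                        # rollout = A_L ... A_2 A_1
    cls_to_patches = result[0, 1:]                 # CLS row, patch columns only
    side = int(np.sqrt(cls_to_patches.size))       # 14 for ViT-Base/16 at 224px
    heat = cls_to_patches.reshape(side, side)
    heat = heat / max(heat.max(), 1e-12)           # scale to [0, 1]
    return np.power(heat, power)                   # contrast stretch, as in power_stretch
```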


9. Key Numbers (Current Performance)

Single-Run Test Set (1,287 images)

| Metric | Value |
|---|---|
| Overall Accuracy | 49.1% |
| Balanced Accuracy | 76.7% |
| Macro F1 | 0.626 |
| Macro AUC | 0.893 |
| Cohen Kappa | 0.293 |

Per-class breakdown:

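| Class | Precision | Recall | F1 | AUC | Support |
|---|---|---|---|---|---|
| Normal | 0.336 | 0.968 | 0.499 | 0.770 | 310 |
| Diabetes/DR | 0.995 | 0.253 | 0.404 | 0.830 | 837 |
| Glaucoma | 0.717 | 0.635 | 0.673 | 0.877 | 52 |
| Cataract | 0.681 | 0.979 | 0.803 | 0.988 | 48 |
| AMD | 0.597 | 1.000 | 0.748 | 0.998 | 40 |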
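Overall accuracy (49.1%) and balanced accuracy (76.7%) diverge because of how each one weights the classes. Balanced accuracy is the unweighted mean of per-class recall. Overall accuracy weights each class's recall by its support, so the 837 DR images at 25.3% recall dominate it. The snippet below recomputes both numbers from the per-class table:

```python
# Per-class recall and support, copied from the table above
recall = {"Normal": 0.968, "Diabetes/DR": 0.253, "Glaucoma": 0.635,
          "Cataract": 0.979, "AMD": 1.000}
support = {"Normal": 310, "Diabetes/DR": 837, "Glaucoma": 52,
           "Cataract": 48, "AMD": 40}

n_total = sum(support.values())                                      # 1287 test images
balanced_acc = sum(recall.values()) / len(recall)                    # mean of recalls
overall_acc = sum(recall[c] * support[c] for c in recall) / n_total  # support-weighted

print(f"balanced={balanced_acc:.3f} overall={overall_acc:.3f}")
# balanced=0.767 overall=0.491
```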

Ensemble (ViT 35% + EfficientNet-B3 65%):

| Metric | Value |
|---|---|
| Accuracy | 74.7% |
| Macro F1 | 0.712 |
| Macro AUC | 0.951 |
| Disagreement rate | 44.2% (flags for human review) |

Domain Gap

| Domain | N | Accuracy | ECE |
|---|---|---|---|
| ODIR | 709 | 67.7% | 0.157 |
| APTOS | 570 | 26.5% | 0.510 |
| REFUGE2 | 8 | 12.5% | n/a |

Critical finding: DR recall = 25.3%, and 573 of 837 DR images are misclassified as Normal. DR precision = 99.5%, so when the model predicts DR it is almost always correct, but it rarely predicts DR. Root cause: APTOS Ben Graham preprocessing creates a domain shift from the CLAHE-processed ODIR images.


10. Configuration Files

Temperature & Thresholds

outputs_v3/temperature.json:

{"temperature": 0.6438, "ece_before": 0.1618, "ece_after": 0.1014}

outputs_v3/thresholds.json:

{
  "thresholds": [0.638, 0.068, 0.840, 0.564, 0.289],
  "class_names": ["Normal", "Diabetes/DR", "Glaucoma", "Cataract", "AMD"]
}

Note: DR threshold = 0.068 (very low). Despite this, APTOS DR images still get classified as Normal because their DR probability stays below even this threshold.
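temperature.json reports ECE before and after scaling. Below is a NumPy sketch of both steps:

* Temperature scaling divides the logits by T before the softmax. T < 1, as here, sharpens the distribution.
* Expected Calibration Error bins predictions by confidence, then averages |accuracy − confidence| across bins, weighted by bin size.

The 15 equal-width bins are an assumption, and the calibration script may bin differently.

```python
import numpy as np

T = 0.6438  # from temperature.json


def softmax(logits, temperature=1.0):
    z = np.asarray(logits, dtype=float) / temperature
    z = z - z.max(axis=1, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def expected_calibration_error(probs, labels, n_bins=15):
    """probs: [N, C] probabilities; labels: [N] integer class ids."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    conf = probs.max(axis=1)
    correct = (probs.argmax(axis=1) == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            # bin weight * |bin accuracy - bin mean confidence|
            ece += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return ece


# Hypothetical usage: ece_after = expected_calibration_error(softmax(logits, T), labels)
```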

Normalization Stats

data/fundus_norm_stats.json keys: mean_rgb and std_rgb (not mean/std).

import json
with open('data/fundus_norm_stats.json') as f:
    stats = json.load(f)
mean = stats['mean_rgb']   # [0.4298, 0.2784, 0.1559]
std  = stats['std_rgb']    # [0.2857, 0.2065, 0.1465]

OOD Detector

import numpy as np
data = np.load('outputs_v3/ood_detector.npz')
# Keys: class_means, precision_matrix, threshold
# threshold = 42.82 (Mahalanobis distance)
# If distance > 42.82 β†’ flag as out-of-distribution
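Below is a sketch of how the stored arrays could be used. It assumes `class_means` has shape [C, D] for a backbone feature dimension D, that `precision_matrix` is a shared [D, D] inverse covariance, and that the score is the minimum non-squared Mahalanobis distance over classes. This guide does not document whether app.py uses squared or non-squared distances, or which features it feeds in.

```python
import numpy as np


def mahalanobis_ood(features, class_means, precision_matrix, threshold):
    """Return (min_distance, is_ood) for one feature vector.

    features:         [D] backbone embedding of the test image
    class_means:      [C, D] per-class mean embeddings from training data
    precision_matrix: [D, D] shared inverse covariance
    """
    diffs = class_means - features                                    # [C, D]
    d2 = np.einsum("cd,de,ce->c", diffs, precision_matrix, diffs)     # squared distances
    dist = float(np.sqrt(d2.min()))                                   # closest class
    return dist, dist > float(threshold)


# Usage (the .npz exists only on the GPU server):
#   det = np.load("outputs_v3/ood_detector.npz")
#   score, is_ood = mahalanobis_ood(feat, det["class_means"],
#                                   det["precision_matrix"], det["threshold"])
```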

11. Data Loading Pattern

CSV columns: image_path, label, source, cache_path

import numpy as np
import pandas as pd
import torch
from torchvision import transforms

df = pd.read_csv('data/test_split.csv')

# Cached images are already preprocessed and resized to 224x224,
# so only normalization is applied here (no ToTensor, no resize).
normalize = transforms.Normalize(mean=[0.4298, 0.2784, 0.1559],
                                 std=[0.2857, 0.2065, 0.1465])

def load_image(row):
    img = np.load(row['cache_path'])   # e.g. preprocessed_cache_v3/xxxx.npy, uint8 [H,W,3]
    img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0   # float [3,H,W] in [0,1]
    return normalize(img)

# Example: x = load_image(df.iloc[0])

# Sources: 'APTOS', 'ODIR', 'REFUGE2'
# Labels: 0=Normal, 1=DR, 2=Glaucoma, 3=Cataract, 4=AMD

12. Complete Task Checklist

Done (Training, GPU server)

  • Train ViT-Base/16 model (retinasense_v3.py): epoch 24, F1=0.854
  • Temperature calibration (T=0.6438)
  • Per-class threshold optimization
  • Attention Rollout XAI (gradcam_v3.py): heatmaps working
  • OOD Mahalanobis detector
  • Evaluation dashboard (eval_dashboard.py): 7 output files
  • MC Dropout uncertainty (mc_dropout_uncertainty.py): 6 output files
  • Integrated Gradients XAI (integrated_gradients_xai.py): 23 output files
  • Fairness/domain analysis (fairness_analysis.py): 7 output files
  • EfficientNet-B3 ensemble (train_ensemble.py): 74.7% acc, AUC=0.951

Done (App, session 2026-03-10)

  • Ensemble inference in app.py (ViT 35% + EfficientNet 65%)
  • Fixed preprocessing pipeline (border crop + circular mask + INTER_AREA)
  • Test-Time Augmentation (4 augmented versions averaged)
  • Clinical triage system (AUTO-SCREEN / REVIEW / URGENT / RESCAN)
  • Model disagreement detection (ViT vs EfficientNet)
  • OOD graceful handling (no crash when npz missing)
  • Gradio web app with all features (port 7860)
  • FastAPI REST server (api/main.py, port 8000)
  • Docker deployment (Dockerfile)

Done (New training code, ready for GPU)

  • unified_preprocessing.py: single CLAHE pipeline for all sources
  • train_dann.py: Domain-Adversarial Neural Network training
  • retfound_backbone.py: RETFound foundation model backbone support
  • enhanced_augmentation.py: CutMix, elastic deform, 5× minority oversampling
  • prepare_datasets.py: download/prep for EyePACS, MESSIDOR-2, REFUGE, ADAM, ORIGA

Needs GPU (run in this order)

  1. Rebuild cache: python unified_preprocessing.py (~30 min)
  2. Expand dataset: python prepare_datasets.py --all (optional)
  3. DANN training: python train_dann.py (~2-3 hrs on H100)
  4. K-Fold CV: python kfold_cv.py (~2 hrs on H100)
  5. Knowledge Distillation: python knowledge_distillation.py (~30 min on H100)
  6. (Alternative) RETFound backbone: modify training to use retfound_backbone.py

Known Issues

  • DR recall = 25.3%. Root cause: APTOS domain shift (Ben Graham vs CLAHE preprocessing). Fix: unified_preprocessing.py + train_dann.py
  • AMD under-represented (~40 samples). Fix: prepare_datasets.py to add the ADAM dataset (~1,200 AMD + Normal images)
  • Model is overconfident on garbage input (86% confidence on random noise). Inherent model property.
  • Ensemble thresholds not calibrated β€” current thresholds are for ViT-only. Fix: Recalibrate after DANN retraining.
  • Temperature T=0.6438 sharpens wrong predictions. Consider recalibrating after retraining.

13. Model Weights β€” Hugging Face Hub

Model weights are NOT in git (too large). They are hosted on Hugging Face:

Repo: https://huggingface.co/tanishq74/retinasense-vit

| File | Size | Description |
|---|---|---|
| best_model.pth | 331MB | ViT-Base/16 trained checkpoint (epoch 24) |
| efficientnet_b3.pth | 45MB | EfficientNet-B3 ensemble checkpoint |

Download automatically (run once):

bash setup.sh    # handles everything including model download

Or download manually:

from huggingface_hub import hf_hub_download
import shutil, os

os.makedirs("outputs_v3/ensemble", exist_ok=True)
for fname, dest in [
    ("best_model.pth",       "outputs_v3/best_model.pth"),
    ("efficientnet_b3.pth",  "outputs_v3/ensemble/efficientnet_b3.pth"),
]:
    path = hf_hub_download(repo_id="tanishq74/retinasense-vit", filename=fname)
    shutil.copy(path, dest)

14. Quick-Start Commands

# Full local setup from scratch (cloned repo):
bash setup.sh                        # installs deps + downloads models from HF

# Check what's already been generated
ls outputs_v3/evaluation/ outputs_v3/uncertainty/ outputs_v3/xai/ outputs_v3/fairness/ outputs_v3/ensemble/

# Start Gradio app
python app.py

# Start FastAPI server
python -m uvicorn api.main:app --host 0.0.0.0 --port 8000

# Run K-Fold CV (GPU required, ~2 hrs)
python kfold_cv.py

# Run Knowledge Distillation (GPU required, ~30 min)
python knowledge_distillation.py

# Check GPU
nvidia-smi

# Check disk space (outputs can be large)
df -h .

# Check running processes
ps aux | grep python | grep -v grep

15. Paper Writing Notes

Use these numbers in the paper:

Table 1: Main Results (Single Run, Test Set n=1,287):

Overall Accuracy:    49.1%  (pulled down by DR: 837/1287 test images are DR, with 25.3% recall)
Balanced Accuracy:   76.7%
Macro F1:            0.626
Macro AUC:           0.893

Table 2: Ensemble Results:

ViT-Base alone:               Acc=49.1%, F1=0.626, AUC=0.893
EfficientNet-B3 alone:        Acc=72.2%, F1=0.658, AUC=0.923
ViT+EfficientNet Ensemble:    Acc=74.7%, F1=0.712, AUC=0.951

Table 3: K-Fold CV (fill in after running kfold_cv.py):

Accuracy:          XX.X% ± X.X%
Balanced Accuracy: XX.X% ± X.X%
Macro F1:          0.XXX ± 0.XXX
Macro AUC:         0.XXX ± 0.XXX

Table 4: DANN Results (fill in after running train_dann.py):

Overall Accuracy:    XX.X%  (expected: 80-85%)
DR Recall:           XX.X%  (expected: 60-75%, up from 25.3%)
Macro F1:            0.XXX  (expected: 0.75-0.85)
Domain accuracy:     ~50%   (closer to 50% = more domain-invariant)

Reviewer issues addressed:

  • Missing per-class metrics → eval_dashboard.py outputs metrics_report.json
  • Missing ROC/PR curves → outputs_v3/evaluation/roc_curves_per_class.png, precision_recall_curves.png
  • No cross-validation → kfold_cv.py (run on GPU)
  • No confidence intervals → kfold_results.json (mean ± std, after K-Fold)
  • No statistical significance → fairness_analysis.py, chi-squared test p=0.296
  • No uncertainty quantification → mc_dropout_uncertainty.py (epistemic/aleatoric split)

16. Research Novelty (Paper Differentiators)

Novelty 1: Domain-Adversarial Retinal Screening

  • Problem: Cross-dataset domain shift (APTOS vs ODIR) causes 25.3% DR recall
  • Solution: DANN with gradient reversal forces domain-invariant features
  • Implementation: train_dann.py
  • Paper angle: "Domain-Invariant Retinal Disease Classification Across Heterogeneous Fundus Image Sources"

Novelty 2: Uncertainty-Guided Clinical Triage

  • Problem: When should the AI auto-screen vs defer to a human?
  • Solution: Combine confidence + MC Dropout uncertainty + ensemble disagreement into triage levels
  • Implementation: app.py (live, working)
  • Paper angle: "Uncertainty-Aware Retinal Screening: When to Trust the AI and When to Defer"

Novelty 3: RETFound Foundation Model Transfer

  • Problem: ImageNet features are suboptimal for retinal pathology
  • Solution: Use RETFound (1.6M retinal images MAE-pretrained) as backbone
  • Implementation: retfound_backbone.py
  • Paper angle: "Parameter-Efficient Transfer from Retinal Foundation Models for Small-Dataset Classification"

Novelty 4: Preprocessing-Induced Domain Shift Analysis

  • Problem: Different preprocessing (Ben Graham vs CLAHE) creates artificial domain shift
  • Finding: CLAHE alone shifts Glaucoma probability by +43 percentage points
  • Solution: Unified CLAHE pipeline eliminates this
  • Implementation: unified_preprocessing.py
  • Paper angle: Novel contribution; documented evidence that preprocessing choices create measurable domain shift in retinal AI

17. Critical Bug Fix Log (2026-03-10)

These bugs were found during the investigation session and are now fixed:

| Bug | Severity | Root cause | Fix |
|---|---|---|---|
| Wrong predictions on all images | CRITICAL | app.py missing circular mask + border crop in preprocessing | Added _crop_black_borders() and _apply_circular_mask() |
| Wrong predictions on all images | CRITICAL | app.py used INTER_LINEAR resize vs training's INTER_AREA | Changed to cv2.INTER_AREA |
| OOD report crash | HIGH | ood.ood_threshold is None when npz missing, format string fails | Added None check with graceful fallback |
| CLAHE +43% Glaucoma shift | MODERATE | CLAHE applied blindly vs domain-conditional during training | Documented; fix requires retraining with unified_preprocessing.py |

Investigation methodology: 4 parallel agents analyzed architecture, preprocessing, raw model outputs, and normalization. Full findings in memory files.