RetinaSense-ViT: Complete Run Guide
Read this file at the start of any new session. Everything needed to understand, run, test, and extend this project is documented here.
0. Local Setup (Start Here If Cloning for the First Time)
The project was originally developed on a remote GPU server. All paths are now relative, so the code works on any machine. Run this once after cloning:
git clone https://github.com/Tanishq74/retina-sense
cd retina-sense
bash setup.sh
setup.sh does everything automatically:
- Installs all Python dependencies (torch CPU build + timm, gradio, fastapi, etc.)
- Creates outputs_v3/ and data/ directories
- Downloads best_model.pth (331MB) from Hugging Face Hub → outputs_v3/
- Downloads efficientnet_b3.pth (45MB) from Hugging Face Hub → outputs_v3/ensemble/
- Verifies configs/ JSON files are present (they are committed to git)
After setup completes:
python app.py # Gradio web demo → http://localhost:7860 (also generates public URL)
What is and isn't in git:
| File | In git? | How to get it |
|---|---|---|
| All .py scripts | Yes | git clone |
| configs/*.json | Yes | git clone (temperature, thresholds, norm stats) |
| outputs_v3/best_model.pth | No (331MB) | bash setup.sh → auto-downloads from HF |
| outputs_v3/ensemble/efficientnet_b3.pth | No (45MB) | bash setup.sh → auto-downloads from HF |
| outputs_v3/ood_detector.npz | No | Only on GPU server; app runs fine without it (OOD check skipped) |
| data/*.csv | No (large) | Only on GPU server; only needed for retraining |
| preprocessed_cache_v3/ | No (multi-GB) | Only on GPU server; only needed for retraining |
HF model repo: https://huggingface.co/tanishq74/retinasense-vit
1. Project Overview
RetinaSense-ViT is a deep learning system for retinal disease classification from fundus photographs, featuring an ensemble architecture, uncertainty-guided clinical triage, and domain-adversarial training.
| Property | Value |
|---|---|
| Task | Multi-class classification, 5 diseases |
| Classes | Normal (0), Diabetes/DR (1), Glaucoma (2), Cataract (3), AMD (4) |
| Inference Model | ViT-Base/16 + EfficientNet-B3 Ensemble + TTA |
| Ensemble Accuracy | 74.7% (vs 49.1% ViT-only) / Macro AUC 0.951 |
| Dataset | 8,540 images: APTOS (3,662) + ODIR (4,878) |
| Split | 70/15/15 train/calib/test (stratified) |
| Best checkpoint | outputs_v3/best_model.pth (epoch 24) |
| GitHub | https://github.com/Tanishq74/retina-sense |
| App Features | Ensemble + TTA + Attention Rollout + MC Dropout + Clinical Triage |
2. Directory Structure
<repo-root>/
│
├── app.py                          # Gradio web demo (Ensemble + TTA + Triage)
├── api/
│   └── main.py                     # FastAPI REST server
│
├── ─── TRAINING SCRIPTS ───
├── retinasense_v3.py               # Main ViT training script (1220 lines)
├── train_ensemble.py               # EfficientNet-B3 ensemble training
├── train_dann.py                   # NEW: Domain-Adversarial Neural Network (GPU)
├── kfold_cv.py                     # 5-fold cross-validation (GPU)
├── knowledge_distillation.py       # KD + ONNX export (GPU)
│
├── ─── IMPROVEMENT MODULES ───
├── unified_preprocessing.py        # NEW: Unified CLAHE pipeline (replaces domain-conditional)
├── retfound_backbone.py            # NEW: RETFound foundation model backbone
├── enhanced_augmentation.py        # NEW: CutMix, elastic deform, class-aware augmentation
├── prepare_datasets.py             # NEW: Download/prep 5 additional public datasets
│
├── ─── ANALYSIS & XAI ───
├── gradcam_v3.py                   # Attention Rollout XAI
├── eval_dashboard.py               # Full evaluation suite
├── mc_dropout_uncertainty.py       # MC Dropout uncertainty
├── integrated_gradients_xai.py     # Integrated Gradients XAI
├── fairness_analysis.py            # Domain fairness analysis
│
├── ─── DATA & CONFIG ───
├── configs/
│   ├── fundus_norm_stats.json      # mean=[0.4298,0.2784,0.1559] std=[0.2857,0.2065,0.1465]
│   ├── temperature.json            # T=0.6438
│   └── thresholds.json             # Per-class thresholds
├── data/                           # CSVs (on GPU server only)
├── preprocessed_cache_v3/          # .npy image cache (on GPU server only)
│
├── ─── MODEL WEIGHTS (not in git) ───
├── outputs_v3/
│   ├── best_model.pth              # ViT-Base/16 checkpoint (331MB, from HF)
│   ├── ensemble/
│   │   └── efficientnet_b3.pth     # EfficientNet-B3 checkpoint (45MB, from HF)
│   ├── evaluation/                 # Phase 1A outputs (7 files)
│   ├── uncertainty/                # Phase 1B outputs (6 files)
│   ├── xai/                        # Phase 1C outputs (23 files)
│   ├── fairness/                   # Phase 1D outputs (7 files)
│   ├── gradcam/                    # Attention Rollout heatmaps (22 files)
│   └── dann/                       # DANN outputs (after training)
│
├── ─── DOCUMENTATION ───
├── RUN.md                          # This file
├── ARCHITECTURE_DOCUMENT.md        # System architecture
├── FUNCTIONAL_DOCUMENT.md          # Functional specification
├── FUNCTIONAL_TEST_CASE_DOCUMENT.md
├── IEEE_RESEARCH_PAPER.md          # Research paper draft
├── FINAL_COMPREHENSIVE_REPORT.md
├── Dockerfile                      # Docker deployment
└── requirements_deploy.txt         # Deployment dependencies
3. Model Architecture
Defined as MultiTaskViT in retinasense_v3.py:
import timm
import torch.nn as nn

class MultiTaskViT(nn.Module):
    def __init__(self):
        super().__init__()
        # Backbone: ViT-Base/16, pretrained ImageNet-21k, output_dim=768
        self.backbone = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
        self.drop = nn.Dropout(0.3)
        # Disease classification head
        self.disease_head = nn.Sequential(
            nn.Linear(768, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, 5),
        )
        # Severity sub-classification head (shared backbone)
        self.severity_head = nn.Sequential(
            nn.Linear(768, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, 5),
        )

    def forward(self, x):
        feat = self.drop(self.backbone(x))  # pooled features [B, 768]
        # Returns: (disease_logits [B,5], severity_logits [B,5])
        # ALWAYS unpack both, even if only using disease_logits
        return self.disease_head(feat), self.severity_head(feat)
EfficientNet-B3 (Ensemble Member)
import timm
import torch.nn as nn

class EfficientNetB3(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model('efficientnet_b3', num_classes=0)  # 1536-dim pooled features
        self.drop = nn.Dropout(0.3)
        self.head = nn.Sequential(
            nn.Linear(1536, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, 5),
        )

    def forward(self, x):
        # Returns: logits [B,5], a single tensor, NOT a tuple
        return self.head(self.drop(self.backbone(x)))
Ensemble Inference (how app.py works)
import torch
import torch.nn.functional as F

# Both models loaded, both in eval mode (vit_model, eff_model, x are placeholder names)
T = 0.6438
with torch.no_grad():
    vit_disease_logits, _ = vit_model(x)   # MultiTaskViT returns a tuple
    eff_logits = eff_model(x)              # EfficientNetB3 returns a single tensor
    vit_probs = F.softmax(vit_disease_logits / T, dim=1)
    eff_probs = F.softmax(eff_logits / T, dim=1)
    # Weighted ensemble
    ensemble_probs = 0.35 * vit_probs + 0.65 * eff_probs
    pred = ensemble_probs.argmax(dim=1)
# TTA: average over 4 augmented versions (original, h-flip, v-flip, ±10° rotation)
# MC Dropout: 15 stochastic passes for uncertainty estimation
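The combination step can be sketched in NumPy to make the order of operations explicit. This is a minimal illustration, not app.py itself. It assumes that TTA averages each model's temperature-scaled probabilities over its augmented views before the 35/65 weighting, and that "disagreement" means the two models' top-1 classes differ. Both are assumptions. The random logits stand in for real model outputs.

```python
import numpy as np

T = 0.6438                 # temperature from temperature.json
W_VIT, W_EFF = 0.35, 0.65  # ensemble weights used by app.py

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def ensemble_tta(vit_logits, eff_logits, temperature=T):
    """vit_logits, eff_logits: arrays [n_views, 5], one row per TTA view.

    Assumption: probabilities are averaged over views per model first,
    then the two models are combined with fixed weights.
    """
    vit_probs = softmax(vit_logits / temperature).mean(axis=0)
    eff_probs = softmax(eff_logits / temperature).mean(axis=0)
    probs = W_VIT * vit_probs + W_EFF * eff_probs
    pred = int(np.argmax(probs))
    # Assumed disagreement rule: the two models' top-1 classes differ
    disagree = int(np.argmax(vit_probs)) != int(np.argmax(eff_probs))
    return probs, pred, disagree

rng = np.random.default_rng(0)
probs, pred, disagree = ensemble_tta(rng.normal(size=(4, 5)), rng.normal(size=(4, 5)))
print(probs.round(3), pred, disagree)
```

Because EfficientNet-B3 carries 65% of the weight, it wins any confident two-way disagreement. That is why a disagreement flag is surfaced separately rather than being absorbed into the averaged probabilities.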
DANNMultiTaskViT (Domain-Adversarial, in train_dann.py)
class DANNMultiTaskViT(nn.Module):
    backbone      = ViT-Base/16 (768-dim)
    disease_head  = 768 → 512 → 256 → 5          (same as MultiTaskViT)
    severity_head = 768 → 256 → 5                (same as MultiTaskViT)
    domain_head   = 768 → GRL → 256 → 128 → 2    (NEW: APTOS vs ODIR)
    # GRL = Gradient Reversal Layer: identity in the forward pass, multiplies gradients
    # by -lambda in the backward pass to force domain invariance
4. Installed Packages
The original GPU server environment has these versions installed (for local use, setup.sh installs the CPU build of torch instead):
torch==2.8.0+cu128 torchvision==0.23.0+cu128
timm==1.0.25 pytorch-lightning==2.6.0
gradio==6.9.0 captum==0.8.0
onnx==1.20.1 onnxruntime==1.24.3
fastapi==0.133.1 uvicorn==0.41.0
fpdf2==2.8.7 opencv-python==4.11.0.86
torchmetrics==1.7.4
If anything is missing:
pip install -r requirements_deploy.txt
pip install captum gradio fpdf2
5. What Has Already Been Run (Do NOT Re-Run)
These scripts completed successfully on a previous GPU session and their outputs are saved:
| Script | Status | Output Location | Key Result |
|---|---|---|---|
| retinasense_v3.py | DONE | outputs_v3/best_model.pth | Epoch 24, F1=0.854 |
| gradcam_v3.py | DONE | outputs_v3/gradcam/ (22 files) | Attention Rollout heatmaps |
| eval_dashboard.py | DONE | outputs_v3/evaluation/ (7 files) | Acc=49.1%, AUC=0.893 |
| mc_dropout_uncertainty.py | DONE | outputs_v3/uncertainty/ (6 files) | T=30 MC passes |
| integrated_gradients_xai.py | DONE | outputs_v3/xai/ (23 files) | Pearson r=0.196 |
| fairness_analysis.py | DONE | outputs_v3/fairness/ (7 files) | APTOS ECE=0.51 |
| train_ensemble.py | DONE | outputs_v3/ensemble/ (2 files) | Acc=74.7%, AUC=0.951 |
6. What Still Needs to Run (Requires GPU)
Recommended execution order:
- python unified_preprocessing.py → rebuild cache (fixes domain shift)
- python prepare_datasets.py --all → expand dataset (optional but recommended)
- python train_dann.py → domain-adversarial training (main accuracy improvement)
- python kfold_cv.py → cross-validation for paper
- python knowledge_distillation.py → model compression
6A-NEW. Unified Preprocessing + Cache Rebuild (~30 min)
python unified_preprocessing.py
What it does: Rebuilds the entire image cache using a single CLAHE pipeline for ALL images (APTOS, ODIR, REFUGE2). This eliminates the domain-conditional preprocessing (Ben Graham for APTOS, CLAHE for ODIR) that caused the domain shift.
Outputs:
- preprocessed_cache_unified/: new .npy cache with consistent preprocessing
- configs/fundus_norm_stats_unified.json: recomputed normalization stats
- data/*_unified.csv: updated CSVs with new cache paths
After running: Update training scripts to use the unified cache and new norm stats.
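Recomputing the normalization stats for a new cache amounts to a per-channel mean and standard deviation over all cached pixels. The NumPy version below is illustrative, not unified_preprocessing.py. It assumes the cache holds uint8 [H,W,3] arrays, as in Section 11, and that pixels are scaled to [0, 1] the way ToTensor does. It also pools all pixels together, including black pixels outside the circular mask; whether the real script excludes those pixels is not documented here. The output uses the mean_rgb / std_rgb keys that the loaders expect.

```python
import json
import numpy as np

def compute_norm_stats(images):
    """images: iterable of uint8 arrays [H, W, 3], e.g. loaded from the .npy cache.

    Accumulates per-channel sums in float64 so the dataset never has to be
    held in memory at once.
    """
    n = 0
    s = np.zeros(3, dtype=np.float64)
    sq = np.zeros(3, dtype=np.float64)
    for img in images:
        x = img.reshape(-1, 3).astype(np.float64) / 255.0  # scale to [0, 1] like ToTensor
        n += x.shape[0]
        s += x.sum(axis=0)
        sq += (x ** 2).sum(axis=0)
    mean = s / n
    std = np.sqrt(sq / n - mean ** 2)  # population std over all pooled pixels
    return {"mean_rgb": mean.round(4).tolist(), "std_rgb": std.round(4).tolist()}

def save_norm_stats(stats, path="configs/fundus_norm_stats_unified.json"):
    with open(path, "w") as f:
        json.dump(stats, f, indent=2)
```

Passing a generator such as `(np.load(p) for p in df["cache_path"])` keeps memory flat. Here `df` is a hypothetical DataFrame read from one of the unified CSVs.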
6B-NEW. Dataset Expansion (optional, ~1-2 hours download + preprocess)
python prepare_datasets.py --list # show available datasets
python prepare_datasets.py --instructions # download instructions
python prepare_datasets.py --dataset eyepacs --raw-dir ./data/eyepacs
python prepare_datasets.py --dataset refuge --raw-dir ./data/refuge
python prepare_datasets.py --dataset adam --raw-dir ./data/adam
python prepare_datasets.py --merge # combine all into unified splits
Available datasets:
| Dataset | Images | Classes Added | Impact |
|---|---|---|---|
| EyePACS | ~35,000 | DR + Normal | Massive DR/Normal boost |
| MESSIDOR-2 | 1,748 | DR grades | More DR diversity |
| REFUGE | ~1,200 | Glaucoma + Normal | 20× more Glaucoma samples |
| ADAM (iChallenge-AMD) | ~1,200 | AMD + Normal | 30× more AMD samples |
| ORIGA | ~650 | Glaucoma + Normal | More Glaucoma diversity |
6C-NEW. Domain-Adversarial Training (~2-3 hours on H100)
python train_dann.py
What it does: Trains a Domain-Adversarial Neural Network (DANN) with gradient reversal that forces the ViT backbone to learn features that are predictive of disease but NOT predictive of which dataset (APTOS vs ODIR) the image came from.
Key components:
- GradientReversalLayer: reverses gradients with Ganin schedule (lambda ramps 0→1)
- DANNMultiTaskViT: disease head + severity head + domain discriminator
- Loss: disease_loss + 0.2*severity_loss + alpha*lambda*domain_loss
- Warm-starts from outputs_v3/best_model.pth
- Same training recipe as v3: AdamW, LLRD, OneCycleLR, Focal Loss, Mixup
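The lambda ramp can be sketched using the schedule from Ganin & Lempitsky's DANN paper: λ(p) = 2 / (1 + exp(-γ·p)) - 1 with γ = 10, where p is training progress from 0 to 1. Whether train_dann.py uses exactly γ = 10 is an assumption. The alpha default below is a placeholder; only the 0.2 severity weight comes from this guide.

```python
import math

def ganin_lambda(progress, gamma=10.0):
    """GRL coefficient: 0 at the start of training, approaching 1 at the end.

    progress: fraction of training completed, clipped to [0, 1].
    gamma=10 is the value from the original DANN paper (assumed here).
    """
    p = min(max(progress, 0.0), 1.0)
    return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0

def dann_total_loss(disease_loss, severity_loss, domain_loss, progress, alpha=1.0):
    """Objective from this guide: disease + 0.2*severity + alpha*lambda*domain.

    alpha=1.0 is a hypothetical default. The discriminator minimizes the
    domain term, while the GRL flips its gradient for the backbone.
    """
    lam = ganin_lambda(progress)
    return disease_loss + 0.2 * severity_loss + alpha * lam * domain_loss

for p in (0.0, 0.25, 0.5, 1.0):
    print(p, round(ganin_lambda(p), 3))
```

The ramp keeps the adversarial signal near zero while the disease head is still adapting from the warm-started checkpoint. The full adversarial signal is applied only later in training.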
Outputs → outputs_v3/dann/:
- best_model.pth: DANN-trained checkpoint
- history.json: per-epoch metrics
- dashboard.png: training curves (disease + domain accuracy)
Expected results: Domain accuracy should converge toward ~50%. That is chance level for two domains, meaning the backbone features no longer reveal whether an image came from APTOS or ODIR. Disease accuracy should improve, especially on APTOS DR images.
6D-NEW. RETFound Backbone (alternative to 6C)
# In your training script, replace:
from retinasense_v3 import MultiTaskViT
# With:
from retfound_backbone import MultiTaskRetFound, setup_retfound
# Download RETFound weights (once)
setup_retfound()
# Create model with retinal-pretrained backbone
model = MultiTaskRetFound(retfound_path='./weights/RETFound_cfp_weights.pth')
What it does: Swaps the ImageNet-pretrained ViT for RETFound, a ViT-Base/16 pretrained on 1.6 million retinal images using masked autoencoding. The architecture is the same, but the features are learned from retinal rather than natural images, which is expected to suit retinal pathology better.
6E. K-Fold Cross-Validation (~2 hours on H100)
python kfold_cv.py
What it does: Runs 5-fold stratified CV on the full train+calib pool (7,265 images). Produces a per-metric mean ± std across folds for the paper.
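Stratified fold assignment can be sketched without extra dependencies. Shuffle each class's indices and deal them round-robin across folds, so every fold keeps roughly the class proportions of the pool. This is a simplified stand-in: kfold_cv.py may use scikit-learn's StratifiedKFold instead, which is not confirmed here. The seed is a hypothetical default.

```python
import numpy as np

def stratified_folds(labels, n_folds=5, seed=42):
    """Return fold_id[i] in [0, n_folds) for each sample.

    Per-class counts differ by at most one between folds.
    """
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    fold_id = np.empty(len(labels), dtype=int)
    offset = 0
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        rng.shuffle(idx)
        # Deal shuffled indices round-robin; the running offset spreads each
        # class's leftover samples over different folds
        fold_id[idx] = (np.arange(len(idx)) + offset) % n_folds
        offset += len(idx)
    return fold_id

def iter_folds(labels, n_folds=5, seed=42):
    """Yield (train_idx, val_idx) pairs, one per fold."""
    fold_id = stratified_folds(labels, n_folds, seed)
    for k in range(n_folds):
        yield np.flatnonzero(fold_id != k), np.flatnonzero(fold_id == k)
```

Stratification matters here because Glaucoma, Cataract and AMD are small classes. A plain random split could leave a fold with too few minority-class examples for a stable per-class F1.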
Config (inside kfold_cv.py):
N_FOLDS=5, N_EPOCHS=30, PATIENCE=8, BATCH_SIZE=32
BASE_LR=3e-4, LLRD_DECAY=0.75, MIXUP_ALPHA=0.4, FOCAL_GAMMA=1.0
Outputs → outputs_v3/kfold/:
- fold_1_best.pth ... fold_5_best.pth
- kfold_results.json: mean ± std per metric
- fold_comparison.png: bar chart per fold
- perclass_f1_boxplot.png: per-class F1 boxplot
Expected results:
Accuracy: 60-75% ± ~5%
Balanced Accuracy: 75-85% ± ~3%
Macro F1: 0.65-0.75 ± ~0.03
Macro AUC: 0.90-0.96 ± ~0.02
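Turning per-fold metrics into the mean ± std entries for Table 3 is simple arithmetic. The sketch below assumes a list of per-fold metric dicts; the actual schema of kfold_results.json is not documented here. It uses the sample standard deviation (n - 1 denominator), which is a choice worth stating explicitly in the paper.

```python
import statistics

def summarize_folds(fold_metrics):
    """fold_metrics: list of dicts such as {"accuracy": ..., "macro_f1": ...}, one per fold.

    Returns {metric: (mean, sample_std)}.
    """
    summary = {}
    for key in fold_metrics[0]:
        values = [m[key] for m in fold_metrics]
        summary[key] = (statistics.mean(values), statistics.stdev(values))  # stdev uses n - 1
    return summary

def format_pct(mean, std):
    """Format a fraction pair as 'XX.X% ± X.X%' for the paper tables."""
    return f"{100 * mean:.1f}% ± {100 * std:.1f}%"
```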
Troubleshooting:
| Problem | Fix |
|---|---|
| CUDA OOM | Reduce BATCH_SIZE from 32 to 16 in kfold_cv.py |
| Cache miss | Check preprocessed_cache_v3/ has .npy files |
| KeyError: mean_rgb | data/fundus_norm_stats.json must exist |
| Arch mismatch | MultiTaskViT in kfold_cv.py must match retinasense_v3.py |
After completion: update SESSION_CONTEXT.md with the mean ± std numbers.
6F. Knowledge Distillation + ONNX Export (~30 min on H100)
cd <repo-root>  # was /teamspace/studios/this_studio on the GPU server
python knowledge_distillation.py
What it does: Compresses ViT-Base (86M params) → ViT-Tiny (5.7M params) using knowledge distillation, then exports to ONNX and quantizes to INT8 for CPU deployment.
Config (inside knowledge_distillation.py):
KD_ALPHA=0.3 (30% CE loss + 70% KL distillation loss)
KD_TEMP=4.0 (softens teacher logits for distillation)
Student: vit_tiny_patch16_224 (192-dim, ~5.7M params)
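KD_ALPHA and KD_TEMP correspond to the usual Hinton-style distillation objective. It combines hard-label cross-entropy with a KL term between temperature-softened teacher and student distributions. The NumPy sketch below assumes the common T² scaling of the KL term; whether knowledge_distillation.py applies that factor is not stated here.

```python
import numpy as np

KD_ALPHA, KD_TEMP = 0.3, 4.0

def log_softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def kd_loss(student_logits, teacher_logits, labels, alpha=KD_ALPHA, T=KD_TEMP):
    """student_logits, teacher_logits: [B, 5]; labels: int array [B]."""
    # Hard-label cross-entropy on the unsoftened student logits
    log_p = log_softmax(student_logits)
    ce = -log_p[np.arange(len(labels)), labels].mean()
    # KL(teacher || student) on temperature-softened distributions
    log_ps = log_softmax(student_logits / T)
    log_pt = log_softmax(teacher_logits / T)
    kl = (np.exp(log_pt) * (log_pt - log_ps)).sum(axis=1).mean()
    # T**2 keeps soft-target gradients on the same scale as the CE term (assumed)
    return alpha * ce + (1 - alpha) * T ** 2 * kl
```

With alpha = 0.3, 70% of the weight goes to matching the teacher's softened probabilities. At T = 4 those probabilities still carry the teacher's ranking of the wrong classes.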
Outputs → outputs_v3/distillation/:
- student_best.pth: distilled ViT-Tiny checkpoint
- retinasense_student.onnx: ONNX model (opset 17, dynamic batch)
- retinasense_student_int8.onnx: INT8 quantized (~6MB)
- distillation_results.json: accuracy comparison + CPU benchmark
- distillation_curves.png: training curves
Expected results:
Teacher (ViT-Base): Acc~74%, AUC~0.95, Size=331MB
Student (ViT-Tiny): Acc~68-72%, AUC~0.91-0.94, Size~23MB
INT8 quantized: Acc~67-71%, Size~6MB, CPU inference ~80ms
7. Running the Web Applications
7A. Gradio Web Demo
python app.py
- Opens on port 7860 (also generates a public shareable URL)
- Features:
- Ensemble prediction: ViT-Base/16 (35%) + EfficientNet-B3 (65%)
- Test-Time Augmentation: 4 augmented versions averaged (h-flip, v-flip, ±10° rotation)
- Attention Rollout heatmap: ViT attention visualization
- MC Dropout uncertainty: 15 stochastic passes (epistemic + aleatoric split)
- Clinical triage: AUTO-SCREEN / PRIORITY REVIEW / URGENT / RESCAN
- Model disagreement detection: shows when ViT and EfficientNet disagree
- OOD detection: Mahalanobis distance (gracefully skipped if npz missing)
- Downloadable clinical report: .txt file with full analysis
What app.py does internally:
- Loads outputs_v3/best_model.pth (ViT-Base/16)
- Loads outputs_v3/ensemble/efficientnet_b3.pth (EfficientNet-B3)
- Loads configs: temperature.json, thresholds.json, fundus_norm_stats.json
- Preprocessing: crop borders → resize 224 (INTER_AREA) → CLAHE → circular mask → normalize
- Runs TTA (4 augmentations) through both models
- Computes ensemble probabilities (35/65 weighted average)
- Runs MC Dropout (15 passes) for uncertainty
- Computes triage level based on confidence + uncertainty + model agreement
- Generates Attention Rollout heatmap
- Generates clinical recommendation + downloadable report
Preprocessing pipeline (CRITICAL: must match training):
import cv2, numpy as np
from torchvision import transforms

def preprocess_image(img_pil):
    img = np.array(img_pil.convert('RGB'))  # PIL image -> uint8 RGB array [H,W,3]
    img = crop_black_borders(img)            # remove dark padding
    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA)
    img = apply_clahe(img)                   # CLAHE on L-channel in LAB
    img = apply_circular_mask(img)           # zero pixels outside fundus circle (r=0.48)
    norm = transforms.Normalize(mean=[0.4298, 0.2784, 0.1559], std=[0.2857, 0.2065, 0.1465])
    return norm(transforms.ToTensor()(img))  # float tensor [3,224,224]
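crop_black_borders and apply_circular_mask are app.py helpers, but their bodies are not reproduced in this guide. The NumPy versions below are hypothetical reconstructions for readers who want to reproduce the pipeline elsewhere. The darkness threshold of 10 is an assumption. The r=0.48 value is read as a fraction of the shorter image side, with the circle centred in the image. CLAHE would typically use OpenCV's cv2.createCLAHE on the L channel after converting to LAB. It is omitted here because app.py's clip limit and tile size are not documented.

```python
import numpy as np

def crop_black_borders(img, thresh=10):
    """Crop rows/columns that are entirely near-black. img: uint8 [H, W, 3].

    thresh=10 is an assumed cut-off, not a value taken from app.py.
    """
    bright = img.mean(axis=2) > thresh
    rows = np.flatnonzero(bright.any(axis=1))
    cols = np.flatnonzero(bright.any(axis=0))
    if rows.size == 0 or cols.size == 0:  # fully dark image: leave unchanged
        return img
    return img[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]

def apply_circular_mask(img, r=0.48):
    """Zero every pixel outside a centred circle of radius r * min(H, W)."""
    h, w = img.shape[:2]
    yy, xx = np.ogrid[:h, :w]
    cy, cx = (h - 1) / 2, (w - 1) / 2
    inside = (yy - cy) ** 2 + (xx - cx) ** 2 <= (r * min(h, w)) ** 2
    return img * inside[..., None].astype(img.dtype)
```

The order matters, as the bug log notes. Cropping first ensures the 224x224 resize, and therefore the circle, is centred on the fundus rather than on the padded frame.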
Clinical Triage System:
| Level | Criteria | Action |
|---|---|---|
| AUTO-SCREEN | Confidence > 70%, low uncertainty, models agree | Routine re-screening |
| PRIORITY REVIEW | Confidence 40-70%, or elevated uncertainty | Specialist within 2 weeks |
| URGENT SPECIALIST | Confidence < 40%, or high uncertainty, or models disagree | Specialist within 48 hours |
| RESCAN NEEDED | OOD detected | Image quality issue, rescan |
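The triage table can be read as an ordered set of rules, checked from most to least severe. This sketch is hypothetical. The confidence cut-offs (0.40, 0.70) come from the table, but the numeric uncertainty thresholds are placeholders, since app.py's actual values and uncertainty scale are not listed in this guide.

```python
def triage(confidence, uncertainty, models_agree, is_ood,
           elevated_unc=0.10, high_unc=0.20):
    """Map one prediction to a triage level.

    confidence: ensemble top-1 probability in [0, 1].
    uncertainty: MC Dropout uncertainty score (scale and thresholds are placeholders).
    is_ood: True/False, or None when the OOD detector is unavailable.
    """
    if is_ood:  # None (detector missing) is treated as "not OOD"
        return "RESCAN NEEDED"
    if confidence < 0.40 or uncertainty >= high_unc or not models_agree:
        return "URGENT SPECIALIST"
    if confidence <= 0.70 or uncertainty >= elevated_unc:
        return "PRIORITY REVIEW"
    return "AUTO-SCREEN"
```

Checking the severe conditions first means any single red flag escalates the case. AUTO-SCREEN is reached only when every criterion in its row holds.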
7B. FastAPI REST Server
cd <repo-root>  # was /teamspace/studios/this_studio on the GPU server
python -m uvicorn api.main:app --host 0.0.0.0 --port 8000
Or with auto-reload during development:
python -m uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload
Endpoints:
| Method | Endpoint | Description |
|---|---|---|
| GET | /health | Returns {"status": "healthy", "model": "loaded"} |
| POST | /predict | Single image inference |
| POST | /predict/batch | Batch inference (multiple images) |
| GET | /docs | Interactive Swagger UI |
Test the API:
# Health check
curl http://localhost:8000/health
# Single prediction (replace with a real fundus image path)
curl -X POST http://localhost:8000/predict \
-F "file=@/path/to/fundus_image.jpg"
# Response format:
{
"predicted_class": "Diabetes/DR",
"predicted_label": 1,
"confidence": 0.819,
"probabilities": {"Normal": 0.08, "Diabetes/DR": 0.82, "Glaucoma": 0.03, "Cataract": 0.04, "AMD": 0.03},
"severity": "Moderate",
"ood_score": 12.4,
"is_ood": false,
"inference_time_ms": 370
}
7C. Docker Deployment
cd <repo-root>  # was /teamspace/studios/this_studio on the GPU server
# Build the image
docker build -t retinasense .
# Run the container
docker run -p 8000:8000 retinasense
# With GPU support
docker run --gpus all -p 8000:8000 retinasense
8. Running XAI / Attention Rollout on New Images
cd <repo-root>  # was /teamspace/studios/this_studio on the GPU server
python gradcam_v3.py
What it does: Runs Attention Rollout (Abnar & Zuidema 2020) on 20 sample images
(4 per class), saves heatmap overlays to outputs_v3/gradcam/.
Key settings in gradcam_v3.py:
discard_ratio = 0.97 # discard bottom 97% of attention (sharp focus)
alpha = 0.7 # heatmap overlay strength
power_stretch = 0.4 # contrast enhancement (np.power(spatial, 0.4))
Note: Standard Grad-CAM does NOT work on ViT (CLS-token problem: the patch tokens receive zero gradients).
Always use ViTAttentionRollout, never GradCAM for this model.
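Attention Rollout itself is short enough to sketch in NumPy. The version below follows the Abnar & Zuidema recipe: fuse the heads, add the identity to account for residual connections, renormalize the rows, and multiply across layers. It adds the discard_ratio and power_stretch settings listed above. Fusing heads by mean and keeping the CLS→CLS entry during the discard step are assumptions. Collecting the per-layer attention maps from the timm model (for example with forward hooks) is not shown.

```python
import numpy as np

def attention_rollout(attentions, discard_ratio=0.97, power=0.4):
    """attentions: list of per-layer arrays [heads, N, N], N = 1 + 14*14 for ViT-B/16 at 224px.

    Returns a [14, 14] map in [0, 1] of how strongly the CLS token attends to each patch.
    """
    n = attentions[0].shape[-1]
    result = np.eye(n)
    for attn in attentions:
        fused = attn.mean(axis=0)                 # fuse heads by averaging (assumed)
        flat = fused.flatten()                    # a copy, safe to modify
        drop = np.argsort(flat)[:int(flat.size * discard_ratio)]  # weakest weights
        drop = drop[drop != 0]                    # keep the CLS -> CLS entry
        flat[drop] = 0.0
        a = 0.5 * flat.reshape(n, n) + 0.5 * np.eye(n)  # residual connection
        a = a / a.sum(axis=-1, keepdims=True)     # rows sum to 1 again
        result = a @ result
    mask = result[0, 1:]                          # CLS row, patch tokens only
    side = int(round(np.sqrt(mask.size)))
    mask = mask.reshape(side, side)
    if mask.max() > 0:
        mask = mask / mask.max()
    return np.power(mask, power)                  # contrast stretch (power_stretch)
```

The identity term is what separates rollout from simply reading the last layer's attention. It models the residual path, so information that bypasses a layer is still credited to its source patch.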
9. Key Numbers (Current Performance)
Single-Run Test Set (1,287 images)
| Metric | Value |
|---|---|
| Overall Accuracy | 49.1% |
| Balanced Accuracy | 76.7% |
| Macro F1 | 0.626 |
| Macro AUC | 0.893 |
| Cohen Kappa | 0.293 |
Per-class breakdown:
| Class | Precision | Recall | F1 | AUC | Support |
|---|---|---|---|---|---|
| Normal | 0.336 | 0.968 | 0.499 | 0.770 | 310 |
| Diabetes/DR | 0.995 | 0.253 | 0.404 | 0.830 | 837 |
| Glaucoma | 0.717 | 0.635 | 0.673 | 0.877 | 52 |
| Cataract | 0.681 | 0.979 | 0.803 | 0.988 | 48 |
| AMD | 0.597 | 1.000 | 0.748 | 0.998 | 40 |
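The gap between 49.1% accuracy and 76.7% balanced accuracy follows directly from the per-class table. Balanced accuracy is the unweighted mean of per-class recall, while accuracy weights each class's recall by its support. DR dominates that weighting: 837 of the 1,287 test images, at a recall of only 0.253. A quick check using only the numbers in the table:

```python
recall  = {"Normal": 0.968, "Diabetes/DR": 0.253, "Glaucoma": 0.635, "Cataract": 0.979, "AMD": 1.000}
support = {"Normal": 310,   "Diabetes/DR": 837,   "Glaucoma": 52,    "Cataract": 48,    "AMD": 40}
f1      = {"Normal": 0.499, "Diabetes/DR": 0.404, "Glaucoma": 0.673, "Cataract": 0.803, "AMD": 0.748}

n = sum(support.values())                                    # test-set size
balanced_acc = sum(recall.values()) / len(recall)            # unweighted mean recall
accuracy = sum(recall[c] * support[c] for c in recall) / n   # support-weighted recall
macro_f1 = sum(f1.values()) / len(f1)                        # unweighted mean F1

print(f"balanced accuracy = {balanced_acc:.3f}")
print(f"accuracy          = {accuracy:.3f}")
print(f"macro F1          = {macro_f1:.3f}")
```

This reproduces 0.767 and 0.491. Macro F1 comes out at 0.625 rather than 0.626 only because the per-class F1 values in the table are themselves rounded.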
Ensemble (ViT 35% + EfficientNet-B3 65%):
| Metric | Value |
|---|---|
| Accuracy | 74.7% |
| Macro F1 | 0.712 |
| Macro AUC | 0.951 |
| Disagreement rate | 44.2% (flags for human review) |
Domain Gap
| Domain | N | Accuracy | ECE |
|---|---|---|---|
| ODIR | 709 | 67.7% | 0.157 |
| APTOS | 570 | 26.5% | 0.510 |
| REFUGE2 | 8 | 12.5% | n/a |
Critical finding: DR recall = 25.3%. 573 of 837 DR images misclassified as Normal. DR precision = 99.5% β when the model says DR, it's correct. But it rarely fires. Root cause: APTOS Ben Graham preprocessing creates a domain shift from ODIR CLAHE images.
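10. Configuration Files
The paths below are the ones used on the original GPU server. In a fresh clone, the temperature, threshold and normalization files are committed under configs/.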
Temperature & Thresholds
outputs_v3/temperature.json:
{"temperature": 0.6438, "ece_before": 0.1618, "ece_after": 0.1014}
outputs_v3/thresholds.json:
{
"thresholds": [0.638, 0.068, 0.840, 0.564, 0.289],
"class_names": ["Normal", "Diabetes/DR", "Glaucoma", "Cataract", "AMD"]
}
Note: DR threshold = 0.068 (very low). Despite this, APTOS DR images still get classified as Normal because their DR probability stays below even this threshold.
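The ece_before / ece_after values are Expected Calibration Error: the support-weighted gap between accuracy and mean confidence across confidence bins. The NumPy sketch below uses the standard equal-width-bin definition. The choice of 15 bins is an assumption, because the bin count behind the numbers above is not recorded.

```python
import numpy as np

def expected_calibration_error(probs, labels, n_bins=15):
    """probs: [N, C] predicted probabilities; labels: int array [N].

    ECE = sum over bins of (bin size / N) * |accuracy(bin) - mean confidence(bin)|.
    """
    conf = probs.max(axis=1)
    correct = (probs.argmax(axis=1) == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            ece += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return ece
```

The before/after comparison would then presumably evaluate this function on softmax(logits) and on softmax(logits / 0.6438) over the calibration split. Because T < 1, the second version scales the logits up by about 1.55× and makes every prediction more confident.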
Normalization Stats
data/fundus_norm_stats.json keys: mean_rgb and std_rgb (not mean/std).
import json
with open('data/fundus_norm_stats.json') as f:
    stats = json.load(f)
mean = stats['mean_rgb']  # [0.4298, 0.2784, 0.1559]
std = stats['std_rgb']    # [0.2857, 0.2065, 0.1465]
OOD Detector
import numpy as np
data = np.load('outputs_v3/ood_detector.npz')
# Keys: class_means, precision_matrix, threshold
# threshold = 42.82 (Mahalanobis distance)
# If distance > 42.82 β flag as out-of-distribution
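The stored class_means, precision_matrix and threshold suggest the usual class-conditional Mahalanobis detector with a shared covariance: score an image by its distance to the nearest class mean in feature space. The sketch below makes two assumptions. First, the features are the 768-dim backbone embeddings. Second, the threshold applies to the plain, not squared, distance. The None branch mirrors the graceful fallback for a missing npz described in the bug fix log.

```python
import numpy as np

def mahalanobis_ood(feature, class_means, precision_matrix, threshold):
    """feature: [D]; class_means: [C, D]; precision_matrix: [D, D] (inverse shared covariance).

    Returns (min_distance, is_ood). is_ood is None when no threshold is available,
    so callers can skip the OOD check instead of crashing.
    """
    diff = class_means - feature                                  # [C, D]
    d2 = np.einsum("cd,de,ce->c", diff, precision_matrix, diff)   # squared distance per class
    dist = float(np.sqrt(d2.min()))
    if threshold is None:
        return dist, None
    return dist, dist > float(threshold)
```

With the npz loaded as in the snippet above, the call would be mahalanobis_ood(feat, data['class_means'], data['precision_matrix'], data['threshold']). Here feat is a hypothetical feature vector for one image.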
11. Data Loading Pattern
CSV columns: image_path, label, source, cache_path
import pandas as pd, numpy as np, torch
from torchvision import transforms

df = pd.read_csv('data/test_split.csv')

# Cached images are already preprocessed and resized, so only normalization is applied
normalize = transforms.Normalize(mean=[0.4298, 0.2784, 0.1559],
                                 std=[0.2857, 0.2065, 0.1465])

# Fast path: load from preprocessed cache (already resized to 224x224)
def load_image(row):
    cache = row['cache_path']   # e.g. preprocessed_cache_v3/xxxx.npy
    img = np.load(cache)        # uint8 [H,W,3]
    img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0  # float [3,H,W] in [0,1]
    return normalize(img)       # apply only normalize, not resize

# Sources: 'APTOS', 'ODIR', 'REFUGE2'
# Labels: 0=Normal, 1=DR, 2=Glaucoma, 3=Cataract, 4=AMD
12. Complete Task Checklist
Done (Training, GPU server)
- Train ViT-Base/16 model (retinasense_v3.py): epoch 24, F1=0.854
- Temperature calibration (T=0.6438)
- Per-class threshold optimization
- Attention Rollout XAI (gradcam_v3.py): heatmaps working
- OOD Mahalanobis detector
- Evaluation dashboard (eval_dashboard.py): 7 output files
- MC Dropout uncertainty (mc_dropout_uncertainty.py): 6 output files
- Integrated Gradients XAI (integrated_gradients_xai.py): 23 output files
- Fairness/domain analysis (fairness_analysis.py): 7 output files
- EfficientNet-B3 ensemble (train_ensemble.py): 74.7% acc, AUC=0.951
Done (App, session 2026-03-10)
- Ensemble inference in app.py (ViT 35% + EfficientNet 65%)
- Fixed preprocessing pipeline (border crop + circular mask + INTER_AREA)
- Test-Time Augmentation (4 augmented versions averaged)
- Clinical triage system (AUTO-SCREEN / REVIEW / URGENT / RESCAN)
- Model disagreement detection (ViT vs EfficientNet)
- OOD graceful handling (no crash when npz missing)
- Gradio web app with all features: port 7860
- FastAPI REST server (api/main.py): port 8000
- Docker deployment (Dockerfile)
Done (New training code, ready for GPU)
- unified_preprocessing.py: single CLAHE pipeline for all sources
- train_dann.py: Domain-Adversarial Neural Network training
- retfound_backbone.py: RETFound foundation model backbone support
- enhanced_augmentation.py: CutMix, elastic deform, 5× minority oversampling
- prepare_datasets.py: download/prep for EyePACS, MESSIDOR-2, REFUGE, ADAM, ORIGA
Needs GPU (run in this order)
- Rebuild cache: python unified_preprocessing.py (~30 min)
- Expand dataset: python prepare_datasets.py --all (optional)
- DANN training: python train_dann.py (~2-3 hrs on H100)
- K-Fold CV: python kfold_cv.py (~2 hrs on H100)
- Knowledge Distillation: python knowledge_distillation.py (~30 min on H100)
- (Alternative) RETFound backbone: modify training to use retfound_backbone.py
Known Issues
- DR recall = 25.3%. Root cause: APTOS domain shift (Ben Graham vs CLAHE preprocessing). Fix: unified_preprocessing.py + train_dann.py
- AMD under-represented (~40 samples). Fix: prepare_datasets.py to add the ADAM dataset (~1,200 images, AMD + Normal)
- Model overconfident on garbage input (86% confidence on random noise). Inherent model property.
- Ensemble thresholds not calibrated: current thresholds are for ViT-only. Fix: recalibrate after DANN retraining.
- Temperature T=0.6438 is below 1, so it sharpens the softmax (logits are scaled up by about 1.55×) and makes wrong predictions more confident as well. Consider recalibrating after retraining.
13. Model Weights β Hugging Face Hub
Model weights are NOT in git (too large). They are hosted on Hugging Face:
Repo: https://huggingface.co/tanishq74/retinasense-vit
| File | Size | Description |
|---|---|---|
| best_model.pth | 331MB | ViT-Base/16 trained checkpoint (epoch 24) |
| efficientnet_b3.pth | 45MB | EfficientNet-B3 ensemble checkpoint |
Download automatically (run once):
bash setup.sh # handles everything including model download
Or download manually:
from huggingface_hub import hf_hub_download
import shutil, os

os.makedirs("outputs_v3/ensemble", exist_ok=True)
for fname, dest in [
    ("best_model.pth", "outputs_v3/best_model.pth"),
    ("efficientnet_b3.pth", "outputs_v3/ensemble/efficientnet_b3.pth"),
]:
    path = hf_hub_download(repo_id="tanishq74/retinasense-vit", filename=fname)  # downloads to the HF cache
    shutil.copy(path, dest)  # copy into the location app.py expects
14. Quick-Start Commands
# Full local setup from scratch (cloned repo):
bash setup.sh # installs deps + downloads models from HF
# Check what's already been generated
ls outputs_v3/evaluation/ outputs_v3/uncertainty/ outputs_v3/xai/ outputs_v3/fairness/ outputs_v3/ensemble/
# Start Gradio app
python app.py
# Start FastAPI server
python -m uvicorn api.main:app --host 0.0.0.0 --port 8000
# Run K-Fold CV (GPU required, ~2 hrs)
python kfold_cv.py
# Run Knowledge Distillation (GPU required, ~30 min)
python knowledge_distillation.py
# Check GPU
nvidia-smi
# Check disk space (outputs can be large)
df -h .
# Check running processes
ps aux | grep python | grep -v grep
15. Paper Writing Notes
Use these numbers in the paper:
Table 1 β Main Results (Single Run, Test Set n=1,287):
Overall Accuracy: 49.1% (dominated by class imbalance: 837/1287 test images are DR, where recall is 25.3%)
Balanced Accuracy: 76.7%
Macro F1: 0.626
Macro AUC: 0.893
Table 2 β Ensemble Results:
ViT-Base alone: Acc=49.1%, F1=0.626, AUC=0.893
EfficientNet-B3 alone: Acc=72.2%, F1=0.658, AUC=0.923
ViT+EfficientNet Ensemble: Acc=74.7%, F1=0.712, AUC=0.951
Table 3 β K-Fold CV (fill in after running kfold_cv.py):
Accuracy: XX.X% ± X.X%
Balanced Accuracy: XX.X% ± X.X%
Macro F1: 0.XXX ± 0.XXX
Macro AUC: 0.XXX ± 0.XXX
Table 4 β DANN Results (fill in after running train_dann.py):
Overall Accuracy: XX.X% (expected: 80-85%)
DR Recall: XX.X% (expected: 60-75%, up from 25.3%)
Macro F1: 0.XXX (expected: 0.75-0.85)
Domain accuracy: ~50% (closer to 50% = more domain-invariant)
Reviewer issues addressed:
- Missing per-class metrics → eval_dashboard.py outputs metrics_report.json
- Missing ROC/PR curves → outputs_v3/evaluation/roc_curves_per_class.png, precision_recall_curves.png
- No cross-validation → kfold_cv.py (run on GPU)
- No confidence intervals → kfold_results.json (mean ± std, after K-Fold)
- No statistical significance → fairness_analysis.py, chi-squared test p=0.296
- No uncertainty quantification → mc_dropout_uncertainty.py (epistemic/aleatoric split)
16. Research Novelty (Paper Differentiators)
Novelty 1: Domain-Adversarial Retinal Screening
- Problem: Cross-dataset domain shift (APTOS vs ODIR) causes 25.3% DR recall
- Solution: DANN with gradient reversal forces domain-invariant features
- Implementation: train_dann.py
- Paper angle: "Domain-Invariant Retinal Disease Classification Across Heterogeneous Fundus Image Sources"
Novelty 2: Uncertainty-Guided Clinical Triage
- Problem: When should the AI auto-screen vs defer to a human?
- Solution: Combine confidence + MC Dropout uncertainty + ensemble disagreement into triage levels
- Implementation: app.py (live, working)
- Paper angle: "Uncertainty-Aware Retinal Screening: When to Trust the AI and When to Defer"
Novelty 3: RETFound Foundation Model Transfer
- Problem: ImageNet features are suboptimal for retinal pathology
- Solution: Use RETFound (MAE-pretrained on 1.6M retinal images) as backbone
- Implementation: retfound_backbone.py
- Paper angle: "Parameter-Efficient Transfer from Retinal Foundation Models for Small-Dataset Classification"
Novelty 4: Preprocessing-Induced Domain Shift Analysis
- Problem: Different preprocessing (Ben Graham vs CLAHE) creates artificial domain shift
- Finding: CLAHE alone shifts Glaucoma probability by +43 percentage points
- Solution: Unified CLAHE pipeline, intended to eliminate this shift (pending retraining)
- Implementation: unified_preprocessing.py
- Paper angle: Novel contribution, giving documented evidence that preprocessing choices create measurable domain shift in retinal AI
17. Critical Bug Fix Log (2026-03-10)
These bugs were found during the investigation session and are now fixed:
| Bug | Severity | Root Cause | Fix |
|---|---|---|---|
| Wrong predictions on all images | CRITICAL | app.py missing circular mask + border crop in preprocessing | Added _crop_black_borders() and _apply_circular_mask() |
| Wrong predictions on all images | CRITICAL | app.py used INTER_LINEAR resize vs training's INTER_AREA | Changed to cv2.INTER_AREA |
| OOD report crash | HIGH | ood.ood_threshold is None when npz missing, so the format string fails | Added None check with graceful fallback |
| CLAHE +43 pp Glaucoma shift | MODERATE | CLAHE applied blindly vs domain-conditional during training | Documented; fix requires retraining with unified_preprocessing.py |
Investigation methodology: 4 parallel agents analyzed architecture, preprocessing, raw model outputs, and normalization. Full findings in memory files.