---
license: mit
tags:
- retinal-disease
- medical-imaging
- vision-transformer
- efficientnet
- hybrid-model
- ophthalmology
datasets:
- aptos2019
- odir5k
metrics:
- accuracy
- f1
- roc_auc
pipeline_tag: image-classification
library_name: pytorch
model-index:
- name: RetinaSense-ViT-v4
  results:
  - task:
      type: image-classification
      name: Retinal Disease Classification
    dataset:
      name: APTOS-2019 + ODIR-5K (10,000 balanced images)
      type: custom
    metrics:
    - name: Test Accuracy
      type: accuracy
      value: 0.820
    - name: Macro AUC
      type: roc_auc
      value: 0.969
    - name: Macro F1
      type: f1
      value: 0.822
    - name: 5-Fold CV Accuracy
      type: accuracy
      value: 0.9113
    - name: 5-Fold CV Macro F1
      type: f1
      value: 0.910
---

# RetinaSense-ViT v4: Hybrid Vision Transformer for Retinal Disease Classification

**RetinaSense-ViT v4** is a hybrid deep learning model that fuses **ViT-Base/16** and **EfficientNet-B3** feature representations for automated retinal disease screening from color fundus photographs. It classifies each image into one of 5 categories (Normal plus four diseases) and provides calibrated confidence estimates, uncertainty quantification, and explainable attention maps.
**HuggingFace Repository:** [tanishq74/retinasense-vit](https://huggingface.co/tanishq74/retinasense-vit)

---

## Model Details

| Property | Value |
|----------|-------|
| **Architecture** | HybridRetinaModel (ViT-Base/16 + EfficientNet-B3 fusion) |
| **Parameters** | 97.8M |
| **Input Size** | 224 x 224 x 3 (RGB fundus photograph) |
| **Output** | 5-class probability distribution |
| **Framework** | PyTorch |
| **License** | MIT |

### Architecture Overview

The HybridRetinaModel extracts complementary features from two pretrained backbones and fuses them through concatenation followed by a classification MLP:

```
Input Image (224x224x3)
        |
        +---> EfficientNet-B3 ---> 1536-dim features --+
        |                                              |
        +---> ViT-Base/16 -------> 768-dim features ---+
                                                       |
                                          Concatenation (2304-dim)
                                                       |
                                                MLP Classifier
                                                       |
                                             5-class predictions
```

- **EfficientNet-B3 branch:** Extracts local texture and fine-grained pathological features (1536-dim pooled features)
- **ViT-Base/16 branch:** Captures global spatial relationships and long-range dependencies (768-dim CLS token)
- **Fusion layer:** The concatenated 2304-dim representation is passed through a multi-layer perceptron with batch normalization and dropout
- **Transfer learning:** Initialized from the v3 ViT checkpoint (v3 `best_model.pth`, 82.59% val acc) and the EfficientNet-B3 checkpoint (`efficientnet_b3.pth`, 71.1% val acc)

### Disease Classes

| Label | Class | Description |
|-------|-------|-------------|
| 0 | Normal | No retinal disease detected |
| 1 | Diabetes/DR | Diabetic retinopathy (microaneurysms, hemorrhages, exudates) |
| 2 | Glaucoma | Optic disc cupping and glaucomatous changes |
| 3 | Cataract | Lens opacity affecting fundus image clarity |
| 4 | AMD | Age-related macular degeneration (drusen, atrophy) |

---

## Training

### Dataset

The training dataset was constructed by merging two public retinal imaging datasets:

| Source | Original Size | After Cleaning |
|--------|--------------|----------------|
| APTOS-2019 | 3,662 images | Deduplicated and quality-filtered |
| ODIR-5K | 6,392 images | Deduplicated and quality-filtered |
| **Merged (raw)** | **8,905 images** | -- |
| **After deduplication** | -- | **5,270 images** |
| **After balancing** | -- | **10,000 images (2,000/class)** |

Balancing combined undersampling of over-represented classes with augmentation-based oversampling of under-represented classes to reach exactly 2,000 images per class.

### Data Splits

| Split | Samples | Purpose |
|-------|---------|---------|
| Train | 7,038 | Model training |
| Validation | 1,476 | Hyperparameter tuning and early stopping |
| Test | 1,486 | Final held-out evaluation |

All splits are **patient-aware** to prevent data leakage: no patient appears in more than one split.

### Preprocessing Pipeline

The preprocessing pipeline must be replicated exactly for correct inference:

```
Raw Image -> Crop Black Borders -> Resize 224x224 -> CLAHE (L-channel) -> Circular Mask -> Normalize
```

| Step | Details |
|------|---------|
| Border Crop | Crop to the bounding box of pixels with grayscale intensity > 7, removing dark padding |
| Resize | 224 x 224 with `cv2.INTER_AREA` interpolation |
| CLAHE | Applied to the L-channel in LAB color space (clipLimit=2.0, tileGridSize=8x8) |
| Circular Mask | Zero out pixels outside a centered circle (radius = 0.48 x min dimension) |
| Normalize | mean=[0.4298, 0.2784, 0.1559], std=[0.2857, 0.2065, 0.1465] |

### Training Configuration

| Parameter | Value |
|-----------|-------|
| Optimizer | AdamW with Layer-wise Learning Rate Decay (LLRD) |
| Scheduler | OneCycleLR with cosine annealing |
| Loss Function | Focal Loss (gamma=2.0) with label smoothing (0.1) |
| Batch Size | 32 |
| Learning Rate | 1e-4 |
| Epochs | 40 |
| Augmentation | MixUp, CutMix, albumentations (flips, rotation, color jitter, elastic transforms) |
| Regularization | Dropout, Stochastic Weight Averaging (SWA) |
| Fine-tuning | Lesion attention training with GradCAM-guided attention loss |

---

## Results

### 5-Fold Stratified Cross-Validation (10,000 images)

| Metric | Mean | Std |
|--------|------|-----|
| Accuracy | 91.13% | +/- 0.55% |
| Macro F1 | 0.910 | +/- 0.006 |
| Macro AUC | 0.986 | +/- 0.001 |

### Held-Out Test Set (1,486 samples)

| Metric | Value |
|--------|-------|
| Accuracy | 80.9% (82.0% with optimized thresholds) |
| Macro F1 | 0.813 (0.822 with optimized thresholds) |
| Macro AUC | 0.969 |
| Cohen's Kappa | 0.761 |
| MC Dropout Accuracy @ 90% retention | 86.0% |

### Per-Class Test Set Performance

These per-class F1 scores average to 0.813, matching the default (argmax) macro F1 above rather than the threshold-optimized 0.822.

| Class | F1 | AUC | Precision | Recall |
|-------|----|-----|-----------|--------|
| Normal | 0.69 | 0.926 | 0.573 | 0.857 |
| Diabetes/DR | 0.78 | 0.965 | 0.844 | 0.726 |
| Glaucoma | 0.78 | 0.981 | 0.925 | 0.670 |
| Cataract | 0.95 | 0.997 | 0.940 | 0.966 |
| AMD | 0.87 | 0.977 | 0.917 | 0.827 |

### Improvement over v3

| Metric | v3 (Ensemble) | v4 (Hybrid Fusion, optimized thresholds) | Delta |
|--------|--------------|------------------------------------------|-------|
| Accuracy | 74.7% | 82.0% | **+7.3 pp** |
| Macro F1 | 0.712 | 0.822 | **+0.110** |
| Macro AUC | 0.951 | 0.969 | **+0.018** |

---

## Usage

### Installation

```bash
pip install torch torchvision timm opencv-python-headless numpy huggingface_hub
```

### Download Model Weights

```python
import os
import shutil

from huggingface_hub import hf_hub_download

os.makedirs("weights", exist_ok=True)
for fname in ["best_model.pth", "temperature.json", "thresholds.json"]:
    # Download into the local Hugging Face cache, then copy into ./weights
    path = hf_hub_download(repo_id="tanishq74/retinasense-vit", filename=fname)
    shutil.copy(path, f"weights/{fname}")
```

### Inference Example

```python
import json

import cv2
import numpy as np
import timm
import torch
import torch.nn as nn
from torchvision import transforms


# --- Model Definition ---
class HybridRetinaModel(nn.Module):
    def __init__(self, num_classes=5, drop_rate=0.3):
        super().__init__()
        # EfficientNet-B3 branch (1536-dim)
        self.efficientnet = timm.create_model(
            'efficientnet_b3', pretrained=False, num_classes=0
        )
        # ViT-Base/16 branch (768-dim)
        self.vit = timm.create_model(
            'vit_base_patch16_224',
            pretrained=False, num_classes=0
        )
        # Fusion MLP: 1536 + 768 = 2304 -> num_classes
        fusion_dim = 2304
        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(drop_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(drop_rate * 0.67),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        eff_features = self.efficientnet(x)  # (B, 1536)
        vit_features = self.vit(x)           # (B, 768)
        fused = torch.cat([eff_features, vit_features], dim=1)  # (B, 2304)
        return self.classifier(fused)


# --- Preprocessing ---
MEAN = [0.4298, 0.2784, 0.1559]
STD = [0.2857, 0.2065, 0.1465]
CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']

normalize = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])


def preprocess(img_path):
    bgr = cv2.imread(img_path)
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Crop black borders: keep the bounding box of pixels brighter than 7
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    mask = gray > 7
    rows, cols = np.any(mask, axis=1), np.any(mask, axis=0)
    if rows.any() and cols.any():
        r0, r1 = np.where(rows)[0][[0, -1]]
        c0, c1 = np.where(cols)[0][[0, -1]]
        img = img[r0:r1 + 1, c0:c1 + 1]

    # Resize
    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA)

    # CLAHE on the L-channel of LAB
    lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    lab[:, :, 0] = clahe.apply(lab[:, :, 0])
    img = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

    # Circular mask (radius = 0.48 x min dimension)
    h, w = img.shape[:2]
    cmask = np.zeros((h, w), dtype=np.uint8)
    cv2.circle(cmask, (w // 2, h // 2), int(min(h, w) * 0.48), 255, -1)
    img = cv2.bitwise_and(img, img, mask=cmask)

    # Normalize and add a batch dimension: (1, 3, 224, 224)
    return normalize(np.clip(img, 0, 255).astype(np.uint8)).unsqueeze(0)


# --- Load Model ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HybridRetinaModel(num_classes=5).to(device)
checkpoint = torch.load('weights/best_model.pth', map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Load temperature for calibrated probabilities
with open('weights/temperature.json') as f:
    temperature = json.load(f)['temperature']

# Load per-class optimized thresholds (optional; not used by the argmax prediction below)
with open('weights/thresholds.json') as f:
    thresholds = json.load(f)

# --- Run Inference ---
x = preprocess('your_fundus_image.jpg').to(device)
with torch.no_grad():
    logits = model(x)
    probs = torch.softmax(logits / temperature, dim=1).cpu().numpy()[0]

pred_class = CLASS_NAMES[probs.argmax()]
confidence = probs.max()
print(f"Prediction: {pred_class} ({confidence * 100:.1f}%)")
print("\nClass Probabilities:")
for i, name in enumerate(CLASS_NAMES):
    print(f"  {name}: {probs[i] * 100:.1f}%")
```

### MC Dropout Uncertainty Estimation

```python
def enable_mc_dropout(model):
    """Put the model in eval mode but keep Dropout layers stochastic.

    Calling model.train() instead would also switch BatchNorm to batch statistics,
    which raises an error for a single image (BatchNorm1d) and overwrites the
    stored running statistics.
    """
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()


def mc_dropout_predict(model, x, temperature=1.0, n_passes=30):
    """Run n_passes stochastic forward passes with dropout enabled at test time."""
    enable_mc_dropout(model)
    predictions = []
    with torch.no_grad():
        for _ in range(n_passes):
            logits = model(x)
            probs = torch.softmax(logits / temperature, dim=1)
            predictions.append(probs.cpu().numpy())
    model.eval()  # restore deterministic inference

    predictions = np.array(predictions)  # (n_passes, batch, classes)
    mean_probs = predictions.mean(axis=0)
    # Inter-pass variance, summed over classes
    epistemic_uncertainty = predictions.var(axis=0).sum(axis=-1)
    return mean_probs, epistemic_uncertainty


mean_probs, uncertainty = mc_dropout_predict(model, x, temperature=temperature)
print(f"Prediction: {CLASS_NAMES[mean_probs[0].argmax()]}")
print(f"Epistemic Uncertainty: {uncertainty[0]:.4f}")
```

---

## Features

### GradCAM Attention Visualization

Generate class-discriminative heatmaps highlighting the retinal regions most relevant to the model's prediction. These are useful for clinical interpretability and for verifying that the model attends to pathologically relevant structures.

### MC Dropout Uncertainty Estimation

Monte Carlo Dropout with 30 stochastic forward passes provides epistemic uncertainty estimates. At 90% retention (discarding the 10% of predictions with the highest uncertainty), accuracy improves from 80.9% to 86.0%.
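Accuracy at a given retention rate can be computed from per-sample uncertainty scores with a small helper. The sketch below is a minimal illustration, not the repository's evaluation code: `accuracy_at_retention` is a hypothetical name, and it assumes you have already collected true labels, predicted labels, and one uncertainty score per test image (for example the summed variance returned by `mc_dropout_predict`). It ranks samples by uncertainty, keeps the least uncertain fraction, and measures accuracy on what remains.

```python
import numpy as np


def accuracy_at_retention(y_true, y_pred, uncertainty, retention=0.9):
    """Accuracy on the `retention` fraction of samples with the lowest uncertainty.

    Hypothetical helper (not part of the repository). `uncertainty` can be any
    per-sample score where larger means less certain, e.g. the summed MC Dropout
    variance returned by mc_dropout_predict.
    """
    if not 0.0 < retention <= 1.0:
        raise ValueError("retention must be in (0, 1]")
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    uncertainty = np.asarray(uncertainty, dtype=float)

    # Number of predictions to keep (at least one)
    n_keep = max(1, int(round(retention * len(y_true))))
    # Ascending sort: the first n_keep indices are the least uncertain samples
    keep = np.argsort(uncertainty, kind="stable")[:n_keep]
    return float(np.mean(y_true[keep] == y_pred[keep]))


# Toy example: 10 predictions; the only error also has the highest uncertainty
y_true = np.array([0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
y_pred = np.array([0, 1, 2, 3, 4, 0, 1, 2, 3, 0])
unc = np.array([0.01, 0.02, 0.03, 0.01, 0.02, 0.05, 0.04, 0.03, 0.02, 0.40])

print(accuracy_at_retention(y_true, y_pred, unc, retention=1.0))  # 0.9
print(accuracy_at_retention(y_true, y_pred, unc, retention=0.9))  # 1.0
```

Sweeping `retention` from 1.0 downward traces an accuracy-coverage curve. The toy data only shows the mechanics; it does not reproduce the reported 86.0%, which depends on the model's actual test-set predictions and uncertainty scores.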
### Temperature-Calibrated Probabilities

Post-hoc temperature scaling divides the logits by a scalar temperature (stored in `temperature.json`) before the softmax, so that predicted confidence values more closely match empirical accuracy.

### Per-Class Optimized Decision Thresholds

Class-specific decision thresholds optimized on the validation set improve macro F1 from 0.813 to 0.822 (and accuracy from 80.9% to 82.0%) on the held-out test set.

### FAISS-Based Similar Case Retrieval

A FAISS index built from the model's 2304-dimensional fusion embeddings enables fast nearest-neighbor retrieval of visually and diagnostically similar cases from the training set, supporting clinical decision-making through case-based reasoning.

---

## Files in Repository

| File / Directory | Description |
|------------------|-------------|
| `best_model.pth` | Trained HybridRetinaModel checkpoint (v4) |
| `temperature.json` | Calibrated temperature scaling parameter |
| `thresholds.json` | Per-class optimized decision thresholds |
| `evaluation/` | Evaluation outputs: confusion matrices, ROC curves, per-class metrics, calibration plots |
| `retrieval/` | FAISS index and retrieval scripts for similar case lookup |
| `kfold/` | 5-fold cross-validation results and per-fold metrics |
| `models/hybrid_retina_model.py` | Model architecture definition |
| `training/retinasense_v4.py` | Main training script |
| `training/kfold_cv.py` | 5-fold cross-validation script |
| `training/lesion_attention_training.py` | GradCAM-guided attention fine-tuning |
| `evaluation/eval_dashboard.py` | Comprehensive evaluation dashboard |
| `retrieval/build_index.py` | FAISS index construction |
| `retrieval/query_index.py` | Similar case retrieval interface |

---

## Limitations and Ethical Considerations

### Technical Limitations

- **Normal class has lower F1 (0.69):** Early-stage disease presentations overlap visually with normal fundus images. The Normal class has high recall (0.857) but low precision (0.573), so roughly 43% of images predicted as Normal actually belong to a disease class. Clinically, this means a Normal prediction can miss disease and should not be used to rule it out.
- **Dataset scope:** Trained exclusively on APTOS-2019 and ODIR-5K datasets. Performance may degrade on fundus images from different camera systems, patient demographics, or imaging protocols not represented in the training data.
- **Single-label classification:** Each image receives one predicted label. Co-morbid conditions (e.g., concurrent DR and glaucoma) are not modeled.
- **Cross-validation vs. test gap:** The 5-fold CV accuracy (91.1%) is higher than the held-out test accuracy (82.0%), which may reflect distribution differences between augmented training data and real test images.

### Intended Use

This model is intended for **research and educational purposes only**. It may serve as a screening aid or decision-support tool in research settings, but it is **not a medical device** and has **not been validated for clinical deployment**.

### Out-of-Scope Uses

- Clinical diagnosis without ophthalmologist verification
- Deployment as a standalone screening tool in any healthcare setting
- Use on non-fundus images or imaging modalities other than color fundus photography
- Medico-legal decision-making

---

## Citation

```bibtex
@misc{retinasense_v4_2026,
  title={RetinaSense-ViT v4: Hybrid Vision Transformer for Retinal Disease Classification},
  author={Tanishq Tamarkar},
  year={2026},
  url={https://huggingface.co/tanishq74/retinasense-vit},
  note={ViT-Base/16 + EfficientNet-B3 hybrid fusion model, 5-class retinal disease classification}
}
```

---

## License

MIT License

---

## Disclaimer

This is a research prototype for AI-assisted retinal screening. It is **NOT** a certified medical device and should **NOT** be used for clinical decision-making without independent verification by a qualified ophthalmologist. All predictions are probabilistic estimates and may be incorrect. The authors assume no liability for any clinical decisions made based on this model's outputs.