#!/usr/bin/env python3 """ RetinaSense v3.0 — Production Training Script ============================================== Vision Transformer (ViT-Base-Patch16-224) with multi-task heads for retinal disease classification and diabetic retinopathy severity grading. v3 Enhancements over ViT baseline: 1. Layer-wise Learning Rate Decay (LLRD, decay=0.75) 2. WeightedRandomSampler for class imbalance 3. MixUp augmentation (alpha=0.4) with Focal Loss mixing 4. CosineAnnealingWarmRestarts (T_0=25, T_mult=2) 5. Extended training: 100 epochs, patience=20 on macro-F1 6. Fundus-specific normalisation (loads from data/fundus_norm_stats.json) 7. 3-way train/calib/test split (CSV-based or auto 70/15/15) 8. Temperature scaling (post-training calibration on calib set) 9. Per-class threshold optimisation on calib set, final eval on test set Usage: python retinasense_v3.py """ import os import sys import time import warnings import json import numpy as np import pandas as pd import cv2 import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import seaborn as sns from PIL import Image from tqdm import tqdm from collections import Counter warnings.filterwarnings('ignore') import torch import torch.nn as nn import torch.nn.functional as F from torch.amp import autocast, GradScaler from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler from torchvision import transforms import timm from scipy.optimize import minimize_scalar from sklearn.model_selection import train_test_split from sklearn.utils.class_weight import compute_class_weight from sklearn.metrics import ( classification_report, confusion_matrix, roc_auc_score, f1_score, roc_curve, auc ) from sklearn.preprocessing import label_binarize # ================================================================ # CONFIG # ================================================================ class Config: DATA_DIR = './data' CACHE_DIR = './preprocessed_cache_v3' OUTPUT_DIR = './outputs_v3' MODEL_NAME = 'vit_base_patch16_224' IMG_SIZE = 224 NUM_DISEASE_CLASSES = 5 NUM_SEVERITY_CLASSES = 5 DROPOUT = 0.3 # reduced from 0.4 in v2 BATCH_SIZE = 32 NUM_EPOCHS = 3 NUM_WORKERS = 8 BASE_LR = 3e-4 LLRD_DECAY = 0.75 WEIGHT_DECAY = 1e-4 GRADIENT_ACCUMULATION = 2 # effective batch = 64 FOCAL_GAMMA = 1.0 MIXUP_ALPHA = 0.4 PATIENCE = 3 MIN_DELTA = 0.001 CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD'] DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Paths for 3-way splits TRAIN_CSV = './data/train_split.csv' CALIB_CSV = './data/calib_split.csv' TEST_CSV = './data/test_split.csv' # ImageNet fallback normalisation IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] cfg = Config() os.makedirs(cfg.OUTPUT_DIR, exist_ok=True) os.makedirs(cfg.CACHE_DIR, exist_ok=True) os.makedirs(cfg.DATA_DIR, exist_ok=True) print('=' * 65) print(' RetinaSense v3.0 — Production Training Pipeline') print('=' * 65) if torch.cuda.is_available(): print(f' GPU : {torch.cuda.get_device_name(0)}') print(f' VRAM : {round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1)} GB') print(f' Backbone : {cfg.MODEL_NAME} (timm)') print(f' Epochs : {cfg.NUM_EPOCHS} (patience={cfg.PATIENCE})') print(f' Batch : {cfg.BATCH_SIZE} (eff. {cfg.BATCH_SIZE * cfg.GRADIENT_ACCUMULATION} via grad accum)') print(f' LLRD decay : {cfg.LLRD_DECAY}') print(f' MixUp alpha : {cfg.MIXUP_ALPHA}') print(f' Focal gamma : {cfg.FOCAL_GAMMA}') print('=' * 65) # ================================================================ # STEP 1 — NORMALISATION STATS # ================================================================ print('\n[1/9] Loading normalisation stats...') norm_stats_path = os.path.join(cfg.DATA_DIR, 'fundus_norm_stats.json') if os.path.exists(norm_stats_path): with open(norm_stats_path) as f: norm_stats = json.load(f) NORM_MEAN = norm_stats['mean_rgb'] NORM_STD = norm_stats['std_rgb'] print(f' Fundus-specific stats loaded: mean={NORM_MEAN}, std={NORM_STD}') else: NORM_MEAN = cfg.IMAGENET_MEAN NORM_STD = cfg.IMAGENET_STD print(f' fundus_norm_stats.json not found — using ImageNet defaults') print(f' mean={NORM_MEAN}, std={NORM_STD}') # ================================================================ # STEP 2 — METADATA # ================================================================ print('\n[2/9] Building metadata...') BASE = './' disease_cols = ['N', 'D', 'G', 'C', 'A'] label_map = {'N': 0, 'D': 1, 'G': 2, 'C': 3, 'A': 4} def _load_odir(base): """Load and filter ODIR metadata to single-label samples.""" odir_csv = os.path.join(base, 'odir', 'full_df.csv') if not os.path.exists(odir_csv): print(' WARNING: ODIR CSV not found, skipping ODIR samples') return pd.DataFrame() df = pd.read_csv(odir_csv) df['disease_count'] = df[disease_cols].sum(axis=1) df = df[df['disease_count'] == 1].copy() def get_label(row): for d in disease_cols: if row[d] == 1: return label_map[d] df['disease_label'] = df.apply(get_label, axis=1) img_col = next( c for c in df.columns if any(k in c.lower() for k in ['filename', 'fundus', 'image']) ) out = pd.DataFrame({ 'image_path': os.path.join(base, 'odir', 'preprocessed_images') + '/' + df[img_col].astype(str), 'source': 'ODIR', 'disease_label': df['disease_label'], 'severity_label': -1, }) return out def _load_aptos(base): """Load APTOS metadata.""" aptos_csv = os.path.join(base, 'aptos', 'train.csv') if not os.path.exists(aptos_csv): print(' WARNING: APTOS CSV not found, skipping APTOS samples') return pd.DataFrame() df = pd.read_csv(aptos_csv) out = pd.DataFrame({ 'image_path': os.path.join(base, 'aptos', 'train_images') + '/' + df['id_code'] + '.png', 'source': 'APTOS', 'disease_label': 1, 'severity_label': df['diagnosis'], }) return out def _load_refuge2(base): """Load REFUGE2 Glaucoma-only subset (~400 images). Only the Glaucoma class is used — targeted fix for the weakest class (308 samples). Images are Zeiss Visucam 500 quality — no Ben Graham needed.""" glaucoma_dir = os.path.join(base, 'refuge2', 'Training400', 'Glaucoma') if not os.path.exists(glaucoma_dir): print(' WARNING: REFUGE2 not found, skipping (expected: refuge2/Training400/Glaucoma/)') return pd.DataFrame() imgs = [os.path.join(glaucoma_dir, f) for f in os.listdir(glaucoma_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))] if not imgs: return pd.DataFrame() out = pd.DataFrame({ 'image_path': imgs, 'source': 'REFUGE2', 'disease_label': 2, # Glaucoma = class 2 'severity_label': -1, }) print(f' REFUGE2 Glaucoma: {len(out)} images loaded') return out odir_meta = _load_odir(BASE) aptos_meta = _load_aptos(BASE) refuge2_meta = _load_refuge2(BASE) parts = [df for df in [odir_meta, aptos_meta, refuge2_meta] if len(df) > 0] if len(parts) == 0: raise RuntimeError('No dataset found. Place ODIR/APTOS data under ./odir and ./aptos.') meta = pd.concat(parts, ignore_index=True) meta = meta[meta['image_path'].apply(os.path.exists)].reset_index(drop=True) # severity -1 (unknown) → 0 meta['severity_label'] = meta['severity_label'].clip(lower=0).fillna(0).astype(int) print(f' Total valid samples: {len(meta)}') dist = meta['disease_label'].value_counts().sort_index() for i, cnt in dist.items(): print(f' {cfg.CLASS_NAMES[i]:15s}: {cnt:4d} ({100 * cnt / len(meta):.1f}%)') # ================================================================ # STEP 3 — PRE-CACHE # ================================================================ print(f'\n[3/9] Pre-caching images @ {cfg.IMG_SIZE}x{cfg.IMG_SIZE}...') def _read_rgb(path): """Read image from disk as RGB numpy array.""" img = cv2.imread(path) if img is None: img = np.array(Image.open(path).convert('RGB')) img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) def _circular_mask(img, sz): mask = np.zeros(img.shape[:2], dtype=np.uint8) cv2.circle(mask, (sz // 2, sz // 2), int(sz * 0.48), 255, -1) return cv2.bitwise_and(img, img, mask=mask) def ben_graham(path, sz=cfg.IMG_SIZE, sigma=10): """Ben Graham enhancement for APTOS field-camera images. Removes low-frequency illumination gradients, amplifies vessel/lesion detail.""" img = cv2.resize(_read_rgb(path), (sz, sz)) img = cv2.addWeighted(img, 4, cv2.GaussianBlur(img, (0, 0), sigma), -4, 128) return _circular_mask(img, sz) def clahe_preprocess(path, sz=cfg.IMG_SIZE): """CLAHE preprocessing for ODIR multi-source clinical images. Normalises local contrast without destroying fine vessel/drusen detail.""" img = cv2.resize(_read_rgb(path), (sz, sz)) lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB) clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) lab[:, :, 0] = clahe.apply(lab[:, :, 0]) img = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB) return _circular_mask(img, sz) def resize_only(path, sz=cfg.IMG_SIZE): """Minimal preprocessing for already-clinical-grade images (REFUGE2). Zeiss Visucam 500 images are standardised high quality — no enhancement needed.""" img = cv2.resize(_read_rgb(path), (sz, sz)) return _circular_mask(img, sz) def preprocess_image(path, source, sz=cfg.IMG_SIZE): """Source-conditional preprocessing dispatcher. APTOS -> Ben Graham (field camera, vignetting correction) ODIR -> CLAHE (multi-source clinical, contrast normalisation) REFUGE2 -> Resize only (Zeiss Visucam 500, already high quality) """ src = str(source).upper() if src == 'APTOS': return ben_graham(path, sz) if src == 'REFUGE2': return resize_only(path, sz) return clahe_preprocess(path, sz) def _cache_key(image_path): """Filename-based cache key (basename without extension).""" stem = os.path.splitext(os.path.basename(image_path))[0] return os.path.join(cfg.CACHE_DIR, f'{stem}_{cfg.IMG_SIZE}.npy') cache_paths = [] cached = 0 for _, row in tqdm(meta.iterrows(), total=len(meta), desc='Caching'): fp = _cache_key(row['image_path']) if not os.path.exists(fp): try: np.save(fp, preprocess_image(row['image_path'], row['source'])) except Exception: np.save(fp, np.zeros((cfg.IMG_SIZE, cfg.IMG_SIZE, 3), dtype=np.uint8)) cached += 1 cache_paths.append(fp) meta['cache_path'] = cache_paths print(f' Newly cached: {cached} | Already cached: {len(meta) - cached}') # ================================================================ # STEP 4 — 3-WAY SPLIT # ================================================================ print('\n[4/9] Preparing train / calib / test splits...') def _load_or_create_splits(meta_df): """ Load splits from CSV files if they exist (train/calib/test). Otherwise perform a stratified 70/15/15 auto-split and persist the CSVs so future runs are reproducible. Returns (train_df, calib_df, test_df). """ splits_exist = (os.path.exists(cfg.TRAIN_CSV) and os.path.exists(cfg.CALIB_CSV) and os.path.exists(cfg.TEST_CSV)) if splits_exist: train_df = pd.read_csv(cfg.TRAIN_CSV) calib_df = pd.read_csv(cfg.CALIB_CSV) test_df = pd.read_csv(cfg.TEST_CSV) # Regenerate if any source is in current metadata but absent from saved splits stale = False for src in ['APTOS', 'REFUGE2']: if (src in meta_df['source'].values and ('source' not in train_df.columns or src not in train_df['source'].values)): print(f' Stale splits detected ({src} missing) — regenerating...') stale = True break if stale: splits_exist = False # fall through to recreate else: print(f' Loaded existing splits: train={len(train_df)}, ' f'calib={len(calib_df)}, test={len(test_df)}') if not splits_exist: print(' Split files not found — creating 70/15/15 stratified split...') train_df, temp_df = train_test_split( meta_df, test_size=0.30, stratify=meta_df['disease_label'], random_state=42 ) calib_df, test_df = train_test_split( temp_df, test_size=0.50, stratify=temp_df['disease_label'], random_state=42 ) train_df.to_csv(cfg.TRAIN_CSV, index=False) calib_df.to_csv(cfg.CALIB_CSV, index=False) test_df.to_csv(cfg.TEST_CSV, index=False) print(f' Auto-split saved: train={len(train_df)}, ' f'calib={len(calib_df)}, test={len(test_df)}') return train_df, calib_df, test_df train_df, calib_df, test_df = _load_or_create_splits(meta) # ================================================================ # STEP 5 — DATASET + TRANSFORMS # ================================================================ print('\n[5/9] Building dataset and loaders...') def make_transforms(phase): """ Return torchvision transform pipeline. Train: spatial augmentation + color jitter + random erasing. Val / calib / test: deterministic normalisation only. """ normalize = transforms.Normalize(NORM_MEAN, NORM_STD) if phase == 'train': return transforms.Compose([ transforms.ToPILImage(), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(p=0.3), transforms.RandomRotation(20), transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05)), transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.02), transforms.ToTensor(), normalize, transforms.RandomErasing(p=0.2), ]) return transforms.Compose([ transforms.ToPILImage(), transforms.ToTensor(), normalize, ]) class RetinalDataset(Dataset): """ Retinal fundus image dataset. Loads from preprocessed_cache_v3/ using a filename-based key. Falls back to on-the-fly ben_graham preprocessing if cache is missing (rare; cache is built in step 3). severity_label -1 is mapped to 0 (unknown severity). """ def __init__(self, df, transform): self.df = df.reset_index(drop=True) self.transform = transform def __len__(self): return len(self.df) def __getitem__(self, idx): row = self.df.iloc[idx] # Load from cache (fast path) cache_fp = row.get('cache_path', _cache_key(row['image_path'])) try: img = np.load(cache_fp) except Exception: # Fallback: source-conditional preprocess on the fly try: img = preprocess_image(row['image_path'], row.get('source', 'ODIR')) except Exception: img = np.zeros((cfg.IMG_SIZE, cfg.IMG_SIZE, 3), dtype=np.uint8) img_tensor = self.transform(img) disease_lbl = int(row['disease_label']) severity_lbl = int(row['severity_label']) # Map -1 (unknown) → 0 if severity_lbl < 0: severity_lbl = 0 return ( img_tensor, torch.tensor(disease_lbl, dtype=torch.long), torch.tensor(severity_lbl, dtype=torch.long), ) # --- WeightedRandomSampler for training --- def _make_weighted_sampler(df): """ Compute per-sample weights inversely proportional to class frequency. Every batch will see all 5 classes roughly equally. """ labels = df['disease_label'].values class_cnt = np.bincount(labels, minlength=cfg.NUM_DISEASE_CLASSES).astype(float) class_cnt = np.where(class_cnt == 0, 1.0, class_cnt) weights = 1.0 / class_cnt[labels] return WeightedRandomSampler( weights=torch.DoubleTensor(weights), num_samples=len(weights), replacement=True, ) train_ds = RetinalDataset(train_df, make_transforms('train')) calib_ds = RetinalDataset(calib_df, make_transforms('val')) test_ds = RetinalDataset(test_df, make_transforms('val')) sampler = _make_weighted_sampler(train_df) train_loader = DataLoader( train_ds, batch_size=cfg.BATCH_SIZE, sampler=sampler, # WeightedRandomSampler replaces shuffle=True num_workers=cfg.NUM_WORKERS, pin_memory=True, persistent_workers=True, prefetch_factor=2, ) calib_loader = DataLoader( calib_ds, batch_size=cfg.BATCH_SIZE, shuffle=False, num_workers=cfg.NUM_WORKERS, pin_memory=True, persistent_workers=True, ) test_loader = DataLoader( test_ds, batch_size=cfg.BATCH_SIZE, shuffle=False, num_workers=cfg.NUM_WORKERS, pin_memory=True, persistent_workers=True, ) print(f' Train : {len(train_ds):5d} ({len(train_loader):3d} batches) — WeightedRandomSampler') print(f' Calib : {len(calib_ds):5d} ({len(calib_loader):3d} batches)') print(f' Test : {len(test_ds):5d} ({len(test_loader):3d} batches) [SEALED until final eval]') # ================================================================ # STEP 6 — MODEL, LOSS, LLRD OPTIMIZER # ================================================================ print('\n[6/9] Building model and optimizer...') # --- Focal Loss --- class FocalLoss(nn.Module): """ Focal Loss — down-weights easy examples, focuses on hard ones. alpha: per-class weight tensor; gamma: focusing parameter. """ def __init__(self, alpha=None, gamma=2.0): super().__init__() self.gamma = gamma if alpha is not None: self.register_buffer('alpha', alpha) else: self.alpha = None def forward(self, logits, targets): ce = F.cross_entropy(logits, targets, reduction='none') pt = torch.exp(-ce) focal = ((1 - pt) ** self.gamma) * ce if self.alpha is not None: at = self.alpha.gather(0, targets) focal = at * focal return focal.mean() # --- Multi-task ViT --- class MultiTaskViT(nn.Module): """ ViT-Base-Patch16-224 backbone with two classification heads: - disease_head : 5-class fundus disease classification - severity_head : 5-class DR severity grading (APTOS only) Dropout reduced to 0.3 (vs 0.4 in v2) since LLRD + MixUp already provide strong regularisation. """ def __init__(self, n_disease=cfg.NUM_DISEASE_CLASSES, n_severity=cfg.NUM_SEVERITY_CLASSES, drop=cfg.DROPOUT): super().__init__() self.backbone = timm.create_model( cfg.MODEL_NAME, pretrained=True, num_classes=0 ) feat = 768 # ViT-Base CLS token dimension self.drop = nn.Dropout(drop) self.disease_head = nn.Sequential( nn.Linear(feat, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3), nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2), nn.Linear(256, n_disease), ) self.severity_head = nn.Sequential( nn.Linear(feat, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, n_severity), ) def forward(self, x): f = self.backbone(x) # (B, 768) f = self.drop(f) return self.disease_head(f), self.severity_head(f) # --- Layer-wise Learning Rate Decay (LLRD) --- def get_optimizer_with_llrd(model, base_lr=cfg.BASE_LR, decay_factor=cfg.LLRD_DECAY): """ Build AdamW with per-parameter-group learning rates following LLRD. Strategy (head → patch_embed, each step multiplies by decay_factor): - disease_head / severity_head / drop : base_lr (full rate = 3e-4) - blocks[11] : base_lr * decay^1 - blocks[10] : base_lr * decay^2 ... - blocks[0] : base_lr * decay^12 - patch_embed + cls_token + pos_embed : base_lr * decay^13 (~1e-6) - norm : same as last block Returns: AdamW optimizer with separate param groups. """ param_groups = [] # 1. Classification heads (full LR) head_params = ( list(model.disease_head.parameters()) + list(model.severity_head.parameters()) + list(model.drop.parameters()) ) param_groups.append({'params': head_params, 'lr': base_lr}) # 2. Transformer blocks (12 blocks, indexed 11 → 0) blocks = model.backbone.blocks # nn.Sequential of 12 blocks num_blocks = len(blocks) for block_idx in range(num_blocks - 1, -1, -1): distance_from_head = num_blocks - block_idx # 1 for block[11], 12 for block[0] lr_i = base_lr * (decay_factor ** distance_from_head) param_groups.append({ 'params': list(blocks[block_idx].parameters()), 'lr': lr_i, }) # 3. Patch embedding + positional embedding + CLS token + norm embed_lr = base_lr * (decay_factor ** (num_blocks + 1)) embed_params = ( list(model.backbone.patch_embed.parameters()) + [model.backbone.cls_token, model.backbone.pos_embed] + list(model.backbone.norm.parameters()) ) param_groups.append({'params': embed_params, 'lr': embed_lr}) optimizer = torch.optim.AdamW( param_groups, weight_decay=cfg.WEIGHT_DECAY, ) # Log LR distribution lrs = [g['lr'] for g in param_groups] print(f' LLRD optimizer: {len(param_groups)} param groups') print(f' Head LR : {lrs[0]:.2e}') print(f' Block[11] : {lrs[1]:.2e}') print(f' Block[0] : {lrs[-2]:.2e}') print(f' Embed LR : {lrs[-1]:.2e}') return optimizer # --- Instantiate model --- model = MultiTaskViT().to(cfg.DEVICE) # --- Focal loss class weights (computed on train set) --- cw = compute_class_weight('balanced', classes=np.arange(cfg.NUM_DISEASE_CLASSES), y=train_df['disease_label'].values) alpha = torch.tensor(cw, dtype=torch.float32).to(cfg.DEVICE) alpha = alpha / alpha.sum() * cfg.NUM_DISEASE_CLASSES # normalise print(f' Focal alpha: {[f"{a:.2f}" for a in alpha.tolist()]}') criterion_d = FocalLoss(alpha=alpha, gamma=cfg.FOCAL_GAMMA) criterion_s = nn.CrossEntropyLoss(ignore_index=-1) total_params = sum(p.numel() for p in model.parameters()) print(f' Total params: {total_params:,}') # --- Optimizer (LLRD) --- optimizer = get_optimizer_with_llrd(model) # --- Scheduler: OneCycleLR --- # 10% warmup then cosine decay — avoids the epoch-3 LR collapse from # CosineAnnealingWarmRestarts. Stepped once per optimizer update (per batch). scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, max_lr=[pg['lr'] for pg in optimizer.param_groups], steps_per_epoch=len(train_loader), epochs=cfg.NUM_EPOCHS, pct_start=0.1, anneal_strategy='cos', div_factor=10.0, final_div_factor=100.0, ) scaler = GradScaler() # ================================================================ # STEP 7 — MIXUP + TRAINING LOOP # ================================================================ def mixup_data(x, y, alpha=cfg.MIXUP_ALPHA): """ MixUp augmentation. Returns mixed inputs, the two label tensors, and the mixing coefficient. Loss is mixed externally: lam * L(pred, y_a) + (1-lam) * L(pred, y_b). """ lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0 batch_size = x.size(0) index = torch.randperm(batch_size, device=x.device) mixed_x = lam * x + (1 - lam) * x[index] return mixed_x, y, y[index], lam def evaluate(loader, model, criterion_d, criterion_s, device, desc='Eval'): """ Run inference on a DataLoader. Returns: loss : average total loss preds : numpy array of argmax predictions targets : numpy array of ground-truth labels probs : numpy array of softmax probabilities (N, C) """ model.eval() total_loss = 0.0 all_preds, all_targets, all_probs = [], [], [] with torch.no_grad(): for imgs, d_lbl, s_lbl in tqdm(loader, desc=desc, leave=False): imgs = imgs.to(device, non_blocking=True) d_lbl = d_lbl.to(device, non_blocking=True) s_lbl = s_lbl.to(device, non_blocking=True) with autocast('cuda'): d_out, s_out = model(imgs) ld = criterion_d(d_out, d_lbl) ls = criterion_s(s_out, s_lbl) loss = ld + 0.2 * ls if not (torch.isnan(loss) or torch.isinf(loss)): total_loss += loss.item() probs = torch.softmax(d_out.float(), dim=1) all_preds.extend(d_out.argmax(1).cpu().numpy()) all_targets.extend(d_lbl.cpu().numpy()) all_probs.extend(probs.cpu().numpy()) avg_loss = total_loss / len(loader) return (avg_loss, np.array(all_preds), np.array(all_targets), np.array(all_probs)) print('\n[7/9] Training...') CHECKPOINT = os.path.join(cfg.OUTPUT_DIR, 'best_model.pth') history = {k: [] for k in [ 'train_loss', 'val_loss', 'train_acc', 'val_acc', 'macro_f1', 'weighted_f1', 'lr', *(f'f1_{c}' for c in cfg.CLASS_NAMES) ]} best_f1 = 0.0 patience_ctr = 0 t_start = time.time() print('=' * 65) for epoch in range(cfg.NUM_EPOCHS): t0 = time.time() # ---- TRAIN ---- model.train() run_loss = 0.0 correct = 0 total = 0 optimizer.zero_grad(set_to_none=True) pbar = tqdm(train_loader, desc=f'E{epoch+1:03d}/{cfg.NUM_EPOCHS} train', leave=False) for step, (imgs, d_lbl, s_lbl) in enumerate(pbar): imgs = imgs.to(cfg.DEVICE, non_blocking=True) d_lbl = d_lbl.to(cfg.DEVICE, non_blocking=True) s_lbl = s_lbl.to(cfg.DEVICE, non_blocking=True) # MixUp augmentation (train only) mixed_imgs, y_a, y_b, lam = mixup_data(imgs, d_lbl, alpha=cfg.MIXUP_ALPHA) with autocast('cuda'): d_out, s_out = model(mixed_imgs) # Mixed Focal Loss: lam * L(y_a) + (1-lam) * L(y_b) loss_d = lam * criterion_d(d_out, y_a) + (1 - lam) * criterion_d(d_out, y_b) loss_s = criterion_s(s_out, s_lbl) loss = (loss_d + 0.2 * loss_s) / cfg.GRADIENT_ACCUMULATION if torch.isnan(loss) or torch.isinf(loss): optimizer.zero_grad(set_to_none=True) continue scaler.scale(loss).backward() if (step + 1) % cfg.GRADIENT_ACCUMULATION == 0: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) scaler.step(optimizer) scaler.update() scheduler.step() optimizer.zero_grad(set_to_none=True) run_loss += loss.item() * cfg.GRADIENT_ACCUMULATION # Use un-mixed predictions for accuracy tracking with torch.no_grad(): preds = d_out.argmax(1) correct += (preds == y_a).sum().item() total += d_lbl.size(0) pbar.set_postfix( loss=f'{loss.item() * cfg.GRADIENT_ACCUMULATION:.3f}', acc=f'{100 * correct / total:.1f}%' ) # Flush remaining gradients for incomplete accumulation window if (len(train_loader)) % cfg.GRADIENT_ACCUMULATION != 0: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) scaler.step(optimizer) scaler.update() scheduler.step() optimizer.zero_grad(set_to_none=True) train_loss = run_loss / len(train_loader) train_acc = 100 * correct / total # ---- VALIDATE on calibration set ---- val_loss, val_preds, val_targets, val_probs = evaluate( calib_loader, model, criterion_d, criterion_s, cfg.DEVICE, desc=f'E{epoch+1:03d}/{cfg.NUM_EPOCHS} calib' ) val_acc = 100 * (val_preds == val_targets).mean() mf1 = f1_score(val_targets, val_preds, average='macro') wf1 = f1_score(val_targets, val_preds, average='weighted') per_f1 = f1_score(val_targets, val_preds, average=None, labels=range(cfg.NUM_DISEASE_CLASSES), zero_division=0) lr_now = optimizer.param_groups[0]['lr'] history['train_loss'].append(train_loss) history['val_loss'].append(val_loss) history['train_acc'].append(train_acc) history['val_acc'].append(val_acc) history['macro_f1'].append(mf1) history['weighted_f1'].append(wf1) history['lr'].append(lr_now) for ci, cn in enumerate(cfg.CLASS_NAMES): history[f'f1_{cn}'].append(float(per_f1[ci])) elapsed = time.time() - t0 # ---- Early stopping on macro-F1 (with min_delta) ---- tag = '' if mf1 > best_f1 + cfg.MIN_DELTA: best_f1 = mf1 patience_ctr = 0 torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'val_acc': val_acc, 'macro_f1': mf1, 'history': history, }, CHECKPOINT) tag = f' * NEW BEST (macro-F1={mf1:.4f})' else: patience_ctr += 1 cls_str = ' | '.join( f'{cn[:3]}:{per_f1[ci]:.2f}' for ci, cn in enumerate(cfg.CLASS_NAMES) ) print( f'E{epoch+1:03d} | {elapsed:.0f}s | LR {lr_now:.2e} | ' f'TrL {train_loss:.3f} TrA {train_acc:.1f}% | ' f'VL {val_loss:.3f} VA {val_acc:.1f}% | ' f'mF1 {mf1:.4f} wF1 {wf1:.4f}{tag}' ) print(f' {cls_str}') if patience_ctr >= cfg.PATIENCE: print(f'\n Early stopping — no improvement for {cfg.PATIENCE} epochs') break total_train_time = time.time() - t_start print(f'\nTraining complete. Best macro-F1: {best_f1:.4f}') print(f'Total training time: {total_train_time / 60:.1f} minutes') # Save training history with open(os.path.join(cfg.OUTPUT_DIR, 'history.json'), 'w') as f: json.dump({k: [float(v) for v in vs] for k, vs in history.items()}, f, indent=2) # ================================================================ # STEP 8 — TEMPERATURE SCALING (post-training calibration) # ================================================================ print('\n[8/9] Temperature scaling on calibration set...') # Reload best model ckpt = torch.load(CHECKPOINT, map_location=cfg.DEVICE, weights_only=False) model.load_state_dict(ckpt['model_state_dict']) model.eval() print(f' Loaded best checkpoint (epoch {ckpt["epoch"]+1}, ' f'macro-F1={ckpt["macro_f1"]:.4f})') def _collect_logits_labels(loader, model, device): """Collect raw logits and labels (no softmax) from a DataLoader.""" all_logits, all_labels = [], [] with torch.no_grad(): for imgs, d_lbl, _ in tqdm(loader, desc='Collecting logits', leave=False): imgs = imgs.to(device, non_blocking=True) d_out, _ = model(imgs) all_logits.append(d_out.float().cpu()) all_labels.append(d_lbl.cpu()) return torch.cat(all_logits, dim=0), torch.cat(all_labels, dim=0) def _ece(probs, labels, n_bins=15): """ Expected Calibration Error. probs : numpy (N, C) softmax probabilities labels : numpy (N,) ground truth class indices """ confidences = probs.max(axis=1) predictions = probs.argmax(axis=1) accuracies = predictions == labels bin_edges = np.linspace(0, 1, n_bins + 1) ece_val = 0.0 for lo, hi in zip(bin_edges[:-1], bin_edges[1:]): mask = (confidences >= lo) & (confidences < hi) if mask.sum() == 0: continue acc_bin = accuracies[mask].mean() conf_bin = confidences[mask].mean() ece_val += mask.sum() * abs(acc_bin - conf_bin) return float(ece_val / len(labels)) calib_logits, calib_labels = _collect_logits_labels(calib_loader, model, cfg.DEVICE) # ECE before calibration probs_before = torch.softmax(calib_logits, dim=1).numpy() ece_before = _ece(probs_before, calib_labels.numpy()) print(f' ECE before temperature scaling: {ece_before:.4f}') def _nll_with_temperature(T, logits, labels): """Negative log-likelihood at temperature T (for scipy minimiser).""" scaled_logits = logits / T log_probs = F.log_softmax(scaled_logits, dim=1) nll = F.nll_loss(log_probs, labels).item() return nll result = minimize_scalar( fun=_nll_with_temperature, args=(calib_logits, calib_labels), bounds=(0.01, 10.0), method='bounded', ) T_opt = float(result.x) print(f' Optimal temperature T = {T_opt:.4f}') probs_after = torch.softmax(calib_logits / T_opt, dim=1).numpy() ece_after = _ece(probs_after, calib_labels.numpy()) print(f' ECE after temperature scaling: {ece_after:.4f}') # Save temperature temp_path = os.path.join(cfg.OUTPUT_DIR, 'temperature.json') with open(temp_path, 'w') as f: json.dump({'temperature': T_opt, 'ece_before': ece_before, 'ece_after': ece_after}, f, indent=2) print(f' Saved -> {temp_path}') # ================================================================ # STEP 9 — THRESHOLD OPTIMISATION ON CALIB SET # ================================================================ print('\n[9/9] Per-class threshold optimisation on calibration set...') def optimise_thresholds(probs, labels, n_classes, n_grid=50): """ Grid-search per-class decision thresholds on the calibration set. For each class c, sweep threshold in [0.05, 0.95] and pick the value maximising F1 for class c (one-vs-rest). Returns: list of per-class thresholds (length n_classes). """ thresholds = [] for c in range(n_classes): binary_labels = (labels == c).astype(int) best_t = 0.5 best_f1 = 0.0 for t in np.linspace(0.05, 0.95, n_grid): preds_c = (probs[:, c] >= t).astype(int) f = f1_score(binary_labels, preds_c, zero_division=0) if f > best_f1: best_f1 = f best_t = t thresholds.append(float(best_t)) print(f' {cfg.CLASS_NAMES[c]:15s}: threshold={best_t:.3f} (calib F1={best_f1:.3f})') return thresholds calib_thresholds = optimise_thresholds( probs_after, calib_labels.numpy(), cfg.NUM_DISEASE_CLASSES, ) thresh_path = os.path.join(cfg.OUTPUT_DIR, 'thresholds.json') with open(thresh_path, 'w') as f: json.dump({'thresholds': calib_thresholds, 'class_names': cfg.CLASS_NAMES}, f, indent=2) print(f' Saved -> {thresh_path}') def apply_thresholds(probs, thresholds): """ Apply per-class thresholds to probability matrix. Assigns each sample to the class with highest prob-above-threshold. Falls back to argmax if no class exceeds its threshold. """ preds = [] for prob_row in probs: above = [i for i, (p, t) in enumerate(zip(prob_row, thresholds)) if p >= t] preds.append(int(above[np.argmax([prob_row[i] for i in above])] if above else np.argmax(prob_row))) return np.array(preds) # ================================================================ # FINAL EVALUATION ON TEST SET (first and only time test is touched) # ================================================================ print('\n' + '=' * 65) print(' FINAL EVALUATION — TEST SET') print('=' * 65) print(' (Test set was never seen during training or threshold tuning)') test_logits, test_labels = _collect_logits_labels(test_loader, model, cfg.DEVICE) test_probs_calibrated = torch.softmax(test_logits / T_opt, dim=1).numpy() test_labels_np = test_labels.numpy() # Raw argmax predictions test_preds_raw = test_probs_calibrated.argmax(axis=1) # Threshold-adjusted predictions test_preds_thr = apply_thresholds(test_probs_calibrated, calib_thresholds) def _print_metrics(preds, targets, probs, label): acc = 100 * (preds == targets).mean() mf1 = f1_score(targets, preds, average='macro') wf1 = f1_score(targets, preds, average='weighted') try: mauc = roc_auc_score(targets, probs, multi_class='ovr', average='macro') except Exception: mauc = 0.0 per = f1_score(targets, preds, average=None, labels=range(cfg.NUM_DISEASE_CLASSES), zero_division=0) ece = _ece(probs, targets) print(f'\n [{label}]') print(f' Accuracy : {acc:.2f}%') print(f' Macro F1 : {mf1:.4f}') print(f' Weighted F1: {wf1:.4f}') print(f' Macro AUC : {mauc:.4f}') print(f' ECE : {ece:.4f}') print() print(classification_report(targets, preds, target_names=cfg.CLASS_NAMES, digits=4)) return {'accuracy': acc, 'macro_f1': mf1, 'weighted_f1': wf1, 'macro_auc': mauc, 'ece': ece, **{f'f1_{cfg.CLASS_NAMES[i]}': float(per[i]) for i in range(cfg.NUM_DISEASE_CLASSES)}} metrics_raw = _print_metrics(test_preds_raw, test_labels_np, test_probs_calibrated, 'Raw argmax (T-scaled)') metrics_thr = _print_metrics(test_preds_thr, test_labels_np, test_probs_calibrated, 'With per-class thresholds') # Save final metrics final_metrics = { 'raw': metrics_raw, 'thresholded': metrics_thr, 'temperature': T_opt, 'thresholds': calib_thresholds, } metrics_path = os.path.join(cfg.OUTPUT_DIR, 'final_metrics.json') with open(metrics_path, 'w') as f: json.dump(final_metrics, f, indent=2) # ================================================================ # PLOTS # ================================================================ print('\nGenerating plots...') ep = range(1, len(history['train_loss']) + 1) colors = ['#2ecc71', '#3498db', '#e74c3c', '#f39c12', '#9b59b6'] fig, axes = plt.subplots(2, 3, figsize=(20, 12)) # 1. Loss axes[0, 0].plot(ep, history['train_loss'], 'b-o', ms=3, label='Train') axes[0, 0].plot(ep, history['val_loss'], 'r-o', ms=3, label='Calib') axes[0, 0].set_title('Loss', fontweight='bold') axes[0, 0].legend(); axes[0, 0].grid(alpha=0.3) # 2. Accuracy axes[0, 1].plot(ep, history['train_acc'], 'b-o', ms=3, label='Train') axes[0, 1].plot(ep, history['val_acc'], 'r-o', ms=3, label='Calib') axes[0, 1].set_title('Accuracy (%)', fontweight='bold') axes[0, 1].legend(); axes[0, 1].grid(alpha=0.3) # 3. F1 axes[0, 2].plot(ep, history['macro_f1'], 'g-o', ms=3, label='Macro F1') axes[0, 2].plot(ep, history['weighted_f1'], 'm-o', ms=3, label='Weighted F1') axes[0, 2].set_title('F1 Scores (calib)', fontweight='bold') axes[0, 2].legend(); axes[0, 2].grid(alpha=0.3) # 4. Per-class F1 for ci, cn in enumerate(cfg.CLASS_NAMES): axes[1, 0].plot(ep, history[f'f1_{cn}'], '-o', ms=2, color=colors[ci], label=cn) axes[1, 0].set_title('Per-Class F1 (calib)', fontweight='bold') axes[1, 0].legend(fontsize=8); axes[1, 0].grid(alpha=0.3) # 5. Confusion matrix (test set, thresholded) cm = confusion_matrix(test_labels_np, test_preds_thr) cm_n = cm.astype(float) / cm.sum(axis=1, keepdims=True) sns.heatmap(cm_n, annot=True, fmt='.2f', cmap='Blues', ax=axes[1, 1], xticklabels=cfg.CLASS_NAMES, yticklabels=cfg.CLASS_NAMES) axes[1, 1].set_title('Confusion Matrix — Test Set (norm)', fontweight='bold') axes[1, 1].set_ylabel('True'); axes[1, 1].set_xlabel('Pred') # 6. ROC curves (test set) y_bin = label_binarize(test_labels_np, classes=list(range(cfg.NUM_DISEASE_CLASSES))) for ci, (cn, col) in enumerate(zip(cfg.CLASS_NAMES, colors)): fpr, tpr, _ = roc_curve(y_bin[:, ci], test_probs_calibrated[:, ci]) axes[1, 2].plot(fpr, tpr, color=col, lw=2, label=f'{cn} ({auc(fpr, tpr):.3f})') axes[1, 2].plot([0, 1], [0, 1], 'k--', lw=1) axes[1, 2].set_title('ROC Curves — Test Set', fontweight='bold') axes[1, 2].legend(loc='lower right', fontsize=8) axes[1, 2].grid(alpha=0.3) plt.suptitle( f'RetinaSense v3.0 — Macro F1={metrics_thr["macro_f1"]:.3f} | ' f'AUC={metrics_thr["macro_auc"]:.3f} | ' f'Test Acc={metrics_thr["accuracy"]:.1f}%', fontsize=14, fontweight='bold', y=1.01 ) plt.tight_layout() plt.savefig(os.path.join(cfg.OUTPUT_DIR, 'dashboard.png'), dpi=150, bbox_inches='tight') plt.close() # LR schedule plot fig, ax = plt.subplots(figsize=(10, 3)) ax.plot(ep, history['lr'], 'b-o', ms=2) ax.set_title('Learning Rate (head param group) — OneCycleLR', fontweight='bold') ax.set_xlabel('Epoch'); ax.set_ylabel('LR') ax.grid(alpha=0.3) plt.tight_layout() plt.savefig(os.path.join(cfg.OUTPUT_DIR, 'lr_schedule.png'), dpi=150) plt.close() # Calibration reliability diagram fig, ax = plt.subplots(figsize=(6, 6)) n_bins = 15 confs = test_probs_calibrated.max(axis=1) acc_arr = (test_preds_thr == test_labels_np).astype(float) bin_edges = np.linspace(0, 1, n_bins + 1) bin_accs, bin_confs = [], [] for lo, hi in zip(bin_edges[:-1], bin_edges[1:]): mask = (confs >= lo) & (confs < hi) if mask.sum() > 0: bin_accs.append(acc_arr[mask].mean()) bin_confs.append(confs[mask].mean()) ax.bar(bin_confs, bin_accs, width=1.0 / n_bins, alpha=0.7, edgecolor='black', label='Model') ax.plot([0, 1], [0, 1], 'r--', lw=2, label='Perfect calibration') ax.set_xlabel('Confidence'); ax.set_ylabel('Accuracy') ax.set_title(f'Reliability Diagram (T={T_opt:.2f}, ECE={ece_after:.3f})', fontweight='bold') ax.legend(); ax.grid(alpha=0.3) plt.tight_layout() plt.savefig(os.path.join(cfg.OUTPUT_DIR, 'calibration.png'), dpi=150) plt.close() # ================================================================ # SUMMARY # ================================================================ print('\n' + '=' * 65) print(' RETINASENSE v3.0 — FINAL SUMMARY') print('=' * 65) print(f' Training epochs : {len(history["train_loss"])}') print(f' Best calib macro-F1 : {best_f1:.4f}') print(f' Temperature T : {T_opt:.4f}') print(f' ECE before / after : {ece_before:.4f} / {ece_after:.4f}') print() print(' TEST SET RESULTS (with thresholds)') print(f' Accuracy : {metrics_thr["accuracy"]:.2f}%') print(f' Macro F1 : {metrics_thr["macro_f1"]:.4f}') print(f' Weighted F1: {metrics_thr["weighted_f1"]:.4f}') print(f' Macro AUC : {metrics_thr["macro_auc"]:.4f}') print(f' ECE : {metrics_thr["ece"]:.4f}') print() print(' Per-class F1 (test, thresholded):') for i, cn in enumerate(cfg.CLASS_NAMES): thr = calib_thresholds[i] fi = metrics_thr[f'f1_{cn}'] print(f' {cn:15s}: F1={fi:.3f} (threshold={thr:.3f})') print() print(f' Training time : {total_train_time / 60:.1f} minutes') print() print(f' Outputs saved to {cfg.OUTPUT_DIR}/') for fname in ['best_model.pth', 'history.json', 'temperature.json', 'thresholds.json', 'final_metrics.json', 'dashboard.png', 'lr_schedule.png', 'calibration.png']: print(f' -- {fname}') print('=' * 65)