#!/usr/bin/env python3 """ RetinaSense v3.0 -- MC Dropout Uncertainty Quantification (Phase 1B) ==================================================================== Performs Monte Carlo Dropout inference on the test set to decompose predictive uncertainty into aleatoric and epistemic components. Strategy for efficiency: - Run the ViT backbone ONCE per image (deterministic, no dropout in backbone) - Cache the 768-dim CLS features - Run T=30 stochastic forward passes through the classification heads only (where the dropout layers live: self.drop + head dropouts) This is 30x faster than running the full model T times. For each test image, computes: - Predictive entropy (total uncertainty) - Expected entropy (aleatoric uncertainty) - Mutual information (epistemic uncertainty) - Per-class prediction variance Generates: - uncertainty_vs_accuracy.png - rejection_curve.png - epistemic_vs_aleatoric.png - uncertainty_by_class.png - confidence_vs_uncertainty.png - mc_dropout_results.json Usage: python mc_dropout_uncertainty.py """ import os import sys import json import time import warnings import numpy as np import pandas as pd import cv2 import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import matplotlib.patches as mpatches from PIL import Image from tqdm import tqdm warnings.filterwarnings('ignore') import torch import torch.nn as nn import torch.nn.functional as F from torchvision import transforms from torch.utils.data import Dataset, DataLoader import timm # Maximize CPU throughput torch.set_num_threads(4) # ================================================================ # CONFIGURATION # ================================================================ BASE_DIR = '/teamspace/studios/this_studio' OUTPUT_DIR = os.path.join(BASE_DIR, 'outputs_v3') UNCERT_DIR = os.path.join(OUTPUT_DIR, 'uncertainty') os.makedirs(UNCERT_DIR, exist_ok=True) MODEL_PATH = os.path.join(OUTPUT_DIR, 'best_model.pth') TEMPERATURE_PATH = os.path.join(OUTPUT_DIR, 'temperature.json') NORM_STATS_PATH = os.path.join(BASE_DIR, 'data', 'fundus_norm_stats.json') TEST_CSV = os.path.join(BASE_DIR, 'data', 'test_split.csv') CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD'] NUM_CLASSES = 5 IMG_SIZE = 224 DROPOUT = 0.3 T_FORWARD_PASSES = 30 # number of MC stochastic forward passes BATCH_SIZE = 32 # batch size for feature extraction HEAD_BATCH = 512 # batch size for head-only MC passes (very lightweight) DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print('=' * 65) print(' RetinaSense v3.0 -- MC Dropout Uncertainty Quantification') print('=' * 65) print(f' Device : {DEVICE}') if torch.cuda.is_available(): print(f' GPU : {torch.cuda.get_device_name(0)}') print(f' MC passes (T) : {T_FORWARD_PASSES}') print(f' Output dir : {UNCERT_DIR}') # ================================================================ # LOAD NORMALISATION STATS # ================================================================ if os.path.exists(NORM_STATS_PATH): with open(NORM_STATS_PATH) as f: norm_stats = json.load(f) NORM_MEAN = norm_stats['mean_rgb'] NORM_STD = norm_stats['std_rgb'] print(f' Fundus norm : mean={[round(v,4) for v in NORM_MEAN]}, ' f'std={[round(v,4) for v in NORM_STD]}') else: NORM_MEAN = [0.485, 0.456, 0.406] NORM_STD = [0.229, 0.224, 0.225] print(' Using ImageNet normalisation fallback') # Load temperature with open(TEMPERATURE_PATH) as f: temp_data = json.load(f) TEMPERATURE = temp_data['temperature'] print(f' Temperature : {TEMPERATURE:.4f}') # ================================================================ # MODEL ARCHITECTURE (mirrors retinasense_v3.py / gradcam_v3.py) # ================================================================ class MultiTaskViT(nn.Module): """ViT-Base-Patch16-224 with disease + severity heads.""" def __init__(self, n_disease=NUM_CLASSES, n_severity=5, drop=DROPOUT): super().__init__() self.backbone = timm.create_model( 'vit_base_patch16_224', pretrained=False, num_classes=0 ) feat = 768 # CLS token dimension self.drop = nn.Dropout(drop) self.disease_head = nn.Sequential( nn.Linear(feat, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3), nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2), nn.Linear(256, n_disease), ) self.severity_head = nn.Sequential( nn.Linear(feat, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, n_severity), ) def forward(self, x): f = self.backbone(x) # (B, 768) f = self.drop(f) return self.disease_head(f), self.severity_head(f) def extract_features(self, x): """Run backbone only (deterministic) to get CLS features.""" return self.backbone(x) # (B, 768) def forward_heads(self, features): """Run dropout + disease head on pre-extracted features.""" f = self.drop(features) return self.disease_head(f) # ================================================================ # LOAD MODEL # ================================================================ print('\nLoading model...') model = MultiTaskViT().to(DEVICE) ckpt = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False) model.load_state_dict(ckpt['model_state_dict']) print(f' Loaded: {MODEL_PATH}') print(f' Checkpoint epoch: {ckpt.get("epoch", "?") + 1} ' f'val_acc={ckpt.get("val_acc", 0):.2f}%') # ================================================================ # MC DROPOUT SETUP # ================================================================ def enable_head_dropout(model): """ Set model to eval mode, then enable dropout ONLY in the classification heads (self.drop, disease_head dropouts). The backbone stays fully deterministic (eval mode) so we only need one backbone pass per image. BatchNorm layers remain in eval mode (use running stats). """ model.eval() # everything to eval (including backbone) # Enable dropout in the drop layer and disease_head model.drop.train() for m in model.disease_head.modules(): if isinstance(m, (nn.Dropout, nn.Dropout2d)): m.train() enable_head_dropout(model) # Count active dropout layers n_dropout_active = 0 for name, m in model.named_modules(): if isinstance(m, (nn.Dropout, nn.Dropout2d)) and m.training: n_dropout_active += 1 n_dropout_total = sum(1 for m in model.modules() if isinstance(m, (nn.Dropout, nn.Dropout2d))) print(f'\n MC Dropout enabled in heads: {n_dropout_active} active / {n_dropout_total} total dropout layers') print(f' Backbone: deterministic (eval mode) -- single pass per image') print(f' Heads: stochastic (train mode dropout) -- {T_FORWARD_PASSES} passes per image') # ================================================================ # PREPROCESSING (matches gradcam_v3.py pipeline) # ================================================================ def ben_graham(path, sz=IMG_SIZE, sigma=10): """Ben Graham high-frequency fundus enhancement (APTOS-style).""" img = cv2.imread(path) if img is None: img = np.array(Image.open(path).convert('RGB')) img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = cv2.resize(img, (sz, sz)) img = cv2.addWeighted(img, 4, cv2.GaussianBlur(img, (0, 0), sigma), -4, 128) mask = np.zeros(img.shape[:2], dtype=np.uint8) cv2.circle(mask, (sz // 2, sz // 2), int(sz * 0.48), 255, -1) return cv2.bitwise_and(img, img, mask=mask) def clahe_preprocess(path, sz=IMG_SIZE): """CLAHE-based contrast enhancement (ODIR-style).""" img = cv2.imread(path) if img is None: img = np.array(Image.open(path).convert('RGB')) img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) img = cv2.resize(img, (sz, sz)) lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) lab[:, :, 0] = clahe.apply(lab[:, :, 0]) img = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR) return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) def resolve_path(image_path): """Resolve image path relative to BASE_DIR.""" if os.path.isabs(image_path) and os.path.exists(image_path): return image_path clean = image_path while clean.startswith('./'): clean = clean[2:] return os.path.join(BASE_DIR, clean) # ================================================================ # DATASET # ================================================================ class TestDataset(Dataset): """Test dataset loading preprocessed images from cache or live.""" def __init__(self, csv_path): self.df = pd.read_csv(csv_path).reset_index(drop=True) self.transform = transforms.Compose([ transforms.ToPILImage(), transforms.ToTensor(), transforms.Normalize(NORM_MEAN, NORM_STD), ]) def __len__(self): return len(self.df) def __getitem__(self, idx): row = self.df.iloc[idx] img_path = str(row['image_path']) dataset = str(row.get('source', 'auto')) label = int(row['disease_label']) # Try loading from cache first cache_path = str(row.get('cache_path', '')) if cache_path and cache_path != 'nan': cache_abs = resolve_path(cache_path) if os.path.exists(cache_abs): try: img_np = np.load(cache_abs) img_tensor = self.transform(img_np) return img_tensor, label, img_path except Exception: pass # Live preprocessing abs_path = resolve_path(img_path) try: if dataset == 'APTOS': img_np = ben_graham(abs_path) else: img_np = clahe_preprocess(abs_path) img_tensor = self.transform(img_np) except Exception: img_tensor = torch.zeros(3, IMG_SIZE, IMG_SIZE) return img_tensor, label, img_path # ================================================================ # TWO-STAGE MC DROPOUT INFERENCE # ================================================================ def extract_all_features(model, dataloader): """ Stage 1: Run backbone once per image to get CLS features (deterministic). Returns features (N, 768), labels (N,), paths list. """ all_features = [] all_labels = [] all_paths = [] print(f'\n Stage 1: Extracting backbone features (deterministic)...') with torch.no_grad(): for images, labels, paths in tqdm(dataloader, desc=' Features', ncols=80): images = images.to(DEVICE) feats = model.extract_features(images) # (B, 768) all_features.append(feats.cpu()) all_labels.extend(labels.numpy().tolist()) all_paths.extend(paths) all_features = torch.cat(all_features, dim=0) # (N, 768) all_labels = np.array(all_labels) return all_features, all_labels, all_paths def mc_dropout_on_heads(model, features, T=T_FORWARD_PASSES, temperature=TEMPERATURE): """ Stage 2: Run T stochastic forward passes through heads only. features: (N, 768) tensor Returns: (N, T, C) numpy array of probability vectors. """ N = features.size(0) all_probs = np.zeros((N, T, NUM_CLASSES), dtype=np.float32) print(f'\n Stage 2: MC Dropout through heads ({T} passes, {N} samples)...') with torch.no_grad(): for t in tqdm(range(T), desc=' MC Passes', ncols=80): # Process in chunks to avoid memory issues for start in range(0, N, HEAD_BATCH): end = min(start + HEAD_BATCH, N) feat_batch = features[start:end].to(DEVICE) logits = model.forward_heads(feat_batch) scaled = logits / temperature probs = F.softmax(scaled, dim=1) all_probs[start:end, t, :] = probs.cpu().numpy() return all_probs # ================================================================ # UNCERTAINTY METRICS # ================================================================ def compute_uncertainty_metrics(mc_probs): """ Compute uncertainty metrics from MC dropout probability samples. Args: mc_probs: (N, T, C) array of MC sampled probability vectors Returns dict with: - p_mean, predicted_class, max_confidence - predictive_entropy (total), expected_entropy (aleatoric), mutual_info (epistemic), class_variance """ N, T, C = mc_probs.shape eps = 1e-10 # Predictive mean: average over T passes p_mean = mc_probs.mean(axis=1) # (N, C) predicted_class = p_mean.argmax(axis=1) # (N,) max_confidence = p_mean.max(axis=1) # (N,) # Predictive entropy: H[p_bar] = -sum(p_bar * log(p_bar)) -- TOTAL uncertainty predictive_entropy = -np.sum(p_mean * np.log(p_mean + eps), axis=1) # (N,) # Per-pass entropies per_pass_entropy = -np.sum(mc_probs * np.log(mc_probs + eps), axis=2) # (N, T) # Expected entropy: E_t[H[p_t]] -- ALEATORIC uncertainty expected_entropy = per_pass_entropy.mean(axis=1) # (N,) # Mutual information: H - E[H] -- EPISTEMIC uncertainty mutual_info = predictive_entropy - expected_entropy mutual_info = np.maximum(mutual_info, 0.0) # Prediction variance per class class_variance = mc_probs.var(axis=1) # (N, C) return { 'p_mean': p_mean, 'predicted_class': predicted_class, 'max_confidence': max_confidence, 'predictive_entropy': predictive_entropy, 'expected_entropy': expected_entropy, 'mutual_info': mutual_info, 'class_variance': class_variance, } # ================================================================ # PLOTTING FUNCTIONS # ================================================================ def plot_uncertainty_vs_accuracy(metrics, labels, save_path): """Scatter: total uncertainty vs correctness, colored by class.""" correct = (metrics['predicted_class'] == labels).astype(int) entropy = metrics['predictive_entropy'] fig, ax = plt.subplots(figsize=(10, 7)) colors = plt.cm.Set2(np.linspace(0, 1, NUM_CLASSES)) for cls_idx in range(NUM_CLASSES): mask = labels == cls_idx ax.scatter( entropy[mask], correct[mask] + np.random.uniform(-0.08, 0.08, mask.sum()), c=[colors[cls_idx]], alpha=0.5, s=20, label=CLASS_NAMES[cls_idx], edgecolors='none' ) ax.set_xlabel('Predictive Entropy (Total Uncertainty)', fontsize=12) ax.set_ylabel('Correctness (1=correct, 0=wrong)', fontsize=12) ax.set_title('MC Dropout: Uncertainty vs Prediction Correctness', fontsize=14) ax.set_yticks([0, 1]) ax.set_yticklabels(['Incorrect', 'Correct']) ax.legend(title='True Class', fontsize=9, title_fontsize=10) ax.grid(True, alpha=0.3) # Add vertical line at median uncertainty med = np.median(entropy) ax.axvline(med, color='red', linestyle='--', alpha=0.5, label=f'Median H={med:.3f}') # Summary stats correct_ent = entropy[correct == 1] wrong_ent = entropy[correct == 0] textstr = (f'Correct: mean H={correct_ent.mean():.3f}\n' f'Wrong: mean H={wrong_ent.mean():.3f}' if len(wrong_ent) > 0 else f'Correct: mean H={correct_ent.mean():.3f}') ax.text(0.98, 0.5, textstr, transform=ax.transAxes, fontsize=9, verticalalignment='center', horizontalalignment='right', bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8)) plt.tight_layout() fig.savefig(save_path, dpi=300, bbox_inches='tight') plt.close(fig) print(f' Saved: {save_path}') def plot_rejection_curve(metrics, labels, save_path): """Accuracy as a function of rejection threshold on uncertainty.""" entropy = metrics['predictive_entropy'] correct = (metrics['predicted_class'] == labels).astype(int) # Sort by decreasing uncertainty sorted_idx = np.argsort(entropy)[::-1] sorted_correct = correct[sorted_idx] N = len(labels) rejection_fracs = np.linspace(0.0, 0.95, 200) accuracies = [] n_remaining = [] for frac in rejection_fracs: n_reject = int(frac * N) kept = sorted_correct[n_reject:] if len(kept) == 0: accuracies.append(np.nan) n_remaining.append(0) else: accuracies.append(kept.mean() * 100) n_remaining.append(len(kept)) accuracies = np.array(accuracies) n_remaining = np.array(n_remaining) fig, ax1 = plt.subplots(figsize=(10, 7)) color1 = '#2196F3' ax1.plot(rejection_fracs * 100, accuracies, color=color1, linewidth=2.0, label='Accuracy') ax1.set_xlabel('Rejection Rate (%)', fontsize=12) ax1.set_ylabel('Accuracy (%)', fontsize=12, color=color1) ax1.tick_params(axis='y', labelcolor=color1) ax1.set_ylim([max(50, np.nanmin(accuracies) - 5), 101]) # Secondary axis: number of remaining samples ax2 = ax1.twinx() color2 = '#FF9800' ax2.plot(rejection_fracs * 100, n_remaining, color=color2, linewidth=1.5, linestyle='--', alpha=0.7, label='Remaining') ax2.set_ylabel('Samples Remaining', fontsize=12, color=color2) ax2.tick_params(axis='y', labelcolor=color2) # Baseline accuracy (no rejection) base_acc = correct.mean() * 100 ax1.axhline(base_acc, color='gray', linestyle=':', alpha=0.5) ax1.text(2, base_acc + 0.5, f'Baseline: {base_acc:.1f}%', fontsize=9, color='gray') ax1.set_title('Rejection Curve: Accuracy vs Uncertainty-Based Rejection', fontsize=14) ax1.grid(True, alpha=0.3) # Combined legend lines1, labels1 = ax1.get_legend_handles_labels() lines2, labels2 = ax2.get_legend_handles_labels() ax1.legend(lines1 + lines2, labels1 + labels2, loc='lower left', fontsize=10) plt.tight_layout() fig.savefig(save_path, dpi=300, bbox_inches='tight') plt.close(fig) print(f' Saved: {save_path}') def plot_epistemic_vs_aleatoric(metrics, labels, save_path): """Scatter separating epistemic and aleatoric uncertainty.""" aleatoric = metrics['expected_entropy'] epistemic = metrics['mutual_info'] correct = (metrics['predicted_class'] == labels).astype(int) fig, ax = plt.subplots(figsize=(10, 7)) colors = plt.cm.Set2(np.linspace(0, 1, NUM_CLASSES)) for cls_idx in range(NUM_CLASSES): mask = labels == cls_idx ax.scatter( aleatoric[mask], epistemic[mask], c=[colors[cls_idx]], alpha=0.45, s=20, label=CLASS_NAMES[cls_idx], edgecolors='none' ) # Mark misclassified samples wrong_mask = correct == 0 if wrong_mask.sum() > 0: ax.scatter( aleatoric[wrong_mask], epistemic[wrong_mask], facecolors='none', edgecolors='red', s=60, linewidths=1.2, label='Misclassified', zorder=5 ) ax.set_xlabel('Aleatoric Uncertainty (Expected Entropy)', fontsize=12) ax.set_ylabel('Epistemic Uncertainty (Mutual Information)', fontsize=12) ax.set_title('Decomposition of Uncertainty: Epistemic vs Aleatoric', fontsize=14) ax.legend(fontsize=9, title='Class', title_fontsize=10) ax.grid(True, alpha=0.3) # Annotate quadrants xlim = ax.get_xlim() ylim = ax.get_ylim() ax.text(xlim[0] + 0.02 * (xlim[1] - xlim[0]), ylim[1] - 0.05 * (ylim[1] - ylim[0]), 'Low aleatoric\nHigh epistemic\n(need more data)', fontsize=8, alpha=0.6, va='top') ax.text(xlim[1] - 0.02 * (xlim[1] - xlim[0]), ylim[1] - 0.05 * (ylim[1] - ylim[0]), 'High aleatoric\nHigh epistemic\n(hard + unseen)', fontsize=8, alpha=0.6, va='top', ha='right') ax.text(xlim[1] - 0.02 * (xlim[1] - xlim[0]), ylim[0] + 0.05 * (ylim[1] - ylim[0]), 'High aleatoric\nLow epistemic\n(inherently noisy)', fontsize=8, alpha=0.6, va='bottom', ha='right') plt.tight_layout() fig.savefig(save_path, dpi=300, bbox_inches='tight') plt.close(fig) print(f' Saved: {save_path}') def plot_uncertainty_by_class(metrics, labels, save_path): """Box plots of uncertainty per class.""" fig, axes = plt.subplots(1, 3, figsize=(18, 6)) data_types = [ ('predictive_entropy', 'Total Uncertainty (Predictive Entropy)'), ('expected_entropy', 'Aleatoric Uncertainty (Expected Entropy)'), ('mutual_info', 'Epistemic Uncertainty (Mutual Information)'), ] for ax, (key, title) in zip(axes, data_types): data = metrics[key] box_data = [data[labels == c] for c in range(NUM_CLASSES)] bp = ax.boxplot(box_data, labels=CLASS_NAMES, patch_artist=True, widths=0.6, showfliers=True, flierprops=dict(marker='o', markersize=3, alpha=0.3)) colors = plt.cm.Set2(np.linspace(0, 1, NUM_CLASSES)) for patch, color in zip(bp['boxes'], colors): patch.set_facecolor(color) patch.set_alpha(0.7) ax.set_title(title, fontsize=11) ax.set_ylabel('Uncertainty', fontsize=10) ax.grid(True, axis='y', alpha=0.3) ax.tick_params(axis='x', rotation=15) # Add sample counts for i, cls_data in enumerate(box_data): ax.text(i + 1, ax.get_ylim()[1] * 0.95, f'n={len(cls_data)}', ha='center', fontsize=8, alpha=0.6) plt.suptitle('Uncertainty Distribution by Disease Class', fontsize=14, y=1.02) plt.tight_layout() fig.savefig(save_path, dpi=300, bbox_inches='tight') plt.close(fig) print(f' Saved: {save_path}') def plot_confidence_vs_uncertainty(metrics, labels, save_path): """Scatter showing confidence vs uncertainty (should be anti-correlated).""" confidence = metrics['max_confidence'] entropy = metrics['predictive_entropy'] correct = (metrics['predicted_class'] == labels).astype(int) fig, ax = plt.subplots(figsize=(10, 7)) scatter_correct = ax.scatter( confidence[correct == 1], entropy[correct == 1], c='#4CAF50', alpha=0.4, s=15, label='Correct', edgecolors='none' ) scatter_wrong = ax.scatter( confidence[correct == 0], entropy[correct == 0], c='#F44336', alpha=0.6, s=25, label='Incorrect', edgecolors='none', marker='x', linewidths=1.0 ) # Compute correlation from scipy import stats r, p_val = stats.pearsonr(confidence, entropy) ax.set_xlabel('Maximum Confidence (max p_bar)', fontsize=12) ax.set_ylabel('Predictive Entropy (Total Uncertainty)', fontsize=12) ax.set_title(f'Confidence vs Uncertainty (Pearson r={r:.3f}, p={p_val:.2e})', fontsize=14) ax.legend(fontsize=10) ax.grid(True, alpha=0.3) # Add trend line z = np.polyfit(confidence, entropy, 1) x_line = np.linspace(confidence.min(), confidence.max(), 100) ax.plot(x_line, np.polyval(z, x_line), 'k--', alpha=0.4, linewidth=1.5) plt.tight_layout() fig.savefig(save_path, dpi=300, bbox_inches='tight') plt.close(fig) print(f' Saved: {save_path}') # ================================================================ # MAIN # ================================================================ def main(): t_start = time.time() # ---- 1. Build DataLoader ---- print('\nLoading test set...') dataset = TestDataset(TEST_CSV) dataloader = DataLoader( dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=False ) print(f' Test samples: {len(dataset)}') # ---- 2. Stage 1: Extract backbone features (single deterministic pass) ---- features, true_labels, image_paths = extract_all_features(model, dataloader) print(f' Features shape: {features.shape}') t_feat = time.time() - t_start print(f' Feature extraction: {t_feat:.1f}s') # ---- 3. Stage 2: MC Dropout on heads only ---- mc_probs = mc_dropout_on_heads( model, features, T=T_FORWARD_PASSES, temperature=TEMPERATURE ) print(f' MC probs shape: {mc_probs.shape} (N, T, C)') t_mc = time.time() - t_start - t_feat print(f' MC head passes: {t_mc:.1f}s') # ---- 4. Compute Uncertainty Metrics ---- print('\nComputing uncertainty metrics...') metrics = compute_uncertainty_metrics(mc_probs) # Print summary statistics correct = (metrics['predicted_class'] == true_labels).astype(int) accuracy = correct.mean() * 100 print(f'\n --- Summary ---') print(f' Accuracy (MC mean): {accuracy:.2f}%') print(f' Predictive entropy: mean={metrics["predictive_entropy"].mean():.4f}, ' f'std={metrics["predictive_entropy"].std():.4f}') print(f' Aleatoric (exp. ent.): mean={metrics["expected_entropy"].mean():.4f}, ' f'std={metrics["expected_entropy"].std():.4f}') print(f' Epistemic (MI): mean={metrics["mutual_info"].mean():.4f}, ' f'std={metrics["mutual_info"].std():.4f}') print(f' Max confidence: mean={metrics["max_confidence"].mean():.4f}, ' f'std={metrics["max_confidence"].std():.4f}') # Per-class stats print(f'\n Per-class uncertainty (predictive entropy):') for cls_idx in range(NUM_CLASSES): mask = true_labels == cls_idx n_cls = mask.sum() cls_acc = correct[mask].mean() * 100 if n_cls > 0 else 0 cls_ent = metrics['predictive_entropy'][mask].mean() if n_cls > 0 else 0 cls_mi = metrics['mutual_info'][mask].mean() if n_cls > 0 else 0 print(f' {CLASS_NAMES[cls_idx]:15s}: n={n_cls:4d}, ' f'acc={cls_acc:5.1f}%, H={cls_ent:.4f}, MI={cls_mi:.4f}') # ---- 5. Generate Plots ---- print('\nGenerating plots...') plot_uncertainty_vs_accuracy( metrics, true_labels, os.path.join(UNCERT_DIR, 'uncertainty_vs_accuracy.png') ) plot_rejection_curve( metrics, true_labels, os.path.join(UNCERT_DIR, 'rejection_curve.png') ) plot_epistemic_vs_aleatoric( metrics, true_labels, os.path.join(UNCERT_DIR, 'epistemic_vs_aleatoric.png') ) plot_uncertainty_by_class( metrics, true_labels, os.path.join(UNCERT_DIR, 'uncertainty_by_class.png') ) plot_confidence_vs_uncertainty( metrics, true_labels, os.path.join(UNCERT_DIR, 'confidence_vs_uncertainty.png') ) # ---- 6. Save JSON Results ---- print('\nSaving results JSON...') per_image = [] for i in range(len(true_labels)): per_image.append({ 'image_path': image_paths[i], 'true_label': int(true_labels[i]), 'true_class': CLASS_NAMES[int(true_labels[i])], 'predicted_label': int(metrics['predicted_class'][i]), 'predicted_class': CLASS_NAMES[int(metrics['predicted_class'][i])], 'correct': bool(correct[i]), 'max_confidence': round(float(metrics['max_confidence'][i]), 6), 'predictive_entropy': round(float(metrics['predictive_entropy'][i]), 6), 'expected_entropy': round(float(metrics['expected_entropy'][i]), 6), 'mutual_information': round(float(metrics['mutual_info'][i]), 6), 'class_variance': [round(float(v), 8) for v in metrics['class_variance'][i]], 'mean_probs': [round(float(v), 6) for v in metrics['p_mean'][i]], }) aggregate = { 'n_samples': int(len(true_labels)), 'n_classes': NUM_CLASSES, 'mc_passes': T_FORWARD_PASSES, 'temperature': TEMPERATURE, 'accuracy_pct': round(float(accuracy), 4), 'overall': { 'predictive_entropy': { 'mean': round(float(metrics['predictive_entropy'].mean()), 6), 'std': round(float(metrics['predictive_entropy'].std()), 6), 'min': round(float(metrics['predictive_entropy'].min()), 6), 'max': round(float(metrics['predictive_entropy'].max()), 6), }, 'expected_entropy': { 'mean': round(float(metrics['expected_entropy'].mean()), 6), 'std': round(float(metrics['expected_entropy'].std()), 6), 'min': round(float(metrics['expected_entropy'].min()), 6), 'max': round(float(metrics['expected_entropy'].max()), 6), }, 'mutual_information': { 'mean': round(float(metrics['mutual_info'].mean()), 6), 'std': round(float(metrics['mutual_info'].std()), 6), 'min': round(float(metrics['mutual_info'].min()), 6), 'max': round(float(metrics['mutual_info'].max()), 6), }, 'max_confidence': { 'mean': round(float(metrics['max_confidence'].mean()), 6), 'std': round(float(metrics['max_confidence'].std()), 6), }, }, 'per_class': {}, } for cls_idx in range(NUM_CLASSES): mask = true_labels == cls_idx n_cls = int(mask.sum()) if n_cls == 0: continue aggregate['per_class'][CLASS_NAMES[cls_idx]] = { 'n_samples': n_cls, 'accuracy': round(float(correct[mask].mean() * 100), 4), 'pred_entropy_mean': round(float(metrics['predictive_entropy'][mask].mean()), 6), 'pred_entropy_std': round(float(metrics['predictive_entropy'][mask].std()), 6), 'aleatoric_mean': round(float(metrics['expected_entropy'][mask].mean()), 6), 'epistemic_mean': round(float(metrics['mutual_info'][mask].mean()), 6), 'confidence_mean': round(float(metrics['max_confidence'][mask].mean()), 6), } # Rejection curve data at key thresholds entropy = metrics['predictive_entropy'] sorted_idx = np.argsort(entropy)[::-1] sorted_correct = correct[sorted_idx] rejection_checkpoints = {} for frac in [0.0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.50]: n_reject = int(frac * len(true_labels)) kept = sorted_correct[n_reject:] if len(kept) > 0: rejection_checkpoints[f'reject_{int(frac*100)}pct'] = { 'accuracy': round(float(kept.mean() * 100), 4), 'n_remaining': int(len(kept)), } aggregate['rejection_curve'] = rejection_checkpoints results = { 'aggregate': aggregate, 'per_image': per_image, } json_path = os.path.join(UNCERT_DIR, 'mc_dropout_results.json') with open(json_path, 'w') as f: json.dump(results, f, indent=2) print(f' Saved: {json_path}') elapsed = time.time() - t_start print(f'\nDone in {elapsed:.1f}s') print('=' * 65) if __name__ == '__main__': main()