| |
"""
RetinaSense v3.0 -- MC Dropout Uncertainty Quantification (Phase 1B)
====================================================================
Performs Monte Carlo Dropout inference on the test set to decompose
predictive uncertainty into aleatoric and epistemic components.

Strategy for efficiency:
  - Run the ViT backbone ONCE per image (deterministic, no dropout in backbone)
  - Cache the 768-dim CLS features
  - Run T=30 stochastic forward passes through the classification heads only
    (where the sampled dropout layers live: self.drop + the disease-head dropouts)
Because the backbone is evaluated once instead of T times, this is roughly
30x faster than running the full model T times.

For each test image, computes:
  - Predictive entropy (total uncertainty)
  - Expected entropy (aleatoric uncertainty)
  - Mutual information (epistemic uncertainty)
  - Per-class prediction variance

Generates:
  - uncertainty_vs_accuracy.png
  - rejection_curve.png
  - epistemic_vs_aleatoric.png
  - uncertainty_by_class.png
  - confidence_vs_uncertainty.png
  - mc_dropout_results.json

Usage:
    python mc_dropout_uncertainty.py
"""
|
|
| import os |
| import sys |
| import json |
| import time |
| import warnings |
| import numpy as np |
| import pandas as pd |
| import cv2 |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| import matplotlib.patches as mpatches |
| from PIL import Image |
| from tqdm import tqdm |
|
|
| warnings.filterwarnings('ignore') |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision import transforms |
| from torch.utils.data import Dataset, DataLoader |
|
|
| import timm |
|
|
| |
# Limit the number of CPU threads torch uses for intra-op parallelism.
torch.set_num_threads(4)

# ---------------------------------------------------------------------------
# Filesystem layout
# ---------------------------------------------------------------------------
BASE_DIR = '/teamspace/studios/this_studio'
OUTPUT_DIR = os.path.join(BASE_DIR, 'outputs_v3')
UNCERT_DIR = os.path.join(OUTPUT_DIR, 'uncertainty')
# All plots and the results JSON produced by this script are written here.
os.makedirs(UNCERT_DIR, exist_ok=True)

_DATA_DIR = os.path.join(BASE_DIR, 'data')
MODEL_PATH = os.path.join(OUTPUT_DIR, 'best_model.pth')
TEMPERATURE_PATH = os.path.join(OUTPUT_DIR, 'temperature.json')
NORM_STATS_PATH = os.path.join(_DATA_DIR, 'fundus_norm_stats.json')
TEST_CSV = os.path.join(_DATA_DIR, 'test_split.csv')

# ---------------------------------------------------------------------------
# Model / data constants
# ---------------------------------------------------------------------------
CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']
NUM_CLASSES = 5
IMG_SIZE = 224
DROPOUT = 0.3

# ---------------------------------------------------------------------------
# Inference settings
# ---------------------------------------------------------------------------
T_FORWARD_PASSES = 30   # stochastic passes through the heads per image
BATCH_SIZE = 32         # images per backbone batch
HEAD_BATCH = 512        # cached feature vectors per head batch

_CUDA_OK = torch.cuda.is_available()
DEVICE = torch.device('cuda' if _CUDA_OK else 'cpu')

# ---------------------------------------------------------------------------
# Start-up banner
# ---------------------------------------------------------------------------
_RULE = '=' * 65
print(_RULE)
print(' RetinaSense v3.0 -- MC Dropout Uncertainty Quantification')
print(_RULE)
print(f' Device : {DEVICE}')
if _CUDA_OK:
    print(f' GPU : {torch.cuda.get_device_name(0)}')
print(f' MC passes (T) : {T_FORWARD_PASSES}')
print(f' Output dir : {UNCERT_DIR}')
|
|
| |
| |
| |
# Per-channel normalisation statistics: prefer the values stored in
# NORM_STATS_PATH and fall back to the ImageNet constants when the file
# is missing.
if not os.path.exists(NORM_STATS_PATH):
    NORM_MEAN = [0.485, 0.456, 0.406]
    NORM_STD = [0.229, 0.224, 0.225]
    print(' Using ImageNet normalisation fallback')
else:
    with open(NORM_STATS_PATH) as f:
        norm_stats = json.load(f)
    NORM_MEAN = norm_stats['mean_rgb']
    NORM_STD = norm_stats['std_rgb']
    _mean_4dp = [round(v, 4) for v in NORM_MEAN]
    _std_4dp = [round(v, 4) for v in NORM_STD]
    print(f' Fundus norm : mean={_mean_4dp}, std={_std_4dp}')

# Scalar that the head logits are divided by before softmax. There is no
# fallback here: a missing file or missing key aborts the script.
with open(TEMPERATURE_PATH) as f:
    temp_data = json.load(f)
TEMPERATURE = temp_data['temperature']
print(f' Temperature : {TEMPERATURE:.4f}')
|
|
| |
| |
| |
class MultiTaskViT(nn.Module):
    """ViT-Base/16 (224 px) backbone shared by a disease and a severity head.

    The backbone is created through timm with ``num_classes=0`` and its
    output is treated as a 768-wide feature vector. A dropout layer
    (``self.drop``) sits between the backbone and both heads.

    Args:
        n_disease: number of disease logits produced by ``disease_head``.
        n_severity: number of severity logits produced by ``severity_head``.
        drop: dropout probability of the shared ``self.drop`` layer.
    """

    def __init__(self, n_disease=NUM_CLASSES, n_severity=5, drop=DROPOUT):
        super().__init__()
        embed_dim = 768  # width of the feature vector fed to both heads

        self.backbone = timm.create_model(
            'vit_base_patch16_224', pretrained=False, num_classes=0
        )
        self.drop = nn.Dropout(drop)

        # Layer order fixes the nn.Sequential indices and therefore the
        # state-dict keys expected by the checkpoint -- do not reorder.
        disease_layers = [
            nn.Linear(embed_dim, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, n_disease),
        ]
        self.disease_head = nn.Sequential(*disease_layers)

        severity_layers = [
            nn.Linear(embed_dim, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, n_severity),
        ]
        self.severity_head = nn.Sequential(*severity_layers)

    def forward(self, x):
        """Return ``(disease_logits, severity_logits)`` for an image batch."""
        shared = self.drop(self.backbone(x))
        return self.disease_head(shared), self.severity_head(shared)

    def extract_features(self, x):
        """Return the backbone output for ``x`` (no head, no ``self.drop``)."""
        return self.backbone(x)

    def forward_heads(self, features):
        """Apply ``self.drop`` and the disease head to pre-extracted features."""
        return self.disease_head(self.drop(features))
|
|
|
|
| |
| |
| |
print('\nLoading model...')
model = MultiTaskViT().to(DEVICE)
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary
# Python objects. Only point MODEL_PATH at checkpoints from a trusted source.
ckpt = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
print(f' Loaded: {MODEL_PATH}')

# The stored epoch is shown +1 (presumably it is zero-based -- confirm
# against the training script). Older checkpoints may lack the key; the
# previous code then evaluated "?" + 1 and crashed with a TypeError.
_ckpt_epoch = ckpt.get('epoch')
_epoch_display = '?' if _ckpt_epoch is None else _ckpt_epoch + 1
print(f' Checkpoint epoch: {_epoch_display} '
      f'val_acc={ckpt.get("val_acc", 0):.2f}%')
|
|
|
|
| |
| |
| |
def enable_head_dropout(model):
    """Put ``model`` in eval mode, then re-activate dropout in the disease path.

    After the call only ``model.drop`` and the ``nn.Dropout`` /
    ``nn.Dropout2d`` modules found inside ``model.disease_head`` are in
    train mode. Every other module (backbone, BatchNorm layers, severity
    head) stays in eval mode, so the backbone output is deterministic and
    BatchNorm keeps using its running statistics.

    Args:
        model: module exposing ``drop`` and ``disease_head`` attributes.
    """
    model.eval()

    stochastic_layers = [model.drop]
    stochastic_layers.extend(
        layer for layer in model.disease_head.modules()
        if isinstance(layer, (nn.Dropout, nn.Dropout2d))
    )
    for layer in stochastic_layers:
        layer.train()
|
|
|
|
enable_head_dropout(model)

# Report how many of the model's dropout layers are actually stochastic now.
_dropout_layers = [
    mod for mod in model.modules()
    if isinstance(mod, (nn.Dropout, nn.Dropout2d))
]
n_dropout_total = len(_dropout_layers)
n_dropout_active = sum(1 for mod in _dropout_layers if mod.training)
print(f'\n MC Dropout enabled in heads: {n_dropout_active} active / {n_dropout_total} total dropout layers')
print(f' Backbone: deterministic (eval mode) -- single pass per image')
print(f' Heads: stochastic (train mode dropout) -- {T_FORWARD_PASSES} passes per image')
|
|
|
|
| |
| |
| |
def ben_graham(path, sz=IMG_SIZE, sigma=10):
    """Ben Graham high-frequency fundus enhancement (APTOS-style).

    Args:
        path: image file to read.
        sz: side length of the square output in pixels.
        sigma: sigma of the Gaussian blur subtracted from the image.

    Returns:
        RGB array of size ``sz`` x ``sz`` with everything outside a centred
        circle of radius ``int(sz * 0.48)`` set to zero.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        # OpenCV could not read the file: let PIL decode it (RGB) and
        # convert so both loading paths continue from a BGR array.
        pil_rgb = np.array(Image.open(path).convert('RGB'))
        bgr = cv2.cvtColor(pil_rgb, cv2.COLOR_RGB2BGR)

    rgb = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), (sz, sz))

    # 4 * image - 4 * blurred + 128: keep local detail around mid-grey.
    blurred = cv2.GaussianBlur(rgb, (0, 0), sigma)
    enhanced = cv2.addWeighted(rgb, 4, blurred, -4, 128)

    disc = np.zeros(enhanced.shape[:2], dtype=np.uint8)
    cv2.circle(disc, (sz // 2, sz // 2), int(sz * 0.48), 255, -1)
    return cv2.bitwise_and(enhanced, enhanced, mask=disc)
|
|
|
|
def clahe_preprocess(path, sz=IMG_SIZE):
    """CLAHE-based contrast enhancement (ODIR-style).

    Args:
        path: image file to read.
        sz: side length of the square output in pixels.

    Returns:
        RGB array of size ``sz`` x ``sz`` whose LAB lightness channel has
        been equalised with CLAHE (clip limit 2.0, 8x8 tiles).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        # OpenCV could not read the file: fall back to PIL and convert the
        # RGB result to BGR.
        pil_rgb = np.array(Image.open(path).convert('RGB'))
        bgr = cv2.cvtColor(pil_rgb, cv2.COLOR_RGB2BGR)

    bgr = cv2.resize(bgr, (sz, sz))
    lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)

    # Equalise lightness only, leaving the two colour channels untouched.
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    lab[:, :, 0] = equalizer.apply(lab[:, :, 0])

    return cv2.cvtColor(cv2.cvtColor(lab, cv2.COLOR_LAB2BGR), cv2.COLOR_BGR2RGB)
|
|
|
|
def resolve_path(image_path):
    """Map a path from the CSV onto the filesystem.

    An absolute path that exists is returned untouched. Anything else has
    every leading ``./`` removed and is joined onto ``BASE_DIR``.
    """
    already_usable = os.path.isabs(image_path) and os.path.exists(image_path)
    if already_usable:
        return image_path

    relative = image_path
    while relative[:2] == './':
        relative = relative[2:]
    return os.path.join(BASE_DIR, relative)
|
|
|
|
| |
| |
| |
class TestDataset(Dataset):
    """Test split that serves normalised image tensors from cache or from disk.

    Each item is ``(image_tensor, label, image_path)``. For a row the
    loader tries, in order:

    1. ``np.load`` of the row's ``cache_path`` (when present and existing),
    2. live preprocessing of ``image_path`` (``ben_graham`` for rows whose
       ``source`` is ``'APTOS'``, ``clahe_preprocess`` otherwise),
    3. an all-zero tensor of shape ``(3, IMG_SIZE, IMG_SIZE)``.

    Falling through to step 2 or 3 is best-effort and never raises, but
    every failure is reported on stderr so that degraded inputs do not
    skew the evaluation unnoticed.
    """

    def __init__(self, csv_path):
        """Read ``csv_path`` and build the tensor/normalisation transform."""
        self.df = pd.read_csv(csv_path).reset_index(drop=True)
        # Assumes the preprocessed arrays are HxWxC uint8 images, which is
        # what ToPILImage accepts -- TODO confirm for cached .npy files.
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(NORM_MEAN, NORM_STD),
        ])

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label, image_path)`` for row ``idx``."""
        row = self.df.iloc[idx]
        img_path = str(row['image_path'])
        dataset = str(row.get('source', 'auto'))
        label = int(row['disease_label'])

        # 1) Cached, already preprocessed array. A NaN cell stringifies to
        #    'nan', which is treated like a missing cache path.
        cache_path = str(row.get('cache_path', ''))
        if cache_path and cache_path != 'nan':
            cache_abs = resolve_path(cache_path)
            if os.path.exists(cache_abs):
                try:
                    img_np = np.load(cache_abs)
                    img_tensor = self.transform(img_np)
                    return img_tensor, label, img_path
                except Exception as exc:
                    # Deliberately best-effort: report and fall through to
                    # live preprocessing instead of aborting the run.
                    print(f'[TestDataset] cache load failed for {cache_abs}: '
                          f'{exc!r}; falling back to live preprocessing',
                          file=sys.stderr)

        # 2) Live preprocessing of the original image file.
        abs_path = resolve_path(img_path)
        try:
            if dataset == 'APTOS':
                img_np = ben_graham(abs_path)
            else:
                img_np = clahe_preprocess(abs_path)
            img_tensor = self.transform(img_np)
        except Exception as exc:
            # 3) Last resort keeps batch shapes intact, but the sample no
            #    longer carries any image information -- make that visible.
            print(f'[TestDataset] could not load {abs_path}: {exc!r}; '
                  f'substituting an all-zero image',
                  file=sys.stderr)
            img_tensor = torch.zeros(3, IMG_SIZE, IMG_SIZE)

        return img_tensor, label, img_path
|
|
|
|
| |
| |
| |
def extract_all_features(model, dataloader):
    """Stage 1: one backbone pass per image, gradients disabled.

    Args:
        model: object providing ``extract_features(images)``.
        dataloader: yields ``(images, labels, paths)`` batches.

    Returns:
        ``(features, labels, paths)``: the per-batch feature tensors
        concatenated along dim 0 (on CPU), the labels as a NumPy array and
        the image paths as a flat list, all in dataloader order.
    """
    feature_chunks, label_list, path_list = [], [], []

    print(f'\n Stage 1: Extracting backbone features (deterministic)...')
    with torch.no_grad():
        for images, labels, paths in tqdm(dataloader, desc=' Features', ncols=80):
            batch_feats = model.extract_features(images.to(DEVICE))
            # Move to CPU immediately so only one batch lives on the device.
            feature_chunks.append(batch_feats.cpu())
            label_list.extend(labels.numpy().tolist())
            path_list.extend(paths)

    return torch.cat(feature_chunks, dim=0), np.array(label_list), path_list
|
|
|
|
def mc_dropout_on_heads(model, features, T=T_FORWARD_PASSES, temperature=TEMPERATURE):
    """Stage 2: ``T`` stochastic passes of the cached features through the heads.

    Args:
        model: object providing ``forward_heads(features)``; its dropout
            layers must already be in train mode for the passes to differ.
        features: tensor whose first dimension indexes the samples.
        T: number of forward passes per sample.
        temperature: divisor applied to the logits before softmax.

    Returns:
        float32 NumPy array of shape ``(N, T, NUM_CLASSES)`` with one
        probability vector per sample and pass.
    """
    n_samples = features.size(0)
    sampled = np.zeros((n_samples, T, NUM_CLASSES), dtype=np.float32)

    print(f'\n Stage 2: MC Dropout through heads ({T} passes, {n_samples} samples)...')

    # Fixed [lo, hi) windows of at most HEAD_BATCH rows, reused by every pass.
    windows = [
        (lo, min(lo + HEAD_BATCH, n_samples))
        for lo in range(0, n_samples, HEAD_BATCH)
    ]

    with torch.no_grad():
        for pass_idx in tqdm(range(T), desc=' MC Passes', ncols=80):
            for lo, hi in windows:
                logits = model.forward_heads(features[lo:hi].to(DEVICE))
                probs = F.softmax(logits / temperature, dim=1)
                sampled[lo:hi, pass_idx, :] = probs.cpu().numpy()

    return sampled
|
|
|
|
| |
| |
| |
def compute_uncertainty_metrics(mc_probs):
    """Derive uncertainty measures from MC-dropout probability samples.

    Args:
        mc_probs: array of shape ``(N, T, C)`` -- ``T`` sampled probability
            vectors over ``C`` classes for each of ``N`` inputs.

    Returns:
        dict with the keys
          - ``p_mean``: mean probability vector over the T passes, ``(N, C)``
          - ``predicted_class``: argmax of ``p_mean``, ``(N,)``
          - ``max_confidence``: max of ``p_mean``, ``(N,)``
          - ``predictive_entropy``: entropy of ``p_mean`` (total)
          - ``expected_entropy``: mean per-pass entropy (aleatoric)
          - ``mutual_info``: predictive minus expected entropy, floored at
            zero (epistemic)
          - ``class_variance``: variance over the T passes, ``(N, C)``
    """
    # The unpacking doubles as a rank check: a non-3-D input raises here.
    n_samples, n_passes, n_classes = mc_probs.shape
    eps = 1e-10  # keeps log() finite for zero probabilities

    def _entropy(p, axis):
        return -np.sum(p * np.log(p + eps), axis=axis)

    p_mean = mc_probs.mean(axis=1)
    predictive_entropy = _entropy(p_mean, 1)
    expected_entropy = _entropy(mc_probs, 2).mean(axis=1)

    return {
        'p_mean': p_mean,
        'predicted_class': p_mean.argmax(axis=1),
        'max_confidence': p_mean.max(axis=1),
        'predictive_entropy': predictive_entropy,
        'expected_entropy': expected_entropy,
        # Tiny negative differences can arise from rounding; clamp them.
        'mutual_info': np.maximum(predictive_entropy - expected_entropy, 0.0),
        'class_variance': mc_probs.var(axis=1),
    }
|
|
|
|
| |
| |
| |
def plot_uncertainty_vs_accuracy(metrics, labels, save_path):
    """Scatter of total uncertainty against correctness, coloured by true class.

    Args:
        metrics: dict holding ``predicted_class`` and ``predictive_entropy``.
        labels: array of true class indices, aligned with ``metrics``.
        save_path: destination of the PNG.
    """
    correct = (metrics['predicted_class'] == labels).astype(int)
    entropy = metrics['predictive_entropy']

    fig, ax = plt.subplots(figsize=(10, 7))

    colors = plt.cm.Set2(np.linspace(0, 1, NUM_CLASSES))
    for cls_idx in range(NUM_CLASSES):
        mask = labels == cls_idx
        # Vertical jitter separates points that share the 0/1 y-value.
        jitter = np.random.uniform(-0.08, 0.08, mask.sum())
        ax.scatter(
            entropy[mask], correct[mask] + jitter,
            c=[colors[cls_idx]], alpha=0.5, s=20, label=CLASS_NAMES[cls_idx],
            edgecolors='none'
        )

    # Draw the labelled median line BEFORE building the legend: a legend
    # only lists artists that exist when ax.legend() is called, so adding
    # the line afterwards silently dropped its 'Median H=...' entry.
    med = np.median(entropy)
    ax.axvline(med, color='red', linestyle='--', alpha=0.5, label=f'Median H={med:.3f}')

    ax.set_xlabel('Predictive Entropy (Total Uncertainty)', fontsize=12)
    ax.set_ylabel('Correctness (1=correct, 0=wrong)', fontsize=12)
    ax.set_title('MC Dropout: Uncertainty vs Prediction Correctness', fontsize=14)
    ax.set_yticks([0, 1])
    ax.set_yticklabels(['Incorrect', 'Correct'])
    ax.legend(title='True Class', fontsize=9, title_fontsize=10)
    ax.grid(True, alpha=0.3)

    # Summary box: mention a group only when it has members, so an empty
    # group does not show up as 'mean H=nan'.
    correct_ent = entropy[correct == 1]
    wrong_ent = entropy[correct == 0]
    summary_lines = []
    if len(correct_ent) > 0:
        summary_lines.append(f'Correct: mean H={correct_ent.mean():.3f}')
    if len(wrong_ent) > 0:
        summary_lines.append(f'Wrong: mean H={wrong_ent.mean():.3f}')
    textstr = '\n'.join(summary_lines)
    ax.text(0.98, 0.5, textstr, transform=ax.transAxes,
            fontsize=9, verticalalignment='center', horizontalalignment='right',
            bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    plt.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f' Saved: {save_path}')
|
|
|
|
def plot_rejection_curve(metrics, labels, save_path):
    """Plot accuracy on the retained samples as the most uncertain are rejected.

    Samples are ranked by predictive entropy (highest first). For 200
    rejection fractions between 0 and 0.95 the top-ranked share is dropped
    and the accuracy of the remainder is plotted, together with the number
    of samples that remain (second y-axis).

    Args:
        metrics: dict holding ``predicted_class`` and ``predictive_entropy``.
        labels: array of true class indices, aligned with ``metrics``.
        save_path: destination of the PNG.
    """
    entropy = metrics['predictive_entropy']
    correct = (metrics['predicted_class'] == labels).astype(int)

    # Correctness flags ordered from most to least uncertain sample.
    ranked_correct = correct[np.argsort(entropy)[::-1]]

    total = len(labels)
    rejection_fracs = np.linspace(0.0, 0.95, 200)
    acc_values, remaining_counts = [], []
    for frac in rejection_fracs:
        survivors = ranked_correct[int(frac * total):]
        remaining_counts.append(len(survivors))
        acc_values.append(survivors.mean() * 100 if len(survivors) else np.nan)

    accuracies = np.array(acc_values)
    n_remaining = np.array(remaining_counts)
    reject_pct = rejection_fracs * 100

    fig, ax1 = plt.subplots(figsize=(10, 7))

    # Left axis: accuracy of the retained samples.
    acc_color = '#2196F3'
    ax1.plot(reject_pct, accuracies, color=acc_color, linewidth=2.0,
             label='Accuracy')
    ax1.set_xlabel('Rejection Rate (%)', fontsize=12)
    ax1.set_ylabel('Accuracy (%)', fontsize=12, color=acc_color)
    ax1.tick_params(axis='y', labelcolor=acc_color)
    y_floor = max(50, np.nanmin(accuracies) - 5)
    ax1.set_ylim([y_floor, 101])

    # Right axis: how many samples are still being classified.
    ax2 = ax1.twinx()
    count_color = '#FF9800'
    ax2.plot(reject_pct, n_remaining, color=count_color, linewidth=1.5,
             linestyle='--', alpha=0.7, label='Remaining')
    ax2.set_ylabel('Samples Remaining', fontsize=12, color=count_color)
    ax2.tick_params(axis='y', labelcolor=count_color)

    # Reference line: accuracy with nothing rejected.
    base_acc = correct.mean() * 100
    ax1.axhline(base_acc, color='gray', linestyle=':', alpha=0.5)
    ax1.text(2, base_acc + 0.5, f'Baseline: {base_acc:.1f}%', fontsize=9, color='gray')

    ax1.set_title('Rejection Curve: Accuracy vs Uncertainty-Based Rejection', fontsize=14)
    ax1.grid(True, alpha=0.3)

    # One legend covering the curves of both axes.
    handles_left, names_left = ax1.get_legend_handles_labels()
    handles_right, names_right = ax2.get_legend_handles_labels()
    ax1.legend(handles_left + handles_right, names_left + names_right,
               loc='lower left', fontsize=10)

    plt.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f' Saved: {save_path}')
|
|
|
|
def plot_epistemic_vs_aleatoric(metrics, labels, save_path):
    """Scatter of expected entropy (x) against mutual information (y).

    Points are coloured by true class; misclassified samples additionally
    get a red ring. Three corner captions label the regions of the plot.

    Args:
        metrics: dict holding ``expected_entropy``, ``mutual_info`` and
            ``predicted_class``.
        labels: array of true class indices, aligned with ``metrics``.
        save_path: destination of the PNG.
    """
    aleatoric = metrics['expected_entropy']
    epistemic = metrics['mutual_info']
    correct = (metrics['predicted_class'] == labels).astype(int)

    fig, ax = plt.subplots(figsize=(10, 7))

    palette = plt.cm.Set2(np.linspace(0, 1, NUM_CLASSES))
    for cls_idx, cls_name in enumerate(CLASS_NAMES[:NUM_CLASSES]):
        members = labels == cls_idx
        ax.scatter(
            aleatoric[members], epistemic[members],
            c=[palette[cls_idx]], alpha=0.45, s=20, label=cls_name,
            edgecolors='none'
        )

    # Ring the errors so they stand out on top of the class colours.
    wrong_mask = correct == 0
    if wrong_mask.sum() > 0:
        ax.scatter(
            aleatoric[wrong_mask], epistemic[wrong_mask],
            facecolors='none', edgecolors='red', s=60, linewidths=1.2,
            label='Misclassified', zorder=5
        )

    ax.set_xlabel('Aleatoric Uncertainty (Expected Entropy)', fontsize=12)
    ax.set_ylabel('Epistemic Uncertainty (Mutual Information)', fontsize=12)
    ax.set_title('Decomposition of Uncertainty: Epistemic vs Aleatoric', fontsize=14)
    ax.legend(fontsize=9, title='Class', title_fontsize=10)
    ax.grid(True, alpha=0.3)

    # Corner captions, inset 2% horizontally and 5% vertically from the
    # current axis limits.
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    x_left = xlim[0] + 0.02 * (xlim[1] - xlim[0])
    x_right = xlim[1] - 0.02 * (xlim[1] - xlim[0])
    y_top = ylim[1] - 0.05 * (ylim[1] - ylim[0])
    y_bottom = ylim[0] + 0.05 * (ylim[1] - ylim[0])

    ax.text(x_left, y_top,
            'Low aleatoric\nHigh epistemic\n(need more data)',
            fontsize=8, alpha=0.6, va='top')
    ax.text(x_right, y_top,
            'High aleatoric\nHigh epistemic\n(hard + unseen)',
            fontsize=8, alpha=0.6, va='top', ha='right')
    ax.text(x_right, y_bottom,
            'High aleatoric\nLow epistemic\n(inherently noisy)',
            fontsize=8, alpha=0.6, va='bottom', ha='right')

    plt.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f' Saved: {save_path}')
|
|
|
|
def plot_uncertainty_by_class(metrics, labels, save_path):
    """Draw per-class box plots of total, aleatoric and epistemic uncertainty.

    Args:
        metrics: dict holding ``predictive_entropy``, ``expected_entropy``
            and ``mutual_info``.
        labels: array of true class indices, aligned with ``metrics``.
        save_path: destination of the PNG.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    data_types = [
        ('predictive_entropy', 'Total Uncertainty (Predictive Entropy)'),
        ('expected_entropy', 'Aleatoric Uncertainty (Expected Entropy)'),
        ('mutual_info', 'Epistemic Uncertainty (Mutual Information)'),
    ]

    # Loop-invariant: same palette and box positions for all three panels.
    colors = plt.cm.Set2(np.linspace(0, 1, NUM_CLASSES))
    positions = list(range(1, NUM_CLASSES + 1))

    for ax, (key, title) in zip(axes, data_types):
        data = metrics[key]
        box_data = [data[labels == c] for c in range(NUM_CLASSES)]

        # The ``labels=`` keyword of boxplot is deprecated since
        # Matplotlib 3.9 (renamed ``tick_labels``). Setting the tick labels
        # explicitly gives the same result on old and new versions.
        bp = ax.boxplot(box_data, positions=positions, patch_artist=True,
                        widths=0.6, showfliers=True,
                        flierprops=dict(marker='o', markersize=3, alpha=0.3))
        ax.set_xticks(positions)
        ax.set_xticklabels(CLASS_NAMES)

        for patch, color in zip(bp['boxes'], colors):
            patch.set_facecolor(color)
            patch.set_alpha(0.7)

        ax.set_title(title, fontsize=11)
        ax.set_ylabel('Uncertainty', fontsize=10)
        ax.grid(True, axis='y', alpha=0.3)
        ax.tick_params(axis='x', rotation=15)

        # Sample count above each box, just below the top of the axis.
        for i, cls_data in enumerate(box_data):
            ax.text(i + 1, ax.get_ylim()[1] * 0.95,
                    f'n={len(cls_data)}', ha='center', fontsize=8, alpha=0.6)

    plt.suptitle('Uncertainty Distribution by Disease Class', fontsize=14, y=1.02)
    plt.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f' Saved: {save_path}')
|
|
|
|
def plot_confidence_vs_uncertainty(metrics, labels, save_path):
    """Scatter of max confidence against predictive entropy.

    Correct and incorrect predictions are drawn separately, the Pearson
    correlation between the two quantities is shown in the title, and a
    least-squares line is overlaid.

    Args:
        metrics: dict holding ``max_confidence``, ``predictive_entropy``
            and ``predicted_class``.
        labels: array of true class indices, aligned with ``metrics``.
        save_path: destination of the PNG.
    """
    # Imported here rather than at module level, as in the original script.
    from scipy import stats

    confidence = metrics['max_confidence']
    entropy = metrics['predictive_entropy']
    hit = metrics['predicted_class'] == labels
    miss = ~hit

    fig, ax = plt.subplots(figsize=(10, 7))

    ax.scatter(
        confidence[hit], entropy[hit],
        c='#4CAF50', alpha=0.4, s=15, label='Correct', edgecolors='none'
    )
    ax.scatter(
        confidence[miss], entropy[miss],
        c='#F44336', alpha=0.6, s=25, label='Incorrect', edgecolors='none',
        marker='x', linewidths=1.0
    )

    r, p_val = stats.pearsonr(confidence, entropy)

    ax.set_xlabel('Maximum Confidence (max p_bar)', fontsize=12)
    ax.set_ylabel('Predictive Entropy (Total Uncertainty)', fontsize=12)
    ax.set_title(f'Confidence vs Uncertainty (Pearson r={r:.3f}, p={p_val:.2e})', fontsize=14)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    # Degree-1 least-squares fit across the observed confidence range.
    trend = np.polyfit(confidence, entropy, 1)
    conf_grid = np.linspace(confidence.min(), confidence.max(), 100)
    ax.plot(conf_grid, np.polyval(trend, conf_grid), 'k--', alpha=0.4, linewidth=1.5)

    plt.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f' Saved: {save_path}')
|
|
|
|
| |
| |
| |
def main():
    """Run the full MC-dropout evaluation on the test split.

    Steps: load the test set, extract backbone features once, sample the
    heads ``T_FORWARD_PASSES`` times, compute the uncertainty metrics,
    print a summary, write the five plots and save
    ``mc_dropout_results.json`` into ``UNCERT_DIR``.
    """
    t_start = time.time()

    def _r6(value):
        """Plain float rounded to 6 decimals (JSON-friendly)."""
        return round(float(value), 6)

    def _describe(values, with_range=True):
        """mean/std (and optionally min/max) of an array, rounded."""
        stats_dict = {'mean': _r6(values.mean()), 'std': _r6(values.std())}
        if with_range:
            stats_dict['min'] = _r6(values.min())
            stats_dict['max'] = _r6(values.max())
        return stats_dict

    # ---- Data -----------------------------------------------------------
    print('\nLoading test set...')
    dataset = TestDataset(TEST_CSV)
    dataloader = DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=2, pin_memory=False
    )
    print(f' Test samples: {len(dataset)}')

    # ---- Stage 1: deterministic backbone features -----------------------
    features, true_labels, image_paths = extract_all_features(model, dataloader)
    print(f' Features shape: {features.shape}')

    t_feat = time.time() - t_start
    print(f' Feature extraction: {t_feat:.1f}s')

    # ---- Stage 2: stochastic head passes --------------------------------
    mc_probs = mc_dropout_on_heads(
        model, features, T=T_FORWARD_PASSES, temperature=TEMPERATURE
    )
    print(f' MC probs shape: {mc_probs.shape} (N, T, C)')

    t_mc = time.time() - t_start - t_feat
    print(f' MC head passes: {t_mc:.1f}s')

    # ---- Metrics --------------------------------------------------------
    print('\nComputing uncertainty metrics...')
    metrics = compute_uncertainty_metrics(mc_probs)

    pred_entropy = metrics['predictive_entropy']
    exp_entropy = metrics['expected_entropy']
    mutual_info = metrics['mutual_info']
    max_conf = metrics['max_confidence']
    predicted = metrics['predicted_class']

    correct = (predicted == true_labels).astype(int)
    accuracy = correct.mean() * 100
    print(f'\n --- Summary ---')
    print(f' Accuracy (MC mean): {accuracy:.2f}%')
    print(f' Predictive entropy: mean={pred_entropy.mean():.4f}, '
          f'std={pred_entropy.std():.4f}')
    print(f' Aleatoric (exp. ent.): mean={exp_entropy.mean():.4f}, '
          f'std={exp_entropy.std():.4f}')
    print(f' Epistemic (MI): mean={mutual_info.mean():.4f}, '
          f'std={mutual_info.std():.4f}')
    print(f' Max confidence: mean={max_conf.mean():.4f}, '
          f'std={max_conf.std():.4f}')

    print(f'\n Per-class uncertainty (predictive entropy):')
    for cls_idx in range(NUM_CLASSES):
        mask = true_labels == cls_idx
        n_cls = mask.sum()
        has_samples = n_cls > 0
        # Classes absent from the test split are reported with zeros.
        cls_acc = correct[mask].mean() * 100 if has_samples else 0
        cls_ent = pred_entropy[mask].mean() if has_samples else 0
        cls_mi = mutual_info[mask].mean() if has_samples else 0
        print(f' {CLASS_NAMES[cls_idx]:15s}: n={n_cls:4d}, '
              f'acc={cls_acc:5.1f}%, H={cls_ent:.4f}, MI={cls_mi:.4f}')

    # ---- Plots ----------------------------------------------------------
    print('\nGenerating plots...')
    plot_jobs = (
        (plot_uncertainty_vs_accuracy, 'uncertainty_vs_accuracy.png'),
        (plot_rejection_curve, 'rejection_curve.png'),
        (plot_epistemic_vs_aleatoric, 'epistemic_vs_aleatoric.png'),
        (plot_uncertainty_by_class, 'uncertainty_by_class.png'),
        (plot_confidence_vs_uncertainty, 'confidence_vs_uncertainty.png'),
    )
    for plot_fn, filename in plot_jobs:
        plot_fn(metrics, true_labels, os.path.join(UNCERT_DIR, filename))

    # ---- JSON export ----------------------------------------------------
    print('\nSaving results JSON...')

    per_image = [
        {
            'image_path': image_paths[i],
            'true_label': int(true_labels[i]),
            'true_class': CLASS_NAMES[int(true_labels[i])],
            'predicted_label': int(predicted[i]),
            'predicted_class': CLASS_NAMES[int(predicted[i])],
            'correct': bool(correct[i]),
            'max_confidence': _r6(max_conf[i]),
            'predictive_entropy': _r6(pred_entropy[i]),
            'expected_entropy': _r6(exp_entropy[i]),
            'mutual_information': _r6(mutual_info[i]),
            'class_variance': [round(float(v), 8) for v in metrics['class_variance'][i]],
            'mean_probs': [_r6(v) for v in metrics['p_mean'][i]],
        }
        for i in range(len(true_labels))
    ]

    aggregate = {
        'n_samples': int(len(true_labels)),
        'n_classes': NUM_CLASSES,
        'mc_passes': T_FORWARD_PASSES,
        'temperature': TEMPERATURE,
        'accuracy_pct': round(float(accuracy), 4),
        'overall': {
            'predictive_entropy': _describe(pred_entropy),
            'expected_entropy': _describe(exp_entropy),
            'mutual_information': _describe(mutual_info),
            'max_confidence': _describe(max_conf, with_range=False),
        },
        'per_class': {},
    }

    for cls_idx in range(NUM_CLASSES):
        mask = true_labels == cls_idx
        n_cls = int(mask.sum())
        if n_cls == 0:
            continue  # classes without samples are left out of the JSON
        aggregate['per_class'][CLASS_NAMES[cls_idx]] = {
            'n_samples': n_cls,
            'accuracy': round(float(correct[mask].mean() * 100), 4),
            'pred_entropy_mean': _r6(pred_entropy[mask].mean()),
            'pred_entropy_std': _r6(pred_entropy[mask].std()),
            'aleatoric_mean': _r6(exp_entropy[mask].mean()),
            'epistemic_mean': _r6(mutual_info[mask].mean()),
            'confidence_mean': _r6(max_conf[mask].mean()),
        }

    # Accuracy after rejecting the most uncertain share of the samples.
    ranked_correct = correct[np.argsort(pred_entropy)[::-1]]
    rejection_checkpoints = {}
    for frac in [0.0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.50]:
        kept = ranked_correct[int(frac * len(true_labels)):]
        if len(kept) > 0:
            rejection_checkpoints[f'reject_{int(frac*100)}pct'] = {
                'accuracy': round(float(kept.mean() * 100), 4),
                'n_remaining': int(len(kept)),
            }
    aggregate['rejection_curve'] = rejection_checkpoints

    results = {
        'aggregate': aggregate,
        'per_image': per_image,
    }

    json_path = os.path.join(UNCERT_DIR, 'mc_dropout_results.json')
    with open(json_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f' Saved: {json_path}')

    elapsed = time.time() - t_start
    print(f'\nDone in {elapsed:.1f}s')
    print('=' * 65)


if __name__ == '__main__':
    main()
|
|