tanishq74 committed on
Commit
c4d737f
·
verified ·
1 Parent(s): dbadc23

Add mc_dropout_uncertainty.py

Browse files
Files changed (1) hide show
  1. mc_dropout_uncertainty.py +817 -0
mc_dropout_uncertainty.py ADDED
@@ -0,0 +1,817 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ RetinaSense v3.0 -- MC Dropout Uncertainty Quantification (Phase 1B)
4
+ ====================================================================
5
+ Performs Monte Carlo Dropout inference on the test set to decompose
6
+ predictive uncertainty into aleatoric and epistemic components.
7
+
8
+ Strategy for efficiency:
9
+ - Run the ViT backbone ONCE per image (deterministic, no dropout in backbone)
10
+ - Cache the 768-dim CLS features
11
+ - Run T=30 stochastic forward passes through the classification heads only
12
+ (where the dropout layers live: self.drop + head dropouts)
13
+ This is 30x faster than running the full model T times.
14
+
15
+ For each test image, computes:
16
+ - Predictive entropy (total uncertainty)
17
+ - Expected entropy (aleatoric uncertainty)
18
+ - Mutual information (epistemic uncertainty)
19
+ - Per-class prediction variance
20
+
21
+ Generates:
22
+ - uncertainty_vs_accuracy.png
23
+ - rejection_curve.png
24
+ - epistemic_vs_aleatoric.png
25
+ - uncertainty_by_class.png
26
+ - confidence_vs_uncertainty.png
27
+ - mc_dropout_results.json
28
+
29
+ Usage:
30
+ python mc_dropout_uncertainty.py
31
+ """
32
+
33
+ import os
34
+ import sys
35
+ import json
36
+ import time
37
+ import warnings
38
+ import numpy as np
39
+ import pandas as pd
40
+ import cv2
41
+ import matplotlib
42
+ matplotlib.use('Agg')
43
+ import matplotlib.pyplot as plt
44
+ import matplotlib.patches as mpatches
45
+ from PIL import Image
46
+ from tqdm import tqdm
47
+
48
+ warnings.filterwarnings('ignore')
49
+
50
+ import torch
51
+ import torch.nn as nn
52
+ import torch.nn.functional as F
53
+ from torchvision import transforms
54
+ from torch.utils.data import Dataset, DataLoader
55
+
56
+ import timm
57
+
58
# Cap PyTorch intra-op CPU parallelism at 4 threads
torch.set_num_threads(4)

# ================================================================
# CONFIGURATION
# ================================================================
# Directory layout: all inputs and outputs live under BASE_DIR. The
# uncertainty output folder is created at import time if it is missing.
BASE_DIR = '/teamspace/studios/this_studio'
OUTPUT_DIR = os.path.join(BASE_DIR, 'outputs_v3')
UNCERT_DIR = os.path.join(OUTPUT_DIR, 'uncertainty')
os.makedirs(UNCERT_DIR, exist_ok=True)

# Input artefacts
MODEL_PATH = os.path.join(OUTPUT_DIR, 'best_model.pth')
TEMPERATURE_PATH = os.path.join(OUTPUT_DIR, 'temperature.json')
NORM_STATS_PATH = os.path.join(BASE_DIR, 'data', 'fundus_norm_stats.json')
TEST_CSV = os.path.join(BASE_DIR, 'data', 'test_split.csv')

# Label space and model hyper-parameters
CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']
NUM_CLASSES = len(CLASS_NAMES)   # 5
IMG_SIZE = 224
DROPOUT = 0.3

# Inference settings
T_FORWARD_PASSES = 30   # number of MC stochastic forward passes
BATCH_SIZE = 32         # batch size for feature extraction
HEAD_BATCH = 512        # batch size for head-only MC passes (very lightweight)

if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Start-up banner
print('=' * 65)
print(' RetinaSense v3.0 -- MC Dropout Uncertainty Quantification')
print('=' * 65)
print(' Device : {}'.format(DEVICE))
if torch.cuda.is_available():
    print(' GPU : {}'.format(torch.cuda.get_device_name(0)))
print(' MC passes (T) : {}'.format(T_FORWARD_PASSES))
print(' Output dir : {}'.format(UNCERT_DIR))
93
+
94
# ================================================================
# LOAD NORMALISATION STATS
# ================================================================
# Start from the ImageNet channel statistics and override them with the
# dataset-specific values when the stats file is present.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]
if not os.path.exists(NORM_STATS_PATH):
    print(' Using ImageNet normalisation fallback')
else:
    with open(NORM_STATS_PATH) as f:
        norm_stats = json.load(f)
    NORM_MEAN, NORM_STD = norm_stats['mean_rgb'], norm_stats['std_rgb']
    print(f' Fundus norm : mean={[round(v,4) for v in NORM_MEAN]}, '
          f'std={[round(v,4) for v in NORM_STD]}')

# Load temperature (scalar the logits are divided by before softmax).
# There is no fallback here: a missing file raises at import time.
with open(TEMPERATURE_PATH) as f:
    temp_data = json.load(f)
TEMPERATURE = temp_data['temperature']
print(' Temperature : {:.4f}'.format(TEMPERATURE))
114
+
115
+ # ================================================================
116
+ # MODEL ARCHITECTURE (mirrors retinasense_v3.py / gradcam_v3.py)
117
+ # ================================================================
118
class MultiTaskViT(nn.Module):
    """ViT-Base-Patch16-224 backbone feeding a disease head and a severity head.

    The sub-module names and the order of layers inside each ``nn.Sequential``
    determine the checkpoint's state-dict keys, so they must not be changed.
    """

    def __init__(self, n_disease=NUM_CLASSES, n_severity=5, drop=DROPOUT):
        """Build the backbone (no pretrained weights, no classifier) and both heads.

        Args:
            n_disease: Number of outputs of the disease head.
            n_severity: Number of outputs of the severity head.
            drop: Dropout probability applied to the backbone features
                before they reach either head.
        """
        super().__init__()
        self.backbone = timm.create_model(
            'vit_base_patch16_224', pretrained=False, num_classes=0
        )
        feat = 768  # CLS token dimension

        # Shared dropout on the backbone features.
        self.drop = nn.Dropout(drop)

        disease_layers = [
            nn.Linear(feat, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, n_disease),
        ]
        self.disease_head = nn.Sequential(*disease_layers)

        severity_layers = [
            nn.Linear(feat, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, n_severity),
        ]
        self.severity_head = nn.Sequential(*severity_layers)

    def forward(self, x):
        """Return ``(disease_logits, severity_logits)`` for an image batch."""
        # A single dropout mask is drawn and shared by both heads.
        shared = self.drop(self.backbone(x))  # (B, 768)
        disease_logits = self.disease_head(shared)
        severity_logits = self.severity_head(shared)
        return disease_logits, severity_logits

    def extract_features(self, x):
        """Return the backbone output only; no dropout or head is applied."""
        feats = self.backbone(x)  # (B, 768)
        return feats

    def forward_heads(self, features):
        """Apply the shared dropout and the disease head to cached features."""
        return self.disease_head(self.drop(features))
153
+
154
+
155
+ # ================================================================
156
+ # LOAD MODEL
157
+ # ================================================================
158
print('\nLoading model...')
model = MultiTaskViT().to(DEVICE)
# NOTE(security): weights_only=False unpickles arbitrary Python objects from
# the checkpoint. Only load checkpoint files that come from a trusted source.
ckpt = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
print(f' Loaded: {MODEL_PATH}')
# The stored epoch is zero-based, so it is shown as epoch + 1. A checkpoint
# without an 'epoch' entry prints '?' instead of failing on '?' + 1.
_ckpt_epoch = ckpt.get('epoch')
_epoch_display = '?' if _ckpt_epoch is None else _ckpt_epoch + 1
print(f' Checkpoint epoch: {_epoch_display} '
      f'val_acc={ckpt.get("val_acc", 0):.2f}%')
165
+
166
+
167
+ # ================================================================
168
+ # MC DROPOUT SETUP
169
+ # ================================================================
170
def enable_head_dropout(model):
    """
    Put ``model`` in eval mode, then re-activate selected dropout layers.

    After the call, ``model.drop`` and every ``nn.Dropout`` / ``nn.Dropout2d``
    found inside ``model.disease_head`` are in train mode (stochastic). All
    other modules, including BatchNorm layers and any other head, stay in
    eval mode. The model is modified in place and nothing is returned.
    """
    dropout_types = (nn.Dropout, nn.Dropout2d)

    # Start from a fully deterministic model ...
    model.eval()

    # ... then re-enable only the dropout layers used for MC sampling.
    model.drop.train()
    head_dropouts = [
        layer for layer in model.disease_head.modules()
        if isinstance(layer, dropout_types)
    ]
    for layer in head_dropouts:
        layer.train()
184
+
185
+
186
enable_head_dropout(model)

# Count how many dropout layers exist in the model and how many of them are
# currently in train mode (i.e. will be stochastic during inference).
_dropout_layers = [
    mod for mod in model.modules()
    if isinstance(mod, (nn.Dropout, nn.Dropout2d))
]
n_dropout_total = len(_dropout_layers)
n_dropout_active = sum(1 for mod in _dropout_layers if mod.training)
print('\n MC Dropout enabled in heads: {} active / {} total dropout layers'.format(
    n_dropout_active, n_dropout_total))
print(' Backbone: deterministic (eval mode) -- single pass per image')
print(' Heads: stochastic (train mode dropout) -- {} passes per image'.format(
    T_FORWARD_PASSES))
197
+
198
+
199
+ # ================================================================
200
+ # PREPROCESSING (matches gradcam_v3.py pipeline)
201
+ # ================================================================
202
def ben_graham(path, sz=IMG_SIZE, sigma=10):
    """Ben Graham high-frequency fundus enhancement (APTOS-style).

    Reads the image at ``path`` (OpenCV first, PIL if OpenCV cannot decode
    it), resizes it to ``sz`` x ``sz``, subtracts a Gaussian-blurred copy
    (``4*img - 4*blur + 128``) and blanks everything outside a centred
    circle of radius ``0.48 * sz``. Returns the result as an RGB array.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        # PIL already yields RGB, which is the channel order used below.
        rgb = np.array(Image.open(path).convert('RGB'))
    else:
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    rgb = cv2.resize(rgb, (sz, sz))
    blurred = cv2.GaussianBlur(rgb, (0, 0), sigma)
    enhanced = cv2.addWeighted(rgb, 4, blurred, -4, 128)

    # Keep only the circular field of view.
    centre = (sz // 2, sz // 2)
    radius = int(sz * 0.48)
    fov_mask = np.zeros(enhanced.shape[:2], dtype=np.uint8)
    cv2.circle(fov_mask, centre, radius, 255, -1)
    return cv2.bitwise_and(enhanced, enhanced, mask=fov_mask)
214
+
215
+
216
def clahe_preprocess(path, sz=IMG_SIZE):
    """CLAHE-based contrast enhancement (ODIR-style).

    Reads the image at ``path`` (OpenCV first, PIL if OpenCV cannot decode
    it), resizes it to ``sz`` x ``sz`` and applies CLAHE (clip limit 2.0,
    8x8 tiles) to the L channel in LAB space. Returns an RGB array.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        pil_rgb = np.array(Image.open(path).convert('RGB'))
        bgr = cv2.cvtColor(pil_rgb, cv2.COLOR_RGB2BGR)
    bgr = cv2.resize(bgr, (sz, sz))

    # Equalise lightness only; the two colour channels are left untouched.
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB))
    equaliser = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    lab_eq = cv2.merge((equaliser.apply(lightness), chan_a, chan_b))

    bgr_eq = cv2.cvtColor(lab_eq, cv2.COLOR_LAB2BGR)
    return cv2.cvtColor(bgr_eq, cv2.COLOR_BGR2RGB)
228
+
229
+
230
def resolve_path(image_path):
    """Return a usable filesystem path for ``image_path``.

    An absolute path that exists is returned unchanged. Otherwise every
    leading ``./`` is stripped and the remainder is joined onto BASE_DIR.
    """
    already_usable = os.path.isabs(image_path) and os.path.exists(image_path)
    if already_usable:
        return image_path

    relative = image_path
    while relative[:2] == './':
        relative = relative[2:]
    return os.path.join(BASE_DIR, relative)
238
+
239
+
240
+ # ================================================================
241
+ # DATASET
242
+ # ================================================================
243
class TestDataset(Dataset):
    """Test dataset loading preprocessed images from cache or live.

    Each item is ``(image_tensor, label, image_path)``. The CSV must have
    ``image_path`` and ``disease_label`` columns; ``source`` and
    ``cache_path`` are optional.
    """

    def __init__(self, csv_path):
        """Read the split CSV and build the array -> normalised tensor transform."""
        self.df = pd.read_csv(csv_path).reset_index(drop=True)
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(NORM_MEAN, NORM_STD),
        ])

    def __len__(self):
        """Number of rows in the split CSV."""
        return len(self.df)

    @staticmethod
    def _report(message):
        """Write a data-loading problem to stderr.

        ``print`` is used instead of ``warnings.warn`` because this script
        globally silences warnings.
        """
        print(f' [warn] {message}', file=sys.stderr)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label, image_path)`` for row ``idx``.

        Loading is best-effort: a cached ``.npy`` array is tried first, then
        live preprocessing (Ben Graham for ``source == 'APTOS'``, CLAHE
        otherwise). If both fail an all-zero tensor is returned so the batch
        can still be formed, and the failure is reported on stderr instead
        of being silently swallowed.
        """
        row = self.df.iloc[idx]
        img_path = str(row['image_path'])
        dataset = str(row.get('source', 'auto'))
        label = int(row['disease_label'])

        # Try loading from cache first. A missing column yields '' and a
        # missing value stringifies to 'nan'; both mean "no cache".
        cache_path = str(row.get('cache_path', ''))
        if cache_path and cache_path != 'nan':
            cache_abs = resolve_path(cache_path)
            if os.path.exists(cache_abs):
                try:
                    img_np = np.load(cache_abs)
                    img_tensor = self.transform(img_np)
                    return img_tensor, label, img_path
                except Exception as err:
                    self._report(
                        f'cache load failed for {cache_abs}: {err!r}; '
                        f'falling back to live preprocessing'
                    )

        # Live preprocessing
        abs_path = resolve_path(img_path)
        try:
            if dataset == 'APTOS':
                img_np = ben_graham(abs_path)
            else:
                img_np = clahe_preprocess(abs_path)
            img_tensor = self.transform(img_np)
        except Exception as err:
            self._report(
                f'could not load {abs_path}: {err!r}; '
                f'substituting an all-zero image'
            )
            img_tensor = torch.zeros(3, IMG_SIZE, IMG_SIZE)

        return img_tensor, label, img_path
287
+
288
+
289
+ # ================================================================
290
+ # TWO-STAGE MC DROPOUT INFERENCE
291
+ # ================================================================
292
def extract_all_features(model, dataloader):
    """
    Stage 1: run ``model.extract_features`` once over every batch.

    Args:
        model: Object exposing ``extract_features(images)``.
        dataloader: Iterable yielding ``(images, labels, paths)`` batches.

    Returns:
        Tuple ``(features, labels, paths)``: the per-batch feature tensors
        concatenated along dim 0 (on CPU), a NumPy array of labels and a
        flat list of paths, all in dataloader order.
    """
    feature_chunks, label_list, path_list = [], [], []

    print('\n Stage 1: Extracting backbone features (deterministic)...')
    with torch.no_grad():
        progress = tqdm(dataloader, desc=' Features', ncols=80)
        for batch_images, batch_labels, batch_paths in progress:
            batch_feats = model.extract_features(batch_images.to(DEVICE))  # (B, 768)
            feature_chunks.append(batch_feats.cpu())
            label_list += batch_labels.numpy().tolist()
            path_list += list(batch_paths)

    stacked = torch.cat(feature_chunks, dim=0)  # (N, 768)
    return stacked, np.array(label_list), path_list
313
+
314
+
315
def mc_dropout_on_heads(model, features, T=T_FORWARD_PASSES, temperature=TEMPERATURE):
    """
    Stage 2: sample ``T`` stochastic head passes over cached features.

    Args:
        model: Object exposing ``forward_heads(features)`` that returns logits.
        features: Tensor whose first dimension indexes samples.
        T: Number of stochastic passes.
        temperature: Scalar the logits are divided by before softmax.

    Returns:
        ``float32`` NumPy array of shape ``(N, T, NUM_CLASSES)`` holding the
        softmax probabilities of every pass.
    """
    n_samples = features.size(0)
    sampled = np.zeros((n_samples, T, NUM_CLASSES), dtype=np.float32)

    print(f'\n Stage 2: MC Dropout through heads ({T} passes, {n_samples} samples)...')

    # Features are fed through in HEAD_BATCH-sized chunks to bound memory.
    chunk_starts = range(0, n_samples, HEAD_BATCH)
    with torch.no_grad():
        for pass_idx in tqdm(range(T), desc=' MC Passes', ncols=80):
            for lo in chunk_starts:
                hi = min(lo + HEAD_BATCH, n_samples)
                chunk_logits = model.forward_heads(features[lo:hi].to(DEVICE))
                chunk_probs = F.softmax(chunk_logits / temperature, dim=1)
                sampled[lo:hi, pass_idx, :] = chunk_probs.cpu().numpy()

    return sampled
338
+
339
+
340
+ # ================================================================
341
+ # UNCERTAINTY METRICS
342
+ # ================================================================
343
def compute_uncertainty_metrics(mc_probs):
    """
    Derive uncertainty measures from MC dropout probability samples.

    Args:
        mc_probs: (N, T, C) array; ``mc_probs[i, t]`` is the probability
            vector of sample ``i`` on stochastic pass ``t``.

    Returns:
        dict with
        - ``p_mean`` (N, C): probabilities averaged over the T passes
        - ``predicted_class`` (N,), ``max_confidence`` (N,): argmax / max of ``p_mean``
        - ``predictive_entropy`` (N,): entropy of ``p_mean`` (total)
        - ``expected_entropy`` (N,): mean per-pass entropy (aleatoric)
        - ``mutual_info`` (N,): their difference, clipped at 0 (epistemic)
        - ``class_variance`` (N, C): variance over passes, per class
    """
    # Unpacking doubles as a check that the input is 3-dimensional.
    _n, _t, _c = mc_probs.shape
    tiny = 1e-10  # keeps log() finite for zero probabilities

    def _entropy(p, axis):
        return -np.sum(p * np.log(p + tiny), axis=axis)

    mean_probs = mc_probs.mean(axis=1)                       # (N, C)
    total = _entropy(mean_probs, axis=1)                     # (N,)
    aleatoric = _entropy(mc_probs, axis=2).mean(axis=1)      # (N,)
    # Clip tiny negative values caused by floating-point error.
    epistemic = np.maximum(total - aleatoric, 0.0)

    return dict(
        p_mean=mean_probs,
        predicted_class=mean_probs.argmax(axis=1),
        max_confidence=mean_probs.max(axis=1),
        predictive_entropy=total,
        expected_entropy=aleatoric,
        mutual_info=epistemic,
        class_variance=mc_probs.var(axis=1),
    )
388
+
389
+
390
+ # ================================================================
391
+ # PLOTTING FUNCTIONS
392
+ # ================================================================
393
def plot_uncertainty_vs_accuracy(metrics, labels, save_path):
    """Scatter: total uncertainty vs correctness, colored by true class.

    Args:
        metrics: Dict providing ``predicted_class`` and ``predictive_entropy``
            arrays (one entry per sample).
        labels: Array of true class indices, aligned with ``metrics``.
        save_path: Destination of the PNG figure.
    """
    correct = (metrics['predicted_class'] == labels).astype(int)
    entropy = metrics['predictive_entropy']

    fig, ax = plt.subplots(figsize=(10, 7))

    colors = plt.cm.Set2(np.linspace(0, 1, NUM_CLASSES))
    for cls_idx in range(NUM_CLASSES):
        mask = labels == cls_idx
        # Vertical jitter spreads the two correctness rows so points do not
        # pile up on a single line.
        ax.scatter(
            entropy[mask], correct[mask] + np.random.uniform(-0.08, 0.08, mask.sum()),
            c=[colors[cls_idx]], alpha=0.5, s=20, label=CLASS_NAMES[cls_idx],
            edgecolors='none'
        )

    # Add vertical line at median uncertainty. It must be drawn BEFORE the
    # legend is created: legend() only lists artists that already exist, so
    # drawing the line afterwards would leave its label out of the legend.
    med = np.median(entropy)
    ax.axvline(med, color='red', linestyle='--', alpha=0.5, label=f'Median H={med:.3f}')

    ax.set_xlabel('Predictive Entropy (Total Uncertainty)', fontsize=12)
    ax.set_ylabel('Correctness (1=correct, 0=wrong)', fontsize=12)
    ax.set_title('MC Dropout: Uncertainty vs Prediction Correctness', fontsize=14)
    ax.set_yticks([0, 1])
    ax.set_yticklabels(['Incorrect', 'Correct'])
    ax.legend(title='True Class', fontsize=9, title_fontsize=10)
    ax.grid(True, alpha=0.3)

    # Summary stats: only report a group that actually has samples, so the
    # box never shows the mean of an empty array.
    correct_ent = entropy[correct == 1]
    wrong_ent = entropy[correct == 0]
    summary_lines = []
    if len(correct_ent) > 0:
        summary_lines.append(f'Correct: mean H={correct_ent.mean():.3f}')
    if len(wrong_ent) > 0:
        summary_lines.append(f'Wrong: mean H={wrong_ent.mean():.3f}')
    if summary_lines:
        ax.text(0.98, 0.5, '\n'.join(summary_lines), transform=ax.transAxes,
                fontsize=9, verticalalignment='center', horizontalalignment='right',
                bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    plt.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f' Saved: {save_path}')
435
+
436
+
437
def plot_rejection_curve(metrics, labels, save_path):
    """Plot accuracy on the retained samples as the most uncertain are rejected.

    Samples are ranked by predictive entropy (highest first). For 200
    rejection fractions between 0% and 95% the accuracy of the remaining
    samples is drawn on the left axis and their count on the right axis.
    """
    uncertainty = metrics['predictive_entropy']
    is_correct = (metrics['predicted_class'] == labels).astype(int)

    # Most uncertain samples first, so slicing off a prefix rejects them.
    by_uncertainty_desc = np.argsort(uncertainty)[::-1]
    ranked_correct = is_correct[by_uncertainty_desc]

    total = len(labels)
    rejection_fracs = np.linspace(0.0, 0.95, 200)

    acc_values, kept_counts = [], []
    for frac in rejection_fracs:
        survivors = ranked_correct[int(frac * total):]
        kept_counts.append(len(survivors))
        acc_values.append(survivors.mean() * 100 if len(survivors) else np.nan)
    accuracies = np.array(acc_values)
    n_remaining = np.array(kept_counts)

    rejection_pct = rejection_fracs * 100
    acc_color, count_color = '#2196F3', '#FF9800'

    fig, ax_acc = plt.subplots(figsize=(10, 7))

    # Left axis: accuracy
    ax_acc.plot(rejection_pct, accuracies, color=acc_color, linewidth=2.0,
                label='Accuracy')
    ax_acc.set_xlabel('Rejection Rate (%)', fontsize=12)
    ax_acc.set_ylabel('Accuracy (%)', fontsize=12, color=acc_color)
    ax_acc.tick_params(axis='y', labelcolor=acc_color)
    lower_bound = max(50, np.nanmin(accuracies) - 5)
    ax_acc.set_ylim([lower_bound, 101])

    # Right axis: number of remaining samples
    ax_count = ax_acc.twinx()
    ax_count.plot(rejection_pct, n_remaining, color=count_color, linewidth=1.5,
                  linestyle='--', alpha=0.7, label='Remaining')
    ax_count.set_ylabel('Samples Remaining', fontsize=12, color=count_color)
    ax_count.tick_params(axis='y', labelcolor=count_color)

    # Reference line: accuracy with nothing rejected
    baseline = is_correct.mean() * 100
    ax_acc.axhline(baseline, color='gray', linestyle=':', alpha=0.5)
    ax_acc.text(2, baseline + 0.5, f'Baseline: {baseline:.1f}%', fontsize=9, color='gray')

    ax_acc.set_title('Rejection Curve: Accuracy vs Uncertainty-Based Rejection', fontsize=14)
    ax_acc.grid(True, alpha=0.3)

    # One legend covering the curves of both axes
    acc_handles, acc_names = ax_acc.get_legend_handles_labels()
    count_handles, count_names = ax_count.get_legend_handles_labels()
    ax_acc.legend(acc_handles + count_handles, acc_names + count_names,
                  loc='lower left', fontsize=10)

    plt.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f' Saved: {save_path}')
499
+
500
+
501
def plot_epistemic_vs_aleatoric(metrics, labels, save_path):
    """Scatter of aleatoric (x) against epistemic (y) uncertainty per sample.

    Points are colored by true class, misclassified samples get a red ring,
    and three corners of the plot carry a short interpretation note.
    """
    x_aleatoric = metrics['expected_entropy']
    y_epistemic = metrics['mutual_info']
    is_correct = (metrics['predicted_class'] == labels).astype(int)

    fig, ax = plt.subplots(figsize=(10, 7))

    palette = plt.cm.Set2(np.linspace(0, 1, NUM_CLASSES))
    for cls_idx, cls_name in zip(range(NUM_CLASSES), CLASS_NAMES):
        in_class = labels == cls_idx
        ax.scatter(
            x_aleatoric[in_class], y_epistemic[in_class],
            c=[palette[cls_idx]], alpha=0.45, s=20, label=cls_name,
            edgecolors='none'
        )

    # Ring the misclassified samples
    misclassified = is_correct == 0
    if misclassified.sum() > 0:
        ax.scatter(
            x_aleatoric[misclassified], y_epistemic[misclassified],
            facecolors='none', edgecolors='red', s=60, linewidths=1.2,
            label='Misclassified', zorder=5
        )

    ax.set_xlabel('Aleatoric Uncertainty (Expected Entropy)', fontsize=12)
    ax.set_ylabel('Epistemic Uncertainty (Mutual Information)', fontsize=12)
    ax.set_title('Decomposition of Uncertainty: Epistemic vs Aleatoric', fontsize=14)
    ax.legend(fontsize=9, title='Class', title_fontsize=10)
    ax.grid(True, alpha=0.3)

    # Corner annotations, positioned relative to the final axis limits
    x_lo, x_hi = ax.get_xlim()
    y_lo, y_hi = ax.get_ylim()
    x_pad = 0.02 * (x_hi - x_lo)
    y_pad = 0.05 * (y_hi - y_lo)
    corner_notes = [
        (x_lo + x_pad, y_hi - y_pad,
         'Low aleatoric\nHigh epistemic\n(need more data)', 'top', 'left'),
        (x_hi - x_pad, y_hi - y_pad,
         'High aleatoric\nHigh epistemic\n(hard + unseen)', 'top', 'right'),
        (x_hi - x_pad, y_lo + y_pad,
         'High aleatoric\nLow epistemic\n(inherently noisy)', 'bottom', 'right'),
    ]
    for x_pos, y_pos, note, v_align, h_align in corner_notes:
        ax.text(x_pos, y_pos, note, fontsize=8, alpha=0.6, va=v_align, ha=h_align)

    plt.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f' Saved: {save_path}')
553
+
554
+
555
def plot_uncertainty_by_class(metrics, labels, save_path):
    """Draw three side-by-side box plots (total / aleatoric / epistemic).

    Each panel shows one uncertainty measure grouped by true class, with the
    per-class sample count written near the top of the panel.
    """
    panels = {
        'predictive_entropy': 'Total Uncertainty (Predictive Entropy)',
        'expected_entropy': 'Aleatoric Uncertainty (Expected Entropy)',
        'mutual_info': 'Epistemic Uncertainty (Mutual Information)',
    }
    class_masks = [labels == c for c in range(NUM_CLASSES)]
    outlier_style = dict(marker='o', markersize=3, alpha=0.3)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    for ax, (metric_key, panel_title) in zip(axes, panels.items()):
        values = metrics[metric_key]
        grouped = [values[in_class] for in_class in class_masks]

        boxes = ax.boxplot(grouped, labels=CLASS_NAMES, patch_artist=True,
                           widths=0.6, showfliers=True,
                           flierprops=outlier_style)

        # One fill color per class
        palette = plt.cm.Set2(np.linspace(0, 1, NUM_CLASSES))
        for box_patch, fill in zip(boxes['boxes'], palette):
            box_patch.set_facecolor(fill)
            box_patch.set_alpha(0.7)

        ax.set_title(panel_title, fontsize=11)
        ax.set_ylabel('Uncertainty', fontsize=10)
        ax.grid(True, axis='y', alpha=0.3)
        ax.tick_params(axis='x', rotation=15)

        # Sample counts; boxplot positions are 1-based
        for position, group in enumerate(grouped, start=1):
            ax.text(position, ax.get_ylim()[1] * 0.95,
                    f'n={len(group)}', ha='center', fontsize=8, alpha=0.6)

    plt.suptitle('Uncertainty Distribution by Disease Class', fontsize=14, y=1.02)
    plt.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f' Saved: {save_path}')
593
+
594
+
595
def plot_confidence_vs_uncertainty(metrics, labels, save_path):
    """Scatter of max confidence (x) against predictive entropy (y).

    Correct and incorrect predictions are drawn as separate groups, the
    Pearson correlation is shown in the title and a straight-line fit is
    overlaid.
    """
    from scipy import stats

    x_conf = metrics['max_confidence']
    y_ent = metrics['predictive_entropy']
    is_correct = (metrics['predicted_class'] == labels).astype(int)

    fig, ax = plt.subplots(figsize=(10, 7))

    hit = is_correct == 1
    miss = is_correct == 0
    ax.scatter(
        x_conf[hit], y_ent[hit],
        c='#4CAF50', alpha=0.4, s=15, label='Correct', edgecolors='none'
    )
    ax.scatter(
        x_conf[miss], y_ent[miss],
        c='#F44336', alpha=0.6, s=25, label='Incorrect', edgecolors='none',
        marker='x', linewidths=1.0
    )

    # Linear correlation between confidence and entropy
    r, p_val = stats.pearsonr(x_conf, y_ent)

    ax.set_xlabel('Maximum Confidence (max p_bar)', fontsize=12)
    ax.set_ylabel('Predictive Entropy (Total Uncertainty)', fontsize=12)
    ax.set_title(f'Confidence vs Uncertainty (Pearson r={r:.3f}, p={p_val:.2e})', fontsize=14)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    # First-degree least-squares fit drawn across the observed range
    fit_coeffs = np.polyfit(x_conf, y_ent, 1)
    grid = np.linspace(x_conf.min(), x_conf.max(), 100)
    ax.plot(grid, np.polyval(fit_coeffs, grid), 'k--', alpha=0.4, linewidth=1.5)

    plt.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f' Saved: {save_path}')
632
+
633
+
634
+ # ================================================================
635
+ # MAIN
636
+ # ================================================================
637
def main():
    """Run the whole pipeline: features, MC sampling, metrics, plots, JSON.

    Uses the module-level ``model`` and configuration constants. Writes five
    PNG figures and ``mc_dropout_results.json`` into UNCERT_DIR and prints a
    summary to stdout.
    """
    t_start = time.time()

    def _r6(value):
        # Plain Python float rounded to 6 decimals (JSON-serialisable).
        return round(float(value), 6)

    def _describe(values, with_range=True):
        # mean/std (and optionally min/max) summary of one metric array.
        summary = {'mean': _r6(values.mean()), 'std': _r6(values.std())}
        if with_range:
            summary['min'] = _r6(values.min())
            summary['max'] = _r6(values.max())
        return summary

    # ---- 1. Build DataLoader ----
    print('\nLoading test set...')
    dataset = TestDataset(TEST_CSV)
    dataloader = DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=2, pin_memory=False
    )
    print(f' Test samples: {len(dataset)}')

    # ---- 2. Stage 1: Extract backbone features (single deterministic pass) ----
    features, true_labels, image_paths = extract_all_features(model, dataloader)
    print(f' Features shape: {features.shape}')

    t_feat = time.time() - t_start
    print(f' Feature extraction: {t_feat:.1f}s')

    # ---- 3. Stage 2: MC Dropout on heads only ----
    mc_probs = mc_dropout_on_heads(
        model, features, T=T_FORWARD_PASSES, temperature=TEMPERATURE
    )
    print(f' MC probs shape: {mc_probs.shape} (N, T, C)')

    t_mc = time.time() - t_start - t_feat
    print(f' MC head passes: {t_mc:.1f}s')

    # ---- 4. Compute Uncertainty Metrics ----
    print('\nComputing uncertainty metrics...')
    metrics = compute_uncertainty_metrics(mc_probs)

    pred_cls = metrics['predicted_class']
    pred_ent = metrics['predictive_entropy']
    exp_ent = metrics['expected_entropy']
    mut_info = metrics['mutual_info']
    max_conf = metrics['max_confidence']

    # Print summary statistics
    correct = (pred_cls == true_labels).astype(int)
    accuracy = correct.mean() * 100
    print('\n --- Summary ---')
    print(f' Accuracy (MC mean): {accuracy:.2f}%')
    console_rows = (
        ('Predictive entropy:', pred_ent),
        ('Aleatoric (exp. ent.):', exp_ent),
        ('Epistemic (MI):', mut_info),
        ('Max confidence:', max_conf),
    )
    for caption, values in console_rows:
        print(f' {caption} mean={values.mean():.4f}, '
              f'std={values.std():.4f}')

    # Per-class stats
    print('\n Per-class uncertainty (predictive entropy):')
    for cls_idx, cls_name in zip(range(NUM_CLASSES), CLASS_NAMES):
        in_class = true_labels == cls_idx
        n_cls = in_class.sum()
        if n_cls > 0:
            cls_acc = correct[in_class].mean() * 100
            cls_ent = pred_ent[in_class].mean()
            cls_mi = mut_info[in_class].mean()
        else:
            cls_acc = cls_ent = cls_mi = 0
        print(f' {cls_name:15s}: n={n_cls:4d}, '
              f'acc={cls_acc:5.1f}%, H={cls_ent:.4f}, MI={cls_mi:.4f}')

    # ---- 5. Generate Plots ----
    print('\nGenerating plots...')
    figure_jobs = (
        (plot_uncertainty_vs_accuracy, 'uncertainty_vs_accuracy.png'),
        (plot_rejection_curve, 'rejection_curve.png'),
        (plot_epistemic_vs_aleatoric, 'epistemic_vs_aleatoric.png'),
        (plot_uncertainty_by_class, 'uncertainty_by_class.png'),
        (plot_confidence_vs_uncertainty, 'confidence_vs_uncertainty.png'),
    )
    for plot_fn, file_name in figure_jobs:
        plot_fn(metrics, true_labels, os.path.join(UNCERT_DIR, file_name))

    # ---- 6. Save JSON Results ----
    print('\nSaving results JSON...')

    per_image = []
    for i, raw_label in enumerate(true_labels):
        truth = int(raw_label)
        guess = int(pred_cls[i])
        per_image.append({
            'image_path': image_paths[i],
            'true_label': truth,
            'true_class': CLASS_NAMES[truth],
            'predicted_label': guess,
            'predicted_class': CLASS_NAMES[guess],
            'correct': bool(correct[i]),
            'max_confidence': _r6(max_conf[i]),
            'predictive_entropy': _r6(pred_ent[i]),
            'expected_entropy': _r6(exp_ent[i]),
            'mutual_information': _r6(mut_info[i]),
            'class_variance': [round(float(v), 8) for v in metrics['class_variance'][i]],
            'mean_probs': [_r6(v) for v in metrics['p_mean'][i]],
        })

    aggregate = {
        'n_samples': int(len(true_labels)),
        'n_classes': NUM_CLASSES,
        'mc_passes': T_FORWARD_PASSES,
        'temperature': TEMPERATURE,
        'accuracy_pct': round(float(accuracy), 4),
        'overall': {
            'predictive_entropy': _describe(pred_ent),
            'expected_entropy': _describe(exp_ent),
            'mutual_information': _describe(mut_info),
            'max_confidence': _describe(max_conf, with_range=False),
        },
        'per_class': {},
    }

    # Classes with no test samples are left out of the per-class report.
    for cls_idx, cls_name in zip(range(NUM_CLASSES), CLASS_NAMES):
        in_class = true_labels == cls_idx
        n_cls = int(in_class.sum())
        if n_cls == 0:
            continue
        aggregate['per_class'][cls_name] = {
            'n_samples': n_cls,
            'accuracy': round(float(correct[in_class].mean() * 100), 4),
            'pred_entropy_mean': _r6(pred_ent[in_class].mean()),
            'pred_entropy_std': _r6(pred_ent[in_class].std()),
            'aleatoric_mean': _r6(exp_ent[in_class].mean()),
            'epistemic_mean': _r6(mut_info[in_class].mean()),
            'confidence_mean': _r6(max_conf[in_class].mean()),
        }

    # Rejection curve data at key thresholds (most uncertain rejected first)
    ranked_correct = correct[np.argsort(pred_ent)[::-1]]
    rejection_checkpoints = {}
    for frac in [0.0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.50]:
        survivors = ranked_correct[int(frac * len(true_labels)):]
        if len(survivors) == 0:
            continue
        rejection_checkpoints[f'reject_{int(frac*100)}pct'] = {
            'accuracy': round(float(survivors.mean() * 100), 4),
            'n_remaining': int(len(survivors)),
        }
    aggregate['rejection_curve'] = rejection_checkpoints

    results = {
        'aggregate': aggregate,
        'per_image': per_image,
    }

    json_path = os.path.join(UNCERT_DIR, 'mc_dropout_results.json')
    with open(json_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f' Saved: {json_path}')

    elapsed = time.time() - t_start
    print(f'\nDone in {elapsed:.1f}s')
    print('=' * 65)


if __name__ == '__main__':
    main()