Add mc_dropout_uncertainty.py
Browse files- mc_dropout_uncertainty.py +817 -0
mc_dropout_uncertainty.py
ADDED
|
@@ -0,0 +1,817 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
RetinaSense v3.0 -- MC Dropout Uncertainty Quantification (Phase 1B)
|
| 4 |
+
====================================================================
|
| 5 |
+
Performs Monte Carlo Dropout inference on the test set to decompose
|
| 6 |
+
predictive uncertainty into aleatoric and epistemic components.
|
| 7 |
+
|
| 8 |
+
Strategy for efficiency:
|
| 9 |
+
- Run the ViT backbone ONCE per image (deterministic, no dropout in backbone)
|
| 10 |
+
- Cache the 768-dim CLS features
|
| 11 |
+
- Run T=30 stochastic forward passes through the classification heads only
|
| 12 |
+
(where the dropout layers live: self.drop + head dropouts)
|
| 13 |
+
This is 30x faster than running the full model T times.
|
| 14 |
+
|
| 15 |
+
For each test image, computes:
|
| 16 |
+
- Predictive entropy (total uncertainty)
|
| 17 |
+
- Expected entropy (aleatoric uncertainty)
|
| 18 |
+
- Mutual information (epistemic uncertainty)
|
| 19 |
+
- Per-class prediction variance
|
| 20 |
+
|
| 21 |
+
Generates:
|
| 22 |
+
- uncertainty_vs_accuracy.png
|
| 23 |
+
- rejection_curve.png
|
| 24 |
+
- epistemic_vs_aleatoric.png
|
| 25 |
+
- uncertainty_by_class.png
|
| 26 |
+
- confidence_vs_uncertainty.png
|
| 27 |
+
- mc_dropout_results.json
|
| 28 |
+
|
| 29 |
+
Usage:
|
| 30 |
+
python mc_dropout_uncertainty.py
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
import os
|
| 34 |
+
import sys
|
| 35 |
+
import json
|
| 36 |
+
import time
|
| 37 |
+
import warnings
|
| 38 |
+
import numpy as np
|
| 39 |
+
import pandas as pd
|
| 40 |
+
import cv2
|
| 41 |
+
import matplotlib
|
| 42 |
+
matplotlib.use('Agg')
|
| 43 |
+
import matplotlib.pyplot as plt
|
| 44 |
+
import matplotlib.patches as mpatches
|
| 45 |
+
from PIL import Image
|
| 46 |
+
from tqdm import tqdm
|
| 47 |
+
|
| 48 |
+
warnings.filterwarnings('ignore')
|
| 49 |
+
|
| 50 |
+
import torch
|
| 51 |
+
import torch.nn as nn
|
| 52 |
+
import torch.nn.functional as F
|
| 53 |
+
from torchvision import transforms
|
| 54 |
+
from torch.utils.data import Dataset, DataLoader
|
| 55 |
+
|
| 56 |
+
import timm
|
| 57 |
+
|
| 58 |
+
# Maximize CPU throughput
|
| 59 |
+
torch.set_num_threads(4)
|
| 60 |
+
|
| 61 |
+
# ================================================================
|
| 62 |
+
# CONFIGURATION
|
| 63 |
+
# ================================================================
|
| 64 |
+
BASE_DIR = '/teamspace/studios/this_studio'
|
| 65 |
+
OUTPUT_DIR = os.path.join(BASE_DIR, 'outputs_v3')
|
| 66 |
+
UNCERT_DIR = os.path.join(OUTPUT_DIR, 'uncertainty')
|
| 67 |
+
os.makedirs(UNCERT_DIR, exist_ok=True)
|
| 68 |
+
|
| 69 |
+
MODEL_PATH = os.path.join(OUTPUT_DIR, 'best_model.pth')
|
| 70 |
+
TEMPERATURE_PATH = os.path.join(OUTPUT_DIR, 'temperature.json')
|
| 71 |
+
NORM_STATS_PATH = os.path.join(BASE_DIR, 'data', 'fundus_norm_stats.json')
|
| 72 |
+
TEST_CSV = os.path.join(BASE_DIR, 'data', 'test_split.csv')
|
| 73 |
+
|
| 74 |
+
CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']
|
| 75 |
+
NUM_CLASSES = 5
|
| 76 |
+
IMG_SIZE = 224
|
| 77 |
+
DROPOUT = 0.3
|
| 78 |
+
|
| 79 |
+
T_FORWARD_PASSES = 30 # number of MC stochastic forward passes
|
| 80 |
+
BATCH_SIZE = 32 # batch size for feature extraction
|
| 81 |
+
HEAD_BATCH = 512 # batch size for head-only MC passes (very lightweight)
|
| 82 |
+
|
| 83 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 84 |
+
|
| 85 |
+
print('=' * 65)
|
| 86 |
+
print(' RetinaSense v3.0 -- MC Dropout Uncertainty Quantification')
|
| 87 |
+
print('=' * 65)
|
| 88 |
+
print(f' Device : {DEVICE}')
|
| 89 |
+
if torch.cuda.is_available():
|
| 90 |
+
print(f' GPU : {torch.cuda.get_device_name(0)}')
|
| 91 |
+
print(f' MC passes (T) : {T_FORWARD_PASSES}')
|
| 92 |
+
print(f' Output dir : {UNCERT_DIR}')
|
| 93 |
+
|
| 94 |
+
# ================================================================
|
| 95 |
+
# LOAD NORMALISATION STATS
|
| 96 |
+
# ================================================================
|
| 97 |
+
if os.path.exists(NORM_STATS_PATH):
|
| 98 |
+
with open(NORM_STATS_PATH) as f:
|
| 99 |
+
norm_stats = json.load(f)
|
| 100 |
+
NORM_MEAN = norm_stats['mean_rgb']
|
| 101 |
+
NORM_STD = norm_stats['std_rgb']
|
| 102 |
+
print(f' Fundus norm : mean={[round(v,4) for v in NORM_MEAN]}, '
|
| 103 |
+
f'std={[round(v,4) for v in NORM_STD]}')
|
| 104 |
+
else:
|
| 105 |
+
NORM_MEAN = [0.485, 0.456, 0.406]
|
| 106 |
+
NORM_STD = [0.229, 0.224, 0.225]
|
| 107 |
+
print(' Using ImageNet normalisation fallback')
|
| 108 |
+
|
| 109 |
+
# Load temperature
|
| 110 |
+
with open(TEMPERATURE_PATH) as f:
|
| 111 |
+
temp_data = json.load(f)
|
| 112 |
+
TEMPERATURE = temp_data['temperature']
|
| 113 |
+
print(f' Temperature : {TEMPERATURE:.4f}')
|
| 114 |
+
|
| 115 |
+
# ================================================================
|
| 116 |
+
# MODEL ARCHITECTURE (mirrors retinasense_v3.py / gradcam_v3.py)
|
| 117 |
+
# ================================================================
|
| 118 |
+
class MultiTaskViT(nn.Module):
|
| 119 |
+
"""ViT-Base-Patch16-224 with disease + severity heads."""
|
| 120 |
+
|
| 121 |
+
def __init__(self, n_disease=NUM_CLASSES, n_severity=5, drop=DROPOUT):
|
| 122 |
+
super().__init__()
|
| 123 |
+
self.backbone = timm.create_model(
|
| 124 |
+
'vit_base_patch16_224', pretrained=False, num_classes=0
|
| 125 |
+
)
|
| 126 |
+
feat = 768 # CLS token dimension
|
| 127 |
+
|
| 128 |
+
self.drop = nn.Dropout(drop)
|
| 129 |
+
|
| 130 |
+
self.disease_head = nn.Sequential(
|
| 131 |
+
nn.Linear(feat, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
|
| 132 |
+
nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
|
| 133 |
+
nn.Linear(256, n_disease),
|
| 134 |
+
)
|
| 135 |
+
self.severity_head = nn.Sequential(
|
| 136 |
+
nn.Linear(feat, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
|
| 137 |
+
nn.Linear(256, n_severity),
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
def forward(self, x):
|
| 141 |
+
f = self.backbone(x) # (B, 768)
|
| 142 |
+
f = self.drop(f)
|
| 143 |
+
return self.disease_head(f), self.severity_head(f)
|
| 144 |
+
|
| 145 |
+
def extract_features(self, x):
|
| 146 |
+
"""Run backbone only (deterministic) to get CLS features."""
|
| 147 |
+
return self.backbone(x) # (B, 768)
|
| 148 |
+
|
| 149 |
+
def forward_heads(self, features):
|
| 150 |
+
"""Run dropout + disease head on pre-extracted features."""
|
| 151 |
+
f = self.drop(features)
|
| 152 |
+
return self.disease_head(f)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# ================================================================
|
| 156 |
+
# LOAD MODEL
|
| 157 |
+
# ================================================================
|
| 158 |
+
print('\nLoading model...')
|
| 159 |
+
model = MultiTaskViT().to(DEVICE)
|
| 160 |
+
ckpt = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
|
| 161 |
+
model.load_state_dict(ckpt['model_state_dict'])
|
| 162 |
+
print(f' Loaded: {MODEL_PATH}')
|
| 163 |
+
print(f' Checkpoint epoch: {ckpt.get("epoch", "?") + 1} '
|
| 164 |
+
f'val_acc={ckpt.get("val_acc", 0):.2f}%')
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# ================================================================
|
| 168 |
+
# MC DROPOUT SETUP
|
| 169 |
+
# ================================================================
|
| 170 |
+
def enable_head_dropout(model):
|
| 171 |
+
"""
|
| 172 |
+
Set model to eval mode, then enable dropout ONLY in the classification
|
| 173 |
+
heads (self.drop, disease_head dropouts). The backbone stays fully
|
| 174 |
+
deterministic (eval mode) so we only need one backbone pass per image.
|
| 175 |
+
BatchNorm layers remain in eval mode (use running stats).
|
| 176 |
+
"""
|
| 177 |
+
model.eval() # everything to eval (including backbone)
|
| 178 |
+
|
| 179 |
+
# Enable dropout in the drop layer and disease_head
|
| 180 |
+
model.drop.train()
|
| 181 |
+
for m in model.disease_head.modules():
|
| 182 |
+
if isinstance(m, (nn.Dropout, nn.Dropout2d)):
|
| 183 |
+
m.train()
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
enable_head_dropout(model)
|
| 187 |
+
|
| 188 |
+
# Count active dropout layers
|
| 189 |
+
n_dropout_active = 0
|
| 190 |
+
for name, m in model.named_modules():
|
| 191 |
+
if isinstance(m, (nn.Dropout, nn.Dropout2d)) and m.training:
|
| 192 |
+
n_dropout_active += 1
|
| 193 |
+
n_dropout_total = sum(1 for m in model.modules() if isinstance(m, (nn.Dropout, nn.Dropout2d)))
|
| 194 |
+
print(f'\n MC Dropout enabled in heads: {n_dropout_active} active / {n_dropout_total} total dropout layers')
|
| 195 |
+
print(f' Backbone: deterministic (eval mode) -- single pass per image')
|
| 196 |
+
print(f' Heads: stochastic (train mode dropout) -- {T_FORWARD_PASSES} passes per image')
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# ================================================================
|
| 200 |
+
# PREPROCESSING (matches gradcam_v3.py pipeline)
|
| 201 |
+
# ================================================================
|
| 202 |
+
def ben_graham(path, sz=IMG_SIZE, sigma=10):
|
| 203 |
+
"""Ben Graham high-frequency fundus enhancement (APTOS-style)."""
|
| 204 |
+
img = cv2.imread(path)
|
| 205 |
+
if img is None:
|
| 206 |
+
img = np.array(Image.open(path).convert('RGB'))
|
| 207 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
| 208 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 209 |
+
img = cv2.resize(img, (sz, sz))
|
| 210 |
+
img = cv2.addWeighted(img, 4, cv2.GaussianBlur(img, (0, 0), sigma), -4, 128)
|
| 211 |
+
mask = np.zeros(img.shape[:2], dtype=np.uint8)
|
| 212 |
+
cv2.circle(mask, (sz // 2, sz // 2), int(sz * 0.48), 255, -1)
|
| 213 |
+
return cv2.bitwise_and(img, img, mask=mask)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def clahe_preprocess(path, sz=IMG_SIZE):
|
| 217 |
+
"""CLAHE-based contrast enhancement (ODIR-style)."""
|
| 218 |
+
img = cv2.imread(path)
|
| 219 |
+
if img is None:
|
| 220 |
+
img = np.array(Image.open(path).convert('RGB'))
|
| 221 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
| 222 |
+
img = cv2.resize(img, (sz, sz))
|
| 223 |
+
lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
|
| 224 |
+
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
| 225 |
+
lab[:, :, 0] = clahe.apply(lab[:, :, 0])
|
| 226 |
+
img = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
|
| 227 |
+
return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def resolve_path(image_path):
|
| 231 |
+
"""Resolve image path relative to BASE_DIR."""
|
| 232 |
+
if os.path.isabs(image_path) and os.path.exists(image_path):
|
| 233 |
+
return image_path
|
| 234 |
+
clean = image_path
|
| 235 |
+
while clean.startswith('./'):
|
| 236 |
+
clean = clean[2:]
|
| 237 |
+
return os.path.join(BASE_DIR, clean)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# ================================================================
|
| 241 |
+
# DATASET
|
| 242 |
+
# ================================================================
|
| 243 |
+
class TestDataset(Dataset):
|
| 244 |
+
"""Test dataset loading preprocessed images from cache or live."""
|
| 245 |
+
|
| 246 |
+
def __init__(self, csv_path):
|
| 247 |
+
self.df = pd.read_csv(csv_path).reset_index(drop=True)
|
| 248 |
+
self.transform = transforms.Compose([
|
| 249 |
+
transforms.ToPILImage(),
|
| 250 |
+
transforms.ToTensor(),
|
| 251 |
+
transforms.Normalize(NORM_MEAN, NORM_STD),
|
| 252 |
+
])
|
| 253 |
+
|
| 254 |
+
def __len__(self):
|
| 255 |
+
return len(self.df)
|
| 256 |
+
|
| 257 |
+
def __getitem__(self, idx):
|
| 258 |
+
row = self.df.iloc[idx]
|
| 259 |
+
img_path = str(row['image_path'])
|
| 260 |
+
dataset = str(row.get('source', 'auto'))
|
| 261 |
+
label = int(row['disease_label'])
|
| 262 |
+
|
| 263 |
+
# Try loading from cache first
|
| 264 |
+
cache_path = str(row.get('cache_path', ''))
|
| 265 |
+
if cache_path and cache_path != 'nan':
|
| 266 |
+
cache_abs = resolve_path(cache_path)
|
| 267 |
+
if os.path.exists(cache_abs):
|
| 268 |
+
try:
|
| 269 |
+
img_np = np.load(cache_abs)
|
| 270 |
+
img_tensor = self.transform(img_np)
|
| 271 |
+
return img_tensor, label, img_path
|
| 272 |
+
except Exception:
|
| 273 |
+
pass
|
| 274 |
+
|
| 275 |
+
# Live preprocessing
|
| 276 |
+
abs_path = resolve_path(img_path)
|
| 277 |
+
try:
|
| 278 |
+
if dataset == 'APTOS':
|
| 279 |
+
img_np = ben_graham(abs_path)
|
| 280 |
+
else:
|
| 281 |
+
img_np = clahe_preprocess(abs_path)
|
| 282 |
+
img_tensor = self.transform(img_np)
|
| 283 |
+
except Exception:
|
| 284 |
+
img_tensor = torch.zeros(3, IMG_SIZE, IMG_SIZE)
|
| 285 |
+
|
| 286 |
+
return img_tensor, label, img_path
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
# ================================================================
|
| 290 |
+
# TWO-STAGE MC DROPOUT INFERENCE
|
| 291 |
+
# ================================================================
|
| 292 |
+
def extract_all_features(model, dataloader):
|
| 293 |
+
"""
|
| 294 |
+
Stage 1: Run backbone once per image to get CLS features (deterministic).
|
| 295 |
+
Returns features (N, 768), labels (N,), paths list.
|
| 296 |
+
"""
|
| 297 |
+
all_features = []
|
| 298 |
+
all_labels = []
|
| 299 |
+
all_paths = []
|
| 300 |
+
|
| 301 |
+
print(f'\n Stage 1: Extracting backbone features (deterministic)...')
|
| 302 |
+
with torch.no_grad():
|
| 303 |
+
for images, labels, paths in tqdm(dataloader, desc=' Features', ncols=80):
|
| 304 |
+
images = images.to(DEVICE)
|
| 305 |
+
feats = model.extract_features(images) # (B, 768)
|
| 306 |
+
all_features.append(feats.cpu())
|
| 307 |
+
all_labels.extend(labels.numpy().tolist())
|
| 308 |
+
all_paths.extend(paths)
|
| 309 |
+
|
| 310 |
+
all_features = torch.cat(all_features, dim=0) # (N, 768)
|
| 311 |
+
all_labels = np.array(all_labels)
|
| 312 |
+
return all_features, all_labels, all_paths
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def mc_dropout_on_heads(model, features, T=T_FORWARD_PASSES, temperature=TEMPERATURE):
|
| 316 |
+
"""
|
| 317 |
+
Stage 2: Run T stochastic forward passes through heads only.
|
| 318 |
+
features: (N, 768) tensor
|
| 319 |
+
Returns: (N, T, C) numpy array of probability vectors.
|
| 320 |
+
"""
|
| 321 |
+
N = features.size(0)
|
| 322 |
+
all_probs = np.zeros((N, T, NUM_CLASSES), dtype=np.float32)
|
| 323 |
+
|
| 324 |
+
print(f'\n Stage 2: MC Dropout through heads ({T} passes, {N} samples)...')
|
| 325 |
+
|
| 326 |
+
with torch.no_grad():
|
| 327 |
+
for t in tqdm(range(T), desc=' MC Passes', ncols=80):
|
| 328 |
+
# Process in chunks to avoid memory issues
|
| 329 |
+
for start in range(0, N, HEAD_BATCH):
|
| 330 |
+
end = min(start + HEAD_BATCH, N)
|
| 331 |
+
feat_batch = features[start:end].to(DEVICE)
|
| 332 |
+
logits = model.forward_heads(feat_batch)
|
| 333 |
+
scaled = logits / temperature
|
| 334 |
+
probs = F.softmax(scaled, dim=1)
|
| 335 |
+
all_probs[start:end, t, :] = probs.cpu().numpy()
|
| 336 |
+
|
| 337 |
+
return all_probs
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
# ================================================================
|
| 341 |
+
# UNCERTAINTY METRICS
|
| 342 |
+
# ================================================================
|
| 343 |
+
def compute_uncertainty_metrics(mc_probs):
|
| 344 |
+
"""
|
| 345 |
+
Compute uncertainty metrics from MC dropout probability samples.
|
| 346 |
+
|
| 347 |
+
Args:
|
| 348 |
+
mc_probs: (N, T, C) array of MC sampled probability vectors
|
| 349 |
+
|
| 350 |
+
Returns dict with:
|
| 351 |
+
- p_mean, predicted_class, max_confidence
|
| 352 |
+
- predictive_entropy (total), expected_entropy (aleatoric),
|
| 353 |
+
mutual_info (epistemic), class_variance
|
| 354 |
+
"""
|
| 355 |
+
N, T, C = mc_probs.shape
|
| 356 |
+
eps = 1e-10
|
| 357 |
+
|
| 358 |
+
# Predictive mean: average over T passes
|
| 359 |
+
p_mean = mc_probs.mean(axis=1) # (N, C)
|
| 360 |
+
predicted_class = p_mean.argmax(axis=1) # (N,)
|
| 361 |
+
max_confidence = p_mean.max(axis=1) # (N,)
|
| 362 |
+
|
| 363 |
+
# Predictive entropy: H[p_bar] = -sum(p_bar * log(p_bar)) -- TOTAL uncertainty
|
| 364 |
+
predictive_entropy = -np.sum(p_mean * np.log(p_mean + eps), axis=1) # (N,)
|
| 365 |
+
|
| 366 |
+
# Per-pass entropies
|
| 367 |
+
per_pass_entropy = -np.sum(mc_probs * np.log(mc_probs + eps), axis=2) # (N, T)
|
| 368 |
+
|
| 369 |
+
# Expected entropy: E_t[H[p_t]] -- ALEATORIC uncertainty
|
| 370 |
+
expected_entropy = per_pass_entropy.mean(axis=1) # (N,)
|
| 371 |
+
|
| 372 |
+
# Mutual information: H - E[H] -- EPISTEMIC uncertainty
|
| 373 |
+
mutual_info = predictive_entropy - expected_entropy
|
| 374 |
+
mutual_info = np.maximum(mutual_info, 0.0)
|
| 375 |
+
|
| 376 |
+
# Prediction variance per class
|
| 377 |
+
class_variance = mc_probs.var(axis=1) # (N, C)
|
| 378 |
+
|
| 379 |
+
return {
|
| 380 |
+
'p_mean': p_mean,
|
| 381 |
+
'predicted_class': predicted_class,
|
| 382 |
+
'max_confidence': max_confidence,
|
| 383 |
+
'predictive_entropy': predictive_entropy,
|
| 384 |
+
'expected_entropy': expected_entropy,
|
| 385 |
+
'mutual_info': mutual_info,
|
| 386 |
+
'class_variance': class_variance,
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
# ================================================================
|
| 391 |
+
# PLOTTING FUNCTIONS
|
| 392 |
+
# ================================================================
|
| 393 |
+
def plot_uncertainty_vs_accuracy(metrics, labels, save_path):
|
| 394 |
+
"""Scatter: total uncertainty vs correctness, colored by class."""
|
| 395 |
+
correct = (metrics['predicted_class'] == labels).astype(int)
|
| 396 |
+
entropy = metrics['predictive_entropy']
|
| 397 |
+
|
| 398 |
+
fig, ax = plt.subplots(figsize=(10, 7))
|
| 399 |
+
|
| 400 |
+
colors = plt.cm.Set2(np.linspace(0, 1, NUM_CLASSES))
|
| 401 |
+
for cls_idx in range(NUM_CLASSES):
|
| 402 |
+
mask = labels == cls_idx
|
| 403 |
+
ax.scatter(
|
| 404 |
+
entropy[mask], correct[mask] + np.random.uniform(-0.08, 0.08, mask.sum()),
|
| 405 |
+
c=[colors[cls_idx]], alpha=0.5, s=20, label=CLASS_NAMES[cls_idx],
|
| 406 |
+
edgecolors='none'
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
ax.set_xlabel('Predictive Entropy (Total Uncertainty)', fontsize=12)
|
| 410 |
+
ax.set_ylabel('Correctness (1=correct, 0=wrong)', fontsize=12)
|
| 411 |
+
ax.set_title('MC Dropout: Uncertainty vs Prediction Correctness', fontsize=14)
|
| 412 |
+
ax.set_yticks([0, 1])
|
| 413 |
+
ax.set_yticklabels(['Incorrect', 'Correct'])
|
| 414 |
+
ax.legend(title='True Class', fontsize=9, title_fontsize=10)
|
| 415 |
+
ax.grid(True, alpha=0.3)
|
| 416 |
+
|
| 417 |
+
# Add vertical line at median uncertainty
|
| 418 |
+
med = np.median(entropy)
|
| 419 |
+
ax.axvline(med, color='red', linestyle='--', alpha=0.5, label=f'Median H={med:.3f}')
|
| 420 |
+
|
| 421 |
+
# Summary stats
|
| 422 |
+
correct_ent = entropy[correct == 1]
|
| 423 |
+
wrong_ent = entropy[correct == 0]
|
| 424 |
+
textstr = (f'Correct: mean H={correct_ent.mean():.3f}\n'
|
| 425 |
+
f'Wrong: mean H={wrong_ent.mean():.3f}' if len(wrong_ent) > 0
|
| 426 |
+
else f'Correct: mean H={correct_ent.mean():.3f}')
|
| 427 |
+
ax.text(0.98, 0.5, textstr, transform=ax.transAxes,
|
| 428 |
+
fontsize=9, verticalalignment='center', horizontalalignment='right',
|
| 429 |
+
bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))
|
| 430 |
+
|
| 431 |
+
plt.tight_layout()
|
| 432 |
+
fig.savefig(save_path, dpi=300, bbox_inches='tight')
|
| 433 |
+
plt.close(fig)
|
| 434 |
+
print(f' Saved: {save_path}')
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def plot_rejection_curve(metrics, labels, save_path):
|
| 438 |
+
"""Accuracy as a function of rejection threshold on uncertainty."""
|
| 439 |
+
entropy = metrics['predictive_entropy']
|
| 440 |
+
correct = (metrics['predicted_class'] == labels).astype(int)
|
| 441 |
+
|
| 442 |
+
# Sort by decreasing uncertainty
|
| 443 |
+
sorted_idx = np.argsort(entropy)[::-1]
|
| 444 |
+
sorted_correct = correct[sorted_idx]
|
| 445 |
+
|
| 446 |
+
N = len(labels)
|
| 447 |
+
rejection_fracs = np.linspace(0.0, 0.95, 200)
|
| 448 |
+
accuracies = []
|
| 449 |
+
n_remaining = []
|
| 450 |
+
|
| 451 |
+
for frac in rejection_fracs:
|
| 452 |
+
n_reject = int(frac * N)
|
| 453 |
+
kept = sorted_correct[n_reject:]
|
| 454 |
+
if len(kept) == 0:
|
| 455 |
+
accuracies.append(np.nan)
|
| 456 |
+
n_remaining.append(0)
|
| 457 |
+
else:
|
| 458 |
+
accuracies.append(kept.mean() * 100)
|
| 459 |
+
n_remaining.append(len(kept))
|
| 460 |
+
|
| 461 |
+
accuracies = np.array(accuracies)
|
| 462 |
+
n_remaining = np.array(n_remaining)
|
| 463 |
+
|
| 464 |
+
fig, ax1 = plt.subplots(figsize=(10, 7))
|
| 465 |
+
|
| 466 |
+
color1 = '#2196F3'
|
| 467 |
+
ax1.plot(rejection_fracs * 100, accuracies, color=color1, linewidth=2.0,
|
| 468 |
+
label='Accuracy')
|
| 469 |
+
ax1.set_xlabel('Rejection Rate (%)', fontsize=12)
|
| 470 |
+
ax1.set_ylabel('Accuracy (%)', fontsize=12, color=color1)
|
| 471 |
+
ax1.tick_params(axis='y', labelcolor=color1)
|
| 472 |
+
ax1.set_ylim([max(50, np.nanmin(accuracies) - 5), 101])
|
| 473 |
+
|
| 474 |
+
# Secondary axis: number of remaining samples
|
| 475 |
+
ax2 = ax1.twinx()
|
| 476 |
+
color2 = '#FF9800'
|
| 477 |
+
ax2.plot(rejection_fracs * 100, n_remaining, color=color2, linewidth=1.5,
|
| 478 |
+
linestyle='--', alpha=0.7, label='Remaining')
|
| 479 |
+
ax2.set_ylabel('Samples Remaining', fontsize=12, color=color2)
|
| 480 |
+
ax2.tick_params(axis='y', labelcolor=color2)
|
| 481 |
+
|
| 482 |
+
# Baseline accuracy (no rejection)
|
| 483 |
+
base_acc = correct.mean() * 100
|
| 484 |
+
ax1.axhline(base_acc, color='gray', linestyle=':', alpha=0.5)
|
| 485 |
+
ax1.text(2, base_acc + 0.5, f'Baseline: {base_acc:.1f}%', fontsize=9, color='gray')
|
| 486 |
+
|
| 487 |
+
ax1.set_title('Rejection Curve: Accuracy vs Uncertainty-Based Rejection', fontsize=14)
|
| 488 |
+
ax1.grid(True, alpha=0.3)
|
| 489 |
+
|
| 490 |
+
# Combined legend
|
| 491 |
+
lines1, labels1 = ax1.get_legend_handles_labels()
|
| 492 |
+
lines2, labels2 = ax2.get_legend_handles_labels()
|
| 493 |
+
ax1.legend(lines1 + lines2, labels1 + labels2, loc='lower left', fontsize=10)
|
| 494 |
+
|
| 495 |
+
plt.tight_layout()
|
| 496 |
+
fig.savefig(save_path, dpi=300, bbox_inches='tight')
|
| 497 |
+
plt.close(fig)
|
| 498 |
+
print(f' Saved: {save_path}')
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def plot_epistemic_vs_aleatoric(metrics, labels, save_path):
|
| 502 |
+
"""Scatter separating epistemic and aleatoric uncertainty."""
|
| 503 |
+
aleatoric = metrics['expected_entropy']
|
| 504 |
+
epistemic = metrics['mutual_info']
|
| 505 |
+
correct = (metrics['predicted_class'] == labels).astype(int)
|
| 506 |
+
|
| 507 |
+
fig, ax = plt.subplots(figsize=(10, 7))
|
| 508 |
+
|
| 509 |
+
colors = plt.cm.Set2(np.linspace(0, 1, NUM_CLASSES))
|
| 510 |
+
for cls_idx in range(NUM_CLASSES):
|
| 511 |
+
mask = labels == cls_idx
|
| 512 |
+
ax.scatter(
|
| 513 |
+
aleatoric[mask], epistemic[mask],
|
| 514 |
+
c=[colors[cls_idx]], alpha=0.45, s=20, label=CLASS_NAMES[cls_idx],
|
| 515 |
+
edgecolors='none'
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
# Mark misclassified samples
|
| 519 |
+
wrong_mask = correct == 0
|
| 520 |
+
if wrong_mask.sum() > 0:
|
| 521 |
+
ax.scatter(
|
| 522 |
+
aleatoric[wrong_mask], epistemic[wrong_mask],
|
| 523 |
+
facecolors='none', edgecolors='red', s=60, linewidths=1.2,
|
| 524 |
+
label='Misclassified', zorder=5
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
ax.set_xlabel('Aleatoric Uncertainty (Expected Entropy)', fontsize=12)
|
| 528 |
+
ax.set_ylabel('Epistemic Uncertainty (Mutual Information)', fontsize=12)
|
| 529 |
+
ax.set_title('Decomposition of Uncertainty: Epistemic vs Aleatoric', fontsize=14)
|
| 530 |
+
ax.legend(fontsize=9, title='Class', title_fontsize=10)
|
| 531 |
+
ax.grid(True, alpha=0.3)
|
| 532 |
+
|
| 533 |
+
# Annotate quadrants
|
| 534 |
+
xlim = ax.get_xlim()
|
| 535 |
+
ylim = ax.get_ylim()
|
| 536 |
+
ax.text(xlim[0] + 0.02 * (xlim[1] - xlim[0]),
|
| 537 |
+
ylim[1] - 0.05 * (ylim[1] - ylim[0]),
|
| 538 |
+
'Low aleatoric\nHigh epistemic\n(need more data)',
|
| 539 |
+
fontsize=8, alpha=0.6, va='top')
|
| 540 |
+
ax.text(xlim[1] - 0.02 * (xlim[1] - xlim[0]),
|
| 541 |
+
ylim[1] - 0.05 * (ylim[1] - ylim[0]),
|
| 542 |
+
'High aleatoric\nHigh epistemic\n(hard + unseen)',
|
| 543 |
+
fontsize=8, alpha=0.6, va='top', ha='right')
|
| 544 |
+
ax.text(xlim[1] - 0.02 * (xlim[1] - xlim[0]),
|
| 545 |
+
ylim[0] + 0.05 * (ylim[1] - ylim[0]),
|
| 546 |
+
'High aleatoric\nLow epistemic\n(inherently noisy)',
|
| 547 |
+
fontsize=8, alpha=0.6, va='bottom', ha='right')
|
| 548 |
+
|
| 549 |
+
plt.tight_layout()
|
| 550 |
+
fig.savefig(save_path, dpi=300, bbox_inches='tight')
|
| 551 |
+
plt.close(fig)
|
| 552 |
+
print(f' Saved: {save_path}')
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def plot_uncertainty_by_class(metrics, labels, save_path):
|
| 556 |
+
"""Box plots of uncertainty per class."""
|
| 557 |
+
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
|
| 558 |
+
|
| 559 |
+
data_types = [
|
| 560 |
+
('predictive_entropy', 'Total Uncertainty (Predictive Entropy)'),
|
| 561 |
+
('expected_entropy', 'Aleatoric Uncertainty (Expected Entropy)'),
|
| 562 |
+
('mutual_info', 'Epistemic Uncertainty (Mutual Information)'),
|
| 563 |
+
]
|
| 564 |
+
|
| 565 |
+
for ax, (key, title) in zip(axes, data_types):
|
| 566 |
+
data = metrics[key]
|
| 567 |
+
box_data = [data[labels == c] for c in range(NUM_CLASSES)]
|
| 568 |
+
|
| 569 |
+
bp = ax.boxplot(box_data, labels=CLASS_NAMES, patch_artist=True,
|
| 570 |
+
widths=0.6, showfliers=True,
|
| 571 |
+
flierprops=dict(marker='o', markersize=3, alpha=0.3))
|
| 572 |
+
|
| 573 |
+
colors = plt.cm.Set2(np.linspace(0, 1, NUM_CLASSES))
|
| 574 |
+
for patch, color in zip(bp['boxes'], colors):
|
| 575 |
+
patch.set_facecolor(color)
|
| 576 |
+
patch.set_alpha(0.7)
|
| 577 |
+
|
| 578 |
+
ax.set_title(title, fontsize=11)
|
| 579 |
+
ax.set_ylabel('Uncertainty', fontsize=10)
|
| 580 |
+
ax.grid(True, axis='y', alpha=0.3)
|
| 581 |
+
ax.tick_params(axis='x', rotation=15)
|
| 582 |
+
|
| 583 |
+
# Add sample counts
|
| 584 |
+
for i, cls_data in enumerate(box_data):
|
| 585 |
+
ax.text(i + 1, ax.get_ylim()[1] * 0.95,
|
| 586 |
+
f'n={len(cls_data)}', ha='center', fontsize=8, alpha=0.6)
|
| 587 |
+
|
| 588 |
+
plt.suptitle('Uncertainty Distribution by Disease Class', fontsize=14, y=1.02)
|
| 589 |
+
plt.tight_layout()
|
| 590 |
+
fig.savefig(save_path, dpi=300, bbox_inches='tight')
|
| 591 |
+
plt.close(fig)
|
| 592 |
+
print(f' Saved: {save_path}')
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def plot_confidence_vs_uncertainty(metrics, labels, save_path):
|
| 596 |
+
"""Scatter showing confidence vs uncertainty (should be anti-correlated)."""
|
| 597 |
+
confidence = metrics['max_confidence']
|
| 598 |
+
entropy = metrics['predictive_entropy']
|
| 599 |
+
correct = (metrics['predicted_class'] == labels).astype(int)
|
| 600 |
+
|
| 601 |
+
fig, ax = plt.subplots(figsize=(10, 7))
|
| 602 |
+
|
| 603 |
+
scatter_correct = ax.scatter(
|
| 604 |
+
confidence[correct == 1], entropy[correct == 1],
|
| 605 |
+
c='#4CAF50', alpha=0.4, s=15, label='Correct', edgecolors='none'
|
| 606 |
+
)
|
| 607 |
+
scatter_wrong = ax.scatter(
|
| 608 |
+
confidence[correct == 0], entropy[correct == 0],
|
| 609 |
+
c='#F44336', alpha=0.6, s=25, label='Incorrect', edgecolors='none',
|
| 610 |
+
marker='x', linewidths=1.0
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
# Compute correlation
|
| 614 |
+
from scipy import stats
|
| 615 |
+
r, p_val = stats.pearsonr(confidence, entropy)
|
| 616 |
+
|
| 617 |
+
ax.set_xlabel('Maximum Confidence (max p_bar)', fontsize=12)
|
| 618 |
+
ax.set_ylabel('Predictive Entropy (Total Uncertainty)', fontsize=12)
|
| 619 |
+
ax.set_title(f'Confidence vs Uncertainty (Pearson r={r:.3f}, p={p_val:.2e})', fontsize=14)
|
| 620 |
+
ax.legend(fontsize=10)
|
| 621 |
+
ax.grid(True, alpha=0.3)
|
| 622 |
+
|
| 623 |
+
# Add trend line
|
| 624 |
+
z = np.polyfit(confidence, entropy, 1)
|
| 625 |
+
x_line = np.linspace(confidence.min(), confidence.max(), 100)
|
| 626 |
+
ax.plot(x_line, np.polyval(z, x_line), 'k--', alpha=0.4, linewidth=1.5)
|
| 627 |
+
|
| 628 |
+
plt.tight_layout()
|
| 629 |
+
fig.savefig(save_path, dpi=300, bbox_inches='tight')
|
| 630 |
+
plt.close(fig)
|
| 631 |
+
print(f' Saved: {save_path}')
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
# ================================================================
|
| 635 |
+
# MAIN
|
| 636 |
+
# ================================================================
|
| 637 |
+
def main():
|
| 638 |
+
t_start = time.time()
|
| 639 |
+
|
| 640 |
+
# ---- 1. Build DataLoader ----
|
| 641 |
+
print('\nLoading test set...')
|
| 642 |
+
dataset = TestDataset(TEST_CSV)
|
| 643 |
+
dataloader = DataLoader(
|
| 644 |
+
dataset, batch_size=BATCH_SIZE, shuffle=False,
|
| 645 |
+
num_workers=2, pin_memory=False
|
| 646 |
+
)
|
| 647 |
+
print(f' Test samples: {len(dataset)}')
|
| 648 |
+
|
| 649 |
+
# ---- 2. Stage 1: Extract backbone features (single deterministic pass) ----
|
| 650 |
+
features, true_labels, image_paths = extract_all_features(model, dataloader)
|
| 651 |
+
print(f' Features shape: {features.shape}')
|
| 652 |
+
|
| 653 |
+
t_feat = time.time() - t_start
|
| 654 |
+
print(f' Feature extraction: {t_feat:.1f}s')
|
| 655 |
+
|
| 656 |
+
# ---- 3. Stage 2: MC Dropout on heads only ----
|
| 657 |
+
mc_probs = mc_dropout_on_heads(
|
| 658 |
+
model, features, T=T_FORWARD_PASSES, temperature=TEMPERATURE
|
| 659 |
+
)
|
| 660 |
+
print(f' MC probs shape: {mc_probs.shape} (N, T, C)')
|
| 661 |
+
|
| 662 |
+
t_mc = time.time() - t_start - t_feat
|
| 663 |
+
print(f' MC head passes: {t_mc:.1f}s')
|
| 664 |
+
|
| 665 |
+
# ---- 4. Compute Uncertainty Metrics ----
|
| 666 |
+
print('\nComputing uncertainty metrics...')
|
| 667 |
+
metrics = compute_uncertainty_metrics(mc_probs)
|
| 668 |
+
|
| 669 |
+
# Print summary statistics
|
| 670 |
+
correct = (metrics['predicted_class'] == true_labels).astype(int)
|
| 671 |
+
accuracy = correct.mean() * 100
|
| 672 |
+
print(f'\n --- Summary ---')
|
| 673 |
+
print(f' Accuracy (MC mean): {accuracy:.2f}%')
|
| 674 |
+
print(f' Predictive entropy: mean={metrics["predictive_entropy"].mean():.4f}, '
|
| 675 |
+
f'std={metrics["predictive_entropy"].std():.4f}')
|
| 676 |
+
print(f' Aleatoric (exp. ent.): mean={metrics["expected_entropy"].mean():.4f}, '
|
| 677 |
+
f'std={metrics["expected_entropy"].std():.4f}')
|
| 678 |
+
print(f' Epistemic (MI): mean={metrics["mutual_info"].mean():.4f}, '
|
| 679 |
+
f'std={metrics["mutual_info"].std():.4f}')
|
| 680 |
+
print(f' Max confidence: mean={metrics["max_confidence"].mean():.4f}, '
|
| 681 |
+
f'std={metrics["max_confidence"].std():.4f}')
|
| 682 |
+
|
| 683 |
+
# Per-class stats
|
| 684 |
+
print(f'\n Per-class uncertainty (predictive entropy):')
|
| 685 |
+
for cls_idx in range(NUM_CLASSES):
|
| 686 |
+
mask = true_labels == cls_idx
|
| 687 |
+
n_cls = mask.sum()
|
| 688 |
+
cls_acc = correct[mask].mean() * 100 if n_cls > 0 else 0
|
| 689 |
+
cls_ent = metrics['predictive_entropy'][mask].mean() if n_cls > 0 else 0
|
| 690 |
+
cls_mi = metrics['mutual_info'][mask].mean() if n_cls > 0 else 0
|
| 691 |
+
print(f' {CLASS_NAMES[cls_idx]:15s}: n={n_cls:4d}, '
|
| 692 |
+
f'acc={cls_acc:5.1f}%, H={cls_ent:.4f}, MI={cls_mi:.4f}')
|
| 693 |
+
|
| 694 |
+
# ---- 5. Generate Plots ----
|
| 695 |
+
print('\nGenerating plots...')
|
| 696 |
+
|
| 697 |
+
plot_uncertainty_vs_accuracy(
|
| 698 |
+
metrics, true_labels,
|
| 699 |
+
os.path.join(UNCERT_DIR, 'uncertainty_vs_accuracy.png')
|
| 700 |
+
)
|
| 701 |
+
plot_rejection_curve(
|
| 702 |
+
metrics, true_labels,
|
| 703 |
+
os.path.join(UNCERT_DIR, 'rejection_curve.png')
|
| 704 |
+
)
|
| 705 |
+
plot_epistemic_vs_aleatoric(
|
| 706 |
+
metrics, true_labels,
|
| 707 |
+
os.path.join(UNCERT_DIR, 'epistemic_vs_aleatoric.png')
|
| 708 |
+
)
|
| 709 |
+
plot_uncertainty_by_class(
|
| 710 |
+
metrics, true_labels,
|
| 711 |
+
os.path.join(UNCERT_DIR, 'uncertainty_by_class.png')
|
| 712 |
+
)
|
| 713 |
+
plot_confidence_vs_uncertainty(
|
| 714 |
+
metrics, true_labels,
|
| 715 |
+
os.path.join(UNCERT_DIR, 'confidence_vs_uncertainty.png')
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
# ---- 6. Save JSON Results ----
|
| 719 |
+
print('\nSaving results JSON...')
|
| 720 |
+
|
| 721 |
+
per_image = []
|
| 722 |
+
for i in range(len(true_labels)):
|
| 723 |
+
per_image.append({
|
| 724 |
+
'image_path': image_paths[i],
|
| 725 |
+
'true_label': int(true_labels[i]),
|
| 726 |
+
'true_class': CLASS_NAMES[int(true_labels[i])],
|
| 727 |
+
'predicted_label': int(metrics['predicted_class'][i]),
|
| 728 |
+
'predicted_class': CLASS_NAMES[int(metrics['predicted_class'][i])],
|
| 729 |
+
'correct': bool(correct[i]),
|
| 730 |
+
'max_confidence': round(float(metrics['max_confidence'][i]), 6),
|
| 731 |
+
'predictive_entropy': round(float(metrics['predictive_entropy'][i]), 6),
|
| 732 |
+
'expected_entropy': round(float(metrics['expected_entropy'][i]), 6),
|
| 733 |
+
'mutual_information': round(float(metrics['mutual_info'][i]), 6),
|
| 734 |
+
'class_variance': [round(float(v), 8) for v in metrics['class_variance'][i]],
|
| 735 |
+
'mean_probs': [round(float(v), 6) for v in metrics['p_mean'][i]],
|
| 736 |
+
})
|
| 737 |
+
|
| 738 |
+
aggregate = {
|
| 739 |
+
'n_samples': int(len(true_labels)),
|
| 740 |
+
'n_classes': NUM_CLASSES,
|
| 741 |
+
'mc_passes': T_FORWARD_PASSES,
|
| 742 |
+
'temperature': TEMPERATURE,
|
| 743 |
+
'accuracy_pct': round(float(accuracy), 4),
|
| 744 |
+
'overall': {
|
| 745 |
+
'predictive_entropy': {
|
| 746 |
+
'mean': round(float(metrics['predictive_entropy'].mean()), 6),
|
| 747 |
+
'std': round(float(metrics['predictive_entropy'].std()), 6),
|
| 748 |
+
'min': round(float(metrics['predictive_entropy'].min()), 6),
|
| 749 |
+
'max': round(float(metrics['predictive_entropy'].max()), 6),
|
| 750 |
+
},
|
| 751 |
+
'expected_entropy': {
|
| 752 |
+
'mean': round(float(metrics['expected_entropy'].mean()), 6),
|
| 753 |
+
'std': round(float(metrics['expected_entropy'].std()), 6),
|
| 754 |
+
'min': round(float(metrics['expected_entropy'].min()), 6),
|
| 755 |
+
'max': round(float(metrics['expected_entropy'].max()), 6),
|
| 756 |
+
},
|
| 757 |
+
'mutual_information': {
|
| 758 |
+
'mean': round(float(metrics['mutual_info'].mean()), 6),
|
| 759 |
+
'std': round(float(metrics['mutual_info'].std()), 6),
|
| 760 |
+
'min': round(float(metrics['mutual_info'].min()), 6),
|
| 761 |
+
'max': round(float(metrics['mutual_info'].max()), 6),
|
| 762 |
+
},
|
| 763 |
+
'max_confidence': {
|
| 764 |
+
'mean': round(float(metrics['max_confidence'].mean()), 6),
|
| 765 |
+
'std': round(float(metrics['max_confidence'].std()), 6),
|
| 766 |
+
},
|
| 767 |
+
},
|
| 768 |
+
'per_class': {},
|
| 769 |
+
}
|
| 770 |
+
|
| 771 |
+
for cls_idx in range(NUM_CLASSES):
|
| 772 |
+
mask = true_labels == cls_idx
|
| 773 |
+
n_cls = int(mask.sum())
|
| 774 |
+
if n_cls == 0:
|
| 775 |
+
continue
|
| 776 |
+
aggregate['per_class'][CLASS_NAMES[cls_idx]] = {
|
| 777 |
+
'n_samples': n_cls,
|
| 778 |
+
'accuracy': round(float(correct[mask].mean() * 100), 4),
|
| 779 |
+
'pred_entropy_mean': round(float(metrics['predictive_entropy'][mask].mean()), 6),
|
| 780 |
+
'pred_entropy_std': round(float(metrics['predictive_entropy'][mask].std()), 6),
|
| 781 |
+
'aleatoric_mean': round(float(metrics['expected_entropy'][mask].mean()), 6),
|
| 782 |
+
'epistemic_mean': round(float(metrics['mutual_info'][mask].mean()), 6),
|
| 783 |
+
'confidence_mean': round(float(metrics['max_confidence'][mask].mean()), 6),
|
| 784 |
+
}
|
| 785 |
+
|
| 786 |
+
# Rejection curve data at key thresholds
|
| 787 |
+
entropy = metrics['predictive_entropy']
|
| 788 |
+
sorted_idx = np.argsort(entropy)[::-1]
|
| 789 |
+
sorted_correct = correct[sorted_idx]
|
| 790 |
+
rejection_checkpoints = {}
|
| 791 |
+
for frac in [0.0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.50]:
|
| 792 |
+
n_reject = int(frac * len(true_labels))
|
| 793 |
+
kept = sorted_correct[n_reject:]
|
| 794 |
+
if len(kept) > 0:
|
| 795 |
+
rejection_checkpoints[f'reject_{int(frac*100)}pct'] = {
|
| 796 |
+
'accuracy': round(float(kept.mean() * 100), 4),
|
| 797 |
+
'n_remaining': int(len(kept)),
|
| 798 |
+
}
|
| 799 |
+
aggregate['rejection_curve'] = rejection_checkpoints
|
| 800 |
+
|
| 801 |
+
results = {
|
| 802 |
+
'aggregate': aggregate,
|
| 803 |
+
'per_image': per_image,
|
| 804 |
+
}
|
| 805 |
+
|
| 806 |
+
json_path = os.path.join(UNCERT_DIR, 'mc_dropout_results.json')
|
| 807 |
+
with open(json_path, 'w') as f:
|
| 808 |
+
json.dump(results, f, indent=2)
|
| 809 |
+
print(f' Saved: {json_path}')
|
| 810 |
+
|
| 811 |
+
elapsed = time.time() - t_start
|
| 812 |
+
print(f'\nDone in {elapsed:.1f}s')
|
| 813 |
+
print('=' * 65)
|
| 814 |
+
|
| 815 |
+
|
| 816 |
+
if __name__ == '__main__':
|
| 817 |
+
main()
|