tanishq74 commited on
Commit
deb35c1
·
verified ·
1 Parent(s): 127e14e

Add eval_dashboard.py

Browse files
Files changed (1) hide show
  1. eval_dashboard.py +752 -0
eval_dashboard.py ADDED
@@ -0,0 +1,752 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ RetinaSense v3.0 -- Phase 1A: Rich Evaluation Dashboard
4
+ ========================================================
5
+ Standalone script that loads the trained ViT model, runs inference on the
6
+ full test set (1,281 images), and produces publication-quality evaluation
7
+ plots plus a structured metrics JSON report.
8
+
9
+ Outputs (all written to outputs_v3/evaluation/):
10
+ - confusion_matrix.png
11
+ - roc_curves_per_class.png
12
+ - precision_recall_curves.png
13
+ - calibration_reliability.png
14
+ - confidence_histograms.png
15
+ - error_analysis_by_source.png
16
+ - metrics_report.json
17
+
18
+ Usage:
19
+ python eval_dashboard.py
20
+ """
21
+
22
+ import os
23
+ import sys
24
+ import json
25
+ import warnings
26
+ import numpy as np
27
+ import pandas as pd
28
+ import cv2
29
+ import matplotlib
30
+ matplotlib.use('Agg')
31
+ import matplotlib.pyplot as plt
32
+ import matplotlib.ticker as mticker
33
+ import seaborn as sns
34
+ from PIL import Image
35
+ from collections import OrderedDict
36
+
37
+ warnings.filterwarnings('ignore')
38
+
39
+ import torch
40
+ import torch.nn as nn
41
+ import torch.nn.functional as F
42
+ from torch.utils.data import Dataset, DataLoader
43
+ from torchvision import transforms
44
+ import timm
45
+
46
+ from sklearn.metrics import (
47
+ confusion_matrix,
48
+ classification_report,
49
+ roc_curve,
50
+ auc,
51
+ precision_recall_curve,
52
+ average_precision_score,
53
+ f1_score,
54
+ accuracy_score,
55
+ cohen_kappa_score,
56
+ matthews_corrcoef,
57
+ balanced_accuracy_score,
58
+ log_loss,
59
+ )
60
+
61
# ================================================================
# CONFIGURATION
# ================================================================
# Workspace root; every other path below is derived from it.
BASE_DIR = '/teamspace/studios/this_studio'
OUTPUT_DIR = os.path.join(BASE_DIR, 'outputs_v3')
EVAL_DIR = os.path.join(OUTPUT_DIR, 'evaluation')
os.makedirs(EVAL_DIR, exist_ok=True)  # all plots and the JSON report land here

# Artifacts this script reads.
MODEL_PATH = os.path.join(OUTPUT_DIR, 'best_model.pth')
THRESHOLDS_PATH = os.path.join(OUTPUT_DIR, 'thresholds.json')
TEMPERATURE_PATH = os.path.join(OUTPUT_DIR, 'temperature.json')
_DATA_DIR = os.path.join(BASE_DIR, 'data')
TEST_CSV = os.path.join(_DATA_DIR, 'test_split.csv')
NORM_STATS_PATH = os.path.join(_DATA_DIR, 'fundus_norm_stats.json')

# Model / loader hyper-parameters.
NUM_CLASSES = 5
IMG_SIZE = 224
DROPOUT = 0.3
BATCH_SIZE = 32

# Index i of this list is the display name of disease label i.
CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']

_HAS_CUDA = torch.cuda.is_available()
DEVICE = torch.device('cuda' if _HAS_CUDA else 'cpu')

# Publication style defaults
_PUBLICATION_STYLE = {
    'font.size': 11,
    'axes.titlesize': 13,
    'axes.labelsize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.15,
    'font.family': 'sans-serif',
}
plt.rcParams.update(_PUBLICATION_STYLE)

# Start-up banner, assembled first and printed in one call.
_banner = [
    '=' * 65,
    ' RetinaSense v3.0 -- Phase 1A: Evaluation Dashboard',
    '=' * 65,
    f' Device : {DEVICE}',
]
if _HAS_CUDA:
    _banner.append(f' GPU : {torch.cuda.get_device_name(0)}')
_banner.append(f' Output : {EVAL_DIR}')
_banner.append('=' * 65)
print('\n'.join(_banner))
107
+
108
+
109
# ================================================================
# LOAD NORMALISATION STATS
# ================================================================
# Per-channel mean/std come from the stats JSON when it exists; otherwise the
# hard-coded ImageNet values below are used.
if not os.path.exists(NORM_STATS_PATH):
    NORM_MEAN = [0.485, 0.456, 0.406]
    NORM_STD = [0.229, 0.224, 0.225]
    print(' Using ImageNet normalisation fallback')
else:
    with open(NORM_STATS_PATH) as stats_file:
        norm_stats = json.load(stats_file)
    NORM_MEAN = norm_stats['mean_rgb']
    NORM_STD = norm_stats['std_rgb']
    _mean_txt = [round(v, 4) for v in NORM_MEAN]
    _std_txt = [round(v, 4) for v in NORM_STD]
    print(f' Fundus norm stats loaded: mean={_mean_txt}, std={_std_txt}')
123
+
124
+
125
# ================================================================
# MODEL ARCHITECTURE (mirrors retinasense_v3.py / gradcam_v3.py)
# ================================================================
class MultiTaskViT(nn.Module):
    """ViT-Base-Patch16-224 backbone feeding a disease head and a severity head.

    The attribute names (``backbone``, ``drop``, ``disease_head``,
    ``severity_head``) and the layer order inside each ``nn.Sequential``
    determine the state-dict keys, so they must stay exactly as they are for
    the saved checkpoint to load.

    Args:
        n_disease: Number of disease logits.
        n_severity: Number of severity logits.
        drop: Dropout probability applied to the backbone features.
    """

    def __init__(self, n_disease=NUM_CLASSES, n_severity=5, drop=DROPOUT):
        super().__init__()
        self.backbone = timm.create_model(
            'vit_base_patch16_224', pretrained=False, num_classes=0
        )
        embed_dim = 768  # CLS token dimension
        self.drop = nn.Dropout(drop)

        disease_layers = [
            nn.Linear(embed_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, n_disease),
        ]
        self.disease_head = nn.Sequential(*disease_layers)

        severity_layers = [
            nn.Linear(embed_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, n_severity),
        ]
        self.severity_head = nn.Sequential(*severity_layers)

    def forward(self, x):
        """Return ``(disease_logits, severity_logits)`` for an image batch."""
        features = self.drop(self.backbone(x))  # (B, 768) CLS token features
        disease_logits = self.disease_head(features)
        severity_logits = self.severity_head(features)
        return disease_logits, severity_logits
152
+
153
+
154
# ================================================================
# LOAD MODEL + CALIBRATION ARTIFACTS
# ================================================================
print('\nLoading model...')
model = MultiTaskViT().to(DEVICE)
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary
# objects. Only point MODEL_PATH at checkpoints you produced yourself.
ckpt = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()  # inference mode: dropout off, BatchNorm uses running statistics
print(f' Loaded: {MODEL_PATH}')

# The stored epoch is printed 1-based. When the checkpoint has no 'epoch'
# entry, show '?' instead of attempting '?' + 1 (which raises TypeError).
_ckpt_epoch = ckpt.get('epoch')
_epoch_txt = '?' if _ckpt_epoch is None else int(_ckpt_epoch) + 1
print(f' Checkpoint epoch: {_epoch_txt} '
      f'val_acc={ckpt.get("val_acc", 0):.2f}%')

with open(THRESHOLDS_PATH) as f:
    thr_data = json.load(f)
THRESHOLDS = thr_data['thresholds']

with open(TEMPERATURE_PATH) as f:
    temp_data = json.load(f)
TEMPERATURE = temp_data['temperature']

print(f' Temperature T = {TEMPERATURE:.4f}')
print(f' Thresholds = {[round(t, 3) for t in THRESHOLDS]}')
176
+
177
+
178
# ================================================================
# DATASET
# ================================================================
class TestDataset(Dataset):
    """Test-set dataset yielding ``(image_tensor, disease_label, source)``.

    Each row is served from its preprocessed ``.npy`` cache when the
    ``cache_path`` column points at a readable file. Otherwise the raw image
    at ``image_path`` is preprocessed on the fly: Ben Graham enhancement for
    rows whose ``source`` is ``'APTOS'``, CLAHE for everything else. If that
    also fails, a black image is substituted and a message is written to
    stderr, so one unreadable file does not abort the whole evaluation.

    Args:
        df: DataFrame with ``image_path`` and ``disease_label`` columns and,
            optionally, ``cache_path`` and ``source``.
        transform: Callable applied to the HxWx3 image array.
    """

    def __init__(self, df, transform):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Fast path: preprocessed cache.
        img = self._load_cached(row.get('cache_path', ''))

        # Fallback: on-the-fly preprocessing.
        if img is None:
            image_path = self._resolve_path(row['image_path'])
            try:
                if row.get('source', 'ODIR') == 'APTOS':
                    img = self._ben_graham(image_path)
                else:
                    img = self._clahe_preprocess(image_path)
            except Exception as exc:
                # Deliberate best-effort: keep evaluating, but say so.
                print(f'[TestDataset] could not preprocess {image_path!r} '
                      f'({exc}); substituting a blank image', file=sys.stderr)
                img = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)

        img_tensor = self.transform(img)
        disease_lbl = int(row['disease_label'])
        source = row.get('source', 'unknown')
        if not isinstance(source, str):
            # A missing CSV cell is a float NaN; mixing it with strings in one
            # batch can break collation and the per-source grouping later on.
            source = 'unknown'
        return img_tensor, disease_lbl, source

    @staticmethod
    def _load_cached(cache_fp):
        """Return the cached array at *cache_fp*, or None if it is unusable."""
        # pandas yields float NaN for empty cells; NaN is truthy and would
        # make os.path.exists raise TypeError, so check the type first.
        if not isinstance(cache_fp, (str, os.PathLike)) or not cache_fp:
            return None
        if not os.path.exists(cache_fp):
            return None
        try:
            return np.load(cache_fp)
        except Exception:
            # Corrupt or unreadable cache file: fall back to preprocessing.
            return None

    @staticmethod
    def _resolve_path(image_path):
        """Anchor a relative *image_path* under BASE_DIR."""
        if os.path.isabs(image_path):
            return image_path
        clean = image_path
        while clean.startswith('./'):
            # Also drop slashes left behind by './/': a leading '/' would make
            # os.path.join discard BASE_DIR altogether.
            clean = clean[2:].lstrip('/')
        return os.path.join(BASE_DIR, clean)

    @staticmethod
    def _read_bgr(path):
        """Read *path* as a BGR array, trying OpenCV first and then PIL."""
        raw = cv2.imread(path)
        if raw is None:
            raw = np.array(Image.open(path).convert('RGB'))
            raw = cv2.cvtColor(raw, cv2.COLOR_RGB2BGR)
        return raw

    @staticmethod
    def _ben_graham(path, sz=IMG_SIZE, sigma=10):
        """Return an RGB sz x sz image with Ben Graham style enhancement.

        The image is resized, a Gaussian-blurred copy is subtracted from it
        (weights 4 / -4, offset 128), and everything outside a centred circle
        of radius ``0.48 * sz`` is zeroed.
        """
        raw = cv2.cvtColor(TestDataset._read_bgr(path), cv2.COLOR_BGR2RGB)
        raw = cv2.resize(raw, (sz, sz))
        raw = cv2.addWeighted(raw, 4, cv2.GaussianBlur(raw, (0, 0), sigma), -4, 128)
        mask = np.zeros(raw.shape[:2], dtype=np.uint8)
        cv2.circle(mask, (sz // 2, sz // 2), int(sz * 0.48), 255, -1)
        return cv2.bitwise_and(raw, raw, mask=mask)

    @staticmethod
    def _clahe_preprocess(path, sz=IMG_SIZE):
        """Return an RGB sz x sz image with CLAHE applied to the LAB L channel."""
        raw = cv2.resize(TestDataset._read_bgr(path), (sz, sz))
        lab = cv2.cvtColor(raw, cv2.COLOR_BGR2LAB)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        lab[:, :, 0] = clahe.apply(lab[:, :, 0])
        raw = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
        return cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
255
+
256
+
257
# Evaluation-time transform: array -> PIL -> tensor -> normalise (no augmentation).
_eval_steps = [
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
]
val_transform = transforms.Compose(_eval_steps)

print('\nLoading test set...')
test_df = pd.read_csv(TEST_CSV)
_source_names = sorted(test_df["source"].unique())
_class_dist = test_df["disease_label"].value_counts().sort_index().to_dict()
print(f' Test samples: {len(test_df)}')
print(f' Sources : {_source_names}')
print(f' Class dist : {_class_dist}')

test_ds = TestDataset(test_df, val_transform)
# shuffle=False keeps predictions in the same order as the CSV rows.
test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)
274
+
275
+
276
# ================================================================
# INFERENCE
# ================================================================
print('\nRunning inference on full test set...')
all_logits = []
all_labels = []
all_sources = []

with torch.no_grad():
    for imgs, labels, sources in test_loader:
        # Only the disease head is evaluated; the severity output is ignored.
        disease_logits, _ = model(imgs.to(DEVICE))
        all_logits.append(disease_logits.cpu())
        all_labels.extend(labels.tolist())
        all_sources.extend(sources)

all_logits = torch.cat(all_logits, dim=0)  # (N, 5)
all_labels = np.array(all_labels)
all_sources = np.array(all_sources)
N = len(all_labels)
print(f' Inference complete: {N} samples')

# Temperature-scaled probabilities
probs_calibrated = F.softmax(all_logits / TEMPERATURE, dim=1).numpy()  # (N, 5)
probs_uncalibrated = F.softmax(all_logits, dim=1).numpy()

# Predictions: argmax of calibrated probabilities
preds = np.argmax(probs_calibrated, axis=1)
confidences = np.max(probs_calibrated, axis=1)

correct_mask = (preds == all_labels)
acc = accuracy_score(all_labels, preds)
# Count hits directly: int(acc * N) can truncate to one below the true count
# when the float product lands just under an integer.
n_correct = int(correct_mask.sum())
print(f' Overall accuracy: {acc:.4f} ({n_correct}/{N})')
309
+
310
+
311
# ================================================================
# 1. CONFUSION MATRIX
# ================================================================
print('\n[1/7] Confusion matrix...')
cm = confusion_matrix(all_labels, preds, labels=list(range(NUM_CLASSES)))

# Row-normalise (each true-class row sums to 1). A class with no test samples
# has a zero row total; leave that row at 0 instead of dividing by zero, which
# would put NaN into both the heatmap and the JSON report.
_row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm, _row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=_row_totals > 0,
)

fig, ax = plt.subplots(figsize=(7, 6))
sns.heatmap(
    cm_norm, annot=True, fmt='.2f', cmap='Blues',
    xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
    linewidths=0.5, linecolor='white',
    cbar_kws={'label': 'Proportion', 'shrink': 0.8},
    ax=ax, vmin=0, vmax=1,
)
# Overlay raw counts in smaller font, slightly below each cell's proportion.
for i in range(NUM_CLASSES):
    for j in range(NUM_CLASSES):
        ax.text(j + 0.5, i + 0.72, f'(n={cm[i, j]})',
                ha='center', va='center', fontsize=7, color='gray')

ax.set_xlabel('Predicted Class')
ax.set_ylabel('True Class')
ax.set_title('Normalized Confusion Matrix (Test Set)')
fig.tight_layout()
fig.savefig(os.path.join(EVAL_DIR, 'confusion_matrix.png'))
plt.close(fig)
print(' Saved confusion_matrix.png')
339
+
340
+
341
# ================================================================
# 2. ROC CURVES PER CLASS
# ================================================================
print('[2/7] ROC curves...')
fig, ax = plt.subplots(figsize=(7, 6))
colors = sns.color_palette('tab10', NUM_CLASSES)  # reused by later plots
all_fpr_tpr = {}
macro_auc_list = []

# One-vs-rest curve per class, scored with that class's calibrated probability.
for cls_idx in range(NUM_CLASSES):
    is_cls = (all_labels == cls_idx).astype(int)
    fpr, tpr, _ = roc_curve(is_cls, probs_calibrated[:, cls_idx])
    roc_auc = auc(fpr, tpr)
    macro_auc_list.append(roc_auc)
    all_fpr_tpr[cls_idx] = (fpr, tpr)
    ax.plot(fpr, tpr, color=colors[cls_idx], lw=2,
            label=f'{CLASS_NAMES[cls_idx]} (AUC={roc_auc:.3f})')

# Macro average ROC: interpolate every curve onto a shared FPR grid and average.
mean_fpr = np.linspace(0, 1, 200)
mean_tpr = np.zeros_like(mean_fpr)
for cls_idx in range(NUM_CLASSES):
    cls_fpr, cls_tpr = all_fpr_tpr[cls_idx]
    mean_tpr += np.interp(mean_fpr, cls_fpr, cls_tpr)
mean_tpr /= NUM_CLASSES
macro_auc = auc(mean_fpr, mean_tpr)
ax.plot(mean_fpr, mean_tpr, 'k--', lw=2.5,
        label=f'Macro-average (AUC={macro_auc:.3f})')
# Chance diagonal for reference.
ax.plot([0, 1], [0, 1], 'k:', lw=1, alpha=0.4)

ax.set_xlim([-0.02, 1.02])
ax.set_ylim([-0.02, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('One-vs-Rest ROC Curves (Calibrated)')
ax.legend(loc='lower right', framealpha=0.9)
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(os.path.join(EVAL_DIR, 'roc_curves_per_class.png'))
plt.close(fig)
print(' Saved roc_curves_per_class.png')
382
+
383
+
384
# ================================================================
# 3. PRECISION-RECALL CURVES
# ================================================================
print('[3/7] Precision-recall curves...')
fig, ax = plt.subplots(figsize=(7, 6))

# One-vs-rest PR curve per class, labelled with its average precision.
for cls_idx in range(NUM_CLASSES):
    is_cls = (all_labels == cls_idx).astype(int)
    cls_scores = probs_calibrated[:, cls_idx]
    prec, rec, _ = precision_recall_curve(is_cls, cls_scores)
    ap = average_precision_score(is_cls, cls_scores)
    ax.plot(rec, prec, color=colors[cls_idx], lw=2,
            label=f'{CLASS_NAMES[cls_idx]} (AP={ap:.3f})')

# Add prevalence baselines: a faint horizontal line at each class's share of
# the test set, in that class's colour.
prevalences = np.bincount(all_labels, minlength=NUM_CLASSES) / N
for cls_idx in range(NUM_CLASSES):
    ax.axhline(y=prevalences[cls_idx], color=colors[cls_idx], ls=':', alpha=0.3)

ax.set_xlim([-0.02, 1.02])
ax.set_ylim([-0.02, 1.05])
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title('Precision-Recall Curves (Calibrated)')
ax.legend(loc='upper right', framealpha=0.9)
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(os.path.join(EVAL_DIR, 'precision_recall_curves.png'))
plt.close(fig)
print(' Saved precision_recall_curves.png')
414
+
415
+
416
# ================================================================
# 4. CALIBRATION RELIABILITY DIAGRAM
# ================================================================
print('[4/7] Calibration reliability diagram...')
n_bins = 10
bin_edges = np.linspace(0, 1, n_bins + 1)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2


# Compute calibration for both calibrated and uncalibrated probabilities
def compute_calibration(confidences_arr, correct_arr, bin_edges):
    """Compute per-bin accuracy and average confidence.

    Bins are ``(lo, hi]``, except the first, which is closed on both sides.
    This way a confidence equal to ``bin_edges[0]`` lands in a bin rather
    than being dropped while still counting toward the sample total.

    Args:
        confidences_arr: 1-D array-like of predicted confidences.
        correct_arr: 1-D array-like, same length, 1.0 where the prediction
            was correct and 0.0 otherwise.
        bin_edges: Ascending sequence of bin boundaries (n_bins + 1 values).

    Returns:
        ``(bin_accs, bin_confs, bin_counts)`` arrays with one entry per bin.
        Empty bins hold NaN accuracy/confidence and a count of 0.
    """
    confidences_arr = np.asarray(confidences_arr, dtype=float)
    correct_arr = np.asarray(correct_arr, dtype=float)
    num_bins = max(len(bin_edges) - 1, 0)

    bin_accs = np.full(num_bins, np.nan)
    bin_confs = np.full(num_bins, np.nan)
    bin_counts = np.zeros(num_bins, dtype=int)

    for b, (lo, hi) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        above_lo = confidences_arr >= lo if b == 0 else confidences_arr > lo
        mask = above_lo & (confidences_arr <= hi)
        count = int(mask.sum())
        if count == 0:
            continue
        bin_accs[b] = correct_arr[mask].mean()
        bin_confs[b] = confidences_arr[mask].mean()
        bin_counts[b] = count
    return bin_accs, bin_confs, bin_counts
441
+
442
# Top-1 confidence with and without temperature scaling.
conf_calib = probs_calibrated.max(axis=1)
conf_uncalib = probs_uncalibrated.max(axis=1)

_correct_float = correct_mask.astype(float)
bin_accs_cal, bin_confs_cal, bin_counts_cal = compute_calibration(
    conf_calib, _correct_float, bin_edges)
bin_accs_uncal, bin_confs_uncal, bin_counts_uncal = compute_calibration(
    conf_uncalib, _correct_float, bin_edges)

# ECE: count-weighted mean |accuracy - confidence| over bins; nansum skips
# the empty (NaN) bins.
ece_cal = np.nansum(
    np.abs(bin_accs_cal - bin_confs_cal) * bin_counts_cal) / N
ece_uncal = np.nansum(
    np.abs(bin_accs_uncal - bin_confs_uncal) * bin_counts_uncal) / N

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

_panels = [
    (bin_accs_cal, bin_confs_cal, bin_counts_cal, ece_cal, 'Calibrated', '#4C72B0'),
    (bin_accs_uncal, bin_confs_uncal, bin_counts_uncal, ece_uncal, 'Uncalibrated', '#DD8452'),
]
for ax, (b_accs, b_confs, b_counts, ece_val, title_suffix, bar_color) in zip(axes, _panels):
    # Perfect calibration line
    ax.plot([0, 1], [0, 1], 'k--', lw=1.5, alpha=0.5, label='Perfectly calibrated')

    # Bar chart of bin accuracy (non-empty bins only)
    valid = ~np.isnan(b_accs)
    ax.bar(bin_centers[valid], b_accs[valid], width=0.08,
           alpha=0.7, color=bar_color, edgecolor='black', linewidth=0.5,
           label=f'Model (ECE={ece_val:.4f})')

    # Gap shading between accuracy and mean confidence
    for j in range(n_bins):
        if not valid[j]:
            continue
        gap_lo = min(b_accs[j], b_confs[j])
        gap_hi = max(b_accs[j], b_confs[j])
        ax.fill_between(
            [bin_centers[j] - 0.04, bin_centers[j] + 0.04],
            gap_lo, gap_hi, alpha=0.15, color='red')

    # Sample counts on top
    for j in range(n_bins):
        if valid[j] and b_counts[j] > 0:
            ax.text(bin_centers[j], b_accs[j] + 0.03,
                    str(b_counts[j]), ha='center', va='bottom', fontsize=7)

    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1.1])
    ax.set_xlabel('Mean Predicted Confidence')
    ax.set_ylabel('Fraction of Correct Predictions')
    ax.set_title(f'Reliability Diagram ({title_suffix})')
    ax.legend(loc='upper left', framealpha=0.9)
    ax.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig(os.path.join(EVAL_DIR, 'calibration_reliability.png'))
plt.close(fig)
print(f' Saved calibration_reliability.png (ECE_cal={ece_cal:.4f}, ECE_uncal={ece_uncal:.4f})')
497
+
498
+
499
# ================================================================
# 5. CONFIDENCE HISTOGRAMS
# ================================================================
print('[5/7] Confidence histograms...')
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
_ax_outcome, _ax_class = axes[0], axes[1]

# Correct vs Incorrect
_outcome_groups = [
    (correct_mask, 'Correct', '#2ca02c'),
    (~correct_mask, 'Incorrect', '#d62728'),
]
for mask, label, color in _outcome_groups:
    _ax_outcome.hist(confidences[mask], bins=30, alpha=0.65, color=color,
                     label=f'{label} (n={mask.sum()})', edgecolor='black', linewidth=0.3)

_ax_outcome.set_xlabel('Prediction Confidence')
_ax_outcome.set_ylabel('Count')
_ax_outcome.set_title('Confidence Distribution: Correct vs Incorrect')
_ax_outcome.legend(loc='upper left', framealpha=0.9)
# Dashed vertical lines mark each group's median confidence; they are drawn
# after the legend call and labelled '_nolegend_'.
_ax_outcome.axvline(x=np.median(confidences[correct_mask]), color='#2ca02c',
                    ls='--', alpha=0.6, label='_nolegend_')
_ax_outcome.axvline(x=np.median(confidences[~correct_mask]), color='#d62728',
                    ls='--', alpha=0.6, label='_nolegend_')
_ax_outcome.grid(True, alpha=0.3, axis='y')

# Per-class confidence, grouped by TRUE class
for cls_idx in range(NUM_CLASSES):
    cls_mask = (all_labels == cls_idx)
    _ax_class.hist(confidences[cls_mask], bins=20, alpha=0.5, color=colors[cls_idx],
                   label=f'{CLASS_NAMES[cls_idx]} (n={cls_mask.sum()})',
                   edgecolor='black', linewidth=0.3)

_ax_class.set_xlabel('Prediction Confidence')
_ax_class.set_ylabel('Count')
_ax_class.set_title('Confidence Distribution by True Class')
_ax_class.legend(loc='upper left', framealpha=0.9, fontsize=9)
_ax_class.grid(True, alpha=0.3, axis='y')

fig.tight_layout()
fig.savefig(os.path.join(EVAL_DIR, 'confidence_histograms.png'))
plt.close(fig)
print(' Saved confidence_histograms.png')
540
+
541
+
542
# ================================================================
# 6. ERROR ANALYSIS BY SOURCE
# ================================================================
print('[6/7] Error analysis by source...')
sources_unique = sorted(np.unique(all_sources))
n_sources = len(sources_unique)

# Build accuracy per (source, class) pair; NaN where the pair has no samples.
source_class_acc = {}
source_class_n = {}
for src in sources_unique:
    src_rows = (all_sources == src)
    for cls_idx in range(NUM_CLASSES):
        mask = src_rows & (all_labels == cls_idx)
        n_cls = int(mask.sum())
        if n_cls > 0:
            acc_sc = (preds[mask] == all_labels[mask]).mean()
        else:
            acc_sc = np.nan
        source_class_acc[(src, cls_idx)] = acc_sc
        source_class_n[(src, cls_idx)] = n_cls

# Also overall accuracy per source
source_overall_acc = {}
for src in sources_unique:
    mask = (all_sources == src)
    source_overall_acc[src] = accuracy_score(all_labels[mask], preds[mask])

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: grouped bar chart of per-class accuracy by source
x = np.arange(NUM_CLASSES)
bar_width = 0.8 / max(n_sources, 1)
source_colors = sns.color_palette('Set2', n_sources)
# Colour lookup by source NAME so both panels agree even when the right panel
# shows only a subset of the sources.
source_color_map = dict(zip(sources_unique, source_colors))

for s_idx, src in enumerate(sources_unique):
    accs = [source_class_acc[(src, c)] for c in range(NUM_CLASSES)]
    counts = [source_class_n[(src, c)] for c in range(NUM_CLASSES)]
    # Centre the group of per-source bars on each class tick.
    offset = (s_idx - n_sources / 2 + 0.5) * bar_width
    bars = axes[0].bar(x + offset, accs, bar_width * 0.9,
                       label=f'{src} (n={sum(counts)})',
                       color=source_color_map[src], edgecolor='black', linewidth=0.5)
    # Annotate sample counts
    for j, (b, n_val) in enumerate(zip(bars, counts)):
        if n_val > 0 and not np.isnan(accs[j]):
            axes[0].text(b.get_x() + b.get_width() / 2, b.get_height() + 0.02,
                         str(n_val), ha='center', va='bottom', fontsize=7)

axes[0].set_xticks(x)
axes[0].set_xticklabels(CLASS_NAMES, rotation=15, ha='right')
axes[0].set_ylabel('Accuracy')
axes[0].set_title('Per-Class Accuracy by Data Source')
axes[0].set_ylim([0, 1.15])
axes[0].legend(loc='upper right', framealpha=0.9)
axes[0].grid(True, alpha=0.3, axis='y')
axes[0].axhline(y=acc, color='black', ls='--', alpha=0.4, lw=1)
axes[0].text(NUM_CLASSES - 0.5, acc + 0.02, f'Overall: {acc:.3f}',
             ha='right', fontsize=9, alpha=0.6)

# Right panel: confusion breakdown -- most common misclassifications per source
error_data = []
for src in sources_unique:
    src_mask = (all_sources == src) & (~correct_mask)
    if src_mask.sum() == 0:
        continue
    for true_cls in range(NUM_CLASSES):
        for pred_cls in range(NUM_CLASSES):
            if true_cls == pred_cls:
                continue
            pair_mask = src_mask & (all_labels == true_cls) & (preds == pred_cls)
            cnt = int(pair_mask.sum())
            if cnt > 0:
                error_data.append({
                    'Source': src,
                    # Label is "<true>><predicted>" using 3-letter prefixes.
                    'Error': f'{CLASS_NAMES[true_cls][:3]}>{CLASS_NAMES[pred_cls][:3]}',
                    'Count': cnt,
                })

if error_data:
    err_df = pd.DataFrame(error_data)
    # Top 10 error types
    top_errors = (err_df.groupby('Error')['Count'].sum()
                  .sort_values(ascending=False).head(10).index.tolist())
    err_df_top = err_df[err_df['Error'].isin(top_errors)]
    pivot = err_df_top.pivot_table(index='Error', columns='Source',
                                   values='Count', aggfunc='sum', fill_value=0)
    # Reorder by total count
    pivot = pivot.loc[pivot.sum(axis=1).sort_values(ascending=True).index]
    # The pivot only has columns for sources that actually made one of the top
    # errors, so pick colours per column instead of slicing the palette by
    # position (which shifts colours when a source is absent).
    pivot.plot(kind='barh', stacked=True, ax=axes[1],
               color=[source_color_map[col] for col in pivot.columns],
               edgecolor='black', linewidth=0.5)
    axes[1].set_xlabel('Error Count')
    axes[1].set_title('Top Misclassification Patterns by Source')
    axes[1].legend(loc='lower right', framealpha=0.9)
    axes[1].grid(True, alpha=0.3, axis='x')
else:
    axes[1].text(0.5, 0.5, 'No errors to display', ha='center', va='center',
                 transform=axes[1].transAxes, fontsize=14)
    axes[1].set_title('Top Misclassification Patterns by Source')

fig.tight_layout()
fig.savefig(os.path.join(EVAL_DIR, 'error_analysis_by_source.png'))
plt.close(fig)
print(' Saved error_analysis_by_source.png')
644
+
645
+
646
# ================================================================
# 7. METRICS REPORT (JSON)
# ================================================================
print('[7/7] Metrics report...')


def _json_safe(value):
    """Return a copy of *value* with non-finite floats replaced by None.

    ``json.dump`` would otherwise emit bare ``NaN`` / ``Infinity`` tokens,
    which strict JSON parsers reject.
    """
    if isinstance(value, dict):
        return {str(k): _json_safe(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(v) for v in value]
    if isinstance(value, float) and not np.isfinite(value):
        return None
    return value


_all_class_ids = list(range(NUM_CLASSES))

# Classification report as dict. Passing labels= keeps it aligned with
# CLASS_NAMES even when a class is absent from both labels and predictions
# (without it sklearn raises on the length mismatch).
cls_report = classification_report(
    all_labels, preds, labels=_all_class_ids, target_names=CLASS_NAMES,
    output_dict=True, zero_division=0)

# Per-class AUC and AP (one-vs-rest on the calibrated probabilities)
per_class_auc = {}
per_class_ap = {}
for i in range(NUM_CLASSES):
    y_bin = (all_labels == i).astype(int)
    y_score = probs_calibrated[:, i]
    fpr_i, tpr_i, _ = roc_curve(y_bin, y_score)
    per_class_auc[CLASS_NAMES[i]] = float(auc(fpr_i, tpr_i))
    per_class_ap[CLASS_NAMES[i]] = float(average_precision_score(y_bin, y_score))

# Build the full report
try:
    # labels= tells log_loss which column belongs to which class, so a class
    # missing from the test labels no longer makes the call fail.
    ll = float(log_loss(all_labels, probs_calibrated, labels=_all_class_ids))
except ValueError:
    ll = None

metrics_report = OrderedDict([
    ('n_test_samples', int(N)),
    ('overall_accuracy', float(acc)),
    ('balanced_accuracy', float(balanced_accuracy_score(all_labels, preds))),
    ('macro_f1', float(f1_score(all_labels, preds, average='macro', zero_division=0))),
    ('weighted_f1', float(f1_score(all_labels, preds, average='weighted', zero_division=0))),
    ('cohen_kappa', float(cohen_kappa_score(all_labels, preds))),
    ('matthews_corrcoef', float(matthews_corrcoef(all_labels, preds))),
    ('log_loss', ll),
    ('macro_auc', float(np.mean(list(per_class_auc.values())))),
    ('ece_calibrated', float(ece_cal)),
    ('ece_uncalibrated', float(ece_uncal)),
    ('temperature', float(TEMPERATURE)),
    ('thresholds', THRESHOLDS),
    ('per_class_metrics', {}),
    ('per_class_auc', per_class_auc),
    ('per_class_ap', per_class_ap),
    ('confusion_matrix_raw', cm.tolist()),
    ('confusion_matrix_normalized', np.round(cm_norm, 4).tolist()),
    ('source_accuracy', {src: float(v) for src, v in source_overall_acc.items()}),
    ('source_class_counts', {
        src: {CLASS_NAMES[c]: source_class_n[(src, c)]
              for c in range(NUM_CLASSES)}
        for src in sources_unique
    }),
    ('class_names', CLASS_NAMES),
])

# Per-class from classification_report
for name in CLASS_NAMES:
    metrics_report['per_class_metrics'][name] = {
        'precision': float(cls_report[name]['precision']),
        'recall': float(cls_report[name]['recall']),
        'f1-score': float(cls_report[name]['f1-score']),
        'support': int(cls_report[name]['support']),
        'auc': per_class_auc[name],
        'average_precision': per_class_ap[name],
    }

report_path = os.path.join(EVAL_DIR, 'metrics_report.json')
with open(report_path, 'w') as f:
    # Sanitise only the serialised copy: metrics_report itself keeps its float
    # values because the summary below formats them with ':.4f'.
    json.dump(_json_safe(metrics_report), f, indent=2, allow_nan=False)
print(f' Saved metrics_report.json')
715
+
716
+
717
# ================================================================
# SUMMARY
# ================================================================
# Console recap of the headline metrics, followed by a listing of the expected
# output files with their size on disk (or MISSING).
print('\n' + '=' * 65)
print(' EVALUATION DASHBOARD COMPLETE')
print('=' * 65)
print(f' Overall Accuracy : {acc:.4f}')
print(f' Balanced Accuracy : {metrics_report["balanced_accuracy"]:.4f}')
print(f' Macro F1 : {metrics_report["macro_f1"]:.4f}')
print(f' Cohen Kappa : {metrics_report["cohen_kappa"]:.4f}')
print(f' Macro AUC : {metrics_report["macro_auc"]:.4f}')
print(f' ECE (calibrated) : {ece_cal:.4f}')
print(f' ECE (uncalibrated) : {ece_uncal:.4f}')

print(f'\n Per-class AUC:')
for name in per_class_auc:
    print(f' {name:15s} : {per_class_auc[name]:.4f}')

print(f'\n Source accuracy:')
for src in source_overall_acc:
    print(f' {src:10s} : {source_overall_acc[src]:.4f}')

print(f'\n All outputs in: {EVAL_DIR}/')
output_files = [
    'confusion_matrix.png',
    'roc_curves_per_class.png',
    'precision_recall_curves.png',
    'calibration_reliability.png',
    'confidence_histograms.png',
    'error_analysis_by_source.png',
    'metrics_report.json',
]
for fname in output_files:
    fpath = os.path.join(EVAL_DIR, fname)
    if os.path.exists(fpath):
        status = f'{os.path.getsize(fpath) / 1024:.0f} KB'
    else:
        status = 'MISSING'
    print(f' [{status:>8s}] {fname}')
print('=' * 65)