tanishq74 commited on
Commit
8552f4b
·
verified ·
1 Parent(s): fa9938d

Add run_error_analysis.py

Browse files
Files changed (1) hide show
  1. run_error_analysis.py +977 -0
run_error_analysis.py ADDED
@@ -0,0 +1,977 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ RetinaSense ViT v2 - Comprehensive Error Analysis & Baseline Report
4
+ ===================================================================
5
+ Runs full evaluation on the validation split, computes ECE,
6
+ confusion analysis, confidence distributions, and source-level
7
+ performance. Saves all plots and metrics to outputs_analysis/v2_baseline/.
8
+ """
9
+
10
+ import os, sys, json, warnings
11
+ import numpy as np
12
+ import pandas as pd
13
+ import matplotlib
14
+ matplotlib.use('Agg')
15
+ import matplotlib.pyplot as plt
16
+ import matplotlib.gridspec as gridspec
17
+ import seaborn as sns
18
+ from pathlib import Path
19
+ from tqdm import tqdm
20
+ warnings.filterwarnings('ignore')
21
+
22
+ import cv2
23
+ from PIL import Image
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from torchvision import transforms
29
+ from torch.utils.data import Dataset, DataLoader
30
+
31
+ import timm
32
+ from sklearn.model_selection import train_test_split
33
+ from sklearn.metrics import (
34
+ confusion_matrix, classification_report,
35
+ f1_score, precision_score, recall_score
36
+ )
37
+
38
+ # ================================================================
39
+ # CONFIG
40
+ # ================================================================
41
+ BASE_DIR = '/teamspace/studios/this_studio'
42
+ MODEL_PATH = f'{BASE_DIR}/outputs_vit/best_model.pth'
43
+ META_CSV = f'{BASE_DIR}/final_unified_metadata.csv'
44
+ THRESH_JSON = f'{BASE_DIR}/outputs_vit/threshold_optimization_results.json'
45
+ CACHE_DIR = f'{BASE_DIR}/preprocessed_cache_vit'
46
+ OUT_DIR = f'{BASE_DIR}/outputs_analysis/v2_baseline'
47
+
48
+ IMG_SIZE = 224
49
+ BATCH_SIZE = 64
50
+ NUM_WORKERS = 8
51
+ NUM_CLASSES = 5
52
+ CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']
53
+
54
+ os.makedirs(OUT_DIR, exist_ok=True)
55
+ os.makedirs(CACHE_DIR, exist_ok=True)
56
+
57
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
58
+ print(f'Device: {device}')
59
+ if torch.cuda.is_available():
60
+ print(f'GPU: {torch.cuda.get_device_name(0)}')
61
+
62
+ # ================================================================
63
+ # MODEL DEFINITION (mirrors retinasense_vit.py)
64
+ # ================================================================
65
+ class MultiTaskViT(nn.Module):
66
+ def __init__(self, n_disease=5, n_severity=5, drop=0.4):
67
+ super().__init__()
68
+ self.backbone = timm.create_model(
69
+ 'vit_base_patch16_224', pretrained=False, num_classes=0)
70
+ feat = 768
71
+ self.drop = nn.Dropout(drop)
72
+ self.disease_head = nn.Sequential(
73
+ nn.Linear(feat, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
74
+ nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
75
+ nn.Linear(256, n_disease))
76
+ self.severity_head = nn.Sequential(
77
+ nn.Linear(feat, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
78
+ nn.Linear(256, n_severity))
79
+
80
+ def forward(self, x):
81
+ f = self.backbone(x)
82
+ f = self.drop(f)
83
+ return self.disease_head(f), self.severity_head(f)
84
+
85
+
86
+ # ================================================================
87
+ # IMAGE PREPROCESSING (Ben Graham method, matches training)
88
+ # ================================================================
89
+ def ben_graham(path, sz=IMG_SIZE, sigma=10):
90
+ img = cv2.imread(str(path))
91
+ if img is None:
92
+ img = np.array(Image.open(str(path)).convert('RGB'))
93
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
94
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
95
+ img = cv2.resize(img, (sz, sz))
96
+ img = cv2.addWeighted(img, 4, cv2.GaussianBlur(img, (0,0), sigma), -4, 128)
97
+ mask = np.zeros(img.shape[:2], dtype=np.uint8)
98
+ cv2.circle(mask, (sz//2, sz//2), int(sz * 0.48), 255, -1)
99
+ return cv2.bitwise_and(img, img, mask=mask)
100
+
101
+
102
+ def resolve_image_path(raw_path):
103
+ """
104
+ Resolve image path from CSV entry (which has leading .// prefix).
105
+ Tries multiple known root locations.
106
+ APTOS images live in:
107
+ aptos/gaussian_filtered_images/gaussian_filtered_images/{Severity}/{stem}.png
108
+ ODIR images live in:
109
+ odir/preprocessed_images/{filename}
110
+ """
111
+ # Strip leading .// or ./
112
+ clean = raw_path.lstrip('.').lstrip('/').lstrip('/')
113
+ clean = clean.replace('//', '/')
114
+
115
+ stem = Path(raw_path).stem
116
+
117
+ candidates = [
118
+ f'{BASE_DIR}/{clean}',
119
+ ]
120
+
121
+ # APTOS: search all severity subfolders
122
+ if 'aptos' in raw_path.lower():
123
+ aptos_base = f'{BASE_DIR}/aptos/gaussian_filtered_images/gaussian_filtered_images'
124
+ for severity in ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferate_DR']:
125
+ for ext in ['.png', '.jpg', '.jpeg']:
126
+ candidates.append(f'{aptos_base}/{severity}/{stem}{ext}')
127
+ # Also try train_images (original path)
128
+ for ext in ['.png', '.jpg', '.jpeg']:
129
+ candidates.append(f'{BASE_DIR}/aptos/train_images/{stem}{ext}')
130
+
131
+ # ODIR: preprocessed_images
132
+ if 'odir' in raw_path.lower():
133
+ fname = Path(raw_path).name
134
+ candidates.append(f'{BASE_DIR}/odir/preprocessed_images/{fname}')
135
+ candidates.append(f'{BASE_DIR}/ocular-disease-recognition-odir5k/preprocessed_images/{fname}')
136
+
137
+ for c in candidates:
138
+ if os.path.exists(c):
139
+ return c
140
+ return None
141
+
142
+
143
+ def load_or_cache(row):
144
+ """
145
+ Load preprocessed image from cache (.npy) or process from disk.
146
+ Returns uint8 HxWx3 numpy array.
147
+ """
148
+ stem = Path(row['image_path_clean']).stem
149
+ cache_fp = f'{CACHE_DIR}/{stem}_224.npy'
150
+
151
+ if os.path.exists(cache_fp):
152
+ try:
153
+ return np.load(cache_fp)
154
+ except Exception:
155
+ pass
156
+
157
+ img_path = row.get('image_path_resolved')
158
+ if img_path and os.path.exists(img_path):
159
+ try:
160
+ arr = ben_graham(img_path)
161
+ np.save(cache_fp, arr)
162
+ return arr
163
+ except Exception as e:
164
+ pass
165
+
166
+ # Fallback: zero image
167
+ return np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)
168
+
169
+
170
+ # ================================================================
171
+ # DATASET
172
+ # ================================================================
173
+ val_transform = transforms.Compose([
174
+ transforms.ToPILImage(),
175
+ transforms.ToTensor(),
176
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
177
+ ])
178
+
179
+
180
+ class RetDS(Dataset):
181
+ def __init__(self, df):
182
+ self.df = df.reset_index(drop=True)
183
+
184
+ def __len__(self):
185
+ return len(self.df)
186
+
187
+ def __getitem__(self, i):
188
+ r = self.df.iloc[i]
189
+ img = load_or_cache(r)
190
+ return (
191
+ val_transform(img),
192
+ torch.tensor(int(r['disease_label']), dtype=torch.long),
193
+ torch.tensor(int(r['severity_label']), dtype=torch.long),
194
+ i # return index so we can track per-sample metadata
195
+ )
196
+
197
+
198
+ # ================================================================
199
+ # STEP 1 — LOAD METADATA & BUILD VAL SPLIT
200
+ # ================================================================
201
+ print('\n[1/6] Loading metadata and building val split...')
202
+
203
+ meta = pd.read_csv(META_CSV)
204
+ print(f' Raw rows: {len(meta)}')
205
+
206
+ # Fix image paths
207
+ meta['image_path_clean'] = meta['image_path'].str.lstrip('.').str.lstrip('/').str.replace('//', '/', regex=False)
208
+ meta['image_path_resolved'] = meta['image_path_clean'].apply(
209
+ lambda p: resolve_image_path(p)
210
+ )
211
+
212
+ n_resolved = meta['image_path_resolved'].notna().sum()
213
+ print(f' Images resolved on disk: {n_resolved} / {len(meta)}')
214
+
215
+ # Build the same stratified split used in training (random_state=42, test_size=0.2)
216
+ train_df, val_df = train_test_split(
217
+ meta,
218
+ test_size=0.2,
219
+ stratify=meta['disease_label'],
220
+ random_state=42
221
+ )
222
+ val_df = val_df.reset_index(drop=True)
223
+ print(f' Val split: {len(val_df)} samples')
224
+ print(f' Val class distribution:')
225
+ for lbl, cnt in val_df['disease_label'].value_counts().sort_index().items():
226
+ print(f' {CLASS_NAMES[int(lbl)]:<15s}: {cnt:4d}')
227
+
228
+ # ================================================================
229
+ # STEP 2 — LOAD MODEL
230
+ # ================================================================
231
+ print('\n[2/6] Loading model...')
232
+
233
+ model = MultiTaskViT().to(device)
234
+ ckpt = torch.load(MODEL_PATH, map_location=device, weights_only=False)
235
+ model.load_state_dict(ckpt['model_state_dict'])
236
+ model.eval()
237
+ print(f' Loaded checkpoint: epoch={ckpt.get("epoch","?")}, '
238
+ f'macro_f1={ckpt.get("macro_f1", 0):.4f}')
239
+
240
+ # Load thresholds
241
+ with open(THRESH_JSON) as f:
242
+ thresh_data = json.load(f)
243
+ thresholds = {int(k): float(v) for k, v in thresh_data['optimal_thresholds'].items()}
244
+ print(f' Optimal thresholds: {thresholds}')
245
+
246
+ # ================================================================
247
+ # STEP 3 — RUN INFERENCE
248
+ # ================================================================
249
+ print('\n[3/6] Running inference on val set...')
250
+
251
+ val_ds = RetDS(val_df)
252
+ val_loader = DataLoader(
253
+ val_ds, batch_size=BATCH_SIZE, shuffle=False,
254
+ num_workers=NUM_WORKERS, pin_memory=True
255
+ )
256
+
257
+ all_probs = [] # (N, 5) softmax probabilities
258
+ all_preds = [] # (N,) argmax predictions
259
+ all_labels = [] # (N,) true labels
260
+ all_idxs = [] # (N,) val_df indices
261
+
262
+ with torch.no_grad():
263
+ for imgs, d_lbl, s_lbl, idx in tqdm(val_loader, desc='Inference'):
264
+ imgs = imgs.to(device, non_blocking=True)
265
+ with torch.amp.autocast('cuda'):
266
+ d_out, _ = model(imgs)
267
+ probs = torch.softmax(d_out.float(), dim=1).cpu().numpy()
268
+ preds = d_out.argmax(1).cpu().numpy()
269
+ all_probs.append(probs)
270
+ all_preds.append(preds)
271
+ all_labels.append(d_lbl.numpy())
272
+ all_idxs.append(idx.numpy())
273
+
274
+ all_probs = np.vstack(all_probs) # (N, 5)
275
+ all_preds = np.concatenate(all_preds)
276
+ all_labels = np.concatenate(all_labels)
277
+ all_idxs = np.concatenate(all_idxs)
278
+
279
+ # Also compute threshold-adjusted predictions
280
+ thresh_preds = np.zeros_like(all_preds)
281
+ for i in range(len(all_probs)):
282
+ adjusted = all_probs[i].copy()
283
+ for c, t in thresholds.items():
284
+ adjusted[c] = all_probs[i][c] / t # scale by threshold
285
+ thresh_preds[i] = adjusted.argmax()
286
+
287
+ raw_acc = (all_preds == all_labels).mean() * 100
288
+ thresh_acc = (thresh_preds == all_labels).mean() * 100
289
+ print(f' Raw accuracy : {raw_acc:.2f}%')
290
+ print(f' Threshold accuracy: {thresh_acc:.2f}%')
291
+
292
+ # Use threshold-adjusted for main analysis (matches published 84.48%)
293
+ preds = thresh_preds
294
+
295
+ # ================================================================
296
+ # STEP 4 — CONFIDENCE CALIBRATION (ECE)
297
+ # ================================================================
298
+ print('\n[4/6] Computing ECE and reliability diagram...')
299
+
300
+ def compute_ece(probs, labels, n_bins=10):
301
+ """Expected Calibration Error with equal-width bins."""
302
+ confidences = probs.max(axis=1) # max probability = confidence
303
+ predicted = probs.argmax(axis=1)
304
+ correct = (predicted == labels).astype(float)
305
+
306
+ bins = np.linspace(0, 1, n_bins + 1)
307
+ ece = 0.0
308
+ bin_acc = []
309
+ bin_conf = []
310
+ bin_count = []
311
+
312
+ for lo, hi in zip(bins[:-1], bins[1:]):
313
+ mask = (confidences >= lo) & (confidences < hi)
314
+ if mask.sum() == 0:
315
+ bin_acc.append(0.0)
316
+ bin_conf.append((lo + hi) / 2)
317
+ bin_count.append(0)
318
+ continue
319
+ acc = correct[mask].mean()
320
+ conf = confidences[mask].mean()
321
+ n = mask.sum()
322
+ ece += (n / len(labels)) * abs(acc - conf)
323
+ bin_acc.append(acc)
324
+ bin_conf.append(conf)
325
+ bin_count.append(int(n))
326
+
327
+ return ece, bin_acc, bin_conf, bin_count, bins
328
+
329
+ ece, bin_acc, bin_conf, bin_count, bins = compute_ece(all_probs, all_labels)
330
+ print(f' ECE (10 bins): {ece:.4f}')
331
+
332
+ # Per-class calibration
333
+ per_class_ece = {}
334
+ for c in range(NUM_CLASSES):
335
+ mask = (all_labels == c)
336
+ if mask.sum() == 0:
337
+ per_class_ece[CLASS_NAMES[c]] = 0.0
338
+ continue
339
+ ece_c, _, _, _, _ = compute_ece(all_probs[mask], all_labels[mask])
340
+ per_class_ece[CLASS_NAMES[c]] = float(ece_c)
341
+ print(f' ECE {CLASS_NAMES[c]:<15s}: {ece_c:.4f}')
342
+
343
+ # -- Reliability diagram --
344
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5))
345
+
346
+ bin_centers = (bins[:-1] + bins[1:]) / 2
347
+ bars = axes[0].bar(
348
+ bin_centers, bin_acc,
349
+ width=(bins[1] - bins[0]) * 0.9,
350
+ alpha=0.7, color='steelblue', label='Accuracy per bin'
351
+ )
352
+ axes[0].plot([0, 1], [0, 1], 'r--', lw=2, label='Perfect calibration')
353
+ axes[0].set_xlabel('Confidence', fontsize=12)
354
+ axes[0].set_ylabel('Accuracy', fontsize=12)
355
+ axes[0].set_title(f'Reliability Diagram\nECE = {ece:.4f}', fontsize=13, fontweight='bold')
356
+ axes[0].legend(fontsize=10)
357
+ axes[0].grid(alpha=0.3)
358
+ axes[0].set_xlim(0, 1); axes[0].set_ylim(0, 1)
359
+
360
+ # Annotate with bin counts
361
+ for bar, cnt in zip(bars, bin_count):
362
+ if cnt > 0:
363
+ axes[0].text(
364
+ bar.get_x() + bar.get_width()/2, min(bar.get_height() + 0.02, 0.97),
365
+ str(cnt), ha='center', va='bottom', fontsize=7, color='black'
366
+ )
367
+
368
+ # Gap diagram (overconfidence = positive gap)
369
+ gap = np.array(bin_conf) - np.array(bin_acc)
370
+ color_gap = ['#e74c3c' if g > 0 else '#2ecc71' for g in gap]
371
+ axes[1].bar(bin_centers, gap, width=(bins[1]-bins[0])*0.9, color=color_gap, alpha=0.8)
372
+ axes[1].axhline(0, color='black', lw=1)
373
+ axes[1].set_xlabel('Confidence', fontsize=12)
374
+ axes[1].set_ylabel('Confidence - Accuracy (Gap)', fontsize=12)
375
+ axes[1].set_title('Calibration Gap\n(Red=overconfident, Green=underconfident)',
376
+ fontsize=13, fontweight='bold')
377
+ axes[1].grid(alpha=0.3)
378
+ axes[1].set_xlim(0, 1)
379
+
380
+ plt.tight_layout()
381
+ plt.savefig(f'{OUT_DIR}/reliability_diagram.png', dpi=150, bbox_inches='tight')
382
+ plt.close()
383
+ print(f' Saved reliability_diagram.png')
384
+
385
+ # ================================================================
386
+ # STEP 5 — CONFUSION MATRIX
387
+ # ================================================================
388
+ print('\n[5/6] Generating confusion matrices...')
389
+
390
+ cm_raw = confusion_matrix(all_labels, preds)
391
+ cm_norm = cm_raw.astype(float) / cm_raw.sum(axis=1, keepdims=True)
392
+
393
+ # -- Raw counts confusion matrix --
394
+ fig, ax = plt.subplots(figsize=(8, 6))
395
+ sns.heatmap(
396
+ cm_raw, annot=True, fmt='d', cmap='Blues', ax=ax,
397
+ xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
398
+ linewidths=0.5, linecolor='gray'
399
+ )
400
+ ax.set_title('Confusion Matrix (Raw Counts)', fontsize=14, fontweight='bold')
401
+ ax.set_ylabel('True Label', fontsize=12)
402
+ ax.set_xlabel('Predicted Label', fontsize=12)
403
+ plt.xticks(rotation=30, ha='right')
404
+ plt.tight_layout()
405
+ plt.savefig(f'{OUT_DIR}/confusion_matrix_raw.png', dpi=150, bbox_inches='tight')
406
+ plt.close()
407
+
408
+ # -- Normalized confusion matrix --
409
+ fig, ax = plt.subplots(figsize=(8, 6))
410
+ sns.heatmap(
411
+ cm_norm, annot=True, fmt='.3f', cmap='Blues', ax=ax,
412
+ xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
413
+ linewidths=0.5, linecolor='gray', vmin=0, vmax=1
414
+ )
415
+ ax.set_title('Confusion Matrix (Normalized by True Class)', fontsize=14, fontweight='bold')
416
+ ax.set_ylabel('True Label', fontsize=12)
417
+ ax.set_xlabel('Predicted Label', fontsize=12)
418
+ plt.xticks(rotation=30, ha='right')
419
+ plt.tight_layout()
420
+ plt.savefig(f'{OUT_DIR}/confusion_matrix_normalized.png', dpi=150, bbox_inches='tight')
421
+ plt.close()
422
+ print(' Saved confusion_matrix_raw.png and confusion_matrix_normalized.png')
423
+
424
+ # -- Top confused pairs --
425
+ confused_pairs = []
426
+ for true_c in range(NUM_CLASSES):
427
+ for pred_c in range(NUM_CLASSES):
428
+ if true_c == pred_c:
429
+ continue
430
+ count = cm_raw[true_c, pred_c]
431
+ rate = cm_norm[true_c, pred_c]
432
+ confused_pairs.append({
433
+ 'true_class': CLASS_NAMES[true_c],
434
+ 'pred_class': CLASS_NAMES[pred_c],
435
+ 'count': int(count),
436
+ 'rate': float(rate),
437
+ 'description': f'{CLASS_NAMES[true_c]} misclassified AS {CLASS_NAMES[pred_c]}'
438
+ })
439
+ confused_pairs.sort(key=lambda x: x['count'], reverse=True)
440
+ top5_pairs = confused_pairs[:5]
441
+
442
+ print('\n Top 5 confused class pairs (by raw count):')
443
+ for p in top5_pairs:
444
+ print(f' {p["description"]}: {p["count"]} ({p["rate"]*100:.1f}%)')
445
+
446
+ # ================================================================
447
+ # STEP 6 — PER-CLASS METRICS
448
+ # ================================================================
449
+ print('\n[6/6] Computing per-class metrics...')
450
+
451
+ report_dict = classification_report(
452
+ all_labels, preds, target_names=CLASS_NAMES, output_dict=True, zero_division=0
453
+ )
454
+ print(classification_report(all_labels, preds, target_names=CLASS_NAMES, digits=4, zero_division=0))
455
+
456
+ per_class_precision = {}
457
+ per_class_recall = {}
458
+ per_class_f1 = {}
459
+ per_class_support = {}
460
+
461
+ for cn in CLASS_NAMES:
462
+ per_class_precision[cn] = report_dict[cn]['precision']
463
+ per_class_recall[cn] = report_dict[cn]['recall']
464
+ per_class_f1[cn] = report_dict[cn]['f1-score']
465
+ per_class_support[cn] = int(report_dict[cn]['support'])
466
+
467
+ overall_accuracy = report_dict['accuracy'] * 100
468
+ macro_f1 = report_dict['macro avg']['f1-score']
469
+ weighted_f1 = report_dict['weighted avg']['f1-score']
470
+
471
+ print(f'\n Overall accuracy : {overall_accuracy:.2f}%')
472
+ print(f' Macro F1 : {macro_f1:.4f}')
473
+ print(f' Weighted F1 : {weighted_f1:.4f}')
474
+
475
+ # ================================================================
476
+ # CONFIDENCE DISTRIBUTION ANALYSIS
477
+ # ================================================================
478
+ print('\nAnalyzing confidence distributions...')
479
+
480
+ fig, axes = plt.subplots(2, 3, figsize=(18, 10))
481
+ axes = axes.flatten()
482
+
483
+ all_max_conf = all_probs.max(axis=1)
484
+ all_correct = (preds == all_labels)
485
+
486
+ for ci, cn in enumerate(CLASS_NAMES):
487
+ ax = axes[ci]
488
+ mask_class = (all_labels == ci)
489
+
490
+ correct_conf = all_max_conf[mask_class & all_correct]
491
+ wrong_conf = all_max_conf[mask_class & ~all_correct]
492
+
493
+ n_correct = len(correct_conf)
494
+ n_wrong = len(wrong_conf)
495
+
496
+ if n_correct > 0:
497
+ ax.hist(correct_conf, bins=20, alpha=0.6, color='#2ecc71',
498
+ label=f'Correct (n={n_correct})', density=True)
499
+ if n_wrong > 0:
500
+ ax.hist(wrong_conf, bins=20, alpha=0.6, color='#e74c3c',
501
+ label=f'Wrong (n={n_wrong})', density=True)
502
+
503
+ # Mark high-confidence wrong predictions
504
+ if n_wrong > 0:
505
+ high_conf_wrong = (wrong_conf > 0.8).sum()
506
+ ax.axvline(0.8, color='darkred', linestyle='--', alpha=0.7, lw=1.5,
507
+ label=f'Conf>0.8 wrong: {high_conf_wrong}')
508
+
509
+ ax.set_title(f'{cn}\nPrec={per_class_precision[cn]:.3f} Rec={per_class_recall[cn]:.3f} F1={per_class_f1[cn]:.3f}',
510
+ fontsize=10, fontweight='bold')
511
+ ax.set_xlabel('Max Confidence', fontsize=9)
512
+ ax.set_ylabel('Density', fontsize=9)
513
+ ax.legend(fontsize=7)
514
+ ax.grid(alpha=0.3)
515
+ ax.set_xlim(0, 1)
516
+
517
+ # Summary panel
518
+ ax = axes[5]
519
+ mean_correct = [all_max_conf[all_labels==c][preds[all_labels==c]==c].mean()
520
+ if (all_labels==c).sum() > 0 else 0 for c in range(NUM_CLASSES)]
521
+ mean_wrong = [all_max_conf[all_labels==c][preds[all_labels==c]!=c].mean()
522
+ if ((all_labels==c) & (preds!=c)).sum() > 0 else 0 for c in range(NUM_CLASSES)]
523
+
524
+ x = np.arange(NUM_CLASSES)
525
+ width = 0.35
526
+ ax.bar(x - width/2, mean_correct, width, label='Mean conf (correct)', color='#2ecc71', alpha=0.8)
527
+ ax.bar(x + width/2, mean_wrong, width, label='Mean conf (wrong)', color='#e74c3c', alpha=0.8)
528
+ ax.set_xticks(x)
529
+ ax.set_xticklabels([c[:6] for c in CLASS_NAMES], rotation=20)
530
+ ax.set_ylabel('Mean Confidence')
531
+ ax.set_title('Mean Confidence: Correct vs Wrong', fontweight='bold')
532
+ ax.legend(fontsize=8)
533
+ ax.grid(alpha=0.3, axis='y')
534
+ ax.set_ylim(0, 1)
535
+
536
+ plt.suptitle('Confidence Distribution Analysis per Class', fontsize=14, fontweight='bold')
537
+ plt.tight_layout()
538
+ plt.savefig(f'{OUT_DIR}/confidence_distributions.png', dpi=150, bbox_inches='tight')
539
+ plt.close()
540
+ print(' Saved confidence_distributions.png')
541
+
542
+ # ================================================================
543
+ # PER-SOURCE ANALYSIS
544
+ # ================================================================
545
+ print('\nRunning per-source analysis...')
546
+
547
+ # Attach dataset source to val_df indices
548
+ source_col = val_df['dataset'].values
549
+
550
+ results_df = pd.DataFrame({
551
+ 'true_label': all_labels,
552
+ 'pred_label': preds,
553
+ 'max_conf': all_max_conf,
554
+ 'dataset': source_col[all_idxs],
555
+ 'correct': (preds == all_labels).astype(int),
556
+ })
557
+
558
+ per_source = {}
559
+ for src in ['ODIR', 'APTOS']:
560
+ mask = results_df['dataset'] == src
561
+ if mask.sum() == 0:
562
+ continue
563
+ src_true = results_df['true_label'][mask].values
564
+ src_pred = results_df['pred_label'][mask].values
565
+ src_acc = (src_true == src_pred).mean() * 100
566
+ src_f1 = f1_score(src_true, src_pred, average='macro', zero_division=0)
567
+
568
+ per_class_acc_src = {}
569
+ for c in range(NUM_CLASSES):
570
+ cmask = (src_true == c)
571
+ if cmask.sum() == 0:
572
+ per_class_acc_src[CLASS_NAMES[c]] = None
573
+ else:
574
+ per_class_acc_src[CLASS_NAMES[c]] = float((src_pred[cmask] == c).mean() * 100)
575
+
576
+ per_source[src] = {
577
+ 'n_samples': int(mask.sum()),
578
+ 'accuracy': float(src_acc),
579
+ 'macro_f1': float(src_f1),
580
+ 'per_class_acc': per_class_acc_src
581
+ }
582
+ print(f'\n {src} (n={mask.sum()}):')
583
+ print(f' Accuracy : {src_acc:.2f}%')
584
+ print(f' Macro F1 : {src_f1:.4f}')
585
+ for cn, acc in per_class_acc_src.items():
586
+ if acc is not None:
587
+ print(f' {cn:<15s}: {acc:.1f}%')
588
+
589
+ # -- Per-source performance plot --
590
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5))
591
+
592
+ # Overall bar
593
+ sources = list(per_source.keys())
594
+ accs = [per_source[s]['accuracy'] for s in sources]
595
+ f1s = [per_source[s]['macro_f1'] for s in sources]
596
+
597
+ x = np.arange(len(sources))
598
+ w = 0.35
599
+ axes[0].bar(x - w/2, accs, w, label='Accuracy (%)', color=['#3498db', '#e67e22'], alpha=0.85)
600
+ axes[0].bar(x + w/2, [f*100 for f in f1s], w, label='Macro F1 ×100',
601
+ color=['#2ecc71', '#e74c3c'], alpha=0.85)
602
+ axes[0].set_xticks(x); axes[0].set_xticklabels(sources)
603
+ axes[0].set_ylim(50, 100)
604
+ axes[0].set_ylabel('Score')
605
+ axes[0].set_title('Overall Performance by Source', fontweight='bold')
606
+ axes[0].legend(); axes[0].grid(alpha=0.3, axis='y')
607
+ for xi, (acc, f1) in enumerate(zip(accs, f1s)):
608
+ axes[0].text(xi - w/2, acc + 0.5, f'{acc:.1f}', ha='center', fontsize=9)
609
+ axes[0].text(xi + w/2, f1*100 + 0.5, f'{f1*100:.1f}', ha='center', fontsize=9)
610
+
611
+ # Per-class accuracy by source
612
+ class_data = {cn: [] for cn in CLASS_NAMES}
613
+ valid_sources = []
614
+ for src in sources:
615
+ valid_sources.append(src)
616
+ for cn in CLASS_NAMES:
617
+ acc = per_source[src]['per_class_acc'].get(cn)
618
+ class_data[cn].append(acc if acc is not None else 0.0)
619
+
620
+ x = np.arange(len(CLASS_NAMES))
621
+ n_src = len(valid_sources)
622
+ width = 0.8 / n_src
623
+ colors_src = ['#3498db', '#e67e22', '#2ecc71']
624
+
625
+ for si, src in enumerate(valid_sources):
626
+ vals = [class_data[cn][si] for cn in CLASS_NAMES]
627
+ offset = (si - n_src/2 + 0.5) * width
628
+ axes[1].bar(x + offset, vals, width, label=src, alpha=0.85, color=colors_src[si])
629
+
630
+ axes[1].set_xticks(x); axes[1].set_xticklabels(CLASS_NAMES, rotation=20, ha='right')
631
+ axes[1].set_ylim(0, 105)
632
+ axes[1].set_ylabel('Accuracy (%)')
633
+ axes[1].set_title('Per-Class Accuracy by Source', fontweight='bold')
634
+ axes[1].legend(); axes[1].grid(alpha=0.3, axis='y')
635
+
636
+ plt.suptitle('Dataset Source Performance Analysis', fontsize=14, fontweight='bold')
637
+ plt.tight_layout()
638
+ plt.savefig(f'{OUT_DIR}/per_source_performance.png', dpi=150, bbox_inches='tight')
639
+ plt.close()
640
+ print(' Saved per_source_performance.png')
641
+
642
+ # ================================================================
643
+ # SAVE METRICS JSON
644
+ # ================================================================
645
+ print('\nSaving metrics JSON...')
646
+
647
+ baseline_metrics = {
648
+ 'overall_accuracy': float(overall_accuracy),
649
+ 'raw_accuracy': float(raw_acc),
650
+ 'threshold_accuracy': float(thresh_acc),
651
+ 'macro_f1': float(macro_f1),
652
+ 'weighted_f1': float(weighted_f1),
653
+ 'ece': float(ece),
654
+ 'per_class_ece': per_class_ece,
655
+ 'per_class_f1': per_class_f1,
656
+ 'per_class_precision': per_class_precision,
657
+ 'per_class_recall': per_class_recall,
658
+ 'per_class_support': per_class_support,
659
+ 'per_source_accuracy': {
660
+ src: {
661
+ 'accuracy': per_source[src]['accuracy'],
662
+ 'macro_f1': per_source[src]['macro_f1'],
663
+ 'n_samples': per_source[src]['n_samples'],
664
+ 'per_class_acc': per_source[src]['per_class_acc']
665
+ }
666
+ for src in per_source
667
+ },
668
+ 'top_confusion_pairs': top5_pairs,
669
+ 'confusion_matrix_raw': cm_raw.tolist(),
670
+ 'val_split_size': len(val_df),
671
+ 'thresholds_used': thresholds,
672
+ 'calibration': {
673
+ 'ece': float(ece),
674
+ 'bin_acc': [float(x) for x in bin_acc],
675
+ 'bin_conf': [float(x) for x in bin_conf],
676
+ 'bin_count': bin_count,
677
+ }
678
+ }
679
+
680
+ with open(f'{OUT_DIR}/baseline_metrics.json', 'w') as f:
681
+ json.dump(baseline_metrics, f, indent=2)
682
+ print(f' Saved baseline_metrics.json')
683
+
684
+ # ================================================================
685
+ # ANALYSIS REPORT
686
+ # ================================================================
687
+ print('\nGenerating analysis report...')
688
+
689
+ # Identify key findings
690
+ worst_recall_class = min(per_class_recall, key=per_class_recall.get)
691
+ worst_f1_class = min(per_class_f1, key=per_class_f1.get)
692
+ best_f1_class = max(per_class_f1, key=per_class_f1.get)
693
+
694
+ # High-confidence wrong predictions per class
695
+ hcw_analysis = {}
696
+ for ci, cn in enumerate(CLASS_NAMES):
697
+ mask_class = (all_labels == ci)
698
+ wrong_mask = mask_class & ~all_correct
699
+ if wrong_mask.sum() > 0:
700
+ high_conf_wrong = ((all_max_conf > 0.8) & wrong_mask).sum()
701
+ hcw_analysis[cn] = {
702
+ 'total_wrong': int(wrong_mask.sum()),
703
+ 'high_conf_wrong_count': int(high_conf_wrong),
704
+ 'high_conf_wrong_pct': float(high_conf_wrong / wrong_mask.sum() * 100) if wrong_mask.sum() > 0 else 0,
705
+ 'mean_wrong_conf': float(all_max_conf[wrong_mask].mean()) if wrong_mask.sum() > 0 else 0,
706
+ }
707
+ else:
708
+ hcw_analysis[cn] = {'total_wrong': 0, 'high_conf_wrong_count': 0,
709
+ 'high_conf_wrong_pct': 0, 'mean_wrong_conf': 0}
710
+
711
+ # Domain gap
712
+ domain_gap = None
713
+ if 'ODIR' in per_source and 'APTOS' in per_source:
714
+ odir_acc = per_source['ODIR']['accuracy']
715
+ aptos_acc = per_source['APTOS']['accuracy']
716
+ domain_gap = abs(odir_acc - aptos_acc)
717
+
718
+ # DR-specific domain gap
719
+ odir_dr = per_source['ODIR']['per_class_acc'].get('Diabetes/DR', 0) or 0
720
+ aptos_dr = per_source['APTOS']['per_class_acc'].get('Diabetes/DR', 0) or 0
721
+ dr_gap = abs(odir_dr - aptos_dr)
722
+ else:
723
+ domain_gap = 0; odir_acc = 0; aptos_acc = 0; odir_dr = 0; aptos_dr = 0; dr_gap = 0
724
+
725
+ calibration_verdict = 'overconfident' if sum(
726
+ b_conf - b_acc for b_conf, b_acc in zip(bin_conf, bin_acc) if bin_count[bin_acc.index(b_acc)] > 0
727
+ ) > 0 else 'underconfident'
728
+
729
+ report = f"""# RetinaSense ViT v2 — Baseline Error Analysis Report
730
+ **Generated**: 2026-03-06
731
+ **Model**: ViT-Base-Patch16-224 (MultiTaskViT)
732
+ **Checkpoint**: outputs_vit/best_model.pth
733
+ **Val Split**: {len(val_df)} samples (20% stratified, random_state=42)
734
+
735
+ ---
736
+
737
+ ## 1. Overall Performance
738
+
739
+ | Metric | Value |
740
+ |--------|-------|
741
+ | Accuracy (raw argmax) | {raw_acc:.2f}% |
742
+ | Accuracy (with thresholds) | {thresh_acc:.2f}% |
743
+ | Macro F1 | {macro_f1:.4f} |
744
+ | Weighted F1 | {weighted_f1:.4f} |
745
+ | ECE (10 bins) | {ece:.4f} |
746
+
747
+ ---
748
+
749
+ ## 2. Per-Class Metrics
750
+
751
+ | Class | Precision | Recall | F1 | Support |
752
+ |-------|-----------|--------|----|---------|
753
+ """
754
+ for cn in CLASS_NAMES:
755
+ report += (f"| {cn:<15s} | {per_class_precision[cn]:.4f} | "
756
+ f"{per_class_recall[cn]:.4f} | {per_class_f1[cn]:.4f} | "
757
+ f"{per_class_support[cn]:4d} |\n")
758
+
759
+ report += f"""
760
+ ---
761
+
762
+ ## 3. Confusion Analysis — Top 5 Confused Pairs
763
+
764
+ | Rank | True Class | Predicted As | Count | Rate |
765
+ |------|-----------|-------------|-------|------|
766
+ """
767
+ for rank, pair in enumerate(top5_pairs, 1):
768
+ report += (f"| {rank} | {pair['true_class']} | {pair['pred_class']} | "
769
+ f"{pair['count']} | {pair['rate']*100:.1f}% |\n")
770
+
771
+ report += f"""
772
+ ### Full Confusion Matrix (normalized by true class)
773
+
774
+ ```
775
+ {(' '.join(f'{cn[:6]:>7s}' for cn in CLASS_NAMES))}
776
+ """
777
+ for ri, rn in enumerate(CLASS_NAMES):
778
+ row_str = ' '.join(f'{cm_norm[ri, ci]:.3f}' for ci in range(NUM_CLASSES))
779
+ report += f"{rn[:8]:>8s} {row_str}\n"
780
+
781
+ report += f"""```
782
+
783
+ ---
784
+
785
+ ## 4. Confidence Calibration Analysis
786
+
787
+ - **ECE (overall)**: {ece:.4f}
788
+ - **Calibration pattern**: The model is predominantly **{calibration_verdict}**
789
+ (mean confidence exceeds accuracy in most bins).
790
+
791
+ ### Per-Class ECE
792
+
793
+ | Class | ECE |
794
+ |-------|-----|
795
+ """
796
+ for cn, ece_c in per_class_ece.items():
797
+ report += f"| {cn} | {ece_c:.4f} |\n"
798
+
799
+ report += f"""
800
+ ### High-Confidence Wrong Predictions (confidence > 0.8)
801
+
802
+ | Class | Total Wrong | High-Conf Wrong | % of Errors | Mean Wrong Conf |
803
+ |-------|------------|----------------|-------------|----------------|
804
+ """
805
+ for cn, hcw in hcw_analysis.items():
806
+ report += (f"| {cn} | {hcw['total_wrong']} | {hcw['high_conf_wrong_count']} | "
807
+ f"{hcw['high_conf_wrong_pct']:.1f}% | {hcw['mean_wrong_conf']:.3f} |\n")
808
+
809
+ report += f"""
810
+ ---
811
+
812
+ ## 5. Dataset Source Analysis (ODIR vs APTOS)
813
+
814
+ | Source | N Samples | Accuracy | Macro F1 |
815
+ |--------|-----------|----------|----------|
816
+ """
817
+ for src, data in per_source.items():
818
+ report += f"| {src} | {data['n_samples']} | {data['accuracy']:.2f}% | {data['macro_f1']:.4f} |\n"
819
+
820
+ report += f"""
821
+ ### Per-Class Accuracy by Source
822
+
823
+ | Class |"""
824
+ for src in per_source:
825
+ report += f" {src} |"
826
+ report += "\n|-------|"
827
+ for _ in per_source:
828
+ report += "--------|"
829
+ report += "\n"
830
+ for cn in CLASS_NAMES:
831
+ report += f"| {cn} |"
832
+ for src in per_source:
833
+ acc = per_source[src]['per_class_acc'].get(cn)
834
+ if acc is None:
835
+ report += " N/A |"
836
+ else:
837
+ report += f" {acc:.1f}% |"
838
+ report += "\n"
839
+
840
+ report += f"""
841
+ **Domain gap (overall accuracy)**: {domain_gap:.2f}pp between ODIR and APTOS
842
+ """
843
+ if 'ODIR' in per_source and 'APTOS' in per_source:
844
+ report += f"""**DR class gap (ODIR vs APTOS)**: ODIR={odir_dr:.1f}% vs APTOS={aptos_dr:.1f}% (gap={dr_gap:.1f}pp)
845
+ """
846
+
847
+ report += f"""
848
+ ---
849
+
850
+ ## 6. Error Pattern Summary
851
+
852
+ ### Q1: What is the model's biggest weakness?
853
+
854
+ The model's biggest weakness is classifying **{worst_f1_class}** (F1={per_class_f1[worst_f1_class]:.4f},
855
+ recall={per_class_recall[worst_f1_class]:.4f}). This class has the worst F1 score, indicating the
856
+ model struggles to both detect and correctly distinguish it from other pathologies.
857
+
858
+ The confusion matrix shows that the primary confusion pathway is:
859
+ - **{top5_pairs[0]['description']}**: {top5_pairs[0]['count']} cases ({top5_pairs[0]['rate']*100:.1f}% error rate)
860
+ - **{top5_pairs[1]['description']}**: {top5_pairs[1]['count']} cases ({top5_pairs[1]['rate']*100:.1f}% error rate)
861
+
862
+ ### Q2: Which class has the worst recall? Why?
863
+
864
+ **{worst_recall_class}** has the worst recall at {per_class_recall[worst_recall_class]:.4f}.
865
+ """
866
+
867
+ # Detailed reason based on support
868
+ worst_support = per_class_support[worst_recall_class]
869
+ all_support = sum(per_class_support.values())
870
+ worst_pct = worst_support / all_support * 100
871
+ report += f"""This class represents only {worst_support} samples ({worst_pct:.1f}% of the val set).
872
+ The low recall is likely caused by:
873
+ 1. **Class imbalance** — the model sees fewer examples during training and defaults to predicting
874
+ more common classes when uncertain.
875
+ 2. **Visual similarity** with other conditions (especially {top5_pairs[0]['pred_class'] if top5_pairs[0]['true_class']==worst_recall_class else 'Normal'})
876
+ at the fundus level.
877
+ 3. **Threshold sensitivity** — the optimized threshold ({thresholds.get(CLASS_NAMES.index(worst_recall_class), 0.5):.2f})
878
+ may overcorrect or undercorrect depending on the calibration.
879
+
880
+ ### Q3: Evidence of domain shift (ODIR vs APTOS)?
881
+
882
+ """
883
+ if domain_gap is not None and domain_gap > 2.0:
884
+ report += f"""YES — there is a **{domain_gap:.1f}pp accuracy gap** between ODIR ({odir_acc:.1f}%) and APTOS
885
+ ({aptos_acc:.1f}%). This is significant and consistent with domain shift between the two data sources.
886
+
887
+ For the DR/Diabetes class specifically, the gap is **{dr_gap:.1f}pp** (ODIR={odir_dr:.1f}% vs APTOS={aptos_dr:.1f}%).
888
+ APTOS images are specifically DR-graded fundus photographs from India (Aravind Eye Hospital),
889
+ while ODIR covers multiple disease classes with more varied image quality and capture conditions.
890
+ The Ben Graham preprocessing helps but does not fully bridge the domain gap.
891
+
892
+ **Implication for v3**: Domain-specific augmentation or source-aware training (e.g., source
893
+ as auxiliary input, separate batch norms, or domain adaptation) may improve generalization.
894
+ """
895
+ elif domain_gap is not None and domain_gap > 0:
896
+ report += f"""MINOR gap observed — {domain_gap:.1f}pp difference between ODIR ({odir_acc:.1f}%) and
897
+ APTOS ({aptos_acc:.1f}%). The gap is small, suggesting the Ben Graham preprocessing and ViT
898
+ architecture generalize reasonably across sources. DR-specific gap: {dr_gap:.1f}pp.
899
+ """
900
+ else:
901
+ report += "Insufficient cross-source data to conclude domain shift.\n"
902
+
903
+ report += f"""
904
+ ### Q4: Calibration assessment
905
+
906
+ ECE = **{ece:.4f}** (scale: 0=perfect, 0.1=poor).
907
+
908
+ """
909
+ if ece < 0.03:
910
+ report += "The model is **well-calibrated** (ECE < 0.03). Confidence scores are reliable."
911
+ elif ece < 0.07:
912
+ report += f"""The model shows **moderate miscalibration** (ECE={ece:.4f}). The reliability diagram
913
+ shows the model is {calibration_verdict} in the high-confidence range, meaning predicted
914
+ confidence scores are not fully reliable. Temperature scaling in v3 is recommended."""
915
+ else:
916
+ report += f"""The model is **poorly calibrated** (ECE={ece:.4f}). The {calibration_verdict}
917
+ pattern is severe. Temperature scaling or label smoothing in v3 training is strongly recommended."""
918
+
919
+ report += f"""
920
+
921
+ ---
922
+
923
+ ## 7. Recommendations for v3 Training
924
+
925
+ Based on this baseline analysis:
926
+
927
+ 1. **Address {worst_recall_class} recall** — increase class weight, targeted augmentation,
928
+ or focal loss gamma tuning for this class.
929
+ 2. **Calibration** — add temperature scaling post-training or increase label smoothing
930
+ (current ECE={ece:.4f}).
931
+ 3. **Domain shift mitigation** — consider source-conditioned augmentation or adversarial
932
+ domain adaptation if ODIR/APTOS gap persists.
933
+ 4. **High-confidence errors** — the model makes confidently wrong predictions on certain
934
+ classes; mixup or CutMix augmentation may improve uncertainty estimation.
935
+ 5. **Top confusion pairs** to specifically target:
936
+ """
937
+ for pair in top5_pairs[:3]:
938
+ report += f" - {pair['description']} ({pair['count']} errors)\n"
939
+
940
+ report += """
941
+ ---
942
+
943
+ ## 8. Output Files
944
+
945
+ | File | Description |
946
+ |------|-------------|
947
+ | confusion_matrix_raw.png | Raw count confusion matrix |
948
+ | confusion_matrix_normalized.png | Recall-normalized confusion matrix |
949
+ | reliability_diagram.png | ECE calibration plot |
950
+ | confidence_distributions.png | Per-class confidence histograms |
951
+ | per_source_performance.png | ODIR vs APTOS breakdown |
952
+ | baseline_metrics.json | All metrics in structured JSON |
953
+
954
+ ---
955
+ *Report generated by RetinaSense ViT v2 error analysis pipeline.*
956
+ """
957
+
958
+ with open(f'{OUT_DIR}/BASELINE_ANALYSIS.md', 'w') as f:
959
+ f.write(report)
960
+ print(f' Saved BASELINE_ANALYSIS.md')
961
+
962
+ # ================================================================
963
+ # FINAL SUMMARY
964
+ # ================================================================
965
+ print('\n' + '='*65)
966
+ print(' BASELINE ANALYSIS COMPLETE')
967
+ print('='*65)
968
+ print(f' Val accuracy (thresh) : {thresh_acc:.2f}%')
969
+ print(f' Macro F1 : {macro_f1:.4f}')
970
+ print(f' ECE : {ece:.4f}')
971
+ print(f' Worst class (F1) : {worst_f1_class} ({per_class_f1[worst_f1_class]:.4f})')
972
+ print(f' Worst class (recall) : {worst_recall_class} ({per_class_recall[worst_recall_class]:.4f})')
973
+ print(f' Top confusion : {top5_pairs[0]["description"]}')
974
+ if domain_gap is not None:
975
+ print(f' Domain gap (ODIR-APTOS): {domain_gap:.2f}pp')
976
+ print(f'\n All outputs in: {OUT_DIR}/')
977
+ print('='*65)