Add eval_dashboard.py
Browse files- eval_dashboard.py +752 -0
eval_dashboard.py
ADDED
|
@@ -0,0 +1,752 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
RetinaSense v3.0 -- Phase 1A: Rich Evaluation Dashboard
|
| 4 |
+
========================================================
|
| 5 |
+
Standalone script that loads the trained ViT model, runs inference on the
|
| 6 |
+
full test set (1,281 images), and produces publication-quality evaluation
|
| 7 |
+
plots plus a structured metrics JSON report.
|
| 8 |
+
|
| 9 |
+
Outputs (all written to outputs_v3/evaluation/):
|
| 10 |
+
- confusion_matrix.png
|
| 11 |
+
- roc_curves_per_class.png
|
| 12 |
+
- precision_recall_curves.png
|
| 13 |
+
- calibration_reliability.png
|
| 14 |
+
- confidence_histograms.png
|
| 15 |
+
- error_analysis_by_source.png
|
| 16 |
+
- metrics_report.json
|
| 17 |
+
|
| 18 |
+
Usage:
|
| 19 |
+
python eval_dashboard.py
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
import sys
|
| 24 |
+
import json
|
| 25 |
+
import warnings
|
| 26 |
+
import numpy as np
|
| 27 |
+
import pandas as pd
|
| 28 |
+
import cv2
|
| 29 |
+
import matplotlib
|
| 30 |
+
matplotlib.use('Agg')
|
| 31 |
+
import matplotlib.pyplot as plt
|
| 32 |
+
import matplotlib.ticker as mticker
|
| 33 |
+
import seaborn as sns
|
| 34 |
+
from PIL import Image
|
| 35 |
+
from collections import OrderedDict
|
| 36 |
+
|
| 37 |
+
warnings.filterwarnings('ignore')
|
| 38 |
+
|
| 39 |
+
import torch
|
| 40 |
+
import torch.nn as nn
|
| 41 |
+
import torch.nn.functional as F
|
| 42 |
+
from torch.utils.data import Dataset, DataLoader
|
| 43 |
+
from torchvision import transforms
|
| 44 |
+
import timm
|
| 45 |
+
|
| 46 |
+
from sklearn.metrics import (
|
| 47 |
+
confusion_matrix,
|
| 48 |
+
classification_report,
|
| 49 |
+
roc_curve,
|
| 50 |
+
auc,
|
| 51 |
+
precision_recall_curve,
|
| 52 |
+
average_precision_score,
|
| 53 |
+
f1_score,
|
| 54 |
+
accuracy_score,
|
| 55 |
+
cohen_kappa_score,
|
| 56 |
+
matthews_corrcoef,
|
| 57 |
+
balanced_accuracy_score,
|
| 58 |
+
log_loss,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# ================================================================
|
| 62 |
+
# CONFIGURATION
|
| 63 |
+
# ================================================================
|
| 64 |
+
BASE_DIR = '/teamspace/studios/this_studio'
|
| 65 |
+
OUTPUT_DIR = os.path.join(BASE_DIR, 'outputs_v3')
|
| 66 |
+
EVAL_DIR = os.path.join(OUTPUT_DIR, 'evaluation')
|
| 67 |
+
os.makedirs(EVAL_DIR, exist_ok=True)
|
| 68 |
+
|
| 69 |
+
MODEL_PATH = os.path.join(OUTPUT_DIR, 'best_model.pth')
|
| 70 |
+
THRESHOLDS_PATH = os.path.join(OUTPUT_DIR, 'thresholds.json')
|
| 71 |
+
TEMPERATURE_PATH = os.path.join(OUTPUT_DIR, 'temperature.json')
|
| 72 |
+
TEST_CSV = os.path.join(BASE_DIR, 'data', 'test_split.csv')
|
| 73 |
+
NORM_STATS_PATH = os.path.join(BASE_DIR, 'data', 'fundus_norm_stats.json')
|
| 74 |
+
|
| 75 |
+
NUM_CLASSES = 5
|
| 76 |
+
IMG_SIZE = 224
|
| 77 |
+
DROPOUT = 0.3
|
| 78 |
+
BATCH_SIZE = 32
|
| 79 |
+
|
| 80 |
+
CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']
|
| 81 |
+
|
| 82 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 83 |
+
|
| 84 |
+
# Publication style defaults
|
| 85 |
+
plt.rcParams.update({
|
| 86 |
+
'font.size': 11,
|
| 87 |
+
'axes.titlesize': 13,
|
| 88 |
+
'axes.labelsize': 12,
|
| 89 |
+
'xtick.labelsize': 10,
|
| 90 |
+
'ytick.labelsize': 10,
|
| 91 |
+
'legend.fontsize': 10,
|
| 92 |
+
'figure.dpi': 300,
|
| 93 |
+
'savefig.dpi': 300,
|
| 94 |
+
'savefig.bbox': 'tight',
|
| 95 |
+
'savefig.pad_inches': 0.15,
|
| 96 |
+
'font.family': 'sans-serif',
|
| 97 |
+
})
|
| 98 |
+
|
| 99 |
+
print('=' * 65)
|
| 100 |
+
print(' RetinaSense v3.0 -- Phase 1A: Evaluation Dashboard')
|
| 101 |
+
print('=' * 65)
|
| 102 |
+
print(f' Device : {DEVICE}')
|
| 103 |
+
if torch.cuda.is_available():
|
| 104 |
+
print(f' GPU : {torch.cuda.get_device_name(0)}')
|
| 105 |
+
print(f' Output : {EVAL_DIR}')
|
| 106 |
+
print('=' * 65)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ================================================================
|
| 110 |
+
# LOAD NORMALISATION STATS
|
| 111 |
+
# ================================================================
|
| 112 |
+
# Per-channel normalisation constants: read them from the stats JSON when it
# exists, otherwise fall back to the standard ImageNet mean/std.
if not os.path.exists(NORM_STATS_PATH):
    NORM_MEAN = [0.485, 0.456, 0.406]
    NORM_STD = [0.229, 0.224, 0.225]
    print(' Using ImageNet normalisation fallback')
else:
    with open(NORM_STATS_PATH) as f:
        norm_stats = json.load(f)
    NORM_MEAN, NORM_STD = norm_stats['mean_rgb'], norm_stats['std_rgb']
    # Rounded copies are for the log line only; the full-precision values are
    # what the transform pipeline uses.
    mean_for_log = [round(v, 4) for v in NORM_MEAN]
    std_for_log = [round(v, 4) for v in NORM_STD]
    print(f' Fundus norm stats loaded: mean={mean_for_log}, std={std_for_log}')
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# ================================================================
|
| 126 |
+
# MODEL ARCHITECTURE (mirrors retinasense_v3.py / gradcam_v3.py)
|
| 127 |
+
# ================================================================
|
| 128 |
+
class MultiTaskViT(nn.Module):
    """ViT-Base/16 (224 px) feature extractor with two classification heads.

    The backbone is created without its own classifier (``num_classes=0``),
    so it returns a feature vector per image. That vector passes through a
    shared dropout and is then fed to a disease head and a severity head.

    Args:
        n_disease: Number of outputs of the disease head.
        n_severity: Number of outputs of the severity head.
        drop: Dropout probability applied to the backbone features.
    """

    def __init__(self, n_disease=NUM_CLASSES, n_severity=5, drop=DROPOUT):
        super().__init__()
        embed_dim = 768  # feature width produced by the ViT-Base backbone

        self.backbone = timm.create_model(
            'vit_base_patch16_224',
            pretrained=False,  # weights come from the checkpoint instead
            num_classes=0,
        )
        self.drop = nn.Dropout(drop)

        # Layer order (and therefore the Sequential indices) must match the
        # checkpoint's state_dict keys, so do not reorder these.
        disease_layers = [
            nn.Linear(embed_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, n_disease),
        ]
        self.disease_head = nn.Sequential(*disease_layers)

        severity_layers = [
            nn.Linear(embed_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, n_severity),
        ]
        self.severity_head = nn.Sequential(*severity_layers)

    def forward(self, x):
        """Return ``(disease_logits, severity_logits)`` for an image batch."""
        features = self.drop(self.backbone(x))
        disease_logits = self.disease_head(features)
        severity_logits = self.severity_head(features)
        return disease_logits, severity_logits
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# ================================================================
|
| 155 |
+
# LOAD MODEL + CALIBRATION ARTIFACTS
|
| 156 |
+
# ================================================================
|
| 157 |
+
print('\nLoading model...')
model = MultiTaskViT().to(DEVICE)
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary
# objects, which can execute code. Only point MODEL_PATH at checkpoints you
# produced yourself or otherwise trust.
ckpt = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()  # disable dropout and use BatchNorm running statistics
print(f' Loaded: {MODEL_PATH}')

# The stored epoch is shown 1-based. A checkpoint without an integer 'epoch'
# entry is reported as '?' instead of failing on `'?' + 1`.
ckpt_epoch = ckpt.get('epoch')
epoch_label = ckpt_epoch + 1 if isinstance(ckpt_epoch, int) else '?'
print(f' Checkpoint epoch: {epoch_label} '
      f'val_acc={ckpt.get("val_acc", 0):.2f}%')

# Per-class decision thresholds saved alongside the model.
with open(THRESHOLDS_PATH) as f:
    thr_data = json.load(f)
THRESHOLDS = thr_data['thresholds']

# Scalar temperature that the logits are divided by before softmax.
with open(TEMPERATURE_PATH) as f:
    temp_data = json.load(f)
TEMPERATURE = temp_data['temperature']
if not TEMPERATURE > 0:
    # A zero or negative temperature would silently produce inf/NaN or
    # inverted probabilities when the logits are divided by it.
    raise ValueError(
        f'Temperature must be positive, got {TEMPERATURE!r} '
        f'from {TEMPERATURE_PATH}'
    )

print(f' Temperature T = {TEMPERATURE:.4f}')
print(f' Thresholds = {[round(t, 3) for t in THRESHOLDS]}')
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# ================================================================
|
| 179 |
+
# DATASET
|
| 180 |
+
# ================================================================
|
| 181 |
+
class TestDataset(Dataset):
    """Test-set dataset backed by a DataFrame of image records.

    For each row the image is taken from the ``.npy`` file named in the
    optional ``cache_path`` column when that file exists and loads cleanly.
    Otherwise the file at ``image_path`` is preprocessed on the fly:
    Ben Graham enhancement when ``source == 'APTOS'``, CLAHE for every other
    source. If preprocessing fails, an all-black image is substituted so a
    single unreadable file does not abort the whole evaluation run.

    Each item is ``(transformed_image, disease_label, source)``.

    Args:
        df: DataFrame with at least ``image_path`` and ``disease_label``
            columns; ``cache_path`` and ``source`` are optional.
        transform: Callable applied to the image array.
            # assumes cached arrays are HxWx3 uint8 like the on-the-fly
            # output — TODO confirm against the cache writer
    """

    def __init__(self, df, transform):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Try the cache first. An empty CSV cell comes back from pandas as a
        # float NaN, which is truthy and makes os.path.exists raise
        # TypeError, so require a real non-empty string.
        cache_fp = row.get('cache_path', '')
        img = None
        if isinstance(cache_fp, str) and cache_fp and os.path.exists(cache_fp):
            try:
                img = np.load(cache_fp)
            except Exception:
                # Unreadable cache entry: fall through to preprocessing.
                img = None

        # Fallback: on-the-fly preprocessing of the original image file.
        if img is None:
            image_path = self._resolve_path(row['image_path'])
            source = row.get('source', 'ODIR')
            try:
                if source == 'APTOS':
                    img = self._ben_graham(image_path)
                else:
                    img = self._clahe_preprocess(image_path)
            except Exception:
                # Deliberate best-effort: substitute a black image.
                img = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)

        img_tensor = self.transform(img)
        disease_lbl = int(row['disease_label'])
        source = row.get('source', 'unknown')
        return img_tensor, disease_lbl, source

    @staticmethod
    def _resolve_path(image_path):
        """Anchor a relative ``image_path`` at ``BASE_DIR``.

        Leading ``./`` prefixes are removed together with any slashes that
        follow them, so ``.//x`` resolves under ``BASE_DIR`` too. Without
        stripping those slashes the remainder would start with ``/`` and
        ``os.path.join`` would discard ``BASE_DIR``.
        """
        if os.path.isabs(image_path):
            return image_path
        clean = image_path
        while clean.startswith('./'):
            clean = clean[2:].lstrip('/')
        return os.path.join(BASE_DIR, clean)

    @staticmethod
    def _read_bgr(path):
        """Read ``path`` as a BGR array, using PIL when OpenCV cannot decode it."""
        raw = cv2.imread(path)
        if raw is None:
            # The context manager closes the file handle promptly.
            with Image.open(path) as pil_img:
                rgb = np.array(pil_img.convert('RGB'))
            raw = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        return raw

    @staticmethod
    def _ben_graham(path, sz=IMG_SIZE, sigma=10):
        """Return an ``sz`` x ``sz`` RGB image with Ben Graham enhancement.

        The image is resized, a Gaussian-blurred copy is subtracted with
        gain 4 and offset 128, and everything outside a centred circle of
        radius ``0.48 * sz`` is set to zero.
        """
        raw = TestDataset._read_bgr(path)
        raw = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
        raw = cv2.resize(raw, (sz, sz))
        raw = cv2.addWeighted(raw, 4, cv2.GaussianBlur(raw, (0, 0), sigma), -4, 128)
        mask = np.zeros(raw.shape[:2], dtype=np.uint8)
        cv2.circle(mask, (sz // 2, sz // 2), int(sz * 0.48), 255, -1)
        return cv2.bitwise_and(raw, raw, mask=mask)

    @staticmethod
    def _clahe_preprocess(path, sz=IMG_SIZE):
        """Return an ``sz`` x ``sz`` RGB image with CLAHE on the LAB L channel."""
        raw = TestDataset._read_bgr(path)
        raw = cv2.resize(raw, (sz, sz))
        lab = cv2.cvtColor(raw, cv2.COLOR_BGR2LAB)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        lab[:, :, 0] = clahe.apply(lab[:, :, 0])
        raw = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
        return cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
# Evaluation-time pipeline: array -> PIL image -> float tensor -> normalised
# tensor. There is no augmentation and no resizing here.
_eval_steps = [
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
]
val_transform = transforms.Compose(_eval_steps)

print('\nLoading test set...')
test_df = pd.read_csv(TEST_CSV)
print(f' Test samples: {len(test_df)}')
_source_names = sorted(test_df["source"].unique())
print(f' Sources : {_source_names}')
_class_counts = test_df["disease_label"].value_counts().sort_index().to_dict()
print(f' Class dist : {_class_counts}')

test_ds = TestDataset(test_df, val_transform)
# shuffle=False keeps predictions in the same order as the CSV rows.
_loader_options = {
    'batch_size': BATCH_SIZE,
    'shuffle': False,
    'num_workers': 4,
    'pin_memory': True,
}
test_loader = DataLoader(test_ds, **_loader_options)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
# ================================================================
|
| 277 |
+
# INFERENCE
|
| 278 |
+
# ================================================================
|
| 279 |
+
print('\nRunning inference on full test set...')
all_logits = []
all_labels = []
all_sources = []

# Gradients are not needed for evaluation; only the disease head's logits are
# kept and the severity output is discarded.
with torch.no_grad():
    for imgs, labels, sources in test_loader:
        imgs = imgs.to(DEVICE)
        disease_logits, _ = model(imgs)
        all_logits.append(disease_logits.cpu())
        all_labels.extend(labels.numpy().tolist())
        all_sources.extend(sources)

all_logits = torch.cat(all_logits, dim=0)  # one row of logits per sample
all_labels = np.array(all_labels)
all_sources = np.array(all_sources)
N = len(all_labels)
print(f' Inference complete: {N} samples')

# Temperature-scaled probabilities (logits divided by T before softmax) and
# the plain softmax for comparison.
probs_calibrated = F.softmax(all_logits / TEMPERATURE, dim=1).numpy()
probs_uncalibrated = F.softmax(all_logits, dim=1).numpy()

# Predictions: argmax of calibrated probabilities
preds = np.argmax(probs_calibrated, axis=1)
confidences = np.max(probs_calibrated, axis=1)

correct_mask = (preds == all_labels)
acc = accuracy_score(all_labels, preds)
# Count correct predictions directly. int(acc * N) can truncate to one less
# than the true count when the float product lands just below an integer.
n_correct = int(correct_mask.sum())
print(f' Overall accuracy: {acc:.4f} ({n_correct}/{N})')
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
# ================================================================
|
| 312 |
+
# 1. CONFUSION MATRIX
|
| 313 |
+
# ================================================================
|
| 314 |
+
print('\n[1/7] Confusion matrix...')
cm = confusion_matrix(all_labels, preds, labels=list(range(NUM_CLASSES)))
# Row-normalise so each row shows how one true class is distributed over the
# predicted classes. A class with no test samples has a zero row total; its
# row is left at 0.0 instead of becoming NaN through a division by zero.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm.astype(float),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals > 0,
)

fig, ax = plt.subplots(figsize=(7, 6))
sns.heatmap(
    cm_norm, annot=True, fmt='.2f', cmap='Blues',
    xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
    linewidths=0.5, linecolor='white',
    cbar_kws={'label': 'Proportion', 'shrink': 0.8},
    ax=ax, vmin=0, vmax=1,
)
# Overlay raw counts in smaller font, slightly below each cell's centre so
# they do not collide with the proportion annotation.
for i in range(NUM_CLASSES):
    for j in range(NUM_CLASSES):
        ax.text(j + 0.5, i + 0.72, f'(n={cm[i, j]})',
                ha='center', va='center', fontsize=7, color='gray')

ax.set_xlabel('Predicted Class')
ax.set_ylabel('True Class')
ax.set_title('Normalized Confusion Matrix (Test Set)')
fig.tight_layout()
fig.savefig(os.path.join(EVAL_DIR, 'confusion_matrix.png'))
plt.close(fig)
print(' Saved confusion_matrix.png')
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
# ================================================================
|
| 342 |
+
# 2. ROC CURVES PER CLASS
|
| 343 |
+
# ================================================================
|
| 344 |
+
print('[2/7] ROC curves...')
fig, ax = plt.subplots(figsize=(7, 6))
colors = sns.color_palette('tab10', NUM_CLASSES)
all_fpr_tpr = {}
macro_auc_list = []

# One-vs-rest curve per class: the class's calibrated probability is the
# score and membership in that class is the binary target.
for i, curve_color in zip(range(NUM_CLASSES), colors):
    y_true_bin = (all_labels == i).astype(int)
    y_score = probs_calibrated[:, i]
    fpr, tpr, _ = roc_curve(y_true_bin, y_score)
    roc_auc = auc(fpr, tpr)
    all_fpr_tpr[i] = (fpr, tpr)
    macro_auc_list.append(roc_auc)
    curve_label = f'{CLASS_NAMES[i]} (AUC={roc_auc:.3f})'
    ax.plot(fpr, tpr, color=curve_color, lw=2, label=curve_label)

# Macro average ROC: interpolate every class curve onto a shared FPR grid and
# average the TPR values point-wise.
mean_fpr = np.linspace(0, 1, 200)
mean_tpr = np.zeros_like(mean_fpr)
for class_fpr, class_tpr in all_fpr_tpr.values():
    mean_tpr += np.interp(mean_fpr, class_fpr, class_tpr)
mean_tpr /= NUM_CLASSES
macro_auc = auc(mean_fpr, mean_tpr)
ax.plot(mean_fpr, mean_tpr, 'k--', lw=2.5,
        label=f'Macro-average (AUC={macro_auc:.3f})')
# Chance diagonal for reference.
ax.plot([0, 1], [0, 1], 'k:', lw=1, alpha=0.4)

ax.set(
    xlim=[-0.02, 1.02],
    ylim=[-0.02, 1.05],
    xlabel='False Positive Rate',
    ylabel='True Positive Rate',
    title='One-vs-Rest ROC Curves (Calibrated)',
)
ax.legend(loc='lower right', framealpha=0.9)
ax.grid(True, alpha=0.3)
fig.tight_layout()
roc_png = os.path.join(EVAL_DIR, 'roc_curves_per_class.png')
fig.savefig(roc_png)
plt.close(fig)
print(' Saved roc_curves_per_class.png')
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
# ================================================================
|
| 385 |
+
# 3. PRECISION-RECALL CURVES
|
| 386 |
+
# ================================================================
|
| 387 |
+
print('[3/7] Precision-recall curves...')
fig, ax = plt.subplots(figsize=(7, 6))

# One-vs-rest precision-recall curve per class, scored with the class's
# calibrated probability. Colours reuse the palette from the ROC figure.
for i, curve_color in zip(range(NUM_CLASSES), colors):
    y_true_bin = (all_labels == i).astype(int)
    y_score = probs_calibrated[:, i]
    prec, rec, _ = precision_recall_curve(y_true_bin, y_score)
    ap = average_precision_score(y_true_bin, y_score)
    curve_label = f'{CLASS_NAMES[i]} (AP={ap:.3f})'
    ax.plot(rec, prec, color=curve_color, lw=2, label=curve_label)

# Add prevalence baselines: a faint horizontal line per class at that class's
# share of the test set.
prevalences = np.bincount(all_labels, minlength=NUM_CLASSES) / N
for i, line_color in zip(range(NUM_CLASSES), colors):
    ax.axhline(y=prevalences[i], color=line_color, ls=':', alpha=0.3)

ax.set(
    xlim=[-0.02, 1.02],
    ylim=[-0.02, 1.05],
    xlabel='Recall',
    ylabel='Precision',
    title='Precision-Recall Curves (Calibrated)',
)
ax.legend(loc='upper right', framealpha=0.9)
ax.grid(True, alpha=0.3)
fig.tight_layout()
pr_png = os.path.join(EVAL_DIR, 'precision_recall_curves.png')
fig.savefig(pr_png)
plt.close(fig)
print(' Saved precision_recall_curves.png')
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
# ================================================================
|
| 417 |
+
# 4. CALIBRATION RELIABILITY DIAGRAM
|
| 418 |
+
# ================================================================
|
| 419 |
+
print('[4/7] Calibration reliability diagram...')
|
| 420 |
+
n_bins = 10
|
| 421 |
+
bin_edges = np.linspace(0, 1, n_bins + 1)
|
| 422 |
+
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
|
| 423 |
+
|
| 424 |
+
# Compute calibration for both calibrated and uncalibrated probabilities
|
| 425 |
+
def compute_calibration(confidences_arr, correct_arr, bin_edges):
    """Compute per-bin accuracy and average confidence.

    Samples are grouped into the bins delimited by consecutive entries of
    ``bin_edges``. Bins are right-closed ``(lo, hi]``, except the first bin,
    which also includes its lower edge: a confidence exactly equal to
    ``bin_edges[0]`` would otherwise fall into no bin at all and silently
    disappear from the counts (and from any ECE computed from them).

    Args:
        confidences_arr: 1-D array-like of per-sample confidences.
        correct_arr: 1-D array-like of the same length holding the
            per-sample correctness values that are averaged per bin.
        bin_edges: Increasing sequence of bin boundaries
            (number of bins + 1 entries).

    Returns:
        Tuple ``(bin_accs, bin_confs, bin_counts)`` of NumPy arrays with one
        entry per bin. Empty bins hold NaN for accuracy and confidence and
        0 for the count.
    """
    confidences_arr = np.asarray(confidences_arr)
    correct_arr = np.asarray(correct_arr)

    bin_accs = []
    bin_confs = []
    bin_counts = []
    for bin_idx, (lo, hi) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        # Only the first bin is closed on the left, so every value inside
        # [bin_edges[0], bin_edges[-1]] lands in exactly one bin.
        if bin_idx == 0:
            above_lo = confidences_arr >= lo
        else:
            above_lo = confidences_arr > lo
        mask = above_lo & (confidences_arr <= hi)
        n_in_bin = int(mask.sum())
        if n_in_bin == 0:
            # NaN marks "no data" so callers can skip the bin (e.g. nansum).
            bin_accs.append(np.nan)
            bin_confs.append(np.nan)
            bin_counts.append(0)
        else:
            bin_accs.append(correct_arr[mask].mean())
            bin_confs.append(confidences_arr[mask].mean())
            bin_counts.append(n_in_bin)
    return np.array(bin_accs), np.array(bin_confs), np.array(bin_counts)
|
| 441 |
+
|
| 442 |
+
# Confidence of a sample = largest probability in its row, taken separately
# for the calibrated and the uncalibrated probability matrices.
conf_calib = np.max(probs_calibrated, axis=1)
conf_uncalib = np.max(probs_uncalibrated, axis=1)

correct_as_float = correct_mask.astype(float)
bin_accs_cal, bin_confs_cal, bin_counts_cal = compute_calibration(
    conf_calib, correct_as_float, bin_edges)
bin_accs_uncal, bin_confs_uncal, bin_counts_uncal = compute_calibration(
    conf_uncalib, correct_as_float, bin_edges)

# ECE: count-weighted |accuracy - confidence| summed over bins and divided
# by N. Empty bins carry NaN and are ignored by nansum.
ece_cal = np.nansum(
    np.abs(bin_accs_cal - bin_confs_cal) * bin_counts_cal) / N
ece_uncal = np.nansum(
    np.abs(bin_accs_uncal - bin_confs_uncal) * bin_counts_uncal) / N

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# One panel per probability set: (accuracies, confidences, counts, ECE,
# title suffix, bar colour).
reliability_panels = [
    (bin_accs_cal, bin_confs_cal, bin_counts_cal, ece_cal,
     'Calibrated', '#4C72B0'),
    (bin_accs_uncal, bin_confs_uncal, bin_counts_uncal, ece_uncal,
     'Uncalibrated', '#DD8452'),
]

for ax, panel in zip(axes, reliability_panels):
    b_accs, b_confs, b_counts, ece_val, title_suffix, bar_color = panel

    # Perfect calibration line
    ax.plot([0, 1], [0, 1], 'k--', lw=1.5, alpha=0.5, label='Perfectly calibrated')

    # Bar chart of bin accuracy (only bins that actually contain samples).
    valid = ~np.isnan(b_accs)
    valid_bins = np.flatnonzero(valid)
    ax.bar(bin_centers[valid], b_accs[valid], width=0.08,
           alpha=0.7, color=bar_color, edgecolor='black', linewidth=0.5,
           label=f'Model (ECE={ece_val:.4f})')

    # Gap shading: red band between bin accuracy and bin mean confidence.
    for j in valid_bins:
        gap_bottom = min(b_accs[j], b_confs[j])
        gap_top = max(b_accs[j], b_confs[j])
        ax.fill_between(
            [bin_centers[j] - 0.04, bin_centers[j] + 0.04],
            gap_bottom, gap_top, alpha=0.15, color='red')

    # Sample counts on top of each bar.
    for j in valid_bins:
        if b_counts[j] > 0:
            ax.text(bin_centers[j], b_accs[j] + 0.03,
                    str(b_counts[j]), ha='center', va='bottom', fontsize=7)

    ax.set(
        xlim=[0, 1],
        ylim=[0, 1.1],  # headroom for the count labels above full-height bars
        xlabel='Mean Predicted Confidence',
        ylabel='Fraction of Correct Predictions',
        title=f'Reliability Diagram ({title_suffix})',
    )
    ax.legend(loc='upper left', framealpha=0.9)
    ax.grid(True, alpha=0.3)

fig.tight_layout()
reliability_path = os.path.join(EVAL_DIR, 'calibration_reliability.png')
fig.savefig(reliability_path)
plt.close(fig)
print(f' Saved calibration_reliability.png (ECE_cal={ece_cal:.4f}, ECE_uncal={ece_uncal:.4f})')
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
# ================================================================
# 5. CONFIDENCE HISTOGRAMS
# ================================================================
# Left panel: confidence distribution of correct vs incorrect predictions,
# with a dashed line at each group's median. Right panel: confidence
# distribution split by true class.
print('[5/7] Confidence histograms...')
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Correct vs Incorrect
correctness_groups = [
    (correct_mask, 'Correct', '#2ca02c'),
    (~correct_mask, 'Incorrect', '#d62728'),
]
for mask, label, color in correctness_groups:
    axes[0].hist(confidences[mask], bins=30, alpha=0.65, color=color,
                 label=f'{label} (n={mask.sum()})', edgecolor='black', linewidth=0.3)

axes[0].set_xlabel('Prediction Confidence')
axes[0].set_ylabel('Count')
axes[0].set_title('Confidence Distribution: Correct vs Incorrect')
axes[0].legend(loc='upper left', framealpha=0.9)

# Median markers. A group can be empty (e.g. no misclassifications at all);
# np.median of an empty selection is NaN and emits a RuntimeWarning, so the
# marker is skipped for empty groups instead of drawing a line at NaN.
for mask, _label, color in correctness_groups:
    if mask.any():
        axes[0].axvline(x=np.median(confidences[mask]), color=color,
                        ls='--', alpha=0.6, label='_nolegend_')
axes[0].grid(True, alpha=0.3, axis='y')

# Per-class confidence
for i in range(NUM_CLASSES):
    cls_mask = (all_labels == i)
    axes[1].hist(confidences[cls_mask], bins=20, alpha=0.5, color=colors[i],
                 label=f'{CLASS_NAMES[i]} (n={cls_mask.sum()})',
                 edgecolor='black', linewidth=0.3)

axes[1].set_xlabel('Prediction Confidence')
axes[1].set_ylabel('Count')
axes[1].set_title('Confidence Distribution by True Class')
axes[1].legend(loc='upper left', framealpha=0.9, fontsize=9)
axes[1].grid(True, alpha=0.3, axis='y')

fig.tight_layout()
fig.savefig(os.path.join(EVAL_DIR, 'confidence_histograms.png'))
plt.close(fig)
print(' Saved confidence_histograms.png')
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
# ================================================================
# 6. ERROR ANALYSIS BY SOURCE
# ================================================================
# Left panel: per-class accuracy as grouped bars, one bar per data source.
# Right panel: the most frequent (true > predicted) confusions, stacked by
# source. Both panels use the same colour for the same source.
print('[6/7] Error analysis by source...')
sources_unique = sorted(np.unique(all_sources))
n_sources = len(sources_unique)

# Build accuracy per (source, class) pair
source_class_acc = {}
source_class_n = {}
for src in sources_unique:
    from_src = (all_sources == src)  # loop-invariant for the inner loop
    for cls_idx in range(NUM_CLASSES):
        mask = from_src & (all_labels == cls_idx)
        n_cls = mask.sum()
        if n_cls > 0:
            acc_sc = (preds[mask] == all_labels[mask]).mean()
        else:
            # No samples of this class from this source: accuracy undefined.
            acc_sc = np.nan
        source_class_acc[(src, cls_idx)] = acc_sc
        source_class_n[(src, cls_idx)] = int(n_cls)

# Also overall accuracy per source
source_overall_acc = {}
for src in sources_unique:
    mask = (all_sources == src)
    source_overall_acc[src] = accuracy_score(all_labels[mask], preds[mask])

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: grouped bar chart of per-class accuracy by source
x = np.arange(NUM_CLASSES)
bar_width = 0.8 / max(n_sources, 1)
source_colors = sns.color_palette('Set2', n_sources)

for s_idx, src in enumerate(sources_unique):
    accs = [source_class_acc[(src, c)] for c in range(NUM_CLASSES)]
    counts = [source_class_n[(src, c)] for c in range(NUM_CLASSES)]
    # Centre the group of bars on the class tick.
    offset = (s_idx - n_sources / 2 + 0.5) * bar_width
    bars = axes[0].bar(x + offset, accs, bar_width * 0.9,
                       label=f'{src} (n={sum(counts)})',
                       color=source_colors[s_idx], edgecolor='black', linewidth=0.5)
    # Annotate sample counts
    for j, (b, n_val) in enumerate(zip(bars, counts)):
        if n_val > 0 and not np.isnan(accs[j]):
            axes[0].text(b.get_x() + b.get_width() / 2, b.get_height() + 0.02,
                         str(n_val), ha='center', va='bottom', fontsize=7)

axes[0].set_xticks(x)
axes[0].set_xticklabels(CLASS_NAMES, rotation=15, ha='right')
axes[0].set_ylabel('Accuracy')
axes[0].set_title('Per-Class Accuracy by Data Source')
axes[0].set_ylim([0, 1.15])
axes[0].legend(loc='upper right', framealpha=0.9)
axes[0].grid(True, alpha=0.3, axis='y')
# Dashed reference line at the overall accuracy.
axes[0].axhline(y=acc, color='black', ls='--', alpha=0.4, lw=1)
axes[0].text(NUM_CLASSES - 0.5, acc + 0.02, f'Overall: {acc:.3f}',
             ha='right', fontsize=9, alpha=0.6)

# Right panel: confusion breakdown -- most common misclassifications per source
error_data = []
for src in sources_unique:
    src_mask = (all_sources == src) & (~correct_mask)
    if src_mask.sum() == 0:
        continue
    for true_cls in range(NUM_CLASSES):
        for pred_cls in range(NUM_CLASSES):
            if true_cls == pred_cls:
                continue
            pair_mask = src_mask & (all_labels == true_cls) & (preds == pred_cls)
            cnt = pair_mask.sum()
            if cnt > 0:
                error_data.append({
                    'Source': src,
                    'Error': f'{CLASS_NAMES[true_cls][:3]}>{CLASS_NAMES[pred_cls][:3]}',
                    'Count': int(cnt),
                })

if error_data:
    err_df = pd.DataFrame(error_data)
    # Top 10 error types
    top_errors = (err_df.groupby('Error')['Count'].sum()
                  .sort_values(ascending=False).head(10).index.tolist())
    err_df_top = err_df[err_df['Error'].isin(top_errors)]
    pivot = err_df_top.pivot_table(index='Error', columns='Source',
                                   values='Count', aggfunc='sum', fill_value=0)
    # Reorder by total count
    pivot = pivot.loc[pivot.sum(axis=1).sort_values(ascending=True).index]
    # The pivot only has columns for sources that appear among the top error
    # patterns. Colours are therefore looked up by source name rather than
    # by column position, so each source keeps the colour it has in the
    # left panel even when another source has no column here.
    pivot_colors = [source_colors[sources_unique.index(col)]
                    for col in pivot.columns]
    pivot.plot(kind='barh', stacked=True, ax=axes[1],
               color=pivot_colors, edgecolor='black', linewidth=0.5)
    axes[1].set_xlabel('Error Count')
    axes[1].set_title('Top Misclassification Patterns by Source')
    axes[1].legend(loc='lower right', framealpha=0.9)
    axes[1].grid(True, alpha=0.3, axis='x')
else:
    axes[1].text(0.5, 0.5, 'No errors to display', ha='center', va='center',
                 transform=axes[1].transAxes, fontsize=14)
    axes[1].set_title('Top Misclassification Patterns by Source')

fig.tight_layout()
fig.savefig(os.path.join(EVAL_DIR, 'error_analysis_by_source.png'))
plt.close(fig)
print(' Saved error_analysis_by_source.png')
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
# ================================================================
# 7. METRICS REPORT (JSON)
# ================================================================
# Collects the scalar metrics, per-class metrics, confusion matrices and
# per-source statistics into one ordered mapping and writes it to
# metrics_report.json inside EVAL_DIR.
print('[7/7] Metrics report...')

# Classification report as dict
cls_report = classification_report(
    all_labels, preds, target_names=CLASS_NAMES,
    output_dict=True, zero_division=0)

# Per-class AUC and AP (one-vs-rest, from the calibrated probabilities)
per_class_auc = {}
per_class_ap = {}
for i in range(NUM_CLASSES):
    y_bin = (all_labels == i).astype(int)
    y_score = probs_calibrated[:, i]
    fpr_i, tpr_i, _ = roc_curve(y_bin, y_score)
    per_class_auc[CLASS_NAMES[i]] = float(auc(fpr_i, tpr_i))
    per_class_ap[CLASS_NAMES[i]] = float(average_precision_score(y_bin, y_score))

# log_loss is best-effort: a failure must not abort the report, but it is
# reported on stdout so a null entry in the JSON is never a silent surprise.
try:
    ll = float(log_loss(all_labels, probs_calibrated))
except Exception as exc:
    print(f' WARNING: log_loss could not be computed ({exc}); storing null')
    ll = None

# Per-class entries: precision/recall/F1/support from classification_report,
# plus the AUC and average precision computed above.
per_class_metrics = {
    name: {
        'precision': float(cls_report[name]['precision']),
        'recall': float(cls_report[name]['recall']),
        'f1-score': float(cls_report[name]['f1-score']),
        'support': int(cls_report[name]['support']),
        'auc': per_class_auc[name],
        'average_precision': per_class_ap[name],
    }
    for name in CLASS_NAMES
}

# Build the full report (key order is the order written to the JSON file).
metrics_report = OrderedDict([
    ('n_test_samples', int(N)),
    ('overall_accuracy', float(acc)),
    ('balanced_accuracy', float(balanced_accuracy_score(all_labels, preds))),
    ('macro_f1', float(f1_score(all_labels, preds, average='macro', zero_division=0))),
    ('weighted_f1', float(f1_score(all_labels, preds, average='weighted', zero_division=0))),
    ('cohen_kappa', float(cohen_kappa_score(all_labels, preds))),
    ('matthews_corrcoef', float(matthews_corrcoef(all_labels, preds))),
    ('log_loss', ll),
    ('macro_auc', float(np.mean(list(per_class_auc.values())))),
    ('ece_calibrated', float(ece_cal)),
    ('ece_uncalibrated', float(ece_uncal)),
    ('temperature', float(TEMPERATURE)),
    ('thresholds', THRESHOLDS),
    ('per_class_metrics', per_class_metrics),
    ('per_class_auc', per_class_auc),
    ('per_class_ap', per_class_ap),
    ('confusion_matrix_raw', cm.tolist()),
    ('confusion_matrix_normalized', np.round(cm_norm, 4).tolist()),
    ('source_accuracy', {src: float(v) for src, v in source_overall_acc.items()}),
    ('source_class_counts', {
        src: {CLASS_NAMES[c]: source_class_n[(src, c)]
              for c in range(NUM_CLASSES)}
        for src in sources_unique
    }),
    ('class_names', CLASS_NAMES),
])

report_path = os.path.join(EVAL_DIR, 'metrics_report.json')
with open(report_path, 'w', encoding='utf-8') as f:
    json.dump(metrics_report, f, indent=2)
print(' Saved metrics_report.json')
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
# ================================================================
# SUMMARY
# ================================================================
# Console recap of the headline metrics, followed by a listing of every
# expected output file with its size on disk (or MISSING if absent).
banner = '=' * 65
print('\n' + banner)
print(' EVALUATION DASHBOARD COMPLETE')
print(banner)
print(f' Overall Accuracy : {acc:.4f}')
print(f' Balanced Accuracy : {metrics_report["balanced_accuracy"]:.4f}')
print(f' Macro F1 : {metrics_report["macro_f1"]:.4f}')
print(f' Cohen Kappa : {metrics_report["cohen_kappa"]:.4f}')
print(f' Macro AUC : {metrics_report["macro_auc"]:.4f}')
print(f' ECE (calibrated) : {ece_cal:.4f}')
print(f' ECE (uncalibrated) : {ece_uncal:.4f}')

print(f'\n Per-class AUC:')
for name, val in per_class_auc.items():
    print(f' {name:15s} : {val:.4f}')

print(f'\n Source accuracy:')
for src, val in source_overall_acc.items():
    print(f' {src:10s} : {val:.4f}')

print(f'\n All outputs in: {EVAL_DIR}/')
output_files = [
    'confusion_matrix.png',
    'roc_curves_per_class.png',
    'precision_recall_curves.png',
    'calibration_reliability.png',
    'confidence_histograms.png',
    'error_analysis_by_source.png',
    'metrics_report.json',
]
for fname in output_files:
    fpath = os.path.join(EVAL_DIR, fname)
    if os.path.exists(fpath):
        size_kb = os.path.getsize(fpath) / 1024
        status = f'{size_kb:.0f} KB'
    else:
        status = 'MISSING'
    print(f' [{status:>8s}] {fname}')
print(banner)
|