#!/usr/bin/env python3 """ RetinaSense v3.0 -- Phase 1A: Rich Evaluation Dashboard ======================================================== Standalone script that loads the trained ViT model, runs inference on the full test set (1,281 images), and produces publication-quality evaluation plots plus a structured metrics JSON report. Outputs (all written to outputs_v3/evaluation/): - confusion_matrix.png - roc_curves_per_class.png - precision_recall_curves.png - calibration_reliability.png - confidence_histograms.png - error_analysis_by_source.png - metrics_report.json Usage: python eval_dashboard.py """ import os import sys import json import warnings import numpy as np import pandas as pd import cv2 import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import matplotlib.ticker as mticker import seaborn as sns from PIL import Image from collections import OrderedDict warnings.filterwarnings('ignore') import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from torchvision import transforms import timm from sklearn.metrics import ( confusion_matrix, classification_report, roc_curve, auc, precision_recall_curve, average_precision_score, f1_score, accuracy_score, cohen_kappa_score, matthews_corrcoef, balanced_accuracy_score, log_loss, ) # ================================================================ # CONFIGURATION # ================================================================ BASE_DIR = '/teamspace/studios/this_studio' OUTPUT_DIR = os.path.join(BASE_DIR, 'outputs_v3') EVAL_DIR = os.path.join(OUTPUT_DIR, 'evaluation') os.makedirs(EVAL_DIR, exist_ok=True) MODEL_PATH = os.path.join(OUTPUT_DIR, 'best_model.pth') THRESHOLDS_PATH = os.path.join(OUTPUT_DIR, 'thresholds.json') TEMPERATURE_PATH = os.path.join(OUTPUT_DIR, 'temperature.json') TEST_CSV = os.path.join(BASE_DIR, 'data', 'test_split.csv') NORM_STATS_PATH = os.path.join(BASE_DIR, 'data', 'fundus_norm_stats.json') NUM_CLASSES = 5 
IMG_SIZE = 224
DROPOUT = 0.3
BATCH_SIZE = 32
CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Publication style defaults
plt.rcParams.update({
    'font.size': 11,
    'axes.titlesize': 13,
    'axes.labelsize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.15,
    'font.family': 'sans-serif',
})

print('=' * 65)
print(' RetinaSense v3.0 -- Phase 1A: Evaluation Dashboard')
print('=' * 65)
print(f' Device : {DEVICE}')
if torch.cuda.is_available():
    print(f' GPU : {torch.cuda.get_device_name(0)}')
print(f' Output : {EVAL_DIR}')
print('=' * 65)

# ================================================================
# LOAD NORMALISATION STATS
# ================================================================
# Prefer dataset-specific statistics; fall back to the ImageNet values when
# the stats file is not present.
if os.path.exists(NORM_STATS_PATH):
    with open(NORM_STATS_PATH) as f:
        norm_stats = json.load(f)
    NORM_MEAN = norm_stats['mean_rgb']
    NORM_STD = norm_stats['std_rgb']
    print(f' Fundus norm stats loaded: mean={[round(v, 4) for v in NORM_MEAN]}, '
          f'std={[round(v, 4) for v in NORM_STD]}')
else:
    NORM_MEAN = [0.485, 0.456, 0.406]
    NORM_STD = [0.229, 0.224, 0.225]
    print(' Using ImageNet normalisation fallback')


# ================================================================
# MODEL ARCHITECTURE (mirrors retinasense_v3.py / gradcam_v3.py)
# ================================================================
class MultiTaskViT(nn.Module):
    """ViT-Base-Patch16-224 with disease + severity heads.

    The layer layout (attribute names and ``nn.Sequential`` ordering) defines
    the state-dict keys, so it must stay in sync with the checkpoint that is
    loaded below.

    Args:
        n_disease: Number of outputs of the disease head.
        n_severity: Number of outputs of the severity head.
        drop: Dropout probability applied to the backbone features.
    """

    def __init__(self, n_disease=NUM_CLASSES, n_severity=5, drop=DROPOUT):
        super().__init__()
        self.backbone = timm.create_model(
            'vit_base_patch16_224', pretrained=False, num_classes=0
        )
        feat = 768  # CLS token dimension
        self.drop = nn.Dropout(drop)
        self.disease_head = nn.Sequential(
            nn.Linear(feat, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, n_disease),
        )
        self.severity_head = nn.Sequential(
            nn.Linear(feat, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, n_severity),
        )

    def forward(self, x):
        """Return ``(disease_logits, severity_logits)`` for a batch of images."""
        f = self.backbone(x)  # (B, 768) CLS token features
        f = self.drop(f)
        return self.disease_head(f), self.severity_head(f)


# ================================================================
# LOAD MODEL + CALIBRATION ARTIFACTS
# ================================================================
print('\nLoading model...')
model = MultiTaskViT().to(DEVICE)
# NOTE(security): weights_only=False unpickles arbitrary objects; only load
# checkpoints from a trusted source.
ckpt = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()
print(f' Loaded: {MODEL_PATH}')
# The stored epoch is shown 1-based; a checkpoint without the key prints '?'
# instead of failing on '?' + 1.
ckpt_epoch = ckpt.get('epoch')
epoch_label = ckpt_epoch + 1 if ckpt_epoch is not None else '?'
print(f' Checkpoint epoch: {epoch_label} '
      f'val_acc={ckpt.get("val_acc", 0):.2f}%')

with open(THRESHOLDS_PATH) as f:
    thr_data = json.load(f)
THRESHOLDS = thr_data['thresholds']

with open(TEMPERATURE_PATH) as f:
    temp_data = json.load(f)
TEMPERATURE = temp_data['temperature']

print(f' Temperature T = {TEMPERATURE:.4f}')
print(f' Thresholds = {[round(t, 3) for t in THRESHOLDS]}')


# ================================================================
# DATASET
# ================================================================
class TestDataset(Dataset):
    """
    Test dataset that loads from preprocessed .npy cache (fast path).
    Falls back to on-the-fly preprocessing if cache is missing.

    Each item is ``(image_tensor, disease_label, source)``.

    Args:
        df: DataFrame with at least ``image_path`` and ``disease_label``
            columns; ``cache_path`` and ``source`` are optional.
        transform: Callable applied to the image array.
    """

    def __init__(self, df, transform):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Try cache path first. A missing CSV cell is a float NaN, which is
        # truthy and would make os.path.exists raise, so require a string.
        cache_fp = row.get('cache_path', '')
        img = None
        if isinstance(cache_fp, str) and cache_fp and os.path.exists(cache_fp):
            try:
                img = np.load(cache_fp)
            except Exception:
                # Unreadable cache entry: fall through to preprocessing.
                img = None

        # Fallback: on-the-fly preprocessing
        if img is None:
            image_path = row['image_path']
            if not os.path.isabs(image_path):
                clean = image_path
                # Strip leading './' segments and any slashes left behind, so
                # that './/x' cannot become '/x' (which os.path.join would
                # treat as absolute and thereby discard BASE_DIR).
                while clean.startswith('./'):
                    clean = clean[2:].lstrip('/')
                image_path = os.path.join(BASE_DIR, clean)

            source = row.get('source', 'ODIR')
            try:
                if source == 'APTOS':
                    img = self._ben_graham(image_path)
                else:
                    img = self._clahe_preprocess(image_path)
            except Exception as exc:
                # Best effort: keep the run going with a blank image, but
                # report it so corrupted metrics can be traced to the file.
                print(f'  [warn] preprocessing failed for {image_path}: '
                      f'{exc!r}; using blank image', file=sys.stderr)
                img = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)

        img_tensor = self.transform(img)
        disease_lbl = int(row['disease_label'])
        source = row.get('source', 'unknown')
        return img_tensor, disease_lbl, source

    @staticmethod
    def _ben_graham(path, sz=IMG_SIZE, sigma=10):
        """Load an image as RGB, resize, subtract a Gaussian blur and mask
        everything outside a centred circle of radius ``0.48 * sz``."""
        raw = cv2.imread(path)
        if raw is None:
            # OpenCV could not decode the file; retry through PIL.
            raw = np.array(Image.open(path).convert('RGB'))
            raw = cv2.cvtColor(raw, cv2.COLOR_RGB2BGR)
        raw = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
        raw = cv2.resize(raw, (sz, sz))
        raw = cv2.addWeighted(raw, 4, cv2.GaussianBlur(raw, (0, 0), sigma), -4, 128)
        mask = np.zeros(raw.shape[:2], dtype=np.uint8)
        cv2.circle(mask, (sz // 2, sz // 2), int(sz * 0.48), 255, -1)
        return cv2.bitwise_and(raw, raw, mask=mask)

    @staticmethod
    def _clahe_preprocess(path, sz=IMG_SIZE):
        """Load an image, resize, apply CLAHE to the L channel in LAB space
        and return the result as RGB."""
        raw = cv2.imread(path)
        if raw is None:
            # OpenCV could not decode the file; retry through PIL.
            raw = np.array(Image.open(path).convert('RGB'))
            raw = cv2.cvtColor(raw, cv2.COLOR_RGB2BGR)
        raw = cv2.resize(raw, (sz, sz))
        lab = cv2.cvtColor(raw, cv2.COLOR_BGR2LAB)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        lab[:, :, 0] = clahe.apply(lab[:, :, 0])
        raw = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
        return cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)


val_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

print('\nLoading test set...')
test_df = pd.read_csv(TEST_CSV)
print(f' Test samples: {len(test_df)}')
print(f' Sources : {sorted(test_df["source"].unique())}')
print(f' Class dist : {test_df["disease_label"].value_counts().sort_index().to_dict()}')

test_ds = TestDataset(test_df, val_transform)
test_loader = DataLoader(
    test_ds, batch_size=BATCH_SIZE, shuffle=False,
    # Pinned host memory only helps host-to-GPU copies.
    num_workers=4, pin_memory=(DEVICE.type == 'cuda'),
)

# ================================================================
# INFERENCE
# ================================================================
print('\nRunning inference on full test set...')
all_logits = []
all_labels = []
all_sources = []

with torch.no_grad():
    for imgs, labels, sources in test_loader:
        imgs = imgs.to(DEVICE)
        disease_logits, _ = model(imgs)  # severity output is not evaluated
        all_logits.append(disease_logits.cpu())
        all_labels.extend(labels.numpy().tolist())
        all_sources.extend(sources)

all_logits = torch.cat(all_logits, dim=0)  # (N, 5)
all_labels = np.array(all_labels)
all_sources = np.array(all_sources)
N = len(all_labels)
print(f' Inference complete: {N} samples')

# Temperature-scaled probabilities
probs_calibrated = F.softmax(all_logits / TEMPERATURE, dim=1).numpy()  # (N, 5)
probs_uncalibrated = F.softmax(all_logits, dim=1).numpy()

# Predictions: argmax of calibrated probabilities
preds = np.argmax(probs_calibrated, axis=1)
confidences = np.max(probs_calibrated, axis=1)
correct_mask = (preds == all_labels)

acc = accuracy_score(all_labels, preds)
# Count correct predictions directly; int(acc * N) can truncate to one less
# than the true count because of floating-point rounding.
print(f' Overall accuracy: {acc:.4f} ({int(correct_mask.sum())}/{N})')
# ================================================================
# 1. CONFUSION MATRIX
# ================================================================
# Row-normalised confusion matrix (each true-class row sums to 1) with the
# raw counts overlaid in a smaller font.
print('\n[1/7] Confusion matrix...')
cm = confusion_matrix(all_labels, preds, labels=list(range(NUM_CLASSES)))
# A class with no test samples has a zero row total; leave that row at 0.0
# instead of dividing by zero and propagating NaN into the plot and report.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm.astype(float), row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals > 0,
)

fig, ax = plt.subplots(figsize=(7, 6))
sns.heatmap(
    cm_norm, annot=True, fmt='.2f', cmap='Blues',
    xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
    linewidths=0.5, linecolor='white',
    cbar_kws={'label': 'Proportion', 'shrink': 0.8},
    ax=ax, vmin=0, vmax=1,
)

# Overlay raw counts in smaller font, slightly below each cell's centre so
# they do not collide with the proportion annotation.
for i in range(NUM_CLASSES):
    for j in range(NUM_CLASSES):
        ax.text(j + 0.5, i + 0.72, f'(n={cm[i, j]})',
                ha='center', va='center', fontsize=7, color='gray')

ax.set_xlabel('Predicted Class')
ax.set_ylabel('True Class')
ax.set_title('Normalized Confusion Matrix (Test Set)')
fig.tight_layout()
fig.savefig(os.path.join(EVAL_DIR, 'confusion_matrix.png'))
plt.close(fig)
print(' Saved confusion_matrix.png')
# ================================================================
# 2. ROC CURVES PER CLASS
# ================================================================
# One-vs-rest ROC curve for every class on the calibrated probabilities,
# plus a macro-average curve built by averaging the per-class TPRs on a
# shared FPR grid.
print('[2/7] ROC curves...')
fig, ax = plt.subplots(figsize=(7, 6))
colors = sns.color_palette('tab10', NUM_CLASSES)

all_fpr_tpr = {}
macro_auc_list = []
for cls_idx in range(NUM_CLASSES):
    is_positive = (all_labels == cls_idx).astype(int)
    fpr, tpr, _ = roc_curve(is_positive, probs_calibrated[:, cls_idx])
    roc_auc = auc(fpr, tpr)
    all_fpr_tpr[cls_idx] = (fpr, tpr)
    macro_auc_list.append(roc_auc)
    ax.plot(fpr, tpr, color=colors[cls_idx], lw=2,
            label=f'{CLASS_NAMES[cls_idx]} (AUC={roc_auc:.3f})')

# Macro average ROC: interpolate each curve onto the common grid, sum in
# class order, then divide by the number of classes.
mean_fpr = np.linspace(0, 1, 200)
mean_tpr = np.zeros_like(mean_fpr)
for cls_fpr, cls_tpr in all_fpr_tpr.values():
    mean_tpr = mean_tpr + np.interp(mean_fpr, cls_fpr, cls_tpr)
mean_tpr = mean_tpr / NUM_CLASSES
macro_auc = auc(mean_fpr, mean_tpr)
ax.plot(mean_fpr, mean_tpr, 'k--', lw=2.5,
        label=f'Macro-average (AUC={macro_auc:.3f})')

# Chance diagonal for reference.
ax.plot([0, 1], [0, 1], 'k:', lw=1, alpha=0.4)
ax.set_xlim([-0.02, 1.02])
ax.set_ylim([-0.02, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('One-vs-Rest ROC Curves (Calibrated)')
ax.legend(loc='lower right', framealpha=0.9)
ax.grid(True, alpha=0.3)

roc_png = os.path.join(EVAL_DIR, 'roc_curves_per_class.png')
fig.tight_layout()
fig.savefig(roc_png)
plt.close(fig)
print(' Saved roc_curves_per_class.png')
# ================================================================
# 3. PRECISION-RECALL CURVES
# ================================================================
# One-vs-rest precision-recall curve per class on the calibrated
# probabilities, with each class's prevalence drawn as a dotted baseline.
print('[3/7] Precision-recall curves...')
fig, ax = plt.subplots(figsize=(7, 6))

for cls_idx in range(NUM_CLASSES):
    is_positive = (all_labels == cls_idx).astype(int)
    cls_scores = probs_calibrated[:, cls_idx]
    prec, rec, _ = precision_recall_curve(is_positive, cls_scores)
    ap = average_precision_score(is_positive, cls_scores)
    ax.plot(rec, prec, color=colors[cls_idx], lw=2,
            label=f'{CLASS_NAMES[cls_idx]} (AP={ap:.3f})')

# Add prevalence baselines (drawn after the curves, unlabeled).
prevalences = np.bincount(all_labels, minlength=NUM_CLASSES) / N
for cls_idx in range(NUM_CLASSES):
    ax.axhline(y=prevalences[cls_idx], color=colors[cls_idx], ls=':', alpha=0.3)

ax.set_xlim([-0.02, 1.02])
ax.set_ylim([-0.02, 1.05])
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title('Precision-Recall Curves (Calibrated)')
ax.legend(loc='upper right', framealpha=0.9)
ax.grid(True, alpha=0.3)

pr_png = os.path.join(EVAL_DIR, 'precision_recall_curves.png')
fig.tight_layout()
fig.savefig(pr_png)
plt.close(fig)
print(' Saved precision_recall_curves.png')
# ================================================================
# 4. CALIBRATION RELIABILITY DIAGRAM
# ================================================================
# Side-by-side reliability diagrams for the temperature-scaled and the raw
# softmax confidences, each annotated with its expected calibration error.
print('[4/7] Calibration reliability diagram...')

n_bins = 10
bin_edges = np.linspace(0, 1, n_bins + 1)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2


# Compute calibration for both calibrated and uncalibrated probabilities
def compute_calibration(confidences_arr, correct_arr, bin_edges):
    """Return per-bin accuracy, mean confidence and sample count.

    Bins are half-open on the left: a sample belongs to a bin when
    ``lo < confidence <= hi``. Empty bins report NaN for accuracy and
    confidence and 0 for the count.
    """
    accs, confs, counts = [], [], []
    for lower, upper in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (confidences_arr > lower) & (confidences_arr <= upper)
        n_in_bin = int(in_bin.sum())
        if n_in_bin:
            accs.append(correct_arr[in_bin].mean())
            confs.append(confidences_arr[in_bin].mean())
        else:
            accs.append(np.nan)
            confs.append(np.nan)
        counts.append(n_in_bin)
    return np.array(accs), np.array(confs), np.array(counts)


conf_calib = probs_calibrated.max(axis=1)
conf_uncalib = probs_uncalibrated.max(axis=1)
correct_as_float = correct_mask.astype(float)

bin_accs_cal, bin_confs_cal, bin_counts_cal = compute_calibration(
    conf_calib, correct_as_float, bin_edges)
bin_accs_uncal, bin_confs_uncal, bin_counts_uncal = compute_calibration(
    conf_uncalib, correct_as_float, bin_edges)

# ECE: count-weighted mean |accuracy - confidence| over the bins; nansum
# drops the empty (NaN) bins.
ece_cal = np.nansum(
    np.abs(bin_accs_cal - bin_confs_cal) * bin_counts_cal) / N
ece_uncal = np.nansum(
    np.abs(bin_accs_uncal - bin_confs_uncal) * bin_counts_uncal) / N

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

reliability_panels = [
    (bin_accs_cal, bin_confs_cal, bin_counts_cal, ece_cal,
     'Calibrated', '#4C72B0'),
    (bin_accs_uncal, bin_confs_uncal, bin_counts_uncal, ece_uncal,
     'Uncalibrated', '#DD8452'),
]
for ax, panel in zip(axes, reliability_panels):
    b_accs, b_confs, b_counts, ece_val, title_suffix, bar_color = panel

    # Perfect calibration line
    ax.plot([0, 1], [0, 1], 'k--', lw=1.5, alpha=0.5, label='Perfectly calibrated')

    # Bar chart of bin accuracy (non-empty bins only)
    valid = ~np.isnan(b_accs)
    ax.bar(bin_centers[valid], b_accs[valid], width=0.08, alpha=0.7,
           color=bar_color, edgecolor='black', linewidth=0.5,
           label=f'Model (ECE={ece_val:.4f})')

    populated_bins = [j for j in range(n_bins) if valid[j]]

    # Gap shading between bin accuracy and bin mean confidence
    for j in populated_bins:
        gap_bottom = min(b_accs[j], b_confs[j])
        gap_top = max(b_accs[j], b_confs[j])
        ax.fill_between(
            [bin_centers[j] - 0.04, bin_centers[j] + 0.04],
            gap_bottom, gap_top, alpha=0.15, color='red')

    # Sample counts on top
    for j in populated_bins:
        if b_counts[j] > 0:
            ax.text(bin_centers[j], b_accs[j] + 0.03, str(b_counts[j]),
                    ha='center', va='bottom', fontsize=7)

    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1.1])
    ax.set_xlabel('Mean Predicted Confidence')
    ax.set_ylabel('Fraction of Correct Predictions')
    ax.set_title(f'Reliability Diagram ({title_suffix})')
    ax.legend(loc='upper left', framealpha=0.9)
    ax.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig(os.path.join(EVAL_DIR, 'calibration_reliability.png'))
plt.close(fig)
print(f' Saved calibration_reliability.png (ECE_cal={ece_cal:.4f}, ECE_uncal={ece_uncal:.4f})')
# ================================================================
# 5. CONFIDENCE HISTOGRAMS
# ================================================================
# Left: confidence of correct vs incorrect predictions, with dashed medians.
# Right: confidence distribution grouped by the true class.
print('[5/7] Confidence histograms...')
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
outcome_ax, class_ax = axes[0], axes[1]

# Correct vs Incorrect
outcome_groups = [
    (correct_mask, 'Correct', '#2ca02c'),
    (~correct_mask, 'Incorrect', '#d62728'),
]
for group_mask, label, color in outcome_groups:
    outcome_ax.hist(confidences[group_mask], bins=30, alpha=0.65, color=color,
                    label=f'{label} (n={group_mask.sum()})', edgecolor='black',
                    linewidth=0.3)
outcome_ax.set_xlabel('Prediction Confidence')
outcome_ax.set_ylabel('Count')
outcome_ax.set_title('Confidence Distribution: Correct vs Incorrect')
outcome_ax.legend(loc='upper left', framealpha=0.9)
# Median markers are added after the legend and labelled '_nolegend_' so
# they never appear in it.
median_correct = np.median(confidences[correct_mask])
median_incorrect = np.median(confidences[~correct_mask])
outcome_ax.axvline(x=median_correct, color='#2ca02c',
                   ls='--', alpha=0.6, label='_nolegend_')
outcome_ax.axvline(x=median_incorrect, color='#d62728',
                   ls='--', alpha=0.6, label='_nolegend_')
outcome_ax.grid(True, alpha=0.3, axis='y')

# Per-class confidence
for cls_idx in range(NUM_CLASSES):
    cls_mask = (all_labels == cls_idx)
    class_ax.hist(confidences[cls_mask], bins=20, alpha=0.5, color=colors[cls_idx],
                  label=f'{CLASS_NAMES[cls_idx]} (n={cls_mask.sum()})',
                  edgecolor='black', linewidth=0.3)
class_ax.set_xlabel('Prediction Confidence')
class_ax.set_ylabel('Count')
class_ax.set_title('Confidence Distribution by True Class')
class_ax.legend(loc='upper left', framealpha=0.9, fontsize=9)
class_ax.grid(True, alpha=0.3, axis='y')

hist_png = os.path.join(EVAL_DIR, 'confidence_histograms.png')
fig.tight_layout()
fig.savefig(hist_png)
plt.close(fig)
print(' Saved confidence_histograms.png')
# ================================================================
# 6. ERROR ANALYSIS BY SOURCE
# ================================================================
# Left: per-class accuracy grouped by data source.
# Right: the ten most frequent true>predicted confusions, stacked by source.
print('[6/7] Error analysis by source...')

sources_unique = sorted(np.unique(all_sources))
n_sources = len(sources_unique)

# Build accuracy per (source, class) pair; pairs without samples get NaN.
source_class_acc = {}
source_class_n = {}
for src in sources_unique:
    for cls_idx in range(NUM_CLASSES):
        mask = (all_sources == src) & (all_labels == cls_idx)
        n_cls = mask.sum()
        if n_cls > 0:
            acc_sc = (preds[mask] == all_labels[mask]).mean()
        else:
            acc_sc = np.nan
        source_class_acc[(src, cls_idx)] = acc_sc
        source_class_n[(src, cls_idx)] = int(n_cls)

# Also overall accuracy per source
source_overall_acc = {}
for src in sources_unique:
    mask = (all_sources == src)
    source_overall_acc[src] = accuracy_score(all_labels[mask], preds[mask])

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: grouped bar chart of per-class accuracy by source
x = np.arange(NUM_CLASSES)
bar_width = 0.8 / max(n_sources, 1)
source_colors = sns.color_palette('Set2', n_sources)
# One fixed colour per source, shared by both panels.
source_color_map = dict(zip(sources_unique, source_colors))

for s_idx, src in enumerate(sources_unique):
    accs = [source_class_acc[(src, c)] for c in range(NUM_CLASSES)]
    counts = [source_class_n[(src, c)] for c in range(NUM_CLASSES)]
    # Centre the group of bars on each class tick.
    offset = (s_idx - n_sources / 2 + 0.5) * bar_width
    bars = axes[0].bar(x + offset, accs, bar_width * 0.9,
                       label=f'{src} (n={sum(counts)})',
                       color=source_color_map[src],
                       edgecolor='black', linewidth=0.5)
    # Annotate sample counts
    for bar, n_val, acc_val in zip(bars, counts, accs):
        if n_val > 0 and not np.isnan(acc_val):
            axes[0].text(bar.get_x() + bar.get_width() / 2,
                         bar.get_height() + 0.02,
                         str(n_val), ha='center', va='bottom', fontsize=7)

axes[0].set_xticks(x)
axes[0].set_xticklabels(CLASS_NAMES, rotation=15, ha='right')
axes[0].set_ylabel('Accuracy')
axes[0].set_title('Per-Class Accuracy by Data Source')
axes[0].set_ylim([0, 1.15])
axes[0].legend(loc='upper right', framealpha=0.9)
axes[0].grid(True, alpha=0.3, axis='y')
axes[0].axhline(y=acc, color='black', ls='--', alpha=0.4, lw=1)
axes[0].text(NUM_CLASSES - 0.5, acc + 0.02, f'Overall: {acc:.3f}',
             ha='right', fontsize=9, alpha=0.6)

# Right panel: confusion breakdown -- most common misclassifications per source
error_data = []
for src in sources_unique:
    src_mask = (all_sources == src) & (~correct_mask)
    if src_mask.sum() == 0:
        continue
    for true_cls in range(NUM_CLASSES):
        for pred_cls in range(NUM_CLASSES):
            if true_cls == pred_cls:
                continue
            pair_mask = src_mask & (all_labels == true_cls) & (preds == pred_cls)
            cnt = pair_mask.sum()
            if cnt > 0:
                error_data.append({
                    'Source': src,
                    'Error': f'{CLASS_NAMES[true_cls][:3]}>{CLASS_NAMES[pred_cls][:3]}',
                    'Count': int(cnt),
                })

if error_data:
    err_df = pd.DataFrame(error_data)
    # Top 10 error types
    top_errors = (err_df.groupby('Error')['Count'].sum()
                  .sort_values(ascending=False).head(10).index.tolist())
    err_df_top = err_df[err_df['Error'].isin(top_errors)]

    pivot = err_df_top.pivot_table(index='Error', columns='Source',
                                   values='Count', aggfunc='sum', fill_value=0)
    # Reorder by total count
    pivot = pivot.loc[pivot.sum(axis=1).sort_values(ascending=True).index]

    # The pivot only has columns for sources that actually appear among the
    # top errors, so colours are looked up by source name; slicing the
    # palette positionally would shift colours whenever a source is absent.
    pivot_colors = [source_color_map[col] for col in pivot.columns]
    pivot.plot(kind='barh', stacked=True, ax=axes[1],
               color=pivot_colors,
               edgecolor='black', linewidth=0.5)
    axes[1].set_xlabel('Error Count')
    axes[1].set_title('Top Misclassification Patterns by Source')
    axes[1].legend(loc='lower right', framealpha=0.9)
    axes[1].grid(True, alpha=0.3, axis='x')
else:
    axes[1].text(0.5, 0.5, 'No errors to display', ha='center', va='center',
                 transform=axes[1].transAxes, fontsize=14)
    axes[1].set_title('Top Misclassification Patterns by Source')

fig.tight_layout()
fig.savefig(os.path.join(EVAL_DIR, 'error_analysis_by_source.png'))
plt.close(fig)
print(' Saved error_analysis_by_source.png')
# ================================================================
# 7. METRICS REPORT (JSON)
# ================================================================
# Collect scalar, per-class and per-source metrics into one ordered mapping,
# write it to metrics_report.json and print a console summary.
print('[7/7] Metrics report...')

# Explicit label list: without it scikit-learn infers the classes from the
# data, and target_names no longer lines up when a class is absent from both
# the labels and the predictions.
class_indices = list(range(NUM_CLASSES))


def _json_safe(value):
    """Return a copy of *value* with non-finite floats replaced by None.

    json.dump would otherwise emit NaN/Infinity tokens, which are not valid
    JSON and are rejected by strict parsers. Dicts and lists are walked
    recursively; every other value is returned unchanged.
    """
    if isinstance(value, dict):
        return {key: _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    if isinstance(value, float) and not np.isfinite(value):
        return None
    return value


# Classification report as dict
cls_report = classification_report(
    all_labels, preds, labels=class_indices, target_names=CLASS_NAMES,
    output_dict=True, zero_division=0)

# Per-class AUC and AP (one-vs-rest on the calibrated probabilities)
per_class_auc = {}
per_class_ap = {}
for i in range(NUM_CLASSES):
    y_bin = (all_labels == i).astype(int)
    y_score = probs_calibrated[:, i]
    fpr_i, tpr_i, _ = roc_curve(y_bin, y_score)
    per_class_auc[CLASS_NAMES[i]] = float(auc(fpr_i, tpr_i))
    per_class_ap[CLASS_NAMES[i]] = float(average_precision_score(y_bin, y_score))

# Log loss is best effort: record None and say why if it cannot be computed.
try:
    ll = float(log_loss(all_labels, probs_calibrated, labels=class_indices))
except ValueError as exc:
    print(f'  [warn] log_loss could not be computed: {exc}', file=sys.stderr)
    ll = None

# Build the full report
metrics_report = OrderedDict([
    ('n_test_samples', int(N)),
    ('overall_accuracy', float(acc)),
    ('balanced_accuracy', float(balanced_accuracy_score(all_labels, preds))),
    ('macro_f1', float(f1_score(all_labels, preds, average='macro', zero_division=0))),
    ('weighted_f1', float(f1_score(all_labels, preds, average='weighted', zero_division=0))),
    ('cohen_kappa', float(cohen_kappa_score(all_labels, preds))),
    ('matthews_corrcoef', float(matthews_corrcoef(all_labels, preds))),
    ('log_loss', ll),
    ('macro_auc', float(np.mean(list(per_class_auc.values())))),
    ('ece_calibrated', float(ece_cal)),
    ('ece_uncalibrated', float(ece_uncal)),
    ('temperature', float(TEMPERATURE)),
    ('thresholds', THRESHOLDS),
    ('per_class_metrics', {}),
    ('per_class_auc', per_class_auc),
    ('per_class_ap', per_class_ap),
    ('confusion_matrix_raw', cm.tolist()),
    ('confusion_matrix_normalized', np.round(cm_norm, 4).tolist()),
    ('source_accuracy', {src: float(v) for src, v in source_overall_acc.items()}),
    ('source_class_counts', {
        src: {CLASS_NAMES[c]: source_class_n[(src, c)] for c in range(NUM_CLASSES)}
        for src in sources_unique
    }),
    ('class_names', CLASS_NAMES),
])

# Per-class from classification_report
for name in CLASS_NAMES:
    metrics_report['per_class_metrics'][name] = {
        'precision': float(cls_report[name]['precision']),
        'recall': float(cls_report[name]['recall']),
        'f1-score': float(cls_report[name]['f1-score']),
        'support': int(cls_report[name]['support']),
        'auc': per_class_auc[name],
        'average_precision': per_class_ap[name],
    }

report_path = os.path.join(EVAL_DIR, 'metrics_report.json')
with open(report_path, 'w') as f:
    # Only the serialised copy is sanitised; metrics_report keeps its floats
    # for the summary printed below.
    json.dump(_json_safe(metrics_report), f, indent=2)
print(' Saved metrics_report.json')

# ================================================================
# SUMMARY
# ================================================================
print('\n' + '=' * 65)
print(' EVALUATION DASHBOARD COMPLETE')
print('=' * 65)
print(f' Overall Accuracy : {acc:.4f}')
print(f' Balanced Accuracy : {metrics_report["balanced_accuracy"]:.4f}')
print(f' Macro F1 : {metrics_report["macro_f1"]:.4f}')
print(f' Cohen Kappa : {metrics_report["cohen_kappa"]:.4f}')
print(f' Macro AUC : {metrics_report["macro_auc"]:.4f}')
print(f' ECE (calibrated) : {ece_cal:.4f}')
print(f' ECE (uncalibrated) : {ece_uncal:.4f}')
print('\n Per-class AUC:')
for name, val in per_class_auc.items():
    print(f' {name:15s} : {val:.4f}')
print('\n Source accuracy:')
for src, val in source_overall_acc.items():
    print(f' {src:10s} : {val:.4f}')
print(f'\n All outputs in: {EVAL_DIR}/')

# List every expected artifact with its size, or MISSING if it was not
# written.
output_files = [
    'confusion_matrix.png',
    'roc_curves_per_class.png',
    'precision_recall_curves.png',
    'calibration_reliability.png',
    'confidence_histograms.png',
    'error_analysis_by_source.png',
    'metrics_report.json',
]
for fname in output_files:
    fpath = os.path.join(EVAL_DIR, fname)
    exists = os.path.exists(fpath)
    size_kb = os.path.getsize(fpath) / 1024 if exists else 0
    status = f'{size_kb:.0f} KB' if exists else 'MISSING'
    print(f' [{status:>8s}] {fname}')
print('=' * 65)