| |
| """ |
| RetinaSense ViT v2 - Comprehensive Error Analysis & Baseline Report |
| =================================================================== |
| Runs full evaluation on the validation split, computes ECE, |
| confusion analysis, confidence distributions, and source-level |
| performance. Saves all plots and metrics to outputs_analysis/v2_baseline/. |
| """ |
|
|
| import os, sys, json, warnings |
| import numpy as np |
| import pandas as pd |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| import matplotlib.gridspec as gridspec |
| import seaborn as sns |
| from pathlib import Path |
| from tqdm import tqdm |
| warnings.filterwarnings('ignore') |
|
|
| import cv2 |
| from PIL import Image |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision import transforms |
| from torch.utils.data import Dataset, DataLoader |
|
|
| import timm |
| from sklearn.model_selection import train_test_split |
| from sklearn.metrics import ( |
| confusion_matrix, classification_report, |
| f1_score, precision_score, recall_score |
| ) |
|
|
| |
| |
| |
# --- Filesystem layout (everything is rooted at BASE_DIR) ---
BASE_DIR = '/teamspace/studios/this_studio'
_VIT_OUTPUTS = f'{BASE_DIR}/outputs_vit'  # holds the checkpoint and threshold JSON
MODEL_PATH = f'{_VIT_OUTPUTS}/best_model.pth'
META_CSV = f'{BASE_DIR}/final_unified_metadata.csv'
THRESH_JSON = f'{_VIT_OUTPUTS}/threshold_optimization_results.json'
CACHE_DIR = f'{BASE_DIR}/preprocessed_cache_vit'
OUT_DIR = f'{BASE_DIR}/outputs_analysis/v2_baseline'

# --- Evaluation settings ---
IMG_SIZE = 224
BATCH_SIZE = 64
NUM_WORKERS = 8
NUM_CLASSES = 5
CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']

# Both directories are written to later in the script, so create them up front.
for _required_dir in (OUT_DIR, CACHE_DIR):
    os.makedirs(_required_dir, exist_ok=True)

# Use the GPU when one is visible, otherwise fall back to CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_available else 'cpu')
print(f'Device: {device}')
if _cuda_available:
    print(f'GPU: {torch.cuda.get_device_name(0)}')
|
|
| |
| |
| |
class MultiTaskViT(nn.Module):
    """ViT backbone feeding two classification heads (disease and severity).

    The backbone is timm's ``vit_base_patch16_224`` built without pretrained
    weights and with ``num_classes=0``. Both heads read the same
    dropout-regularised backbone output.

    Args:
        n_disease: Number of outputs of the disease head.
        n_severity: Number of outputs of the severity head.
        drop: Dropout probability applied to the backbone output.
    """

    def __init__(self, n_disease=5, n_severity=5, drop=0.4):
        super().__init__()
        embed_dim = 768  # hard-coded input width of both heads

        self.backbone = timm.create_model(
            'vit_base_patch16_224', pretrained=False, num_classes=0)
        self.drop = nn.Dropout(drop)

        # NOTE: attribute names and layer order define the state-dict keys the
        # checkpoint is loaded against, so they must not change.
        disease_layers = [
            nn.Linear(embed_dim, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, n_disease),
        ]
        self.disease_head = nn.Sequential(*disease_layers)

        severity_layers = [
            nn.Linear(embed_dim, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, n_severity),
        ]
        self.severity_head = nn.Sequential(*severity_layers)

    def forward(self, x):
        """Return ``(disease_logits, severity_logits)`` for the batch ``x``."""
        features = self.drop(self.backbone(x))
        disease_logits = self.disease_head(features)
        severity_logits = self.severity_head(features)
        return disease_logits, severity_logits
|
|
|
|
| |
| |
| |
def ben_graham(path, sz=IMG_SIZE, sigma=10):
    """Load an image and apply blur-subtraction contrast enhancement.

    Steps: read the file as RGB, resize to ``sz`` x ``sz``, compute
    ``4 * img - 4 * gaussian_blur(img, sigma) + 128``, then zero out
    everything outside a centred circle of radius ``0.48 * sz``.

    Args:
        path: Image location (anything ``str()`` turns into a file path).
        sz: Side length of the square output.
        sigma: Standard deviation of the Gaussian blur.

    Returns:
        The processed image array as returned by ``cv2.bitwise_and``.
    """
    src = str(path)
    bgr = cv2.imread(src)
    if bgr is None:
        # OpenCV could not decode the file; let PIL try. PIL already yields RGB.
        rgb = np.array(Image.open(src).convert('RGB'))
    else:
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    resized = cv2.resize(rgb, (sz, sz))
    blurred = cv2.GaussianBlur(resized, (0, 0), sigma)
    enhanced = cv2.addWeighted(resized, 4, blurred, -4, 128)

    # Filled circular mask centred on the image.
    circle_mask = np.zeros(enhanced.shape[:2], dtype=np.uint8)
    centre = (sz // 2, sz // 2)
    cv2.circle(circle_mask, centre, int(sz * 0.48), 255, -1)
    return cv2.bitwise_and(enhanced, enhanced, mask=circle_mask)
|
|
|
|
def resolve_image_path(raw_path):
    """Map a metadata CSV path entry to an existing image file on disk.

    CSV entries carry a leading ``.//`` prefix. Candidate locations are tried
    in this order and the first one that is a regular file wins:

    1. ``BASE_DIR/<cleaned path>``
    2. For paths mentioning "aptos":
       ``aptos/gaussian_filtered_images/gaussian_filtered_images/{Severity}/{stem}{ext}``
       and then ``aptos/train_images/{stem}{ext}``
    3. For paths mentioning "odir":
       ``odir/preprocessed_images/{filename}`` and
       ``ocular-disease-recognition-odir5k/preprocessed_images/{filename}``

    Args:
        raw_path: Path string taken from the metadata CSV.

    Returns:
        The first candidate that is an existing file, or ``None`` when the
        input is not a non-empty path string or nothing matches.
    """
    # Missing CSV cells arrive as NaN/None; they can never resolve to a file.
    if not isinstance(raw_path, str):
        return None

    clean = raw_path.lstrip('.').lstrip('/')
    while '//' in clean:  # collapse runs of slashes of any length
        clean = clean.replace('//', '/')
    if not clean:
        # An empty path would otherwise "resolve" to the BASE_DIR directory.
        return None

    source = Path(raw_path)
    stem, fname = source.stem, source.name
    lowered = raw_path.lower()
    extensions = ('.png', '.jpg', '.jpeg')

    candidates = [f'{BASE_DIR}/{clean}']

    if 'aptos' in lowered:
        aptos_base = f'{BASE_DIR}/aptos/gaussian_filtered_images/gaussian_filtered_images'
        for severity in ('No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferate_DR'):
            candidates.extend(f'{aptos_base}/{severity}/{stem}{ext}' for ext in extensions)
        candidates.extend(f'{BASE_DIR}/aptos/train_images/{stem}{ext}' for ext in extensions)

    if 'odir' in lowered:
        candidates.append(f'{BASE_DIR}/odir/preprocessed_images/{fname}')
        candidates.append(
            f'{BASE_DIR}/ocular-disease-recognition-odir5k/preprocessed_images/{fname}')

    for candidate in candidates:
        # isfile (not exists) so that a directory is never returned as an image.
        if os.path.isfile(candidate):
            return candidate
    return None
|
|
|
|
def load_or_cache(row):
    """Return the preprocessed image for one metadata row.

    Lookup order:
      1. ``CACHE_DIR/{stem}_{IMG_SIZE}.npy`` if it exists and loads cleanly.
      2. ``ben_graham(row['image_path_resolved'])``; the result is saved to
         the cache before being returned.
      3. An all-zero ``IMG_SIZE x IMG_SIZE x 3`` uint8 array as a last resort.

    Failures are reported on stderr instead of being swallowed, because a
    silent black-image fallback distorts every downstream metric.

    Args:
        row: Mapping-like record (supports ``[]`` and ``.get``) with an
            ``image_path_clean`` entry and, optionally, ``image_path_resolved``.

    Returns:
        The image as a numpy array (uint8 HxWx3 for the fallback image).
    """
    stem = Path(row['image_path_clean']).stem
    # Same file name as before for IMG_SIZE == 224; follows IMG_SIZE if changed.
    cache_fp = f'{CACHE_DIR}/{stem}_{IMG_SIZE}.npy'

    if os.path.exists(cache_fp):
        try:
            return np.load(cache_fp)
        except Exception as exc:
            # Unreadable cache entry: rebuild it from the source image below.
            # print is used because this script filters all warnings globally.
            print(f'[load_or_cache] unreadable cache file {cache_fp}: {exc!r}',
                  file=sys.stderr)

    img_path = row.get('image_path_resolved')
    # Unresolved entries are None (or NaN after pandas round-trips).
    if isinstance(img_path, str) and os.path.exists(img_path):
        try:
            arr = ben_graham(img_path)
            np.save(cache_fp, arr)
            return arr
        except Exception as exc:
            print(f'[load_or_cache] failed to preprocess {img_path}: {exc!r}',
                  file=sys.stderr)

    print(f'[load_or_cache] no usable image for "{stem}"; using a blank image',
          file=sys.stderr)
    return np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)
|
|
|
|
| |
| |
| |
# Per-channel normalisation constants (these match the widely used ImageNet
# statistics).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Deterministic validation pipeline: array -> PIL -> tensor -> normalised
# tensor. No augmentation is applied.
_val_steps = [
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
val_transform = transforms.Compose(_val_steps)
|
|
|
|
class RetDS(Dataset):
    """Dataset over a metadata frame, yielding tensors for evaluation.

    Each item is ``(image_tensor, disease_label, severity_label, index)``.
    ``index`` is the position in the re-indexed frame, so predictions can be
    joined back to their metadata rows.
    """

    def __init__(self, df):
        # Re-index so positional and label-based indexing agree.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        record = self.df.iloc[i]
        pixels = load_or_cache(record)
        image_tensor = val_transform(pixels)
        disease = torch.tensor(int(record['disease_label']), dtype=torch.long)
        severity = torch.tensor(int(record['severity_label']), dtype=torch.long)
        return image_tensor, disease, severity, i
|
|
|
|
| |
| |
| |
print('\n[1/6] Loading metadata and building val split...')

meta = pd.read_csv(META_CSV)
print(f' Raw rows: {len(meta)}')

# Strip the leading "./" / ".//" prefix and collapse doubled slashes.
_stripped_paths = meta['image_path'].str.lstrip('.').str.lstrip('/')
meta['image_path_clean'] = _stripped_paths.str.replace('//', '/', regex=False)
meta['image_path_resolved'] = meta['image_path_clean'].apply(resolve_image_path)

n_resolved = meta['image_path_resolved'].notna().sum()
print(f' Images resolved on disk: {n_resolved} / {len(meta)}')

# 80/20 split stratified on the disease label, with a fixed seed so the split
# is reproducible. Only the validation part is evaluated below.
train_df, val_df = train_test_split(
    meta,
    test_size=0.2,
    stratify=meta['disease_label'],
    random_state=42,
)
val_df = val_df.reset_index(drop=True)
print(f' Val split: {len(val_df)} samples')
print(f' Val class distribution:')
_val_class_counts = val_df['disease_label'].value_counts().sort_index()
for lbl, cnt in _val_class_counts.items():
    print(f' {CLASS_NAMES[int(lbl)]:<15s}: {cnt:4d}')
|
|
| |
| |
| |
print('\n[2/6] Loading model...')

model = MultiTaskViT().to(device)
# NOTE(review): weights_only=False unpickles arbitrary objects from the
# checkpoint file. Only load checkpoints from a trusted source.
ckpt = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

_ckpt_epoch = ckpt.get("epoch", "?")
_ckpt_macro_f1 = ckpt.get("macro_f1", 0)
print(f' Loaded checkpoint: epoch={_ckpt_epoch}, macro_f1={_ckpt_macro_f1:.4f}')

# Per-class decision thresholds. JSON object keys are strings, so convert
# them back to integer class indices.
with open(THRESH_JSON) as thresh_file:
    thresh_data = json.load(thresh_file)
thresholds = {}
for _class_key, _value in thresh_data['optimal_thresholds'].items():
    thresholds[int(_class_key)] = float(_value)
print(f' Optimal thresholds: {thresholds}')
|
|
| |
| |
| |
print('\n[3/6] Running inference on val set...')

# Mixed precision and pinned host memory only make sense on a CUDA device.
_use_cuda = device.type == 'cuda'

val_ds = RetDS(val_df)
val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=_use_cuda,
)

all_probs = []
all_preds = []
all_labels = []
all_idxs = []

with torch.no_grad():
    for imgs, d_lbl, _s_lbl, idx in tqdm(val_loader, desc='Inference'):
        imgs = imgs.to(device, non_blocking=True)
        # Autocast follows the actual device and is a no-op on CPU.
        with torch.amp.autocast(device_type=device.type, enabled=_use_cuda):
            d_out, _ = model(imgs)
        # Softmax in float32 so probabilities are not computed in half precision.
        probs = torch.softmax(d_out.float(), dim=1).cpu().numpy()
        preds = d_out.argmax(1).cpu().numpy()
        all_probs.append(probs)
        all_preds.append(preds)
        all_labels.append(d_lbl.numpy())
        all_idxs.append(idx.numpy())

all_probs = np.vstack(all_probs)
all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)
all_idxs = np.concatenate(all_idxs)

# Threshold-adjusted prediction: divide each class probability by its
# threshold and take the argmax. Classes without a threshold keep a divisor
# of 1. Done for all samples at once instead of a per-sample Python loop.
_threshold_scale = np.ones(all_probs.shape[1], dtype=np.float64)
for c, t in thresholds.items():
    _threshold_scale[c] = t
thresh_preds = (
    (all_probs.astype(np.float64) / _threshold_scale)
    .argmax(axis=1)
    .astype(all_preds.dtype)
)

raw_acc = (all_preds == all_labels).mean() * 100
thresh_acc = (thresh_preds == all_labels).mean() * 100
print(f' Raw accuracy : {raw_acc:.2f}%')
print(f' Threshold accuracy: {thresh_acc:.2f}%')

# Everything downstream reports on the threshold-adjusted predictions.
preds = thresh_preds
|
|
| |
| |
| |
print('\n[4/6] Computing ECE and reliability diagram...')


def compute_ece(probs, labels, n_bins=10):
    """Expected Calibration Error with equal-width confidence bins.

    Confidence is the maximum probability of each row and the prediction is
    its argmax. Bins are half-open ``[lo, hi)`` except the last one, which is
    closed ``[lo, 1.0]`` so that samples with confidence exactly 1.0 are
    counted instead of being dropped.

    Args:
        probs: 2-D array of per-class probabilities, one row per sample.
        labels: 1-D array of true class indices, same length as ``probs``.
        n_bins: Number of equal-width bins over [0, 1].

    Returns:
        Tuple ``(ece, bin_acc, bin_conf, bin_count, bins)``. ``ece`` is the
        count-weighted mean of ``|accuracy - confidence|`` over bins. The
        three lists hold one entry per bin; empty bins report accuracy 0.0
        and the bin midpoint as confidence. ``bins`` holds the bin edges.
    """
    confidences = probs.max(axis=1)
    predicted = probs.argmax(axis=1)
    correct = (predicted == labels).astype(float)

    bins = np.linspace(0, 1, n_bins + 1)
    n_total = len(labels)
    last_bin = n_bins - 1

    ece = 0.0
    bin_acc = []
    bin_conf = []
    bin_count = []

    for b, (lo, hi) in enumerate(zip(bins[:-1], bins[1:])):
        below_upper = (confidences <= hi) if b == last_bin else (confidences < hi)
        mask = (confidences >= lo) & below_upper
        n = int(mask.sum())
        if n == 0:
            bin_acc.append(0.0)
            bin_conf.append((lo + hi) / 2)
            bin_count.append(0)
            continue
        acc = correct[mask].mean()
        conf = confidences[mask].mean()
        ece += (n / n_total) * abs(acc - conf)
        bin_acc.append(acc)
        bin_conf.append(conf)
        bin_count.append(n)

    return ece, bin_acc, bin_conf, bin_count, bins
|
|
ece, bin_acc, bin_conf, bin_count, bins = compute_ece(all_probs, all_labels)
print(f' ECE (10 bins): {ece:.4f}')

# ECE restricted to the samples of each true class.
per_class_ece = {}
for c in range(NUM_CLASSES):
    mask = (all_labels == c)
    if mask.sum() == 0:
        per_class_ece[CLASS_NAMES[c]] = 0.0
        continue
    ece_c, _, _, _, _ = compute_ece(all_probs[mask], all_labels[mask])
    per_class_ece[CLASS_NAMES[c]] = float(ece_c)
    print(f' ECE {CLASS_NAMES[c]:<15s}: {ece_c:.4f}')

# --- Reliability diagram (left) and calibration gap (right) ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

bin_centers = (bins[:-1] + bins[1:]) / 2
bar_width = (bins[1] - bins[0]) * 0.9
bars = axes[0].bar(
    bin_centers, bin_acc,
    width=bar_width,
    alpha=0.7, color='steelblue', label='Accuracy per bin'
)
axes[0].plot([0, 1], [0, 1], 'r--', lw=2, label='Perfect calibration')
axes[0].set_xlabel('Confidence', fontsize=12)
axes[0].set_ylabel('Accuracy', fontsize=12)
axes[0].set_title(f'Reliability Diagram\nECE = {ece:.4f}', fontsize=13, fontweight='bold')
axes[0].legend(fontsize=10)
axes[0].grid(alpha=0.3)
axes[0].set_xlim(0, 1)
axes[0].set_ylim(0, 1)

# Annotate each non-empty bar with its sample count.
for bar, cnt in zip(bars, bin_count):
    if cnt > 0:
        axes[0].text(
            bar.get_x() + bar.get_width() / 2, min(bar.get_height() + 0.02, 0.97),
            str(cnt), ha='center', va='bottom', fontsize=7, color='black'
        )

# Empty bins carry placeholder values (accuracy 0, confidence = midpoint).
# Their gap is forced to 0 so they are not drawn as overconfident bars.
gap = np.where(
    np.array(bin_count) > 0,
    np.array(bin_conf, dtype=float) - np.array(bin_acc, dtype=float),
    0.0,
)
color_gap = ['#e74c3c' if g > 0 else '#2ecc71' for g in gap]
axes[1].bar(bin_centers, gap, width=bar_width, color=color_gap, alpha=0.8)
axes[1].axhline(0, color='black', lw=1)
axes[1].set_xlabel('Confidence', fontsize=12)
axes[1].set_ylabel('Confidence - Accuracy (Gap)', fontsize=12)
axes[1].set_title('Calibration Gap\n(Red=overconfident, Green=underconfident)',
                  fontsize=13, fontweight='bold')
axes[1].grid(alpha=0.3)
axes[1].set_xlim(0, 1)

plt.tight_layout()
plt.savefig(f'{OUT_DIR}/reliability_diagram.png', dpi=150, bbox_inches='tight')
plt.close()
print(f' Saved reliability_diagram.png')
|
|
| |
| |
| |
print('\n[5/6] Generating confusion matrices...')


def _save_confusion_heatmap(matrix, fmt, title, filename, **heatmap_kwargs):
    """Draw one confusion-matrix heatmap and save it under OUT_DIR."""
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        matrix, annot=True, fmt=fmt, cmap='Blues', ax=ax,
        xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
        linewidths=0.5, linecolor='gray', **heatmap_kwargs
    )
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_ylabel('True Label', fontsize=12)
    ax.set_xlabel('Predicted Label', fontsize=12)
    plt.xticks(rotation=30, ha='right')
    plt.tight_layout()
    plt.savefig(f'{OUT_DIR}/{filename}', dpi=150, bbox_inches='tight')
    plt.close()


# Pin the label set so the matrix is always NUM_CLASSES x NUM_CLASSES and its
# rows/columns line up with CLASS_NAMES, even if a class never occurs.
cm_raw = confusion_matrix(all_labels, preds, labels=list(range(NUM_CLASSES)))

# Row-normalise by true-class support; rows with zero support stay all-zero
# instead of becoming NaN.
_row_totals = cm_raw.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm_raw.astype(float), _row_totals,
    out=np.zeros(cm_raw.shape, dtype=float),
    where=_row_totals > 0,
)

_save_confusion_heatmap(
    cm_raw, 'd', 'Confusion Matrix (Raw Counts)', 'confusion_matrix_raw.png')
_save_confusion_heatmap(
    cm_norm, '.3f', 'Confusion Matrix (Normalized by True Class)',
    'confusion_matrix_normalized.png', vmin=0, vmax=1)
print(' Saved confusion_matrix_raw.png and confusion_matrix_normalized.png')

# Every off-diagonal (true, predicted) cell, most frequent first. The sort is
# stable, so ties keep row-major order.
confused_pairs = [
    {
        'true_class': CLASS_NAMES[true_c],
        'pred_class': CLASS_NAMES[pred_c],
        'count': int(cm_raw[true_c, pred_c]),
        'rate': float(cm_norm[true_c, pred_c]),
        'description': f'{CLASS_NAMES[true_c]} misclassified AS {CLASS_NAMES[pred_c]}',
    }
    for true_c in range(NUM_CLASSES)
    for pred_c in range(NUM_CLASSES)
    if true_c != pred_c
]
confused_pairs.sort(key=lambda pair: pair['count'], reverse=True)
top5_pairs = confused_pairs[:5]

print('\n Top 5 confused class pairs (by raw count):')
for p in top5_pairs:
    print(f' {p["description"]}: {p["count"]} ({p["rate"]*100:.1f}%)')
|
|
| |
| |
| |
print('\n[6/6] Computing per-class metrics...')

# The same report is produced twice: as a dict for later use and as text for
# the console.
_report_options = dict(target_names=CLASS_NAMES, zero_division=0)
report_dict = classification_report(
    all_labels, preds, output_dict=True, **_report_options
)
print(classification_report(all_labels, preds, digits=4, **_report_options))

# Per-class lookups keyed by class name.
per_class_precision = {name: report_dict[name]['precision'] for name in CLASS_NAMES}
per_class_recall = {name: report_dict[name]['recall'] for name in CLASS_NAMES}
per_class_f1 = {name: report_dict[name]['f1-score'] for name in CLASS_NAMES}
per_class_support = {name: int(report_dict[name]['support']) for name in CLASS_NAMES}

# Aggregate scores.
overall_accuracy = report_dict['accuracy'] * 100
macro_f1 = report_dict['macro avg']['f1-score']
weighted_f1 = report_dict['weighted avg']['f1-score']

print(f'\n Overall accuracy : {overall_accuracy:.2f}%')
print(f' Macro F1 : {macro_f1:.4f}')
print(f' Weighted F1 : {weighted_f1:.4f}')
|
|
| |
| |
| |
print('\nAnalyzing confidence distributions...')


def _mean_or_zero(values):
    """Mean of ``values``, or 0.0 when empty (avoids a NaN from an empty mean)."""
    return float(values.mean()) if values.size else 0.0


fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes = axes.flatten()

# NOTE: confidence is the raw maximum softmax probability, while correctness
# is judged against `preds` (the predictions used throughout the report).
all_max_conf = all_probs.max(axis=1)
all_correct = (preds == all_labels)

# One histogram panel per class: confidence of correct vs wrong predictions.
for ci, cn in enumerate(CLASS_NAMES):
    ax = axes[ci]
    mask_class = (all_labels == ci)

    correct_conf = all_max_conf[mask_class & all_correct]
    wrong_conf = all_max_conf[mask_class & ~all_correct]

    n_correct = len(correct_conf)
    n_wrong = len(wrong_conf)

    if n_correct > 0:
        ax.hist(correct_conf, bins=20, alpha=0.6, color='#2ecc71',
                label=f'Correct (n={n_correct})', density=True)
    if n_wrong > 0:
        ax.hist(wrong_conf, bins=20, alpha=0.6, color='#e74c3c',
                label=f'Wrong (n={n_wrong})', density=True)
        # Mark the 0.8 line and report how many errors lie above it.
        high_conf_wrong = (wrong_conf > 0.8).sum()
        ax.axvline(0.8, color='darkred', linestyle='--', alpha=0.7, lw=1.5,
                   label=f'Conf>0.8 wrong: {high_conf_wrong}')

    ax.set_title(f'{cn}\nPrec={per_class_precision[cn]:.3f} Rec={per_class_recall[cn]:.3f} F1={per_class_f1[cn]:.3f}',
                 fontsize=10, fontweight='bold')
    ax.set_xlabel('Max Confidence', fontsize=9)
    ax.set_ylabel('Density', fontsize=9)
    ax.legend(fontsize=7)
    ax.grid(alpha=0.3)
    ax.set_xlim(0, 1)

# Sixth panel: mean confidence of correct vs wrong predictions per class.
# A class with no correct (or no wrong) samples gets 0 rather than NaN.
ax = axes[5]
mean_correct = [_mean_or_zero(all_max_conf[(all_labels == c) & all_correct])
                for c in range(NUM_CLASSES)]
mean_wrong = [_mean_or_zero(all_max_conf[(all_labels == c) & ~all_correct])
              for c in range(NUM_CLASSES)]

x = np.arange(NUM_CLASSES)
width = 0.35
ax.bar(x - width/2, mean_correct, width, label='Mean conf (correct)', color='#2ecc71', alpha=0.8)
ax.bar(x + width/2, mean_wrong, width, label='Mean conf (wrong)', color='#e74c3c', alpha=0.8)
ax.set_xticks(x)
ax.set_xticklabels([c[:6] for c in CLASS_NAMES], rotation=20)
ax.set_ylabel('Mean Confidence')
ax.set_title('Mean Confidence: Correct vs Wrong', fontweight='bold')
ax.legend(fontsize=8)
ax.grid(alpha=0.3, axis='y')
ax.set_ylim(0, 1)

plt.suptitle('Confidence Distribution Analysis per Class', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(f'{OUT_DIR}/confidence_distributions.png', dpi=150, bbox_inches='tight')
plt.close()
print(' Saved confidence_distributions.png')
|
|
| |
| |
| |
print('\nRunning per-source analysis...')

source_col = val_df['dataset'].values

# One row per evaluated sample. all_idxs maps loader order back to val_df rows.
results_df = pd.DataFrame({
    'true_label': all_labels,
    'pred_label': preds,
    'max_conf': all_max_conf,
    'dataset': source_col[all_idxs],
    'correct': (preds == all_labels).astype(int),
})

per_source = {}
for src in ['ODIR', 'APTOS']:
    mask = results_df['dataset'] == src
    if mask.sum() == 0:
        continue
    src_true = results_df['true_label'][mask].values
    src_pred = results_df['pred_label'][mask].values
    src_acc = (src_true == src_pred).mean() * 100
    src_f1 = f1_score(src_true, src_pred, average='macro', zero_division=0)

    # Per-class accuracy within this source; None when the class is absent.
    per_class_acc_src = {}
    for c in range(NUM_CLASSES):
        cmask = (src_true == c)
        if cmask.sum() == 0:
            per_class_acc_src[CLASS_NAMES[c]] = None
        else:
            per_class_acc_src[CLASS_NAMES[c]] = float((src_pred[cmask] == c).mean() * 100)

    per_source[src] = {
        'n_samples': int(mask.sum()),
        'accuracy': float(src_acc),
        'macro_f1': float(src_f1),
        'per_class_acc': per_class_acc_src
    }
    print(f'\n {src} (n={mask.sum()}):')
    print(f' Accuracy : {src_acc:.2f}%')
    print(f' Macro F1 : {src_f1:.4f}')
    for cn, acc in per_class_acc_src.items():
        if acc is not None:
            print(f' {cn:<15s}: {acc:.1f}%')

if not per_source:
    # The match on the 'dataset' column is exact and case-sensitive.
    print(' WARNING: no rows with dataset == "ODIR" or "APTOS"; '
          'per-source plots will be empty')

# --- Plots: overall scores (left) and per-class accuracy (right) by source ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

sources = list(per_source.keys())
accs = [per_source[s]['accuracy'] for s in sources]
f1s = [per_source[s]['macro_f1'] for s in sources]

x = np.arange(len(sources))
w = 0.35
axes[0].bar(x - w/2, accs, w, label='Accuracy (%)', color=['#3498db', '#e67e22'], alpha=0.85)
axes[0].bar(x + w/2, [f*100 for f in f1s], w, label='Macro F1 ×100',
            color=['#2ecc71', '#e74c3c'], alpha=0.85)
axes[0].set_xticks(x)
axes[0].set_xticklabels(sources)
axes[0].set_ylim(50, 100)
axes[0].set_ylabel('Score')
axes[0].set_title('Overall Performance by Source', fontweight='bold')
axes[0].legend()
axes[0].grid(alpha=0.3, axis='y')
for xi, (acc, f1) in enumerate(zip(accs, f1s)):
    axes[0].text(xi - w/2, acc + 0.5, f'{acc:.1f}', ha='center', fontsize=9)
    axes[0].text(xi + w/2, f1*100 + 0.5, f'{f1*100:.1f}', ha='center', fontsize=9)

# Per-class accuracy per source; classes absent from a source are drawn as 0.
valid_sources = list(sources)
class_data = {
    cn: [per_source[src]['per_class_acc'].get(cn) or 0.0 for src in valid_sources]
    for cn in CLASS_NAMES
}

x = np.arange(len(CLASS_NAMES))
n_src = len(valid_sources)
# max(..., 1) avoids a ZeroDivisionError when no source matched at all.
width = 0.8 / max(n_src, 1)
colors_src = ['#3498db', '#e67e22', '#2ecc71']

for si, src in enumerate(valid_sources):
    vals = [class_data[cn][si] for cn in CLASS_NAMES]
    offset = (si - n_src/2 + 0.5) * width
    axes[1].bar(x + offset, vals, width, label=src, alpha=0.85, color=colors_src[si])

axes[1].set_xticks(x)
axes[1].set_xticklabels(CLASS_NAMES, rotation=20, ha='right')
axes[1].set_ylim(0, 105)
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Per-Class Accuracy by Source', fontweight='bold')
axes[1].legend()
axes[1].grid(alpha=0.3, axis='y')

plt.suptitle('Dataset Source Performance Analysis', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(f'{OUT_DIR}/per_source_performance.png', dpi=150, bbox_inches='tight')
plt.close()
print(' Saved per_source_performance.png')
|
|
| |
| |
| |
print('\nSaving metrics JSON...')

# Per-source entries are re-emitted with a fixed field order.
_SOURCE_FIELDS = ('accuracy', 'macro_f1', 'n_samples', 'per_class_acc')
_per_source_summary = {
    src: {field: stats[field] for field in _SOURCE_FIELDS}
    for src, stats in per_source.items()
}

# numpy scalars are converted to plain floats so json can serialise them.
_calibration_summary = {
    'ece': float(ece),
    'bin_acc': list(map(float, bin_acc)),
    'bin_conf': list(map(float, bin_conf)),
    'bin_count': bin_count,
}

# Key order is kept stable so successive runs produce diff-friendly files.
baseline_metrics = {
    'overall_accuracy': float(overall_accuracy),
    'raw_accuracy': float(raw_acc),
    'threshold_accuracy': float(thresh_acc),
    'macro_f1': float(macro_f1),
    'weighted_f1': float(weighted_f1),
    'ece': float(ece),
    'per_class_ece': per_class_ece,
    'per_class_f1': per_class_f1,
    'per_class_precision': per_class_precision,
    'per_class_recall': per_class_recall,
    'per_class_support': per_class_support,
    'per_source_accuracy': _per_source_summary,
    'top_confusion_pairs': top5_pairs,
    'confusion_matrix_raw': cm_raw.tolist(),
    'val_split_size': len(val_df),
    'thresholds_used': thresholds,
    'calibration': _calibration_summary,
}

_metrics_path = f'{OUT_DIR}/baseline_metrics.json'
with open(_metrics_path, 'w') as metrics_file:
    json.dump(baseline_metrics, metrics_file, indent=2)
print(' Saved baseline_metrics.json')
|
|
| |
| |
| |
print('\nGenerating analysis report...')

# Classes at the extremes of recall / F1.
worst_recall_class = min(per_class_recall, key=per_class_recall.get)
worst_f1_class = min(per_class_f1, key=per_class_f1.get)
best_f1_class = max(per_class_f1, key=per_class_f1.get)

# High-confidence-wrong analysis: for each true class, how many of its errors
# were made with max confidence above 0.8.
hcw_analysis = {}
for ci, cn in enumerate(CLASS_NAMES):
    wrong_mask = (all_labels == ci) & ~all_correct
    n_wrong = int(wrong_mask.sum())
    if n_wrong == 0:
        hcw_analysis[cn] = {'total_wrong': 0, 'high_conf_wrong_count': 0,
                            'high_conf_wrong_pct': 0, 'mean_wrong_conf': 0}
        continue
    n_high_conf_wrong = int(((all_max_conf > 0.8) & wrong_mask).sum())
    hcw_analysis[cn] = {
        'total_wrong': n_wrong,
        'high_conf_wrong_count': n_high_conf_wrong,
        'high_conf_wrong_pct': float(n_high_conf_wrong / n_wrong * 100),
        'mean_wrong_conf': float(all_max_conf[wrong_mask].mean()),
    }

# Accuracy gap between the two sources, overall and for the DR class. All
# values fall back to 0 when either source is missing, because later code
# formats them as numbers.
domain_gap = None
if 'ODIR' in per_source and 'APTOS' in per_source:
    odir_acc = per_source['ODIR']['accuracy']
    aptos_acc = per_source['APTOS']['accuracy']
    domain_gap = abs(odir_acc - aptos_acc)

    # `or 0` also covers a stored None (class absent from that source).
    odir_dr = per_source['ODIR']['per_class_acc'].get('Diabetes/DR', 0) or 0
    aptos_dr = per_source['APTOS']['per_class_acc'].get('Diabetes/DR', 0) or 0
    dr_gap = abs(odir_dr - aptos_dr)
else:
    domain_gap = 0
    odir_acc = 0
    aptos_acc = 0
    odir_dr = 0
    aptos_dr = 0
    dr_gap = 0

# Sign of the summed (confidence - accuracy) gap over non-empty bins. The
# three per-bin lists are zipped together so each gap is paired with its own
# count; looking the count up via bin_acc.index() picks the wrong bin
# whenever two bins share the same accuracy value.
_total_gap = sum(
    b_conf - b_acc
    for b_conf, b_acc, b_cnt in zip(bin_conf, bin_acc, bin_count)
    if b_cnt > 0
)
calibration_verdict = 'overconfident' if _total_gap > 0 else 'underconfident'
|
|
# Build report sections 1-5 as a list of fragments and join once at the end.
from datetime import date

_both_sources = 'ODIR' in per_source and 'APTOS' in per_source

# The explanatory sentence has to agree with the computed verdict.
if calibration_verdict == 'overconfident':
    _calibration_detail = 'mean confidence exceeds accuracy on aggregate across non-empty bins'
else:
    _calibration_detail = 'accuracy exceeds mean confidence on aggregate across non-empty bins'

_parts = []

# --- Header and section 1; the date is the actual run date ---
_parts.append(f"""# RetinaSense ViT v2 — Baseline Error Analysis Report
**Generated**: {date.today().isoformat()}
**Model**: ViT-Base-Patch16-224 (MultiTaskViT)
**Checkpoint**: outputs_vit/best_model.pth
**Val Split**: {len(val_df)} samples (20% stratified, random_state=42)

---

## 1. Overall Performance

| Metric | Value |
|--------|-------|
| Accuracy (raw argmax) | {raw_acc:.2f}% |
| Accuracy (with thresholds) | {thresh_acc:.2f}% |
| Macro F1 | {macro_f1:.4f} |
| Weighted F1 | {weighted_f1:.4f} |
| ECE (10 bins) | {ece:.4f} |

---

## 2. Per-Class Metrics

| Class | Precision | Recall | F1 | Support |
|-------|-----------|--------|----|---------|
""")
for cn in CLASS_NAMES:
    _parts.append(f"| {cn:<15s} | {per_class_precision[cn]:.4f} | "
                  f"{per_class_recall[cn]:.4f} | {per_class_f1[cn]:.4f} | "
                  f"{per_class_support[cn]:4d} |\n")

# --- Section 3: confusion analysis ---
_parts.append("""
---

## 3. Confusion Analysis — Top 5 Confused Pairs

| Rank | True Class | Predicted As | Count | Rate |
|------|-----------|-------------|-------|------|
""")
for rank, pair in enumerate(top5_pairs, 1):
    _parts.append(f"| {rank} | {pair['true_class']} | {pair['pred_class']} | "
                  f"{pair['count']} | {pair['rate']*100:.1f}% |\n")

# Fixed-width text table: 8-char row label + space, then 7-char columns, so
# the header lines up with the values inside the code fence.
_cm_header = ' ' * 9 + ' '.join(f'{cn[:6]:>7s}' for cn in CLASS_NAMES)
_parts.append(f"""
### Full Confusion Matrix (normalized by true class)

```
{_cm_header}
""")
for ri, rn in enumerate(CLASS_NAMES):
    row_str = ' '.join(f'{cm_norm[ri, ci]:>7.3f}' for ci in range(NUM_CLASSES))
    _parts.append(f"{rn[:8]:>8s} {row_str}\n")

# --- Section 4: calibration ---
_parts.append(f"""```

---

## 4. Confidence Calibration Analysis

- **ECE (overall)**: {ece:.4f}
- **Calibration pattern**: The model is predominantly **{calibration_verdict}**
({_calibration_detail}).

### Per-Class ECE

| Class | ECE |
|-------|-----|
""")
for cn, ece_c in per_class_ece.items():
    _parts.append(f"| {cn} | {ece_c:.4f} |\n")

_parts.append("""
### High-Confidence Wrong Predictions (confidence > 0.8)

| Class | Total Wrong | High-Conf Wrong | % of Errors | Mean Wrong Conf |
|-------|------------|----------------|-------------|----------------|
""")
for cn, hcw in hcw_analysis.items():
    _parts.append(f"| {cn} | {hcw['total_wrong']} | {hcw['high_conf_wrong_count']} | "
                  f"{hcw['high_conf_wrong_pct']:.1f}% | {hcw['mean_wrong_conf']:.3f} |\n")

# --- Section 5: per-source breakdown ---
_parts.append("""
---

## 5. Dataset Source Analysis (ODIR vs APTOS)

| Source | N Samples | Accuracy | Macro F1 |
|--------|-----------|----------|----------|
""")
for src, data in per_source.items():
    _parts.append(f"| {src} | {data['n_samples']} | {data['accuracy']:.2f}% | {data['macro_f1']:.4f} |\n")

_parts.append("""
### Per-Class Accuracy by Source

| Class |""")
_parts.extend(f" {src} |" for src in per_source)
_parts.append("\n|-------|")
_parts.extend("--------|" for _ in per_source)
_parts.append("\n")
for cn in CLASS_NAMES:
    _parts.append(f"| {cn} |")
    for src in per_source:
        acc = per_source[src]['per_class_acc'].get(cn)
        _parts.append(" N/A |" if acc is None else f" {acc:.1f}% |")
    _parts.append("\n")

# Report a gap only when both sources were evaluated; otherwise the 0
# placeholder would read as "no domain gap".
if _both_sources:
    _parts.append(f"""
**Domain gap (overall accuracy)**: {domain_gap:.2f}pp between ODIR and APTOS
**DR class gap (ODIR vs APTOS)**: ODIR={odir_dr:.1f}% vs APTOS={aptos_dr:.1f}% (gap={dr_gap:.1f}pp)
""")
else:
    _parts.append("""
**Domain gap (overall accuracy)**: N/A (requires both ODIR and APTOS samples)
""")

report = ''.join(_parts)
|
# --- Section 6: narrative error-pattern summary (Q1-Q4) ---
report += f"""
---

## 6. Error Pattern Summary

### Q1: What is the model's biggest weakness?

The model's biggest weakness is classifying **{worst_f1_class}** (F1={per_class_f1[worst_f1_class]:.4f},
recall={per_class_recall[worst_f1_class]:.4f}). This class has the worst F1 score, indicating the
model struggles to both detect and correctly distinguish it from other pathologies.

The confusion matrix shows that the primary confusion pathway is:
- **{top5_pairs[0]['description']}**: {top5_pairs[0]['count']} cases ({top5_pairs[0]['rate']*100:.1f}% error rate)
- **{top5_pairs[1]['description']}**: {top5_pairs[1]['count']} cases ({top5_pairs[1]['rate']*100:.1f}% error rate)

### Q2: Which class has the worst recall? Why?

**{worst_recall_class}** has the worst recall at {per_class_recall[worst_recall_class]:.4f}.
"""

worst_support = per_class_support[worst_recall_class]
all_support = sum(per_class_support.values())
worst_pct = worst_support / all_support * 100 if all_support else 0.0

# Class the worst-recall class is most often mistaken for. confused_pairs is
# sorted by count, so the first match is the most frequent confusion. This
# avoids naming the class itself when the worst-recall class is 'Normal'.
_most_confused_with = next(
    (p['pred_class'] for p in confused_pairs
     if p['true_class'] == worst_recall_class and p['count'] > 0),
    'other classes',
)
_worst_threshold = thresholds.get(CLASS_NAMES.index(worst_recall_class), 0.5)

report += f"""This class represents only {worst_support} samples ({worst_pct:.1f}% of the val set).
The low recall is likely caused by:
1. **Class imbalance** — the model sees fewer examples during training and defaults to predicting
more common classes when uncertain.
2. **Visual similarity** with other conditions (especially {_most_confused_with})
at the fundus level.
3. **Threshold sensitivity** — the optimized threshold ({_worst_threshold:.2f})
may overcorrect or undercorrect depending on the calibration.

### Q3: Evidence of domain shift (ODIR vs APTOS)?

"""
if domain_gap is not None and domain_gap > 2.0:
    report += f"""YES — there is a **{domain_gap:.1f}pp accuracy gap** between ODIR ({odir_acc:.1f}%) and APTOS
({aptos_acc:.1f}%). This is significant and consistent with domain shift between the two data sources.

For the DR/Diabetes class specifically, the gap is **{dr_gap:.1f}pp** (ODIR={odir_dr:.1f}% vs APTOS={aptos_dr:.1f}%).
APTOS images are specifically DR-graded fundus photographs from India (Aravind Eye Hospital),
while ODIR covers multiple disease classes with more varied image quality and capture conditions.
The Ben Graham preprocessing helps but does not fully bridge the domain gap.

**Implication for v3**: Domain-specific augmentation or source-aware training (e.g., source
as auxiliary input, separate batch norms, or domain adaptation) may improve generalization.
"""
elif domain_gap is not None and domain_gap > 0:
    report += f"""MINOR gap observed — {domain_gap:.1f}pp difference between ODIR ({odir_acc:.1f}%) and
APTOS ({aptos_acc:.1f}%). The gap is small, suggesting the Ben Graham preprocessing and ViT
architecture generalize reasonably across sources. DR-specific gap: {dr_gap:.1f}pp.
"""
else:
    report += "Insufficient cross-source data to conclude domain shift.\n"

report += f"""
### Q4: Calibration assessment

ECE = **{ece:.4f}** (scale: 0=perfect, 0.1=poor).

"""
if ece < 0.03:
    report += "The model is **well-calibrated** (ECE < 0.03). Confidence scores are reliable."
elif ece < 0.07:
    report += f"""The model shows **moderate miscalibration** (ECE={ece:.4f}). The reliability diagram
shows the model is {calibration_verdict} in the high-confidence range, meaning predicted
confidence scores are not fully reliable. Temperature scaling in v3 is recommended."""
else:
    report += f"""The model is **poorly calibrated** (ECE={ece:.4f}). The {calibration_verdict}
pattern is severe. Temperature scaling or label smoothing in v3 training is strongly recommended."""

# --- Section 7: recommendations ---
report += f"""

---

## 7. Recommendations for v3 Training

Based on this baseline analysis:

1. **Address {worst_recall_class} recall** — increase class weight, targeted augmentation,
or focal loss gamma tuning for this class.
2. **Calibration** — add temperature scaling post-training or increase label smoothing
(current ECE={ece:.4f}).
3. **Domain shift mitigation** — consider source-conditioned augmentation or adversarial
domain adaptation if ODIR/APTOS gap persists.
4. **High-confidence errors** — the model makes confidently wrong predictions on certain
classes; mixup or CutMix augmentation may improve uncertainty estimation.
5. **Top confusion pairs** to specifically target:
"""
for pair in top5_pairs[:3]:
    report += f" - {pair['description']} ({pair['count']} errors)\n"

# --- Section 8: index of generated artefacts ---
report += """
---

## 8. Output Files

| File | Description |
|------|-------------|
| confusion_matrix_raw.png | Raw count confusion matrix |
| confusion_matrix_normalized.png | Recall-normalized confusion matrix |
| reliability_diagram.png | ECE calibration plot |
| confidence_distributions.png | Per-class confidence histograms |
| per_source_performance.png | ODIR vs APTOS breakdown |
| baseline_metrics.json | All metrics in structured JSON |

---
*Report generated by RetinaSense ViT v2 error analysis pipeline.*
"""

# The report contains non-ASCII characters (em dashes), so the encoding is
# set explicitly instead of relying on the platform default.
with open(f'{OUT_DIR}/BASELINE_ANALYSIS.md', 'w', encoding='utf-8') as f:
    f.write(report)
print(f' Saved BASELINE_ANALYSIS.md')
|
|
| |
| |
| |
# Console summary of the headline numbers, printed in a single call.
_rule = '=' * 65
_summary_lines = [
    '\n' + _rule,
    ' BASELINE ANALYSIS COMPLETE',
    _rule,
    f' Val accuracy (thresh) : {thresh_acc:.2f}%',
    f' Macro F1 : {macro_f1:.4f}',
    f' ECE : {ece:.4f}',
    f' Worst class (F1) : {worst_f1_class} ({per_class_f1[worst_f1_class]:.4f})',
    f' Worst class (recall) : {worst_recall_class} ({per_class_recall[worst_recall_class]:.4f})',
    f' Top confusion : {top5_pairs[0]["description"]}',
]
if domain_gap is not None:
    _summary_lines.append(f' Domain gap (ODIR-APTOS): {domain_gap:.2f}pp')
_summary_lines.append(f'\n All outputs in: {OUT_DIR}/')
_summary_lines.append(_rule)
print('\n'.join(_summary_lines))
|
|