retinasense-vit / run_error_analysis.py
tanishq74's picture
Add run_error_analysis.py
8552f4b verified
raw
history blame
36 kB
#!/usr/bin/env python3
"""
RetinaSense ViT v2 - Comprehensive Error Analysis & Baseline Report
===================================================================
Runs full evaluation on the validation split, computes ECE,
confusion analysis, confidence distributions, and source-level
performance. Saves all plots and metrics to outputs_analysis/v2_baseline/.
"""
import os, sys, json, warnings
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
warnings.filterwarnings('ignore')
import cv2
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import timm
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
confusion_matrix, classification_report,
f1_score, precision_score, recall_score
)
# ================================================================
# CONFIG
# ================================================================
BASE_DIR = '/teamspace/studios/this_studio'
MODEL_PATH = f'{BASE_DIR}/outputs_vit/best_model.pth'
META_CSV = f'{BASE_DIR}/final_unified_metadata.csv'
THRESH_JSON = f'{BASE_DIR}/outputs_vit/threshold_optimization_results.json'
CACHE_DIR = f'{BASE_DIR}/preprocessed_cache_vit'
OUT_DIR = f'{BASE_DIR}/outputs_analysis/v2_baseline'
IMG_SIZE = 224
BATCH_SIZE = 64
NUM_WORKERS = 8
NUM_CLASSES = 5
CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']
os.makedirs(OUT_DIR, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
if torch.cuda.is_available():
print(f'GPU: {torch.cuda.get_device_name(0)}')
# ================================================================
# MODEL DEFINITION (mirrors retinasense_vit.py)
# ================================================================
class MultiTaskViT(nn.Module):
def __init__(self, n_disease=5, n_severity=5, drop=0.4):
super().__init__()
self.backbone = timm.create_model(
'vit_base_patch16_224', pretrained=False, num_classes=0)
feat = 768
self.drop = nn.Dropout(drop)
self.disease_head = nn.Sequential(
nn.Linear(feat, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
nn.Linear(256, n_disease))
self.severity_head = nn.Sequential(
nn.Linear(feat, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
nn.Linear(256, n_severity))
def forward(self, x):
f = self.backbone(x)
f = self.drop(f)
return self.disease_head(f), self.severity_head(f)
# ================================================================
# IMAGE PREPROCESSING (Ben Graham method, matches training)
# ================================================================
def ben_graham(path, sz=IMG_SIZE, sigma=10):
img = cv2.imread(str(path))
if img is None:
img = np.array(Image.open(str(path)).convert('RGB'))
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (sz, sz))
img = cv2.addWeighted(img, 4, cv2.GaussianBlur(img, (0,0), sigma), -4, 128)
mask = np.zeros(img.shape[:2], dtype=np.uint8)
cv2.circle(mask, (sz//2, sz//2), int(sz * 0.48), 255, -1)
return cv2.bitwise_and(img, img, mask=mask)
def resolve_image_path(raw_path):
"""
Resolve image path from CSV entry (which has leading .// prefix).
Tries multiple known root locations.
APTOS images live in:
aptos/gaussian_filtered_images/gaussian_filtered_images/{Severity}/{stem}.png
ODIR images live in:
odir/preprocessed_images/{filename}
"""
# Strip leading .// or ./
clean = raw_path.lstrip('.').lstrip('/').lstrip('/')
clean = clean.replace('//', '/')
stem = Path(raw_path).stem
candidates = [
f'{BASE_DIR}/{clean}',
]
# APTOS: search all severity subfolders
if 'aptos' in raw_path.lower():
aptos_base = f'{BASE_DIR}/aptos/gaussian_filtered_images/gaussian_filtered_images'
for severity in ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferate_DR']:
for ext in ['.png', '.jpg', '.jpeg']:
candidates.append(f'{aptos_base}/{severity}/{stem}{ext}')
# Also try train_images (original path)
for ext in ['.png', '.jpg', '.jpeg']:
candidates.append(f'{BASE_DIR}/aptos/train_images/{stem}{ext}')
# ODIR: preprocessed_images
if 'odir' in raw_path.lower():
fname = Path(raw_path).name
candidates.append(f'{BASE_DIR}/odir/preprocessed_images/{fname}')
candidates.append(f'{BASE_DIR}/ocular-disease-recognition-odir5k/preprocessed_images/{fname}')
for c in candidates:
if os.path.exists(c):
return c
return None
def load_or_cache(row):
"""
Load preprocessed image from cache (.npy) or process from disk.
Returns uint8 HxWx3 numpy array.
"""
stem = Path(row['image_path_clean']).stem
cache_fp = f'{CACHE_DIR}/{stem}_224.npy'
if os.path.exists(cache_fp):
try:
return np.load(cache_fp)
except Exception:
pass
img_path = row.get('image_path_resolved')
if img_path and os.path.exists(img_path):
try:
arr = ben_graham(img_path)
np.save(cache_fp, arr)
return arr
except Exception as e:
pass
# Fallback: zero image
return np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)
# ================================================================
# DATASET
# ================================================================
val_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
class RetDS(Dataset):
def __init__(self, df):
self.df = df.reset_index(drop=True)
def __len__(self):
return len(self.df)
def __getitem__(self, i):
r = self.df.iloc[i]
img = load_or_cache(r)
return (
val_transform(img),
torch.tensor(int(r['disease_label']), dtype=torch.long),
torch.tensor(int(r['severity_label']), dtype=torch.long),
i # return index so we can track per-sample metadata
)
# ================================================================
# STEP 1 — LOAD METADATA & BUILD VAL SPLIT
# ================================================================
print('\n[1/6] Loading metadata and building val split...')
meta = pd.read_csv(META_CSV)
print(f' Raw rows: {len(meta)}')
# Fix image paths
meta['image_path_clean'] = meta['image_path'].str.lstrip('.').str.lstrip('/').str.replace('//', '/', regex=False)
meta['image_path_resolved'] = meta['image_path_clean'].apply(
lambda p: resolve_image_path(p)
)
n_resolved = meta['image_path_resolved'].notna().sum()
print(f' Images resolved on disk: {n_resolved} / {len(meta)}')
# Build the same stratified split used in training (random_state=42, test_size=0.2)
train_df, val_df = train_test_split(
meta,
test_size=0.2,
stratify=meta['disease_label'],
random_state=42
)
val_df = val_df.reset_index(drop=True)
print(f' Val split: {len(val_df)} samples')
print(f' Val class distribution:')
for lbl, cnt in val_df['disease_label'].value_counts().sort_index().items():
print(f' {CLASS_NAMES[int(lbl)]:<15s}: {cnt:4d}')
# ================================================================
# STEP 2 — LOAD MODEL
# ================================================================
print('\n[2/6] Loading model...')
model = MultiTaskViT().to(device)
ckpt = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()
print(f' Loaded checkpoint: epoch={ckpt.get("epoch","?")}, '
f'macro_f1={ckpt.get("macro_f1", 0):.4f}')
# Load thresholds
with open(THRESH_JSON) as f:
thresh_data = json.load(f)
thresholds = {int(k): float(v) for k, v in thresh_data['optimal_thresholds'].items()}
print(f' Optimal thresholds: {thresholds}')
# ================================================================
# STEP 3 — RUN INFERENCE
# ================================================================
print('\n[3/6] Running inference on val set...')
val_ds = RetDS(val_df)
val_loader = DataLoader(
val_ds, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS, pin_memory=True
)
all_probs = [] # (N, 5) softmax probabilities
all_preds = [] # (N,) argmax predictions
all_labels = [] # (N,) true labels
all_idxs = [] # (N,) val_df indices
with torch.no_grad():
for imgs, d_lbl, s_lbl, idx in tqdm(val_loader, desc='Inference'):
imgs = imgs.to(device, non_blocking=True)
with torch.amp.autocast('cuda'):
d_out, _ = model(imgs)
probs = torch.softmax(d_out.float(), dim=1).cpu().numpy()
preds = d_out.argmax(1).cpu().numpy()
all_probs.append(probs)
all_preds.append(preds)
all_labels.append(d_lbl.numpy())
all_idxs.append(idx.numpy())
all_probs = np.vstack(all_probs) # (N, 5)
all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)
all_idxs = np.concatenate(all_idxs)
# Also compute threshold-adjusted predictions
thresh_preds = np.zeros_like(all_preds)
for i in range(len(all_probs)):
adjusted = all_probs[i].copy()
for c, t in thresholds.items():
adjusted[c] = all_probs[i][c] / t # scale by threshold
thresh_preds[i] = adjusted.argmax()
raw_acc = (all_preds == all_labels).mean() * 100
thresh_acc = (thresh_preds == all_labels).mean() * 100
print(f' Raw accuracy : {raw_acc:.2f}%')
print(f' Threshold accuracy: {thresh_acc:.2f}%')
# Use threshold-adjusted for main analysis (matches published 84.48%)
preds = thresh_preds
# ================================================================
# STEP 4 — CONFIDENCE CALIBRATION (ECE)
# ================================================================
print('\n[4/6] Computing ECE and reliability diagram...')
def compute_ece(probs, labels, n_bins=10):
"""Expected Calibration Error with equal-width bins."""
confidences = probs.max(axis=1) # max probability = confidence
predicted = probs.argmax(axis=1)
correct = (predicted == labels).astype(float)
bins = np.linspace(0, 1, n_bins + 1)
ece = 0.0
bin_acc = []
bin_conf = []
bin_count = []
for lo, hi in zip(bins[:-1], bins[1:]):
mask = (confidences >= lo) & (confidences < hi)
if mask.sum() == 0:
bin_acc.append(0.0)
bin_conf.append((lo + hi) / 2)
bin_count.append(0)
continue
acc = correct[mask].mean()
conf = confidences[mask].mean()
n = mask.sum()
ece += (n / len(labels)) * abs(acc - conf)
bin_acc.append(acc)
bin_conf.append(conf)
bin_count.append(int(n))
return ece, bin_acc, bin_conf, bin_count, bins
ece, bin_acc, bin_conf, bin_count, bins = compute_ece(all_probs, all_labels)
print(f' ECE (10 bins): {ece:.4f}')
# Per-class calibration
per_class_ece = {}
for c in range(NUM_CLASSES):
mask = (all_labels == c)
if mask.sum() == 0:
per_class_ece[CLASS_NAMES[c]] = 0.0
continue
ece_c, _, _, _, _ = compute_ece(all_probs[mask], all_labels[mask])
per_class_ece[CLASS_NAMES[c]] = float(ece_c)
print(f' ECE {CLASS_NAMES[c]:<15s}: {ece_c:.4f}')
# -- Reliability diagram --
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
bin_centers = (bins[:-1] + bins[1:]) / 2
bars = axes[0].bar(
bin_centers, bin_acc,
width=(bins[1] - bins[0]) * 0.9,
alpha=0.7, color='steelblue', label='Accuracy per bin'
)
axes[0].plot([0, 1], [0, 1], 'r--', lw=2, label='Perfect calibration')
axes[0].set_xlabel('Confidence', fontsize=12)
axes[0].set_ylabel('Accuracy', fontsize=12)
axes[0].set_title(f'Reliability Diagram\nECE = {ece:.4f}', fontsize=13, fontweight='bold')
axes[0].legend(fontsize=10)
axes[0].grid(alpha=0.3)
axes[0].set_xlim(0, 1); axes[0].set_ylim(0, 1)
# Annotate with bin counts
for bar, cnt in zip(bars, bin_count):
if cnt > 0:
axes[0].text(
bar.get_x() + bar.get_width()/2, min(bar.get_height() + 0.02, 0.97),
str(cnt), ha='center', va='bottom', fontsize=7, color='black'
)
# Gap diagram (overconfidence = positive gap)
gap = np.array(bin_conf) - np.array(bin_acc)
color_gap = ['#e74c3c' if g > 0 else '#2ecc71' for g in gap]
axes[1].bar(bin_centers, gap, width=(bins[1]-bins[0])*0.9, color=color_gap, alpha=0.8)
axes[1].axhline(0, color='black', lw=1)
axes[1].set_xlabel('Confidence', fontsize=12)
axes[1].set_ylabel('Confidence - Accuracy (Gap)', fontsize=12)
axes[1].set_title('Calibration Gap\n(Red=overconfident, Green=underconfident)',
fontsize=13, fontweight='bold')
axes[1].grid(alpha=0.3)
axes[1].set_xlim(0, 1)
plt.tight_layout()
plt.savefig(f'{OUT_DIR}/reliability_diagram.png', dpi=150, bbox_inches='tight')
plt.close()
print(f' Saved reliability_diagram.png')
# ================================================================
# STEP 5 — CONFUSION MATRIX
# ================================================================
print('\n[5/6] Generating confusion matrices...')
cm_raw = confusion_matrix(all_labels, preds)
cm_norm = cm_raw.astype(float) / cm_raw.sum(axis=1, keepdims=True)
# -- Raw counts confusion matrix --
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
cm_raw, annot=True, fmt='d', cmap='Blues', ax=ax,
xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
linewidths=0.5, linecolor='gray'
)
ax.set_title('Confusion Matrix (Raw Counts)', fontsize=14, fontweight='bold')
ax.set_ylabel('True Label', fontsize=12)
ax.set_xlabel('Predicted Label', fontsize=12)
plt.xticks(rotation=30, ha='right')
plt.tight_layout()
plt.savefig(f'{OUT_DIR}/confusion_matrix_raw.png', dpi=150, bbox_inches='tight')
plt.close()
# -- Normalized confusion matrix --
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
cm_norm, annot=True, fmt='.3f', cmap='Blues', ax=ax,
xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
linewidths=0.5, linecolor='gray', vmin=0, vmax=1
)
ax.set_title('Confusion Matrix (Normalized by True Class)', fontsize=14, fontweight='bold')
ax.set_ylabel('True Label', fontsize=12)
ax.set_xlabel('Predicted Label', fontsize=12)
plt.xticks(rotation=30, ha='right')
plt.tight_layout()
plt.savefig(f'{OUT_DIR}/confusion_matrix_normalized.png', dpi=150, bbox_inches='tight')
plt.close()
print(' Saved confusion_matrix_raw.png and confusion_matrix_normalized.png')
# -- Top confused pairs --
confused_pairs = []
for true_c in range(NUM_CLASSES):
for pred_c in range(NUM_CLASSES):
if true_c == pred_c:
continue
count = cm_raw[true_c, pred_c]
rate = cm_norm[true_c, pred_c]
confused_pairs.append({
'true_class': CLASS_NAMES[true_c],
'pred_class': CLASS_NAMES[pred_c],
'count': int(count),
'rate': float(rate),
'description': f'{CLASS_NAMES[true_c]} misclassified AS {CLASS_NAMES[pred_c]}'
})
confused_pairs.sort(key=lambda x: x['count'], reverse=True)
top5_pairs = confused_pairs[:5]
print('\n Top 5 confused class pairs (by raw count):')
for p in top5_pairs:
print(f' {p["description"]}: {p["count"]} ({p["rate"]*100:.1f}%)')
# ================================================================
# STEP 6 — PER-CLASS METRICS
# ================================================================
print('\n[6/6] Computing per-class metrics...')
report_dict = classification_report(
all_labels, preds, target_names=CLASS_NAMES, output_dict=True, zero_division=0
)
print(classification_report(all_labels, preds, target_names=CLASS_NAMES, digits=4, zero_division=0))
per_class_precision = {}
per_class_recall = {}
per_class_f1 = {}
per_class_support = {}
for cn in CLASS_NAMES:
per_class_precision[cn] = report_dict[cn]['precision']
per_class_recall[cn] = report_dict[cn]['recall']
per_class_f1[cn] = report_dict[cn]['f1-score']
per_class_support[cn] = int(report_dict[cn]['support'])
overall_accuracy = report_dict['accuracy'] * 100
macro_f1 = report_dict['macro avg']['f1-score']
weighted_f1 = report_dict['weighted avg']['f1-score']
print(f'\n Overall accuracy : {overall_accuracy:.2f}%')
print(f' Macro F1 : {macro_f1:.4f}')
print(f' Weighted F1 : {weighted_f1:.4f}')
# ================================================================
# CONFIDENCE DISTRIBUTION ANALYSIS
# ================================================================
print('\nAnalyzing confidence distributions...')
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes = axes.flatten()
all_max_conf = all_probs.max(axis=1)
all_correct = (preds == all_labels)
for ci, cn in enumerate(CLASS_NAMES):
ax = axes[ci]
mask_class = (all_labels == ci)
correct_conf = all_max_conf[mask_class & all_correct]
wrong_conf = all_max_conf[mask_class & ~all_correct]
n_correct = len(correct_conf)
n_wrong = len(wrong_conf)
if n_correct > 0:
ax.hist(correct_conf, bins=20, alpha=0.6, color='#2ecc71',
label=f'Correct (n={n_correct})', density=True)
if n_wrong > 0:
ax.hist(wrong_conf, bins=20, alpha=0.6, color='#e74c3c',
label=f'Wrong (n={n_wrong})', density=True)
# Mark high-confidence wrong predictions
if n_wrong > 0:
high_conf_wrong = (wrong_conf > 0.8).sum()
ax.axvline(0.8, color='darkred', linestyle='--', alpha=0.7, lw=1.5,
label=f'Conf>0.8 wrong: {high_conf_wrong}')
ax.set_title(f'{cn}\nPrec={per_class_precision[cn]:.3f} Rec={per_class_recall[cn]:.3f} F1={per_class_f1[cn]:.3f}',
fontsize=10, fontweight='bold')
ax.set_xlabel('Max Confidence', fontsize=9)
ax.set_ylabel('Density', fontsize=9)
ax.legend(fontsize=7)
ax.grid(alpha=0.3)
ax.set_xlim(0, 1)
# Summary panel
ax = axes[5]
mean_correct = [all_max_conf[all_labels==c][preds[all_labels==c]==c].mean()
if (all_labels==c).sum() > 0 else 0 for c in range(NUM_CLASSES)]
mean_wrong = [all_max_conf[all_labels==c][preds[all_labels==c]!=c].mean()
if ((all_labels==c) & (preds!=c)).sum() > 0 else 0 for c in range(NUM_CLASSES)]
x = np.arange(NUM_CLASSES)
width = 0.35
ax.bar(x - width/2, mean_correct, width, label='Mean conf (correct)', color='#2ecc71', alpha=0.8)
ax.bar(x + width/2, mean_wrong, width, label='Mean conf (wrong)', color='#e74c3c', alpha=0.8)
ax.set_xticks(x)
ax.set_xticklabels([c[:6] for c in CLASS_NAMES], rotation=20)
ax.set_ylabel('Mean Confidence')
ax.set_title('Mean Confidence: Correct vs Wrong', fontweight='bold')
ax.legend(fontsize=8)
ax.grid(alpha=0.3, axis='y')
ax.set_ylim(0, 1)
plt.suptitle('Confidence Distribution Analysis per Class', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(f'{OUT_DIR}/confidence_distributions.png', dpi=150, bbox_inches='tight')
plt.close()
print(' Saved confidence_distributions.png')
# ================================================================
# PER-SOURCE ANALYSIS
# ================================================================
print('\nRunning per-source analysis...')
# Attach dataset source to val_df indices
source_col = val_df['dataset'].values
results_df = pd.DataFrame({
'true_label': all_labels,
'pred_label': preds,
'max_conf': all_max_conf,
'dataset': source_col[all_idxs],
'correct': (preds == all_labels).astype(int),
})
per_source = {}
for src in ['ODIR', 'APTOS']:
mask = results_df['dataset'] == src
if mask.sum() == 0:
continue
src_true = results_df['true_label'][mask].values
src_pred = results_df['pred_label'][mask].values
src_acc = (src_true == src_pred).mean() * 100
src_f1 = f1_score(src_true, src_pred, average='macro', zero_division=0)
per_class_acc_src = {}
for c in range(NUM_CLASSES):
cmask = (src_true == c)
if cmask.sum() == 0:
per_class_acc_src[CLASS_NAMES[c]] = None
else:
per_class_acc_src[CLASS_NAMES[c]] = float((src_pred[cmask] == c).mean() * 100)
per_source[src] = {
'n_samples': int(mask.sum()),
'accuracy': float(src_acc),
'macro_f1': float(src_f1),
'per_class_acc': per_class_acc_src
}
print(f'\n {src} (n={mask.sum()}):')
print(f' Accuracy : {src_acc:.2f}%')
print(f' Macro F1 : {src_f1:.4f}')
for cn, acc in per_class_acc_src.items():
if acc is not None:
print(f' {cn:<15s}: {acc:.1f}%')
# -- Per-source performance plot --
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Overall bar
sources = list(per_source.keys())
accs = [per_source[s]['accuracy'] for s in sources]
f1s = [per_source[s]['macro_f1'] for s in sources]
x = np.arange(len(sources))
w = 0.35
axes[0].bar(x - w/2, accs, w, label='Accuracy (%)', color=['#3498db', '#e67e22'], alpha=0.85)
axes[0].bar(x + w/2, [f*100 for f in f1s], w, label='Macro F1 ×100',
color=['#2ecc71', '#e74c3c'], alpha=0.85)
axes[0].set_xticks(x); axes[0].set_xticklabels(sources)
axes[0].set_ylim(50, 100)
axes[0].set_ylabel('Score')
axes[0].set_title('Overall Performance by Source', fontweight='bold')
axes[0].legend(); axes[0].grid(alpha=0.3, axis='y')
for xi, (acc, f1) in enumerate(zip(accs, f1s)):
axes[0].text(xi - w/2, acc + 0.5, f'{acc:.1f}', ha='center', fontsize=9)
axes[0].text(xi + w/2, f1*100 + 0.5, f'{f1*100:.1f}', ha='center', fontsize=9)
# Per-class accuracy by source
class_data = {cn: [] for cn in CLASS_NAMES}
valid_sources = []
for src in sources:
valid_sources.append(src)
for cn in CLASS_NAMES:
acc = per_source[src]['per_class_acc'].get(cn)
class_data[cn].append(acc if acc is not None else 0.0)
x = np.arange(len(CLASS_NAMES))
n_src = len(valid_sources)
width = 0.8 / n_src
colors_src = ['#3498db', '#e67e22', '#2ecc71']
for si, src in enumerate(valid_sources):
vals = [class_data[cn][si] for cn in CLASS_NAMES]
offset = (si - n_src/2 + 0.5) * width
axes[1].bar(x + offset, vals, width, label=src, alpha=0.85, color=colors_src[si])
axes[1].set_xticks(x); axes[1].set_xticklabels(CLASS_NAMES, rotation=20, ha='right')
axes[1].set_ylim(0, 105)
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Per-Class Accuracy by Source', fontweight='bold')
axes[1].legend(); axes[1].grid(alpha=0.3, axis='y')
plt.suptitle('Dataset Source Performance Analysis', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(f'{OUT_DIR}/per_source_performance.png', dpi=150, bbox_inches='tight')
plt.close()
print(' Saved per_source_performance.png')
# ================================================================
# SAVE METRICS JSON
# ================================================================
print('\nSaving metrics JSON...')
baseline_metrics = {
'overall_accuracy': float(overall_accuracy),
'raw_accuracy': float(raw_acc),
'threshold_accuracy': float(thresh_acc),
'macro_f1': float(macro_f1),
'weighted_f1': float(weighted_f1),
'ece': float(ece),
'per_class_ece': per_class_ece,
'per_class_f1': per_class_f1,
'per_class_precision': per_class_precision,
'per_class_recall': per_class_recall,
'per_class_support': per_class_support,
'per_source_accuracy': {
src: {
'accuracy': per_source[src]['accuracy'],
'macro_f1': per_source[src]['macro_f1'],
'n_samples': per_source[src]['n_samples'],
'per_class_acc': per_source[src]['per_class_acc']
}
for src in per_source
},
'top_confusion_pairs': top5_pairs,
'confusion_matrix_raw': cm_raw.tolist(),
'val_split_size': len(val_df),
'thresholds_used': thresholds,
'calibration': {
'ece': float(ece),
'bin_acc': [float(x) for x in bin_acc],
'bin_conf': [float(x) for x in bin_conf],
'bin_count': bin_count,
}
}
with open(f'{OUT_DIR}/baseline_metrics.json', 'w') as f:
json.dump(baseline_metrics, f, indent=2)
print(f' Saved baseline_metrics.json')
# ================================================================
# ANALYSIS REPORT
# ================================================================
print('\nGenerating analysis report...')
# Identify key findings
worst_recall_class = min(per_class_recall, key=per_class_recall.get)
worst_f1_class = min(per_class_f1, key=per_class_f1.get)
best_f1_class = max(per_class_f1, key=per_class_f1.get)
# High-confidence wrong predictions per class
hcw_analysis = {}
for ci, cn in enumerate(CLASS_NAMES):
mask_class = (all_labels == ci)
wrong_mask = mask_class & ~all_correct
if wrong_mask.sum() > 0:
high_conf_wrong = ((all_max_conf > 0.8) & wrong_mask).sum()
hcw_analysis[cn] = {
'total_wrong': int(wrong_mask.sum()),
'high_conf_wrong_count': int(high_conf_wrong),
'high_conf_wrong_pct': float(high_conf_wrong / wrong_mask.sum() * 100) if wrong_mask.sum() > 0 else 0,
'mean_wrong_conf': float(all_max_conf[wrong_mask].mean()) if wrong_mask.sum() > 0 else 0,
}
else:
hcw_analysis[cn] = {'total_wrong': 0, 'high_conf_wrong_count': 0,
'high_conf_wrong_pct': 0, 'mean_wrong_conf': 0}
# Domain gap
domain_gap = None
if 'ODIR' in per_source and 'APTOS' in per_source:
odir_acc = per_source['ODIR']['accuracy']
aptos_acc = per_source['APTOS']['accuracy']
domain_gap = abs(odir_acc - aptos_acc)
# DR-specific domain gap
odir_dr = per_source['ODIR']['per_class_acc'].get('Diabetes/DR', 0) or 0
aptos_dr = per_source['APTOS']['per_class_acc'].get('Diabetes/DR', 0) or 0
dr_gap = abs(odir_dr - aptos_dr)
else:
domain_gap = 0; odir_acc = 0; aptos_acc = 0; odir_dr = 0; aptos_dr = 0; dr_gap = 0
calibration_verdict = 'overconfident' if sum(
b_conf - b_acc for b_conf, b_acc in zip(bin_conf, bin_acc) if bin_count[bin_acc.index(b_acc)] > 0
) > 0 else 'underconfident'
report = f"""# RetinaSense ViT v2 — Baseline Error Analysis Report
**Generated**: 2026-03-06
**Model**: ViT-Base-Patch16-224 (MultiTaskViT)
**Checkpoint**: outputs_vit/best_model.pth
**Val Split**: {len(val_df)} samples (20% stratified, random_state=42)
---
## 1. Overall Performance
| Metric | Value |
|--------|-------|
| Accuracy (raw argmax) | {raw_acc:.2f}% |
| Accuracy (with thresholds) | {thresh_acc:.2f}% |
| Macro F1 | {macro_f1:.4f} |
| Weighted F1 | {weighted_f1:.4f} |
| ECE (10 bins) | {ece:.4f} |
---
## 2. Per-Class Metrics
| Class | Precision | Recall | F1 | Support |
|-------|-----------|--------|----|---------|
"""
for cn in CLASS_NAMES:
report += (f"| {cn:<15s} | {per_class_precision[cn]:.4f} | "
f"{per_class_recall[cn]:.4f} | {per_class_f1[cn]:.4f} | "
f"{per_class_support[cn]:4d} |\n")
report += f"""
---
## 3. Confusion Analysis — Top 5 Confused Pairs
| Rank | True Class | Predicted As | Count | Rate |
|------|-----------|-------------|-------|------|
"""
for rank, pair in enumerate(top5_pairs, 1):
report += (f"| {rank} | {pair['true_class']} | {pair['pred_class']} | "
f"{pair['count']} | {pair['rate']*100:.1f}% |\n")
report += f"""
### Full Confusion Matrix (normalized by true class)
```
{(' '.join(f'{cn[:6]:>7s}' for cn in CLASS_NAMES))}
"""
for ri, rn in enumerate(CLASS_NAMES):
row_str = ' '.join(f'{cm_norm[ri, ci]:.3f}' for ci in range(NUM_CLASSES))
report += f"{rn[:8]:>8s} {row_str}\n"
report += f"""```
---
## 4. Confidence Calibration Analysis
- **ECE (overall)**: {ece:.4f}
- **Calibration pattern**: The model is predominantly **{calibration_verdict}**
(mean confidence exceeds accuracy in most bins).
### Per-Class ECE
| Class | ECE |
|-------|-----|
"""
for cn, ece_c in per_class_ece.items():
report += f"| {cn} | {ece_c:.4f} |\n"
report += f"""
### High-Confidence Wrong Predictions (confidence > 0.8)
| Class | Total Wrong | High-Conf Wrong | % of Errors | Mean Wrong Conf |
|-------|------------|----------------|-------------|----------------|
"""
for cn, hcw in hcw_analysis.items():
report += (f"| {cn} | {hcw['total_wrong']} | {hcw['high_conf_wrong_count']} | "
f"{hcw['high_conf_wrong_pct']:.1f}% | {hcw['mean_wrong_conf']:.3f} |\n")
report += f"""
---
## 5. Dataset Source Analysis (ODIR vs APTOS)
| Source | N Samples | Accuracy | Macro F1 |
|--------|-----------|----------|----------|
"""
for src, data in per_source.items():
report += f"| {src} | {data['n_samples']} | {data['accuracy']:.2f}% | {data['macro_f1']:.4f} |\n"
report += f"""
### Per-Class Accuracy by Source
| Class |"""
for src in per_source:
report += f" {src} |"
report += "\n|-------|"
for _ in per_source:
report += "--------|"
report += "\n"
for cn in CLASS_NAMES:
report += f"| {cn} |"
for src in per_source:
acc = per_source[src]['per_class_acc'].get(cn)
if acc is None:
report += " N/A |"
else:
report += f" {acc:.1f}% |"
report += "\n"
report += f"""
**Domain gap (overall accuracy)**: {domain_gap:.2f}pp between ODIR and APTOS
"""
if 'ODIR' in per_source and 'APTOS' in per_source:
report += f"""**DR class gap (ODIR vs APTOS)**: ODIR={odir_dr:.1f}% vs APTOS={aptos_dr:.1f}% (gap={dr_gap:.1f}pp)
"""
report += f"""
---
## 6. Error Pattern Summary
### Q1: What is the model's biggest weakness?
The model's biggest weakness is classifying **{worst_f1_class}** (F1={per_class_f1[worst_f1_class]:.4f},
recall={per_class_recall[worst_f1_class]:.4f}). This class has the worst F1 score, indicating the
model struggles to both detect and correctly distinguish it from other pathologies.
The confusion matrix shows that the primary confusion pathway is:
- **{top5_pairs[0]['description']}**: {top5_pairs[0]['count']} cases ({top5_pairs[0]['rate']*100:.1f}% error rate)
- **{top5_pairs[1]['description']}**: {top5_pairs[1]['count']} cases ({top5_pairs[1]['rate']*100:.1f}% error rate)
### Q2: Which class has the worst recall? Why?
**{worst_recall_class}** has the worst recall at {per_class_recall[worst_recall_class]:.4f}.
"""
# Detailed reason based on support
worst_support = per_class_support[worst_recall_class]
all_support = sum(per_class_support.values())
worst_pct = worst_support / all_support * 100
report += f"""This class represents only {worst_support} samples ({worst_pct:.1f}% of the val set).
The low recall is likely caused by:
1. **Class imbalance** — the model sees fewer examples during training and defaults to predicting
more common classes when uncertain.
2. **Visual similarity** with other conditions (especially {top5_pairs[0]['pred_class'] if top5_pairs[0]['true_class']==worst_recall_class else 'Normal'})
at the fundus level.
3. **Threshold sensitivity** — the optimized threshold ({thresholds.get(CLASS_NAMES.index(worst_recall_class), 0.5):.2f})
may overcorrect or undercorrect depending on the calibration.
### Q3: Evidence of domain shift (ODIR vs APTOS)?
"""
if domain_gap is not None and domain_gap > 2.0:
report += f"""YES — there is a **{domain_gap:.1f}pp accuracy gap** between ODIR ({odir_acc:.1f}%) and APTOS
({aptos_acc:.1f}%). This is significant and consistent with domain shift between the two data sources.
For the DR/Diabetes class specifically, the gap is **{dr_gap:.1f}pp** (ODIR={odir_dr:.1f}% vs APTOS={aptos_dr:.1f}%).
APTOS images are specifically DR-graded fundus photographs from India (Aravind Eye Hospital),
while ODIR covers multiple disease classes with more varied image quality and capture conditions.
The Ben Graham preprocessing helps but does not fully bridge the domain gap.
**Implication for v3**: Domain-specific augmentation or source-aware training (e.g., source
as auxiliary input, separate batch norms, or domain adaptation) may improve generalization.
"""
elif domain_gap is not None and domain_gap > 0:
report += f"""MINOR gap observed — {domain_gap:.1f}pp difference between ODIR ({odir_acc:.1f}%) and
APTOS ({aptos_acc:.1f}%). The gap is small, suggesting the Ben Graham preprocessing and ViT
architecture generalize reasonably across sources. DR-specific gap: {dr_gap:.1f}pp.
"""
else:
report += "Insufficient cross-source data to conclude domain shift.\n"
report += f"""
### Q4: Calibration assessment
ECE = **{ece:.4f}** (scale: 0=perfect, 0.1=poor).
"""
if ece < 0.03:
report += "The model is **well-calibrated** (ECE < 0.03). Confidence scores are reliable."
elif ece < 0.07:
report += f"""The model shows **moderate miscalibration** (ECE={ece:.4f}). The reliability diagram
shows the model is {calibration_verdict} in the high-confidence range, meaning predicted
confidence scores are not fully reliable. Temperature scaling in v3 is recommended."""
else:
report += f"""The model is **poorly calibrated** (ECE={ece:.4f}). The {calibration_verdict}
pattern is severe. Temperature scaling or label smoothing in v3 training is strongly recommended."""
report += f"""
---
## 7. Recommendations for v3 Training
Based on this baseline analysis:
1. **Address {worst_recall_class} recall** — increase class weight, targeted augmentation,
or focal loss gamma tuning for this class.
2. **Calibration** — add temperature scaling post-training or increase label smoothing
(current ECE={ece:.4f}).
3. **Domain shift mitigation** — consider source-conditioned augmentation or adversarial
domain adaptation if ODIR/APTOS gap persists.
4. **High-confidence errors** — the model makes confidently wrong predictions on certain
classes; mixup or CutMix augmentation may improve uncertainty estimation.
5. **Top confusion pairs** to specifically target:
"""
for pair in top5_pairs[:3]:
report += f" - {pair['description']} ({pair['count']} errors)\n"
report += """
---
## 8. Output Files
| File | Description |
|------|-------------|
| confusion_matrix_raw.png | Raw count confusion matrix |
| confusion_matrix_normalized.png | Recall-normalized confusion matrix |
| reliability_diagram.png | ECE calibration plot |
| confidence_distributions.png | Per-class confidence histograms |
| per_source_performance.png | ODIR vs APTOS breakdown |
| baseline_metrics.json | All metrics in structured JSON |
---
*Report generated by RetinaSense ViT v2 error analysis pipeline.*
"""
with open(f'{OUT_DIR}/BASELINE_ANALYSIS.md', 'w') as f:
f.write(report)
print(f' Saved BASELINE_ANALYSIS.md')
# ================================================================
# FINAL SUMMARY
# ================================================================
print('\n' + '='*65)
print(' BASELINE ANALYSIS COMPLETE')
print('='*65)
print(f' Val accuracy (thresh) : {thresh_acc:.2f}%')
print(f' Macro F1 : {macro_f1:.4f}')
print(f' ECE : {ece:.4f}')
print(f' Worst class (F1) : {worst_f1_class} ({per_class_f1[worst_f1_class]:.4f})')
print(f' Worst class (recall) : {worst_recall_class} ({per_class_recall[worst_recall_class]:.4f})')
print(f' Top confusion : {top5_pairs[0]["description"]}')
if domain_gap is not None:
print(f' Domain gap (ODIR-APTOS): {domain_gap:.2f}pp')
print(f'\n All outputs in: {OUT_DIR}/')
print('='*65)