Add run_error_analysis.py
Browse files- run_error_analysis.py +977 -0
run_error_analysis.py
ADDED
|
@@ -0,0 +1,977 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
RetinaSense ViT v2 - Comprehensive Error Analysis & Baseline Report
|
| 4 |
+
===================================================================
|
| 5 |
+
Runs full evaluation on the validation split, computes ECE,
|
| 6 |
+
confusion analysis, confidence distributions, and source-level
|
| 7 |
+
performance. Saves all plots and metrics to outputs_analysis/v2_baseline/.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os, sys, json, warnings
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import matplotlib
|
| 14 |
+
matplotlib.use('Agg')
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
import matplotlib.gridspec as gridspec
|
| 17 |
+
import seaborn as sns
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
warnings.filterwarnings('ignore')
|
| 21 |
+
|
| 22 |
+
import cv2
|
| 23 |
+
from PIL import Image
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
from torchvision import transforms
|
| 29 |
+
from torch.utils.data import Dataset, DataLoader
|
| 30 |
+
|
| 31 |
+
import timm
|
| 32 |
+
from sklearn.model_selection import train_test_split
|
| 33 |
+
from sklearn.metrics import (
|
| 34 |
+
confusion_matrix, classification_report,
|
| 35 |
+
f1_score, precision_score, recall_score
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# ================================================================
|
| 39 |
+
# CONFIG
|
| 40 |
+
# ================================================================
|
| 41 |
+
# Workspace root; every input and output location below is derived from it.
BASE_DIR = '/teamspace/studios/this_studio'
MODEL_PATH = f'{BASE_DIR}/outputs_vit/best_model.pth'
META_CSV = f'{BASE_DIR}/final_unified_metadata.csv'
THRESH_JSON = f'{BASE_DIR}/outputs_vit/threshold_optimization_results.json'
CACHE_DIR = f'{BASE_DIR}/preprocessed_cache_vit'
OUT_DIR = f'{BASE_DIR}/outputs_analysis/v2_baseline'

IMG_SIZE = 224        # default side length used by ben_graham and the zero-image fallback
BATCH_SIZE = 64       # DataLoader batch size
NUM_WORKERS = 8       # DataLoader worker processes
NUM_CLASSES = 5
CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']

# Both folders are written to later, so create them up front.
for _required_dir in (OUT_DIR, CACHE_DIR):
    os.makedirs(_required_dir, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')
|
| 61 |
+
|
| 62 |
+
# ================================================================
|
| 63 |
+
# MODEL DEFINITION (mirrors retinasense_vit.py)
|
| 64 |
+
# ================================================================
|
| 65 |
+
class MultiTaskViT(nn.Module):
    """ViT backbone with a shared dropout and two heads (disease, severity).

    The attribute names and the order of layers inside each ``nn.Sequential``
    determine the state-dict keys, so they must stay as they are for the
    saved checkpoint to load.

    Args:
        n_disease: number of outputs of the disease head.
        n_severity: number of outputs of the severity head.
        drop: dropout probability applied to the backbone output.
    """

    def __init__(self, n_disease=5, n_severity=5, drop=0.4):
        super().__init__()
        # num_classes=0 asks timm for the model without its classifier layer.
        self.backbone = timm.create_model(
            'vit_base_patch16_224', pretrained=False, num_classes=0)
        embed_dim = 768  # input width expected by both heads
        self.drop = nn.Dropout(drop)

        disease_layers = [
            nn.Linear(embed_dim, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, n_disease),
        ]
        self.disease_head = nn.Sequential(*disease_layers)

        severity_layers = [
            nn.Linear(embed_dim, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, n_severity),
        ]
        self.severity_head = nn.Sequential(*severity_layers)

    def forward(self, x):
        """Return ``(disease_logits, severity_logits)`` for the batch ``x``."""
        features = self.drop(self.backbone(x))
        disease_logits = self.disease_head(features)
        severity_logits = self.severity_head(features)
        return disease_logits, severity_logits
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# ================================================================
|
| 87 |
+
# IMAGE PREPROCESSING (Ben Graham method, matches training)
|
| 88 |
+
# ================================================================
|
| 89 |
+
def ben_graham(path, sz=IMG_SIZE, sigma=10):
    """Load an image and apply the Ben Graham style preprocessing.

    Steps: read the file, convert to RGB, resize to ``sz`` x ``sz``, subtract
    a Gaussian-blurred copy (weights 4 / -4, offset 128) and black out
    everything outside a centred circle of radius ``0.48 * sz``.

    Args:
        path: image location; converted with ``str()`` before reading.
        sz: side length of the square output.
        sigma: standard deviation passed to the Gaussian blur.

    Returns:
        The masked image produced by ``cv2.bitwise_and``.
    """
    src = str(path)
    picture = cv2.imread(src)
    if picture is None:
        # OpenCV could not decode the file: read it with PIL and bring it to
        # the BGR layout the conversion below expects.
        picture = np.array(Image.open(src).convert('RGB'))
        picture = cv2.cvtColor(picture, cv2.COLOR_RGB2BGR)
    picture = cv2.cvtColor(picture, cv2.COLOR_BGR2RGB)
    picture = cv2.resize(picture, (sz, sz))

    blurred = cv2.GaussianBlur(picture, (0, 0), sigma)
    picture = cv2.addWeighted(picture, 4, blurred, -4, 128)

    centre = (sz // 2, sz // 2)
    radius = int(sz * 0.48)
    disc = np.zeros(picture.shape[:2], dtype=np.uint8)
    cv2.circle(disc, centre, radius, 255, -1)
    return cv2.bitwise_and(picture, picture, mask=disc)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def resolve_image_path(raw_path):
    """Map a metadata image path to a file that exists on disk.

    The path from the CSV (which may start with ``./`` or ``.//``) is first
    tried relative to ``BASE_DIR``.  Paths mentioning ``aptos`` are also
    looked up by file stem in
    ``aptos/gaussian_filtered_images/gaussian_filtered_images/{Severity}/``
    and in ``aptos/train_images/``; paths mentioning ``odir`` are also looked
    up by file name in the two ``preprocessed_images`` folders.

    Args:
        raw_path: path string taken from the metadata.

    Returns:
        The first candidate that exists, or ``None`` when none does.
    """
    # str.lstrip removes every leading occurrence, so one call per character suffices.
    relative = raw_path.lstrip('.').lstrip('/').replace('//', '/')
    lowered = raw_path.lower()
    file_stem = Path(raw_path).stem
    image_exts = ('.png', '.jpg', '.jpeg')

    search_paths = [f'{BASE_DIR}/{relative}']

    # APTOS: search all severity subfolders, then the original train_images folder
    if 'aptos' in lowered:
        aptos_base = f'{BASE_DIR}/aptos/gaussian_filtered_images/gaussian_filtered_images'
        search_paths += [
            f'{aptos_base}/{grade}/{file_stem}{ext}'
            for grade in ('No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferate_DR')
            for ext in image_exts
        ]
        search_paths += [
            f'{BASE_DIR}/aptos/train_images/{file_stem}{ext}' for ext in image_exts
        ]

    # ODIR: preprocessed_images
    if 'odir' in lowered:
        file_name = Path(raw_path).name
        search_paths += [
            f'{BASE_DIR}/odir/preprocessed_images/{file_name}',
            f'{BASE_DIR}/ocular-disease-recognition-odir5k/preprocessed_images/{file_name}',
        ]

    return next((p for p in search_paths if os.path.exists(p)), None)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def load_or_cache(row):
    """Return the preprocessed image for one metadata row.

    Lookup order: the ``.npy`` file in ``CACHE_DIR`` named after the stem of
    ``row['image_path_clean']``; otherwise the image at
    ``row['image_path_resolved']`` processed with ``ben_graham`` (and then
    written to the cache); otherwise an all-zero image.

    Every fallback is reported on stderr so that missing or corrupt files do
    not silently turn into black images in the evaluation.

    Args:
        row: mapping-like record (a DataFrame row when called from ``RetDS``)
            with ``'image_path_clean'`` and, optionally,
            ``'image_path_resolved'``.

    Returns:
        The cached or freshly processed image array, or a zero-filled
        ``IMG_SIZE`` x ``IMG_SIZE`` x 3 uint8 array when nothing is usable.
    """
    stem = Path(row['image_path_clean']).stem
    # The cache key is the file stem only, which keeps existing cache files valid.
    cache_fp = f'{CACHE_DIR}/{stem}_224.npy'

    if os.path.exists(cache_fp):
        try:
            return np.load(cache_fp)
        except Exception as exc:
            # Unreadable cache entry: report it and rebuild from the source image.
            print(f'[load_or_cache] unreadable cache file {cache_fp}: {exc!r}',
                  file=sys.stderr)

    img_path = row.get('image_path_resolved')
    # An unresolved path may arrive as None or NaN; only real paths reach the filesystem.
    if isinstance(img_path, (str, os.PathLike)) and img_path and os.path.exists(img_path):
        try:
            arr = ben_graham(img_path)
        except Exception as exc:
            print(f'[load_or_cache] preprocessing failed for {img_path}: {exc!r}',
                  file=sys.stderr)
        else:
            try:
                np.save(cache_fp, arr)
            except OSError as exc:
                # A failed cache write must not discard a correctly processed image.
                print(f'[load_or_cache] could not write cache file {cache_fp}: {exc!r}',
                      file=sys.stderr)
            return arr

    # Fallback: zero image
    print(f'[load_or_cache] no usable image for {stem!r}; substituting a zero image',
          file=sys.stderr)
    return np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# ================================================================
|
| 171 |
+
# DATASET
|
| 172 |
+
# ================================================================
|
| 173 |
+
# Evaluation-time transform applied to each array returned by load_or_cache:
# array -> PIL image -> tensor -> per-channel normalisation.
# NOTE(review): the mean/std values are presumably the ones used in training — confirm.
val_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class RetDS(Dataset):
    """Dataset over a metadata frame yielding transformed images and labels."""

    def __init__(self, df):
        # Work on a copy indexed 0..n-1 so positions and labels coincide.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, i):
        """Return ``(image, disease_label, severity_label, i)`` for row ``i``."""
        record = self.df.iloc[i]
        image = val_transform(load_or_cache(record))
        disease = torch.tensor(int(record['disease_label']), dtype=torch.long)
        severity = torch.tensor(int(record['severity_label']), dtype=torch.long)
        # The position is handed back so per-sample metadata can be looked up later.
        return image, disease, severity, i
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# ================================================================
|
| 199 |
+
# STEP 1 — LOAD METADATA & BUILD VAL SPLIT
|
| 200 |
+
# ================================================================
|
| 201 |
+
print('\n[1/6] Loading metadata and building val split...')

meta = pd.read_csv(META_CSV)
print(f' Raw rows: {len(meta)}')

# Fix image paths: drop the leading "./" style prefix and collapse "//".
_raw_paths = meta['image_path']
meta['image_path_clean'] = (
    _raw_paths.str.lstrip('.')
    .str.lstrip('/')
    .str.replace('//', '/', regex=False)
)
meta['image_path_resolved'] = meta['image_path_clean'].apply(resolve_image_path)

n_resolved = meta['image_path_resolved'].notna().sum()
print(f' Images resolved on disk: {n_resolved} / {len(meta)}')

# Build the same stratified split used in training (random_state=42, test_size=0.2)
train_df, val_df = train_test_split(
    meta,
    test_size=0.2,
    stratify=meta['disease_label'],
    random_state=42,
)
val_df = val_df.reset_index(drop=True)
print(f' Val split: {len(val_df)} samples')
print(' Val class distribution:')
_val_counts = val_df['disease_label'].value_counts().sort_index()
for lbl, cnt in _val_counts.items():
    print(f' {CLASS_NAMES[int(lbl)]:<15s}: {cnt:4d}')
|
| 227 |
+
|
| 228 |
+
# ================================================================
|
| 229 |
+
# STEP 2 — LOAD MODEL
|
| 230 |
+
# ================================================================
|
| 231 |
+
print('\n[2/6] Loading model...')

model = MultiTaskViT().to(device)
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
# so MODEL_PATH must only ever point at a trusted checkpoint.
ckpt = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()
_ckpt_epoch = ckpt.get("epoch", "?")
_ckpt_f1 = ckpt.get("macro_f1", 0)
print(f' Loaded checkpoint: epoch={_ckpt_epoch}, macro_f1={_ckpt_f1:.4f}')

# Load thresholds (JSON object keys are strings, so convert them to int class ids)
with open(THRESH_JSON) as f:
    thresh_data = json.load(f)
thresholds = {}
for _class_key, _cutoff in thresh_data['optimal_thresholds'].items():
    thresholds[int(_class_key)] = float(_cutoff)
print(f' Optimal thresholds: {thresholds}')
|
| 245 |
+
|
| 246 |
+
# ================================================================
|
| 247 |
+
# STEP 3 — RUN INFERENCE
|
| 248 |
+
# ================================================================
|
| 249 |
+
print('\n[3/6] Running inference on val set...')

# Pinned host memory and CUDA autocast are only meaningful when the model
# actually runs on a GPU; on CPU both are switched off instead of being
# requested for a device that is not in use.
use_cuda = (device.type == 'cuda')

val_ds = RetDS(val_df)
val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=use_cuda
)

all_probs = []   # (N, 5) softmax probabilities
all_preds = []   # (N,) argmax predictions
all_labels = []  # (N,) true labels
all_idxs = []    # (N,) val_df indices

with torch.no_grad():
    for imgs, d_lbl, s_lbl, idx in tqdm(val_loader, desc='Inference'):
        imgs = imgs.to(device, non_blocking=True)
        with torch.amp.autocast(device_type=device.type, enabled=use_cuda):
            d_out, _ = model(imgs)  # the severity output is not used here
        # Softmax is taken in float32 regardless of the autocast dtype.
        probs = torch.softmax(d_out.float(), dim=1).cpu().numpy()
        preds = d_out.argmax(1).cpu().numpy()
        all_probs.append(probs)
        all_preds.append(preds)
        all_labels.append(d_lbl.numpy())
        all_idxs.append(idx.numpy())

all_probs = np.vstack(all_probs)  # (N, 5)
all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)
all_idxs = np.concatenate(all_idxs)

# Also compute threshold-adjusted predictions: each class probability is
# divided by that class's threshold (classes without one keep a divisor of 1)
# and the largest scaled score wins.  Done for all samples at once.
class_scale = np.ones(all_probs.shape[1], dtype=np.float64)
for c, t in thresholds.items():
    class_scale[c] = t
thresh_preds = (all_probs / class_scale).argmax(axis=1).astype(all_preds.dtype)

raw_acc = (all_preds == all_labels).mean() * 100
thresh_acc = (thresh_preds == all_labels).mean() * 100
print(f' Raw accuracy : {raw_acc:.2f}%')
print(f' Threshold accuracy: {thresh_acc:.2f}%')

# Use threshold-adjusted for main analysis (matches published 84.48%)
preds = thresh_preds
|
| 294 |
+
|
| 295 |
+
# ================================================================
|
| 296 |
+
# STEP 4 — CONFIDENCE CALIBRATION (ECE)
|
| 297 |
+
# ================================================================
|
| 298 |
+
print('\n[4/6] Computing ECE and reliability diagram...')
|
| 299 |
+
|
| 300 |
+
def compute_ece(probs, labels, n_bins=10):
    """Expected Calibration Error with equal-width confidence bins.

    The confidence of a sample is its largest probability and its prediction
    is the argmax.  Samples are grouped into ``n_bins`` equal-width bins over
    [0, 1]; ECE is the sample-weighted mean of ``|accuracy - confidence|``
    over the non-empty bins.  Bins are half-open ``[lo, hi)`` except the last
    one, which is closed so that a confidence of exactly 1.0 is counted.

    Args:
        probs: 2-D array of per-class probabilities, one row per sample.
        labels: 1-D array of true class indices, aligned with ``probs``.
        n_bins: number of equal-width bins.

    Returns:
        Tuple ``(ece, bin_acc, bin_conf, bin_count, bins)``: the ECE value,
        per-bin accuracy (0.0 for empty bins), per-bin mean confidence (the
        bin midpoint for empty bins), per-bin sample counts, and the array of
        ``n_bins + 1`` bin edges.
    """
    confidences = probs.max(axis=1)   # max probability = confidence
    predicted = probs.argmax(axis=1)
    correct = (predicted == labels).astype(float)

    bins = np.linspace(0, 1, n_bins + 1)
    n_total = len(labels)
    last_bin = n_bins - 1

    ece = 0.0
    bin_acc = []
    bin_conf = []
    bin_count = []

    for b, (lo, hi) in enumerate(zip(bins[:-1], bins[1:])):
        mask = confidences >= lo
        if b == last_bin:
            # Closed upper edge: saturated predictions (confidence == 1.0)
            # belong to the top bin instead of being dropped.
            mask &= confidences <= hi
        else:
            mask &= confidences < hi

        n = int(mask.sum())
        if n == 0:
            bin_acc.append(0.0)
            bin_conf.append((lo + hi) / 2)
            bin_count.append(0)
            continue

        acc = correct[mask].mean()
        conf = confidences[mask].mean()
        ece += (n / n_total) * abs(acc - conf)
        bin_acc.append(acc)
        bin_conf.append(conf)
        bin_count.append(n)

    return ece, bin_acc, bin_conf, bin_count, bins
|
| 328 |
+
|
| 329 |
+
ece, bin_acc, bin_conf, bin_count, bins = compute_ece(all_probs, all_labels)
print(f' ECE (10 bins): {ece:.4f}')

# Per-class calibration: ECE restricted to the samples whose true label is c.
per_class_ece = {}
for c in range(NUM_CLASSES):
    mask = (all_labels == c)
    if mask.sum() > 0:
        ece_c = compute_ece(all_probs[mask], all_labels[mask])[0]
        per_class_ece[CLASS_NAMES[c]] = float(ece_c)
        print(f' ECE {CLASS_NAMES[c]:<15s}: {ece_c:.4f}')
    else:
        per_class_ece[CLASS_NAMES[c]] = 0.0

# -- Reliability diagram --
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_rel, ax_gap = axes[0], axes[1]

bin_centers = (bins[:-1] + bins[1:]) / 2
bar_width = (bins[1] - bins[0]) * 0.9

# Left panel: accuracy per confidence bin against the diagonal.
bars = ax_rel.bar(
    bin_centers, bin_acc,
    width=bar_width,
    alpha=0.7, color='steelblue', label='Accuracy per bin'
)
ax_rel.plot([0, 1], [0, 1], 'r--', lw=2, label='Perfect calibration')
ax_rel.set_xlabel('Confidence', fontsize=12)
ax_rel.set_ylabel('Accuracy', fontsize=12)
ax_rel.set_title(f'Reliability Diagram\nECE = {ece:.4f}', fontsize=13, fontweight='bold')
ax_rel.legend(fontsize=10)
ax_rel.grid(alpha=0.3)
ax_rel.set_xlim(0, 1)
ax_rel.set_ylim(0, 1)

# Annotate with bin counts (empty bins get no label)
for bar, cnt in zip(bars, bin_count):
    if cnt <= 0:
        continue
    label_x = bar.get_x() + bar.get_width() / 2
    label_y = min(bar.get_height() + 0.02, 0.97)  # keep the text inside the axes
    ax_rel.text(label_x, label_y, str(cnt),
                ha='center', va='bottom', fontsize=7, color='black')

# Gap diagram (overconfidence = positive gap)
gap = np.array(bin_conf) - np.array(bin_acc)
color_gap = ['#e74c3c' if g > 0 else '#2ecc71' for g in gap]
ax_gap.bar(bin_centers, gap, width=bar_width, color=color_gap, alpha=0.8)
ax_gap.axhline(0, color='black', lw=1)
ax_gap.set_xlabel('Confidence', fontsize=12)
ax_gap.set_ylabel('Confidence - Accuracy (Gap)', fontsize=12)
ax_gap.set_title('Calibration Gap\n(Red=overconfident, Green=underconfident)',
                 fontsize=13, fontweight='bold')
ax_gap.grid(alpha=0.3)
ax_gap.set_xlim(0, 1)

plt.tight_layout()
plt.savefig(f'{OUT_DIR}/reliability_diagram.png', dpi=150, bbox_inches='tight')
plt.close()
print(f' Saved reliability_diagram.png')
|
| 384 |
+
|
| 385 |
+
# ================================================================
|
| 386 |
+
# STEP 5 — CONFUSION MATRIX
|
| 387 |
+
# ================================================================
|
| 388 |
+
print('\n[5/6] Generating confusion matrices...')

# Pin the label set so the matrix is always NUM_CLASSES x NUM_CLASSES, even
# when some class never occurs in the labels or the predictions; the tick
# labels and the index loops below rely on that shape.
cm_raw = confusion_matrix(all_labels, preds, labels=list(range(NUM_CLASSES)))

# Row-normalise by true-class support; rows without samples stay at 0
# instead of becoming NaN through a division by zero.
row_totals = cm_raw.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm_raw.astype(float), row_totals,
    out=np.zeros(cm_raw.shape, dtype=float), where=row_totals > 0,
)

# -- Raw counts confusion matrix --
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    cm_raw, annot=True, fmt='d', cmap='Blues', ax=ax,
    xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
    linewidths=0.5, linecolor='gray'
)
ax.set_title('Confusion Matrix (Raw Counts)', fontsize=14, fontweight='bold')
ax.set_ylabel('True Label', fontsize=12)
ax.set_xlabel('Predicted Label', fontsize=12)
plt.xticks(rotation=30, ha='right')
plt.tight_layout()
plt.savefig(f'{OUT_DIR}/confusion_matrix_raw.png', dpi=150, bbox_inches='tight')
plt.close()

# -- Normalized confusion matrix --
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    cm_norm, annot=True, fmt='.3f', cmap='Blues', ax=ax,
    xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
    linewidths=0.5, linecolor='gray', vmin=0, vmax=1
)
ax.set_title('Confusion Matrix (Normalized by True Class)', fontsize=14, fontweight='bold')
ax.set_ylabel('True Label', fontsize=12)
ax.set_xlabel('Predicted Label', fontsize=12)
plt.xticks(rotation=30, ha='right')
plt.tight_layout()
plt.savefig(f'{OUT_DIR}/confusion_matrix_normalized.png', dpi=150, bbox_inches='tight')
plt.close()
print(' Saved confusion_matrix_raw.png and confusion_matrix_normalized.png')

# -- Top confused pairs: every off-diagonal cell, ranked by raw count --
confused_pairs = []
for true_c in range(NUM_CLASSES):
    for pred_c in range(NUM_CLASSES):
        if true_c == pred_c:
            continue
        confused_pairs.append({
            'true_class': CLASS_NAMES[true_c],
            'pred_class': CLASS_NAMES[pred_c],
            'count': int(cm_raw[true_c, pred_c]),
            'rate': float(cm_norm[true_c, pred_c]),
            'description': f'{CLASS_NAMES[true_c]} misclassified AS {CLASS_NAMES[pred_c]}'
        })
confused_pairs.sort(key=lambda x: x['count'], reverse=True)
top5_pairs = confused_pairs[:5]

print('\n Top 5 confused class pairs (by raw count):')
for p in top5_pairs:
    print(f' {p["description"]}: {p["count"]} ({p["rate"]*100:.1f}%)')
|
| 445 |
+
|
| 446 |
+
# ================================================================
|
| 447 |
+
# STEP 6 — PER-CLASS METRICS
|
| 448 |
+
# ================================================================
|
| 449 |
+
print('\n[6/6] Computing per-class metrics...')

# Passing the full label set keeps target_names aligned with the classes even
# when one of them is absent from both the labels and the predictions
# (without it classification_report rejects the length mismatch).
_class_ids = list(range(NUM_CLASSES))
report_dict = classification_report(
    all_labels, preds, labels=_class_ids, target_names=CLASS_NAMES,
    output_dict=True, zero_division=0
)
print(classification_report(
    all_labels, preds, labels=_class_ids, target_names=CLASS_NAMES,
    digits=4, zero_division=0
))

# Per-class views of the report, keyed by class name.
per_class_precision = {cn: report_dict[cn]['precision'] for cn in CLASS_NAMES}
per_class_recall = {cn: report_dict[cn]['recall'] for cn in CLASS_NAMES}
per_class_f1 = {cn: report_dict[cn]['f1-score'] for cn in CLASS_NAMES}
per_class_support = {cn: int(report_dict[cn]['support']) for cn in CLASS_NAMES}

# NOTE(review): some scikit-learn releases appear to report 'micro avg'
# instead of 'accuracy' when `labels` contains an absent class — confirm.
# The fallback is the plain fraction of matching predictions.
overall_accuracy = report_dict.get(
    'accuracy', float((preds == all_labels).mean())
) * 100
macro_f1 = report_dict['macro avg']['f1-score']
weighted_f1 = report_dict['weighted avg']['f1-score']

print(f'\n Overall accuracy : {overall_accuracy:.2f}%')
print(f' Macro F1 : {macro_f1:.4f}')
print(f' Weighted F1 : {weighted_f1:.4f}')
|
| 474 |
+
|
| 475 |
+
# ================================================================
|
| 476 |
+
# CONFIDENCE DISTRIBUTION ANALYSIS
|
| 477 |
+
# ================================================================
|
| 478 |
+
print('\nAnalyzing confidence distributions...')

fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes = axes.flatten()

# Confidence is the largest softmax probability; correctness is judged
# against `preds`, the predictions selected for the main analysis.
all_max_conf = all_probs.max(axis=1)
all_correct = (preds == all_labels)

# One histogram panel per class: confidence of correct vs wrong predictions.
for ci, cn in enumerate(CLASS_NAMES):
    ax = axes[ci]
    mask_class = (all_labels == ci)

    correct_conf = all_max_conf[mask_class & all_correct]
    wrong_conf = all_max_conf[mask_class & ~all_correct]

    n_correct = len(correct_conf)
    n_wrong = len(wrong_conf)

    if n_correct > 0:
        ax.hist(correct_conf, bins=20, alpha=0.6, color='#2ecc71',
                label=f'Correct (n={n_correct})', density=True)
    if n_wrong > 0:
        ax.hist(wrong_conf, bins=20, alpha=0.6, color='#e74c3c',
                label=f'Wrong (n={n_wrong})', density=True)

        # Mark high-confidence wrong predictions
        high_conf_wrong = (wrong_conf > 0.8).sum()
        ax.axvline(0.8, color='darkred', linestyle='--', alpha=0.7, lw=1.5,
                   label=f'Conf>0.8 wrong: {high_conf_wrong}')

    ax.set_title(f'{cn}\nPrec={per_class_precision[cn]:.3f} Rec={per_class_recall[cn]:.3f} F1={per_class_f1[cn]:.3f}',
                 fontsize=10, fontweight='bold')
    ax.set_xlabel('Max Confidence', fontsize=9)
    ax.set_ylabel('Density', fontsize=9)
    ax.legend(fontsize=7)
    ax.grid(alpha=0.3)
    ax.set_xlim(0, 1)

# Summary panel: mean confidence of correct vs wrong predictions per class.
# Each mean is guarded by the size of the exact subset it averages, so a
# class with samples but no correct (or no wrong) predictions yields 0
# rather than the NaN produced by averaging an empty selection.
ax = axes[5]
mean_correct = []
mean_wrong = []
for c in range(NUM_CLASSES):
    in_class = (all_labels == c)
    hits = in_class & all_correct
    misses = in_class & ~all_correct
    mean_correct.append(all_max_conf[hits].mean() if hits.any() else 0)
    mean_wrong.append(all_max_conf[misses].mean() if misses.any() else 0)

x = np.arange(NUM_CLASSES)
width = 0.35
ax.bar(x - width/2, mean_correct, width, label='Mean conf (correct)', color='#2ecc71', alpha=0.8)
ax.bar(x + width/2, mean_wrong, width, label='Mean conf (wrong)', color='#e74c3c', alpha=0.8)
ax.set_xticks(x)
ax.set_xticklabels([c[:6] for c in CLASS_NAMES], rotation=20)
ax.set_ylabel('Mean Confidence')
ax.set_title('Mean Confidence: Correct vs Wrong', fontweight='bold')
ax.legend(fontsize=8)
ax.grid(alpha=0.3, axis='y')
ax.set_ylim(0, 1)

plt.suptitle('Confidence Distribution Analysis per Class', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(f'{OUT_DIR}/confidence_distributions.png', dpi=150, bbox_inches='tight')
plt.close()
print(' Saved confidence_distributions.png')
|
| 541 |
+
|
| 542 |
+
# ================================================================
|
| 543 |
+
# PER-SOURCE ANALYSIS
|
| 544 |
+
# ================================================================
|
| 545 |
+
print('\nRunning per-source analysis...')

# Attach dataset source to val_df indices
source_col = val_df['dataset'].values

# One row per evaluated sample; all_idxs maps each output back to its val_df row.
results_df = pd.DataFrame({
    'true_label': all_labels,
    'pred_label': preds,
    'max_conf': all_max_conf,
    'dataset': source_col[all_idxs],
    'correct': (preds == all_labels).astype(int),
})

per_source = {}
for src in ['ODIR', 'APTOS']:
    mask = results_df['dataset'] == src
    src_count = int(mask.sum())
    if src_count == 0:
        continue  # this source does not occur in the evaluated samples

    src_true = results_df.loc[mask, 'true_label'].values
    src_pred = results_df.loc[mask, 'pred_label'].values
    src_acc = (src_true == src_pred).mean() * 100
    src_f1 = f1_score(src_true, src_pred, average='macro', zero_division=0)

    # Accuracy per true class within this source; None when the class is absent.
    per_class_acc_src = {
        CLASS_NAMES[c]: (
            float((src_pred[src_true == c] == c).mean() * 100)
            if (src_true == c).any() else None
        )
        for c in range(NUM_CLASSES)
    }

    per_source[src] = {
        'n_samples': src_count,
        'accuracy': float(src_acc),
        'macro_f1': float(src_f1),
        'per_class_acc': per_class_acc_src
    }
    print(f'\n {src} (n={src_count}):')
    print(f' Accuracy : {src_acc:.2f}%')
    print(f' Macro F1 : {src_f1:.4f}')
    for cn, acc in per_class_acc_src.items():
        if acc is None:
            continue
        print(f' {cn:<15s}: {acc:.1f}%')
|
| 588 |
+
|
| 589 |
+
# -- Per-source performance plot --
# Left panel: overall accuracy and macro F1 (scaled to percent) per source.
# Right panel: per-class accuracy with one bar per source inside each class group.
# The figure is skipped when no source produced any samples; otherwise the
# bar-width computation below would divide by zero.
sources = list(per_source.keys())

if not sources:
    print(' Skipped per_source_performance.png (no per-source results available)')
else:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Overall bar
    accs = [per_source[s]['accuracy'] for s in sources]
    f1s = [per_source[s]['macro_f1'] for s in sources]

    x = np.arange(len(sources))
    w = 0.35
    axes[0].bar(x - w/2, accs, w, label='Accuracy (%)', color=['#3498db', '#e67e22'], alpha=0.85)
    axes[0].bar(x + w/2, [f*100 for f in f1s], w, label='Macro F1 ×100',
                color=['#2ecc71', '#e74c3c'], alpha=0.85)
    axes[0].set_xticks(x)
    axes[0].set_xticklabels(sources)
    axes[0].set_ylim(50, 100)
    axes[0].set_ylabel('Score')
    axes[0].set_title('Overall Performance by Source', fontweight='bold')
    axes[0].legend()
    axes[0].grid(alpha=0.3, axis='y')
    # Value labels just above each bar
    for xi, (acc, f1) in enumerate(zip(accs, f1s)):
        axes[0].text(xi - w/2, acc + 0.5, f'{acc:.1f}', ha='center', fontsize=9)
        axes[0].text(xi + w/2, f1*100 + 0.5, f'{f1*100:.1f}', ha='center', fontsize=9)

    # Per-class accuracy by source
    # A class missing from a source (stored as None) is drawn as a 0-height bar.
    class_data = {cn: [] for cn in CLASS_NAMES}
    for src in sources:
        for cn in CLASS_NAMES:
            acc = per_source[src]['per_class_acc'].get(cn)
            class_data[cn].append(acc if acc is not None else 0.0)

    x = np.arange(len(CLASS_NAMES))
    n_src = len(sources)
    width = 0.8 / n_src  # n_src >= 1 is guaranteed by the guard above
    colors_src = ['#3498db', '#e67e22', '#2ecc71']

    for si, src in enumerate(sources):
        vals = [class_data[cn][si] for cn in CLASS_NAMES]
        # Centre the group of per-source bars on each class tick
        offset = (si - n_src/2 + 0.5) * width
        axes[1].bar(x + offset, vals, width, label=src, alpha=0.85, color=colors_src[si])

    axes[1].set_xticks(x)
    axes[1].set_xticklabels(CLASS_NAMES, rotation=20, ha='right')
    axes[1].set_ylim(0, 105)
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].set_title('Per-Class Accuracy by Source', fontweight='bold')
    axes[1].legend()
    axes[1].grid(alpha=0.3, axis='y')

    plt.suptitle('Dataset Source Performance Analysis', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f'{OUT_DIR}/per_source_performance.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(' Saved per_source_performance.png')
|
| 641 |
+
|
| 642 |
+
# ================================================================
# SAVE METRICS JSON
# ================================================================
# Collects the metrics computed above into one dict and writes it to
# OUT_DIR/baseline_metrics.json.
print('\nSaving metrics JSON...')

# Per-source entries are copied field by field so only these four keys are exported.
per_source_export = {
    name: {
        'accuracy': stats['accuracy'],
        'macro_f1': stats['macro_f1'],
        'n_samples': stats['n_samples'],
        'per_class_acc': stats['per_class_acc']
    }
    for name, stats in per_source.items()
}

# Reliability-diagram data; bin values are cast to plain floats for serialization.
calibration_export = {
    'ece': float(ece),
    'bin_acc': list(map(float, bin_acc)),
    'bin_conf': list(map(float, bin_conf)),
    'bin_count': bin_count,
}

baseline_metrics = {
    'overall_accuracy': float(overall_accuracy),
    'raw_accuracy': float(raw_acc),
    'threshold_accuracy': float(thresh_acc),
    'macro_f1': float(macro_f1),
    'weighted_f1': float(weighted_f1),
    'ece': float(ece),
    'per_class_ece': per_class_ece,
    'per_class_f1': per_class_f1,
    'per_class_precision': per_class_precision,
    'per_class_recall': per_class_recall,
    'per_class_support': per_class_support,
    'per_source_accuracy': per_source_export,
    'top_confusion_pairs': top5_pairs,
    'confusion_matrix_raw': cm_raw.tolist(),
    'val_split_size': len(val_df),
    'thresholds_used': thresholds,
    'calibration': calibration_export
}

metrics_path = f'{OUT_DIR}/baseline_metrics.json'
with open(metrics_path, 'w') as metrics_file:
    json.dump(baseline_metrics, metrics_file, indent=2)
print(' Saved baseline_metrics.json')
|
| 683 |
+
|
| 684 |
+
# ================================================================
# ANALYSIS REPORT
# ================================================================
# Derives the headline findings used by the markdown report: worst/best
# classes, high-confidence error statistics, the ODIR/APTOS accuracy gap and
# an over/under-confidence verdict.
print('\nGenerating analysis report...')

# Identify key findings
worst_recall_class = min(per_class_recall, key=per_class_recall.get)
worst_f1_class = min(per_class_f1, key=per_class_f1.get)
best_f1_class = max(per_class_f1, key=per_class_f1.get)

# High-confidence wrong predictions per class
# "High confidence" means max confidence above 0.8 on a misclassified sample.
hcw_analysis = {}
for ci, cn in enumerate(CLASS_NAMES):
    wrong_mask = (all_labels == ci) & ~all_correct
    n_wrong = int(wrong_mask.sum())
    if n_wrong == 0:
        hcw_analysis[cn] = {'total_wrong': 0, 'high_conf_wrong_count': 0,
                            'high_conf_wrong_pct': 0, 'mean_wrong_conf': 0}
        continue
    n_high_conf_wrong = int(((all_max_conf > 0.8) & wrong_mask).sum())
    hcw_analysis[cn] = {
        'total_wrong': n_wrong,
        'high_conf_wrong_count': n_high_conf_wrong,
        'high_conf_wrong_pct': float(n_high_conf_wrong / n_wrong * 100),
        'mean_wrong_conf': float(all_max_conf[wrong_mask].mean()),
    }

# Domain gap
# All values fall back to 0 when either source is missing so the numeric
# format specs used when rendering the report keep working.
if 'ODIR' in per_source and 'APTOS' in per_source:
    odir_acc = per_source['ODIR']['accuracy']
    aptos_acc = per_source['APTOS']['accuracy']
    domain_gap = abs(odir_acc - aptos_acc)

    # DR-specific domain gap (`or 0` maps a stored None to 0)
    odir_dr = per_source['ODIR']['per_class_acc'].get('Diabetes/DR', 0) or 0
    aptos_dr = per_source['APTOS']['per_class_acc'].get('Diabetes/DR', 0) or 0
    dr_gap = abs(odir_dr - aptos_dr)
else:
    domain_gap = odir_acc = aptos_acc = odir_dr = aptos_dr = dr_gap = 0

# Sum of (confidence - accuracy) over non-empty bins. The three bin sequences
# are walked in lockstep: looking a bin up by its accuracy value would pick
# the wrong bin whenever two bins share the same accuracy.
signed_calibration_gap = sum(
    b_conf - b_acc
    for b_conf, b_acc, b_n in zip(bin_conf, bin_acc, bin_count)
    if b_n > 0
)
calibration_verdict = 'overconfident' if signed_calibration_gap > 0 else 'underconfident'
|
| 728 |
+
|
| 729 |
+
# Build the markdown report section by section and write it to
# OUT_DIR/BASELINE_ANALYSIS.md.
# Local import so this section does not depend on the file's import header.
from datetime import date as _date

# Both sources must be present for any ODIR-vs-APTOS comparison to be meaningful.
has_both_sources = 'ODIR' in per_source and 'APTOS' in per_source

# Keep the explanatory sentence consistent with the computed verdict.
if calibration_verdict == 'overconfident':
    calibration_note = 'mean confidence exceeds accuracy on balance across non-empty bins'
else:
    calibration_note = 'mean confidence does not exceed accuracy on balance across non-empty bins'

report = f"""# RetinaSense ViT v2 — Baseline Error Analysis Report
**Generated**: {_date.today().isoformat()}
**Model**: ViT-Base-Patch16-224 (MultiTaskViT)
**Checkpoint**: outputs_vit/best_model.pth
**Val Split**: {len(val_df)} samples (20% stratified, random_state=42)

---

## 1. Overall Performance

| Metric | Value |
|--------|-------|
| Accuracy (raw argmax) | {raw_acc:.2f}% |
| Accuracy (with thresholds) | {thresh_acc:.2f}% |
| Macro F1 | {macro_f1:.4f} |
| Weighted F1 | {weighted_f1:.4f} |
| ECE (10 bins) | {ece:.4f} |

---

## 2. Per-Class Metrics

| Class | Precision | Recall | F1 | Support |
|-------|-----------|--------|----|---------|
"""
for cn in CLASS_NAMES:
    report += (f"| {cn:<15s} | {per_class_precision[cn]:.4f} | "
               f"{per_class_recall[cn]:.4f} | {per_class_f1[cn]:.4f} | "
               f"{per_class_support[cn]:4d} |\n")

report += f"""
---

## 3. Confusion Analysis — Top 5 Confused Pairs

| Rank | True Class | Predicted As | Count | Rate |
|------|-----------|-------------|-------|------|
"""
for rank, pair in enumerate(top5_pairs, 1):
    report += (f"| {rank} | {pair['true_class']} | {pair['pred_class']} | "
               f"{pair['count']} | {pair['rate']*100:.1f}% |\n")

# Column header of the text confusion matrix (class names truncated to 6 chars)
cm_header = ' '.join(f'{cn[:6]:>7s}' for cn in CLASS_NAMES)
report += f"""
### Full Confusion Matrix (normalized by true class)

```
{cm_header}
"""
for ri, rn in enumerate(CLASS_NAMES):
    row_str = ' '.join(f'{cm_norm[ri, ci]:.3f}' for ci in range(NUM_CLASSES))
    report += f"{rn[:8]:>8s} {row_str}\n"

report += f"""```

---

## 4. Confidence Calibration Analysis

- **ECE (overall)**: {ece:.4f}
- **Calibration pattern**: The model is predominantly **{calibration_verdict}**
({calibration_note}).

### Per-Class ECE

| Class | ECE |
|-------|-----|
"""
for cn, ece_c in per_class_ece.items():
    report += f"| {cn} | {ece_c:.4f} |\n"

report += f"""
### High-Confidence Wrong Predictions (confidence > 0.8)

| Class | Total Wrong | High-Conf Wrong | % of Errors | Mean Wrong Conf |
|-------|------------|----------------|-------------|----------------|
"""
for cn, hcw in hcw_analysis.items():
    report += (f"| {cn} | {hcw['total_wrong']} | {hcw['high_conf_wrong_count']} | "
               f"{hcw['high_conf_wrong_pct']:.1f}% | {hcw['mean_wrong_conf']:.3f} |\n")

report += f"""
---

## 5. Dataset Source Analysis (ODIR vs APTOS)

| Source | N Samples | Accuracy | Macro F1 |
|--------|-----------|----------|----------|
"""
for src, data in per_source.items():
    report += f"| {src} | {data['n_samples']} | {data['accuracy']:.2f}% | {data['macro_f1']:.4f} |\n"

report += f"""
### Per-Class Accuracy by Source

| Class |"""
for src in per_source:
    report += f" {src} |"
report += "\n|-------|"
for _ in per_source:
    report += "--------|"
report += "\n"
for cn in CLASS_NAMES:
    report += f"| {cn} |"
    for src in per_source:
        acc = per_source[src]['per_class_acc'].get(cn)
        if acc is None:
            report += " N/A |"
        else:
            report += f" {acc:.1f}% |"
    report += "\n"

# A gap is only reported when both sources were evaluated; otherwise the
# fallback value of 0 would read as "no gap".
if has_both_sources:
    report += f"""
**Domain gap (overall accuracy)**: {domain_gap:.2f}pp between ODIR and APTOS
**DR class gap (ODIR vs APTOS)**: ODIR={odir_dr:.1f}% vs APTOS={aptos_dr:.1f}% (gap={dr_gap:.1f}pp)
"""
else:
    report += """
**Domain gap (overall accuracy)**: N/A (requires both ODIR and APTOS samples)
"""

report += f"""
---

## 6. Error Pattern Summary

### Q1: What is the model's biggest weakness?

The model's biggest weakness is classifying **{worst_f1_class}** (F1={per_class_f1[worst_f1_class]:.4f},
recall={per_class_recall[worst_f1_class]:.4f}). This class has the worst F1 score, indicating the
model struggles to both detect and correctly distinguish it from other pathologies.

The confusion matrix shows that the primary confusion pathway is:
"""
# Slicing instead of indexing keeps this working with fewer than two pairs.
for pair in top5_pairs[:2]:
    report += (f"- **{pair['description']}**: {pair['count']} cases "
               f"({pair['rate']*100:.1f}% error rate)\n")

report += f"""
### Q2: Which class has the worst recall? Why?

**{worst_recall_class}** has the worst recall at {per_class_recall[worst_recall_class]:.4f}.
"""

# Detailed reason based on support
worst_support = per_class_support[worst_recall_class]
all_support = sum(per_class_support.values())
worst_pct = worst_support / all_support * 100

# Name the top confusion target when it concerns the worst-recall class,
# otherwise fall back to 'Normal'.
if top5_pairs and top5_pairs[0]['true_class'] == worst_recall_class:
    similar_class = top5_pairs[0]['pred_class']
else:
    similar_class = 'Normal'
worst_threshold = thresholds.get(CLASS_NAMES.index(worst_recall_class), 0.5)

report += f"""This class represents only {worst_support} samples ({worst_pct:.1f}% of the val set).
The low recall is likely caused by:
1. **Class imbalance** — the model sees fewer examples during training and defaults to predicting
more common classes when uncertain.
2. **Visual similarity** with other conditions (especially {similar_class})
at the fundus level.
3. **Threshold sensitivity** — the optimized threshold ({worst_threshold:.2f})
may overcorrect or undercorrect depending on the calibration.

### Q3: Evidence of domain shift (ODIR vs APTOS)?

"""
if has_both_sources and domain_gap > 2.0:
    report += f"""YES — there is a **{domain_gap:.1f}pp accuracy gap** between ODIR ({odir_acc:.1f}%) and APTOS
({aptos_acc:.1f}%). This is significant and consistent with domain shift between the two data sources.

For the DR/Diabetes class specifically, the gap is **{dr_gap:.1f}pp** (ODIR={odir_dr:.1f}% vs APTOS={aptos_dr:.1f}%).
APTOS images are specifically DR-graded fundus photographs from India (Aravind Eye Hospital),
while ODIR covers multiple disease classes with more varied image quality and capture conditions.
The Ben Graham preprocessing helps but does not fully bridge the domain gap.

**Implication for v3**: Domain-specific augmentation or source-aware training (e.g., source
as auxiliary input, separate batch norms, or domain adaptation) may improve generalization.
"""
elif has_both_sources:
    report += f"""MINOR gap observed — {domain_gap:.1f}pp difference between ODIR ({odir_acc:.1f}%) and
APTOS ({aptos_acc:.1f}%). The gap is small, suggesting the Ben Graham preprocessing and ViT
architecture generalize reasonably across sources. DR-specific gap: {dr_gap:.1f}pp.
"""
else:
    report += "Insufficient cross-source data to conclude domain shift.\n"

report += f"""
### Q4: Calibration assessment

ECE = **{ece:.4f}** (scale: 0=perfect, 0.1=poor).

"""
if ece < 0.03:
    report += "The model is **well-calibrated** (ECE < 0.03). Confidence scores are reliable."
elif ece < 0.07:
    report += f"""The model shows **moderate miscalibration** (ECE={ece:.4f}). The reliability diagram
shows the model is {calibration_verdict} in the high-confidence range, meaning predicted
confidence scores are not fully reliable. Temperature scaling in v3 is recommended."""
else:
    report += f"""The model is **poorly calibrated** (ECE={ece:.4f}). The {calibration_verdict}
pattern is severe. Temperature scaling or label smoothing in v3 training is strongly recommended."""

report += f"""

---

## 7. Recommendations for v3 Training

Based on this baseline analysis:

1. **Address {worst_recall_class} recall** — increase class weight, targeted augmentation,
or focal loss gamma tuning for this class.
2. **Calibration** — add temperature scaling post-training or increase label smoothing
(current ECE={ece:.4f}).
3. **Domain shift mitigation** — consider source-conditioned augmentation or adversarial
domain adaptation if ODIR/APTOS gap persists.
4. **High-confidence errors** — the model makes confidently wrong predictions on certain
classes; mixup or CutMix augmentation may improve uncertainty estimation.
5. **Top confusion pairs** to specifically target:
"""
for pair in top5_pairs[:3]:
    report += f" - {pair['description']} ({pair['count']} errors)\n"

report += """
---

## 8. Output Files

| File | Description |
|------|-------------|
| confusion_matrix_raw.png | Raw count confusion matrix |
| confusion_matrix_normalized.png | Recall-normalized confusion matrix |
| reliability_diagram.png | ECE calibration plot |
| confidence_distributions.png | Per-class confidence histograms |
| per_source_performance.png | ODIR vs APTOS breakdown |
| baseline_metrics.json | All metrics in structured JSON |

---
*Report generated by RetinaSense ViT v2 error analysis pipeline.*
"""

# The report contains non-ASCII punctuation (em dashes), so the encoding is
# pinned instead of relying on the platform's locale default.
with open(f'{OUT_DIR}/BASELINE_ANALYSIS.md', 'w', encoding='utf-8') as f:
    f.write(report)
print(' Saved BASELINE_ANALYSIS.md')
|
| 961 |
+
|
| 962 |
+
# ================================================================
# FINAL SUMMARY
# ================================================================
# Console recap of the headline numbers and where the outputs were written.
rule = '=' * 65
worst_f1_value = per_class_f1[worst_f1_class]
worst_recall_value = per_class_recall[worst_recall_class]

print('\n' + rule)
print(' BASELINE ANALYSIS COMPLETE')
print(rule)
print(f' Val accuracy (thresh) : {thresh_acc:.2f}%')
print(f' Macro F1 : {macro_f1:.4f}')
print(f' ECE : {ece:.4f}')
print(f' Worst class (F1) : {worst_f1_class} ({worst_f1_value:.4f})')
print(f' Worst class (recall) : {worst_recall_class} ({worst_recall_value:.4f})')
top_pair = top5_pairs[0]
print(f' Top confusion : {top_pair["description"]}')
# The gap line is printed only when a gap value is available
if not (domain_gap is None):
    print(f' Domain gap (ODIR-APTOS): {domain_gap:.2f}pp')
print(f'\n All outputs in: {OUT_DIR}/')
print(rule)
|