# retinasense-vit / fairness_analysis.py
# Author: tanishq74 -- commit "Add fairness_analysis.py" (e5d94b0, verified), 36.3 kB
# (Header copied from the repository web page; kept as comments so the module
# remains valid Python.)
#!/usr/bin/env python3
"""
RetinaSense v3.0 — Phase 1D: Fairness & Domain Robustness Analysis
===================================================================
Evaluates model performance across data sources (APTOS, ODIR, REFUGE2)
to quantify domain gap and identify fairness concerns.
Outputs (saved to outputs_v3/fairness/):
- performance_by_source.png : grouped bar chart of metrics per class per source
- calibration_by_source.png : reliability diagrams by source
- confusion_matrix_aptos.png : confusion matrix for APTOS subset
- confusion_matrix_odir.png : confusion matrix for ODIR subset
- confidence_by_source.png : violin plots of prediction confidence
- error_patterns.png : most common misclassification pairs by source
- domain_gap_report.json : full quantitative report
"""
import os
import sys
import json
import warnings
import numpy as np
import pandas as pd
import cv2
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from PIL import Image
from tqdm import tqdm
from collections import Counter, defaultdict
warnings.filterwarnings('ignore')
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import timm
from sklearn.metrics import (
accuracy_score, f1_score, precision_score, recall_score,
confusion_matrix, classification_report
)
from scipy import stats
# ================================================================
# CONFIGURATION
# ================================================================
# All paths are rooted at a fixed studio directory; change BASE_DIR to run
# elsewhere. The fairness output directory is created as soon as the module runs.
BASE_DIR = '/teamspace/studios/this_studio'
OUTPUT_DIR = os.path.join(BASE_DIR, 'outputs_v3')
FAIRNESS_DIR = os.path.join(OUTPUT_DIR, 'fairness')
os.makedirs(FAIRNESS_DIR, exist_ok=True)
MODEL_PATH = os.path.join(OUTPUT_DIR, 'best_model.pth')
TEST_CSV = os.path.join(BASE_DIR, 'data', 'test_split.csv')
NORM_STATS_PATH = os.path.join(BASE_DIR, 'data', 'fundus_norm_stats.json')
TEMPERATURE_PATH = os.path.join(OUTPUT_DIR, 'temperature.json')
NUM_CLASSES = 5
IMG_SIZE = 224
DROPOUT = 0.3
BATCH_SIZE = 32
# Position i in CLASS_NAMES is the name of integer disease label i
# (e.g. label 1 == 'DR', as used by the DR masks further down).
CLASS_NAMES = ['Normal', 'DR', 'Glaucoma', 'Cataract', 'AMD']
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('=' * 65)
print(' RetinaSense v3.0 -- Phase 1D: Fairness & Domain Robustness')
print('=' * 65)
print(f' Device : {DEVICE}')
if torch.cuda.is_available():
    print(f' GPU : {torch.cuda.get_device_name(0)}')
print(f' Output : {FAIRNESS_DIR}')
print('=' * 65)
# ================================================================
# LOAD NORMALISATION STATS
# ================================================================
# Dataset-specific fundus statistics when available, otherwise the standard
# ImageNet mean/std.
if os.path.exists(NORM_STATS_PATH):
    with open(NORM_STATS_PATH) as f:
        norm_stats = json.load(f)
    NORM_MEAN = norm_stats['mean_rgb']
    NORM_STD = norm_stats['std_rgb']
    print(f' Fundus norm: mean={[round(v, 4) for v in NORM_MEAN]}, '
          f'std={[round(v, 4) for v in NORM_STD]}')
else:
    NORM_MEAN = [0.485, 0.456, 0.406]
    NORM_STD = [0.229, 0.224, 0.225]
    print(' Using ImageNet normalisation fallback')
# Load the calibration temperature. Logits are divided by T before softmax,
# so T = 1.0 means "no scaling". Like the norm stats above, fall back instead
# of crashing when the calibration file has not been produced yet.
if os.path.exists(TEMPERATURE_PATH):
    with open(TEMPERATURE_PATH) as f:
        temp_data = json.load(f)
    TEMPERATURE = float(temp_data['temperature'])
else:
    TEMPERATURE = 1.0
    print(' temperature.json not found -- using T = 1.0 (no scaling)')
# T <= 0 would divide by zero or flip the logit ordering, silently corrupting
# every probability and prediction downstream.
if TEMPERATURE <= 0:
    raise ValueError(f'Temperature must be positive, got {TEMPERATURE}')
print(f' Temperature T = {TEMPERATURE:.4f}')
# ================================================================
# MODEL ARCHITECTURE (mirrors retinasense_v3.py / gradcam_v3.py)
# ================================================================
class MultiTaskViT(nn.Module):
    """ViT-Base-Patch16-224 with disease + severity heads.

    A shared ViT backbone produces one feature vector per image. Two MLP heads
    read that vector: a disease classifier (``n_disease`` logits) and a
    severity classifier (``n_severity`` logits).

    NOTE: keep the submodule names and the ``nn.Sequential`` layer order
    exactly as they are. The checkpoint is restored with ``load_state_dict``,
    which matches parameters by key (e.g. ``disease_head.0.weight``).
    """
    def __init__(self, n_disease=NUM_CLASSES, n_severity=5, drop=DROPOUT):
        super().__init__()
        # num_classes=0: timm builds the backbone without a classification
        # layer, so it returns feature vectors (width `feat` below).
        # pretrained=False because all weights come from the checkpoint.
        self.backbone = timm.create_model(
            'vit_base_patch16_224', pretrained=False, num_classes=0
        )
        feat = 768  # ViT-Base embedding width; input size of both heads
        self.drop = nn.Dropout(drop)
        # Disease head: feat -> 512 -> 256 -> n_disease
        self.disease_head = nn.Sequential(
            nn.Linear(feat, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, n_disease),
        )
        # Severity head: feat -> 256 -> n_severity
        self.severity_head = nn.Sequential(
            nn.Linear(feat, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, n_severity),
        )
    def forward(self, x):
        """Return ``(disease_logits, severity_logits)`` for the image batch *x*."""
        f = self.backbone(x)
        f = self.drop(f)
        return self.disease_head(f), self.severity_head(f)
# ================================================================
# LOAD MODEL
# ================================================================
print('\nLoading model...')
model = MultiTaskViT().to(DEVICE)
# weights_only=False unpickles arbitrary objects stored in the checkpoint;
# only ever point MODEL_PATH at a trusted file.
ckpt = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()  # inference mode for Dropout / BatchNorm layers
print(f' Loaded: {MODEL_PATH}')
# The stored epoch is shown +1. The old `ckpt.get("epoch", "?") + 1` raised a
# TypeError ("?" + 1) whenever the checkpoint had no epoch; print '?' instead.
ckpt_epoch = ckpt.get('epoch')
epoch_str = '?' if ckpt_epoch is None else str(int(ckpt_epoch) + 1)
print(f' Checkpoint epoch: {epoch_str} '
      f'val_acc={ckpt.get("val_acc", 0):.2f}%')
# ================================================================
# DATASET
# ================================================================
class FairnessDataset(Dataset):
    """Test-set dataset yielding ``(tensor, label, source, idx)`` tuples.

    Each row prefers its preprocessed ``.npy`` cache (column ``cache_path``,
    optional). If that is missing, the raw image (column ``image_path``) is
    preprocessed on the fly: Ben Graham enhancement for APTOS rows, CLAHE for
    every other source.

    Args:
        csv_path: CSV with ``disease_label``, ``source`` and ``image_path``
            columns (plus optional ``cache_path``).
        base_dir: Directory that relative CSV paths are resolved against.
    """

    def __init__(self, csv_path, base_dir):
        self.df = pd.read_csv(csv_path)
        self.base_dir = base_dir
        # No resizing or augmentation here: both the cache and the fallback
        # preprocessors already produce IMG_SIZE x IMG_SIZE RGB uint8 images.
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(NORM_MEAN, NORM_STD),
        ])
        print(f' Test set: {len(self.df)} images')
        print(f' Sources : {dict(self.df["source"].value_counts())}')

    def __len__(self):
        return len(self.df)

    def _resolve(self, path):
        """Return *path* if it is absolute, else join it onto ``base_dir``.

        Leading ``./`` segments are stripped, along with any slashes they leave
        behind (e.g. ``.//x``). Otherwise the remainder starts with ``/`` and
        ``os.path.join`` would silently discard ``base_dir``.
        """
        if os.path.isabs(path):
            return path
        clean = path
        while clean.startswith('./'):
            clean = clean[2:].lstrip('/')
        return os.path.join(self.base_dir, clean)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        label = int(row['disease_label'])
        source = row['source']
        # Try the cached .npy first. A missing value reads back from pandas
        # as a float (NaN), which the isinstance check skips.
        cache_path = row.get('cache_path', '')
        if isinstance(cache_path, str) and cache_path:
            cache_path = self._resolve(cache_path)
            if os.path.exists(cache_path):
                img = np.load(cache_path)  # expected (224, 224, 3) uint8
                return self.transform(img), label, source, idx
        # Fallback: load and preprocess from the source image
        image_path = self._resolve(row['image_path'])
        if source == 'APTOS':
            img = self._ben_graham(image_path)
        else:
            img = self._clahe_preprocess(image_path)
        return self.transform(img), label, source, idx

    @staticmethod
    def _ben_graham(path, sz=IMG_SIZE, sigma=10):
        """Ben Graham preprocessing: local-contrast boost plus a circular crop.

        Computes ``4*img - 4*GaussianBlur(img) + 128`` and zeroes everything
        outside a centred circle of radius ``0.48 * sz``. Returns an RGB
        uint8 array of shape (sz, sz, 3).
        """
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None for files it cannot decode; read with
            # PIL instead and convert to BGR so both paths agree below.
            img = np.array(Image.open(path).convert('RGB'))
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (sz, sz))
        img = cv2.addWeighted(img, 4, cv2.GaussianBlur(img, (0, 0), sigma), -4, 128)
        mask = np.zeros(img.shape[:2], dtype=np.uint8)
        cv2.circle(mask, (sz // 2, sz // 2), int(sz * 0.48), 255, -1)
        return cv2.bitwise_and(img, img, mask=mask)

    @staticmethod
    def _clahe_preprocess(path, sz=IMG_SIZE):
        """Resize, then apply CLAHE to the L channel in LAB space; returns RGB uint8."""
        img = cv2.imread(path)
        if img is None:
            # Same PIL fallback as _ben_graham; keep the image in BGR order.
            img = np.array(Image.open(path).convert('RGB'))
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        img = cv2.resize(img, (sz, sz))
        lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        lab[:, :, 0] = clahe.apply(lab[:, :, 0])
        img = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# ================================================================
# RUN INFERENCE
# ================================================================
print('\nRunning inference on test set...')
dataset = FairnessDataset(TEST_CSV, BASE_DIR)
# Pinned host memory only speeds up host->GPU copies. On a CPU-only run it is
# wasted work (and PyTorch warns about it), so enable it only on CUDA.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False,
                    num_workers=4, pin_memory=(DEVICE.type == 'cuda'))
all_labels = []
all_preds = []
all_probs = []  # temperature-scaled probabilities
all_logits = []  # raw (unscaled) disease logits
all_sources = []
all_indices = []
with torch.no_grad():
    for batch_imgs, batch_labels, batch_sources, batch_idx in tqdm(loader, desc='Inference'):
        batch_imgs = batch_imgs.to(DEVICE)
        disease_logits, _ = model(batch_imgs)  # severity output is not analysed
        # Temperature scaling: divide the logits by T before the softmax
        scaled_logits = disease_logits / TEMPERATURE
        probs = F.softmax(scaled_logits, dim=1)
        preds = probs.argmax(dim=1)
        all_labels.extend(batch_labels.numpy().tolist())
        all_preds.extend(preds.cpu().numpy().tolist())
        all_probs.extend(probs.cpu().numpy().tolist())
        all_logits.extend(disease_logits.cpu().numpy().tolist())
        all_sources.extend(list(batch_sources))
        all_indices.extend(batch_idx.numpy().tolist())
all_labels = np.array(all_labels)
all_preds = np.array(all_preds)
all_probs = np.array(all_probs)
all_logits = np.array(all_logits)
all_sources = np.array(all_sources)
all_correct = (all_labels == all_preds).astype(int)
# Max confidence (probability of predicted class)
all_confidence = np.max(all_probs, axis=1)
print(f'\n Total images: {len(all_labels)}')
print(f' Overall accuracy: {accuracy_score(all_labels, all_preds):.4f}')
print(f' Sources: {Counter(all_sources)}')
# ================================================================
# HELPER: compute metrics for a subset
# ================================================================
def compute_metrics(labels, preds, probs=None, class_names=CLASS_NAMES):
    """Compute per-class (one-vs-rest) and overall macro metrics for a subset.

    Args:
        labels: True integer class ids (array-like).
        preds: Predicted integer class ids, same length as *labels*.
        probs: Unused; accepted so call sites can pass it alongside preds.
        class_names: Class names indexed by class id.

    Returns:
        Dict keyed by each class name plus ``'Overall'``. Every value holds
        ``n``, ``accuracy``, ``precision``, ``recall`` and ``f1``. A class with
        no true samples gets ``n=0`` and ``None`` metrics. Per-class
        ``accuracy`` is the fraction of that class's samples predicted
        correctly, which equals its recall. ``Overall`` metrics are
        macro-averaged over every class seen in *labels* or *preds*; they are
        all ``None`` for an empty subset.
    """
    # np.asarray so plain lists work too: `list == c` is a single bool, not a mask.
    labels = np.asarray(labels)
    preds = np.asarray(preds)
    empty = {'n': 0, 'accuracy': None, 'precision': None, 'recall': None, 'f1': None}
    results = {}
    for c, name in enumerate(class_names):
        is_c = labels == c
        n_c = is_c.sum()
        if n_c == 0:
            results[name] = dict(empty)
            continue
        # One-vs-rest confusion counts for class c
        pred_c = preds == c
        tp = (is_c & pred_c).sum()
        fn = (is_c & ~pred_c).sum()
        fp = (~is_c & pred_c).sum()
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
        acc = (labels[is_c] == preds[is_c]).mean()
        results[name] = {
            'n': int(n_c),
            'accuracy': float(round(acc, 4)),
            'precision': float(round(prec, 4)),
            'recall': float(round(rec, 4)),
            'f1': float(round(f1, 4)),
        }
    # Overall metrics are undefined for an empty subset
    if labels.size == 0:
        results['Overall'] = dict(empty)
        return results
    overall_acc = accuracy_score(labels, preds)
    # Macro-average over classes that occur in either the truth or the predictions
    present_labels = sorted(set(labels) | set(preds))
    overall_f1 = f1_score(labels, preds, labels=present_labels, average='macro', zero_division=0)
    overall_prec = precision_score(labels, preds, labels=present_labels, average='macro', zero_division=0)
    overall_rec = recall_score(labels, preds, labels=present_labels, average='macro', zero_division=0)
    results['Overall'] = {
        'n': int(len(labels)),
        'accuracy': float(round(overall_acc, 4)),
        'precision': float(round(overall_prec, 4)),
        'recall': float(round(overall_rec, 4)),
        'f1': float(round(overall_f1, 4)),
    }
    return results
# ================================================================
# SPLIT BY SOURCE
# ================================================================
# Build a boolean mask over the inference arrays for each source, plus that
# source's per-class / overall metrics. Both dicts are used by the plots and
# the JSON report below.
sources_unique = sorted(set(all_sources))
print(f'\n Unique sources: {sources_unique}')
source_masks = {}
source_metrics = {}
for src in sources_unique:
    mask = all_sources == src
    source_masks[src] = mask
    labels_s = all_labels[mask]
    preds_s = all_preds[mask]
    probs_s = all_probs[mask]
    metrics = compute_metrics(labels_s, preds_s, probs_s)
    source_metrics[src] = metrics
    acc = metrics['Overall']['accuracy']
    f1 = metrics['Overall']['f1']
    print(f' {src:8s}: n={mask.sum():4d} acc={acc:.4f} macro-F1={f1:.4f}')
# ================================================================
# PLOT 1: Performance by Source (grouped bar chart)
# ================================================================
print('\nGenerating performance_by_source.png...')
# Left panel: DR metrics compared across sources. Right panel: ODIR per-class.
fig, axes = plt.subplots(1, 2, figsize=(18, 7))
plt.subplots_adjust(wspace=0.35)
metric_names = ['accuracy', 'f1', 'precision', 'recall']
metric_labels = ['Accuracy', 'F1-Score', 'Precision', 'Recall']
bar_colors = {'APTOS': '#2196F3', 'ODIR': '#FF9800', 'REFUGE2': '#4CAF50'}
# DR-only comparison
ax = axes[0]
x = np.arange(len(metric_names))
# Keep each metric's group of bars within 0.8 x-units however many sources
# there are. With <= 3 sources this is the original fixed width of 0.25.
bar_width = min(0.25, 0.8 / max(len(sources_unique), 1))
offsets = np.arange(len(sources_unique)) - (len(sources_unique) - 1) / 2
for i, src in enumerate(sources_unique):
    dr_metrics = source_metrics[src].get('DR', {})
    # A source with no DR images has None metrics: draw them as zero-height
    # bars and leave them unlabelled.
    vals = []
    for m in metric_names:
        v = dr_metrics.get(m, None)
        vals.append(v if v is not None else 0)
    bars = ax.bar(x + offsets[i] * bar_width, vals, bar_width,
                  label=f'{src} (n={dr_metrics.get("n", 0)})',
                  color=bar_colors.get(src, '#999999'), edgecolor='white', linewidth=0.5)
    # Value labels
    for bar, val in zip(bars, vals):
        if val > 0:
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                    f'{val:.2f}', ha='center', va='bottom', fontsize=8, fontweight='bold')
ax.set_xlabel('Metric', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
# Title lists the sources actually present (the old literal named three fixed
# sources even when the test split contained others)
ax.set_title('DR Classification: ' + ' vs '.join(sources_unique),
             fontsize=13, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(metric_labels)
ax.set_ylim(0, 1.15)
ax.legend(fontsize=9)
ax.grid(axis='y', alpha=0.3)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# Right panel: ODIR per-class performance (only classes with test images)
ax = axes[1]
odir_metrics = source_metrics.get('ODIR', {})
classes_present = [c for c in CLASS_NAMES if odir_metrics.get(c, {}).get('n', 0) > 0]
x2 = np.arange(len(classes_present))
bar_width2 = 0.18
colors_metric = ['#1976D2', '#388E3C', '#F57C00', '#D32F2F']
# Centre the metric bars around each class tick (i - 1.5 for four metrics)
metric_center = (len(metric_names) - 1) / 2
for i, (m, ml) in enumerate(zip(metric_names, metric_labels)):
    vals = [odir_metrics[c][m] if odir_metrics[c][m] is not None else 0 for c in classes_present]
    bars = ax.bar(x2 + (i - metric_center) * bar_width2, vals, bar_width2,
                  label=ml, color=colors_metric[i], edgecolor='white', linewidth=0.5)
    for bar, val in zip(bars, vals):
        if val > 0:
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                    f'{val:.2f}', ha='center', va='bottom', fontsize=7, fontweight='bold')
ax.set_xlabel('Disease Class', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
ax.set_title('ODIR Per-Class Performance', fontsize=13, fontweight='bold')
class_labels = [f'{c}\n(n={odir_metrics[c]["n"]})' for c in classes_present]
ax.set_xticks(x2)
ax.set_xticklabels(class_labels)
ax.set_ylim(0, 1.15)
ax.legend(fontsize=9, ncol=2)
ax.grid(axis='y', alpha=0.3)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
fig.suptitle('Fairness & Domain Robustness: Performance by Source',
             fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
fig.savefig(os.path.join(FAIRNESS_DIR, 'performance_by_source.png'),
            dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)
print(' Saved performance_by_source.png')
# ================================================================
# PLOT 2: Calibration by Source (reliability diagrams)
# ================================================================
print('Generating calibration_by_source.png...')
# axes[0]: reliability curves per source; axes[1]: confidence histograms.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
# Reliability diagram helper
def reliability_diagram(labels, probs, preds, n_bins=10):
    """Bin top-1 predictions by confidence and report accuracy per bin.

    Confidence is the row-wise maximum of *probs*. Bins are equal-width on
    [0, 1] and half-open ``[lo, hi)``, except the last, which is closed so a
    confidence of exactly 1.0 is counted. Empty bins are left out.

    Returns:
        Four arrays with one entry per non-empty bin: bin centre, accuracy
        (fraction correct), mean confidence, and sample count.
    """
    conf = np.max(probs, axis=1)
    hit = (labels == preds).astype(float)
    edges = np.linspace(0, 1, n_bins + 1)
    last = n_bins - 1
    rows = []
    for b, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        below_hi = (conf <= hi) if b == last else (conf < hi)
        in_bin = (conf >= lo) & below_hi
        size = in_bin.sum()
        if size == 0:
            continue
        rows.append(((lo + hi) / 2, hit[in_bin].mean(), conf[in_bin].mean(), int(size)))
    centres = np.array([r[0] for r in rows])
    accuracy = np.array([r[1] for r in rows])
    mean_conf = np.array([r[2] for r in rows])
    counts = np.array([r[3] for r in rows])
    return centres, accuracy, mean_conf, counts
# Left: all sources overlaid
ax = axes[0]
ax.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Perfect calibration')
source_colors = {'APTOS': '#2196F3', 'ODIR': '#FF9800', 'REFUGE2': '#4CAF50'}
for src in sources_unique:
    mask = source_masks[src]
    # Sources with fewer than 10 test images are skipped.
    if mask.sum() < 10:
        continue
    centers, accs, confs, counts = reliability_diagram(
        all_labels[mask], all_probs[mask], all_preds[mask], n_bins=10
    )
    # x is each non-empty bin's mean confidence, not the bin centre
    ax.plot(confs, accs, 'o-', color=source_colors.get(src, '#999'),
            label=f'{src} (n={mask.sum()})', markersize=6, linewidth=2)
ax.set_xlabel('Mean Predicted Confidence', fontsize=12)
ax.set_ylabel('Actual Accuracy (Fraction Correct)', fontsize=12)
ax.set_title('Reliability Diagram by Source', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.grid(alpha=0.3)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_aspect('equal')
# Right: density histogram of prediction confidence per source.
# (ECE is not drawn here; it is computed for the JSON report further down.)
ax = axes[1]
for src in sources_unique:
    mask = source_masks[src]
    if mask.sum() < 10:
        continue
    confs = np.max(all_probs[mask], axis=1)
    ax.hist(confs, bins=20, alpha=0.5, label=f'{src} (n={mask.sum()})',
            color=source_colors.get(src, '#999'), density=True, edgecolor='white')
ax.set_xlabel('Prediction Confidence', fontsize=12)
ax.set_ylabel('Density', fontsize=12)
ax.set_title('Confidence Distribution by Source', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(alpha=0.3)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.tight_layout()
fig.savefig(os.path.join(FAIRNESS_DIR, 'calibration_by_source.png'),
            dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)
print(' Saved calibration_by_source.png')
# ================================================================
# PLOT 3 & 4: Confusion Matrices
# ================================================================
def plot_confusion_matrix(labels, preds, title, save_path, class_names=CLASS_NAMES):
    """Save a publication-quality, row-normalised confusion-matrix heatmap.

    Rows and columns cover only the classes that occur in *labels* or
    *preds*. Each cell is annotated with its raw count and its share of the
    true-class row; rows with no true samples show 0%. Colours use the
    row-normalised proportions on a fixed 0-1 scale.

    Returns:
        The raw (unnormalised) count matrix.
    """
    shown = sorted(set(labels) | set(preds))
    tick_names = [class_names[k] for k in shown]
    counts = confusion_matrix(labels, preds, labels=shown)
    # Rows with zero true samples divide 0/0 -> NaN, which becomes 0.
    with np.errstate(divide='ignore', invalid='ignore'):
        share = np.nan_to_num(counts.astype(float) / counts.sum(axis=1, keepdims=True))
    n_rows, n_cols = counts.shape
    annot = np.array(
        [[f'{counts[r, k]}\n({share[r, k]:.1%})' for k in range(n_cols)]
         for r in range(n_rows)],
        dtype=object,
    )
    fig, ax = plt.subplots(figsize=(8, 7))
    sns.heatmap(share, annot=annot, fmt='', cmap='Blues',
                xticklabels=tick_names, yticklabels=tick_names,
                ax=ax, vmin=0, vmax=1, linewidths=0.5, linecolor='white',
                cbar_kws={'label': 'Proportion'})
    ax.set_xlabel('Predicted Label', fontsize=12, fontweight='bold')
    ax.set_ylabel('True Label', fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold', pad=15)
    ax.tick_params(labelsize=10)
    plt.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)
    return counts
print('Generating confusion matrices...')
# APTOS confusion matrix (DR only, but predictions can span all classes)
# aptos_mask / odir_mask are reused by the error-pattern plot and the report.
aptos_mask = all_sources == 'APTOS'
if aptos_mask.sum() > 0:
    cm_aptos = plot_confusion_matrix(
        all_labels[aptos_mask], all_preds[aptos_mask],
        f'Confusion Matrix: APTOS (n={aptos_mask.sum()}, DR images only)',
        os.path.join(FAIRNESS_DIR, 'confusion_matrix_aptos.png')
    )
    print(' Saved confusion_matrix_aptos.png')
# ODIR confusion matrix (all 5 classes)
odir_mask = all_sources == 'ODIR'
if odir_mask.sum() > 0:
    cm_odir = plot_confusion_matrix(
        all_labels[odir_mask], all_preds[odir_mask],
        f'Confusion Matrix: ODIR (n={odir_mask.sum()}, all 5 classes)',
        os.path.join(FAIRNESS_DIR, 'confusion_matrix_odir.png')
    )
    print(' Saved confusion_matrix_odir.png')
# ================================================================
# PLOT 5: Confidence by Source (violin/box plots)
# ================================================================
print('Generating confidence_by_source.png...')
fig, axes = plt.subplots(1, 2, figsize=(16, 7))
# Build dataframe for plotting (one row per test image)
conf_df = pd.DataFrame({
    'Source': all_sources,
    'Confidence': all_confidence,
    'Correct': ['Correct' if c else 'Incorrect' for c in all_correct],
    'Predicted': [CLASS_NAMES[p] for p in all_preds],
    'True': [CLASS_NAMES[lbl] for lbl in all_labels],
})
# Explicit category orders. Without `order`, seaborn places string x
# categories by first appearance in the data, which need not match the sorted
# source lists used to position the mean markers. Without `hue_order`, which
# half of each split violin holds 'Correct' is data-dependent, and split needs
# exactly two hue levels even if one group is empty.
correctness_order = ['Correct', 'Incorrect']
correctness_palette = {'Correct': '#4CAF50', 'Incorrect': '#F44336'}
# Left: violin by source and correctness
ax = axes[0]
plot_sources = [s for s in sources_unique if (all_sources == s).sum() >= 10]
conf_df_filtered = conf_df[conf_df['Source'].isin(plot_sources)]
if len(conf_df_filtered) > 0:
    sns.violinplot(data=conf_df_filtered, x='Source', y='Confidence', hue='Correct',
                   order=plot_sources, hue_order=correctness_order,
                   split=True, inner='quartile', ax=ax,
                   palette=correctness_palette)
    # Mean confidence per half-violin as black diamonds
    # ('Correct' half on the left, 'Incorrect' half on the right)
    for i, src in enumerate(plot_sources):
        for j, corr_label in enumerate(correctness_order):
            mask_sc = (conf_df_filtered['Source'] == src) & (conf_df_filtered['Correct'] == corr_label)
            if mask_sc.sum() > 0:
                mean_val = conf_df_filtered.loc[mask_sc, 'Confidence'].mean()
                offset = -0.05 if j == 0 else 0.05
                ax.scatter(i + offset, mean_val, color='black', s=30, zorder=5, marker='D')
ax.set_title('Prediction Confidence by Source & Correctness', fontsize=13, fontweight='bold')
ax.set_ylabel('Confidence (max probability)', fontsize=12)
ax.set_xlabel('Data Source', fontsize=12)
ax.grid(axis='y', alpha=0.3)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# Right: confidence by source for DR images only (label 1 == 'DR')
ax = axes[1]
dr_df = conf_df[conf_df['True'] == 'DR']
dr_plot_sources = [s for s in sources_unique if ((all_sources == s) & (all_labels == 1)).sum() >= 5]
dr_df_filtered = dr_df[dr_df['Source'].isin(dr_plot_sources)]
if len(dr_df_filtered) > 0:
    sns.violinplot(data=dr_df_filtered, x='Source', y='Confidence', hue='Correct',
                   order=dr_plot_sources, hue_order=correctness_order,
                   split=True, inner='quartile', ax=ax,
                   palette=correctness_palette)
ax.set_title('DR Confidence: APTOS vs ODIR', fontsize=13, fontweight='bold')
ax.set_ylabel('Confidence (max probability)', fontsize=12)
ax.set_xlabel('Data Source', fontsize=12)
ax.grid(axis='y', alpha=0.3)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.tight_layout()
fig.savefig(os.path.join(FAIRNESS_DIR, 'confidence_by_source.png'),
            dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)
print(' Saved confidence_by_source.png')
# ================================================================
# PLOT 6: Error Patterns
# ================================================================
print('Generating error_patterns.png...')
# axes[0]: APTOS misclassification pairs; axes[1]: ODIR misclassification pairs.
fig, axes = plt.subplots(1, 2, figsize=(16, 7))
def get_error_pairs(labels, preds, class_names, top_k=10):
    """Return the *top_k* most frequent ``'True -> Predicted'`` misclassifications.

    Correct predictions are ignored. The result is a list of
    ``(pair_label, count)`` tuples, most common first; ties keep the order in
    which each pair was first seen (``Counter.most_common`` semantics).
    """
    mistakes = Counter(
        f'{class_names[truth]} -> {class_names[guess]}'
        for truth, guess in zip(labels, preds)
        if truth != guess
    )
    return mistakes.most_common(top_k)
def _draw_error_bars(ax, errors, color):
    """Plot ``(pair, count)`` tuples as horizontal bars, most frequent on top.

    *color* is either one colour for every bar or a callable mapping a pair
    label to its colour. When *errors* is empty, a centred 'No errors!' note
    is drawn instead.
    """
    if not errors:
        ax.text(0.5, 0.5, 'No errors!', ha='center', va='center', fontsize=14,
                transform=ax.transAxes)
        return
    pair_labels, pair_counts = zip(*errors)
    rows = np.arange(len(pair_labels))
    fill = [color(lbl) for lbl in pair_labels] if callable(color) else color
    drawn = ax.barh(rows, pair_counts, color=fill, edgecolor='white', linewidth=0.5)
    ax.set_yticks(rows)
    ax.set_yticklabels(pair_labels, fontsize=10)
    ax.invert_yaxis()
    # Print each count just past the end of its bar
    for rect, count in zip(drawn, pair_counts):
        ax.text(rect.get_width() + 0.3, rect.get_y() + rect.get_height() / 2,
                str(count), va='center', fontsize=10, fontweight='bold')


def _odir_pair_color(pair):
    """Colour an ODIR pair by the classes it involves (Normal > DR > other)."""
    if 'Normal' in pair:
        return '#FF9800'
    if 'DR' in pair:
        return '#F44336'
    return '#9C27B0'


def _finish_error_axis(ax, title):
    """Apply the shared title, axis labels, grid and spine styling."""
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.set_xlabel('Count', fontsize=12)
    ax.set_ylabel('True -> Predicted', fontsize=12)
    ax.grid(axis='x', alpha=0.3)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


# APTOS errors: a single colour for every pair
if aptos_mask.sum() > 0:
    _draw_error_bars(
        axes[0],
        get_error_pairs(all_labels[aptos_mask], all_preds[aptos_mask], CLASS_NAMES, top_k=8),
        '#2196F3',
    )
# n_aptos_errors / n_odir_errors are also used by the key-findings section
n_aptos_errors = (all_labels[aptos_mask] != all_preds[aptos_mask]).sum() if aptos_mask.sum() > 0 else 0
_finish_error_axis(axes[0], f'APTOS Error Patterns\n({n_aptos_errors} errors / {aptos_mask.sum()} images)')
# ODIR errors: colour depends on the classes in each pair
if odir_mask.sum() > 0:
    _draw_error_bars(
        axes[1],
        get_error_pairs(all_labels[odir_mask], all_preds[odir_mask], CLASS_NAMES, top_k=10),
        _odir_pair_color,
    )
n_odir_errors = (all_labels[odir_mask] != all_preds[odir_mask]).sum() if odir_mask.sum() > 0 else 0
_finish_error_axis(axes[1], f'ODIR Error Patterns\n({n_odir_errors} errors / {odir_mask.sum()} images)')
plt.tight_layout()
fig.savefig(os.path.join(FAIRNESS_DIR, 'error_patterns.png'),
            dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)
print(' Saved error_patterns.png')
# ================================================================
# COMPUTE ECE (Expected Calibration Error) per source
# ================================================================
def compute_ece(labels, probs, preds, n_bins=15):
    """Return the Expected Calibration Error of the top-1 predictions.

    ECE = sum over bins of (bin_size / N) * |accuracy(bin) - mean_confidence(bin)|,
    using *n_bins* equal-width confidence bins on [0, 1]. The last bin is
    closed at 1.0, and empty bins contribute nothing.
    """
    top_conf = np.max(probs, axis=1)
    is_right = (labels == preds).astype(float)
    edges = np.linspace(0, 1, n_bins + 1)
    total = len(labels)
    ece = 0.0
    for b in range(n_bins):
        lo, hi = edges[b], edges[b + 1]
        upper_ok = (top_conf <= hi) if b == n_bins - 1 else (top_conf < hi)
        members = (top_conf >= lo) & upper_ok
        size = members.sum()
        if size == 0:
            continue
        gap = abs(is_right[members].mean() - top_conf[members].mean())
        ece += (size / total) * gap
    return float(ece)
# ================================================================
# STATISTICAL SIGNIFICANCE TEST
# ================================================================
print('\nRunning statistical tests...')
# Chi-squared test: is there a significant difference in accuracy between
# APTOS-DR and ODIR-DR?  (label 1 == 'DR' in CLASS_NAMES)
aptos_dr_mask = (all_sources == 'APTOS') & (all_labels == 1)
odir_dr_mask = (all_sources == 'ODIR') & (all_labels == 1)
stat_results = {}
if aptos_dr_mask.sum() > 0 and odir_dr_mask.sum() > 0:
    aptos_dr_correct = all_correct[aptos_dr_mask].sum()
    aptos_dr_total = aptos_dr_mask.sum()
    odir_dr_correct = all_correct[odir_dr_mask].sum()
    odir_dr_total = odir_dr_mask.sum()
    # Contingency table: [[aptos_correct, aptos_wrong], [odir_correct, odir_wrong]]
    contingency = np.array([
        [aptos_dr_correct, aptos_dr_total - aptos_dr_correct],
        [odir_dr_correct, odir_dr_total - odir_dr_correct],
    ])
    if (contingency.sum(axis=0) == 0).any():
        # Both groups are 100% correct (or both 0%), so their accuracies are
        # identical. chi2_contingency would raise ValueError here because an
        # expected frequency is zero, so record the "no difference" result
        # explicitly.
        chi2, p_value, dof = 0.0, 1.0, 1
    else:
        chi2, p_value, dof, _expected = stats.chi2_contingency(contingency)
    stat_results = {
        'test': 'chi-squared',
        'chi2_statistic': float(round(chi2, 4)),
        'p_value': float(p_value),
        'degrees_of_freedom': int(dof),
        'significant_at_005': bool(p_value < 0.05),
        'aptos_dr_accuracy': float(round(aptos_dr_correct / aptos_dr_total, 4)),
        'odir_dr_accuracy': float(round(odir_dr_correct / odir_dr_total, 4)),
        'aptos_dr_n': int(aptos_dr_total),
        'odir_dr_n': int(odir_dr_total),
    }
    print(f' Chi-squared test (DR: APTOS vs ODIR):')
    print(f' APTOS-DR accuracy: {aptos_dr_correct}/{aptos_dr_total} = '
          f'{aptos_dr_correct / aptos_dr_total:.4f}')
    print(f' ODIR-DR accuracy: {odir_dr_correct}/{odir_dr_total} = '
          f'{odir_dr_correct / odir_dr_total:.4f}')
    print(f' chi2 = {chi2:.4f}, p = {p_value:.6f}, significant = {p_value < 0.05}')
# ================================================================
# BUILD DOMAIN GAP REPORT
# ================================================================
print('\nBuilding domain_gap_report.json...')
# Per-source ECE (sources with fewer than 10 test images are omitted)
source_ece = {}
for src in sources_unique:
    mask = source_masks[src]
    if mask.sum() >= 10:
        ece = compute_ece(all_labels[mask], all_probs[mask], all_preds[mask])
        source_ece[src] = round(ece, 4)


def _gap_direction(aptos_value, odir_value):
    """Name the higher-scoring source; an exact tie is reported as 'APTOS = ODIR'.

    The original ternary labelled ties as 'ODIR > APTOS'.
    """
    if aptos_value > odir_value:
        return 'APTOS > ODIR'
    if aptos_value < odir_value:
        return 'ODIR > APTOS'
    return 'APTOS = ODIR'


# Domain gap: absolute accuracy difference plus which source scores higher
overall_accs = {src: source_metrics[src]['Overall']['accuracy'] for src in sources_unique}
domain_gap = {}
if 'APTOS' in overall_accs and 'ODIR' in overall_accs:
    # For DR specifically (only available if the significance test ran)
    aptos_dr_acc = stat_results.get('aptos_dr_accuracy', None)
    odir_dr_acc = stat_results.get('odir_dr_accuracy', None)
    if aptos_dr_acc is not None and odir_dr_acc is not None:
        domain_gap['dr_accuracy_gap'] = round(abs(aptos_dr_acc - odir_dr_acc), 4)
        domain_gap['dr_gap_direction'] = _gap_direction(aptos_dr_acc, odir_dr_acc)
    domain_gap['overall_accuracy_gap'] = round(abs(overall_accs['APTOS'] - overall_accs['ODIR']), 4)
    domain_gap['overall_gap_direction'] = _gap_direction(overall_accs['APTOS'], overall_accs['ODIR'])
# Mean confidence by source and correctness. The correct/incorrect means are
# None when a source has no predictions in that group.
confidence_stats = {}
for src in sources_unique:
    mask = source_masks[src]
    correct_mask = mask & (all_correct == 1)
    incorrect_mask = mask & (all_correct == 0)
    confidence_stats[src] = {
        'mean_confidence': round(float(all_confidence[mask].mean()), 4),
        'mean_confidence_correct': round(float(all_confidence[correct_mask].mean()), 4) if correct_mask.sum() > 0 else None,
        'mean_confidence_incorrect': round(float(all_confidence[incorrect_mask].mean()), 4) if incorrect_mask.sum() > 0 else None,
        'n_correct': int(correct_mask.sum()),
        'n_incorrect': int(incorrect_mask.sum()),
    }
# Key findings: human-readable sentences stored in the JSON report and
# printed in the final summary.
findings = []
# 1. Domain gap in DR (only when the APTOS-vs-ODIR significance test ran)
if stat_results:
    gap = domain_gap.get('dr_accuracy_gap', 0)
    direction = domain_gap.get('dr_gap_direction', '')
    sig = stat_results.get('significant_at_005', False)
    findings.append(
        f"DR accuracy gap: {gap:.1%} ({direction}). "
        f"Chi-squared p={stat_results['p_value']:.4f} "
        f"({'statistically significant' if sig else 'not statistically significant'} at alpha=0.05)."
    )
# 2. Calibration (only sources that have an ECE entry)
for src in sources_unique:
    if src in source_ece:
        findings.append(f"{src} ECE (Expected Calibration Error) = {source_ece[src]:.4f}.")
# 3. Confidence analysis: gap between mean confidence on correct and on
#    incorrect predictions
for src in sources_unique:
    cs = confidence_stats.get(src, {})
    mc_corr = cs.get('mean_confidence_correct')
    mc_incorr = cs.get('mean_confidence_incorrect')
    if mc_corr is not None and mc_incorr is not None:
        findings.append(
            f"{src}: mean confidence on correct={mc_corr:.3f}, incorrect={mc_incorr:.3f} "
            f"(separation={mc_corr - mc_incorr:.3f})."
        )
# 4. Error pattern analysis (n_aptos_errors / n_odir_errors were computed in PLOT 6)
if aptos_mask.sum() > 0:
    aptos_errors_list = get_error_pairs(all_labels[aptos_mask], all_preds[aptos_mask], CLASS_NAMES, top_k=3)
    if aptos_errors_list:
        top_err = aptos_errors_list[0]
        findings.append(
            f"APTOS top error: {top_err[0]} ({top_err[1]} instances). "
            f"Total APTOS errors: {n_aptos_errors}/{aptos_mask.sum()} "
            f"({n_aptos_errors / aptos_mask.sum():.1%})."
        )
if odir_mask.sum() > 0:
    odir_errors_list = get_error_pairs(all_labels[odir_mask], all_preds[odir_mask], CLASS_NAMES, top_k=3)
    if odir_errors_list:
        top_err = odir_errors_list[0]
        findings.append(
            f"ODIR top error: {top_err[0]} ({top_err[1]} instances). "
            f"Total ODIR errors: {n_odir_errors}/{odir_mask.sum()} "
            f"({n_odir_errors / odir_mask.sum():.1%})."
        )
# Assemble the report. Every value is already a plain Python int, float, str,
# list, dict or None, so json.dump needs no custom encoder.
report = {
    'phase': '1D - Fairness & Domain Robustness Analysis',
    'model': 'RetinaSense-ViT v3 (vit_base_patch16_224)',
    'test_set_size': int(len(all_labels)),
    'temperature': TEMPERATURE,
    'source_distribution': {src: int((all_sources == src).sum()) for src in sources_unique},
    'per_source_metrics': {src: source_metrics[src] for src in sources_unique},
    'domain_gap': domain_gap,
    'statistical_test': stat_results,
    'calibration': {
        'ece_by_source': source_ece,
    },
    'confidence_analysis': confidence_stats,
    'key_findings': findings,
}
report_path = os.path.join(FAIRNESS_DIR, 'domain_gap_report.json')
with open(report_path, 'w') as report_file:
    json.dump(report, report_file, indent=2)
print(f' Saved domain_gap_report.json')
# ================================================================
# SUMMARY
# ================================================================
# Console recap: each expected artefact with its size (or MISSING), followed
# by the numbered key findings.
expected_outputs = (
    'performance_by_source.png',
    'calibration_by_source.png',
    'confusion_matrix_aptos.png',
    'confusion_matrix_odir.png',
    'confidence_by_source.png',
    'error_patterns.png',
    'domain_gap_report.json',
)
print('\n' + '=' * 65)
print(' FAIRNESS ANALYSIS COMPLETE')
print('=' * 65)
print(f'\n Output directory: {FAIRNESS_DIR}')
print(f' Files generated:')
for fname in expected_outputs:
    fpath = os.path.join(FAIRNESS_DIR, fname)
    if os.path.exists(fpath):
        status = f'{os.path.getsize(fpath) / 1024:.1f} KB'
    else:
        status = 'MISSING'
    print(f' {fname:35s} {status}')
print(f'\n KEY FINDINGS:')
for number, finding in enumerate(findings, start=1):
    print(f' {number}. {finding}')
print('=' * 65)