| |
| """ |
| RetinaSense v3.0 — Phase 1D: Fairness & Domain Robustness Analysis |
| =================================================================== |
| Evaluates model performance across data sources (APTOS, ODIR, REFUGE2) |
| to quantify domain gap and identify fairness concerns. |
| |
| Outputs (saved to outputs_v3/fairness/): |
| - performance_by_source.png : grouped bar chart of metrics per class per source |
| - calibration_by_source.png : reliability diagrams by source |
| - confusion_matrix_aptos.png : confusion matrix for APTOS subset |
| - confusion_matrix_odir.png : confusion matrix for ODIR subset |
| - confidence_by_source.png : violin plots of prediction confidence |
| - error_patterns.png : most common misclassification pairs by source |
| - domain_gap_report.json : full quantitative report |
| """ |
|
|
| import os |
| import sys |
| import json |
| import warnings |
| import numpy as np |
| import pandas as pd |
| import cv2 |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| import matplotlib.patches as mpatches |
| import seaborn as sns |
| from PIL import Image |
| from tqdm import tqdm |
| from collections import Counter, defaultdict |
|
|
| warnings.filterwarnings('ignore') |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision import transforms |
| from torch.utils.data import Dataset, DataLoader |
|
|
| import timm |
|
|
| from sklearn.metrics import ( |
| accuracy_score, f1_score, precision_score, recall_score, |
| confusion_matrix, classification_report |
| ) |
| from scipy import stats |
|
|
| |
| |
| |
# ── Paths ─────────────────────────────────────────────────────────────
# Everything is derived from BASE_DIR; the fairness output folder is created
# up front so later savefig/json.dump calls can write into it.
BASE_DIR = '/teamspace/studios/this_studio'
OUTPUT_DIR = os.path.join(BASE_DIR, 'outputs_v3')
FAIRNESS_DIR = os.path.join(OUTPUT_DIR, 'fairness')
os.makedirs(FAIRNESS_DIR, exist_ok=True)

MODEL_PATH = os.path.join(OUTPUT_DIR, 'best_model.pth')
TEST_CSV = os.path.join(BASE_DIR, 'data', 'test_split.csv')
NORM_STATS_PATH = os.path.join(BASE_DIR, 'data', 'fundus_norm_stats.json')
TEMPERATURE_PATH = os.path.join(OUTPUT_DIR, 'temperature.json')

# ── Model / data hyper-parameters ─────────────────────────────────────
NUM_CLASSES = 5
IMG_SIZE = 224
DROPOUT = 0.3
BATCH_SIZE = 32

# Index in this list == integer disease label.
CLASS_NAMES = ['Normal', 'DR', 'Glaucoma', 'Cataract', 'AMD']

_HAS_CUDA = torch.cuda.is_available()
DEVICE = torch.device('cuda' if _HAS_CUDA else 'cpu')

# ── Banner ────────────────────────────────────────────────────────────
_BANNER_RULE = '=' * 65
print(_BANNER_RULE)
print(' RetinaSense v3.0 -- Phase 1D: Fairness & Domain Robustness')
print(_BANNER_RULE)
print(f' Device : {DEVICE}')
if _HAS_CUDA:
    print(f' GPU : {torch.cuda.get_device_name(0)}')
print(f' Output : {FAIRNESS_DIR}')
print(_BANNER_RULE)
|
|
| |
| |
| |
# Normalisation statistics: use the values stored in the JSON file when it
# exists, otherwise fall back to the ImageNet constants.
if not os.path.exists(NORM_STATS_PATH):
    NORM_MEAN = [0.485, 0.456, 0.406]
    NORM_STD = [0.229, 0.224, 0.225]
    print(' Using ImageNet normalisation fallback')
else:
    with open(NORM_STATS_PATH) as f:
        norm_stats = json.load(f)
    NORM_MEAN = norm_stats['mean_rgb']
    NORM_STD = norm_stats['std_rgb']
    _mean_rounded = [round(v, 4) for v in NORM_MEAN]
    _std_rounded = [round(v, 4) for v in NORM_STD]
    print(f' Fundus norm: mean={_mean_rounded}, std={_std_rounded}')

# Temperature used later to scale the disease logits before softmax.
# Unlike the normalisation stats there is no fallback: a missing file raises.
with open(TEMPERATURE_PATH) as f:
    temp_data = json.load(f)
TEMPERATURE = temp_data['temperature']
print(f' Temperature T = {TEMPERATURE:.4f}')
|
|
|
|
| |
| |
| |
class MultiTaskViT(nn.Module):
    """ViT-Base/16 (224 px) backbone shared by a disease and a severity head.

    ``forward`` returns ``(disease_logits, severity_logits)``. Attribute names
    and the layer order inside each ``nn.Sequential`` determine the state-dict
    keys, so they must stay in sync with the checkpoint being loaded.
    """

    def __init__(self, n_disease=NUM_CLASSES, n_severity=5, drop=DROPOUT):
        super().__init__()
        # Built without pretrained weights and with num_classes=0; the two
        # heads below consume its output directly.
        self.backbone = timm.create_model(
            'vit_base_patch16_224', pretrained=False, num_classes=0
        )
        embed_dim = 768  # input width expected by both heads

        # Dropout applied once to the shared features (rate = ``drop``).
        self.drop = nn.Dropout(drop)

        disease_layers = [
            nn.Linear(embed_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, n_disease),
        ]
        self.disease_head = nn.Sequential(*disease_layers)

        severity_layers = [
            nn.Linear(embed_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, n_severity),
        ]
        self.severity_head = nn.Sequential(*severity_layers)

    def forward(self, x):
        """Return ``(disease_logits, severity_logits)`` for image batch *x*."""
        shared = self.drop(self.backbone(x))
        disease_out = self.disease_head(shared)
        severity_out = self.severity_head(shared)
        return disease_out, severity_out
|
|
|
|
| |
| |
| |
print('\nLoading model...')
model = MultiTaskViT().to(DEVICE)
# NOTE(security): weights_only=False makes torch.load unpickle arbitrary
# objects from the file. Only load checkpoints that come from a trusted source.
ckpt = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()  # inference mode: dropout off, BatchNorm uses running stats
print(f' Loaded: {MODEL_PATH}')

# The stored epoch is displayed as stored value + 1. When the checkpoint has
# no 'epoch' entry we print '?' instead of attempting '?' + 1, which would
# raise TypeError.
_ckpt_epoch = ckpt.get('epoch')
_epoch_display = _ckpt_epoch + 1 if _ckpt_epoch is not None else '?'
print(f' Checkpoint epoch: {_epoch_display} '
      f'val_acc={ckpt.get("val_acc", 0):.2f}%')
|
|
|
|
| |
| |
| |
class FairnessDataset(Dataset):
    """Test-split dataset yielding ``(tensor, label, source, idx)`` per row.

    The CSV must provide ``disease_label``, ``source`` and ``image_path``
    columns. An optional ``cache_path`` column may point at an array saved
    with ``np.save``; when that file exists it is used instead of re-running
    the image preprocessing.
    """

    def __init__(self, csv_path, base_dir):
        """Read *csv_path*; relative paths in it are resolved under *base_dir*."""
        self.df = pd.read_csv(csv_path)
        self.base_dir = base_dir
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(NORM_MEAN, NORM_STD),
        ])
        print(f' Test set: {len(self.df)} images')
        print(f' Sources : {dict(self.df["source"].value_counts())}')

    def __len__(self):
        return len(self.df)

    def _resolve(self, path):
        """Return *path* anchored at ``self.base_dir`` unless it is absolute.

        Leading ``./`` components are removed together with any slashes that
        follow them. Without that, a value such as ``.//data/x.npy`` would be
        reduced to ``/data/x.npy``, which ``os.path.join`` treats as absolute
        and therefore silently drops ``base_dir``.
        """
        if os.path.isabs(path):
            return path
        relative = path
        while relative.startswith('./'):
            relative = relative[2:].lstrip('/')
        return os.path.join(self.base_dir, relative)

    def __getitem__(self, idx):
        """Load sample *idx*, preferring the cached array over the raw image."""
        row = self.df.iloc[idx]
        label = int(row['disease_label'])
        source = row['source']

        # Cached array first. The isinstance check also rejects NaN, which is
        # what pandas yields for an empty CSV cell.
        cache_path = row.get('cache_path', '')
        if isinstance(cache_path, str) and cache_path:
            cache_path = self._resolve(cache_path)
            if os.path.exists(cache_path):
                img = np.load(cache_path)
                return self.transform(img), label, source, idx

        # Fallback: preprocess the raw image. APTOS rows use the Ben Graham
        # pipeline, every other source uses CLAHE.
        image_path = self._resolve(row['image_path'])
        if source == 'APTOS':
            img = self._ben_graham(image_path)
        else:
            img = self._clahe_preprocess(image_path)

        return self.transform(img), label, source, idx

    @staticmethod
    def _read_bgr(path):
        """Read *path* as a BGR array, falling back to PIL if OpenCV fails."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None; retry with PIL and
            # convert its RGB output to OpenCV's BGR channel order.
            rgb = np.array(Image.open(path).convert('RGB'))
            img = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        return img

    @staticmethod
    def _ben_graham(path, sz=IMG_SIZE, sigma=10):
        """Return an RGB ``sz`` x ``sz`` image with local-contrast enhancement.

        The image is resized, a Gaussian-blurred copy is subtracted with
        weights 4 / -4 and offset 128, and everything outside a centred circle
        of radius ``0.48 * sz`` is zeroed.
        """
        img = FairnessDataset._read_bgr(path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (sz, sz))
        blurred = cv2.GaussianBlur(img, (0, 0), sigma)
        img = cv2.addWeighted(img, 4, blurred, -4, 128)
        mask = np.zeros(img.shape[:2], dtype=np.uint8)
        cv2.circle(mask, (sz // 2, sz // 2), int(sz * 0.48), 255, -1)
        return cv2.bitwise_and(img, img, mask=mask)

    @staticmethod
    def _clahe_preprocess(path, sz=IMG_SIZE):
        """Return an RGB ``sz`` x ``sz`` image with CLAHE applied to LAB channel 0."""
        img = FairnessDataset._read_bgr(path)
        img = cv2.resize(img, (sz, sz))
        lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        lab[:, :, 0] = clahe.apply(lab[:, :, 0])
        img = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
|
|
| |
| |
| |
print('\nRunning inference on test set...')
dataset = FairnessDataset(TEST_CSV, BASE_DIR)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep CSV order so results line up with the rows
    num_workers=4,
    pin_memory=True,
)

# Per-sample accumulators, filled batch by batch.
all_labels, all_preds, all_probs = [], [], []
all_logits, all_sources, all_indices = [], [], []

with torch.no_grad():
    for batch_imgs, batch_labels, batch_sources, batch_idx in tqdm(loader, desc='Inference'):
        # Only the disease head is analysed; severity logits are discarded.
        disease_logits, _ = model(batch_imgs.to(DEVICE))

        # Temperature scaling: divide logits by T before the softmax.
        # all_logits below keeps the raw, unscaled logits.
        probs = F.softmax(disease_logits / TEMPERATURE, dim=1)
        preds = probs.argmax(dim=1)

        all_labels += batch_labels.numpy().tolist()
        all_preds += preds.cpu().numpy().tolist()
        all_probs += probs.cpu().numpy().tolist()
        all_logits += disease_logits.cpu().numpy().tolist()
        all_sources += list(batch_sources)
        all_indices += batch_idx.numpy().tolist()

# Convert to arrays so boolean masks can be applied per source.
all_labels = np.array(all_labels)
all_preds = np.array(all_preds)
all_probs = np.array(all_probs)
all_logits = np.array(all_logits)
all_sources = np.array(all_sources)

# 1 where the prediction matches the label, else 0.
all_correct = (all_labels == all_preds).astype(int)

# Confidence = probability of the predicted class.
all_confidence = np.max(all_probs, axis=1)

print(f'\n Total images: {len(all_labels)}')
print(f' Overall accuracy: {accuracy_score(all_labels, all_preds):.4f}')
print(f' Sources: {Counter(all_sources)}')
|
|
|
|
| |
| |
| |
def compute_metrics(labels, preds, probs=None, class_names=CLASS_NAMES):
    """Return per-class and macro-averaged metrics as a nested dict.

    Args:
        labels: NumPy array of integer ground-truth class indices.
        preds: NumPy array of integer predicted class indices.
        probs: Accepted for call-site symmetry; not used here.
        class_names: Names indexed by class id.

    Returns:
        ``{class_name: {...}, ..., 'Overall': {...}}`` where every inner dict
        has keys ``n``, ``accuracy``, ``precision``, ``recall`` and ``f1``
        (rounded to 4 decimals). Classes with no ground-truth samples get
        ``n == 0`` and ``None`` for every metric. The per-class ``accuracy``
        is the fraction of that class's samples predicted correctly.
    """
    results = {}

    for class_idx, class_name in enumerate(class_names):
        is_true = labels == class_idx
        support = is_true.sum()
        if support == 0:
            results[class_name] = {
                'n': 0,
                'accuracy': None,
                'precision': None,
                'recall': None,
                'f1': None,
            }
            continue

        # One-vs-rest counts for this class.
        is_pred = preds == class_idx
        tp = (is_true & is_pred).sum()
        fn = (is_true & ~is_pred).sum()
        fp = (~is_true & is_pred).sum()

        predicted_total = tp + fp
        actual_total = tp + fn
        prec = tp / predicted_total if predicted_total > 0 else 0.0
        rec = tp / actual_total if actual_total > 0 else 0.0
        pr_sum = prec + rec
        f1 = 2 * prec * rec / pr_sum if pr_sum > 0 else 0.0
        # Restricted to samples whose true label is this class.
        acc = (preds[is_true] == class_idx).mean()

        results[class_name] = {
            'n': int(support),
            'accuracy': float(round(acc, 4)),
            'precision': float(round(prec, 4)),
            'recall': float(round(rec, 4)),
            'f1': float(round(f1, 4)),
        }

    # Macro averages run over classes that occur in labels OR predictions.
    present_labels = sorted(set(labels) | set(preds))
    macro_kwargs = dict(labels=present_labels, average='macro', zero_division=0)
    overall_acc = accuracy_score(labels, preds)
    overall_f1 = f1_score(labels, preds, **macro_kwargs)
    overall_prec = precision_score(labels, preds, **macro_kwargs)
    overall_rec = recall_score(labels, preds, **macro_kwargs)

    results['Overall'] = {
        'n': int(len(labels)),
        'accuracy': float(round(overall_acc, 4)),
        'precision': float(round(overall_prec, 4)),
        'recall': float(round(overall_rec, 4)),
        'f1': float(round(overall_f1, 4)),
    }

    return results
|
|
|
|
| |
| |
| |
# Evaluate each data source separately.
sources_unique = sorted(set(all_sources))
print(f'\n Unique sources: {sources_unique}')

# source -> boolean mask over the test set; reused by the later plots.
source_masks = {src: all_sources == src for src in sources_unique}

# source -> metrics dict produced by compute_metrics.
source_metrics = {}
for src, mask in source_masks.items():
    metrics = compute_metrics(all_labels[mask], all_preds[mask], all_probs[mask])
    source_metrics[src] = metrics
    acc = metrics['Overall']['accuracy']
    f1 = metrics['Overall']['f1']
    print(f' {src:8s}: n={mask.sum():4d} acc={acc:.4f} macro-F1={f1:.4f}')
|
|
|
|
| |
| |
| |
print('\nGenerating performance_by_source.png...')


def _label_bars(axis, bars, values, fontsize):
    """Write the value above every bar whose value is positive."""
    for bar, value in zip(bars, values):
        if value > 0:
            axis.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                      f'{value:.2f}', ha='center', va='bottom',
                      fontsize=fontsize, fontweight='bold')


def _finish_score_axis(axis, xlabel, title, ticks, tick_labels, **legend_kwargs):
    """Apply the labels, limits, legend and grid shared by both panels."""
    axis.set_xlabel(xlabel, fontsize=12)
    axis.set_ylabel('Score', fontsize=12)
    axis.set_title(title, fontsize=13, fontweight='bold')
    axis.set_xticks(ticks)
    axis.set_xticklabels(tick_labels)
    axis.set_ylim(0, 1.15)  # headroom for the value labels
    axis.legend(**legend_kwargs)
    axis.grid(axis='y', alpha=0.3)
    axis.spines['top'].set_visible(False)
    axis.spines['right'].set_visible(False)


fig, axes = plt.subplots(1, 2, figsize=(18, 7))
plt.subplots_adjust(wspace=0.35)

metric_names = ['accuracy', 'f1', 'precision', 'recall']
metric_labels = ['Accuracy', 'F1-Score', 'Precision', 'Recall']
bar_colors = {'APTOS': '#2196F3', 'ODIR': '#FF9800', 'REFUGE2': '#4CAF50'}

# ── Left panel: DR metrics, one bar group per metric, one bar per source ──
ax = axes[0]
x = np.arange(len(metric_names))
bar_width = 0.25
n_sources = len(sources_unique)
# Offsets centre the group of bars on each tick.
offsets = np.arange(n_sources) - (n_sources - 1) / 2

for offset, src in zip(offsets, sources_unique):
    dr_metrics = source_metrics[src].get('DR', {})
    # Missing or None metrics are drawn as zero-height bars.
    vals = [0 if dr_metrics.get(m) is None else dr_metrics[m] for m in metric_names]
    bars = ax.bar(x + offset * bar_width, vals, bar_width,
                  label=f'{src} (n={dr_metrics.get("n", 0)})',
                  color=bar_colors.get(src, '#999999'), edgecolor='white', linewidth=0.5)
    _label_bars(ax, bars, vals, fontsize=8)

_finish_score_axis(ax, 'Metric', 'DR Classification: APTOS vs ODIR vs REFUGE2',
                   x, metric_labels, fontsize=9)

# ── Right panel: ODIR metrics for every class that has samples ──
ax = axes[1]
odir_metrics = source_metrics.get('ODIR', {})
classes_present = [c for c in CLASS_NAMES if odir_metrics.get(c, {}).get('n', 0) > 0]
x2 = np.arange(len(classes_present))
bar_width2 = 0.18
colors_metric = ['#1976D2', '#388E3C', '#F57C00', '#D32F2F']

for i, (m, ml) in enumerate(zip(metric_names, metric_labels)):
    vals = [0 if odir_metrics[c][m] is None else odir_metrics[c][m] for c in classes_present]
    # (i - 1.5) centres the four metric bars around each class tick.
    bars = ax.bar(x2 + (i - 1.5) * bar_width2, vals, bar_width2,
                  label=ml, color=colors_metric[i], edgecolor='white', linewidth=0.5)
    _label_bars(ax, bars, vals, fontsize=7)

class_labels = [f'{c}\n(n={odir_metrics[c]["n"]})' for c in classes_present]
_finish_score_axis(ax, 'Disease Class', 'ODIR Per-Class Performance',
                   x2, class_labels, fontsize=9, ncol=2)

fig.suptitle('Fairness & Domain Robustness: Performance by Source',
             fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
fig.savefig(os.path.join(FAIRNESS_DIR, 'performance_by_source.png'),
            dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)
print(' Saved performance_by_source.png')
|
|
|
|
| |
| |
| |
print('Generating calibration_by_source.png...')

# Two side-by-side panels: reliability diagram and confidence histogram.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(14, 6))
|
|
| |
def reliability_diagram(labels, probs, preds, n_bins=10):
    """Bin samples by confidence and summarise each non-empty bin.

    Confidence is the row-wise maximum of *probs*. The interval [0, 1] is cut
    into *n_bins* equal-width bins; each is half-open ``[lo, hi)`` except the
    last, which also includes 1.0. Empty bins are omitted.

    Returns:
        Four arrays of equal length, one entry per non-empty bin:
        bin centres, accuracy in the bin, mean confidence in the bin,
        and sample count.
    """
    confidences = np.max(probs, axis=1)
    correct = (labels == preds).astype(float)
    bin_edges = np.linspace(0, 1, n_bins + 1)

    last_bin = n_bins - 1
    summaries = []  # (center, accuracy, mean confidence, count) per bin
    for b, (lo, hi) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        below_upper = confidences <= hi if b == last_bin else confidences < hi
        in_bin = (confidences >= lo) & below_upper
        if in_bin.sum() == 0:
            continue
        summaries.append((
            (lo + hi) / 2,
            correct[in_bin].mean(),
            confidences[in_bin].mean(),
            int(in_bin.sum()),
        ))

    bin_centers, bin_accs, bin_confs, bin_counts = (
        np.array([row[col] for row in summaries]) for col in range(4)
    )
    return bin_centers, bin_accs, bin_confs, bin_counts
|
|
|
|
| |
source_colors = {'APTOS': '#2196F3', 'ODIR': '#FF9800', 'REFUGE2': '#4CAF50'}

# Sources with fewer than 10 test samples are left out of both panels.
_calib_sources = [s for s in sources_unique if source_masks[s].sum() >= 10]

# ── Left panel: reliability curve per source ──
ax = axes[0]
ax.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Perfect calibration')

for src in _calib_sources:
    mask = source_masks[src]
    centers, accs, confs, counts = reliability_diagram(
        all_labels[mask], all_probs[mask], all_preds[mask], n_bins=10
    )
    # x = mean confidence inside each bin (not the bin centre).
    ax.plot(confs, accs, 'o-', color=source_colors.get(src, '#999'),
            label=f'{src} (n={mask.sum()})', markersize=6, linewidth=2)

ax.set_xlabel('Mean Predicted Confidence', fontsize=12)
ax.set_ylabel('Actual Accuracy (Fraction Correct)', fontsize=12)
ax.set_title('Reliability Diagram by Source', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.grid(alpha=0.3)
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.set_aspect('equal')

# ── Right panel: density histogram of confidences per source ──
ax = axes[1]
for src in _calib_sources:
    mask = source_masks[src]
    confs = np.max(all_probs[mask], axis=1)
    ax.hist(confs, bins=20, alpha=0.5, label=f'{src} (n={mask.sum()})',
            color=source_colors.get(src, '#999'), density=True, edgecolor='white')

ax.set_xlabel('Prediction Confidence', fontsize=12)
ax.set_ylabel('Density', fontsize=12)
ax.set_title('Confidence Distribution by Source', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(alpha=0.3)
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

plt.tight_layout()
fig.savefig(os.path.join(FAIRNESS_DIR, 'calibration_by_source.png'),
            dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)
print(' Saved calibration_by_source.png')
|
|
|
|
| |
| |
| |
def plot_confusion_matrix(labels, preds, title, save_path, class_names=CLASS_NAMES):
    """Draw a row-normalised confusion-matrix heatmap and save it.

    Only classes that occur in *labels* or *preds* get a row/column. Each cell
    is annotated with the raw count and the row-normalised percentage.

    Args:
        labels: Integer ground-truth class indices.
        preds: Integer predicted class indices.
        title: Figure title.
        save_path: Destination image path.
        class_names: Names indexed by class id.

    Returns:
        The raw (un-normalised) confusion matrix.
    """
    present_classes = sorted(set(labels) | set(preds))
    present_names = [class_names[c] for c in present_classes]

    cm = confusion_matrix(labels, preds, labels=present_classes)
    # A class seen only in predictions has an all-zero row; its 0/0 entries
    # become NaN and are mapped back to 0 here.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm.astype(float) / row_totals)

    # Cell text: "<count>\n(<row share>)".
    annot = np.empty_like(cm, dtype=object)
    for i, j in np.ndindex(*cm.shape):
        annot[i, j] = f'{cm[i, j]}\n({cm_norm[i, j]:.1%})'

    fig, ax = plt.subplots(figsize=(8, 7))
    sns.heatmap(
        cm_norm,
        annot=annot,
        fmt='',  # annotations are preformatted strings
        cmap='Blues',
        xticklabels=present_names,
        yticklabels=present_names,
        ax=ax,
        vmin=0,
        vmax=1,
        linewidths=0.5,
        linecolor='white',
        cbar_kws={'label': 'Proportion'},
    )

    ax.set_xlabel('Predicted Label', fontsize=12, fontweight='bold')
    ax.set_ylabel('True Label', fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold', pad=15)
    ax.tick_params(labelsize=10)

    plt.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)
    return cm
|
|
|
|
print('Generating confusion matrices...')

# APTOS subset; skipped when the test split has no APTOS rows.
aptos_mask = all_sources == 'APTOS'
if aptos_mask.sum() > 0:
    _aptos_title = f'Confusion Matrix: APTOS (n={aptos_mask.sum()}, DR images only)'
    _aptos_png = os.path.join(FAIRNESS_DIR, 'confusion_matrix_aptos.png')
    cm_aptos = plot_confusion_matrix(
        all_labels[aptos_mask], all_preds[aptos_mask], _aptos_title, _aptos_png
    )
    print(' Saved confusion_matrix_aptos.png')

# ODIR subset; skipped when the test split has no ODIR rows.
odir_mask = all_sources == 'ODIR'
if odir_mask.sum() > 0:
    _odir_title = f'Confusion Matrix: ODIR (n={odir_mask.sum()}, all 5 classes)'
    _odir_png = os.path.join(FAIRNESS_DIR, 'confusion_matrix_odir.png')
    cm_odir = plot_confusion_matrix(
        all_labels[odir_mask], all_preds[odir_mask], _odir_title, _odir_png
    )
    print(' Saved confusion_matrix_odir.png')
|
|
|
|
| |
| |
| |
print('Generating confidence_by_source.png...')

fig, axes = plt.subplots(1, 2, figsize=(16, 7))

# One row per test sample.
conf_df = pd.DataFrame({
    'Source': all_sources,
    'Confidence': all_confidence,
    'Correct': ['Correct' if c else 'Incorrect' for c in all_correct],
    'Predicted': [CLASS_NAMES[p] for p in all_preds],
    'True': [CLASS_NAMES[l] for l in all_labels],
})

# The category order on the x axis and the order of the two violin halves are
# pinned explicitly. The mean markers below are placed by list index
# (source i, 'Correct' to the left, 'Incorrect' to the right), so the violins
# must be laid out in exactly that order rather than in whatever order the
# values first appear in the data.
correctness_order = ['Correct', 'Incorrect']
correctness_palette = {'Correct': '#4CAF50', 'Incorrect': '#F44336'}

# ── Left panel: all classes, sources with at least 10 samples ──
ax = axes[0]
plot_sources = [s for s in sources_unique if (all_sources == s).sum() >= 10]
conf_df_filtered = conf_df[conf_df['Source'].isin(plot_sources)]

# Skip drawing when no source qualifies, mirroring the guard on the DR panel.
if len(conf_df_filtered) > 0:
    sns.violinplot(data=conf_df_filtered, x='Source', y='Confidence', hue='Correct',
                   order=plot_sources, hue_order=correctness_order,
                   split=True, inner='quartile', ax=ax,
                   palette=correctness_palette)

    # Black diamonds mark the mean confidence of each half.
    for i, src in enumerate(plot_sources):
        for j, corr_label in enumerate(correctness_order):
            mask_sc = (conf_df_filtered['Source'] == src) & (conf_df_filtered['Correct'] == corr_label)
            if mask_sc.sum() > 0:
                mean_val = conf_df_filtered.loc[mask_sc, 'Confidence'].mean()
                offset = -0.05 if j == 0 else 0.05
                ax.scatter(i + offset, mean_val, color='black', s=30, zorder=5, marker='D')

ax.set_title('Prediction Confidence by Source & Correctness', fontsize=13, fontweight='bold')
ax.set_ylabel('Confidence (max probability)', fontsize=12)
ax.set_xlabel('Data Source', fontsize=12)
ax.grid(axis='y', alpha=0.3)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# ── Right panel: true-DR samples only (label 1), sources with >= 5 of them ──
ax = axes[1]
dr_df = conf_df[conf_df['True'] == 'DR']
dr_plot_sources = [s for s in sources_unique if ((all_sources == s) & (all_labels == 1)).sum() >= 5]
dr_df_filtered = dr_df[dr_df['Source'].isin(dr_plot_sources)]

if len(dr_df_filtered) > 0:
    sns.violinplot(data=dr_df_filtered, x='Source', y='Confidence', hue='Correct',
                   order=dr_plot_sources, hue_order=correctness_order,
                   split=True, inner='quartile', ax=ax,
                   palette=correctness_palette)

ax.set_title('DR Confidence: APTOS vs ODIR', fontsize=13, fontweight='bold')
ax.set_ylabel('Confidence (max probability)', fontsize=12)
ax.set_xlabel('Data Source', fontsize=12)
ax.grid(axis='y', alpha=0.3)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.tight_layout()
fig.savefig(os.path.join(FAIRNESS_DIR, 'confidence_by_source.png'),
            dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)
print(' Saved confidence_by_source.png')
|
|
|
|
| |
| |
| |
print('Generating error_patterns.png...')

# Two side-by-side panels: APTOS on the left, ODIR on the right.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 7))
|
|
def get_error_pairs(labels, preds, class_names, top_k=10):
    """Return the *top_k* most frequent misclassifications.

    Each entry is ``('<true name> -> <predicted name>', count)``, ordered by
    descending count. Correct predictions are ignored.
    """
    confusions = Counter(
        f'{class_names[actual]} -> {class_names[guessed]}'
        for actual, guessed in zip(labels, preds)
        if actual != guessed
    )
    return confusions.most_common(top_k)
|
|
| |
def _draw_error_bars(axis, error_pairs, bar_color):
    """Horizontal bar chart of ``(pair label, count)`` tuples, largest on top."""
    pair_labels, pair_counts = zip(*error_pairs)
    y_pos = np.arange(len(pair_labels))
    bars = axis.barh(y_pos, pair_counts, color=bar_color, edgecolor='white', linewidth=0.5)
    axis.set_yticks(y_pos)
    axis.set_yticklabels(pair_labels, fontsize=10)
    axis.invert_yaxis()
    for bar, count in zip(bars, pair_counts):
        axis.text(bar.get_width() + 0.3, bar.get_y() + bar.get_height() / 2,
                  str(count), va='center', fontsize=10, fontweight='bold')


def _finish_error_axis(axis, title):
    """Apply the title, axis labels and grid shared by both panels."""
    axis.set_title(title, fontsize=13, fontweight='bold')
    axis.set_xlabel('Count', fontsize=12)
    axis.set_ylabel('True -> Predicted', fontsize=12)
    axis.grid(axis='x', alpha=0.3)
    axis.spines['top'].set_visible(False)
    axis.spines['right'].set_visible(False)


# ── Left panel: APTOS, top 8 pairs, single colour ──
ax = axes[0]
if aptos_mask.sum() > 0:
    aptos_errors = get_error_pairs(all_labels[aptos_mask], all_preds[aptos_mask], CLASS_NAMES, top_k=8)
    if not aptos_errors:
        ax.text(0.5, 0.5, 'No errors!', ha='center', va='center', fontsize=14,
                transform=ax.transAxes)
    else:
        _draw_error_bars(ax, aptos_errors, '#2196F3')

n_aptos_errors = (all_labels[aptos_mask] != all_preds[aptos_mask]).sum() if aptos_mask.sum() > 0 else 0
_finish_error_axis(
    ax, f'APTOS Error Patterns\n({n_aptos_errors} errors / {aptos_mask.sum()} images)')

# ── Right panel: ODIR, top 10 pairs, coloured by the classes involved ──
ax = axes[1]
if odir_mask.sum() > 0:
    odir_errors = get_error_pairs(all_labels[odir_mask], all_preds[odir_mask], CLASS_NAMES, top_k=10)
    if not odir_errors:
        ax.text(0.5, 0.5, 'No errors!', ha='center', va='center', fontsize=14,
                transform=ax.transAxes)
    else:
        # Orange if the pair mentions Normal, else red if it mentions DR,
        # else purple.
        pair_colors = [
            '#FF9800' if 'Normal' in pair
            else '#F44336' if 'DR' in pair
            else '#9C27B0'
            for pair, _ in odir_errors
        ]
        _draw_error_bars(ax, odir_errors, pair_colors)

n_odir_errors = (all_labels[odir_mask] != all_preds[odir_mask]).sum() if odir_mask.sum() > 0 else 0
_finish_error_axis(
    ax, f'ODIR Error Patterns\n({n_odir_errors} errors / {odir_mask.sum()} images)')

plt.tight_layout()
fig.savefig(os.path.join(FAIRNESS_DIR, 'error_patterns.png'),
            dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)
print(' Saved error_patterns.png')
|
|
|
|
| |
| |
| |
def compute_ece(labels, probs, preds, n_bins=15):
    """Return the Expected Calibration Error as a Python float.

    Samples are grouped into *n_bins* equal-width confidence bins over [0, 1]
    (half-open, except the last bin which includes 1.0). The result is the
    sum over bins of ``|accuracy - mean confidence|`` weighted by the bin's
    share of all samples.
    """
    confidences = np.max(probs, axis=1)
    correct = (labels == preds).astype(float)
    bin_edges = np.linspace(0, 1, n_bins + 1)
    final_bin = n_bins - 1

    ece = 0.0
    for b, (lo, hi) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        upper = confidences <= hi if b == final_bin else confidences < hi
        in_bin = (confidences >= lo) & upper
        if in_bin.sum() == 0:
            continue
        bin_acc = correct[in_bin].mean()
        bin_conf = confidences[in_bin].mean()
        bin_weight = in_bin.sum() / len(labels)
        ece += bin_weight * abs(bin_acc - bin_conf)
    return float(ece)
|
|
|
|
| |
| |
| |
print('\nRunning statistical tests...')

# Compare DR accuracy between APTOS and ODIR with a chi-squared test on the
# 2x2 table (source x correct/incorrect). DR is class index 1 in CLASS_NAMES.
aptos_dr_mask = (all_sources == 'APTOS') & (all_labels == 1)
odir_dr_mask = (all_sources == 'ODIR') & (all_labels == 1)

# Stays empty when either source has no DR samples; downstream code checks
# its truthiness before reading it.
stat_results = {}

if aptos_dr_mask.sum() > 0 and odir_dr_mask.sum() > 0:
    aptos_dr_correct = all_correct[aptos_dr_mask].sum()
    aptos_dr_total = aptos_dr_mask.sum()
    odir_dr_correct = all_correct[odir_dr_mask].sum()
    odir_dr_total = odir_dr_mask.sum()

    # Rows: APTOS, ODIR. Columns: correct, incorrect.
    contingency = np.array([
        [aptos_dr_correct, aptos_dr_total - aptos_dr_correct],
        [odir_dr_correct, odir_dr_total - odir_dr_correct],
    ])

    # chi2_contingency raises ValueError when an expected frequency is zero,
    # which happens here exactly when a column is all zeros (both sources are
    # 100% correct, or both 100% wrong). The two accuracies are then equal,
    # so record "no evidence of a difference" instead of crashing.
    degenerate = bool((contingency.sum(axis=0) == 0).any())
    if degenerate:
        chi2, p_value, dof = 0.0, 1.0, 1
    else:
        # For a 2x2 table SciPy applies Yates' continuity correction by default.
        chi2, p_value, dof, expected = stats.chi2_contingency(contingency)

    stat_results = {
        'test': 'chi-squared',
        'chi2_statistic': float(round(chi2, 4)),
        'p_value': float(p_value),
        'degrees_of_freedom': int(dof),
        'significant_at_005': bool(p_value < 0.05),
        'aptos_dr_accuracy': float(round(aptos_dr_correct / aptos_dr_total, 4)),
        'odir_dr_accuracy': float(round(odir_dr_correct / odir_dr_total, 4)),
        'aptos_dr_n': int(aptos_dr_total),
        'odir_dr_n': int(odir_dr_total),
    }
    if degenerate:
        stat_results['note'] = (
            'Degenerate contingency table (one outcome never occurs); '
            'chi-squared test not computed, reported as chi2=0, p=1.'
        )

    print(f' Chi-squared test (DR: APTOS vs ODIR):')
    print(f' APTOS-DR accuracy: {aptos_dr_correct}/{aptos_dr_total} = '
          f'{aptos_dr_correct / aptos_dr_total:.4f}')
    print(f' ODIR-DR accuracy: {odir_dr_correct}/{odir_dr_total} = '
          f'{odir_dr_correct / odir_dr_total:.4f}')
    print(f' chi2 = {chi2:.4f}, p = {p_value:.6f}, significant = {p_value < 0.05}')
    if degenerate:
        print(' (degenerate table: test skipped, no difference detectable)')
|
|
|
|
| |
| |
| |
print('\nBuilding domain_gap_report.json...')

# ECE per source; sources with fewer than 10 samples are skipped.
source_ece = {}
for src in sources_unique:
    mask = source_masks[src]
    if mask.sum() < 10:
        continue
    source_ece[src] = round(
        compute_ece(all_labels[mask], all_probs[mask], all_preds[mask]), 4
    )

# Overall accuracy per source, taken from the metrics computed earlier.
overall_accs = {src: source_metrics[src]['Overall']['accuracy'] for src in sources_unique}

# APTOS-vs-ODIR gaps; left empty unless both sources are present.
domain_gap = {}
if 'APTOS' in overall_accs and 'ODIR' in overall_accs:
    # DR-specific gap needs the accuracies recorded by the statistical test.
    aptos_dr_acc = stat_results.get('aptos_dr_accuracy', None)
    odir_dr_acc = stat_results.get('odir_dr_accuracy', None)
    if aptos_dr_acc is not None and odir_dr_acc is not None:
        domain_gap['dr_accuracy_gap'] = round(abs(aptos_dr_acc - odir_dr_acc), 4)
        if aptos_dr_acc > odir_dr_acc:
            domain_gap['dr_gap_direction'] = 'APTOS > ODIR'
        else:
            domain_gap['dr_gap_direction'] = 'ODIR > APTOS'

    _aptos_acc = overall_accs['APTOS']
    _odir_acc = overall_accs['ODIR']
    domain_gap['overall_accuracy_gap'] = round(abs(_aptos_acc - _odir_acc), 4)
    domain_gap['overall_gap_direction'] = (
        'APTOS > ODIR' if _aptos_acc > _odir_acc else 'ODIR > APTOS'
    )

# Confidence summary per source, split by whether the prediction was right.
confidence_stats = {}
for src in sources_unique:
    mask = source_masks[src]
    correct_mask = mask & (all_correct == 1)
    incorrect_mask = mask & (all_correct == 0)
    n_right = int(correct_mask.sum())
    n_wrong = int(incorrect_mask.sum())
    confidence_stats[src] = {
        'mean_confidence': round(float(all_confidence[mask].mean()), 4),
        # None when the source has no samples in that group.
        'mean_confidence_correct': (
            round(float(all_confidence[correct_mask].mean()), 4) if n_right > 0 else None
        ),
        'mean_confidence_incorrect': (
            round(float(all_confidence[incorrect_mask].mean()), 4) if n_wrong > 0 else None
        ),
        'n_correct': n_right,
        'n_incorrect': n_wrong,
    }
|
|
| |
# Human-readable summary sentences, printed at the end and stored in the report.
findings = []

# 1) DR accuracy gap and its significance (only if the test was run).
if stat_results:
    gap = domain_gap.get('dr_accuracy_gap', 0)
    direction = domain_gap.get('dr_gap_direction', '')
    sig = stat_results.get('significant_at_005', False)
    verdict = 'statistically significant' if sig else 'not statistically significant'
    findings.append(
        f"DR accuracy gap: {gap:.1%} ({direction}). "
        f"Chi-squared p={stat_results['p_value']:.4f} "
        f"({verdict} at alpha=0.05)."
    )

# 2) Calibration error for every source that has an ECE value.
findings.extend(
    f"{src} ECE (Expected Calibration Error) = {source_ece[src]:.4f}."
    for src in sources_unique
    if src in source_ece
)

# 3) Confidence separation between correct and incorrect predictions.
for src in sources_unique:
    cs = confidence_stats.get(src, {})
    mc_corr = cs.get('mean_confidence_correct')
    mc_incorr = cs.get('mean_confidence_incorrect')
    if mc_corr is None or mc_incorr is None:
        continue
    findings.append(
        f"{src}: mean confidence on correct={mc_corr:.3f}, incorrect={mc_incorr:.3f} "
        f"(separation={mc_corr - mc_incorr:.3f})."
    )

# 4) Most frequent misclassification for APTOS, then ODIR.
for _name, _mask, _n_errors in (
    ('APTOS', aptos_mask, n_aptos_errors),
    ('ODIR', odir_mask, n_odir_errors),
):
    if _mask.sum() == 0:
        continue
    _top_pairs = get_error_pairs(all_labels[_mask], all_preds[_mask], CLASS_NAMES, top_k=3)
    if not _top_pairs:
        continue
    top_err = _top_pairs[0]
    findings.append(
        f"{_name} top error: {top_err[0]} ({top_err[1]} instances). "
        f"Total {_name} errors: {_n_errors}/{_mask.sum()} "
        f"({_n_errors / _mask.sum():.1%})."
    )
|
|
| |
# Assemble everything computed above into one JSON-serialisable dict.
report = {
    'phase': '1D - Fairness & Domain Robustness Analysis',
    'model': 'RetinaSense-ViT v3 (vit_base_patch16_224)',
    'test_set_size': int(len(all_labels)),
    'temperature': TEMPERATURE,
    'source_distribution': {
        src: int((all_sources == src).sum()) for src in sources_unique
    },
    'per_source_metrics': {src: source_metrics[src] for src in sources_unique},
    'domain_gap': domain_gap,
    'statistical_test': stat_results,
    'calibration': {'ece_by_source': source_ece},
    'confidence_analysis': confidence_stats,
    'key_findings': findings,
}

report_path = os.path.join(FAIRNESS_DIR, 'domain_gap_report.json')
with open(report_path, 'w') as f:
    json.dump(report, f, indent=2)

print(f' Saved domain_gap_report.json')
|
|
| |
| |
| |
_SUMMARY_RULE = '=' * 65
print('\n' + _SUMMARY_RULE)
print(' FAIRNESS ANALYSIS COMPLETE')
print(_SUMMARY_RULE)
print(f'\n Output directory: {FAIRNESS_DIR}')
print(f' Files generated:')

# Report the on-disk size of every expected artefact, or flag it as missing.
_expected_outputs = (
    'performance_by_source.png',
    'calibration_by_source.png',
    'confusion_matrix_aptos.png',
    'confusion_matrix_odir.png',
    'confidence_by_source.png',
    'error_patterns.png',
    'domain_gap_report.json',
)
for fname in _expected_outputs:
    fpath = os.path.join(FAIRNESS_DIR, fname)
    if os.path.exists(fpath):
        status = f'{os.path.getsize(fpath) / 1024:.1f} KB'
    else:
        status = 'MISSING'
    print(f' {fname:35s} {status}')

print(f'\n KEY FINDINGS:')
for i, finding in enumerate(findings, 1):
    print(f' {i}. {finding}')
print(_SUMMARY_RULE)
|
|