tanishq74 commited on
Commit
9022502
·
verified ·
1 Parent(s): e5d94b0

Add gradcam_v3.py

Browse files
Files changed (1) hide show
  1. gradcam_v3.py +1179 -0
gradcam_v3.py ADDED
@@ -0,0 +1,1179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ RetinaSense v3.0 — Grad-CAM Explainability Pipeline
4
+ ====================================================
5
+ Implements:
6
+ 1. ViTGradCAM : Gradient-weighted Class Activation Maps for ViT backbone
7
+ 2. OODDetector : Mahalanobis-distance out-of-distribution detection
8
+ 3. predict_with_gradcam : Full inference pipeline (preprocess → OOD → CAM → calibrate)
9
+ 4. Batch evaluation on 20 test images (4 per class)
10
+ 5. Disease-specific heatmap validation against known anatomical regions
11
+ 6. Clinical output report (GRADCAM_REPORT.md)
12
+
13
+ Usage:
14
+ python gradcam_v3.py
15
+ """
16
+
17
+ import os
18
+ import sys
19
+ import json
20
+ import warnings
21
+ import numpy as np
22
+ import cv2
23
+ import matplotlib
24
+ matplotlib.use('Agg')
25
+ import matplotlib.pyplot as plt
26
+ import matplotlib.patches as mpatches
27
+ from PIL import Image
28
+ from datetime import datetime
29
+ import time
30
+
31
+ warnings.filterwarnings('ignore')
32
+
33
+ import torch
34
+ import torch.nn as nn
35
+ import torch.nn.functional as F
36
+ from torchvision import transforms
37
+
38
+ import timm
39
+
40
+ # ================================================================
41
+ # CONFIGURATION
42
+ # ================================================================
43
+ BASE_DIR = '/teamspace/studios/this_studio'
44
+ OUTPUT_DIR = os.path.join(BASE_DIR, 'outputs_v3')
45
+ GRADCAM_DIR = os.path.join(OUTPUT_DIR, 'gradcam')
46
+ os.makedirs(GRADCAM_DIR, exist_ok=True)
47
+
48
+ MODEL_PATH = os.path.join(OUTPUT_DIR, 'best_model.pth')
49
+ THRESHOLDS_PATH = os.path.join(OUTPUT_DIR, 'thresholds.json')
50
+ TEMPERATURE_PATH = os.path.join(OUTPUT_DIR, 'temperature.json')
51
+ TEST_CSV = os.path.join(BASE_DIR, 'data', 'test_split.csv')
52
+ NORM_STATS_PATH = os.path.join(BASE_DIR, 'data', 'fundus_norm_stats.json')
53
+
54
+ CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']
55
+ NUM_CLASSES = 5
56
+ IMG_SIZE = 224
57
+ DROPOUT = 0.3
58
+
59
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
60
+
61
+ # Anatomical regions expected for each disease class
62
+ EXPECTED_REGIONS = {
63
+ 0: 'low uniform activation (Normal)',
64
+ 1: 'scattered periphery and macula (DR)',
65
+ 2: 'optic disc (Glaucoma)',
66
+ 3: 'diffuse lens opacity (Cataract)',
67
+ 4: 'macula/centre-temporal (AMD)',
68
+ }
69
+
70
+ print('=' * 65)
71
+ print(' RetinaSense v3.0 — Grad-CAM Explainability Pipeline')
72
+ print('=' * 65)
73
+ print(f' Device : {DEVICE}')
74
+ if torch.cuda.is_available():
75
+ print(f' GPU : {torch.cuda.get_device_name(0)}')
76
+ print(f' Output : {GRADCAM_DIR}')
77
+ print('=' * 65)
78
+
79
+
80
+ # ================================================================
81
+ # LOAD NORMALISATION STATS
82
+ # ================================================================
83
+ if os.path.exists(NORM_STATS_PATH):
84
+ with open(NORM_STATS_PATH) as f:
85
+ norm_stats = json.load(f)
86
+ NORM_MEAN = norm_stats['mean_rgb']
87
+ NORM_STD = norm_stats['std_rgb']
88
+ print(f' Fundus norm stats: mean={[round(v,4) for v in NORM_MEAN]}, std={[round(v,4) for v in NORM_STD]}')
89
+ else:
90
+ NORM_MEAN = [0.485, 0.456, 0.406]
91
+ NORM_STD = [0.229, 0.224, 0.225]
92
+ print(' Using ImageNet normalisation fallback')
93
+
94
+
95
+ # ================================================================
96
+ # MODEL ARCHITECTURE (mirrors retinasense_v3.py exactly)
97
+ # ================================================================
98
+ class MultiTaskViT(nn.Module):
99
+ """ViT-Base-Patch16-224 with disease + severity heads."""
100
+
101
+ def __init__(self, n_disease=NUM_CLASSES, n_severity=5, drop=DROPOUT):
102
+ super().__init__()
103
+ self.backbone = timm.create_model(
104
+ 'vit_base_patch16_224', pretrained=False, num_classes=0
105
+ )
106
+ feat = 768 # CLS token dimension
107
+
108
+ self.drop = nn.Dropout(drop)
109
+
110
+ self.disease_head = nn.Sequential(
111
+ nn.Linear(feat, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
112
+ nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
113
+ nn.Linear(256, n_disease),
114
+ )
115
+ self.severity_head = nn.Sequential(
116
+ nn.Linear(feat, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
117
+ nn.Linear(256, n_severity),
118
+ )
119
+
120
+ def forward(self, x):
121
+ f = self.backbone(x) # (B, 768) — CLS token features
122
+ f = self.drop(f)
123
+ return self.disease_head(f), self.severity_head(f)
124
+
125
+ def get_features(self, x):
126
+ """Return raw CLS token features (before heads and dropout)."""
127
+ return self.backbone(x) # (B, 768)
128
+
129
+ def forward_with_tokens(self, x):
130
+ """Return (disease_logits, full_token_sequence (B,197,768))."""
131
+ tokens = self.backbone.forward_features(x) # (B, 197, 768)
132
+ cls_feat = tokens[:, 0, :]
133
+ cls_feat_d = self.drop(cls_feat)
134
+ d_out = self.disease_head(cls_feat_d)
135
+ return d_out, tokens
136
+
137
+
138
+ # ================================================================
139
+ # LOAD MODEL
140
+ # ================================================================
141
+ print('\nLoading model...')
142
+ model = MultiTaskViT().to(DEVICE)
143
+ ckpt = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
144
+ model.load_state_dict(ckpt['model_state_dict'])
145
+ model.eval()
146
+ print(f' Loaded: {MODEL_PATH}')
147
+ print(f' Checkpoint epoch: {ckpt.get("epoch", "?") + 1} val_acc={ckpt.get("val_acc", 0):.2f}%')
148
+
149
+ # Load thresholds and temperature
150
+ with open(THRESHOLDS_PATH) as f:
151
+ thr_data = json.load(f)
152
+ THRESHOLDS = thr_data['thresholds']
153
+
154
+ with open(TEMPERATURE_PATH) as f:
155
+ temp_data = json.load(f)
156
+ TEMPERATURE = temp_data['temperature']
157
+
158
+ print(f' Temperature T = {TEMPERATURE:.4f}')
159
+ print(f' Thresholds = {[round(t,3) for t in THRESHOLDS]}')
160
+
161
+
162
+ # ================================================================
163
+ # IMAGE PREPROCESSING
164
+ # ================================================================
165
+ def ben_graham(path, sz=IMG_SIZE, sigma=10):
166
+ """Ben Graham high-frequency fundus enhancement (APTOS-style)."""
167
+ img = cv2.imread(path)
168
+ if img is None:
169
+ img = np.array(Image.open(path).convert('RGB'))
170
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
171
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
172
+ img = cv2.resize(img, (sz, sz))
173
+ img = cv2.addWeighted(img, 4, cv2.GaussianBlur(img, (0, 0), sigma), -4, 128)
174
+ mask = np.zeros(img.shape[:2], dtype=np.uint8)
175
+ cv2.circle(mask, (sz // 2, sz // 2), int(sz * 0.48), 255, -1)
176
+ return cv2.bitwise_and(img, img, mask=mask)
177
+
178
+
179
+ def clahe_preprocess(path, sz=IMG_SIZE):
180
+ """CLAHE-based contrast enhancement (ODIR-style)."""
181
+ img = cv2.imread(path)
182
+ if img is None:
183
+ img = np.array(Image.open(path).convert('RGB'))
184
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
185
+ img = cv2.resize(img, (sz, sz))
186
+ lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
187
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
188
+ lab[:, :, 0] = clahe.apply(lab[:, :, 0])
189
+ img = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
190
+ return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
191
+
192
+
193
+ def load_and_preprocess(image_path, dataset='auto'):
194
+ """
195
+ Load image and apply domain-conditional preprocessing.
196
+ Returns:
197
+ img_np : numpy (224,224,3) uint8 preprocessed
198
+ img_orig : numpy (224,224,3) uint8 original (for overlay)
199
+ """
200
+ # Normalise path: handle relative paths from CSV (e.g. "aptos/..." or "./aptos/...")
201
+ # If the path is already absolute and exists, use it directly.
202
+ # Otherwise resolve relative to BASE_DIR, stripping any leading ./ or .// first.
203
+ if not os.path.isabs(image_path):
204
+ # Strip any leading './' or '../' patterns to get a clean relative path
205
+ clean = image_path
206
+ while clean.startswith('./') or clean.startswith('.//'):
207
+ clean = clean[2:] if clean.startswith('./') else clean[3:]
208
+ image_path = os.path.join(BASE_DIR, clean)thinl
209
+ # Auto-detect domain
210
+ if dataset == 'auto':
211
+ if 'aptos' in image_path.lower() or 'gaussian' in image_path.lower():
212
+ dataset = 'APTOS'
213
+ else:
214
+ dataset = 'ODIR'
215
+
216
+ # Load original (unprocessed, for overlay)
217
+ raw = cv2.imread(image_path)
218
+ if raw is None:
219
+ raw = np.array(Image.open(image_path).convert('RGB'))
220
+ else:
221
+ raw = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
222
+ img_orig = cv2.resize(raw, (IMG_SIZE, IMG_SIZE))
223
+
224
+ # Apply preprocessing
225
+ if dataset == 'APTOS':
226
+ img_np = ben_graham(image_path)
227
+ else:
228
+ img_np = clahe_preprocess(image_path)
229
+
230
+ return img_np, img_orig
231
+
232
+
233
+ def preprocess_to_tensor(img_np):
234
+ """Convert preprocessed numpy image to normalised tensor (1, 3, 224, 224)."""
235
+ transform = transforms.Compose([
236
+ transforms.ToPILImage(),
237
+ transforms.ToTensor(),
238
+ transforms.Normalize(NORM_MEAN, NORM_STD),
239
+ ])
240
+ return transform(img_np).unsqueeze(0)
241
+
242
+
243
+ # ================================================================
244
+ # ViT GRAD-CAM
245
+ # ================================================================
246
+ class ViTAttentionRollout:
247
+ """
248
+ Attention Rollout for Vision Transformer (Abnar & Zuidema, 2020).
249
+
250
+ WHY this works better than Grad-CAM for ViT:
251
+ - ViT uses CLS token pooling: gradients flow ONLY through CLS token (index 0)
252
+ - All 196 patch token gradients at block output = zero → Grad-CAM fails
253
+ - Attention Rollout instead traces how information flows from image patches
254
+ to the CLS token across ALL 12 transformer layers
255
+ - Accounts for residual connections by adding identity to each attention map
256
+ - Produces spatially meaningful maps that highlight actual disease regions
257
+
258
+ Algorithm:
259
+ 1. Collect attention maps A_l from all 12 blocks: shape (B, H, N, N)
260
+ 2. Average over H heads: A_l → (B, N, N)
261
+ 3. Add identity: A_l = A_l + I (accounts for residual connection)
262
+ 4. Row-normalize: A_l = A_l / row_sum
263
+ 5. Matrix-multiply all layers: Rollout = A_0 @ A_1 @ ... @ A_11
264
+ 6. Take CLS row, patch tokens only: Rollout[0, 1:] → (196,)
265
+ 7. Reshape 14×14 → bilinear upsample → 224×224
266
+ """
267
+
268
+ def __init__(self, model, discard_ratio=0.97):
269
+ self.model = model
270
+ self.discard_ratio = discard_ratio # zero out weakest attention weights
271
+ self._attention_maps = []
272
+ self._hooks = []
273
+
274
+ # Disable fused attention for explicit weight access
275
+ for block in model.backbone.blocks:
276
+ block.attn.fused_attn = False
277
+
278
+ # Register forward hooks on ALL transformer blocks
279
+ for block in model.backbone.blocks:
280
+ h = block.attn.register_forward_hook(self._attn_hook)
281
+ self._hooks.append(h)
282
+
283
+ def _attn_hook(self, module, input, output):
284
+ """Capture softmax attention weights from each block."""
285
+ x = input[0]
286
+ B, N, C = x.shape
287
+ with torch.no_grad():
288
+ qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, module.head_dim).permute(2, 0, 3, 1, 4)
289
+ q, k, _ = qkv.unbind(0)
290
+ q, k = module.q_norm(q), module.k_norm(k)
291
+ attn = (q * module.scale @ k.transpose(-2, -1)).softmax(dim=-1)
292
+ self._attention_maps.append(attn.detach().cpu()) # (B, H, N, N)
293
+
294
+ def generate(self, image_tensor, class_idx=None):
295
+ """
296
+ Generate attention rollout heatmap.
297
+
298
+ Returns:
299
+ heatmap : np.ndarray (224, 224) float32 [0, 1]
300
+ High values = regions most important for prediction
301
+ predicted_label : int
302
+ confidence : float (raw softmax)
303
+ """
304
+ self.model.eval()
305
+ self._attention_maps = []
306
+
307
+ with torch.no_grad():
308
+ image_tensor = image_tensor.to(DEVICE)
309
+ d_out, _ = self.model(image_tensor)
310
+ probs = torch.softmax(d_out, dim=1)
311
+ predicted_label = int(probs.argmax(dim=1).item())
312
+ confidence = float(probs[0, predicted_label].item())
313
+
314
+ if class_idx is None:
315
+ class_idx = predicted_label
316
+
317
+ # --- Attention Rollout computation ---
318
+ # Stack all layer attentions: list of (1, H, N, N) → (L, H, N, N)
319
+ attn_stack = torch.stack(self._attention_maps, dim=0) # (L, 1, H, N, N)
320
+ attn_stack = attn_stack[:, 0] # (L, H, N, N), batch=1
321
+
322
+ # Average over heads
323
+ attn_mean = attn_stack.mean(dim=1) # (L, N, N)
324
+
325
+ # Optional: discard weakest connections (sharpens the map)
326
+ if self.discard_ratio > 0:
327
+ flat = attn_mean.reshape(attn_mean.shape[0], -1)
328
+ thresh = torch.quantile(flat, self.discard_ratio, dim=1, keepdim=True)
329
+ thresh = thresh.unsqueeze(-1) # broadcast over N,N
330
+ attn_mean = torch.where(attn_mean >= thresh, attn_mean, torch.zeros_like(attn_mean))
331
+
332
+ # Add identity matrix for residual connection, then row-normalize
333
+ I = torch.eye(attn_mean.shape[-1]).unsqueeze(0) # (1, N, N)
334
+ attn_aug = attn_mean + I
335
+ attn_aug = attn_aug / attn_aug.sum(dim=-1, keepdim=True).clamp(min=1e-8)
336
+
337
+ # Matrix-multiply across all layers
338
+ rollout = attn_aug[0]
339
+ for l in range(1, len(attn_aug)):
340
+ rollout = rollout @ attn_aug[l]
341
+
342
+ # CLS token's attention to all patch tokens (skip CLS at index 0)
343
+ cls_attention = rollout[0, 1:] # (196,)
344
+
345
+ # Reshape and upsample
346
+ spatial = cls_attention.numpy().reshape(14, 14).astype(np.float32)
347
+ spatial = cv2.resize(spatial, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
348
+
349
+ # Normalize to [0, 1]
350
+ s_min, s_max = spatial.min(), spatial.max()
351
+ if s_max - s_min > 1e-8:
352
+ spatial = (spatial - s_min) / (s_max - s_min)
353
+ else:
354
+ spatial = np.zeros_like(spatial)
355
+
356
+ # Power-curve stretch: boosts mid-range attention values for visual clarity
357
+ # gamma < 1 brightens the map; 0.4 gives strong contrast enhancement
358
+ spatial = np.power(spatial, 0.4)
359
+
360
+ return spatial.astype(np.float32), predicted_label, confidence
361
+
362
+ def overlay(self, original_image_np, heatmap, alpha=0.5):
363
+ """
364
+ Blend attention rollout heatmap onto original fundus image.
365
+ Uses INFERNO colormap (dark=low, bright=high) — better for medical images.
366
+
367
+ Args:
368
+ original_image_np : (224, 224, 3) uint8 RGB
369
+ heatmap : (224, 224) float32 [0, 1]
370
+ alpha : heatmap opacity (0.5 gives good visibility)
371
+
372
+ Returns:
373
+ overlay : (224, 224, 3) uint8 RGB
374
+ """
375
+ # Apply JET colormap
376
+ heatmap_uint8 = (heatmap * 255).astype(np.uint8)
377
+ colormap = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
378
+ colormap_rgb = cv2.cvtColor(colormap, cv2.COLOR_BGR2RGB)
379
+
380
+ # Apply circular mask to ignore black borders (fundus images are circular)
381
+ h, w = heatmap.shape
382
+ cy, cx = h // 2, w // 2
383
+ radius = min(h, w) // 2 - 5
384
+ mask = np.zeros((h, w), dtype=np.float32)
385
+ cv2.circle(mask, (cx, cy), radius, 1.0, -1)
386
+ mask = cv2.GaussianBlur(mask, (21, 21), 0)
387
+
388
+ # Blend only inside the retinal circle
389
+ orig = original_image_np.astype(np.float32)
390
+ cmap = colormap_rgb.astype(np.float32)
391
+ blended = orig.copy()
392
+ for c in range(3):
393
+ blended[:, :, c] = (
394
+ orig[:, :, c] * (1 - alpha * mask)
395
+ + cmap[:, :, c] * (alpha * mask)
396
+ )
397
+ return np.clip(blended, 0, 255).astype(np.uint8)
398
+
399
+ def remove_hooks(self):
400
+ """Clean up all registered hooks."""
401
+ for h in self._hooks:
402
+ h.remove()
403
+ self._hooks = []
404
+
405
+ # Keep old name as alias for backward compatibility
406
+ ViTGradCAM = ViTAttentionRollout
407
+
408
+
409
+
410
+ # ================================================================
411
+ # OOD DETECTION (Mahalanobis Distance)
412
+ # ================================================================
413
+ class OODDetector:
414
+ """
415
+ Out-of-Distribution detector using class-conditional Mahalanobis distance.
416
+
417
+ Fit on training-set CLS token features; at inference, computes the
418
+ minimum Mahalanobis distance from the test feature to the nearest
419
+ class centroid. High distance = likely OOD.
420
+ """
421
+
422
+ def __init__(self, threshold_percentile=97.5):
423
+ self.class_means = None # (num_classes, feat_dim)
424
+ self.cov_inv = None # (feat_dim, feat_dim)
425
+ self.ood_threshold = None
426
+ self.threshold_percentile = threshold_percentile
427
+ self.is_fitted = False
428
+
429
+ def fit(self, model, dataloader, device, max_batches=60):
430
+ """
431
+ Extract CLS token features for all samples, compute class-conditional
432
+ means and shared inverse covariance matrix.
433
+ """
434
+ print(' OODDetector.fit: extracting features...')
435
+ all_features = []
436
+ all_labels = []
437
+
438
+ model.eval()
439
+ with torch.no_grad():
440
+ for i, batch in enumerate(dataloader):
441
+ if i >= max_batches:
442
+ break
443
+ imgs, d_lbl, _ = batch
444
+ imgs = imgs.to(device)
445
+ feats = model.get_features(imgs) # (B, 768)
446
+ all_features.append(feats.cpu().numpy())
447
+ all_labels.append(d_lbl.numpy())
448
+
449
+ features = np.concatenate(all_features, axis=0) # (N, 768)
450
+ labels = np.concatenate(all_labels, axis=0) # (N,)
451
+
452
+ num_classes = NUM_CLASSES
453
+ feat_dim = features.shape[1]
454
+
455
+ # Class-conditional means
456
+ self.class_means = np.zeros((num_classes, feat_dim), dtype=np.float64)
457
+ for c in range(num_classes):
458
+ mask = labels == c
459
+ if mask.sum() > 0:
460
+ self.class_means[c] = features[mask].mean(axis=0)
461
+
462
+ # Shared (pooled) covariance matrix
463
+ cov = np.zeros((feat_dim, feat_dim), dtype=np.float64)
464
+ total = 0
465
+ for c in range(num_classes):
466
+ mask = labels == c
467
+ if mask.sum() < 2:
468
+ continue
469
+ diff = features[mask] - self.class_means[c]
470
+ cov += diff.T @ diff
471
+ total += mask.sum()
472
+
473
+ cov /= max(total - num_classes, 1)
474
+
475
+ # Regularise for numerical stability (add small diagonal)
476
+ cov += np.eye(feat_dim) * 1e-4
477
+
478
+ # Pseudo-inverse via SVD (numerically stable for high-dim)
479
+ try:
480
+ self.cov_inv = np.linalg.pinv(cov)
481
+ except np.linalg.LinAlgError:
482
+ self.cov_inv = np.eye(feat_dim)
483
+
484
+ # Compute train-set Mahalanobis distances to set threshold
485
+ train_dists = []
486
+ for feat in features:
487
+ d = self._mahal_min_dist(feat)
488
+ train_dists.append(d)
489
+ self.ood_threshold = float(np.percentile(train_dists, self.threshold_percentile))
490
+
491
+ self.is_fitted = True
492
+ print(f' OOD threshold ({self.threshold_percentile}th pct): {self.ood_threshold:.4f}')
493
+ print(f' Features extracted: {len(features)} samples')
494
+
495
+ def _mahal_min_dist(self, feat):
496
+ """Minimum Mahalanobis distance to any class centroid."""
497
+ min_dist = float('inf')
498
+ for c in range(NUM_CLASSES):
499
+ diff = feat - self.class_means[c]
500
+ dist = float(diff @ self.cov_inv @ diff)
501
+ dist = max(dist, 0.0) # guard against floating-point negatives
502
+ if dist < min_dist:
503
+ min_dist = dist
504
+ return np.sqrt(min_dist)
505
+
506
+ def score(self, features):
507
+ """
508
+ Compute OOD score for a batch of features.
509
+
510
+ Args:
511
+ features : np.ndarray (N, 768) or (768,)
512
+
513
+ Returns:
514
+ distances : np.ndarray (N,) Mahalanobis distances
515
+ ood_flags : np.ndarray (N,) bool, True = likely OOD
516
+ """
517
+ if not self.is_fitted:
518
+ raise RuntimeError('OODDetector.fit() must be called before score()')
519
+
520
+ if features.ndim == 1:
521
+ features = features[np.newaxis, :]
522
+
523
+ distances = np.array([self._mahal_min_dist(f) for f in features])
524
+ ood_flags = distances > self.ood_threshold
525
+ return distances, ood_flags
526
+
527
+ def save(self, path):
528
+ np.savez(path,
529
+ class_means=self.class_means,
530
+ cov_inv=self.cov_inv,
531
+ ood_threshold=np.array([self.ood_threshold]),
532
+ threshold_percentile=np.array([self.threshold_percentile]))
533
+ print(f' OOD detector saved -> {path}.npz')
534
+
535
+ def load(self, path):
536
+ if not path.endswith('.npz'):
537
+ path = path + '.npz'
538
+ data = np.load(path)
539
+ self.class_means = data['class_means']
540
+ self.cov_inv = data['cov_inv']
541
+ self.ood_threshold = float(data['ood_threshold'][0])
542
+ self.threshold_percentile = float(data['threshold_percentile'][0])
543
+ self.is_fitted = True
544
+ print(f' OOD detector loaded <- {path}')
545
+
546
+
547
+ # ================================================================
548
+ # ATTENTION REGION ANALYSER
549
+ # ================================================================
550
+ def analyse_attention_region(heatmap, disease_class):
551
+ """
552
+ Check if the Grad-CAM heatmap activation pattern is consistent
553
+ with the expected anatomical region for the given disease.
554
+
555
+ Returns:
556
+ attention_region : str describing where activation is
557
+ is_consistent : bool
558
+ region_scores : dict with activation energy in each zone
559
+ """
560
+ h, w = heatmap.shape # (224, 224)
561
+ cx, cy = w // 2, h // 2
562
+
563
+ # Define anatomical zones (approximate, relative to image centre)
564
+ # Centre disc zone: circle r ~ 30px (optic disc)
565
+ r_disc = int(h * 0.13)
566
+ # Macula zone: circle r ~ 55px centred slightly temporal
567
+ r_macula = int(h * 0.25)
568
+ cx_mac = int(cx + w * 0.10) # slightly nasal offset
569
+
570
+ # Build zone masks
571
+ Y, X = np.ogrid[:h, :w]
572
+
573
+ # Optic disc (small circle, centre of image)
574
+ disc_mask = ((X - cx)**2 + (Y - cy)**2) <= r_disc**2
575
+
576
+ # Macula (larger circle, centre-temporal)
577
+ macula_mask = ((X - cx_mac)**2 + (Y - cy)**2) <= r_macula**2
578
+
579
+ # Periphery: outer 30% of image
580
+ peri_mask = (X < int(w * 0.15)) | (X > int(w * 0.85)) | \
581
+ (Y < int(h * 0.15)) | (Y > int(h * 0.85))
582
+
583
+ # Compute mean activation in each zone
584
+ disc_score = float(heatmap[disc_mask].mean()) if disc_mask.sum() > 0 else 0.0
585
+ macula_score = float(heatmap[macula_mask].mean()) if macula_mask.sum() > 0 else 0.0
586
+ peri_score = float(heatmap[peri_mask].mean()) if peri_mask.sum() > 0 else 0.0
587
+ overall_mean = float(heatmap.mean())
588
+
589
+ region_scores = {
590
+ 'optic_disc': round(disc_score, 4),
591
+ 'macula': round(macula_score, 4),
592
+ 'periphery': round(peri_score, 4),
593
+ 'overall': round(overall_mean, 4),
594
+ }
595
+
596
+ # Determine dominant region label
597
+ max_zone = max(region_scores, key=lambda k: region_scores[k] if k != 'overall' else -1)
598
+
599
+ zone_labels = {
600
+ 'optic_disc': 'optic disc (centre)',
601
+ 'macula': 'macula (centre-temporal)',
602
+ 'periphery': 'scattered periphery',
603
+ }
604
+ dominant_label = zone_labels.get(max_zone, 'diffuse')
605
+
606
+ # Assess uniformity (low std = diffuse / uniform)
607
+ if heatmap.std() < 0.10:
608
+ dominant_label = 'diffuse (low activation)'
609
+
610
+ # Check consistency with expected region
611
+ consistency_map = {
612
+ 0: lambda s: s['overall'] < 0.25, # Normal → low uniform
613
+ 1: lambda s: s['periphery'] > 0.20 or s['macula'] > 0.25, # DR → periphery/macula
614
+ 2: lambda s: s['optic_disc'] > 0.30, # Glaucoma → disc
615
+ 3: lambda s: heatmap.std() < 0.15, # Cataract → diffuse
616
+ 4: lambda s: s['macula'] > 0.25, # AMD → macula
617
+ }
618
+ check_fn = consistency_map.get(disease_class, lambda s: True)
619
+ is_consistent = check_fn(region_scores)
620
+
621
+ return dominant_label, is_consistent, region_scores
622
+
623
+
624
+ # ================================================================
625
+ # FULL INFERENCE PIPELINE
626
+ # ================================================================
627
+ def predict_with_gradcam(image_path, model, gradcam, ood_detector,
628
+ thresholds, temperature, device,
629
+ true_label=None, dataset='auto'):
630
+ """
631
+ End-to-end inference with Grad-CAM and OOD detection.
632
+
633
+ Steps:
634
+ 1. Load and preprocess image (Ben Graham for APTOS, CLAHE for ODIR)
635
+ 2. OOD check on ViT CLS token features
636
+ 3. Generate Grad-CAM heatmap
637
+ 4. Apply temperature scaling to logits
638
+ 5. Apply per-class thresholds
639
+ 6. Analyse attention region
640
+
641
+ Returns:
642
+ dict with predicted_class, confidence, gradcam_heatmap, etc.
643
+ """
644
+ # 1. Preprocess
645
+ img_np, img_orig = load_and_preprocess(image_path, dataset=dataset)
646
+ img_tensor = preprocess_to_tensor(img_np).to(device)
647
+
648
+ # 2. OOD check using raw CLS features
649
+ model.eval()
650
+ with torch.no_grad():
651
+ features = model.get_features(img_tensor).cpu().numpy() # (1, 768)
652
+
653
+ if ood_detector.is_fitted:
654
+ distances, ood_flags = ood_detector.score(features)
655
+ ood_distance = float(distances[0])
656
+ ood_flag = bool(ood_flags[0])
657
+ else:
658
+ ood_distance = 0.0
659
+ ood_flag = False
660
+
661
+ # 3. Generate Grad-CAM (also runs forward + backward pass)
662
+ heatmap, predicted_label, raw_confidence = gradcam.generate(img_tensor)
663
+
664
+ # 4. Temperature-scaled calibrated probabilities
665
+ # Run a clean no-grad forward pass to get stable logits for calibration
666
+ model.eval()
667
+ with torch.no_grad():
668
+ raw_feats = model.backbone(img_tensor) # (1, 768)
669
+ raw_feats = model.drop(raw_feats)
670
+ logits = model.disease_head(raw_feats).float().cpu() # (1, 5)
671
+
672
+ scaled_logits = logits / temperature
673
+ calibrated_probs = torch.softmax(scaled_logits, dim=1)[0].numpy() # (5,)
674
+
675
+ # 5. Apply per-class thresholds
676
+ above = [i for i, (p, t) in enumerate(zip(calibrated_probs, thresholds)) if p >= t]
677
+ if above:
678
+ final_label = int(above[np.argmax([calibrated_probs[i] for i in above])])
679
+ else:
680
+ final_label = int(np.argmax(calibrated_probs))
681
+
682
+ final_confidence = float(calibrated_probs[final_label])
683
+ predicted_class = CLASS_NAMES[final_label]
684
+
685
+ # 6. Heatmap overlay
686
+ gradcam_overlay = gradcam.overlay(img_orig, heatmap, alpha=0.7)
687
+
688
+ # 7. Attention region analysis
689
+ attention_region, region_consistent, region_scores = analyse_attention_region(
690
+ heatmap, final_label
691
+ )
692
+
693
+ # Append disease name for clarity
694
+ disease_tag = CLASS_NAMES[final_label].replace('/', '-')
695
+ attention_region_full = f'{attention_region} ({disease_tag})'
696
+
697
+ # 8. Review flag: low confidence OR OOD
698
+ review_flag = ood_flag or final_confidence < 0.50
699
+
700
+ return {
701
+ 'image_path': image_path,
702
+ 'predicted_class': predicted_class,
703
+ 'predicted_label': final_label,
704
+ 'confidence': round(final_confidence, 4),
705
+ 'raw_confidence': round(raw_confidence, 4),
706
+ 'all_probabilities': [round(float(p), 4) for p in calibrated_probs],
707
+ 'gradcam_heatmap': heatmap, # (224, 224) float32
708
+ 'gradcam_overlay': gradcam_overlay, # (224, 224, 3) uint8
709
+ 'img_orig': img_orig, # original for display
710
+ 'ood_flag': ood_flag,
711
+ 'ood_distance': round(ood_distance, 4),
712
+ 'review_flag': review_flag,
713
+ 'attention_region': attention_region_full,
714
+ 'region_scores': region_scores,
715
+ 'region_consistent': region_consistent,
716
+ 'true_label': true_label,
717
+ }
718
+
719
+
720
+ # ================================================================
721
+ # BATCH EVALUATION
722
+ # ================================================================
723
+ def run_batch_evaluation(model, gradcam, ood_detector,
724
+ thresholds, temperature, device,
725
+ n_per_class=4):
726
+ """
727
+ Run inference on n_per_class images per disease class (20 total).
728
+ Saves individual overlay images + summary grid.
729
+ """
730
+ import pandas as pd
731
+ print(f'\nRunning batch evaluation ({n_per_class} per class = {n_per_class * NUM_CLASSES} total)...')
732
+
733
+ df = pd.read_csv(TEST_CSV)
734
+
735
+ # Collect n_per_class unique samples per class
736
+ samples = []
737
+ for label in range(NUM_CLASSES):
738
+ subset = df[df['disease_label'] == label].drop_duplicates(subset='image_path')
739
+ chosen = subset.head(n_per_class)
740
+ for _, row in chosen.iterrows():
741
+ samples.append({
742
+ 'image_path': row['image_path'],
743
+ 'true_label': int(row['disease_label']),
744
+ 'dataset': str(row.get('dataset', 'auto')),
745
+ })
746
+
747
+ results = []
748
+ failed = []
749
+
750
+ for i, sample in enumerate(samples):
751
+ img_path = sample['image_path']
752
+ true_label = sample['true_label']
753
+ dataset = sample['dataset']
754
+
755
+ print(f' [{i+1:2d}/{len(samples)}] {CLASS_NAMES[true_label]:15s} | {os.path.basename(img_path)}', end=' ')
756
+
757
+ try:
758
+ result = predict_with_gradcam(
759
+ img_path, model, gradcam, ood_detector,
760
+ thresholds, temperature, device,
761
+ true_label=true_label,
762
+ dataset=dataset,
763
+ )
764
+ correct = (result['predicted_label'] == true_label)
765
+ flag_str = ' [OOD]' if result['ood_flag'] else ''
766
+ flag_str += ' [REVIEW]' if result['review_flag'] else ''
767
+ print(f'-> pred={result["predicted_class"]:15s} conf={result["confidence"]:.3f} {"OK" if correct else "WRONG"}{flag_str}')
768
+
769
+ # Save overlay image
770
+ save_name = f'gradcam_{i+1:02d}_true{true_label}_pred{result["predicted_label"]}_{os.path.splitext(os.path.basename(img_path))[0][:20]}.png'
771
+ save_path = os.path.join(GRADCAM_DIR, save_name)
772
+
773
+ fig, axes = plt.subplots(1, 3, figsize=(12, 4))
774
+ axes[0].imshow(result['img_orig'])
775
+ axes[0].set_title(f'Original\nTrue: {CLASS_NAMES[true_label]}', fontsize=9)
776
+ axes[0].axis('off')
777
+
778
+ axes[1].imshow(result['gradcam_heatmap'], cmap='jet', vmin=0, vmax=1)
779
+ axes[1].set_title('Grad-CAM Heatmap', fontsize=9)
780
+ axes[1].axis('off')
781
+
782
+ axes[2].imshow(result['gradcam_overlay'])
783
+ flag_line = ' [OOD]' if result['ood_flag'] else ''
784
+ axes[2].set_title(
785
+ f'Overlay\nPred: {result["predicted_class"]} ({result["confidence"]:.2f}){flag_line}',
786
+ fontsize=9, color='red' if not correct else 'green'
787
+ )
788
+ axes[2].axis('off')
789
+
790
+ plt.suptitle(
791
+ f'Attention: {result["attention_region"]}',
792
+ fontsize=8, color='gray'
793
+ )
794
+ plt.tight_layout()
795
+ plt.savefig(save_path, dpi=120, bbox_inches='tight')
796
+ plt.close()
797
+
798
+ result['save_path'] = save_path
799
+ results.append(result)
800
+
801
+ except Exception as e:
802
+ print(f' ERROR: {e}')
803
+ failed.append({'image_path': img_path, 'error': str(e)})
804
+
805
+ return results, failed
806
+
807
+
808
+ # ================================================================
809
+ # SUMMARY GRID (4 rows = classes 0-4, 4 cols = samples)
810
+ # ================================================================
811
+ def save_summary_grid(results):
812
+ """Save a 5×4 summary grid (rows=classes, cols=samples)."""
813
+ n_rows = NUM_CLASSES
814
+ n_cols = 4
815
+
816
+ # Group results by true label
817
+ by_class = {i: [] for i in range(NUM_CLASSES)}
818
+ for r in results:
819
+ tl = r.get('true_label', r['predicted_label'])
820
+ by_class[tl].append(r)
821
+
822
+ fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 20))
823
+ fig.patch.set_facecolor('#1a1a2e')
824
+
825
+ for row_idx in range(n_rows):
826
+ class_results = by_class[row_idx]
827
+ for col_idx in range(n_cols):
828
+ ax = axes[row_idx, col_idx]
829
+ if col_idx < len(class_results):
830
+ r = class_results[col_idx]
831
+ ax.imshow(r['gradcam_overlay'])
832
+ correct = (r['predicted_label'] == r.get('true_label', r['predicted_label']))
833
+ border_color = '#2ecc71' if correct else '#e74c3c'
834
+ for spine in ax.spines.values():
835
+ spine.set_edgecolor(border_color)
836
+ spine.set_linewidth(3)
837
+ label_str = f'{r["predicted_class"]}\n{r["confidence"]:.2f}'
838
+ if r['ood_flag']:
839
+ label_str += '\n[OOD]'
840
+ ax.set_title(label_str, fontsize=7, color='white', pad=2)
841
+ else:
842
+ ax.set_facecolor('#1a1a2e')
843
+
844
+ ax.axis('off')
845
+ if col_idx == 0:
846
+ ax.set_ylabel(CLASS_NAMES[row_idx], rotation=0, labelpad=50,
847
+ fontsize=10, color='white', fontweight='bold',
848
+ va='center')
849
+
850
+ plt.suptitle(
851
+ 'RetinaSense v3.0 — Grad-CAM Summary Grid\n'
852
+ 'Rows = True Class | Green border = Correct | Red border = Wrong',
853
+ fontsize=12, color='white', y=1.01
854
+ )
855
+ plt.tight_layout()
856
+ grid_path = os.path.join(GRADCAM_DIR, 'gradcam_summary_grid.png')
857
+ plt.savefig(grid_path, dpi=130, bbox_inches='tight',
858
+ facecolor=fig.get_facecolor())
859
+ plt.close()
860
+ print(f' Summary grid saved -> {grid_path}')
861
+ return grid_path
862
+
863
+
864
+ # ================================================================
865
+ # DISEASE-SPECIFIC HEATMAP VALIDATION
866
+ # ================================================================
867
+ def validate_heatmaps(results):
868
+ """
869
+ Check per-disease whether Grad-CAM activates the expected anatomical region.
870
+ Returns a validation summary dict, saves to heatmap_validation.json.
871
+ """
872
+ print('\nRunning disease-specific heatmap validation...')
873
+
874
+ validation = {}
875
+ for cls_idx, cls_name in enumerate(CLASS_NAMES):
876
+ cls_results = [r for r in results if r.get('true_label') == cls_idx]
877
+ if not cls_results:
878
+ validation[cls_name] = {'n_samples': 0}
879
+ continue
880
+
881
+ consistent_count = sum(1 for r in cls_results if r.get('region_consistent', False))
882
+ avg_scores = {k: 0.0 for k in ['optic_disc', 'macula', 'periphery', 'overall']}
883
+ for r in cls_results:
884
+ for k in avg_scores:
885
+ avg_scores[k] += r['region_scores'].get(k, 0.0)
886
+ for k in avg_scores:
887
+ avg_scores[k] = round(avg_scores[k] / len(cls_results), 4)
888
+
889
+ dominant_zone = max(
890
+ ['optic_disc', 'macula', 'periphery'],
891
+ key=lambda k: avg_scores[k]
892
+ )
893
+
894
+ validation[cls_name] = {
895
+ 'n_samples': len(cls_results),
896
+ 'expected_region': EXPECTED_REGIONS[cls_idx],
897
+ 'dominant_zone': dominant_zone,
898
+ 'consistent_samples': consistent_count,
899
+ 'consistency_pct': round(100 * consistent_count / len(cls_results), 1),
900
+ 'avg_region_scores': avg_scores,
901
+ }
902
+
903
+ print(f' {cls_name:15s}: {consistent_count}/{len(cls_results)} consistent '
904
+ f'({validation[cls_name]["consistency_pct"]:.0f}%) '
905
+ f'dominant={dominant_zone}')
906
+
907
+ # Save
908
+ val_path = os.path.join(GRADCAM_DIR, 'heatmap_validation.json')
909
+ with open(val_path, 'w') as f:
910
+ json.dump(validation, f, indent=2)
911
+ print(f' Validation saved -> {val_path}')
912
+
913
+ return validation
914
+
915
+
916
+ # ================================================================
917
+ # CLINICAL REPORT
918
+ # ================================================================
919
+ def generate_clinical_report(results, validation, ood_stats, failed):
920
+ """Generate GRADCAM_REPORT.md with clinical analysis."""
921
+ now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
922
+ n_total = len(results)
923
+ n_correct = sum(1 for r in results if r.get('predicted_label') == r.get('true_label'))
924
+ n_ood = sum(1 for r in results if r.get('ood_flag'))
925
+ n_review = sum(1 for r in results if r.get('review_flag'))
926
+ avg_conf = np.mean([r['confidence'] for r in results]) if results else 0.0
927
+
928
+ lines = [
929
+ '# RetinaSense v3.0 — Grad-CAM Clinical Report',
930
+ f'',
931
+ f'**Generated**: {now} ',
932
+ f'**Model**: ViT-Base-Patch16-224 (81.19% test accuracy) ',
933
+ f'**Pipeline**: Grad-CAM + Mahalanobis OOD + Temperature Scaling + Per-Class Thresholds',
934
+ '',
935
+ '---',
936
+ '',
937
+ '## Executive Summary',
938
+ '',
939
+ f'| Metric | Value |',
940
+ f'|--------|-------|',
941
+ f'| Images processed | {n_total} |',
942
+ f'| Correct predictions | {n_correct}/{n_total} ({100*n_correct/max(n_total,1):.1f}%) |',
943
+ f'| Avg calibrated confidence | {avg_conf:.3f} |',
944
+ f'| OOD flags raised | {n_ood} |',
945
+ f'| Human review flags | {n_review} |',
946
+ f'| Failed images | {len(failed)} |',
947
+ f'| Temperature T | {TEMPERATURE:.4f} |',
948
+ '',
949
+ '---',
950
+ '',
951
+ '## Per-Sample Predictions',
952
+ '',
953
+ '| # | Image | True | Predicted | Confidence | OOD | Review | Attention Region |',
954
+ '|---|-------|------|-----------|-----------|-----|--------|-----------------|',
955
+ ]
956
+
957
+ for i, r in enumerate(results):
958
+ true_name = CLASS_NAMES[r['true_label']] if r.get('true_label') is not None else 'Unknown'
959
+ correct_marker = 'OK' if r['predicted_label'] == r.get('true_label') else '**WRONG**'
960
+ lines.append(
961
+ f'| {i+1} | {os.path.basename(r["image_path"])[:25]} '
962
+ f'| {true_name} '
963
+ f'| {r["predicted_class"]} ({correct_marker}) '
964
+ f'| {r["confidence"]:.3f} '
965
+ f'| {"YES" if r["ood_flag"] else "no"} '
966
+ f'| {"YES" if r["review_flag"] else "no"} '
967
+ f'| {r["attention_region"]} |'
968
+ )
969
+
970
+ lines += [
971
+ '',
972
+ '---',
973
+ '',
974
+ '## Per-Class Attention Pattern Analysis',
975
+ '',
976
+ '| Disease | Expected Region | Dominant Zone | Consistency |',
977
+ '|---------|----------------|---------------|-------------|',
978
+ ]
979
+ for cls_name, v in validation.items():
980
+ if v.get('n_samples', 0) == 0:
981
+ lines.append(f'| {cls_name} | N/A | N/A | N/A (no samples) |')
982
+ else:
983
+ lines.append(
984
+ f'| {cls_name} | {v["expected_region"]} '
985
+ f'| {v["dominant_zone"]} '
986
+ f'| {v["consistency_pct"]:.0f}% ({v["consistent_samples"]}/{v["n_samples"]}) |'
987
+ )
988
+
989
+ lines += [
990
+ '',
991
+ '---',
992
+ '',
993
+ '## OOD Detection Statistics',
994
+ '',
995
+ f'- **Method**: Mahalanobis distance to nearest class centroid (CLS token features)',
996
+ f'- **Threshold percentile**: 97.5th percentile of training-set distances',
997
+ f'- **OOD threshold**: {ood_stats.get("threshold", "N/A")}',
998
+ f'- **Images flagged OOD**: {n_ood}/{n_total}',
999
+ '',
1000
+ '### Interpretation',
1001
+ '',
1002
+ '- Mahalanobis distance measures how far a feature embedding lies from known class distributions',
1003
+ '- Low-quality images, extreme artefacts, or off-distribution fundus cameras may trigger OOD flags',
1004
+ '- All OOD-flagged images are automatically sent for human review',
1005
+ '',
1006
+ '---',
1007
+ '',
1008
+ '## Grad-CAM Heatmap Descriptions',
1009
+ '',
1010
+ '| Disease | Expected activation | Clinical significance |',
1011
+ '|---------|--------------------|-----------------------|',
1012
+ '| Normal | Low, uniform | No focal pathology — model attention diffuse |',
1013
+ '| Diabetes/DR | Scattered periphery + macula | Microaneurysms, exudates, NV |',
1014
+ '| Glaucoma | Optic disc (centre) | Structural disc changes, CDR |',
1015
+ '| Cataract | Diffuse lens opacity | Posterior/anterior capsule opacification |',
1016
+ '| AMD | Macula / centre-temporal | Drusen, RPE atrophy, CNV |',
1017
+ '',
1018
+ '---',
1019
+ '',
1020
+ '## Thresholds Applied',
1021
+ '',
1022
+ '| Class | Threshold |',
1023
+ '|-------|-----------|',
1024
+ ]
1025
+ for cls_name, thr in zip(CLASS_NAMES, THRESHOLDS):
1026
+ lines.append(f'| {cls_name} | {thr:.4f} |')
1027
+
1028
+ lines += [
1029
+ '',
1030
+ '---',
1031
+ '',
1032
+ '## Deployment Recommendations',
1033
+ '',
1034
+ '1. **Confidence gate**: Flag predictions below 0.50 for mandatory ophthalmologist review.',
1035
+ '2. **OOD gate**: Any Mahalanobis distance above threshold should trigger QC check on image quality before clinical use.',
1036
+ '3. **Grad-CAM review**: Clinicians should inspect heatmaps for cases where model attention does not align with expected anatomy.',
1037
+ '4. **Glaucoma caution**: Current dataset imbalance (46 test samples) — consider supplementing ODIR with additional glaucoma images.',
1038
+ '5. **Continuous monitoring**: Re-calibrate temperature and thresholds quarterly on production data.',
1039
+ '6. **Not for standalone diagnosis**: Grad-CAM is an explainability aid; all predictions require clinical validation.',
1040
+ '',
1041
+ '---',
1042
+ '',
1043
+ f'*Report auto-generated by RetinaSense v3.0 Grad-CAM Pipeline | {now}*',
1044
+ ]
1045
+
1046
+ report_path = os.path.join(GRADCAM_DIR, 'GRADCAM_REPORT.md')
1047
+ with open(report_path, 'w') as f:
1048
+ f.write('\n'.join(lines))
1049
+ print(f' Clinical report saved -> {report_path}')
1050
+ return report_path
1051
+
1052
+
1053
+ # ================================================================
1054
+ # MAIN
1055
+ # ================================================================
1056
+ def main():
1057
+ t_start = time.time()
1058
+
1059
+ # ---- 1. Build Grad-CAM ---
1060
+ print('\n[1/6] Initialising ViTGradCAM...')
1061
+ gradcam = ViTGradCAM(model)
1062
+ print(f' Method : Attention Rollout (all 12 transformer blocks)')
1063
+ print(f' Hooks : {len(gradcam._hooks)} attention hooks registered')
1064
+ print(f' fused_attn disabled for attention weight access')
1065
+
1066
+ # ---- 2. Fit OOD Detector ---
1067
+ print('\n[2/6] Fitting OOD detector...')
1068
+ ood_path = os.path.join(OUTPUT_DIR, 'ood_detector')
1069
+ ood_detector = OODDetector(threshold_percentile=97.5)
1070
+
1071
+ if os.path.exists(ood_path + '.npz'):
1072
+ ood_detector.load(ood_path)
1073
+ else:
1074
+ # Build a small DataLoader from training data to fit OOD detector
1075
+ import pandas as pd
1076
+ from torch.utils.data import Dataset, DataLoader
1077
+ from torchvision import transforms as T
1078
+
1079
+ train_df = pd.read_csv(os.path.join(BASE_DIR, 'data', 'train_split.csv'))
1080
+
1081
+ class SimpleDataset(Dataset):
1082
+ def __init__(self, df):
1083
+ self.df = df.reset_index(drop=True)
1084
+ self.transform = transforms.Compose([
1085
+ transforms.ToPILImage(),
1086
+ transforms.ToTensor(),
1087
+ transforms.Normalize(NORM_MEAN, NORM_STD),
1088
+ ])
1089
+
1090
+ def __len__(self):
1091
+ return len(self.df)
1092
+
1093
+ def __getitem__(self, idx):
1094
+ row = self.df.iloc[idx]
1095
+ img_path = str(row['image_path'])
1096
+ if not os.path.isabs(img_path):
1097
+ clean = img_path
1098
+ while clean.startswith('./') or clean.startswith('.//'):
1099
+ clean = clean[2:] if clean.startswith('./') else clean[3:]
1100
+ img_path = os.path.join(BASE_DIR, clean)
1101
+ dataset = str(row.get('dataset', 'auto'))
1102
+
1103
+ try:
1104
+ img_np, _ = load_and_preprocess(img_path, dataset=dataset)
1105
+ img_tensor = self.transform(img_np)
1106
+ except Exception:
1107
+ img_tensor = torch.zeros(3, IMG_SIZE, IMG_SIZE)
1108
+
1109
+ lbl = int(row['disease_label'])
1110
+ return img_tensor, torch.tensor(lbl, dtype=torch.long), torch.tensor(0, dtype=torch.long)
1111
+
1112
+ ood_ds = SimpleDataset(train_df)
1113
+ ood_loader = DataLoader(ood_ds, batch_size=32, shuffle=False, num_workers=4)
1114
+ ood_detector.fit(model, ood_loader, DEVICE, max_batches=80)
1115
+ ood_detector.save(ood_path)
1116
+
1117
+ # ---- 3. Batch Evaluation ---
1118
+ print('\n[3/6] Batch evaluation on 20 test images...')
1119
+ results, failed = run_batch_evaluation(
1120
+ model, gradcam, ood_detector,
1121
+ THRESHOLDS, TEMPERATURE, DEVICE,
1122
+ n_per_class=4
1123
+ )
1124
+
1125
+ # ---- 4. Summary Grid ---
1126
+ print('\n[4/6] Generating summary grid...')
1127
+ grid_path = save_summary_grid(results)
1128
+
1129
+ # ---- 5. Heatmap Validation ---
1130
+ print('\n[5/6] Heatmap validation...')
1131
+ validation = validate_heatmaps(results)
1132
+
1133
+ # ---- 6. Clinical Report ---
1134
+ print('\n[6/6] Generating clinical report...')
1135
+ ood_stats = {'threshold': round(ood_detector.ood_threshold, 4) if ood_detector.is_fitted else 'N/A'}
1136
+ report_path = generate_clinical_report(results, validation, ood_stats, failed)
1137
+
1138
+ # ---- Cleanup ---
1139
+ gradcam.remove_hooks()
1140
+
1141
+ # ================================================================
1142
+ # FINAL SUMMARY
1143
+ # ================================================================
1144
+ elapsed = time.time() - t_start
1145
+ n_total = len(results)
1146
+ n_correct = sum(1 for r in results if r.get('predicted_label') == r.get('true_label'))
1147
+ avg_conf = np.mean([r['confidence'] for r in results]) if results else 0.0
1148
+ n_ood = sum(1 for r in results if r['ood_flag'])
1149
+ n_review = sum(1 for r in results if r['review_flag'])
1150
+
1151
+ print('\n' + '=' * 65)
1152
+ print(' RetinaSense v3.0 — GRAD-CAM PIPELINE COMPLETE')
1153
+ print('=' * 65)
1154
+ print(f' Images processed : {n_total}')
1155
+ print(f' Correct predictions : {n_correct}/{n_total} ({100*n_correct/max(n_total,1):.1f}%)')
1156
+ print(f' Avg calibrated conf : {avg_conf:.3f}')
1157
+ print(f' OOD flags : {n_ood}')
1158
+ print(f' Review flags : {n_review}')
1159
+ print(f' Failed images : {len(failed)}')
1160
+ print(f' Elapsed time : {elapsed:.1f}s')
1161
+ print()
1162
+ print(f' Output directory : {GRADCAM_DIR}')
1163
+ output_files = [
1164
+ 'gradcam_summary_grid.png',
1165
+ 'heatmap_validation.json',
1166
+ 'GRADCAM_REPORT.md',
1167
+ ] + [os.path.basename(r.get('save_path', '')) for r in results if r.get('save_path')]
1168
+ for fname in output_files:
1169
+ if fname:
1170
+ full = os.path.join(GRADCAM_DIR, fname)
1171
+ exists = os.path.exists(full)
1172
+ print(f' {"[OK]" if exists else "[!!]"} {fname}')
1173
+ print('=' * 65)
1174
+
1175
+ return results, validation
1176
+
1177
+
1178
+ if __name__ == '__main__':
1179
+ main()