tanishq74 committed on
Commit
5a7853c
·
verified ·
1 Parent(s): 9022502

Add integrated_gradients_xai.py

Browse files
Files changed (1) hide show
  1. integrated_gradients_xai.py +858 -0
integrated_gradients_xai.py ADDED
@@ -0,0 +1,858 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ RetinaSense v3.0 -- Phase 1C: Advanced XAI with Integrated Gradients
4
+ =====================================================================
5
+ Compares Attention Rollout (existing) vs Integrated Gradients (captum)
6
+ on 20 test images (4 per class).
7
+
8
+ Outputs (all saved to outputs_v3/xai/):
9
+ - comparison_grid.png : 20-row x 3-column grid [Original | Rollout | IG]
10
+ - ig_individual_01..20.png : Individual IG heatmaps
11
+ - agreement_heatmap.png : Spatial correlation matrix between methods
12
+ - agreement_score.json : Numerical agreement scores per image
13
+
14
+ Usage:
15
+ python integrated_gradients_xai.py
16
+ """
17
+
18
+ import os
19
+ import sys
20
+ import json
21
+ import warnings
22
+ import numpy as np
23
+ import cv2
24
+ import matplotlib
25
+ matplotlib.use('Agg')
26
+ import matplotlib.pyplot as plt
27
+ from matplotlib.colors import Normalize
28
+ from PIL import Image
29
+ import pandas as pd
30
+ from scipy.stats import pearsonr
31
+
32
+ warnings.filterwarnings('ignore')
33
+
34
+ import torch
35
+ import torch.nn as nn
36
+ import torch.nn.functional as F
37
+ from torchvision import transforms
38
+ import timm
39
+
40
+ from captum.attr import IntegratedGradients
41
+
42
# Maximize CPU parallelism
# Falls back to 4 threads when os.cpu_count() returns None.
torch.set_num_threads(os.cpu_count() or 4)

# ================================================================
# CONFIGURATION
# ================================================================
# All paths are derived from a hard-coded studio root; the XAI output
# directory is created at import time as a side effect.
BASE_DIR = '/teamspace/studios/this_studio'
OUTPUT_DIR = os.path.join(BASE_DIR, 'outputs_v3')
XAI_DIR = os.path.join(OUTPUT_DIR, 'xai')
os.makedirs(XAI_DIR, exist_ok=True)

# Input artefacts read by this script.
MODEL_PATH = os.path.join(OUTPUT_DIR, 'best_model.pth')
TEMPERATURE_PATH = os.path.join(OUTPUT_DIR, 'temperature.json')
TEST_CSV = os.path.join(BASE_DIR, 'data', 'test_split.csv')
NORM_STATS_PATH = os.path.join(BASE_DIR, 'data', 'fundus_norm_stats.json')

# Index in CLASS_NAMES == integer disease label used throughout the script.
CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']
NUM_CLASSES = 5
IMG_SIZE = 224
DROPOUT = 0.3
N_PER_CLASS = 4  # images sampled per class by select_test_images()

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Start-up banner.
print('=' * 65)
print(' RetinaSense v3.0 -- Phase 1C: Integrated Gradients XAI')
print('=' * 65)
print(f' Device : {DEVICE}')
if torch.cuda.is_available():
    print(f' GPU : {torch.cuda.get_device_name(0)}')
print(f' Output : {XAI_DIR}')
print('=' * 65)

# ================================================================
# LOAD NORMALISATION STATS
# ================================================================
# Prefer dataset-specific channel statistics when the JSON file exists;
# otherwise fall back to the standard ImageNet mean/std.
if os.path.exists(NORM_STATS_PATH):
    with open(NORM_STATS_PATH) as f:
        norm_stats = json.load(f)
    NORM_MEAN = norm_stats['mean_rgb']
    NORM_STD = norm_stats['std_rgb']
    print(f' Fundus norm stats: mean={[round(v,4) for v in NORM_MEAN]}, '
          f'std={[round(v,4) for v in NORM_STD]}')
else:
    NORM_MEAN = [0.485, 0.456, 0.406]
    NORM_STD = [0.229, 0.224, 0.225]
    print(' Using ImageNet normalisation fallback')

# NOTE(review): unlike the norm stats above there is no fallback here, so a
# missing temperature.json aborts the script at import time. In the code
# shown, TEMPERATURE is only printed and never applied to the logits.
with open(TEMPERATURE_PATH) as f:
    TEMPERATURE = json.load(f)['temperature']
print(f' Temperature T = {TEMPERATURE:.4f}')
93
+
94
+
95
+ # ================================================================
96
+ # MODEL ARCHITECTURE (mirrors gradcam_v3.py / retinasense_v3.py)
97
+ # ================================================================
98
class MultiTaskViT(nn.Module):
    """Shared ViT-Base/16 (224 px) feature extractor with two MLP heads.

    The timm backbone is created without a classifier (``num_classes=0``)
    and without pretrained weights; its 768-dim output is passed through
    dropout and then fed to both a disease head and a severity head.
    """

    def __init__(self, n_disease=NUM_CLASSES, n_severity=5, drop=DROPOUT):
        super().__init__()
        embed_dim = 768

        self.backbone = timm.create_model(
            'vit_base_patch16_224', pretrained=False, num_classes=0
        )
        self.drop = nn.Dropout(drop)

        # Layer order (and therefore Sequential indices) must stay fixed so
        # that saved state_dict keys keep matching.
        disease_layers = [
            nn.Linear(embed_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, n_disease),
        ]
        self.disease_head = nn.Sequential(*disease_layers)

        severity_layers = [
            nn.Linear(embed_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, n_severity),
        ]
        self.severity_head = nn.Sequential(*severity_layers)

    def forward(self, x):
        """Return ``(disease_logits, severity_logits)`` for input batch ``x``."""
        features = self.drop(self.backbone(x))
        disease_logits = self.disease_head(features)
        severity_logits = self.severity_head(features)
        return disease_logits, severity_logits
124
+
125
+
126
+ # ================================================================
127
+ # DISEASE-LOGITS WRAPPER FOR CAPTUM
128
+ # ================================================================
129
class DiseaseLogitModel(nn.Module):
    """
    Adapter exposing only the disease logits of a two-output model.

    The wrapped model returns a ``(disease_logits, severity_logits)`` pair;
    attribution code needs a forward pass that yields a single tensor, so
    the severity output is dropped here. The class to explain is chosen
    later through the ``target`` argument of the attribution call, hence
    the full per-class logits are returned unchanged.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        logits, _severity = self.model(x)
        return logits
145
+
146
+
147
+ # ================================================================
148
+ # ATTENTION ROLLOUT (copied from gradcam_v3.py for self-containment)
149
+ # ================================================================
150
class ViTAttentionRollout:
    """
    Attention Rollout for a Vision Transformer.

    Forward hooks on every ``block.attn`` module of ``model.backbone``
    recompute and store the softmax attention weights. ``generate`` then
    averages heads, optionally discards weak attention, adds the identity
    (residual path), row-normalises, and multiplies the per-layer matrices
    to obtain how the output CLS token depends on each input patch.

    While the hooks are installed, ``fused_attn`` is set to ``False`` on
    every attention module; ``remove_hooks`` restores the previous values.
    """

    def __init__(self, model, discard_ratio=0.97):
        """
        Args:
            model         : model exposing ``model.backbone.blocks[i].attn``
            discard_ratio : per-layer quantile below which head-averaged
                            attention is zeroed (0 disables discarding)
        """
        self.model = model
        self.discard_ratio = discard_ratio
        self._attention_maps = []
        self._hooks = []
        # (attention module, fused_attn value before we touched it or None)
        self._fused_state = []

        for block in model.backbone.blocks:
            attn = block.attn
            self._fused_state.append((attn, getattr(attn, 'fused_attn', None)))
            # Disable fused attention for explicit weight access
            attn.fused_attn = False
            self._hooks.append(attn.register_forward_hook(self._attn_hook))

    def _attn_hook(self, module, args, output):
        """Recompute and store the softmax attention weights of one block."""
        x = args[0]
        B, N, C = x.shape
        with torch.no_grad():
            qkv = module.qkv(x).reshape(
                B, N, 3, module.num_heads, module.head_dim
            ).permute(2, 0, 3, 1, 4)
            q, k, _ = qkv.unbind(0)
            q, k = module.q_norm(q), module.k_norm(k)
            # (q * scale) @ k^T, softmax over keys -> (B, heads, N, N)
            attn = (q * module.scale @ k.transpose(-2, -1)).softmax(dim=-1)
        self._attention_maps.append(attn.detach().cpu())

    @staticmethod
    def _compose_layers(layer_attn):
        """
        Multiply per-layer attention matrices into a single rollout matrix.

        ``layer_attn`` is indexed by layer (first layer at index 0). Tokens
        propagate as ``x_l = A_l @ x_{l-1}``, so the composed map from input
        tokens to final tokens is ``A_L @ ... @ A_1``: each later layer is
        applied on the LEFT of the running product.
        """
        rollout = layer_attn[0]
        for attn in layer_attn[1:]:
            rollout = attn @ rollout
        return rollout

    def generate(self, image_tensor, class_idx=None):
        """
        Generate attention rollout heatmap for the first image in the batch.

        Args:
            image_tensor : normalised input batch; moved to DEVICE here.
                           The prediction uses ``.item()``, so a batch of
                           one image is expected.
            class_idx    : accepted for API compatibility; rollout is
                           class-agnostic and does not use it.

        Returns:
            heatmap (IMG_SIZE, IMG_SIZE) float32 [0, 1], predicted_label,
            confidence

        Raises:
            RuntimeError : no attention maps were captured (e.g. hooks
                           already removed).
            ValueError   : the number of patch tokens is not a square.
        """
        self.model.eval()
        self._attention_maps = []

        with torch.no_grad():
            image_tensor = image_tensor.to(DEVICE)
            d_out, _ = self.model(image_tensor)
            probs = torch.softmax(d_out, dim=1)
            predicted_label = int(probs.argmax(dim=1).item())
            confidence = float(probs[0, predicted_label].item())

        if not self._attention_maps:
            raise RuntimeError(
                'No attention maps captured; were the hooks removed?'
            )

        # (layers, B, heads, N, N) -> first batch item -> mean over heads
        attn_stack = torch.stack(self._attention_maps, dim=0)[:, 0].float()
        attn_mean = attn_stack.mean(dim=1)

        if self.discard_ratio > 0:
            # Keep only attention at or above the per-layer quantile.
            flat = attn_mean.reshape(attn_mean.shape[0], -1)
            thresh = torch.quantile(flat, self.discard_ratio, dim=1, keepdim=True)
            thresh = thresh.unsqueeze(-1)
            attn_mean = torch.where(
                attn_mean >= thresh, attn_mean, torch.zeros_like(attn_mean)
            )

        # Add identity for the residual connection, then re-normalise rows.
        eye = torch.eye(attn_mean.shape[-1]).unsqueeze(0)
        attn_aug = attn_mean + eye
        attn_aug = attn_aug / attn_aug.sum(dim=-1, keepdim=True).clamp(min=1e-8)

        rollout = self._compose_layers(attn_aug)

        # Row 0 is the CLS token; drop its self-attention column.
        cls_attention = rollout[0, 1:]
        n_patches = cls_attention.numel()
        grid = int(round(n_patches ** 0.5))
        if grid * grid != n_patches:
            raise ValueError(
                f'Cannot arrange {n_patches} patch tokens on a square grid'
            )
        spatial = cls_attention.numpy().reshape(grid, grid).astype(np.float32)
        spatial = cv2.resize(spatial, (IMG_SIZE, IMG_SIZE),
                             interpolation=cv2.INTER_LINEAR)

        s_min, s_max = spatial.min(), spatial.max()
        if s_max - s_min > 1e-8:
            spatial = (spatial - s_min) / (s_max - s_min)
        else:
            spatial = np.zeros_like(spatial)

        spatial = np.power(spatial, 0.4)  # gamma stretch

        return spatial.astype(np.float32), predicted_label, confidence

    def remove_hooks(self):
        """Detach all hooks and restore each module's previous fused_attn."""
        for h in self._hooks:
            h.remove()
        self._hooks = []
        for attn, previous in self._fused_state:
            if previous is not None:
                attn.fused_attn = previous
        self._fused_state = []
241
+
242
+
243
+ # ================================================================
244
+ # IMAGE PREPROCESSING (mirrors gradcam_v3.py)
245
+ # ================================================================
246
def ben_graham(path, sz=IMG_SIZE, sigma=10):
    """Ben Graham style fundus enhancement (used for APTOS images).

    Reads the image (OpenCV first, PIL as fallback), resizes it to
    ``sz`` x ``sz`` RGB, subtracts a Gaussian-blurred copy to emphasise
    high-frequency detail around mid-grey, and zeroes everything outside
    a centred disc.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        # OpenCV could not decode the file; PIL already yields RGB.
        rgb = np.array(Image.open(path).convert('RGB'))
    else:
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    rgb = cv2.resize(rgb, (sz, sz))
    blurred = cv2.GaussianBlur(rgb, (0, 0), sigma)
    enhanced = cv2.addWeighted(rgb, 4, blurred, -4, 128)

    disc = np.zeros(enhanced.shape[:2], dtype=np.uint8)
    centre = (sz // 2, sz // 2)
    cv2.circle(disc, centre, int(sz * 0.48), 255, -1)
    return cv2.bitwise_and(enhanced, enhanced, mask=disc)
258
+
259
+
260
def clahe_preprocess(path, sz=IMG_SIZE):
    """CLAHE contrast enhancement on the lightness channel (used for ODIR).

    Reads the image (OpenCV first, PIL as fallback), resizes it to
    ``sz`` x ``sz``, equalises the L channel in LAB space and returns RGB.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        pil_rgb = np.array(Image.open(path).convert('RGB'))
        bgr = cv2.cvtColor(pil_rgb, cv2.COLOR_RGB2BGR)
    bgr = cv2.resize(bgr, (sz, sz))

    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB))
    equaliser = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    lab = cv2.merge((equaliser.apply(lightness), chan_a, chan_b))

    return cv2.cvtColor(cv2.cvtColor(lab, cv2.COLOR_LAB2BGR), cv2.COLOR_BGR2RGB)
272
+
273
+
274
def load_and_preprocess(image_path, dataset='auto'):
    """
    Read an image and apply the preprocessing matching its source dataset.

    Relative paths have any leading './' stripped and are resolved against
    BASE_DIR. With ``dataset='auto'`` the source is guessed from the path:
    'aptos' or 'gaussian' (case-insensitive) selects APTOS, anything else
    ODIR. APTOS images go through ``ben_graham``; every other value goes
    through ``clahe_preprocess``.

    Returns:
        img_np   : (IMG_SIZE, IMG_SIZE, 3) uint8 preprocessed RGB
        img_orig : (IMG_SIZE, IMG_SIZE, 3) uint8 resized original RGB
    """
    if not os.path.isabs(image_path):
        relative = image_path
        while relative.startswith('./'):
            relative = relative[2:]
        image_path = os.path.join(BASE_DIR, relative)

    if dataset == 'auto':
        lowered = image_path.lower()
        looks_like_aptos = 'aptos' in lowered or 'gaussian' in lowered
        dataset = 'APTOS' if looks_like_aptos else 'ODIR'

    # Untouched original (only resized) for display purposes.
    bgr = cv2.imread(image_path)
    if bgr is not None:
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    else:
        rgb = np.array(Image.open(image_path).convert('RGB'))
    img_orig = cv2.resize(rgb, (IMG_SIZE, IMG_SIZE))

    enhance = ben_graham if dataset == 'APTOS' else clahe_preprocess
    img_np = enhance(image_path, sz=IMG_SIZE)

    return img_np, img_orig
306
+
307
+
308
def preprocess_to_tensor(img_np):
    """Turn a preprocessed HxWx3 image array into a normalised batch of one.

    Applies ToPILImage -> ToTensor -> Normalize(NORM_MEAN, NORM_STD) and
    prepends a batch dimension.
    """
    as_pil = transforms.ToPILImage()(img_np)
    as_tensor = transforms.ToTensor()(as_pil)
    normalised = transforms.Normalize(NORM_MEAN, NORM_STD)(as_tensor)
    return normalised.unsqueeze(0)
316
+
317
+
318
+ # ================================================================
319
+ # CIRCULAR FUNDUS MASK
320
+ # ================================================================
321
def create_fundus_mask(h=IMG_SIZE, w=IMG_SIZE):
    """
    Build a soft-edged circular mask centred in an ``h`` x ``w`` frame.

    A filled disc of radius ``min(h, w) // 2 - 5`` is drawn and then
    smoothed with a 21x21 Gaussian kernel so the border fades gradually.
    Returns a float32 array of shape (h, w).
    """
    centre = (w // 2, h // 2)  # OpenCV expects (x, y)
    radius = min(h, w) // 2 - 5
    disc = np.zeros((h, w), dtype=np.float32)
    cv2.circle(disc, centre, radius, 1.0, -1)
    return cv2.GaussianBlur(disc, (21, 21), 0)
333
+
334
+
335
+ # ================================================================
336
+ # INTEGRATED GRADIENTS COMPUTATION
337
+ # ================================================================
338
def compute_ig_attribution(ig_model, ig_method, img_tensor, target_class,
                           n_steps=50, internal_batch_size=10, sigma=10):
    """
    Compute Integrated Gradients attribution for a single image.

    Uses a Gaussian-blurred copy of the input as baseline, which is more
    appropriate for fundus images than a black baseline (since the
    background is already dark).

    The caller's tensor is never modified: the baseline is built from a
    detached view and attribution runs on a detached clone, so the same
    tensor can safely be passed again. All intermediate tensors follow the
    device and dtype of ``img_tensor``.

    Args:
        ig_model            : DiseaseLogitModel wrapper (not used here; the
                              model is already bound inside ``ig_method``)
        ig_method           : captum IntegratedGradients instance
        img_tensor          : (1, 3, H, W) normalised tensor
        target_class        : int, disease class to explain
        n_steps             : number of interpolation steps
        internal_batch_size : batch size for internal IG computation
        sigma               : Gaussian blur sigma for baseline

    Returns:
        attribution : (H, W) float32 numpy array, min-max normalised to
                      [0, 1] (all zeros when the map is constant)
    """
    device, dtype = img_tensor.device, img_tensor.dtype
    mean_t = torch.tensor(NORM_MEAN, device=device, dtype=dtype).view(1, 3, 1, 1)
    std_t = torch.tensor(NORM_STD, device=device, dtype=dtype).view(1, 3, 1, 1)

    # Baseline: denormalise -> 8-bit pixels -> blur -> renormalise.
    # detach() keeps .numpy() legal even if the input already tracks grads.
    img_denorm = img_tensor.detach() * std_t + mean_t  # approx [0, 1] range
    pixels = (
        img_denorm[0].permute(1, 2, 0).cpu().numpy() * 255
    ).clip(0, 255).astype(np.uint8)
    blurred_np = cv2.GaussianBlur(pixels, (0, 0), sigma)
    to_normalised = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ])
    blurred_tensor = to_normalised(blurred_np).unsqueeze(0).to(
        device=device, dtype=dtype
    )

    # Attribute on a private copy so the caller's tensor is left untouched.
    ig_input = img_tensor.detach().clone().requires_grad_(True)
    attributions = ig_method.attribute(
        ig_input,
        baselines=blurred_tensor,
        target=target_class,
        n_steps=n_steps,
        internal_batch_size=internal_batch_size,
    )

    # Collapse RGB with an L2 norm per pixel: (1, 3, H, W) -> (H, W).
    # This keeps attribution magnitude and discards its sign.
    attr_np = attributions[0].detach().cpu().numpy()
    attr_spatial = np.sqrt(np.sum(attr_np ** 2, axis=0))

    # Normalise to [0, 1]
    a_min, a_max = attr_spatial.min(), attr_spatial.max()
    if a_max - a_min > 1e-8:
        attr_spatial = (attr_spatial - a_min) / (a_max - a_min)
    else:
        attr_spatial = np.zeros_like(attr_spatial)

    return attr_spatial.astype(np.float32)
399
+
400
+
401
+ # ================================================================
402
+ # OVERLAY FUNCTION
403
+ # ================================================================
404
def overlay_heatmap(original_np, heatmap, alpha=0.6, cmap_name='inferno'):
    """
    Alpha-blend a colour-mapped heatmap over an image inside the fundus disc.

    The blend weight at each pixel is ``alpha`` times the soft circular
    mask from ``create_fundus_mask``, so the original image shows through
    unchanged wherever the mask is zero.

    Args:
        original_np : (H, W, 3) uint8 RGB
        heatmap     : (H, W) float32 [0, 1]
        alpha       : heatmap blending opacity
        cmap_name   : matplotlib colormap name

    Returns:
        blended : (H, W, 3) uint8 RGB
    """
    # Colour-map the heatmap and quantise to 8-bit before blending.
    rgb_unit = plt.get_cmap(cmap_name)(heatmap)[:, :, :3]
    heat_rgb = (rgb_unit * 255).astype(np.uint8).astype(np.float32)

    rows, cols = heatmap.shape[0], heatmap.shape[1]
    weight = (alpha * create_fundus_mask(rows, cols))[:, :, np.newaxis]

    base = original_np.astype(np.float32)
    mixed = base * (1 - weight) + heat_rgb * weight
    return np.clip(mixed, 0, 255).astype(np.uint8)
435
+
436
+
437
+ # ================================================================
438
+ # SELECT TEST IMAGES (same logic as gradcam_v3.py)
439
+ # ================================================================
440
def select_test_images(n_per_class=N_PER_CLASS):
    """
    Select up to ``n_per_class`` images per disease class from TEST_CSV.

    For each label in ``range(NUM_CLASSES)`` the first ``n_per_class``
    rows (after dropping duplicate ``image_path`` values) are taken.

    Returns:
        list of dicts with keys ``image_path``, ``true_label`` and
        ``dataset``. ``dataset`` is the row's ``source`` value, or 'auto'
        when the column is absent or the cell is empty, so that the loader
        falls back to path-based detection instead of receiving the
        meaningless string 'nan'.
    """
    df = pd.read_csv(TEST_CSV)
    samples = []
    for label in range(NUM_CLASSES):
        subset = df[df['disease_label'] == label].drop_duplicates(subset='image_path')
        for _, row in subset.head(n_per_class).iterrows():
            source = row.get('source', 'auto')
            samples.append({
                'image_path': row['image_path'],
                'true_label': int(row['disease_label']),
                'dataset': 'auto' if pd.isna(source) else str(source),
            })
    print(f' Selected {len(samples)} test images '
          f'({n_per_class} per class x {NUM_CLASSES} classes)')
    return samples
456
+
457
+
458
+ # ================================================================
459
+ # COMPUTE AGREEMENT METRICS
460
+ # ================================================================
461
def compute_agreement(rollout_map, ig_map, fundus_mask):
    """
    Compute spatial agreement between Attention Rollout and IG heatmaps.

    All three metrics are evaluated only on pixels where
    ``fundus_mask > 0.5``:
        - Pearson correlation (0.0 with p=1.0 when either map is constant
          or fewer than three pixels are available)
        - Intersection over Union (IoU) of each map's top-20% pixels,
          i.e. pixels strictly above that map's in-mask 80th percentile
        - Cosine similarity of the flattened masked vectors

    Pixels outside the mask never enter the IoU, so strong responses in
    the image corners cannot distort it.

    Args:
        rollout_map : 2-D float array, Attention Rollout heatmap
        ig_map      : 2-D float array, same shape, IG heatmap
        fundus_mask : 2-D float array, same shape

    Returns:
        dict with keys ``pearson_r``, ``pearson_p``, ``iou_top20`` and
        ``cosine_sim`` (rounded floats). If the mask selects no pixels,
        every score is 0.0 and ``pearson_p`` is 1.0.
    """
    inside = fundus_mask > 0.5
    r_flat = rollout_map[inside]
    i_flat = ig_map[inside]

    if r_flat.size == 0:
        return {
            'pearson_r': 0.0,
            'pearson_p': 1.0,
            'iou_top20': 0.0,
            'cosine_sim': 0.0,
        }

    # Pearson correlation (undefined for constant inputs -> report 0)
    if len(r_flat) > 2 and r_flat.std() > 1e-8 and i_flat.std() > 1e-8:
        pearson_r, pearson_p = pearsonr(r_flat, i_flat)
    else:
        pearson_r, pearson_p = 0.0, 1.0

    # IoU of top-20% regions, thresholded and compared inside the mask only
    r_top = r_flat > np.percentile(r_flat, 80)
    i_top = i_flat > np.percentile(i_flat, 80)
    intersection = np.logical_and(r_top, i_top).sum()
    union = np.logical_or(r_top, i_top).sum()
    iou = float(intersection / max(union, 1))

    # Cosine similarity
    r_norm = np.linalg.norm(r_flat)
    i_norm = np.linalg.norm(i_flat)
    if r_norm > 1e-8 and i_norm > 1e-8:
        cosine = float(np.dot(r_flat, i_flat) / (r_norm * i_norm))
    else:
        cosine = 0.0

    return {
        'pearson_r': round(float(pearson_r), 4),
        'pearson_p': round(float(pearson_p), 6),
        'iou_top20': round(iou, 4),
        'cosine_sim': round(cosine, 4),
    }
506
+
507
+
508
+ # ================================================================
509
+ # MAIN PIPELINE
510
+ # ================================================================
511
def main():
    """
    Run the full Rollout-vs-Integrated-Gradients comparison pipeline.

    Steps:
        1. Load the MultiTaskViT checkpoint from MODEL_PATH.
        2. Pick test images via ``select_test_images``.
        3. Preprocess every image and compute its Attention Rollout map.
        4. Remove the rollout hooks, then compute Integrated Gradients for
           the predicted class of each image and score the agreement.
        5. Save per-image IG figures, the comparison grid and the
           agreement heatmap into XAI_DIR.
        6. Save the agreement scores as JSON and print a summary.

    Images that fail in step 3 or 4 are reported and skipped. The process
    exits with status 1 if no image could be processed.
    """
    import time
    import traceback

    t_start = time.time()

    # ---- 1. Load model ----
    print('\n[1/6] Loading model...')
    model = MultiTaskViT().to(DEVICE)
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoints from a trusted source.
    ckpt = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    print(f' Loaded: {MODEL_PATH}')
    # The stored epoch is shown 1-based; fall back to '?' when the
    # checkpoint does not record one.
    epoch = ckpt.get('epoch')
    epoch_display = '?' if epoch is None else epoch + 1
    print(f' Checkpoint epoch: {epoch_display}')

    # ---- 2. Select images ----
    print('\n[2/6] Selecting test images...')
    samples = select_test_images()

    # ---- 3. Preprocess all images + run Attention Rollout ----
    # IMPORTANT: Run rollout FIRST then remove hooks BEFORE IG.
    # This prevents rollout hooks from firing during IG's many
    # forward passes (50 steps), which would be very slow on CPU.
    print('\n[3/6] Running Attention Rollout on all images...')
    # Remember each block's fused-attention setting so it can be restored
    # (rather than blindly forced on) once rollout is finished.
    attn_modules = [block.attn for block in model.backbone.blocks]
    fused_before = [getattr(attn, 'fused_attn', None) for attn in attn_modules]

    rollout = ViTAttentionRollout(model, discard_ratio=0.97)
    print(f' Attention Rollout: {len(rollout._hooks)} hooks registered')

    preprocessed = []     # (img_np, img_orig, img_tensor) or None, per sample
    rollout_results = []  # (heatmap, pred_label, confidence) or None, per sample

    for idx, sample in enumerate(samples):
        img_path = sample['image_path']
        true_label = sample['true_label']
        dataset = sample['dataset']
        basename = os.path.basename(img_path)

        print(f' [{idx+1:2d}/{len(samples)}] '
              f'{CLASS_NAMES[true_label]:15s} | {basename}', end=' ')

        # Exactly one entry is appended to each list per sample, whatever
        # fails, so both lists stay index-aligned with `samples`.
        entry, result = None, None
        try:
            img_np, img_orig = load_and_preprocess(img_path, dataset=dataset)
            img_tensor = preprocess_to_tensor(img_np).to(DEVICE)
            heatmap, pred_label, pred_conf = rollout.generate(img_tensor)
            entry = (img_np, img_orig, img_tensor)
            result = (heatmap, pred_label, pred_conf)
            print(f'-> pred={CLASS_NAMES[pred_label]:15s} conf={pred_conf:.3f}')
        except Exception as e:
            print(f'FAILED: {e}')
            entry, result = None, None
        preprocessed.append(entry)
        rollout_results.append(result)

    # Remove rollout hooks BEFORE running IG to avoid extra computation
    rollout.remove_hooks()
    # Put fused attention back the way the backbone had it, for faster
    # forward passes during IG where the runtime supports it.
    for attn, was_fused in zip(attn_modules, fused_before):
        if was_fused is not None:
            attn.fused_attn = was_fused
    print(' Rollout hooks removed. fused_attn re-enabled for IG speed.')

    # ---- 4. Run Integrated Gradients (no rollout hooks active) ----
    print('\n[4/6] Computing Integrated Gradients attributions...')
    disease_model = DiseaseLogitModel(model)
    disease_model.eval()
    ig_method = IntegratedGradients(disease_model)
    # Single source of truth for the IG settings that are printed and used.
    ig_sigma, ig_steps, ig_batch = 10, 50, 25
    print(f' Baseline: Gaussian blur (sigma={ig_sigma}), n_steps={ig_steps}, '
          f'internal_batch_size={ig_batch}')

    all_results = []
    fundus_mask = create_fundus_mask()

    for idx, sample in enumerate(samples):
        if preprocessed[idx] is None or rollout_results[idx] is None:
            continue

        img_path = sample['image_path']
        true_label = sample['true_label']
        basename = os.path.basename(img_path)
        _img_np, img_orig, img_tensor = preprocessed[idx]
        rollout_heatmap, pred_label, pred_conf = rollout_results[idx]

        print(f' [{idx+1:2d}/{len(samples)}] '
              f'{CLASS_NAMES[true_label]:15s} | {basename}', end=' ')

        try:
            # Fresh tensor (IG needs requires_grad)
            ig_input = img_tensor.clone().detach().to(DEVICE)

            ig_heatmap = compute_ig_attribution(
                disease_model, ig_method, ig_input,
                target_class=pred_label,
                n_steps=ig_steps,
                internal_batch_size=ig_batch,
                sigma=ig_sigma,
            )

            # Agreement
            agreement = compute_agreement(rollout_heatmap, ig_heatmap,
                                          fundus_mask)

            print(f'-> pearson={agreement["pearson_r"]:.3f} '
                  f'IoU={agreement["iou_top20"]:.3f}')

            all_results.append({
                'idx': idx,
                'image_path': img_path,
                'basename': basename,
                'true_label': true_label,
                'pred_label': pred_label,
                'pred_class': CLASS_NAMES[pred_label],
                'confidence': round(pred_conf, 4),
                'img_orig': img_orig,
                'rollout_heatmap': rollout_heatmap,
                'ig_heatmap': ig_heatmap,
                'agreement': agreement,
            })

        except Exception as e:
            print(f'FAILED: {e}')
            traceback.print_exc()

    n_success = len(all_results)
    print(f'\n Completed: {n_success}/{len(samples)} images')

    if n_success == 0:
        print('ERROR: No images processed successfully. Exiting.')
        sys.exit(1)

    # ---- 5. Generate visualisations ----
    print('\n[5/6] Generating visualisations...')

    # 5a. Individual IG heatmaps
    print(' Saving individual IG heatmaps...')
    for r in all_results:
        fig, axes = plt.subplots(1, 3, figsize=(13, 4.5))

        # Original
        axes[0].imshow(r['img_orig'])
        axes[0].set_title(f'Original\nTrue: {CLASS_NAMES[r["true_label"]]}',
                          fontsize=10, fontweight='bold')
        axes[0].axis('off')

        # IG heatmap (raw)
        axes[1].imshow(r['ig_heatmap'], cmap='inferno', vmin=0, vmax=1)
        axes[1].set_title('Integrated Gradients\n(attribution magnitude)',
                          fontsize=10)
        axes[1].axis('off')

        # IG overlay on original
        ig_overlay = overlay_heatmap(r['img_orig'], r['ig_heatmap'],
                                     alpha=0.6, cmap_name='inferno')
        axes[2].imshow(ig_overlay)
        correct = r['pred_label'] == r['true_label']
        status = 'OK' if correct else 'WRONG'
        axes[2].set_title(
            f'IG Overlay\nPred: {r["pred_class"]} ({r["confidence"]:.2f}) [{status}]',
            fontsize=10,
            color='green' if correct else 'red',
            fontweight='bold')
        axes[2].axis('off')

        plt.tight_layout()
        save_path = os.path.join(XAI_DIR,
                                 f'ig_individual_{r["idx"]+1:02d}.png')
        fig.savefig(save_path, dpi=150, bbox_inches='tight',
                    facecolor='white')
        plt.close(fig)

    print(f' Saved {n_success} individual IG heatmaps')

    # 5b. Comparison grid: one row per processed image x 3 columns
    print(' Generating comparison grid...')
    n_rows = n_success
    fig, axes = plt.subplots(n_rows, 3, figsize=(14, 4.2 * n_rows))

    # With a single row, subplots returns a 1-D array; make it 2-D.
    if n_rows == 1:
        axes = axes[np.newaxis, :]

    # Column headers
    col_titles = ['Original Image', 'Attention Rollout', 'Integrated Gradients']

    for i, r in enumerate(all_results):
        true_name = CLASS_NAMES[r['true_label']]
        pred_name = r['pred_class']
        correct = r['pred_label'] == r['true_label']
        status = 'OK' if correct else 'WRONG'

        # Column 0: Original
        axes[i, 0].imshow(r['img_orig'])
        axes[i, 0].set_ylabel(f'#{r["idx"]+1}\nTrue: {true_name}',
                              fontsize=9, fontweight='bold', rotation=0,
                              labelpad=70, va='center')
        axes[i, 0].set_xticks([])
        axes[i, 0].set_yticks([])
        if i == 0:
            axes[i, 0].set_title(col_titles[0], fontsize=12, fontweight='bold',
                                 pad=10)

        # Column 1: Attention Rollout overlay
        rollout_overlay = overlay_heatmap(r['img_orig'], r['rollout_heatmap'],
                                          alpha=0.6, cmap_name='inferno')
        axes[i, 1].imshow(rollout_overlay)
        axes[i, 1].axis('off')
        if i == 0:
            axes[i, 1].set_title(col_titles[1], fontsize=12, fontweight='bold',
                                 pad=10)

        # Column 2: IG overlay
        ig_overlay = overlay_heatmap(r['img_orig'], r['ig_heatmap'],
                                     alpha=0.6, cmap_name='inferno')
        axes[i, 2].imshow(ig_overlay)
        color = 'green' if correct else 'red'
        axes[i, 2].set_xlabel(
            f'Pred: {pred_name} ({r["confidence"]:.2f}) [{status}] | '
            f'Pearson r={r["agreement"]["pearson_r"]:.2f}',
            fontsize=8, color=color, fontweight='bold')
        axes[i, 2].set_xticks([])
        axes[i, 2].set_yticks([])
        if i == 0:
            axes[i, 2].set_title(col_titles[2], fontsize=12, fontweight='bold',
                                 pad=10)

    plt.suptitle('RetinaSense v3.0 -- Attention Rollout vs Integrated Gradients',
                 fontsize=14, fontweight='bold', y=1.001)
    plt.tight_layout()
    grid_path = os.path.join(XAI_DIR, 'comparison_grid.png')
    fig.savefig(grid_path, dpi=120, bbox_inches='tight', facecolor='white')
    plt.close(fig)
    print(f' Saved: {grid_path}')

    # 5c. Agreement heatmap (matrix showing per-image spatial correlation)
    print(' Generating agreement heatmap...')

    # Build per-image metrics matrix: rows = images, columns = metrics
    image_labels = [
        f'#{r["idx"]+1} {CLASS_NAMES[r["true_label"]][:6]}'
        for r in all_results
    ]
    metric_names = ['Pearson r', 'IoU (top 20%)', 'Cosine Sim']
    agreement_matrix = np.zeros((n_success, 3))
    for i, r in enumerate(all_results):
        agreement_matrix[i, 0] = r['agreement']['pearson_r']
        agreement_matrix[i, 1] = r['agreement']['iou_top20']
        agreement_matrix[i, 2] = r['agreement']['cosine_sim']

    fig, ax = plt.subplots(figsize=(7, max(8, n_success * 0.45)))
    im = ax.imshow(agreement_matrix, cmap='RdYlGn', aspect='auto',
                   vmin=-0.2, vmax=1.0)

    ax.set_xticks(range(3))
    ax.set_xticklabels(metric_names, fontsize=10, fontweight='bold')
    ax.set_yticks(range(n_success))
    ax.set_yticklabels(image_labels, fontsize=8)

    # Annotate cells
    for i in range(n_success):
        for j in range(3):
            val = agreement_matrix[i, j]
            color = 'white' if val < 0.3 else 'black'
            ax.text(j, i, f'{val:.2f}', ha='center', va='center',
                    fontsize=8, color=color, fontweight='bold')

    ax.set_title('Rollout vs IG Agreement Scores per Image',
                 fontsize=12, fontweight='bold', pad=12)
    plt.colorbar(im, ax=ax, shrink=0.6, label='Score')
    plt.tight_layout()

    heatmap_path = os.path.join(XAI_DIR, 'agreement_heatmap.png')
    fig.savefig(heatmap_path, dpi=150, bbox_inches='tight', facecolor='white')
    plt.close(fig)
    print(f' Saved: {heatmap_path}')

    # ---- 6. Save agreement scores JSON ----
    print('\n[6/6] Saving agreement scores...')

    scores_output = {
        'summary': {
            'n_images': n_success,
            'mean_pearson_r': round(float(agreement_matrix[:, 0].mean()), 4),
            'mean_iou_top20': round(float(agreement_matrix[:, 1].mean()), 4),
            'mean_cosine_sim': round(float(agreement_matrix[:, 2].mean()), 4),
            'std_pearson_r': round(float(agreement_matrix[:, 0].std()), 4),
            'std_iou_top20': round(float(agreement_matrix[:, 1].std()), 4),
            'std_cosine_sim': round(float(agreement_matrix[:, 2].std()), 4),
        },
        'per_image': [],
    }
    for r in all_results:
        scores_output['per_image'].append({
            'image': r['basename'],
            'true_label': r['true_label'],
            'true_class': CLASS_NAMES[r['true_label']],
            'pred_label': r['pred_label'],
            'pred_class': r['pred_class'],
            'confidence': r['confidence'],
            'agreement': r['agreement'],
        })

    # Per-class summary (grouped by TRUE label; empty classes are omitted)
    per_class = {}
    for cls_idx in range(NUM_CLASSES):
        cls_results = [r for r in all_results if r['true_label'] == cls_idx]
        if cls_results:
            pearson_vals = [r['agreement']['pearson_r'] for r in cls_results]
            iou_vals = [r['agreement']['iou_top20'] for r in cls_results]
            cosine_vals = [r['agreement']['cosine_sim'] for r in cls_results]
            per_class[CLASS_NAMES[cls_idx]] = {
                'n_images': len(cls_results),
                'mean_pearson_r': round(float(np.mean(pearson_vals)), 4),
                'mean_iou_top20': round(float(np.mean(iou_vals)), 4),
                'mean_cosine_sim': round(float(np.mean(cosine_vals)), 4),
            }
    scores_output['per_class'] = per_class

    json_path = os.path.join(XAI_DIR, 'agreement_score.json')
    with open(json_path, 'w') as f:
        json.dump(scores_output, f, indent=2)
    print(f' Saved: {json_path}')

    # ---- Summary ----
    elapsed = time.time() - t_start
    n_correct = sum(1 for r in all_results
                    if r['pred_label'] == r['true_label'])

    print('\n' + '=' * 65)
    print(' PHASE 1C COMPLETE: Integrated Gradients XAI')
    print('=' * 65)
    print(f' Images processed : {n_success}/{len(samples)}')
    print(f' Correct preds : {n_correct}/{n_success} '
          f'({100*n_correct/max(n_success,1):.1f}%)')
    print(f' Mean Pearson r : {scores_output["summary"]["mean_pearson_r"]:.4f}')
    print(f' Mean IoU (top 20%): {scores_output["summary"]["mean_iou_top20"]:.4f}')
    print(f' Mean Cosine Sim : {scores_output["summary"]["mean_cosine_sim"]:.4f}')
    print(f' Time elapsed : {elapsed:.1f}s')
    print(f' Outputs in : {XAI_DIR}')
    print('=' * 65)

    # List outputs
    print('\n Output files:')
    for fname in sorted(os.listdir(XAI_DIR)):
        fpath = os.path.join(XAI_DIR, fname)
        size_kb = os.path.getsize(fpath) / 1024
        print(f' {fname:40s} {size_kb:8.1f} KB')
855
+
856
+
857
+ if __name__ == '__main__':
858
+ main()