tanishq74 committed on
Commit
6193b31
·
verified ·
1 Parent(s): 0d9c418

Add retfound_backbone.py

Browse files
Files changed (1) hide show
  1. retfound_backbone.py +459 -0
retfound_backbone.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ RETFound Backbone for RetinaSense
4
+ ==================================
5
+ Provides a RETFound-based backbone (ViT-Base/16 pretrained via MAE on
6
+ 1.6 million retinal images) as a drop-in replacement for the ImageNet-
7
+ pretrained ViT used in retinasense_v3.py.
8
+
9
+ RETFound (Retinal Foundation Model) was pretrained with masked
10
+ autoencoders on a large corpus of colour fundus photographs and OCT
11
+ scans. Using it as the backbone gives the model domain-specific
12
+ features (vessel topology, optic-disc morphology, drusen texture) that
13
+ ImageNet weights cannot provide.
14
+
15
+ Weight download
16
+ ---------------
17
+ Colour-fundus-photo weights are hosted on Hugging Face:
18
+ Repository : rmaphoh/RETFound_MAE
19
+ File : RETFound_cfp_weights.pth
20
+
21
+ You can download them in one of two ways:
22
+
23
+ 1. Programmatic (recommended):
24
+ from retfound_backbone import setup_retfound
25
+ path = setup_retfound() # downloads ~350 MB on first call
26
+
27
+ 2. Manual:
28
+ pip install huggingface_hub
29
+ huggingface-cli download rmaphoh/RETFound_MAE RETFound_cfp_weights.pth \\
30
+ --local-dir ./weights
31
+
32
+ Reference
33
+ ---------
34
+ Zhou et al., "A foundation model for generalizable disease detection
35
+ from retinal images", Nature 2023.
36
+ https://github.com/rmaphoh/RETFound_MAE
37
+
38
+ Usage with the training pipeline
39
+ ---------------------------------
40
+ from retfound_backbone import MultiTaskRetFound, setup_retfound
41
+
42
+ weights_path = setup_retfound() # or pass your own path
43
+ model = MultiTaskRetFound(pretrained_path=weights_path).to(device)
44
+
45
+ The model exposes the same (disease_logits, severity_logits) forward
46
+ interface as MultiTaskViT in retinasense_v3.py, so the training loop,
47
+ LLRD optimiser, and evaluation code work without modification.
48
+ """
49
+
50
+ import os
51
+ import re
52
+ import logging
53
+ from collections import OrderedDict
54
+
55
+ import torch
56
+ import torch.nn as nn
57
+ import timm
58
+
59
+ logger = logging.getLogger(__name__)
60
+
61
+ # ---------------------------------------------------------------------------
62
+ # Constants
63
+ # ---------------------------------------------------------------------------
64
# Hugging Face repository id that setup_retfound() downloads from.
_HF_REPO = "rmaphoh/RETFound_MAE"
# Default checkpoint file name requested from that repository.
_HF_FILE = "RETFound_cfp_weights.pth"
# Default download target: a "weights" directory next to this module.
_DEFAULT_WEIGHTS_DIR = os.path.join(os.path.dirname(__file__), "weights")
# Input width of the classification heads built in MultiTaskRetFound.
_VIT_EMBED_DIM = 768 # ViT-Base CLS token dimension
68
+
69
+
70
+ # ===================================================================
71
+ # Weight download helper
72
+ # ===================================================================
73
def setup_retfound(
    save_dir: str = _DEFAULT_WEIGHTS_DIR,
    filename: str = _HF_FILE,
) -> str:
    """Return a local path to the RETFound weights, downloading if needed.

    If ``save_dir/filename`` already exists the function returns
    immediately, without touching the network and without requiring
    ``huggingface_hub`` to be installed.  Otherwise the file is fetched
    from the ``_HF_REPO`` repository with
    ``huggingface_hub.hf_hub_download``.

    Parameters
    ----------
    save_dir : str
        Local directory to store the weight file.  Created if absent
        (only when a download is actually required).
    filename : str
        Name of the weight file on Hugging Face.

    Returns
    -------
    str
        Absolute path to the ``.pth`` file.

    Raises
    ------
    ImportError
        If a download is required and ``huggingface_hub`` is not
        installed.
    """
    # Normalise up front so both return paths honour the documented
    # "absolute path" contract even when save_dir is relative.
    local_path = os.path.abspath(os.path.join(save_dir, filename))

    # Check the local copy *before* importing huggingface_hub: weights
    # that were downloaded manually must be usable without the optional
    # dependency.
    if os.path.isfile(local_path):
        logger.info("RETFound weights already present at %s", local_path)
        return local_path

    try:
        from huggingface_hub import hf_hub_download
    except ImportError as err:
        raise ImportError(
            "huggingface_hub is required to download RETFound weights.\n"
            "Install it with: pip install huggingface_hub"
        ) from err

    os.makedirs(save_dir, exist_ok=True)

    logger.info(
        "Downloading RETFound weights from %s/%s ...", _HF_REPO, filename
    )
    # NOTE(review): ``local_dir_use_symlinks`` is reported as deprecated in
    # recent huggingface_hub releases -- confirm against the pinned
    # version and drop the argument once older releases are unsupported.
    downloaded = hf_hub_download(
        repo_id=_HF_REPO,
        filename=filename,
        local_dir=save_dir,
        local_dir_use_symlinks=False,
    )
    downloaded = os.path.abspath(downloaded)
    logger.info("RETFound weights saved to %s", downloaded)
    return downloaded
125
+
126
+
127
+ # ===================================================================
128
+ # Key-mapping helpers (RETFound MAE checkpoint -> timm ViT)
129
+ # ===================================================================
130
+ def _map_retfound_keys(mae_state_dict: dict) -> OrderedDict:
131
+ """Translate a RETFound MAE encoder checkpoint into timm ViT keys.
132
+
133
+ RETFound saves its encoder weights under a ``'model'`` (or
134
+ ``'model_state_dict'``) top-level key. Inside, the naming
135
+ convention differs from timm's ``VisionTransformer``:
136
+
137
+ RETFound key pattern timm key pattern
138
+ -------------------------------- --------------------------------
139
+ encoder.patch_embed.* patch_embed.*
140
+ encoder.cls_token cls_token
141
+ encoder.pos_embed pos_embed
142
+ encoder.blocks.{i}.* blocks.{i}.*
143
+ encoder.norm.* norm.*
144
+ fc_norm.* (skipped -- MAE head norm)
145
+ decoder_* (skipped -- MAE decoder)
146
+ mask_token (skipped)
147
+
148
+ Some RETFound releases omit the ``encoder.`` prefix; both forms
149
+ are handled.
150
+
151
+ Parameters
152
+ ----------
153
+ mae_state_dict : dict
154
+ The raw state dict loaded from the ``.pth`` file, *after*
155
+ extracting the ``'model'`` sub-key if present.
156
+
157
+ Returns
158
+ -------
159
+ OrderedDict
160
+ State dict with keys compatible with
161
+ ``timm.create_model('vit_base_patch16_224', num_classes=0)``.
162
+ """
163
+ mapped = OrderedDict()
164
+
165
+ # Patterns to skip (decoder weights, mask token, MAE head norms)
166
+ _skip_prefixes = ("decoder", "mask_token", "fc_norm", "head")
167
+
168
+ for key, value in mae_state_dict.items():
169
+ # Skip decoder / MAE-head parameters
170
+ if any(key.startswith(p) for p in _skip_prefixes):
171
+ continue
172
+
173
+ new_key = key
174
+
175
+ # Strip 'encoder.' prefix if present
176
+ if new_key.startswith("encoder."):
177
+ new_key = new_key[len("encoder."):]
178
+
179
+ # Some checkpoints store blocks as 'encoder.blocks.N.*' which
180
+ # after stripping becomes 'blocks.N.*' -- already correct for timm.
181
+
182
+ # RETFound sometimes names the final LayerNorm 'norm.' which
183
+ # matches timm, but occasionally uses 'ln_pre' or 'encoder_norm'.
184
+ new_key = re.sub(r"^encoder_norm\.", "norm.", new_key)
185
+ new_key = re.sub(r"^ln_pre\.", "norm.", new_key)
186
+
187
+ mapped[new_key] = value
188
+
189
+ return mapped
190
+
191
+
192
+ # ===================================================================
193
+ # Backbone factory
194
+ # ===================================================================
195
def create_retfound_model(pretrained_path: str = None) -> nn.Module:
    """Create a ViT-Base/16 backbone, optionally with RETFound weights.

    Parameters
    ----------
    pretrained_path : str or None
        Path to a RETFound checkpoint (e.g. ``RETFound_cfp_weights.pth``).
        When *None*, timm is asked for its own pretrained weights for
        ``vit_base_patch16_224`` instead.

    Returns
    -------
    nn.Module
        The model returned by
        ``timm.create_model("vit_base_patch16_224", num_classes=0)``,
        i.e. without a classification head.

    Raises
    ------
    FileNotFoundError
        If ``pretrained_path`` is given but does not point to a file.
    """
    # Same timm architecture name regardless of weight source; timm's own
    # pretrained weights are only requested when no checkpoint is given.
    backbone = timm.create_model(
        "vit_base_patch16_224",
        pretrained=(pretrained_path is None),  # timm-pretrained fallback
        num_classes=0,
    )

    if pretrained_path is None:
        return backbone

    if not os.path.isfile(pretrained_path):
        raise FileNotFoundError(
            f"RETFound weights not found at {pretrained_path}. "
            f"Run setup_retfound() or download manually from "
            f"https://huggingface.co/{_HF_REPO}"
        )

    logger.info("Loading RETFound weights from %s", pretrained_path)
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code.  Only load checkpoints from a
    # trusted source.  NOTE(review): switch to weights_only=True if the
    # checkpoint is confirmed to contain nothing but tensors.
    raw_ckpt = torch.load(pretrained_path, map_location="cpu", weights_only=False)

    # The state dict may be wrapped under one of these keys; otherwise the
    # file is assumed to *be* the state dict.
    raw_sd = raw_ckpt
    for wrapper_key in ("model", "model_state_dict", "state_dict"):
        if wrapper_key in raw_ckpt:
            raw_sd = raw_ckpt[wrapper_key]
            break

    mapped_sd = _map_retfound_keys(raw_sd)

    # strict=False tolerates missing / extra *keys* only.  A tensor whose
    # shape differs from the backbone's still raises RuntimeError.
    # NOTE(review): confirm the checkpoint's shapes match ViT-Base
    # (embedding dim 768, 12 blocks).
    missing, unexpected = backbone.load_state_dict(mapped_sd, strict=False)

    # Head keys are never loaded from the checkpoint; keep them out of the
    # warning so real gaps stand out.
    expected_missing = {"head.weight", "head.bias"}
    real_missing = [k for k in missing if k not in expected_missing]

    if real_missing:
        logger.warning(
            "Keys in timm model but NOT in RETFound checkpoint (%d): %s",
            len(real_missing),
            real_missing[:10],
        )
    if unexpected:
        logger.warning(
            "Unexpected keys from RETFound checkpoint (%d): %s",
            len(unexpected),
            unexpected[:10],
        )

    n_loaded = len(mapped_sd) - len(unexpected)
    if n_loaded <= 0:
        # pretrained=False was used above, so nothing loaded means the
        # backbone still holds its random initialisation.
        logger.warning(
            "No RETFound checkpoint keys matched the timm backbone; "
            "the backbone is still randomly initialised"
        )

    # real_missing holds the keys that were NOT expected to be missing,
    # so the summary must not label them "expected".
    logger.info(
        "RETFound backbone loaded: %d parameters mapped, "
        "%d unexpectedly missing, %d unexpected (skipped)",
        n_loaded,
        len(real_missing),
        len(unexpected),
    )

    return backbone
277
+
278
+
279
+ # ===================================================================
280
+ # Multi-task model with RETFound backbone
281
+ # ===================================================================
282
class MultiTaskRetFound(nn.Module):
    """Two-headed classifier on top of the backbone from ``create_retfound_model``.

    A single backbone embedding is passed through a shared dropout and
    then fed to two independent MLP heads: one for disease
    classification and one for severity grading.  The layout is intended
    to mirror ``MultiTaskViT`` from ``retinasense_v3.py`` so existing
    training and evaluation code can be reused.

    Parameters
    ----------
    n_disease : int
        Output width of the disease head (default 5: Normal, DR,
        Glaucoma, Cataract, AMD).
    n_severity : int
        Output width of the severity head (default 5: DR grades 0-4 on
        the APTOS scale).
    drop : float
        Probability of the shared dropout applied to the backbone
        embedding before both heads.
    pretrained_path : str or None
        Forwarded to ``create_retfound_model``; *None* selects timm's
        own pretrained weights.
    """

    def __init__(
        self,
        n_disease: int = 5,
        n_severity: int = 5,
        drop: float = 0.3,
        pretrained_path: str = None,
    ):
        super().__init__()

        # Feature extractor.
        self.backbone = create_retfound_model(pretrained_path=pretrained_path)
        embed_dim = _VIT_EMBED_DIM

        # Dropout shared by both heads.
        self.drop = nn.Dropout(drop)

        # Disease head: embed_dim -> 512 -> 256 -> n_disease.
        disease_layers = [
            nn.Linear(embed_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, n_disease),
        ]
        self.disease_head = nn.Sequential(*disease_layers)

        # Severity head: embed_dim -> 256 -> n_severity.
        severity_layers = [
            nn.Linear(embed_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, n_severity),
        ]
        self.severity_head = nn.Sequential(*severity_layers)

    def forward(self, x: torch.Tensor):
        """Run the backbone once and evaluate both heads.

        Parameters
        ----------
        x : torch.Tensor
            Batch of images; the original documentation gives the shape
            as ``(B, 3, 224, 224)``.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            ``(disease_logits, severity_logits)``.
        """
        embedding = self.drop(self.backbone(x))
        disease_logits = self.disease_head(embedding)
        severity_logits = self.severity_head(embedding)
        return disease_logits, severity_logits
360
+
361
+
362
+ # ===================================================================
363
+ # LLRD helper (convenience re-export so callers do not have to touch
364
+ # retinasense_v3.py internals)
365
+ # ===================================================================
366
def get_retfound_optimizer_with_llrd(
    model: MultiTaskRetFound,
    base_lr: float = 3e-4,
    decay_factor: float = 0.75,
    weight_decay: float = 1e-4,
) -> torch.optim.AdamW:
    """Create an AdamW optimiser whose learning rate decays with depth.

    Parameter groups are built in this order:

    - heads (``disease_head``, ``severity_head``, ``drop``): ``base_lr``
    - backbone blocks, last to first: block ``i`` of ``N`` gets
      ``base_lr * decay_factor ** (N - i)``
    - ``patch_embed``, ``cls_token``, ``pos_embed`` and the backbone's
      ``norm``: ``base_lr * decay_factor ** (N + 1)``

    Parameters
    ----------
    model : MultiTaskRetFound
        Model providing ``backbone``, ``disease_head``,
        ``severity_head`` and ``drop``.
    base_lr : float
        Learning rate of the head group (the largest one).
    decay_factor : float
        Per-block multiplicative decay.
    weight_decay : float
        Weight-decay coefficient passed to ``AdamW`` for every group.

    Returns
    -------
    torch.optim.AdamW
    """
    vit = model.backbone
    depth = len(vit.blocks)

    # Heads train at the full rate.
    head_group = {
        "params": [
            *model.disease_head.parameters(),
            *model.severity_head.parameters(),
            *model.drop.parameters(),
        ],
        "lr": base_lr,
    }

    # Walk the blocks from the output side inwards; the exponent is the
    # distance from the heads (1 for the last block, depth for block 0).
    block_groups = [
        {
            "params": list(vit.blocks[idx].parameters()),
            "lr": base_lr * (decay_factor ** (depth - idx)),
        }
        for idx in reversed(range(depth))
    ]

    # Embeddings and the final norm share the smallest rate.
    embed_group = {
        "params": [
            *vit.patch_embed.parameters(),
            vit.cls_token,
            vit.pos_embed,
            *vit.norm.parameters(),
        ],
        "lr": base_lr * (decay_factor ** (depth + 1)),
    }

    param_groups = [head_group, *block_groups, embed_group]
    optimizer = torch.optim.AdamW(param_groups, weight_decay=weight_decay)

    # Summarise the spread of learning rates across the groups.
    rates = [group["lr"] for group in param_groups]
    logger.info(
        "LLRD optimizer: %d groups | Head %.2e | Block[11] %.2e | "
        "Block[0] %.2e | Embed %.2e",
        len(param_groups), rates[0], rates[1], rates[-2], rates[-1],
    )

    return optimizer
435
+
436
+
437
+ # ===================================================================
438
+ # Quick sanity check
439
+ # ===================================================================
440
def _run_sanity_check() -> None:
    """Build the model with fallback weights and print basic diagnostics.

    Runs one forward pass on a random batch, reports the output shapes
    and total parameter count, and builds the LLRD optimiser to report
    its number of parameter groups.
    """
    logging.basicConfig(level=logging.INFO)

    print("Creating MultiTaskRetFound with ImageNet fallback weights ...")
    net = MultiTaskRetFound(pretrained_path=None)
    batch = torch.randn(2, 3, 224, 224)
    disease_out, severity_out = net(batch)
    print(f" disease_logits : {disease_out.shape}") # (2, 5)
    print(f" severity_logits : {severity_out.shape}") # (2, 5)

    n_params = sum(p.numel() for p in net.parameters())
    print(f" Total params : {n_params:,}")

    optimizer = get_retfound_optimizer_with_llrd(net)
    print(f" Optimizer groups: {len(optimizer.param_groups)}")

    print("\nTo load RETFound weights instead:")
    print(" path = setup_retfound()")
    print(" model = MultiTaskRetFound(pretrained_path=path)")
    print("\nDone.")


if __name__ == "__main__":
    _run_sanity_check()