Zorrojurro commited on
Commit
2749a58
Β·
verified Β·
1 Parent(s): 303959d

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +490 -0
model.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hybrid CNN-ViT Model for Brain Tumor Classification.
3
+
4
+ Self-contained module bundling CNN backbone, Vision Transformer,
5
+ Radiomics, and Feature Fusion into a single file for Hugging Face deployment.
6
+
7
+ Architecture:
8
+ 1. ResNet50 CNN backbone β†’ local texture/shape features
9
+ 2. Vision Transformer encoder β†’ global context via self-attention
10
+ 3. Learnable Radiomics branch β†’ texture + shape features
11
+ 4. Feature Fusion β†’ concatenation + MLP projection
12
+ 5. Classification Head β†’ 4-class tumor prediction
13
+
14
+ Author: Vishnu K (ZorroJurro)
15
+ """
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import torchvision.models as models
21
+ from typing import Dict, List, Optional, Tuple
22
+ from einops import rearrange, repeat
23
+
24
+
25
+ # =============================================================================
26
+ # 1. CNN Backbone
27
+ # =============================================================================
28
+
29
+ class CNNBackbone(nn.Module):
30
+ """ResNet50 backbone for local feature extraction from brain MRI."""
31
+
32
+ def __init__(
33
+ self,
34
+ backbone_name: str = "resnet50",
35
+ pretrained: bool = True,
36
+ output_features: bool = True,
37
+ ):
38
+ super().__init__()
39
+ self.backbone_name = backbone_name.lower()
40
+ self.output_features = output_features
41
+
42
+ resnet_configs = {
43
+ "resnet18": (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1, 512),
44
+ "resnet34": (models.resnet34, models.ResNet34_Weights.IMAGENET1K_V1, 512),
45
+ "resnet50": (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V2, 2048),
46
+ "resnet101": (models.resnet101, models.ResNet101_Weights.IMAGENET1K_V2, 2048),
47
+ }
48
+
49
+ if self.backbone_name not in resnet_configs:
50
+ raise ValueError(f"Unsupported backbone: {self.backbone_name}")
51
+
52
+ model_fn, weights, self.num_features = resnet_configs[self.backbone_name]
53
+ model = model_fn(weights=weights if pretrained else None)
54
+
55
+ # Remove final avg pool and fc to get feature maps
56
+ self.backbone = nn.Sequential(*list(model.children())[:-2])
57
+
58
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
59
+ return self.backbone(x)
60
+
61
+
62
+ # =============================================================================
63
+ # 2. Vision Transformer Components
64
+ # =============================================================================
65
+
66
+ class PatchEmbedding(nn.Module):
67
+ """Convert CNN feature maps to patch embeddings for ViT."""
68
+
69
+ def __init__(
70
+ self,
71
+ feature_size: int = 7,
72
+ feature_dim: int = 2048,
73
+ embed_dim: int = 512,
74
+ patch_size: int = 1,
75
+ ):
76
+ super().__init__()
77
+ self.feature_size = feature_size
78
+ self.patch_size = patch_size
79
+ self.num_patches = (feature_size // patch_size) ** 2
80
+
81
+ if patch_size == 1:
82
+ self.projection = nn.Linear(feature_dim, embed_dim)
83
+ else:
84
+ self.projection = nn.Conv2d(
85
+ feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size
86
+ )
87
+
88
+ self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
89
+ self.pos_embedding = nn.Parameter(
90
+ torch.randn(1, self.num_patches + 1, embed_dim) * 0.02
91
+ )
92
+ self.embed_dim = embed_dim
93
+
94
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
95
+ B = x.shape[0]
96
+
97
+ if self.patch_size == 1:
98
+ x = rearrange(x, "b c h w -> b (h w) c")
99
+ x = self.projection(x)
100
+ else:
101
+ x = self.projection(x)
102
+ x = rearrange(x, "b c h w -> b (h w) c")
103
+
104
+ cls_tokens = repeat(self.cls_token, "1 1 d -> b 1 d", b=B)
105
+ x = torch.cat([cls_tokens, x], dim=1)
106
+ x = x + self.pos_embedding[:, : x.size(1)]
107
+ return x
108
+
109
+
110
+ class MultiHeadSelfAttention(nn.Module):
111
+ """Multi-Head Self-Attention for Vision Transformer."""
112
+
113
+ def __init__(
114
+ self,
115
+ embed_dim: int = 512,
116
+ num_heads: int = 8,
117
+ dropout: float = 0.1,
118
+ attention_dropout: float = 0.1,
119
+ ):
120
+ super().__init__()
121
+ assert embed_dim % num_heads == 0
122
+ self.num_heads = num_heads
123
+ self.head_dim = embed_dim // num_heads
124
+ self.scale = self.head_dim ** -0.5
125
+
126
+ self.qkv = nn.Linear(embed_dim, embed_dim * 3)
127
+ self.attn_dropout = nn.Dropout(attention_dropout)
128
+ self.proj = nn.Linear(embed_dim, embed_dim)
129
+ self.proj_dropout = nn.Dropout(dropout)
130
+ self.attention_weights = None
131
+
132
+ def forward(
133
+ self, x: torch.Tensor, return_attention: bool = False
134
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
135
+ B, N, D = x.shape
136
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
137
+ qkv = qkv.permute(2, 0, 3, 1, 4)
138
+ q, k, v = qkv[0], qkv[1], qkv[2]
139
+
140
+ attn = (q @ k.transpose(-2, -1)) * self.scale
141
+ attn = attn.softmax(dim=-1)
142
+ attn = self.attn_dropout(attn)
143
+ self.attention_weights = attn.detach()
144
+
145
+ x = (attn @ v).transpose(1, 2).reshape(B, N, D)
146
+ x = self.proj(x)
147
+ x = self.proj_dropout(x)
148
+
149
+ if return_attention:
150
+ return x, attn
151
+ return x, None
152
+
153
+
154
+ class TransformerBlock(nn.Module):
155
+ """Transformer encoder block: MHSA + FFN with residual connections."""
156
+
157
+ def __init__(
158
+ self,
159
+ embed_dim: int = 512,
160
+ num_heads: int = 8,
161
+ mlp_ratio: float = 4.0,
162
+ dropout: float = 0.1,
163
+ attention_dropout: float = 0.1,
164
+ ):
165
+ super().__init__()
166
+ self.norm1 = nn.LayerNorm(embed_dim)
167
+ self.attn = MultiHeadSelfAttention(
168
+ embed_dim, num_heads, dropout, attention_dropout
169
+ )
170
+ self.norm2 = nn.LayerNorm(embed_dim)
171
+ mlp_hidden = int(embed_dim * mlp_ratio)
172
+ self.mlp = nn.Sequential(
173
+ nn.Linear(embed_dim, mlp_hidden),
174
+ nn.GELU(),
175
+ nn.Dropout(dropout),
176
+ nn.Linear(mlp_hidden, embed_dim),
177
+ nn.Dropout(dropout),
178
+ )
179
+ self.attention_weights = None
180
+
181
+ def forward(
182
+ self, x: torch.Tensor, return_attention: bool = False
183
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
184
+ attn_out, attn = self.attn(self.norm1(x), return_attention)
185
+ x = x + attn_out
186
+ x = x + self.mlp(self.norm2(x))
187
+ self.attention_weights = attn
188
+ if return_attention:
189
+ return x, attn
190
+ return x, None
191
+
192
+
193
+ class ViTEncoder(nn.Module):
194
+ """Vision Transformer encoder: stack of TransformerBlocks."""
195
+
196
+ def __init__(
197
+ self,
198
+ embed_dim: int = 512,
199
+ depth: int = 6,
200
+ num_heads: int = 8,
201
+ mlp_ratio: float = 4.0,
202
+ dropout: float = 0.1,
203
+ attention_dropout: float = 0.1,
204
+ ):
205
+ super().__init__()
206
+ self.embed_dim = embed_dim
207
+ self.depth = depth
208
+ self.blocks = nn.ModuleList(
209
+ [
210
+ TransformerBlock(
211
+ embed_dim, num_heads, mlp_ratio, dropout, attention_dropout
212
+ )
213
+ for _ in range(depth)
214
+ ]
215
+ )
216
+ self.norm = nn.LayerNorm(embed_dim)
217
+ self.attention_weights_all = []
218
+
219
+ def forward(
220
+ self, x: torch.Tensor, return_attention: bool = False
221
+ ) -> Tuple[torch.Tensor, Optional[list]]:
222
+ self.attention_weights_all = []
223
+ for block in self.blocks:
224
+ x, attn = block(x, return_attention)
225
+ if return_attention and attn is not None:
226
+ self.attention_weights_all.append(attn)
227
+ x = self.norm(x)
228
+ if return_attention:
229
+ return x, self.attention_weights_all
230
+ return x, None
231
+
232
+
233
+ # =============================================================================
234
+ # 3. Learnable Radiomics
235
+ # =============================================================================
236
+
237
+ class LearnableRadiomics(nn.Module):
238
+ """CNN-based radiomics: texture + shape branches fused together."""
239
+
240
+ def __init__(self, in_channels: int = 3, feature_dim: int = 128):
241
+ super().__init__()
242
+ self.texture_branch = nn.Sequential(
243
+ nn.Conv2d(in_channels, 32, 3, padding=1),
244
+ nn.BatchNorm2d(32),
245
+ nn.ReLU(),
246
+ nn.Conv2d(32, 64, 3, padding=1),
247
+ nn.BatchNorm2d(64),
248
+ nn.ReLU(),
249
+ nn.AdaptiveAvgPool2d(1),
250
+ nn.Flatten(),
251
+ nn.Linear(64, feature_dim // 2),
252
+ )
253
+ self.shape_branch = nn.Sequential(
254
+ nn.Conv2d(in_channels, 32, 5, padding=2),
255
+ nn.BatchNorm2d(32),
256
+ nn.ReLU(),
257
+ nn.MaxPool2d(2),
258
+ nn.Conv2d(32, 64, 5, padding=2),
259
+ nn.BatchNorm2d(64),
260
+ nn.ReLU(),
261
+ nn.AdaptiveAvgPool2d(1),
262
+ nn.Flatten(),
263
+ nn.Linear(64, feature_dim // 2),
264
+ )
265
+ self.fusion = nn.Sequential(
266
+ nn.Linear(feature_dim, feature_dim),
267
+ nn.LayerNorm(feature_dim),
268
+ nn.ReLU(),
269
+ )
270
+ self.feature_dim = feature_dim
271
+
272
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
273
+ texture = self.texture_branch(x)
274
+ shape = self.shape_branch(x)
275
+ combined = torch.cat([texture, shape], dim=-1)
276
+ return self.fusion(combined)
277
+
278
+
279
+ # =============================================================================
280
+ # 4. Feature Fusion
281
+ # =============================================================================
282
+
283
+ class FeatureFusion(nn.Module):
284
+ """Fuse CNN, ViT, and radiomics features via concatenation + MLP."""
285
+
286
+ def __init__(
287
+ self,
288
+ cnn_dim: int = 2048,
289
+ vit_dim: int = 512,
290
+ radiomics_dim: int = 128,
291
+ output_dim: int = 512,
292
+ fusion_type: str = "concat",
293
+ use_radiomics: bool = True,
294
+ ):
295
+ super().__init__()
296
+ self.use_radiomics = use_radiomics
297
+ self.fusion_type = fusion_type
298
+
299
+ total_dim = cnn_dim + vit_dim + (radiomics_dim if use_radiomics else 0)
300
+
301
+ if fusion_type == "concat":
302
+ self.fusion = nn.Sequential(
303
+ nn.Linear(total_dim, output_dim * 2),
304
+ nn.LayerNorm(output_dim * 2),
305
+ nn.GELU(),
306
+ nn.Dropout(0.1),
307
+ nn.Linear(output_dim * 2, output_dim),
308
+ nn.LayerNorm(output_dim),
309
+ nn.GELU(),
310
+ )
311
+ self.output_dim = output_dim
312
+
313
+ def forward(
314
+ self,
315
+ cnn_features: torch.Tensor,
316
+ vit_features: torch.Tensor,
317
+ radiomics_features: Optional[torch.Tensor] = None,
318
+ ) -> torch.Tensor:
319
+ if self.use_radiomics and radiomics_features is not None:
320
+ x = torch.cat([cnn_features, vit_features, radiomics_features], dim=-1)
321
+ else:
322
+ x = torch.cat([cnn_features, vit_features], dim=-1)
323
+ return self.fusion(x)
324
+
325
+
326
+ # =============================================================================
327
+ # 5. Complete Hybrid CNN-ViT Model
328
+ # =============================================================================
329
+
330
+ class HybridCNNViT(nn.Module):
331
+ """
332
+ Hybrid CNN-ViT for Brain Tumor Classification.
333
+
334
+ Pipeline:
335
+ Image β†’ CNN Backbone β†’ Feature Maps
336
+ ↓
337
+ Patch Embedding β†’ ViT Encoder β†’ CLS Token
338
+ ↓
339
+ Image β†’ Radiomics Branch β†’ Radiomics Features
340
+ ↓
341
+ [CNN Pooled | ViT CLS | Radiomics] β†’ Fusion β†’ Classifier
342
+ """
343
+
344
+ def __init__(
345
+ self,
346
+ num_classes: int = 4,
347
+ cnn_backbone: str = "resnet50",
348
+ cnn_pretrained: bool = True,
349
+ vit_embed_dim: int = 512,
350
+ vit_depth: int = 6,
351
+ vit_num_heads: int = 8,
352
+ vit_mlp_ratio: float = 4.0,
353
+ use_radiomics: bool = True,
354
+ radiomics_dim: int = 128,
355
+ fusion_type: str = "concat",
356
+ dropout: float = 0.3,
357
+ ):
358
+ super().__init__()
359
+ self.use_radiomics = use_radiomics
360
+
361
+ # CNN Backbone
362
+ self.cnn = CNNBackbone(
363
+ backbone_name=cnn_backbone,
364
+ pretrained=cnn_pretrained,
365
+ output_features=True,
366
+ )
367
+ cnn_feature_dim = self.cnn.num_features
368
+ self.feature_size = 7
369
+
370
+ # Patch Embedding
371
+ self.patch_embed = PatchEmbedding(
372
+ feature_size=self.feature_size,
373
+ feature_dim=cnn_feature_dim,
374
+ embed_dim=vit_embed_dim,
375
+ patch_size=1,
376
+ )
377
+
378
+ # ViT Encoder
379
+ self.vit_encoder = ViTEncoder(
380
+ embed_dim=vit_embed_dim,
381
+ depth=vit_depth,
382
+ num_heads=vit_num_heads,
383
+ mlp_ratio=vit_mlp_ratio,
384
+ dropout=dropout * 0.5,
385
+ )
386
+
387
+ # CNN global pooling
388
+ self.cnn_pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())
389
+
390
+ # Radiomics branch
391
+ if use_radiomics:
392
+ self.radiomics = LearnableRadiomics(in_channels=3, feature_dim=radiomics_dim)
393
+ else:
394
+ self.radiomics = None
395
+ radiomics_dim = 0
396
+
397
+ # Fusion
398
+ self.fusion = FeatureFusion(
399
+ cnn_dim=cnn_feature_dim,
400
+ vit_dim=vit_embed_dim,
401
+ radiomics_dim=radiomics_dim,
402
+ output_dim=512,
403
+ fusion_type=fusion_type,
404
+ use_radiomics=use_radiomics,
405
+ )
406
+
407
+ # Classifier
408
+ self.classifier = nn.Sequential(
409
+ nn.Dropout(dropout),
410
+ nn.Linear(512, 256),
411
+ nn.LayerNorm(256),
412
+ nn.GELU(),
413
+ nn.Dropout(dropout * 0.5),
414
+ nn.Linear(256, num_classes),
415
+ )
416
+
417
+ self.attention_weights = None
418
+
419
+ def forward(
420
+ self,
421
+ x: torch.Tensor,
422
+ return_features: bool = False,
423
+ return_attention: bool = False,
424
+ ) -> Dict[str, torch.Tensor]:
425
+ # CNN backbone
426
+ cnn_features = self.cnn(x)
427
+ cnn_pooled = self.cnn_pool(cnn_features)
428
+
429
+ # ViT encoder
430
+ patch_embeddings = self.patch_embed(cnn_features)
431
+ vit_output, attention = self.vit_encoder(patch_embeddings, return_attention)
432
+ vit_cls = vit_output[:, 0]
433
+
434
+ if return_attention:
435
+ self.attention_weights = attention
436
+
437
+ # Radiomics
438
+ if self.use_radiomics:
439
+ radiomics_features = self.radiomics(x)
440
+ else:
441
+ radiomics_features = None
442
+
443
+ # Fusion + Classification
444
+ fused = self.fusion(cnn_pooled, vit_cls, radiomics_features)
445
+ logits = self.classifier(fused)
446
+
447
+ output = {"logits": logits}
448
+ if return_features:
449
+ output["cnn_features"] = cnn_pooled
450
+ output["vit_features"] = vit_cls
451
+ output["fused_features"] = fused
452
+ if radiomics_features is not None:
453
+ output["radiomics_features"] = radiomics_features
454
+ if return_attention:
455
+ output["attention"] = attention
456
+
457
+ return output
458
+
459
+
460
+ class BrainTumorClassifier(nn.Module):
461
+ """Top-level wrapper that creates HybridCNNViT from config dict."""
462
+
463
+ def __init__(self, config: Dict):
464
+ super().__init__()
465
+ model_config = config.get("model", {})
466
+ self.model = HybridCNNViT(
467
+ num_classes=config.get("data", {}).get("num_classes", 4),
468
+ cnn_backbone=model_config.get("cnn_backbone", "resnet50"),
469
+ cnn_pretrained=model_config.get("cnn_pretrained", True),
470
+ vit_embed_dim=model_config.get("vit_embed_dim", 512),
471
+ vit_depth=model_config.get("vit_depth", 6),
472
+ vit_num_heads=model_config.get("vit_num_heads", 8),
473
+ vit_mlp_ratio=model_config.get("vit_mlp_ratio", 4.0),
474
+ use_radiomics=model_config.get("use_radiomics", True),
475
+ radiomics_dim=model_config.get("radiomics_features", 128),
476
+ fusion_type="concat",
477
+ dropout=model_config.get("dropout", 0.3),
478
+ )
479
+ self.num_classes = config.get("data", {}).get("num_classes", 4)
480
+
481
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
482
+ output = self.model(x)
483
+ return output["logits"]
484
+
485
+ def predict(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
486
+ with torch.no_grad():
487
+ logits = self.forward(x)
488
+ probs = F.softmax(logits, dim=-1)
489
+ preds = torch.argmax(probs, dim=-1)
490
+ return preds, probs