Zorrojurro commited on
Commit
0b3f1ab
·
verified ·
1 Parent(s): 1a72329

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +504 -0
app.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Brain Tumor Detection — Gradio Space
3
+ Hybrid CNN-ViT model with Grad-CAM explainability.
4
+
5
+ Author: Vishnu K (ZorroJurro)
6
+ """
7
+
8
+ import os
9
+ import json
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torchvision.models as models
15
+ from torchvision import transforms
16
+ from PIL import Image
17
+ import cv2
18
+ import gradio as gr
19
+ from huggingface_hub import hf_hub_download
20
+ from einops import rearrange, repeat
21
+ from typing import Dict, Optional, Tuple
22
+ import matplotlib
23
+ matplotlib.use("Agg")
24
+ import matplotlib.pyplot as plt
25
+
26
+
27
+ # =============================================================================
28
+ # Model Architecture (self-contained)
29
+ # =============================================================================
30
+
31
+ class CNNBackbone(nn.Module):
32
+ def __init__(self, backbone_name="resnet50", pretrained=False, output_features=True):
33
+ super().__init__()
34
+ self.backbone_name = backbone_name.lower()
35
+ self.output_features = output_features
36
+ configs = {
37
+ "resnet50": (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V2, 2048),
38
+ }
39
+ model_fn, weights, self.num_features = configs[self.backbone_name]
40
+ model = model_fn(weights=weights if pretrained else None)
41
+ self.backbone = nn.Sequential(*list(model.children())[:-2])
42
+
43
+ def forward(self, x):
44
+ return self.backbone(x)
45
+
46
+
47
+ class PatchEmbedding(nn.Module):
48
+ def __init__(self, feature_size=7, feature_dim=2048, embed_dim=512, patch_size=1):
49
+ super().__init__()
50
+ self.patch_size = patch_size
51
+ self.num_patches = (feature_size // patch_size) ** 2
52
+ self.projection = nn.Linear(feature_dim, embed_dim) if patch_size == 1 else nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
53
+ self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
54
+ self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim) * 0.02)
55
+
56
+ def forward(self, x):
57
+ B = x.shape[0]
58
+ if self.patch_size == 1:
59
+ x = rearrange(x, "b c h w -> b (h w) c")
60
+ x = self.projection(x)
61
+ else:
62
+ x = self.projection(x)
63
+ x = rearrange(x, "b c h w -> b (h w) c")
64
+ cls_tokens = repeat(self.cls_token, "1 1 d -> b 1 d", b=B)
65
+ x = torch.cat([cls_tokens, x], dim=1)
66
+ x = x + self.pos_embedding[:, :x.size(1)]
67
+ return x
68
+
69
+
70
+ class MultiHeadSelfAttention(nn.Module):
71
+ def __init__(self, embed_dim=512, num_heads=8, dropout=0.1, attention_dropout=0.1):
72
+ super().__init__()
73
+ self.num_heads = num_heads
74
+ self.head_dim = embed_dim // num_heads
75
+ self.scale = self.head_dim ** -0.5
76
+ self.qkv = nn.Linear(embed_dim, embed_dim * 3)
77
+ self.attn_dropout = nn.Dropout(attention_dropout)
78
+ self.proj = nn.Linear(embed_dim, embed_dim)
79
+ self.proj_dropout = nn.Dropout(dropout)
80
+ self.attention_weights = None
81
+
82
+ def forward(self, x, return_attention=False):
83
+ B, N, D = x.shape
84
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
85
+ q, k, v = qkv[0], qkv[1], qkv[2]
86
+ attn = (q @ k.transpose(-2, -1)) * self.scale
87
+ attn = attn.softmax(dim=-1)
88
+ attn = self.attn_dropout(attn)
89
+ self.attention_weights = attn.detach()
90
+ x = (attn @ v).transpose(1, 2).reshape(B, N, D)
91
+ x = self.proj_dropout(self.proj(x))
92
+ return (x, attn) if return_attention else (x, None)
93
+
94
+
95
+ class TransformerBlock(nn.Module):
96
+ def __init__(self, embed_dim=512, num_heads=8, mlp_ratio=4.0, dropout=0.1, attention_dropout=0.1):
97
+ super().__init__()
98
+ self.norm1 = nn.LayerNorm(embed_dim)
99
+ self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout, attention_dropout)
100
+ self.norm2 = nn.LayerNorm(embed_dim)
101
+ mlp_hidden = int(embed_dim * mlp_ratio)
102
+ self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(mlp_hidden, embed_dim), nn.Dropout(dropout))
103
+ self.attention_weights = None
104
+
105
+ def forward(self, x, return_attention=False):
106
+ attn_out, attn = self.attn(self.norm1(x), return_attention)
107
+ x = x + attn_out
108
+ x = x + self.mlp(self.norm2(x))
109
+ self.attention_weights = attn
110
+ return (x, attn) if return_attention else (x, None)
111
+
112
+
113
+ class ViTEncoder(nn.Module):
114
+ def __init__(self, embed_dim=512, depth=6, num_heads=8, mlp_ratio=4.0, dropout=0.1, attention_dropout=0.1):
115
+ super().__init__()
116
+ self.blocks = nn.ModuleList([TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout, attention_dropout) for _ in range(depth)])
117
+ self.norm = nn.LayerNorm(embed_dim)
118
+ self.attention_weights_all = []
119
+
120
+ def forward(self, x, return_attention=False):
121
+ self.attention_weights_all = []
122
+ for block in self.blocks:
123
+ x, attn = block(x, return_attention)
124
+ if return_attention and attn is not None:
125
+ self.attention_weights_all.append(attn)
126
+ x = self.norm(x)
127
+ return (x, self.attention_weights_all) if return_attention else (x, None)
128
+
129
+
130
+ class LearnableRadiomics(nn.Module):
131
+ def __init__(self, in_channels=3, feature_dim=128):
132
+ super().__init__()
133
+ self.texture_branch = nn.Sequential(nn.Conv2d(in_channels, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, feature_dim // 2))
134
+ self.shape_branch = nn.Sequential(nn.Conv2d(in_channels, 32, 5, padding=2), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(32, 64, 5, padding=2), nn.BatchNorm2d(64), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, feature_dim // 2))
135
+ self.fusion = nn.Sequential(nn.Linear(feature_dim, feature_dim), nn.LayerNorm(feature_dim), nn.ReLU())
136
+
137
+ def forward(self, x):
138
+ return self.fusion(torch.cat([self.texture_branch(x), self.shape_branch(x)], dim=-1))
139
+
140
+
141
+ class FeatureFusion(nn.Module):
142
+ def __init__(self, cnn_dim=2048, vit_dim=512, radiomics_dim=128, output_dim=512, use_radiomics=True):
143
+ super().__init__()
144
+ self.use_radiomics = use_radiomics
145
+ total_dim = cnn_dim + vit_dim + (radiomics_dim if use_radiomics else 0)
146
+ self.fusion = nn.Sequential(nn.Linear(total_dim, output_dim * 2), nn.LayerNorm(output_dim * 2), nn.GELU(), nn.Dropout(0.1), nn.Linear(output_dim * 2, output_dim), nn.LayerNorm(output_dim), nn.GELU())
147
+
148
+ def forward(self, cnn_features, vit_features, radiomics_features=None):
149
+ parts = [cnn_features, vit_features]
150
+ if self.use_radiomics and radiomics_features is not None:
151
+ parts.append(radiomics_features)
152
+ return self.fusion(torch.cat(parts, dim=-1))
153
+
154
+
155
+ class HybridCNNViT(nn.Module):
156
+ def __init__(self, num_classes=4, cnn_backbone="resnet50", cnn_pretrained=False,
157
+ vit_embed_dim=512, vit_depth=6, vit_num_heads=8, vit_mlp_ratio=4.0,
158
+ use_radiomics=True, radiomics_dim=128, fusion_type="concat", dropout=0.3):
159
+ super().__init__()
160
+ self.use_radiomics = use_radiomics
161
+ self.cnn = CNNBackbone(backbone_name=cnn_backbone, pretrained=cnn_pretrained, output_features=True)
162
+ cnn_feature_dim = self.cnn.num_features
163
+ self.patch_embed = PatchEmbedding(feature_size=7, feature_dim=cnn_feature_dim, embed_dim=vit_embed_dim, patch_size=1)
164
+ self.vit_encoder = ViTEncoder(embed_dim=vit_embed_dim, depth=vit_depth, num_heads=vit_num_heads, mlp_ratio=vit_mlp_ratio, dropout=dropout * 0.5)
165
+ self.cnn_pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())
166
+ self.radiomics = LearnableRadiomics(in_channels=3, feature_dim=radiomics_dim) if use_radiomics else None
167
+ if not use_radiomics:
168
+ radiomics_dim = 0
169
+ self.fusion = FeatureFusion(cnn_dim=cnn_feature_dim, vit_dim=vit_embed_dim, radiomics_dim=radiomics_dim, output_dim=512, use_radiomics=use_radiomics)
170
+ self.classifier = nn.Sequential(nn.Dropout(dropout), nn.Linear(512, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(dropout * 0.5), nn.Linear(256, num_classes))
171
+ self.attention_weights = None
172
+
173
+ def forward(self, x, return_features=False, return_attention=False):
174
+ cnn_features = self.cnn(x)
175
+ cnn_pooled = self.cnn_pool(cnn_features)
176
+ patch_embeddings = self.patch_embed(cnn_features)
177
+ vit_output, attention = self.vit_encoder(patch_embeddings, return_attention)
178
+ vit_cls = vit_output[:, 0]
179
+ if return_attention:
180
+ self.attention_weights = attention
181
+ radiomics_features = self.radiomics(x) if self.use_radiomics else None
182
+ fused = self.fusion(cnn_pooled, vit_cls, radiomics_features)
183
+ logits = self.classifier(fused)
184
+ output = {"logits": logits}
185
+ if return_features:
186
+ output["cnn_features"] = cnn_pooled
187
+ output["vit_features"] = vit_cls
188
+ output["fused_features"] = fused
189
+ if return_attention:
190
+ output["attention"] = attention
191
+ return output
192
+
193
+
194
+ class BrainTumorClassifier(nn.Module):
195
+ def __init__(self, config):
196
+ super().__init__()
197
+ mc = config.get("model", {})
198
+ self.model = HybridCNNViT(
199
+ num_classes=config.get("data", {}).get("num_classes", 4),
200
+ cnn_backbone=mc.get("cnn_backbone", "resnet50"),
201
+ cnn_pretrained=mc.get("cnn_pretrained", False),
202
+ vit_embed_dim=mc.get("vit_embed_dim", 512),
203
+ vit_depth=mc.get("vit_depth", 6),
204
+ vit_num_heads=mc.get("vit_num_heads", 8),
205
+ vit_mlp_ratio=mc.get("vit_mlp_ratio", 4.0),
206
+ use_radiomics=mc.get("use_radiomics", True),
207
+ radiomics_dim=mc.get("radiomics_features", 128),
208
+ dropout=mc.get("dropout", 0.3),
209
+ )
210
+ self.num_classes = config.get("data", {}).get("num_classes", 4)
211
+
212
+ def forward(self, x):
213
+ return self.model(x)["logits"]
214
+
215
+
216
+ # =============================================================================
217
+ # Grad-CAM Implementation
218
+ # =============================================================================
219
+
220
+ class GradCAM:
221
+ """Simplified Grad-CAM for the CNN backbone."""
222
+
223
+ def __init__(self, model: HybridCNNViT):
224
+ self.model = model
225
+ self.gradients = None
226
+ self.activations = None
227
+ self._register_hooks()
228
+
229
+ def _register_hooks(self):
230
+ # Hook into the last conv layer of the CNN backbone
231
+ target_layer = self.model.cnn.backbone[-1]
232
+
233
+ def forward_hook(module, input, output):
234
+ self.activations = output.detach()
235
+
236
+ def backward_hook(module, grad_input, grad_output):
237
+ self.gradients = grad_output[0].detach()
238
+
239
+ target_layer.register_forward_hook(forward_hook)
240
+ target_layer.register_full_backward_hook(backward_hook)
241
+
242
+ def generate(self, input_tensor: torch.Tensor, target_class: int = None) -> np.ndarray:
243
+ self.model.eval()
244
+ input_tensor.requires_grad_(True)
245
+
246
+ output = self.model(input_tensor)
247
+ logits = output["logits"]
248
+
249
+ if target_class is None:
250
+ target_class = logits.argmax(dim=-1).item()
251
+
252
+ self.model.zero_grad()
253
+ logits[0, target_class].backward()
254
+
255
+ gradients = self.gradients
256
+ activations = self.activations
257
+
258
+ # Global average pooling of gradients
259
+ weights = gradients.mean(dim=(2, 3), keepdim=True)
260
+ cam = (weights * activations).sum(dim=1, keepdim=True)
261
+ cam = F.relu(cam)
262
+
263
+ # Normalize
264
+ cam = cam.squeeze().cpu().numpy()
265
+ cam = cam - cam.min()
266
+ cam = cam / (cam.max() + 1e-8)
267
+
268
+ return cam
269
+
270
+
271
+ def create_gradcam_overlay(
272
+ original_image: np.ndarray,
273
+ cam: np.ndarray,
274
+ alpha: float = 0.5,
275
+ ) -> np.ndarray:
276
+ """Create a Grad-CAM heatmap overlay on the original image."""
277
+ h, w = original_image.shape[:2]
278
+ cam_resized = cv2.resize(cam, (w, h))
279
+
280
+ # Apply colormap
281
+ heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
282
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
283
+
284
+ # Overlay
285
+ overlay = np.float32(heatmap) * alpha + np.float32(original_image) * (1 - alpha)
286
+ overlay = np.clip(overlay, 0, 255).astype(np.uint8)
287
+
288
+ return overlay
289
+
290
+
291
+ # =============================================================================
292
+ # Model Loading
293
+ # =============================================================================
294
+
295
+ REPO_ID = "Zorrojurro/brain-tumor-cnn-vit"
296
+ CLASS_NAMES = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]
297
+ CLASS_EMOJIS = {"Glioma": "🔴", "Meningioma": "🟠", "No Tumor": "🟢", "Pituitary": "🟡"}
298
+
299
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
300
+
301
+ # Image preprocessing
302
+ TRANSFORM = transforms.Compose([
303
+ transforms.Resize((224, 224)),
304
+ transforms.ToTensor(),
305
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
306
+ ])
307
+
308
+
309
+ def load_model():
310
+ """Download and load the model from Hugging Face Hub."""
311
+ print("📥 Downloading model from Hugging Face Hub...")
312
+
313
+ # Download checkpoint
314
+ checkpoint_path = hf_hub_download(
315
+ repo_id=REPO_ID,
316
+ filename="best_model.pth",
317
+ cache_dir="./model_cache",
318
+ )
319
+
320
+ # Create model
321
+ config = {
322
+ "data": {"num_classes": 4},
323
+ "model": {
324
+ "cnn_backbone": "resnet50",
325
+ "cnn_pretrained": False,
326
+ "vit_embed_dim": 512,
327
+ "vit_depth": 6,
328
+ "vit_num_heads": 8,
329
+ "vit_mlp_ratio": 4.0,
330
+ "use_radiomics": True,
331
+ "radiomics_features": 128,
332
+ "dropout": 0.3,
333
+ },
334
+ }
335
+
336
+ classifier = BrainTumorClassifier(config)
337
+ model = classifier.model
338
+
339
+ # Load checkpoint
340
+ checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
341
+ state_dict = checkpoint.get("model_state_dict", checkpoint)
342
+
343
+ # Handle key prefix mismatches
344
+ new_state_dict = {}
345
+ for k, v in state_dict.items():
346
+ # Remove 'model.' prefix if present
347
+ new_key = k.replace("model.", "") if k.startswith("model.") else k
348
+ new_state_dict[new_key] = v
349
+
350
+ model.load_state_dict(new_state_dict, strict=False)
351
+ model.eval().to(DEVICE)
352
+
353
+ print(f"✅ Model loaded on {DEVICE}")
354
+ return model
355
+
356
+
357
+ # Load model at startup
358
+ MODEL = load_model()
359
+ GRADCAM = GradCAM(MODEL)
360
+
361
+
362
+ # =============================================================================
363
+ # Prediction Function
364
+ # =============================================================================
365
+
366
+ def predict(image: Image.Image):
367
+ """Run prediction and generate Grad-CAM visualization."""
368
+ if image is None:
369
+ return None, None, "Please upload an image."
370
+
371
+ # Convert to RGB
372
+ image = image.convert("RGB")
373
+ original_np = np.array(image)
374
+
375
+ # Preprocess
376
+ input_tensor = TRANSFORM(image).unsqueeze(0).to(DEVICE)
377
+
378
+ # Forward pass with gradients for Grad-CAM
379
+ with torch.enable_grad():
380
+ cam = GRADCAM.generate(input_tensor)
381
+
382
+ # Get predictions
383
+ with torch.no_grad():
384
+ output = MODEL(input_tensor)
385
+ logits = output["logits"]
386
+ probs = F.softmax(logits, dim=-1)[0]
387
+
388
+ # Build confidence dict
389
+ confidences = {}
390
+ for i, name in enumerate(CLASS_NAMES):
391
+ emoji = CLASS_EMOJIS[name]
392
+ confidences[f"{emoji} {name}"] = float(probs[i])
393
+
394
+ # Grad-CAM overlay
395
+ gradcam_overlay = create_gradcam_overlay(original_np, cam, alpha=0.45)
396
+
397
+ # Predicted class info
398
+ pred_idx = probs.argmax().item()
399
+ pred_name = CLASS_NAMES[pred_idx]
400
+ pred_conf = probs[pred_idx].item()
401
+ emoji = CLASS_EMOJIS[pred_name]
402
+
403
+ summary = f"## {emoji} {pred_name}\n**Confidence:** {pred_conf:.1%}\n\n"
404
+ if pred_name == "No Tumor":
405
+ summary += "✅ No tumor detected in the MRI scan."
406
+ else:
407
+ summary += f"⚠️ Potential **{pred_name.lower()}** detected. Please consult a medical professional."
408
+
409
+ return confidences, gradcam_overlay, summary
410
+
411
+
412
+ # =============================================================================
413
+ # Gradio UI
414
+ # =============================================================================
415
+
416
+ CUSTOM_CSS = """
417
+ .gradio-container {
418
+ max-width: 1100px !important;
419
+ margin: auto !important;
420
+ }
421
+ .gr-button-primary {
422
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
423
+ border: none !important;
424
+ }
425
+ .gr-button-primary:hover {
426
+ background: linear-gradient(135deg, #764ba2 0%, #667eea 100%) !important;
427
+ transform: translateY(-1px);
428
+ box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4);
429
+ }
430
+ footer {visibility: hidden}
431
+ """
432
+
433
+ DESCRIPTION = """
434
+ # 🧠 Brain Tumor Detection — Hybrid CNN-ViT
435
+
436
+ Upload a brain MRI scan for instant AI-powered classification with **Grad-CAM explainability**.
437
+
438
+ **Model Architecture**: ResNet50 (CNN) + 6-Layer Vision Transformer + Learnable Radiomics
439
+ **Classes**: Glioma · Meningioma · No Tumor · Pituitary
440
+ **Performance**: 98% Accuracy · 0.97 F1-Score · 0.99 AUC
441
+
442
+ > ⚠️ *For research and educational purposes only. Not a substitute for professional medical diagnosis.*
443
+ """
444
+
445
+ with gr.Blocks(
446
+ css=CUSTOM_CSS,
447
+ theme=gr.themes.Soft(
448
+ primary_hue="indigo",
449
+ secondary_hue="purple",
450
+ neutral_hue="slate",
451
+ ),
452
+ title="Brain Tumor Detection — CNN-ViT",
453
+ ) as demo:
454
+
455
+ gr.Markdown(DESCRIPTION)
456
+
457
+ with gr.Row(equal_height=True):
458
+ with gr.Column(scale=1):
459
+ input_image = gr.Image(
460
+ type="pil",
461
+ label="Upload Brain MRI",
462
+ height=350,
463
+ )
464
+ predict_btn = gr.Button(
465
+ "🔬 Analyze MRI",
466
+ variant="primary",
467
+ size="lg",
468
+ )
469
+
470
+ with gr.Column(scale=1):
471
+ gradcam_output = gr.Image(
472
+ label="Grad-CAM Visualization",
473
+ height=350,
474
+ )
475
+
476
+ with gr.Row():
477
+ with gr.Column(scale=1):
478
+ label_output = gr.Label(
479
+ label="Classification Confidence",
480
+ num_top_classes=4,
481
+ )
482
+ with gr.Column(scale=1):
483
+ summary_output = gr.Markdown(
484
+ label="Diagnosis Summary",
485
+ )
486
+
487
+ predict_btn.click(
488
+ fn=predict,
489
+ inputs=[input_image],
490
+ outputs=[label_output, gradcam_output, summary_output],
491
+ )
492
+
493
+ gr.Markdown(
494
+ """
495
+ ---
496
+ **Built by [Vishnu K](https://huggingface.co/ZorroJurro)** ·
497
+ [Model Card](https://huggingface.co/ZorroJurro/brain-tumor-cnn-vit) ·
498
+ [GitHub](https://github.com/ZorroJurro)
499
+ """
500
+ )
501
+
502
+
503
+ if __name__ == "__main__":
504
+ demo.launch()