Spaces:
Sleeping
Sleeping
| """ | |
| Brain Tumor Detection — Gradio Space | |
| Hybrid CNN-ViT model with Grad-CAM explainability. | |
| Author: Vishnu K (ZorroJurro) | |
| """ | |
| import os | |
| import json | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.models as models | |
| from torchvision import transforms | |
| from PIL import Image | |
| import cv2 | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| from einops import rearrange, repeat | |
| from typing import Dict, Optional, Tuple | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| # ============================================================================= | |
| # Model Architecture (self-contained) | |
| # ============================================================================= | |
class CNNBackbone(nn.Module):
    """Convolutional trunk cut out of a torchvision classification model.

    The selected torchvision model is instantiated and its last two child
    modules are discarded (for torchvision ResNets these are the pooling and
    fully connected layers), so ``forward`` returns the output of the
    remaining layer stack.

    Args:
        backbone_name: Key into the supported-backbone table (case-insensitive).
            Only ``"resnet50"`` is registered; any other name raises ``KeyError``.
        pretrained: When true, the ImageNet weights listed in the table are
            loaded; otherwise the network is randomly initialised.
        output_features: Stored on the instance; not read anywhere in this class.

    Attributes:
        num_features: Channel count associated with the chosen backbone (2048
            for ``resnet50``).
    """

    def __init__(self, backbone_name="resnet50", pretrained=False, output_features=True):
        super().__init__()
        self.backbone_name = backbone_name.lower()
        self.output_features = output_features

        # name -> (constructor, pretrained weights enum, output channel count)
        registry = {
            "resnet50": (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V2, 2048),
        }
        builder, imagenet_weights, self.num_features = registry[self.backbone_name]

        chosen_weights = imagenet_weights if pretrained else None
        full_model = builder(weights=chosen_weights)

        # Keep everything except the final two children of the classifier.
        trunk_layers = list(full_model.children())[:-2]
        self.backbone = nn.Sequential(*trunk_layers)

    def forward(self, x):
        """Run ``x`` through the truncated backbone and return its output."""
        return self.backbone(x)
class PatchEmbedding(nn.Module):
    """Turn a CNN feature map into a token sequence with CLS token and positions.

    Args:
        feature_size: Side length of the (square) feature map used to size the
            positional embedding.
        feature_dim: Channel count of the incoming feature map.
        embed_dim: Width of each output token.
        patch_size: Spatial size of one patch. With ``1`` every spatial
            location becomes a token through a linear layer; otherwise a
            strided convolution produces the tokens.
    """

    def __init__(self, feature_size=7, feature_dim=2048, embed_dim=512, patch_size=1):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (feature_size // patch_size) ** 2

        if patch_size == 1:
            self.projection = nn.Linear(feature_dim, embed_dim)
        else:
            self.projection = nn.Conv2d(
                feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size
            )

        # Small-variance random init for the learnable CLS token and positions.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
        self.pos_embedding = nn.Parameter(
            torch.randn(1, self.num_patches + 1, embed_dim) * 0.02
        )

    def forward(self, x):
        """Return tokens of shape ``(batch, 1 + h*w, embed_dim)`` for a 4-D input.

        ``x`` is expected as ``(batch, channels, height, width)``.
        """
        batch = x.shape[0]

        if self.patch_size == 1:
            # (b, c, h, w) -> (b, h*w, c), then project channels to embed_dim.
            tokens = self.projection(x.flatten(2).transpose(1, 2))
        else:
            # Project with the strided conv first, then flatten the grid.
            tokens = self.projection(x).flatten(2).transpose(1, 2)

        # One copy of the CLS token per sample, placed in front of the patches.
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)

        # Slice so only as many positions as tokens are added.
        return tokens + self.pos_embedding[:, :tokens.size(1)]
class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product self-attention over a token sequence.

    Args:
        embed_dim: Width of each token; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied after the output projection.
        attention_dropout: Dropout probability applied to the attention matrix.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``. Without this check the module would be built
            successfully and only fail later in ``forward`` with an opaque
            reshape error.

    Attributes:
        attention_weights: Detached attention matrix from the most recent
            forward pass (taken after attention dropout), or ``None`` before
            the first call.
    """

    def __init__(self, embed_dim=512, num_heads=8, dropout=0.1, attention_dropout=0.1):
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        # A single linear layer produces queries, keys and values together.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_dropout = nn.Dropout(attention_dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_dropout = nn.Dropout(dropout)
        self.attention_weights = None

    def forward(self, x, return_attention=False):
        """Attend over ``x`` of shape ``(batch, tokens, embed_dim)``.

        Returns:
            ``(output, attention)`` where ``attention`` has shape
            ``(batch, heads, tokens, tokens)`` when ``return_attention`` is
            true and is ``None`` otherwise.
        """
        B, N, D = x.shape
        # (B, N, 3*D) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_dropout(attn)
        # Kept for later inspection; detached so it does not retain the graph.
        self.attention_weights = attn.detach()

        # Merge heads back: (B, heads, N, head_dim) -> (B, N, D)
        x = (attn @ v).transpose(1, 2).reshape(B, N, D)
        x = self.proj_dropout(self.proj(x))
        return (x, attn) if return_attention else (x, None)
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: self-attention then MLP, each with a residual.

    Args:
        embed_dim: Token width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``embed_dim``.
        dropout: Dropout probability used in the MLP and the attention output.
        attention_dropout: Dropout probability on the attention matrix.

    Attributes:
        attention_weights: Attention returned by the attention module on the
            last forward pass (``None`` when it was not requested).
    """

    def __init__(self, embed_dim=512, num_heads=8, mlp_ratio=4.0, dropout=0.1, attention_dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout, attention_dropout)
        self.norm2 = nn.LayerNorm(embed_dim)

        hidden_width = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, embed_dim),
            nn.Dropout(dropout),
        )
        self.attention_weights = None

    def forward(self, x, return_attention=False):
        """Apply the block to ``x`` and return ``(output, attention_or_None)``."""
        # Attention sub-layer on the normalised input, added back as a residual.
        normed = self.norm1(x)
        attn_out, attn = self.attn(normed, return_attention)
        x = x + attn_out

        # Feed-forward sub-layer, again with a residual connection.
        x = x + self.mlp(self.norm2(x))

        self.attention_weights = attn
        if return_attention:
            return x, attn
        return x, None
class ViTEncoder(nn.Module):
    """Stack of ``TransformerBlock`` layers followed by a final LayerNorm.

    Args:
        embed_dim: Token width.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width multiplier per block.
        dropout: Dropout probability passed to each block.
        attention_dropout: Attention dropout probability passed to each block.

    Attributes:
        attention_weights_all: Per-block attention tensors collected during the
            last forward pass; empty unless attention was requested.
    """

    def __init__(self, embed_dim=512, depth=6, num_heads=8, mlp_ratio=4.0, dropout=0.1, attention_dropout=0.1):
        super().__init__()
        layers = [
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout, attention_dropout)
            for _ in range(depth)
        ]
        self.blocks = nn.ModuleList(layers)
        self.norm = nn.LayerNorm(embed_dim)
        self.attention_weights_all = []

    def forward(self, x, return_attention=False):
        """Encode ``x``; return ``(tokens, attention_list)`` or ``(tokens, None)``."""
        # Start every call with a fresh list so results never mix across calls.
        collected = []
        self.attention_weights_all = collected

        for layer in self.blocks:
            x, layer_attn = layer(x, return_attention)
            if return_attention and layer_attn is not None:
                collected.append(layer_attn)

        x = self.norm(x)
        if return_attention:
            return x, collected
        return x, None
class LearnableRadiomics(nn.Module):
    """Two small convolutional branches whose pooled outputs are fused.

    Each branch ends in global average pooling and a linear layer producing
    ``feature_dim // 2`` values; the two halves are concatenated and passed
    through a Linear + LayerNorm + ReLU fusion head.

    Args:
        in_channels: Channel count of the input image.
        feature_dim: Width of the fused output vector.
    """

    def __init__(self, in_channels=3, feature_dim=128):
        super().__init__()
        half_dim = feature_dim // 2

        # 3x3 kernels, no spatial down-sampling before the global pool.
        self.texture_branch = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, half_dim),
        )

        # 5x5 kernels with one 2x max-pool between the two conv stages.
        self.shape_branch = nn.Sequential(
            nn.Conv2d(in_channels, 32, 5, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, half_dim),
        )

        self.fusion = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return the fused ``(batch, feature_dim)`` descriptor for image batch ``x``."""
        texture = self.texture_branch(x)
        shape = self.shape_branch(x)
        combined = torch.cat([texture, shape], dim=-1)
        return self.fusion(combined)
class FeatureFusion(nn.Module):
    """Concatenate CNN, ViT and optional radiomics vectors and project them.

    Args:
        cnn_dim: Width of the pooled CNN feature vector.
        vit_dim: Width of the ViT feature vector.
        radiomics_dim: Width of the radiomics vector; only counted towards the
            input width when ``use_radiomics`` is true.
        output_dim: Width of the fused output.
        use_radiomics: Whether radiomics features take part in the fusion.
    """

    def __init__(self, cnn_dim=2048, vit_dim=512, radiomics_dim=128, output_dim=512, use_radiomics=True):
        super().__init__()
        self.use_radiomics = use_radiomics

        extra_dim = radiomics_dim if use_radiomics else 0
        total_dim = cnn_dim + vit_dim + extra_dim
        wide_dim = output_dim * 2

        self.fusion = nn.Sequential(
            nn.Linear(total_dim, wide_dim),
            nn.LayerNorm(wide_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(wide_dim, output_dim),
            nn.LayerNorm(output_dim),
            nn.GELU(),
        )

    def forward(self, cnn_features, vit_features, radiomics_features=None):
        """Fuse the given feature vectors along the last dimension.

        ``radiomics_features`` is ignored when the module was built with
        ``use_radiomics=False`` or when it is ``None``.
        """
        include_radiomics = self.use_radiomics and radiomics_features is not None
        if include_radiomics:
            pieces = [cnn_features, vit_features, radiomics_features]
        else:
            pieces = [cnn_features, vit_features]
        return self.fusion(torch.cat(pieces, dim=-1))
class HybridCNNViT(nn.Module):
    """Classifier combining a CNN trunk, a ViT encoder and optional radiomics.

    The CNN feature map is used twice: globally pooled into a vector, and
    tokenised for the transformer encoder, whose CLS token is taken as the ViT
    feature. Both (plus the radiomics vector when enabled) are fused and fed to
    an MLP classification head.

    Args:
        num_classes: Number of output logits.
        cnn_backbone: Backbone name forwarded to ``CNNBackbone``.
        cnn_pretrained: Whether the backbone loads pretrained weights.
        vit_embed_dim: Token width of the transformer.
        vit_depth: Number of transformer blocks.
        vit_num_heads: Attention heads per block.
        vit_mlp_ratio: MLP width multiplier per block.
        use_radiomics: Whether to build and use the ``LearnableRadiomics`` branch.
        radiomics_dim: Output width of the radiomics branch.
        fusion_type: Accepted but not used by this implementation.
        dropout: Classifier dropout; half of it is used inside the encoder.

    Note:
        The patch embedding is sized for a 7x7 feature map — presumably
        matching 224x224 inputs to ResNet-50; confirm before changing input size.
    """

    def __init__(self, num_classes=4, cnn_backbone="resnet50", cnn_pretrained=False,
                 vit_embed_dim=512, vit_depth=6, vit_num_heads=8, vit_mlp_ratio=4.0,
                 use_radiomics=True, radiomics_dim=128, fusion_type="concat", dropout=0.3):
        super().__init__()
        self.use_radiomics = use_radiomics

        # CNN trunk and the tokeniser that feeds its feature map to the ViT.
        self.cnn = CNNBackbone(
            backbone_name=cnn_backbone, pretrained=cnn_pretrained, output_features=True
        )
        cnn_channels = self.cnn.num_features
        self.patch_embed = PatchEmbedding(
            feature_size=7, feature_dim=cnn_channels, embed_dim=vit_embed_dim, patch_size=1
        )
        self.vit_encoder = ViTEncoder(
            embed_dim=vit_embed_dim,
            depth=vit_depth,
            num_heads=vit_num_heads,
            mlp_ratio=vit_mlp_ratio,
            dropout=dropout * 0.5,
        )
        self.cnn_pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())

        # Optional image-level radiomics branch.
        if use_radiomics:
            self.radiomics = LearnableRadiomics(in_channels=3, feature_dim=radiomics_dim)
        else:
            self.radiomics = None
        fused_radiomics_dim = radiomics_dim if use_radiomics else 0

        self.fusion = FeatureFusion(
            cnn_dim=cnn_channels,
            vit_dim=vit_embed_dim,
            radiomics_dim=fused_radiomics_dim,
            output_dim=512,
            use_radiomics=use_radiomics,
        )
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(256, num_classes),
        )
        self.attention_weights = None

    def forward(self, x, return_features=False, return_attention=False):
        """Classify image batch ``x``.

        Returns:
            A dict that always contains ``"logits"``. With ``return_features``
            it also holds ``"cnn_features"``, ``"vit_features"`` and
            ``"fused_features"``; with ``return_attention`` it also holds
            ``"attention"`` (the encoder's per-block attention list).
        """
        feature_map = self.cnn(x)
        cnn_pooled = self.cnn_pool(feature_map)

        tokens = self.patch_embed(feature_map)
        encoded, attention = self.vit_encoder(tokens, return_attention)
        vit_cls = encoded[:, 0]  # first token is the CLS token
        if return_attention:
            self.attention_weights = attention

        radiomics_features = None
        if self.use_radiomics:
            radiomics_features = self.radiomics(x)

        fused = self.fusion(cnn_pooled, vit_cls, radiomics_features)
        output = {"logits": self.classifier(fused)}

        if return_features:
            output.update(
                cnn_features=cnn_pooled,
                vit_features=vit_cls,
                fused_features=fused,
            )
        if return_attention:
            output["attention"] = attention
        return output
class BrainTumorClassifier(nn.Module):
    """Config-driven wrapper that builds a ``HybridCNNViT`` and returns raw logits.

    Args:
        config: Nested mapping. ``config["data"]["num_classes"]`` and the
            entries under ``config["model"]`` are read; every missing entry
            falls back to the default shown in the code.
    """

    def __init__(self, config):
        super().__init__()
        model_cfg = config.get("model", {})
        data_cfg = config.get("data", {})
        num_classes = data_cfg.get("num_classes", 4)

        hybrid_kwargs = {
            "num_classes": num_classes,
            "cnn_backbone": model_cfg.get("cnn_backbone", "resnet50"),
            "cnn_pretrained": model_cfg.get("cnn_pretrained", False),
            "vit_embed_dim": model_cfg.get("vit_embed_dim", 512),
            "vit_depth": model_cfg.get("vit_depth", 6),
            "vit_num_heads": model_cfg.get("vit_num_heads", 8),
            "vit_mlp_ratio": model_cfg.get("vit_mlp_ratio", 4.0),
            "use_radiomics": model_cfg.get("use_radiomics", True),
            # The config key is "radiomics_features", the model argument "radiomics_dim".
            "radiomics_dim": model_cfg.get("radiomics_features", 128),
            "dropout": model_cfg.get("dropout", 0.3),
        }
        self.model = HybridCNNViT(**hybrid_kwargs)
        self.num_classes = num_classes

    def forward(self, x):
        """Return only the ``"logits"`` entry of the wrapped model's output."""
        outputs = self.model(x)
        return outputs["logits"]
| # ============================================================================= | |
| # Grad-CAM Implementation | |
| # ============================================================================= | |
class GradCAM:
    """Simplified Grad-CAM for the CNN backbone.

    Forward and backward hooks are attached to the last child of
    ``model.cnn.backbone``; they capture that layer's output and the gradient
    flowing into it, from which the class-activation map is computed.

    Args:
        model: Network exposing ``cnn.backbone`` (an indexable module
            container) and returning a dict with a ``"logits"`` entry.

    Attributes:
        activations: Detached output of the target layer from the last forward.
        gradients: Detached gradient w.r.t. that output from the last backward.
    """

    def __init__(self, model: HybridCNNViT):
        self.model = model
        self.gradients = None
        self.activations = None
        # Handles are kept so the hooks can be detached again via remove_hooks().
        self._hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach the capture hooks to the last layer of the CNN backbone."""
        target_layer = self.model.cnn.backbone[-1]

        def forward_hook(module, inputs, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        self._hook_handles = [
            target_layer.register_forward_hook(forward_hook),
            target_layer.register_full_backward_hook(backward_hook),
        ]

    def remove_hooks(self):
        """Detach the hooks from the model; safe to call more than once."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def generate(self, input_tensor: torch.Tensor, target_class: Optional[int] = None) -> np.ndarray:
        """Compute a normalised Grad-CAM map for the first sample in the batch.

        Args:
            input_tensor: Preprocessed image batch. It is not modified; a
                detached copy is used for the gradient pass.
            target_class: Class index to explain. Defaults to the class with
                the highest logit for the first sample.

        Returns:
            A 2-D float array scaled to the range [0, 1] with the spatial size
            of the target layer's output.

        Raises:
            RuntimeError: If the hooks did not capture activations and
                gradients (for example after ``remove_hooks``).
        """
        self.model.eval()

        # Work on a private leaf tensor so the caller's tensor keeps its
        # requires_grad flag and does not accumulate a .grad buffer.
        inputs = input_tensor.detach().clone().requires_grad_(True)

        # Drop results of any earlier call so stale data is never reused.
        self.gradients = None
        self.activations = None

        # Gradients are required here even if the caller is inside no_grad().
        with torch.enable_grad():
            logits = self.model(inputs)["logits"]
            if target_class is None:
                target_class = logits[0].argmax().item()
            self.model.zero_grad()
            logits[0, target_class].backward()

        gradients = self.gradients
        activations = self.activations
        if gradients is None or activations is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture activations/gradients; "
                "were the hooks removed?"
            )

        # Global average pooling of gradients gives one weight per channel.
        weights = gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)

        # Normalise the first sample's map to [0, 1]; epsilon avoids 0/0.
        cam = cam[0, 0].cpu().numpy()
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)
        return cam
def create_gradcam_overlay(
    original_image: np.ndarray,
    cam: np.ndarray,
    alpha: float = 0.5,
) -> np.ndarray:
    """Blend a JET-coloured Grad-CAM heatmap onto an image.

    Args:
        original_image: Image array whose first two axes are height and width.
            It is blended with a three-channel RGB heatmap, so an RGB image is
            expected.
        cam: 2-D activation map; values are multiplied by 255 before colouring,
            so they should lie in [0, 1].
        alpha: Weight of the heatmap; the image receives ``1 - alpha``.

    Returns:
        ``uint8`` array with the blended result, clipped to [0, 255].
    """
    height, width = original_image.shape[:2]

    # cv2.resize takes the target size as (width, height).
    cam_scaled = cv2.resize(cam, (width, height))

    # OpenCV colour maps produce BGR; convert so the overlay matches RGB input.
    heat_bgr = cv2.applyColorMap(np.uint8(255 * cam_scaled), cv2.COLORMAP_JET)
    heat_rgb = cv2.cvtColor(heat_bgr, cv2.COLOR_BGR2RGB)

    heat_part = np.float32(heat_rgb) * alpha
    image_part = np.float32(original_image) * (1 - alpha)
    blended = heat_part + image_part
    return np.clip(blended, 0, 255).astype(np.uint8)
| # ============================================================================= | |
| # Model Loading | |
| # ============================================================================= | |
# Hugging Face Hub repository that hosts the trained checkpoint.
REPO_ID = "Zorrojurro/brain-tumor-cnn-vit"

# Display names in logit order: index i of the model output maps to CLASS_NAMES[i].
CLASS_NAMES = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]
CLASS_EMOJIS = {
    "Glioma": "🔴",
    "Meningioma": "🟠",
    "No Tumor": "🟢",
    "Pituitary": "🟡",
}

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Image preprocessing: fixed 224x224 resize, tensor conversion, then per-channel
# normalisation with the commonly used ImageNet mean / std values.
TRANSFORM = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
def load_model():
    """Download the checkpoint from the Hugging Face Hub and build the model.

    The checkpoint ``best_model.pth`` is fetched from ``REPO_ID`` into
    ``./model_cache``, a ``HybridCNNViT`` is constructed with the fixed
    configuration below, and the weights are loaded non-strictly. Keys that
    could not be matched are reported on stdout instead of being silently
    ignored, so a partially (or not at all) initialised model is visible in
    the logs.

    Returns:
        The ``HybridCNNViT`` instance in eval mode on ``DEVICE``.
    """
    print("📥 Downloading model from Hugging Face Hub...")
    checkpoint_path = hf_hub_download(
        repo_id=REPO_ID,
        filename="best_model.pth",
        cache_dir="./model_cache",
    )

    # Architecture hyper-parameters; they must match the trained checkpoint.
    config = {
        "data": {"num_classes": 4},
        "model": {
            "cnn_backbone": "resnet50",
            "cnn_pretrained": False,
            "vit_embed_dim": 512,
            "vit_depth": 6,
            "vit_num_heads": 8,
            "vit_mlp_ratio": 4.0,
            "use_radiomics": True,
            "radiomics_features": 128,
            "dropout": 0.3,
        },
    }
    classifier = BrainTumorClassifier(config)
    model = classifier.model

    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code from the downloaded file. Kept because
    # the checkpoint may contain non-tensor entries; only use with a trusted
    # repository, or switch to weights_only=True if the checkpoint allows it.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    state_dict = checkpoint.get("model_state_dict", checkpoint)

    # Strip only a LEADING "model." (keys saved from the wrapper class).
    # str.replace would also remove later occurrences inside the key.
    prefix = "model."
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    # Non-strict load is kept as best effort, but mismatches are reported.
    load_result = model.load_state_dict(new_state_dict, strict=False)
    if load_result.missing_keys:
        print(
            f"⚠️ {len(load_result.missing_keys)} model parameter(s) missing from "
            f"checkpoint (left at initial values), e.g. {load_result.missing_keys[:5]}"
        )
    if load_result.unexpected_keys:
        print(
            f"⚠️ {len(load_result.unexpected_keys)} checkpoint key(s) not used by "
            f"the model, e.g. {load_result.unexpected_keys[:5]}"
        )

    model.eval().to(DEVICE)
    print(f"✅ Model loaded on {DEVICE}")
    return model
# Build the shared singletons once at import time: every request reuses the
# same loaded model and the Grad-CAM helper hooked into it.
MODEL = load_model()
GRADCAM = GradCAM(MODEL)
| # ============================================================================= | |
| # Prediction Function | |
| # ============================================================================= | |
def predict(image: Image.Image):
    """Classify an uploaded MRI image and build its Grad-CAM overlay.

    Args:
        image: PIL image from the UI, or ``None`` when nothing was uploaded.

    Returns:
        A ``(confidences, overlay, summary)`` tuple: a mapping from
        ``"<emoji> <class name>"`` to probability, the heatmap overlay as an
        array, and a Markdown summary. Without an image the first two entries
        are ``None`` and the summary asks for an upload.
    """
    if image is None:
        return None, None, "Please upload an image."

    # Force three channels; the un-resized copy is used for the overlay.
    rgb_image = image.convert("RGB")
    original_np = np.array(rgb_image)
    input_tensor = TRANSFORM(rgb_image).unsqueeze(0).to(DEVICE)

    # Grad-CAM needs a backward pass, so gradients are enabled explicitly.
    with torch.enable_grad():
        cam = GRADCAM.generate(input_tensor)

    # Separate gradient-free pass for the class probabilities.
    with torch.no_grad():
        logits = MODEL(input_tensor)["logits"]
        probs = F.softmax(logits, dim=-1)[0]

    confidences = {
        f"{CLASS_EMOJIS[name]} {name}": float(probs[idx])
        for idx, name in enumerate(CLASS_NAMES)
    }

    gradcam_overlay = create_gradcam_overlay(original_np, cam, alpha=0.45)

    # Headline result for the Markdown panel.
    pred_idx = probs.argmax().item()
    pred_name = CLASS_NAMES[pred_idx]
    pred_conf = probs[pred_idx].item()
    emoji = CLASS_EMOJIS[pred_name]

    header = f"## {emoji} {pred_name}\n**Confidence:** {pred_conf:.1%}\n\n"
    if pred_name == "No Tumor":
        verdict = "✅ No tumor detected in the MRI scan."
    else:
        verdict = f"⚠️ Potential **{pred_name.lower()}** detected. Please consult a medical professional."
    summary = header + verdict

    return confidences, gradcam_overlay, summary
| # ============================================================================= | |
| # Gradio UI | |
| # ============================================================================= | |
# Stylesheet handed to gr.Blocks: page width, primary-button gradient and
# hover effect, and a rule that hides the page footer.
CUSTOM_CSS = """
.gradio-container {
max-width: 1100px !important;
margin: auto !important;
}
.gr-button-primary {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
}
.gr-button-primary:hover {
background: linear-gradient(135deg, #764ba2 0%, #667eea 100%) !important;
transform: translateY(-1px);
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4);
}
footer {visibility: hidden}
"""

# Markdown rendered at the top of the page: title, model facts, disclaimer.
DESCRIPTION = """
# 🧠 Brain Tumor Detection — Hybrid CNN-ViT
Upload a brain MRI scan for instant AI-powered classification with **Grad-CAM explainability**.
**Model Architecture**: ResNet50 (CNN) + 6-Layer Vision Transformer + Learnable Radiomics
**Classes**: Glioma · Meningioma · No Tumor · Pituitary
**Performance**: 98% Accuracy · 0.97 F1-Score · 0.99 AUC
> ⚠️ *For research and educational purposes only. Not a substitute for professional medical diagnosis.*
"""
# Page layout. NOTE(review): the nesting below is reconstructed from the order
# of the components — confirm against the deployed Space.
with gr.Blocks(
    css=CUSTOM_CSS,
    theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="purple", neutral_hue="slate"),
    title="Brain Tumor Detection — CNN-ViT",
) as demo:
    gr.Markdown(DESCRIPTION)

    # Top row: upload and trigger on the left, Grad-CAM heatmap on the right.
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Brain MRI", height=350)
            predict_btn = gr.Button("🔬 Analyze MRI", variant="primary", size="lg")
        with gr.Column(scale=1):
            gradcam_output = gr.Image(label="Grad-CAM Visualization", height=350)

    # Bottom row: per-class confidences next to the textual summary.
    with gr.Row():
        with gr.Column(scale=1):
            label_output = gr.Label(label="Classification Confidence", num_top_classes=4)
        with gr.Column(scale=1):
            summary_output = gr.Markdown(label="Diagnosis Summary")

    # Output order must match the tuple returned by predict().
    predict_btn.click(
        fn=predict,
        inputs=[input_image],
        outputs=[label_output, gradcam_output, summary_output],
    )

    gr.Markdown(
        """
---
**Built by [Vishnu K](https://huggingface.co/ZorroJurro)** ·
[Model Card](https://huggingface.co/ZorroJurro/brain-tumor-cnn-vit) ·
[GitHub](https://github.com/ZorroJurro)
"""
    )

# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()