#!/usr/bin/env python3
"""
RetinaSense v4 -- Gradio Demo App
===================================
Interactive web UI for retinal disease classification with:
  - 5-class disease prediction with confidence scores
  - GradCAM attention heatmap overlay
  - Similar case retrieval via FAISS
  - MC Dropout uncertainty estimation

Usage:
    python app.py
    python app.py --share          # public link
    python app.py --model-path outputs_v4/lesion_attention/best_model.pth
"""

import argparse
import json
import os
import sys
import warnings

import cv2
import numpy as np

# Installed before torch is imported so warnings emitted during that import
# are silenced as well.
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------------------------------------------------------
# Make the project-local ``models`` package importable when the script is
# started from another working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
if SCRIPT_DIR not in sys.path:
    sys.path.insert(0, SCRIPT_DIR)

from models.hybrid_retina_model import HybridRetinaModel

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
NUM_CLASSES = 5
IMG_SIZE = 224
CLASS_NAMES = ["Normal", "Diabetes/DR", "Glaucoma", "Cataract", "AMD"]
NORM_MEAN = [0.4298, 0.2784, 0.1559]
NORM_STD = [0.2857, 0.2065, 0.1465]

# Default value of --model-path; main() compares against it to decide whether
# the user explicitly chose a checkpoint.
DEFAULT_MODEL_PATH = "outputs_v4/best_model.pth"
# Number of stochastic forward passes used by predict(); also shown in the UI.
MC_DROPOUT_PASSES = 20
# Tokens skipped at the start of the ViT sequence before the patch tokens
# (the original code slices ``[:, 1:, :]``, i.e. one leading token).
VIT_PATCH_OFFSET = 1

# ---------------------------------------------------------------------------
# Global state (loaded once at startup)
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = None
faiss_index = None
faiss_metadata = None
temperature = 1.0
thresholds = None


def load_model(model_path: str):
    """Load the hybrid model checkpoint and its calibration files.

    Populates the module globals ``model``, ``temperature`` and
    ``thresholds``.  When ``model_path`` does not exist the model keeps its
    freshly initialised weights and a warning is printed.

    Args:
        model_path: Path to a checkpoint that is either a state dict or a
            dict holding one under ``"model_state_dict"``.
    """
    global model, temperature, thresholds

    model = HybridRetinaModel(pretrained=False)
    if os.path.exists(model_path):
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects. Only point this at checkpoints from a trusted source.
        ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
        state = ckpt.get("model_state_dict", ckpt)
        # strict=False tolerates key mismatches; report them so a badly
        # matching checkpoint does not silently run with random layers.
        load_result = model.load_state_dict(state, strict=False)
        print(f"Loaded model from {model_path}")
        if load_result.missing_keys:
            print(
                f"WARNING: {len(load_result.missing_keys)} model keys missing "
                f"from checkpoint (left at initial values), "
                f"e.g. {load_result.missing_keys[:3]}"
            )
        if load_result.unexpected_keys:
            print(
                f"WARNING: {len(load_result.unexpected_keys)} unexpected "
                f"checkpoint keys ignored, "
                f"e.g. {load_result.unexpected_keys[:3]}"
            )
    else:
        print(f"WARNING: {model_path} not found, using random weights")

    model = model.to(device).eval()

    # Load temperature scaling
    temp_path = os.path.join(SCRIPT_DIR, "outputs_v4", "temperature.json")
    if os.path.exists(temp_path):
        with open(temp_path) as f:
            temperature = json.load(f).get("temperature", 1.0)
        print(f"Temperature scaling: {temperature:.4f}")

    # Load optimized thresholds
    # NOTE(review): ``thresholds`` is stored but not read anywhere in this
    # file -- confirm whether another module consumes it.
    thresh_path = os.path.join(SCRIPT_DIR, "outputs_v4", "thresholds.json")
    if os.path.exists(thresh_path):
        with open(thresh_path) as f:
            thresholds = json.load(f)
        print("Loaded per-class thresholds")


def load_faiss_index():
    """Load the FAISS index and its metadata for similar-case retrieval.

    Populates the globals ``faiss_index`` and ``faiss_metadata``.  Retrieval
    stays disabled (both remain ``None``) when ``faiss`` is not installed or
    either file is missing.
    """
    global faiss_index, faiss_metadata

    try:
        import faiss
    except ImportError:
        print("faiss not installed, retrieval disabled")
        return

    retrieval_dir = os.path.join(SCRIPT_DIR, "outputs_v4", "retrieval")
    index_path = os.path.join(retrieval_dir, "index_flat_l2.faiss")
    meta_path = os.path.join(retrieval_dir, "metadata.json")

    if os.path.exists(index_path) and os.path.exists(meta_path):
        faiss_index = faiss.read_index(index_path)
        with open(meta_path) as f:
            faiss_metadata = json.load(f)
        print(f"FAISS index loaded: {faiss_index.ntotal} vectors")
    else:
        print("FAISS index not found, retrieval disabled")


# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
def preprocess_image(image_rgb: np.ndarray):
    """Preprocess a raw RGB image for model input.

    Applies: resize 224 -> CLAHE on L-channel -> circular mask.
    Grayscale (2-D or single-channel) and RGBA inputs are first converted to
    3-channel RGB so the colour-space conversion below does not fail.

    Args:
        image_rgb: Input image array; the CLAHE step requires an 8-bit image.

    Returns:
        (tensor [1,3,224,224], display_image [224,224,3]).
    """
    from torchvision import transforms

    img = np.asarray(image_rgb)
    if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.ndim == 3 and img.shape[2] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

    # Resize to 224x224 (returns a new array, so the caller's image is not
    # modified by the in-place steps below).
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)

    # CLAHE on L-channel
    lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    lab[:, :, 0] = clahe.apply(lab[:, :, 0])
    img = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

    # Circular mask: zero every pixel outside the inscribed circle.
    h, w = img.shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)
    cv2.circle(mask, (w // 2, h // 2), min(h, w) // 2, 255, -1)
    img[mask == 0] = 0

    display_img = img.copy()

    # To tensor + normalize
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ])
    tensor = transform(img).unsqueeze(0)  # (1, 3, 224, 224)

    return tensor, display_img


# ---------------------------------------------------------------------------
# GradCAM
# ---------------------------------------------------------------------------
def compute_gradcam(input_tensor: torch.Tensor, target_class: int) -> np.ndarray:
    """Compute GradCAM heatmap for the given class.

    Hooks ``model.vit.blocks[-1].norm1`` to capture its output and the
    gradient of the class score with respect to that output.

    Args:
        input_tensor: Preprocessed image batch of size one. It is not
            modified; the computation runs on a detached copy.
        target_class: Index of the logit to explain.

    Returns:
        (224, 224) array in [0, 1]; all zeros when the map is constant.
    """
    activations = {}
    gradients = {}

    target_layer = model.vit.blocks[-1].norm1

    def fwd_hook(module, inp, out):
        activations["val"] = out

    def bwd_hook(module, grad_in, grad_out):
        gradients["val"] = grad_out[0]

    h_fwd = target_layer.register_forward_hook(fwd_hook)
    h_bwd = target_layer.register_full_backward_hook(bwd_hook)

    try:
        model.zero_grad()
        # Tensor.to() returns the very same tensor when it is already on the
        # target device, so clone first: requires_grad_ must not leak onto
        # the caller's tensor.
        grad_input = input_tensor.detach().clone().to(device).requires_grad_(True)
        logits = model(grad_input)
        score = logits[0, target_class]
        score.backward()
    finally:
        # Always detach the hooks, otherwise a failed pass would leave them
        # registered on the shared global model.
        h_fwd.remove()
        h_bwd.remove()
        # Parameter gradients are not needed after this point; release them.
        model.zero_grad(set_to_none=True)

    # Drop the leading token(s); the rest are treated as patch tokens
    # (presumably (1, 196, 768) for ViT-Base/16 at 224px -- TODO confirm).
    act = activations["val"][:, VIT_PATCH_OFFSET:, :]
    grad = gradients["val"][:, VIT_PATCH_OFFSET:, :]

    weights = grad.mean(dim=1, keepdim=True)
    cam = (act * weights).sum(dim=-1)
    cam = F.relu(cam)

    # Arrange the patch tokens on a square grid (14x14 for 196 tokens).
    grid = int(round(cam.shape[1] ** 0.5))
    cam = cam.reshape(1, 1, grid, grid)
    cam = F.interpolate(cam, size=(IMG_SIZE, IMG_SIZE), mode="bilinear", align_corners=False)
    cam = cam.squeeze().detach().cpu().numpy()

    # Min-max normalise to [0, 1]; a flat map carries no information.
    cam_min, cam_max = cam.min(), cam.max()
    if cam_max - cam_min > 1e-8:
        cam = (cam - cam_min) / (cam_max - cam_min)
    else:
        cam = np.zeros_like(cam)

    return cam


def overlay_heatmap(image: np.ndarray, heatmap: np.ndarray, alpha: float = 0.4) -> np.ndarray:
    """Overlay a heatmap on an image.

    Args:
        image: RGB image with the same height/width as ``heatmap``.
        heatmap: 2-D map, expected in [0, 1]; values outside are clipped so
            they cannot wrap around in the uint8 conversion.
        alpha: Blend weight of the coloured heatmap.

    Returns:
        uint8 RGB image ``image * (1 - alpha) + jet(heatmap) * alpha``.
    """
    heatmap_u8 = (np.clip(heatmap, 0.0, 1.0) * 255).astype(np.uint8)
    heatmap_colored = cv2.applyColorMap(heatmap_u8, cv2.COLORMAP_JET)
    # OpenCV colour maps are BGR; the rest of the pipeline is RGB.
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
    overlay = (image.astype(float) * (1 - alpha) + heatmap_colored.astype(float) * alpha).astype(np.uint8)
    return overlay


# ---------------------------------------------------------------------------
# MC Dropout Uncertainty
# ---------------------------------------------------------------------------
@torch.no_grad()
def mc_dropout_predict(input_tensor: torch.Tensor, n_passes: int = 20) -> dict:
    """Run MC Dropout for uncertainty estimation.

    Puts every ``nn.Dropout`` module in train mode (everything else stays in
    eval mode), runs ``n_passes`` forward passes and aggregates the
    temperature-scaled softmax outputs.  The model is returned to eval mode
    afterwards, even if a pass raises.

    Args:
        input_tensor: Preprocessed image batch of size one.
        n_passes: Number of stochastic forward passes (must be >= 1).

    Returns:
        Dict with ``mean_probs``, ``std_probs``, ``predicted_class``,
        ``predicted_name``, ``confidence`` (mean probability of the predicted
        class) and ``uncertainty`` (entropy of the mean distribution divided
        by ``log(NUM_CLASSES)``).

    Raises:
        ValueError: If ``n_passes`` is smaller than 1.
    """
    if n_passes < 1:
        raise ValueError(f"n_passes must be >= 1, got {n_passes}")

    input_tensor = input_tensor.to(device)

    def enable_dropout(m):
        if isinstance(m, nn.Dropout):
            m.train()

    model.eval()
    model.apply(enable_dropout)

    all_probs = []
    try:
        for _ in range(n_passes):
            logits = model(input_tensor)
            if temperature != 1.0:
                logits = logits / temperature
            probs = F.softmax(logits, dim=1)
            all_probs.append(probs.cpu().numpy())
    finally:
        # Never leave dropout active on the shared model (GradCAM and
        # retrieval run on it right after this function).
        model.eval()

    all_probs = np.array(all_probs)  # (n_passes, 1, 5)
    mean_probs = all_probs.mean(axis=0)[0]
    std_probs = all_probs.std(axis=0)[0]

    pred_class = int(mean_probs.argmax())
    # Epsilon keeps log() finite for zero probabilities.
    entropy = -np.sum(mean_probs * np.log(mean_probs + 1e-10))
    max_entropy = np.log(NUM_CLASSES)
    uncertainty = entropy / max_entropy

    return {
        "mean_probs": mean_probs,
        "std_probs": std_probs,
        "predicted_class": pred_class,
        "predicted_name": CLASS_NAMES[pred_class],
        "confidence": float(mean_probs[pred_class]),
        "uncertainty": float(uncertainty),
    }


# ---------------------------------------------------------------------------
# FAISS Retrieval
# ---------------------------------------------------------------------------
@torch.no_grad()
def retrieve_similar(input_tensor: torch.Tensor, k: int = 5) -> list:
    """Retrieve top-k similar cases from the FAISS index.

    The query embedding is the (first) output of ``model.vit`` after L2
    normalisation.

    Args:
        input_tensor: Preprocessed image batch of size one.
        k: Number of neighbours requested from the index.

    Returns:
        List of dicts with ``rank``, ``class_name``, ``label``,
        ``similarity`` (percent), ``distance`` and ``image`` (array loaded
        from the entry's ``cache_path`` or ``None``).  Empty when retrieval
        is disabled.
    """
    if faiss_index is None or faiss_metadata is None:
        return []

    import faiss as faiss_lib

    input_tensor = input_tensor.to(device)
    embedding = model.vit(input_tensor)
    if isinstance(embedding, (tuple, list)):
        embedding = embedding[0]
    embedding = embedding.cpu().numpy().astype(np.float32)
    faiss_lib.normalize_L2(embedding)

    distances, indices = faiss_index.search(embedding, k)
    distances = distances[0]
    indices = indices[0]

    results = []
    for rank, (dist, idx) in enumerate(zip(distances, indices), 1):
        # Skip ids that have no metadata entry (negative or out of range).
        if idx < 0 or idx >= len(faiss_metadata):
            continue
        meta = faiss_metadata[int(idx)]
        # Maps distance 0 -> 1.0 and 4 -> 0.0 (presumably the squared-L2
        # range for unit vectors -- confirm the index stores normalised
        # embeddings).
        similarity = max(0.0, 1.0 - dist / 4.0)

        img = None
        cache_path = meta.get("cache_path", "")
        if cache_path and os.path.exists(cache_path):
            try:
                img = np.load(cache_path)
            except Exception as exc:
                # Best effort: a broken cache file must not fail the whole
                # request, but it should not go unnoticed either.
                print(f"WARNING: could not load cached image {cache_path}: {exc}")

        results.append({
            "rank": rank,
            "class_name": meta["class_name"],
            "label": meta["label"],
            "similarity": round(similarity * 100, 1),
            "distance": round(float(dist), 4),
            "image": img,
        })

    return results


# ---------------------------------------------------------------------------
# Main Prediction Function
# ---------------------------------------------------------------------------
def predict(image: np.ndarray):
    """Main prediction pipeline called by Gradio.

    Args:
        image: Uploaded image as a numpy array, or ``None`` when nothing was
            uploaded.

    Returns:
        Tuple ``(confidences, gradcam_overlay, uncertainty_text,
        gallery_images)`` matching the four output components wired up in
        ``create_demo``.
    """
    if image is None:
        return {}, None, "No image uploaded", []

    # Preprocess
    input_tensor, display_img = preprocess_image(image)

    # MC Dropout prediction
    mc_result = mc_dropout_predict(input_tensor, n_passes=MC_DROPOUT_PASSES)
    pred_class = mc_result["predicted_class"]

    # Confidence scores for gr.Label
    confidences = {
        CLASS_NAMES[i]: float(mc_result["mean_probs"][i])
        for i in range(NUM_CLASSES)
    }

    # GradCAM
    heatmap = compute_gradcam(input_tensor, pred_class)
    gradcam_overlay = overlay_heatmap(display_img, heatmap, alpha=0.4)

    # Uncertainty text
    conf = mc_result["confidence"] * 100
    unc = mc_result["uncertainty"] * 100
    std = mc_result["std_probs"][pred_class] * 100

    if unc < 30:
        unc_level = "LOW (reliable prediction)"
    elif unc < 60:
        unc_level = "MODERATE (clinical review recommended)"
    else:
        unc_level = "HIGH (unreliable -- refer to specialist)"

    uncertainty_text = (
        f"Prediction: {mc_result['predicted_name']}\n"
        f"Confidence: {conf:.1f}%\n"
        f"Std Dev: +/- {std:.1f}%\n"
        f"Uncertainty: {unc:.1f}% -- {unc_level}\n"
        f"(Based on {MC_DROPOUT_PASSES} MC Dropout forward passes)"
    )

    # Retrieval: only entries whose cached image could be loaded are shown.
    retrieval_results = retrieve_similar(input_tensor, k=5)
    gallery_images = []
    for r in retrieval_results:
        if r["image"] is not None:
            caption = f"#{r['rank']} {r['class_name']} ({r['similarity']}% similar)"
            gallery_images.append((r["image"], caption))

    return confidences, gradcam_overlay, uncertainty_text, gallery_images


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
def create_demo():
    """Build the Gradio interface.

    Returns:
        The ``gr.Blocks`` app with the "Analyze" button wired to ``predict``.
    """
    import gradio as gr

    with gr.Blocks(
        title="RetinaSense v4 -- Retinal Disease Classifier",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown(
            """
            # RetinaSense v4 -- Retinal Disease Classification
            **Hybrid ViT + EfficientNet-B3 model** for 5-class retinal disease detection.

            Upload a retinal fundus image to get:
            - Disease classification with confidence scores
            - GradCAM attention heatmap showing where the model focuses
            - Uncertainty estimation via MC Dropout
            - Similar case retrieval from the training database

            **Classes**: Normal | Diabetes/DR | Glaucoma | Cataract | AMD |
            **Performance**: 91.1% accuracy (5-fold CV), 0.986 macro AUC
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="Upload Retinal Fundus Image",
                    type="numpy",
                    height=300,
                )
                predict_btn = gr.Button("Analyze", variant="primary", size="lg")

                gr.Markdown(
                    "*Upload any retinal fundus photograph (JPEG/PNG). "
                    "The model expects standard color fundus images.*"
                )

            with gr.Column(scale=1):
                label_output = gr.Label(
                    label="Classification Confidence",
                    num_top_classes=5,
                )
                gradcam_output = gr.Image(
                    label="GradCAM Attention Map",
                    height=300,
                )

        with gr.Row():
            with gr.Column(scale=1):
                uncertainty_output = gr.Textbox(
                    label="Uncertainty Analysis (MC Dropout)",
                    lines=5,
                    interactive=False,
                )
            with gr.Column(scale=1):
                gallery_output = gr.Gallery(
                    label="Similar Cases from Training Database",
                    columns=5,
                    height=250,
                )

        # Output order must match the tuple returned by predict().
        predict_btn.click(
            fn=predict,
            inputs=[input_image],
            outputs=[label_output, gradcam_output, uncertainty_output, gallery_output],
        )

        gr.Markdown(
            """
            ---
            **Model**: HybridRetinaModel (ViT-Base/16 + EfficientNet-B3, 97.8M params) |
            **Dataset**: APTOS-2019 + ODIR-5K (10,000 balanced images) |
            **Training**: Focal Loss + MixUp + LLRD + SWA + GradCAM Attention

            *This tool is for research and educational purposes only.
            Not intended for clinical diagnosis.*
            """
        )

    return demo


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main():
    """Parse CLI arguments, load the model and index, and launch the UI."""
    parser = argparse.ArgumentParser(description="RetinaSense v4 Gradio Demo")
    parser.add_argument(
        "--model-path", type=str,
        default=DEFAULT_MODEL_PATH,
        help="Path to model checkpoint",
    )
    parser.add_argument("--share", action="store_true", help="Create public share link")
    parser.add_argument("--port", type=int, default=7860, help="Server port")
    args = parser.parse_args()

    # Resolve path (relative paths are taken relative to this script)
    model_path = args.model_path
    if not os.path.isabs(model_path):
        model_path = os.path.join(SCRIPT_DIR, model_path)

    # Check for lesion attention model (preferred if available); only when
    # the user left --model-path at its default.
    lesion_path = os.path.join(SCRIPT_DIR, "outputs_v4", "lesion_attention", "best_model.pth")
    if os.path.exists(lesion_path) and args.model_path == DEFAULT_MODEL_PATH:
        print(f"Found lesion attention model, using: {lesion_path}")
        model_path = lesion_path

    print("=" * 60)
    print("  RetinaSense v4 -- Gradio Demo")
    print("=" * 60)
    print(f"  Device: {device}")
    if torch.cuda.is_available():
        print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  Model: {model_path}")
    print("=" * 60)

    load_model(model_path)
    load_faiss_index()

    demo = create_demo()
    # 0.0.0.0 binds all interfaces, so the demo is reachable from other hosts.
    demo.launch(
        server_name="0.0.0.0",
        server_port=args.port,
        share=args.share,
    )


if __name__ == "__main__":
    main()