""" Hugging Face Spaces - Chest X-Ray Medical Diagnosis API Serves both API endpoints and static frontend Uses pre-trained CheXpert DenseNet121 model from HuggingFace """ import os import io import torch import torch.nn as nn from PIL import Image from datetime import datetime, timezone from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, FileResponse from fastapi.staticfiles import StaticFiles from torchvision import models, transforms from safetensors.torch import load_file from huggingface_hub import hf_hub_download from pymongo import MongoClient from pymongo.errors import ConnectionFailure import certifi # Initialize FastAPI app = FastAPI( title="Chest X-Ray Disease Diagnosis API", description="AI-powered chest X-ray analysis for detecting 14 diseases using CheXpert model", version="2.0.0" ) # CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # CheXpert Disease labels (order matters - matches model output) LABELS = [ "No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly", "Lung Opacity", "Lung Lesion", "Edema", "Consolidation", "Pneumonia", "Atelectasis", "Pneumothorax", "Pleural Effusion", "Pleural Other", "Fracture", "Support Devices" ] DISEASE_INFO = { "No Finding": "No abnormality detected in the chest X-ray", "Enlarged Cardiomediastinum": "Widening of the central chest area", "Cardiomegaly": "Enlarged heart", "Lung Opacity": "Cloudy or hazy areas in the lungs", "Lung Lesion": "Abnormal tissue in the lung", "Edema": "Fluid accumulation in the lungs", "Consolidation": "Lung tissue filled with liquid instead of air", "Pneumonia": "Lung infection causing inflammation", "Atelectasis": "Partial or complete lung collapse", "Pneumothorax": "Collapsed lung due to air leak", "Pleural Effusion": "Fluid around the lungs", "Pleural Other": "Other pleural abnormalities", "Fracture": "Bone fracture visible in chest X-ray", "Support Devices": "Medical devices detected (tubes, wires, etc.)" } # Model info MODEL_REPO_ID = "itsomk/chexpert-densenet121" MODEL_FILENAME = "pytorch_model.safetensors" # MongoDB Atlas Configuration MONGODB_URI = os.environ.get( "MONGODB_URI", "mongodb+srv://vadhavanarajanmi:RajanMongo1976@rajanvadhavana.8l1my.mongodb.net/?retryWrites=true&w=majority" ) DATABASE_NAME = "chest_xray_diagnosis" COLLECTION_NAME = "predictions" # MongoDB client (initialized on startup) mongo_client = None db = None predictions_collection = None # PyTorch model model = None device = None preprocess = None class DenseNet121_CheXpert(nn.Module): """DenseNet121 model for CheXpert multi-label classification""" def __init__(self, num_labels=14, pretrained=False): super().__init__() self.densenet = models.densenet121(pretrained=pretrained) num_features = self.densenet.classifier.in_features self.densenet.classifier = nn.Linear(num_features, num_labels) def forward(self, x): return self.densenet(x) def load_model(): """Download and load the pre-trained CheXpert model from HuggingFace""" global model, device, preprocess print("Downloading CheXpert model from HuggingFace...") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # Download model weights local_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME) print(f"Model downloaded to: {local_path}") # Load model model = DenseNet121_CheXpert(num_labels=14, pretrained=False) state_dict = load_file(local_path) model.load_state_dict(state_dict, strict=False) model = model.to(device) model.eval() # Define preprocessing (ImageNet normalization) preprocess = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) print("✅ CheXpert DenseNet121 model loaded successfully!") return model def preprocess_image(image_bytes): """Preprocess image for PyTorch model""" img = Image.open(io.BytesIO(image_bytes)).convert("RGB") img_tensor = preprocess(img).unsqueeze(0) # Add batch dimension return img_tensor.to(device) @app.on_event("startup") async def startup_event(): """Load model and connect to MongoDB on startup""" global model, mongo_client, db, predictions_collection # Initialize MongoDB connection try: mongo_client = MongoClient(MONGODB_URI, tlsCAFile=certifi.where(), serverSelectionTimeoutMS=10000) mongo_client.admin.command('ping') db = mongo_client[DATABASE_NAME] predictions_collection = db[COLLECTION_NAME] print("✅ Connected to MongoDB Atlas successfully!") except ConnectionFailure as e: print(f"⚠️ MongoDB connection failed: {e}") print("Predictions will not be stored in the cloud.") except Exception as e: print(f"⚠️ MongoDB error: {e}") # Load ML model from HuggingFace model = load_model() @app.get("/") async def root(): """Serve frontend""" return FileResponse("static/index.html") @app.get("/api") async def api_info(): """API health check""" return { "status": "healthy", "message": "Chest X-Ray Diagnosis API", "version": "2.0.0", "model": "CheXpert DenseNet121 (HuggingFace)", "model_repo": MODEL_REPO_ID, "diseases_detected": len(LABELS), "device": str(device) if device else "not loaded" } @app.get("/diseases") async def get_diseases(): """Get list of detectable diseases""" return { "count": len(LABELS), "diseases": [{"name": l, "description": DISEASE_INFO.get(l, "")} for l in LABELS] } @app.post("/predict") async def predict(file: UploadFile = File(...)): """Predict diseases from chest X-ray""" global model if not file.content_type or not file.content_type.startswith('image/'): raise HTTPException(status_code=400, detail="File must be an image") try: image_bytes = await file.read() img_tensor = preprocess_image(image_bytes) # Run inference with torch.no_grad(): logits = model(img_tensor) probs = torch.sigmoid(logits).squeeze().cpu().numpy() results = [] detected_diseases = [] for i, label in enumerate(LABELS): prob = float(probs[i]) # Use 0.5 threshold for detection (standard for multi-label) is_detected = prob > 0.5 result = { "disease": label, "description": DISEASE_INFO.get(label, ""), "probability": round(prob * 100, 2), "detected": is_detected, "severity": "High" if prob > 0.7 else "Medium" if prob > 0.5 else "Low" } results.append(result) if is_detected and label != "No Finding": detected_diseases.append(label) # Sort by probability (highest first) results.sort(key=lambda x: x["probability"], reverse=True) response_data = { "success": True, "filename": file.filename, "model": "CheXpert DenseNet121", "total_diseases_checked": len(LABELS), "diseases_detected": len(detected_diseases), "detected_list": detected_diseases, "predictions": results, "disclaimer": "AI-assisted tool. Consult a medical professional for diagnosis." } # Store prediction in MongoDB Atlas if predictions_collection is not None: try: mongo_doc = { "filename": file.filename, "timestamp": datetime.now(timezone.utc), "model": "CheXpert DenseNet121", "diseases_detected_count": len(detected_diseases), "detected_diseases": detected_diseases, "predictions": results, "total_diseases_checked": len(LABELS) } insert_result = predictions_collection.insert_one(mongo_doc) response_data["mongodb_id"] = str(insert_result.inserted_id) response_data["stored_in_cloud"] = True print(f"✅ Prediction stored in MongoDB: {insert_result.inserted_id}") except Exception as db_error: print(f"⚠️ Failed to store in MongoDB: {db_error}") response_data["stored_in_cloud"] = False else: response_data["stored_in_cloud"] = False return JSONResponse(content=response_data) except Exception as e: raise HTTPException(status_code=500, detail=f"Error: {str(e)}") @app.get("/predictions/history") async def get_prediction_history(limit: int = 20): """Get prediction history from MongoDB Atlas""" if predictions_collection is None: raise HTTPException(status_code=503, detail="MongoDB not connected") try: cursor = predictions_collection.find().sort("timestamp", -1).limit(limit) history = [] for doc in cursor: doc["_id"] = str(doc["_id"]) doc["timestamp"] = doc["timestamp"].isoformat() if doc.get("timestamp") else None history.append(doc) return { "success": True, "count": len(history), "predictions": history } except Exception as e: raise HTTPException(status_code=500, detail=f"Error fetching history: {str(e)}") @app.get("/predictions/stats") async def get_prediction_stats(): """Get statistics from stored predictions""" if predictions_collection is None: raise HTTPException(status_code=503, detail="MongoDB not connected") try: total_predictions = predictions_collection.count_documents({}) pipeline = [ {"$unwind": "$detected_diseases"}, {"$group": {"_id": "$detected_diseases", "count": {"$sum": 1}}}, {"$sort": {"count": -1}} ] disease_stats = list(predictions_collection.aggregate(pipeline)) return { "success": True, "total_predictions": total_predictions, "disease_frequency": disease_stats, "database": DATABASE_NAME, "collection": COLLECTION_NAME } except Exception as e: raise HTTPException(status_code=500, detail=f"Error fetching stats: {str(e)}") # Mount static files app.mount("/static", StaticFiles(directory="static"), name="static") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)