import os
import time
import cv2
import numpy as np

# ==========================================
# GGUF ENGINE — replaces HuggingFace + PEFT
# ==========================================
from gguf_engine import (
    generate_with_adapter,
    generate_with_adapter_vision,
    exclude_thinking_component,
)

import re
import json
import matplotlib
matplotlib.use("Agg")  # must be selected before pyplot is imported
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
from tqdm import tqdm
from typing import TypedDict, List, Optional
from langgraph.graph import StateGraph, END
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

DEVICE = 'cpu'

# ==========================================
# 0. EMBEDDING MODEL (shared, loaded once)
# ==========================================
print("Initializing Embedding Model...")
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={'device': DEVICE}
)

# ==========================================
# DEFAULT FALLBACK GUIDELINES
# Used when no admin KB has been configured.
# ==========================================
_DEFAULT_GUIDELINE_DOCS = [
    Document(
        page_content="NCCN Multiple Myeloma Guidelines: For transplant-eligible patients with high-risk cytogenetics and normal renal function, standard VRd is recommended.",
        metadata={"source": "Built-in NCCN excerpt", "page": 1}
    ),
    Document(
        page_content="NCCN Multiple Myeloma Guidelines: For transplant-eligible patients presenting with severe renal impairment (Creatinine > 2.0), VRd is still preferred, but Bortezomib requires careful dose adjustment (typically 1.3 mg/m2 on day 1, 4, 8, 11) to prevent toxicity.",
        metadata={"source": "Built-in NCCN excerpt", "page": 1}
    ),
    Document(
        page_content="ESMO Multiple Myeloma Guidelines: Daratumumab plus VTd is a category 1 recommendation for transplant-eligible patients regardless of renal status.",
        metadata={"source": "Built-in ESMO excerpt", "page": 1}
    ),
    Document(
        page_content="Acute Lymphoblastic Leukemia Guidelines: First-line induction for Philadelphia chromosome-positive ALL includes targeted TKIs like Imatinib combined with Hyper-CVAD.",
        metadata={"source": "Built-in ALL excerpt", "page": 1}
    ),
]

print("Building default fallback FAISS Vector Database...")
_default_vector_db = FAISS.from_documents(_DEFAULT_GUIDELINE_DOCS, embeddings)
retriever = _default_vector_db.as_retriever(search_kwargs={"k": 3})
print("RAG setup complete. GGUF models will be lazy-loaded on first inference call.")

# ==========================================
# VECTOR STORE PATH (for persisted admin KB)
# ==========================================
VECTOR_STORE_PATH = "./local_vector_store"


def load_persisted_retriever():
    """
    Load the admin-configured FAISS index from disk if it exists.

    Returns:
        (retriever, source_label) when VECTOR_STORE_PATH exists and loads
        successfully; (None, None) when the path is missing or loading fails.
        Load failures are printed, not raised, so the caller can fall back to
        the built-in defaults.
    """
    if os.path.exists(VECTOR_STORE_PATH):
        try:
            # SECURITY NOTE: allow_dangerous_deserialization=True makes FAISS
            # unpickle the stored docstore. Only safe while VECTOR_STORE_PATH
            # is written exclusively by this application (see
            # save_retriever_from_pdf) and is not writable by untrusted users.
            db = FAISS.load_local(
                VECTOR_STORE_PATH,
                embeddings,
                allow_dangerous_deserialization=True
            )
            print(f" [RAG] Loaded persisted knowledge base from {VECTOR_STORE_PATH}")
            return db.as_retriever(search_kwargs={"k": 4}), "Hospital Knowledge Base (admin-configured)"
        except Exception as e:
            # Deliberate best-effort: a corrupt/incompatible index must not
            # take down the RAG node.
            print(f" [RAG] Failed to load persisted KB: {e} — using defaults")
    return None, None


# ==========================================
# ADMIN: PDF → FAISS (saves to disk)
# Called from the admin tab in cpu_app.py.
# ==========================================
def _chunk_page_text(page_text: str, chunk_size: int = 1000, overlap: int = 200,
                     min_chars: int = 50) -> List[str]:
    """Split one page of text into overlapping, stripped chunks.

    Windows of ``chunk_size`` characters advance by ``chunk_size - overlap``.
    Chunks whose stripped length is not greater than ``min_chars`` are dropped.
    Iteration stops as soon as a window reaches the end of the text, so no
    trailing chunk is emitted that is wholly contained in the previous one.
    """
    pieces = []
    text_len = len(page_text)
    step = chunk_size - overlap
    start = 0
    while start < text_len:
        end = min(start + chunk_size, text_len)
        piece = page_text[start:end].strip()
        if len(piece) > min_chars:
            pieces.append(piece)
        if end >= text_len:
            # The window already covers the tail of the page; stepping again
            # would only re-emit a suffix of this chunk.
            break
        start += step
    return pieces


def save_retriever_from_pdf(pdf_path: str):
    """
    Build and persist the admin knowledge base from a PDF.

    Parses the PDF page-by-page with pdfplumber, chunks each page with overlap,
    attaches {"source", "page"} metadata to every chunk, builds a FAISS index
    and saves it to VECTOR_STORE_PATH.

    Returns:
        (success: bool, message: str). Any exception (including a missing
        pdfplumber install) is reported through the message, never raised.
    """
    try:
        # Imported lazily so a missing pdfplumber is reported as a failed
        # upload instead of breaking module import.
        import pdfplumber

        source_name = os.path.basename(pdf_path)
        chunks = []

        with pdfplumber.open(pdf_path) as pdf:
            total_pages = len(pdf.pages)
            for page_num, page in enumerate(pdf.pages, start=1):
                page_text = page.extract_text() or ""
                if not page_text.strip():
                    continue
                chunks.extend(
                    Document(
                        page_content=piece,
                        metadata={"source": source_name, "page": page_num}
                    )
                    for piece in _chunk_page_text(page_text, chunk_size=1000, overlap=200)
                )

        if not chunks:
            return False, "No extractable text found in PDF. The document may be scanned/image-only."

        pdf_db = FAISS.from_documents(chunks, embeddings)
        os.makedirs(VECTOR_STORE_PATH, exist_ok=True)
        pdf_db.save_local(VECTOR_STORE_PATH)

        msg = (
            f"✅ Knowledge base saved successfully.\n"
            f" Document : {source_name}\n"
            f" Pages : {total_pages}\n"
            f" Chunks : {len(chunks)}\n"
            f" Saved to : {VECTOR_STORE_PATH}"
        )
        print(f" [Admin] {msg}")
        return True, msg

    except Exception as e:
        return False, f"❌ Failed to build knowledge base: {e}"


# ==========================================
# WSI PATCH ENGINE (Module 3 visual path)
# ==========================================
# Border colour drawn around each patch, keyed by predicted class.
_WSI_COLOR_MAP = {
    "Malignant": "red",
    "Normal": "green",
    "Background": "gray",
    "Unknown": "yellow",
}

_WSI_PATCH_PROMPT = (
    "Analyze this Bone Marrow Biopsy patch. "
    "Does it contain any plasma cells indicative of Multiple Myeloma? "
    "Answer with one word: Malignant, Normal, or Background."
)


def _classify_patch(text: str) -> str:
    """Map free-form model output to one of the keys of _WSI_COLOR_MAP.

    Matching is case-insensitive substring search, checked in the order
    Malignant → Normal → Background; anything else is "Unknown".
    """
    t = text.lower()
    if "malignant" in t:
        return "Malignant"
    if "normal" in t:
        return "Normal"
    if "background" in t or "no plasma" in t:
        return "Background"
    return "Unknown"


def _is_background_patch(
    patch_np: np.ndarray,
    *,
    tissue_sat_thresh: int = 40,  # HSV S channel: mean below this means "background"
) -> tuple[bool, float]:
    """
    Determine whether *patch_np* is a background (non-tissue) patch.

    The patch is converted RGB → HSV and the mean of the saturation channel is
    compared against ``tissue_sat_thresh``.

    Args:
        patch_np: patch pixels; the caller passes slices of an RGB image array.
        tissue_sat_thresh: mean saturation below which the patch is treated as
            background.

    Returns:
        (is_background, sat_mean) — the decision and the mean saturation that
        produced it (used by the caller for logging).
    """
    hsv = cv2.cvtColor(patch_np, cv2.COLOR_RGB2HSV)
    sat_mean = float(hsv[:, :, 1].mean())

    # Single signal: low average saturation → no stained tissue in the patch.
    is_background = sat_mean < tissue_sat_thresh
    return is_background, sat_mean


def process_and_annotate_wsi(
    wsi_path: str,
    patch_size: int = 448,
    border_width: int = 8,
    output_path: str = "annotated_wsi_output.png",
) -> tuple:
    """
    Classify a whole-slide image patch-by-patch and save annotated outputs.

    The image is cut into a grid of ``patch_size`` squares (partial patches at
    the right/bottom edge are discarded). Low-saturation patches are labelled
    "Background" without calling the model; every other patch is sent to
    ``generate_with_adapter_vision`` (adapter "module3") and the reply is
    mapped to a class by ``_classify_patch``. Each patch gets a coloured
    border (see _WSI_COLOR_MAP) and is pasted onto an output canvas.

    Side effects:
        Writes the annotated canvas to ``output_path`` and a titled matplotlib
        figure next to it (``<name>_figure<ext>``).

    Returns:
        (canvas, n_malignant, total_patches, summary_text, output_path)
    """
    print(f"\n [WSI Engine] Opening: {wsi_path}")
    # convert() returns a fully loaded copy, so the source file handle can be
    # released immediately.
    with Image.open(wsi_path) as src_image:
        wsi_image = src_image.convert("RGB")
    W, H = wsi_image.size

    cols = W // patch_size
    rows = H // patch_size
    total = cols * rows
    # NOTE(review): an image smaller than one patch yields total == 0; the
    # percentage maths below tolerates that, but saving/plotting an empty
    # canvas has not been verified.
    print(f" [WSI Engine] {W}x{H}px | Grid {cols}x{rows} = {total} patches")

    canvas = Image.new("RGB", (cols * patch_size, rows * patch_size))
    counts = {"Malignant": 0, "Normal": 0, "Background": 0, "Unknown": 0}
    _wsi_start = time.time()

    # Pre-convert full image to numpy once — avoids repeated PIL↔numpy round-trips
    wsi_np = np.array(wsi_image)

    model_calls = 0   # how many patches actually hit the vision model
    skipped_bg = 0    # how many patches the saturation pre-filter short-circuited
    patches_done = 0

    with tqdm(total=total, desc=" WSI Patches") as pbar:
        for r in range(rows):
            for c in range(cols):
                left, upper = c * patch_size, r * patch_size
                right, lower = left + patch_size, upper + patch_size

                # Slice numpy array (a view, no copy) for OpenCV analysis
                patch_np = wsi_np[upper:lower, left:right]
                patch = Image.fromarray(patch_np)

                is_bg, signals = _is_background_patch(patch_np)

                if is_bg:
                    # ── Fast path: skip the model entirely ────────────────
                    pred = "Background"
                    skipped_bg += 1
                    print("SKIPPED BY CV.............")
                else:
                    # ── Slow path: send to the vision model ───────────────
                    gen_text = generate_with_adapter_vision(
                        image=patch,
                        prompt=_WSI_PATCH_PROMPT,
                        adapter_name="module3",
                        max_tokens=20,
                    )
                    print(f" [WSI] patch({r},{c}) model→ {gen_text!r} signals={signals}")
                    pred = _classify_patch(gen_text)
                    model_calls += 1

                patches_done += 1
                counts[pred] += 1

                # Draw border on patch
                draw = ImageDraw.Draw(patch)
                draw.rectangle(
                    [(0, 0), (patch_size - 1, patch_size - 1)],
                    outline=_WSI_COLOR_MAP[pred],
                    width=border_width,
                )
                canvas.paste(patch, (left, upper))

                pbar.update(1)
                pbar.set_postfix({"pred": pred, "patches done": patches_done,
                                  "model calls": model_calls, "skipped BG": skipped_bg})

    _wsi_elapsed = time.time() - _wsi_start
    pct_skipped = 100 * skipped_bg / total if total else 0
    print(
        f" [WSI Engine] Done in {_wsi_elapsed:.2f}s | "
        f"Model calls: {model_calls}/{total} "
        f"({100 - pct_skipped:.0f}% of patches) | "
        f"Skipped (BG): {skipped_bg} ({pct_skipped:.0f}%)"
    )

    n_mal = counts["Malignant"]
    pct = round((n_mal / total) * 100, 1) if total else 0.0

    canvas.save(output_path)

    # Derive the figure name from the extension only, so the figure can never
    # collide with (and overwrite) output_path, and ".png" inside a directory
    # name is left untouched.
    out_root, out_ext = os.path.splitext(output_path)
    fig_path = f"{out_root}_figure{out_ext or '.png'}"

    plt.figure(figsize=(15, 15))
    plt.imshow(canvas)
    plt.axis("off")
    plt.title(
        f"AI-Annotated Bone Marrow WSI\n"
        f"(Red=Malignant | Green=Normal | Gray=Background | Yellow=Unknown)\n"
        f"Malignant: {n_mal}/{total} ({pct}%) | "
        f"MedGemma called for {model_calls}/{total} patches ({100-pct_skipped:.0f}%)",
        fontsize=16,
    )
    plt.tight_layout()
    plt.savefig(fig_path, dpi=150, bbox_inches="tight")
    plt.close()

    summary_text = (
        f"WSI patch-level analysis ({cols}x{rows} grid, {patch_size}px patches): "
        f"{n_mal}/{total} patches ({pct}%) classified as Malignant. "
        f"Breakdown — Normal: {counts['Normal']}, "
        f"Background: {counts['Background']} (OpenCV pre-filtered), "
        f"Unknown: {counts['Unknown']}. "
        f"MedGemma invoked for {model_calls}/{total} patches. "
        f"Annotated image saved to: {output_path}"
    )
    return canvas, n_mal, total, summary_text, output_path


# ==========================================
# LANGGRAPH STATE
# ==========================================
class PatientState(TypedDict):
    """State dictionary passed between the LangGraph nodes below."""
    patient_id: str
    user_query: str
    raw_clinical_text: str
    modules_queue: List[str]      # module node names still to run, in order
    wsi_image_path: str           # empty → Module 3 falls back to text summary
    wsi_output_path: str
    module2_risk_score: str
    module3_wsi_analysis: str
    module4_progression: str
    module5_guidelines: str       # formatted context WITH [SOURCE | PAGE] tags
    module5_raw_chunks: str       # passed to UI accordion for citation verification
    module5_source: str           # human-readable label of which source was used
    final_recommendation: str


# ==========================================
# PLANNER NODE
# ==========================================
# Keys are the graph node names; their order is the canonical execution order.
MODULE_DESCRIPTIONS = {
    "module2": "ONLY Assess risk profile / risk score (e.g. High/Low/Standard risk)",
    "module3": "ONLY Analyse bone marrow biopsy / WSI / plasma cell infiltration",
    "module4": "ONLY Track disease progression / M-Spike / progression metrics",
    "module5_rag": "ONLY Retrieve treatment guidelines from knowledge base (NCCN/ESMO)",
}


def planner_node(state: PatientState):
    """
    Ask the default adapter which modules are needed for the user's query.

    The model is prompted (few-shot) to emit a JSON array of module names. The
    first ``[...]`` span in the reply is parsed; recognised names are kept in
    the canonical order of MODULE_DESCRIPTIONS. If no array is found, the JSON
    is invalid, or no recognised module name is present, every module is run
    as a fallback.

    Returns:
        {"modules_queue": [...]} — partial state update.
    """
    print("\n--- NODE: Planner (Dynamic Module Selection) ---")

    prompt = f"""You are a medical AI orchestrator. Given a clinician's question, output ONLY a JSON array of module names needed to answer it. Nothing else — no explanation, no markdown, no extra text.

Available modules:
- module2: Assess risk profile / risk score (High / Low / Standard risk)
- module3: Analyse bone marrow biopsy / WSI / plasma cell infiltration
- module4: Track disease progression / M-Spike / progression metrics
- module5_rag: Retrieve treatment guidelines from knowledge base (NCCN/ESMO)

---
EXAMPLES

Question: "What is the risk score for this patient?"
Answer: ["module2"]

Question: "What does the bone marrow biopsy show?"
Answer: ["module3"]

Question: "Is this patient's disease progressing rapidly based on M-Spike?"
Answer: ["module4"]

Question: "What is the risk and progression status of this patient?"
Answer: ["module2", "module4"]

Question: "Summarise the biopsy and assess the risk level."
Answer: ["module2", "module3"]

Question: "What treatment should be recommended for this myeloma patient?"
Answer: ["module2", "module3", "module4", "module5_rag"]

Question: "What NCCN guidelines apply to this transplant-eligible patient with renal impairment?"
Answer: ["module5_rag"]

Question: "Give me a full clinical workup and treatment plan for this patient."
Answer: ["module2", "module3", "module4", "module5_rag"]

Question: "What is the WSI finding and is there rapid progression?"
Answer: ["module3", "module4"]

---
Now answer for the following. Output ONLY the JSON array, nothing else.

Question: "{state['user_query']}"
Answer: """

    raw = generate_with_adapter(prompt, adapter_name="default", max_tokens=80)
    # Same clean-up Module 5 applies to the default adapter's output, so that
    # bracketed text inside a thinking section is not mistaken for the answer.
    cleaned = exclude_thinking_component(raw)

    canonical_order = list(MODULE_DESCRIPTIONS.keys())
    modules_queue = []

    match = re.search(r'\[.*?\]', cleaned, re.DOTALL)
    if match:
        try:
            selected = json.loads(match.group())
        except json.JSONDecodeError:
            selected = []
        modules_queue = [m for m in canonical_order if m in selected]

    if not modules_queue:
        print(" [Planner] Could not parse module list — running all modules as fallback.")
        modules_queue = canonical_order

    print(f" [Planner] Selected modules: {modules_queue}")
    return {"modules_queue": modules_queue}


# ==========================================
# ROUTING LOGIC
# ==========================================
def route_next(state: PatientState) -> str:
    """Return the next node name: head of the queue, or "orchestrator" when empty."""
    queue = state.get("modules_queue") or []
    if not queue:
        return "orchestrator"
    return queue[0]


# ==========================================
# MODULE NODES
# ==========================================
def _pop_queue(state: PatientState) -> List[str]:
    """Return the module queue without its first entry (the module now running)."""
    return (state.get("modules_queue") or [])[1:]


def run_module2_node(state: PatientState):
    """Risk assessment: prompt the "module2" adapter with the clinical text."""
    print("\n--- NODE: Module 2 (Risk Assessment) ---")
    prompt = (
        f"Assess the Multiple Myeloma risk profile. "
        f"OUTPUT ONE OF THESE: Standard, Low risk, High risk\n"
        f"Based on this clinical data:\n{state['raw_clinical_text']}\nConcise Risk Profile:"
    )
    result = generate_with_adapter(prompt, adapter_name="module2", max_tokens=200)
    return {"module2_risk_score": result, "modules_queue": _pop_queue(state)}


def run_module3_node(state: PatientState):
    """
    Bone-marrow analysis.

    With a non-empty ``wsi_image_path`` the image is run through
    ``process_and_annotate_wsi`` and its summary is stored; otherwise the
    "module3" adapter writes a one-sentence summary from the clinical text.
    """
    print("\n--- NODE: Module 3 (Bone Marrow WSI) ---")
    # `or ""` keeps this safe when the key is present but set to None.
    wsi_path = (state.get("wsi_image_path") or "").strip()

    if wsi_path:
        print(f" [Module 3] WSI path detected: {wsi_path}")
        output_png = (state.get("wsi_output_path") or "").strip() or "annotated_wsi_output.png"
        _, n_mal, total, summary_text, saved_path = process_and_annotate_wsi(
            wsi_path=wsi_path,
            patch_size=448,
            border_width=8,
            output_path=output_png,
        )
        pct = round((n_mal / total) * 100, 1) if total else 0
        print(f" [Module 3] Complete — {n_mal}/{total} malignant ({pct}%) -> {saved_path}")
        return {
            "module3_wsi_analysis": summary_text,
            "wsi_output_path": saved_path,
            "modules_queue": _pop_queue(state),
        }

    print(" [Module 3] No WSI path — generating text summary from clinical notes.")
    prompt = (
        f"Provide a 1-sentence summary of the abnormal plasma cell infiltration (WSI) "
        f"from this data:\n{state['raw_clinical_text']}\nWSI Summary:"
    )
    result = generate_with_adapter(prompt, adapter_name="module3", max_tokens=300)
    return {"module3_wsi_analysis": result, "modules_queue": _pop_queue(state)}


def run_module4_node(state: PatientState):
    """Progression tracking: prompt the "module4" adapter with the clinical text."""
    print("\n--- NODE: Module 4 (Progression Tracking) ---")
    prompt = (
        f"Identify the M-Spike progression metrics from this data and state if it indicates rapid progression:\n"
        f"{state['raw_clinical_text']}\nProgression Summary:"
    )
    result = generate_with_adapter(prompt, adapter_name="module4", max_tokens=300)
    return {"module4_progression": result, "modules_queue": _pop_queue(state)}


def run_module5_rag_node(state: PatientState):
    """
    Retrieve guideline passages for the patient.

    Retriever priority:
      1. Admin-persisted FAISS on disk (VECTOR_STORE_PATH)
      2. Built-in default guideline excerpts (fallback)

    A short search query is generated by the default adapter from the clinical
    text. Retrieved chunks are formatted with [SOURCE | PAGE] tags for citation.
    """
    print("\n--- NODE: Module 5 (RAG Retrieval) ---")

    active_retriever = None
    source_label = ""

    # ── Priority 1: admin-persisted KB ──────────────────────────────────────
    r, label = load_persisted_retriever()
    if r:
        active_retriever = r
        source_label = label

    # ── Priority 2: built-in defaults ───────────────────────────────────────
    if active_retriever is None:
        print(" [Module 5] No admin KB found — using default built-in guidelines")
        active_retriever = retriever
        source_label = "Default built-in NCCN/ESMO guideline excerpts"

    print(f" [Module 5] Active source: {source_label}")

    # ── Generate search query ───────────────────────────────────────────────
    prompt = (
        f"Formulate a 4-to-8 word search query to look up treatment guidelines for this patient.\n"
        f"Clinical Data: {state['raw_clinical_text']}\n\n"
        f"Output strictly the search query and nothing else.\nSearch Query String:"
    )
    search_query = generate_with_adapter(prompt, adapter_name="default", max_tokens=60)
    clean_query = exclude_thinking_component(search_query).replace('"', '').replace('\n', ' ')
    print(f" [Module 5] Searching for: '{clean_query}'")

    # ── Retrieve and format WITH citation tags ──────────────────────────────
    retrieved_docs = active_retriever.invoke(clean_query)

    formatted_chunks = []
    for doc in retrieved_docs:
        page = doc.metadata.get("page", "?")
        source = doc.metadata.get("source", "Guidelines")
        formatted_chunks.append(
            f"[SOURCE: {source} | PAGE: {page}]\n{doc.page_content}"
        )
    formatted_context = "\n\n".join(formatted_chunks)

    return {
        "module5_guidelines": formatted_context,
        "module5_raw_chunks": formatted_context,
        "module5_source": source_label,
        "modules_queue": _pop_queue(state),
    }


def orchestrator_synthesis_node(state: PatientState):
    """
    Produce the final answer from whatever module outputs are present.

    Only non-empty module results are included in the prompt; missing
    guidelines are replaced by an explicit "not retrieved" sentence.
    """
    print("\n--- NODE: Orchestrator Synthesis ---")

    profile_parts = []
    if state.get("module2_risk_score"):
        profile_parts.append(f"- Risk Score: {state['module2_risk_score']}")
    if state.get("module3_wsi_analysis"):
        profile_parts.append(f"- Bone Marrow WSI: {state['module3_wsi_analysis']}")
    if state.get("module4_progression"):
        profile_parts.append(f"- Progression: {state['module4_progression']}")

    profile_block = "\n".join(profile_parts) if profile_parts else "No sub-module analyses performed."
    guidelines_block = (
        state.get("module5_guidelines")
        or "No treatment guidelines retrieved (RAG not selected for this query)."
    )

    prompt = f"""You are the Master Hematology Orchestrator.
Answer the clinician's specific question using ONLY the outputs provided below.
Do not speculate beyond the available data. Do not output your internal thought process.

CRITICAL CITATION RULE: Every treatment or guideline recommendation you state MUST be followed by the exact [SOURCE: ... | PAGE: ...] tag from the retrieved guidelines below. Do not invent citations. If no supporting guideline is found for a statement, write "(no guideline available)" after it.

[CLINICIAN'S QUESTION]
{state['user_query']}

[PATIENT CLINICAL DATA]
{state['raw_clinical_text']}

[ANALYSIS FROM SELECTED MODULES]
{profile_block}

[RETRIEVED GUIDELINES — cite these exactly]
{guidelines_block}

Final Answer (with citations):"""

    result = generate_with_adapter(prompt, adapter_name="default", max_tokens=800)
    return {"final_recommendation": result}


# ==========================================
# BUILD AND COMPILE THE GRAPH
# ==========================================
workflow = StateGraph(PatientState)

workflow.add_node("planner", planner_node)
workflow.add_node("module2", run_module2_node)
workflow.add_node("module3", run_module3_node)
workflow.add_node("module4", run_module4_node)
workflow.add_node("module5_rag", run_module5_rag_node)
workflow.add_node("orchestrator", orchestrator_synthesis_node)

workflow.set_entry_point("planner")

# Every module (and the planner) routes through route_next, which picks the
# head of modules_queue or "orchestrator" once the queue is empty.
all_modules = ["module2", "module3", "module4", "module5_rag", "orchestrator"]
workflow.add_conditional_edges("planner", route_next, {m: m for m in all_modules})
for mod in ["module2", "module3", "module4", "module5_rag"]:
    workflow.add_conditional_edges(mod, route_next, {m: m for m in all_modules})

workflow.add_edge("orchestrator", END)
full_agent = workflow.compile()


# ==========================================
# RUN THE AGENT (CLI testing)
# ==========================================
if __name__ == "__main__":

    clinical_text = """
    65-year-old Male patient. Transplant eligible.
    - CRAB Panel: Creatinine: 2.5 mg/dL, Calcium: 10.5 mg/dL, Hemoglobin: 9.0 g/dL.
    - Tumor/Staging Panel: Beta-2 Microglobulin: 5.8 mg/L, LDH: 280 U/L.
    - Behaviour: Feeling Dizziness with Chest Pain and Mental Confusion.
    """

    _base = {
        "patient_id": "MMRF_2240",
        "raw_clinical_text": clinical_text,
        "modules_queue": [],
        "wsi_image_path": "",
        "wsi_output_path": "",
        "module2_risk_score": "",
        "module3_wsi_analysis": "",
        "module4_progression": "",
        "module5_guidelines": "",
        "module5_raw_chunks": "",
        "module5_source": "",
        "final_recommendation": "",
    }

    # print("\n\n=== QUERY 1: Risk Assessment Only ===")
    # out1 = full_agent.invoke({
    #     **_base,
    #     "user_query": "Assess the risk of multiple myeloma of this patient based on their lab records and behaviours. Do only risk assessment",
    # })
    # print("\n=== FINAL ANSWER ===")
    # print(out1["final_recommendation"])

    # ---- Example 2: WSI bone marrow analysis with real .bmp image ----
    print("\n\n=== QUERY 2: WSI Patch Analysis ===")
    out2 = full_agent.invoke({
        **_base,
        "user_query": "Analyse the bone marrow biopsy and report malignant patch percentage.",
        "wsi_image_path": "/media/shrish/Data/medgemma_finetune/segpc/TCIA_SegPC_dataset/"
                          "TCIA_SegPC_dataset/TCIA_SegPC_dataset/validation/validation/x/104.bmp",
        "wsi_output_path": "annotated_wsi_output2.png",
    })
    print("\n=== FINAL ANSWER ===")
    print(out2["final_recommendation"])

    # ---- Example 3: Full pipeline — all modules + real WSI ----
    # print("\n\n=== QUERY 3: Full Treatment Plan + WSI ===")
    # out3 = full_agent.invoke({
    #     **_base,
    #     "user_query": "What is the recommended treatment for this transplant-eligible myeloma patient with renal impairment?",
    #     "wsi_image_path": "/media/shrish/Data/medgemma_finetune/segpc/TCIA_SegPC_dataset/"
    #                       "TCIA_SegPC_dataset/TCIA_SegPC_dataset/validation/validation/x/104.bmp",
    #     "wsi_output_path": "annotated_wsi_output.png",
    # })
    # print("\n=== FINAL ANSWER ===")
    # print(out3["final_recommendation"])