import os
import time
import cv2
import numpy as np
from gguf_engine import (
    generate_with_adapter,
    generate_with_adapter_vision,
    exclude_thinking_component,
)
import re
import json
import matplotlib

# Headless backend: figures are only ever written to disk (see plt.savefig below).
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
from tqdm import tqdm
from typing import TypedDict, List, Optional

from langgraph.graph import StateGraph, END
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

DEVICE = 'cpu'

print("Initializing Embedding Model...")
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={'device': DEVICE}
)

# Small built-in guideline corpus; Module 5 falls back to it whenever no
# admin-built knowledge base has been persisted to VECTOR_STORE_PATH.
_DEFAULT_GUIDELINE_DOCS = [
    Document(
        page_content="NCCN Multiple Myeloma Guidelines: For transplant-eligible patients with high-risk cytogenetics and normal renal function, standard VRd is recommended.",
        metadata={"source": "Built-in NCCN excerpt", "page": 1}
    ),
    Document(
        page_content="NCCN Multiple Myeloma Guidelines: For transplant-eligible patients presenting with severe renal impairment (Creatinine > 2.0), VRd is still preferred, but Bortezomib requires careful dose adjustment (typically 1.3 mg/m2 on day 1, 4, 8, 11) to prevent toxicity.",
        metadata={"source": "Built-in NCCN excerpt", "page": 1}
    ),
    Document(
        page_content="ESMO Multiple Myeloma Guidelines: Daratumumab plus VTd is a category 1 recommendation for transplant-eligible patients regardless of renal status.",
        metadata={"source": "Built-in ESMO excerpt", "page": 1}
    ),
    Document(
        page_content="Acute Lymphoblastic Leukemia Guidelines: First-line induction for Philadelphia chromosome-positive ALL includes targeted TKIs like Imatinib combined with Hyper-CVAD.",
        metadata={"source": "Built-in ALL excerpt", "page": 1}
    ),
]

print("Building default fallback FAISS Vector Database...")
_default_vector_db = FAISS.from_documents(_DEFAULT_GUIDELINE_DOCS, embeddings)
retriever = _default_vector_db.as_retriever(search_kwargs={"k": 3})
print("RAG setup complete. GGUF models will be lazy-loaded on first inference call.")

VECTOR_STORE_PATH = "./local_vector_store"


def load_persisted_retriever():
    """Load the admin-configured FAISS knowledge base, if one has been saved.

    Returns:
        ``(retriever, label)`` when ``VECTOR_STORE_PATH`` exists and loads
        cleanly (k=4 retrieval), otherwise ``(None, None)`` so the caller can
        fall back to the built-in guideline excerpts. Load errors are printed,
        not raised.
    """
    if os.path.exists(VECTOR_STORE_PATH):
        try:
            # SECURITY: allow_dangerous_deserialization makes FAISS unpickle the
            # stored docstore. Only acceptable because this directory is written
            # locally by save_retriever_from_pdf — never point it at files from
            # an untrusted source.
            db = FAISS.load_local(
                VECTOR_STORE_PATH, embeddings, allow_dangerous_deserialization=True
            )
            print(f" [RAG] Loaded persisted knowledge base from {VECTOR_STORE_PATH}")
            return db.as_retriever(search_kwargs={"k": 4}), "Hospital Knowledge Base (admin-configured)"
        except Exception as e:
            print(f" [RAG] Failed to load persisted KB: {e} — using defaults")
    return None, None


def save_retriever_from_pdf(pdf_path: str):
    """Build a FAISS knowledge base from a PDF and persist it to VECTOR_STORE_PATH.

    Each page's text is cut into overlapping character windows (1000 chars,
    200-char overlap); windows of 50 characters or fewer after stripping are
    dropped. Every chunk carries ``source`` (file basename) and ``page``
    (1-based) metadata so the orchestrator can cite it.

    Returns:
        ``(ok, message)``. All failures, including a missing ``pdfplumber``
        install, are reported through the message instead of being raised.
    """
    chunk_size, overlap = 1000, 200
    step = chunk_size - overlap
    try:
        import pdfplumber

        source_name = os.path.basename(pdf_path)
        chunks = []
        with pdfplumber.open(pdf_path) as pdf:
            total_pages = len(pdf.pages)
            for page_num, page in enumerate(pdf.pages, start=1):
                page_text = page.extract_text() or ""
                if not page_text.strip():
                    continue
                text_len = len(page_text)
                start = 0
                while start < text_len:
                    end = min(start + chunk_size, text_len)
                    chunk = page_text[start:end].strip()
                    if len(chunk) > 50:
                        chunks.append(Document(
                            page_content=chunk,
                            metadata={"source": source_name, "page": page_num}
                        ))
                    if end == text_len:
                        # This window already reached the end of the page; stepping
                        # again would only emit a tail fully contained in it.
                        break
                    start += step
        if not chunks:
            return False, "No extractable text found in PDF. The document may be scanned/image-only."
        pdf_db = FAISS.from_documents(chunks, embeddings)
        os.makedirs(VECTOR_STORE_PATH, exist_ok=True)
        pdf_db.save_local(VECTOR_STORE_PATH)
        msg = (
            f"✅ Knowledge base saved successfully.\n"
            f" Document : {source_name}\n"
            f" Pages : {total_pages}\n"
            f" Chunks : {len(chunks)}\n"
            f" Saved to : {VECTOR_STORE_PATH}"
        )
        print(f" [Admin] {msg}")
        return True, msg
    except Exception as e:
        return False, f"❌ Failed to build knowledge base: {e}"


# ==========================================
# WSI PATCH ENGINE
# ==========================================

# Border colour drawn around each patch, keyed by predicted label.
_WSI_COLOR_MAP = {
    "Malignant": "red",
    "Normal": "green",
    "Background": "gray",
    "Unknown": "yellow",
}

_WSI_PATCH_PROMPT = (
    "Analyze this Bone Marrow Biopsy patch. "
    "Does it contain any plasma cells indicative of Multiple Myeloma? "
    "Answer with one word: Malignant, Normal, or Background."
)


def _classify_patch(text: str) -> str:
    """Map free-form model output for one patch onto a WSI label.

    Checks are ordered: "malignant" wins over "normal", which wins over the
    background cues. Anything unmatched becomes "Unknown".
    """
    t = text.lower()
    if "malignant" in t:
        return "Malignant"
    # Whole-word match so an answer such as "Abnormal" is NOT labelled Normal.
    if re.search(r"\bnormal\b", t):
        return "Normal"
    if "background" in t or "no plasma" in t:
        return "Background"
    return "Unknown"


def _is_background_patch(patch_np: np.ndarray, *, tissue_sat_thresh: int = 40) -> tuple:
    """Cheap OpenCV pre-filter that flags low-saturation patches as background.

    Args:
        patch_np: RGB patch array as sliced from the ``convert("RGB")`` image in
            process_and_annotate_wsi (presumably uint8, H x W x 3).
        tissue_sat_thresh: mean HSV saturation below which the patch is treated
            as background and never sent to the model.

    Returns:
        ``(is_background, mean_saturation)``.
    """
    hsv = cv2.cvtColor(patch_np, cv2.COLOR_RGB2HSV)
    sat_mean = hsv[:, :, 1].mean()
    is_background = (sat_mean < tissue_sat_thresh)
    return is_background, sat_mean


def process_and_annotate_wsi(
    wsi_path: str,
    patch_size: int = 448,
    border_width: int = 8,
    output_path: str = "annotated_wsi_output.png",
) -> tuple:
    """Tile a slide image, classify each patch, and save annotated outputs.

    Patches flagged by ``_is_background_patch`` skip the model. All other
    patches go through the ``module3`` vision adapter. Every patch is
    redrawn with a border colour from ``_WSI_COLOR_MAP``. Pixels beyond the
    last full patch row/column are discarded.

    Two files are written: the annotated canvas at ``output_path`` and a
    titled matplotlib figure next to it (``<stem>_figure<ext>``).

    Returns:
        ``(canvas, n_malignant, total_patches, summary_text, output_path)``.

    Raises:
        ValueError: if the image is smaller than one patch in either dimension.
    """
    print(f"\n [WSI Engine] Opening: {wsi_path}")
    # Close the source file handle; convert() returns an in-memory copy.
    with Image.open(wsi_path) as src:
        wsi_image = src.convert("RGB")
    W, H = wsi_image.size
    cols = W // patch_size
    rows = H // patch_size
    total = cols * rows
    print(f" [WSI Engine] {W}x{H}px | Grid {cols}x{rows} = {total} patches")
    if total == 0:
        raise ValueError(
            f"WSI image {W}x{H}px is smaller than one {patch_size}px patch; nothing to analyse."
        )

    canvas = Image.new("RGB", (cols * patch_size, rows * patch_size))
    counts = {"Malignant": 0, "Normal": 0, "Background": 0, "Unknown": 0}
    _wsi_start = time.time()
    wsi_np = np.array(wsi_image)
    model_calls = 0
    skipped_bg = 0
    patches_done = 0

    with tqdm(total=total, desc=" WSI Patches") as pbar:
        for r in range(rows):
            for c in range(cols):
                left, upper = c * patch_size, r * patch_size
                right, lower = left + patch_size, upper + patch_size
                patch_np = wsi_np[upper:lower, left:right]
                is_bg, signals = _is_background_patch(patch_np)
                patch = Image.fromarray(patch_np)
                if is_bg:
                    pred = "Background"
                    skipped_bg += 1
                    print("SKIPPED BY CV.............")
                else:
                    gen_text = generate_with_adapter_vision(
                        image=patch,
                        prompt=_WSI_PATCH_PROMPT,
                        adapter_name="module3",
                        max_tokens=20,
                    )
                    print(f" [WSI] patch({r},{c}) model→ {gen_text!r} signals={signals}")
                    pred = _classify_patch(gen_text)
                    model_calls += 1
                patches_done += 1
                counts[pred] += 1
                draw = ImageDraw.Draw(patch)
                draw.rectangle(
                    [(0, 0), (patch_size - 1, patch_size - 1)],
                    outline=_WSI_COLOR_MAP[pred],
                    width=border_width,
                )
                canvas.paste(patch, (left, upper))
                pbar.update(1)
                pbar.set_postfix({"pred": pred, "done": patches_done, "model": model_calls, "bg_skip": skipped_bg})

    _wsi_elapsed = time.time() - _wsi_start
    pct_skipped = 100 * skipped_bg / total
    print(
        f" [WSI Engine] Done in {_wsi_elapsed:.2f}s | "
        f"Model calls: {model_calls}/{total} | Skipped BG: {skipped_bg} ({pct_skipped:.1f}%)"
    )
    n_mal = counts["Malignant"]
    pct = round((n_mal / total) * 100, 1)
    canvas.save(output_path)

    # Derive the figure path from the extension so it can never collide with
    # output_path (str.replace(".png", ...) did when the suffix differed, e.g.
    # ".PNG", and also rewrote ".png" occurring inside directory names).
    stem, ext = os.path.splitext(output_path)
    fig_path = f"{stem}_figure{ext or '.png'}"
    plt.figure(figsize=(15, 15))
    plt.imshow(canvas)
    plt.axis("off")
    plt.title(
        f"AI-Annotated Bone Marrow WSI\n"
        f"(Red=Malignant | Green=Normal | Gray=Background | Yellow=Unknown)\n"
        f"Malignant: {n_mal}/{total} ({pct}%) | MedGemma called for {model_calls}/{total} patches",
        fontsize=16,
    )
    plt.tight_layout()
    plt.savefig(fig_path, dpi=150, bbox_inches="tight")
    plt.close()

    summary_text = (
        f"WSI patch-level analysis ({cols}x{rows} grid, {patch_size}px patches): "
        f"{n_mal}/{total} patches ({pct}%) classified as Malignant. "
        f"Breakdown — Normal: {counts['Normal']}, "
        f"Background: {counts['Background']} (OpenCV pre-filtered), "
        f"Unknown: {counts['Unknown']}. "
        f"MedGemma invoked for {model_calls}/{total} patches. "
        f"Annotated image saved to: {output_path}"
    )
    return canvas, n_mal, total, summary_text, output_path


# ==========================================
# LANGGRAPH STATE
# ==========================================
class PatientState(TypedDict):
    # Shared graph state; each node returns a partial dict of updated keys.
    patient_id: str
    user_query: str
    raw_clinical_text: str
    modules_queue: List[str]  # remaining module nodes to visit, in order
    wsi_image_path: str
    wsi_output_path: str
    module2_risk_score: str
    module3_wsi_analysis: str
    module4_progression: str
    module5_guidelines: str
    module5_raw_chunks: str
    module5_source: str
    final_recommendation: str


# ==========================================
# PLANNER NODE
# ==========================================
# Insertion order doubles as the canonical execution order of the modules.
MODULE_DESCRIPTIONS = {
    "module2": "ONLY Assess risk profile / risk score (e.g. High/Low/Standard risk)",
    "module3": "ONLY Analyse bone marrow biopsy / WSI / plasma cell infiltration",
    "module4": "ONLY Track disease progression / M-Spike / progression metrics",
    "module5_rag": "ONLY Retrieve treatment guidelines from knowledge base (NCCN/ESMO)",
}


def planner_node(state: PatientState):
    """Ask the default adapter which modules the query needs; fill modules_queue.

    The model is prompted for a JSON array. Selected names are filtered and
    re-ordered by ``MODULE_DESCRIPTIONS``. If no array can be parsed, every
    module is queued as a fallback.
    """
    print("\n--- NODE: Planner (Dynamic Module Selection) ---")
    prompt = f"""You are a medical AI orchestrator.
Given a clinician's question, output ONLY a JSON array of module names needed to answer it.
Nothing else — no explanation, no markdown, no extra text.

Available modules:
- module2: Assess risk profile / risk score (High / Low / Standard risk)
- module3: Analyse bone marrow biopsy / WSI / plasma cell infiltration
- module4: Track disease progression / M-Spike / progression metrics
- module5_rag: Retrieve treatment guidelines from knowledge base (NCCN/ESMO)

---
EXAMPLES

Question: "What is the risk score for this patient?"
Answer: ["module2"]

Question: "What does the bone marrow biopsy show?"
Answer: ["module3"]

Question: "Is this patient's disease progressing rapidly based on M-Spike?"
Answer: ["module4"]

Question: "What is the risk and progression status of this patient?"
Answer: ["module2", "module4"]

Question: "Summarise the biopsy and assess the risk level."
Answer: ["module2", "module3"]

Question: "What treatment should be recommended for this myeloma patient?"
Answer: ["module2", "module3", "module4", "module5_rag"]

Question: "What NCCN guidelines apply to this transplant-eligible patient with renal impairment?"
Answer: ["module5_rag"]

Question: "Give me a full clinical workup and treatment plan for this patient."
Answer: ["module2", "module3", "module4", "module5_rag"]

Question: "What is the WSI finding and is there rapid progression?"
Answer: ["module3", "module4"]
---

Now answer for the following. Output ONLY the JSON array, nothing else.

Question: "{state['user_query']}"
Answer: """
    raw = generate_with_adapter(prompt, adapter_name="default", max_tokens=80)
    # Strip any reasoning preamble first (Module 5 treats the same adapter's
    # output this way) so a bracketed aside can't be parsed as the module list.
    answer = exclude_thinking_component(raw)
    canonical_order = list(MODULE_DESCRIPTIONS.keys())
    match = re.search(r'\[.*?\]', answer, re.DOTALL)
    if match:
        try:
            selected = json.loads(match.group())
            modules_queue = [m for m in canonical_order if m in selected]
        except json.JSONDecodeError:
            print(" [Planner] Module list was not valid JSON — running all modules as fallback.")
            modules_queue = canonical_order
    else:
        print(" [Planner] Could not parse module list — running all modules as fallback.")
        modules_queue = canonical_order
    print(f" [Planner] Selected modules: {modules_queue}")
    return {"modules_queue": modules_queue}


# ==========================================
# ROUTING LOGIC
# ==========================================
def route_next(state: PatientState) -> str:
    """Return the next node: the head of modules_queue, or the orchestrator when empty."""
    queue = state.get("modules_queue", [])
    if not queue:
        return "orchestrator"
    return queue[0]


# ==========================================
# MODULE NODES
# ==========================================
def _pop_queue(state: PatientState) -> List[str]:
    """Return modules_queue without its head (the module currently running)."""
    return state.get("modules_queue", [])[1:]


def run_module2_node(state: PatientState):
    """Module 2: risk-profile assessment from the clinical text."""
    print("\n--- NODE: Module 2 (Risk Assessment) ---")
    prompt = (
        f"Assess the Multiple Myeloma risk profile. "
        f"OUTPUT ONE OF THESE: Standard, Low risk, High risk\n"
        f"Based on this clinical data:\n{state['raw_clinical_text']}\nConcise Risk Profile:"
    )
    result = generate_with_adapter(prompt, adapter_name="module2", max_tokens=200)
    return {"module2_risk_score": result, "modules_queue": _pop_queue(state)}


def run_module3_node(state: PatientState):
    """Module 3: patch-level WSI analysis when a slide path is given, else a text summary."""
    print("\n--- NODE: Module 3 (Bone Marrow WSI) ---")
    # `or ""` also tolerates an explicit None in the state.
    wsi_path = (state.get("wsi_image_path") or "").strip()
    if wsi_path:
        print(f" [Module 3] WSI path detected: {wsi_path}")
        output_png = (state.get("wsi_output_path") or "").strip() or "annotated_wsi_output.png"
        _, n_mal, total, summary_text, saved_path = process_and_annotate_wsi(
            wsi_path=wsi_path,
            patch_size=448,
            border_width=8,
            output_path=output_png,
        )
        pct = round((n_mal / total) * 100, 1) if total else 0
        print(f" [Module 3] Complete — {n_mal}/{total} malignant ({pct}%) -> {saved_path}")
        return {
            "module3_wsi_analysis": summary_text,
            "wsi_output_path": saved_path,
            "modules_queue": _pop_queue(state),
        }
    print(" [Module 3] No WSI path — generating text summary from clinical notes.")
    prompt = (
        f"Provide a 1-sentence summary of the abnormal plasma cell infiltration (WSI) "
        f"from this data:\n{state['raw_clinical_text']}\nWSI Summary:"
    )
    result = generate_with_adapter(prompt, adapter_name="module3", max_tokens=300)
    return {"module3_wsi_analysis": result, "modules_queue": _pop_queue(state)}


def run_module4_node(state: PatientState):
    """Module 4: M-Spike progression summary from the clinical text."""
    print("\n--- NODE: Module 4 (Progression Tracking) ---")
    prompt = (
        f"Identify the M-Spike progression metrics from this data and state if it indicates rapid progression:\n"
        f"{state['raw_clinical_text']}\nProgression Summary:"
    )
    result = generate_with_adapter(prompt, adapter_name="module4", max_tokens=300)
    return {"module4_progression": result, "modules_queue": _pop_queue(state)}


def run_module5_rag_node(state: PatientState):
    """Module 5: generate a search query and retrieve citable guideline chunks.

    Uses the persisted admin knowledge base when available, otherwise the
    built-in default ``retriever``. Each chunk is prefixed with a
    ``[SOURCE: ... | PAGE: ...]`` tag that the orchestrator is told to cite.
    """
    print("\n--- NODE: Module 5 (RAG Retrieval) ---")
    persisted, label = load_persisted_retriever()
    if persisted is not None:
        active_retriever, source_label = persisted, label
    else:
        print(" [Module 5] No admin KB found — using default built-in guidelines")
        active_retriever = retriever
        source_label = "Default built-in NCCN/ESMO guideline excerpts"
    print(f" [Module 5] Active source: {source_label}")

    prompt = (
        f"Formulate a 4-to-8 word search query to look up treatment guidelines for this patient.\n"
        f"Clinical Data: {state['raw_clinical_text']}\n\n"
        f"Output strictly the search query and nothing else.\nSearch Query String:"
    )
    search_query = generate_with_adapter(prompt, adapter_name="default", max_tokens=60)
    clean_query = (
        exclude_thinking_component(search_query).replace('"', '').replace('\n', ' ').strip()
    )
    if not clean_query:
        # An empty generated query gives the retriever nothing meaningful to
        # match; use the clinical text itself as the query instead.
        clean_query = state['raw_clinical_text']
    print(f" [Module 5] Searching for: '{clean_query}'")
    retrieved_docs = active_retriever.invoke(clean_query)

    formatted_context = "\n\n".join(
        f"[SOURCE: {doc.metadata.get('source', 'Guidelines')} | PAGE: {doc.metadata.get('page', '?')}]\n"
        f"{doc.page_content}"
        for doc in retrieved_docs
    )
    return {
        "module5_guidelines": formatted_context,
        "module5_raw_chunks": formatted_context,
        "module5_source": source_label,
        "modules_queue": _pop_queue(state),
    }


def orchestrator_synthesis_node(state: PatientState):
    """
    Dual-mode orchestrator:

    FAST PATH — module5_rag did NOT run (risk / WSI / progression queries only).
        Short focused prompt, max_tokens=100. No citation overhead.

    FULL PATH — module5_rag DID run (therapy recommendation requested).
        Full citation prompt with guideline context, max_tokens=500.
    """
    print("\n--- NODE: Orchestrator Synthesis ---")
    profile_parts = []
    if state.get("module2_risk_score"):
        profile_parts.append(f"- Risk Score: {state['module2_risk_score']}")
    if state.get("module3_wsi_analysis"):
        profile_parts.append(f"- Bone Marrow WSI: {state['module3_wsi_analysis']}")
    if state.get("module4_progression"):
        profile_parts.append(f"- Progression: {state['module4_progression']}")
    profile_block = "\n".join(profile_parts) if profile_parts else "No sub-module analyses performed."

    rag_ran = bool(state.get("module5_guidelines", "").strip())
    if not rag_ran:
        # ── FAST PATH ─────────────────────────────────────────────────────────
        max_tokens = 100
        print(f" [Orchestrator] Fast path — no RAG, short summary prompt (max_tokens={max_tokens})")
        prompt = (
            f"You are a hematology AI assistant.\n"
            f"Answer the clinician's question concisely using only the findings below.\n"
            f"Do not speculate. Do not mention treatment guidelines.\n\n"
            f"[QUESTION]\n{state['user_query']}\n\n"
            f"[FINDINGS]\n{profile_block}\n\n"
            f"Answer in 2-4 sentences:"
        )
    else:
        # ── FULL PATH ─────────────────────────────────────────────────────────
        max_tokens = 500
        print(f" [Orchestrator] Full path — RAG active, citation prompt (max_tokens={max_tokens})")
        guidelines_block = state.get("module5_guidelines", "")
        prompt = (
            f"You are the Master Hematology Orchestrator.\n"
            f"Answer the clinician's question using ONLY the data below.\n"
            f"Do not speculate. Do not show your reasoning.\n\n"
            f"CITATION RULE: After every treatment or guideline recommendation,\n"
            f"append the exact [SOURCE: ... | PAGE: ...] tag from the guidelines.\n"
            f"If no supporting guideline exists for a statement, write (no guideline available).\n\n"
            f"[QUESTION]\n{state['user_query']}\n\n"
            f"[PATIENT DATA]\n{state['raw_clinical_text']}\n\n"
            f"[MODULE FINDINGS]\n{profile_block}\n\n"
            f"[RETRIEVED GUIDELINES — cite exactly]\n{guidelines_block}\n\n"
            f"Final Answer (with citations):"
        )
    result = generate_with_adapter(prompt, adapter_name="default", max_tokens=max_tokens)
    return {"final_recommendation": result}


# ==========================================
# BUILD AND COMPILE THE GRAPH
# ==========================================
workflow = StateGraph(PatientState)
workflow.add_node("planner", planner_node)
workflow.add_node("module2", run_module2_node)
workflow.add_node("module3", run_module3_node)
workflow.add_node("module4", run_module4_node)
workflow.add_node("module5_rag", run_module5_rag_node)
workflow.add_node("orchestrator", orchestrator_synthesis_node)
workflow.set_entry_point("planner")

# Every module node (and the planner) routes via route_next to the next queued
# module, or to the orchestrator once the queue is empty.
all_modules = ["module2", "module3", "module4", "module5_rag", "orchestrator"]
workflow.add_conditional_edges("planner", route_next, {m: m for m in all_modules})
for mod in ["module2", "module3", "module4", "module5_rag"]:
    workflow.add_conditional_edges(mod, route_next, {m: m for m in all_modules})
workflow.add_edge("orchestrator", END)
full_agent = workflow.compile()


# ==========================================
# RUN THE AGENT (CLI testing)
# ==========================================
if __name__ == "__main__":
    clinical_text = """
    65-year-old Male patient. Transplant eligible.
    - CRAB Panel: Creatinine: 2.5 mg/dL, Calcium: 10.5 mg/dL, Hemoglobin: 9.0 g/dL.
    - Tumor/Staging Panel: Beta-2 Microglobulin: 5.8 mg/L, LDH: 280 U/L.
    - Behaviour: Feeling Dizziness with Chest Pain and Mental Confusion.
    """
    _base = {
        "patient_id": "MMRF_2240",
        "raw_clinical_text": clinical_text,
        "modules_queue": [],
        "wsi_image_path": "",
        "wsi_output_path": "",
        "module2_risk_score": "",
        "module3_wsi_analysis": "",
        "module4_progression": "",
        "module5_guidelines": "",
        "module5_raw_chunks": "",
        "module5_source": "",
        "final_recommendation": "",
    }

    print("\n\n=== QUERY 2: WSI Patch Analysis ===")
    out2 = full_agent.invoke({
        **_base,
        "user_query": "Analyse the bone marrow biopsy and report malignant patch percentage.",
        "wsi_image_path": "/media/shrish/Data/medgemma_finetune/segpc/TCIA_SegPC_dataset/"
                          "TCIA_SegPC_dataset/TCIA_SegPC_dataset/validation/validation/x/104.bmp",
        "wsi_output_path": "annotated_wsi_output2.png",
    })
    print("\n=== FINAL ANSWER ===")
    print(out2["final_recommendation"])