import gradio as gr
import traceback
import time
import json
import re
import os
from transformers import AutoTokenizer
# --- Internal Core Layer Dependencies ---
from modules.database import restore_vector_database, production_smart_reranked_retrieval
from modules.inference import backend_engine, QueryExpansionSchema, torch
# --- CRITICAL IMPORT FIXED: Pulling Qdrant Query Models to prevent attribute errors ---
from qdrant_client.models import Filter, FieldCondition, MatchValue
from fastembed.sparse import SparseTextEmbedding
from fastembed import TextEmbedding
# --- Initialize core engine globals (runs at import time) ---
# Restore the Qdrant collection snapshot used by every lookup in this file.
# NOTE(review): the ".pkl" name suggests restore_vector_database unpickles this
# file — only load snapshots you produced yourself; confirm in modules.database.
client = restore_vector_database("qdrant_in_memory_db.pkl")
# Reranker model/tokenizer are owned by the inference backend; they are only
# re-exported here so they can be passed to production_smart_reranked_retrieval().
rerank_model = backend_engine.rerank_model
rerank_tokenizer = backend_engine.rerank_tokenizer
# Dense embedder behind get_dense_embedding() (query-side dense vectors).
embedding_model = TextEmbedding(
    model_name="BAAI/bge-small-en-v1.5"
)
# NOTE(review): despite its name this loads the SmolLM2-1.7B-Instruct tokenizer,
# not TinyLlama. It is used only to count/truncate prompt and context tokens.
tinyllama_tokenizer = AutoTokenizer.from_pretrained(
    "HuggingFaceTB/SmolLM2-1.7B-Instruct"
)
print("๐ฅ Initializing local BM25 Text Encoder...")
# Sparse BM25 encoder handed to the hybrid retrieval call.
sparse_encoder = SparseTextEmbedding(model_name="Qdrant/bm25")
def get_dense_embedding(text):
    """Return the dense embedding of a single string as a plain Python list.

    ``embedding_model.embed`` works on batches, so the text is sent as a
    one-element batch; the first resulting vector is converted with
    ``.tolist()``.
    """
    batch_vectors = list(embedding_model.embed([text]))
    vector = batch_vectors[0]
    return vector.tolist()
def _coerce_page_number(page_number):
    """Return ``page_number`` as a non-negative int, or None if it is not one.

    Accepts ints, integral floats (gr.Number may deliver e.g. 1742.0
    depending on its precision setting) and decimal-digit strings; rejects
    bools, negatives, fractional values and blanks.
    """
    if page_number is None or isinstance(page_number, bool):
        return None
    if isinstance(page_number, int):
        return page_number if page_number >= 0 else None
    if isinstance(page_number, float):
        if page_number.is_integer() and page_number >= 0:
            return int(page_number)
        return None
    text = str(page_number).strip()
    return int(text) if text.isdecimal() else None


def lookup_chunks_by_page(page_number, limit=100):
    """Return a text report of the stored chunks tagged with a given page number.

    Scrolls the "medical_manual" collection of the module-level Qdrant
    ``client`` for points whose ``metadata.page_numbers`` payload matches
    ``page_number`` and shows up to 800 characters of each chunk's text.

    Args:
        page_number: Page identifier from the UI (int, integral float or
            digit string).
        limit: Maximum number of points fetched by the single scroll call.

    Returns:
        A human-readable report. Invalid input and Qdrant errors are returned
        as messages instead of being raised, so the UI always shows something.
    """
    page_val = _coerce_page_number(page_number)
    if page_val is None:
        return "⚠️ Please provide a valid numerical page identifier configuration."

    try:
        # NOTE(review): goes through the private client._client handle of the
        # restored in-memory client; confirm client.scroll() is not enough.
        scroll_results, next_offset = client._client.scroll(
            collection_name="medical_manual",
            scroll_filter=Filter(
                must=[
                    FieldCondition(
                        key="metadata.page_numbers",
                        match=MatchValue(value=page_val),
                    )
                ]
            ),
            limit=limit,
            with_payload=True,
            with_vectors=False,
        )
    except Exception as e:
        return f"❌ Error querying local in-memory Qdrant: {e}"

    if not scroll_results:
        return f"❌ No chunks found matching exactly with stored metadata identifier: metadata.page_numbers={page_val}"

    report = [f"๐ Found {len(scroll_results)} distinct chunks indexed for Page {page_val}:\n\n"]
    for idx, point in enumerate(scroll_results, start=1):
        payload = point.payload or {}
        text = payload.get("text")
        if text is None:
            text = "[[No text key discovered in payload]]"
        elif not isinstance(text, str):
            text = str(text)
        snippet = text[:800] + "..." if len(text) > 800 else text
        report.append(f"--- [CHUNK {idx}] ---\n")
        report.append(f"Snippet text content (First 800 Chars):\n{snippet}\n\n")

    # scroll() returns a next-page offset when more matching points exist.
    if next_offset is not None:
        report.append(f"⚠️ Showing the first {limit} chunks only; more chunks exist for this page.\n")
    return "".join(report)
# NOTE(review): the banner markup and several symbols in this section were
# reconstructed from a corrupted copy of the file (UTF-8 emoji mis-decoded as
# Thai text, HTML tags stripped). Prefixes still shown as "๐…" could not be
# decoded unambiguously — restore them from version control.

# Inline CSS for the tracker banners (dark background, bright text, coloured
# left border): amber while a step runs, green once done, slate while waiting.
_STYLE_ACTIVE = "margin-bottom:8px; padding:10px 14px; border-left:4px solid #f59e0b; background-color:#292524; color:#fef08a; font-family:sans-serif; border-radius:0 4px 4px 0;"
_STYLE_DONE = "margin-bottom:8px; padding:10px 14px; border-left:4px solid #10b981; background-color:#1c1917; color:#a7f3d0; font-family:sans-serif; border-radius:0 4px 4px 0;"
_STYLE_HOLD = "margin-bottom:8px; padding:10px 14px; border-left:4px solid #475569; background-color:#1c1917; color:#94a3b8; font-family:sans-serif; border-radius:0 4px 4px 0;"

# Maximum tokens of retrieved context placed in the generation prompt, counted
# with tinyllama_tokenizer (which actually loads the SmolLM2-1.7B tokenizer).
MAX_CONTEXT_TOKENS = 1400

# Section headers in the generation output: either bracketed tags or the plain
# "Label:" lines the prompt asks for, at line start, optionally wrapped in
# markdown '#'/'*' decoration (e.g. "**Symptoms:**").
_SECTION_HEADER_RE = re.compile(
    r"\[(?P<tag>PATHOLOGY|SYMPTOMS|INTERVENTIONS|SUMMARY)\]"
    r"|^[ \t#*]*(?P<label>Condition|Symptoms|Treatments|(?:Clinical[ \t]+)?Summary)[ \t*]*:[ \t*]*",
    re.IGNORECASE | re.MULTILINE,
)
# Lower-cased last word of a header -> result slot it fills.
_SECTION_SLOT = {
    "pathology": "pathology",
    "condition": "pathology",
    "symptoms": "symptoms",
    "interventions": "meds",
    "treatments": "meds",
    "summary": "summary",
}


def _step_html(style, title, detail=""):
    """Return one tracker banner: a styled <div> with a bold title and optional detail line."""
    inner = f"<b>{title}</b>"
    if detail:
        inner += f"<br><small>{detail}</small>"
    return f"<div style='{style}'>{inner}</div>"


def _json_values_in(text):
    """Yield every JSON object/array that decodes cleanly from a '{' or '[' in *text*."""
    decoder = json.JSONDecoder()
    for opener in re.finditer(r"[\[{]", text):
        try:
            value, _ = decoder.raw_decode(text[opener.start():])
        except json.JSONDecodeError:
            continue
        yield value


def _queries_from_json(value):
    """Return the non-empty query strings held by one decoded JSON value.

    Handles {"queries": [...]} and a bare top-level list; list items may be
    strings or objects carrying a string "query" field.
    """
    items = value.get("queries") if isinstance(value, dict) else value
    if not isinstance(items, list):
        return []
    found = []
    for item in items:
        if isinstance(item, dict):
            item = item.get("query")
        if isinstance(item, str) and item.strip():
            found.append(item.strip())
    return found


def _parse_expanded_queries(raw_output):
    """Extract rephrased queries from the query-expansion model output.

    The first JSON value in the text that yields at least one query wins; an
    already-decoded dict/list is accepted as well. Returns [] when nothing
    usable is found (the caller then searches with the original query only).
    """
    if isinstance(raw_output, (dict, list)):
        candidates = [raw_output]
    elif isinstance(raw_output, str):
        candidates = _json_values_in(raw_output)
    else:
        return []
    for value in candidates:
        queries = _queries_from_json(value)
        if queries:
            return queries
    return []


def _merge_context_segments(context_text, seen_chunks, merged_contexts):
    """Append the unseen '--- SOURCE: ' segments of *context_text* to *merged_contexts*.

    *seen_chunks* (set of stripped segment texts) is updated in place so a
    segment returned for several rephrased queries is kept only once.
    Returns the log text describing the newly added segments.
    """
    log_lines = []
    for segment in context_text.split("--- SOURCE: "):
        cleaned = segment.strip()
        if not cleaned or cleaned in seen_chunks:
            continue
        seen_chunks.add(cleaned)
        merged_contexts.append(f"--- SOURCE: {cleaned}")
        preview = cleaned.split("\n")[0][:80] + "..."
        log_lines.append(f" ๐ฅ Cached Node: {preview}\n")
    return "".join(log_lines)


def _format_rerank_line(dbg):
    """Format one reranker debug record (keys 'score', 'page', 'header') as a log line."""
    score = dbg.get("score")
    try:
        score_txt = f"{score:6.2f}"
    except (TypeError, ValueError):
        # Missing/non-numeric score: show it raw instead of aborting the run.
        score_txt = f"{score!s:>6}"
    return f" ๐ [Score: {score_txt}] -> PAGE {dbg.get('page')} ({dbg.get('header')})\n"


def _truncate_to_token_budget(text, max_tokens=MAX_CONTEXT_TOKENS):
    """Cut *text* to at most *max_tokens* tinyllama_tokenizer tokens.

    Falls back to a crude 12,000-character cut if tokenisation fails.
    """
    try:
        token_ids = tinyllama_tokenizer.encode(text, add_special_tokens=False)
        if len(token_ids) > max_tokens:
            print(
                f"⚠️ Context too large "
                f"({len(token_ids)} tokens). "
                f"Truncating to {max_tokens} tokens."
            )
            token_ids = token_ids[:max_tokens]
            text = tinyllama_tokenizer.decode(token_ids, skip_special_tokens=True)
        print(f"๐ง Final Context Token Count: {len(token_ids)}")
        return text
    except Exception as e:
        print(f"⚠️ Token truncation failed: {e}")
        # Emergency fallback
        return text[:12000]


def _extract_template_sections(raw_output):
    """Split the generation output into pathology / symptoms / meds / summary text.

    Each section runs from its header to the next recognised header (or the
    end of the text), so a section is still found when the one that normally
    follows it is omitted. The first occurrence of each section wins; missing
    sections are "".
    """
    sections = dict.fromkeys(("pathology", "symptoms", "meds", "summary"), "")
    headers = list(_SECTION_HEADER_RE.finditer(raw_output))
    for current, following in zip(headers, headers[1:] + [None]):
        name = (current.group("tag") or current.group("label")).lower().split()[-1]
        slot = _SECTION_SLOT[name]
        end = following.start() if following is not None else len(raw_output)
        if not sections[slot]:
            sections[slot] = raw_output[current.end():end].strip()
    return sections


def process_reactive_clinical_pipeline(user_query):
    """Run the four-stage RAG pipeline for *user_query*, streaming UI updates.

    Stages: (1) LLM query expansion, (2) hybrid retrieval through
    production_smart_reranked_retrieval() for every expanded query,
    (3) publishing the reranker scores collected during retrieval,
    (4) templated answer generation and section extraction.

    Yields 14-tuples in the order of the Gradio ``outputs`` list this generator
    is bound to: status-group update, four banner HTML strings, four
    detail-panel updates, then pathology, summary, symptoms, interventions and
    surgical text (the last is always empty).
    """
    if not (user_query or "").strip():
        # The banners live inside the tracker group, so show the group or the
        # message would never be visible.
        yield (
            gr.update(visible=True),
            _step_html(_STYLE_ACTIVE, "❌ No query provided."), "", "", "",
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            "", "", "", "", "",
        )
        return

    collection_name = "medical_manual"

    # -------------------------------------------------------------
    # INITIALIZATION STATE
    # -------------------------------------------------------------
    s1 = _step_html(_STYLE_ACTIVE, "⏳ Step 1: Rephrasing Query Matrix", "Compiling alternative nomenclature options using local engine...")
    s2 = _step_html(_STYLE_HOLD, "๐ Step 2: Scouring Vector DB Storage Pool")
    s3 = _step_html(_STYLE_HOLD, "๐ Step 3: Filtering via Cross-Encoder Metrics")
    s4 = _step_html(_STYLE_HOLD, "๐ Step 4: Synthesizing Validated Clinical Sheet")
    yield (
        gr.update(visible=True),
        s1, s2, s3, s4,
        gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
        "", "", "", "", "",
    )
    time.sleep(0.1)

    # -------------------------------------------------------------
    # STAGE 1: QUERY EXPANSION
    # -------------------------------------------------------------
    prompt_expansion = f"""
<|im_start|>system
You generate short medical semantic search queries for retrieval systems.
Rules:
- Generate exactly 2 queries
- Keep the meaning of the original question
- Use different wording
- Include symptoms, diagnosis, or treatment terms if relevant
- Keep each query under 12 words
- Output ONLY valid JSON
Return ONLY this JSON format:
{{
"queries": [
"generated query 1",
"generated query 2"
]
}}
Any other format will NOT be parsed correctly.
<|im_end|>
<|im_start|>user
Question:
{user_query}
<|im_end|>
<|im_start|>assistant
"""
    # NOTE(review): two ~12-word queries plus JSON syntax can exceed 40 tokens;
    # unless the backend constrains output for response_mode="queries", the
    # JSON may be cut off and expansion silently falls back to the raw query.
    expansion_output = backend_engine.run_inference_raw(
        prompt_expansion,
        max_tokens=40,
        response_mode="queries"
    )
    print("Expansion Output Before Formatting Check: ", expansion_output)

    rephrased = _parse_expanded_queries(expansion_output)
    if not rephrased:
        print(
            "⚠️ Query expansion extraction bypassed. "
            "Fallback engaged. Reason: no parseable queries in model output"
        )
    # Original query first; strip, drop blanks and duplicates, keep order.
    expanded_queries = list(dict.fromkeys(
        q.strip() for q in [user_query, *rephrased] if q.strip()
    ))
    print("๐ฏ๏ธ Rephrased Queries: ", expanded_queries)
    expanded_queries_display = "\n".join(
        f"➡️ Target Vector Search Matrix {idx}: {q}"
        for idx, q in enumerate(expanded_queries, 1)
    )

    s1 = _step_html(_STYLE_DONE, "✅ Step 1: Rephrasing Complete", f"Query matrix successfully compiled down to {len(expanded_queries)} extraction criteria targets.")
    s2 = _step_html(_STYLE_ACTIVE, "⏳ Step 2: Scouring Vector DB Storage Pool", "Polling in-memory Qdrant chunks via dense dot-products and sparse BM25 tokens...")
    yield (
        gr.update(visible=True),
        s1, s2, s3, s4,
        gr.update(visible=True, value=expanded_queries_display), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
        "", "", "", "", "",
    )
    time.sleep(0.1)

    # -------------------------------------------------------------
    # STAGE 2: RETRIEVAL FOR EVERY EXPANDED QUERY
    # -------------------------------------------------------------
    merged_contexts = []
    seen_chunks = set()
    db_retrieval_logs = "๐ Isolated Candidate Tracking Windows:\n\n"
    # Reranker scores come from the reranking pass inside
    # production_smart_reranked_retrieval() (its "debug" records).
    reranker_logs = "๐ฏ Cross-Encoder Evaluation Metrics Matrix Weights:\n\n"

    for q in expanded_queries:
        context_payload = production_smart_reranked_retrieval(
            q,
            client,
            collection_name=collection_name,
            rerank_model=rerank_model,
            rerank_tokenizer=rerank_tokenizer,
            sparse_encoder=sparse_encoder,
            get_dense_embedding=get_dense_embedding,
        )
        if isinstance(context_payload, dict):
            # Structured payload: {"context": str, "debug": [reranker records]}.
            for dbg in context_payload.get("debug") or []:
                if isinstance(dbg, dict):
                    reranker_logs += _format_rerank_line(dbg)
            context_text = context_payload.get("context") or ""
        elif isinstance(context_payload, str) and not context_payload.startswith("⚠️"):
            # Legacy plain-string payload; "⚠️"-prefixed strings are presumably
            # warnings rather than context.
            context_text = context_payload
        else:
            continue
        db_retrieval_logs += _merge_context_segments(context_text, seen_chunks, merged_contexts)

    final_merged_context_string = (
        "\n\n".join(merged_contexts) if merged_contexts else "No context found."
    )
    # Strip an e-mail address (presumably embedded in the source document).
    # Note: the second replace removes "@gmail.com" from any address.
    final_merged_context_string = final_merged_context_string.replace('adarshofficial11', '').replace('@gmail.com', '')
    final_merged_context_string = _truncate_to_token_budget(final_merged_context_string)

    db_retrieval_logs += (
        f"\n๐ Gathered total of "
        f"{len(seen_chunks)} raw context segments "
        f"from sliding layers."
    )
    s2 = _step_html(_STYLE_DONE, "✅ Step 2: DB Extraction Complete", f"Located {len(seen_chunks)} page candidates inside database memory blocks.")
    s3 = _step_html(_STYLE_ACTIVE, "⏳ Step 3: Filtering via Cross-Encoder Metrics", "Evaluating textual relevance using cross-encoder architecture weights...")
    yield (
        gr.update(visible=True),
        s1, s2, s3, s4,
        gr.update(visible=True), gr.update(visible=True, value=db_retrieval_logs), gr.update(visible=False), gr.update(visible=False),
        "", "", "", "", "",
    )
    time.sleep(0.1)

    # -------------------------------------------------------------
    # STAGE 3: PUBLISH RERANKER SCORES (computed during stage 2)
    # -------------------------------------------------------------
    s3 = _step_html(_STYLE_DONE, "✅ Step 3: Neural Filtering Complete", "Cross-Encoder validation matrix updated. Irrelevant records dropped.")
    s4 = _step_html(_STYLE_ACTIVE, "⏳ Step 4: Synthesizing Validated Clinical Sheet", "Formulating natural markdown response layers (Template execution)...")
    yield (
        gr.update(visible=True),
        s1, s2, s3, s4,
        gr.update(visible=True), gr.update(visible=True), gr.update(visible=True, value=reranker_logs), gr.update(visible=False),
        "", "", "", "", "",
    )
    time.sleep(0.1)

    # -------------------------------------------------------------
    # STAGE 4: TEMPLATED ANSWER GENERATION
    # -------------------------------------------------------------
    generation_prompt = f"""
<|im_start|>system
You are a medical retrieval assistant.
Use ONLY the provided context.
Rules:
- Do not invent information
- Do not guess missing details
- Do not place treatments under symptoms
- Do not place symptoms under treatments
- Omit sections if information is unavailable
- Keep the response concise and factual
- If details don't fit a section then put them at the end as summary.
If insufficient information exists, say:
"I could not find enough information in the provided medical sources."
Output Format:
Condition:
Symptoms:
- symptom
Treatments:
- treatment
Clinical Summary:
<2 sentence summary>
<|im_end|>
<|im_start|>user
Context:
{final_merged_context_string}
Question:
{user_query}
<|im_end|>
<|im_start|>assistant
"""
    print(f"๐ Final Prompt Length: {len(generation_prompt)} chars")
    try:
        prompt_tokens = tinyllama_tokenizer.encode(generation_prompt, add_special_tokens=False)
        print(f"๐ง Final Prompt Tokens: {len(prompt_tokens)}")
    except Exception as e:
        # Diagnostic only — never let the token count abort generation.
        print(f"⚠️ Prompt token count failed: {e}")

    raw_template_output = backend_engine.run_inference_raw(
        generation_prompt,
        max_tokens=300,
        response_mode="template"
    )
    print("๐ LLM Output: ", raw_template_output)
    raw_template_output = "" if raw_template_output is None else str(raw_template_output)

    sections = _extract_template_sections(raw_template_output)
    pathology = sections["pathology"].upper()
    symptoms = sections["symptoms"]
    meds = sections["meds"]
    # Without a recognisable summary section, show the whole answer instead.
    summary = sections["summary"] or raw_template_output

    s4 = _step_html(_STYLE_DONE, "✅ Step 4: Profile Synthesis Complete", "Template extracted successfully. Markdown fields generated with clean sentence flow.")
    yield (
        gr.update(visible=True),
        s1, s2, s3, s4,
        gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True, value=raw_template_output),
        pathology, summary, symptoms, meds, "",
    )
# -------------------------------------------------------------
# GRAPHICAL GRADIO VIEW PORT INTERFACE LAYOUT
# -------------------------------------------------------------
# NOTE(review): many label emoji below look mis-decoded (UTF-8 emoji shown as
# Thai characters such as "๐" / "โ"); restore them from the original source.
# The three gr.HTML separators were broken string literals in this copy and
# are written as "<hr>" — confirm the intended markup.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=4):
            gr.Markdown("# ๐ฅ Clinical Diagnostic Multi-Query RAG\n### Local Vision-Language Synthesizer Node")
        with gr.Column(scale=1, min_width=150):
            gr.Markdown("๐ข **SYSTEM STATUS**\nEngine Tier: `CPU-Optimized`")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("""When clinical retrieval pipeline is running,
please do not refresh the page.
CPU inference may take up to 60 seconds.
You can check the Tracker for updates.
""")
    gr.HTML("<hr>")

    with gr.Tabs():
        # --- TAB 1: PRIMARY PIPELINE ---
        with gr.TabItem("๐ Production Diagnostic Dashboard"):
            with gr.Row():
                with gr.Column(scale=2, variant="panel"):
                    gr.Markdown("### ๐๏ธ Clinical Query Console")
                    query_box = gr.Textbox(
                        label="Input Clinical Presentation or Search Query",
                        placeholder="e.g., Schizophrenia causes, common symptoms and clinical interventions?",
                        lines=5,
                    )
                    with gr.Row():
                        clear_btn = gr.Button("๐งน Clear Input", variant="secondary")
                        submit_btn = gr.Button("โก Execute Retrieval & Grounding", variant="primary")

                    # Hidden until a query runs; the pipeline generator toggles it.
                    with gr.Group(visible=False) as pipeline_status_box:
                        gr.Markdown("### โ๏ธ Pipeline Tracking Monitor")
                        s1_html = gr.HTML()
                        s2_html = gr.HTML()
                        s3_html = gr.HTML()
                        s4_html = gr.HTML()
                        gr.HTML("<hr>")
                        with gr.Accordion("๐ Target Expansion Matrix Logs (Dual Channels)", open=False) as a1:
                            s1_details = gr.Code(language="markdown")
                        with gr.Accordion("๐ Localized DB Window Chunk Extractions (800 Chars Preview)", open=False) as a2:
                            s2_details = gr.Code(language="markdown")
                        with gr.Accordion("๐ BGE Cross-Encoder Evaluation Metrics Weights", open=False) as a3:
                            s3_details = gr.Code(language="markdown")
                        with gr.Accordion("๐ Native Unrestricted Text Manifest Payload", open=False) as a4:
                            s4_details = gr.Code(language="markdown")

                with gr.Column(scale=3):
                    gr.Markdown("### ๐ Intelligent Diagnostic Readout")
                    with gr.Group():
                        pathology_lbl = gr.Textbox(label="Identified Medical Pathology Target", interactive=False)
                        summary_md = gr.Markdown(value="*Analysis results will generate here.*")
                    with gr.Row():
                        symptoms_txt = gr.Markdown(label="Clinical Presentation / Symptoms")
                        meds_txt = gr.Markdown(label="Medical Interventions")
                        surg_txt = gr.Markdown(label="Surgical Operations")

        # --- TAB 2: RESTORED PAGE CHUNK SEARCH FINDER ---
        with gr.TabItem("๐ Vector Storage Inspect Matrix"):
            gr.Markdown("### ๐๏ธ Real-time Chunk Investigator")
            gr.Markdown(
                "#### ๐ ๏ธ Document Processing & Local Ingestion Flow\n"
                "```\n"
                "[ Raw Medical PDF Handbook Document ]\n"
                " โ\n"
                " โผ\n"
                "[ Character Token Window Chunk Splitting Matrix ]\n"
                " โ\n"
                " โผ\n"
                "[ Structural Index Metadata Tagging (Page Number, Chapter ID) ]\n"
                " โ\n"
                " โผ\n"
                "[ Qdrant Memory-Isolated High-Performance Storage Pool ]\n"
                "```"
            )
            gr.HTML("<hr>")
            with gr.Row():
                with gr.Column(scale=1, variant="panel"):
                    page_input = gr.Number(label="Target Query Metadata Page Number", value=1742, precision=0)
                    fetch_page_btn = gr.Button("๐ฐ๏ธ Pull Chunks directly from Database", variant="primary")
                with gr.Column(scale=3):
                    page_output_display = gr.Textbox(label="Raw Target Payload Output Matrix (Max 800 Chars Per Block)", lines=18, interactive=False)

    # --- Event Handling Grid Bindings (registered inside the Blocks context) ---
    submit_btn.click(
        fn=process_reactive_clinical_pipeline,
        inputs=[query_box],
        outputs=[
            pipeline_status_box, s1_html, s2_html, s3_html, s4_html,
            # The pipeline sends the retrieval and reranker logs as *values* in
            # slots 7-8, so those slots must be the Code boxes, not the
            # Accordions wrapping them (Accordions have no value to show).
            s1_details, s2_details, s3_details, s4_details,
            pathology_lbl, summary_md, symptoms_txt, meds_txt, surg_txt,
        ],
        show_progress="hidden",
    )
    # Reset: hide the tracker, blank every panel, restore the summary placeholder.
    clear_btn.click(
        lambda: (gr.update(visible=False), "", "", "", "", "", "", "", "", "", "", "*Analysis results will generate here.*", "", "", ""),
        inputs=[],
        outputs=[
            pipeline_status_box, s1_html, s2_html, s3_html, s4_html,
            s1_details, s2_details, s3_details, s4_details,
            query_box, pathology_lbl, summary_md, symptoms_txt, meds_txt, surg_txt,
        ],
    )
    fetch_page_btn.click(
        fn=lookup_chunks_by_page,
        inputs=[page_input],
        outputs=[page_output_display],
    )
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside the
    # container (e.g. on Hugging Face Spaces). The env var duplicates
    # server_name= below for any code that reads GRADIO_SERVER_NAME directly.
    # (`os` is imported at module level.)
    os.environ["GRADIO_SERVER_NAME"] = "0.0.0.0"
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        allowed_paths=["/app"],
        # NOTE(review): launch(theme=...) is only accepted by newer Gradio
        # releases; older ones take gr.Blocks(theme=...) — confirm against the
        # pinned Gradio version.
        theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate"),
    )