# HNTAI / generate_patient_summary_colab.py
# Author: sachinchandrankallar
# Commit f91c303: "patient summary working" (32.3 kB)
# Colab-ready script to replicate generate_patient_summary API logic in a single file
# Install necessary packages for Colab environment (run this in a separate cell):
# !pip install transformers optimum llama-cpp-python huggingface_hub requests
import os
import re
import time
import json
import logging
import requests
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline as transformers_pipeline
from optimum.intel.openvino import OVModelForCausalLM
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Utils from openvino_summarizer_utils.py ---
def parse_ehr_chartsummarydtl(chartsummarydtl):
    """Normalise raw EHR ``chartsummarydtl`` entries into per-visit dicts.

    Every output visit has:
      - ``chartdate``: the first 10 characters of the entry's chartdate
        (the date part of an ISO timestamp), or ``""`` when absent/null;
      - ``vitals``: ``{name: value}`` parsed from ``"name: value"`` strings
        (strings without ``":"`` are skipped);
      - ``labtests``: ``[{"name", "value"}]`` for labs with a name or value.
    A vital whose name starts with "weight" (case-insensitive) is stored as a
    top-level ``weight`` instead of inside ``vitals``. ``allergies``,
    ``diagnosis`` and ``medications`` are copied through when present, and
    ``radiologyorders`` is reduced to the list of non-empty order names.

    Args:
        chartsummarydtl: Iterable of visit dicts from the EHR API. ``None``
            is treated as an empty list, as are null list fields.

    Returns:
        list[dict]: One normalised visit per input entry, in input order.
    """
    visits = []
    for entry in chartsummarydtl or []:
        # ``or ""`` guards against an explicit ``"chartdate": null``.
        visit = {"chartdate": str(entry.get("chartdate") or "")[:10]}

        vitals_dict = {}
        weight = None
        for raw_vital in entry.get("vitals") or []:
            if ":" not in raw_vital:
                continue
            name, value = (part.strip() for part in raw_vital.split(":", 1))
            if name.lower().startswith("weight"):
                # Kept top-level so the ("weight",) lookup path can find it.
                weight = value
            else:
                vitals_dict[name] = value
        visit["vitals"] = vitals_dict
        if weight:
            visit["weight"] = weight

        for passthrough_key in ("allergies", "diagnosis", "medications"):
            if passthrough_key in entry:
                visit[passthrough_key] = entry[passthrough_key]

        visit["labtests"] = [
            {"name": lab.get("name", ""), "value": lab.get("value", "")}
            for lab in entry.get("labtests") or []
            if lab.get("name", "") or lab.get("value", "")
        ]

        if "radiologyorders" in entry:
            visit["radiologyorders"] = [
                order.get("name", "")
                for order in entry["radiologyorders"] or []
                if order.get("name")
            ]
        visits.append(visit)
    return visits
# Canonical lookup path -> ordered list of equivalent paths tried by
# latest_value(); the first path that yields a value wins. Each list starts
# with the canonical path itself.
ALIASES = {
    ("vitals", "Bp(sys)(mmHg)"): [
        ("vitals", "Bp(sys)(mmHg)"),
        ("vitals", "Bp_sys"),
        ("vitals", "SBP"),
    ],
    ("vitals", "Bp(dia)(mmHg)"): [
        ("vitals", "Bp(dia)(mmHg)"),
        ("vitals", "Bp_dia"),
        ("vitals", "DBP"),
    ],
    ("labtests", "HbA1c (%)"): [
        ("labtests", "HbA1c (%)"),
        ("labtests", "HbA1c"),
    ],
    ("labtests", "Creatinine Ratio"): [
        ("labtests", "Creatinine Ratio"),
        ("labtests", "Creatinine"),
    ],
}
def visits_sorted(v):
    """Return the visits in *v* ordered oldest-first by ``chartdate``.

    Dates are compared as strings, which orders ISO ``YYYY-MM-DD`` dates
    chronologically. Visits with a missing or ``None`` date sort first. The
    sort is stable and the input is not modified.
    """
    return sorted(v, key=lambda visit: visit.get("chartdate") or "")
# First signed integer/decimal in a string; also accepts a leading-dot form
# such as ".5".
_FIRST_NUMBER_RE = re.compile(r"-?(?:\d+\.?\d*|\.\d+)")


def to_float(val):
    """Extract the first number embedded in *val* as a float.

    Works on clinical strings such as ``"7.2"``, ``"140 mmHg"``, ``"-1.5"``
    or ``".5"``; non-string values are converted with ``str()`` first.

    Returns:
        float | None: The number, or ``None`` when *val* is ``None`` or
        contains no digits.
    """
    if val is None:
        return None
    match = _FIRST_NUMBER_RE.search(str(val))
    return float(match.group()) if match else None
def _latest_value_exact(visits, key_path):
    """Return the most recent value stored under exactly *key_path*.

    Visits are scanned newest-first (by ``chartdate``, via ``visits_sorted``)
    and the first visit containing the path wins. A value missing from the
    latest visit therefore falls back to an earlier one.

    Args:
        visits: Visit dicts as produced by ``parse_ehr_chartsummarydtl``.
        key_path: Tuple of keys. ``("labtests", name)`` is special-cased: it
            searches the visit's lab list for an entry whose ``name`` equals
            *name*. Any other path is followed as nested dict keys, e.g.
            ``("vitals", "Bp(sys)(mmHg)")`` or ``("weight",)``.

    Returns:
        The stored value, unconverted, or ``None`` if no visit has it.
    """
    newest_first = visits_sorted(visits)[::-1]
    if not newest_first:
        return None

    if key_path[0] == "labtests":
        lab_name = key_path[1]
        for visit in newest_first:
            for lab in visit.get("labtests", []):
                if lab.get("name") == lab_name:
                    return lab.get("value")
        return None

    for visit in newest_first:
        node = visit
        for key in key_path:
            if not (isinstance(node, dict) and key in node):
                break
            node = node[key]
        else:
            # Every key resolved in this visit.
            return node
    return None
def latest_value(visits, key_path):
    """Return the most recent value for *key_path*, trying known aliases.

    If *key_path* is a key of ``ALIASES``, each equivalent path listed there
    is tried in order and the first non-``None`` result is returned.
    Otherwise only *key_path* itself is looked up.

    Returns:
        The raw stored value, or ``None`` when no alias matches any visit.
    """
    for candidate_path in ALIASES.get(key_path, [key_path]):
        value = _latest_value_exact(visits, candidate_path)
        if value is not None:
            return value
    return None
def active_set(visits, field):
    """Return every value ever recorded under *field* across *visits*.

    Despite the name, items are never removed: anything listed at any visit
    is included. A bare string value is treated as one item rather than being
    split into characters. A missing, empty or ``None`` field contributes
    nothing.

    Returns:
        set: The union of all recorded items.
    """
    items = set()
    for visit in visits:
        value = visit.get(field)
        if not value:
            continue
        if isinstance(value, str):
            items.add(value)
        else:
            items.update(value)
    return items
def _fmt(x, spec=None):
    """Format *x* for display, returning ``"N/A"`` when it is ``None``.

    Args:
        x: Value to render.
        spec: Optional ``format()`` spec such as ``"+.1f"``. If the spec does
            not apply to *x* (e.g. a numeric spec on a string), the plain
            ``str(x)`` is returned instead.
    """
    if x is None:
        return "N/A"
    if not spec:
        return str(x)
    try:
        return format(x, spec)
    except (TypeError, ValueError):
        # format() raises these for a spec the value's type cannot handle.
        return str(x)
def compute_deltas(old_visits, new_visits):
    """Compare the patient state before and after *new_visits* are added.

    "prev" values come from *old_visits* alone. "curr" values come from
    *old_visits* + *new_visits*. Numeric values are extracted with
    ``to_float`` from the latest recorded value (see ``latest_value``). When
    *old_visits* is empty, every ``prev`` and every ``delta`` is ``None``.

    NOTE(review): ``active_set`` is a union over all visits, so the current
    set always contains the previous set. ``stopped_meds`` is therefore
    always empty. Detecting stopped medications would need a per-visit
    "current list" definition; confirm the intended semantics.

    Args:
        old_visits: List of previously known visits.
        new_visits: List of newly added visits.

    Returns:
        dict with keys:
          - ``added_dx``: sorted diagnoses present now but not before;
          - ``started_meds`` / ``stopped_meds``: sorted medication set
            differences;
          - ``weight``, ``bp_sys``, ``bp_dia``: ``{"prev", "curr", "delta"}``
            with floats or ``None``;
          - ``labs``: ``{lab_name: {"prev", "curr", "delta"}}`` for every
            named lab with at least one numeric value, in sorted name order
            (so the rendered prompt is deterministic across runs).
    """
    prev_all = old_visits
    curr_all = old_visits + new_visits

    def numeric(visits, path):
        return to_float(latest_value(visits, path))

    def change(path):
        prev, curr = numeric(prev_all, path), numeric(curr_all, path)
        # Explicit None checks so a legitimate 0.0 still produces a delta.
        delta = curr - prev if prev is not None and curr is not None else None
        return {"prev": prev, "curr": curr, "delta": delta}

    # curr_all is a superset of prev_all, so it already holds every lab name.
    lab_names = sorted({
        lab["name"]
        for visit in curr_all
        for lab in visit.get("labtests", [])
        if lab.get("name")
    })
    labs = {}
    for lab_name in lab_names:
        entry = change(("labtests", lab_name))
        if entry["prev"] is not None or entry["curr"] is not None:
            labs[lab_name] = entry

    prev_dx, curr_dx = active_set(prev_all, "diagnosis"), active_set(curr_all, "diagnosis")
    prev_meds, curr_meds = active_set(prev_all, "medications"), active_set(curr_all, "medications")
    return {
        "added_dx": sorted(curr_dx - prev_dx),
        "started_meds": sorted(curr_meds - prev_meds),
        "stopped_meds": sorted(prev_meds - curr_meds),
        "weight": change(("weight",)),
        "bp_sys": change(("vitals", "Bp(sys)(mmHg)")),
        "bp_dia": change(("vitals", "Bp(dia)(mmHg)")),
        "labs": labs,
    }
def build_compact_baseline(all_visits):
    """Render the patient's latest state as a short plain-text block.

    The output has five lines:
      - the latest chart date;
      - every diagnosis and every medication ever recorded (sorted);
      - the latest systolic/diastolic blood pressure and weight;
      - the latest value of every named lab test, sorted by lab name.
    Missing values render as ``N/A``.

    Args:
        all_visits: Visit dicts as produced by ``parse_ehr_chartsummarydtl``.

    Returns:
        str: The newline-separated baseline text.
    """
    lab_names = sorted({
        lab["name"]
        for visit in all_visits
        for lab in visit.get("labtests", [])
        if lab.get("name")
    })
    lab_strings = []
    for lab_name in lab_names:
        lab_value = latest_value(all_visits, ("labtests", lab_name))
        if lab_value is not None:
            lab_strings.append(f"{lab_name}: {lab_value}")
    labs_text = ", ".join(lab_strings) or "N/A"

    latest_date = latest_value(all_visits, ("chartdate",)) or "N/A"
    diagnoses = ", ".join(sorted(active_set(all_visits, "diagnosis"))) or "N/A"
    medications = ", ".join(sorted(active_set(all_visits, "medications"))) or "N/A"
    # _fmt renders a missing value as "N/A" rather than the string "None".
    bp_sys = _fmt(latest_value(all_visits, ("vitals", "Bp(sys)(mmHg)")))
    bp_dia = _fmt(latest_value(all_visits, ("vitals", "Bp(dia)(mmHg)")))
    weight = _fmt(latest_value(all_visits, ("weight",)))

    return (
        f"Latest date: {latest_date}\n"
        f"Active Diagnoses: {diagnoses}\n"
        f"Active Medications: {medications}\n"
        f"Latest Vitals: Bp: {bp_sys}/{bp_dia} mmHg, Weight: {weight}\n"
        f"Latest Labs: {labs_text}"
    )
def delta_to_text(delta):
    """Render the dict returned by ``compute_deltas`` as plain-text lines.

    Diagnosis and medication lines appear only when non-empty. The weight
    and BP lines always appear. Each lab with a prev or curr value gets a
    line. Numbers go through ``_fmt``, so missing values render as ``N/A``.

    Example lines::

        Weight: 182.0 -> 180.0 (-2.0)
        BP: 140.0/90.0 (Δs +5, Δd +5)
        HbA1c (%): 7.5 -> 7.2 (-0.3)

    Returns:
        str: The lines joined with newlines.
    """
    lines = []
    for key, label in (
        ("added_dx", "New Diagnoses"),
        ("started_meds", "Medications Started"),
        ("stopped_meds", "Medications Stopped"),
    ):
        if delta[key]:
            lines.append(f"{label}: " + ", ".join(delta[key]))

    weight = delta["weight"]
    lines.append(
        f"Weight: {_fmt(weight['prev'])} -> {_fmt(weight['curr'])} "
        f"({_fmt(weight['delta'], '+.1f')})"
    )

    sys_bp, dia_bp = delta["bp_sys"], delta["bp_dia"]
    lines.append(
        f"BP: {_fmt(sys_bp['curr'])}/{_fmt(dia_bp['curr'])} "
        f"(Δs {_fmt(sys_bp['delta'], '+.0f')}, Δd {_fmt(dia_bp['delta'], '+.0f')})"
    )

    for lab_name, lab in delta["labs"].items():
        if lab["prev"] is not None or lab["curr"] is not None:
            lines.append(
                f"{lab_name}: {_fmt(lab['prev'])} -> {_fmt(lab['curr'])} "
                f"({_fmt(lab['delta'], '+.1f')})"
            )
    return "\n".join(lines)
def build_main_prompt(baseline, delta_text, patient_info=""):
    """Assemble the LLM prompt for a four-section clinical summary.

    The prompt lists the required sections (Clinical Assessment, Key Trends &
    Changes, Plan & Suggested Actions, Direct Guidance for Physician). It
    then embeds *patient_info*, *baseline* and *delta_text* verbatim, and
    ends with the line "Now generate the complete clinical summary with all
    four sections in markdown format:".

    NOTE(review): the instructions tell the model to "Use the
    chartsummarydtl for context", but the raw chartsummarydtl is not
    included here; only the derived baseline and deltas are.

    Args:
        baseline: Output of ``build_compact_baseline``.
        delta_text: Output of ``delta_to_text``.
        patient_info: Optional free-text demographics; may be empty.

    Returns:
        str: The full prompt.
    """
    instructions = (
        "You are an expert clinical AI assistant. Your task is to generate a patient summary.\n"
        "Use the chartsummarydtl for context. The STRUCTURED BASELINE and DELTAS are the absolute ground truth.\n"
        "Produce a concise, physician-ready summary. Never omit critical new information from the deltas.\n\n"
        "The summary MUST have four sections:\n"
        "1) Clinical Assessment\n"
        "2) Key Trends & Changes\n"
        "3) Plan & Suggested Actions\n"
        "4) Direct Guidance for Physician\n\n"
    )
    data = (
        f"PATIENT INFORMATION:\n{patient_info}\n\n"
        f"STRUCTURED BASELINE (authoritative):\n{baseline}\n\n"
        f"STRUCTURED DELTAS (authoritative):\n{delta_text}\n\n"
    )
    return instructions + data + "Now generate the complete clinical summary with all four sections in markdown format:"
# --- GGUF model loader and pipeline (simplified) ---
class GGUFModelPipeline:
    """CPU-only wrapper around a llama.cpp (``llama_cpp.Llama``) GGUF model.

    Provides:
      - ``generate``: timeout-guarded text completion;
      - ``generate_full_summary``: multi-pass generation that re-prompts
        until the four required clinical-summary sections are present.
    """

    def __init__(self, model_path_or_repo, filename=None, cache_dir=None, timeout=300):
        """Resolve the model file (downloading if needed) and load it.

        Args:
            model_path_or_repo: Hugging Face Hub repo id when *filename* is
                given; otherwise a local path to the ``.gguf`` file.
            filename: File name inside the Hub repo, or ``None`` for a local
                path.
            cache_dir: Download cache directory. Defaults to ``$HF_HOME`` or
                ``/tmp/huggingface`` and is created if missing.
            timeout: Default generation timeout in seconds. Used when
                ``GGUF_GENERATION_TIMEOUT`` is unset (raised to at least
                600 s on a Hugging Face Space).

        Environment:
            ``GGUF_N_THREADS``, ``GGUF_N_BATCH`` and ``GGUF_N_CTX`` override the
            thread count, batch size and context length. When ``SPACE_ID`` is
            set (Hugging Face Space), the defaults are 1 thread, batch 16 and
            context 4096. Otherwise they are up to 2 threads, batch 32 and
            context 1024.

        Raises:
            RuntimeError: The download or llama.cpp initialisation failed.
            FileNotFoundError: The resolved model path does not exist.
        """
        cache_dir = cache_dir or os.environ.get("HF_HOME", "/tmp/huggingface")
        os.makedirs(cache_dir, exist_ok=True)
        self.timeout = timeout

        if filename is not None:
            try:
                logging.info(f"Downloading model from {model_path_or_repo}/{filename}")
                local_path = hf_hub_download(
                    repo_id=model_path_or_repo,
                    filename=filename,
                    cache_dir=cache_dir,
                    resume_download=True,
                    local_files_only=False,
                )
                logging.info(f"Model downloaded successfully to {local_path}")
            except Exception as e:
                logging.error(f"Failed to download model: {e}")
                raise RuntimeError(f"Model download failed: {str(e)}") from e
        else:
            local_path = model_path_or_repo
        if not os.path.exists(local_path):
            raise FileNotFoundError(f"Model path does not exist: {local_path}")

        file_size = os.path.getsize(local_path) / (1024 * 1024)
        logging.info(f"Model file size: {file_size:.2f} MB")
        if file_size > 5000:
            logging.warning(f"Model file is very large ({file_size:.2f} MB), may cause memory issues")

        load_start = time.time()
        try:
            cpu_count = os.cpu_count() or 2
            if os.environ.get("SPACE_ID") is not None:
                default_threads, default_batch, default_ctx = 1, 16, 4096
                logging.info("[GGUF] Detected Hugging Face Space - using ultra-conservative memory settings")
            else:
                default_threads, default_batch, default_ctx = max(1, min(2, cpu_count)), 32, 1024
            n_threads = int(os.environ.get("GGUF_N_THREADS", str(default_threads)))
            n_batch = int(os.environ.get("GGUF_N_BATCH", str(default_batch)))
            n_ctx = int(os.environ.get("GGUF_N_CTX", str(default_ctx)))
            # RoPE base/scale are left at llama.cpp's defaults (taken from the
            # GGUF metadata). Forcing a fixed value breaks models trained with
            # a different RoPE base.
            self.model = Llama(
                model_path=local_path,
                n_ctx=n_ctx,
                n_threads=n_threads,
                n_batch=n_batch,
                n_gpu_layers=0,
                logits_all=False,
                embedding=False,
                use_mmap=True,
                use_mlock=False,
                seed=0,
                verbose=False,
                mul_mat_q=True,
                f16_kv=True,
                vocab_only=False,
                n_threads_batch=n_threads,
            )
        except Exception as e:
            logging.error(f"Failed to initialize GGUF model: {e}")
            raise RuntimeError(f"Failed to initialize GGUF model via llama.cpp: {e}") from e
        load_time = time.time() - load_start
        logging.info(f"[GGUF] Model initialized in {load_time:.2f}s from {local_path} (threads={n_threads}, batch={n_batch})")

    def _strip_special_tokens(self, text: str) -> str:
        """Remove chat-template/special tokens (``<|assistant|>``, ``</s>``, ...) and trim."""
        patterns = [
            r"<\|assistant\|>", r"<\|user\|>", r"<\|system\|>", r"<\|end\|>", r"<\|endoftext\|>", r"</s>", r"<s>"
        ]
        for pattern in patterns:
            text = re.sub(pattern, "", text, flags=re.IGNORECASE)
        return text.strip()

    def _generate_with_timeout(self, prompt, max_tokens=4000, temperature=0.5, top_p=0.95, timeout=None):
        """Run one llama.cpp completion in a worker thread with a deadline.

        When *timeout* is ``None`` it is resolved as follows:
          - ``GGUF_GENERATION_TIMEOUT`` if that variable is set;
          - otherwise the instance ``timeout`` (default 300);
          - on a Hugging Face Space, at least 600.
        Generation stops at ``</s>`` or ``###``.

        Returns:
            dict: The raw llama.cpp completion result.

        Raises:
            TimeoutError: No result within *timeout* seconds. The llama.cpp
                call cannot be interrupted, so the worker thread keeps running
                in the background until it finishes.
        """
        if timeout is None:
            env_timeout = os.environ.get("GGUF_GENERATION_TIMEOUT")
            if env_timeout is not None:
                timeout = int(env_timeout)
            else:
                timeout = getattr(self, "timeout", None) or 300
                if os.environ.get("SPACE_ID") is not None:
                    timeout = max(timeout, 600)

        # Not a `with` block: its exit would wait for the worker, so a
        # timeout would only be reported after generation had finished.
        executor = ThreadPoolExecutor(max_workers=1)
        future = executor.submit(
            self.model,
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=["</s>", "###"],
        )
        try:
            return future.result(timeout=timeout)
        except FutureTimeoutError:
            # Before Python 3.11 this is distinct from the builtin TimeoutError.
            future.cancel()
            raise TimeoutError(f"Generation timed out after {timeout} seconds") from None
        finally:
            executor.shutdown(wait=False)

    def generate(self, prompt, max_tokens=4000, temperature=0.5, top_p=0.95):
        """Generate a completion for *prompt* and return its cleaned text.

        Raises:
            TimeoutError: Generation exceeded the configured timeout.
            RuntimeError: Any other generation or response-parsing failure.
        """
        t0 = time.time()
        try:
            output = self._generate_with_timeout(prompt, max_tokens, temperature, top_p)
            dt = time.time() - t0
            text = self._strip_special_tokens(output["choices"][0]["text"].strip())
            approx_words = len(text.split())
            logging.info(f"[GGUF] generate: {dt:.2f}s, ~{approx_words} words, max_tokens={max_tokens}")
            return text
        except TimeoutError as e:
            logging.error(f"Generation timed out: {e}")
            raise
        except Exception as e:
            logging.error(f"Generation failed: {e}")
            raise RuntimeError(f"Text generation failed: {str(e)}") from e

    def generate_full_summary(self, prompt, max_tokens=4000, max_loops=5):
        """Generate a markdown summary containing all four required sections.

        Steps:
          1. Make up to *max_loops* passes. After each pass, check the
             accumulated output for the four section titles and for
             sentence-ending punctuation.
          2. If the output is incomplete, re-prompt the next pass with the
             output so far and what is missing.
          3. Generate any sections still missing one at a time.
          4. Give any section that is still missing a placeholder.
        This method never raises: on an unexpected error it returns a
        minimal four-section placeholder summary.

        Args:
            prompt: The full clinical prompt.
            max_tokens: Token budget per pass (halved for single-section passes).
            max_loops: Maximum number of full-summary passes.

        Returns:
            str: The markdown summary.
        """
        required_sections = [
            'Clinical Assessment',
            'Key Trends & Changes',
            'Plan & Suggested Actions',
            'Direct Guidance for Physician'
        ]

        def is_complete(text):
            """Return ``(complete, missing_sections)`` for *text*."""
            missing_sections = [s for s in required_sections if s not in text]
            if missing_sections:
                logging.info(f"[GGUF] Missing sections: {missing_sections}")
                return False, missing_sections
            ends_with_punct = bool(re.search(r'[.!?][\s\n]*$', text))
            if not ends_with_punct:
                logging.info("[GGUF] Summary does not end with a full sentence")
            return ends_with_punct, []

        def generate_missing_section(section_name, base_prompt):
            """Generate one missing section with a targeted prompt; never raises."""
            section_prompts = {
                'Clinical Assessment': f"Based on the patient data provided, generate only the Clinical Assessment section in markdown format. Focus on the current clinical status, key findings, and overall patient condition.\n\nPatient Data:\n{base_prompt}\n\n## Clinical Assessment\n",
                'Key Trends & Changes': f"Based on the patient data provided, generate only the Key Trends & Changes section in markdown format. Analyze trends in vitals, labs, diagnoses, and medications over time.\n\nPatient Data:\n{base_prompt}\n\n## Key Trends & Changes\n",
                'Plan & Suggested Actions': f"Based on the patient data provided, generate only the Plan & Suggested Actions section in markdown format. Recommend next steps, follow-up actions, and treatment considerations.\n\nPatient Data:\n{base_prompt}\n\n## Plan & Suggested Actions\n",
                'Direct Guidance for Physician': f"Based on the patient data provided, generate only the Direct Guidance for Physician section in markdown format. Provide specific recommendations for the treating physician.\n\nPatient Data:\n{base_prompt}\n\n## Direct Guidance for Physician\n"
            }
            targeted_prompt = section_prompts.get(section_name, f"Generate the {section_name} section based on the patient data.\n\n{base_prompt}\n\n## {section_name}\n")
            try:
                section_output = self.generate(targeted_prompt, max_tokens=max_tokens // 2)
                header = f"## {section_name}"
                if header in section_output:
                    section_content = section_output.split(header, 1)[1].strip()
                    # Keep only this section: cut at the next "##" heading.
                    section_content = re.split(r'##\s+', section_content, 1)[0].strip()
                    return f"{header}\n{section_content}"
                # The model ignored the format; use its raw output.
                return f"{header}\n{section_output.strip()}"
            except Exception as e:
                logging.error(f"Failed to generate {section_name} section: {e}")
                return f"## {section_name}\nUnable to generate this section due to processing error. Please review patient data manually."

        full_output = ""
        current_prompt = prompt
        total_start = time.time()
        try:
            logging.info(f"[GGUF] Starting enhanced full summary generation with max_loops={max_loops}")
            logging.info(f"[GGUF] Prompt length: {len(prompt)} characters")

            for loop_idx in range(max_loops):
                loop_start = time.time()
                logging.info(f"[GGUF] Starting loop {loop_idx+1}/{max_loops}")
                logging.info(f"[GGUF] Current prompt length: {len(current_prompt)} characters")
                output = self.generate(current_prompt, max_tokens=max_tokens)
                if output.startswith(prompt):
                    output = output[len(prompt):].strip()
                full_output += output
                loop_time = time.time() - loop_start
                logging.info(f"[GGUF] loop {loop_idx+1}/{max_loops}: {loop_time:.2f}s, cumulative {time.time()-total_start:.2f}s, length={len(full_output)} chars")
                logging.info(f"[GGUF] Generated {len(output)} characters in this loop")

                complete, missing_sections = is_complete(full_output)
                if complete:
                    logging.info(f"[GGUF] All required sections found after loop {loop_idx+1}")
                    break
                if loop_idx < max_loops - 1:
                    if missing_sections:
                        missing_list = ", ".join(missing_sections)
                        current_prompt = f"{prompt}\n\n{full_output}\n\nThe summary is missing these sections: {missing_list}. Please continue and complete all sections in markdown format:"
                    else:
                        current_prompt = f"{prompt}\n\n{full_output}\n\nContinue the summary and ensure it ends with a complete sentence:"
                    logging.info(f"[GGUF] Preparing next prompt for loop {loop_idx+2}")

            # Post-processing: generate any sections the loops never produced.
            _, missing_sections = is_complete(full_output)
            if missing_sections:
                logging.info(f"[GGUF] Generating {len(missing_sections)} missing sections post-processing")
                generated_sections = []
                for section in missing_sections:
                    logging.info(f"[GGUF] Generating missing section: {section}")
                    generated_sections.append(generate_missing_section(section, prompt))
                if generated_sections:
                    full_output += "\n\n" + "\n\n".join(generated_sections)

            total_time = time.time() - total_start
            logging.info(f"[GGUF] generate_full_summary completed in {total_time:.2f}s")
            logging.info(f"[GGUF] Final summary length: {len(full_output)} characters")

            final_complete, final_missing = is_complete(full_output)
            if not final_complete:
                logging.warning(f"[GGUF] Final summary still incomplete. Missing: {final_missing}")
                # Last resort: guarantee every section heading is present.
                if final_missing:
                    fallback_sections = [
                        f"## {section}\nPlease review the patient data for this section."
                        for section in final_missing
                    ]
                    full_output += "\n\n" + "\n\n".join(fallback_sections)
            return full_output.strip()
        except Exception as e:
            logging.exception(f"Full summary generation failed: {e}")
            minimal_sections = [
                "## Clinical Assessment\nPatient data processing encountered an error. Please review the raw patient information manually.",
                "## Key Trends & Changes\nUnable to analyze trends due to processing error. Manual review recommended.",
                "## Plan & Suggested Actions\nError in generating action plan. Consult with healthcare provider for appropriate next steps.",
                "## Direct Guidance for Physician\nProcessing error occurred. Please conduct a thorough manual review of all patient data."
            ]
            return "\n\n".join(minimal_sections)
def create_fallback_pipeline():
    """Return a model-free stand-in that emits a fixed four-section summary.

    The returned object exposes ``generate`` and ``generate_full_summary``
    (the same method names the model pipelines use). Both ignore the prompt
    and any keyword arguments and return the same static markdown text.
    """

    class FallbackPipeline:
        """Static-text pipeline returning a placeholder clinical summary."""

        def __init__(self):
            self.name = "fallback_text"

        def generate(self, prompt, **kwargs):
            sections = [
                "## Clinical Assessment\nThe patient data indicates a medical condition that requires professional evaluation. Key indicators include vital signs, diagnoses, and medication history that suggest ongoing clinical management is necessary.",
                "## Key Trends & Changes\nPatient records show historical data points including vital measurements, laboratory results, and medication adjustments. Trends in weight, blood pressure, and lab values should be monitored for significant changes over time.",
                "## Plan & Suggested Actions\nImmediate actions include reviewing all available patient data, consulting with specialists if needed, and ensuring continuity of care. Follow-up appointments and medication adherence should be prioritized.",
                "## Direct Guidance for Physician\nThis summary was generated using a fallback method due to model unavailability. Please conduct a thorough review of the patient's complete medical record, including EHR data, imaging, and clinical notes. Consider interdisciplinary consultation for complex cases."
            ]
            return "\n\n".join(sections)

        def generate_full_summary(self, prompt, **kwargs):
            return self.generate(prompt, **kwargs)

    return FallbackPipeline()
# --- OpenVINO pipeline loader ---
class OpenVinoPipeline:
    """Minimal container pairing an OpenVINO causal LM with its tokenizer.

    Callers use the ``model`` and ``tokenizer`` attributes directly, the same
    attribute names a Hugging Face ``pipeline`` object exposes.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
def get_openvino_pipeline(model_name: str):
    """Load an OpenVINO causal LM and its fast tokenizer on CPU.

    A local directory is loaded with ``compile=True``. Any other name is
    loaded with ``export=False, compile=False``; presumably it must already
    be an exported OpenVINO model on the Hub (TODO confirm). Both the model
    and the tokenizer use ``$HF_HOME`` (default ``/tmp/huggingface``) as the
    cache directory.

    Returns:
        OpenVinoPipeline: The loaded model and tokenizer.
    """
    cache_dir = os.environ.get('HF_HOME', '/tmp/huggingface')
    if os.path.isdir(model_name):
        model = OVModelForCausalLM.from_pretrained(model_name, compile=True, device="CPU", cache_dir=cache_dir)
    else:
        model = OVModelForCausalLM.from_pretrained(model_name, export=False, compile=False, device="CPU", cache_dir=cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, cache_dir=cache_dir)
    return OpenVinoPipeline(model, tokenizer)
# --- Summarizer Agent ---
class SummarizerAgent:
    """Summarise free text with a model obtained from a loader on each call.

    The loader is any object with a ``load()`` method. ``load()`` must return
    one of:
      - an object with ``generate_full_summary(text, max_tokens=...,
        max_loops=...)``, e.g. ``GGUFModelPipeline``;
      - a callable pipeline invoked as ``model(text, max_length=512,
        min_length=50, do_sample=False)`` that returns ``[{"summary_text": ...}]``
        (or ``generated_text``) or a string.
    """

    def __init__(self, summarization_model_loader=None, loader=None):
        """Store the model loader.

        Args:
            summarization_model_loader: Object exposing ``load()``.
            loader: Keyword alias for *summarization_model_loader*; used only
                when the latter is ``None``.
        """
        if summarization_model_loader is None:
            summarization_model_loader = loader
        self.summarization_model_loader = summarization_model_loader
        self.last_summary_length = 0
        self.request_count = 0

    @staticmethod
    def _to_text(summary):
        """Coerce a pipeline result (list of dicts, str, other) to a string."""
        if isinstance(summary, list) and summary:
            first = summary[0]
            if isinstance(first, dict):
                return first.get('summary_text') or first.get('generated_text', '')
            return str(first)
        if isinstance(summary, str):
            return summary
        return str(summary)

    def generate_summary(self, text):
        """Summarise *text*; never raises.

        Returns:
            str: The summary. Returns "Input text is too short for
            summarization" for input under 5 words. Returns a
            "Summary generation failed: ..." message on any error, including
            a missing loader.
        """
        try:
            clean_text = text.strip()
            if not clean_text or len(clean_text.split()) < 5:
                return "Input text is too short for summarization"
            if self.summarization_model_loader is None:
                raise RuntimeError("No summarization model loader configured")
            self.request_count += 1
            model = self.summarization_model_loader.load()
            if hasattr(model, 'generate_full_summary'):
                summary = model.generate_full_summary(clean_text, max_tokens=4000, max_loops=2)
            else:
                summary = model(clean_text, max_length=512, min_length=50, do_sample=False)
            summary = self._to_text(summary)
            self.last_summary_length = len(summary)
            return summary
        except Exception as e:
            return f"Summary generation failed: {str(e)}"
# --- Main function to replicate generate_patient_summary API ---
def _fetch_patient_ehr(patientid, token, key):
    """POST to the EHR ``patientsummary`` endpoint and return the decoded body.

    *key* is used as the API base URL. When it does not start with "http",
    it is also sent as ``x-api-key``. A non-JSON body is returned as text.

    Raises:
        ValueError: *patientid*, *token* or *key* is missing.
        RuntimeError: The API answered with a non-200 status.
    """
    if not patientid or not token or not key:
        raise ValueError("Missing required fields: patientid, token, or key")
    api_url = f"{key}/Transactionapi/api/PatientList/patientsummary"
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    if not key.startswith("http"):
        headers["x-api-key"] = key
    response = requests.post(api_url, json={"patientid": patientid}, headers=headers, timeout=30)
    if response.status_code != 200:
        raise RuntimeError(f"API request failed with status {response.status_code}: {response.text}")
    try:
        return response.json()
    except ValueError:
        return response.text


def _extract_markdown_summary(raw):
    """Strip an echoed prompt and any preamble from model output.

    Keeps the text after the last occurrence of the prompt's final
    instruction line. Then drops everything before the first line-leading
    "Clinical Assessment" heading (``#``/``**``/``1)`` prefixes allowed),
    keeping the heading itself.
    """
    marker = "Now generate the complete clinical summary with all four sections in markdown format:"
    if marker in raw:
        raw = raw.rsplit(marker, 1)[1]
    heading = re.search(
        r"^[ \t]*(?:#{1,6}[ \t]*|\*\*|\d+[.)][ \t]*)?Clinical Assessment",
        raw,
        re.MULTILINE,
    )
    if heading:
        raw = raw[heading.start():]
    return raw.strip()


class _HFSummarizationLoader:
    """Lazily build (once) a Hugging Face ``summarization`` pipeline; exposes ``load()``."""

    def __init__(self, model_name):
        self.model_name = model_name
        self._pipeline = None

    def load(self):
        if self._pipeline is None:
            self._pipeline = transformers_pipeline("summarization", model=self.model_name)
        return self._pipeline


def generate_patient_summary(patientid=None, token=None, key=None, ehr_data=None, patient_summarizer_model_name="falconsai/medical_summarization", patient_summarizer_model_type="summarization"):
    """Build a clinical summary for one patient from EHR chart data.

    EHR data comes from *ehr_data* when given. Otherwise it is fetched from
    the API using *patientid*, *token* and *key*. Either source may be the
    bare payload or wrapped as ``{"result": ...}``. The chartsummarydtl is
    turned into a structured baseline and deltas, which are embedded in an
    LLM prompt and summarised with the selected backend:

    - ``"gguf"``: llama.cpp via ``GGUFModelPipeline``. The model name is a
      local ``.gguf`` path or ``"repo_id/file.gguf"``. On any error a static
      fallback summary is returned.
    - ``"text-generation"`` / ``"causal-openvino"``: greedy decoding with a
      Hugging Face or OpenVINO causal LM.
    - ``"summarization"``: a Hugging Face summarization pipeline via
      ``SummarizerAgent``.

    Note: calls ``torch.set_num_threads(2)``, which affects the whole process.

    Returns:
        tuple[str, str, str]: ``(summary, baseline, delta_text)``.

    Raises:
        ValueError: Missing API credentials or unsupported model type.
        RuntimeError: API failure or no ``chartsummarydtl`` in the EHR data.
    """
    start_total = time.time()
    if ehr_data is None:
        ehr_data = _fetch_patient_ehr(patientid, token, key)
    if isinstance(ehr_data, dict):
        ehr_result = ehr_data.get("result") or ehr_data
    else:
        ehr_result = ehr_data
    chartsummarydtl = ehr_result.get("chartsummarydtl") if isinstance(ehr_result, dict) else None
    if not chartsummarydtl:
        raise RuntimeError("Missing chartsummarydtl in EHR response")

    visits = parse_ehr_chartsummarydtl(chartsummarydtl)
    # No earlier snapshot exists here, so every visit counts as "new".
    delta = compute_deltas([], visits)
    baseline = build_compact_baseline(visits_sorted(visits))
    delta_text = delta_to_text(delta)
    prompt = build_main_prompt(baseline, delta_text)
    torch.set_num_threads(2)

    model_name = patient_summarizer_model_name
    model_type = patient_summarizer_model_type

    if model_type == "gguf":
        try:
            # "repo_id/file.gguf" means a Hub file unless it exists locally.
            if not os.path.exists(model_name) and model_name.endswith('.gguf') and '/' in model_name:
                repo_id, filename = model_name.rsplit('/', 1)
                gguf_pipeline = GGUFModelPipeline(repo_id, filename)
            else:
                gguf_pipeline = GGUFModelPipeline(model_name)
            summary_raw = gguf_pipeline.generate_full_summary(prompt, max_tokens=100000, max_loops=5)
            markdown_summary = _extract_markdown_summary(summary_raw)
            logger.info(f"[TIMING] TOTAL: {time.time() - start_total:.2f}s")
            return markdown_summary, baseline, delta_text
        except Exception:
            logger.exception("GGUF summarization failed; returning fallback summary")
            fallback_summary = create_fallback_pipeline().generate_full_summary(prompt)
            return fallback_summary, baseline, delta_text

    if model_type in {"text-generation", "causal-openvino"}:
        if model_type == "causal-openvino":
            text_pipeline = get_openvino_pipeline(model_name)
        else:
            hf_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            hf_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
            text_pipeline = transformers_pipeline("text-generation", model=hf_model, tokenizer=hf_tokenizer)
        tokenizer = text_pipeline.tokenizer
        inputs = tokenizer([prompt], return_tensors="pt")
        # `is not None` so a legitimate eos id of 0 is kept.
        pad_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 32000
        outputs = text_pipeline.model.generate(**inputs, max_new_tokens=100000, do_sample=False, pad_token_id=pad_token_id)
        # Decoder-only models return prompt + continuation; decode only the new tokens.
        prompt_length = inputs["input_ids"].shape[-1]
        text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return _extract_markdown_summary(text), baseline, delta_text

    if model_type == "summarization":
        summarizer_agent = SummarizerAgent(_HFSummarizationLoader(model_name))
        summary = summarizer_agent.generate_summary(prompt)
        return summary, baseline, delta_text

    raise ValueError(f"Unsupported model type: {patient_summarizer_model_type}")
# Example usage
if __name__ == "__main__":
    # Demo entry point: summarise a small hard-coded two-visit record with the
    # default backend ("summarization", falconsai/medical_summarization),
    # which is downloaded on first use.
    #
    # Option 1: Using API call (replace with actual values)
    # patientid = "your_patient_id"
    # token = "your_bearer_token"
    # key = "https://your-ehr-api.com"
    # summary, baseline, delta_text = generate_patient_summary(patientid, token, key)

    # Option 2: Using direct EHR data (recommended for testing)
    sample_ehr_data = {
        "chartsummarydtl": [
            {
                "chartdate": "2023-12-01",
                "vitals": ["Bp(sys)(mmHg): 140", "Bp(dia)(mmHg): 90", "Weight: 180"],
                "diagnosis": ["Hypertension", "Diabetes"],
                "medications": ["Lisinopril 10mg", "Metformin 500mg"],
                "labtests": [
                    {"name": "HbA1c (%)", "value": "7.2"},
                    {"name": "Creatinine", "value": "1.1"},
                ],
            },
            {
                "chartdate": "2023-11-15",
                "vitals": ["Bp(sys)(mmHg): 135", "Bp(dia)(mmHg): 85", "Weight: 182"],
                "diagnosis": ["Hypertension"],
                "medications": ["Lisinopril 10mg"],
                "labtests": [
                    {"name": "HbA1c (%)", "value": "7.5"},
                    {"name": "Creatinine", "value": "1.0"},
                ],
            },
        ]
    }
    summary, baseline, delta_text = generate_patient_summary(ehr_data=sample_ehr_data)
    print("Patient Summary:")
    print(summary)
    print("\nBaseline:")
    print(baseline)
    print("\nDelta Text:")
    print(delta_text)