# Captured from a Hugging Face Space (status at capture time: Paused).
# Colab-ready script that replicates the generate_patient_summary API logic in a single file.
# Install the required packages in Colab first (run this in a separate cell):
# !pip install transformers "optimum[openvino]" llama-cpp-python huggingface_hub requests
| import os | |
| import re | |
| import time | |
| import json | |
| import logging | |
| import requests | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline as transformers_pipeline | |
| from optimum.intel.openvino import OVModelForCausalLM | |
| from llama_cpp import Llama | |
| from huggingface_hub import hf_hub_download | |
| from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError | |
# Module-level logger used by the pipeline code below; root logging is configured
# at INFO so progress/timing messages are visible when run as a script or notebook.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
| # --- Utils from openvino_summarizer_utils.py --- | |
def parse_ehr_chartsummarydtl(chartsummarydtl):
    """Normalise raw EHR ``chartsummarydtl`` entries into simple visit dicts.

    Each returned visit contains:

    * ``chartdate`` -- the first 10 characters of the entry's chartdate
      ("" when missing or None);
    * ``vitals`` -- ``{name: value}`` parsed from ``"Name: value"`` strings,
      excluding weight;
    * ``weight`` -- the weight reading, only when a non-empty vital whose name
      starts with "weight" (case-insensitive) exists;
    * ``allergies`` / ``diagnosis`` / ``medications`` -- copied when present
      (a None value becomes ``[]``);
    * ``labtests`` -- list of ``{"name", "value"}`` for labs with a name or value;
    * ``radiologyorders`` -- non-empty order names, only when the key is present.

    Malformed data is skipped instead of raising: non-dict entries, vitals that
    are not strings or contain no ":", and non-dict lab/radiology items. List
    fields whose value is None are treated as empty.
    """
    visits = []
    for entry in chartsummarydtl or []:
        if not isinstance(entry, dict):
            continue
        raw_date = entry.get("chartdate")
        # Keep the first 10 chars -- the date part if chartdate is an ISO timestamp
        # such as "2023-12-01T10:00:00" (presumed format; confirm with the EHR API).
        visit = {"chartdate": str(raw_date)[:10] if raw_date else ""}

        vitals_dict = {}
        weight = None
        for reading in entry.get("vitals") or []:
            if not isinstance(reading, str) or ":" not in reading:
                continue
            name, value = (part.strip() for part in reading.split(":", 1))
            if name.lower().startswith("weight"):
                weight = value
            else:
                vitals_dict[name] = value
        visit["vitals"] = vitals_dict
        if weight:
            visit["weight"] = weight

        for field in ("allergies", "diagnosis", "medications"):
            if field in entry:
                # Downstream code iterates these, so a null list becomes empty.
                visit[field] = entry[field] if entry[field] is not None else []

        labtests = []
        for lab in entry.get("labtests") or []:
            if not isinstance(lab, dict):
                continue
            name = lab.get("name", "")
            value = lab.get("value", "")
            if name or value:
                labtests.append({"name": name, "value": value})
        visit["labtests"] = labtests

        if "radiologyorders" in entry:
            visit["radiologyorders"] = [
                order.get("name", "")
                for order in entry["radiologyorders"] or []
                if isinstance(order, dict) and order.get("name")
            ]
        visits.append(visit)
    return visits
# Canonical (section, field) key -> every key latest_value() tries, in priority
# order; the canonical key itself always comes first.
# NOTE(review): a creatinine *ratio* and plain creatinine are usually different
# measurements -- confirm the "Creatinine" fallback below is intended.
ALIASES = dict([
    (("vitals", "Bp(sys)(mmHg)"),
     [("vitals", "Bp(sys)(mmHg)"), ("vitals", "Bp_sys"), ("vitals", "SBP")]),
    (("vitals", "Bp(dia)(mmHg)"),
     [("vitals", "Bp(dia)(mmHg)"), ("vitals", "Bp_dia"), ("vitals", "DBP")]),
    (("labtests", "HbA1c (%)"),
     [("labtests", "HbA1c (%)"), ("labtests", "HbA1c")]),
    (("labtests", "Creatinine Ratio"),
     [("labtests", "Creatinine Ratio"), ("labtests", "Creatinine")]),
])
def visits_sorted(v):
    """Return a new list of the visits ordered by their ``chartdate`` string.

    The ordering is lexicographic, which is chronological for ISO-8601 dates.
    Visits without a ``chartdate`` sort as "" (i.e. first). The sort is stable.
    """
    def _date_key(visit):
        return visit.get("chartdate", "")

    return sorted(v, key=_date_key)
# Optional sign, then either digits with an optional fractional part ("5", "5.",
# "5.25") or a bare fraction (".5").
_NUMBER_RE = re.compile(r"-?(?:\d+\.?\d*|\.\d+)")


def to_float(val):
    """Return the first number found in *val* as a float, or None.

    Real numbers (int/float, but not bool) are converted directly. Stringifying
    them first would mis-read values such as ``1e-05``.

    Anything else is converted with ``str()``, and its first number-like token is
    parsed. For example, ``"140 mmHg"`` gives 140.0, ``"120/80"`` gives 120.0 and
    ``"-.5"`` gives -0.5.

    Never raises: unparseable input (including None and bools) yields None.
    """
    if isinstance(val, bool):
        return None
    try:
        if isinstance(val, (int, float)):
            return float(val)
        match = _NUMBER_RE.search(str(val))
    except Exception:  # exotic __str__ / huge-int overflow: treat as "no number"
        return None
    return float(match.group()) if match else None
def _latest_value_exact(visits, key_path):
    """Return the newest value stored at *key_path*, without alias fallback.

    Lookup rules:

    * ``("labtests", name)`` returns the ``value`` of the first lab whose ``name``
      matches exactly, even if that value is None.
    * Any other path is followed as nested dict keys, e.g.
      ``("vitals", "Bp(sys)(mmHg)")`` or ``("weight",)``.

    Visits are scanned newest-first by ``chartdate``. Returns None when nothing
    matches or *visits* is empty.
    """
    newest_first = list(reversed(visits_sorted(visits)))
    if not newest_first:
        return None

    if key_path[0] == "labtests":
        for visit in newest_first:
            hit = next(
                (lab for lab in visit.get("labtests", []) if lab.get("name") == key_path[1]),
                None,
            )
            if hit is not None:
                return hit.get("value")
        return None

    for visit in newest_first:
        node = visit
        for key in key_path:
            if not (isinstance(node, dict) and key in node):
                break
            node = node[key]
        else:
            # Every key was found: this visit has the value.
            return node
    return None
def latest_value(visits, key_path):
    """Return the newest value for *key_path*, trying its ALIASES in order.

    Keys without an ALIASES entry are looked up as-is. The first alias whose
    lookup is not None wins; aliases after it are not evaluated. Returns None
    if every alias misses.
    """
    candidates = ALIASES.get(key_path, [key_path])
    lookups = (_latest_value_exact(visits, alias) for alias in candidates)
    return next((found for found in lookups if found is not None), None)
def active_set(visits, field):
    """Return the union of the items in *field* across all *visits*.

    Visits lacking *field* contribute nothing.
    """
    return {item for visit in visits for item in visit.get(field, [])}
def _fmt(x, spec=None):
    """Render *x* for prompt text.

    None becomes "N/A". With a non-empty *spec*, ``format(x, spec)`` is used,
    falling back to ``str(x)`` if the spec does not apply to the value. Otherwise
    ``str(x)`` is returned.
    """
    if x is None:
        return "N/A"
    if not spec:
        return str(x)
    try:
        return format(x, spec)
    except Exception:
        return str(x)
def compute_deltas(old_visits, new_visits):
    """Compare the patient state before and after *new_visits* are added.

    "Previous" values are the latest ones in *old_visits*. "Current" values are
    the latest ones in ``old_visits + new_visits``. Numbers are parsed with
    ``to_float``.

    Returns a dict with:

    * ``added_dx`` / ``started_meds`` / ``stopped_meds`` -- sorted lists of
      diagnosis / medication set differences;
    * ``weight`` / ``bp_sys`` / ``bp_dia`` -- ``{"prev", "curr", "delta"}``;
    * ``labs`` -- ``{lab_name: {"prev", "curr", "delta"}}`` in sorted name order,
      only for labs with at least one numeric reading.

    ``delta`` is ``curr - prev`` whenever both are known (a 0 reading counts),
    otherwise None.
    """
    prev_all = list(old_visits)
    curr_all = prev_all + list(new_visits)

    def get_val(visits, path):
        return to_float(latest_value(visits, path))

    def change(path):
        prev, curr = get_val(prev_all, path), get_val(curr_all, path)
        # Compare with None explicitly: a reading of 0 is still a reading.
        delta = curr - prev if prev is not None and curr is not None else None
        return {"prev": prev, "curr": curr, "delta": delta}

    # curr_all is a superset of prev_all, so it holds every lab name. Sorting gives a
    # stable output order; set iteration order of strings varies between runs.
    lab_names = sorted(
        {lab["name"] for visit in curr_all for lab in visit.get("labtests", []) if lab.get("name")},
        key=str,
    )
    labs = {}
    for name in lab_names:
        entry = change(("labtests", name))
        if entry["prev"] is not None or entry["curr"] is not None:
            labs[name] = entry

    prev_dx, curr_dx = active_set(prev_all, "diagnosis"), active_set(curr_all, "diagnosis")
    prev_meds, curr_meds = active_set(prev_all, "medications"), active_set(curr_all, "medications")
    # NOTE(review): curr_all always contains prev_all, so stopped_meds is always
    # empty -- confirm whether "current" should mean the newest visits only.
    return {
        "added_dx": sorted(curr_dx - prev_dx),
        "started_meds": sorted(curr_meds - prev_meds),
        "stopped_meds": sorted(prev_meds - curr_meds),
        "weight": change(("weight",)),
        "bp_sys": change(("vitals", "Bp(sys)(mmHg)")),
        "bp_dia": change(("vitals", "Bp(dia)(mmHg)")),
        "labs": labs,
    }
def build_compact_baseline(all_visits):
    """Render the latest known patient state as a short multi-line text block.

    The block has five lines:

    * the latest chart date;
    * the union of diagnoses across all visits;
    * the union of medications across all visits;
    * the latest BP and weight;
    * the latest value of every lab, sorted by lab name.

    Missing values are rendered as "N/A", consistently with the rest of the text.
    """
    lab_names = sorted(
        {lab["name"] for visit in all_visits for lab in visit.get("labtests", []) if lab.get("name")},
        key=str,
    )
    lab_parts = []
    for name in lab_names:
        value = latest_value(all_visits, ("labtests", name))
        if value is not None:
            lab_parts.append(f"{name}: {value}")
    labs_text = ", ".join(lab_parts) or "N/A"

    latest_date = latest_value(all_visits, ("chartdate",)) or "N/A"
    diagnoses = ", ".join(sorted(active_set(all_visits, "diagnosis"))) or "N/A"
    medications = ", ".join(sorted(active_set(all_visits, "medications"))) or "N/A"
    # _fmt turns a missing reading into "N/A" instead of the literal "None".
    bp_sys = _fmt(latest_value(all_visits, ("vitals", "Bp(sys)(mmHg)")))
    bp_dia = _fmt(latest_value(all_visits, ("vitals", "Bp(dia)(mmHg)")))
    weight = _fmt(latest_value(all_visits, ("weight",)))

    return (
        f"Latest date: {latest_date}\n"
        f"Active Diagnoses: {diagnoses}\n"
        f"Active Medications: {medications}\n"
        f"Latest Vitals: Bp: {bp_sys}/{bp_dia} mmHg, Weight: {weight}\n"
        f"Latest Labs: {labs_text}"
    )
def delta_to_text(delta):
    """Render a ``compute_deltas`` result as newline-separated text lines.

    Lines for new diagnoses and started/stopped medications appear only when
    non-empty. The weight and BP lines always appear. Labs follow in the dict's
    order.
    """
    lines = []
    list_sections = (
        ("added_dx", "New Diagnoses: "),
        ("started_meds", "Medications Started: "),
        ("stopped_meds", "Medications Stopped: "),
    )
    for key, label in list_sections:
        if delta[key]:
            lines.append(label + ", ".join(delta[key]))

    weight = delta["weight"]
    lines.append(f"Weight: {_fmt(weight['prev'])} -> {_fmt(weight['curr'])} (Δ {_fmt(weight['delta'], '+.1f')})")

    sys_bp, dia_bp = delta["bp_sys"], delta["bp_dia"]
    lines.append(
        f"BP: {_fmt(sys_bp['curr'])}/{_fmt(dia_bp['curr'])} "
        f"(Δs {_fmt(sys_bp['delta'], '+.0f')}, Δd {_fmt(dia_bp['delta'], '+.0f')})"
    )

    lines.extend(
        f"{name}: {_fmt(lab['prev'])} -> {_fmt(lab['curr'])} (Δ {_fmt(lab['delta'], '+.1f')})"
        for name, lab in delta["labs"].items()
        if lab["prev"] is not None or lab["curr"] is not None
    )
    return "\n".join(lines)
def build_main_prompt(baseline, delta_text, patient_info=""):
    """Assemble the LLM instruction prompt from the baseline and delta texts.

    The prompt asks for a four-section markdown summary and ends with the
    generation instruction line.
    """
    prompt_lines = [
        "You are an expert clinical AI assistant. Your task is to generate a patient summary.",
        "Use the chartsummarydtl for context. The STRUCTURED BASELINE and DELTAS are the absolute ground truth.",
        "Produce a concise, physician-ready summary. Never omit critical new information from the deltas.",
        "",
        "The summary MUST have four sections:",
        "1) Clinical Assessment",
        "2) Key Trends & Changes",
        "3) Plan & Suggested Actions",
        "4) Direct Guidance for Physician",
        "",
        "PATIENT INFORMATION:",
        f"{patient_info}",
        "",
        "STRUCTURED BASELINE (authoritative):",
        f"{baseline}",
        "",
        "STRUCTURED DELTAS (authoritative):",
        f"{delta_text}",
        "",
        "Now generate the complete clinical summary with all four sections in markdown format:",
    ]
    return "\n".join(prompt_lines)
| # --- GGUF model loader and pipeline (simplified) --- | |
class GGUFModelPipeline:
    """Generate clinical summaries with a GGUF model run through llama.cpp.

    The model is a local ``.gguf`` file or is downloaded from a Hugging Face Hub
    repository.

    Environment variables read:

    * ``HF_HOME`` -- default download cache directory.
    * ``SPACE_ID`` -- when set (Hugging Face Space), the defaults are 1 thread,
      batch 16 and n_ctx 4096. Elsewhere they are <=2 threads, batch 32 and
      n_ctx 1024.
    * ``GGUF_N_THREADS``, ``GGUF_N_BATCH``, ``GGUF_N_CTX`` -- override those defaults.
    * ``GGUF_GENERATION_TIMEOUT`` -- per-call timeout in seconds when none is
      configured (default 600 on a Space, 300 elsewhere).
    """

    _REQUIRED_SECTIONS = (
        "Clinical Assessment",
        "Key Trends & Changes",
        "Plan & Suggested Actions",
        "Direct Guidance for Physician",
    )
    # Applied one after another, in this order, by _strip_special_tokens.
    _SPECIAL_TOKEN_PATTERNS = tuple(
        re.compile(p, re.IGNORECASE)
        for p in (
            r"<\|assistant\|>", r"<\|user\|>", r"<\|system\|>", r"<\|end\|>",
            r"<\|endoftext\|>", r"</s>", r"<s>",
        )
    )

    def __init__(self, model_path_or_repo, filename=None, cache_dir=None, timeout=None):
        """Resolve (downloading if needed) the model file and initialise llama.cpp.

        Args:
            model_path_or_repo: local ``.gguf`` path, or a Hub repo id when
                *filename* is given.
            filename: file inside the Hub repo. None means *model_path_or_repo*
                is a local path.
            cache_dir: download cache (default ``$HF_HOME`` or ``/tmp/huggingface``).
            timeout: default per-generation timeout in seconds. None defers to
                ``$GGUF_GENERATION_TIMEOUT`` or the platform default.

        Raises:
            RuntimeError: download or llama.cpp initialisation failed.
            FileNotFoundError: the resolved model path does not exist.
        """
        import threading

        from huggingface_hub import hf_hub_download
        from llama_cpp import Llama

        cache_dir = cache_dir or os.environ.get("HF_HOME", "/tmp/huggingface")
        os.makedirs(cache_dir, exist_ok=True)
        self.timeout = timeout
        # Serialises calls into the llama.cpp model, which is not documented as safe
        # for concurrent use (see _generate_with_timeout).
        self._model_lock = threading.Lock()

        if filename is not None:
            logging.info("Downloading model from %s/%s", model_path_or_repo, filename)
            try:
                # Interrupted downloads are resumed by default; `resume_download` is deprecated.
                local_path = hf_hub_download(
                    repo_id=model_path_or_repo,
                    filename=filename,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
            except Exception as e:
                logging.error("Failed to download model: %s", e)
                raise RuntimeError(f"Model download failed: {str(e)}") from e
            logging.info("Model downloaded successfully to %s", local_path)
        else:
            local_path = model_path_or_repo

        if not os.path.exists(local_path):
            raise FileNotFoundError(f"Model path does not exist: {local_path}")
        file_size_mb = os.path.getsize(local_path) / (1024 * 1024)
        logging.info("Model file size: %.2f MB", file_size_mb)
        if file_size_mb > 5000:
            logging.warning("Model file is very large (%.2f MB), may cause memory issues", file_size_mb)

        load_start = time.time()
        if os.environ.get("SPACE_ID") is not None:
            # NOTE(review): described as "ultra-conservative", yet this n_ctx (4096) is
            # larger than the non-Space default (1024) -- confirm intent.
            default_threads, default_batch, default_ctx = 1, 16, 4096
            logging.info("[GGUF] Detected Hugging Face Space - using ultra-conservative memory settings")
        else:
            default_threads = max(1, min(2, os.cpu_count() or 2))
            default_batch, default_ctx = 32, 1024
        try:
            n_threads = int(os.environ.get("GGUF_N_THREADS", str(default_threads)))
            n_batch = int(os.environ.get("GGUF_N_BATCH", str(default_batch)))
            n_ctx = int(os.environ.get("GGUF_N_CTX", str(default_ctx)))
            # rope_freq_base / rope_freq_scale are deliberately not passed. The llama.cpp
            # default (0) reads them from the GGUF metadata, whereas forcing 10000 / 1.0
            # is wrong for models trained with other RoPE settings. mmap / mul_mat_q /
            # f16_kv / cache_type_* are not parameters of current llama_cpp.Llama.
            self.model = Llama(
                model_path=local_path,
                n_ctx=n_ctx,
                n_threads=n_threads,
                n_threads_batch=n_threads,
                n_batch=n_batch,
                n_gpu_layers=0,
                logits_all=False,
                embedding=False,
                use_mmap=True,
                use_mlock=False,
                vocab_only=False,
                seed=0,
                verbose=False,
            )
        except Exception as e:
            logging.error("Failed to initialize GGUF model: %s", e)
            raise RuntimeError(f"Failed to initialize GGUF model via llama.cpp: {e}") from e
        logging.info(
            "[GGUF] Model initialized in %.2fs from %s (threads=%s, batch=%s, ctx=%s)",
            time.time() - load_start, local_path, n_threads, n_batch, n_ctx,
        )

    def _strip_special_tokens(self, text: str) -> str:
        """Remove chat-template / special tokens that may appear in raw model text."""
        for pattern in self._SPECIAL_TOKEN_PATTERNS:
            text = pattern.sub("", text)
        return text.strip()

    def _generate_with_timeout(self, prompt, max_tokens=4000, temperature=0.5, top_p=0.95, timeout=None):
        """Run one completion in a worker thread, waiting at most *timeout* seconds.

        The timeout is taken from, in order:

        1. the *timeout* argument;
        2. the constructor's ``timeout``;
        3. ``$GGUF_GENERATION_TIMEOUT`` (default 600s on a Space, 300s elsewhere).

        Returns the raw llama.cpp completion dict.

        Raises:
            TimeoutError: the completion did not finish in time. The native call
                cannot be interrupted. It keeps running in the background while
                holding the model lock, so later calls wait for it rather than
                running concurrently.
        """
        import threading
        from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError

        if timeout is None:
            timeout = getattr(self, "timeout", None)
        if timeout is None:
            is_hf_space = os.environ.get("SPACE_ID") is not None
            timeout = int(os.environ.get("GGUF_GENERATION_TIMEOUT", "600" if is_hf_space else "300"))

        lock = getattr(self, "_model_lock", None)
        if lock is None:
            lock = self._model_lock = threading.Lock()
        abandoned = threading.Event()

        def _generate():
            with lock:
                if abandoned.is_set():
                    # The caller already timed out; don't start an expensive call nobody reads.
                    return None
                return self.model(
                    prompt,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    stop=["</s>", "###"],
                )

        executor = ThreadPoolExecutor(max_workers=1)
        future = executor.submit(_generate)
        try:
            return future.result(timeout=timeout)
        except FutureTimeoutError:  # distinct from builtin TimeoutError before Python 3.11
            abandoned.set()
            future.cancel()
            raise TimeoutError(f"Generation timed out after {timeout} seconds") from None
        finally:
            # A `with ThreadPoolExecutor()` block would join the worker here and
            # make the timeout ineffective; shut down without waiting instead.
            executor.shutdown(wait=False)

    def generate(self, prompt, max_tokens=4000, temperature=0.5, top_p=0.95):
        """Return one completion for *prompt* with special tokens stripped.

        Raises:
            TimeoutError: generation exceeded the configured timeout.
            RuntimeError: any other llama.cpp or output-format failure.
        """
        t0 = time.time()
        try:
            output = self._generate_with_timeout(prompt, max_tokens, temperature, top_p)
            text = self._strip_special_tokens(output["choices"][0]["text"].strip())
        except TimeoutError as e:
            logging.error("Generation timed out: %s", e)
            raise
        except Exception as e:
            logging.error("Generation failed: %s", e)
            raise RuntimeError(f"Text generation failed: {str(e)}") from e
        logging.info(
            "[GGUF] generate: %.2fs, ~%d words, max_tokens=%s",
            time.time() - t0, len(text.split()), max_tokens,
        )
        return text

    def generate_full_summary(self, prompt, max_tokens=4000, max_loops=5):
        """Generate a summary containing all four required sections; never raises.

        The summary is built in stages:

        1. Up to *max_loops* completions are requested. After each one, the
           accumulated text is checked for the four section names and for ending
           with ``.``, ``!`` or ``?``. The next prompt asks the model to add
           whatever is missing.
        2. Sections still missing afterwards are generated one at a time with a
           targeted prompt, using half of *max_tokens* each.
        3. As a last resort, any remaining gap is filled with a placeholder.

        If a later loop fails (e.g. timeout, or the growing prompt no longer fits
        the context window), the text produced so far is kept. If the very first
        completion fails, a static four-section error summary is returned.
        """
        required_sections = self._REQUIRED_SECTIONS

        def is_complete(text):
            missing = [s for s in required_sections if s not in text]
            if missing:
                logging.info("[GGUF] Missing sections: %s", missing)
                return False, missing
            ends_with_punct = bool(re.search(r"[.!?][\s\n]*$", text))
            if not ends_with_punct:
                logging.info("[GGUF] Summary does not end with a full sentence")
            return ends_with_punct, []

        def generate_missing_section(section_name, base_prompt):
            """Generate one missing section with a targeted prompt; never raises."""
            section_prompts = {
                'Clinical Assessment': f"Based on the patient data provided, generate only the Clinical Assessment section in markdown format. Focus on the current clinical status, key findings, and overall patient condition.\n\nPatient Data:\n{base_prompt}\n\n## Clinical Assessment\n",
                'Key Trends & Changes': f"Based on the patient data provided, generate only the Key Trends & Changes section in markdown format. Analyze trends in vitals, labs, diagnoses, and medications over time.\n\nPatient Data:\n{base_prompt}\n\n## Key Trends & Changes\n",
                'Plan & Suggested Actions': f"Based on the patient data provided, generate only the Plan & Suggested Actions section in markdown format. Recommend next steps, follow-up actions, and treatment considerations.\n\nPatient Data:\n{base_prompt}\n\n## Plan & Suggested Actions\n",
                'Direct Guidance for Physician': f"Based on the patient data provided, generate only the Direct Guidance for Physician section in markdown format. Provide specific recommendations for the treating physician.\n\nPatient Data:\n{base_prompt}\n\n## Direct Guidance for Physician\n",
            }
            targeted_prompt = section_prompts.get(
                section_name,
                f"Generate the {section_name} section based on the patient data.\n\n{base_prompt}\n\n## {section_name}\n",
            )
            try:
                section_output = self.generate(targeted_prompt, max_tokens=max_tokens // 2)
            except Exception as e:
                logging.error("Failed to generate %s section: %s", section_name, e)
                return f"## {section_name}\nUnable to generate this section due to processing error. Please review patient data manually."
            heading = f"## {section_name}"
            if heading in section_output:
                content = section_output.split(heading, 1)[1].strip()
                # Keep only this section: drop everything from the next "##" heading on.
                content = re.split(r"##\s+", content, maxsplit=1)[0].strip()
                return f"{heading}\n{content}"
            # The model didn't echo the heading; use its raw output as the body.
            return f"{heading}\n{section_output.strip()}"

        full_output = ""
        current_prompt = prompt
        total_start = time.time()
        try:
            logging.info("[GGUF] Starting enhanced full summary generation with max_loops=%s", max_loops)
            logging.info("[GGUF] Prompt length: %d characters", len(prompt))
            for loop_idx in range(max_loops):
                loop_start = time.time()
                logging.info("[GGUF] Starting loop %d/%d", loop_idx + 1, max_loops)
                logging.info("[GGUF] Current prompt length: %d characters", len(current_prompt))
                try:
                    output = self.generate(current_prompt, max_tokens=max_tokens)
                except Exception as e:
                    if not full_output:
                        raise  # nothing to salvage; handled by the outer fallback
                    logging.warning(
                        "[GGUF] loop %d failed (%s); keeping the %d characters generated so far",
                        loop_idx + 1, e, len(full_output),
                    )
                    break
                if output.startswith(prompt):
                    output = output[len(prompt):].strip()
                if full_output and output:
                    # Separate chunks so a "## Heading" from a continuation starts its own line.
                    full_output += "\n"
                full_output += output
                logging.info(
                    "[GGUF] loop %d/%d: %.2fs, cumulative %.2fs, length=%d chars (+%d)",
                    loop_idx + 1, max_loops, time.time() - loop_start,
                    time.time() - total_start, len(full_output), len(output),
                )

                complete, missing_sections = is_complete(full_output)
                if complete:
                    logging.info("[GGUF] All required sections found after loop %d", loop_idx + 1)
                    break
                if loop_idx < max_loops - 1:
                    if missing_sections:
                        missing_list = ", ".join(missing_sections)
                        current_prompt = f"{prompt}\n\n{full_output}\n\nThe summary is missing these sections: {missing_list}. Please continue and complete all sections in markdown format:"
                    else:
                        current_prompt = f"{prompt}\n\n{full_output}\n\nContinue the summary and ensure it ends with a complete sentence:"
                    logging.info("[GGUF] Preparing next prompt for loop %d", loop_idx + 2)

            # Post-processing: generate any sections the loops did not produce.
            _, missing_sections = is_complete(full_output)
            if missing_sections:
                logging.info("[GGUF] Generating %d missing sections post-processing", len(missing_sections))
                generated_sections = []
                for section in missing_sections:
                    logging.info("[GGUF] Generating missing section: %s", section)
                    generated_sections.append(generate_missing_section(section, prompt))
                full_output += "\n\n" + "\n\n".join(generated_sections)

            logging.info("[GGUF] generate_full_summary completed in %.2fs", time.time() - total_start)
            logging.info("[GGUF] Final summary length: %d characters", len(full_output))

            final_complete, final_missing = is_complete(full_output)
            if not final_complete:
                logging.warning("[GGUF] Final summary still incomplete. Missing: %s", final_missing)
                if final_missing:
                    # Last resort: guarantee the four-section structure.
                    full_output += "\n\n" + "\n\n".join(
                        f"## {section}\nPlease review the patient data for this section."
                        for section in final_missing
                    )
            return full_output.strip()
        except Exception as e:
            logging.error("Full summary generation failed: %s", e)
            # Return a minimal but structurally complete summary instead of raising.
            minimal_sections = [
                "## Clinical Assessment\nPatient data processing encountered an error. Please review the raw patient information manually.",
                "## Key Trends & Changes\nUnable to analyze trends due to processing error. Manual review recommended.",
                "## Plan & Suggested Actions\nError in generating action plan. Consult with healthcare provider for appropriate next steps.",
                "## Direct Guidance for Physician\nProcessing error occurred. Please conduct a thorough manual review of all patient data.",
            ]
            return "\n\n".join(minimal_sections)
def create_fallback_pipeline():
    """Return a model-free stand-in exposing ``generate`` / ``generate_full_summary``.

    Both methods ignore their arguments and return the same fixed, generic
    four-section markdown summary. This is used when the real model pipeline
    fails.
    """

    class FallbackPipeline:
        _SECTIONS = (
            "## Clinical Assessment\nThe patient data indicates a medical condition that requires professional evaluation. Key indicators include vital signs, diagnoses, and medication history that suggest ongoing clinical management is necessary.",
            "## Key Trends & Changes\nPatient records show historical data points including vital measurements, laboratory results, and medication adjustments. Trends in weight, blood pressure, and lab values should be monitored for significant changes over time.",
            "## Plan & Suggested Actions\nImmediate actions include reviewing all available patient data, consulting with specialists if needed, and ensuring continuity of care. Follow-up appointments and medication adherence should be prioritized.",
            "## Direct Guidance for Physician\nThis summary was generated using a fallback method due to model unavailability. Please conduct a thorough review of the patient's complete medical record, including EHR data, imaging, and clinical notes. Consider interdisciplinary consultation for complex cases.",
        )

        def __init__(self):
            self.name = "fallback_text"

        def generate(self, prompt, **kwargs):
            # prompt/kwargs are accepted only for interface compatibility.
            return "\n\n".join(self._SECTIONS)

        def generate_full_summary(self, prompt, **kwargs):
            return self.generate(prompt, **kwargs)

    return FallbackPipeline()
| # --- OpenVINO pipeline loader --- | |
class OpenVinoPipeline:
    """Pair an OpenVINO causal LM with its tokenizer.

    It exposes the same ``.model`` / ``.tokenizer`` attributes that
    generate_patient_summary reads from a transformers pipeline.
    """

    def __init__(self, model, tokenizer):
        self.model, self.tokenizer = model, tokenizer
def get_openvino_pipeline(model_name: str):
    """Load an OpenVINO causal LM and its fast tokenizer for CPU inference.

    Loading depends on where the model comes from:

    * A local directory is loaded with ``compile=True``.
    * Any other name is passed to ``from_pretrained`` with ``export=False`` and
      ``compile=False``. This presumably requires the repo to already contain
      OpenVINO IR -- confirm.

    Both loads use ``$HF_HOME`` (default ``/tmp/huggingface``) as the cache.
    """
    import os

    hf_cache = os.environ.get('HF_HOME', '/tmp/huggingface')
    if os.path.isdir(model_name):
        load_options = {"compile": True}
    else:
        load_options = {"export": False, "compile": False}
    model = OVModelForCausalLM.from_pretrained(model_name, device="CPU", cache_dir=hf_cache, **load_options)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, cache_dir=hf_cache)
    return OpenVinoPipeline(model, tokenizer)
| # --- Summarizer Agent --- | |
class SummarizerAgent:
    """Summarise free text with a model obtained from a lazy loader.

    ``summarization_model_loader.load()`` must return one of:

    * an object with ``generate_full_summary(text, max_tokens=..., max_loops=...)``
      (e.g. GGUFModelPipeline);
    * a callable compatible with a transformers summarization / text-generation
      pipeline.

    Attributes:
        request_count: number of ``generate_summary`` calls made.
        last_summary_length: length of the last successfully produced summary.
    """

    def __init__(self, summarization_model_loader):
        self.summarization_model_loader = summarization_model_loader
        self.last_summary_length = 0
        self.request_count = 0

    @staticmethod
    def _extract_text(result):
        """Pull the text out of a pipeline result (str, or a list of result dicts)."""
        if isinstance(result, str):
            return result
        if isinstance(result, list) and result:
            first = result[0]
            if isinstance(first, dict):
                return first.get("summary_text", first.get("generated_text", ""))
            return str(first)
        return str(result)

    def generate_summary(self, text):
        """Summarise *text*; never raises.

        Returns one of:

        * the summary text;
        * a notice when the input has fewer than 5 words;
        * ``"Summary generation failed: <error>"`` on any error, which is also
          logged.
        """
        self.request_count += 1
        try:
            clean_text = text.strip()
            if not clean_text or len(clean_text.split()) < 5:
                return "Input text is too short for summarization"
            model = self.summarization_model_loader.load()
            if hasattr(model, "generate_full_summary"):
                raw = model.generate_full_summary(clean_text, max_tokens=4000, max_loops=2)
            else:
                raw = model(clean_text, max_length=512, min_length=50, do_sample=False)
            summary = self._extract_text(raw)
        except Exception as e:
            logging.getLogger(__name__).exception("Summary generation failed")
            return f"Summary generation failed: {str(e)}"
        self.last_summary_length = len(summary)
        return summary
| # --- Main function to replicate generate_patient_summary API --- | |
class _HFSummarizationLoader:
    """Lazily build, then reuse, a transformers ``summarization`` pipeline (SummarizerAgent's loader)."""

    def __init__(self, model_name):
        self.model_name = model_name
        self._pipeline = None

    def load(self):
        if self._pipeline is None:
            self._pipeline = transformers_pipeline("summarization", model=self.model_name)
        return self._pipeline


def _unwrap_ehr_result(payload):
    """Return ``payload["result"]`` for a dict payload that has a truthy one, else *payload*."""
    if isinstance(payload, dict):
        return payload.get("result") or payload
    return payload


def _fetch_ehr_result(patientid, token, key):
    """POST to the EHR patient-summary endpoint (*key* is the base URL) and unwrap the payload."""
    if not patientid or not token or not key:
        raise ValueError("Missing required fields: patientid, token, or key")
    api_url = f"{key}/Transactionapi/api/PatientList/patientsummary"
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    # NOTE(review): key doubles as the URL base above, so a non-"http" key yields an
    # unusable URL -- confirm whether a separate API key was intended.
    if not key.startswith("http"):
        headers["x-api-key"] = key
    response = requests.post(api_url, json={"patientid": patientid}, headers=headers, timeout=30)
    if response.status_code != 200:
        raise RuntimeError(f"API request failed with status {response.status_code}: {response.text}")
    try:
        api_data = response.json()
    except ValueError:
        api_data = response.text
    return _unwrap_ehr_result(api_data)


def _build_gguf_pipeline(model_name):
    """Create a GGUFModelPipeline from a local path or a Hub spec.

    Hub specs have the form ``owner/repo/[subdir/]file.gguf``.
    """
    is_local = os.path.isabs(model_name) or os.path.exists(model_name)
    if is_local or not (model_name.endswith(".gguf") and "/" in model_name):
        return GGUFModelPipeline(model_name)
    parts = model_name.split("/")
    if len(parts) >= 3:
        # Hub repo ids are "owner/name"; the rest is a path inside the repo.
        repo_id, filename = "/".join(parts[:2]), "/".join(parts[2:])
    else:
        repo_id, filename = model_name.rsplit("/", 1)
    return GGUFModelPipeline(repo_id, filename)


# Tried in order: prefer a real heading over a bare mention of the section name.
_SUMMARY_START_MARKERS = ("## Clinical Assessment", "# Clinical Assessment", "Clinical Assessment")


def _extract_markdown_summary(summary_raw):
    """Drop any preamble before the Clinical Assessment section, keeping its whole heading line."""
    for marker in _SUMMARY_START_MARKERS:
        idx = summary_raw.find(marker)
        if idx != -1:
            line_start = summary_raw.rfind("\n", 0, idx) + 1
            return summary_raw[line_start:].strip()
    return summary_raw


def generate_patient_summary(patientid=None, token=None, key=None, ehr_data=None, patient_summarizer_model_name="falconsai/medical_summarization", patient_summarizer_model_type="summarization"):
    """Build a structured prompt from EHR visits and summarise it with the chosen model.

    Args:
        patientid, token, key: used to call the EHR API when *ehr_data* is None
            (*key* is the API base URL).
        ehr_data: EHR payload to use directly; a dict optionally wrapped in
            ``{"result": ...}``.
        patient_summarizer_model_name: model id or path for the chosen backend.
        patient_summarizer_model_type: one of "gguf", "text-generation",
            "causal-openvino" or "summarization".

    Returns:
        ``(summary_markdown, baseline_text, delta_text)``. The "gguf" backend
        falls back to a static summary on any error. The "summarization" backend
        returns an error message as the summary on failure.

    Raises:
        ValueError: missing API credentials, or an unsupported model type.
        RuntimeError: API failure, or no ``chartsummarydtl`` in the payload.
    """
    start_total = time.time()
    if ehr_data is not None:
        ehr_result = _unwrap_ehr_result(ehr_data)
    else:
        ehr_result = _fetch_ehr_result(patientid, token, key)

    chartsummarydtl = ehr_result.get("chartsummarydtl") if isinstance(ehr_result, dict) else None
    if not chartsummarydtl:
        raise RuntimeError("Missing chartsummarydtl in EHR response")

    visits = parse_ehr_chartsummarydtl(chartsummarydtl)
    # NOTE: no earlier snapshot is passed, so every diagnosis/medication is reported
    # as new and every "prev" value is N/A; pass previously summarised visits as
    # old_visits to get real deltas.
    delta = compute_deltas([], visits)
    all_visits = visits_sorted(visits)
    baseline = build_compact_baseline(all_visits)
    delta_text = delta_to_text(delta)
    prompt = build_main_prompt(baseline, delta_text)

    torch.set_num_threads(2)
    model_name = patient_summarizer_model_name
    model_type = patient_summarizer_model_type

    if model_type == "gguf":
        try:
            pipeline = _build_gguf_pipeline(model_name)
            summary_raw = pipeline.generate_full_summary(prompt, max_tokens=100000, max_loops=5)
            summary = _extract_markdown_summary(summary_raw)
        except Exception:
            logger.exception("GGUF summarisation failed; returning the static fallback summary")
            summary = create_fallback_pipeline().generate_full_summary(prompt)
    elif model_type in {"text-generation", "causal-openvino"}:
        from transformers import AutoModelForCausalLM, AutoTokenizer

        if model_type == "causal-openvino":
            pipeline = get_openvino_pipeline(model_name)
        else:
            # SECURITY: trust_remote_code executes Python shipped with the model repo;
            # only use with trusted model names.
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
            pipeline = transformers_pipeline("text-generation", model=model, tokenizer=tokenizer)
        tok = pipeline.tokenizer
        inputs = tok([prompt], return_tensors="pt")
        pad_id = tok.eos_token_id if tok.eos_token_id is not None else 32000
        # NOTE(review): max_new_tokens=100000 exceeds most models' context length;
        # generation relies on hitting EOS first.
        outputs = pipeline.model.generate(**inputs, max_new_tokens=100000, do_sample=False, pad_token_id=pad_id)
        # Decoder-only models return prompt + continuation; decode only the new tokens.
        prompt_len = inputs["input_ids"].shape[-1]
        summary = tok.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
    elif model_type == "summarization":
        summarizer_agent = SummarizerAgent(_HFSummarizationLoader(model_name))
        summary = summarizer_agent.generate_summary(prompt)
    else:
        raise ValueError(f"Unsupported model type: {patient_summarizer_model_type}")

    logger.info("[TIMING] TOTAL: %.2fs", time.time() - start_total)
    return summary, baseline, delta_text
| # Example usage | |
if __name__ == "__main__":
    # Option 1: fetch the record from the EHR API (replace with real values):
    #   summary, baseline, delta_text = generate_patient_summary(
    #       "your_patient_id", "your_bearer_token", "https://your-ehr-api.com")
    #
    # Option 2 (recommended for testing): pass EHR data directly, as below.
    newer_visit = {
        "chartdate": "2023-12-01",
        "vitals": ["Bp(sys)(mmHg): 140", "Bp(dia)(mmHg): 90", "Weight: 180"],
        "diagnosis": ["Hypertension", "Diabetes"],
        "medications": ["Lisinopril 10mg", "Metformin 500mg"],
        "labtests": [
            {"name": "HbA1c (%)", "value": "7.2"},
            {"name": "Creatinine", "value": "1.1"},
        ],
    }
    older_visit = {
        "chartdate": "2023-11-15",
        "vitals": ["Bp(sys)(mmHg): 135", "Bp(dia)(mmHg): 85", "Weight: 182"],
        "diagnosis": ["Hypertension"],
        "medications": ["Lisinopril 10mg"],
        "labtests": [
            {"name": "HbA1c (%)", "value": "7.5"},
            {"name": "Creatinine", "value": "1.0"},
        ],
    }
    sample_ehr_data = {"chartsummarydtl": [newer_visit, older_visit]}

    summary, baseline, delta_text = generate_patient_summary(ehr_data=sample_ehr_data)
    report = (
        ("Patient Summary:", summary),
        ("\nBaseline:", baseline),
        ("\nDelta Text:", delta_text),
    )
    for heading, body in report:
        print(heading)
        print(body)