"""
Colab-Ready Patient Summary Generation
=====================================

This file replicates the functionality of generate_patient_summary from the
HNTAI project. It provides a complete, self-contained implementation that can
be run in Google Colab.

Features:
- Robust JSON parsing with flexible key matching
- Medical summarization model (Falconsai/medical_summarization)
- Original chartsummarydtl data format support
- Clinical summary generation with 4 sections
- Error handling and fallback mechanisms
- Exact prompt format matching original system

Usage in Google Colab:
    # Install dependencies
    !pip install torch transformers accelerate huggingface-hub

    # Run the code
    from patient_summary_colab import PatientSummaryGenerator, create_sample_patient_data

    # Create generator with medical model
    generator = PatientSummaryGenerator(model_name="Falconsai/medical_summarization")

    # Use sample data or your own data
    sample_data = create_sample_patient_data()
    summary = generator.generate_patient_summary(sample_data)
    print(summary)

Data Format:
    The system expects data in the original chartsummarydtl format:
    {
        "result": {
            "patientname": "John Doe",
            "agey": 65,
            "gender": "Male",
            "chartsummarydtl": [...],
            "past_medical_history": [...],
            "allergies": [...]
        }
    }
"""

import os
import json
import re
import time
import logging
import warnings
from typing import Dict, Any, List, Optional, Union
from datetime import datetime
import concurrent.futures
from textwrap import fill

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore", category=UserWarning)

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def _as_text_list(value: Any) -> List[str]:
    """Coerce a scalar or a collection into a list of non-empty, stripped strings.

    A plain string becomes a one-element list (it is never split into
    characters); falsy values and falsy items are dropped.
    """
    if not value:
        return []
    items = value if isinstance(value, (list, tuple, set)) else [value]
    texts = (str(item).strip() for item in items if item)
    return [text for text in texts if text]


def _as_mapping(value: Any) -> Dict[str, Any]:
    """Coerce vitals/lab data into a ``{name: value}`` dictionary.

    Accepts a dict (returned unchanged), or a list whose items are either
    ``"name: value"`` strings or ``{"name": ..., "value": ...}`` dicts.
    Anything else yields an empty dict.
    """
    if isinstance(value, dict):
        return value
    mapping: Dict[str, Any] = {}
    if isinstance(value, (list, tuple)):
        for item in value:
            if isinstance(item, str) and ':' in item:
                key, _, val = item.partition(':')
                mapping[key.strip()] = val.strip()
            elif isinstance(item, dict) and item.get('name'):
                mapping[str(item['name']).strip()] = item.get('value', '')
    return mapping


def _wrap_lines(text: Any, width: int = 80) -> str:
    """Wrap each line of *text* separately so existing line breaks survive.

    ``textwrap.fill`` on the whole text would merge markdown headings and
    bullets into a single paragraph.
    """
    return "\n".join(fill(line, width=width) for line in str(text).splitlines())


class _TemplateFallbackModel:
    """Stand-in "model" that returns a fixed four-section template."""

    def __init__(self):
        self.name = "fallback_template"

    def generate(self, prompt: str, **kwargs) -> str:
        """Return the static template; *prompt* and *kwargs* are ignored."""
        sections = [
            "## Clinical Assessment\nBased on the provided information, this appears to be a medical case requiring clinical review.",
            "## Key Trends & Changes\nPlease review the patient data for any significant changes or trends.",
            "## Plan & Suggested Actions\nConsider consulting with a healthcare provider for proper medical assessment.",
            "## Direct Guidance for Physician\nThis summary was generated using a fallback method. Please review all patient data thoroughly."
        ]
        return "\n\n".join(sections)


class RobustJSONParser:
    """Enhanced JSON parsing utilities for medical data with flexible key matching."""

    @staticmethod
    def safe_get(data_dict: Dict[str, Any], key_aliases: List[str]) -> Optional[Any]:
        """Return the value stored under the first matching alias, or None.

        Matching is case-insensitive. Aliases are tried in the order given;
        when several dictionary keys differ only by case, the first one in
        dictionary order wins. Non-string keys are ignored.
        """
        if not isinstance(data_dict, dict):
            return None
        lowered: Dict[str, Any] = {}
        for key, value in data_dict.items():
            if isinstance(key, str):
                lowered.setdefault(key.lower(), value)
        for alias in key_aliases:
            folded = alias.lower()
            if folded in lowered:
                return lowered[folded]
        return None

    @staticmethod
    def _clean_text_list(value: Any) -> Optional[List[str]]:
        """Return a cleaned list for list/str input, or None for empty/unsupported input."""
        if not value:
            return None
        if isinstance(value, list):
            return [str(item).strip() for item in value if item]
        if isinstance(value, str):
            return [value.strip()]
        return None

    @staticmethod
    def normalize_visit_data(visit: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a single visit record to canonical keys.

        Only fields that are present (and non-empty) in *visit* appear in the
        result. Canonical keys: ``chartdate``, ``vitals`` (dict),
        ``diagnosis``, ``medications``, ``allergies``, ``labtests`` (list of
        dicts), ``chiefComplaint`` and ``symptoms``.
        """
        if not isinstance(visit, dict):
            return {}

        get = RobustJSONParser.safe_get
        normalized: Dict[str, Any] = {}

        # Date handling with multiple possible keys
        date_value = get(visit, ['chartdate', 'date', 'visitDate', 'encounterDate', 'appointmentDate'])
        if date_value:
            normalized['chartdate'] = str(date_value)[:10]  # Keep the YYYY-MM-DD prefix only

        # Vitals: either a mapping or a list of "name: value" entries
        vitals = get(visit, ['vitals', 'vitalSigns', 'vital_signs', 'vitalsigns'])
        if vitals and isinstance(vitals, (dict, list)):
            normalized['vitals'] = _as_mapping(vitals)

        # Diagnoses, medications and allergies share the same list/str handling
        list_fields = (
            ('diagnosis', ['diagnoses', 'diagnosis', 'conditions', 'diagnosisList']),
            ('medications', ['medications', 'meds', 'prescriptions', 'medicationList', 'drugs']),
            ('allergies', ['allergies', 'allergyList', 'allergyInfo']),
        )
        for target, aliases in list_fields:
            cleaned = RobustJSONParser._clean_text_list(get(visit, aliases))
            if cleaned is not None:
                normalized[target] = cleaned

        # Lab tests: dicts are kept, "name: value" strings are split
        labtests = get(visit, ['labtests', 'labTests', 'lab_tests', 'laboratory', 'labs'])
        if labtests and isinstance(labtests, list):
            parsed_labs = []
            for lab in labtests:
                if isinstance(lab, dict):
                    parsed_labs.append(lab)
                elif isinstance(lab, str) and ':' in lab:
                    name, _, value = lab.partition(':')
                    parsed_labs.append({'name': name.strip(), 'value': value.strip()})
            normalized['labtests'] = parsed_labs

        # Chief complaint
        complaint = get(visit, ['chiefComplaint', 'reasonForVisit', 'chief_complaint', 'complaint'])
        if complaint:
            normalized['chiefComplaint'] = str(complaint).strip()

        # Symptoms
        symptoms = RobustJSONParser._clean_text_list(
            get(visit, ['symptoms', 'reportedSymptoms', 'symptomList'])
        )
        if symptoms is not None:
            normalized['symptoms'] = symptoms

        return normalized

    @staticmethod
    def process_patient_record_robust(patient_data: Dict[str, Any]) -> Dict[str, Any]:
        """Process a generic (non-chartsummarydtl) patient record.

        Returns ``{"error": ...}`` when *patient_data* is not a dict;
        otherwise a dict with ``patientName``/``patientNumber`` plus whichever
        of ``demographics``, ``pastMedicalHistory``, ``allergies``,
        ``habits``, ``comorbidities`` and ``visits`` could be found.
        """
        if not isinstance(patient_data, dict):
            return {"error": "Invalid patient data format"}

        get = RobustJSONParser.safe_get
        processed: Dict[str, Any] = {}

        # Demographics with flexible key matching
        demographics = get(patient_data, ['demographics', 'patientInfo', 'patient_info', 'demographic'])
        if demographics and isinstance(demographics, dict):
            processed['demographics'] = {
                'age': get(demographics, ['age', 'yearsOld', 'age_years', 'ageYears']),
                'gender': get(demographics, ['gender', 'sex', 'genderIdentity']),
                'dob': get(demographics, ['dob', 'dateOfBirth', 'birthdate', 'birthDate'])
            }

        # Patient name and ID
        processed['patientName'] = get(patient_data, ['patientName', 'patient_name', 'name', 'patient'])
        processed['patientNumber'] = get(
            patient_data, ['patientNumber', 'patient_number', 'id', 'patientId', 'patientID']
        )

        # Fields that are stored as lists (a scalar is wrapped in a list)
        wrapped_fields = (
            ('pastMedicalHistory', ['pastMedicalHistory', 'pmh', 'medical_history', 'medicalHistory']),
            ('allergies', ['allergies', 'allergyInfo', 'allergy_list']),
            ('comorbidities', ['comorbidities', 'chronicConditions', 'chronic_conditions']),
        )
        for target, aliases in wrapped_fields:
            value = get(patient_data, aliases)
            if value:
                processed[target] = value if isinstance(value, list) else [value]

        # Habits/Lifestyle
        habits = get(patient_data, ['habits', 'lifestyle', 'lifestyleFactors'])
        if habits and isinstance(habits, dict):
            processed['habits'] = habits

        # Visits/Encounters with robust processing
        visits = get(patient_data, ['visits', 'encounters', 'appointments', 'encounterList', 'visitList'])
        if visits and isinstance(visits, list):
            normalized_visits = (
                RobustJSONParser.normalize_visit_data(visit)
                for visit in visits
                if isinstance(visit, dict)
            )
            processed['visits'] = [visit for visit in normalized_visits if visit]

        return processed


class PatientSummaryGenerator:
    """Main class for generating patient summaries using AI models."""

    # Token budget passed to ``generate`` (same value the original pipeline call used).
    MAX_LENGTH = 2048

    _REQUIRED_SECTIONS = {
        "## Clinical Assessment": "- Assessment pending clinical review.",
        "## Key Trends & Changes": "- No significant trends identified.",
        "## Plan & Suggested Actions": "- Follow standard clinical protocols.",
        "## Direct Guidance for Physician": "- Review all clinical data thoroughly.",
    }

    def __init__(self, model_name: str = "Falconsai/medical_summarization", device: str = "auto"):
        """
        Initialize the patient summary generator.

        The model itself is loaded lazily on the first call to
        generate_patient_summary.

        Args:
            model_name: Name of the model to use for generation
                (default: Falconsai/medical_summarization)
            device: Device to run on ('auto', 'cpu', 'cuda')
        """
        self.model_name = model_name
        self.device = self._get_device(device)
        self.model = None
        self.tokenizer = None
        self._is_seq2seq = False
        self.parser = RobustJSONParser()
        logger.info(f"PatientSummaryGenerator initialized with model: {model_name} on {self.device}")

    def _get_device(self, device: str) -> str:
        """Resolve 'auto' to 'cuda' or 'cpu'; any other value is returned as given."""
        if device != "auto":
            return device
        try:
            import torch
        except ImportError:
            return "cpu"
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _load_model(self):
        """Load the tokenizer and model once; fall back to the template model on any failure.

        The model class is chosen from the checkpoint configuration so that
        encoder-decoder (seq2seq) checkpoints load as well as decoder-only
        ones.
        """
        if self.model is not None:
            return
        try:
            from transformers import (
                AutoConfig,
                AutoModelForCausalLM,
                AutoModelForSeq2SeqLM,
                AutoTokenizer,
            )

            logger.info(f"Loading model: {self.model_name}")
            config = AutoConfig.from_pretrained(self.model_name)
            is_seq2seq = bool(getattr(config, "is_encoder_decoder", False))
            model_class = AutoModelForSeq2SeqLM if is_seq2seq else AutoModelForCausalLM

            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            model = model_class.from_pretrained(
                self.model_name,
                torch_dtype="auto",
                device_map="auto" if self.device == "cuda" else None
            )

            # Set pad token if not present
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Publish only after everything succeeded so a half-loaded state
            # (tokenizer without model) can never be observed.
            self._is_seq2seq = is_seq2seq
            self.tokenizer = tokenizer
            self.model = model
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            self.tokenizer = None
            self._create_fallback_model()

    def _create_fallback_model(self):
        """Install the template-based fallback model as ``self.model``."""
        self.model = _TemplateFallbackModel()
        logger.warning("Using fallback model for generation")

    def _generate_with_model(self, prompt: str) -> str:
        """Run the loaded transformers model on *prompt* and return only the new text."""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        output_ids = self.model.generate(
            **inputs,
            max_length=self.MAX_LENGTH,
            do_sample=True,
            temperature=0.7,
            pad_token_id=self.tokenizer.pad_token_id
        )[0]
        if not self._is_seq2seq:
            # Decoder-only models echo the prompt; keep only the continuation.
            output_ids = output_ids[inputs["input_ids"].shape[-1]:]
        return self.tokenizer.decode(output_ids, skip_special_tokens=True).strip()

    @staticmethod
    def _result_of(patient_data: Any) -> Dict[str, Any]:
        """Return the ``result`` dict of *patient_data*, or an empty dict."""
        if not isinstance(patient_data, dict):
            return {}
        result = patient_data.get("result", {})
        return result if isinstance(result, dict) else {}

    def _canonical_encounter(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Map one encounter/visit record onto the keys used by the report code.

        ``diagnosis``/``medications`` are always lists of strings and
        ``vitals``/``lab_results`` are always dicts; ``visit_date``,
        ``chief_complaint``, ``symptoms``, ``dr_notes`` and ``treatment`` are
        present only when the record has a non-empty value for them.
        """
        get = self.parser.safe_get
        encounter: Dict[str, Any] = {
            'diagnosis': _as_text_list(
                get(record, ['diagnosis', 'diagnoses', 'conditions', 'diagnosisList'])
            ),
            'medications': _as_text_list(
                get(record, ['medications', 'meds', 'prescriptions', 'medicationList', 'drugs'])
            ),
            'vitals': _as_mapping(
                get(record, ['vitals', 'vitalSigns', 'vital_signs', 'vitalsigns'])
            ),
            'lab_results': _as_mapping(
                get(record, ['lab_results', 'labResults', 'labtests', 'lab_tests', 'laboratory', 'labs'])
            ),
        }
        optional_fields = (
            ('visit_date', ['visit_date', 'chartdate', 'visitDate', 'date', 'encounterDate']),
            ('chief_complaint', ['chief_complaint', 'chiefComplaint', 'reasonForVisit', 'complaint']),
            ('symptoms', ['symptoms', 'reportedSymptoms', 'symptomList']),
            ('dr_notes', ['dr_notes', 'doctorNotes', 'notes']),
            ('treatment', ['treatment']),
        )
        for target, aliases in optional_fields:
            value = get(record, aliases)
            if value:
                encounter[target] = value
        return encounter

    def _get_encounters(self, result: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Return canonical encounters from ``encounters`` or ``chartsummarydtl``.

        Input order is kept; callers treat the last item as the most recent
        encounter.
        """
        records = result.get("encounters") or result.get("chartsummarydtl") or []
        if not isinstance(records, list):
            return []
        return [self._canonical_encounter(rec) for rec in records if isinstance(rec, dict)]

    def _build_chronological_narrative(self, patient_data: dict) -> str:
        """Builds a chronological narrative from multi-encounter patient history."""
        result = self._result_of(patient_data)
        narrative = []

        # Past Medical History
        pmh = _as_text_list(self.parser.safe_get(
            result,
            ['past_medical_history', 'pastMedicalHistory', 'pmh', 'medical_history', 'medicalHistory']
        ))
        if pmh:
            narrative.append(f"Past Medical History: {', '.join(pmh)}.")
        else:
            narrative.append("Past Medical History: Not specified.")

        # Social History
        social = self.parser.safe_get(result, ['social_history', 'socialHistory', 'social', 'lifestyle'])
        if social:
            narrative.append(f"Social History: {social}.")
        else:
            narrative.append("Social History: Not specified.")

        # Allergies
        allergies = _as_text_list(
            self.parser.safe_get(result, ['allergies', 'allergyInfo', 'allergy_list'])
        )
        if allergies:
            narrative.append(f"Allergies: {', '.join(allergies)}.")
        else:
            narrative.append("Allergies: None reported.")

        # Loop through encounters in the order supplied
        for enc in self._get_encounters(result):
            visit_date = enc.get('visit_date', 'Unknown')
            complaint = enc.get('chief_complaint', 'Not specified')
            symptoms = ', '.join(_as_text_list(enc.get('symptoms'))) or 'None reported'
            diagnosis = ', '.join(enc['diagnosis']) or 'Not specified'
            notes = enc.get('dr_notes', 'N/A')
            encounter_str = (
                f"Encounter on {visit_date}: "
                f"Chief Complaint: '{complaint}'. "
                f"Symptoms: {symptoms}. "
                f"Diagnosis: {diagnosis}. "
                f"Doctor's Notes: {notes}. "
            )
            if enc['vitals']:
                vitals_text = ', '.join(f'{k}: {v}' for k, v in enc['vitals'].items())
                encounter_str += f"Vitals: {vitals_text}. "
            if enc['lab_results']:
                labs_text = ', '.join(f'{k}: {v}' for k, v in enc['lab_results'].items())
                encounter_str += f"Labs: {labs_text}. "
            if enc['medications']:
                encounter_str += f"Medications: {', '.join(enc['medications'])}. "
            if enc.get('treatment'):
                encounter_str += f"Treatment: {enc['treatment']}."
            narrative.append(encounter_str)

        return "\n".join(narrative)

    def _create_ai_prompt(self, processed_data: Dict[str, Any]) -> str:
        """Creates a comprehensive AI prompt from processed patient data using the original format."""
        # Convert patient data to plain text format (matching original)
        patient_text = self._convert_patient_data_to_plain_text(processed_data)

        # Create the exact prompt format from the original system
        prompt = f"""<|system|>
You are an expert clinical AI assistant. Your task is to generate a comprehensive patient summary based on the provided medical data. Analyze the patient's medical history, current condition, and provide a structured clinical assessment.

**PATIENT MEDICAL DATA:**
{patient_text}

**REQUIRED OUTPUT FORMAT:**
Generate a complete patient summary with exactly 4 sections in markdown format:

## Clinical Assessment
- Provide a comprehensive assessment of the patient's current condition
- Include key diagnoses, vital signs, and clinical findings
- Analyze the patient's overall health status

## Key Trends & Changes
- Identify important trends in the patient's medical history
- Note any significant changes in vital signs, lab values, or conditions
- Highlight progression or improvement of conditions

## Plan & Suggested Actions
- Recommend specific next steps for patient care
- Suggest monitoring parameters and follow-up requirements
- Include medication management recommendations

## Direct Guidance for Physician
- Provide clear, actionable clinical guidance
- Highlight critical concerns or risks
- Offer specific recommendations for the treating physician
<|user|>
Generate the complete, four-part integrated patient summary based on the medical data provided above. Keep the summary under 4000 tokens with each section containing approximately 500 tokens including bullet points.

<|assistant|>"""
        return prompt

    @staticmethod
    def _sorted_visits(visits: Any) -> List[Dict[str, Any]]:
        """Return the dict entries of *visits* sorted by ``chartdate`` (missing dates first)."""
        if not isinstance(visits, list):
            return []
        dict_visits = [visit for visit in visits if isinstance(visit, dict)]
        return sorted(dict_visits, key=lambda v: str(v.get('chartdate') or ''))

    def _convert_patient_data_to_plain_text(self, patient_data: Dict[str, Any]) -> str:
        """Convert processed patient data into the plain-text block embedded in the prompt.

        Tolerates vitals given as a dict or as a list of ``"name: value"``
        strings, and list fields given as a single string.
        """
        text_parts = []

        # Patient demographics
        demographics = patient_data.get('demographics', {})
        if demographics:
            age = demographics.get('age', 'Unknown')
            gender = demographics.get('gender', 'Unknown')
            # The generic parser stores the name at the top level, not in demographics.
            patient_name = (
                demographics.get('patientName')
                or patient_data.get('patientName')
                or 'Unknown'
            )
            text_parts.append(f"Patient: {patient_name}, {age} year old {gender}")

        # Visits/Encounters
        sorted_visits = self._sorted_visits(patient_data.get('visits', []))
        if sorted_visits:
            text_parts.append("\nMedical History:")

            for i, visit in enumerate(sorted_visits):
                visit_date = visit.get('chartdate', 'Unknown date')
                text_parts.append(f"\nVisit {i+1} - {visit_date}:")

                # Chief complaint
                complaint = visit.get('chiefComplaint', '')
                if complaint:
                    text_parts.append(f"Chief Complaint: {complaint}")

                # Symptoms
                symptoms = _as_text_list(visit.get('symptoms'))
                if symptoms:
                    text_parts.append(f"Symptoms: {', '.join(symptoms)}")

                # Vitals
                vitals_text = [f"{key}: {value}" for key, value in _as_mapping(visit.get('vitals')).items()]
                if vitals_text:
                    text_parts.append(f"Vital Signs: {', '.join(vitals_text)}")

                # Diagnoses
                diagnoses = _as_text_list(visit.get('diagnosis'))
                if diagnoses:
                    text_parts.append(f"Diagnoses: {', '.join(diagnoses)}")

                # Medications
                medications = _as_text_list(visit.get('medications'))
                if medications:
                    text_parts.append(f"Medications: {', '.join(medications)}")

                # Allergies
                allergies = _as_text_list(visit.get('allergies'))
                if allergies:
                    text_parts.append(f"Allergies: {', '.join(allergies)}")

                # Lab tests
                labtests = visit.get('labtests', [])
                if isinstance(labtests, list) and labtests:
                    lab_text = []
                    for lab in labtests:
                        if isinstance(lab, dict):
                            name = lab.get('name', '')
                            value = lab.get('value', '')
                            if name and value:
                                lab_text.append(f"{name}: {value}")
                        elif isinstance(lab, str):
                            lab_text.append(lab)
                    if lab_text:
                        text_parts.append(f"Lab Results: {', '.join(lab_text)}")

                # Radiology orders
                radiology = _as_text_list(visit.get('radiologyorders'))
                if radiology:
                    text_parts.append(f"Imaging Orders: {', '.join(radiology)}")

        # Past medical history
        pmh = _as_text_list(patient_data.get('pastMedicalHistory'))
        if pmh:
            text_parts.append(f"\nPast Medical History: {', '.join(pmh)}")

        # Allergies (general)
        allergies = _as_text_list(patient_data.get('allergies'))
        if allergies:
            text_parts.append(f"Known Allergies: {', '.join(allergies)}")

        # Habits/Lifestyle
        habits = patient_data.get('habits', {})
        if habits and isinstance(habits, dict):
            habit_text = [f"{key}: {value}" for key, value in habits.items()]
            text_parts.append(f"Lifestyle Factors: {', '.join(habit_text)}")

        # Comorbidities
        comorbidities = _as_text_list(patient_data.get('comorbidities'))
        if comorbidities:
            text_parts.append(f"Chronic Conditions: {', '.join(comorbidities)}")

        return "\n".join(text_parts)

    def _extract_structured_summary(self, processed_data: Dict[str, Any]) -> str:
        """Build a bullet-style baseline/visit summary from processed patient data.

        All visits except the latest (by ``chartdate``) are listed as history;
        the latest is reported as the current visit.
        """
        summary_parts = []

        # Patient baseline
        summary_parts.append("Patient Baseline Profile:")

        # Demographics
        demographics = processed_data.get('demographics', {}) or {}
        age = demographics.get('age', 'N/A')
        gender = demographics.get('gender', 'N/A')
        summary_parts.append(f"- Demographics: {age} y/o {gender}")

        # Past medical history
        pmh = _as_text_list(processed_data.get('pastMedicalHistory'))
        if pmh:
            summary_parts.append(f"- Past Medical History: {', '.join(pmh)}")

        # Allergies
        allergies = _as_text_list(processed_data.get('allergies'))
        if allergies:
            summary_parts.append(f"- Allergies: {', '.join(allergies)}")

        # Habits
        habits = processed_data.get('habits', {})
        if habits and isinstance(habits, dict):
            habit_list = [f"{key}: {value}" for key, value in habits.items()]
            summary_parts.append(f"- Habits: {', '.join(habit_list)}")

        # Comorbidities
        comorbidities = _as_text_list(processed_data.get('comorbidities'))
        if comorbidities:
            summary_parts.append(f"- Key Comorbidities: {', '.join(comorbidities)}")

        # Visit history
        visits = processed_data.get('visits', [])
        if visits:
            try:
                sorted_visits = self._sorted_visits(visits)

                # Historical visits
                historical_visits = sorted_visits[:-1]
                if historical_visits:
                    summary_parts.append("\nLongitudinal Visit History:")
                    for visit in historical_visits:
                        visit_date = visit.get('chartdate', 'N/A')
                        summary_parts.append(f"\n- Date: {visit_date}")

                        # Vitals
                        vitals = _as_mapping(visit.get('vitals'))
                        if vitals:
                            bp_sys = vitals.get('Bp(sys)(mmHg)', vitals.get('systolic', 'N/A'))
                            bp_dia = vitals.get('Bp(dia)(mmHg)', vitals.get('diastolic', 'N/A'))
                            pulse = vitals.get('Pulse(bpm)', vitals.get('heartrate', 'N/A'))
                            summary_parts.append(f" - Vitals: BP {bp_sys}/{bp_dia} mmHg, Pulse {pulse} bpm")

                        # Diagnoses
                        diagnoses = _as_text_list(visit.get('diagnosis'))
                        if diagnoses:
                            summary_parts.append(f" - Diagnoses: {', '.join(diagnoses)}")

                        # Medications
                        medications = _as_text_list(visit.get('medications'))
                        if medications:
                            summary_parts.append(f" - Medications: {', '.join(medications)}")

                # Current visit
                if sorted_visits:
                    current_visit = sorted_visits[-1]
                    summary_parts.append("\nCurrent Visit Details:")
                    current_date = current_visit.get('chartdate', 'N/A')
                    summary_parts.append(f"- Date: {current_date}")
                    complaint = current_visit.get('chiefComplaint', 'Not specified')
                    summary_parts.append(f"- Chief Complaint: {complaint}")

                    symptoms = _as_text_list(current_visit.get('symptoms'))
                    if symptoms:
                        summary_parts.append(f"- Reported Symptoms: {', '.join(symptoms)}")

                    vitals = _as_mapping(current_visit.get('vitals'))
                    if vitals:
                        vitals_str = ", ".join([f"{key}: {value}" for key, value in vitals.items()])
                        summary_parts.append(f"- Vitals: {vitals_str}")

                    diagnoses = _as_text_list(current_visit.get('diagnosis'))
                    if diagnoses:
                        summary_parts.append(f"- Diagnoses This Visit: {', '.join(diagnoses)}")
            except Exception as e:
                logger.error(f"Error processing visits: {str(e)}")
                summary_parts.append("\nError: Could not process visit data")

        return "\n".join(summary_parts)

    def _format_clinical_output(self, raw_summary: str, patient_data: dict) -> str:
        """Format the AI-generated summary into a structured, doctor-friendly report.

        Encounters are read from ``result['encounters']`` or, when absent,
        from ``result['chartsummarydtl']``; with neither present the
        top-level ``result`` fields are used as the "most recent encounter".
        """
        result = self._result_of(patient_data)
        encounters = self._get_encounters(result)
        last_encounter = encounters[-1] if encounters else self._canonical_encounter(result)

        # Consolidate active problems (parenthesised qualifiers are stripped)
        all_diagnoses_raw = set(_as_text_list(result.get('past_medical_history')))
        for enc in encounters:
            all_diagnoses_raw.update(enc['diagnosis'])
        cleaned_diagnoses = sorted({
            re.sub(r'\s*\([^)]*\)', '', dx).strip()
            for dx in all_diagnoses_raw
        })

        # Consolidate current medications
        all_medications = set()
        for enc in encounters:
            all_medications.update(enc['medications'])
        current_meds = sorted(all_medications)

        # Report Header
        report = "\n==============================================\n"
        report += " CLINICAL SUMMARY REPORT\n"
        report += "==============================================\n"
        report += f"Generated On: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"

        # Patient Overview
        report += "\n--- PATIENT OVERVIEW ---\n"
        report += f"Name: {result.get('patientname', 'Unknown')}\n"
        report += f"Patient ID: {result.get('patientnumber', 'Unknown')}\n"
        gender = str(result.get('gender') or 'Unknown')
        sex_initial = 'U' if gender == 'Unknown' else gender[0]
        report += f"Age/Sex: {result.get('agey', 'Unknown')} {sex_initial}\n"
        allergies = _as_text_list(result.get('allergies')) or ['None']
        report += f"Allergies: {', '.join(allergies)}\n"

        # Social History
        report += "\n--- SOCIAL HISTORY ---\n"
        report += fill(str(result.get('social_history') or 'Not specified.'), width=80) + "\n"

        # Immediate Attention
        report += "\n--- IMMEDIATE ATTENTION (Most Recent Encounter) ---\n"
        report += f"Date of Event: {last_encounter.get('visit_date', 'Unknown')}\n"
        report += f"Chief Complaint: {last_encounter.get('chief_complaint', 'Not specified')}\n"
        if last_encounter['vitals']:
            vitals_str = ', '.join([f'{k}: {v}' for k, v in last_encounter['vitals'].items()])
            report += f"Vitals: {vitals_str}\n"

        critical_keywords = ['acute', 'new onset', 'fall', 'afib', 'kidney injury']
        critical_diagnoses = [
            dx for dx in last_encounter['diagnosis']
            if any(kw in dx.lower() for kw in critical_keywords)
        ]
        if critical_diagnoses:
            report += f"Critical New Diagnoses: {', '.join(critical_diagnoses)}\n"
        report += f"Doctor's Notes: {last_encounter.get('dr_notes', 'N/A')}\n"

        # Active Problem List
        report += "\n--- ACTIVE PROBLEM LIST (Consolidated) ---\n"
        report += "\n".join(f"- {dx}" for dx in cleaned_diagnoses) + "\n"

        # Current Medications
        report += "\n--- CURRENT MEDICATION LIST (Consolidated) ---\n"
        report += "\n".join(f"- {med}" for med in current_meds) + "\n"

        # AI-Generated Narrative (wrapped per line so markdown sections stay intact)
        report += "\n--- AI-GENERATED CLINICAL NARRATIVE ---\n"
        report += _wrap_lines(raw_summary, width=80) + "\n"

        return report

    def _evaluate_summary_against_guidelines(self, summary_text: str, patient_data: dict) -> str:
        """Simulated, keyword-based evaluation of the summary against clinical guidelines."""
        result = self._result_of(patient_data)
        encounters = self._get_encounters(result)
        last_enc = encounters[-1] if encounters else self._canonical_encounter({})
        summary_lower = str(summary_text).lower()

        evaluation = (
            "\n==============================================\n"
            " AI SUMMARY EVALUATION & GUIDELINE CHECK\n"
            "==============================================\n"
        )

        # Keyword-based accuracy
        critical_keywords = [
            "fall", "dizziness", "atrial fibrillation", "afib", "rvr", "kidney", "ckd",
            "diabetes", "anticoagulation", "warfarin", "aspirin", "statin", "metformin",
            "gout", "angina", "pci", "bph", "hypertension", "metoprolol", "clopidogrel"
        ]
        found = [kw for kw in critical_keywords if kw in summary_lower]
        score = (len(found) / len(critical_keywords)) * 10
        evaluation += f"\n1. KEYWORD ACCURACY SCORE: {score:.1f}/10\n"
        evaluation += f" - Found {len(found)} out of {len(critical_keywords)} critical concepts.\n"

        # Guideline checks
        evaluation += "\n2. CLINICAL GUIDELINE COMMENTARY (SIMULATED):\n"

        diagnoses = [dx.lower() for dx in last_enc['diagnosis']]
        medications = [med.lower() for med in last_enc['medications']]
        history = [hx.lower() for hx in _as_text_list(result.get('past_medical_history'))]
        # Lower-case the notes once so both keyword checks are case-insensitive.
        notes = str(last_enc.get('dr_notes') or '').lower()

        has_afib = any("atrial fibrillation" in dx for dx in diagnoses)
        on_anticoag = any("warfarin" in med or "apixaban" in med for med in medications)
        if has_afib:
            evaluation += " - ✅ Patient with Atrial Fibrillation is on anticoagulation.\n" if on_anticoag \
                else " - ❌ Atrial Fibrillation present but no anticoagulant prescribed.\n"

        has_mi = any("myocardial infarction" in hx for hx in history)
        on_statin = any("atorvastatin" in med or "statin" in med for med in medications)
        if has_mi:
            evaluation += " - ✅ Patient with MI history is on statin therapy.\n" if on_statin \
                else " - ❌ Patient with MI history is not on statin therapy.\n"

        has_aki = any("acute kidney injury" in dx for dx in diagnoses)
        acei_held = "hold" in notes and "lisinopril" in notes
        if has_aki:
            evaluation += " - ✅ AKI noted and ACE inhibitor was appropriately held.\n" if acei_held \
                else " - ⚠️ AKI present but ACE inhibitor not documented as held.\n"

        evaluation += (
            "\nDisclaimer: This is a simulated evaluation and not a substitute for clinical judgment.\n"
        )
        return evaluation

    def _prepare_prompt_data(self, patient_data: Any) -> Dict[str, Any]:
        """Turn raw input into the processed structure consumed by the prompt builder."""
        if isinstance(patient_data, dict) and isinstance(patient_data.get('result'), dict):
            result = patient_data['result']
            raw_visits = result.get('chartsummarydtl') or []
            if not isinstance(raw_visits, list):
                raw_visits = []
            # Normalize known fields (e.g. list-form vitals) while keeping
            # extra keys such as radiologyorders and dr_notes.
            visits = [
                {**visit, **self.parser.normalize_visit_data(visit)}
                for visit in raw_visits
                if isinstance(visit, dict)
            ]
            social_history = result.get('social_history')
            return {
                'demographics': {
                    'age': result.get('agey', 'Unknown'),
                    'gender': result.get('gender', 'Unknown'),
                    'patientName': result.get('patientname', 'Unknown')
                },
                'visits': visits,
                'pastMedicalHistory': result.get('past_medical_history', []),
                'allergies': result.get('allergies', []),
                'habits': {'social_history': social_history} if social_history else {}
            }
        # Fallback processing
        return self.parser.process_patient_record_robust(patient_data)

    @staticmethod
    def _extract_summary_body(raw_summary: str) -> str:
        """Drop any preamble before the summary while keeping the first section heading.

        Text after the last occurrence of the first matching pattern is kept;
        when that pattern is a "Clinical Assessment" heading variant, the
        canonical ``## Clinical Assessment`` heading is put back in front.
        """
        marker = "Now generate the complete clinical summary"
        if marker in raw_summary:
            return raw_summary.rpartition(marker)[2].strip()
        canonical = "## Clinical Assessment"
        for heading in (canonical, "# Clinical Assessment", "Clinical Assessment"):
            if heading in raw_summary:
                body = raw_summary.rpartition(heading)[2].strip()
                return f"{canonical}\n{body}"
        return raw_summary

    def generate_patient_summary(self, patient_data: Union[List[str], Dict], callback=None) -> str:
        """
        Generate the complete patient summary using the original format and processing.

        Args:
            patient_data: Patient data in dictionary format (with 'result' key
                containing chartsummarydtl)
            callback: Optional callback function for progress updates.
                Accepted for interface compatibility; it is not invoked.

        Returns:
            Complete patient summary as formatted string. On an unexpected
            failure a string starting with "Error generating patient summary:"
            is returned instead of raising.
        """
        logger.info(f"Generating patient summary using model: {self.model_name}")
        try:
            # Load model if not already loaded
            self._load_model()

            processed_data = self._prepare_prompt_data(patient_data)

            # Create AI prompt using the original format
            prompt = self._create_ai_prompt(processed_data)

            # A tokenizer is only set when a real transformers model was loaded.
            raw_summary = ""
            if self.tokenizer is not None:
                try:
                    raw_summary = self._generate_with_model(prompt)
                except Exception as e:
                    logger.warning(f"Model generation failed: {e}")
            if not raw_summary:
                fallback = (
                    self.model
                    if isinstance(self.model, _TemplateFallbackModel)
                    else _TemplateFallbackModel()
                )
                raw_summary = fallback.generate(prompt)

            # Clean up the summary and make sure all four sections exist
            new_summary = self._ensure_four_sections(self._extract_summary_body(raw_summary))

            # Format the output to match original
            formatted_report = self._format_clinical_output(new_summary, patient_data)
            evaluation_report = self._evaluate_summary_against_guidelines(new_summary, patient_data)

            # Combine final output
            final_output = (
                f"\n{'='*80}\n"
                f" FINAL CLINICAL SUMMARY REPORT\n"
                f"{'='*80}\n"
                f"{formatted_report}\n\n"
                f"{'='*80}\n"
                f" SIMULATED EVALUATION REPORT\n"
                f"{'='*80}\n"
                f"{evaluation_report}"
            )
            return final_output
        except Exception as e:
            logger.exception(f"Error during summary generation: {e}")
            return f"Error generating patient summary: {str(e)}"

    def _ensure_four_sections(self, summary: str) -> str:
        """Append a placeholder for every required section heading missing from *summary*."""
        missing = [
            (heading, placeholder)
            for heading, placeholder in self._REQUIRED_SECTIONS.items()
            if heading not in summary
        ]
        for heading, placeholder in missing:
            summary += f"\n\n{heading}\n{placeholder}"
        return summary


def create_sample_patient_data():
    """Create sample patient data in the original chartsummarydtl format."""
    return {
        "result": {
            "patientname": "John Doe",
            "patientnumber": "12345",
            "agey": 65,
            "gender": "Male",
            "dob": "1958-01-15",
            "lastvisitdt": "2024-01-20",
            "past_medical_history": [
                "Hypertension",
                "Type 2 Diabetes",
                "Atrial Fibrillation"
            ],
            "allergies": ["Penicillin", "Shellfish"],
            "social_history": "Former smoker, quit 5 years ago. Occasional alcohol use.",
            "chartsummarydtl": [
                {
                    "chartdate": "2024-01-15",
                    "chiefComplaint": "Chest pain and shortness of breath",
                    "symptoms": ["chest pain", "dyspnea", "fatigue"],
                    "vitals": [
                        "Bp(sys)(mmHg): 150",
                        "Bp(dia)(mmHg): 95",
                        "Pulse(bpm): 110",
                        "Weight: 180"
                    ],
                    "diagnosis": ["Acute coronary syndrome", "Atrial fibrillation with RVR"],
                    "medications": ["Metoprolol 50mg", "Aspirin 81mg", "Warfarin 5mg"],
                    "labtests": [
                        {"name": "Troponin I", "value": "2.5"},
                        {"name": "CK-MB", "value": "15.2"},
                        {"name": "BNP", "value": "450"}
                    ],
                    "radiologyorders": ["Chest X-ray", "Echocardiogram"],
                    "dr_notes": "Patient presents with acute chest pain. EKG shows atrial fibrillation with rapid ventricular response. Troponin elevated consistent with NSTEMI. Started on dual antiplatelet therapy and anticoagulation."
                },
                {
                    "chartdate": "2024-01-20",
                    "chiefComplaint": "Follow-up after cardiac event",
                    "symptoms": ["mild chest discomfort", "fatigue"],
                    "vitals": [
                        "Bp(sys)(mmHg): 140",
                        "Bp(dia)(mmHg): 88",
                        "Pulse(bpm): 85",
                        "Weight: 178"
                    ],
                    "diagnosis": ["Post-MI follow-up", "Atrial fibrillation", "Hypertension"],
                    "medications": ["Metoprolol 50mg", "Aspirin 81mg", "Warfarin 5mg", "Atorvastatin 40mg"],
                    "labtests": [
                        {"name": "Troponin I", "value": "0.8"},
                        {"name": "INR", "value": "2.1"}
                    ],
                    "radiologyorders": ["Echocardiogram follow-up"],
                    "dr_notes": "Patient doing well post-MI. Troponin trending down. INR therapeutic. Continue current medications. Follow up in 2 weeks."
                }
            ]
        }
    }


def main():
    """Example usage: build the sample record and print its generated summary."""
    print("Patient Summary Generator - Colab Ready")
    print("=" * 50)

    # Create generator instance with medical summarization model
    generator = PatientSummaryGenerator(model_name="Falconsai/medical_summarization")

    # Create sample data in original format
    sample_data = create_sample_patient_data()
    print("Sample patient data created.")
    print(f"Patient: {sample_data['result']['patientname']}")
    print(f"Visits: {len(sample_data['result']['chartsummarydtl'])}")
    print()

    # Generate summary
    print("Generating patient summary...")
    try:
        summary = generator.generate_patient_summary(sample_data)
        print("Summary generated successfully!")
        print("\n" + "="*80)
        print(summary)
    except Exception as e:
        print(f"Error generating summary: {e}")
        print("This is expected if transformers is not installed.")
        print("Please install with: pip install transformers torch")
        print("\nTrying with fallback model...")
        try:
            # Try with fallback
            generator_fallback = PatientSummaryGenerator(model_name="facebook/bart-base")
            summary = generator_fallback.generate_patient_summary(sample_data)
            print("Summary generated with fallback model!")
            print("\n" + "="*80)
            print(summary)
        except Exception as e2:
            print(f"Fallback also failed: {e2}")
            print("Please install transformers: pip install transformers torch")


# Example usage and testing
if __name__ == "__main__":
    main()