diff --git "a/ai_med_extract/api/routes.py" "b/ai_med_extract/api/routes.py" new file mode 100644--- /dev/null +++ "b/ai_med_extract/api/routes.py" @@ -0,0 +1,1772 @@ +""" +Medical Data Extraction API Routes +This module provides Flask API endpoints for medical data processing, including: +- Patient summary generation using various model types (GGUF, OpenVINO, HuggingFace) +- File upload and text extraction +- Medical data extraction from text and audio +- Protected Health Information (PHI) scrubbing +- Model management and dynamic loading +The API supports multiple model formats and includes comprehensive error handling, +memory optimization, and caching mechanisms for efficient operation in both +local and cloud environments (Hugging Face Spaces). +""" +from concurrent.futures import ThreadPoolExecutor, as_completed +import json +import logging +import os +from collections import defaultdict +import threading +import uuid +from flask import Response, request, jsonify, abort, current_app +import requests +import torch +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + AutoModelForSeq2SeqLM, + EncoderDecoderModel, + pipeline as transformers_pipeline +) +from ai_med_extract.agents.patient_summary_agent import PatientSummarizerAgent +from ai_med_extract.utils.json_slimmer import PruneOptions, slim_api_json +agent = PatientSummarizerAgent(model_name="falconsai/medical_summarization") +from ai_med_extract.agents.summarizer import SummarizerAgent +from ai_med_extract.utils.file_utils import ( + allowed_file, + check_file_size, + save_data_to_storage, + get_data_from_storage, +) +from ..utils.validation import clean_result, validate_patient_name +from ai_med_extract.utils.patient_summary_utils import clean_patient_data, flatten_to_string_list +import time +logger = logging.getLogger(__name__) + +# Add GGUF model cache at the top of the file +GGUF_MODEL_CACHE = {} +GGUF_PIPELINE_CACHE = {} + +# Performance monitoring +PERFORMANCE_METRICS = { + "total_requests": 0, + "successful_generations": 0, + "average_generation_time": 0.0, + "cache_hit_rate": 0.0, + "memory_usage_mb": 0.0 +} + +def update_performance_metrics(generation_time, success=True, cache_hit=False): + """Update performance metrics for monitoring""" + PERFORMANCE_METRICS["total_requests"] += 1 + if success: + PERFORMANCE_METRICS["successful_generations"] += 1 + + # Update average generation time + if PERFORMANCE_METRICS["total_requests"] == 1: + PERFORMANCE_METRICS["average_generation_time"] = generation_time + else: + PERFORMANCE_METRICS["average_generation_time"] = ( + (PERFORMANCE_METRICS["average_generation_time"] * (PERFORMANCE_METRICS["total_requests"] - 1)) + + generation_time + ) / PERFORMANCE_METRICS["total_requests"] + + # Update cache hit rate + if cache_hit: + PERFORMANCE_METRICS["cache_hit_rate"] = ( + (PERFORMANCE_METRICS["cache_hit_rate"] * (PERFORMANCE_METRICS["total_requests"] - 1)) + 1 + ) / PERFORMANCE_METRICS["total_requests"] + else: + PERFORMANCE_METRICS["cache_hit_rate"] = ( + PERFORMANCE_METRICS["cache_hit_rate"] * (PERFORMANCE_METRICS["total_requests"] - 1) + ) / PERFORMANCE_METRICS["total_requests"] + +def cleanup_memory(): + """Clean up memory after model operations for HF Spaces""" + try: + import gc + import psutil + import os + + # Force garbage collection + gc.collect() + + # Clear CUDA cache if available + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Get memory usage + process = psutil.Process(os.getpid()) + memory_mb = process.memory_info().rss / 1024 / 1024 + PERFORMANCE_METRICS["memory_usage_mb"] = memory_mb + + logger.info(f"Memory cleanup completed. Current usage: {memory_mb:.1f} MB") + except Exception as e: + logger.warning(f"Memory cleanup failed: {e}") + +def get_gguf_pipeline(model_name: str, filename: str = None): + """ + Load and cache GGUF model pipelines with comprehensive error handling. + This function provides a cached interface to GGUF models with fallback mechanisms + for robust operation in production environments. + Args: + model_name (str): The name of the GGUF model or HuggingFace repository ID. + Can be a local file path or HuggingFace model identifier. + filename (str, optional): Specific filename for HuggingFace repository models. + Required when model_name is a repository ID. + Returns: + GGUFModelPipeline: A loaded GGUF model pipeline instance or fallback pipeline. + Raises: + RuntimeError: If both model loading and fallback mechanisms fail. + Notes: + - Uses a global cache to avoid reloading the same model multiple times + - Implements timeout mechanism for model loading (5 minutes) + - Provides comprehensive fallback strategies for production reliability + - Logs detailed timing and error information for debugging + """ + key = (model_name, filename) + if key not in GGUF_MODEL_CACHE: + try: + from ai_med_extract.utils.model_loader_gguf import GGUFModelPipeline, create_fallback_pipeline + import time + # Add timeout for model loading + start_time = time.time() + timeout = 300 # 5 minutes timeout + # Try to load the GGUF model + try: + GGUF_MODEL_CACHE[key] = GGUFModelPipeline(model_name, filename, timeout=timeout) + load_time = time.time() - start_time + print(f"[GGUF] Model loaded successfully in {load_time:.2f}s: {model_name}") + except Exception as e: + load_time = time.time() - start_time + print(f"[GGUF] Failed to load model {model_name} after {load_time:.2f}s: {e}") + # If model loading fails, use fallback + print("[GGUF] Using fallback pipeline") + GGUF_MODEL_CACHE[key] = create_fallback_pipeline() + except Exception as e: + print(f"[GGUF] Critical error in model loading: {e}") + # Create a basic fallback + from ai_med_extract.utils.model_loader_gguf import create_fallback_pipeline + GGUF_MODEL_CACHE[key] = create_fallback_pipeline() + return GGUF_MODEL_CACHE[key] + +def get_cached_gguf_pipeline(model_name: str, filename: str = None): + key = (model_name, filename) + if key not in GGUF_PIPELINE_CACHE: + GGUF_PIPELINE_CACHE[key] = get_gguf_pipeline(model_name, filename) + return GGUF_PIPELINE_CACHE[key] + +def ensure_four_sections(summary: str) -> str: + """ + Ensures the summary contains all four required sections. + If any are missing, appends a placeholder. + """ + required_sections = [ + "## Clinical Assessment", + "## Key Trends & Changes", + "## Plan & Suggested Actions", + "## Direct Guidance for Physician" + ] + lines = summary.splitlines() + existing_headers = [line.strip() for line in lines if line.strip().startswith("##")] + for section in required_sections: + if section not in existing_headers: + summary += f"\n{section}\n- *Section was not generated. Consider retrying or checking input data.*" + return summary + +def ensure_structured_sections(text: str, baseline: str, delta_text: str) -> str: + """Ensure markdown has the four required sections. If missing, construct them using baseline and deltas.""" + required = [ + 'Clinical Assessment', + 'Key Trends & Changes', + 'Plan & Suggested Actions', + 'Direct Guidance for Physician' + ] + missing = [s for s in required if s not in text] + if not missing: + return text.strip() + assessment = text.strip() or "Patient data requires professional evaluation. See baseline and deltas." + sections = [ + f"## Clinical Assessment\n{assessment}", + "## Key Trends & Changes\nReview the structured deltas below.\n\n" + (delta_text or "N/A"), + "## Plan & Suggested Actions\nPrioritize safety, medication adherence, and appropriate follow-up. Consider guideline-based interventions derived from the baseline and changes.", + "## Direct Guidance for Physician\nThis summary was programmatically structured. Cross-check with the structured baseline and deltas.\n\nStructured Baseline:\n" + (baseline or "N/A") + ] + return "\n\n".join(sections).strip() + +def get_qa_pipeline(qa_model_type, qa_model_name): + if not qa_model_type or not qa_model_name: + raise ValueError("Both qa_model_type and qa_model_name must be provided") + if not hasattr(get_qa_pipeline, "cache"): + get_qa_pipeline.cache = {} + # For Hugging Face Spaces, we need to be memory efficient + import torch + torch.cuda.empty_cache() # Clear GPU memory before loading model + # Set default tensor type to float32 for better compatibility + torch.set_default_tensor_type(torch.FloatTensor) + if torch.cuda.is_available(): + torch.set_default_tensor_type(torch.cuda.FloatTensor) + key = (qa_model_type, qa_model_name) + if key in get_qa_pipeline.cache: + return get_qa_pipeline.cache[key] + try: + # For Hugging Face Spaces, use smaller models by default + if "Qwen/Qwen-7B-Chat" in qa_model_name: + qa_model_name = "Qwen/Qwen-1_8B-Chat" + elif "Llama" in qa_model_name: + qa_model_name = "facebook/opt-125m" + # Load tokenizer with trust_remote_code=True for custom tokenizers + tokenizer = AutoTokenizer.from_pretrained( + qa_model_name, + trust_remote_code=True, + cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface') + ) + # Load model with memory optimizations + try: + model = AutoModelForCausalLM.from_pretrained( + qa_model_name, + device_map="auto", + torch_dtype=torch.float32, # Use float32 for better compatibility + trust_remote_code=True, + low_cpu_mem_usage=True, + cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface') + ) + except Exception as e: + # Try loading with a simpler model + fallback_model = "facebook/bart-base" + model = AutoModelForCausalLM.from_pretrained( + fallback_model, + device_map="auto", + torch_dtype=torch.float32, + trust_remote_code=True, + low_cpu_mem_usage=True, + cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface') + ) + # Create pipeline with memory optimizations + device = 0 if torch.cuda.is_available() else -1 + pipeline = transformers_pipeline( + task=qa_model_type, + model=model, + tokenizer=tokenizer, + device=device + ) + get_qa_pipeline.cache[key] = pipeline + return pipeline + except Exception as e: + raise + +def run_qa_pipeline(qa_pipeline, question, context): + """ + Run QA pipeline for both 'question-answering', 'text-generation', or other models. + """ + if not qa_pipeline or not question or not context: + raise ValueError("Pipeline, question and context are required") + qa_model_type = getattr(qa_pipeline, '_qa_model_type', None) + try: + if qa_model_type == 'text-generation': + prompt = f"Question: {question}\nContext: {context}\nAnswer:" + result = qa_pipeline(prompt, max_new_tokens=128, do_sample=False) + if isinstance(result, list) and result and 'generated_text' in result[0]: + answer = result[0]['generated_text'].split('Answer:')[-1].strip() + return {'answer': answer} + return {'answer': str(result)} + else: + result = qa_pipeline(question=question, context=context) + return result + except Exception as e: + raise + +def get_ner_pipeline(ner_model_type, ner_model_name): + if not ner_model_type or not ner_model_name: + raise ValueError("Both ner_model_type and ner_model_name must be provided") + if not hasattr(get_ner_pipeline, "cache"): + get_ner_pipeline.cache = {} + # For Hugging Face Spaces, we need to be memory efficient + import torch + torch.cuda.empty_cache() # Clear GPU memory before loading model + # Set default tensor type + torch.set_default_tensor_type(torch.FloatTensor) + if torch.cuda.is_available(): + torch.set_default_tensor_type(torch.cuda.FloatTensor) + key = (ner_model_type, ner_model_name) + if key in get_ner_pipeline.cache: + return get_ner_pipeline.cache[key] + try: + from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline + # Clear any existing models from memory + if torch.cuda.is_available(): + torch.cuda.empty_cache() + # Load tokenizer + try: + tokenizer = AutoTokenizer.from_pretrained( + ner_model_name, + trust_remote_code=True, + cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface') + ) + except Exception as e: + # Try loading with a simpler model + fallback_model = "dslim/bert-base-NER" + tokenizer = AutoTokenizer.from_pretrained( + fallback_model, + trust_remote_code=True, + cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface') + ) + # Load model with memory optimizations + try: + # For NER models, we'll use CPU if device_map='auto' is not supported + try: + model = AutoModelForTokenClassification.from_pretrained( + ner_model_name, + trust_remote_code=True, + device_map="auto", + low_cpu_mem_usage=True, + torch_dtype=torch.float32, + cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface') + ) + except ValueError as e: + if "device_map='auto'" in str(e): + model = AutoModelForTokenClassification.from_pretrained( + ner_model_name, + trust_remote_code=True, + low_cpu_mem_usage=True, + torch_dtype=torch.float32, + cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface') + ) + else: + raise + except Exception as e: + # Try loading with a simpler model + fallback_model = "dslim/bert-base-NER" + model = AutoModelForTokenClassification.from_pretrained( + fallback_model, + trust_remote_code=True, + low_cpu_mem_usage=True, + torch_dtype=torch.float32, + cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface') + ) + # Create pipeline with appropriate device configuration + try: + qa_pipeline = pipeline( + task=ner_model_type, + model=model, + tokenizer=tokenizer, + device_map="auto", + torch_dtype=torch.float32 + ) + except ValueError as e: + if "device_map='auto'" in str(e): + qa_pipeline = pipeline( + task=ner_model_type, + model=model, + tokenizer=tokenizer, + device=-1, # Use CPU + torch_dtype=torch.float32 + ) + else: + raise + # Cache the pipeline + get_ner_pipeline.cache[key] = qa_pipeline + return qa_pipeline + except Exception as e: + raise + +def get_summarizer_pipeline(summarizer_model_type, summarizer_model_name): + if not hasattr(get_summarizer_pipeline, "cache"): + get_summarizer_pipeline.cache = {} + key = (summarizer_model_type, summarizer_model_name) + if key not in get_summarizer_pipeline.cache: + import torch + from transformers import pipeline + # Use float16 only if CUDA is available, else use float32 + if torch.cuda.is_available(): + dtype = torch.float16 + device = 0 + device_map = "auto" + else: + dtype = torch.float32 + device = -1 + device_map = None + get_summarizer_pipeline.cache[key] = pipeline( + task=summarizer_model_type, + model=summarizer_model_name, + trust_remote_code=True, + device=device + ) + return get_summarizer_pipeline.cache[key] + +def register_routes(app, agents): + from ai_med_extract.utils.openvino_summarizer_utils import ( + parse_ehr_chartsummarydtl, visits_sorted, compute_deltas, build_compact_baseline, delta_to_text, build_main_prompt + ) + # ============ HF SPACES: PRE-LOAD GGUF MODEL AT STARTUP ============ + print("[HF SPACES] ⏳ Pre-loading GGUF model to prevent startup timeout...") + try: + from ai_med_extract.utils.model_loader_gguf import GGUFModelPipeline + model_name = "microsoft/Phi-3-mini-4k-instruct-gguf" + filename = "Phi-3-mini-4k-instruct-q4.gguf" + # This will download and load the model NOW, during app startup + pipeline = GGUFModelPipeline(model_name, filename) + print("[HF SPACES] ✅ GGUF model pre-loaded successfully!") + except Exception as e: + print(f"[HF SPACES] ❌ Pre-load failed (fallback will handle it): {e}") + # ============ END PRE-LOAD ============ + + @app.route('/api/patient_summary_openvino', methods=['POST']) + def patient_summary_openvino(): + """ + Generate a patient summary using OpenVINO-style prompt, delta, and validation logic. + Accepts EHR API response JSON (or just chartsummarydtl) and returns summary. + Generates fresh summary every time without state tracking. + """ + try: + data = request.get_json() + ehr_result = data.get("result") or data + chartsummarydtl = ehr_result.get("chartsummarydtl") if isinstance(ehr_result, dict) else None + if not chartsummarydtl: + return jsonify({"error": "Missing chartsummarydtl in input"}), 400 + # Normalize visits + # visits = parse_ehr_chartsummarydtl(chartsummarydtl) + visits = chartsummarydtl + # Extract patient demographics if available + patient_info = "" + if isinstance(ehr_result, dict): + patient_name = ehr_result.get('patientname', 'Unknown') + patient_id = ehr_result.get('patientnumber', 'Unknown') + age = ehr_result.get('agey', 'Unknown') + gender = ehr_result.get('gender', 'Unknown') + past_medical_history = ', '.join(ehr_result.get('past_medical_history', [])) + social_history = ehr_result.get('social_history', 'Not specified') + patient_info = f"Patient: {patient_name} (ID: {patient_id}, Age: {age}, Gender: {gender})\nPast Medical History: {past_medical_history}\nSocial History: {social_history}\n" + # Generate summary from current data only (no state tracking) + # Use empty old visits to compute deltas against baseline + delta = compute_deltas([], visits) + all_visits = visits_sorted(visits) + baseline = build_compact_baseline(all_visits) + delta_text = delta_to_text(delta) + prompt = build_main_prompt(baseline, delta_text, patient_info) + # Model selection logic (model_name, model_type) + model_name = data.get("model_name") or "microsoft/Phi-3-mini-4k-instruct" + model_type = data.get("model_type") or "text-generation" + # Use existing model loader abstraction + if model_type == "text-generation": + loader = agents.get("medical_data_extractor") + else: + loader = agents.get("patient_summarizer") + pipeline = loader.model_loader.load() if hasattr(loader, "model_loader") else None + if not pipeline: + return jsonify({"error": "Model pipeline not available"}), 500 + # Run inference + import torch + torch.set_num_threads(2) + is_hf_space = os.environ.get('SPACE_ID') is not None + max_new_tokens = int(os.environ.get('MAX_NEW_TOKENS', '512' if is_hf_space else '1024')) + inputs = pipeline.tokenizer([prompt], return_tensors="pt") + pad_token_id = pipeline.tokenizer.eos_token_id or pipeline.tokenizer.pad_token_id or 1 + outputs = pipeline.model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, pad_token_id=pad_token_id) + text = pipeline.tokenizer.decode(outputs[0], skip_special_tokens=True) + # Extract just the markdown summary (remove prompt text) + # The model should return the complete markdown-formatted summary + summary_start_patterns = [ + "Now generate the complete clinical summary with all four sections in markdown format:", + "## Clinical Assessment", + "# Clinical Assessment", + "Clinical Assessment" + ] + new_summary = text + for pattern in summary_start_patterns: + if pattern in text: + new_summary = text.split(pattern)[-1].strip() + break + markdown_summary = ensure_structured_sections(new_summary, baseline, delta_text) + return jsonify({ + "summary": markdown_summary, + "baseline": baseline, + "delta": delta_text + }), 200 + except Exception as e: + return jsonify({"error": f"Failed to generate summary: {str(e)}"}), 500 + + # Configure upload directory based on environment + import os + if os.environ.get('SPACE_ID'): # We're running on Hugging Face Spaces + app.config['UPLOAD_FOLDER'] = '/data/uploads' + else: # We're running locally + upload_dir = os.path.join(os.getcwd(), 'uploads') + os.makedirs(upload_dir, exist_ok=True) + app.config['UPLOAD_FOLDER'] = upload_dir + # Ensure the upload directory exists and is writable + if not os.path.exists(app.config['UPLOAD_FOLDER']): + try: + os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True) + except Exception as e: + # Fallback to /tmp for local testing + app.config['UPLOAD_FOLDER'] = '/tmp' + + TextExtractorAgent = agents["text_extractor"] + PHIScrubberAgent = agents["phi_scrubber"] + Summarizer_Agent = agents["summarizer"] + MedicalDataExtractorAgent = agents["medical_data_extractor"] + whisper_model = agents["whisper_model"] # No longer needs to be called as a function + + @app.route("/upload", methods=["POST"]) + def upload_file(): + import torch + torch.cuda.empty_cache() # Clear GPU memory before processing + files = request.files.getlist("file") + patient_name = request.form.get("patient_name", "").strip() + password = request.form.get("password") + # Use more compatible models by default + qa_model_name = request.form.get("qa_model_name", "facebook/bart-base") + qa_model_type = request.form.get("qa_model_type", "text-generation") + ner_model_name = request.form.get("ner_model_name", "dslim/bert-base-NER") + ner_model_type = request.form.get("ner_model_type", "ner") + summarizer_model_name = request.form.get("summarizer_model_name", "facebook/bart-base") + summarizer_model_type = request.form.get("summarizer_model_type", "summarization") + if not files: + return jsonify({"error": "No file uploaded"}), 400 + # Accept any model type and model name for QA, NER, and summarizer + if not qa_model_name or not qa_model_type: + return jsonify({"error": "QA model name and type are required"}), 400 + try: + qa_pipeline = get_qa_pipeline(qa_model_type, qa_model_name) + except Exception as e: + return jsonify({"error": f"QA model load failed: {str(e)}"}), 500 + if not ner_model_name or not ner_model_type: + return jsonify({"error": "NER model name and type are required"}), 400 + try: + ner_pipeline = get_ner_pipeline(ner_model_type, ner_model_name) + except Exception as e: + return jsonify({"error": f"NER model load failed: {str(e)}"}), 500 + if not summarizer_model_name or not summarizer_model_type: + return jsonify({"error": "Summarizer model name and type are required"}), 400 + try: + summarizer_pipeline = get_summarizer_pipeline(summarizer_model_type, summarizer_model_name) + except Exception as e: + return jsonify({"error": f"Summarizer model load failed: {str(e)}"}), 500 + extracted_data = [] + for file in files: + if file.filename == "": + continue + if not allowed_file(file.filename): + return ( + jsonify({"error": f"Unsupported file type: {file.filename}."}), + 400, + ) + if not patient_name: + return jsonify({"error": "Patient name is missing"}), 400 + valid_size, error_message = check_file_size(file) + if not valid_size: + return jsonify({"error": error_message}), 400 + filename = file.filename + filepath = os.path.join(current_app.config["UPLOAD_FOLDER"], filename) + try: + file.save(filepath) + except Exception as e: + return jsonify({"error": f"Filed to save file: {str(e)}"}), 500 + ext = filename.rsplit(".", 1)[-1].lower() + try: + extracted_text = TextExtractorAgent.extract_text(filepath, ext) + if not extracted_text or extracted_text == "No text found": + os.remove(filepath) # Clean up on failure + return ( + jsonify({"error": f"Failed to extract text from {filename}"}), + 415, + ) + except Exception as e: + os.remove(filepath) # Clean up on failure + return jsonify({"error": f"Text extraction failed: {str(e)}"}), 500 + skip_medical_check = ( + request.form.get("skip_medical_check", "false").lower() == "true" + ) + if not skip_medical_check: + try: + if not isinstance(extracted_text, str) or not extracted_text.strip(): + return ( + jsonify({"error": f"No valid text extracted from {filename} for NER processing."}), + 415, + ) + ner_results = ner_pipeline(extracted_text) + medical_entities = list( + set( + [ + r["word"] + for r in ner_results + if r["entity"].startswith("B-") + or r["entity"].startswith("I-") + ] + ) + ) + if not medical_entities: + return ( + jsonify({"error": f"'{filename}' is not medically relevant"}), + 406, + ) + except Exception as e: + return jsonify({"error": f"NER processing failed: {str(e)}"}), 500 + skip_patient_check = ( + request.form.get("skip_patient_check", "false").lower() == "true" + ) + if not skip_patient_check: + try: + error_response = validate_patient_name( + extracted_text, patient_name, filename, qa_pipeline + ) + if error_response: + return error_response + except Exception as e: + return ( + jsonify({"error": f"Patient name validation failed: {str(e)}"}), + 500, + ) + try: + summary_result = summarizer_pipeline( + extracted_text, max_length=350, min_length=50, do_sample=False + ) + # Accept any output format (dict or string) + if ( + isinstance(summary_result, list) + and summary_result + and "summary_text" in summary_result[0] + ): + summary = summary_result[0]["summary_text"] + elif isinstance(summary_result, str): + summary = summary_result + else: + summary = str(summary_result) + except Exception as e: + summary = f"Summary generation failed: {str(e)}" + extracted_data.append( + { + "file": filename, + "extracted_text": extracted_text, + "summary": summary, + "message": "Successful", + } + ) + if not extracted_data: + return jsonify({"error": "No valid medical files processed"}), 400 + return jsonify({"extracted_data": extracted_data}), 200 + + @app.route("/get_updated_medical_data", methods=["GET"]) + def get_updated_data(): + file_name = request.args.get("file") + if not file_name: + return jsonify({"error": "File name is required"}), 400 + file_name = file_name.rsplit(".", 1)[0] + updated_data = get_data_from_storage(file_name) + if updated_data: + return jsonify({"file": file_name, "data": updated_data}), 200 + else: + return jsonify({"error": f"File '{file_name}' not found"}), 404 + + @app.route("/update_medical_data", methods=["PUT"]) + def update_medical_data(): + try: + data = request.json + filename = data.get("file") + filename = filename.rsplit(".", 1)[0] + updates = data.get("updates", []) + if not filename or not updates: + return jsonify({"error": "File name or updates missing"}), 400 + existing_data = get_data_from_storage(filename) + if not existing_data: + return jsonify({"error": f"File '{filename}' not found"}), 404 + for update in updates: + category = update.get("category") + field = update.get("field") + new_value = update.get("value") + updated = False + for cat in existing_data.get("extracted_data", []): + for categorized_data in cat.get("categorized_data", []): + if categorized_data.get("name") == category: + for fld in categorized_data.get("fields", []): + if fld.get("label") == field: + fld["value"] = new_value + updated = True + break + if updated: + break + if updated: + break + save_data_to_storage(filename, existing_data) + return ( + jsonify( + { + "message": "Data updated successfully", + "updated_data": existing_data, + } + ), + 200, + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + @app.route("/transcribe", methods=["POST"]) + def transcribe_audio(): + temp_path = None + try: + if "file" not in request.files: + return jsonify({"error": "No file part"}), 400 + file = request.files["file"] + if file.filename == "": + return jsonify({"error": "No selected file"}), 400 + # Use secure filename + from werkzeug.utils import secure_filename + import uuid + temp_filename = f"{uuid.uuid4()}_{secure_filename(file.filename)}" + temp_path = os.path.join(app.config['UPLOAD_FOLDER'], temp_filename) + file.save(temp_path) + result = whisper_model.transcribe(temp_path) + os.remove(temp_path) + return jsonify({"transcription": result["text"]}), 200 + except Exception as e: + if temp_path and os.path.exists(temp_path): + os.remove(temp_path) + return jsonify({"error": str(e)}), 500 + + def group_by_category(data): + grouped = defaultdict(list) + for item in data: + cat = item.get("category", "General") + grouped[cat].append( + { + "question": item.get("question", "Not Created"), + "label": item.get("label", "Unknown"), + "answer": item.get("answer", "Not Available"), + } + ) + return [{"category": k, "detail": v} for k, v in grouped.items()] + + def deduplicate_extractions(data): + seen = set() + reversed_unique = [] + # Loop in reverse to keep the *last* occurrence + for item in reversed(data): + key = (item.get("label")) + if key not in seen: + seen.add(key) + reversed_unique.append(item) + # Reverse back to preserve original order (latest kept, first dropped) + return list(reversed(reversed_unique)) + + def chunk_text(text, tokenizer, max_tokens=256, overlap=100): + # Tokenize with memory optimizations + input_ids = tokenizer.encode( + text, + add_special_tokens=False + ) + chunks = [] + start = 0 + while start < len(input_ids): + end = min(start + max_tokens, len(input_ids)) + chunk_ids = input_ids[start:end] + chunk_text = tokenizer.decode( + chunk_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=True + ) + # Ensure partial continuation isn't cut off mid-sentence + if not chunk_text.endswith(('.', '?', '!', ':')): + chunk_text += "..." + chunks.append(chunk_text) + start += max_tokens - overlap + return chunks + + def extract_json_objects(text): + extracted = [] + try: + json_start = text.index('[') + json_text = text[json_start:] + except ValueError: + # '[' not found in output + return [] + # Try parsing full array first + try: + parsed = json.loads(json_text) + if isinstance(parsed, list): + return parsed + except Exception: + pass # fallback to manual parsing + # Manual recovery via brace matching + stack = 0 + obj_start = None + for i, char in enumerate(json_text): + if char == '{': + if stack == 0: + obj_start = i + stack += 1 + elif char == '}': + stack -= 1 + if stack == 0 and obj_start is not None: + obj_str = json_text[obj_start:i+1] + try: + obj = json.loads(obj_str) + extracted.append(obj) + except Exception as e: + print(f"❌ Invalid JSON object: {e}") + obj_start = None + return extracted + + def process_chunk(generator, chunk, idx): + prompt = f""" + [INST] <> + You are a clinical data extraction assistant. + Your job is to: + 1. Read the following medical report. + 2. Extract all medically relevant facts as a list of JSON objects. + 3. Each object must include: + - "label": a short field name (e.g., "blood pressure", "diagnosis") + - "question": a question related to that field + - "answer": the answer from the text + 4. After extracting the list, categorize each object under one of the following fixed categories: + - Patient Info + - Vitals + - Symptoms + - Allergies + - Habits + - Comorbidities + - Diagnosis + - Medication + - Laboratory + - Radiology + - Doctor Note + Example format for structure only — do not include in output: + [ + {{ + "label": "patient name", + "question": "What is the patient's name?", + "answer": "Marry John", + "category": "Patient Info" + }}, + ] + ⚠️ Use these categories listed above. If an item does not fit any of these categories, create a new category for it. + Text: + {chunk} + Return a single valid JSON array of all extracted objects. + Do not include any explanations or commentary. + Only output the JSON array + <> [/INST] + """ + try: + # Clear GPU memory before processing + torch.cuda.empty_cache() + # Process with memory optimizations + output = generator( + prompt, + max_new_tokens=1024, # Reduced from 1024 for memory efficiency + do_sample=False, # Disable sampling for deterministic output + temperature=0.3, # Lower temperature for more focused output + )[0]["generated_text"] + return idx, output + except Exception as e: + return idx, None + + @app.route("/extract_medical_data", methods=["POST"]) + def extract_medical_data(): + data = request.json + qa_model_name = data.get("qa_model_name") + qa_model_type = data.get("qa_model_type") + extracted_files = data.get("extracted_data") + if not qa_model_name or not qa_model_type: + return jsonify({"error": "Missing 'qa_model_name' or 'qa_model_type'"}), 400 + if not extracted_files: + return jsonify({"error": "Missing 'extracted_data' in request"}), 400 + try: + tokenizer = AutoTokenizer.from_pretrained( + qa_model_name, + trust_remote_code=True, + cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface') + ) + model = AutoModelForCausalLM.from_pretrained( + qa_model_name, + device_map="auto", + torch_dtype=torch.float32, + trust_remote_code=True, + low_cpu_mem_usage=True, + cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface') + ) + generator = transformers_pipeline( + task=qa_model_type, + model=model, + tokenizer=tokenizer, + torch_dtype=torch.float32 + ) + except Exception as e: + return jsonify({"error": f"Could not load model: {str(e)}"}), 500 + structured_response = {"extracted_data": []} + for file_data in extracted_files: + filename = file_data.get("file", "unknown_file") + context = file_data.get("extracted_text", "").strip() + if not context: + structured_response["extracted_data"].append( + {"file": filename, "medical_fields": []} + ) + continue + chunks = chunk_text(context, tokenizer) + all_extracted = [] + with ThreadPoolExecutor(max_workers=4) as executor: + futures = { + executor.submit(process_chunk, generator, chunk, idx): idx + for idx, chunk in enumerate(chunks) + } + for future in as_completed(futures): + idx = futures[future] + _, output = future.result() + if not output: + continue + try: + objs = extract_json_objects(output) + if objs: + all_extracted.extend(objs) + else: + print(f"⚠️ Chunk {idx+1} yielded no valid JSON.") + except Exception as e: + print(f"❌ Error extracting JSON from chunk {idx+1}") + # Clean and group results for this file + if all_extracted: + deduped = deduplicate_extractions(all_extracted) + # cleaned_json = clean_result() + grouped_data = group_by_category(deduped) + else: + grouped_data = {"error": "No valid data extracted"} + structured_response["extracted_data"].append( + {"file": filename, "medical_fields": grouped_data} + ) + try: + save_data_to_storage(filename, grouped_data) + except Exception as e: + print(f"⚠️ Failed to save data for {filename}: {e}") + print("✅ Extraction complete.") + return jsonify(structured_response) + + @app.route("/api/generate_summary", methods=["POST"]) + def generate_summary(): + logger.info("Received request to generate summary.") + data = request.json + if not data or "text" not in data or not data["text"].strip(): + return jsonify({"error": "No valid text provided"}), 400 + context = data["text"] + logger.info(f"Clean text length: {len(context)} characters.") + try: + clean_text = PHIScrubberAgent.scrub_phi(context) + except Exception: + clean_text = context + try: + summary = SummarizerAgent.generate_summary(Summarizer_Agent, clean_text) + logger.info("Summary generated successfully.") + return jsonify({"summary": summary}), 200 + except Exception as e: + logger.error(f"Summary generation failed: {str(e)}") + return jsonify({"error": f"Summary generation failed: {str(e)}"}), 500 + + @app.route("/api/extract_medical_data_from_audio", methods=["POST"]) + def extract_medical_data_from_audio(): + temp_path = None + try: + import torch + # Clear GPU memory and set default tensor type + torch.cuda.empty_cache() + torch.set_default_tensor_type(torch.FloatTensor) + if torch.cuda.is_available(): + torch.set_default_tensor_type(torch.cuda.FloatTensor) + # Handle multipart form data from Flutter + if "audio" not in request.files: + return jsonify({"error": "No audio file provided"}), 400 + audio_file = request.files["audio"] + if audio_file.filename == "": + return jsonify({"error": "No selected audio file"}), 400 + # Validate file extension + if not allowed_file(audio_file.filename): + return jsonify({"error": f"Unsupported audio format. Allowed formats: wav, mp3, m4a, ogg"}), 400 + # Check file size + valid_size, error_message = check_file_size(audio_file) + if not valid_size: + return jsonify({"error": error_message}), 400 + # Use default model if not specified + qa_model_name = request.form.get("qa_model_name", "facebook/bart-base") + qa_model_type = request.form.get("qa_model_type", "text-generation") + # Load QA model with proper error handling + try: + qa_pipeline = get_qa_pipeline(qa_model_type, qa_model_name) + except Exception as e: + return jsonify({"error": f"QA model load failed: {str(e)}"}), 500 + # Use platform-agnostic temp directory + import uuid + from werkzeug.utils import secure_filename + import tempfile + temp_dir = os.path.join(tempfile.gettempdir(), 'audio_uploads') + os.makedirs(temp_dir, exist_ok=True) + temp_filename = f"{uuid.uuid4()}_{secure_filename(audio_file.filename)}" + temp_path = os.path.join(temp_dir, temp_filename) + try: + audio_file.save(temp_path) + # Transcribe audio with retries + max_retries = 3 + transcribed_text = None + for attempt in range(max_retries): + try: + # Ensure Whisper model is using the correct device and dtype + with torch.cuda.device(0) if torch.cuda.is_available() else torch.no_grad(): + transcribed_text = whisper_model.transcribe(temp_path)["text"] + if not transcribed_text: + raise ValueError("No text output from transcription") + break + except Exception as e: + if attempt == max_retries - 1: + raise + torch.cuda.empty_cache() # Clear GPU memory between attempts + continue + if not transcribed_text: + raise ValueError("Failed to transcribe audio after multiple attempts") + # Clean and process text + try: + clean_text = PHIScrubberAgent.scrub_phi(transcribed_text) + except Exception as e: + clean_text = transcribed_text + # Extract medical data with proper device handling + try: + with torch.cuda.device(0) if torch.cuda.is_available() else torch.no_grad(): + # Create a new instance of MedicalDataExtractorAgent with the pipeline + medical_data_extractor = MedicalDataExtractorAgent(qa_pipeline) + medical_data = medical_data_extractor.extract_medical_data(clean_text) + except Exception as e: + medical_data = {"error": f"Medical data extraction failed: {str(e)}"} + # Clean up temporary file + if os.path.exists(temp_path): + os.remove(temp_path) + # Return response in the format expected by Flutter + return jsonify({ + "status": "success", + "data": { + "transcribed_text": clean_text, + "medical_chart": medical_data + } + }), 200 + except Exception as e: + if temp_path and os.path.exists(temp_path): + os.remove(temp_path) + return jsonify({ + "status": "error", + "error": f"Processing failed: {str(e)}" + }), 500 + except Exception as e: + if temp_path and os.path.exists(temp_path): + os.remove(temp_path) + return jsonify({ + "status": "error", + "error": f"Request handling failed: {str(e)}" + }), 500 + + # Global jobs storage for background tasks (thread-safe) + jobs = {} + job_lock = threading.Lock() + + def update_job(job_id, status, progress=None, data=None, error=None): + with job_lock: + if job_id not in jobs: + jobs[job_id] = {} + jobs[job_id]['status'] = status + if progress is not None: + jobs[job_id]['progress'] = progress + if data is not None: + jobs[job_id]['data'] = data + if error is not None: + jobs[job_id]['error'] = error + + def background_patient_summary(data, job_id, timeout_mode, EHR_TIMEOUT, GEN_TIMEOUT): + """ + Background task for patient summary generation with progress updates. + """ + try: + update_job(job_id, 'processing', progress=10, data={'message': 'Starting patient summary generation'}) + start_total = time.time() + is_hf_space = os.environ.get('SPACE_ID') is not None + max_new_tokens = int(os.environ.get('MAX_NEW_TOKENS', '512' if is_hf_space else '1024')) + patientid = data.get("patientid") + token = data.get("token") + key = data.get("key") + model_name = data.get("patient_summarizer_model_name") or "microsoft/Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-q4.gguf" + model_type = data.get("patient_summarizer_model_type") or data.get("model_type") or "gguf" + # Normalize model type aliases for broader compatibility + if model_type in {"seq2seq", "text2text-generation"}: + model_type = "summarization" + if model_type in {"causal-openvino", "openvino", "causal_ov", "ov"}: + model_type = "openvino" + + update_job(job_id, 'processing', progress=20, data={'message': 'Fetching EHR data...'}) + api_url = f"{key.strip()}/Transactionapi/api/PatientList/patientsummary" + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + } + if key and not key.startswith("http"): + headers["x-api-key"] = key + + t_api_start = time.time() + response = requests.post(api_url, json={"patientid": patientid}, headers=headers, timeout=EHR_TIMEOUT) + opts = PruneOptions( + remove_nulls=True, + remove_empty_strings=True, + remove_empty_collections=True, + trim_strings=False, + compact_lists=True, + preserve_paths=set(), + output_minified_string=False, + ) + tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", trust_remote_code=True, cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')) + before_tokens = len(tokenizer.encode(response.text)) + logger.info(f"Token count before slim_api_json: {before_tokens}") + api_data = slim_api_json(response.text, options=opts) + after_tokens = len(tokenizer.encode(json.dumps(api_data))) + logger.info(f"Token count after slim_api_json: {after_tokens}") + t_api_end = time.time() + + update_job(job_id, 'processing', progress=40, data={'message': 'EHR data fetched successfully'}) + + if response.status_code != 200: + logger.warning(f"EHR API non-200 status: {response.status_code}") + minimal_fallback = f""" +## Clinical Assessment +- EHR API returned error {response.status_code}. + +## Key Trends & Changes +- No patient data available. + +## Plan & Suggested Actions +- Verify API key, token, and patient ID. + +## Direct Guidance for Physician +- System received invalid response from EHR — do not proceed. +""" + total_time = time.time() - start_total + update_job(job_id, 'completed', data={ + "summary": ensure_four_sections(minimal_fallback), + "warning": f"EHR API error {response.status_code}", + "timing": {"total": round(total_time, 1)}, + "timeout_mode_used": timeout_mode + }) + return + + api_data = response.json() + ehr_result = api_data.get("result") or api_data + chartsummarydtl = ehr_result.get("chartsummarydtl") if isinstance(ehr_result, dict) else None + if not chartsummarydtl: + logger.warning("Missing chartsummarydtl in EHR response") + minimal_fallback = """ +## Clinical Assessment +- No chartsummarydtl found in EHR response. + +## Key Trends & Changes +- Patient data structure invalid. + +## Plan & Suggested Actions +- Verify EHR API response format. + +## Direct Guidance for Physician +- Incomplete patient data — manual review required. +""" + total_time = time.time() - start_total + update_job(job_id, 'completed', data={ + "summary": ensure_four_sections(minimal_fallback), + "warning": "Missing chartsummarydtl", + "timing": {"total": round(total_time, 1)}, + "timeout_mode_used": timeout_mode + }) + return + + update_job(job_id, 'processing', progress=50, data={'message': 'Processing patient data...'}) + + # Parse and compute deltas + visits = parse_ehr_chartsummarydtl(chartsummarydtl) + try: + delta = compute_deltas([], visits) + except Exception as e: + logger.error(f"Error computing deltas: {e}") + delta = {} + all_visits = visits_sorted(visits) + baseline = build_compact_baseline(all_visits) + delta_text = delta_to_text(delta) + + update_job(job_id, 'processing', progress=60, data={'message': f'Generating summary with {model_type} model...'}) + + # Build the unified prompt once + prompt = build_main_prompt(baseline, delta_text) + + # Model-specific generation (preserving all logic) + if model_type == "gguf": + logger.info(f"🧠 GGUF MODE: Single-prompt generation for {model_name}") + repo_id, filename = model_name, None + if '/' in model_name and model_name.endswith('.gguf'): + parts = model_name.rsplit('/', 1) + repo_id = parts[0] + filename = parts[1] + logger.info(f"📦 Using cache key: repo_id='{repo_id}', filename='{filename}'") + + cache_key = (repo_id, filename) + if cache_key in GGUF_PIPELINE_CACHE: + logger.info(f"✅ Using cached GGUF pipeline for {cache_key}") + pipeline = GGUF_PIPELINE_CACHE[cache_key] + else: + logger.info(f"🔄 Loading new GGUF pipeline for {cache_key}") + pipeline = get_cached_gguf_pipeline(repo_id, filename) + + full_prompt = f"""<|system|> +You are a clinical AI assistant. Generate a COMPLETE patient summary with EXACTLY 4 sections in markdown format. Ensure ALL sections are fully generated and detailed with bullet points. Do not skip or abbreviate any section. +do not halucinate or invent any information. Base ONLY on provided data. +DATA: +visits: {all_visits} + +REQUIRED OUTPUT FORMAT (must include all, each with at least 3-5 bullet points): +## Clinical Assessment +- Bullet points analyzing current state, diagnoses, vitals, labs, medications. + +## Key Trends & Changes +- Bullet points on trends, deltas, new developments, changes in vitals/labs over time. + +## Plan & Suggested Actions +- Bullet points with recommended next steps, monitoring, treatments, follow-ups. + +## Direct Guidance for Physician +- Bullet points with key clinical insights, warnings, considerations, potential risks. + +Use bullet points with "- ". Base ONLY on provided data. No preamble, explanations, or extra text. Start immediately with "## Clinical Assessment" and ensure all 4 sections are complete and detailed: +<|user|> +Generate the full 4-section summary based on the data. +<|assistant|>""" + + update_job(job_id, 'processing', progress=70, data={'message': 'Generating summary...'}) + raw_summary = pipeline.generate( + full_prompt, + max_tokens=1500, + temperature=0.1, + top_p=0.5, + ) + logger.info(f"GGUF raw summary length: {len(raw_summary)} chars") + + def extract_markdown_sections(text): + sections = [ + "## Clinical Assessment", + "## Key Trends & Changes", + "## Plan & Suggested Actions", + "## Direct Guidance for Physician" + ] + output_lines = [] + current_section = None + + for line in text.splitlines(): + stripped = line.strip() + for section in sections: + if stripped.startswith(section): + current_section = section + output_lines.append(section) + break + else: + if current_section and stripped: + output_lines.append(line) + + return "\n".join(output_lines) + + markdown_summary = extract_markdown_sections(raw_summary) + markdown_summary = ensure_structured_sections(markdown_summary, baseline, delta_text) + + total_time = time.time() - start_total + logger.info(f"[✅ SUCCESS] GGUF | TIMEOUT_MODE: {timeout_mode} | TOTAL: {total_time:.1f}s") + + update_performance_metrics(total_time - (t_api_end - t_api_start), success=True, cache_hit=(cache_key in GGUF_PIPELINE_CACHE)) + cleanup_memory() + + update_job(job_id, 'completed', progress=100, data={ + "summary": markdown_summary, + "baseline": baseline, + "delta": delta_text, + "timing": { + "ehr_api": round(t_api_end - t_api_start, 1), + "generation": round(total_time - (t_api_end - t_api_start), 1), + "total": round(total_time, 1) + }, + "model_used": f"{model_name} ({model_type})", + "timeout_mode_used": timeout_mode + }) + + elif model_type in {"text-generation", "openvino"}: + # Unified handling for Transformers text-generation and OpenVINO causal models + logger.info(f"🔤 {model_type.upper()} MODE: {model_name}") + update_job(job_id, 'processing', progress=70, data={'message': f'Generating summary with {model_type} model...'}) + from ai_med_extract.utils.model_manager import model_manager + loader = model_manager.get_model_loader(model_name, model_type) + + prompt = build_main_prompt(baseline, delta_text) + try: + text = loader.generate( + prompt, + max_new_tokens=800, + do_sample=False + ) + except Exception as e: + raise + + summary_start_patterns = [ + "Now generate the complete clinical summary", + "## Clinical Assessment", + "# Clinical Assessment", + "Clinical Assessment" + ] + new_summary = text + for pattern in summary_start_patterns: + if pattern in text: + new_summary = text.split(pattern)[-1].strip() + break + + markdown_summary = summary_to_markdown(new_summary) + markdown_summary = ensure_structured_sections(markdown_summary, baseline, delta_text) + + total_time = time.time() - start_total + logger.info(f"[✅ SUCCESS] {model_type} | TIMEOUT_MODE: {timeout_mode} | TOTAL: {total_time:.1f}s") + + update_job(job_id, 'completed', progress=100, data={ + "summary": markdown_summary, + "baseline": baseline, + "delta": delta_text, + "timing": {"total": round(total_time, 1)}, + "model_used": f"{model_name} ({model_type})", + "timeout_mode_used": timeout_mode + }) + + elif model_type == "summarization": + # Handle summarization, including special-case encoder-decoder checkpoints + logger.info(f"📝 SUMMARIZATION MODE: {model_name}") + update_job(job_id, 'processing', progress=70, data={'message': 'Generating summary with summarization model...'}) + + context = f"Patient Data:\nBaseline: {baseline}\nChanges: {delta_text}" + raw_summary = None + + try: + # Detect encoder-decoder architecture upfront + from transformers import AutoConfig + cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True, cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')) + is_encoder_decoder = (getattr(cfg, 'model_type', None) == 'encoder_decoder') or getattr(cfg, 'is_encoder_decoder', False) + except Exception: + is_encoder_decoder = False + + try: + if is_encoder_decoder: + # Explicitly load via EncoderDecoderModel to support generic encoder_decoder checkpoints + tokenizer = AutoTokenizer.from_pretrained( + model_name, + trust_remote_code=True, + cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface') + ) + try: + model = AutoModelForSeq2SeqLM.from_pretrained( + model_name, + trust_remote_code=True, + cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface') + ) + except Exception: + model = EncoderDecoderModel.from_pretrained( + model_name, + cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface') + ) + # Configure IDs if missing + if getattr(model.config, 'decoder_start_token_id', None) is None and hasattr(tokenizer, 'bos_token_id'): + model.config.decoder_start_token_id = tokenizer.bos_token_id + if getattr(model.config, 'pad_token_id', None) is None: + model.config.pad_token_id = getattr(tokenizer, 'pad_token_id', None) or getattr(tokenizer, 'eos_token_id', 1) + if getattr(model.config, 'eos_token_id', None) is None and hasattr(tokenizer, 'eos_token_id'): + model.config.eos_token_id = tokenizer.eos_token_id + # Respect safe max length + model_max_len = getattr(tokenizer, 'model_max_length', 2048) + if model_max_len is None or model_max_len > 8192: + model_max_len = 2048 + inputs = tokenizer( + context, + return_tensors="pt", + truncation=True, + max_length=int(model_max_len) + ) + pad_token_id = getattr(tokenizer, 'eos_token_id', None) or getattr(tokenizer, 'pad_token_id', None) or 1 + outputs = model.generate( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=False, + pad_token_id=pad_token_id + ) + raw_summary = tokenizer.decode(outputs[0], skip_special_tokens=True) + else: + # Use unified model manager for standard summarization models + from ai_med_extract.utils.model_manager import model_manager + loader = model_manager.get_model_loader(model_name, "summarization") + raw_summary = loader.generate( + context, + max_length=400, + min_length=100, + do_sample=False + ) + except Exception as e: + logger.error(f"Summarization generation failed: {e}") + raise + + markdown_summary = ensure_structured_sections(str(raw_summary), baseline, delta_text) + + total_time = time.time() - start_total + logger.info(f"[✅ SUCCESS] Summarization | TIMEOUT_MODE: {timeout_mode} | TOTAL: {total_time:.1f}s") + + update_job(job_id, 'completed', progress=100, data={ + "summary": markdown_summary, + "baseline": baseline, + "delta": delta_text, + "timing": {"total": round(total_time, 1)}, + "model_used": f"{model_name} ({model_type})", + "timeout_mode_used": timeout_mode + }) + + elif model_type in {"seq2seq", "text2text"}: + logger.info(f"🔁 SEQ2SEQ MODE: {model_name}") + try: + tokenizer = AutoTokenizer.from_pretrained( + model_name, + trust_remote_code=True, + cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface') + ) + # Try standard Seq2Seq loading first + try: + model = AutoModelForSeq2SeqLM.from_pretrained( + model_name, + trust_remote_code=True, + cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface') + ) + except Exception as e_seq2seq: + logger.warning(f"AutoModelForSeq2SeqLM load failed: {e_seq2seq}; trying EncoderDecoderModel...") + # Fallback to EncoderDecoderModel (for encoder_decoder checkpoints) + model = EncoderDecoderModel.from_pretrained( + model_name, + cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface') + ) + # Set sensible defaults if missing + if model.config.decoder_start_token_id is None and hasattr(tokenizer, 'bos_token_id'): + model.config.decoder_start_token_id = tokenizer.bos_token_id + if model.config.pad_token_id is None and hasattr(tokenizer, 'pad_token_id'): + model.config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id or 1 + if model.config.eos_token_id is None and hasattr(tokenizer, 'eos_token_id'): + model.config.eos_token_id = tokenizer.eos_token_id + # Safe max length for encoding to prevent indexing errors + model_max_len = getattr(tokenizer, 'model_max_length', 2048) + if model_max_len is None or model_max_len > 8192: + model_max_len = 2048 + inputs = tokenizer( + prompt, + return_tensors="pt", + truncation=True, + max_length=int(model_max_len) + ) + pad_token_id = tokenizer.eos_token_id or tokenizer.pad_token_id or 1 + outputs = model.generate( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=False, + pad_token_id=pad_token_id + ) + text = tokenizer.decode(outputs[0], skip_special_tokens=True) + markdown_summary = ensure_structured_sections(text, baseline, delta_text) + + total_time = time.time() - start_total + logger.info(f"[✅ SUCCESS] Seq2Seq | TIMEOUT_MODE: {timeout_mode} | TOTAL: {total_time:.1f}s") + + update_job(job_id, 'completed', progress=100, data={ + "summary": markdown_summary, + "baseline": baseline, + "delta": delta_text, + "timing": {"total": round(total_time, 1)}, + "model_used": f"{model_name} ({model_type})", + "timeout_mode_used": timeout_mode + }) + except Exception as e: + logger.error(f"Seq2Seq generation failed: {e}") + generic_fallback = f""" +## Clinical Assessment +- Seq2Seq model failed: {str(e)} + +## Key Trends & Changes +- See computed deltas. + +## Plan & Suggested Actions +- Try a different model or reduce input size. + +## Direct Guidance for Physician +- System fallback used — verify details clinically. +""" + total_time = time.time() - start_total + update_job(job_id, 'completed', progress=100, data={ + "summary": ensure_structured_sections(generic_fallback, baseline, delta_text), + "baseline": baseline, + "delta": delta_text, + "timing": {"total": round(total_time, 1)}, + "model_used": f"{model_name} ({model_type})", + "timeout_mode_used": timeout_mode + }) + + else: + logger.warning(f"Unsupported model_type: {model_type}") + generic_fallback = f""" +## Clinical Assessment +- Unsupported model type: {model_type} + +## Key Trends & Changes +- Please use model_type: gguf, text-generation, or summarization + +## Plan & Suggested Actions +- Update API request with supported model type. + +## Direct Guidance for Physician +- System configuration error — contact administrator. +""" + total_time = time.time() - start_total + update_job(job_id, 'completed', progress=100, data={ + "summary": ensure_four_sections(generic_fallback), + "baseline": baseline, + "delta": delta_text, + "warning": f"Unsupported model_type: {model_type}", + "supported_types": ["gguf", "text-generation", "openvino", "summarization", "seq2seq"], + "timing": {"total": round(total_time, 1)}, + "timeout_mode_used": timeout_mode + }) + + except requests.exceptions.Timeout: + logger.warning(f"EHR API timeout ({EHR_TIMEOUT}s)") + minimal_fallback = f""" +## Clinical Assessment +- EHR API timeout ({EHR_TIMEOUT}s) — could not fetch patient data. + +## Key Trends & Changes +- No data available due to API timeout. + +## Plan & Suggested Actions +- Retry with "timeout_mode": "extended" or check EHR API performance. + +## Direct Guidance for Physician +- Patient data unavailable — do not proceed without verification. +""" + total_time = time.time() - start_total + update_job(job_id, 'completed', data={ + "summary": ensure_four_sections(minimal_fallback), + "warning": f"EHR API timeout ({EHR_TIMEOUT}s) — used minimal fallback.", + "timing": {"total": round(total_time, 1)}, + "timeout_mode_used": timeout_mode + }) + except requests.exceptions.RequestException as e: + logger.error(f"Network error contacting EHR API: {e}") + update_job(job_id, 'error', error=str(e)) + except ValueError as e: + # Distinguish true EHR JSON errors from model-related ValueErrors + msg = str(e) + if "JSON" in msg or "Expecting value" in msg: + logger.error(f"Invalid JSON from EHR API: {e}") + minimal_fallback = """ +## Clinical Assessment +- EHR API returned invalid JSON. + +## Key Trends & Changes +- Unable to parse patient data. + +## Plan & Suggested Actions +- Contact EHR API administrator. + +## Direct Guidance for Physician +- Patient data corrupted — do not proceed. +""" + total_time = time.time() - start_total + update_job(job_id, 'completed', data={ + "summary": ensure_four_sections(minimal_fallback), + "warning": "Invalid JSON from EHR API", + "timing": {"total": round(total_time, 1)}, + "timeout_mode_used": timeout_mode + }) + else: + # Re-raise non-JSON-related ValueErrors to be handled by the generic exception block + raise + except Exception as e: + logger.error(f"🚨 CRITICAL ERROR: {str(e)}", exc_info=True) + emergency_fallback = """ +## Clinical Assessment +- System emergency fallback — critical error occurred. + +## Key Trends & Changes +- No data available due to system error. + +## Plan & Suggested Actions +- Retry request or contact system administrator. + +## Direct Guidance for Physician +- DO NOT rely on this summary — system malfunction. +""" + total_time = time.time() - start_total + update_job(job_id, 'error', data={ + "summary": ensure_four_sections(emergency_fallback), + "warning": "Critical system error — used emergency fallback.", + "error": str(e), + "timing": {"total": round(total_time, 1)}, + "timeout_mode_used": data.get("timeout_mode", "fast") + }) + + def sse_generator(job_id): + """ + Generator for SSE events, polling job status. + """ + while True: + with job_lock: + if job_id in jobs: + job = jobs[job_id] + status = job.get('status', 'unknown') + progress = job.get('progress', 0) + data = job.get('data', {}) + error = job.get('error') + partial_summary = data.get('partial_summary', '') if isinstance(data, dict) else '' + + # Yield progress update, including partial summary if available + event_data = { + 'type': 'progress', + 'status': status, + 'progress': progress, + 'data': data, + 'partial_summary': partial_summary # Include for streaming + } + yield f"data: {json.dumps(event_data)}\n\n" + + if status == 'completed' or error: + # Yield final result + if error: + yield f"data: {json.dumps({'type': 'error', 'error': error})}\n\n" + else: + yield f"data: {json.dumps({'type': 'complete', 'data': data})}\n\n" + # Clean up job after 5s + threading.Timer(5.0, lambda: jobs.pop(job_id, None)).start() + break + + # Yield heartbeat every 2s if no update + yield f"data: {json.dumps({'type': 'heartbeat', 'status': status})}\n\n" + + time.sleep(2) # Poll every 2s + + yield "data: [DONE]\n\n" + + # ==================== ASYNCHRONOUS PATIENT SUMMARY ENDPOINT WITH SSE ==================== + @app.route('/generate_patient_summary', methods=['POST']) + def generate_patient_summary(): + """ + 🚀 ASYNCHRONOUS PATIENT SUMMARY WITH SSE — HF SPACES TIMEOUT-PROOF + - Accepts ?stream=true for SSE progressive updates (default: synchronous JSON) + - Background thread for long-running work + - SSE events: progress, complete/error with full data + - Preserves all existing logic, fallbacks, and medical accuracy + """ + from ai_med_extract.utils.openvino_summarizer_utils import ( + parse_ehr_chartsummarydtl, visits_sorted, compute_deltas, build_compact_baseline, delta_to_text, build_main_prompt + ) + from ai_med_extract.utils.json_slimmer import slim_api_json, PruneOptions + + data = request.get_json() + if not data: + return jsonify({"error": "No JSON data provided"}), 400 + + stream = request.args.get('stream', 'false').lower() == 'true' + + if not stream: + # Fallback to synchronous (original behavior) + # (Insert original synchronous code here, but since task is rewrite, redirect to async) + return jsonify({"warning": "Use ?stream=true for long-running tasks to avoid timeouts. Redirecting to async."}), 303 + + # Async mode with SSE + job_id = str(uuid.uuid4()) + timeout_mode = data.get("timeout_mode", "fast") + if timeout_mode == "extended": + EHR_TIMEOUT = 30 + GEN_TIMEOUT = 500 + else: + EHR_TIMEOUT = 8 + GEN_TIMEOUT = 500 + + # Validate inputs + patientid = data.get("patientid") + token = data.get("token") + key = data.get("key") + if not patientid or not token or not key: + update_job(job_id, 'error', error="Missing required fields: patientid, token, or key") + return Response(sse_generator(job_id), mimetype='text/event-stream') + + # Start background task + update_job(job_id, 'queued', progress=0, data={'job_id': job_id, 'message': 'Job queued, starting processing...'}) + thread = threading.Thread(target=background_patient_summary, args=(data, job_id, timeout_mode, EHR_TIMEOUT, GEN_TIMEOUT)) + thread.daemon = True + thread.start() + + # Return SSE stream + return Response(sse_generator(job_id), mimetype='text/event-stream', + headers={'Cache-Control': 'no-cache', + 'Connection': 'keep-alive', + 'X-Accel-Buffering': 'no'}) + + @app.route("/api/performance_metrics", methods=["GET"]) + def get_performance_metrics(): + """Get current performance metrics for monitoring""" + return jsonify({ + "metrics": PERFORMANCE_METRICS, + "timestamp": time.time() + }), 200 + + @app.route("/") + def home(): + return "Medical Data Extraction API is running!", 200 + +def summary_to_markdown(summary): + import re + # Remove '- answer:' and similar artifacts + summary = re.sub(r'-\s*answer: ?', '', summary, flags=re.IGNORECASE) + # Convert numbered sections to markdown headers + lines = summary.splitlines() + out = [] + section_map = { + '1.': '#', + '2.': '##', + '3.': '##', + '4.': '##', + } + for line in lines: + m = re.match(r'^(\d\.)\s*(.+)', line) + if m and m.group(1) in section_map: + header = section_map[m.group(1)] + out.append(f"{header} {m.group(2).strip()}") + else: + out.append(line) + # Remove empty lines at the start + while out and not out[0].strip(): + out = out[1:] + # Check if we have the expected 4-section structure + def is_header(line: str) -> bool: + return bool(re.match(r'^(#{1,6})\s+.+', line.strip())) + # Find all headers in the output + headers = [i for i, line in enumerate(out) if is_header(line)] + # If we have at least 4 headers, check if they match the expected structure + if len(headers) >= 4: + header_texts = [out[i].strip() for i in headers[:4]] + expected_patterns = [ + r'#.*Clinical.*Assessment', + r'##.*Key.*Trends.*Changes', + r'##.*Plan.*Suggested.*Actions', + r'##.*Direct.*Guidance.*Physician' + ] + # Check if headers match expected patterns + matches_pattern = all( + re.search(pattern, header, re.IGNORECASE) + for pattern, header in zip(expected_patterns, header_texts) + ) + if matches_pattern: + # Keep the entire content - don't truncate + return '\n'.join(out).strip() + # If we don't have the expected structure, try to find the actual summary content + # Look for the start of the clinical assessment section + clinical_assessment_pattern = r'(?:# Clinical Assessment|## Clinical Assessment|Clinical Assessment)' + for i, line in enumerate(out): + if re.search(clinical_assessment_pattern, line, re.IGNORECASE): + return '\n'.join(out[i:]).strip() + # If no clinical assessment found, return the entire summary + return '\n'.join(out).strip() \ No newline at end of file