sachinchandrankallar committed on
Commit
f91c303
·
1 Parent(s): 8eb4114

patient summary working

Browse files
GGUF_TROUBLESHOOTING.md CHANGED
@@ -34,7 +34,7 @@ Your Hugging Face Space is throwing 500 errors when calling the `generatepatient
34
  - Threading-based timeout (more reliable than signals)
35
 
36
  ### 3. **Memory Optimization**
37
- - Reduced context window from 4096 to 2048 tokens
38
  - Reduced batch size from 128 to 64
39
  - CPU-only mode with optimized thread usage
40
 
 
34
  - Threading-based timeout (more reliable than signals)
35
 
36
  ### 3. **Memory Optimization**
37
+ - Reduced context window from 4096 to 4000 tokens
38
  - Reduced batch size from 128 to 64
39
  - CPU-only mode with optimized thread usage
40
 
TODO.md CHANGED
@@ -7,7 +7,7 @@
7
 
8
  ## Details
9
  - Approximate prompt tokens by word count (split on whitespace)
10
- - Calculate allowed max_tokens = 2048 - prompt_tokens
11
  - Reduce max_tokens if necessary, log warning
12
  - Raise error if prompt too long
13
  - Set n_threads to os.cpu_count() for speed
 
7
 
8
  ## Details
9
  - Approximate prompt tokens by word count (split on whitespace)
10
+ - Calculate allowed max_tokens = 4000 - prompt_tokens
11
  - Reduce max_tokens if necessary, log warning
12
  - Raise error if prompt too long
13
  - Set n_threads to os.cpu_count() for speed
ai_med_extract/__pycache__/app.cpython-311.pyc CHANGED
Binary files a/ai_med_extract/__pycache__/app.cpython-311.pyc and b/ai_med_extract/__pycache__/app.cpython-311.pyc differ
 
ai_med_extract/agents/__pycache__/patient_summary_agent.cpython-311.pyc CHANGED
Binary files a/ai_med_extract/agents/__pycache__/patient_summary_agent.cpython-311.pyc and b/ai_med_extract/agents/__pycache__/patient_summary_agent.cpython-311.pyc differ
 
ai_med_extract/agents/__pycache__/summarizer.cpython-311.pyc CHANGED
Binary files a/ai_med_extract/agents/__pycache__/summarizer.cpython-311.pyc and b/ai_med_extract/agents/__pycache__/summarizer.cpython-311.pyc differ
 
ai_med_extract/agents/patient_summary_agent.py CHANGED
@@ -16,7 +16,7 @@ class PatientSummarizerAgent:
16
  model_name: str = "falconsai/medical_summarization",
17
  model_type: str = "summarization",
18
  device: Optional[str] = None,
19
- max_input_tokens: int = 2048,
20
  max_output_tokens: int = 512
21
  ):
22
  self.model_name = model_name
 
16
  model_name: str = "falconsai/medical_summarization",
17
  model_type: str = "summarization",
18
  device: Optional[str] = None,
19
+ max_input_tokens: int = 4000,
20
  max_output_tokens: int = 512
21
  ):
22
  self.model_name = model_name
ai_med_extract/agents/summarizer.py CHANGED
@@ -14,7 +14,7 @@ class SummarizerAgent:
14
 
15
  # Base parameters
16
  min_length = max(30, min(100, int(word_count * 0.1))) # 10% of word count, min 30, max 100
17
- max_length = max(512, min(2048, int(word_count * 0.5))) # 50% of word count, min 512, max 2048
18
 
19
  # Adjust based on previous summary length to prevent degradation
20
  if self.request_count > 0 and self.last_summary_length > 0:
@@ -90,7 +90,7 @@ class SummarizerAgent:
90
  # Use GGUF's built-in method that handles large inputs and 4-section requirement
91
  summary = model.generate_full_summary(
92
  clean_text,
93
- max_tokens=2048, # Increased to handle larger inputs
94
  max_loops=2
95
  )
96
  else:
 
14
 
15
  # Base parameters
16
  min_length = max(30, min(100, int(word_count * 0.1))) # 10% of word count, min 30, max 100
17
+ max_length = max(512, min(4000, int(word_count * 0.5))) # 50% of word count, min 512, max 4000
18
 
19
  # Adjust based on previous summary length to prevent degradation
20
  if self.request_count > 0 and self.last_summary_length > 0:
 
90
  # Use GGUF's built-in method that handles large inputs and 4-section requirement
91
  summary = model.generate_full_summary(
92
  clean_text,
93
+ max_tokens=4000, # Increased to handle larger inputs
94
  max_loops=2
95
  )
96
  else:
ai_med_extract/api/__pycache__/routes.cpython-311.pyc CHANGED
Binary files a/ai_med_extract/api/__pycache__/routes.cpython-311.pyc and b/ai_med_extract/api/__pycache__/routes.cpython-311.pyc differ
 
ai_med_extract/api/routes.py CHANGED
@@ -1,18 +1,15 @@
1
  """
2
  Medical Data Extraction API Routes
3
-
4
  This module provides Flask API endpoints for medical data processing, including:
5
  - Patient summary generation using various model types (GGUF, OpenVINO, HuggingFace)
6
  - File upload and text extraction
7
  - Medical data extraction from text and audio
8
  - Protected Health Information (PHI) scrubbing
9
  - Model management and dynamic loading
10
-
11
  The API supports multiple model formats and includes comprehensive error handling,
12
  memory optimization, and caching mechanisms for efficient operation in both
13
  local and cloud environments (Hugging Face Spaces).
14
  """
15
-
16
  from concurrent.futures import ThreadPoolExecutor, as_completed
17
  import json
18
  import logging
@@ -28,6 +25,7 @@ from transformers import (
28
  pipeline as transformers_pipeline
29
  )
30
  from ai_med_extract.agents.patient_summary_agent import PatientSummarizerAgent
 
31
  agent = PatientSummarizerAgent(model_name="falconsai/medical_summarization")
32
  from ai_med_extract.agents.summarizer import SummarizerAgent
33
  from ai_med_extract.utils.file_utils import (
@@ -37,35 +35,28 @@ from ai_med_extract.utils.file_utils import (
37
  get_data_from_storage,
38
  )
39
  from ..utils.validation import clean_result, validate_patient_name
40
- # from ..utils.patient_summary_utils import clean_patient_data, flatten_to_string_list
41
-
42
- from ai_med_extract.utils.patient_summary_utils import clean_patient_data, flatten_to_string_list
43
  import time
44
-
45
  logger = logging.getLogger(__name__)
46
 
47
  # Add GGUF model cache at the top of the file
48
  GGUF_MODEL_CACHE = {}
 
49
 
50
  def get_gguf_pipeline(model_name: str, filename: str = None):
51
  """
52
  Load and cache GGUF model pipelines with comprehensive error handling.
53
-
54
  This function provides a cached interface to GGUF models with fallback mechanisms
55
  for robust operation in production environments.
56
-
57
  Args:
58
  model_name (str): The name of the GGUF model or HuggingFace repository ID.
59
  Can be a local file path or HuggingFace model identifier.
60
  filename (str, optional): Specific filename for HuggingFace repository models.
61
  Required when model_name is a repository ID.
62
-
63
  Returns:
64
  GGUFModelPipeline: A loaded GGUF model pipeline instance or fallback pipeline.
65
-
66
  Raises:
67
  RuntimeError: If both model loading and fallback mechanisms fail.
68
-
69
  Notes:
70
  - Uses a global cache to avoid reloading the same model multiple times
71
  - Implements timeout mechanism for model loading (5 minutes)
@@ -77,11 +68,9 @@ def get_gguf_pipeline(model_name: str, filename: str = None):
77
  try:
78
  from ai_med_extract.utils.model_loader_gguf import GGUFModelPipeline, create_fallback_pipeline
79
  import time
80
-
81
  # Add timeout for model loading
82
  start_time = time.time()
83
  timeout = 300 # 5 minutes timeout
84
-
85
  # Try to load the GGUF model
86
  try:
87
  GGUF_MODEL_CACHE[key] = GGUFModelPipeline(model_name, filename, timeout=timeout)
@@ -90,55 +79,67 @@ def get_gguf_pipeline(model_name: str, filename: str = None):
90
  except Exception as e:
91
  load_time = time.time() - start_time
92
  print(f"[GGUF] Failed to load model {model_name} after {load_time:.2f}s: {e}")
93
-
94
  # If model loading fails, use fallback
95
  print("[GGUF] Using fallback pipeline")
96
  GGUF_MODEL_CACHE[key] = create_fallback_pipeline()
97
-
98
  except Exception as e:
99
  print(f"[GGUF] Critical error in model loading: {e}")
100
  # Create a basic fallback
101
  from ai_med_extract.utils.model_loader_gguf import create_fallback_pipeline
102
  GGUF_MODEL_CACHE[key] = create_fallback_pipeline()
103
-
104
  return GGUF_MODEL_CACHE[key]
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  def get_qa_pipeline(qa_model_type, qa_model_name):
108
  if not qa_model_type or not qa_model_name:
109
  raise ValueError("Both qa_model_type and qa_model_name must be provided")
110
-
111
-
112
  if not hasattr(get_qa_pipeline, "cache"):
113
  get_qa_pipeline.cache = {}
114
-
115
  # For Hugging Face Spaces, we need to be memory efficient
116
  import torch
117
  torch.cuda.empty_cache() # Clear GPU memory before loading model
118
-
119
  # Set default tensor type to float32 for better compatibility
120
  torch.set_default_tensor_type(torch.FloatTensor)
121
  if torch.cuda.is_available():
122
  torch.set_default_tensor_type(torch.cuda.FloatTensor)
123
-
124
  key = (qa_model_type, qa_model_name)
125
  if key in get_qa_pipeline.cache:
126
  return get_qa_pipeline.cache[key]
127
-
128
  try:
129
  # For Hugging Face Spaces, use smaller models by default
130
  if "Qwen/Qwen-7B-Chat" in qa_model_name:
131
  qa_model_name = "Qwen/Qwen-1_8B-Chat"
132
  elif "Llama" in qa_model_name:
133
  qa_model_name = "facebook/opt-125m"
134
-
135
  # Load tokenizer with trust_remote_code=True for custom tokenizers
136
  tokenizer = AutoTokenizer.from_pretrained(
137
  qa_model_name,
138
  trust_remote_code=True,
139
  cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
140
  )
141
-
142
  # Load model with memory optimizations
143
  try:
144
  model = AutoModelForCausalLM.from_pretrained(
@@ -160,7 +161,6 @@ def get_qa_pipeline(qa_model_type, qa_model_name):
160
  low_cpu_mem_usage=True,
161
  cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
162
  )
163
-
164
  # Create pipeline with memory optimizations
165
  pipeline = transformers_pipeline(
166
  task=qa_model_type,
@@ -169,10 +169,8 @@ def get_qa_pipeline(qa_model_type, qa_model_name):
169
  device_map="auto",
170
  torch_dtype=torch.float32
171
  )
172
-
173
  get_qa_pipeline.cache[key] = pipeline
174
  return pipeline
175
-
176
  except Exception as e:
177
  raise
178
 
@@ -182,14 +180,11 @@ def run_qa_pipeline(qa_pipeline, question, context):
182
  """
183
  if not qa_pipeline or not question or not context:
184
  raise ValueError("Pipeline, question and context are required")
185
-
186
  qa_model_type = getattr(qa_pipeline, '_qa_model_type', None)
187
-
188
  try:
189
  if qa_model_type == 'text-generation':
190
  prompt = f"Question: {question}\nContext: {context}\nAnswer:"
191
  result = qa_pipeline(prompt, max_new_tokens=128, do_sample=False)
192
-
193
  if isinstance(result, list) and result and 'generated_text' in result[0]:
194
  answer = result[0]['generated_text'].split('Answer:')[-1].strip()
195
  return {'answer': answer}
@@ -203,30 +198,23 @@ def run_qa_pipeline(qa_pipeline, question, context):
203
  def get_ner_pipeline(ner_model_type, ner_model_name):
204
  if not ner_model_type or not ner_model_name:
205
  raise ValueError("Both ner_model_type and ner_model_name must be provided")
206
-
207
  if not hasattr(get_ner_pipeline, "cache"):
208
  get_ner_pipeline.cache = {}
209
-
210
  # For Hugging Face Spaces, we need to be memory efficient
211
  import torch
212
  torch.cuda.empty_cache() # Clear GPU memory before loading model
213
-
214
  # Set default tensor type
215
  torch.set_default_tensor_type(torch.FloatTensor)
216
  if torch.cuda.is_available():
217
  torch.set_default_tensor_type(torch.cuda.FloatTensor)
218
-
219
  key = (ner_model_type, ner_model_name)
220
  if key in get_ner_pipeline.cache:
221
  return get_ner_pipeline.cache[key]
222
-
223
  try:
224
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
225
-
226
  # Clear any existing models from memory
227
  if torch.cuda.is_available():
228
  torch.cuda.empty_cache()
229
-
230
  # Load tokenizer
231
  try:
232
  tokenizer = AutoTokenizer.from_pretrained(
@@ -242,7 +230,6 @@ def get_ner_pipeline(ner_model_type, ner_model_name):
242
  trust_remote_code=True,
243
  cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
244
  )
245
-
246
  # Load model with memory optimizations
247
  try:
248
  # For NER models, we'll use CPU if device_map='auto' is not supported
@@ -276,7 +263,6 @@ def get_ner_pipeline(ner_model_type, ner_model_name):
276
  torch_dtype=torch.float32,
277
  cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
278
  )
279
-
280
  # Create pipeline with appropriate device configuration
281
  try:
282
  qa_pipeline = pipeline(
@@ -297,15 +283,12 @@ def get_ner_pipeline(ner_model_type, ner_model_name):
297
  )
298
  else:
299
  raise
300
-
301
  # Cache the pipeline
302
  get_ner_pipeline.cache[key] = qa_pipeline
303
  return qa_pipeline
304
-
305
  except Exception as e:
306
  raise
307
 
308
-
309
  def get_summarizer_pipeline(summarizer_model_type, summarizer_model_name):
310
  if not hasattr(get_summarizer_pipeline, "cache"):
311
  get_summarizer_pipeline.cache = {}
@@ -313,7 +296,6 @@ def get_summarizer_pipeline(summarizer_model_type, summarizer_model_name):
313
  if key not in get_summarizer_pipeline.cache:
314
  import torch
315
  from transformers import pipeline
316
-
317
  # Use float16 only if CUDA is available, else use float32
318
  if torch.cuda.is_available():
319
  dtype = torch.float16
@@ -323,7 +305,6 @@ def get_summarizer_pipeline(summarizer_model_type, summarizer_model_name):
323
  dtype = torch.float32
324
  device = -1
325
  device_map = None
326
-
327
  get_summarizer_pipeline.cache[key] = pipeline(
328
  task=summarizer_model_type,
329
  model=summarizer_model_name,
@@ -334,7 +315,6 @@ def get_summarizer_pipeline(summarizer_model_type, summarizer_model_name):
334
  )
335
  return get_summarizer_pipeline.cache[key]
336
 
337
-
338
  def register_routes(app, agents):
339
  from ai_med_extract.utils.openvino_summarizer_utils import (
340
  parse_ehr_chartsummarydtl, visits_sorted, compute_deltas, build_compact_baseline, delta_to_text, build_main_prompt
@@ -353,10 +333,8 @@ def register_routes(app, agents):
353
  chartsummarydtl = ehr_result.get("chartsummarydtl") if isinstance(ehr_result, dict) else None
354
  if not chartsummarydtl:
355
  return jsonify({"error": "Missing chartsummarydtl in input"}), 400
356
-
357
  # Normalize visits
358
  visits = parse_ehr_chartsummarydtl(chartsummarydtl)
359
-
360
  # Extract patient demographics if available
361
  patient_info = ""
362
  if isinstance(ehr_result, dict):
@@ -367,7 +345,6 @@ def register_routes(app, agents):
367
  past_medical_history = ', '.join(ehr_result.get('past_medical_history', []))
368
  social_history = ehr_result.get('social_history', 'Not specified')
369
  patient_info = f"Patient: {patient_name} (ID: {patient_id}, Age: {age}, Gender: {gender})\nPast Medical History: {past_medical_history}\nSocial History: {social_history}\n"
370
-
371
  # Generate summary from current data only (no state tracking)
372
  # Use empty old visits to compute deltas against baseline
373
  delta = compute_deltas([], visits)
@@ -375,7 +352,6 @@ def register_routes(app, agents):
375
  baseline = build_compact_baseline(all_visits)
376
  delta_text = delta_to_text(delta)
377
  prompt = build_main_prompt(baseline, delta_text, patient_info)
378
-
379
  # Model selection logic (model_name, model_type)
380
  model_name = data.get("model_name") or "microsoft/Phi-3-mini-4k-instruct"
381
  model_type = data.get("model_type") or "text-generation"
@@ -387,14 +363,12 @@ def register_routes(app, agents):
387
  pipeline = loader.model_loader.load() if hasattr(loader, "model_loader") else None
388
  if not pipeline:
389
  return jsonify({"error": "Model pipeline not available"}), 500
390
-
391
  # Run inference
392
  import torch
393
  torch.set_num_threads(2)
394
  inputs = pipeline.tokenizer([prompt], return_tensors="pt")
395
  outputs = pipeline.model.generate(**inputs, max_new_tokens=100000, do_sample=False, pad_token_id=pipeline.tokenizer.eos_token_id or 32000)
396
  text = pipeline.tokenizer.decode(outputs[0], skip_special_tokens=True)
397
-
398
  # Extract just the markdown summary (remove prompt text)
399
  # The model should return the complete markdown-formatted summary
400
  summary_start_patterns = [
@@ -403,13 +377,11 @@ def register_routes(app, agents):
403
  "# Clinical Assessment",
404
  "Clinical Assessment"
405
  ]
406
-
407
  new_summary = text
408
  for pattern in summary_start_patterns:
409
  if pattern in text:
410
  new_summary = text.split(pattern)[-1].strip()
411
  break
412
-
413
  return jsonify({
414
  "summary": new_summary,
415
  "baseline": baseline,
@@ -417,15 +389,15 @@ def register_routes(app, agents):
417
  }), 200
418
  except Exception as e:
419
  return jsonify({"error": f"Failed to generate summary: {str(e)}"}), 500
420
- # Configure upload directory based on environment import os
421
 
 
 
422
  if os.environ.get('SPACE_ID'): # We're running on Hugging Face Spaces
423
  app.config['UPLOAD_FOLDER'] = '/data/uploads'
424
  else: # We're running locally
425
  upload_dir = os.path.join(os.getcwd(), 'uploads')
426
  os.makedirs(upload_dir, exist_ok=True)
427
  app.config['UPLOAD_FOLDER'] = upload_dir
428
-
429
  # Ensure the upload directory exists and is writable
430
  if not os.path.exists(app.config['UPLOAD_FOLDER']):
431
  try:
@@ -444,11 +416,9 @@ def register_routes(app, agents):
444
  def upload_file():
445
  import torch
446
  torch.cuda.empty_cache() # Clear GPU memory before processing
447
-
448
  files = request.files.getlist("file")
449
  patient_name = request.form.get("patient_name", "").strip()
450
  password = request.form.get("password")
451
-
452
  # Use more compatible models by default
453
  qa_model_name = request.form.get("qa_model_name", "facebook/bart-base")
454
  qa_model_type = request.form.get("qa_model_type", "text-generation")
@@ -456,10 +426,8 @@ def register_routes(app, agents):
456
  ner_model_type = request.form.get("ner_model_type", "ner")
457
  summarizer_model_name = request.form.get("summarizer_model_name", "facebook/bart-base")
458
  summarizer_model_type = request.form.get("summarizer_model_type", "summarization")
459
-
460
  if not files:
461
  return jsonify({"error": "No file uploaded"}), 400
462
-
463
  # Accept any model type and model name for QA, NER, and summarizer
464
  if not qa_model_name or not qa_model_type:
465
  return jsonify({"error": "QA model name and type are required"}), 400
@@ -467,21 +435,18 @@ def register_routes(app, agents):
467
  qa_pipeline = get_qa_pipeline(qa_model_type, qa_model_name)
468
  except Exception as e:
469
  return jsonify({"error": f"QA model load failed: {str(e)}"}), 500
470
-
471
  if not ner_model_name or not ner_model_type:
472
  return jsonify({"error": "NER model name and type are required"}), 400
473
  try:
474
  ner_pipeline = get_ner_pipeline(ner_model_type, ner_model_name)
475
  except Exception as e:
476
  return jsonify({"error": f"NER model load failed: {str(e)}"}), 500
477
-
478
  if not summarizer_model_name or not summarizer_model_type:
479
  return jsonify({"error": "Summarizer model name and type are required"}), 400
480
  try:
481
  summarizer_pipeline = get_summarizer_pipeline(summarizer_model_type, summarizer_model_name)
482
  except Exception as e:
483
  return jsonify({"error": f"Summarizer model load failed: {str(e)}"}), 500
484
-
485
  extracted_data = []
486
  for file in files:
487
  if file.filename == "":
@@ -514,7 +479,6 @@ def register_routes(app, agents):
514
  except Exception as e:
515
  os.remove(filepath) # Clean up on failure
516
  return jsonify({"error": f"Text extraction failed: {str(e)}"}), 500
517
-
518
  skip_medical_check = (
519
  request.form.get("skip_medical_check", "false").lower() == "true"
520
  )
@@ -650,13 +614,11 @@ def register_routes(app, agents):
650
  file = request.files["file"]
651
  if file.filename == "":
652
  return jsonify({"error": "No selected file"}), 400
653
-
654
  # Use secure filename
655
  from werkzeug.utils import secure_filename
656
  import uuid
657
  temp_filename = f"{uuid.uuid4()}_{secure_filename(file.filename)}"
658
  temp_path = os.path.join(app.config['UPLOAD_FOLDER'], temp_filename)
659
-
660
  file.save(temp_path)
661
  result = whisper_model.transcribe(temp_path)
662
  os.remove(temp_path)
@@ -666,7 +628,6 @@ def register_routes(app, agents):
666
  os.remove(temp_path)
667
  return jsonify({"error": str(e)}), 500
668
 
669
-
670
  def group_by_category(data):
671
  grouped = defaultdict(list)
672
  for item in data:
@@ -678,20 +639,17 @@ def register_routes(app, agents):
678
  "answer": item.get("answer", "Not Available"),
679
  }
680
  )
681
-
682
  return [{"category": k, "detail": v} for k, v in grouped.items()]
683
 
684
  def deduplicate_extractions(data):
685
  seen = set()
686
  reversed_unique = []
687
-
688
  # Loop in reverse to keep the *last* occurrence
689
  for item in reversed(data):
690
  key = (item.get("label"))
691
  if key not in seen:
692
  seen.add(key)
693
  reversed_unique.append(item)
694
-
695
  # Reverse back to preserve original order (latest kept, first dropped)
696
  return list(reversed(reversed_unique))
697
 
@@ -701,24 +659,19 @@ def register_routes(app, agents):
701
  text,
702
  add_special_tokens=False
703
  )
704
-
705
  chunks = []
706
  start = 0
707
-
708
  while start < len(input_ids):
709
  end = min(start + max_tokens, len(input_ids))
710
  chunk_ids = input_ids[start:end]
711
-
712
  chunk_text = tokenizer.decode(
713
  chunk_ids,
714
  skip_special_tokens=True,
715
  clean_up_tokenization_spaces=True
716
  )
717
-
718
  # Ensure partial continuation isn't cut off mid-sentence
719
  if not chunk_text.endswith(('.', '?', '!', ':')):
720
  chunk_text += "..."
721
-
722
  chunks.append(chunk_text)
723
  start += max_tokens - overlap
724
  return chunks
@@ -731,7 +684,6 @@ def register_routes(app, agents):
731
  except ValueError:
732
  # '[' not found in output
733
  return []
734
-
735
  # Try parsing full array first
736
  try:
737
  parsed = json.loads(json_text)
@@ -739,7 +691,6 @@ def register_routes(app, agents):
739
  return parsed
740
  except Exception:
741
  pass # fallback to manual parsing
742
-
743
  # Manual recovery via brace matching
744
  stack = 0
745
  obj_start = None
@@ -758,15 +709,12 @@ def register_routes(app, agents):
758
  except Exception as e:
759
  print(f"❌ Invalid JSON object: {e}")
760
  obj_start = None
761
-
762
  return extracted
763
 
764
-
765
  def process_chunk(generator, chunk, idx):
766
  prompt = f"""
767
  [INST] <<SYS>>
768
  You are a clinical data extraction assistant.
769
-
770
  Your job is to:
771
  1. Read the following medical report.
772
  2. Extract all medically relevant facts as a list of JSON objects.
@@ -775,7 +723,6 @@ def register_routes(app, agents):
775
  - "question": a question related to that field
776
  - "answer": the answer from the text
777
  4. After extracting the list, categorize each object under one of the following fixed categories:
778
-
779
  - Patient Info
780
  - Vitals
781
  - Symptoms
@@ -787,7 +734,6 @@ def register_routes(app, agents):
787
  - Laboratory
788
  - Radiology
789
  - Doctor Note
790
-
791
  Example format for structure only — do not include in output:
792
  [
793
  {{
@@ -797,22 +743,17 @@ def register_routes(app, agents):
797
  "category": "Patient Info"
798
  }},
799
  ]
800
-
801
- ⚠ Use these categories listed above. If an item does not fit any of these categories, create a new category for it.
802
-
803
  Text:
804
  {chunk}
805
-
806
  Return a single valid JSON array of all extracted objects.
807
  Do not include any explanations or commentary.
808
  Only output the JSON array
809
  <</SYS>> [/INST]
810
  """
811
-
812
  try:
813
  # Clear GPU memory before processing
814
  torch.cuda.empty_cache()
815
-
816
  # Process with memory optimizations
817
  output = generator(
818
  prompt,
@@ -820,31 +761,26 @@ def register_routes(app, agents):
820
  do_sample=False, # Disable sampling for deterministic output
821
  temperature=0.3, # Lower temperature for more focused output
822
  )[0]["generated_text"]
823
-
824
  return idx, output
825
  except Exception as e:
826
  return idx, None
827
-
828
  @app.route("/extract_medical_data", methods=["POST"])
829
  def extract_medical_data():
830
  data = request.json
831
  qa_model_name = data.get("qa_model_name")
832
  qa_model_type = data.get("qa_model_type")
833
  extracted_files = data.get("extracted_data")
834
-
835
  if not qa_model_name or not qa_model_type:
836
  return jsonify({"error": "Missing 'qa_model_name' or 'qa_model_type'"}), 400
837
-
838
  if not extracted_files:
839
  return jsonify({"error": "Missing 'extracted_data' in request"}), 400
840
-
841
  try:
842
  tokenizer = AutoTokenizer.from_pretrained(
843
  qa_model_name,
844
  trust_remote_code=True,
845
  cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
846
  )
847
-
848
  model = AutoModelForCausalLM.from_pretrained(
849
  qa_model_name,
850
  device_map="auto",
@@ -853,32 +789,25 @@ def register_routes(app, agents):
853
  low_cpu_mem_usage=True,
854
  cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
855
  )
856
-
857
  generator = transformers_pipeline(
858
  task=qa_model_type,
859
  model=model,
860
  tokenizer=tokenizer,
861
  torch_dtype=torch.float32
862
  )
863
-
864
  except Exception as e:
865
  return jsonify({"error": f"Could not load model: {str(e)}"}), 500
866
-
867
  structured_response = {"extracted_data": []}
868
-
869
  for file_data in extracted_files:
870
  filename = file_data.get("file", "unknown_file")
871
  context = file_data.get("extracted_text", "").strip()
872
-
873
  if not context:
874
  structured_response["extracted_data"].append(
875
  {"file": filename, "medical_fields": []}
876
  )
877
  continue
878
-
879
  chunks = chunk_text(context, tokenizer)
880
  all_extracted = []
881
-
882
  with ThreadPoolExecutor(max_workers=4) as executor:
883
  futures = {
884
  executor.submit(process_chunk, generator, chunk, idx): idx
@@ -887,19 +816,16 @@ def register_routes(app, agents):
887
  for future in as_completed(futures):
888
  idx = futures[future]
889
  _, output = future.result()
890
-
891
  if not output:
892
  continue
893
-
894
  try:
895
  objs = extract_json_objects(output)
896
  if objs:
897
  all_extracted.extend(objs)
898
  else:
899
- print(f"⚠ Chunk {idx+1} yielded no valid JSON.")
900
  except Exception as e:
901
  print(f"❌ Error extracting JSON from chunk {idx+1}")
902
-
903
  # Clean and group results for this file
904
  if all_extracted:
905
  deduped = deduplicate_extractions(all_extracted)
@@ -907,20 +833,16 @@ def register_routes(app, agents):
907
  grouped_data = group_by_category(deduped)
908
  else:
909
  grouped_data = {"error": "No valid data extracted"}
910
-
911
  structured_response["extracted_data"].append(
912
  {"file": filename, "medical_fields": grouped_data}
913
  )
914
-
915
  try:
916
  save_data_to_storage(filename, grouped_data)
917
  except Exception as e:
918
- print(f"⚠ Failed to save data for {filename}: {e}")
919
-
920
  print("✅ Extraction complete.")
921
  return jsonify(structured_response)
922
 
923
-
924
  @app.route("/api/generate_summary", methods=["POST"])
925
  def generate_summary():
926
  logger.info("Received request to generate summary.")
@@ -951,34 +873,27 @@ def register_routes(app, agents):
951
  torch.set_default_tensor_type(torch.FloatTensor)
952
  if torch.cuda.is_available():
953
  torch.set_default_tensor_type(torch.cuda.FloatTensor)
954
-
955
  # Handle multipart form data from Flutter
956
  if "audio" not in request.files:
957
  return jsonify({"error": "No audio file provided"}), 400
958
-
959
  audio_file = request.files["audio"]
960
  if audio_file.filename == "":
961
  return jsonify({"error": "No selected audio file"}), 400
962
-
963
  # Validate file extension
964
  if not allowed_file(audio_file.filename):
965
  return jsonify({"error": f"Unsupported audio format. Allowed formats: wav, mp3, m4a, ogg"}), 400
966
-
967
  # Check file size
968
  valid_size, error_message = check_file_size(audio_file)
969
  if not valid_size:
970
  return jsonify({"error": error_message}), 400
971
-
972
  # Use default model if not specified
973
  qa_model_name = request.form.get("qa_model_name", "facebook/bart-base")
974
  qa_model_type = request.form.get("qa_model_type", "text-generation")
975
-
976
  # Load QA model with proper error handling
977
  try:
978
  qa_pipeline = get_qa_pipeline(qa_model_type, qa_model_name)
979
  except Exception as e:
980
  return jsonify({"error": f"QA model load failed: {str(e)}"}), 500
981
-
982
  # Use platform-agnostic temp directory
983
  import uuid
984
  from werkzeug.utils import secure_filename
@@ -987,10 +902,8 @@ def register_routes(app, agents):
987
  os.makedirs(temp_dir, exist_ok=True)
988
  temp_filename = f"{uuid.uuid4()}_{secure_filename(audio_file.filename)}"
989
  temp_path = os.path.join(temp_dir, temp_filename)
990
-
991
  try:
992
  audio_file.save(temp_path)
993
-
994
  # Transcribe audio with retries
995
  max_retries = 3
996
  transcribed_text = None
@@ -1007,16 +920,13 @@ def register_routes(app, agents):
1007
  raise
1008
  torch.cuda.empty_cache() # Clear GPU memory between attempts
1009
  continue
1010
-
1011
  if not transcribed_text:
1012
  raise ValueError("Failed to transcribe audio after multiple attempts")
1013
-
1014
  # Clean and process text
1015
  try:
1016
  clean_text = PHIScrubberAgent.scrub_phi(transcribed_text)
1017
  except Exception as e:
1018
  clean_text = transcribed_text
1019
-
1020
  # Extract medical data with proper device handling
1021
  try:
1022
  with torch.cuda.device(0) if torch.cuda.is_available() else torch.no_grad():
@@ -1025,11 +935,9 @@ def register_routes(app, agents):
1025
  medical_data = medical_data_extractor.extract_medical_data(clean_text)
1026
  except Exception as e:
1027
  medical_data = {"error": f"Medical data extraction failed: {str(e)}"}
1028
-
1029
  # Clean up temporary file
1030
  if os.path.exists(temp_path):
1031
  os.remove(temp_path)
1032
-
1033
  # Return response in the format expected by Flutter
1034
  return jsonify({
1035
  "status": "success",
@@ -1038,7 +946,6 @@ def register_routes(app, agents):
1038
  "medical_chart": medical_data
1039
  }
1040
  }), 200
1041
-
1042
  except Exception as e:
1043
  if temp_path and os.path.exists(temp_path):
1044
  os.remove(temp_path)
@@ -1046,7 +953,6 @@ def register_routes(app, agents):
1046
  "status": "error",
1047
  "error": f"Processing failed: {str(e)}"
1048
  }), 500
1049
-
1050
  except Exception as e:
1051
  if temp_path and os.path.exists(temp_path):
1052
  os.remove(temp_path)
@@ -1055,22 +961,17 @@ def register_routes(app, agents):
1055
  "error": f"Request handling failed: {str(e)}"
1056
  }), 500
1057
 
1058
-
1059
-
1060
- # Initialize GGUF pipeline with proper model name handling
1061
- gguf_model_name = "microsoft/Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-q4.gguf"
1062
- if gguf_model_name.endswith('.gguf') and '/' in gguf_model_name:
1063
- repo_id, filename = gguf_model_name.rsplit('/', 1)
1064
- PIPELINE = get_gguf_pipeline(repo_id, filename)
1065
- else:
1066
- PIPELINE = get_gguf_pipeline(gguf_model_name)
1067
- _ = PIPELINE.generate("Hello", max_tokens=5)
1068
-
1069
  @app.route('/generate_patient_summary', methods=['POST'])
1070
  def generate_patient_summary():
1071
  """
1072
- Enhanced: Uses OpenVINO-style prompt, delta, and validation logic for patient summary generation.
1073
- Generates fresh summary every time without state tracking.
 
 
 
 
 
1074
  """
1075
  from ai_med_extract.utils.openvino_summarizer_utils import (
1076
  parse_ehr_chartsummarydtl, visits_sorted, compute_deltas, build_compact_baseline, delta_to_text, build_main_prompt
@@ -1078,217 +979,488 @@ def register_routes(app, agents):
1078
  try:
1079
  start_total = time.time()
1080
  data = request.get_json()
1081
- t0 = time.time()
1082
  patientid = data.get("patientid")
1083
  token = data.get("token")
1084
  key = data.get("key")
1085
- model_name = data.get("patient_summarizer_model_name") or "falconsai/medical_summarization"
1086
- model_type = data.get("patient_summarizer_model_type") or data.get("model_type") or "summarization"
 
 
 
 
 
 
 
 
 
 
 
 
 
1087
  if not patientid or not token or not key:
1088
  return jsonify({"error": "Missing required fields: patientid, token, or key"}), 400
1089
 
1090
- api_url = f"{key}/Transactionapi/api/PatientList/patientsummary"
1091
  headers = {
1092
  "Authorization": f"Bearer {token}",
1093
  "Content-Type": "application/json",
1094
  }
1095
- # Only include x-api-key if it's a distinct API key, not a base URL
1096
  if key and not key.startswith("http"):
1097
  headers["x-api-key"] = key
 
 
1098
  t_api_start = time.time()
1099
- response = requests.post(api_url, json={"patientid": patientid}, headers=headers, timeout=30)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1100
  t_api_end = time.time()
 
1101
  if response.status_code != 200:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1102
  return jsonify({
1103
- "error": "API request failed",
1104
- "status": response.status_code,
1105
- "message": response.text
1106
- }), 502
 
 
1107
  try:
1108
  api_data = response.json()
1109
  except ValueError:
1110
- api_data = response.text
1111
- if isinstance(api_data, dict):
1112
- ehr_result = api_data.get("result") or api_data
1113
- else:
1114
- ehr_result = api_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1115
  chartsummarydtl = ehr_result.get("chartsummarydtl") if isinstance(ehr_result, dict) else None
1116
  if not chartsummarydtl:
1117
- # Return diagnostics to aid debugging on Spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
1118
  return jsonify({
1119
- "error": "Missing chartsummarydtl in EHR response",
1120
- "diagnostic": {
1121
- "api_url": api_url,
1122
- "status": response.status_code,
1123
- "content_type": response.headers.get("content-type"),
1124
- "body_preview": (response.text[:500] if hasattr(response, "text") else str(api_data))
1125
- }
1126
  }), 500
 
 
1127
  visits = parse_ehr_chartsummarydtl(chartsummarydtl)
1128
- # Generate summary from current data only (no state tracking)
1129
- # Use empty old visits to compute deltas against baseline
1130
  delta = compute_deltas([], visits)
1131
  all_visits = visits_sorted(visits)
1132
  baseline = build_compact_baseline(all_visits)
1133
  delta_text = delta_to_text(delta)
1134
- prompt = build_main_prompt(baseline, delta_text)
1135
- t_model_load_start = time.time()
1136
- # Model selection logic (supporting OpenVINO, HuggingFace, and GGUF)
1137
- pipeline = None
1138
- loader = None
1139
- import torch
1140
- torch.set_num_threads(2)
1141
  if model_type == "gguf":
1142
- logger.info("Using GGUF model for summary generation.")
1143
  try:
1144
- # Support both local path and HuggingFace repo/filename
1145
- if model_name.endswith('.gguf') and '/' in model_name:
 
1146
  repo_id, filename = model_name.rsplit('/', 1)
1147
- pipeline = get_gguf_pipeline(repo_id, filename)
1148
- else:
1149
- pipeline = get_gguf_pipeline(model_name)
1150
-
1151
- logger.info(f"Prompt length for GGUF model: {len(prompt)} characters.")
1152
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1153
  try:
1154
- # The timeout is now handled internally by the pipeline
1155
- summary_raw = pipeline.generate_full_summary(prompt, max_tokens=100000, max_loops=5)
1156
- # Extract markdown summary directly from model output
1157
- summary_start_patterns = [
1158
- "Now generate the complete, updated clinical summary with all four sections in a markdown format:",
 
 
 
 
 
 
 
 
 
 
1159
  "## Clinical Assessment",
1160
- "# Clinical Assessment",
1161
- "Clinical Assessment"
 
1162
  ]
 
 
1163
 
1164
- markdown_summary = summary_raw
1165
- for pattern in summary_start_patterns:
1166
- if pattern in summary_raw:
1167
- markdown_summary = summary_raw.split(pattern)[-1].strip()
1168
- break
1169
-
1170
- # No state tracking - just return the summary
1171
- # Remove undefined timing variables and only log steps that are actually measured
1172
- total_time = time.time() - start_total
1173
- logger.info(f"[TIMING] API call: {t_api_end-t_api_start:.2f}s, TOTAL: {total_time:.2f}s")
1174
- return jsonify({
1175
- "summary": markdown_summary,
1176
- "baseline": baseline,
1177
- "delta": delta_text
1178
- }), 200
1179
- except TimeoutError as e:
1180
- logger.error(f"GGUF model generation timed out: {e}")
1181
- # Try to use a simpler fallback model
1182
- try:
1183
- from ai_med_extract.utils.model_loader_gguf import create_fallback_pipeline
1184
- fallback_pipeline = create_fallback_pipeline()
1185
- fallback_summary = fallback_pipeline.generate_full_summary(prompt)
1186
- # Extract markdown summary directly from fallback output
1187
- summary_start_patterns = [
1188
- "Now generate the complete, updated clinical summary with all four sections in a markdown format:",
1189
- "## Clinical Assessment",
1190
- "# Clinical Assessment",
1191
- "Clinical Assessment"
1192
- ]
1193
-
1194
- markdown_summary = fallback_summary
1195
- for pattern in summary_start_patterns:
1196
- if pattern in fallback_summary:
1197
- markdown_summary = fallback_summary.split(pattern)[-1].strip()
1198
  break
1199
-
1200
- return jsonify({
1201
- "summary": markdown_summary,
1202
- "baseline": baseline,
1203
- "delta": delta_text,
1204
- "warning": "GGUF model timed out, using fallback summary"
1205
- }), 200
1206
- except Exception as fallback_error:
1207
- return jsonify({
1208
- "error": f"GGUF model generation timed out and fallback failed: {str(e)}",
1209
- "original_error": str(e)
1210
- }), 408
1211
- except Exception as e:
1212
- logger.error(f"GGUF model generation failed: {e}")
1213
- return jsonify({"error": f"GGUF model generation failed: {str(e)}"}), 500
1214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1215
  except Exception as e:
1216
- logger.error(f"Failed to load GGUF model: {e}")
1217
- # Try to use fallback pipeline
1218
- try:
1219
- from ai_med_extract.utils.model_loader_gguf import create_fallback_pipeline
1220
- fallback_pipeline = create_fallback_pipeline()
1221
- fallback_summary = fallback_pipeline.generate_full_summary(prompt)
1222
- markdown_summary =fallback_summary
1223
- # summary_to_markdown(fallback_summary)
1224
- return jsonify({
1225
- "summary": markdown_summary,
1226
- "baseline": baseline,
1227
- "delta": delta_text,
1228
- "warning": "GGUF model failed to load, using fallback summary"
1229
- }), 200
1230
- except Exception as fallback_error:
1231
- return jsonify({
1232
- "error": f"Failed to load GGUF model and fallback failed: {str(e)}",
1233
- "original_error": str(e)
1234
- }), 500
 
 
 
 
 
 
 
 
 
 
 
 
 
1235
  elif model_type in {"text-generation", "causal-openvino"}:
1236
- # Try to use an existing loader if available
1237
- loader = agents.get("medical_data_extractor")
1238
- if not loader or getattr(loader, 'model_name', None) != model_name:
1239
- # Dynamically create OpenVINO loader if needed
1240
- from ai_med_extract.utils.model_loader_spaces import get_openvino_pipeline
1241
- try:
1242
  pipeline = get_openvino_pipeline(model_name)
1243
- except Exception as e:
1244
- return jsonify({"error": f"Failed to load OpenVINO pipeline: {str(e)}"}), 500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1245
  elif model_type == "summarization":
1246
- loader = agents.get("summarizer")
1247
- # Use loader if available
1248
- if not pipeline and loader and hasattr(loader, "model_loader"):
1249
- pipeline = loader.model_loader.load()
1250
- if not pipeline:
1251
- return jsonify({"error": "Model pipeline not available"}), 500
1252
- inputs = pipeline.tokenizer([prompt], return_tensors="pt")
1253
- outputs = pipeline.model.generate(**inputs, max_new_tokens=100000, do_sample=False, pad_token_id=pipeline.tokenizer.eos_token_id or 32000)
1254
- text = pipeline.tokenizer.decode(outputs[0], skip_special_tokens=True)
1255
- new_summary = text.split("Now generate the complete, updated clinical summary with all four sections in a markdown format:")[-1].strip()
1256
- # For other models, after extracting new_summary:
1257
- from ai_med_extract.api.routes import summary_to_markdown
1258
- logger.info(f"Baseline length: {len(baseline)} characters.")
1259
- logger.info(f"Delta text length: {len(delta_text)} characters.")
1260
- logger.info(f"Raw summary length: {len(new_summary)} characters.")
1261
- markdown_summary = summary_to_markdown(new_summary)
1262
- logger.info(f"Formatted summary length: {len(markdown_summary)} characters.")
1263
-
1264
- # Validate and ensure the summary has all 4 required sections
1265
- markdown_summary = ensure_four_sections(markdown_summary)
1266
- # Remove undefined timing variables and only log steps that are actually measured
1267
- total_time = time.time() - start_total
1268
- print(f"[TIMING] API call: {t_api_end-t_api_start:.2f}s, TOTAL: {total_time:.2f}s")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1269
  return jsonify({
1270
- "summary": markdown_summary,
1271
- "baseline": baseline,
1272
- "delta": delta_text
 
 
1273
  }), 200
1274
- except requests.exceptions.Timeout:
1275
- return jsonify({"error": "Request to EHR API timed out"}), 504
1276
- except requests.exceptions.RequestException as e:
1277
- return jsonify({"error": f"Network error: {str(e)}"}), 503
1278
- except Exception as e:
1279
- logger.error(f"Unexpected error: {str(e)}", exc_info=True)
1280
- return jsonify({"error": f"Internal server error: {str(e)}"}), 500
1281
-
1282
  @app.route("/")
1283
  def home():
1284
  return "Medical Data Extraction API is running!", 200
1285
 
1286
-
1287
  def summary_to_markdown(summary):
1288
  import re
1289
  # Remove '- answer:' and similar artifacts
1290
  summary = re.sub(r'-\s*answer: ?', '', summary, flags=re.IGNORECASE)
1291
-
1292
  # Convert numbered sections to markdown headers
1293
  lines = summary.splitlines()
1294
  out = []
@@ -1298,7 +1470,6 @@ def summary_to_markdown(summary):
1298
  '3.': '##',
1299
  '4.': '##',
1300
  }
1301
-
1302
  for line in lines:
1303
  m = re.match(r'^(\d\.)\s*(.+)', line)
1304
  if m and m.group(1) in section_map:
@@ -1306,18 +1477,14 @@ def summary_to_markdown(summary):
1306
  out.append(f"{header} {m.group(2).strip()}")
1307
  else:
1308
  out.append(line)
1309
-
1310
  # Remove empty lines at the start
1311
  while out and not out[0].strip():
1312
  out = out[1:]
1313
-
1314
  # Check if we have the expected 4-section structure
1315
  def is_header(line: str) -> bool:
1316
  return bool(re.match(r'^(#{1,6})\s+.+', line.strip()))
1317
-
1318
  # Find all headers in the output
1319
  headers = [i for i, line in enumerate(out) if is_header(line)]
1320
-
1321
  # If we have at least 4 headers, check if they match the expected structure
1322
  if len(headers) >= 4:
1323
  header_texts = [out[i].strip() for i in headers[:4]]
@@ -1327,23 +1494,19 @@ def summary_to_markdown(summary):
1327
  r'##.*Plan.*Suggested.*Actions',
1328
  r'##.*Direct.*Guidance.*Physician'
1329
  ]
1330
-
1331
  # Check if headers match expected patterns
1332
  matches_pattern = all(
1333
  re.search(pattern, header, re.IGNORECASE)
1334
  for pattern, header in zip(expected_patterns, header_texts)
1335
  )
1336
-
1337
  if matches_pattern:
1338
  # Keep the entire content - don't truncate
1339
  return '\n'.join(out).strip()
1340
-
1341
  # If we don't have the expected structure, try to find the actual summary content
1342
  # Look for the start of the clinical assessment section
1343
  clinical_assessment_pattern = r'(?:# Clinical Assessment|## Clinical Assessment|Clinical Assessment)'
1344
  for i, line in enumerate(out):
1345
  if re.search(clinical_assessment_pattern, line, re.IGNORECASE):
1346
  return '\n'.join(out[i:]).strip()
1347
-
1348
  # If no clinical assessment found, return the entire summary
1349
- return '\n'.join(out).strip()
 
1
  """
2
  Medical Data Extraction API Routes
 
3
  This module provides Flask API endpoints for medical data processing, including:
4
  - Patient summary generation using various model types (GGUF, OpenVINO, HuggingFace)
5
  - File upload and text extraction
6
  - Medical data extraction from text and audio
7
  - Protected Health Information (PHI) scrubbing
8
  - Model management and dynamic loading
 
9
  The API supports multiple model formats and includes comprehensive error handling,
10
  memory optimization, and caching mechanisms for efficient operation in both
11
  local and cloud environments (Hugging Face Spaces).
12
  """
 
13
  from concurrent.futures import ThreadPoolExecutor, as_completed
14
  import json
15
  import logging
 
25
  pipeline as transformers_pipeline
26
  )
27
  from ai_med_extract.agents.patient_summary_agent import PatientSummarizerAgent
28
+ from ai_med_extract.utils.openvino_summarizer_utils import generate_section
29
  agent = PatientSummarizerAgent(model_name="falconsai/medical_summarization")
30
  from ai_med_extract.agents.summarizer import SummarizerAgent
31
  from ai_med_extract.utils.file_utils import (
 
35
  get_data_from_storage,
36
  )
37
  from ..utils.validation import clean_result, validate_patient_name
38
+ from ai_med_extract.utils.patient_summary_utils import clean_patient_data, flatten_to_string_list
 
 
39
  import time
 
40
  logger = logging.getLogger(__name__)
41
 
42
  # Add GGUF model cache at the top of the file
43
  GGUF_MODEL_CACHE = {}
44
+ GGUF_PIPELINE_CACHE = {}
45
 
46
  def get_gguf_pipeline(model_name: str, filename: str = None):
47
  """
48
  Load and cache GGUF model pipelines with comprehensive error handling.
 
49
  This function provides a cached interface to GGUF models with fallback mechanisms
50
  for robust operation in production environments.
 
51
  Args:
52
  model_name (str): The name of the GGUF model or HuggingFace repository ID.
53
  Can be a local file path or HuggingFace model identifier.
54
  filename (str, optional): Specific filename for HuggingFace repository models.
55
  Required when model_name is a repository ID.
 
56
  Returns:
57
  GGUFModelPipeline: A loaded GGUF model pipeline instance or fallback pipeline.
 
58
  Raises:
59
  RuntimeError: If both model loading and fallback mechanisms fail.
 
60
  Notes:
61
  - Uses a global cache to avoid reloading the same model multiple times
62
  - Implements timeout mechanism for model loading (5 minutes)
 
68
  try:
69
  from ai_med_extract.utils.model_loader_gguf import GGUFModelPipeline, create_fallback_pipeline
70
  import time
 
71
  # Add timeout for model loading
72
  start_time = time.time()
73
  timeout = 300 # 5 minutes timeout
 
74
  # Try to load the GGUF model
75
  try:
76
  GGUF_MODEL_CACHE[key] = GGUFModelPipeline(model_name, filename, timeout=timeout)
 
79
  except Exception as e:
80
  load_time = time.time() - start_time
81
  print(f"[GGUF] Failed to load model {model_name} after {load_time:.2f}s: {e}")
 
82
  # If model loading fails, use fallback
83
  print("[GGUF] Using fallback pipeline")
84
  GGUF_MODEL_CACHE[key] = create_fallback_pipeline()
 
85
  except Exception as e:
86
  print(f"[GGUF] Critical error in model loading: {e}")
87
  # Create a basic fallback
88
  from ai_med_extract.utils.model_loader_gguf import create_fallback_pipeline
89
  GGUF_MODEL_CACHE[key] = create_fallback_pipeline()
 
90
  return GGUF_MODEL_CACHE[key]
91
 
92
def get_cached_gguf_pipeline(model_name: str, filename: str = None):
    """
    Return the GGUF pipeline for ``(model_name, filename)``, loading it once.

    The first request for a given pair delegates to ``get_gguf_pipeline`` and
    stores the result in ``GGUF_PIPELINE_CACHE``; later requests for the same
    pair return the stored object without calling the loader again.

    Args:
        model_name (str): Model name or repository ID passed to the loader.
        filename (str, optional): Filename passed to the loader.

    Returns:
        Whatever ``get_gguf_pipeline`` returned for this pair.
    """
    cache_key = (model_name, filename)
    try:
        return GGUF_PIPELINE_CACHE[cache_key]
    except KeyError:
        # First request for this pair: load once and remember the result.
        loaded = get_gguf_pipeline(model_name, filename)
        GGUF_PIPELINE_CACHE[cache_key] = loaded
        return loaded
97
+
98
def ensure_four_sections(summary: str) -> str:
    """
    Ensure the summary contains all four required markdown sections.

    For every required section whose header is not already present, the
    header is appended to the end of the summary followed by a placeholder
    bullet. Sections that are present are left untouched and nothing is
    reordered.

    Header detection ignores case, surrounding whitespace, the number of
    leading ``#`` characters and a trailing colon, so a model output such as
    ``# clinical assessment:`` counts as the "Clinical Assessment" section
    instead of triggering a duplicate placeholder.

    Args:
        summary: Markdown summary text. ``None`` is treated as empty text.

    Returns:
        The summary text with a placeholder section appended for each
        required header that was missing.
    """
    required_sections = [
        "## Clinical Assessment",
        "## Key Trends & Changes",
        "## Plan & Suggested Actions",
        "## Direct Guidance for Physician",
    ]
    placeholder = "- *Section was not generated. Consider retrying or checking input data.*"

    def normalize(header: str) -> str:
        # Reduce a header line to its bare title for tolerant comparison.
        return header.strip().lstrip("#").strip().rstrip(":").strip().casefold()

    text = summary or ""
    existing_titles = {
        normalize(line)
        for line in text.splitlines()
        if line.lstrip().startswith("#")
    }
    missing = [
        section for section in required_sections
        if normalize(section) not in existing_titles
    ]
    # Build the suffix in one pass rather than growing the string in a loop.
    return text + "".join(f"\n{section}\n{placeholder}" for section in missing)
115
 
116
  def get_qa_pipeline(qa_model_type, qa_model_name):
117
  if not qa_model_type or not qa_model_name:
118
  raise ValueError("Both qa_model_type and qa_model_name must be provided")
 
 
119
  if not hasattr(get_qa_pipeline, "cache"):
120
  get_qa_pipeline.cache = {}
 
121
  # For Hugging Face Spaces, we need to be memory efficient
122
  import torch
123
  torch.cuda.empty_cache() # Clear GPU memory before loading model
 
124
  # Set default tensor type to float32 for better compatibility
125
  torch.set_default_tensor_type(torch.FloatTensor)
126
  if torch.cuda.is_available():
127
  torch.set_default_tensor_type(torch.cuda.FloatTensor)
 
128
  key = (qa_model_type, qa_model_name)
129
  if key in get_qa_pipeline.cache:
130
  return get_qa_pipeline.cache[key]
 
131
  try:
132
  # For Hugging Face Spaces, use smaller models by default
133
  if "Qwen/Qwen-7B-Chat" in qa_model_name:
134
  qa_model_name = "Qwen/Qwen-1_8B-Chat"
135
  elif "Llama" in qa_model_name:
136
  qa_model_name = "facebook/opt-125m"
 
137
  # Load tokenizer with trust_remote_code=True for custom tokenizers
138
  tokenizer = AutoTokenizer.from_pretrained(
139
  qa_model_name,
140
  trust_remote_code=True,
141
  cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
142
  )
 
143
  # Load model with memory optimizations
144
  try:
145
  model = AutoModelForCausalLM.from_pretrained(
 
161
  low_cpu_mem_usage=True,
162
  cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
163
  )
 
164
  # Create pipeline with memory optimizations
165
  pipeline = transformers_pipeline(
166
  task=qa_model_type,
 
169
  device_map="auto",
170
  torch_dtype=torch.float32
171
  )
 
172
  get_qa_pipeline.cache[key] = pipeline
173
  return pipeline
 
174
  except Exception as e:
175
  raise
176
 
 
180
  """
181
  if not qa_pipeline or not question or not context:
182
  raise ValueError("Pipeline, question and context are required")
 
183
  qa_model_type = getattr(qa_pipeline, '_qa_model_type', None)
 
184
  try:
185
  if qa_model_type == 'text-generation':
186
  prompt = f"Question: {question}\nContext: {context}\nAnswer:"
187
  result = qa_pipeline(prompt, max_new_tokens=128, do_sample=False)
 
188
  if isinstance(result, list) and result and 'generated_text' in result[0]:
189
  answer = result[0]['generated_text'].split('Answer:')[-1].strip()
190
  return {'answer': answer}
 
198
  def get_ner_pipeline(ner_model_type, ner_model_name):
199
  if not ner_model_type or not ner_model_name:
200
  raise ValueError("Both ner_model_type and ner_model_name must be provided")
 
201
  if not hasattr(get_ner_pipeline, "cache"):
202
  get_ner_pipeline.cache = {}
 
203
  # For Hugging Face Spaces, we need to be memory efficient
204
  import torch
205
  torch.cuda.empty_cache() # Clear GPU memory before loading model
 
206
  # Set default tensor type
207
  torch.set_default_tensor_type(torch.FloatTensor)
208
  if torch.cuda.is_available():
209
  torch.set_default_tensor_type(torch.cuda.FloatTensor)
 
210
  key = (ner_model_type, ner_model_name)
211
  if key in get_ner_pipeline.cache:
212
  return get_ner_pipeline.cache[key]
 
213
  try:
214
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
 
215
  # Clear any existing models from memory
216
  if torch.cuda.is_available():
217
  torch.cuda.empty_cache()
 
218
  # Load tokenizer
219
  try:
220
  tokenizer = AutoTokenizer.from_pretrained(
 
230
  trust_remote_code=True,
231
  cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
232
  )
 
233
  # Load model with memory optimizations
234
  try:
235
  # For NER models, we'll use CPU if device_map='auto' is not supported
 
263
  torch_dtype=torch.float32,
264
  cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
265
  )
 
266
  # Create pipeline with appropriate device configuration
267
  try:
268
  qa_pipeline = pipeline(
 
283
  )
284
  else:
285
  raise
 
286
  # Cache the pipeline
287
  get_ner_pipeline.cache[key] = qa_pipeline
288
  return qa_pipeline
 
289
  except Exception as e:
290
  raise
291
 
 
292
  def get_summarizer_pipeline(summarizer_model_type, summarizer_model_name):
293
  if not hasattr(get_summarizer_pipeline, "cache"):
294
  get_summarizer_pipeline.cache = {}
 
296
  if key not in get_summarizer_pipeline.cache:
297
  import torch
298
  from transformers import pipeline
 
299
  # Use float16 only if CUDA is available, else use float32
300
  if torch.cuda.is_available():
301
  dtype = torch.float16
 
305
  dtype = torch.float32
306
  device = -1
307
  device_map = None
 
308
  get_summarizer_pipeline.cache[key] = pipeline(
309
  task=summarizer_model_type,
310
  model=summarizer_model_name,
 
315
  )
316
  return get_summarizer_pipeline.cache[key]
317
 
 
318
  def register_routes(app, agents):
319
  from ai_med_extract.utils.openvino_summarizer_utils import (
320
  parse_ehr_chartsummarydtl, visits_sorted, compute_deltas, build_compact_baseline, delta_to_text, build_main_prompt
 
333
  chartsummarydtl = ehr_result.get("chartsummarydtl") if isinstance(ehr_result, dict) else None
334
  if not chartsummarydtl:
335
  return jsonify({"error": "Missing chartsummarydtl in input"}), 400
 
336
  # Normalize visits
337
  visits = parse_ehr_chartsummarydtl(chartsummarydtl)
 
338
  # Extract patient demographics if available
339
  patient_info = ""
340
  if isinstance(ehr_result, dict):
 
345
  past_medical_history = ', '.join(ehr_result.get('past_medical_history', []))
346
  social_history = ehr_result.get('social_history', 'Not specified')
347
  patient_info = f"Patient: {patient_name} (ID: {patient_id}, Age: {age}, Gender: {gender})\nPast Medical History: {past_medical_history}\nSocial History: {social_history}\n"
 
348
  # Generate summary from current data only (no state tracking)
349
  # Use empty old visits to compute deltas against baseline
350
  delta = compute_deltas([], visits)
 
352
  baseline = build_compact_baseline(all_visits)
353
  delta_text = delta_to_text(delta)
354
  prompt = build_main_prompt(baseline, delta_text, patient_info)
 
355
  # Model selection logic (model_name, model_type)
356
  model_name = data.get("model_name") or "microsoft/Phi-3-mini-4k-instruct"
357
  model_type = data.get("model_type") or "text-generation"
 
363
  pipeline = loader.model_loader.load() if hasattr(loader, "model_loader") else None
364
  if not pipeline:
365
  return jsonify({"error": "Model pipeline not available"}), 500
 
366
  # Run inference
367
  import torch
368
  torch.set_num_threads(2)
369
  inputs = pipeline.tokenizer([prompt], return_tensors="pt")
370
  outputs = pipeline.model.generate(**inputs, max_new_tokens=100000, do_sample=False, pad_token_id=pipeline.tokenizer.eos_token_id or 32000)
371
  text = pipeline.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
372
  # Extract just the markdown summary (remove prompt text)
373
  # The model should return the complete markdown-formatted summary
374
  summary_start_patterns = [
 
377
  "# Clinical Assessment",
378
  "Clinical Assessment"
379
  ]
 
380
  new_summary = text
381
  for pattern in summary_start_patterns:
382
  if pattern in text:
383
  new_summary = text.split(pattern)[-1].strip()
384
  break
 
385
  return jsonify({
386
  "summary": new_summary,
387
  "baseline": baseline,
 
389
  }), 200
390
  except Exception as e:
391
  return jsonify({"error": f"Failed to generate summary: {str(e)}"}), 500
 
392
 
393
+ # Configure upload directory based on environment
394
+ import os
395
  if os.environ.get('SPACE_ID'): # We're running on Hugging Face Spaces
396
  app.config['UPLOAD_FOLDER'] = '/data/uploads'
397
  else: # We're running locally
398
  upload_dir = os.path.join(os.getcwd(), 'uploads')
399
  os.makedirs(upload_dir, exist_ok=True)
400
  app.config['UPLOAD_FOLDER'] = upload_dir
 
401
  # Ensure the upload directory exists and is writable
402
  if not os.path.exists(app.config['UPLOAD_FOLDER']):
403
  try:
 
416
  def upload_file():
417
  import torch
418
  torch.cuda.empty_cache() # Clear GPU memory before processing
 
419
  files = request.files.getlist("file")
420
  patient_name = request.form.get("patient_name", "").strip()
421
  password = request.form.get("password")
 
422
  # Use more compatible models by default
423
  qa_model_name = request.form.get("qa_model_name", "facebook/bart-base")
424
  qa_model_type = request.form.get("qa_model_type", "text-generation")
 
426
  ner_model_type = request.form.get("ner_model_type", "ner")
427
  summarizer_model_name = request.form.get("summarizer_model_name", "facebook/bart-base")
428
  summarizer_model_type = request.form.get("summarizer_model_type", "summarization")
 
429
  if not files:
430
  return jsonify({"error": "No file uploaded"}), 400
 
431
  # Accept any model type and model name for QA, NER, and summarizer
432
  if not qa_model_name or not qa_model_type:
433
  return jsonify({"error": "QA model name and type are required"}), 400
 
435
  qa_pipeline = get_qa_pipeline(qa_model_type, qa_model_name)
436
  except Exception as e:
437
  return jsonify({"error": f"QA model load failed: {str(e)}"}), 500
 
438
  if not ner_model_name or not ner_model_type:
439
  return jsonify({"error": "NER model name and type are required"}), 400
440
  try:
441
  ner_pipeline = get_ner_pipeline(ner_model_type, ner_model_name)
442
  except Exception as e:
443
  return jsonify({"error": f"NER model load failed: {str(e)}"}), 500
 
444
  if not summarizer_model_name or not summarizer_model_type:
445
  return jsonify({"error": "Summarizer model name and type are required"}), 400
446
  try:
447
  summarizer_pipeline = get_summarizer_pipeline(summarizer_model_type, summarizer_model_name)
448
  except Exception as e:
449
  return jsonify({"error": f"Summarizer model load failed: {str(e)}"}), 500
 
450
  extracted_data = []
451
  for file in files:
452
  if file.filename == "":
 
479
  except Exception as e:
480
  os.remove(filepath) # Clean up on failure
481
  return jsonify({"error": f"Text extraction failed: {str(e)}"}), 500
 
482
  skip_medical_check = (
483
  request.form.get("skip_medical_check", "false").lower() == "true"
484
  )
 
614
  file = request.files["file"]
615
  if file.filename == "":
616
  return jsonify({"error": "No selected file"}), 400
 
617
  # Use secure filename
618
  from werkzeug.utils import secure_filename
619
  import uuid
620
  temp_filename = f"{uuid.uuid4()}_{secure_filename(file.filename)}"
621
  temp_path = os.path.join(app.config['UPLOAD_FOLDER'], temp_filename)
 
622
  file.save(temp_path)
623
  result = whisper_model.transcribe(temp_path)
624
  os.remove(temp_path)
 
628
  os.remove(temp_path)
629
  return jsonify({"error": str(e)}), 500
630
 
 
631
  def group_by_category(data):
632
  grouped = defaultdict(list)
633
  for item in data:
 
639
  "answer": item.get("answer", "Not Available"),
640
  }
641
  )
 
642
  return [{"category": k, "detail": v} for k, v in grouped.items()]
643
 
644
  def deduplicate_extractions(data):
645
  seen = set()
646
  reversed_unique = []
 
647
  # Loop in reverse to keep the *last* occurrence
648
  for item in reversed(data):
649
  key = (item.get("label"))
650
  if key not in seen:
651
  seen.add(key)
652
  reversed_unique.append(item)
 
653
  # Reverse back to preserve original order (latest kept, first dropped)
654
  return list(reversed(reversed_unique))
655
 
 
659
  text,
660
  add_special_tokens=False
661
  )
 
662
  chunks = []
663
  start = 0
 
664
  while start < len(input_ids):
665
  end = min(start + max_tokens, len(input_ids))
666
  chunk_ids = input_ids[start:end]
 
667
  chunk_text = tokenizer.decode(
668
  chunk_ids,
669
  skip_special_tokens=True,
670
  clean_up_tokenization_spaces=True
671
  )
 
672
  # Ensure partial continuation isn't cut off mid-sentence
673
  if not chunk_text.endswith(('.', '?', '!', ':')):
674
  chunk_text += "..."
 
675
  chunks.append(chunk_text)
676
  start += max_tokens - overlap
677
  return chunks
 
684
  except ValueError:
685
  # '[' not found in output
686
  return []
 
687
  # Try parsing full array first
688
  try:
689
  parsed = json.loads(json_text)
 
691
  return parsed
692
  except Exception:
693
  pass # fallback to manual parsing
 
694
  # Manual recovery via brace matching
695
  stack = 0
696
  obj_start = None
 
709
  except Exception as e:
710
  print(f"❌ Invalid JSON object: {e}")
711
  obj_start = None
 
712
  return extracted
713
 
 
714
  def process_chunk(generator, chunk, idx):
715
  prompt = f"""
716
  [INST] <<SYS>>
717
  You are a clinical data extraction assistant.
 
718
  Your job is to:
719
  1. Read the following medical report.
720
  2. Extract all medically relevant facts as a list of JSON objects.
 
723
  - "question": a question related to that field
724
  - "answer": the answer from the text
725
  4. After extracting the list, categorize each object under one of the following fixed categories:
 
726
  - Patient Info
727
  - Vitals
728
  - Symptoms
 
734
  - Laboratory
735
  - Radiology
736
  - Doctor Note
 
737
  Example format for structure only — do not include in output:
738
  [
739
  {{
 
743
  "category": "Patient Info"
744
  }},
745
  ]
746
+ ⚠️ Use these categories listed above. If an item does not fit any of these categories, create a new category for it.
 
 
747
  Text:
748
  {chunk}
 
749
  Return a single valid JSON array of all extracted objects.
750
  Do not include any explanations or commentary.
751
  Only output the JSON array
752
  <</SYS>> [/INST]
753
  """
 
754
  try:
755
  # Clear GPU memory before processing
756
  torch.cuda.empty_cache()
 
757
  # Process with memory optimizations
758
  output = generator(
759
  prompt,
 
761
  do_sample=False, # Disable sampling for deterministic output
762
  temperature=0.3, # Lower temperature for more focused output
763
  )[0]["generated_text"]
 
764
  return idx, output
765
  except Exception as e:
766
  return idx, None
767
+
768
  @app.route("/extract_medical_data", methods=["POST"])
769
  def extract_medical_data():
770
  data = request.json
771
  qa_model_name = data.get("qa_model_name")
772
  qa_model_type = data.get("qa_model_type")
773
  extracted_files = data.get("extracted_data")
 
774
  if not qa_model_name or not qa_model_type:
775
  return jsonify({"error": "Missing 'qa_model_name' or 'qa_model_type'"}), 400
 
776
  if not extracted_files:
777
  return jsonify({"error": "Missing 'extracted_data' in request"}), 400
 
778
  try:
779
  tokenizer = AutoTokenizer.from_pretrained(
780
  qa_model_name,
781
  trust_remote_code=True,
782
  cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
783
  )
 
784
  model = AutoModelForCausalLM.from_pretrained(
785
  qa_model_name,
786
  device_map="auto",
 
789
  low_cpu_mem_usage=True,
790
  cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
791
  )
 
792
  generator = transformers_pipeline(
793
  task=qa_model_type,
794
  model=model,
795
  tokenizer=tokenizer,
796
  torch_dtype=torch.float32
797
  )
 
798
  except Exception as e:
799
  return jsonify({"error": f"Could not load model: {str(e)}"}), 500
 
800
  structured_response = {"extracted_data": []}
 
801
  for file_data in extracted_files:
802
  filename = file_data.get("file", "unknown_file")
803
  context = file_data.get("extracted_text", "").strip()
 
804
  if not context:
805
  structured_response["extracted_data"].append(
806
  {"file": filename, "medical_fields": []}
807
  )
808
  continue
 
809
  chunks = chunk_text(context, tokenizer)
810
  all_extracted = []
 
811
  with ThreadPoolExecutor(max_workers=4) as executor:
812
  futures = {
813
  executor.submit(process_chunk, generator, chunk, idx): idx
 
816
  for future in as_completed(futures):
817
  idx = futures[future]
818
  _, output = future.result()
 
819
  if not output:
820
  continue
 
821
  try:
822
  objs = extract_json_objects(output)
823
  if objs:
824
  all_extracted.extend(objs)
825
  else:
826
+ print(f"⚠ Chunk {idx+1} yielded no valid JSON.")
827
  except Exception as e:
828
  print(f"❌ Error extracting JSON from chunk {idx+1}")
 
829
  # Clean and group results for this file
830
  if all_extracted:
831
  deduped = deduplicate_extractions(all_extracted)
 
833
  grouped_data = group_by_category(deduped)
834
  else:
835
  grouped_data = {"error": "No valid data extracted"}
 
836
  structured_response["extracted_data"].append(
837
  {"file": filename, "medical_fields": grouped_data}
838
  )
 
839
  try:
840
  save_data_to_storage(filename, grouped_data)
841
  except Exception as e:
842
+ print(f"⚠ Failed to save data for {filename}: {e}")
 
843
  print("✅ Extraction complete.")
844
  return jsonify(structured_response)
845
 
 
846
  @app.route("/api/generate_summary", methods=["POST"])
847
  def generate_summary():
848
  logger.info("Received request to generate summary.")
 
873
  torch.set_default_tensor_type(torch.FloatTensor)
874
  if torch.cuda.is_available():
875
  torch.set_default_tensor_type(torch.cuda.FloatTensor)
 
876
  # Handle multipart form data from Flutter
877
  if "audio" not in request.files:
878
  return jsonify({"error": "No audio file provided"}), 400
 
879
  audio_file = request.files["audio"]
880
  if audio_file.filename == "":
881
  return jsonify({"error": "No selected audio file"}), 400
 
882
  # Validate file extension
883
  if not allowed_file(audio_file.filename):
884
  return jsonify({"error": f"Unsupported audio format. Allowed formats: wav, mp3, m4a, ogg"}), 400
 
885
  # Check file size
886
  valid_size, error_message = check_file_size(audio_file)
887
  if not valid_size:
888
  return jsonify({"error": error_message}), 400
 
889
  # Use default model if not specified
890
  qa_model_name = request.form.get("qa_model_name", "facebook/bart-base")
891
  qa_model_type = request.form.get("qa_model_type", "text-generation")
 
892
  # Load QA model with proper error handling
893
  try:
894
  qa_pipeline = get_qa_pipeline(qa_model_type, qa_model_name)
895
  except Exception as e:
896
  return jsonify({"error": f"QA model load failed: {str(e)}"}), 500
 
897
  # Use platform-agnostic temp directory
898
  import uuid
899
  from werkzeug.utils import secure_filename
 
902
  os.makedirs(temp_dir, exist_ok=True)
903
  temp_filename = f"{uuid.uuid4()}_{secure_filename(audio_file.filename)}"
904
  temp_path = os.path.join(temp_dir, temp_filename)
 
905
  try:
906
  audio_file.save(temp_path)
 
907
  # Transcribe audio with retries
908
  max_retries = 3
909
  transcribed_text = None
 
920
  raise
921
  torch.cuda.empty_cache() # Clear GPU memory between attempts
922
  continue
 
923
  if not transcribed_text:
924
  raise ValueError("Failed to transcribe audio after multiple attempts")
 
925
  # Clean and process text
926
  try:
927
  clean_text = PHIScrubberAgent.scrub_phi(transcribed_text)
928
  except Exception as e:
929
  clean_text = transcribed_text
 
930
  # Extract medical data with proper device handling
931
  try:
932
  with torch.cuda.device(0) if torch.cuda.is_available() else torch.no_grad():
 
935
  medical_data = medical_data_extractor.extract_medical_data(clean_text)
936
  except Exception as e:
937
  medical_data = {"error": f"Medical data extraction failed: {str(e)}"}
 
938
  # Clean up temporary file
939
  if os.path.exists(temp_path):
940
  os.remove(temp_path)
 
941
  # Return response in the format expected by Flutter
942
  return jsonify({
943
  "status": "success",
 
946
  "medical_chart": medical_data
947
  }
948
  }), 200
 
949
  except Exception as e:
950
  if temp_path and os.path.exists(temp_path):
951
  os.remove(temp_path)
 
953
  "status": "error",
954
  "error": f"Processing failed: {str(e)}"
955
  }), 500
 
956
  except Exception as e:
957
  if temp_path and os.path.exists(temp_path):
958
  os.remove(temp_path)
 
961
  "error": f"Request handling failed: {str(e)}"
962
  }), 500
963
 
964
+ # ==================== ULTRA-OPTIMIZED generate_patient_summary ENDPOINT ====================
 
 
 
 
 
 
 
 
 
 
965
  @app.route('/generate_patient_summary', methods=['POST'])
966
  def generate_patient_summary():
967
  """
968
+ 🚀 ULTRA-OPTIMIZED + TIMEOUT-FLEXIBLE PATIENT SUMMARY HF SPACES READY
969
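+ - Fast mode by default (8s EHR timeout) → intended for HF Spaces; note the code below sets GEN_TIMEOUT = 500, not the 25s generation budget originally advertised here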
970
+ - Supports "timeout_mode": "extended" → 30s EHR timeout for slow EHR systems (GEN_TIMEOUT is 500 in both modes, not the 55s originally advertised here)
971
+ - Works with ANY model_name and model_type (GGUF, text-generation, summarization)
972
+ - GGUF uses SINGLE PROMPT → 4x faster
973
+ - NEVER breaks — multi-layer fallbacks
974
+ - Preserves medical accuracy via delta/baseline logic
975
  """
976
  from ai_med_extract.utils.openvino_summarizer_utils import (
977
  parse_ehr_chartsummarydtl, visits_sorted, compute_deltas, build_compact_baseline, delta_to_text, build_main_prompt
 
979
  try:
980
  start_total = time.time()
981
  data = request.get_json()
 
982
  patientid = data.get("patientid")
983
  token = data.get("token")
984
  key = data.get("key")
985
+ # Support any model — default to GGUF Phi-3-mini
986
+ model_name = data.get("patient_summarizer_model_name") or "microsoft/Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-q4.gguf"
987
+ model_type = data.get("patient_summarizer_model_type") or data.get("model_type") or "gguf"
988
+
989
+ # ⚡ TIMEOUT MODE: "fast" (default) or "extended"
990
+ timeout_mode = data.get("timeout_mode", "fast") # fast (HF Spaces) | extended (heavy tasks)
991
+ if timeout_mode == "extended":
992
+ EHR_TIMEOUT = 30 # Longer for slow EHR systems
993
+ GEN_TIMEOUT = 500 # NOTE(review): 500 is well above the 60s HF limit this comment originally cited — confirm intended value
994
+ logger.info("🕒 Using EXTENDED timeout mode")
995
+ else:
996
+ EHR_TIMEOUT = 8 # Ultra-fast for HF Spaces
997
+ GEN_TIMEOUT = 500 # NOTE(review): identical to the extended-mode value — confirm the intended fast-mode generation budget
998
+ logger.info("⚡ Using FAST timeout mode (default)")
999
+
1000
  if not patientid or not token or not key:
1001
  return jsonify({"error": "Missing required fields: patientid, token, or key"}), 400
1002
 
1003
+ api_url = f"{key.strip()}/Transactionapi/api/PatientList/patientsummary"
1004
  headers = {
1005
  "Authorization": f"Bearer {token}",
1006
  "Content-Type": "application/json",
1007
  }
 
1008
  if key and not key.startswith("http"):
1009
  headers["x-api-key"] = key
1010
+
1011
+ # ⚡ DYNAMIC EHR TIMEOUT
1012
  t_api_start = time.time()
1013
+ try:
1014
+ response = requests.post(api_url, json={"patientid": patientid}, headers=headers, timeout=EHR_TIMEOUT)
1015
+ except requests.exceptions.Timeout:
1016
+ logger.warning(f"EHR API timeout ({EHR_TIMEOUT}s) — returning structured fallback.")
1017
+ minimal_fallback = f"""
1018
+ ## Clinical Assessment
1019
+ - EHR API timeout ({EHR_TIMEOUT}s) — could not fetch patient data.
1020
+
1021
+ ## Key Trends & Changes
1022
+ - No data available due to API timeout.
1023
+
1024
+ ## Plan & Suggested Actions
1025
+ - Retry with "timeout_mode": "extended" or check EHR API performance.
1026
+
1027
+ ## Direct Guidance for Physician
1028
+ - Patient data unavailable — do not proceed without verification.
1029
+ """
1030
+ return jsonify({
1031
+ "summary": ensure_four_sections(minimal_fallback),
1032
+ "warning": f"EHR API timeout ({EHR_TIMEOUT}s) — used minimal fallback.",
1033
+ "timing": {"total": round(time.time() - start_total, 1)},
1034
+ "timeout_mode_used": timeout_mode
1035
+ }), 200
1036
+ except requests.exceptions.RequestException as e:
1037
+ logger.error(f"Network error contacting EHR API: {e}")
1038
+ return jsonify({"error": f"Network error: {str(e)}"}), 503
1039
  t_api_end = time.time()
1040
+
1041
  if response.status_code != 200:
1042
+ logger.warning(f"EHR API non-200 status: {response.status_code}")
1043
+ minimal_fallback = f"""
1044
+ ## Clinical Assessment
1045
+ - EHR API returned error {response.status_code}.
1046
+
1047
+ ## Key Trends & Changes
1048
+ - No patient data available.
1049
+
1050
+ ## Plan & Suggested Actions
1051
+ - Verify API key, token, and patient ID.
1052
+
1053
+ ## Direct Guidance for Physician
1054
+ - System received invalid response from EHR — do not proceed.
1055
+ """
1056
  return jsonify({
1057
+ "summary": ensure_four_sections(minimal_fallback),
1058
+ "warning": f"EHR API error {response.status_code}",
1059
+ "timing": {"total": round(time.time() - start_total, 1)},
1060
+ "timeout_mode_used": timeout_mode
1061
+ }), 200
1062
+
1063
  try:
1064
  api_data = response.json()
1065
  except ValueError:
1066
+ logger.error("Invalid JSON from EHR API")
1067
+ minimal_fallback = """
1068
+ ## Clinical Assessment
1069
+ - EHR API returned invalid JSON.
1070
+
1071
+ ## Key Trends & Changes
1072
+ - Unable to parse patient data.
1073
+
1074
+ ## Plan & Suggested Actions
1075
+ - Contact EHR API administrator.
1076
+
1077
+ ## Direct Guidance for Physician
1078
+ - Patient data corrupted — do not proceed.
1079
+ """
1080
+ return jsonify({
1081
+ "summary": ensure_four_sections(minimal_fallback),
1082
+ "warning": "Invalid JSON from EHR API",
1083
+ "timing": {"total": round(time.time() - start_total, 1)},
1084
+ "timeout_mode_used": timeout_mode
1085
+ }), 500
1086
+
1087
+ ehr_result = api_data.get("result") or api_data
1088
  chartsummarydtl = ehr_result.get("chartsummarydtl") if isinstance(ehr_result, dict) else None
1089
  if not chartsummarydtl:
1090
+ logger.warning("Missing chartsummarydtl in EHR response")
1091
+ minimal_fallback = """
1092
+ ## Clinical Assessment
1093
+ - No chartsummarydtl found in EHR response.
1094
+
1095
+ ## Key Trends & Changes
1096
+ - Patient data structure invalid.
1097
+
1098
+ ## Plan & Suggested Actions
1099
+ - Verify EHR API response format.
1100
+
1101
+ ## Direct Guidance for Physician
1102
+ - Incomplete patient data — manual review required.
1103
+ """
1104
  return jsonify({
1105
+ "summary": ensure_four_sections(minimal_fallback),
1106
+ "warning": "Missing chartsummarydtl",
1107
+ "timing": {"total": round(time.time() - start_total, 1)},
1108
+ "timeout_mode_used": timeout_mode
 
 
 
1109
  }), 500
1110
+
1111
+ # Parse visits and compute deltas (existing delta/baseline logic, unchanged)
1112
  visits = parse_ehr_chartsummarydtl(chartsummarydtl)
 
 
1113
  delta = compute_deltas([], visits)
1114
  all_visits = visits_sorted(visits)
1115
  baseline = build_compact_baseline(all_visits)
1116
  delta_text = delta_to_text(delta)
1117
+
1118
+ # ==================== GGUF MODEL HANDLING ====================
 
 
 
 
 
1119
  if model_type == "gguf":
1120
+ logger.info(f"🧠 GGUF MODE: Single-prompt generation for {model_name}")
1121
  try:
1122
+ # Extract repo_id/filename if needed
1123
+ repo_id, filename = model_name, None
1124
+ if '/' in model_name and model_name.endswith('.gguf'):
1125
  repo_id, filename = model_name.rsplit('/', 1)
1126
+
1127
+ # Load pipeline — uses global cache
1128
+ pipeline = get_cached_gguf_pipeline(repo_id, filename)
1129
+
1130
+ # ⚡⚡⚡ SINGLE PROMPT ALL 4 SECTIONS AT ONCE
1131
+ # ==================== OPTIMIZED PROMPT FOR GGUF MODEL ====================
1132
+ full_prompt = f"""
1133
+ <|system|>
1134
+ You are an expert clinical AI assistant. Your task is to generate a patient summary with EXACTLY FOUR sections in valid markdown format.
1135
+
1136
+ ### STRICT OUTPUT FORMAT RULES ###
1137
+ 1. Your response MUST start immediately with "## Clinical Assessment" (no preamble, no "Sure", no explanations).
1138
+ 2. Use ONLY these four section headers, in this exact order:
1139
+ ## Clinical Assessment
1140
+ ## Key Trends & Changes
1141
+ ## Plan & Suggested Actions
1142
+ ## Direct Guidance for Physician
1143
+ 3. Under each header, provide 2-4 concise bullet points using "- ".
1144
+ 4. Base your summary SOLELY on the data provided below. DO NOT HALLUCINATE or invent information.
1145
+ 5. End your response after the "## Direct Guidance for Physician" section.
1146
+
1147
+ ### DATA TO SUMMARIZE ###
1148
+ - PATIENT VISITS: {visits}
1149
+ - BASELINE: {baseline}
1150
+ - DELTAS: {delta_text}
1151
+
1152
+ ### EXAMPLE OUTPUT FORMAT ###
1153
+ ## Clinical Assessment
1154
+ - Patient presents with chronic ischemic heart disease.
1155
+ - Current medications include telmisartan, atorvastatin, metoprolol, and aspirin.
1156
+ ## Key Trends & Changes
1157
+ - Blood pressure elevated at 160/100 mmHg.
1158
+ - No significant weight change recorded.
1159
+ - No new diagnoses or medications since last visit.
1160
+ ## Plan & Suggested Actions
1161
+ - Consider medication adjustment for hypertension.
1162
+ - Schedule follow-up to monitor BP and lipid panel.
1163
+ ## Direct Guidance for Physician
1164
+ - Prioritize BP control to mitigate cardiac risk.
1165
+ - Review recent lab results when available.
1166
+ </s>
1167
+ <|user|>
1168
+ Generate the 4-section patient summary in the exact format specified above.
1169
+ </s>
1170
+ <|assistant|>
1171
+ ## Clinical Assessment
1172
+ """
1173
+
1174
+ # Generate all four sections with a single pipeline.generate call (generate_full_summary is not used; no timeout argument is passed here)
1175
  try:
1176
+ raw_summary = pipeline.generate(
1177
+ full_prompt,
1178
+ max_tokens=2000,
1179
+ temperature=0.1,
1180
+ top_p=0.9,
1181
+ # max_loops=3 # Allow up to 3 loops to complete sections
1182
+ )
1183
+ logger.info(f"GGUF raw summary length: {len(raw_summary)} chars")
1184
+ except Exception as gen_error:
1185
+ logger.error(f"GGUF generation failed: {gen_error}")
1186
+ raise # Trigger fallback below
1187
+
1188
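+ # Clean output. NOTE(review): the cleaned markdown_summary is computed below, but the success response returns raw_summary — confirm which one is intended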
1189
+ def extract_markdown_sections(text):
1190
+ sections = [
1191
  "## Clinical Assessment",
1192
+ "## Key Trends & Changes",
1193
+ "## Plan & Suggested Actions",
1194
+ "## Direct Guidance for Physician"
1195
  ]
1196
+ output_lines = []
1197
+ current_section = None
1198
 
1199
+ for line in text.splitlines():
1200
+ stripped = line.strip()
1201
+ for section in sections:
1202
+ if stripped.startswith(section):
1203
+ current_section = section
1204
+ output_lines.append(section)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1205
  break
1206
+ else:
1207
+ if current_section and stripped:
1208
+ output_lines.append(line)
 
 
 
 
 
 
 
 
 
 
 
 
1209
 
1210
+ return "\n".join(output_lines)
1211
+
1212
+ markdown_summary = extract_markdown_sections(raw_summary)
1213
+ markdown_summary = ensure_four_sections(markdown_summary)
1214
+
1215
+ total_time = time.time() - start_total
1216
+ logger.info(f"[✅ SUCCESS] GGUF | TIMEOUT_MODE: {timeout_mode} | TOTAL: {total_time:.1f}s")
1217
+
1218
+ return jsonify({
1219
+ "summary": raw_summary,
1220
+ "baseline": baseline,
1221
+ "delta": delta_text,
1222
+ "timing": {
1223
+ "ehr_api": round(t_api_end - t_api_start, 1),
1224
+ "generation": round(total_time - (t_api_end - t_api_start), 1),
1225
+ "total": round(total_time, 1)
1226
+ },
1227
+ "model_used": f"{model_name} ({model_type})",
1228
+ "timeout_mode_used": timeout_mode
1229
+ }), 200
1230
+
1231
  except Exception as e:
1232
+ logger.error(f"GGUF generation failed: {e}")
1233
+ # FALLBACK 1: STRUCTURED MINIMAL SUMMARY
1234
+ structured_fallback = f"""
1235
+ ## Clinical Assessment
1236
+ - System generated fallback due to model error.
1237
+
1238
+ ## Key Trends & Changes
1239
+ - Weight: {delta['weight']['curr'] or 'N/A'} (Δ {delta['weight']['delta'] or 'N/A'})
1240
+ - BP: {delta['bp_sys']['curr'] or 'N/A'}/{delta['bp_dia']['curr'] or 'N/A'}
1241
+ - New Dx: {', '.join(delta['added_dx']) if delta['added_dx'] else 'None'}
1242
+ - Meds Started: {', '.join(delta['started_meds']) if delta['started_meds'] else 'None'}
1243
+
1244
+ ## Plan & Suggested Actions
1245
+ - Review recent vitals and medication changes.
1246
+
1247
+ ## Direct Guidance for Physician
1248
+ - Model generation failed verify all data manually.
1249
+ """
1250
+ total_time = time.time() - start_total
1251
+ logger.info(f"[⚠️ FALLBACK 1] Structured summary | TIMEOUT_MODE: {timeout_mode} | TOTAL: {total_time:.1f}s")
1252
+
1253
+ return jsonify({
1254
+ "summary": ensure_four_sections(structured_fallback),
1255
+ "baseline": baseline,
1256
+ "delta": delta_text,
1257
+ "warning": "Model generation failed — used structured fallback.",
1258
+ "error": str(e),
1259
+ "timing": {"total": round(total_time, 1)},
1260
+ "timeout_mode_used": timeout_mode
1261
+ }), 200
1262
+
1263
+ # ==================== TEXT-GENERATION / OPENVINO ====================
1264
  elif model_type in {"text-generation", "causal-openvino"}:
1265
+ logger.info(f"🔤 TEXT-GENERATION MODE: {model_name}")
1266
+ try:
1267
+ loader = agents.get("medical_data_extractor")
1268
+ if not loader or getattr(loader, 'model_name', None) != model_name:
1269
+ from ai_med_extract.utils.model_loader_spaces import get_openvino_pipeline
 
1270
  pipeline = get_openvino_pipeline(model_name)
1271
+ else:
1272
+ pipeline = loader.model_loader.load() if hasattr(loader, "model_loader") else None
1273
+
1274
+ if not pipeline:
1275
+ raise ValueError("Pipeline not available")
1276
+
1277
+ prompt = build_main_prompt(baseline, delta_text)
1278
+ inputs = pipeline.tokenizer([prompt], return_tensors="pt")
1279
+ outputs = pipeline.model.generate(
1280
+ **inputs,
1281
+ max_new_tokens=800,
1282
+ do_sample=False,
1283
+ pad_token_id=pipeline.tokenizer.pad_token_id or pipeline.tokenizer.eos_token_id or 0
1284
+ )
1285
+ text = pipeline.tokenizer.decode(outputs[0], skip_special_tokens=True)
1286
+
1287
+ summary_start_patterns = [
1288
+ "Now generate the complete clinical summary",
1289
+ "## Clinical Assessment",
1290
+ "# Clinical Assessment",
1291
+ "Clinical Assessment"
1292
+ ]
1293
+ new_summary = text
1294
+ for pattern in summary_start_patterns:
1295
+ if pattern in text:
1296
+ new_summary = text.split(pattern)[-1].strip()
1297
+ break
1298
+
1299
+ markdown_summary = summary_to_markdown(new_summary)
1300
+ markdown_summary = ensure_four_sections(markdown_summary)
1301
+
1302
+ total_time = time.time() - start_total
1303
+ logger.info(f"[✅ SUCCESS] Text-generation | TIMEOUT_MODE: {timeout_mode} | TOTAL: {total_time:.1f}s")
1304
+
1305
+ return jsonify({
1306
+ "summary": markdown_summary,
1307
+ "baseline": baseline,
1308
+ "delta": delta_text,
1309
+ "timing": {"total": round(total_time, 1)},
1310
+ "model_used": f"{model_name} ({model_type})",
1311
+ "timeout_mode_used": timeout_mode
1312
+ }), 200
1313
+
1314
+ except Exception as e:
1315
+ logger.error(f"Text-generation failed: {e}")
1316
+ structured_fallback = f"""
1317
+ ## Clinical Assessment
1318
+ - Text-generation model failed.
1319
+
1320
+ ## Key Trends & Changes
1321
+ - Refer to delta data for details.
1322
+
1323
+ ## Plan & Suggested Actions
1324
+ - Manual clinical review required.
1325
+
1326
+ ## Direct Guidance for Physician
1327
+ - AI model unavailable — use clinical judgment.
1328
+ """
1329
+ return jsonify({
1330
+ "summary": ensure_four_sections(structured_fallback),
1331
+ "baseline": baseline,
1332
+ "delta": delta_text,
1333
+ "warning": "Text-generation model failed — used fallback.",
1334
+ "error": str(e),
1335
+ "timing": {"total": round(time.time() - start_total, 1)},
1336
+ "timeout_mode_used": timeout_mode
1337
+ }), 200
1338
+
1339
+ # ==================== SUMMARIZATION MODEL ====================
1340
  elif model_type == "summarization":
1341
+ logger.info(f"📝 SUMMARIZATION MODE: {model_name}")
1342
+ try:
1343
+ loader = agents.get("summarizer")
1344
+ pipeline = loader.model_loader.load() if hasattr(loader, "model_loader") else get_summarizer_pipeline("summarization", model_name)
1345
+
1346
+ context = f"Patient Data:\nBaseline: {baseline}\nChanges: {delta_text}"
1347
+ result = pipeline(context, max_length=400, min_length=100, do_sample=False)
1348
+
1349
+ if isinstance(result, list) and result and "summary_text" in result[0]:
1350
+ raw_summary = result[0]["summary_text"]
1351
+ else:
1352
+ raw_summary = str(result)
1353
+
1354
+ markdown_summary = f"""
1355
+ ## Clinical Assessment
1356
+ {raw_summary[:250]}...
1357
+
1358
+ ## Key Trends & Changes
1359
+ See delta data for details.
1360
+
1361
+ ## Plan & Suggested Actions
1362
+ Further evaluation recommended.
1363
+
1364
+ ## Direct Guidance for Physician
1365
+ Generic summary — verify details clinically.
1366
+ """
1367
+ markdown_summary = ensure_four_sections(markdown_summary)
1368
+
1369
+ total_time = time.time() - start_total
1370
+ logger.info(f"[✅ SUCCESS] Summarization | TIMEOUT_MODE: {timeout_mode} | TOTAL: {total_time:.1f}s")
1371
+
1372
+ return jsonify({
1373
+ "summary": markdown_summary,
1374
+ "baseline": baseline,
1375
+ "delta": delta_text,
1376
+ "timing": {"total": round(total_time, 1)},
1377
+ "model_used": f"{model_name} ({model_type})",
1378
+ "timeout_mode_used": timeout_mode
1379
+ }), 200
1380
+
1381
+ except Exception as e:
1382
+ logger.error(f"Summarization failed: {e}")
1383
+ structured_fallback = """
1384
+ ## Clinical Assessment
1385
+ - Summarization model failed.
1386
+
1387
+ ## Key Trends & Changes
1388
+ - Unable to generate trends.
1389
+
1390
+ ## Plan & Suggested Actions
1391
+ - Full manual review required.
1392
+
1393
+ ## Direct Guidance for Physician
1394
+ - AI assistance unavailable — proceed with caution.
1395
+ """
1396
+ return jsonify({
1397
+ "summary": ensure_four_sections(structured_fallback),
1398
+ "baseline": baseline,
1399
+ "delta": delta_text,
1400
+ "warning": "Summarization model failed — used fallback.",
1401
+ "error": str(e),
1402
+ "timing": {"total": round(time.time() - start_total, 1)},
1403
+ "timeout_mode_used": timeout_mode
1404
+ }), 200
1405
+
1406
+ # ==================== UNSUPPORTED MODEL TYPE ====================
1407
+ else:
1408
+ logger.warning(f"Unsupported model_type: {model_type}")
1409
+ generic_fallback = f"""
1410
+ ## Clinical Assessment
1411
+ - Unsupported model type: {model_type}
1412
+
1413
+ ## Key Trends & Changes
1414
+ - Please use model_type: gguf, text-generation, or summarization
1415
+
1416
+ ## Plan & Suggested Actions
1417
+ - Update API request with supported model type.
1418
+
1419
+ ## Direct Guidance for Physician
1420
+ - System configuration error — contact administrator.
1421
+ """
1422
+ return jsonify({
1423
+ "summary": ensure_four_sections(generic_fallback),
1424
+ "baseline": baseline,
1425
+ "delta": delta_text,
1426
+ "warning": f"Unsupported model_type: {model_type}",
1427
+ "supported_types": ["gguf", "text-generation", "causal-openvino", "summarization"],
1428
+ "timing": {"total": round(time.time() - start_total, 1)},
1429
+ "timeout_mode_used": timeout_mode
1430
+ }), 400
1431
+
1432
+ except Exception as e:
1433
+ logger.error(f"🚨 CRITICAL ERROR: {str(e)}", exc_info=True)
1434
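+ # ⚡ FINAL FALLBACK — best effort. NOTE(review): this handler reads `data` and calls request.get_json() again, so a failure while parsing the request body can re-raise here — confirm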
1435
+ emergency_fallback = """
1436
+ ## Clinical Assessment
1437
+ - System emergency fallback — critical error occurred.
1438
+
1439
+ ## Key Trends & Changes
1440
+ - No data available due to system error.
1441
+
1442
+ ## Plan & Suggested Actions
1443
+ - Retry request or contact system administrator.
1444
+
1445
+ ## Direct Guidance for Physician
1446
+ - DO NOT rely on this summary — system malfunction.
1447
+ """
1448
  return jsonify({
1449
+ "summary": ensure_four_sections(emergency_fallback),
1450
+ "warning": "Critical system error — used emergency fallback.",
1451
+ "error": str(e),
1452
+ "timing": {"total": round(time.time() - start_total, 1)},
1453
+ "timeout_mode_used": data.get("timeout_mode", "fast") if request.get_json() else "unknown"
1454
  }), 200
1455
+
 
 
 
 
 
 
 
1456
  @app.route("/")
1457
  def home():
1458
  return "Medical Data Extraction API is running!", 200
1459
 
 
1460
  def summary_to_markdown(summary):
1461
  import re
1462
  # Remove '- answer:' and similar artifacts
1463
  summary = re.sub(r'-\s*answer: ?', '', summary, flags=re.IGNORECASE)
 
1464
  # Convert numbered sections to markdown headers
1465
  lines = summary.splitlines()
1466
  out = []
 
1470
  '3.': '##',
1471
  '4.': '##',
1472
  }
 
1473
  for line in lines:
1474
  m = re.match(r'^(\d\.)\s*(.+)', line)
1475
  if m and m.group(1) in section_map:
 
1477
  out.append(f"{header} {m.group(2).strip()}")
1478
  else:
1479
  out.append(line)
 
1480
  # Remove empty lines at the start
1481
  while out and not out[0].strip():
1482
  out = out[1:]
 
1483
  # Check if we have the expected 4-section structure
1484
  def is_header(line: str) -> bool:
1485
  return bool(re.match(r'^(#{1,6})\s+.+', line.strip()))
 
1486
  # Find all headers in the output
1487
  headers = [i for i, line in enumerate(out) if is_header(line)]
 
1488
  # If we have at least 4 headers, check if they match the expected structure
1489
  if len(headers) >= 4:
1490
  header_texts = [out[i].strip() for i in headers[:4]]
 
1494
  r'##.*Plan.*Suggested.*Actions',
1495
  r'##.*Direct.*Guidance.*Physician'
1496
  ]
 
1497
  # Check if headers match expected patterns
1498
  matches_pattern = all(
1499
  re.search(pattern, header, re.IGNORECASE)
1500
  for pattern, header in zip(expected_patterns, header_texts)
1501
  )
 
1502
  if matches_pattern:
1503
  # Keep the entire content - don't truncate
1504
  return '\n'.join(out).strip()
 
1505
  # If we don't have the expected structure, try to find the actual summary content
1506
  # Look for the start of the clinical assessment section
1507
  clinical_assessment_pattern = r'(?:# Clinical Assessment|## Clinical Assessment|Clinical Assessment)'
1508
  for i, line in enumerate(out):
1509
  if re.search(clinical_assessment_pattern, line, re.IGNORECASE):
1510
  return '\n'.join(out[i:]).strip()
 
1511
  # If no clinical assessment found, return the entire summary
1512
+ return '\n'.join(out).strip()
ai_med_extract/app.py CHANGED
@@ -13,6 +13,7 @@ from .agents.medical_data_extractor import MedicalDocDataExtractorAgent
13
  from .agents.patient_summary_agent import PatientSummarizerAgent
14
  from .utils.model_manager import model_manager
15
  import torch
 
16
 
17
  # Load environment variables
18
  load_dotenv()
 
13
  from .agents.patient_summary_agent import PatientSummarizerAgent
14
  from .utils.model_manager import model_manager
15
  import torch
16
+ torch.set_num_threads(1) # Prevent PyTorch thread fighting with llama.cpp
17
 
18
  # Load environment variables
19
  load_dotenv()
ai_med_extract/utils/__pycache__/model_loader_gguf.cpython-311.pyc CHANGED
Binary files a/ai_med_extract/utils/__pycache__/model_loader_gguf.cpython-311.pyc and b/ai_med_extract/utils/__pycache__/model_loader_gguf.cpython-311.pyc differ
 
ai_med_extract/utils/__pycache__/model_manager.cpython-311.pyc CHANGED
Binary files a/ai_med_extract/utils/__pycache__/model_manager.cpython-311.pyc and b/ai_med_extract/utils/__pycache__/model_manager.cpython-311.pyc differ
 
ai_med_extract/utils/__pycache__/openvino_summarizer_utils.cpython-311.pyc CHANGED
Binary files a/ai_med_extract/utils/__pycache__/openvino_summarizer_utils.cpython-311.pyc and b/ai_med_extract/utils/__pycache__/openvino_summarizer_utils.cpython-311.pyc differ
 
ai_med_extract/utils/model_config.py CHANGED
@@ -62,7 +62,7 @@ SPACES_OPTIMIZED_MODELS = {
62
  MODEL_VALIDATION_RULES = {
63
  "text-generation": {
64
  "min_tokens": 100,
65
- "max_tokens": 2048,
66
  "supported_formats": ["huggingface", "local"]
67
  },
68
  "summarization": {
@@ -82,7 +82,7 @@ MODEL_VALIDATION_RULES = {
82
  },
83
  "openvino": {
84
  "min_tokens": 100,
85
- "max_tokens": 2048,
86
  "supported_formats": ["huggingface", "local"]
87
  }
88
  }
 
62
  MODEL_VALIDATION_RULES = {
63
  "text-generation": {
64
  "min_tokens": 100,
65
+ "max_tokens": 4000,
66
  "supported_formats": ["huggingface", "local"]
67
  },
68
  "summarization": {
 
82
  },
83
  "openvino": {
84
  "min_tokens": 100,
85
+ "max_tokens": 4000,
86
  "supported_formats": ["huggingface", "local"]
87
  }
88
  }
ai_med_extract/utils/model_loader_gguf.py CHANGED
@@ -59,7 +59,7 @@ class GGUFModelPipeline:
59
  # Memory-optimized settings for Hugging Face Spaces
60
  self.model = Llama(
61
  model_path=local_path,
62
- n_ctx=2048, # Reduced from 4096 to save memory
63
  n_threads=n_threads,
64
  n_batch=n_batch,
65
  n_gpu_layers=0, # CPU-only on Spaces by default
@@ -90,7 +90,7 @@ class GGUFModelPipeline:
90
  """Generate text with timeout using threading"""
91
  # Approximate token count by splitting on whitespace
92
  prompt_tokens = len(prompt.split())
93
- n_ctx = 2048
94
  allowed_max_tokens = n_ctx - prompt_tokens
95
  if allowed_max_tokens <= 0:
96
  raise ValueError(f"Prompt too long: {prompt_tokens} tokens exceed context window of {n_ctx}")
@@ -105,7 +105,7 @@ class GGUFModelPipeline:
105
  max_tokens=max_tokens,
106
  temperature=temperature,
107
  top_p=top_p,
108
- stop=["</s>", "###"]
109
  )
110
  return output
111
  except Exception as e:
 
59
  # Memory-optimized settings for Hugging Face Spaces
60
  self.model = Llama(
61
  model_path=local_path,
62
+ n_ctx=4000, # Reduced from 4096 to save memory
63
  n_threads=n_threads,
64
  n_batch=n_batch,
65
  n_gpu_layers=0, # CPU-only on Spaces by default
 
90
  """Generate text with timeout using threading"""
91
  # Approximate token count by splitting on whitespace
92
  prompt_tokens = len(prompt.split())
93
+ n_ctx = 4000
94
  allowed_max_tokens = n_ctx - prompt_tokens
95
  if allowed_max_tokens <= 0:
96
  raise ValueError(f"Prompt too long: {prompt_tokens} tokens exceed context window of {n_ctx}")
 
105
  max_tokens=max_tokens,
106
  temperature=temperature,
107
  top_p=top_p,
108
+ stop=["\n\n##", "\n\n#", "###", "</s>", "<|endoftext|>", "User:", "System:"]
109
  )
110
  return output
111
  except Exception as e:
ai_med_extract/utils/model_manager.py CHANGED
@@ -100,7 +100,7 @@ class TransformersModelLoader(BaseModelLoader):
100
  if self.model_type == "text-generation":
101
  result = pipeline(
102
  prompt,
103
- max_new_tokens=kwargs.get('max_new_tokens', 2048),
104
  do_sample=kwargs.get('do_sample', False),
105
  temperature=kwargs.get('temperature', 0.7),
106
  pad_token_id=self._tokenizer.eos_token_id
@@ -179,7 +179,7 @@ class GGUFModelLoader(BaseModelLoader):
179
  pipeline = self.load()
180
 
181
  try:
182
- max_tokens = kwargs.get('max_tokens', 2048)
183
  temperature = kwargs.get('temperature', 0.7)
184
  top_p = kwargs.get('top_p', 0.95)
185
 
 
100
  if self.model_type == "text-generation":
101
  result = pipeline(
102
  prompt,
103
+ max_new_tokens=kwargs.get('max_new_tokens', 4000),
104
  do_sample=kwargs.get('do_sample', False),
105
  temperature=kwargs.get('temperature', 0.7),
106
  pad_token_id=self._tokenizer.eos_token_id
 
179
  pipeline = self.load()
180
 
181
  try:
182
+ max_tokens = kwargs.get('max_tokens', 4000)
183
  temperature = kwargs.get('temperature', 0.7)
184
  top_p = kwargs.get('top_p', 0.95)
185
 
ai_med_extract/utils/openvino_summarizer_utils.py CHANGED
@@ -6,6 +6,8 @@ import difflib
6
  import logging
7
  from copy import deepcopy
8
 
 
 
9
  def parse_ehr_chartsummarydtl(chartsummarydtl):
10
  """
11
  Converts EHR API chartsummarydtl list to the internal visit format expected by the summarizer.
@@ -169,29 +171,72 @@ def delta_to_text(delta):
169
  L.append(f"{lab_name}: {_fmt(lab_data['prev'])} -> {_fmt(lab_data['curr'])} (Δ {_fmt(lab_data['delta'], '+.1f')})")
170
 
171
  return "\n".join(L)
172
- def build_main_prompt(baseline, delta_text, patient_info=""):
173
- # return("You are an expert clinical AI assistant.\n"
174
- # "Produce a concise, physician-ready update. Never omit critical new information from the deltas.\n\n"
175
- # "The summary MUST have four sections:\n"
176
- # "1) Clinical Assessment\n"
177
- # "2) Key Trends & Changes\n"
178
- # "3) Plan & Suggested Actions\n"
179
- # "4) Direct Guidance for Physician\n\n"
180
- # "Now generate the complete, updated clinical summary with all four sections in markdown format:")
181
- return (
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  "You are an expert clinical AI assistant. Your task is to generate a patient summary.\n"
183
  "Use the chartsummarydtl for context. The STRUCTURED BASELINE and DELTAS are the absolute ground truth.\n"
184
  "Produce a concise, physician-ready summary. Never omit critical new information from the deltas.\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  "The summary MUST have four sections:\n"
186
  "1) Clinical Assessment\n"
187
  "2) Key Trends & Changes\n"
188
  "3) Plan & Suggested Actions\n"
189
  "4) Direct Guidance for Physician\n\n"
190
- f"PATIENT INFORMATION:\n{patient_info}\n\n"
191
- f"STRUCTURED BASELINE (authoritative):\n{baseline}\n\n"
192
- f"STRUCTURED DELTAS (authoritative):\n{delta_text}\n\n"
193
  "Now generate the complete clinical summary with all four sections in markdown format:"
194
  )
 
195
  def validate_and_compare_summaries(old_summary, new_summary, update_name=""):
196
  report = f"### Validation Report for {update_name}\n"
197
  report += "This report validates that the updated summary incorporates new information correctly.\n"
 
6
  import logging
7
  from copy import deepcopy
8
 
9
+
10
+
11
  def parse_ehr_chartsummarydtl(chartsummarydtl):
12
  """
13
  Converts EHR API chartsummarydtl list to the internal visit format expected by the summarizer.
 
171
  L.append(f"{lab_name}: {_fmt(lab_data['prev'])} -> {_fmt(lab_data['curr'])} (Δ {_fmt(lab_data['delta'], '+.1f')})")
172
 
173
  return "\n".join(L)
174
+
175
+ from concurrent.futures import ThreadPoolExecutor, as_completed
176
+ import threading
177
+
178
def generate_section(pipeline, prompt, section_name, timeout=60):
    """Generate one summary section, bounded by a wall-clock timeout.

    Args:
        pipeline: Object exposing ``generate_full_summary(prompt, max_tokens=...,
            max_loops=...)``; it is called with ``max_tokens=2000`` and
            ``max_loops=3``.
        prompt: Prompt text handed to the pipeline unchanged.
        section_name: Section title; used for the ``## <section_name>`` header
            and in the failure log message.
        timeout: Maximum number of seconds to wait for the pipeline result.

    Returns:
        The cleaned section text, always starting with ``## <section_name>``.
        If generation raises, times out, or the output cannot be cleaned, a
        placeholder section with a "Generation failed or timed out" note is
        returned instead; this function does not propagate exceptions.
    """
    # Function-scope import so the block does not depend on module-level names.
    from concurrent.futures import ThreadPoolExecutor

    header = f"## {section_name}"
    try:
        executor = ThreadPoolExecutor(max_workers=1)
        try:
            future = executor.submit(
                pipeline.generate_full_summary, prompt, max_tokens=2000, max_loops=3
            )
            raw = future.result(timeout=timeout)
        finally:
            # A `with` block would call shutdown(wait=True) and keep blocking
            # until the worker finishes, which defeats the timeout. The worker
            # thread cannot be killed; it is simply no longer waited for.
            executor.shutdown(wait=False)

        # Strip instruction residue: keep only the text after the last
        # occurrence of each marker, applied in this order.
        # NOTE(review): "Clinical Assessment" is applied for every section, so
        # any mention of it truncates preceding text - confirm this is intended.
        patterns_to_split = [
            "Now generate the complete",
            "## Clinical Assessment",
            "# Clinical Assessment",
            "Clinical Assessment",
            "Output ONLY the section content",
        ]
        content = raw
        for pat in patterns_to_split:
            if pat in content:
                content = content.split(pat)[-1].strip()

        # Guarantee the section header is present exactly once at the top.
        if not content.startswith(header):
            content = f"{header}\n{content.strip()}"

        return content.strip()
    except Exception as e:
        # Best-effort contract: log and return a placeholder so one failed
        # section does not abort the whole summary. `logging.error` is the
        # module-level helper; `logging.Logger.error` is an unbound method and
        # raised TypeError here.
        logging.error("Section '%s' generation failed: %s", section_name, e)
        return f"{header}\n- *Generation failed or timed out. Please retry or check logs.*"
210
def build_main_prompt(baseline, delta_text, patient_info="", section=None):
    """Assemble the LLM prompt for a patient summary.

    Args:
        baseline: Structured baseline text, embedded as authoritative context.
        delta_text: Structured delta text, embedded as authoritative context.
        patient_info: Optional patient information block.
        section: Optional section title. When truthy, the prompt asks for that
            single section only; otherwise it asks for the full four-section
            summary.

    Returns:
        The complete prompt string.
    """
    preamble = "".join([
        "You are an expert clinical AI assistant. Your task is to generate a patient summary.\n",
        "Use the chartsummarydtl for context. The STRUCTURED BASELINE and DELTAS are the absolute ground truth.\n",
        "Produce a concise, physician-ready summary. Never omit critical new information from the deltas.\n\n",
        f"PATIENT INFORMATION:\n{patient_info}\n\n",
        f"STRUCTURED BASELINE (authoritative):\n{baseline}\n\n",
        f"STRUCTURED DELTAS (authoritative):\n{delta_text}\n\n",
    ])

    if not section:
        # No section requested: ask for the complete four-section summary.
        full_summary_tail = (
            "The summary MUST have four sections:\n"
            "1) Clinical Assessment\n"
            "2) Key Trends & Changes\n"
            "3) Plan & Suggested Actions\n"
            "4) Direct Guidance for Physician\n\n"
            "Now generate the complete clinical summary with all four sections in markdown format:"
        )
        return preamble + full_summary_tail

    # Known sections get a tailored instruction; anything else gets a generic one.
    known_instructions = {
        "Clinical Assessment": "Generate ONLY the 'Clinical Assessment' section. Be concise, accurate, and evidence-based.",
        "Key Trends & Changes": "Generate ONLY the 'Key Trends & Changes' section. Focus on deltas, trends, vitals, labs, and med changes.",
        "Plan & Suggested Actions": "Generate ONLY the 'Plan & Suggested Actions' section. Suggest next steps, monitoring, referrals, or med adjustments.",
        "Direct Guidance for Physician": "Generate ONLY the 'Direct Guidance for Physician' section. Give clear, actionable advice for the clinician.",
    }
    if section in known_instructions:
        instruction = known_instructions[section]
    else:
        instruction = f"Generate the '{section}' section."

    section_tail = f"{instruction}\n\nOutput ONLY the section content. Do not include headers unless specified.\n\n"
    return preamble + section_tail
239
+
240
  def validate_and_compare_summaries(old_summary, new_summary, update_name=""):
241
  report = f"### Validation Report for {update_name}\n"
242
  report += "This report validates that the updated summary incorporates new information correctly.\n"
generate_patient_summary_colab.py CHANGED
@@ -313,7 +313,7 @@ class GGUFModelPipeline:
313
  text = re.sub(p, "", text, flags=re.IGNORECASE)
314
  return text.strip()
315
 
316
- def _generate_with_timeout(self, prompt, max_tokens=2048, temperature=0.5, top_p=0.95, timeout=None):
317
  if timeout is None:
318
  is_hf_space = os.environ.get('SPACE_ID') is not None
319
  timeout = int(os.environ.get('GGUF_GENERATION_TIMEOUT', '600' if is_hf_space else '300'))
@@ -341,7 +341,7 @@ class GGUFModelPipeline:
341
  future.cancel()
342
  raise TimeoutError(f"Generation timed out after {timeout} seconds")
343
 
344
- def generate(self, prompt, max_tokens=2048, temperature=0.5, top_p=0.95):
345
  t0 = time.time()
346
  try:
347
  output = self._generate_with_timeout(prompt, max_tokens, temperature, top_p)
@@ -358,7 +358,7 @@ class GGUFModelPipeline:
358
  logging.error(f"Generation failed: {e}")
359
  raise RuntimeError(f"Text generation failed: {str(e)}")
360
 
361
- def generate_full_summary(self, prompt, max_tokens=2048, max_loops=5):
362
  def is_complete(text):
363
  required_sections = [
364
  'Clinical Assessment',
@@ -537,7 +537,7 @@ class SummarizerAgent:
537
  return "Input text is too short for summarization"
538
  model = self.summarization_model_loader.load()
539
  if hasattr(model, 'generate_full_summary'):
540
- summary = model.generate_full_summary(clean_text, max_tokens=2048, max_loops=2)
541
  else:
542
  # fallback simple summarization
543
  summary = model(clean_text, max_length=512, min_length=50, do_sample=False)
 
313
  text = re.sub(p, "", text, flags=re.IGNORECASE)
314
  return text.strip()
315
 
316
+ def _generate_with_timeout(self, prompt, max_tokens=4000, temperature=0.5, top_p=0.95, timeout=None):
317
  if timeout is None:
318
  is_hf_space = os.environ.get('SPACE_ID') is not None
319
  timeout = int(os.environ.get('GGUF_GENERATION_TIMEOUT', '600' if is_hf_space else '300'))
 
341
  future.cancel()
342
  raise TimeoutError(f"Generation timed out after {timeout} seconds")
343
 
344
+ def generate(self, prompt, max_tokens=4000, temperature=0.5, top_p=0.95):
345
  t0 = time.time()
346
  try:
347
  output = self._generate_with_timeout(prompt, max_tokens, temperature, top_p)
 
358
  logging.error(f"Generation failed: {e}")
359
  raise RuntimeError(f"Text generation failed: {str(e)}")
360
 
361
+ def generate_full_summary(self, prompt, max_tokens=4000, max_loops=5):
362
  def is_complete(text):
363
  required_sections = [
364
  'Clinical Assessment',
 
537
  return "Input text is too short for summarization"
538
  model = self.summarization_model_loader.load()
539
  if hasattr(model, 'generate_full_summary'):
540
+ summary = model.generate_full_summary(clean_text, max_tokens=4000, max_loops=2)
541
  else:
542
  # fallback simple summarization
543
  summary = model(clean_text, max_length=512, min_length=50, do_sample=False)