Joyna-Joy committed on
Commit
a9612bb
·
1 Parent(s): 206c5f7
ai_med_extract/api/routes.py.REMOVED.git-id CHANGED
@@ -1 +1 @@
1
- ff540d5471cce91e425947ea7e6397c986f9a7fb
 
1
+ 053c0d73058268dec33b161e1067d37c3fbe1855
ai_med_extract/utils/validation.py CHANGED
@@ -1,8 +1,43 @@
 
 
 
1
  import re
 
2
  from flask import jsonify
3
  import logging
4
  import os
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  def clean_result(value):
7
  value = re.sub(r"\s+", " ", value)
8
  value = re.sub(r"[-_:]+", " ", value)
@@ -138,3 +173,207 @@ def validate_patient_name(extracted_text, patient_name, filename, qa_pipeline):
138
 
139
 
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import functools
3
+ import json
4
  import re
5
+ import time
6
  from flask import jsonify
7
  import logging
8
  import os
9
 
10
+
11
# -------------------- Logging Config -------------------- #
# Root logging is configured when this module is imported: records at INFO
# and above are sent both to the file "app.log" and to a StreamHandler
# created with its default stream.
logging.basicConfig(
    handlers=[logging.FileHandler("app.log"), logging.StreamHandler()],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by every helper in this file.
logger = logging.getLogger(__name__)
21
+
22
# -------------------- Execution Time Decorator -------------------- #
def log_execution_time(level=logging.INFO):
    """Build a decorator that logs how long each call of a function takes.

    Args:
        level: Logging level used for the "executed in" message that is
            emitted after a successful call. Defaults to ``logging.INFO``.

    Returns:
        A decorator. The decorated function returns whatever the original
        returns. If the original raises, the exception is logged together
        with its traceback and the elapsed time, then re-raised unchanged.
    """
    def decorator(func):
        # Not every callable has __name__ (e.g. functools.partial objects).
        func_name = getattr(func, "__name__", repr(func))

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter() is monotonic, so the measured duration cannot be
            # distorted by wall-clock adjustments the way time.time() can.
            start_time = time.perf_counter()
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                duration = time.perf_counter() - start_time
                logger.exception(
                    "❌ Exception in %s after %.6f seconds: %s",
                    func_name,
                    duration,
                    e,
                )
                raise
            # Logged outside the try block so that a failure while logging is
            # never reported as a failure of the wrapped function.
            duration = time.perf_counter() - start_time
            logger.log(level, "⏱️ %s executed in %.6f seconds", func_name, duration)
            return result

        return wrapper

    return decorator
39
+
40
+
41
  def clean_result(value):
42
  value = re.sub(r"\s+", " ", value)
43
  value = re.sub(r"[-_:]+", " ", value)
 
173
 
174
 
175
 
176
# ------------------ CLEAN FUNCTION ------------------ #
def _clean_value(value):
    """Recursively normalise one extracted value (helper for clean_result)."""
    if isinstance(value, str):
        value = re.sub(r"[-_:]+", " ", value)
        value = re.sub(r"[^\x00-\x7F]+", " ", value)
        # Remove commas in numbers like 250,000
        value = re.sub(r"(?<=\d),(?=\d)", "", value)
        # Collapse whitespace LAST: the substitutions above insert spaces, so
        # collapsing first would leave runs such as "BP  120" in the result.
        value = re.sub(r"\s+", " ", value).strip()
        return value or "Not Available"
    if isinstance(value, list):
        cleaned = [_clean_value(v) for v in value if v is not None]
        return cleaned or ["Not Available"]
    if isinstance(value, dict):
        return {k: _clean_value(v) for k, v in value.items()}
    return value


@log_execution_time()
def clean_result(value):
    """Normalise an extracted value, recursing into lists and dicts.

    Strings: runs of ``-``, ``_`` and ``:`` and runs of non-ASCII characters
    are replaced by a space, commas between digits are removed, whitespace is
    collapsed and trimmed; an empty result becomes ``"Not Available"``.
    Lists: ``None`` items are dropped and the rest cleaned; an empty result
    becomes ``["Not Available"]``.
    Dicts: every value is cleaned, keys are left untouched.
    Any other type is returned unchanged.

    The recursion lives in ``_clean_value`` so that the timing decorator and
    the debug message fire once per top-level call rather than once per
    nested element.
    """
    logger.debug("Cleaning value: %s", value)
    return _clean_value(value)
194
+
195
# ------------------ Group by Category ------------------ #
@log_execution_time()
def group_by_category(data):
    """Group extracted items by their ``"category"`` value.

    Args:
        data: Iterable of dicts that may carry the keys ``"category"``,
            ``"question"``, ``"label"`` and ``"answer"``. ``None`` is treated
            as an empty input.

    Returns:
        A list of ``{"category": <name>, "detail": [<entry>, ...]}`` dicts in
        first-seen category order. Each entry holds ``"question"``,
        ``"label"`` and ``"answer"``, defaulting to ``"Not Created"``,
        ``"Unknown"`` and ``"Not Available"`` when the key is absent.

    Items that are not dicts are skipped with a warning instead of raising.
    A category that is missing, blank or not a string falls back to
    ``"General"`` (this also avoids unhashable values being used as keys).
    """
    logger.info("Grouping extracted items by category")
    grouped = defaultdict(list)
    category_times = defaultdict(float)

    for item in data or []:
        if not isinstance(item, dict):
            logger.warning(
                "Skipping item of type %s while grouping", type(item).__name__
            )
            continue

        start_time = time.perf_counter()
        cat = item.get("category")
        if not isinstance(cat, str) or not cat.strip():
            cat = "General"
        grouped[cat].append(
            {
                "question": item.get("question", "Not Created"),
                "label": item.get("label", "Unknown"),
                "answer": item.get("answer", "Not Available"),
            }
        )
        category_times[cat] += time.perf_counter() - start_time

    for cat, details in grouped.items():
        logger.info(
            "📂 Category '%s': %d items, time taken: %.4fs",
            cat,
            len(details),
            category_times[cat],
        )

    return [{"category": k, "detail": v} for k, v in grouped.items()]
219
+
220
# ------------------ Detect duplicates, keep the latest ------------------ #
@log_execution_time()
def deduplicate_extractions(data):
    """Drop items whose ``"label"`` repeats, keeping the last occurrence.

    Args:
        data: Iterable of extracted item dicts. ``None`` is treated as empty.

    Returns:
        A new list in the original relative order in which every label
        appears at most once (the latest item wins).

    Items without a label (key absent or ``None``) cannot be judged
    duplicates of one another and are all kept; previously they shared the
    key ``None`` and all but one were silently discarded. Entries that are
    not dicts are skipped with a warning instead of raising.
    """
    logger.info("Deduplicating extracted data (keep last duplicates)")

    seen = set()
    reversed_unique = []

    # Walk backwards so the first time a label is met is its LAST occurrence.
    for item in reversed(list(data or [])):
        if not isinstance(item, dict):
            logger.warning(
                "Skipping item of type %s while deduplicating",
                type(item).__name__,
            )
            continue

        label = item.get("label")
        if label is None:
            reversed_unique.append(item)
            continue

        try:
            hash(label)
            key = label
        except TypeError:
            # Unhashable label (e.g. a list): compare by its repr instead.
            key = repr(label)

        if key not in seen:
            seen.add(key)
            reversed_unique.append(item)

    # Reverse back to restore the original ordering of the survivors.
    reversed_unique.reverse()
    return reversed_unique
237
+
238
# ------------------ Split text into overlapping chunks ------------------ #
@log_execution_time()
def chunk_text(text, tokenizer, max_tokens=512, overlap=50):
    """
    Splits text into overlapping token-based chunks without using NLTK.

    Args:
        text (str): Raw input text.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            and ``decode(ids, skip_special_tokens=True)`` (e.g. a Hugging
            Face tokenizer).
        max_tokens (int): Max tokens per chunk. Must be positive.
        overlap (int): Number of tokens shared by consecutive chunks.
            Must satisfy ``0 <= overlap < max_tokens``.

    Returns:
        List[str]: Decoded text chunks. A chunk that does not end with
        ``.``, ``?``, ``!`` or ``:`` gets ``"..."`` appended to mark that it
        was cut. Empty input yields an empty list.

    Raises:
        ValueError: If ``max_tokens`` or ``overlap`` is out of range. With
            ``overlap >= max_tokens`` the window would never advance and the
            loop would not terminate.
    """
    if max_tokens <= 0:
        raise ValueError("max_tokens must be a positive integer")
    if not 0 <= overlap < max_tokens:
        raise ValueError("overlap must satisfy 0 <= overlap < max_tokens")

    logger.info("Splitting text into chunks")
    # Tokenize the full text once, then slide a window over the ids.
    input_ids = tokenizer.encode(text, add_special_tokens=False)
    total = len(input_ids)
    step = max_tokens - overlap

    chunks = []
    start = 0
    while start < total:
        end = start + max_tokens
        piece = tokenizer.decode(input_ids[start:end], skip_special_tokens=True)
        # Mark chunks that stop mid-sentence.
        if not piece.endswith(('.', '?', '!', ':')):
            piece += "..."
        chunks.append(piece)

        # The window already reached the end of the text: stepping again
        # would only re-emit the overlap tail as a redundant extra chunk.
        if end >= total:
            break
        start += step

    logger.info("Created %d chunks", len(chunks))
    return chunks
270
+
271
# ------------------ PARSE JSON OBJECTS FROM OUTPUT ------------------ #
@log_execution_time()
def extract_json_objects(text):
    """Recover JSON items from free-form model output.

    Strategy:
      1. Locate the first ``[`` and try to decode a JSON array starting
         there. Text following the array is ignored.
      2. If that fails, scan the remainder for ``{`` and decode each JSON
         object individually, skipping anything that does not parse.

    Args:
        text: Raw output text. ``None`` or an empty string yields ``[]``.

    Returns:
        The decoded list when step 1 succeeds, otherwise the list of objects
        recovered in step 2 (possibly empty).

    Decoding is done with ``json.JSONDecoder.raw_decode`` rather than by
    counting braces, so ``{`` / ``}`` inside string values and stray closing
    braces no longer break the recovery. When an enclosing object is
    malformed, objects nested inside it may be recovered on their own.
    """
    logger.info("Extracting JSON objects from text")
    if not text:
        logger.warning("⚠ '[' not found in output")
        return []

    json_start = text.find('[')
    if json_start == -1:
        logger.warning("⚠ '[' not found in output")
        return []
    json_text = text[json_start:]

    decoder = json.JSONDecoder()

    # Try parsing the full array first; raw_decode tolerates trailing text.
    try:
        parsed, _ = decoder.raw_decode(json_text)
    except (ValueError, RecursionError):
        logger.debug("Full-array parse failed; recovering objects one by one")
    else:
        if isinstance(parsed, list):
            return parsed

    # Manual recovery: decode every object that starts at a '{'.
    extracted = []
    pos = json_text.find('{')
    while pos != -1:
        try:
            obj, next_pos = decoder.raw_decode(json_text, pos)
        except (ValueError, RecursionError) as e:
            logger.error("❌ Invalid JSON object: %s", e)
            # Resume just after this brace so the scan always advances.
            next_pos = pos + 1
        else:
            extracted.append(obj)
        pos = json_text.find('{', next_pos)

    return extracted
311
+
312
+
313
# ------------------ PROCESS A SINGLE CHUNK ------------------ #
@log_execution_time()
def process_chunk(generator, chunk, idx):
    """Run the extraction prompt for one text chunk.

    Args:
        generator: Callable invoked as
            ``generator(prompt, max_new_tokens=..., do_sample=..., temperature=...)``
            whose result is indexed as ``[0]["generated_text"]``
            (presumably a Hugging Face text-generation pipeline - confirm
            against the caller).
        chunk: Text of the chunk, inserted verbatim into the prompt.
        idx: Zero-based chunk index; returned unchanged so that results can
            be matched back to their chunk.

    Returns:
        ``(idx, output)`` on success, ``(idx, None)`` if generation raised.
        If the generator echoes the prompt at the start of its output, the
        echo is removed so that the example object embedded in the prompt
        cannot be mistaken for extracted data.
    """
    logger.info("Processing chunk %d", idx + 1)
    # NOTE: the prompt lines are deliberately flush-left so that no source
    # indentation leaks into the text sent to the model.
    prompt = f"""
[INST] <<SYS>>
You are a clinical data extraction assistant.

Your job is to:
1. Read the following medical report.
2. Extract all medically relevant facts as a list of JSON objects.
3. Each object must include:
- "label": a short field name (e.g., "blood pressure", "diagnosis")
- "question": a question related to that field
- "answer": the answer from the text
4. After extracting the list, categorize each object under one of the following fixed categories:

- Patient Info
- Vitals
- Symptoms
- Allergies
- Habits
- Comorbidities
- Diagnosis
- Medication
- Laboratory
- Radiology
- Doctor Note

Example format for structure only — do not include in output:
[
{{
"label": "patient name",
"question": "What is the patient's name?",
"answer": "Marry John",
"category": "Patient Info"
}}
]

⚠ Use these categories listed above.If an item does not fit any of these categories, create a new category for it.

Text:
{chunk}

Return a single valid JSON array of all extracted objects.
Do not include any explanations or commentary.
Only output the JSON array
<</SYS>> [/INST]
"""

    try:
        output = generator(
            prompt,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.3
        )[0]["generated_text"]
    except Exception as e:
        # Best-effort: a failing chunk must not abort the whole document.
        logger.exception("Error processing chunk %d: %s", idx + 1, e)
        return idx, None

    # Drop an echoed prompt, if present, so only newly generated text remains.
    if isinstance(output, str) and output.startswith(prompt):
        output = output[len(prompt):]

    logger.info("📤 Output from chunk %d: %s...", idx + 1, output)
    return idx, output
376
+
377
+
378
+
379
+