import functools
import json
import logging
import os
import re
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed

import cv2
import numpy as np
import pdfplumber
import pytesseract
import torch
from docx import Document
from flask import Flask, jsonify, request
from flask_cors import CORS
from huggingface_hub import login
from pdf2image import convert_from_path
from PIL import Image
from PyPDF2 import PdfReader
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from werkzeug.utils import secure_filename

# -------------------- Logging Config -------------------- #
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)


# -------------------- Execution Time Decorator -------------------- #
def log_execution_time(level=logging.INFO):
    """Decorator factory that logs how long the wrapped function takes.

    Args:
        level: logging level used for the success message.

    If the wrapped function raises, the exception is logged with a traceback
    (``logger.exception``) together with the elapsed time, then re-raised.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic: durations stay correct even if the
            # wall clock is adjusted while the function runs.
            start_time = time.perf_counter()
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                duration = time.perf_counter() - start_time
                logger.exception(
                    f"❌ Exception in {func.__name__} after {duration:.6f} seconds: {e}"
                )
                raise
            duration = time.perf_counter() - start_time
            logger.log(level, f"⏱️ {func.__name__} executed in {duration:.6f} seconds")
            return result

        return wrapper

    return decorator


# -------------------- Hugging Face authentication -------------------- #
# SECURITY: an access token used to be hard-coded in this call. A token that
# was ever committed to source control must be treated as leaked and revoked
# on huggingface.co. The token is now read from the environment instead.
_HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
if _HF_TOKEN:
    # Stores the token so later model downloads in this process are authenticated.
    login(token=_HF_TOKEN)
else:
    logger.warning(
        "HF_TOKEN is not set; skipping Hugging Face login "
        "(gated or private models may fail to download)."
    )

# NOTE(review): this shared pool appears unused by the routes in this module,
# which create their own short-lived pools; kept for backward compatibility.
executor = ThreadPoolExecutor(max_workers=5)
logger.info("Executor initialized with 5 workers")

# Path to the Tesseract binary. Override with the TESSERACT_CMD environment
# variable, e.g. on Windows: C:\Program Files\Tesseract-OCR\tesseract.exe
pytesseract.pytesseract.tesseract_cmd = os.environ.get(
    "TESSERACT_CMD", "/usr/local/bin/tesseract"
)

# Set up Flask app
app = Flask(__name__)
CORS(app)

UPLOAD_FOLDER = "uploads"
ALLOWED_EXTENSIONS = {"pdf", "jpg", "jpeg", "png", "svg", "docx", "doc"}
app.config["UPLOAD_FOLDER"] = UPLOAD_FOLDER

# Set file size limits
MAX_SIZE_PDF_DOCS = 1 * 1024 * 1024 * 1024  # 1GB
MAX_SIZE_IMAGES = 500 * 1024 * 1024  # 500MB

# Extension groups used to pick the size limit and the text extractor.
DOCUMENT_EXTENSIONS = {"pdf", "docx", "doc"}
IMAGE_EXTENSIONS = {"jpg", "jpeg", "png", "svg"}

# # Load ClinicalBERT Model for Classification
# try:
#     zero_shot_classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
#     print("✅ zero_shot_classifier Model Loaded Successfully")
# except Exception as e:
#     zero_shot_classifier = None
#     print("❌ Error loading ClinicalBERT Model:", str(e))

os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Sentinel strings returned by extract_text_from_image() when OCR cannot
# produce usable text. upload_file() rejects them so they are never treated
# as document content.
OCR_BLURRY_MESSAGE = "Image is too blurry, OCR failed."
OCR_JUNK_MESSAGE = "OCR failed to extract meaningful text."
OCR_UNREADABLE_MESSAGE = "Image could not be read."
EXTRACTION_FAILED_MESSAGE = "Failed to extract text"
_OCR_FAILURE_MESSAGES = frozenset(
    {
        OCR_BLURRY_MESSAGE,
        OCR_JUNK_MESSAGE,
        OCR_UNREADABLE_MESSAGE,
        EXTRACTION_FAILED_MESSAGE,
    }
)

# NER to Detect medical info
CONFIDENCE_THRESHOLD = 0.80

# Entity labels (raw B-/I- tags or aggregated groups) counted as medical.
MEDICAL_ENTITY_LABELS = frozenset(
    {
        # Diseases & Symptoms
        "Disease", "MedicalCondition", "Symptom", "Sign_or_Symptom",
        "B-DISEASE", "I-DISEASE",
        # Tests, Measurements, and Lab Values
        "Test", "Measurement", "B-TEST", "I-TEST",
        "Lab_value", "B-Lab_value", "I-Lab_value",
        # Medications, Treatments, and Procedures
        "Medication", "B-MEDICATION", "I-MEDICATION", "Treatment", "Procedure",
        "B-Diagnostic_procedure", "I-Diagnostic_procedure",
        # Body Parts & Medical Anatomy
        "Anatomical_site", "Body_Part", "Organ_or_Tissue",
        # Medical Procedures
        "Diagnostic_procedure", "Surgical_Procedure", "Therapeutic_Procedure",
        # Clinical Terms
        "Health_condition", "B-Health_condition", "I-Health_condition",
        "Pathological_Condition", "Clinical_Event",
        # Biological & Chemical Substances (Relevant to Lab Reports)
        "Chemical_Substance", "B-Chemical_Substance", "I-Chemical_Substance",
        "Biological_Entity", "B-Biological_Entity", "I-Biological_Entity",
    }
)


@log_execution_time()
def extract_medical_entities(text, ner_pipeline=None):
    """Return distinct medical terms found in *text* by a token-classification pipeline.

    Args:
        text: document text.
        ner_pipeline: a transformers NER pipeline. When omitted, a module-level
            ``ner_pipeline`` global is used if one exists (backward compatibility;
            previously the global was required but never defined, which raised
            NameError).

    Returns:
        A list of at least 5 normalised entity words (lower-cased, hyphens
        removed, longer than 2 chars) whose label is in MEDICAL_ENTITY_LABELS
        and whose score is >= CONFIDENCE_THRESHOLD; otherwise
        ``["No medical entities found"]``.
    """
    if not text or not text.strip():
        return ["No medical entities found"]

    if ner_pipeline is None:
        ner_pipeline = globals().get("ner_pipeline")
    if ner_pipeline is None:
        logger.warning("NER model is not loaded, skipping entity extraction.")
        return ["No medical entities found"]

    medical_entities = set()
    for ent in ner_pipeline(text):
        # Aggregated pipelines report "entity_group"; raw ones report "entity".
        entity_label = ent.get("entity_group") or ent.get("entity")
        if (
            entity_label in MEDICAL_ENTITY_LABELS
            and ent.get("score", 0.0) >= CONFIDENCE_THRESHOLD
        ):
            word = ent.get("word", "").lower().strip().replace("-", "")  # Normalize text
            if len(word) > 2:  # Ignore short/junk words
                medical_entities.add(word)

    if len(medical_entities) >= 5:
        logger.info(f"Extracted {len(medical_entities)} medical entities")
        return list(medical_entities)

    logger.info("Not enough medical entities found")
    return ["No medical entities found"]


# Validation: Check Allowed File Types
def allowed_file(filename):
    """Return True if *filename* has an extension listed in ALLOWED_EXTENSIONS (case-insensitive)."""
    return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS


# Validation: Check File Size
def check_file_size(file):
    """Check an uploaded file against the size limit for its extension.

    Measures the stream size by seeking to the end, then rewinds it.

    Returns:
        (True, None) if within limits, else (False, error_message).
    """
    file.seek(0, os.SEEK_END)
    size = file.tell()
    file.seek(0)

    extension = file.filename.rsplit(".", 1)[-1].lower()
    logger.info(f"Checking file size for '{file.filename}' - Size: {size} bytes")

    if extension in DOCUMENT_EXTENSIONS and size > MAX_SIZE_PDF_DOCS:
        logger.warning(f"{file.filename} exceeds 1GB limit")
        return False, f"File {file.filename} exceeds 1GB size limit"
    if extension in IMAGE_EXTENSIONS and size > MAX_SIZE_IMAGES:
        logger.warning(f"{file.filename} exceeds 500MB image limit")
        return False, f"Image {file.filename} exceeds 500MB size limit"
    return True, None


@log_execution_time()
def extract_patient_name(text, qa_pipeline):
    """Ask a question-answering pipeline for the patient's name in *text*.

    Returns:
        The stripped answer string (may be empty), or None if inputs are
        missing or the pipeline raises.
    """
    if not text or not qa_pipeline:
        return None
    try:
        result = qa_pipeline(question="What is the patient's name?", context=text)
        answer = (result.get("answer") or "").strip()
        logger.info(f"Extracted patient name: {answer}")
        return answer
    except Exception as e:
        logger.error(f"Error extracting patient name: {e}")
        return None


# Honorifics stripped from the start of a name before comparison.
_SALUTATIONS = ("mr", "mrs", "ms", "miss", "mx", "dr", "prof", "sir", "madam", "master")
_SALUTATION_RE = re.compile(rf"^(?:{'|'.join(_SALUTATIONS)})\b\s*")


def normalize_name(name):
    """Normalise a person's name for comparison.

    Lower-cases, removes punctuation, strips one leading honorific (Mr, Dr, ...)
    and collapses whitespace. Returns "" for empty input.
    """
    if not name:
        return ""
    name = name.lower().strip()
    name = re.sub(r"[^\w\s]", "", name)
    # Only strip known honorifics. The previous pattern removed ANY first word
    # of 1-5 letters, so "John Smith" became "smith" and different patients
    # sharing a surname compared equal.
    name = _SALUTATION_RE.sub("", name)
    return re.sub(r"\s+", " ", name).strip()


@log_execution_time()
def validate_patient_name(extracted_text, patient_name, filename, qa_pipeline):
    """Check that the name found in the document matches the registered patient.

    Returns:
        None when validation passes, otherwise a Flask ``(response, 400)`` tuple.
    """
    detected_name = extract_patient_name(extracted_text, qa_pipeline)

    if not detected_name:
        logger.warning(f"Could not determine patient name from {filename}")
        return (
            jsonify({"error": f"Could not determine patient name from {filename}"}),
            400,
        )

    detected_tokens = set(normalize_name(detected_name).split())
    registered_tokens = set(normalize_name(patient_name).split())

    # Every word of the detected name must appear as a whole word in the
    # registered name. The old substring test accepted fragments ("jo" in
    # "john smith") and an empty name (e.g. a bare salutation).
    if not detected_tokens or not detected_tokens <= registered_tokens:
        logger.warning(
            f"Patient mismatch in file '{filename}': Found '{detected_name}'"
        )
        return (
            jsonify(
                {
                    "error": f"Document '{filename}' does not belong to {patient_name}. Found: {detected_name}"
                }
            ),
            400,
        )

    logger.info(f"Patient name validation passed for '{filename}'")
    return None  # No error, validation passed


# Check if the image is blurred using the Laplacian method
def is_blurred(image_path, variance_threshold=150):
    """Heuristically decide whether the image at *image_path* is too blurry for OCR.

    The image is flagged when BOTH the variance of its Laplacian is below
    *variance_threshold* AND the mean Canny edge intensity is below 10.
    Unreadable images (cv2.imread returns None, e.g. SVG) and any exception
    are treated as blurry.
    """
    try:
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            logger.error(f"Unable to read image: {image_path}")
            return True  # Assume it's blurry if not readable

        # Compute Laplacian variance
        laplacian_var = cv2.Laplacian(image, cv2.CV_64F).var()
        logger.info(
            f"Blur Check on '{image_path}': Laplacian Variance = {laplacian_var:.2f} (Threshold = {variance_threshold})"
        )

        # Compute Edge Density (Additional Check)
        edges = cv2.Canny(image, 50, 150)
        edge_density = np.mean(edges)
        logger.info(f"Edge Density for '{image_path}': {edge_density:.2f}")

        is_blurry = laplacian_var < variance_threshold and edge_density < 10
        if is_blurry:
            logger.warning(f"Image '{image_path}' flagged as blurry.")
        return is_blurry
    except Exception as e:
        logger.exception(f"Exception during blur detection for '{image_path}': {e}")
        return True  # Assume it's blurry on failure


# Helper Function: Extract Text from Images (OCR) with Blur Detection
@log_execution_time()
def extract_text_from_image(filepath):
    """OCR an image file with Tesseract after blur detection and preprocessing.

    Preprocessing: grayscale, Gaussian blur, adaptive threshold, 2x2 dilation.
    The processed image is written next to the input as
    ``<filepath>_processed.png``.

    Returns:
        The OCR text, or one of the sentinel strings OCR_BLURRY_MESSAGE,
        OCR_UNREADABLE_MESSAGE, OCR_JUNK_MESSAGE (< 5 words) or
        EXTRACTION_FAILED_MESSAGE.
    """
    try:
        if is_blurred(filepath):
            logger.warning(f"OCR skipped: '{filepath}' is too blurry.")
            return OCR_BLURRY_MESSAGE

        image = cv2.imread(filepath)
        if image is None:
            logger.error(f"OCR failed: Unable to read image '{filepath}'.")
            return OCR_UNREADABLE_MESSAGE

        # Convert to Grayscale and Apply Thresholding
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        gray = cv2.GaussianBlur(gray, (5, 5), 0)
        gray = cv2.adaptiveThreshold(
            gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2
        )

        # Apply dilation (bolds the text) for better OCR accuracy
        kernel = np.ones((2, 2), np.uint8)
        gray = cv2.dilate(gray, kernel, iterations=1)

        processed_path = f"{filepath}_processed.png"
        cv2.imwrite(processed_path, gray)
        logger.info(f"Image preprocessed and saved: {processed_path}")

        # Context manager closes the file handle PIL keeps open.
        with Image.open(processed_path) as processed_image:
            text = pytesseract.image_to_string(processed_image, lang="eng").strip()

        # Validate OCR output (Reject if too little text is extracted)
        word_count = len(text.split())
        logger.info(
            f"OCR completed for '{filepath}' with {word_count} words extracted."
        )
        if word_count < 5:
            logger.warning(f"OCR output too small for '{filepath}'. Might be junk.")
            return OCR_JUNK_MESSAGE

        return text
    except Exception as e:
        logger.exception(f"Error extracting text from image '{filepath}': {e}")
        return EXTRACTION_FAILED_MESSAGE


# Helper Function: Extract Text from PDF
@log_execution_time()
def extract_text_from_pdf(filepath, password=None):
    """Extract text from a PDF: PyPDF2 (encrypted only), then pdfplumber, then OCR.

    Args:
        filepath: path to the PDF.
        password: user password for encrypted PDFs; it is forwarded to
            pdfplumber and pdf2image so the fallbacks can open the file too.

    Returns:
        Always a ``(payload, status)`` tuple:
        * ``(text, 200)`` on success;
        * ``({"error": ...}, 401)`` encrypted and no password supplied;
        * ``({"error": ...}, 403)`` wrong password;
        * ``("No text found", 415)`` when even OCR finds nothing;
        * ``({"error": "Failed to extract text"}, 500)`` on unexpected errors.
          (This used to be a bare string, which callers mistook for text.)
    """
    text = ""
    try:
        logger.info(f"Starting PDF extraction: {filepath}")
        reader = PdfReader(filepath)

        if reader.is_encrypted:
            if not password:
                logger.warning("Encrypted PDF without password.")
                return {
                    "error": "File is password-protected. Please provide a password."
                }, 401

            # decrypt() returns 0 (PasswordType.NOT_DECRYPTED) on failure.
            if reader.decrypt(password) == 0:
                logger.error("Incorrect password provided.")
                return {"error": "Invalid password provided."}, 403
            logger.info("PDF decryption successful.")

            text = "\n".join(page.extract_text() or "" for page in reader.pages)
            if text.strip():
                logger.info("Text extracted from decrypted PDF.")
                return text.strip(), 200
            text = ""
        else:
            # Never send a password to the fallbacks for unencrypted files.
            password = None

        # pdfplumber for text extraction (needs the password if encrypted)
        with pdfplumber.open(filepath, password=password) as pdf:
            for page in pdf.pages:
                page_text = page.extract_text()
                if page_text:
                    text += page_text + "\n"

        if text.strip():
            logger.info(
                f"PDF text extracted using pdfplumber: {len(text.split())} words."
            )
            return text.strip(), 200

        logger.info("No text found via pdfplumber. Falling back to OCR.")

        # OCR for PDFs without selectable text
        images = convert_from_path(filepath, userpw=password)
        with ThreadPoolExecutor(max_workers=5) as pool:
            ocr_text = list(
                pool.map(
                    lambda img: pytesseract.image_to_string(img, lang="eng"), images
                )
            )

        full_ocr_text = "\n".join(ocr_text).strip()
        logger.info(
            f"OCR fallback complete for PDF: {len(full_ocr_text.split())} words extracted."
        )
        return (full_ocr_text, 200) if full_ocr_text else ("No text found", 415)

    except Exception:
        logger.exception(f"Error during PDF processing: {filepath}")
        return {"error": EXTRACTION_FAILED_MESSAGE}, 500


# Helper Function: Extract Text from DOCX
@log_execution_time()
def extract_text_from_docx(filepath):
    """Extract paragraph and table text from a .docx file.

    Table rows are appended after the paragraphs, one row per line with cells
    separated by tabs (lab results are often laid out in tables, which
    ``doc.paragraphs`` does not include).

    Returns:
        The stripped text, or None if empty or the file cannot be parsed.
    """
    try:
        doc = Document(filepath)
        parts = [para.text for para in doc.paragraphs]
        for table in doc.tables:
            for row in table.rows:
                cells = [cell.text.strip() for cell in row.cells]
                if any(cells):
                    parts.append("\t".join(cells))
        text = "\n".join(parts)
        word_count = len(text.split())
        logger.info(f"DOCX extracted from '{filepath}': {word_count} words.")
        return text.strip() or None
    except Exception:
        logger.exception(f"Failed to extract text from DOCX: {filepath}")
        return None


# Masking function to hide sensitive data
# NOTE(review): not called anywhere in this module. The patterns are weaker
# than the comments suggest:
# * "names": masks every word of 3+ chars that is followed by whitespace and
#   another 2+ char word (keeps 2 letters + "*"), not just names;
# * "DOB": only matches 4-4-4 digit groups such as 1234-5678-9012; common
#   dates like 1990-05-12 are NOT masked;
# * "phone": a 10-digit number is replaced by its last two digits.
def mask_sensitive_info(text):
    text = re.sub(r"(?<=\b\w{2})\w+(?=\s\w{2,})", "*", text)  # Mask names
    text = re.sub(
        r"\b(\d{2})\d{2}-(\d{2})\d{2}-(\d{2})\d{2}\b", r"\2-\3-", text
    )  # Mask DOB
    text = re.sub(r"\b(\d{8})(\d{2})\b", r"\2", text)  # Mask phone numbers
    return text


def _load_named_pipeline(label, task, model_name):
    """Load a transformers pipeline; return (pipeline, None) or (None, error_response)."""
    try:
        loaded = pipeline(task=task, model=model_name)
    except Exception as e:
        logger.error(f"❌ {label} model load failed: {e}")
        return None, (jsonify({"error": f"{label} model load failed: {str(e)}"}), 500)
    logger.info(f"✅ {label} model loaded: {model_name}")
    return loaded, None


def _collect_ner_entities(ner_results):
    """Return the distinct words tagged as entities in token-classification output.

    Raw output carries ``entity`` with B-/I- prefixes; aggregated output
    carries ``entity_group`` without a prefix (any non-"O" group counts).
    """
    words = set()
    for result in ner_results:
        if "entity" in result:
            is_entity = str(result["entity"]).startswith(("B-", "I-"))
        else:
            group = result.get("entity_group")
            is_entity = bool(group) and group != "O"
        if is_entity:
            words.add(result.get("word", ""))
    return list(words)


def _extract_uploaded_text(filepath, filename, extension, password):
    """Run the extractor for *extension*; return (text, None) or (None, error_response)."""
    extracted_text = None
    if extension == "pdf":
        logger.info("🧾 Extracting text from PDF")
        result = extract_text_from_pdf(filepath, password)
        if isinstance(result, tuple):
            extracted_text, status_code = result
        else:
            extracted_text, status_code = result, 200
        # Error payloads (401/403/500) are returned to the client as-is.
        if isinstance(extracted_text, dict) and "error" in extracted_text:
            logger.warning(f"⚠️ PDF extraction error: {extracted_text}")
            return None, (jsonify(extracted_text), status_code)
    elif extension == "docx":
        extracted_text = extract_text_from_docx(filepath)
    elif extension in IMAGE_EXTENSIONS:
        # NOTE(review): cv2 cannot decode SVG, so SVG uploads end up reported
        # as "too blurry" below.
        logger.info("🖼️ Extracting text from image")
        extracted_text = extract_text_from_image(filepath)
    # Legacy ".doc" is accepted by allowed_file() but has no extractor, so it
    # falls through to the 415 response below.

    if not extracted_text or extracted_text == "No text found":
        logger.warning(f"⚠️ No text extracted from {filename}")
        return None, (
            jsonify({"error": f"Failed to extract text from {filename}"}),
            415,
        )  # Unsupported Media Type

    if extracted_text in _OCR_FAILURE_MESSAGES:
        logger.warning(f"🔍 OCR failed or image too blurry: {filename}")
        return None, (
            jsonify({"error": f"'{filename}' is too blurry or text is unreadable."}),
            422,
        )  # Unprocessable Entity

    return extracted_text, None


# ------------------Upload Documents ------------------ #
# API Route: Upload File & Extract Text
@app.route("/upload", methods=["POST"])
@log_execution_time()
def upload_file():
    """Upload documents, extract their text, validate medical relevance and summarise.

    Form fields:
        file: one or more uploaded files (multipart field ``file``).
        patient_name: required, non-empty.
        password: optional password for encrypted PDFs.
        qa_model_name/qa_model_type, ner_model_name/ner_model_type,
        summarizer_model_name/summarizer_model_type: passed to
            ``transformers.pipeline(task=..., model=...)``.
        skip_medical_check: "true" to bypass the NER relevance check.

    Returns:
        200 with ``{"extracted_data": [...]}``. Processing stops at the first
        failing file and returns that file's error response.
    """
    logger.info("📥 Upload request received")

    files = request.files.getlist("file")
    patient_name = request.form.get("patient_name", "").strip()
    password = request.form.get("password")  # Get password if provided
    skip_medical_check = (
        request.form.get("skip_medical_check", "false").lower() == "true"
    )

    if not files:
        logger.warning("No file uploaded")
        return jsonify({"error": "No file uploaded"}), 400
    # Checked before the (slow) model loading so bad requests fail fast.
    if not patient_name:
        logger.warning("Patient name missing")
        return jsonify({"error": "Patient name is missing"}), 400

    # Load models dynamically.
    # NOTE(review): pipelines are reloaded on every request, which is slow;
    # consider caching by (task, model). The QA pipeline is only used by the
    # patient-name validation, which is currently disabled below.
    qa_pipeline, error = _load_named_pipeline(
        "QA", request.form.get("qa_model_type"), request.form.get("qa_model_name")
    )
    if error:
        return error
    ner_pipeline, error = _load_named_pipeline(
        "NER", request.form.get("ner_model_type"), request.form.get("ner_model_name")
    )
    if error:
        return error
    summarizer_pipeline, error = _load_named_pipeline(
        "Summarizer",
        request.form.get("summarizer_model_type"),
        request.form.get("summarizer_model_name"),
    )
    if error:
        return error

    extracted_data = []
    for file in files:
        logger.info(f"📂 Processing file: {file.filename}")

        if not file.filename:
            logger.warning("Skipping unnamed file")
            continue  # Skip empty file names

        if not allowed_file(file.filename):
            logger.warning(f"Unsupported file type: {file.filename}")
            return (
                jsonify(
                    {
                        "error": f"Unsupported file type: {file.filename}. Supported file types are: {', '.join(ALLOWED_EXTENSIONS)}"
                    }
                ),
                400,
            )

        valid_size, error_message = check_file_size(file)
        if not valid_size:
            logger.warning(f"❌ File size validation failed: {error_message}")
            return jsonify({"error": error_message}), 400

        # Dispatch on the validated, lower-cased original extension: the
        # sanitized name may differ in case (".PDF") or lose its extension
        # entirely (non-ASCII names).
        extension = file.filename.rsplit(".", 1)[1].lower()
        filename = secure_filename(file.filename) or f"upload.{extension}"
        filepath = os.path.join(UPLOAD_FOLDER, filename)
        file.save(filepath)
        logger.info(f"✅ File saved: {filepath}")

        extracted_text, error = _extract_uploaded_text(
            filepath, filename, extension, password
        )
        if error:
            return error

        # Medical Validation using NER
        if not skip_medical_check:
            logger.info("🧠 Running NER medical validation")
            start_time = time.perf_counter()
            medical_entities = _collect_ner_entities(ner_pipeline(extracted_text))
            elapsed_time = time.perf_counter() - start_time
            logger.info(f"⏱️ Medical entity validation took {elapsed_time:.2f}s")
            logger.info(f"🩺 Medical entities found: {medical_entities}")

            if not medical_entities:
                logger.warning(f"❌ No medical relevance in {filename}")
                return (
                    jsonify({"error": f"'{filename}' is not medically relevant"}),
                    406,
                )
        else:
            logger.info(f"⏭️ Skipping medical validation for {filename}")

        # # Patient Name Validation using QA
        # try:
        #     logger.info("🧐 Validating patient name")
        #     start_time = time.time()
        #     error_response = validate_patient_name(extracted_text, patient_name, filename, qa_pipeline)
        #     elapsed_time = time.time() - start_time
        #     logger.info(f"⏱️ Patient name validation took {elapsed_time:.2f}s")
        #     if error_response:
        #         return error_response
        # except Exception as e:
        #     logger.error(f"❌ Patient name validation failed: {e}")
        #     return jsonify({"error": f"Patient name validation failed: {str(e)}"}), 500

        # Generate Summary using Summarizer
        try:
            # Only sizes are logged at INFO: the text itself is patient data.
            logger.info(
                "📝 Generating summary for %s (%d chars)", filename, len(extracted_text)
            )
            start_time = time.perf_counter()
            # truncation=True keeps long documents within the model's input
            # limit instead of failing outright.
            summary = summarizer_pipeline(
                extracted_text,
                max_length=350,
                min_length=50,
                do_sample=False,
                truncation=True,
            )[0]["summary_text"]
            elapsed_time = time.perf_counter() - start_time
            logger.info(f"✅ Summary generated ({len(summary)} chars)")
            logger.info(f"⏱️ Summary generation took {elapsed_time:.2f} seconds")
        except Exception as e:
            summary = "Summary failed"
            logger.warning(f"⚠ Summary generation failed: {e}")

        # # Classify report type
        # report_type = classify_medical_document(extracted_text)

        extracted_data.append(
            {
                "file": filename,
                # "document_type": report_type,
                "extracted_text": extracted_text,
                "summary": summary,
                "message": "Successful",
            }
        )
        logger.info(f"✅ Finished processing file: {filename}")

    if not extracted_data:
        logger.warning("❌ No valid medical files processed")
        return jsonify({"error": "No valid medical files processed"}), 400

    logger.info("📦 Upload processing completed successfully")
    return jsonify({"extracted_data": extracted_data}), 200


# # Legacy QA-based implementation (disabled). Its on-disk format
# # ({"extracted_data": [{"categorized_data": ...}]}) is still accepted by
# # update_medical_data.
# # API Route: Extract Medical Data Based on Predefined Questions
# @app.route('/extract_medical_data', methods=['POST'])
# def extract_medical_data():
#     data = request.json
#     print(f"📥 Incoming request data: {data}")
#
#     qa_model_name = data.get("qa_model_name")
#     qa_model_type = data.get("qa_model_type")
#     if "extracted_data" not in data:
#         return jsonify({"error": "Missing 'extracted_data' in request"}), 400
#     if not qa_model_name or not qa_model_type:
#         return jsonify({"error": "Missing 'model_name' or 'model_type'"}), 400
#     try:
#         print(f"🌀 Loading model: {qa_model_name} ({qa_model_type})")
#         qa_pipeline = pipeline(task=qa_model_type, model=qa_model_name)
#         print(f"✅ Model loaded: {qa_pipeline.model.config._name_or_path}")
#     except Exception as e:
#         print("❌ Error loading model:", str(e))
#         return jsonify({"error": f"Could not load model: {str(e)}"}), 500
#     questions = {
#         "Patient Name": "What is the patient's name?",
#         "Age": "What is the patient's age?",
#         "Gender": "What is the patient's gender?",
#         "Date of Birth": "What is the patient's date of birth?",
#         "Patient ID": "What is the patient ID?",
#         "Reason for Visit": "What is the reason for the patient's visit?",
#         "Physician": "Who is the physician in charge of the patient?",
#         "Test Date": "What is the test date?",
#         "Hemoglobin": "What is the patient's hemoglobin level?",
#         "Blood Glucose (Fasting)": "What is the patient's fasting blood glucose level?",
#         "Total Cholesterol": "What is the total cholesterol level?",
#         "LDL Cholesterol": "What is the LDL cholesterol level?",
#         "HDL Cholesterol": "What is the HDL cholesterol level?",
#         "Serum Creatinine": "What is the serum creatinine level?",
#         "Vitamin D (25-OH)": "What is the patient's Vitamin D level?",
#         "Height": "What is the patient's height?",
#         "Weight": "What is the patient's weight?",
#         "Blood Pressure (Systolic)": "What is the patient's systolic blood pressure?",
#         "Blood Pressure (Diastolic)": "What is the patient's diastolic blood pressure?",
#         "Recommendations": "What are the recommendations based on the test results?"
#     }
#     structured_response = {"extracted_data": []}
#     for file_data in data["extracted_data"]:
#         filename = file_data["file"]
#         context = file_data["extracted_text"]
#         if not context:
#             structured_response["extracted_data"].append({
#                 "file": filename,
#                 "medical_terms": "No data extracted"
#             })
#             continue
#         # Prepare batch QA input
#         qa_inputs = [
#             {"question": q, "context": context}
#             for q in questions.values()
#         ]
#         try:
#             qa_outputs = qa_pipeline(qa_inputs)
#             print("📤 Batch QA outputs:", qa_outputs)
#         except Exception as e:
#             print("⚠️ Batch failed, falling back to loop:", str(e))
#             qa_outputs = [qa_pipeline(q) for q in qa_inputs]
#         # Map answers back to questions
#         extracted_info = {}
#         for i, key in enumerate(questions.keys()):
#             answer = qa_outputs[i].get("answer", "").strip()
#             score = qa_outputs[i].get("score", 0.0)
#             # If the model returns an empty string or very low confidence, mark as "Not Mentioned"
#             if not answer or score < 0.1:
#                 extracted_info[key] = "Not Mentioned"
#             else:
#                 extracted_info[key] = answer
#         # Optional: Clean results
#         # extracted_info = {k: clean_result(v) for k, v in extracted_info.items()}
#         categorized_data = [
#             {
#                 "name": "Patient Information",
#                 "fields": [
#                     {"label": "Patient Name", "value": extracted_info.get("Patient Name", "")},
#                     {"label": "Date of Birth", "value": extracted_info.get("Date of Birth", "")},
#                     {"label": "Gender", "value": extracted_info.get("Gender", "")},
#                     {"label": "Patient ID", "value": extracted_info.get("Patient ID", "")}
#                 ]
#             },
#             {
#                 "name": "Vitals",
#                 "fields": [
#                     {"label": "Height", "value": extracted_info.get("Height", "")},
#                     {"label": "Weight", "value": extracted_info.get("Weight", "")},
#                     {"label": "Blood Pressure", "value": f"{extracted_info.get('Blood Pressure (Systolic)', '')}/{extracted_info.get('Blood Pressure (Diastolic)', '')} mmHg"},
#                     {"label": "Hemoglobin", "value": extracted_info.get("Hemoglobin", "")},
#                     {"label": "Serum Creatinine", "value": extracted_info.get("Serum Creatinine", "")}
#                 ]
#             },
#             {
#                 "name": "Lab Results",
#                 "fields": [
#                     {"label": "Blood Glucose (Fasting)", "value": extracted_info.get("Blood Glucose (Fasting)", "")},
#                     {"label": "Total Cholesterol", "value": extracted_info.get("Total Cholesterol", "")},
#                     {"label": "LDL Cholesterol", "value": extracted_info.get("LDL Cholesterol", "")},
#                     {"label": "HDL Cholesterol", "value": extracted_info.get("HDL Cholesterol", "")},
#                     {"label": "Vitamin D (25-OH)", "value": extracted_info.get("Vitamin D (25-OH)", "")}
#                 ]
#             },
#             {
#                 "name": "Medical Notes",
#                 "fields": [
#                     {"label": "Reason for Visit", "value": extracted_info.get("Reason for Visit", "")},
#                     {"label": "Physician", "value": extracted_info.get("Physician", "")},
#                     {"label": "Test Date", "value": extracted_info.get("Test Date", "")},
#                     {"label": "Recommendations", "value": extracted_info.get("Recommendations", "")}
#                 ]
#             }
#         ]
#         structured_response["extracted_data"].append({
#             "file": filename,
#             "medical_terms": extracted_info,
#             "categorized_data": categorized_data
#         })
#         save_data_to_storage(filename, structured_response)
#         print(f"✅ Extracted data saved to: {os.path.join(UPLOAD_FOLDER, f'{filename}.json')}")
#     return jsonify(structured_response)


# ------------------ CLEAN FUNCTION ------------------ #
@log_execution_time()
def clean_result(value):
    """Recursively normalise an extracted value for display.

    * str: runs of ``-``, ``_``, ``:`` and of non-ASCII characters become a
      space, commas between digits are removed ("250,000" -> "250000"), and
      whitespace is collapsed; an empty result becomes "Not Available".
    * list: each non-None element is cleaned; an empty result becomes
      ``["Not Available"]``.
    * dict: values are cleaned recursively.
    * anything else is returned unchanged.
    """
    logger.debug("Cleaning value: %s", value)
    if isinstance(value, str):
        value = re.sub(r"[-_:]+", " ", value)
        value = re.sub(r"[^\x00-\x7F]+", " ", value)
        value = re.sub(r"(?<=\d),(?=\d)", "", value)  # Remove commas in numbers like 250,000
        # Collapse whitespace LAST: the substitutions above insert spaces, so
        # collapsing first (as before) left runs such as "a   b".
        value = re.sub(r"\s+", " ", value).strip()
        return value or "Not Available"
    if isinstance(value, list):
        cleaned = [clean_result(v) for v in value if v is not None]
        return cleaned or ["Not Available"]
    if isinstance(value, dict):
        return {k: clean_result(v) for k, v in value.items()}
    return value


# ------------------Group by Category ------------------ #
@log_execution_time()
def group_by_category(data):
    """Group extraction dicts by their "category" (default "General").

    Returns:
        ``[{"category": name, "detail": [{"question", "label", "answer"}, ...]}]``
        in first-seen category order. Non-dict items are skipped.
    """
    logger.info("Grouping extracted items by category")
    grouped = defaultdict(list)
    category_times = {}

    for item in data:
        if not isinstance(item, dict):
            logger.warning("Skipping non-object extraction: %r", item)
            continue
        # str() keeps the key hashable even if the model emits a list/dict.
        cat = str(item.get("category") or "General")
        start_time = time.perf_counter()
        grouped[cat].append(
            {
                "question": item.get("question", "Not Created"),
                "label": item.get("label", "Unknown"),
                "answer": item.get("answer", "Not Available"),
            }
        )
        category_times[cat] = category_times.get(cat, 0) + time.perf_counter() - start_time

    for cat, details in grouped.items():
        logger.info(f"📂 Category '{cat}': {len(details)} items, time taken: {category_times[cat]:.4f}s")

    return [{"category": k, "detail": v} for k, v in grouped.items()]


# ------------------detect duplicate to remove it ------------------ #
def _dedup_key(item):
    """Case- and whitespace-insensitive (label, answer) identity of an extraction."""

    def norm(v):
        return " ".join(str(v).split()).casefold() if v is not None else ""

    return norm(item.get("label")), norm(item.get("answer"))


@log_execution_time()
def deduplicate_extractions(data):
    """Drop repeated extractions (e.g. from overlapping chunks), keeping the first.

    Items are duplicates when label AND answer match (ignoring case and
    whitespace). Keying on the label alone, as before, discarded distinct
    facts sharing a label, such as several "medication" entries.
    Non-dict items are skipped.
    """
    logger.info("Deduplicating extracted data")
    seen = set()
    unique = []
    for item in data:
        if not isinstance(item, dict):
            logger.warning("Skipping non-object extraction: %r", item)
            continue
        key = _dedup_key(item)
        if key not in seen:
            seen.add(key)
            unique.append(item)
    return unique


# Default tokenizer, loaded once at import. extract_medical_data() prefers the
# requested model's own tokenizer and only falls back to this one.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")


# -----------------------------Split text into overlapping chunks---------------#
@log_execution_time()
def chunk_text(text, tokenizer, max_tokens=512, overlap=50):
    """Split text into overlapping token-based chunks without using NLTK.

    Args:
        text (str): Raw input text.
        tokenizer: Hugging Face tokenizer instance (``encode``/``decode``).
        max_tokens (int): Max tokens per chunk (> 0).
        overlap (int): Tokens shared by consecutive chunks (0 <= overlap < max_tokens).

    Returns:
        List[str]: decoded chunks; "..." is appended to a chunk that does not
        end in . ? ! or :. Empty text yields an empty list.

    Raises:
        ValueError: if the window parameters would never advance (previously
        this looped forever).
    """
    if max_tokens <= 0:
        raise ValueError("max_tokens must be positive")
    if not 0 <= overlap < max_tokens:
        raise ValueError("overlap must satisfy 0 <= overlap < max_tokens")

    logger.info("Splitting text into chunks")
    input_ids = tokenizer.encode(text, add_special_tokens=False)
    step = max_tokens - overlap

    chunks = []
    for start in range(0, len(input_ids), step):
        piece = tokenizer.decode(input_ids[start:start + max_tokens], skip_special_tokens=True)
        # Mark chunks that were cut off mid-sentence
        if not piece.endswith((".", "?", "!", ":")):
            piece += "..."
        chunks.append(piece)
        # Stop once the window reaches the end; another step would only emit
        # a tail already contained in this chunk.
        if start + max_tokens >= len(input_ids):
            break

    logger.info("Created %d chunks", len(chunks))
    return chunks


# ------------------ PARSE JSON OBJECTS FROM OUTPUT ------------------ #
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)


@log_execution_time()
def extract_json_objects(text):
    """Parse the JSON array of extraction objects out of raw model output.

    1. ``<think>...</think>`` reasoning blocks are removed.
    2. Starting at the first '[', one complete JSON value is decoded
       (trailing text is ignored); if it is a list it is returned.
    3. Otherwise every balanced top-level ``{...}`` after that '[' that
       parses as JSON is recovered (tolerates truncated arrays). Braces
       inside JSON strings are not accounted for.

    Returns:
        list of parsed items (possibly empty); elements are whatever the model
        emitted, so callers must not assume they are all dicts.
    """
    logger.info("Extracting JSON objects from text")
    if not text:
        return []
    text = _THINK_BLOCK_RE.sub("", text)

    json_start = text.find("[")
    if json_start == -1:
        logger.warning("⚠ '[' not found in output")
        return []
    json_text = text[json_start:]

    # Try parsing the full array first
    try:
        parsed, _ = json.JSONDecoder().raw_decode(json_text)
    except (ValueError, RecursionError):
        parsed = None  # fall back to manual parsing
    if isinstance(parsed, list):
        return parsed

    # Manual recovery via brace matching
    extracted = []
    depth = 0
    obj_start = None
    for i, char in enumerate(json_text):
        if char == "{":
            if depth == 0:
                obj_start = i
            depth += 1
        elif char == "}" and depth > 0:  # ignore stray closing braces
            depth -= 1
            if depth == 0 and obj_start is not None:
                try:
                    extracted.append(json.loads(json_text[obj_start:i + 1]))
                except (ValueError, RecursionError) as e:
                    logger.error(f"❌ Invalid JSON object: {e}")
                obj_start = None
    return extracted


# ------------------ PROCESS A SINGLE CHUNK ------------------ #
@log_execution_time()
def process_chunk(generator, chunk, idx):
    """Run the extraction prompt for one chunk through a text-generation pipeline.

    Returns:
        ``(idx, generated_text)`` containing only the newly generated text,
        or ``(idx, None)`` if generation fails.
    """
    logger.info("Processing chunk %d", idx + 1)
    # NOTE(review): the "<>" markers look like mangled Llama-2 "<<SYS>>" /
    # "<</SYS>>" tags, and the default Qwen3 tokenizer uses a different chat
    # template; confirm the intended prompt format.
    prompt = f"""
[INST] <>
You are a clinical data extraction assistant.

Your job is to:
1. Read the following medical report.
2. Extract all medically relevant facts as a list of JSON objects.
3. Each object must include:
   - "label": a short field name (e.g., "blood pressure", "diagnosis")
   - "question": a question related to that field
   - "answer": the answer from the text
4. After extracting the list, categorize each object under one of the following fixed categories:
   - Patient Info
   - Vitals
   - Symptoms
   - Allergies
   - Habits
   - Comorbidities
   - Diagnosis
   - Medication
   - Laboratory
   - Radiology
   - Doctor Note

Example format for structure only — do not include in output:
[
  {{
    "label": "patient name",
    "question": "What is the patient's name?",
    "answer": "John Doe",
    "category": "Patient Info"
  }},
  {{
    "label": "heart rate",
    "question": "What is the heart rate?",
    "answer": "78 bpm",
    "category": "Vitals"
  }}
]

⚠ Use these categories listed above.If an item does not fit any of these categories, create a new category for it.

Text:
{chunk}

Return a single valid JSON array of all extracted objects.
Do not include any explanations or commentary.
Only output the JSON array
<>
[/INST]
"""
    try:
        # return_full_text=False: by default the text-generation pipeline
        # echoes the prompt, whose example objects ("John Doe", "78 bpm") were
        # then parsed as if they had been extracted from the document.
        output = generator(
            prompt,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.3,
            return_full_text=False,
        )[0]["generated_text"]
        if output.startswith(prompt):  # defensive: strip an echoed prompt
            output = output[len(prompt):]
        # Only the size at INFO: the raw output contains patient data.
        logger.info(f"📤 Output from chunk {idx + 1}: {len(output)} characters")
        logger.debug("Raw output from chunk %d: %s", idx + 1, output)
        return idx, output
    except Exception as e:
        logger.error("Error processing chunk %d: %s", idx + 1, e)
        return idx, None


def _load_model_tokenizer(model_name):
    """Load *model_name*'s own tokenizer; fall back to the module default on failure."""
    try:
        return AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    except Exception as e:
        logger.warning(
            f"⚠ Could not load tokenizer for {model_name} ({e}); using default tokenizer"
        )
        return tokenizer


# ------------------Extract Medical Data ------------------ #
@app.route("/extract_medical_data", methods=["POST"])
@log_execution_time()
def extract_medical_data():
    """Extract categorised medical facts from previously uploaded texts with a causal LM.

    JSON body:
        qa_model_name, qa_model_type: model id and pipeline task
            (expected to be a text-generation task for a causal LM).
        extracted_data: list of ``{"file": ..., "extracted_text": ...}``.

    Each file's grouped result is also persisted via save_data_to_storage().
    Chunk outputs are combined in chunk order, so results are deterministic
    apart from model sampling.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    qa_model_name = data.get("qa_model_name")
    qa_model_type = data.get("qa_model_type")
    extracted_files = data.get("extracted_data")

    if not qa_model_name or not qa_model_type:
        return jsonify({"error": "Missing 'qa_model_name' or 'qa_model_type'"}), 400
    if not extracted_files:
        return jsonify({"error": "Missing 'extracted_data' in request"}), 400
    if not isinstance(extracted_files, list):
        return jsonify({"error": "'extracted_data' must be a list"}), 400

    # Log metadata only: the payload contains full report text (PHI).
    logger.info("Received request for %d file(s)", len(extracted_files))

    try:
        logger.info(f"🌀 Loading model: {qa_model_name} ({qa_model_type})")
        # Use the model's own tokenizer; pairing an arbitrary model with the
        # fixed Qwen tokenizer produces garbage for any other model family.
        model_tokenizer = _load_model_tokenizer(qa_model_name)
        model = AutoModelForCausalLM.from_pretrained(
            qa_model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
        generator = pipeline(task=qa_model_type, model=model, tokenizer=model_tokenizer)
        logger.info(f"✅ Model loaded successfully: {generator.model.config._name_or_path}")
    except Exception as e:
        logger.exception("❌ Model load failure")
        return jsonify({"error": f"Could not load model: {str(e)}"}), 500

    structured_response = {"extracted_data": []}

    for file_data in extracted_files:
        if not isinstance(file_data, dict):
            logger.warning("Skipping malformed entry in 'extracted_data'")
            continue
        filename = file_data.get("file") or "unknown_file"
        context = file_data.get("extracted_text") or ""
        if not isinstance(context, str):
            context = str(context)
        context = context.strip()
        logger.info("Processing file: %s", filename)

        if not context:
            logger.warning("No text found in file: %s", filename)
            structured_response["extracted_data"].append(
                {"file": filename, "medical_fields": "No data extracted"}
            )
            continue

        chunks = chunk_text(context, model_tokenizer)
        logger.info(f"📚 Chunked into {len(chunks)} parts for {filename}")

        # NOTE(review): concurrent calls share one model/pipeline instance;
        # confirm this is safe and beneficial on the target hardware.
        # map() yields results in chunk order regardless of completion order.
        with ThreadPoolExecutor(max_workers=4) as pool:
            results = list(
                pool.map(
                    process_chunk,
                    [generator] * len(chunks),
                    chunks,
                    range(len(chunks)),
                )
            )

        all_extracted = []
        for idx, output in results:
            if not output:
                continue
            try:
                objs = extract_json_objects(output)
            except Exception:
                logger.exception(f"❌ Error extracting JSON from chunk {idx + 1}")
                continue
            if objs:
                all_extracted.extend(objs)
            else:
                logger.error(f"⚠ Chunk {idx + 1} yielded no valid JSON.")

        # Clean and group results for this file
        if all_extracted:
            deduped = deduplicate_extractions(all_extracted)
            grouped_data = group_by_category(deduped)
        else:
            grouped_data = {"error": "No valid data extracted"}

        structured_response["extracted_data"].append(
            {"file": filename, "medical_fields": grouped_data}
        )

        # save_data_to_storage logs its own failures and returns False.
        save_data_to_storage(filename, grouped_data)

    logger.info("✅ Extraction complete.")
    return jsonify(structured_response)


# -------------------------- JSON storage helpers ----------------------#
def _storage_key(name):
    """Map a client-supplied file name to a safe storage key.

    secure_filename() removes path separators and "..", so a key can never
    escape UPLOAD_FOLDER. One trailing upload extension (see
    ALLOWED_EXTENSIONS) or ".json" is then removed, so "report.pdf" and
    "report" map to the same key while "my.report" keeps its inner dot.
    Returns "" when nothing usable remains.
    """
    safe = secure_filename(str(name or ""))
    stem, dot, ext = safe.rpartition(".")
    if dot and stem and ext.lower() in (ALLOWED_EXTENSIONS | {"json"}):
        return stem
    return safe


def _storage_path(name):
    """Return the JSON path for *name* inside UPLOAD_FOLDER, or None if the name is unusable."""
    key = _storage_key(name)
    return os.path.join(UPLOAD_FOLDER, f"{key}.json") if key else None


@log_execution_time()
def save_data_to_storage(filename, data):
    """Write *data* as JSON to ``UPLOAD_FOLDER/<storage key>.json``.

    The file is written to a temporary path and then os.replace()d into
    place, so readers never see a partially written file.

    Returns:
        True on success, False on failure (failures are logged, not raised).
    """
    try:
        filepath = _storage_path(filename)
        if filepath is None:
            logger.error(f"🚨 Refusing to save data: invalid file name {filename!r}")
            return False
        logger.info(f"💾 Saving to: {filepath}")
        tmp_path = f"{filepath}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as file:
            json.dump(data, file)
        os.replace(tmp_path, filepath)
        logger.info(f"✅ Data saved successfully to {filepath}")
        return True
    except Exception as e:
        logger.error(f"🚨 Exception during save: {e}")
        return False


# Function to get data from a JSON file
# 🔍 Get data from storage
@log_execution_time()
def get_data_from_storage(filename):
    """Load the JSON stored for *filename* (name with or without its upload extension).

    Returns:
        The parsed JSON, or None if the file is missing, the name is unusable,
        or loading fails.
    """
    try:
        filepath = _storage_path(filename)
        if filepath is None:
            logger.warning(f"🚫 Invalid file name: {filename!r}")
            return None
        logger.info(f"🔍 Looking for file at: {filepath}")

        if not os.path.exists(filepath):
            logger.warning(f"🚫 File not found at: {filepath}")
            return None

        with open(filepath, "r", encoding="utf-8") as file:
            data = json.load(file)
        logger.info(f"✅ File found and loaded: {filepath}")
        return data
    except Exception as e:
        logger.error(f"🚨 Error loading data: {e}")
        return None


# 🔹 Fetch updated medical data
@app.route("/get_updated_medical_data", methods=["GET"])
@log_execution_time()
def get_updated_data():
    """Return the stored extraction for the ``file`` query parameter (extension optional)."""
    requested = request.args.get("file")
    if not requested:
        return jsonify({"error": "File name is required"}), 400

    # Storage key: sanitised, upload extension stripped exactly once.
    file_name = _storage_key(requested)

    updated_data = get_data_from_storage(requested)
    # "is not None" so an empty-but-stored result is not reported as missing.
    if updated_data is not None:
        return jsonify({"file": file_name, "data": updated_data}), 200
    return jsonify({"error": f"File '{file_name}' not found"}), 404


def _apply_update(stored, category, field, new_value):
    """Set the first field matching (*category*, *field*) in *stored*; return True if found.

    Two on-disk layouts are supported:
    * current, written by extract_medical_data()/group_by_category():
      ``[{"category": ..., "detail": [{"label", "question", "answer"}]}]``
      -> the matching item's "answer" is replaced;
    * legacy, written by the disabled QA implementation:
      ``{"extracted_data": [{"categorized_data": [{"name", "fields": [{"label", "value"}]}]}]}``
      -> the matching field's "value" is replaced.
    """
    if isinstance(stored, list):
        for group in stored:
            if isinstance(group, dict) and group.get("category") == category:
                for item in group.get("detail", []):
                    if isinstance(item, dict) and item.get("label") == field:
                        item["answer"] = new_value
                        return True
        return False

    if isinstance(stored, dict):
        for extracted in stored.get("extracted_data", []):
            for cat in extracted.get("categorized_data", []):
                if cat.get("name") == category:
                    for fld in cat.get("fields", []):
                        if fld.get("label") == field:
                            fld["value"] = new_value
                            return True
    return False


@app.route("/update_medical_data", methods=["PUT"])
@log_execution_time()
def update_medical_data():
    """Apply field edits to a stored extraction and save it.

    JSON body:
        file: stored file name (extension optional).
        updates: list of ``{"category": ..., "field": ..., "value": ...}``.

    Updates that match no field are logged and skipped. For the legacy
    layout, "medical_terms" is re-synced from "categorized_data".
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "File name or updates missing"}), 400
        logger.debug("Received update: %s", json.dumps(data, indent=2))

        requested = data.get("file", "")
        filename = _storage_key(requested)
        updates = data.get("updates", [])

        if not filename or not updates:
            return jsonify({"error": "File name or updates missing"}), 400

        # Load current stored data
        existing_data = get_data_from_storage(requested)
        if not existing_data:
            return jsonify({"error": f"File '{filename}' not found"}), 404

        for update in updates:
            if not isinstance(update, dict):
                logger.warning("Skipping malformed update: %r", update)
                continue
            category = update.get("category")
            field = update.get("field")
            new_value = update.get("value")
            if _apply_update(existing_data, category, field, new_value):
                logger.info("Updating [%s] %s → %s", category, field, new_value)
            else:
                logger.warning("No field matched update [%s] %s", category, field)

        # 🧠 Sync medical_terms with categorized_data (legacy layout only)
        if isinstance(existing_data, dict):
            for extracted in existing_data.get("extracted_data", []):
                if "categorized_data" in extracted:
                    extracted["medical_terms"] = {
                        field.get("label"): field.get("value", "")
                        for category in extracted["categorized_data"]
                        for field in category.get("fields", [])
                    }
                    logger.info("Synced 'medical_terms' with 'categorized_data'")

        # Save updated data to file
        if not save_data_to_storage(requested, existing_data):
            return jsonify({"error": "Failed to save updated data"}), 500
        logger.info("✅ Updated data saved successfully")

        return (
            jsonify(
                {"message": "Data updated successfully", "updated_data": existing_data}
            ),
            200,
        )

    except Exception as e:
        logger.exception("Update error: %s", e)
        return jsonify({"error": str(e)}), 500


# Test Route
@app.route("/")
def home():
    """Health-check endpoint."""
    return "Medical Data Extraction API is running!"


if __name__ == "__main__":
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so it is only enabled when FLASK_DEBUG=1 is set explicitly (the server
    # listens on all interfaces).
    app.run(host="0.0.0.0", port=5000, debug=os.environ.get("FLASK_DEBUG") == "1")

# if __name__ == '__main__':
#     from gevent.pywsgi import WSGIServer  # type: ignore
#     http_server = WSGIServer(('0.0.0.0', 5000), app)
#     http_server.serve_forever()