"""Flask API for medical text extraction, PHI scrubbing, summarization and QA.

The service accepts documents (PDF, images, DOCX) and audio files, extracts
their text (OCR / Whisper), optionally validates that the content is medical
and belongs to the registered patient, and produces summaries and structured
medical data using Hugging Face models.
"""

import copy
import json
import logging
import os
import re
import tempfile
from concurrent.futures import ThreadPoolExecutor

import cv2
import faiss
import numpy as np
import pandas as pd
import pdfplumber
import pytesseract
import tensorflow.keras.layers as KL  # Instead of keras.layers as KL
import whisper
from docx import Document
from dotenv import load_dotenv
from flask import Flask, request, jsonify, abort
from flask_cors import CORS
from flask_executor import Executor
from pdf2image import convert_from_path
from PIL import Image
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from werkzeug.utils import secure_filename

# Load environment variables
load_dotenv()

# Set Tesseract OCR path. An explicit TESSERACT_CMD always wins; the historical
# Windows default is only applied when it actually exists, so that other
# platforms fall back to pytesseract's own PATH lookup instead of a dead path.
_DEFAULT_TESSERACT_CMD = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
_tesseract_cmd = os.getenv('TESSERACT_CMD')
if _tesseract_cmd:
    pytesseract.pytesseract.tesseract_cmd = _tesseract_cmd
elif os.path.exists(_DEFAULT_TESSERACT_CMD):
    pytesseract.pytesseract.tesseract_cmd = _DEFAULT_TESSERACT_CMD

# Initialize Flask app
app = Flask(__name__)
CORS(app)

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Configure upload directory and max file size
UPLOAD_DIR = os.getenv('UPLOAD_DIR', os.path.join(os.getcwd(), 'uploads'))
os.makedirs(UPLOAD_DIR, exist_ok=True)
# NOTE(review): Flask rejects request bodies above this limit, so the larger
# per-type limits below (MAX_SIZE_PDF_DOCS / MAX_SIZE_IMAGES) can never trigger.
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16 MB max file size

# Initialize Flask-Executor for asynchronous tasks
executor = Executor(app)
whisper_model = whisper.load_model("tiny")

# Allowed file extensions
ALLOWED_AUDIO_EXTENSIONS = {'mp3', 'wav', 'flac'}
ALLOWED_DOCUMENT_EXTENSIONS = {'pdf', 'jpg', 'jpeg', 'png', 'docx', 'xlsx', 'xls'}

# NOTE(review): this is the folder the routes actually use; UPLOAD_DIR above is
# created but not used for storage. Confirm which one is intended.
UPLOAD_FOLDER = 'Uploads'
ALLOWED_EXTENSIONS = {'pdf', 'jpg', 'jpeg', 'png', 'svg', 'docx', 'doc'}
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# The folder must exist before any route saves into it.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Set file size limits
MAX_SIZE_PDF_DOCS = 1 * 1024 * 1024 * 1024  # 1GB
MAX_SIZE_IMAGES = 500 * 1024 * 1024  # 500MB

# Sentinel strings returned by the module-level extraction helpers.
_UNREADABLE_TEXT_MESSAGES = frozenset({
    "Image is too blurry, OCR failed.",
    "OCR failed to extract meaningful text.",
})
_EXTRACTION_FAILURE_MESSAGES = frozenset({
    "No text found",
    "Failed to extract text",
    "Image could not be read.",
})

# Titles removed from the front of a name before comparison.
_SALUTATIONS = frozenset({
    "mr", "mrs", "ms", "miss", "mx", "dr", "prof", "sir", "madam", "master",
})


def _remove_quietly(path):
    """Delete *path*, logging (not raising) if the file cannot be removed."""
    try:
        os.remove(path)
    except OSError as e:
        logging.warning(f"Could not remove temporary file {path}: {e}")


def _ocr_image_array(image_array, lang='eng'):
    """Run Tesseract on an OpenCV image via a temporary PNG that is always removed."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        temp_path = temp_file.name
    try:
        # Written after the handle is closed so the path can be reopened on Windows.
        cv2.imwrite(temp_path, image_array)
        with Image.open(temp_path) as pil_image:
            return pytesseract.image_to_string(pil_image, lang=lang)
    finally:
        _remove_quietly(temp_path)


# Lazy model loading to save resources
class LazyModelLoader:
    """Load a Hugging Face model (or pipeline) on first use and cache it."""

    def __init__(self, model_name, task, tokenizer=None):
        self.model_name = model_name
        self.task = task
        self.tokenizer = tokenizer
        self._model = None
        # Only populated by load() for the "text-generation" task.
        self._tokenizer = None

    def load(self):
        """Load the model if not already loaded.

        For ``task == "text-generation"`` this returns a causal LM and stores
        its tokenizer in ``self._tokenizer``; for any other task it returns a
        ``transformers.pipeline``.
        """
        if self._model is not None:
            return self._model

        logging.info(f"Loading model: {self.model_name}")
        if self.task == "text-generation":
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map="auto",
                torch_dtype="auto"
            )
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, legacy=False)
            generation_config = self._model.generation_config
            if generation_config.pad_token_id is None or generation_config.pad_token_id < 0:
                if self._tokenizer.eos_token_id is not None:
                    generation_config.pad_token_id = self._tokenizer.eos_token_id
                    logging.info(f"Set pad_token_id to {self._tokenizer.eos_token_id}")
                else:
                    logging.warning("No valid eos_token_id found. Setting pad_token_id to 0 as a fallback.")
                    generation_config.pad_token_id = 0
        else:
            self._model = pipeline(self.task, model=self.model_name, tokenizer=self.tokenizer)
        return self._model


# Text extraction agents
class TextExtractorAgent:
    """Extract plain text from PDFs, images, DOCX and Excel files."""

    @staticmethod
    def extract_text(filepath, ext):
        """Extract text based on file type.

        Returns the extracted text, or ``None`` when the extension is not
        handled, nothing was extracted, or extraction raised.
        """
        try:
            if ext == "pdf":
                return TextExtractorAgent.extract_text_from_pdf(filepath)
            elif ext in {"jpg", "jpeg", "png"}:
                return TextExtractorAgent.extract_text_from_image(filepath)
            elif ext == "docx":
                return TextExtractorAgent.extract_text_from_docx(filepath)
            elif ext in {"xlsx", "xls"}:
                return TextExtractorAgent.extract_text_from_excel(filepath)
            return None
        except Exception as e:
            logging.error(f"Text extraction failed: {e}")
            return None

    @staticmethod
    def extract_text_from_pdf(filepath):
        """Extract text from a PDF file."""
        page_texts = []
        with pdfplumber.open(filepath) as pdf:
            for page in pdf.pages:
                page_text = page.extract_text()
                if page_text:
                    page_texts.append(page_text + "\n")
        return "".join(page_texts).strip() or None

    @staticmethod
    def extract_text_from_image(filepath):
        """Extract text from an image using OCR (Otsu-thresholded grayscale)."""
        image = cv2.imread(filepath)
        if image is None:
            # cv2.imread signals an unreadable file by returning None.
            logging.error(f"Unable to read image {filepath}")
            return None
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        _, processed = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        text = _ocr_image_array(processed)
        return text.strip() or None

    @staticmethod
    def extract_text_from_docx(filepath):
        """Extract text from a DOCX file."""
        doc = Document(filepath)
        text = "\n".join([para.text for para in doc.paragraphs])
        return text.strip() or None

    @staticmethod
    def extract_text_from_excel(filepath):
        """Extract text from an Excel file (all sheets, column by column)."""
        dfs = pd.read_excel(filepath, sheet_name=None)
        text = "\n".join([
            "\n".join([
                " ".join(map(str, df[col].dropna())) for col in df.columns
            ]) for df in dfs.values()
        ])
        return text.strip() or None


# PHI scrubbing agent
class PHIScrubberAgent:
    """Regex-based masking of phone numbers, e-mails, SSNs, addresses and names."""

    @staticmethod
    def scrub_phi(text):
        """Remove sensitive personal health information (PHI).

        Best effort: if a substitution raises, the error is logged and the
        text is returned as far as it had been scrubbed.
        NOTE(review): the final pattern masks any two consecutive capitalized
        words, which also hits terms such as "Blood Pressure".
        """
        try:
            text = re.sub(r'\b(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b', '[PHONE]', text)
            text = re.sub(r'\b[\w\.-]+@[\w\.-]+\.\w{2,4}\b', '[EMAIL]', text)
            text = re.sub(r'\b\d{3}-\d{2}-\d{4}\b', '[SSN]', text)
            text = re.sub(
                r'\b\d{1,5}\s+\w+\s+(Street|St|Avenue|Ave|Boulevard|Blvd|Road|Rd|Lane|Ln)\b',
                '[ADDRESS]',
                text,
                flags=re.IGNORECASE,
            )
            text = re.sub(r'\bDr\.?\s+[A-Z][a-z]+\s+[A-Z][a-z]+\b', 'Dr. [NAME]', text)
            text = re.sub(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', '[NAME]', text)
        except Exception as e:
            logging.error(f"PHI scrubbing failed: {e}")
        return text


# Summarization agent
class SummarizerAgent:
    """Summarize text with a lazily loaded summarization pipeline."""

    def __init__(self, summarization_model_loader):
        self.summarization_model_loader = summarization_model_loader

    def generate_summary(self, text):
        """Generate a summary of the provided text.

        Returns the summary, or the string "Summary generation failed." when
        the pipeline call raises.
        """
        model = self.summarization_model_loader.load()
        try:
            summary_result = model(text, do_sample=False)
            return summary_result[0]['summary_text'].strip()
        except Exception as e:
            logging.error(f"Summary generation failed: {e}")
            return "Summary generation failed."


def allowed_file(filename, allowed_extensions=ALLOWED_EXTENSIONS):
    """Check if the file extension is allowed."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_extensions


# Knowledge Base
class KnowledgeBase:
    """In-memory FAISS (L2) index over sentence embeddings of *documents*."""

    def __init__(self, documents):
        self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
        self.documents = documents
        self.embeddings = self.embedding_model.encode(documents)
        self.dimension = self.embedding_model.get_sentence_embedding_dimension()
        self.index = faiss.IndexFlatL2(self.dimension)
        if len(documents) > 0:
            self.index.add(self.embeddings)

    def retrieve_relevant_info(self, query, top_k=3):
        """Retrieve relevant medical information from the knowledge base.

        Returns at most ``top_k`` documents, fewer when the corpus is smaller.
        """
        if not self.documents or top_k <= 0:
            return []
        # FAISS pads missing neighbours with index -1, which would silently
        # select the last document; never ask for more than we have.
        k = min(top_k, len(self.documents))
        query_embedding = self.embedding_model.encode([query])
        _, indices = self.index.search(query_embedding, k)
        return [self.documents[i] for i in indices[0] if i >= 0]


# Medical data extraction agent
class MedicalDataExtractorAgent:
    """Extract structured medical data from free text with a causal LM."""

    def __init__(self, model_loader, knowledge_base):
        self.model_loader = model_loader
        self.knowledge_base = knowledge_base

    def retrieve_relevant_info(self, query, top_k=3):
        """Retrieve relevant medical information from the knowledge base."""
        return self.knowledge_base.retrieve_relevant_info(query, top_k)

    def extract_medical_data(self, text):
        """Extract structured medical data from text using Agentic RAG.

        Returns a JSON *string*: either the normalized record or an object
        with a single ``"error"`` key when generation or parsing fails.
        """
        try:
            default_schema = {
                "patient_name": "[NAME]",
                "age": None,
                "gender": None,
                "diagnosis": [],
                "symptoms": [],
                "medications": [],
                "allergies": [],
                "vitals": {
                    "blood_pressure": None,
                    "heart_rate": None,
                    "temperature": None
                },
                "notes": ""
            }
            # "\\." keeps the literal backslash-dot of the original prompt
            # without relying on an invalid escape sequence.
            # NOTE(review): the prompt's original line breaks were lost in the
            # collapsed source; this layout is a reconstruction.
            prompt = (
                "\n"
                "### Instruction:\n"
                "Extract structured medical data from the following text as a JSON whose "
                "parameters are enclosed in \"\" and without any \\. The JSON should include "
                "patientname, age, gender, medications, allergies, diagnosis, symptoms, "
                "vitals, and notes.\n"
                "\n"
                "### Text:\n"
                f"{text}\n"
                "\n"
                "### Response:\n"
            )
            model = self.model_loader.load()
            tokenizer = self.model_loader._tokenizer
            # The model is loaded with device_map="auto"; inputs must live on
            # the same device as the model.
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
            # NOTE(review): no max_new_tokens is set, so the output length
            # depends on the model's generation config.
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                num_return_sequences=1,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )
            # Decoder-only generate() returns prompt + continuation. Decode
            # only the continuation so braces in the input text cannot be
            # mistaken for the model's JSON.
            prompt_length = inputs.input_ids.shape[1]
            response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
            logging.info(f"Model response: {response}")

            json_start = response.find("{")
            json_end = response.rfind("}")
            if json_start == -1 or json_end < json_start:
                raise ValueError("No JSON found in the model response.")
            structured_data = json.loads(response[json_start:json_end + 1])

            # The prompt asks for "patientname" while the schema uses "patient_name".
            if (
                isinstance(structured_data, dict)
                and "patientname" in structured_data
                and "patient_name" not in structured_data
            ):
                structured_data["patient_name"] = structured_data["patientname"]

            normalized_data = self.normalize_json_output(structured_data, default_schema)
            blood_pressure = normalized_data["vitals"]["blood_pressure"]
            if blood_pressure and isinstance(blood_pressure, str):
                normalized_data["vitals"]["blood_pressure"] = blood_pressure.strip('"')
            return json.dumps(normalized_data)
        except json.JSONDecodeError as e:
            logging.error(f"JSON parsing failed: {e}")
            return json.dumps({"error": f"Failed to parse JSON: {str(e)}"})
        except Exception as e:
            logging.error(f"Error extracting medical data: {e}")
            return json.dumps({"error": f"Failed to extract medical data: {str(e)}"})

    @staticmethod
    def normalize_json_output(model_output, default_schema):
        """Normalize the model's JSON output to match the default schema.

        Only keys present in *default_schema* are kept. Nested dict defaults
        (e.g. ``vitals``) are merged key by key so missing sub-keys keep their
        default; a non-dict value for such a key is ignored. *default_schema*
        is never mutated.
        """
        normalized_output = copy.deepcopy(default_schema)
        if not isinstance(model_output, dict):
            logging.error(f"Failed to normalize JSON: expected an object, got {type(model_output).__name__}")
            return normalized_output

        for key, default_value in default_schema.items():
            if key not in model_output:
                continue
            value = model_output[key]
            if isinstance(default_value, dict):
                if isinstance(value, dict):
                    normalized_output[key] = {**copy.deepcopy(default_value), **value}
            else:
                normalized_output[key] = value
        return normalized_output


# Initialize lazy loaders
medalpaca_model_loader = LazyModelLoader(
    model_name="stanford-crfm/BioMedLM",
    task="text-generation"
)
summarization_model_loader = LazyModelLoader("google-t5/t5-small", "summarization")

# Initialize knowledge base
medical_documents = [
    "Hypertension is a chronic condition characterized by elevated blood pressure.",
    "Diabetes is a metabolic disorder that affects blood sugar levels.",
    "Common symptoms of chest pain include pressure, tightness, or discomfort in the chest."
]
knowledge_base = KnowledgeBase(medical_documents)

# Initialize agents
text_extractor_agent = TextExtractorAgent()
phi_scrubber_agent = PHIScrubberAgent()
medical_data_extractor_agent = MedicalDataExtractorAgent(medalpaca_model_loader, knowledge_base)
summarizer_agent = SummarizerAgent(summarization_model_loader)

# NER to Detect medical info
CONFIDENCE_THRESHOLD = 0.80


def extract_medical_entities(text, ner_pipeline):
    """Return confident medical entity words found in *text*.

    Returns the list of distinct entity words when at least five were found,
    otherwise ``["No medical entities found"]``.
    NOTE(review): the source formatting was lost; the five-entity threshold is
    assumed to be evaluated after the loop.
    """
    if not text or not text.strip():
        return ["No medical entities found"]
    if ner_pipeline is None:
        logging.warning("⚠️ NER model is not loaded, skipping entity extraction.")
        return ["No medical entities found"]

    ner_results = ner_pipeline(text)
    relevant_entities = {
        "Disease", "MedicalCondition", "Symptom", "Sign_or_Symptom", "B-DISEASE", "I-DISEASE",
        "Test", "Measurement", "B-TEST", "I-TEST", "Lab_value", "B-Lab_value", "I-Lab_value",
        "Medication", "B-MEDICATION", "I-MEDICATION", "Treatment", "Procedure",
        "B-Diagnostic_procedure", "I-Diagnostic_procedure", "Anatomical_site", "Body_Part",
        "Organ_or_Tissue", "Diagnostic_procedure", "Surgical_Procedure", "Therapeutic_Procedure",
        "Health_condition", "B-Health_condition", "I-Health_condition", "Pathological_Condition",
        "Clinical_Event", "Chemical_Substance", "B-Chemical_Substance", "I-Chemical_Substance",
        "Biological_Entity", "B-Biological_Entity", "I-Biological_Entity"
    }

    medical_entities = set()
    for ent in ner_results:
        entity_label = ent.get("entity_group") or ent.get("entity")
        if entity_label in relevant_entities and ent["score"] >= CONFIDENCE_THRESHOLD:
            word = ent["word"].lower().strip().replace("-", "")
            if len(word) > 2:
                medical_entities.add(word)

    if len(medical_entities) >= 5:
        return list(medical_entities)
    return ["No medical entities found"]


# Validation: Check File Size
def check_file_size(file):
    """Return ``(True, None)`` if *file* is within its type's size limit, else ``(False, message)``."""
    file.seek(0, os.SEEK_END)
    size = file.tell()
    file.seek(0)  # rewind so the caller can still save/read the upload
    extension = file.filename.rsplit('.', 1)[-1].lower()
    if extension in {'pdf', 'docx'} and size > MAX_SIZE_PDF_DOCS:
        return False, f"File {file.filename} exceeds 1GB size limit"
    elif extension in {'jpg', 'jpeg', 'png'} and size > MAX_SIZE_IMAGES:
        return False, f"Image {file.filename} exceeds 500MB size limit"
    return True, None


def extract_patient_name(text, qa_pipeline):
    """Extracts patient name using the given QA pipeline.

    Returns the stripped answer, or ``None`` when inputs are missing or the
    pipeline raises.
    """
    if not text or not qa_pipeline:
        return None
    try:
        result = qa_pipeline(
            question="What is the patient's name?",
            context=text
        )
        return result.get("answer", "").strip()
    except Exception as e:
        logging.warning(f"⚠️ Error extracting patient name: {e}")
        return None


def normalize_name(name):
    """Clean and normalize a name for comparison.

    Lower-cases, removes punctuation and drops leading salutations (Mr, Dr, …).
    Only known titles are removed: stripping any short first word would turn
    "John Smith" into "smith" and break matching against "Dr. John Smith".
    """
    if not name:
        return ""
    name = name.lower().strip()
    name = re.sub(r"[^\w\s]", "", name)
    tokens = name.split()
    # Keep at least one token so a bare title does not normalize to nothing.
    while len(tokens) > 1 and tokens[0] in _SALUTATIONS:
        tokens.pop(0)
    return " ".join(tokens)


def validate_patient_name(extracted_text, patient_name, filename, qa_pipeline):
    """Validates if the extracted name matches the registered patient name.

    Returns ``None`` on success, or a ``(response, 400)`` tuple describing the
    mismatch. Must be called inside a Flask application context.
    """
    detected_name = extract_patient_name(extracted_text, qa_pipeline)
    if not detected_name:
        return jsonify({"error": f"Could not determine patient name from {filename}"}), 400

    normalized_detected_name = normalize_name(detected_name)
    normalized_patient_name = normalize_name(patient_name)

    # An empty string is "in" every string, so it must not count as a match.
    if not normalized_detected_name:
        return jsonify({"error": f"Could not determine patient name from {filename}"}), 400

    if normalized_detected_name not in normalized_patient_name:
        return jsonify({
            "error": f"Document '{filename}' does not belong to {patient_name}. Found: {detected_name}"
        }), 400
    return None


def is_blurred(image_path, variance_threshold=150):
    """Return True when the image looks blurry (or cannot be analysed).

    An image is considered blurry when both the Laplacian variance is below
    *variance_threshold* and the mean of its Canny edge map is below 10.
    """
    try:
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            logging.error(f"❌ Error: Unable to read image {image_path}")
            return True
        laplacian_var = cv2.Laplacian(image, cv2.CV_64F).var()
        logging.info(f"🔍 Blur Check: Variance={laplacian_var} (Threshold={variance_threshold})")
        edges = cv2.Canny(image, 50, 150)
        edge_density = np.mean(edges)
        logging.info(f"📏 Edge Density: {edge_density}")
        return laplacian_var < variance_threshold and edge_density < 10
    except Exception as e:
        logging.error(f"❌ Error detecting blur: {e}")
        return True


def extract_text_from_image(filepath):
    """OCR an image after a blur check and adaptive-threshold preprocessing.

    Returns the recognised text, or one of the sentinel strings
    "Image is too blurry, OCR failed.", "Image could not be read.",
    "OCR failed to extract meaningful text." or "Failed to extract text".
    """
    try:
        if is_blurred(filepath):
            return "Image is too blurry, OCR failed."
        image = cv2.imread(filepath)
        if image is None:
            return "Image could not be read."

        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        gray = cv2.GaussianBlur(gray, (5, 5), 0)
        gray = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2)
        kernel = np.ones((2, 2), np.uint8)
        gray = cv2.dilate(gray, kernel, iterations=1)

        # The intermediate image goes to a temp file that is always cleaned up
        # instead of accumulating "<upload>_processed.png" files.
        text = _ocr_image_array(gray).strip()
        if len(text.split()) < 5:
            return "OCR failed to extract meaningful text."
        return text
    except Exception as e:
        logging.error(f"❌ Error processing {filepath}: {e}")
        return "Failed to extract text"


def extract_text_from_pdf(filepath, password=None):
    """Extract text from PDFs using PyPDF2, then pdfplumber, then OCR.

    Returns ``(text, 200)`` on success, ``({"error": ...}, 401 | 403)`` for a
    missing/wrong password, ``("No text found", 415)`` when even OCR yields
    nothing, or the bare string "Failed to extract text" if processing raised.
    """
    try:
        reader = PdfReader(filepath)
        if reader.is_encrypted:
            if not password:
                logging.warning("🔒 PDF is encrypted but no password was provided.")
                return {"error": "File is password-protected. Please provide a password."}, 401
            decryption_result = reader.decrypt(password)
            if decryption_result == 0:
                logging.warning("❌ Incorrect password provided!")
                return {"error": "Invalid password provided."}, 403
            logging.info("✅ PDF successfully decrypted!")

        text = "\n".join([page.extract_text() or "" for page in reader.pages]).strip()
        if text:
            return text, 200

        # The fallbacks open the file themselves, so they need the password too.
        page_texts = []
        with pdfplumber.open(filepath, password=password) as pdf:
            for page in pdf.pages:
                page_text = page.extract_text()
                if page_text:
                    page_texts.append(page_text)
        text = "\n".join(page_texts).strip()
        if text:
            return text, 200

        images = convert_from_path(filepath, userpw=password)
        with ThreadPoolExecutor(max_workers=5) as pool:
            ocr_text = list(pool.map(lambda img: pytesseract.image_to_string(img, lang='eng'), images))
        text = "\n".join(ocr_text).strip()
        return (text, 200) if text else ("No text found", 415)
    except Exception as e:
        logging.error(f"❌ Error processing PDF {filepath}: {e}")
        return "Failed to extract text"


def extract_text_from_docx(filepath):
    """Return the paragraph text of a DOCX file, or ``None`` when it is empty."""
    doc = Document(filepath)
    text = "\n".join([para.text for para in doc.paragraphs])
    return text.strip() or None


def clean_result(value):
    """Tidy a QA answer: separators and non-ASCII become spaces, whitespace is collapsed.

    Returns "Not Available" when nothing printable remains.
    """
    if value is None:
        return "Not Available"
    value = re.sub(r"[-_:]+", " ", str(value))
    value = re.sub(r"[^\x00-\x7F]+", " ", value)
    # Collapse last so the spaces inserted above do not pile up, and strip so
    # a whitespace-only answer is reported as missing.
    value = re.sub(r"\s+", " ", value).strip()
    return value if value else "Not Available"


def mask_sensitive_info(text):
    """Mask parts of names and long digit sequences in *text*."""
    text = re.sub(r'(?<=\b\w{2})\w+(?=\s\w{2,})', '***', text)
    text = re.sub(r'\b(\d{2})\d{2}-(\d{2})\d{2}-(\d{2})\d{2}\b', r'**\2-**\3-**', text)
    text = re.sub(r'\b(\d{8})(\d{2})\b', r'********\2', text)
    return text


def _get_json_body():
    """Return the request's JSON object, or ``{}`` when it is absent or not an object."""
    data = request.get_json(silent=True)
    return data if isinstance(data, dict) else {}


def _get_request_text():
    """Return the non-blank ``"text"`` field of the JSON body, or ``None``."""
    text = _get_json_body().get("text")
    if isinstance(text, str) and text.strip():
        return text
    return None


def _save_validated_audio():
    """Validate the ``audio`` upload of the current request and save it.

    Aborts with HTTP 400 on a missing file, disallowed extension or unusable
    file name. Returns the path the audio was saved to.
    """
    if 'audio' not in request.files:
        abort(400, description="No audio file provided")
    audio_file = request.files['audio']
    if not allowed_file(audio_file.filename, ALLOWED_AUDIO_EXTENSIONS):
        abort(400, description="Invalid file type. Only mp3, wav, and flac files are allowed.")
    filename = secure_filename(audio_file.filename)
    if not filename:
        abort(400, description="Invalid file name")
    audio_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    logging.info(f"Saving uploaded audio to {audio_path}")
    audio_file.save(audio_path)
    return audio_path


def _collect_tagged_entities(ner_results):
    """Return the distinct words the NER pipeline tagged as entities.

    Handles both raw output (``entity`` = "B-…"/"I-…") and aggregated output
    (``entity_group``).
    """
    words = set()
    for item in ner_results:
        label = item.get("entity")
        if label is not None:
            if label.startswith(("B-", "I-")):
                words.add(item["word"])
        elif item.get("entity_group") not in (None, "O"):
            words.add(item["word"])
    return list(words)


# API Endpoints
@app.route('/extract_medical_data', methods=['POST'])
def extract_medical_data():
    """Extract structured medical data from raw text."""
    try:
        raw_text = _get_request_text()
        if raw_text is None:
            return jsonify({"error": "No valid text provided"}), 400
        clean_text = phi_scrubber_agent.scrub_phi(raw_text)
        structured_data = medical_data_extractor_agent.extract_medical_data(clean_text)
        return jsonify(json.loads(structured_data)), 200
    except Exception as e:
        logging.error(f"Failed to extract medical data: {e}")
        return jsonify({"error": f"Extraction Error: {str(e)}"}), 500


@app.route('/api/transcribe', methods=['POST'])
def transcribe_audio():
    """Transcribe audio files into text."""
    audio_path = _save_validated_audio()
    try:
        result = whisper_model.transcribe(audio_path)
        return jsonify({"transcribed_text": result["text"]}), 200
    except Exception as e:
        logging.error(f"Transcription failed: {str(e)}")
        return jsonify({"error": f"Transcription failed: {str(e)}"}), 500
    finally:
        # Remove the upload whether or not transcription succeeded.
        _remove_quietly(audio_path)


@app.route('/api/generate_summary', methods=['POST'])
def generate_summary():
    """Generate a summary from the provided text."""
    context = _get_request_text()
    if context is None:
        return jsonify({"error": "No valid text provided"}), 400
    clean_text = phi_scrubber_agent.scrub_phi(context)
    summary = summarizer_agent.generate_summary(clean_text)
    return jsonify({"summary": summary}), 200


@app.route('/api/extract_medical_data_from_audio', methods=['POST'])
def extract_medical_data_from_audio():
    """Extract medical data from transcribed audio."""
    audio_path = _save_validated_audio()
    try:
        result = whisper_model.transcribe(audio_path)
        transcribed_text = result["text"]
        # As in the text endpoints, the models only ever see scrubbed text.
        clean_text = phi_scrubber_agent.scrub_phi(transcribed_text)
        summary = summarizer_agent.generate_summary(clean_text)
        structured_data = medical_data_extractor_agent.extract_medical_data(clean_text)
        response = {
            "transcribed_text": transcribed_text,
            "summary": summary,
            "medical_chart": json.loads(structured_data)
        }
        return jsonify(response), 200
    except Exception as e:
        logging.error(f"Processing failed: {str(e)}")
        return jsonify({"error": f"Processing failed: {str(e)}"}), 500
    finally:
        _remove_quietly(audio_path)


@app.route('/upload', methods=['POST'])
def upload_file():
    """Extract, validate and summarize one or more uploaded documents.

    Form fields: ``file`` (repeatable), ``patient_name``, optional ``password``
    for encrypted PDFs, the ``*_model_name`` / ``*_model_type`` pairs for the
    QA, NER and summarizer pipelines, and the optional flags
    ``skip_medical_check`` / ``skip_patient_check``.
    Processing stops at the first file that fails a check.
    """
    files = request.files.getlist("file")
    patient_name = request.form.get("patient_name", "").strip()
    password = request.form.get("password")
    qa_model_name = request.form.get("qa_model_name")
    qa_model_type = request.form.get("qa_model_type")
    ner_model_name = request.form.get("ner_model_name")
    ner_model_type = request.form.get("ner_model_type")
    summarizer_model_name = request.form.get("summarizer_model_name")
    summarizer_model_type = request.form.get("summarizer_model_type")

    if not files:
        return jsonify({"error": "No file uploaded"}), 400

    # NOTE(review): model names come straight from the client, so any hub model
    # can be downloaded and loaded per request. Consider an allow-list.
    try:
        qa_pipeline = pipeline(task=qa_model_type, model=qa_model_name)
        logging.info(f"✅ QA Model Loaded: {qa_model_name}")
    except Exception as e:
        return jsonify({"error": f"QA model load failed: {str(e)}"}), 500

    try:
        ner_pipeline = pipeline(task=ner_model_type, model=ner_model_name)
        logging.info(f"✅ NER Model Loaded: {ner_model_name}")
    except Exception as e:
        return jsonify({"error": f"NER model load failed: {str(e)}"}), 500

    try:
        summarizer_pipeline = pipeline(task=summarizer_model_type, model=summarizer_model_name)
        logging.info(f"✅ Summarizer Model Loaded: {summarizer_model_name}")
    except Exception as e:
        return jsonify({"error": f"Summarizer model load failed: {str(e)}"}), 500

    skip_medical_check = request.form.get("skip_medical_check", "false").lower() == "true"
    skip_patient_check = request.form.get("skip_patient_check", "false").lower() == "true"

    extracted_data = []
    for file in files:
        if file.filename == '':
            continue
        if not allowed_file(file.filename):
            return jsonify({"error": f"Unsupported file type: {file.filename}. Supported file types are: {', '.join(ALLOWED_EXTENSIONS)}"}), 400
        if not patient_name:
            return jsonify({"error": "Patient name is missing"}), 400

        valid_size, error_message = check_file_size(file)
        if not valid_size:
            return jsonify({"error": error_message}), 400

        filename = secure_filename(file.filename)
        if not filename:
            return jsonify({"error": f"Unsupported file type: {file.filename}. Supported file types are: {', '.join(ALLOWED_EXTENSIONS)}"}), 400
        filepath = os.path.join(UPLOAD_FOLDER, filename)
        file.save(filepath)

        # allowed_file() is case-insensitive, so the dispatch must be as well.
        extension = filename.rsplit('.', 1)[-1].lower()
        extracted_text = None
        if extension == "pdf":
            result = extract_text_from_pdf(filepath, password)
            if isinstance(result, tuple):
                extracted_text, status_code = result
            else:
                extracted_text, status_code = result, 200
            if isinstance(extracted_text, dict) and "error" in extracted_text:
                return jsonify(extracted_text), status_code
        elif extension == "docx":
            extracted_text = extract_text_from_docx(filepath)
        elif extension in {"jpg", "jpeg", "png", "svg"}:
            extracted_text = extract_text_from_image(filepath)

        # The helpers report failures as sentinel strings; none of them may be
        # passed on to the models as if it were document text.
        if not extracted_text or extracted_text in _EXTRACTION_FAILURE_MESSAGES:
            return jsonify({"error": f"Failed to extract text from {filename}"}), 415
        if extracted_text in _UNREADABLE_TEXT_MESSAGES:
            return jsonify({"error": f"'{filename}' is too blurry or text is unreadable."}), 422

        if not skip_medical_check:
            ner_results = ner_pipeline(extracted_text)
            medical_entities = _collect_tagged_entities(ner_results)
            logging.info(f"Medical entities found: {medical_entities}")
            if not medical_entities:
                return jsonify({"error": f"'{filename}' is not medically relevant"}), 406
        else:
            logging.info(f"Skipping Medical Validation for {filename}")

        if not skip_patient_check:
            try:
                error_response = validate_patient_name(extracted_text, patient_name, filename, qa_pipeline)
                if error_response:
                    return error_response
            except Exception as e:
                return jsonify({"error": f"Patient name validation failed: {str(e)}"}), 500
        else:
            logging.info(f"Skipping Patient Name Validation for {filename}")

        try:
            summary = summarizer_pipeline(
                extracted_text, max_length=350, min_length=50, do_sample=False
            )[0]["summary_text"]
        except Exception as e:
            summary = "Summary failed"
            logging.warning(f"⚠️ Error summarizing: {e}")

        extracted_data.append({
            "file": filename,
            "extracted_text": extracted_text,
            "summary": summary,
            "message": "Successful"
        })

    if not extracted_data:
        return jsonify({"error": "No valid medical files processed"}), 400
    return jsonify({"extracted_data": extracted_data}), 200


# Questions asked of the QA model, keyed by the label used in the response.
_QA_QUESTIONS = {
    "Patient Name": "What is the patient's name?",
    "Age": "What is the patient's age?",
    "Gender": "What is the patient's gender?",
    "Date of Birth": "What is the patient's date of birth?",
    "Patient ID": "What is the patient ID?",
    "Reason for Visit": "What is the reason for the patient's visit?",
    "Physician": "Who is the physician in charge of the patient?",
    "Test Date": "What is the test date?",
    "Hemoglobin": "What is the patient's hemoglobin level?",
    "Blood Glucose (Fasting)": "What is the patient's fasting blood glucose level?",
    "Total Cholesterol": "What is the total cholesterol level?",
    "LDL Cholesterol": "What is the LDL cholesterol level?",
    "HDL Cholesterol": "What is the HDL cholesterol level?",
    "Serum Creatinine": "What is the serum creatinine level?",
    "Vitamin D (25-OH)": "What is the patient's Vitamin D level?",
    "Height": "What is the patient's height?",
    "Weight": "What is the patient's weight?",
    "Blood Pressure (Systolic)": "What is the patient's systolic blood pressure?",
    "Blood Pressure (Diastolic)": "What is the patient's diastolic blood pressure?",
    "Recommendations": "What are the recommendations based on the test results?"
}

# Display grouping of the answers; "Blood Pressure" is composed from two answers.
_CATEGORY_LAYOUT = (
    ("Patient Information", ("Patient Name", "Date of Birth", "Gender", "Patient ID")),
    ("Vitals", ("Height", "Weight", "Blood Pressure", "Hemoglobin", "Serum Creatinine")),
    ("Lab Results", ("Blood Glucose (Fasting)", "Total Cholesterol", "LDL Cholesterol",
                     "HDL Cholesterol", "Vitamin D (25-OH)")),
    ("Medical Notes", ("Reason for Visit", "Physician", "Test Date", "Recommendations")),
)


def _categorize_extracted_info(extracted_info):
    """Group the flat QA answers into the categorized structure returned to clients."""
    def field_value(label):
        if label == "Blood Pressure":
            systolic = extracted_info.get('Blood Pressure (Systolic)', '')
            diastolic = extracted_info.get('Blood Pressure (Diastolic)', '')
            return f"{systolic}/{diastolic} mmHg"
        return extracted_info.get(label, "")

    return [
        {
            "name": category_name,
            "fields": [{"label": label, "value": field_value(label)} for label in labels]
        }
        for category_name, labels in _CATEGORY_LAYOUT
    ]


@app.route('/extract_medical_data_questions', methods=['POST'])
def extract_medical_data_questions():
    """Extract medical data based on predefined questions.

    Expects JSON with ``qa_model_name``, ``qa_model_type`` and
    ``extracted_data`` (a list of ``{"file", "extracted_text"}`` objects).
    Each processed file's result is also persisted for later retrieval/update.
    """
    data = _get_json_body()
    qa_model_name = data.get("qa_model_name")
    qa_model_type = data.get("qa_model_type")

    if "extracted_data" not in data:
        return jsonify({"error": "Missing 'extracted_data' in request"}), 400
    if not qa_model_name or not qa_model_type:
        return jsonify({"error": "Missing 'qa_model_name' or 'qa_model_type'"}), 400

    try:
        logging.info(f"🌀 Loading model: {qa_model_name} ({qa_model_type})")
        qa_pipeline = pipeline(task=qa_model_type, model=qa_model_name)
        loaded_model_name = qa_pipeline.model.config._name_or_path
        loaded_model_type = qa_pipeline.task
        logging.info(f"✅ Model loaded: {loaded_model_name}")
    except Exception as e:
        logging.error(f"❌ Error loading model: {str(e)}")
        return jsonify({"error": f"Could not load model: {str(e)}"}), 500

    structured_response = {"extracted_data": []}
    for file_data in data["extracted_data"]:
        filename = file_data.get("file")
        context = file_data.get("extracted_text")
        if not context:
            structured_response["extracted_data"].append({
                "file": filename,
                "medical_terms": "No data extracted",
            })
            continue

        extracted_info = {}
        for key, question in _QA_QUESTIONS.items():
            try:
                result = qa_pipeline(question=question, context=context)
                extracted_info[key] = clean_result(result.get("answer", "Not Available"))
            except Exception as e:
                logging.warning(f"QA failed for '{key}': {e}")
                extracted_info[key] = "Error extracting"

        file_entry = {
            "file": filename,
            "medical_terms": extracted_info,
            "categorized_data": _categorize_extracted_info(extracted_info),
            "model_used": loaded_model_name,
            "model_type": loaded_model_type
        }
        structured_response["extracted_data"].append(file_entry)
        # Persist only this file's entry: storing the cumulative response would
        # copy other files' data into it and make updates hit the wrong entry.
        save_data_to_storage(filename, {"extracted_data": [file_entry]})

    return jsonify(structured_response), 200


def _storage_path(stem):
    """Return the JSON storage path for *stem*, or ``None`` if no safe name remains.

    ``secure_filename`` removes path separators and ".." so a client-supplied
    name cannot escape UPLOAD_FOLDER.
    """
    safe_stem = secure_filename(stem) if isinstance(stem, str) else ""
    if not safe_stem:
        return None
    return os.path.join(UPLOAD_FOLDER, f"{safe_stem}.json")


def get_data_from_storage(filename):
    """Load the stored JSON for *filename* (given without extension).

    Returns the parsed data, or ``None`` when the file is missing or unreadable.
    """
    try:
        filepath = _storage_path(filename)
        if filepath is None:
            logging.warning(f"🚫 Invalid storage name: {filename!r}")
            return None
        logging.info(f"🔍 Looking for file at: {filepath}")
        if not os.path.exists(filepath):
            logging.info(f"🚫 File not found at: {filepath}")
            return None
        with open(filepath, "r", encoding="utf-8") as file:
            data = json.load(file)
        logging.info(f"✅ File found and loaded: {filepath}")
        return data
    except Exception as e:
        logging.error(f"🚨 Error loading data: {e}")
        return None


def save_data_to_storage(filename, data):
    """Write *data* as JSON next to the uploads, named after *filename* minus its extension.

    Best effort: failures are logged, not raised.
    """
    try:
        filepath = _storage_path(filename.rsplit(".", 1)[0])
        if filepath is None:
            logging.error(f"🚨 Exception during save: invalid file name {filename!r}")
            return
        logging.info(f"Saving to: {filepath}")
        os.makedirs(UPLOAD_FOLDER, exist_ok=True)
        with open(filepath, "w", encoding="utf-8") as file:
            json.dump(data, file)
        logging.info(f"✅ Data saved successfully to {filepath}")
    except Exception as e:
        logging.error(f"🚨 Exception during save: {e}")


@app.route('/get_updated_medical_data', methods=['GET'])
def get_updated_data():
    """Return the stored extraction result for the ``file`` query parameter."""
    file_name = request.args.get('file')
    if not file_name:
        return jsonify({"error": "File name is required"}), 400
    file_name = file_name.rsplit(".", 1)[0]
    updated_data = get_data_from_storage(file_name)
    if updated_data:
        return jsonify({"file": file_name, "data": updated_data}), 200
    return jsonify({"error": f"File '{file_name}' not found"}), 404


def _apply_field_update(stored_data, category, field, new_value):
    """Set the first field matching *category* / *field*; return whether one was found."""
    for entry in stored_data.get("extracted_data", []):
        for categorized in entry.get("categorized_data", []):
            if categorized.get("name") != category:
                continue
            for fld in categorized.get("fields", []):
                if fld.get("label") == field:
                    logging.info(f"🔄 Updating {category} -> {field}")
                    fld["value"] = new_value
                    return True
    return False


@app.route('/update_medical_data', methods=['PUT'])
def update_medical_data():
    """Apply field updates to a stored extraction result.

    Expects JSON ``{"file": name, "updates": [{"category", "field", "value"}]}``.
    """
    try:
        data = _get_json_body()
        requested_name = data.get("file")
        updates = data.get("updates", [])
        if not requested_name or not isinstance(requested_name, str) or not updates:
            return jsonify({"error": "File name or updates missing"}), 400

        filename = requested_name.rsplit(".", 1)[0]
        existing_data = get_data_from_storage(filename)
        if not existing_data:
            return jsonify({"error": f"File '{filename}' not found"}), 404

        for update in updates:
            category = update.get("category")
            field = update.get("field")
            if not _apply_field_update(existing_data, category, field, update.get("value")):
                logging.warning(f"No field '{field}' in category '{category}' for '{filename}'")

        # save_data_to_storage() strips the extension itself; passing the
        # already-stripped name would strip twice ("a.b.pdf" -> "a.json").
        save_data_to_storage(requested_name, existing_data)
        return jsonify({"message": "Data updated successfully", "updated_data": existing_data}), 200
    except Exception as e:
        logging.error(f"❌ Error: {str(e)}")
        return jsonify({"error": str(e)}), 500


@app.route('/')
def home():
    """Health-check endpoint."""
    return "Medical Data Extraction API is running!"


if __name__ == '__main__':
    # The Werkzeug debugger allows arbitrary code execution, so it must be
    # opted into explicitly rather than enabled on a server bound to 0.0.0.0.
    debug_enabled = os.getenv('FLASK_DEBUG', 'false').lower() in {'1', 'true', 'yes'}
    app.run(host='0.0.0.0', port=5000, debug=debug_enabled)