"""Flask service that scrubs PHI from clinical text/audio and runs summarization,
structured-data extraction and transcription models over it."""

import json
import logging
import os
import re
import shutil
import threading

import cv2
import pandas as pd
import pdfplumber
import pytesseract
import torch
import whisper
from docx import Document
from dotenv import load_dotenv
from flask import Flask, request, jsonify, abort
from flask_cors import CORS
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from werkzeug.utils import secure_filename

# Load environment variables (UPLOAD_DIR, FFMPEG_PATH, model names, ...) from .env.
load_dotenv()

# Initialize Flask app
app = Flask(__name__)
CORS(app)

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Configure upload directory and max file size
UPLOAD_DIR = os.getenv('UPLOAD_DIR', os.path.join(os.getcwd(), 'uploads'))
os.makedirs(UPLOAD_DIR, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_DIR
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16 MB max file size

# Allowed file extensions
ALLOWED_AUDIO_EXTENSIONS = {'mp3', 'wav', 'flac'}
ALLOWED_DOCUMENT_EXTENSIONS = {'pdf', 'jpg', 'jpeg', 'png', 'docx', 'xlsx', 'xls'}

# Locate ffmpeg (whisper needs it to decode audio). Resolution order: an explicit
# FFMPEG_PATH env var, then whatever is on PATH, then a conventional Windows location.
ffmpeg_path = (
    os.getenv("FFMPEG_PATH")
    or shutil.which("ffmpeg")
    or "C:\\ffmpeg\\bin\\ffmpeg.exe"
)
if not os.path.exists(ffmpeg_path):
    raise RuntimeError("FFmpeg not found! Please install FFmpeg and set the correct path.")
_ffmpeg_dir = os.path.dirname(ffmpeg_path)
_current_path = os.environ.get("PATH", "")
# Only append the directory when it is not already on PATH (avoids duplicate entries).
if _ffmpeg_dir and _ffmpeg_dir not in _current_path.split(os.pathsep):
    os.environ["PATH"] = (_current_path + os.pathsep + _ffmpeg_dir) if _current_path else _ffmpeg_dir


def allowed_file(filename, allowed_extensions):
    """Return True if *filename* ends in an extension from *allowed_extensions*.

    The extension is the text after the last '.', compared case-insensitively.
    A missing (None) or empty filename, or one without a '.', is rejected.
    """
    if not filename or '.' not in filename:
        return False
    return filename.rsplit('.', 1)[1].lower() in allowed_extensions


class LazyModelLoader:
    """Build a transformers ``pipeline`` on first use and cache it.

    Args:
        model_name: model identifier passed to ``transformers.pipeline``.
        task: pipeline task name (e.g. ``"summarization"``, ``"text-generation"``).
        tokenizer: optional tokenizer forwarded to ``pipeline``.
        apply_quantization: when True, the pipeline's model is converted with
            ``torch.quantization.quantize_dynamic`` (``nn.Linear`` -> int8) before use.
    """

    def __init__(self, model_name, task, tokenizer=None, apply_quantization=False):
        self.model_name = model_name
        self.task = task
        self.tokenizer = tokenizer
        self.apply_quantization = apply_quantization
        self._pipeline = None
        # Flask serves requests on several threads; the lock stops concurrent
        # first requests from each loading (and holding) a copy of the model.
        self._lock = threading.Lock()

    def load(self):
        """Return the cached pipeline, building it on the first call."""
        if self._pipeline is None:
            with self._lock:
                # Re-check: another thread may have finished loading while we waited.
                if self._pipeline is None:
                    self._pipeline = self._build_pipeline()
        return self._pipeline

    def _build_pipeline(self):
        """Create the pipeline for ``self.task``, quantizing its model if requested."""
        logging.info(f"Loading pipeline for task: {self.task} | model: {self.model_name}")
        pipe = pipeline(self.task, model=self.model_name, tokenizer=self.tokenizer)
        if not self.apply_quantization:
            return pipe
        logging.info("Applying quantization...")
        quantized_model = torch.quantization.quantize_dynamic(
            pipe.model, {torch.nn.Linear}, dtype=torch.qint8
        )
        # quantize_dynamic returns a copy; re-wrap it so the pipeline runs the int8 model.
        return pipeline(self.task, model=quantized_model, tokenizer=pipe.tokenizer)


# PHI scrubbing agent
class PHIScrubberAgent:
    """Best-effort, regex-based redaction of common PHI patterns in free text.

    This is a heuristic: the generic name rule redacts any two consecutive
    Capitalized words (so it also hits phrases like "Blood Pressure"), and names
    or addresses that do not fit the simple patterns are not caught.
    """

    # (compiled pattern, replacement) pairs, applied in order. Emails run first so
    # a phone-number-like local part is redacted as part of the whole address, and
    # SSNs run before the looser phone pattern.
    _RULES = (
        (re.compile(r'\b[\w.+-]+@[\w-]+(?:\.[\w-]+)*\.[A-Za-z]{2,}\b'), '[EMAIL]'),
        (re.compile(r'\b\d{3}-\d{2}-\d{4}\b'), '[SSN]'),
        # (?<!\w) rather than \b so an opening "(" is included in the match.
        (re.compile(r'(?<!\w)(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b'), '[PHONE]'),
        (re.compile(r'\b\d{1,5}\s+\w+\s+(?:Street|St|Avenue|Ave|Boulevard|Blvd|Road|Rd|Lane|Ln)\b',
                    re.IGNORECASE), '[ADDRESS]'),
        (re.compile(r'\bDr\.?\s+[A-Z][a-z]+\s+[A-Z][a-z]+\b'), 'Dr. [NAME]'),
        (re.compile(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b'), '[NAME]'),
    )

    @staticmethod
    def scrub_phi(text):
        """Return *text* with emails, SSNs, phone numbers, street addresses and names replaced.

        If scrubbing raises (e.g. *text* is not a string), the error is logged and
        the value is returned as far as it was processed.
        """
        try:
            for pattern, replacement in PHIScrubberAgent._RULES:
                text = pattern.sub(replacement, text)
        except Exception as e:
            logging.error(f"PHI scrubbing failed: {e}")
        return text


# Summarization Agent
class SummarizerAgent:
    """Summarize text with a lazily loaded summarization pipeline."""

    def __init__(self, summarization_model_loader):
        self.summarization_model_loader = summarization_model_loader

    def generate_summary(self, text):
        """Return a 30-150 token summary of *text*, or "Summary generation failed." on error.

        Model loading happens inside the ``try`` so a loading failure is reported the
        same way as a generation failure. Input longer than the model's maximum
        length is truncated instead of raising.
        """
        try:
            summarizer = self.summarization_model_loader.load()
            summary_result = summarizer(
                text, max_length=150, min_length=30, do_sample=False, truncation=True
            )
            return summary_result[0]['summary_text'].strip()
        except Exception as e:
            logging.error(f"Summary generation failed: {e}")
            return "Summary generation failed."


# Medical Data Extraction Agent
class MedicalDataExtractorAgent:
    """Extract structured medical fields from a clinical note via a text-generation model."""

    def __init__(self, gen_model_loader):
        self.gen_model_loader = gen_model_loader

    def extract_medical_data(self, text):
        """Prompt the model for JSON and return the first JSON object it generates.

        Returns:
            The decoded dict on success, or ``{"error": "..."}`` if loading,
            generation or JSON parsing fails.
        """
        try:
            generator = self.gen_model_loader.load()
            prompt = (
                "Extract structured medical information from the following clinical note.\n\n"
                "Return the result in JSON format with the following fields:\n"
                "patient_condition, symptoms, current_problems, allergies, dr_notes, "
                "prescription, investigations, follow_up_instructions.\n\n"
                f"Clinical Note:\n{text}\n\n"
                "Structured JSON Output:\n"
            )
            # return_full_text=False: only the model's continuation comes back, so
            # braces inside the prompt/clinical note can't be mistaken for output.
            outputs = generator(prompt, max_new_tokens=256, return_full_text=False)
            response = outputs[0]["generated_text"]
            logging.debug(f"Raw model output: {response}")
            return self._parse_first_json_object(response)
        except Exception as e:
            logging.error(f"Error extracting medical data: {e}")
            return {"error": f"Failed to extract medical data: {str(e)}"}

    @staticmethod
    def _parse_first_json_object(response):
        """Decode the first complete JSON object embedded in *response*.

        Each '{' is tried in turn with ``JSONDecoder.raw_decode``, which stops at the
        end of the object, so trailing prose or a second object does not break parsing.

        Raises:
            ValueError: if no '{' starts a decodable JSON object.
        """
        decoder = json.JSONDecoder()
        start = response.find("{")
        while start != -1:
            try:
                parsed, _ = decoder.raw_decode(response, start)
                return parsed
            except json.JSONDecodeError:
                start = response.find("{", start + 1)
        raise ValueError("No JSON found in the model response.")


# Initialize lazy loaders (model names can be overridden via environment variables).
gen_model_loader = LazyModelLoader(
    os.getenv("GEN_MODEL_NAME", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
    "text-generation",
)
summarization_model_loader = LazyModelLoader(
    os.getenv("SUMMARIZATION_MODEL_NAME", "google-t5/t5-large"),
    "summarization",
    apply_quantization=True,
)
# Whisper is loaded eagerly at import time, unlike the lazy pipelines above.
whisper_model = whisper.load_model(os.getenv("WHISPER_MODEL_NAME", "base"))

# Initialize agents
phi_scrubber_agent = PHIScrubberAgent()
medical_data_extractor_agent = MedicalDataExtractorAgent(gen_model_loader)
summarizer_agent = SummarizerAgent(summarization_model_loader)


# Request helpers

def _get_request_text():
    """Return the "text" field of the JSON request body, or None if it is unusable.

    The body is parsed with ``silent=True``, so a missing or malformed JSON body
    yields None instead of raising. A non-object body, a missing key, a non-string
    value, or a blank string also yield None. The returned text is not stripped.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return None
    text = data.get("text")
    if not isinstance(text, str) or not text.strip():
        return None
    return text


def _audio_upload_error():
    """Return a JSON 400 response if the request lacks an acceptable 'audio' upload, else None."""
    if 'audio' not in request.files:
        return jsonify({"error": "No audio file provided"}), 400
    if not allowed_file(request.files['audio'].filename or "", ALLOWED_AUDIO_EXTENSIONS):
        return jsonify({"error": "Invalid file type. Only mp3, wav, and flac files are allowed."}), 400
    return None


def _save_audio_upload(audio_file):
    """Save *audio_file* into UPLOAD_FOLDER under a collision-free name; return its path.

    The random prefix keeps concurrent uploads that share a client filename from
    overwriting, or being transcribed as, each other's audio.
    """
    import uuid  # imported locally so this helper carries its own dependency

    safe_name = secure_filename(audio_file.filename or "")
    audio_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}_{safe_name}")
    audio_file.save(audio_path)
    return audio_path


def _remove_upload(path):
    """Delete a saved upload, logging rather than raising if removal fails."""
    try:
        os.remove(path)
    except OSError as e:
        logging.warning(f"Could not remove upload {path}: {e}")


# API Endpoints

@app.route('/api/extract_medical_data', methods=['POST'])
def extract_medical_data():
    """Scrub PHI from the posted "text" and return the extracted structured data."""
    raw_text = _get_request_text()
    if raw_text is None:
        return jsonify({"error": "No valid text provided"}), 400
    try:
        clean_text = phi_scrubber_agent.scrub_phi(raw_text)
        structured_data = medical_data_extractor_agent.extract_medical_data(clean_text)
        return jsonify(structured_data), 200
    except Exception as e:
        logging.error(f"Failed to extract medical data: {e}")
        return jsonify({"error": f"Extraction Error: {str(e)}"}), 500


@app.route('/api/transcribe', methods=['POST'])
def transcribe_audio():
    """Transcribe the uploaded 'audio' file with whisper and return the raw text."""
    error = _audio_upload_error()
    if error is not None:
        return error
    audio_path = _save_audio_upload(request.files['audio'])
    try:
        result = whisper_model.transcribe(audio_path)
        return jsonify({"transcribed_text": result["text"]}), 200
    except Exception as e:
        logging.error(f"Transcription failed: {str(e)}")
        return jsonify({"error": f"Transcription failed: {str(e)}"}), 500
    finally:
        # Always delete the uploaded audio (it may contain PHI), even on failure.
        _remove_upload(audio_path)


@app.route('/api/generate_summary', methods=['POST'])
def generate_summary():
    """Scrub PHI from the posted "text" and return its summary."""
    context = _get_request_text()
    if context is None:
        return jsonify({"error": "No valid text provided"}), 400
    clean_text = phi_scrubber_agent.scrub_phi(context)
    summary = summarizer_agent.generate_summary(clean_text)
    return jsonify({"summary": summary}), 200


@app.route('/api/extract_medical_data_from_audio', methods=['POST'])
def extract_medical_data_from_audio():
    """Transcribe the uploaded audio, scrub PHI, then summarize and extract structured data.

    The response's "transcribed_text" is the PHI-scrubbed transcription.
    """
    error = _audio_upload_error()
    if error is not None:
        return error
    audio_path = _save_audio_upload(request.files['audio'])
    try:
        result = whisper_model.transcribe(audio_path)
        clean_text = phi_scrubber_agent.scrub_phi(result["text"])
        summary = summarizer_agent.generate_summary(clean_text)
        structured_data = medical_data_extractor_agent.extract_medical_data(clean_text)
        response = {
            "transcribed_text": clean_text,
            "summary": summary,
            "medical_chart": structured_data,
        }
        return jsonify(response), 200
    except Exception as e:
        logging.error(f"Processing failed: {str(e)}")
        return jsonify({"error": f"Processing failed: {str(e)}"}), 500
    finally:
        # Always delete the uploaded audio (it may contain PHI), even on failure.
        _remove_upload(audio_path)


if __name__ == '__main__':
    # PORT may be set in the environment; 5000 remains the default.
    app.run(host='0.0.0.0', port=int(os.getenv('PORT', '5000')), debug=False)

# NOTE(review): an earlier revision of this service (document upload with OCR/PDF/
# DOCX/Excel text extraction, a FAISS knowledge base, vicuna-based extraction) used
# to follow here as a large block of commented-out code. It was removed as dead
# code; recover it from version-control history if it is needed again.