Spaces:
Running
Running
Commit ·
b1406c1
1
Parent(s): 7807e33
feat: NLP module, multimodal fusion, and RAG chatbot
Browse files- backend/api/v1/routers/analyze.py +30 -2
- backend/api/v1/routers/chat.py +59 -0
- backend/ml/fusion/medclip.py +74 -0
- backend/ml/nlp/classifier.py +81 -0
- backend/ml/nlp/ner.py +134 -0
- backend/ml/nlp/whisper.py +70 -0
- backend/ml/rag/generator.py +53 -0
- backend/ml/rag/retriever.py +53 -0
- backend/ml/rag/vectorstore.py +117 -0
- frontend/components/analysis/VoiceInput.jsx +138 -0
- frontend/components/chat/ChatInterface.jsx +98 -0
- training/scripts/finetune_ner.py +103 -0
- training/scripts/ingest_pubmed.py +77 -0
- training/scripts/prepare_ner_data.py +83 -0
backend/api/v1/routers/analyze.py
CHANGED
|
@@ -1,14 +1,16 @@
|
|
|
|
|
| 1 |
from fastapi import APIRouter, Depends, Form, File, UploadFile, HTTPException, Request
|
| 2 |
from sqlalchemy.ext.asyncio import AsyncSession
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
from backend.db.session import get_db
|
| 6 |
from backend.db.models import AnalysisSession, AnalysisTask
|
| 7 |
-
from backend.core.dependencies import get_current_user, get_pagination
|
| 8 |
-
from backend.core.exceptions import InvalidFileTypeError, FileTooLargeError
|
| 9 |
from backend.utils.validators import ImageValidator, sanitize_symptoms_text, validate_patient_id, safe_temp_path
|
| 10 |
from backend.api.v1.schemas.analysis import TaskSubmitResponse, TaskStatusResponse, AnalysisResult
|
| 11 |
from backend.orchestration.queue import task_queue
|
|
|
|
| 12 |
|
| 13 |
router = APIRouter()
|
| 14 |
|
|
@@ -77,3 +79,29 @@ async def get_result(task_id: str, db: AsyncSession = Depends(get_db), current_u
|
|
| 77 |
async def cancel_task(task_id: str, current_user = Depends(get_current_user)):
|
| 78 |
success = await task_queue.cancel(task_id, current_user.id)
|
| 79 |
return {"cancelled": success, "message": "Task cancelled" if success else "Cannot cancel task"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
from fastapi import APIRouter, Depends, Form, File, UploadFile, HTTPException, Request
|
| 3 |
from sqlalchemy.ext.asyncio import AsyncSession
|
| 4 |
from pathlib import Path
|
| 5 |
|
| 6 |
from backend.db.session import get_db
|
| 7 |
from backend.db.models import AnalysisSession, AnalysisTask
|
| 8 |
+
from backend.core.dependencies import get_current_user, get_pagination, get_model_registry
|
| 9 |
+
from backend.core.exceptions import InvalidFileTypeError, FileTooLargeError, ModelNotLoadedError
|
| 10 |
from backend.utils.validators import ImageValidator, sanitize_symptoms_text, validate_patient_id, safe_temp_path
|
| 11 |
from backend.api.v1.schemas.analysis import TaskSubmitResponse, TaskStatusResponse, AnalysisResult
|
| 12 |
from backend.orchestration.queue import task_queue
|
| 13 |
+
from backend.ml.nlp.whisper import WhisperTranscriber
|
| 14 |
|
| 15 |
router = APIRouter()
|
| 16 |
|
|
|
|
| 79 |
async def cancel_task(task_id: str, current_user = Depends(get_current_user)):
|
| 80 |
success = await task_queue.cancel(task_id, current_user.id)
|
| 81 |
return {"cancelled": success, "message": "Task cancelled" if success else "Cannot cancel task"}
|
| 82 |
+
|
| 83 |
+
@router.post("/transcribe")
async def transcribe_audio(
    audio: UploadFile = File(...),
    registry = Depends(get_model_registry),
    current_user = Depends(get_current_user)
):
    """Transcribe an uploaded audio clip with the registry's ``whisper_tiny`` model.

    The upload is written to a temporary file, handed to
    ``WhisperTranscriber.transcribe`` on a worker thread, and the temporary
    file is removed whether or not transcription succeeds.

    Raises:
        InvalidFileTypeError: content type is not WAV, MP3, WebM or OGG.
        FileTooLargeError: the upload is larger than 25 MB.
        ModelNotLoadedError: the Whisper model state reports it is unavailable.
    """
    allowed_types = {"audio/wav", "audio/mpeg", "audio/webm", "audio/ogg"}
    max_bytes = 25 * 1024 * 1024

    if audio.content_type not in allowed_types:
        raise InvalidFileTypeError("Audio must be WAV, MP3, WebM, or OGG")

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large body into memory.
    content = await audio.read(max_bytes + 1)
    if len(content) > max_bytes:
        raise FileTooLargeError("Audio file too large (max 25MB)")

    temp_path = safe_temp_path(audio.filename or "audio.webm")

    try:
        # The write happens inside the try so a partially written file is
        # still cleaned up, and on a worker thread so the event loop is not
        # blocked by disk I/O.
        await asyncio.to_thread(temp_path.write_bytes, content)

        whisper_state = await registry.get("whisper_tiny")
        if not whisper_state.is_available:
            raise ModelNotLoadedError("Voice transcription unavailable")

        result = await asyncio.to_thread(WhisperTranscriber.transcribe, temp_path, whisper_state.model)
        return result
    finally:
        temp_path.unlink(missing_ok=True)
|
backend/api/v1/routers/chat.py
CHANGED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import asyncio
|
| 3 |
+
from fastapi import APIRouter, Depends
|
| 4 |
+
from fastapi.responses import StreamingResponse
|
| 5 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 6 |
+
|
| 7 |
+
from backend.api.v1.schemas.chat import ChatRequest
|
| 8 |
+
from backend.core.dependencies import get_session_or_404, get_model_registry, get_current_user
|
| 9 |
+
from backend.db.session import get_db
|
| 10 |
+
# Note: Assume session_manager is built in db/utils.py to manage chat history records
|
| 11 |
+
# from backend.db.utils import session_manager
|
| 12 |
+
from backend.ml.rag.retriever import MedicalRAG
|
| 13 |
+
from backend.ml.rag.generator import ChatGenerator
|
| 14 |
+
|
| 15 |
+
router = APIRouter()
|
| 16 |
+
|
| 17 |
+
@router.post("/")
async def chat(
    body: ChatRequest,
    session = Depends(get_session_or_404),
    registry = Depends(get_model_registry),
    current_user = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Answer a chat message as a server-sent-event stream.

    Events are JSON objects on ``data:`` lines, emitted in this order: one
    ``sources`` event (up to three retrieved chunks), zero or more ``token``
    events, an optional ``token`` event carrying the safety warning, and a
    final ``done`` event.
    """
    # NOTE(review): ``is_safe`` is computed but never consulted -- flagged
    # queries are still answered and only get the warning appended. Confirm
    # whether generation should be skipped for unsafe queries.
    is_safe, warning = MedicalRAG.is_safe_query(body.message)

    # Placeholder for actual DB history fetch
    history_dicts = []

    session_result = session.result_json if hasattr(session, 'result_json') else {}
    # Retrieval is synchronous; run it on a worker thread so it does not
    # stall the event loop.
    chunks = await asyncio.to_thread(MedicalRAG.retrieve, body.message, session_result, n_results=5)
    prompt = MedicalRAG.build_prompt(body.message, chunks, history_dicts, session_result)

    biogpt_state = await registry.get("biogpt_base")
    full_response = []

    async def stream_generator():
        yield f"data: {json.dumps({'type': 'sources', 'sources': chunks[:3]})}\n\n"

        if not biogpt_state.is_available:
            fallback = "I cannot provide a detailed answer right now as the AI system is unavailable."
            yield f"data: {json.dumps({'type': 'token', 'token': fallback})}\n\n"
        else:
            token_iter = ChatGenerator.generate_stream(prompt, biogpt_state.model, biogpt_state.tokenizer)
            exhausted = object()
            while True:
                # Pull one token at a time on a worker thread so each token
                # is forwarded as soon as it is produced, instead of
                # buffering the whole generation before the first byte.
                token = await asyncio.to_thread(next, token_iter, exhausted)
                if token is exhausted:
                    break
                full_response.append(token)
                yield f"data: {json.dumps({'type': 'token', 'token': token})}\n\n"

        if warning:
            yield f"data: {json.dumps({'type': 'token', 'token': warning})}\n\n"

        yield f"data: {json.dumps({'type': 'done'})}\n\n"

    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}
    )
|
backend/ml/fusion/medclip.py
CHANGED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
class MultimodalFusion:
    """Computes image-text alignment using BiomedVLP."""
    IMAGE_SIZE = 224

    @staticmethod
    def get_image_transform():
        """Return the resize / to-tensor / normalize pipeline applied to input images."""
        return transforms.Compose([
            transforms.Resize((MultimodalFusion.IMAGE_SIZE, MultimodalFusion.IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def get_image_embedding(image_path: Path, model, device: str) -> torch.Tensor:
        """Return the L2-normalized embedding from ``model.get_image_embeddings`` for one image."""
        # Use a context manager so the underlying file handle is released as
        # soon as the pixels are converted; ``convert`` returns a new image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        transform = MultimodalFusion.get_image_transform()
        image_tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            embedding = model.get_image_embeddings(image_tensor)
        return F.normalize(embedding, p=2, dim=-1)

    @staticmethod
    def get_text_embedding(text: str, model, tokenizer, device: str) -> torch.Tensor:
        """Return the L2-normalized embedding from ``model.get_text_embeddings`` for ``text``.

        The text is truncated / padded to 256 tokens before encoding.
        """
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256, padding="max_length").to(device)
        with torch.no_grad():
            embedding = model.get_text_embeddings(**inputs)
        return F.normalize(embedding, p=2, dim=-1)

    @staticmethod
    def compute_similarity(image_path: Path, text: str, model, tokenizer, device: str) -> tuple[float, str]:
        """Score how well ``text`` matches the image.

        Returns:
            ``(score, alignment)`` where ``score`` is the cosine similarity
            shifted into [0, 1] and rounded to 3 decimals, and ``alignment``
            is ``"HIGH"`` (>= 0.7), ``"MEDIUM"`` (>= 0.4) or ``"LOW"``.
            Any failure is logged and reported as ``(0.5, "UNKNOWN")``.
        """
        try:
            img_emb = MultimodalFusion.get_image_embedding(image_path, model, device)
            txt_emb = MultimodalFusion.get_text_embedding(text, model, tokenizer, device)

            similarity = float(torch.cosine_similarity(img_emb, txt_emb).item())
            similarity = (similarity + 1) / 2  # Shift to [0,1]

            if similarity >= 0.7:
                alignment = "HIGH"
            elif similarity >= 0.4:
                alignment = "MEDIUM"
            else:
                alignment = "LOW"
            return round(similarity, 3), alignment
        except Exception as e:
            # Deliberate best-effort: callers get a neutral score, not an error.
            logger.warning("Fusion similarity computation failed: %s", e)
            return 0.5, "UNKNOWN"

    @staticmethod
    def get_fused_embedding(image_path: Path, text: str, model, tokenizer, device: str) -> torch.Tensor:
        """Return the image and text embeddings concatenated along the last dimension."""
        img_emb = MultimodalFusion.get_image_embedding(image_path, model, device)
        txt_emb = MultimodalFusion.get_text_embedding(text, model, tokenizer, device)
        return torch.cat([img_emb, txt_emb], dim=-1)
|
| 64 |
+
|
| 65 |
+
class FallbackFusion:
    """Keyword-based stand-in for ``MultimodalFusion`` that needs no model."""

    # Substrings counted as evidence that the text describes a chest study.
    CHEST_KEYWORDS = (
        "chest", "lung", "cardiac", "pleural", "pneumo",
        "infiltrate", "opacity", "nodule", "effusion",
    )

    @staticmethod
    def compute_similarity(image_path: Path, text: str) -> tuple[float, str]:
        """Simple keyword-based fallback when BiomedVLP unavailable due to RAM constraints.

        ``image_path`` is accepted for signature parity with
        ``MultimodalFusion.compute_similarity`` but is not read.

        Returns:
            ``(score, alignment)``: the score starts at 0.3, gains 0.1 per
            distinct keyword found in ``text`` and is capped at 0.9;
            ``alignment`` is ``"HIGH"`` above 0.6, ``"MEDIUM"`` above 0.4,
            otherwise ``"LOW"``.
        """
        text_lower = text.lower()
        matches = sum(1 for kw in FallbackFusion.CHEST_KEYWORDS if kw in text_lower)
        # Round before thresholding: 0.3 + 3 * 0.1 evaluates to
        # 0.6000000000000001 in binary floating point, which would otherwise
        # slip past the "> 0.6" check and leak float noise into the score.
        score = min(0.9, round(0.3 + matches * 0.1, 3))
        alignment = "HIGH" if score > 0.6 else "MEDIUM" if score > 0.4 else "LOW"
        return score, alignment
|
backend/ml/nlp/classifier.py
CHANGED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
from backend.api.v1.schemas.analysis import NERResult
|
| 3 |
+
|
| 4 |
+
# Candidate labels handed to the zero-shot pipeline in DiseaseClassifier.classify.
# The last two entries are catch-all labels rather than specific conditions.
CHEST_CONDITIONS = [
    "Pneumonia", "Pleural Effusion", "Cardiomegaly", "Atelectasis",
    "Pneumothorax", "Pulmonary Edema", "Tuberculosis", "Lung Cancer",
    "COVID-19", "Chronic Obstructive Pulmonary Disease", "Asthma",
    "Pulmonary Fibrosis", "Bronchitis", "Emphysema", "Heart Failure",
    "Aortic Aneurysm", "Pulmonary Embolism", "Sarcoidosis",
    "No significant finding", "Other condition"
]
|
| 12 |
+
|
| 13 |
+
class DiseaseClassifier:
    """Zero-shot classifier. No fine-tuning needed. Always runs on CPU."""
    _pipeline = None
    _pipeline_lock = threading.Lock()

    @classmethod
    def _get_pipeline(cls):
        """Return the shared zero-shot pipeline, building it on first use."""
        if cls._pipeline is not None:
            return cls._pipeline

        with cls._pipeline_lock:
            # Re-check under the lock: another thread may have finished
            # building the pipeline while this one was waiting.
            if cls._pipeline is None:
                from transformers import pipeline
                # Note: For faster inference use "valhalla/distilbart-mnli-12-1"
                # We stick to bart-large-mnli as instructed, but it takes ~2-4s on CPU
                cls._pipeline = pipeline(
                    "zero-shot-classification",
                    model="facebook/bart-large-mnli",
                    device=-1
                )
        return cls._pipeline

    @staticmethod
    def classify(text: str, entities: NERResult, top_k: int = 3) -> dict:
        """Rank CHEST_CONDITIONS against ``text`` enriched with the extracted entities.

        Returns a dict with the best label (``primary``), its rounded score
        (``confidence``) and the runners-up at ranks 2..top_k
        (``differential``). Blank text short-circuits to an
        "Insufficient information" result without loading the model.
        """
        if not text or not text.strip():
            return {"primary": "Insufficient information", "confidence": 0.0, "differential": []}

        enriched_text = DiseaseClassifier._build_enriched_text(text, entities)
        classifier = DiseaseClassifier._get_pipeline()

        output = classifier(
            enriched_text,
            candidate_labels=CHEST_CONDITIONS,
            multi_label=False
        )

        score_by_label = dict(zip(output["labels"], output["scores"]))
        ranked = sorted(score_by_label, key=score_by_label.get, reverse=True)

        primary = ranked[0]
        differential = []
        for label in ranked[1:top_k]:
            differential.append({"disease": label, "confidence": round(score_by_label[label], 3)})

        return {
            "primary": primary,
            "confidence": round(score_by_label[primary], 3),
            "differential": differential
        }

    @staticmethod
    def _build_enriched_text(original: str, entities: NERResult) -> str:
        """Append non-empty entity lists to ``original`` and cap the result at 1024 characters."""
        labelled_groups = (
            ("Diagnosed conditions", entities.diseases),
            ("Presenting symptoms", entities.symptoms),
            ("Current medications", entities.medications),
        )
        parts = [original]
        for heading, values in labelled_groups:
            if values:
                parts.append(f"{heading}: {', '.join(values)}")

        return ". ".join(parts)[:1024]
|
| 76 |
+
|
| 77 |
+
# Manual smoke test: classify one hard-coded symptom description.
# Running this loads the zero-shot pipeline via DiseaseClassifier._get_pipeline.
if __name__ == "__main__":
    test_text = "Patient complains of severe chest pain and shortness of breath."
    entities = NERResult(diseases=[], symptoms=["chest pain", "shortness of breath"], medications=[], anatomy=[], raw_entities=[])
    res = DiseaseClassifier.classify(test_text, entities)
    print("Classification result:", res)
|
backend/ml/nlp/ner.py
CHANGED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoModelForTokenClassification, AutoTokenizer
|
| 3 |
+
from backend.api.v1.schemas.analysis import NERResult
|
| 4 |
+
|
| 5 |
+
class NERExtractor:
    """Runs a token-classification model over text and groups the entities it finds."""

    # Maps the model's label suffix (after the B-/I- prefix) to a result field.
    ENTITY_MAP = {
        "DISEASE": "diseases",
        "SYMPTOM": "symptoms",
        "MEDICATION": "medications",
        "ANATOMY": "anatomy"
    }

    @staticmethod
    def extract(text: str, model: "AutoModelForTokenClassification", tokenizer: "AutoTokenizer", device: str = "cuda") -> "NERResult":
        """Extract entities from ``text`` and return them grouped by category.

        The text is processed in overlapping chunks. Entities are
        de-duplicated case-insensitively by (text, label), so only the first
        occurrence of a repeated mention is kept in ``raw_entities``.
        Blank input yields an empty result without calling the model.
        """
        if not text or not text.strip():
            return NERResult(diseases=[], symptoms=[], medications=[], anatomy=[], raw_entities=[])

        chunks = NERExtractor._chunk_text(text, tokenizer, max_length=400, overlap=50)

        all_entities = []
        seen_spans = set()

        for chunk_text, char_offset in chunks:
            chunk_entities = NERExtractor._extract_chunk(chunk_text, model, tokenizer, device, char_offset)
            for entity in chunk_entities:
                span_key = (entity["text"].lower(), entity["entity_type"])
                if span_key not in seen_spans:
                    seen_spans.add(span_key)
                    all_entities.append(entity)

        grouped = {"diseases": [], "symptoms": [], "medications": [], "anatomy": []}
        for entity in all_entities:
            entity_type = entity["entity_type"].split("-")[-1]
            group_key = NERExtractor.ENTITY_MAP.get(entity_type)
            if group_key:
                entity_text = entity["text"].strip()
                if entity_text and entity_text not in grouped[group_key]:
                    grouped[group_key].append(entity_text)

        return NERResult(
            diseases=grouped["diseases"], symptoms=grouped["symptoms"],
            medications=grouped["medications"], anatomy=grouped["anatomy"],
            raw_entities=all_entities
        )

    @staticmethod
    def _extract_chunk(text: str, model, tokenizer, device: str, char_offset: int = 0) -> list[dict]:
        """Tag one chunk and merge B-/I- token labels into entity dicts.

        Each entity has ``text``, ``entity_type`` (the B- label), ``start`` /
        ``end`` (chunk offsets shifted by ``char_offset``) and ``confidence``
        (softmax maximum of the entity's first token). Entities at or below
        0.7 confidence are dropped.
        """
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, return_offsets_mapping=True, padding=False)
        offset_mapping = inputs.pop("offset_mapping")[0]
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        predictions = torch.argmax(outputs.logits, dim=-1)[0]
        entities = []
        current_entity = None
        entity_start = 0  # chunk-relative start of the entity currently open

        for idx, (pred_id, offsets) in enumerate(zip(predictions, offset_mapping)):
            label = model.config.id2label[pred_id.item()]
            start, end = offsets[0].item(), offsets[1].item()

            if start == 0 and end == 0:  # Special token
                if current_entity:
                    entities.append(current_entity)
                    current_entity = None
                continue

            if label.startswith("B-"):
                if current_entity:
                    entities.append(current_entity)
                entity_start = start
                current_entity = {
                    "text": text[start:end], "entity_type": label,
                    "start": start + char_offset, "end": end + char_offset,
                    "confidence": torch.softmax(outputs.logits[0][idx], dim=-1).max().item()
                }
            elif label.startswith("I-") and current_entity:
                # Re-slice the whole span from the source text. Appending the
                # token's own characters would drop the whitespace between
                # words and turn "chest pain" into "chestpain".
                current_entity["text"] = text[entity_start:end]
                current_entity["end"] = end + char_offset
            else:
                if current_entity:
                    entities.append(current_entity)
                    current_entity = None

        if current_entity:
            entities.append(current_entity)

        return [e for e in entities if e["confidence"] > 0.7]

    @staticmethod
    def _chunk_text(text: str, tokenizer, max_length: int = 400, overlap: int = 50) -> list[tuple[str, int]]:
        """Split ``text`` into word-aligned chunks of roughly ``max_length`` tokens.

        Returns ``(chunk, offset)`` pairs where ``chunk`` is an exact slice of
        ``text`` starting at ``offset``, so entity offsets computed on a chunk
        map back onto the original string even when it contains newlines or
        repeated spaces. Consecutive chunks share ``ceil(overlap / 4)``
        trailing words, but never a whole chunk, so every chunk advances by
        at least one word.
        """
        # Locate every whitespace-separated word in the original text.
        spans = []
        cursor = 0
        for word in text.split():
            start = text.find(word, cursor)
            cursor = start + len(word)
            spans.append((start, cursor))

        if not spans:
            return [(text, 0)]

        # Tokenize each word once; the counts are reused for the carry-over.
        token_counts = [
            len(tokenizer(text[s:e], add_special_tokens=False)["input_ids"])
            for s, e in spans
        ]
        overlap_words = max(0, -(-overlap // 4))  # ceil(overlap / 4)

        chunks = []
        first = 0  # index of the first word in the chunk being built
        current_length = 0
        for idx, count in enumerate(token_counts):
            if current_length + count > max_length and idx > first:
                chunk_start = spans[first][0]
                chunks.append((text[chunk_start:spans[idx - 1][1]], chunk_start))

                carried = min(overlap_words, idx - first - 1)
                first = idx - carried
                current_length = sum(token_counts[first:idx])

            current_length += count

        chunk_start = spans[first][0]
        chunks.append((text[chunk_start:spans[-1][1]], chunk_start))
        return chunks
|
| 117 |
+
|
| 118 |
+
def highlight_entities(text: str, entities: list[dict]) -> str:
    """Return ``text`` as HTML with each entity span wrapped in a colored ``<mark>``.

    Args:
        text: the string the entities' ``start`` / ``end`` offsets index into.
        entities: dicts with ``start``, ``end``, ``entity_type`` and
            ``confidence`` keys.

    All text, inside and outside the marks, is HTML-escaped, so markup in the
    input cannot be injected into the page. The highlighted characters are
    taken from ``text[start:end]``. Empty spans and spans that overlap an
    earlier one are skipped, so the output stays well-formed.
    """
    from html import escape

    COLORS = {
        "DISEASE": "#FF6B6B", "SYMPTOM": "#FFD93D",
        "MEDICATION": "#6BCB77", "ANATOMY": "#4D96FF"
    }
    pieces = []
    cursor = 0
    for entity in sorted(entities, key=lambda e: e["start"]):
        start, end = entity["start"], entity["end"]
        if start < cursor or end <= start:
            continue

        entity_type = entity["entity_type"].replace("B-", "").replace("I-", "")
        color = COLORS.get(entity_type, "#cccccc")
        confidence = entity["confidence"]

        pieces.append(escape(text[cursor:start]))
        pieces.append(
            f'<mark style="background:{color};padding:2px 4px;border-radius:3px;'
            f'font-size:0.85em" title="{escape(entity_type)} ({confidence:.0%})">'
            f'{escape(text[start:end])}</mark>'
        )
        cursor = end

    pieces.append(escape(text[cursor:]))
    return "".join(pieces)
|
backend/ml/nlp/whisper.py
CHANGED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import subprocess
|
| 2 |
+
import json
|
| 3 |
+
import numpy as np
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from backend.core.exceptions import InvalidFileError, ValidationError, InferenceError
|
| 6 |
+
|
| 7 |
+
class WhisperTranscriber:
    """Transcribes short audio clips with a Whisper model, converting to WAV via ffmpeg."""

    SUPPORTED_INPUT_FORMATS = {".wav", ".mp3", ".webm", ".ogg", ".m4a"}
    MAX_DURATION_SECONDS = 60

    @staticmethod
    def transcribe(audio_file_path: Path, model, language: str = "en") -> dict:
        """Transcribe one audio file.

        Non-WAV input is first converted to a temporary WAV next to the
        source file; that temporary file is always removed afterwards.

        Returns:
            dict with ``text``, ``language``, ``confidence`` (exp of the mean
            segment log-probability, clamped to [0, 1]),
            ``duration_seconds`` and ``segments``.

        Raises:
            InvalidFileError: the file does not exist.
            ValidationError: the clip is longer than MAX_DURATION_SECONDS.
            InferenceError: ffmpeg conversion failed or could not be run.
        """
        audio_path = Path(audio_file_path)
        if not audio_path.exists():
            raise InvalidFileError(f"Audio file not found: {audio_path}")

        wav_path = None
        try:
            if audio_path.suffix.lower() != ".wav":
                wav_path = audio_path.with_suffix(".converted.wav")
                WhisperTranscriber._convert_to_wav(audio_path, wav_path)
                process_path = wav_path
            else:
                process_path = audio_path

            duration = WhisperTranscriber._get_duration(process_path)
            if duration > WhisperTranscriber.MAX_DURATION_SECONDS:
                raise ValidationError(f"Audio too long ({duration:.0f}s). Max {WhisperTranscriber.MAX_DURATION_SECONDS}s.")

            result = model.transcribe(
                str(process_path), language=language, verbose=False,
                word_timestamps=True, fp16=False, condition_on_previous_text=False,
                no_speech_threshold=0.6, logprob_threshold=-1.0
            )

            segments = result["segments"]
            avg_logprob = np.mean([s["avg_logprob"] for s in segments]) if segments else -1
            confidence = float(min(1.0, max(0.0, np.exp(avg_logprob))))

            return {
                "text": result["text"].strip(),
                "language": result["language"],
                "confidence": round(confidence, 3),
                "duration_seconds": round(duration, 1),
                "segments": [{"start": s["start"], "end": s["end"], "text": s["text"]} for s in segments]
            }
        finally:
            if wav_path and wav_path.exists():
                wav_path.unlink()

    @staticmethod
    def _convert_to_wav(input_path: Path, output_path: Path):
        """Convert ``input_path`` to 16 kHz mono 16-bit PCM WAV with ffmpeg.

        Raises:
            InferenceError: ffmpeg exited non-zero, timed out after 30 s, or
                could not be launched.
        """
        cmd = [
            "ffmpeg", "-i", str(input_path), "-acodec", "pcm_s16le",
            "-ac", "1", "-ar", "16000", "-y", str(output_path)
        ]
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
        except (subprocess.TimeoutExpired, OSError) as exc:
            # A hung or missing ffmpeg is reported the same way as a failed
            # conversion instead of escaping as an unrelated exception type.
            raise InferenceError(f"Audio conversion failed: {exc}") from exc
        if result.returncode != 0:
            raise InferenceError(f"Audio conversion failed: {result.stderr[:200]}")

    @staticmethod
    def _get_duration(wav_path: Path) -> float:
        """Return the clip length in seconds as reported by ffprobe.

        Best-effort: returns 0.0 whenever the duration cannot be determined
        (ffprobe missing, timing out, failing, or printing unusable output).
        Note that a 0.0 result lets the caller's length check pass.
        """
        cmd = [
            "ffprobe", "-v", "quiet", "-print_format", "json",
            "-show_format", str(wav_path)
        ]
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
        except (subprocess.TimeoutExpired, OSError):
            return 0.0
        if result.returncode != 0:
            return 0.0
        try:
            data = json.loads(result.stdout)
            return float(data.get("format", {}).get("duration", 0))
        except (ValueError, TypeError, AttributeError):
            # Covers malformed JSON and non-numeric durations such as "N/A".
            return 0.0
|
backend/ml/rag/generator.py
CHANGED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from threading import Thread
|
| 3 |
+
from transformers import TextIteratorStreamer
|
| 4 |
+
|
| 5 |
+
class ReportGenerator:
    """Builds a markdown diagnostic report from vision / NLP / fusion results."""

    @staticmethod
    def generate(vision, nlp, fusion, model, tokenizer) -> str:
        """Prompt ``model`` with the available findings and return the formatted report.

        Each of ``vision``, ``nlp`` and ``fusion`` is optional; falsy ones
        are left out of the prompt.
        """
        prompt_parts = ["Generate a medical radiology report based on these findings:"]
        if vision:
            prompt_parts.append(f"Imaging: {vision.risk_level} risk, anomaly score {vision.anomaly_score}/100")
        if nlp:
            prompt_parts.append(f"Clinical: {nlp.primary_diagnosis}, confidence {nlp.diagnosis_confidence:.0%}")
        if fusion:
            prompt_parts.append(f"Image-text alignment: {fusion.alignment}")
        prompt_parts.append("Report:")

        encoded = tokenizer(" ".join(prompt_parts), return_tensors="pt", max_length=512, truncation=True)

        with torch.no_grad():
            outputs = model.generate(
                **encoded, max_new_tokens=300, do_sample=False, num_beams=4,
                early_stopping=True, pad_token_id=tokenizer.eos_token_id, repetition_penalty=1.3
            )

        # Decode only the tokens produced after the prompt.
        prompt_length = encoded["input_ids"].shape[1]
        generated = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return ReportGenerator._format_report(generated, vision, nlp)

    @staticmethod
    def _format_report(raw_text: str, vision, nlp) -> str:
        """Assemble the markdown report; sections for falsy ``vision`` / ``nlp`` are omitted."""
        sections = ["## AI Diagnostic Report\n"]

        if vision:
            if vision.risk_level == "HIGH":
                risk_emoji = "🔴"
            elif vision.risk_level == "MEDIUM":
                risk_emoji = "🟡"
            else:
                risk_emoji = "🟢"
            sections.append(f"### Imaging Findings\n{risk_emoji} **Risk Level:** {vision.risk_level} \n**Anomaly Score:** {vision.anomaly_score}/100\n")

        if nlp:
            sections.append(f"### Clinical Assessment\n**Primary Impression:** {nlp.primary_diagnosis}\n")

        analysis = raw_text.strip()
        sections.append(f"### AI Analysis\n{analysis}\n\n### Recommendation\nPlease consult a licensed physician.")
        return "\n".join(sections)
|
| 36 |
+
|
| 37 |
+
class ChatGenerator:
    """Streams sampled text from a causal language model."""

    @staticmethod
    def generate_stream(prompt: str, model, tokenizer, max_new_tokens: int = 200):
        """Yield decoded text pieces for ``prompt`` as the model produces them.

        Generation runs on a background thread that feeds a
        ``TextIteratorStreamer``; this generator drains the streamer.

        Raises:
            Exception: whatever ``model.generate`` raised, re-raised here
                after the pieces produced before the failure were yielded.
        """
        inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = {
            **inputs, "streamer": streamer, "max_new_tokens": max_new_tokens,
            "do_sample": True, "temperature": 0.7, "top_p": 0.9,
            "repetition_penalty": 1.2, "pad_token_id": tokenizer.eos_token_id
        }

        failures = []

        def _run_generation():
            try:
                model.generate(**generation_kwargs)
            except Exception as exc:
                # Without this the streamer would never get its stop signal
                # and the consuming loop below would block forever.
                failures.append(exc)
                streamer.end()

        # Daemon thread: an abandoned stream must not keep the process alive.
        thread = Thread(target=_run_generation, daemon=True)
        thread.start()

        for token in streamer:
            yield token

        # The stop signal has been seen, so the worker is finished or about to be.
        thread.join()
        if failures:
            raise failures[0]
|
backend/ml/rag/retriever.py
CHANGED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from backend.ml.rag.vectorstore import vector_store
|
| 3 |
+
|
| 4 |
+
class MedicalRAG:
|
| 5 |
+
"""Retrieval-Augmented Generation for medical Q&A."""
|
| 6 |
+
MAX_CONTEXT_TOKENS = 800
|
| 7 |
+
MAX_HISTORY_TURNS = 6
|
| 8 |
+
SAFETY_DISCLAIMER = "\n\n*Note: This is AI-generated information for educational purposes only. Please consult a licensed physician for medical advice.*"
|
| 9 |
+
DRUG_DOSAGE_PATTERNS = [r"\b\d+\s*mg\b", r"how much.*take", r"dosage", r"dose of"]
|
| 10 |
+
|
| 11 |
+
@staticmethod
|
| 12 |
+
def retrieve(query: str, session_result: dict | None, n_results: int = 5) -> list[dict]:
|
| 13 |
+
enriched_query = query
|
| 14 |
+
if session_result:
|
| 15 |
+
primary_dx = session_result.get("nlp", {}).get("primary_diagnosis", "")
|
| 16 |
+
symptoms = session_result.get("nlp", {}).get("entities", {}).get("symptoms", [])
|
| 17 |
+
if primary_dx:
|
| 18 |
+
enriched_query = f"{query} {primary_dx} {' '.join(symptoms[:3])}"
|
| 19 |
+
return vector_store.search(enriched_query, n_results=n_results)
|
| 20 |
+
|
| 21 |
+
@staticmethod
|
| 22 |
+
def build_prompt(query: str, retrieved_chunks: list[dict], chat_history: list[dict], session_result: dict | None) -> str:
|
| 23 |
+
parts = []
|
| 24 |
+
if session_result:
|
| 25 |
+
vision, nlp = session_result.get("vision", {}), session_result.get("nlp", {})
|
| 26 |
+
if vision or nlp:
|
| 27 |
+
parts.append("Patient analysis context:")
|
| 28 |
+
if vision:
|
| 29 |
+
parts.append(f"- Imaging: {vision.get('risk_level', 'unknown')} risk (anomaly score: {vision.get('anomaly_score', 'N/A')}/100)")
|
| 30 |
+
if nlp:
|
| 31 |
+
parts.append(f"- Primary impression: {nlp.get('primary_diagnosis', 'unknown')}")
|
| 32 |
+
parts.append("")
|
| 33 |
+
|
| 34 |
+
if retrieved_chunks:
|
| 35 |
+
parts.append("Relevant medical information:")
|
| 36 |
+
for chunk in retrieved_chunks[:3]:
|
| 37 |
+
parts.append(f"- {chunk['text'][:200]}...")
|
| 38 |
+
parts.append("")
|
| 39 |
+
|
| 40 |
+
recent_history = chat_history[-MedicalRAG.MAX_HISTORY_TURNS:]
|
| 41 |
+
for msg in recent_history:
|
| 42 |
+
role = "Patient" if msg["role"] == "user" else "Assistant"
|
| 43 |
+
parts.append(f"{role}: {msg['content'][:100]}")
|
| 44 |
+
|
| 45 |
+
parts.append(f"Patient: {query}\nAssistant:")
|
| 46 |
+
return "\n".join(parts)
|
| 47 |
+
|
| 48 |
+
@staticmethod
def is_safe_query(query: str) -> tuple[bool, str | None]:
    """Reject questions that ask for drug dosage information.

    The lower-cased query is matched against every pattern in
    ``MedicalRAG.DRUG_DOSAGE_PATTERNS``.

    Returns:
        ``(True, None)`` when no pattern matches, otherwise ``(False, message)``
        where ``message`` is the refusal text to show the user.
    """
    lowered = query.lower()
    asks_for_dosage = any(
        re.search(dosage_rx, lowered) for dosage_rx in MedicalRAG.DRUG_DOSAGE_PATTERNS
    )
    if not asks_for_dosage:
        return True, None
    return False, "Please consult a physician for specific dosage information."
|
backend/ml/rag/vectorstore.py
CHANGED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import asyncio
|
| 3 |
+
import logging
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from backend.core.config import settings
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
class MedicalVectorStore:
|
| 10 |
+
COLLECTION_NAME = "pubmed_medical"
|
| 11 |
+
CHUNK_SIZE = 256
|
| 12 |
+
CHUNK_OVERLAP = 32
|
| 13 |
+
BATCH_SIZE = 100
|
| 14 |
+
|
| 15 |
+
def __init__(self):
|
| 16 |
+
self._client = None
|
| 17 |
+
self._collection = None
|
| 18 |
+
self._embedder = None
|
| 19 |
+
|
| 20 |
+
async def initialize(self):
|
| 21 |
+
if settings.ENVIRONMENT == "production":
|
| 22 |
+
import chromadb
|
| 23 |
+
self._client = chromadb.EphemeralClient()
|
| 24 |
+
else:
|
| 25 |
+
import chromadb
|
| 26 |
+
persist_path = str(settings.BASE_DIR / "data" / "chromadb" if hasattr(settings, 'BASE_DIR') else Path("data/chromadb"))
|
| 27 |
+
self._client = chromadb.PersistentClient(path=persist_path)
|
| 28 |
+
|
| 29 |
+
self._collection = self._client.get_or_create_collection(
|
| 30 |
+
name=self.COLLECTION_NAME, metadata={"hnsw:space": "cosine"}
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
doc_count = self._collection.count()
|
| 34 |
+
logger.info(f"ChromaDB initialized. Documents: {doc_count}")
|
| 35 |
+
|
| 36 |
+
if doc_count == 0:
|
| 37 |
+
logger.info("ChromaDB is empty. Starting PubMed ingestion...")
|
| 38 |
+
await asyncio.to_thread(self.ingest_from_json, Path("data/pubmed_raw.json"))
|
| 39 |
+
|
| 40 |
+
def ingest_from_json(self, json_path: Path):
|
| 41 |
+
if not json_path.exists():
|
| 42 |
+
logger.warning(f"PubMed data file not found: {json_path}. RAG will be limited.")
|
| 43 |
+
return
|
| 44 |
+
|
| 45 |
+
records = json.loads(json_path.read_text(encoding="utf-8"))
|
| 46 |
+
logger.info(f"Ingesting {len(records)} PubMed records into ChromaDB...")
|
| 47 |
+
|
| 48 |
+
all_chunks, all_ids, all_metadatas = [], [], []
|
| 49 |
+
|
| 50 |
+
for record in records:
|
| 51 |
+
if not record.get("abstract"): continue
|
| 52 |
+
text = f"{record['title']}. {record['abstract']}"
|
| 53 |
+
chunks = self._chunk_text(text, self.CHUNK_SIZE, self.CHUNK_OVERLAP)
|
| 54 |
+
|
| 55 |
+
for i, chunk in enumerate(chunks):
|
| 56 |
+
all_chunks.append(chunk)
|
| 57 |
+
all_ids.append(f"{record['pmid']}_chunk_{i}")
|
| 58 |
+
all_metadatas.append({
|
| 59 |
+
"pmid": str(record["pmid"]),
|
| 60 |
+
"title": record["title"][:200],
|
| 61 |
+
"disease_category": record.get("category", "general"),
|
| 62 |
+
"year": str(record.get("year", ""))
|
| 63 |
+
})
|
| 64 |
+
|
| 65 |
+
embedder = self._get_embedder()
|
| 66 |
+
|
| 67 |
+
for i in range(0, len(all_chunks), self.BATCH_SIZE):
|
| 68 |
+
batch_chunks = all_chunks[i:i+self.BATCH_SIZE]
|
| 69 |
+
batch_ids = all_ids[i:i+self.BATCH_SIZE]
|
| 70 |
+
batch_metas = all_metadatas[i:i+self.BATCH_SIZE]
|
| 71 |
+
|
| 72 |
+
embeddings = embedder.encode(batch_chunks, show_progress_bar=False).tolist()
|
| 73 |
+
self._collection.add(documents=batch_chunks, embeddings=embeddings, ids=batch_ids, metadatas=batch_metas)
|
| 74 |
+
|
| 75 |
+
logger.info(f"Ingestion complete. Total documents: {self._collection.count()}")
|
| 76 |
+
|
| 77 |
+
def search(self, query: str, n_results: int = 5, disease_filter: str = None) -> list[dict]:
|
| 78 |
+
if not self._collection: return []
|
| 79 |
+
embedder = self._get_embedder()
|
| 80 |
+
query_embedding = embedder.encode([query])[0].tolist()
|
| 81 |
+
|
| 82 |
+
where_filter = {"disease_category": disease_filter} if disease_filter else None
|
| 83 |
+
|
| 84 |
+
results = self._collection.query(
|
| 85 |
+
query_embeddings=[query_embedding], n_results=n_results,
|
| 86 |
+
where=where_filter, include=["documents", "metadatas", "distances"]
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
output = []
|
| 90 |
+
if results["documents"]:
|
| 91 |
+
for doc, meta, dist in zip(results["documents"][0], results["metadatas"][0], results["distances"][0]):
|
| 92 |
+
relevance = float(1 - dist)
|
| 93 |
+
output.append({
|
| 94 |
+
"text": doc, "pmid": meta.get("pmid", ""), "title": meta.get("title", ""),
|
| 95 |
+
"disease_category": meta.get("disease_category", ""), "year": meta.get("year", ""),
|
| 96 |
+
"relevance_score": round(max(0, relevance), 3)
|
| 97 |
+
})
|
| 98 |
+
return output
|
| 99 |
+
|
| 100 |
+
def _chunk_text(self, text: str, chunk_size: int, overlap: int) -> list[str]:
|
| 101 |
+
words = text.split()
|
| 102 |
+
chunks, start = [], 0
|
| 103 |
+
while start < len(words):
|
| 104 |
+
end = min(start + chunk_size, len(words))
|
| 105 |
+
chunks.append(" ".join(words[start:end]))
|
| 106 |
+
start = end - overlap
|
| 107 |
+
if start >= len(words): break
|
| 108 |
+
return chunks
|
| 109 |
+
|
| 110 |
+
def _get_embedder(self):
|
| 111 |
+
if self._embedder is None:
|
| 112 |
+
from sentence_transformers import SentenceTransformer
|
| 113 |
+
cache_dir = str(settings.MODEL_CACHE_DIR / "minilm")
|
| 114 |
+
self._embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
|
| 115 |
+
return self._embedder
|
| 116 |
+
|
| 117 |
+
vector_store = MedicalVectorStore()
|
frontend/components/analysis/VoiceInput.jsx
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client'
|
| 2 |
+
import { useState, useRef } from 'react'
|
| 3 |
+
import { apiClient } from '@/lib/api/client'
|
| 4 |
+
|
| 5 |
+
export default function VoiceInput({ onTranscribed }) {
|
| 6 |
+
const [isRecording, setIsRecording] = useState(false)
|
| 7 |
+
const [recordingDuration, setRecordingDuration] = useState(0)
|
| 8 |
+
const [isTranscribing, setIsTranscribing] = useState(false)
|
| 9 |
+
const [transcribedText, setTranscribedText] = useState("")
|
| 10 |
+
const [confidence, setConfidence] = useState(null)
|
| 11 |
+
const [error, setError] = useState(null)
|
| 12 |
+
const [audioBlob, setAudioBlob] = useState(null)
|
| 13 |
+
|
| 14 |
+
const mediaRecorderRef = useRef(null)
|
| 15 |
+
const chunksRef = useRef([])
|
| 16 |
+
const timerRef = useRef(null)
|
| 17 |
+
|
| 18 |
+
const startRecording = async () => {
|
| 19 |
+
try {
|
| 20 |
+
const stream = await navigator.mediaDevices.getUserMedia({ audio: true })
|
| 21 |
+
const mediaRecorder = new MediaRecorder(stream, { mimeType: "audio/webm" })
|
| 22 |
+
mediaRecorderRef.current = mediaRecorder
|
| 23 |
+
chunksRef.current = []
|
| 24 |
+
|
| 25 |
+
mediaRecorder.ondataavailable = e => chunksRef.current.push(e.data)
|
| 26 |
+
mediaRecorder.onstop = () => {
|
| 27 |
+
const blob = new Blob(chunksRef.current, { type: "audio/webm" })
|
| 28 |
+
setAudioBlob(blob)
|
| 29 |
+
sendForTranscription(blob)
|
| 30 |
+
stream.getTracks().forEach(track => track.stop())
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
mediaRecorder.start(1000)
|
| 34 |
+
setIsRecording(true)
|
| 35 |
+
setRecordingDuration(0)
|
| 36 |
+
setError(null)
|
| 37 |
+
|
| 38 |
+
timerRef.current = setInterval(() => {
|
| 39 |
+
setRecordingDuration(prev => {
|
| 40 |
+
if (prev >= 59) {
|
| 41 |
+
stopRecording()
|
| 42 |
+
return 60
|
| 43 |
+
}
|
| 44 |
+
return prev + 1
|
| 45 |
+
})
|
| 46 |
+
}, 1000)
|
| 47 |
+
} catch (err) {
|
| 48 |
+
setError("Microphone access denied or unavailable.")
|
| 49 |
+
}
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
const stopRecording = () => {
|
| 53 |
+
if (mediaRecorderRef.current && isRecording) {
|
| 54 |
+
mediaRecorderRef.current.stop()
|
| 55 |
+
setIsRecording(false)
|
| 56 |
+
clearInterval(timerRef.current)
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
const sendForTranscription = async (blob) => {
|
| 61 |
+
setIsTranscribing(true)
|
| 62 |
+
const formData = new FormData()
|
| 63 |
+
formData.append("audio", blob, "recording.webm")
|
| 64 |
+
|
| 65 |
+
try {
|
| 66 |
+
const response = await apiClient.post('/api/v1/analyze/transcribe', formData, {
|
| 67 |
+
headers: { 'Content-Type': 'multipart/form-data' }
|
| 68 |
+
})
|
| 69 |
+
setTranscribedText(response.data.text)
|
| 70 |
+
setConfidence(response.data.confidence)
|
| 71 |
+
} catch (err) {
|
| 72 |
+
setError("Transcription failed. Please try again.")
|
| 73 |
+
} finally {
|
| 74 |
+
setIsTranscribing(false)
|
| 75 |
+
}
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
const handleReset = () => {
|
| 79 |
+
setTranscribedText("")
|
| 80 |
+
setConfidence(null)
|
| 81 |
+
setAudioBlob(null)
|
| 82 |
+
setRecordingDuration(0)
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
return (
|
| 86 |
+
<div className="p-4 border rounded-lg bg-gray-50">
|
| 87 |
+
{!transcribedText && !isTranscribing && (
|
| 88 |
+
<div className="flex flex-col items-center">
|
| 89 |
+
<button
|
| 90 |
+
onClick={isRecording ? stopRecording : startRecording}
|
| 91 |
+
className={`w-16 h-16 rounded-full flex items-center justify-center transition-all ${
|
| 92 |
+
isRecording ? "bg-red-500 animate-pulse" : "bg-teal-500 hover:bg-teal-600"
|
| 93 |
+
}`}
|
| 94 |
+
>
|
| 95 |
+
<svg className="w-8 h-8 text-white" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
| 96 |
+
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth="2" d="M19 11a7 7 0 01-7 7m0 0a7 7 0 01-7-7m7 7v4m0 0H8m4 0h4m-4-8a3 3 0 01-3-3V5a3 3 0 116 0v6a3 3 0 01-3 3z"></path>
|
| 97 |
+
</svg>
|
| 98 |
+
</button>
|
| 99 |
+
|
| 100 |
+
<p className={`mt-2 font-mono ${recordingDuration > 50 ? 'text-red-500' : 'text-gray-600'}`}>
|
| 101 |
+
{Math.floor(recordingDuration / 60)}:{(recordingDuration % 60).toString().padStart(2, '0')} / 1:00
|
| 102 |
+
</p>
|
| 103 |
+
{error && <p className="mt-2 text-sm text-red-500">{error}</p>}
|
| 104 |
+
</div>
|
| 105 |
+
)}
|
| 106 |
+
|
| 107 |
+
{isTranscribing && (
|
| 108 |
+
<div className="flex flex-col items-center py-4">
|
| 109 |
+
<div className="w-6 h-6 border-2 border-teal-500 border-t-transparent rounded-full animate-spin" />
|
| 110 |
+
<p className="mt-2 text-sm text-gray-600">Transcribing audio...</p>
|
| 111 |
+
</div>
|
| 112 |
+
)}
|
| 113 |
+
|
| 114 |
+
{transcribedText && !isTranscribing && (
|
| 115 |
+
<div className="space-y-4">
|
| 116 |
+
<div className="p-3 bg-white border rounded">
|
| 117 |
+
<p className="text-gray-800 animate-[typewriter_0.5s_steps(40,end)] overflow-hidden break-words">
|
| 118 |
+
"{transcribedText}"
|
| 119 |
+
</p>
|
| 120 |
+
<div className="flex items-center mt-2 text-xs">
|
| 121 |
+
<span className={confidence > 0.8 ? "text-green-600" : "text-yellow-600"}>
|
| 122 |
+
{confidence > 0.8 ? "✅" : "⚠️"} Confidence: {(confidence * 100).toFixed(0)}%
|
| 123 |
+
</span>
|
| 124 |
+
</div>
|
| 125 |
+
</div>
|
| 126 |
+
<div className="flex space-x-2">
|
| 127 |
+
<button onClick={() => onTranscribed(transcribedText)} className="px-4 py-2 text-white bg-teal-600 rounded hover:bg-teal-700">
|
| 128 |
+
Use This Text
|
| 129 |
+
</button>
|
| 130 |
+
<button onClick={handleReset} className="px-4 py-2 text-gray-700 bg-gray-200 rounded hover:bg-gray-300">
|
| 131 |
+
Re-record
|
| 132 |
+
</button>
|
| 133 |
+
</div>
|
| 134 |
+
</div>
|
| 135 |
+
)}
|
| 136 |
+
</div>
|
| 137 |
+
)
|
| 138 |
+
}
|
frontend/components/chat/ChatInterface.jsx
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client'
|
| 2 |
+
import { useState } from 'react'
|
| 3 |
+
import { v4 as uuid } from 'uuid'
|
| 4 |
+
import { useAuth } from '@/lib/auth/AuthContext'
|
| 5 |
+
|
| 6 |
+
export default function ChatInterface({ sessionId }) {
|
| 7 |
+
const { accessToken } = useAuth()
|
| 8 |
+
const [messages, setMessages] = useState([])
|
| 9 |
+
const [inputText, setInputText] = useState("")
|
| 10 |
+
const [streamingMessageId, setStreamingMessageId] = useState(null)
|
| 11 |
+
|
| 12 |
+
const sendMessage = async (e) => {
|
| 13 |
+
e.preventDefault()
|
| 14 |
+
if(!inputText.trim() || streamingMessageId) return
|
| 15 |
+
|
| 16 |
+
const text = inputText
|
| 17 |
+
setInputText("")
|
| 18 |
+
|
| 19 |
+
const userMsg = { id: uuid(), role: "user", content: text }
|
| 20 |
+
setMessages(prev => [...prev, userMsg])
|
| 21 |
+
|
| 22 |
+
const assistantMsgId = uuid()
|
| 23 |
+
const assistantMsg = { id: assistantMsgId, role: "assistant", content: "", isStreaming: true, sources: [] }
|
| 24 |
+
setMessages(prev => [...prev, assistantMsg])
|
| 25 |
+
setStreamingMessageId(assistantMsgId)
|
| 26 |
+
|
| 27 |
+
const API_URL = process.env.NEXT_PUBLIC_API_URL || 'http://localhost:8000'
|
| 28 |
+
|
| 29 |
+
try {
|
| 30 |
+
const response = await fetch(`${API_URL}/api/v1/chat`, {
|
| 31 |
+
method: "POST",
|
| 32 |
+
headers: { "Content-Type": "application/json", "Authorization": `Bearer ${accessToken}` },
|
| 33 |
+
body: JSON.stringify({ session_id: sessionId, message: text })
|
| 34 |
+
})
|
| 35 |
+
|
| 36 |
+
const reader = response.body.getReader()
|
| 37 |
+
const decoder = new TextDecoder()
|
| 38 |
+
|
| 39 |
+
while (true) {
|
| 40 |
+
const { done, value } = await reader.read()
|
| 41 |
+
if (done) break
|
| 42 |
+
|
| 43 |
+
const chunk = decoder.decode(value, { stream: true })
|
| 44 |
+
const lines = chunk.split("\n\n").filter(l => l.startsWith("data: "))
|
| 45 |
+
|
| 46 |
+
for (const line of lines) {
|
| 47 |
+
const data = JSON.parse(line.replace("data: ", ""))
|
| 48 |
+
|
| 49 |
+
if (data.type === "token") {
|
| 50 |
+
setMessages(prev => prev.map(m => m.id === assistantMsgId ? { ...m, content: m.content + data.token } : m))
|
| 51 |
+
} else if (data.type === "sources") {
|
| 52 |
+
setMessages(prev => prev.map(m => m.id === assistantMsgId ? { ...m, sources: data.sources } : m))
|
| 53 |
+
} else if (data.type === "done") {
|
| 54 |
+
setMessages(prev => prev.map(m => m.id === assistantMsgId ? { ...m, isStreaming: false } : m))
|
| 55 |
+
setStreamingMessageId(null)
|
| 56 |
+
}
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
} catch (error) {
|
| 60 |
+
setMessages(prev => prev.map(m => m.id === assistantMsgId ? { ...m, content: "Network error occurred.", isStreaming: false } : m))
|
| 61 |
+
setStreamingMessageId(null)
|
| 62 |
+
}
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
return (
|
| 66 |
+
<div className="flex flex-col h-[600px] border rounded-lg bg-white">
|
| 67 |
+
<div className="flex-1 p-4 overflow-y-auto space-y-4">
|
| 68 |
+
{messages.map(m => (
|
| 69 |
+
<div key={m.id} className={`flex ${m.role === 'user' ? 'justify-end' : 'justify-start'}`}>
|
| 70 |
+
<div className={`max-w-[80%] p-3 rounded-lg ${m.role === 'user' ? 'bg-teal-600 text-white' : 'bg-gray-100 text-gray-800'}`}>
|
| 71 |
+
<p>{m.content}</p>
|
| 72 |
+
{m.isStreaming && <span className="inline-block w-2 h-4 ml-1 bg-gray-400 animate-pulse" />}
|
| 73 |
+
{m.sources?.length > 0 && (
|
| 74 |
+
<div className="mt-2 pt-2 border-t border-gray-300 text-xs">
|
| 75 |
+
<p className="font-semibold mb-1">Sources:</p>
|
| 76 |
+
{m.sources.map((s, i) => <p key={i}>• {s.title.substring(0, 50)}...</p>)}
|
| 77 |
+
</div>
|
| 78 |
+
)}
|
| 79 |
+
</div>
|
| 80 |
+
</div>
|
| 81 |
+
))}
|
| 82 |
+
</div>
|
| 83 |
+
<form onSubmit={sendMessage} className="p-3 border-t bg-gray-50 flex items-center">
|
| 84 |
+
<input
|
| 85 |
+
type="text"
|
| 86 |
+
value={inputText}
|
| 87 |
+
onChange={e => setInputText(e.target.value)}
|
| 88 |
+
disabled={!!streamingMessageId}
|
| 89 |
+
className="flex-1 p-2 border rounded-l-md focus:outline-none focus:ring-2 focus:ring-teal-500 disabled:opacity-50"
|
| 90 |
+
placeholder="Ask about the results..."
|
| 91 |
+
/>
|
| 92 |
+
<button type="submit" disabled={!!streamingMessageId} className="px-4 py-2 bg-teal-600 text-white rounded-r-md hover:bg-teal-700 disabled:opacity-50">
|
| 93 |
+
Send
|
| 94 |
+
</button>
|
| 95 |
+
</form>
|
| 96 |
+
</div>
|
| 97 |
+
)
|
| 98 |
+
}
|
training/scripts/finetune_ner.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import numpy as np
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from transformers import (
|
| 5 |
+
AutoModelForTokenClassification, AutoTokenizer,
|
| 6 |
+
TrainingArguments, Trainer, DataCollatorForTokenClassification
|
| 7 |
+
)
|
| 8 |
+
from datasets import load_from_disk
|
| 9 |
+
from seqeval.metrics import f1_score, precision_score, recall_score, classification_report
|
| 10 |
+
|
| 11 |
+
logging.basicConfig(level=logging.INFO)
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
id2label = {
|
| 15 |
+
0: "O", 1: "B-DISEASE", 2: "I-DISEASE",
|
| 16 |
+
3: "B-SYMPTOM", 4: "I-SYMPTOM",
|
| 17 |
+
5: "B-MEDICATION", 6: "I-MEDICATION",
|
| 18 |
+
7: "B-ANATOMY", 8: "I-ANATOMY"
|
| 19 |
+
}
|
| 20 |
+
label2id = {v: k for k, v in id2label.items()}
|
| 21 |
+
|
| 22 |
+
def compute_metrics(eval_pred):
    """Score token-classification predictions with seqeval.

    ``eval_pred`` unpacks into per-token class scores and gold label ids.
    Positions whose gold id is -100 are dropped from both sequences before the
    remaining ids are mapped to tag names through ``id2label``.

    Returns:
        Dict with ``precision``, ``recall`` and ``f1``.
    """
    scores, gold_ids = eval_pred
    predicted_ids = np.argmax(scores, axis=2)

    gold_tags = [
        [id2label[tag_id] for tag_id in row if tag_id != -100]
        for row in gold_ids
    ]

    predicted_tags = []
    for predicted_row, gold_row in zip(predicted_ids, gold_ids):
        kept = [id2label[p] for p, g in zip(predicted_row, gold_row) if g != -100]
        predicted_tags.append(kept)

    scorers = {"precision": precision_score, "recall": recall_score, "f1": f1_score}
    return {name: scorer(gold_tags, predicted_tags) for name, scorer in scorers.items()}
|
| 37 |
+
|
| 38 |
+
def main():
    """Fine-tune BioBERT for NER, report test-set metrics and save the model.

    Loads the tokenized dataset from ``data/processed/ner_dataset``, resumes
    from the newest ``checkpoint-<step>`` directory under
    ``models/biobert_ner`` when one exists, prints a seqeval classification
    report for the test split and writes the final model and tokenizer to
    ``models/biobert_ner_finetuned``.
    """
    dataset = load_from_disk("data/processed/ner_dataset")
    model_id = "dmis-lab/biobert-base-cased-v1.2"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForTokenClassification.from_pretrained(
        model_id, num_labels=9, id2label=id2label, label2id=label2id
    )

    training_args = TrainingArguments(
        output_dir="models/biobert_ner",
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=4,
        fp16=True,
        gradient_checkpointing=True,
        dataloader_num_workers=0,  # Windows fix
        dataloader_pin_memory=False,  # Windows fix
        num_train_epochs=10,
        warmup_ratio=0.1,
        learning_rate=2e-5,
        weight_decay=0.01,
        # NOTE(review): newer transformers releases rename this to
        # `eval_strategy` — confirm against the pinned version.
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        save_total_limit=3,
        logging_steps=50,
        report_to="none",
    )

    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics
    )

    # Auto-resume logic
    checkpoint_dir = Path("models/biobert_ner")
    latest_checkpoint = None
    if checkpoint_dir.exists():
        # Keep only "checkpoint-<step>" entries so a stray non-numeric suffix
        # cannot break the numeric sort.
        checkpoints = [
            p for p in checkpoint_dir.glob("checkpoint-*")
            if p.name.split("-")[-1].isdigit()
        ]
        if checkpoints:
            latest_checkpoint = str(max(checkpoints, key=lambda p: int(p.name.split("-")[-1])))
            logger.info(f"📂 Resuming from {latest_checkpoint}")

    trainer.train(resume_from_checkpoint=latest_checkpoint)

    logger.info("Evaluating on test set...")
    test_results = trainer.predict(dataset["test"])

    # Decode ids to tag names here. Positions labelled -100 (special tokens
    # and padding) are excluded from both sequences.
    predicted_ids = np.argmax(test_results.predictions, axis=2)
    true_labels = [
        [id2label[l] for l in label_row if l != -100]
        for label_row in test_results.label_ids
    ]
    true_predictions = [
        [id2label[p] for p, l in zip(pred_row, label_row) if l != -100]
        for pred_row, label_row in zip(predicted_ids, test_results.label_ids)
    ]
    print(classification_report(true_labels, true_predictions))

    final_dir = "models/biobert_ner_finetuned"
    trainer.save_model(final_dir)
    tokenizer.save_pretrained(final_dir)
    logger.info(f"✅ Model saved to {final_dir}")
|
| 101 |
+
|
| 102 |
+
if __name__ == "__main__":
|
| 103 |
+
main()
|
training/scripts/ingest_pubmed.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import requests
|
| 5 |
+
import xml.etree.ElementTree as ET
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
logging.basicConfig(level=logging.INFO)
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
SEARCH_QUERIES = {
|
| 12 |
+
"pneumonia": "chest X-ray pneumonia findings radiology",
|
| 13 |
+
"pleural_effusion": "pleural effusion chest radiograph diagnosis",
|
| 14 |
+
"cardiomegaly": "cardiomegaly cardiac enlargement chest X-ray",
|
| 15 |
+
"tuberculosis": "tuberculosis pulmonary chest radiograph findings",
|
| 16 |
+
"normal_chest": "normal chest radiograph no significant finding"
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
def _element_text(element) -> str:
    """Return all text inside ``element``, including text within child tags."""
    if element is None:
        return ""
    return "".join(element.itertext()).strip()


def fetch_pubmed_abstracts(query: str, category: str, max_results=200) -> list[dict]:
    """Search PubMed for ``query`` and download the matching abstracts.

    PMIDs come from one esearch call; the articles are then fetched as XML in
    batches of 20. A batch whose request or XML parsing fails is logged and
    skipped so the remaining batches are still collected.

    Args:
        query: PubMed search term.
        category: Label stored on every returned record.
        max_results: Maximum number of PMIDs to request from esearch.

    Returns:
        A list of dicts with ``pmid``, ``title``, ``abstract``, ``year`` and
        ``category``. Articles lacking a title or abstract are omitted.

    Raises:
        requests.HTTPError: If the initial esearch request fails.
    """
    logger.info(f"Searching PubMed for: {category}")
    search_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    params = {"db": "pubmed", "term": query, "retmax": max_results, "retmode": "json", "usehistory": "y"}

    res = requests.get(search_url, params=params, timeout=30)
    res.raise_for_status()
    pmids = res.json().get("esearchresult", {}).get("idlist", [])

    results = []
    fetch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"

    for i in range(0, len(pmids), 20):
        batch = pmids[i:i+20]
        f_params = {"db": "pubmed", "id": ",".join(batch), "rettype": "abstract", "retmode": "xml"}

        try:
            f_res = requests.get(fetch_url, params=f_params, timeout=30)
            f_res.raise_for_status()
            tree = ET.fromstring(f_res.content)
        except requests.RequestException as exc:
            logger.error("efetch request failed on a batch, skipping: %s", exc)
        except ET.ParseError:
            logger.error("XML parse error on a batch, skipping.")
        else:
            for article in tree.findall(".//PubmedArticle"):
                pmid = article.findtext(".//PMID")
                # itertext() keeps text that sits inside inline markup such as
                # <i> or <sub>; findtext() would stop at the first child tag.
                title = _element_text(article.find(".//ArticleTitle"))
                # Structured abstracts have one AbstractText element per
                # section; join them all instead of keeping only the first.
                sections = article.findall(".//Abstract/AbstractText") or article.findall(".//AbstractText")
                abstract = " ".join(filter(None, (_element_text(s) for s in sections)))
                year = article.findtext(".//PubDate/Year")

                if abstract and title:
                    results.append({
                        "pmid": pmid, "title": title, "abstract": abstract,
                        "year": year, "category": category
                    })
        time.sleep(0.4)  # Respect 3 req/sec rate limit
    return results
|
| 53 |
+
|
| 54 |
+
def main():
    """Fetch abstracts for every search category into ``data/pubmed_raw.json``.

    The output file is rewritten after each category. On restart, categories
    that already have at least one record in the file are not fetched again.
    """
    out_file = Path("data/pubmed_raw.json")
    out_file.parent.mkdir(parents=True, exist_ok=True)

    if out_file.exists():
        all_data = json.loads(out_file.read_text())
        completed_cats = {record["category"] for record in all_data}
        logger.info(f"Resuming... found {len(all_data)} existing records.")
    else:
        all_data = []
        completed_cats = set()

    pending = [(cat, query) for cat, query in SEARCH_QUERIES.items() if cat not in completed_cats]
    for cat, query in pending:
        data = fetch_pubmed_abstracts(query, cat)
        all_data.extend(data)
        # Persist after every category so an interrupted run can resume.
        out_file.write_text(json.dumps(all_data, indent=2))
        logger.info(f"Saved {len(data)} abstracts for {cat}.")

    logger.info(f"✅ PubMed ingestion complete. Total records: {len(all_data)}")
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
|
| 77 |
+
main()
|
training/scripts/prepare_ner_data.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from datasets import load_dataset
|
| 5 |
+
from transformers import AutoTokenizer
|
| 6 |
+
|
| 7 |
+
logging.basicConfig(level=logging.INFO)
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
SYMPTOMS_LIST = [
    "fever", "cough", "pain", "fatigue", "nausea", "vomiting",
    "headache", "dizziness", "shortness of breath", "chest pain",
    "dyspnea", "tachycardia", "edema", "rash"
]

# Symptom phrases as word lists, longest first, so that "chest pain" is tried
# before the bare "pain".
_SYMPTOM_PHRASES = sorted((phrase.split() for phrase in SYMPTOMS_LIST), key=len, reverse=True)


def add_symptom_labels(example):
    """Tag known symptom phrases in an example that has only disease labels.

    Tokens are compared case-insensitively against ``SYMPTOMS_LIST``. A phrase
    is tagged only when every token it covers is currently ``O`` (0), so
    existing entity labels are never overwritten. The first token of a match
    gets B-SYMPTOM (3) and any following tokens get I-SYMPTOM (4).

    Args:
        example: Dict with parallel ``"tokens"`` and ``"ner_tags"`` lists.

    Returns:
        Dict with the original tokens and a new ``ner_tags`` list; the input
        tag list is left unmodified.
    """
    tokens = example["tokens"]
    tags = list(example["ner_tags"])
    lowered = [token.lower() for token in tokens]

    i = 0
    while i < len(tokens):
        consumed = 0
        if tags[i] == 0:  # If currently 'O'
            for words in _SYMPTOM_PHRASES:
                end = i + len(words)
                # A slice running past the end is shorter than `words` and so
                # never compares equal.
                if lowered[i:end] == words and not any(tags[i:end]):
                    # 3 is B-SYMPTOM, 4 is I-SYMPTOM in our mapping
                    tags[i] = 3
                    tags[i + 1:end] = [4] * (len(words) - 1)
                    consumed = len(words)
                    break
        i += consumed or 1
    return {"tokens": tokens, "ner_tags": tags}
|
| 26 |
+
|
| 27 |
+
def tokenize_and_align_labels(examples, tokenizer):
    """Tokenize pre-split words and project word-level tags onto subwords.

    For each example, positions with no source word (special tokens) get -100
    so they are ignored in the loss. The first subword of a word keeps the
    word's tag. Later subwords of the same word keep it too, except that an
    odd tag id (a B- tag) is bumped to the next id (its I- tag).

    Args:
        examples: Batch dict with ``"tokens"`` and ``"ner_tags"`` lists.
        tokenizer: Tokenizer whose output provides ``word_ids(batch_index=...)``.

    Returns:
        The tokenizer output with an added ``"labels"`` entry.
    """
    encoded = tokenizer(
        examples["tokens"],
        truncation=True,
        max_length=512,
        is_split_into_words=True,  # CRITICAL: Tells tokenizer input is already word-split
        padding=False
    )

    def _align(word_ids, word_tags):
        # Map one example's word-level tags onto its subword positions.
        subword_tags = []
        last_word = None
        for word in word_ids:
            if word is None:
                subword_tags.append(-100)  # Special tokens (CLS, SEP) ignored in loss
            else:
                tag = word_tags[word]
                # Continuation subword: a B- tag (odd id) becomes its I- tag.
                if word == last_word and tag % 2 == 1:
                    tag += 1
                subword_tags.append(tag)
            last_word = word
        return subword_tags

    encoded["labels"] = [
        _align(encoded.word_ids(batch_index=row), word_tags)
        for row, word_tags in enumerate(examples["ner_tags"])
    ]
    return encoded
|
| 59 |
+
|
| 60 |
+
def main():
    """Build the tokenized NER dataset and save it under ``data/processed``.

    Loads the NCBI Disease dataset, adds symptom tags, tokenizes with the
    BioBERT tokenizer while aligning tags to subwords, and writes the result
    to ``data/processed/ner_dataset``.
    """
    out_dir = Path("data/processed/ner_dataset")
    out_dir.parent.mkdir(parents=True, exist_ok=True)

    logger.info("Loading NCBI Disease dataset...")
    raw = load_dataset("ncbi_disease")

    logger.info("Augmenting symptom labels...")
    augmented = raw.map(add_symptom_labels)

    logger.info("Tokenizing and aligning BIO tags...")
    tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.2")

    def _encode(batch):
        # Bind the tokenizer so the mapper takes a single batch argument.
        return tokenize_and_align_labels(batch, tokenizer)

    # Drop the raw columns so only the tokenizer outputs and labels remain.
    original_columns = augmented["train"].column_names
    tokenized_dataset = augmented.map(
        _encode,
        batched=True,
        remove_columns=original_columns
    )

    tokenized_dataset.save_to_disk(str(out_dir))
    logger.info(f"✅ NER dataset prepared and saved to {out_dir}")
|
| 81 |
+
|
| 82 |
+
if __name__ == "__main__":
|
| 83 |
+
main()
|