hoshikrana commited on
Commit
b1406c1
·
1 Parent(s): 7807e33

feat: NLP module, multimodal fusion, and RAG chatbot

Browse files
backend/api/v1/routers/analyze.py CHANGED
@@ -1,14 +1,16 @@
 
1
  from fastapi import APIRouter, Depends, Form, File, UploadFile, HTTPException, Request
2
  from sqlalchemy.ext.asyncio import AsyncSession
3
  from pathlib import Path
4
 
5
  from backend.db.session import get_db
6
  from backend.db.models import AnalysisSession, AnalysisTask
7
- from backend.core.dependencies import get_current_user, get_pagination
8
- from backend.core.exceptions import InvalidFileTypeError, FileTooLargeError
9
  from backend.utils.validators import ImageValidator, sanitize_symptoms_text, validate_patient_id, safe_temp_path
10
  from backend.api.v1.schemas.analysis import TaskSubmitResponse, TaskStatusResponse, AnalysisResult
11
  from backend.orchestration.queue import task_queue
 
12
 
13
  router = APIRouter()
14
 
@@ -77,3 +79,29 @@ async def get_result(task_id: str, db: AsyncSession = Depends(get_db), current_u
77
  async def cancel_task(task_id: str, current_user = Depends(get_current_user)):
78
  success = await task_queue.cancel(task_id, current_user.id)
79
  return {"cancelled": success, "message": "Task cancelled" if success else "Cannot cancel task"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
  from fastapi import APIRouter, Depends, Form, File, UploadFile, HTTPException, Request
3
  from sqlalchemy.ext.asyncio import AsyncSession
4
  from pathlib import Path
5
 
6
  from backend.db.session import get_db
7
  from backend.db.models import AnalysisSession, AnalysisTask
8
+ from backend.core.dependencies import get_current_user, get_pagination, get_model_registry
9
+ from backend.core.exceptions import InvalidFileTypeError, FileTooLargeError, ModelNotLoadedError
10
  from backend.utils.validators import ImageValidator, sanitize_symptoms_text, validate_patient_id, safe_temp_path
11
  from backend.api.v1.schemas.analysis import TaskSubmitResponse, TaskStatusResponse, AnalysisResult
12
  from backend.orchestration.queue import task_queue
13
+ from backend.ml.nlp.whisper import WhisperTranscriber
14
 
15
  router = APIRouter()
16
 
 
79
  async def cancel_task(task_id: str, current_user = Depends(get_current_user)):
80
  success = await task_queue.cancel(task_id, current_user.id)
81
  return {"cancelled": success, "message": "Task cancelled" if success else "Cannot cancel task"}
82
+
83
+ @router.post("/transcribe")
84
+ async def transcribe_audio(
85
+ audio: UploadFile = File(...),
86
+ registry = Depends(get_model_registry),
87
+ current_user = Depends(get_current_user)
88
+ ):
89
+ if audio.content_type not in ["audio/wav", "audio/mpeg", "audio/webm", "audio/ogg"]:
90
+ raise InvalidFileTypeError("Audio must be WAV, MP3, WebM, or OGG")
91
+
92
+ content = await audio.read()
93
+ if len(content) > 25 * 1024 * 1024:
94
+ raise FileTooLargeError("Audio file too large (max 25MB)")
95
+
96
+ temp_path = safe_temp_path(audio.filename or "audio.webm")
97
+ temp_path.write_bytes(content)
98
+
99
+ try:
100
+ whisper_state = await registry.get("whisper_tiny")
101
+ if not whisper_state.is_available:
102
+ raise ModelNotLoadedError("Voice transcription unavailable")
103
+
104
+ result = await asyncio.to_thread(WhisperTranscriber.transcribe, temp_path, whisper_state.model)
105
+ return result
106
+ finally:
107
+ temp_path.unlink(missing_ok=True)
backend/api/v1/routers/chat.py CHANGED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import asyncio
3
+ from fastapi import APIRouter, Depends
4
+ from fastapi.responses import StreamingResponse
5
+ from sqlalchemy.ext.asyncio import AsyncSession
6
+
7
+ from backend.api.v1.schemas.chat import ChatRequest
8
+ from backend.core.dependencies import get_session_or_404, get_model_registry, get_current_user
9
+ from backend.db.session import get_db
10
+ # Note: Assume session_manager is built in db/utils.py to manage chat history records
11
+ # from backend.db.utils import session_manager
12
+ from backend.ml.rag.retriever import MedicalRAG
13
+ from backend.ml.rag.generator import ChatGenerator
14
+
15
+ router = APIRouter()
16
+
17
+ @router.post("/")
18
+ async def chat(
19
+ body: ChatRequest,
20
+ session = Depends(get_session_or_404),
21
+ registry = Depends(get_model_registry),
22
+ current_user = Depends(get_current_user),
23
+ db: AsyncSession = Depends(get_db)
24
+ ):
25
+ is_safe, warning = MedicalRAG.is_safe_query(body.message)
26
+
27
+ # Placeholder for actual DB history fetch
28
+ history_dicts = []
29
+
30
+ session_result = session.result_json if hasattr(session, 'result_json') else {}
31
+ chunks = MedicalRAG.retrieve(body.message, session_result, n_results=5)
32
+ prompt = MedicalRAG.build_prompt(body.message, chunks, history_dicts, session_result)
33
+
34
+ biogpt_state = await registry.get("biogpt_base")
35
+ full_response = []
36
+
37
+ async def stream_generator():
38
+ yield f"data: {json.dumps({'type': 'sources', 'sources': chunks[:3]})}\n\n"
39
+
40
+ if not biogpt_state.is_available:
41
+ fallback = "I cannot provide a detailed answer right now as the AI system is unavailable."
42
+ yield f"data: {json.dumps({'type': 'token', 'token': fallback})}\n\n"
43
+ else:
44
+ token_gen = await asyncio.to_thread(lambda: list(ChatGenerator.generate_stream(prompt, biogpt_state.model, biogpt_state.tokenizer)))
45
+ for token in token_gen:
46
+ full_response.append(token)
47
+ yield f"data: {json.dumps({'type': 'token', 'token': token})}\n\n"
48
+ await asyncio.sleep(0.02) # Optional smoothing
49
+
50
+ if warning:
51
+ yield f"data: {json.dumps({'type': 'token', 'token': warning})}\n\n"
52
+
53
+ yield f"data: {json.dumps({'type': 'done'})}\n\n"
54
+
55
+ return StreamingResponse(
56
+ stream_generator(),
57
+ media_type="text/event-stream",
58
+ headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}
59
+ )
backend/ml/fusion/medclip.py CHANGED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from pathlib import Path
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ class MultimodalFusion:
11
+ """Computes image-text alignment using BiomedVLP."""
12
+ IMAGE_SIZE = 224
13
+
14
+ @staticmethod
15
+ def get_image_transform():
16
+ return transforms.Compose([
17
+ transforms.Resize((MultimodalFusion.IMAGE_SIZE, MultimodalFusion.IMAGE_SIZE)),
18
+ transforms.ToTensor(),
19
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
20
+ ])
21
+
22
+ @staticmethod
23
+ def get_image_embedding(image_path: Path, model, device: str) -> torch.Tensor:
24
+ image = Image.open(image_path).convert("RGB")
25
+ transform = MultimodalFusion.get_image_transform()
26
+ image_tensor = transform(image).unsqueeze(0).to(device)
27
+
28
+ with torch.no_grad():
29
+ embedding = model.get_image_embeddings(image_tensor)
30
+ return F.normalize(embedding, p=2, dim=-1)
31
+
32
+ @staticmethod
33
+ def get_text_embedding(text: str, model, tokenizer, device: str) -> torch.Tensor:
34
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256, padding="max_length").to(device)
35
+ with torch.no_grad():
36
+ embedding = model.get_text_embeddings(**inputs)
37
+ return F.normalize(embedding, p=2, dim=-1)
38
+
39
+ @staticmethod
40
+ def compute_similarity(image_path: Path, text: str, model, tokenizer, device: str) -> tuple[float, str]:
41
+ try:
42
+ img_emb = MultimodalFusion.get_image_embedding(image_path, model, device)
43
+ txt_emb = MultimodalFusion.get_text_embedding(text, model, tokenizer, device)
44
+
45
+ similarity = float(torch.cosine_similarity(img_emb, txt_emb).item())
46
+ similarity = (similarity + 1) / 2 # Shift to [0,1]
47
+
48
+ if similarity >= 0.7:
49
+ alignment = "HIGH"
50
+ elif similarity >= 0.4:
51
+ alignment = "MEDIUM"
52
+ else:
53
+ alignment = "LOW"
54
+ return round(similarity, 3), alignment
55
+ except Exception as e:
56
+ logger.warning(f"Fusion similarity computation failed: {e}")
57
+ return 0.5, "UNKNOWN"
58
+
59
+ @staticmethod
60
+ def get_fused_embedding(image_path: Path, text: str, model, tokenizer, device: str) -> torch.Tensor:
61
+ img_emb = MultimodalFusion.get_image_embedding(image_path, model, device)
62
+ txt_emb = MultimodalFusion.get_text_embedding(text, model, tokenizer, device)
63
+ return torch.cat([img_emb, txt_emb], dim=-1)
64
+
65
+ class FallbackFusion:
66
+ @staticmethod
67
+ def compute_similarity(image_path: Path, text: str) -> tuple[float, str]:
68
+ """Simple keyword-based fallback when BiomedVLP unavailable due to RAM constraints."""
69
+ CHEST_KEYWORDS = ["chest", "lung", "cardiac", "pleural", "pneumo", "infiltrate", "opacity", "nodule", "effusion"]
70
+ text_lower = text.lower()
71
+ matches = sum(1 for kw in CHEST_KEYWORDS if kw in text_lower)
72
+ score = min(0.9, 0.3 + matches * 0.1)
73
+ alignment = "HIGH" if score > 0.6 else "MEDIUM" if score > 0.4 else "LOW"
74
+ return score, alignment
backend/ml/nlp/classifier.py CHANGED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ from backend.api.v1.schemas.analysis import NERResult
3
+
4
+ CHEST_CONDITIONS = [
5
+ "Pneumonia", "Pleural Effusion", "Cardiomegaly", "Atelectasis",
6
+ "Pneumothorax", "Pulmonary Edema", "Tuberculosis", "Lung Cancer",
7
+ "COVID-19", "Chronic Obstructive Pulmonary Disease", "Asthma",
8
+ "Pulmonary Fibrosis", "Bronchitis", "Emphysema", "Heart Failure",
9
+ "Aortic Aneurysm", "Pulmonary Embolism", "Sarcoidosis",
10
+ "No significant finding", "Other condition"
11
+ ]
12
+
13
+ class DiseaseClassifier:
14
+ """Zero-shot classifier. No fine-tuning needed. Always runs on CPU."""
15
+ _pipeline = None
16
+ _pipeline_lock = threading.Lock()
17
+
18
+ @classmethod
19
+ def _get_pipeline(cls):
20
+ if cls._pipeline is None:
21
+ with cls._pipeline_lock:
22
+ if cls._pipeline is None:
23
+ from transformers import pipeline
24
+ # Note: For faster inference use "valhalla/distilbart-mnli-12-1"
25
+ # We stick to bart-large-mnli as instructed, but it takes ~2-4s on CPU
26
+ cls._pipeline = pipeline(
27
+ "zero-shot-classification",
28
+ model="facebook/bart-large-mnli",
29
+ device=-1
30
+ )
31
+ return cls._pipeline
32
+
33
+ @staticmethod
34
+ def classify(text: str, entities: NERResult, top_k: int = 3) -> dict:
35
+ if not text or not text.strip():
36
+ return {"primary": "Insufficient information", "confidence": 0.0, "differential": []}
37
+
38
+ enriched_text = DiseaseClassifier._build_enriched_text(text, entities)
39
+ pipe = DiseaseClassifier._get_pipeline()
40
+
41
+ result = pipe(
42
+ enriched_text,
43
+ candidate_labels=CHEST_CONDITIONS,
44
+ multi_label=False
45
+ )
46
+
47
+ scores_dict = dict(zip(result["labels"], result["scores"]))
48
+ sorted_labels = sorted(scores_dict, key=scores_dict.get, reverse=True)
49
+
50
+ primary = sorted_labels[0]
51
+ primary_confidence = scores_dict[primary]
52
+
53
+ differential = [
54
+ {"disease": label, "confidence": round(scores_dict[label], 3)}
55
+ for label in sorted_labels[1:top_k]
56
+ ]
57
+
58
+ return {
59
+ "primary": primary,
60
+ "confidence": round(primary_confidence, 3),
61
+ "differential": differential
62
+ }
63
+
64
+ @staticmethod
65
+ def _build_enriched_text(original: str, entities: NERResult) -> str:
66
+ parts = [original]
67
+ if entities.diseases:
68
+ parts.append(f"Diagnosed conditions: {', '.join(entities.diseases)}")
69
+ if entities.symptoms:
70
+ parts.append(f"Presenting symptoms: {', '.join(entities.symptoms)}")
71
+ if entities.medications:
72
+ parts.append(f"Current medications: {', '.join(entities.medications)}")
73
+
74
+ enriched = ". ".join(parts)
75
+ return enriched[:1024]
76
+
77
+ if __name__ == "__main__":
78
+ test_text = "Patient complains of severe chest pain and shortness of breath."
79
+ entities = NERResult(diseases=[], symptoms=["chest pain", "shortness of breath"], medications=[], anatomy=[], raw_entities=[])
80
+ res = DiseaseClassifier.classify(test_text, entities)
81
+ print("Classification result:", res)
backend/ml/nlp/ner.py CHANGED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
3
+ from backend.api.v1.schemas.analysis import NERResult
4
+
5
+ class NERExtractor:
6
+ ENTITY_MAP = {
7
+ "DISEASE": "diseases",
8
+ "SYMPTOM": "symptoms",
9
+ "MEDICATION": "medications",
10
+ "ANATOMY": "anatomy"
11
+ }
12
+
13
+ @staticmethod
14
+ def extract(text: str, model: AutoModelForTokenClassification, tokenizer: AutoTokenizer, device: str = "cuda") -> NERResult:
15
+ if not text or not text.strip():
16
+ return NERResult(diseases=[], symptoms=[], medications=[], anatomy=[], raw_entities=[])
17
+
18
+ chunks = NERExtractor._chunk_text(text, tokenizer, max_length=400, overlap=50)
19
+
20
+ all_entities = []
21
+ seen_spans = set()
22
+
23
+ for chunk_text, char_offset in chunks:
24
+ chunk_entities = NERExtractor._extract_chunk(chunk_text, model, tokenizer, device, char_offset)
25
+ for entity in chunk_entities:
26
+ span_key = (entity["text"].lower(), entity["entity_type"])
27
+ if span_key not in seen_spans:
28
+ seen_spans.add(span_key)
29
+ all_entities.append(entity)
30
+
31
+ grouped = {"diseases": [], "symptoms": [], "medications": [], "anatomy": []}
32
+ for entity in all_entities:
33
+ entity_type = entity["entity_type"].split("-")[-1]
34
+ group_key = NERExtractor.ENTITY_MAP.get(entity_type)
35
+ if group_key:
36
+ entity_text = entity["text"].strip()
37
+ if entity_text and entity_text not in grouped[group_key]:
38
+ grouped[group_key].append(entity_text)
39
+
40
+ return NERResult(
41
+ diseases=grouped["diseases"], symptoms=grouped["symptoms"],
42
+ medications=grouped["medications"], anatomy=grouped["anatomy"],
43
+ raw_entities=all_entities
44
+ )
45
+
46
+ @staticmethod
47
+ def _extract_chunk(text: str, model, tokenizer, device: str, char_offset: int = 0) -> list[dict]:
48
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, return_offsets_mapping=True, padding=False)
49
+ offset_mapping = inputs.pop("offset_mapping")[0]
50
+ inputs = {k: v.to(device) for k, v in inputs.items()}
51
+
52
+ with torch.no_grad():
53
+ outputs = model(**inputs)
54
+
55
+ predictions = torch.argmax(outputs.logits, dim=-1)[0]
56
+ entities = []
57
+ current_entity = None
58
+
59
+ for idx, (pred_id, offsets) in enumerate(zip(predictions, offset_mapping)):
60
+ label = model.config.id2label[pred_id.item()]
61
+ start, end = offsets[0].item(), offsets[1].item()
62
+
63
+ if start == 0 and end == 0: # Special token
64
+ if current_entity:
65
+ entities.append(current_entity)
66
+ current_entity = None
67
+ continue
68
+
69
+ token_text = text[start:end]
70
+
71
+ if label.startswith("B-"):
72
+ if current_entity: entities.append(current_entity)
73
+ current_entity = {
74
+ "text": token_text, "entity_type": label,
75
+ "start": start + char_offset, "end": end + char_offset,
76
+ "confidence": torch.softmax(outputs.logits[0][idx], dim=-1).max().item()
77
+ }
78
+ elif label.startswith("I-") and current_entity:
79
+ current_entity["text"] += token_text if not token_text.startswith("##") else token_text[2:]
80
+ current_entity["end"] = end + char_offset
81
+ else:
82
+ if current_entity:
83
+ entities.append(current_entity)
84
+ current_entity = None
85
+
86
+ if current_entity:
87
+ entities.append(current_entity)
88
+
89
+ return [e for e in entities if e["confidence"] > 0.7]
90
+
91
+ @staticmethod
92
+ def _chunk_text(text: str, tokenizer, max_length: int = 400, overlap: int = 50) -> list[tuple[str, int]]:
93
+ words = text.split()
94
+ chunks = []
95
+ current_words = []
96
+ current_length = 0
97
+ char_offset = 0
98
+
99
+ for word in words:
100
+ word_tokens = tokenizer(word, add_special_tokens=False)["input_ids"]
101
+ if current_length + len(word_tokens) > max_length and current_words:
102
+ chunk_text = " ".join(current_words)
103
+ chunks.append((chunk_text, char_offset))
104
+
105
+ overlap_words = current_words[-overlap//4:]
106
+ char_offset += len(" ".join(current_words[:-len(overlap_words)])) + 1
107
+ current_words = overlap_words
108
+ current_length = sum(len(tokenizer(w, add_special_tokens=False)["input_ids"]) for w in current_words)
109
+
110
+ current_words.append(word)
111
+ current_length += len(word_tokens)
112
+
113
+ if current_words:
114
+ chunks.append((" ".join(current_words), char_offset))
115
+
116
+ return chunks if chunks else [(text, 0)]
117
+
118
+ def highlight_entities(text: str, entities: list[dict]) -> str:
119
+ COLORS = {
120
+ "DISEASE": "#FF6B6B", "SYMPTOM": "#FFD93D",
121
+ "MEDICATION": "#6BCB77", "ANATOMY": "#4D96FF"
122
+ }
123
+ sorted_entities = sorted(entities, key=lambda e: e["start"], reverse=True)
124
+ result = text
125
+ for entity in sorted_entities:
126
+ entity_type = entity["entity_type"].replace("B-", "").replace("I-", "")
127
+ color = COLORS.get(entity_type, "#cccccc")
128
+ span = (
129
+ f'<mark style="background:{color};padding:2px 4px;border-radius:3px;'
130
+ f'font-size:0.85em" title="{entity_type} ({entity["confidence"]:.0%})">'
131
+ f'{entity["text"]}</mark>'
132
+ )
133
+ result = result[:entity["start"]] + span + result[entity["end"]:]
134
+ return result
backend/ml/nlp/whisper.py CHANGED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import json
3
+ import numpy as np
4
+ from pathlib import Path
5
+ from backend.core.exceptions import InvalidFileError, ValidationError, InferenceError
6
+
7
+ class WhisperTranscriber:
8
+ SUPPORTED_INPUT_FORMATS = {".wav", ".mp3", ".webm", ".ogg", ".m4a"}
9
+ MAX_DURATION_SECONDS = 60
10
+
11
+ @staticmethod
12
+ def transcribe(audio_file_path: Path, model, language: str = "en") -> dict:
13
+ audio_path = Path(audio_file_path)
14
+ if not audio_path.exists():
15
+ raise InvalidFileError(f"Audio file not found: {audio_path}")
16
+
17
+ wav_path = None
18
+ try:
19
+ if audio_path.suffix.lower() != ".wav":
20
+ wav_path = audio_path.with_suffix(".converted.wav")
21
+ WhisperTranscriber._convert_to_wav(audio_path, wav_path)
22
+ process_path = wav_path
23
+ else:
24
+ process_path = audio_path
25
+
26
+ duration = WhisperTranscriber._get_duration(process_path)
27
+ if duration > WhisperTranscriber.MAX_DURATION_SECONDS:
28
+ raise ValidationError(f"Audio too long ({duration:.0f}s). Max {WhisperTranscriber.MAX_DURATION_SECONDS}s.")
29
+
30
+ result = model.transcribe(
31
+ str(process_path), language=language, verbose=False,
32
+ word_timestamps=True, fp16=False, condition_on_previous_text=False,
33
+ no_speech_threshold=0.6, logprob_threshold=-1.0
34
+ )
35
+
36
+ avg_logprob = np.mean([s["avg_logprob"] for s in result["segments"]]) if result["segments"] else -1
37
+ confidence = float(min(1.0, max(0.0, np.exp(avg_logprob))))
38
+
39
+ return {
40
+ "text": result["text"].strip(),
41
+ "language": result["language"],
42
+ "confidence": round(confidence, 3),
43
+ "duration_seconds": round(duration, 1),
44
+ "segments": [{"start": s["start"], "end": s["end"], "text": s["text"]} for s in result["segments"]]
45
+ }
46
+ finally:
47
+ if wav_path and wav_path.exists():
48
+ wav_path.unlink()
49
+
50
+ @staticmethod
51
+ def _convert_to_wav(input_path: Path, output_path: Path):
52
+ cmd = [
53
+ "ffmpeg", "-i", str(input_path), "-acodec", "pcm_s16le",
54
+ "-ac", "1", "-ar", "16000", "-y", str(output_path)
55
+ ]
56
+ result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
57
+ if result.returncode != 0:
58
+ raise InferenceError(f"Audio conversion failed: {result.stderr[:200]}")
59
+
60
+ @staticmethod
61
+ def _get_duration(wav_path: Path) -> float:
62
+ cmd = [
63
+ "ffprobe", "-v", "quiet", "-print_format", "json",
64
+ "-show_format", str(wav_path)
65
+ ]
66
+ result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
67
+ if result.returncode != 0:
68
+ return 0.0
69
+ data = json.loads(result.stdout)
70
+ return float(data.get("format", {}).get("duration", 0))
backend/ml/rag/generator.py CHANGED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from threading import Thread
3
+ from transformers import TextIteratorStreamer
4
+
5
+ class ReportGenerator:
6
+ @staticmethod
7
+ def generate(vision, nlp, fusion, model, tokenizer) -> str:
8
+ parts = ["Generate a medical radiology report based on these findings:"]
9
+ if vision: parts.append(f"Imaging: {vision.risk_level} risk, anomaly score {vision.anomaly_score}/100")
10
+ if nlp: parts.append(f"Clinical: {nlp.primary_diagnosis}, confidence {nlp.diagnosis_confidence:.0%}")
11
+ if fusion: parts.append(f"Image-text alignment: {fusion.alignment}")
12
+ parts.append("Report:")
13
+
14
+ prompt = " ".join(parts)
15
+ inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
16
+
17
+ with torch.no_grad():
18
+ outputs = model.generate(
19
+ **inputs, max_new_tokens=300, do_sample=False, num_beams=4,
20
+ early_stopping=True, pad_token_id=tokenizer.eos_token_id, repetition_penalty=1.3
21
+ )
22
+
23
+ generated = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
24
+ return ReportGenerator._format_report(generated, vision, nlp)
25
+
26
+ @staticmethod
27
+ def _format_report(raw_text: str, vision, nlp) -> str:
28
+ sections = ["## AI Diagnostic Report\n"]
29
+ if vision:
30
+ risk_emoji = "🔴" if vision.risk_level == "HIGH" else "🟡" if vision.risk_level == "MEDIUM" else "🟢"
31
+ sections.append(f"### Imaging Findings\n{risk_emoji} **Risk Level:** {vision.risk_level} \n**Anomaly Score:** {vision.anomaly_score}/100\n")
32
+ if nlp:
33
+ sections.append(f"### Clinical Assessment\n**Primary Impression:** {nlp.primary_diagnosis}\n")
34
+ sections.append(f"### AI Analysis\n{raw_text.strip()}\n\n### Recommendation\nPlease consult a licensed physician.")
35
+ return "\n".join(sections)
36
+
37
+ class ChatGenerator:
38
+ @staticmethod
39
+ def generate_stream(prompt: str, model, tokenizer, max_new_tokens: int = 200):
40
+ inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
41
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
42
+
43
+ generation_kwargs = {
44
+ **inputs, "streamer": streamer, "max_new_tokens": max_new_tokens,
45
+ "do_sample": True, "temperature": 0.7, "top_p": 0.9,
46
+ "repetition_penalty": 1.2, "pad_token_id": tokenizer.eos_token_id
47
+ }
48
+
49
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
50
+ thread.start()
51
+
52
+ for token in streamer:
53
+ yield token
backend/ml/rag/retriever.py CHANGED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from backend.ml.rag.vectorstore import vector_store
3
+
4
+ class MedicalRAG:
5
+ """Retrieval-Augmented Generation for medical Q&A."""
6
+ MAX_CONTEXT_TOKENS = 800
7
+ MAX_HISTORY_TURNS = 6
8
+ SAFETY_DISCLAIMER = "\n\n*Note: This is AI-generated information for educational purposes only. Please consult a licensed physician for medical advice.*"
9
+ DRUG_DOSAGE_PATTERNS = [r"\b\d+\s*mg\b", r"how much.*take", r"dosage", r"dose of"]
10
+
11
+ @staticmethod
12
+ def retrieve(query: str, session_result: dict | None, n_results: int = 5) -> list[dict]:
13
+ enriched_query = query
14
+ if session_result:
15
+ primary_dx = session_result.get("nlp", {}).get("primary_diagnosis", "")
16
+ symptoms = session_result.get("nlp", {}).get("entities", {}).get("symptoms", [])
17
+ if primary_dx:
18
+ enriched_query = f"{query} {primary_dx} {' '.join(symptoms[:3])}"
19
+ return vector_store.search(enriched_query, n_results=n_results)
20
+
21
+ @staticmethod
22
+ def build_prompt(query: str, retrieved_chunks: list[dict], chat_history: list[dict], session_result: dict | None) -> str:
23
+ parts = []
24
+ if session_result:
25
+ vision, nlp = session_result.get("vision", {}), session_result.get("nlp", {})
26
+ if vision or nlp:
27
+ parts.append("Patient analysis context:")
28
+ if vision:
29
+ parts.append(f"- Imaging: {vision.get('risk_level', 'unknown')} risk (anomaly score: {vision.get('anomaly_score', 'N/A')}/100)")
30
+ if nlp:
31
+ parts.append(f"- Primary impression: {nlp.get('primary_diagnosis', 'unknown')}")
32
+ parts.append("")
33
+
34
+ if retrieved_chunks:
35
+ parts.append("Relevant medical information:")
36
+ for chunk in retrieved_chunks[:3]:
37
+ parts.append(f"- {chunk['text'][:200]}...")
38
+ parts.append("")
39
+
40
+ recent_history = chat_history[-MedicalRAG.MAX_HISTORY_TURNS:]
41
+ for msg in recent_history:
42
+ role = "Patient" if msg["role"] == "user" else "Assistant"
43
+ parts.append(f"{role}: {msg['content'][:100]}")
44
+
45
+ parts.append(f"Patient: {query}\nAssistant:")
46
+ return "\n".join(parts)
47
+
48
+ @staticmethod
49
+ def is_safe_query(query: str) -> tuple[bool, str | None]:
50
+ for pattern in MedicalRAG.DRUG_DOSAGE_PATTERNS:
51
+ if re.search(pattern, query.lower()):
52
+ return False, "Please consult a physician for specific dosage information."
53
+ return True, None
backend/ml/rag/vectorstore.py CHANGED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import asyncio
3
+ import logging
4
+ from pathlib import Path
5
+ from backend.core.config import settings
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ class MedicalVectorStore:
10
+ COLLECTION_NAME = "pubmed_medical"
11
+ CHUNK_SIZE = 256
12
+ CHUNK_OVERLAP = 32
13
+ BATCH_SIZE = 100
14
+
15
+ def __init__(self):
16
+ self._client = None
17
+ self._collection = None
18
+ self._embedder = None
19
+
20
+ async def initialize(self):
21
+ if settings.ENVIRONMENT == "production":
22
+ import chromadb
23
+ self._client = chromadb.EphemeralClient()
24
+ else:
25
+ import chromadb
26
+ persist_path = str(settings.BASE_DIR / "data" / "chromadb" if hasattr(settings, 'BASE_DIR') else Path("data/chromadb"))
27
+ self._client = chromadb.PersistentClient(path=persist_path)
28
+
29
+ self._collection = self._client.get_or_create_collection(
30
+ name=self.COLLECTION_NAME, metadata={"hnsw:space": "cosine"}
31
+ )
32
+
33
+ doc_count = self._collection.count()
34
+ logger.info(f"ChromaDB initialized. Documents: {doc_count}")
35
+
36
+ if doc_count == 0:
37
+ logger.info("ChromaDB is empty. Starting PubMed ingestion...")
38
+ await asyncio.to_thread(self.ingest_from_json, Path("data/pubmed_raw.json"))
39
+
40
+ def ingest_from_json(self, json_path: Path):
41
+ if not json_path.exists():
42
+ logger.warning(f"PubMed data file not found: {json_path}. RAG will be limited.")
43
+ return
44
+
45
+ records = json.loads(json_path.read_text(encoding="utf-8"))
46
+ logger.info(f"Ingesting {len(records)} PubMed records into ChromaDB...")
47
+
48
+ all_chunks, all_ids, all_metadatas = [], [], []
49
+
50
+ for record in records:
51
+ if not record.get("abstract"): continue
52
+ text = f"{record['title']}. {record['abstract']}"
53
+ chunks = self._chunk_text(text, self.CHUNK_SIZE, self.CHUNK_OVERLAP)
54
+
55
+ for i, chunk in enumerate(chunks):
56
+ all_chunks.append(chunk)
57
+ all_ids.append(f"{record['pmid']}_chunk_{i}")
58
+ all_metadatas.append({
59
+ "pmid": str(record["pmid"]),
60
+ "title": record["title"][:200],
61
+ "disease_category": record.get("category", "general"),
62
+ "year": str(record.get("year", ""))
63
+ })
64
+
65
+ embedder = self._get_embedder()
66
+
67
+ for i in range(0, len(all_chunks), self.BATCH_SIZE):
68
+ batch_chunks = all_chunks[i:i+self.BATCH_SIZE]
69
+ batch_ids = all_ids[i:i+self.BATCH_SIZE]
70
+ batch_metas = all_metadatas[i:i+self.BATCH_SIZE]
71
+
72
+ embeddings = embedder.encode(batch_chunks, show_progress_bar=False).tolist()
73
+ self._collection.add(documents=batch_chunks, embeddings=embeddings, ids=batch_ids, metadatas=batch_metas)
74
+
75
+ logger.info(f"Ingestion complete. Total documents: {self._collection.count()}")
76
+
77
+ def search(self, query: str, n_results: int = 5, disease_filter: str = None) -> list[dict]:
78
+ if not self._collection: return []
79
+ embedder = self._get_embedder()
80
+ query_embedding = embedder.encode([query])[0].tolist()
81
+
82
+ where_filter = {"disease_category": disease_filter} if disease_filter else None
83
+
84
+ results = self._collection.query(
85
+ query_embeddings=[query_embedding], n_results=n_results,
86
+ where=where_filter, include=["documents", "metadatas", "distances"]
87
+ )
88
+
89
+ output = []
90
+ if results["documents"]:
91
+ for doc, meta, dist in zip(results["documents"][0], results["metadatas"][0], results["distances"][0]):
92
+ relevance = float(1 - dist)
93
+ output.append({
94
+ "text": doc, "pmid": meta.get("pmid", ""), "title": meta.get("title", ""),
95
+ "disease_category": meta.get("disease_category", ""), "year": meta.get("year", ""),
96
+ "relevance_score": round(max(0, relevance), 3)
97
+ })
98
+ return output
99
+
100
+ def _chunk_text(self, text: str, chunk_size: int, overlap: int) -> list[str]:
101
+ words = text.split()
102
+ chunks, start = [], 0
103
+ while start < len(words):
104
+ end = min(start + chunk_size, len(words))
105
+ chunks.append(" ".join(words[start:end]))
106
+ start = end - overlap
107
+ if start >= len(words): break
108
+ return chunks
109
+
110
+ def _get_embedder(self):
111
+ if self._embedder is None:
112
+ from sentence_transformers import SentenceTransformer
113
+ cache_dir = str(settings.MODEL_CACHE_DIR / "minilm")
114
+ self._embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
115
+ return self._embedder
116
+
117
+ vector_store = MedicalVectorStore()
frontend/components/analysis/VoiceInput.jsx ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 'use client'
2
+ import { useState, useRef } from 'react'
3
+ import { apiClient } from '@/lib/api/client'
4
+
5
+ export default function VoiceInput({ onTranscribed }) {
6
+ const [isRecording, setIsRecording] = useState(false)
7
+ const [recordingDuration, setRecordingDuration] = useState(0)
8
+ const [isTranscribing, setIsTranscribing] = useState(false)
9
+ const [transcribedText, setTranscribedText] = useState("")
10
+ const [confidence, setConfidence] = useState(null)
11
+ const [error, setError] = useState(null)
12
+ const [audioBlob, setAudioBlob] = useState(null)
13
+
14
+ const mediaRecorderRef = useRef(null)
15
+ const chunksRef = useRef([])
16
+ const timerRef = useRef(null)
17
+
18
+ const startRecording = async () => {
19
+ try {
20
+ const stream = await navigator.mediaDevices.getUserMedia({ audio: true })
21
+ const mediaRecorder = new MediaRecorder(stream, { mimeType: "audio/webm" })
22
+ mediaRecorderRef.current = mediaRecorder
23
+ chunksRef.current = []
24
+
25
+ mediaRecorder.ondataavailable = e => chunksRef.current.push(e.data)
26
+ mediaRecorder.onstop = () => {
27
+ const blob = new Blob(chunksRef.current, { type: "audio/webm" })
28
+ setAudioBlob(blob)
29
+ sendForTranscription(blob)
30
+ stream.getTracks().forEach(track => track.stop())
31
+ }
32
+
33
+ mediaRecorder.start(1000)
34
+ setIsRecording(true)
35
+ setRecordingDuration(0)
36
+ setError(null)
37
+
38
+ timerRef.current = setInterval(() => {
39
+ setRecordingDuration(prev => {
40
+ if (prev >= 59) {
41
+ stopRecording()
42
+ return 60
43
+ }
44
+ return prev + 1
45
+ })
46
+ }, 1000)
47
+ } catch (err) {
48
+ setError("Microphone access denied or unavailable.")
49
+ }
50
+ }
51
+
52
+ const stopRecording = () => {
53
+ if (mediaRecorderRef.current && isRecording) {
54
+ mediaRecorderRef.current.stop()
55
+ setIsRecording(false)
56
+ clearInterval(timerRef.current)
57
+ }
58
+ }
59
+
60
+ const sendForTranscription = async (blob) => {
61
+ setIsTranscribing(true)
62
+ const formData = new FormData()
63
+ formData.append("audio", blob, "recording.webm")
64
+
65
+ try {
66
+ const response = await apiClient.post('/api/v1/analyze/transcribe', formData, {
67
+ headers: { 'Content-Type': 'multipart/form-data' }
68
+ })
69
+ setTranscribedText(response.data.text)
70
+ setConfidence(response.data.confidence)
71
+ } catch (err) {
72
+ setError("Transcription failed. Please try again.")
73
+ } finally {
74
+ setIsTranscribing(false)
75
+ }
76
+ }
77
+
78
+ const handleReset = () => {
79
+ setTranscribedText("")
80
+ setConfidence(null)
81
+ setAudioBlob(null)
82
+ setRecordingDuration(0)
83
+ }
84
+
85
+ return (
86
+ <div className="p-4 border rounded-lg bg-gray-50">
87
+ {!transcribedText && !isTranscribing && (
88
+ <div className="flex flex-col items-center">
89
+ <button
90
+ onClick={isRecording ? stopRecording : startRecording}
91
+ className={`w-16 h-16 rounded-full flex items-center justify-center transition-all ${
92
+ isRecording ? "bg-red-500 animate-pulse" : "bg-teal-500 hover:bg-teal-600"
93
+ }`}
94
+ >
95
+ <svg className="w-8 h-8 text-white" fill="none" stroke="currentColor" viewBox="0 0 24 24">
96
+ <path strokeLinecap="round" strokeLinejoin="round" strokeWidth="2" d="M19 11a7 7 0 01-7 7m0 0a7 7 0 01-7-7m7 7v4m0 0H8m4 0h4m-4-8a3 3 0 01-3-3V5a3 3 0 116 0v6a3 3 0 01-3 3z"></path>
97
+ </svg>
98
+ </button>
99
+
100
+ <p className={`mt-2 font-mono ${recordingDuration > 50 ? 'text-red-500' : 'text-gray-600'}`}>
101
+ {Math.floor(recordingDuration / 60)}:{(recordingDuration % 60).toString().padStart(2, '0')} / 1:00
102
+ </p>
103
+ {error && <p className="mt-2 text-sm text-red-500">{error}</p>}
104
+ </div>
105
+ )}
106
+
107
+ {isTranscribing && (
108
+ <div className="flex flex-col items-center py-4">
109
+ <div className="w-6 h-6 border-2 border-teal-500 border-t-transparent rounded-full animate-spin" />
110
+ <p className="mt-2 text-sm text-gray-600">Transcribing audio...</p>
111
+ </div>
112
+ )}
113
+
114
+ {transcribedText && !isTranscribing && (
115
+ <div className="space-y-4">
116
+ <div className="p-3 bg-white border rounded">
117
+ <p className="text-gray-800 animate-[typewriter_0.5s_steps(40,end)] overflow-hidden break-words">
118
+ "{transcribedText}"
119
+ </p>
120
+ <div className="flex items-center mt-2 text-xs">
121
+ <span className={confidence > 0.8 ? "text-green-600" : "text-yellow-600"}>
122
+ {confidence > 0.8 ? "✅" : "⚠️"} Confidence: {(confidence * 100).toFixed(0)}%
123
+ </span>
124
+ </div>
125
+ </div>
126
+ <div className="flex space-x-2">
127
+ <button onClick={() => onTranscribed(transcribedText)} className="px-4 py-2 text-white bg-teal-600 rounded hover:bg-teal-700">
128
+ Use This Text
129
+ </button>
130
+ <button onClick={handleReset} className="px-4 py-2 text-gray-700 bg-gray-200 rounded hover:bg-gray-300">
131
+ Re-record
132
+ </button>
133
+ </div>
134
+ </div>
135
+ )}
136
+ </div>
137
+ )
138
+ }
frontend/components/chat/ChatInterface.jsx ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 'use client'
2
+ import { useState } from 'react'
3
+ import { v4 as uuid } from 'uuid'
4
+ import { useAuth } from '@/lib/auth/AuthContext'
5
+
6
+ export default function ChatInterface({ sessionId }) {
7
+ const { accessToken } = useAuth()
8
+ const [messages, setMessages] = useState([])
9
+ const [inputText, setInputText] = useState("")
10
+ const [streamingMessageId, setStreamingMessageId] = useState(null)
11
+
12
+ const sendMessage = async (e) => {
13
+ e.preventDefault()
14
+ if(!inputText.trim() || streamingMessageId) return
15
+
16
+ const text = inputText
17
+ setInputText("")
18
+
19
+ const userMsg = { id: uuid(), role: "user", content: text }
20
+ setMessages(prev => [...prev, userMsg])
21
+
22
+ const assistantMsgId = uuid()
23
+ const assistantMsg = { id: assistantMsgId, role: "assistant", content: "", isStreaming: true, sources: [] }
24
+ setMessages(prev => [...prev, assistantMsg])
25
+ setStreamingMessageId(assistantMsgId)
26
+
27
+ const API_URL = process.env.NEXT_PUBLIC_API_URL || 'http://localhost:8000'
28
+
29
+ try {
30
+ const response = await fetch(`${API_URL}/api/v1/chat`, {
31
+ method: "POST",
32
+ headers: { "Content-Type": "application/json", "Authorization": `Bearer ${accessToken}` },
33
+ body: JSON.stringify({ session_id: sessionId, message: text })
34
+ })
35
+
36
+ const reader = response.body.getReader()
37
+ const decoder = new TextDecoder()
38
+
39
+ while (true) {
40
+ const { done, value } = await reader.read()
41
+ if (done) break
42
+
43
+ const chunk = decoder.decode(value, { stream: true })
44
+ const lines = chunk.split("\n\n").filter(l => l.startsWith("data: "))
45
+
46
+ for (const line of lines) {
47
+ const data = JSON.parse(line.replace("data: ", ""))
48
+
49
+ if (data.type === "token") {
50
+ setMessages(prev => prev.map(m => m.id === assistantMsgId ? { ...m, content: m.content + data.token } : m))
51
+ } else if (data.type === "sources") {
52
+ setMessages(prev => prev.map(m => m.id === assistantMsgId ? { ...m, sources: data.sources } : m))
53
+ } else if (data.type === "done") {
54
+ setMessages(prev => prev.map(m => m.id === assistantMsgId ? { ...m, isStreaming: false } : m))
55
+ setStreamingMessageId(null)
56
+ }
57
+ }
58
+ }
59
+ } catch (error) {
60
+ setMessages(prev => prev.map(m => m.id === assistantMsgId ? { ...m, content: "Network error occurred.", isStreaming: false } : m))
61
+ setStreamingMessageId(null)
62
+ }
63
+ }
64
+
65
+ return (
66
+ <div className="flex flex-col h-[600px] border rounded-lg bg-white">
67
+ <div className="flex-1 p-4 overflow-y-auto space-y-4">
68
+ {messages.map(m => (
69
+ <div key={m.id} className={`flex ${m.role === 'user' ? 'justify-end' : 'justify-start'}`}>
70
+ <div className={`max-w-[80%] p-3 rounded-lg ${m.role === 'user' ? 'bg-teal-600 text-white' : 'bg-gray-100 text-gray-800'}`}>
71
+ <p>{m.content}</p>
72
+ {m.isStreaming && <span className="inline-block w-2 h-4 ml-1 bg-gray-400 animate-pulse" />}
73
+ {m.sources?.length > 0 && (
74
+ <div className="mt-2 pt-2 border-t border-gray-300 text-xs">
75
+ <p className="font-semibold mb-1">Sources:</p>
76
+ {m.sources.map((s, i) => <p key={i}>• {s.title.substring(0, 50)}...</p>)}
77
+ </div>
78
+ )}
79
+ </div>
80
+ </div>
81
+ ))}
82
+ </div>
83
+ <form onSubmit={sendMessage} className="p-3 border-t bg-gray-50 flex items-center">
84
+ <input
85
+ type="text"
86
+ value={inputText}
87
+ onChange={e => setInputText(e.target.value)}
88
+ disabled={!!streamingMessageId}
89
+ className="flex-1 p-2 border rounded-l-md focus:outline-none focus:ring-2 focus:ring-teal-500 disabled:opacity-50"
90
+ placeholder="Ask about the results..."
91
+ />
92
+ <button type="submit" disabled={!!streamingMessageId} className="px-4 py-2 bg-teal-600 text-white rounded-r-md hover:bg-teal-700 disabled:opacity-50">
93
+ Send
94
+ </button>
95
+ </form>
96
+ </div>
97
+ )
98
+ }
training/scripts/finetune_ner.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import numpy as np
3
+ from pathlib import Path
4
+ from transformers import (
5
+ AutoModelForTokenClassification, AutoTokenizer,
6
+ TrainingArguments, Trainer, DataCollatorForTokenClassification
7
+ )
8
+ from datasets import load_from_disk
9
+ from seqeval.metrics import f1_score, precision_score, recall_score, classification_report
10
+
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ id2label = {
15
+ 0: "O", 1: "B-DISEASE", 2: "I-DISEASE",
16
+ 3: "B-SYMPTOM", 4: "I-SYMPTOM",
17
+ 5: "B-MEDICATION", 6: "I-MEDICATION",
18
+ 7: "B-ANATOMY", 8: "I-ANATOMY"
19
+ }
20
+ label2id = {v: k for k, v in id2label.items()}
21
+
22
+ def compute_metrics(eval_pred):
23
+ predictions, labels = eval_pred
24
+ predictions = np.argmax(predictions, axis=2)
25
+
26
+ true_labels = [[id2label[l] for l in label if l != -100] for label in labels]
27
+ true_predictions = [
28
+ [id2label[p] for p, l in zip(pred, label) if l != -100]
29
+ for pred, label in zip(predictions, labels)
30
+ ]
31
+
32
+ return {
33
+ "precision": precision_score(true_labels, true_predictions),
34
+ "recall": recall_score(true_labels, true_predictions),
35
+ "f1": f1_score(true_labels, true_predictions),
36
+ }
37
+
38
+ def main():
39
+ dataset = load_from_disk("data/processed/ner_dataset")
40
+ model_id = "dmis-lab/biobert-base-cased-v1.2"
41
+
42
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
43
+ model = AutoModelForTokenClassification.from_pretrained(
44
+ model_id, num_labels=9, id2label=id2label, label2id=label2id
45
+ )
46
+
47
+ training_args = TrainingArguments(
48
+ output_dir="models/biobert_ner",
49
+ per_device_train_batch_size=8,
50
+ per_device_eval_batch_size=8,
51
+ gradient_accumulation_steps=4,
52
+ fp16=True,
53
+ gradient_checkpointing=True,
54
+ dataloader_num_workers=0, # Windows fix
55
+ dataloader_pin_memory=False, # Windows fix
56
+ num_train_epochs=10,
57
+ warmup_ratio=0.1,
58
+ learning_rate=2e-5,
59
+ weight_decay=0.01,
60
+ evaluation_strategy="epoch",
61
+ save_strategy="epoch",
62
+ load_best_model_at_end=True,
63
+ metric_for_best_model="f1",
64
+ greater_is_better=True,
65
+ save_total_limit=3,
66
+ logging_steps=50,
67
+ report_to="none",
68
+ )
69
+
70
+ data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
71
+
72
+ trainer = Trainer(
73
+ model=model,
74
+ args=training_args,
75
+ train_dataset=dataset["train"],
76
+ eval_dataset=dataset["validation"],
77
+ tokenizer=tokenizer,
78
+ data_collator=data_collator,
79
+ compute_metrics=compute_metrics
80
+ )
81
+
82
+ # Auto-resume logic
83
+ checkpoint_dir = Path("models/biobert_ner")
84
+ latest_checkpoint = None
85
+ if checkpoint_dir.exists():
86
+ checkpoints = list(checkpoint_dir.glob("checkpoint-*"))
87
+ if checkpoints:
88
+ latest_checkpoint = str(max(checkpoints, key=lambda p: int(p.name.split("-")[1])))
89
+ logger.info(f"📂 Resuming from {latest_checkpoint}")
90
+
91
+ trainer.train(resume_from_checkpoint=latest_checkpoint)
92
+
93
+ logger.info("Evaluating on test set...")
94
+ test_results = trainer.predict(dataset["test"])
95
+ print(classification_report(*Trainer._get_labels_and_preds(test_results.predictions, test_results.label_ids)))
96
+
97
+ final_dir = "models/biobert_ner_finetuned"
98
+ trainer.save_model(final_dir)
99
+ tokenizer.save_pretrained(final_dir)
100
+ logger.info(f"✅ Model saved to {final_dir}")
101
+
102
+ if __name__ == "__main__":
103
+ main()
training/scripts/ingest_pubmed.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import json
3
+ import logging
4
+ import requests
5
+ import xml.etree.ElementTree as ET
6
+ from pathlib import Path
7
+
8
+ logging.basicConfig(level=logging.INFO)
9
+ logger = logging.getLogger(__name__)
10
+
11
+ SEARCH_QUERIES = {
12
+ "pneumonia": "chest X-ray pneumonia findings radiology",
13
+ "pleural_effusion": "pleural effusion chest radiograph diagnosis",
14
+ "cardiomegaly": "cardiomegaly cardiac enlargement chest X-ray",
15
+ "tuberculosis": "tuberculosis pulmonary chest radiograph findings",
16
+ "normal_chest": "normal chest radiograph no significant finding"
17
+ }
18
+
19
+ def fetch_pubmed_abstracts(query: str, category: str, max_results=200) -> list[dict]:
20
+ logger.info(f"Searching PubMed for: {category}")
21
+ search_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
22
+ params = {"db": "pubmed", "term": query, "retmax": max_results, "retmode": "json", "usehistory": "y"}
23
+
24
+ res = requests.get(search_url, params=params, timeout=30)
25
+ res.raise_for_status()
26
+ pmids = res.json().get("esearchresult", {}).get("idlist", [])
27
+
28
+ results = []
29
+ fetch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
30
+
31
+ for i in range(0, len(pmids), 20):
32
+ batch = pmids[i:i+20]
33
+ f_params = {"db": "pubmed", "id": ",".join(batch), "rettype": "abstract", "retmode": "xml"}
34
+ f_res = requests.get(fetch_url, params=f_params, timeout=30)
35
+
36
+ try:
37
+ tree = ET.fromstring(f_res.content)
38
+ for article in tree.findall(".//PubmedArticle"):
39
+ pmid = article.findtext(".//PMID")
40
+ title = article.findtext(".//ArticleTitle")
41
+ abstract = article.findtext(".//AbstractText")
42
+ year = article.findtext(".//PubDate/Year")
43
+
44
+ if abstract and title:
45
+ results.append({
46
+ "pmid": pmid, "title": title, "abstract": abstract,
47
+ "year": year, "category": category
48
+ })
49
+ except ET.ParseError:
50
+ logger.error("XML parse error on a batch, skipping.")
51
+ time.sleep(0.4) # Respect 3 req/sec rate limit
52
+ return results
53
+
54
+ def main():
55
+ out_file = Path("data/pubmed_raw.json")
56
+ out_file.parent.mkdir(parents=True, exist_ok=True)
57
+
58
+ all_data = []
59
+ completed_cats = set()
60
+
61
+ if out_file.exists():
62
+ all_data = json.loads(out_file.read_text())
63
+ completed_cats = {d["category"] for d in all_data}
64
+ logger.info(f"Resuming... found {len(all_data)} existing records.")
65
+
66
+ for cat, query in SEARCH_QUERIES.items():
67
+ if cat in completed_cats:
68
+ continue
69
+ data = fetch_pubmed_abstracts(query, cat)
70
+ all_data.extend(data)
71
+ out_file.write_text(json.dumps(all_data, indent=2))
72
+ logger.info(f"Saved {len(data)} abstracts for {cat}.")
73
+
74
+ logger.info(f"✅ PubMed ingestion complete. Total records: {len(all_data)}")
75
+
76
+ if __name__ == "__main__":
77
+ main()
training/scripts/prepare_ner_data.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from pathlib import Path
4
+ from datasets import load_dataset
5
+ from transformers import AutoTokenizer
6
+
7
+ logging.basicConfig(level=logging.INFO)
8
+ logger = logging.getLogger(__name__)
9
+
10
+ SYMPTOMS_LIST = [
11
+ "fever", "cough", "pain", "fatigue", "nausea", "vomiting",
12
+ "headache", "dizziness", "shortness of breath", "chest pain",
13
+ "dyspnea", "tachycardia", "edema", "rash"
14
+ ]
15
+
16
+ def add_symptom_labels(example):
17
+ tokens = example["tokens"]
18
+ tags = example["ner_tags"]
19
+
20
+ for i, token in enumerate(tokens):
21
+ if tags[i] == 0: # If currently 'O'
22
+ if token.lower() in SYMPTOMS_LIST:
23
+ # 3 is B-SYMPTOM, 4 is I-SYMPTOM in our mapping
24
+ tags[i] = 3
25
+ return {"tokens": tokens, "ner_tags": tags}
26
+
27
+ def tokenize_and_align_labels(examples, tokenizer):
28
+ tokenized = tokenizer(
29
+ examples["tokens"],
30
+ truncation=True,
31
+ max_length=512,
32
+ is_split_into_words=True, # CRITICAL: Tells tokenizer input is already word-split
33
+ padding=False
34
+ )
35
+
36
+ all_labels = []
37
+ for i, labels in enumerate(examples["ner_tags"]):
38
+ word_ids = tokenized.word_ids(batch_index=i)
39
+ aligned = []
40
+ prev_word_id = None
41
+
42
+ for word_id in word_ids:
43
+ if word_id is None:
44
+ aligned.append(-100) # Special tokens (CLS, SEP) ignored in loss
45
+ elif word_id != prev_word_id:
46
+ aligned.append(labels[word_id]) # First subword gets the actual label
47
+ else:
48
+ label = labels[word_id]
49
+ # If B- label (odd numbers in NCBI), convert to I- label (even) for subwords
50
+ if label % 2 == 1:
51
+ label += 1
52
+ aligned.append(label)
53
+ prev_word_id = word_id
54
+
55
+ all_labels.append(aligned)
56
+
57
+ tokenized["labels"] = all_labels
58
+ return tokenized
59
+
60
+ def main():
61
+ out_dir = Path("data/processed/ner_dataset")
62
+ out_dir.parent.mkdir(parents=True, exist_ok=True)
63
+
64
+ logger.info("Loading NCBI Disease dataset...")
65
+ dataset = load_dataset("ncbi_disease")
66
+
67
+ logger.info("Augmenting symptom labels...")
68
+ dataset = dataset.map(add_symptom_labels)
69
+
70
+ logger.info("Tokenizing and aligning BIO tags...")
71
+ tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.2")
72
+
73
+ tokenized_dataset = dataset.map(
74
+ lambda x: tokenize_and_align_labels(x, tokenizer),
75
+ batched=True,
76
+ remove_columns=dataset["train"].column_names
77
+ )
78
+
79
+ tokenized_dataset.save_to_disk(str(out_dir))
80
+ logger.info(f"✅ NER dataset prepared and saved to {out_dir}")
81
+
82
+ if __name__ == "__main__":
83
+ main()