import asyncio
from collections import Counter
import hashlib
import hmac
import io
import json
import math
import os
import re
import secrets
import time
import uuid
import urllib.error
import urllib.request
from datetime import datetime
from typing import Any, Dict, List, Optional
from zoneinfo import ZoneInfo

import numpy as np
import soundfile as sf
import torch
import whisper
from fastapi import HTTPException
from google.genai import types
from PIL import Image, ImageOps
from pymongo.errors import OperationFailure
from transformers import AutoTokenizer, VitsModel

from core import (
    AUTO_COMPACT_COOLDOWN_SEC,
    AUTO_COMPACT_ENABLED,
    AUTO_COMPACT_MIN_MESSAGES,
    AUTO_COMPACT_MIN_TOTAL_CHARS,
    COMPACT_HARD_DELETE,
    PASSWORD_HASH_ITERATIONS,
    SESSION_TTL_DAYS,
    TTS_CHUNK_CHARS,
    TTS_DEFAULT_SAMPLE_RATE,
    TTS_MODEL_NAME,
    _audio_model_lock,
    _auto_compact_lock,
    chat_collection,
    image_assets_collection,
    memory_collection,
    model_client,
    NVIDIA_API_BASE,
    NVIDIA_KEY,
    proposals_collection,
    sessions_collection,
    tasks_collection,
    TEAM_AGENT_MODEL,
    team_chat_collection,
    team_doc_chunks_collection,
    team_docs_collection,
    teams_collection,
    users_collection,
    projects_collection,
    issues_collection,
    UPLOAD_DIR,
    WHISPER_MODEL_NAME,
)
from prompts import DOC_QA_COMPACT_PROMPT, TTS_REWRITE_PROMPT

# Timezone used for every timestamp produced by this module.
VN_TZ = ZoneInfo("Asia/Ho_Chi_Minh")

# Lazily-initialised model singletons (see get_whisper_model / get_tts_model).
_whisper_model = None
_tts_tokenizer = None
_tts_model = None

# time.time() value of the last auto-compaction that was started.
_last_auto_compact_ts = 0.0

# Document-tree tuning knobs, overridable through the environment.
TEAM_DOC_NODE_CONTENT_LIMIT = int(os.environ.get("TEAM_DOC_NODE_CONTENT_LIMIT", "2400"))
TEAM_DOC_NODE_CHUNK_SIZE = int(os.environ.get("TEAM_DOC_NODE_CHUNK_SIZE", "80"))


def get_vn_now() -> datetime:
    """Return the current timezone-aware time in ``VN_TZ``."""
    return datetime.now(tz=VN_TZ)


def get_vn_time_str() -> str:
    """Return the current ``VN_TZ`` time as ``YYYY-MM-DDTHH:MM:SS+07:00``.

    The ``+07:00`` suffix is a literal in the format string, not derived
    from the timezone object.
    """
    return get_vn_now().strftime("%Y-%m-%dT%H:%M:%S+07:00")


def get_whisper_model():
    """Return the shared Whisper model, loading it on first use.

    The ``None`` check is repeated after acquiring ``_audio_model_lock`` so a
    caller that waited on the lock does not load the model a second time.
    """
    global _whisper_model
    if _whisper_model is None:
        with _audio_model_lock:
            if _whisper_model is None:
                _whisper_model = whisper.load_model(WHISPER_MODEL_NAME)
    return _whisper_model


def get_tts_model():
    """Return the shared ``(tokenizer, model)`` pair for TTS, loading on first use.

    Both objects are loaded from ``TTS_MODEL_NAME`` under ``_audio_model_lock``
    and the model is switched to eval mode.
    """
    global _tts_tokenizer, _tts_model
    if _tts_tokenizer is None or _tts_model is None:
        with _audio_model_lock:
            if _tts_tokenizer is None or _tts_model is None:
                _tts_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_NAME)
                _tts_model = VitsModel.from_pretrained(TTS_MODEL_NAME)
                _tts_model.eval()
    return _tts_tokenizer, _tts_model


def _wav_bytes_from_audio(audio: Any, sample_rate: int) -> bytes:
    """Encode an audio array as WAV bytes.

    Integer input is peak-normalised; all input is converted to float32 and
    clipped to ``[-1.0, 1.0]`` before being written with ``soundfile``.

    Args:
        audio: Array-like of samples.
        sample_rate: Sample rate written into the WAV header.

    Returns:
        The encoded WAV file contents.

    Raises:
        RuntimeError: If ``audio`` contains no samples.
    """
    # atleast_1d: squeeze() turns a single-sample input into a 0-d array.
    samples = np.atleast_1d(np.asarray(audio).squeeze())
    if samples.size == 0:
        raise RuntimeError("Empty audio from TTS engine")

    is_integer = samples.dtype.kind in ("i", "u")
    # Convert before taking abs(): abs() of the most negative signed integer
    # overflows in integer arithmetic and would yield a wrong peak.
    samples = samples.astype(np.float32)
    if is_integer:
        peak = float(np.max(np.abs(samples))) or 1.0
        samples = samples / peak
    samples = np.clip(samples, -1.0, 1.0)

    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, samples, sample_rate, format="WAV")
    return wav_buffer.getvalue()


def synthesize_vi_speech(text: str) -> bytes:
    """Run the TTS model on ``text`` and return the result as WAV bytes."""
    tokenizer, tts_model = get_tts_model()
    inputs = tokenizer(text=text, return_tensors="pt")
    with torch.no_grad(), np.errstate(all="ignore"):
        output = tts_model(**inputs)
    audio = output.waveform.squeeze().detach().cpu().numpy()
    return _wav_bytes_from_audio(audio, int(tts_model.config.sampling_rate))


def _split_oversized_tts_part(part: str, max_chars: int) -> List[str]:
    """Word-wrap ``part`` into pieces of at most ``max_chars`` characters.

    A single word longer than ``max_chars`` is cut at the limit.
    ``max_chars`` must be positive.
    """
    pieces: List[str] = []
    current = ""
    for word in part.split():
        while len(word) > max_chars:
            if current:
                pieces.append(current)
                current = ""
            pieces.append(word[:max_chars])
            word = word[max_chars:]
        if not word:
            continue
        candidate = f"{current} {word}" if current else word
        if len(candidate) <= max_chars:
            current = candidate
        else:
            pieces.append(current)
            current = word
    if current:
        pieces.append(current)
    return pieces


def split_tts_chunks(text: str, max_chars: int) -> List[str]:
    """Split ``text`` into chunks of at most ``max_chars`` characters.

    Text is first split after sentence punctuation (``. ! ? ; :``) and the
    sentences are greedily packed into chunks. A sentence that alone exceeds
    ``max_chars`` is word-wrapped (and, if necessary, cut mid-word) so that
    no chunk exceeds the limit when ``max_chars`` is positive.

    Args:
        text: Text to split.
        max_chars: Maximum chunk length.

    Returns:
        The chunks in order. ``[text]`` when the text already fits.
    """
    if len(text) <= max_chars:
        return [text]

    chunks: List[str] = []
    current = ""
    for sentence in re.split(r"(?<=[.!?;:])\s+", text):
        sentence = sentence.strip()
        if not sentence:
            continue
        candidate = f"{current} {sentence}" if current else sentence
        if len(candidate) <= max_chars:
            current = candidate
            continue
        if current:
            chunks.append(current)
            current = ""
        if max_chars > 0 and len(sentence) > max_chars:
            pieces = _split_oversized_tts_part(sentence, max_chars)
            chunks.extend(pieces[:-1])
            # Keep the tail open so the next sentence can still join it.
            current = pieces[-1]
        else:
            current = sentence
    if current:
        chunks.append(current)
    return chunks


def synthesize_vi_speech_long(text: str) -> bytes:
    """Synthesize ``text`` chunk by chunk and return one merged WAV.

    Each chunk from ``split_tts_chunks`` is synthesized separately,
    multi-channel audio is averaged to mono, and a 0.12 s silence is inserted
    between consecutive chunks. The sample rate of the first chunk is used
    for the silence and for the output.

    Raises:
        RuntimeError: If no audio was produced (raised by the WAV encoder).
    """
    chunks = split_tts_chunks(text, TTS_CHUNK_CHARS)
    segments: List[np.ndarray] = []
    target_sr: Optional[int] = None
    last_index = len(chunks) - 1

    for idx, chunk in enumerate(chunks):
        wav_chunk = synthesize_vi_speech(chunk)
        audio_arr, sr = sf.read(io.BytesIO(wav_chunk), dtype="float32")
        if isinstance(audio_arr, np.ndarray) and audio_arr.ndim > 1:
            audio_arr = audio_arr.mean(axis=1)
        if target_sr is None:
            target_sr = int(sr)
        segments.append(audio_arr.astype(np.float32))
        if idx < last_index and target_sr:
            segments.append(np.zeros(int(target_sr * 0.12), dtype=np.float32))

    merged = np.concatenate(segments) if segments else np.array([], dtype=np.float32)
    return _wav_bytes_from_audio(merged, target_sr or TTS_DEFAULT_SAMPLE_RATE)


def normalize_tts_text(raw_text: str) -> str:
    """Strip markdown noise from ``raw_text`` and collapse whitespace.

    Markdown links are replaced by their label, the characters
    `` ` * _ # > - `` become spaces, and runs of whitespace become one space.
    """
    text = (raw_text or "").replace("\r\n", "\n").replace("\r", "\n")
    text = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", r"\1", text)
    text = re.sub(r"[`*_#>-]", " ", text)
    return re.sub(r"\s+", " ", text).strip()


def rewrite_text_for_speech(base_text: str, speech_prompt: Optional[str] = None) -> str:
    """Ask the LLM to rewrite ``base_text`` for reading aloud.

    Args:
        base_text: Text to rewrite.
        speech_prompt: Optional extra style instructions appended to
            ``TTS_REWRITE_PROMPT``.

    Returns:
        The rewritten text, or ``base_text`` when the model returns nothing
        or the call fails.
    """
    prompt = TTS_REWRITE_PROMPT
    if speech_prompt and speech_prompt.strip():
        prompt = f"{prompt}\n\nYêu cầu phong cách đọc thêm từ client:\n{speech_prompt.strip()}"
    try:
        response = model_client.models.generate_content(
            model=os.environ.get("MODEL_NAME", "gemini-flash-lite-latest"),
            contents=base_text,
            config=types.GenerateContentConfig(system_instruction=prompt),
        )
        rewritten = (response.text or "").strip()
        return rewritten or base_text
    except Exception:
        # Best effort: any failure falls back to the unmodified text.
        return base_text


def save_chat_message(role: str, content: str, msg_id: Optional[str] = None):
    """Insert one chat message into ``chat_collection``.

    Args:
        role: Message role stored as-is.
        content: Message text stored as-is.
        msg_id: Optional id; a UUID4 string is generated when omitted.
    """
    # One clock read so ``timestamp`` and ``date`` cannot straddle midnight.
    now = get_vn_now()
    chat_collection.insert_one(
        {
            "id": msg_id or str(uuid.uuid4()),
            "role": role,
            "content": content,
            "timestamp": now.isoformat(),
            "date": now.strftime("%Y-%m-%d"),
        }
    )


def get_daily_chat(include_compacted: bool = False) -> List[Dict[str, Any]]:
    """Return today's chat messages ordered by ascending ``timestamp``.

    Args:
        include_compacted: When false, messages whose ``compacted`` field is
            set to a value other than ``False`` are excluded.
    """
    today = get_vn_now().strftime("%Y-%m-%d")
    query: Dict[str, Any] = {"date": today}
    if not include_compacted:
        query["$or"] = [{"compacted": {"$exists": False}}, {"compacted": False}]
    return list(chat_collection.find(query, {"_id": 0}).sort("timestamp", 1))


def get_memory() -> str:
    """Return the stored ``user_memory`` content, or ``""`` when absent."""
    mem = memory_collection.find_one({"type": "user_memory"}, {"_id": 0})
    # .get() mirrors get_team_doc_qa_memory and tolerates a missing field.
    return str(mem.get("content") or "") if mem else ""


def _team_doc_qa_memory_scope_key(team_id: str, project_id: Optional[str]) -> str:
    """Build the memory scope key ``<team_id>::<project_id or 'global'>``."""
    return f"{team_id}::{project_id or 'global'}"
def get_team_doc_qa_memory(team_id: str, project_id: Optional[str]) -> str: scope_key = _team_doc_qa_memory_scope_key(team_id, project_id) mem = memory_collection.find_one({"type": "team_doc_qa_memory", "scope_key": scope_key}, {"_id": 0}) return str(mem.get("content") or "") if mem else "" def _save_team_doc_qa_memory(team_id: str, project_id: Optional[str], content: str) -> None: scope_key = _team_doc_qa_memory_scope_key(team_id, project_id) memory_collection.update_one( {"type": "team_doc_qa_memory", "scope_key": scope_key}, { "$set": { "type": "team_doc_qa_memory", "scope_key": scope_key, "team_id": team_id, "project_id": project_id, "content": content, "updated_at": get_vn_now().isoformat(), } }, upsert=True, ) async def compact_team_doc_qa_memory( team_id: str, project_id: Optional[str], query: str, answer: str, doc_context: Dict[str, Any], selected_messages: List[Dict[str, Any]], citations: Optional[List[Dict[str, Any]]] = None, ) -> Dict[str, Any]: current_memory = get_team_doc_qa_memory(team_id, project_id) sections = doc_context.get("sections") if isinstance(doc_context, dict) else [] payload = { "team_id": team_id, "project_id": project_id, "current_memory": current_memory, "query": query, "answer": answer, "selected_messages": selected_messages[-4:] if isinstance(selected_messages, list) else [], "citations": citations[:6] if isinstance(citations, list) else [], "evidence_sections": [ { "document_id": section.get("document_id"), "document_name": section.get("document_name"), "section_id": section.get("section_id"), "section_title": section.get("section_title"), "section_path": section.get("section_path"), "section_content": _clip_text(str(section.get("section_content") or ""), 420), "section_summary": _clip_text(str(section.get("section_summary") or ""), 240), } for section in sections[:6] ], } def run_compact() -> str: return _nvidia_chat_completion( system_prompt=DOC_QA_COMPACT_PROMPT, user_prompt=json.dumps(payload, ensure_ascii=False), 
model=TEAM_AGENT_MODEL, temperature=0.0, max_tokens=800, ) result_text = await asyncio.to_thread(run_compact) result_json = _extract_json_object(result_text) memory_summary = str(result_json.get("memory_summary") or "").strip() if not memory_summary: memory_summary = current_memory if memory_summary: _save_team_doc_qa_memory(team_id, project_id, memory_summary) return { "memory_summary": memory_summary, "raw": result_json, } async def compact_chat_with_prompt(system_prompt: str, min_messages: int = 6) -> Dict[str, Any]: messages = get_daily_chat() if len(messages) < min_messages: return {"ok": False, "reason": "not_enough_messages", "deleted": 0, "hard_delete": COMPACT_HARD_DELETE} msgs_json = json.dumps( [{"id": m["id"], "role": m["role"], "content": m["content"]} for m in messages], ensure_ascii=False, ) def run_compact_once() -> str: response = model_client.models.generate_content( model=os.environ.get("MODEL_NAME", "gemini-flash-lite-latest"), contents=f"Lịch sử chat:\n{msgs_json}", config=types.GenerateContentConfig(system_instruction=system_prompt), ) return response.text or "" result_text = await asyncio.to_thread(run_compact_once) json_match = re.search(r"\{[\s\S]*\}", result_text) if not json_match: return {"ok": False, "reason": "invalid_compact_result", "deleted": 0, "hard_delete": COMPACT_HARD_DELETE} result = json.loads(json_match.group()) delete_ids = result.get("delete_ids", []) memory_summary = result.get("memory_summary", "") compacted_count = 0 if delete_ids: if COMPACT_HARD_DELETE: res = chat_collection.delete_many({"id": {"$in": delete_ids}}) compacted_count = res.deleted_count else: res = chat_collection.update_many( {"id": {"$in": delete_ids}}, {"$set": {"compacted": True, "compacted_at": get_vn_now().isoformat()}}, ) compacted_count = res.modified_count if memory_summary: memory_collection.update_one( {"type": "user_memory"}, {"$set": {"content": memory_summary, "updated_at": get_vn_now().isoformat()}}, upsert=True, ) return { "ok": True, 
"deleted": compacted_count, "hard_delete": COMPACT_HARD_DELETE, "memory": memory_summary, } async def maybe_auto_compact_voice_chat(system_prompt: str) -> None: global _last_auto_compact_ts if not AUTO_COMPACT_ENABLED: return messages = get_daily_chat() if len(messages) < AUTO_COMPACT_MIN_MESSAGES: return total_chars = sum(len((m.get("content") or "")) for m in messages) if total_chars < AUTO_COMPACT_MIN_TOTAL_CHARS: return now_ts = time.time() with _auto_compact_lock: if now_ts - _last_auto_compact_ts < AUTO_COMPACT_COOLDOWN_SEC: return _last_auto_compact_ts = now_ts try: await compact_chat_with_prompt(system_prompt=system_prompt, min_messages=AUTO_COMPACT_MIN_MESSAGES) except Exception: return def _hash_password(password: str, salt: Optional[bytes] = None) -> tuple[str, str]: salt_bytes = salt or secrets.token_bytes(16) password_hash = hashlib.pbkdf2_hmac( "sha256", password.encode("utf-8"), salt_bytes, PASSWORD_HASH_ITERATIONS, ) return salt_bytes.hex(), password_hash.hex() def _verify_password(password: str, salt_hex: str, password_hash_hex: str) -> bool: _, computed_hash = _hash_password(password, bytes.fromhex(salt_hex)) return hmac.compare_digest(computed_hash, password_hash_hex) def _create_session(user_id: str) -> Dict[str, str]: token = secrets.token_urlsafe(32) expires_at = get_vn_now().timestamp() + (SESSION_TTL_DAYS * 24 * 60 * 60) expires_iso = datetime.fromtimestamp(expires_at, tz=VN_TZ).isoformat() sessions_collection.insert_one( { "token": token, "user_id": user_id, "created_at": get_vn_now().isoformat(), "expires_at": expires_iso, } ) return {"token": token, "expires_at": expires_iso} def get_session_user(session_token: Optional[str]) -> Optional[Dict[str, Any]]: if not session_token: return None session = sessions_collection.find_one( { "token": session_token, "revoked_at": {"$exists": False}, "expires_at": {"$gt": get_vn_now().isoformat()}, }, {"_id": 0}, ) if not session: return None user = users_collection.find_one( {"id": session["user_id"]}, 
{"_id": 0, "password_hash": 0, "password_salt": 0}, ) return user def require_session_user(session_token: Optional[str]) -> Dict[str, Any]: user = get_session_user(session_token) if not user: raise HTTPException(status_code=401, detail="Unauthorized") return user def unique_ids(*lists: List[str]) -> List[str]: seen: set[str] = set() result: List[str] = [] for item_list in lists: for item in item_list: if item and item not in seen: seen.add(item) result.append(item) return result def get_user_team_ids(user_id: str) -> List[str]: return [team["id"] for team in teams_collection.find({"member_ids": user_id}, {"_id": 0, "id": 1})] def project_visibility_query(user_id: str) -> Dict[str, Any]: team_ids = get_user_team_ids(user_id) clauses: List[Dict[str, Any]] = [{"owner_id": user_id}, {"member_ids": user_id}] if team_ids: clauses.append({"team_ids": {"$in": team_ids}}) return {"$or": clauses} def get_users_by_ids(user_ids: List[str]) -> List[Dict[str, Any]]: normalized_ids = [user_id for user_id in unique_ids(user_ids) if user_id] if not normalized_ids: return [] return list( users_collection.find( {"id": {"$in": normalized_ids}}, {"_id": 0, "password_hash": 0, "password_salt": 0}, ) ) def resolve_user_reference(payload: Dict[str, Any]) -> Optional[Dict[str, Any]]: lookup_order = [ (payload.get("user_id") or payload.get("assignee_id") or payload.get("member_id"), "id"), (payload.get("user_email") or payload.get("email") or payload.get("assignee_email"), "email"), (payload.get("user_name") or payload.get("name") or payload.get("assignee_name"), "name"), ] for value, field in lookup_order: if not value: continue if field == "id": user = users_collection.find_one({"id": value}, {"_id": 0, "password_hash": 0, "password_salt": 0}) else: regex = re.compile(re.escape(str(value).strip()), re.IGNORECASE) user = users_collection.find_one( {field: regex}, {"_id": 0, "password_hash": 0, "password_salt": 0}, ) if user: return user return None def get_team_agent_context(team_id: str, 
project_id: Optional[str], issue_anchor_id: Optional[str], window: int = 7) -> Dict[str, Any]: team = teams_collection.find_one({"id": team_id}, {"_id": 0}) project = None if project_id: project = projects_collection.find_one({"id": project_id}, {"_id": 0}) team_member_ids = unique_ids([team.get("owner_id", "")] if team else [], team.get("member_ids", []) if team else []) project_member_ids = unique_ids([project.get("owner_id", "")] if project else [], project.get("member_ids", []) if project else []) visible_member_ids = unique_ids(team_member_ids, project_member_ids) recent_chat = get_team_chat_context(team_id, None, window=window) issue_anchor = None if issue_anchor_id: issue_anchor = issues_collection.find_one({"id": issue_anchor_id}, {"_id": 0}) project_issues: List[Dict[str, Any]] = [] if project_id: project_issues = list( issues_collection.find( {"project_id": project_id}, { "_id": 0, "id": 1, "title": 1, "description": 1, "severity": 1, "status": 1, "assignee_id": 1, "tags": 1, "requirement_text": 1, "attachment_urls": 1, "updated_at": 1, "created_at": 1, }, ).sort("updated_at", -1).limit(20) ) issue_related: List[Dict[str, Any]] = [] if issue_anchor and project_issues: anchor_tags = set(issue_anchor.get("tags", []) or []) anchor_title = str(issue_anchor.get("title", "")).lower() for issue in project_issues: if issue["id"] == issue_anchor["id"]: continue tags = set(issue.get("tags", []) or []) if tags.intersection(anchor_tags): issue_related.append(issue) continue title = str(issue.get("title", "")).lower() if anchor_title and (anchor_title in title or title in anchor_title): issue_related.append(issue) issue_related = issue_related[:6] open_issues = [issue for issue in project_issues if issue.get("status") not in {"done", "closed", "resolved"}] critical_issues = [issue for issue in project_issues if issue.get("severity") == "critical"] high_issues = [issue for issue in project_issues if issue.get("severity") == "high"] members = 
get_users_by_ids(visible_member_ids) return { "team": team, "project": project, "members": members, "member_ids": visible_member_ids, "issue_anchor": issue_anchor, "recent_chat": recent_chat, "project_issues": project_issues, "open_issue_count": len(open_issues), "critical_issue_count": len(critical_issues), "high_issue_count": len(high_issues), "related_issues": issue_related, } def compress_image_bytes(raw_bytes: bytes, max_size: int = 1600, quality: int = 78) -> tuple[bytes, str, int, int]: image = Image.open(io.BytesIO(raw_bytes)) image = ImageOps.exif_transpose(image) image = image.convert("RGB") image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS) output = io.BytesIO() image.save(output, format="WEBP", quality=quality, method=6, optimize=True) return output.getvalue(), "image/webp", image.width, image.height def store_uploaded_image(raw_bytes: bytes, original_name: str, scope: str, scope_id: str) -> Dict[str, Any]: file_id = str(uuid.uuid4()) ext = ".webp" folder = os.path.join(UPLOAD_DIR, scope, scope_id) os.makedirs(folder, exist_ok=True) file_path = os.path.join(folder, f"{file_id}{ext}") compressed_bytes, mime_type, width, height = compress_image_bytes(raw_bytes) with open(file_path, "wb") as file_handle: file_handle.write(compressed_bytes) url_path = f"/uploads/{scope}/{scope_id}/{file_id}{ext}" asset = { "id": file_id, "scope": scope, "scope_id": scope_id, "original_name": original_name, "stored_name": f"{file_id}{ext}", "url": url_path, "mime_type": mime_type, "width": width, "height": height, "size_bytes": len(compressed_bytes), "created_at": get_vn_now().isoformat(), } image_assets_collection.insert_one(asset) return asset def get_team_chat_context(team_id: str, issue_anchor_id: Optional[str], window: int = 7) -> List[Dict[str, Any]]: history = list(team_chat_collection.find({"team_id": team_id}, {"_id": 0}).sort("timestamp", 1)) if not history: return [] if not issue_anchor_id: return history[-(window * 2 + 1) :] idx = next((i for i, msg 
in enumerate(history) if msg.get("id") == issue_anchor_id), None) if idx is None: return history[-(window * 2 + 1) :] start = max(0, idx - window) end = min(len(history), idx + window + 1) return history[start:end] def get_selected_team_messages(team_id: str, selected_ids: List[str], project_id: Optional[str] = None) -> List[Dict[str, Any]]: if not selected_ids: return [] query: Dict[str, Any] = {"team_id": team_id, "id": {"$in": selected_ids}} if project_id: query["project_id"] = project_id messages = list(team_chat_collection.find(query, {"_id": 0}).sort("timestamp", 1)) by_id = {message.get("id"): message for message in messages} ordered = [by_id[msg_id] for msg_id in selected_ids if msg_id in by_id] return ordered def _safe_decode_text(raw_bytes: bytes, fallback_name: str = "") -> str: encodings = ["utf-8", "utf-8-sig", "cp1258", "latin-1"] for enc in encodings: try: return raw_bytes.decode(enc) except Exception: continue return raw_bytes.decode("utf-8", errors="ignore") def _split_large_node_content(tree: Dict[str, Any], max_chars: int = TEAM_DOC_NODE_CONTENT_LIMIT) -> Dict[str, Any]: nodes = tree.get("nodes") or [] if not nodes: return tree by_id: Dict[str, Dict[str, Any]] = {str(node.get("id") or ""): node for node in nodes if node.get("id")} root_id = str(tree.get("root_id") or "root") next_split_idx = 0 normalized_nodes: List[Dict[str, Any]] = [] for node in nodes: node_id = str(node.get("id") or "") clone = dict(node) clone["children"] = list(node.get("children") or []) content = str(clone.get("content") or "") if node_id == root_id or len(content) <= max_chars: normalized_nodes.append(clone) continue parts = [content[i : i + max_chars].strip() for i in range(0, len(content), max_chars)] parts = [part for part in parts if part] if not parts: clone["content"] = "" normalized_nodes.append(clone) continue clone["content"] = parts[0] clone["summary"] = _clip_text(parts[0], 280) clone["contextual_summary"] = _clip_text(parts[0], 260) parent_level = 
int(clone.get("level") or 1) continuation_children = list(clone.get("children") or []) title = str(clone.get("title") or "Section") normalized_nodes.append(clone) for idx, part in enumerate(parts[1:], start=2): next_split_idx += 1 split_id = f"{node_id}_part_{next_split_idx}" while split_id in by_id: next_split_idx += 1 split_id = f"{node_id}_part_{next_split_idx}" child_node = { "id": split_id, "parent_id": node_id, "level": parent_level + 1, "title": f"{title} (phần {idx})", "summary": _clip_text(part, 280), "contextual_summary": _clip_text(part, 260), "scope": f"Nội dung tiếp theo của: {title}", "content": part, "children": [], } continuation_children.append(split_id) by_id[split_id] = child_node normalized_nodes.append(child_node) clone["children"] = continuation_children return { "root_id": root_id, "nodes": normalized_nodes, "total_nodes": len(normalized_nodes), } def _chunk_document_nodes(nodes: List[Dict[str, Any]], chunk_size: int = TEAM_DOC_NODE_CHUNK_SIZE) -> List[List[Dict[str, Any]]]: if not nodes: return [] size = max(1, int(chunk_size)) return [nodes[i : i + size] for i in range(0, len(nodes), size)] def _store_team_document_chunks(doc_id: str, nodes: List[Dict[str, Any]], created_at: str) -> None: chunks = _chunk_document_nodes(nodes) if not chunks: return chunk_docs = [ { "id": str(uuid.uuid4()), "doc_id": doc_id, "chunk_index": idx, "nodes": chunk_nodes, "created_at": created_at, "updated_at": created_at, } for idx, chunk_nodes in enumerate(chunks) ] team_doc_chunks_collection.insert_many(chunk_docs) def _load_team_document_nodes(doc_id: str) -> List[Dict[str, Any]]: normalized_doc_id = str(doc_id or "").strip() if not normalized_doc_id: return [] projection = {"_id": 0, "chunk_index": 1, "nodes": 1} try: # Prefer index-backed sort. If index is not ready yet, fallback below. 
cursor = team_doc_chunks_collection.find( {"doc_id": normalized_doc_id}, projection, ).hint("idx_team_doc_chunks_doc_id_chunk").sort("chunk_index", 1) if hasattr(cursor, "allow_disk_use"): cursor = cursor.allow_disk_use(True) rows = list(cursor) except (OperationFailure, ValueError): rows = list( team_doc_chunks_collection.find( {"doc_id": normalized_doc_id}, projection, ) ) rows.sort(key=lambda row: int(row.get("chunk_index", 0))) nodes: List[Dict[str, Any]] = [] for row in rows: chunk_nodes = row.get("nodes") if isinstance(chunk_nodes, list): nodes.extend(chunk_nodes) return nodes def _resolve_document_tree(doc: Dict[str, Any]) -> Dict[str, Any]: inline_tree = doc.get("tree") if isinstance(doc.get("tree"), dict) else None if inline_tree and isinstance(inline_tree.get("nodes"), list): total_nodes = inline_tree.get("total_nodes") if not isinstance(total_nodes, int): inline_tree["total_nodes"] = len(inline_tree.get("nodes") or []) return inline_tree tree_meta = doc.get("tree_meta") if isinstance(doc.get("tree_meta"), dict) else {} nodes = _load_team_document_nodes(str(doc.get("id") or "")) root_id = str(tree_meta.get("root_id") or "root") total_nodes = int(tree_meta.get("total_nodes") or len(nodes)) return { "root_id": root_id, "nodes": nodes, "total_nodes": total_nodes, } def build_document_tree(text: str) -> Dict[str, Any]: lines = (text or "").splitlines() nodes: List[Dict[str, Any]] = [] root_id = "root" nodes.append( { "id": root_id, "parent_id": None, "level": 0, "title": "Document Root", "summary": "Điểm bắt đầu của tài liệu", "scope": "Câu hỏi tổng quan về toàn bộ tài liệu", "content": "", "children": [], } ) stack: List[Dict[str, Any]] = [nodes[0]] node_counter = 0 def _create_node(title: str, level: int) -> Dict[str, Any]: nonlocal node_counter node_counter += 1 while stack and stack[-1]["level"] >= level: stack.pop() parent = stack[-1] if stack else nodes[0] node = { "id": f"sec_{node_counter}", "parent_id": parent["id"], "level": level, "title": 
title.strip() or f"Section {node_counter}", "summary": "", "contextual_summary": "", "scope": "", "content": "", "children": [], } parent["children"].append(node["id"]) nodes.append(node) stack.append(node) return node current = _create_node("Nội dung chính", 1) for raw_line in lines: line = raw_line.rstrip() if not line.strip(): current["content"] += "\n" continue md_heading = re.match(r"^(#{1,6})\s+(.+)$", line) if md_heading: current = _create_node(md_heading.group(2), len(md_heading.group(1))) continue num_heading = re.match(r"^(\d+(?:\.\d+)*)\s+(.+)$", line) if num_heading and len(num_heading.group(1).split(".")) <= 4: current = _create_node(num_heading.group(2), len(num_heading.group(1).split("."))) continue current["content"] += line + "\n" for node in nodes: if node["id"] == root_id: continue paragraphs = [part.strip() for part in node["content"].split("\n") if part.strip()] summary = paragraphs[0] if paragraphs else f"Mục {node['title']}" node["summary"] = summary[:280] node["contextual_summary"] = node["summary"] node["scope"] = f"Dùng để trả lời câu hỏi liên quan tới: {node['title']}" node["content"] = node["content"].strip() return { "root_id": root_id, "nodes": nodes, "total_nodes": len(nodes), } def _contextualize_document_tree(file_name: str, text: str, tree: Dict[str, Any]) -> Dict[str, Any]: nodes = tree.get("nodes") or [] target_nodes = [node for node in nodes if node.get("id") and node.get("id") != tree.get("root_id")] if not target_nodes: return { "global_summary": f"Tài liệu {file_name}", "context_coverage": 0, "context_source": "fallback", } compact_nodes = [ { "id": str(node.get("id") or ""), "title": str(node.get("title") or ""), "summary": _clip_text(str(node.get("summary") or ""), 180), "content_snippet": _clip_text(str(node.get("content") or ""), 180), } for node in target_nodes[:80] ] global_context = "" contextual_map: Dict[str, str] = {} try: payload = { "file_name": file_name, "document_snippet": _clip_text(text, 1600), "nodes": 
compact_nodes, "instruction": ( "Sinh context retrieval cho từng node để tăng độ chính xác tìm kiếm. " "Mỗi context 1 câu ngắn, có thực thể/chủ đề cụ thể, không bịa thêm dữ kiện." ), } response = _nvidia_chat_completion( system_prompt=( "Trả về JSON thuần: " "{\"global_summary\":\"...\",\"nodes\":[{\"id\":\"sec_x\",\"context\":\"...\"}]}." ), user_prompt=json.dumps(payload, ensure_ascii=False), model=TEAM_AGENT_MODEL, temperature=0.0, max_tokens=1000, ) parsed = _extract_json_object(response) global_context = str(parsed.get("global_summary") or "").strip() node_items = parsed.get("nodes") if isinstance(parsed.get("nodes"), list) else [] for item in node_items: if not isinstance(item, dict): continue node_id = str(item.get("id") or "").strip() context = str(item.get("context") or "").strip() if node_id and context: contextual_map[node_id] = _clip_text(context, 260) except Exception: global_context = "" if not global_context: first_lines = [line.strip() for line in (text or "").splitlines() if line.strip()] global_context = _clip_text(" ".join(first_lines[:3]) or f"Tài liệu {file_name}", 260) applied = 0 for node in target_nodes: node_id = str(node.get("id") or "") node_context = contextual_map.get(node_id) if not node_context: node_context = _clip_text( f"{global_context}. 
Mục {node.get('title')}: {node.get('summary') or 'Nội dung liên quan'}", 260, ) else: applied += 1 node["contextual_summary"] = node_context root_id = tree.get("root_id") for node in nodes: if node.get("id") == root_id: node["summary"] = _clip_text(global_context, 280) node["contextual_summary"] = node["summary"] node["scope"] = "Tóm tắt toàn bộ tài liệu cho truy vấn tổng quan" break return { "global_summary": global_context, "context_coverage": round(applied / max(1, len(target_nodes)), 4), "context_source": "llm_contextualizer" if contextual_map else "fallback", } def save_team_document( team_id: str, project_id: Optional[str], uploader_id: str, file_name: str, mime_type: str, raw_bytes: bytes, ) -> Dict[str, Any]: text = _safe_decode_text(raw_bytes, file_name) tree = _split_large_node_content(build_document_tree(text)) contextual_meta = _contextualize_document_tree(file_name=file_name, text=text, tree=tree) doc_id = str(uuid.uuid4()) now_iso = get_vn_now().isoformat() doc = { "id": doc_id, "team_id": team_id, "project_id": project_id, "name": file_name, "mime_type": mime_type, "uploader_id": uploader_id, "tree_meta": { "root_id": tree.get("root_id", "root"), "total_nodes": int(tree.get("total_nodes") or len(tree.get("nodes") or [])), "storage": "chunked", "chunk_size": TEAM_DOC_NODE_CHUNK_SIZE, }, "text_meta": { "raw_bytes": len(raw_bytes), "decoded_chars": len(text), "storage": "transient", }, "contextual_global_summary": contextual_meta.get("global_summary", ""), "contextual_meta": contextual_meta, "created_at": now_iso, "updated_at": now_iso, } try: team_docs_collection.insert_one(doc) _store_team_document_chunks(doc_id=doc_id, nodes=tree.get("nodes") or [], created_at=now_iso) except Exception: team_docs_collection.delete_many({"id": doc_id}) team_doc_chunks_collection.delete_many({"doc_id": doc_id}) raise out_doc = dict(doc) out_doc["tree"] = { "root_id": tree.get("root_id", "root"), "total_nodes": tree.get("total_nodes", len(tree.get("nodes") or [])), } 
return out_doc def list_team_documents(team_id: str, project_id: Optional[str] = None) -> List[Dict[str, Any]]: query: Dict[str, Any] = {"team_id": team_id} if project_id: query["project_id"] = project_id docs = list( team_docs_collection.find( query, { "_id": 0, "id": 1, "team_id": 1, "project_id": 1, "name": 1, "mime_type": 1, "uploader_id": 1, "created_at": 1, "updated_at": 1, "tree": 1, "tree_meta": 1, }, ).sort("updated_at", -1) ) for doc in docs: tree = _resolve_document_tree(doc) nodes = tree.get("nodes") or [] doc["tree"] = { "root_id": tree.get("root_id", "root"), "total_nodes": int(tree.get("total_nodes") or len(nodes)), } doc["node_catalog"] = [ { "id": node.get("id"), "parent_id": node.get("parent_id"), "level": node.get("level"), "title": node.get("title"), "summary": node.get("summary"), "contextual_summary": node.get("contextual_summary"), "path": _build_node_path(tree, str(node.get("id") or "")).get("node_path", ""), "path_titles": _build_node_path(tree, str(node.get("id") or "")).get("node_path_titles", []), "path_ids": _build_node_path(tree, str(node.get("id") or "")).get("node_path_ids", []), } for node in nodes if node.get("id") ] return docs def get_team_documents_by_ids(team_id: str, doc_ids: List[str], project_id: Optional[str] = None) -> List[Dict[str, Any]]: if not doc_ids: return [] query: Dict[str, Any] = {"team_id": team_id, "id": {"$in": doc_ids}} if project_id: query["project_id"] = project_id docs = list( team_docs_collection.find( query, { "_id": 0, "id": 1, "team_id": 1, "project_id": 1, "name": 1, "mime_type": 1, "uploader_id": 1, "contextual_global_summary": 1, "contextual_meta": 1, "created_at": 1, "updated_at": 1, "tree": 1, "tree_meta": 1, }, ) ) for doc in docs: doc["tree"] = _resolve_document_tree(doc) by_id = {doc.get("id"): doc for doc in docs} ordered = [by_id[doc_id] for doc_id in doc_ids if doc_id in by_id] return ordered def build_requirement_node_options_from_documents(documents: List[Dict[str, Any]], limit: int = 8) 
-> List[Dict[str, Any]]: options: List[Dict[str, Any]] = [] seen_ids: set[str] = set() for doc in documents: tree = doc.get("tree") or {} for node in tree.get("nodes") or []: node_id = str(node.get("id") or "").strip() if not node_id or node_id in seen_ids: continue path = _build_node_path(tree, node_id) node_title = str(node.get("title") or "").strip() node_path = str(path.get("node_path") or "").strip() if not node_title and not node_path: continue options.append( { "node_id": node_id, "node_title": node_title or node_id, "node_path": node_path or node_title or node_id, "node_path_titles": path.get("node_path_titles", []), "node_path_ids": path.get("node_path_ids", []), "node_depth": path.get("node_depth", 0), "document_id": doc.get("id"), "document_name": doc.get("name"), } ) seen_ids.add(node_id) if len(options) >= limit: return options return options def _nvidia_chat_completion(system_prompt: str, user_prompt: str, model: Optional[str] = None, temperature: float = 0.1, max_tokens: int = 1200) -> str: if not NVIDIA_KEY: raise HTTPException(status_code=500, detail="Missing NVIDIA_KEY for team agent") payload = { "model": model or TEAM_AGENT_MODEL, "messages": [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ], "temperature": temperature, "max_tokens": max_tokens, } req = urllib.request.Request( f"{NVIDIA_API_BASE.rstrip('/')}/chat/completions", data=json.dumps(payload).encode("utf-8"), headers={ "Authorization": f"Bearer {NVIDIA_KEY}", "Content-Type": "application/json", "Accept": "application/json", }, method="POST", ) try: with urllib.request.urlopen(req, timeout=90) as response: body = json.loads(response.read().decode("utf-8")) return body.get("choices", [{}])[0].get("message", {}).get("content", "") except urllib.error.HTTPError as exc: detail = exc.read().decode("utf-8", errors="ignore") raise HTTPException(status_code=502, detail=f"NVIDIA API error: {detail[:400]}") except Exception as exc: raise 
HTTPException(status_code=502, detail=f"NVIDIA API request failed: {str(exc)}") def _extract_json_object(text: str) -> Dict[str, Any]: if not text: return {} candidates: List[str] = [text.strip()] fenced = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE) if fenced: candidates.insert(0, fenced.group(1).strip()) obj = re.search(r"\{[\s\S]*\}", text) if obj: candidates.insert(0, obj.group()) for cand in candidates: try: parsed = json.loads(cand) if isinstance(parsed, dict): return parsed except Exception: continue return {} def _normalize_search_text(text: str) -> str: lowered = (text or "").lower() return re.sub(r"[^\w\s]", " ", lowered, flags=re.UNICODE) def _tokenize_search_text(text: str) -> List[str]: normalized = _normalize_search_text(text) return [token for token in normalized.split() if token] def _clip_text(text: str, max_len: int = 420) -> str: content = (text or "").strip() if len(content) <= max_len: return content return f"{content[: max_len - 3].rstrip()}..." def _build_node_path(tree: Dict[str, Any], node_id: str) -> Dict[str, Any]: nodes = tree.get("nodes") or [] by_id = {str(node.get("id") or ""): node for node in nodes if node.get("id")} root_id = str(tree.get("root_id") or "root") current_id = str(node_id or "").strip() path_nodes: List[Dict[str, Any]] = [] while current_id and current_id in by_id: node = by_id[current_id] path_nodes.append(node) parent_id = str(node.get("parent_id") or "").strip() if not parent_id or parent_id == current_id: break current_id = parent_id path_nodes.reverse() filtered_nodes = [ node for node in path_nodes if str(node.get("id") or "").strip() != root_id and str(node.get("title") or "").strip() != "Document Root" ] titles = [str(node.get("title") or "").strip() for node in filtered_nodes if str(node.get("title") or "").strip()] ids = [str(node.get("id") or "").strip() for node in filtered_nodes if str(node.get("id") or "").strip()] return { "node_path_titles": titles, "node_path_ids": ids, "node_path": 
" > ".join(titles), "node_depth": max(0, len(titles) - 1), "node_title": titles[-1] if titles else "", "parent_node_title": titles[-2] if len(titles) > 1 else "", } def _format_requirement_node(node_ref: Dict[str, Any]) -> str: node_path = str(node_ref.get("node_path") or "").strip() node_title = str(node_ref.get("node_title") or "").strip() document_name = str(node_ref.get("document_name") or "").strip() if node_path and document_name: return f"{document_name} > {node_path}" if node_path: return node_path return node_title or document_name or "" def _build_requirement_node_reference(sections: List[Dict[str, Any]]) -> Dict[str, Any]: if not sections: return {} top = sections[0] reference = { "document_id": top.get("document_id"), "document_name": top.get("document_name"), "section_id": top.get("section_id"), "section_title": top.get("section_title"), "node_id": top.get("section_id"), "node_title": top.get("section_title"), "node_path": top.get("section_path"), "node_path_titles": top.get("section_path_titles", []), "node_path_ids": top.get("section_path_ids", []), "node_depth": top.get("section_depth", 0), "retrieval_source": top.get("retrieval_source"), "retrieval_score": top.get("retrieval_score"), } reference["node_display"] = _format_requirement_node(reference) return reference def resolve_requirement_node_reference_from_documents(documents: List[Dict[str, Any]], preferred_node_id: Optional[str]) -> Dict[str, Any]: target_id = str(preferred_node_id or "").strip() if not target_id: return {} for doc in documents: tree = doc.get("tree") or {} nodes = tree.get("nodes") or [] by_id = {str(node.get("id") or ""): node for node in nodes if node.get("id")} if target_id not in by_id: continue node = by_id[target_id] path = _build_node_path(tree, target_id) return { "document_id": doc.get("id"), "document_name": doc.get("name"), "node_id": target_id, "node_title": node.get("title"), "node_path": path.get("node_path", ""), "node_path_titles": path.get("node_path_titles", 
[]), "node_path_ids": path.get("node_path_ids", []), "node_depth": path.get("node_depth", 0), "node_display": _format_requirement_node({ "document_name": doc.get("name"), "node_path": path.get("node_path", ""), "node_title": node.get("title"), }), "source": "preferred_node", } return {} def _node_to_search_blob(node: Dict[str, Any]) -> str: fields = [ str(node.get("title") or ""), str(node.get("summary") or ""), str(node.get("scope") or ""), str(node.get("contextual_summary") or ""), str(node.get("content") or ""), ] return "\n".join(field for field in fields if field) def _prepare_bm25f_corpus(documents: List[Dict[str, Any]]) -> Dict[str, Any]: field_names = ["title", "summary", "contextual_summary", "content"] rows: List[Dict[str, Any]] = [] for doc in documents: tree = doc.get("tree") or {} for node in tree.get("nodes") or []: node_id = str(node.get("id") or "").strip() if not node_id: continue field_tokens: Dict[str, List[str]] = {} for field in field_names: field_tokens[field] = _tokenize_search_text(str(node.get(field) or "")) rows.append( { "document_id": doc.get("id"), "document_name": doc.get("name"), "node": node, "field_tokens": field_tokens, } ) total_docs = max(1, len(rows)) avg_field_len: Dict[str, float] = {} doc_freq: Dict[str, Dict[str, int]] = {field: {} for field in field_names} for field in field_names: lengths = [len(row["field_tokens"][field]) for row in rows] avg_field_len[field] = (sum(lengths) / len(lengths)) if lengths else 1.0 for row in rows: unique_terms = set(row["field_tokens"][field]) for term in unique_terms: doc_freq[field][term] = doc_freq[field].get(term, 0) + 1 return { "rows": rows, "field_names": field_names, "avg_field_len": avg_field_len, "doc_freq": doc_freq, "total_docs": total_docs, } def _bm25f_score_row(query_tokens: List[str], row: Dict[str, Any], corpus: Dict[str, Any]) -> float: if not query_tokens: return 0.0 field_weights = { "title": 2.2, "summary": 1.4, "contextual_summary": 1.8, "content": 1.0, } k1 = 1.5 b = 
0.75 total_docs = int(corpus.get("total_docs", 1)) avg_field_len = corpus.get("avg_field_len", {}) doc_freq = corpus.get("doc_freq", {}) score = 0.0 for term in query_tokens: term_score = 0.0 max_df = 0 for field in ["title", "summary", "contextual_summary", "content"]: tokens = row["field_tokens"][field] tf = tokens.count(term) if tf <= 0: continue field_len = len(tokens) avg_len = max(1e-6, float(avg_field_len.get(field, 1.0))) norm = (1 - b) + b * (field_len / avg_len) tf_norm = (tf * (k1 + 1)) / (tf + (k1 * norm)) term_score += field_weights[field] * tf_norm df_field = int(doc_freq.get(field, {}).get(term, 0)) max_df = max(max_df, df_field) if term_score <= 0: continue idf = math.log(1 + (total_docs - max_df + 0.5) / (max_df + 0.5)) if max_df > 0 else 0.0 score += term_score * idf return score def _collect_bm25f_candidates(query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]: query_tokens = _tokenize_search_text(query) if not query_tokens: return [] corpus = _prepare_bm25f_corpus(documents) rows = corpus.get("rows", []) candidates: List[Dict[str, Any]] = [] for row in rows: score = _bm25f_score_row(query_tokens, row, corpus) if score <= 0: continue candidates.append( { "document_id": row.get("document_id"), "document_name": row.get("document_name"), "node": row.get("node") or {}, "score": score, } ) candidates.sort(key=lambda item: float(item.get("score", 0.0)), reverse=True) return candidates[:60] def _generate_hyde_variant(query: str, selected_messages: Optional[List[Dict[str, Any]]] = None) -> str: payload = { "query": query, "selected_messages": (selected_messages or [])[-6:], "instruction": ( "Sinh một đoạn giả định ngắn (2-4 câu) mô tả câu trả lời lý tưởng để phục vụ retrieval. " "Giữ keyword/thuật ngữ kỹ thuật quan trọng, không thêm lan man." ), } text = _nvidia_chat_completion( system_prompt=( "Bạn là bộ tạo HyDE query cho retrieval. " "Trả về văn bản thuần duy nhất, không markdown, không JSON." 
), user_prompt=json.dumps(payload, ensure_ascii=False), model=TEAM_AGENT_MODEL, temperature=0.0, max_tokens=220, ) return _clip_text(text.strip(), 500) def _expand_query_variants( query: str, max_variants: int = 5, selected_messages: Optional[List[Dict[str, Any]]] = None, use_hyde: bool = True, ) -> Dict[str, Any]: base = (query or "").strip() if not base: return {"variants": [], "hyde_variant": ""} variants: List[str] = [base] hyde_variant = "" try: payload = { "query": base, "instruction": ( "Sinh tối đa 3 truy vấn thay thế để tăng recall tài liệu kỹ thuật. " "Giữ nguyên ý nghĩa, thêm biến thể keyword/chuẩn thuật ngữ." ), } response = _nvidia_chat_completion( system_prompt=( "Trả về JSON thuần: {\"variants\":[\"...\"]}. " "Không thêm giải thích." ), user_prompt=json.dumps(payload, ensure_ascii=False), model=TEAM_AGENT_MODEL, temperature=0.0, max_tokens=220, ) parsed = _extract_json_object(response) llm_variants = parsed.get("variants") if isinstance(parsed.get("variants"), list) else [] for item in llm_variants: text = str(item or "").strip() if text and text.lower() not in {v.lower() for v in variants}: variants.append(text) if len(variants) >= max_variants: break except Exception: pass # Deterministic backup variant using token dedupe. 
tokens = _tokenize_search_text(base) if tokens: keyword_variant = " ".join(sorted(set(tokens), key=tokens.index)) if keyword_variant and keyword_variant.lower() not in {v.lower() for v in variants}: variants.append(keyword_variant) if use_hyde and len(variants) < max_variants: try: hyde_variant = _generate_hyde_variant(base, selected_messages=selected_messages) if hyde_variant and hyde_variant.lower() not in {v.lower() for v in variants}: variants.append(hyde_variant) except Exception: hyde_variant = "" return {"variants": variants[:max_variants], "hyde_variant": hyde_variant} def _collect_multi_query_candidates( query: str, documents: List[Dict[str, Any]], selected_messages: Optional[List[Dict[str, Any]]] = None, ) -> Dict[str, Any]: variant_payload = _expand_query_variants( query, max_variants=5, selected_messages=selected_messages, use_hyde=True, ) variants = variant_payload.get("variants", []) if isinstance(variant_payload, dict) else [] if not variants: return {"variants": [query], "candidates": [], "hyde_variant": ""} k = 60.0 merged: Dict[str, Dict[str, Any]] = {} for variant in variants: candidates = _collect_bm25f_candidates(variant, documents) for rank, item in enumerate(candidates): node = item.get("node") or {} doc_id = str(item.get("document_id") or "") node_id = str(node.get("id") or "") if not doc_id or not node_id: continue key = f"{doc_id}::{node_id}" rrf = 1.0 / (k + rank + 1) score = float(item.get("score", 0.0)) if key not in merged: merged[key] = { "document_id": item.get("document_id"), "document_name": item.get("document_name"), "node": node, "score": 0.0, "max_lexical": 0.0, "query_hits": [], } merged[key]["score"] += rrf merged[key]["max_lexical"] = max(float(merged[key]["max_lexical"]), score) if variant not in merged[key]["query_hits"]: merged[key]["query_hits"].append(variant) fused = list(merged.values()) for item in fused: item["score"] = float(item.get("score", 0.0)) + float(item.get("max_lexical", 0.0)) * 0.45 fused.sort(key=lambda 
item: float(item.get("score", 0.0)), reverse=True) return { "variants": variants, "candidates": fused[:50], "hyde_variant": variant_payload.get("hyde_variant", "") if isinstance(variant_payload, dict) else "", } def _llm_rerank_document_candidates(query: str, candidates: List[Dict[str, Any]], top_k: int = 8) -> List[str]: if not candidates: return [] payload = { "query": query, "candidates": [ { "id": str(item["node"].get("id") or ""), "document_id": item.get("document_id"), "document_name": item.get("document_name"), "title": item["node"].get("title"), "summary": item["node"].get("summary"), "snippet": _clip_text(item["node"].get("content", ""), 220), "lexical_score": round(float(item.get("score", 0.0)), 4), } for item in candidates[:18] ], "top_k": max(1, min(top_k, 10)), } rerank_text = _nvidia_chat_completion( system_prompt=( "Bạn là bộ xếp hạng bằng chứng tài liệu. " "Trả về JSON thuần: {\"selected_ids\": [\"node_id\"], \"reason\": \"...\"}. " "Chọn các node liên quan nhất để trả lời câu hỏi." ), user_prompt=json.dumps(payload, ensure_ascii=False), model=TEAM_AGENT_MODEL, temperature=0.0, max_tokens=260, ) rerank_json = _extract_json_object(rerank_text) selected_ids_raw = rerank_json.get("selected_ids") if isinstance(selected_ids_raw, list): selected_ids = [str(item).strip() for item in selected_ids_raw if str(item).strip()] if selected_ids: return selected_ids[:top_k] return [ str(item["node"].get("id")) for item in candidates[:top_k] if item["node"].get("id") ] def retrieve_document_context_with_tree( query: str, documents: List[Dict[str, Any]], selected_messages: Optional[List[Dict[str, Any]]] = None, ) -> Dict[str, Any]: if not documents: return {"sections": [], "citations": []} picked_sections: List[Dict[str, Any]] = [] citations: List[Dict[str, Any]] = [] seen_node_ids: set[str] = set() # Layer 1: tree navigation keeps hierarchical intent and ensures at least one anchor per doc. 
tree_picks: List[Dict[str, Any]] = [] for doc in documents: tree = doc.get("tree") or {} nodes = tree.get("nodes") or [] by_id = {node.get("id"): node for node in nodes if node.get("id")} children: Dict[str, List[str]] = {} for node in nodes: parent_id = node.get("parent_id") children.setdefault(parent_id or "", []).append(node.get("id")) current_parent = tree.get("root_id", "root") selected_node: Optional[Dict[str, Any]] = None for _ in range(4): candidate_ids = children.get(current_parent, []) candidate_nodes = [by_id[cid] for cid in candidate_ids if cid in by_id] if not candidate_nodes: break nav_prompt = { "query": query, "document_name": doc.get("name"), "candidates": [ { "id": node.get("id"), "title": node.get("title"), "summary": node.get("summary"), "scope": node.get("scope"), "level": node.get("level"), } for node in candidate_nodes ], "instruction": "Chọn node phù hợp nhất để đi tiếp. Nếu đã đủ thì chọn stop.", } nav_text = _nvidia_chat_completion( system_prompt="Trả về JSON: {\"selected_id\": \"...\" hoặc \"stop\", \"reason\": \"...\"}", user_prompt=json.dumps(nav_prompt, ensure_ascii=False), model=TEAM_AGENT_MODEL, temperature=0.0, max_tokens=220, ) nav_json = _extract_json_object(nav_text) selected_id = str(nav_json.get("selected_id") or "").strip() if not selected_id or selected_id.lower() == "stop": break if selected_id not in by_id: break selected_node = by_id[selected_id] current_parent = selected_id if selected_node is None: top_nodes = [by_id[cid] for cid in children.get(current_parent, []) if cid in by_id] selected_node = top_nodes[0] if top_nodes else None if not selected_node: continue selected_node_id = str(selected_node.get("id") or "").strip() node_path = _build_node_path(tree, selected_node_id) path_titles = node_path.get("node_path_titles", []) path_ids = node_path.get("node_path_ids", []) tree_picks.append( { "document_id": doc.get("id"), "document_name": doc.get("name"), "node": selected_node, "score": 1.0, "source": "tree_nav", } ) # 
Layer 2: multi-query lexical retrieval broadens recall. fused_results = _collect_multi_query_candidates(query, documents, selected_messages=selected_messages) lexical_candidates = fused_results.get("candidates", []) if isinstance(fused_results, dict) else [] query_variants = fused_results.get("variants", [query]) if isinstance(fused_results, dict) else [query] hyde_variant = fused_results.get("hyde_variant", "") if isinstance(fused_results, dict) else "" # Layer 3: LLM reranking improves precision on top lexical candidates. reranked_ids = set(_llm_rerank_document_candidates(query, lexical_candidates, top_k=8)) merged_candidates: List[Dict[str, Any]] = [] merged_candidates.extend(tree_picks) for item in lexical_candidates: node_id = str(item["node"].get("id") or "") if not node_id: continue score = float(item.get("score", 0.0)) if node_id in reranked_ids: score += 1.0 merged_candidates.append( { "document_id": item.get("document_id"), "document_name": item.get("document_name"), "node": item.get("node") or {}, "score": score, "source": "hybrid_rerank" if node_id in reranked_ids else "lexical", } ) merged_candidates.sort(key=lambda item: float(item.get("score", 0.0)), reverse=True) for item in merged_candidates: node = item.get("node") or {} section_id = str(node.get("id") or "").strip() if not section_id or section_id in seen_node_ids: continue seen_node_ids.add(section_id) section_text = str(node.get("content") or "") picked_sections.append( { "document_id": item.get("document_id"), "document_name": item.get("document_name"), "section_id": section_id, "section_title": node.get("title"), "section_content": section_text, "section_summary": node.get("summary", ""), "section_context": node.get("contextual_summary", ""), "section_path": node_path.get("node_path", ""), "section_path_titles": path_titles, "section_path_ids": path_ids, "section_depth": node_path.get("node_depth", 0), "retrieval_score": round(float(item.get("score", 0.0)), 4), "retrieval_source": 
item.get("source"), "query_hit_count": len(item.get("query_hits", [])) if isinstance(item.get("query_hits"), list) else 0, } ) citations.append( { "document_id": item.get("document_id"), "document_name": item.get("document_name"), "section_id": section_id, "section_title": node.get("title"), "section_path": node_path.get("node_path", ""), "section_path_titles": path_titles, "section_path_ids": path_ids, "source": item.get("source"), } ) if len(picked_sections) >= 10: break return { "sections": picked_sections, "citations": citations, "retrieval_meta": { "tree_pick_count": len(tree_picks), "lexical_candidate_count": len(lexical_candidates), "rerank_pick_count": len(reranked_ids), "query_variants": query_variants, "hyde_used": bool(hyde_variant), }, "requirement_node_reference": _build_requirement_node_reference(picked_sections), } def _evaluate_grounding_confidence( answer: str, citations: List[Dict[str, Any]], sections: List[Dict[str, Any]], retrieval_meta: Dict[str, Any], llm_confidence: str, ) -> Dict[str, Any]: section_ids = { str(section.get("section_id") or "").strip() for section in sections if str(section.get("section_id") or "").strip() } citation_match = 0 for item in citations: if not isinstance(item, dict): continue section_id = str(item.get("section_id") or "").strip() if section_id and section_id in section_ids: citation_match += 1 llm_map = {"low": 0.35, "medium": 0.65, "high": 0.9} llm_score = llm_map.get(llm_confidence, 0.6) cited_ratio = citation_match / max(1, len(citations) if citations else 1) retrieval_strength = min(float(retrieval_meta.get("rerank_pick_count", 0)) / 6.0, 1.0) section_strength = min(len(sections) / 8.0, 1.0) answer_len_strength = min(len((answer or "").split()) / 80.0, 1.0) score = ( llm_score * 0.40 + cited_ratio * 0.25 + retrieval_strength * 0.20 + section_strength * 0.10 + answer_len_strength * 0.05 ) score = max(0.0, min(score, 1.0)) if score >= 0.78: label = "high" elif score >= 0.56: label = "medium" else: label = "low" 
return { "confidence": label, "confidence_score": round(score, 4), "needs_clarification": label == "low", } def build_document_grounded_answer( query: str, selected_messages: List[Dict[str, Any]], doc_context: Dict[str, Any], qa_memory: Optional[str] = None, ) -> Dict[str, Any]: sections = doc_context.get("sections") if isinstance(doc_context, dict) else [] citations = doc_context.get("citations") if isinstance(doc_context, dict) else [] retrieval_meta = doc_context.get("retrieval_meta") if isinstance(doc_context, dict) else {} if not isinstance(sections, list) or not sections: return { "answer": "", "citations": [], "confidence": "low", "confidence_score": 0.0, "needs_clarification": True, "clarifying_question": "Bạn có thể chọn thêm tài liệu hoặc section liên quan để mình trả lời chính xác hơn không?", } payload = { "query": query, "qa_memory": (qa_memory or "").strip(), "selected_messages": selected_messages[-8:] if isinstance(selected_messages, list) else [], "evidence_sections": [ { "document_id": section.get("document_id"), "document_name": section.get("document_name"), "section_id": section.get("section_id"), "section_title": section.get("section_title"), "section_path": section.get("section_path"), "summary": _clip_text(str(section.get("section_summary") or ""), 280), "context": _clip_text(str(section.get("section_context") or ""), 240), "content": _clip_text(str(section.get("section_content") or ""), 520), } for section in sections[:8] ], "citations": citations[:8] if isinstance(citations, list) else [], } answer_text = _nvidia_chat_completion( system_prompt=( "Bạn là trợ lý phân tích tài liệu dạng NotebookLM-style cho team chat. " "Nhiệm vụ: trả lời trực tiếp câu hỏi user dựa trên evidence đã cho, không bịa, không suy diễn vượt dữ liệu. " "Nếu có qa_memory thì dùng như ngữ cảnh ổn định cho các lượt QA tiếp theo, nhưng không được vượt quá evidence hiện có. 
" "Trả về JSON thuần: {\"answer\":\"...\",\"citations\":[{\"document_id\":\"...\",\"section_id\":\"...\"}],\"confidence\":\"high|medium|low\"}." ), user_prompt=json.dumps(payload, ensure_ascii=False), model=TEAM_AGENT_MODEL, temperature=0.1, max_tokens=700, ) answer_json = _extract_json_object(answer_text) answer = str(answer_json.get("answer") or "").strip() out_citations = answer_json.get("citations") if isinstance(answer_json.get("citations"), list) else [] confidence = str(answer_json.get("confidence") or "medium").strip().lower() if confidence not in {"high", "medium", "low"}: confidence = "medium" if not answer: top = sections[0] fallback_title = str(top.get("section_title") or "nội dung liên quan").strip() fallback_doc = str(top.get("document_name") or "tài liệu").strip() answer = f"Theo {fallback_doc}, phần '{fallback_title}' là dữ liệu liên quan nhất với câu hỏi hiện tại." out_citations = [ { "document_id": top.get("document_id"), "section_id": top.get("section_id"), } ] confidence = "low" eval_result = _evaluate_grounding_confidence( answer=answer, citations=out_citations, sections=sections, retrieval_meta=retrieval_meta if isinstance(retrieval_meta, dict) else {}, llm_confidence=confidence, ) clarifying_question = "" if eval_result.get("needs_clarification"): clarifying_question = ( "Mình chưa đủ chắc chắn vì bằng chứng tài liệu còn yếu. " "Bạn muốn mình bám vào tài liệu nào hoặc section nào cụ thể hơn?" 
) return { "answer": answer, "citations": out_citations, "confidence": eval_result.get("confidence", confidence), "confidence_score": eval_result.get("confidence_score", 0.0), "needs_clarification": bool(eval_result.get("needs_clarification")), "clarifying_question": clarifying_question, "requirement_node_reference": _build_requirement_node_reference(sections), } def run_team_agent_with_nvidia(system_prompt: str, payload: Dict[str, Any]) -> str: return _nvidia_chat_completion( system_prompt=system_prompt, user_prompt=json.dumps(payload, ensure_ascii=False), model=TEAM_AGENT_MODEL, temperature=0.15, max_tokens=1400, ) def save_team_chat_message( team_id: str, role: str, content: str, project_id: Optional[str] = None, attachment_urls: Optional[List[str]] = None, ) -> Dict[str, Any]: doc = { "id": str(uuid.uuid4()), "team_id": team_id, "project_id": project_id, "role": role, "content": content, "attachment_urls": attachment_urls or [], "timestamp": get_vn_now().isoformat(), } team_chat_collection.insert_one(doc) return doc def create_issue_for_project(project_id: str, reporter_id: str, payload: Dict[str, Any]) -> Dict[str, Any]: assignee = resolve_user_reference(payload) tags = payload.get("tags", []) if isinstance(tags, str): tags = [tag.strip() for tag in tags.split(",") if tag.strip()] issue = { "id": str(uuid.uuid4()), "project_id": project_id, "title": payload.get("title", "Issue mới"), "description": payload.get("description", ""), "severity": payload.get("severity", "medium"), "status": payload.get("status", "open"), "assignee_id": assignee["id"] if assignee else payload.get("assignee_id"), "tags": tags, "requirement_text": payload.get("requirement_text"), "requirement_node_id": payload.get("requirement_node_id"), "requirement_node_title": payload.get("requirement_node_title"), "requirement_node_path": payload.get("requirement_node_path"), "requirement_node_path_titles": payload.get("requirement_node_path_titles", []), "requirement_node_path_ids": 
payload.get("requirement_node_path_ids", []), "requirement_node_depth": payload.get("requirement_node_depth"), "requirement_document_id": payload.get("requirement_document_id"), "requirement_document_name": payload.get("requirement_document_name"), "attachment_urls": payload.get("attachment_urls", []), "reporter_id": reporter_id, "created_at": get_vn_now().isoformat(), "updated_at": get_vn_now().isoformat(), } issues_collection.insert_one(issue) return issue def create_task_from_agent(payload: Dict[str, Any]) -> Dict[str, Any]: tags = payload.get("tags", ["TeamChat"]) if isinstance(tags, str): tags = [tag.strip() for tag in tags.split(",") if tag.strip()] task = { "id": str(uuid.uuid4()), "title": payload.get("title", "Task từ team chat"), "description": payload.get("description", ""), "start_time": payload.get("start_time"), "end_time": payload.get("end_time"), "priority": payload.get("priority", "medium"), "tags": tags, "reminder": payload.get("reminder") or payload.get("start_time"), "requirement_node_id": payload.get("requirement_node_id"), "requirement_node_title": payload.get("requirement_node_title"), "requirement_node_path": payload.get("requirement_node_path"), "requirement_node_path_titles": payload.get("requirement_node_path_titles", []), "requirement_node_path_ids": payload.get("requirement_node_path_ids", []), "requirement_node_depth": payload.get("requirement_node_depth"), "requirement_document_id": payload.get("requirement_document_id"), "requirement_document_name": payload.get("requirement_document_name"), } tasks_collection.insert_one(task) return task def update_issue_from_agent(issue_id: str, payload: Dict[str, Any]) -> Dict[str, Any]: issue = issues_collection.find_one({"id": issue_id}, {"_id": 0}) if not issue: raise HTTPException(status_code=404, detail="Issue not found") update_data: Dict[str, Any] = {"updated_at": get_vn_now().isoformat()} field_names = ["title", "description", "severity", "status", "tags", "attachment_urls", 
"requirement_text", "requirement_node_id", "requirement_node_title", "requirement_node_path", "requirement_document_id", "requirement_document_name"] for field_name in field_names: value = payload.get(field_name) if value is None: continue if isinstance(value, str): update_data[field_name] = value.strip() else: update_data[field_name] = value if "requirement_node_path_titles" in payload and payload["requirement_node_path_titles"] is not None: update_data["requirement_node_path_titles"] = payload["requirement_node_path_titles"] if "requirement_node_path_ids" in payload and payload["requirement_node_path_ids"] is not None: update_data["requirement_node_path_ids"] = payload["requirement_node_path_ids"] if "requirement_node_depth" in payload and payload["requirement_node_depth"] is not None: update_data["requirement_node_depth"] = payload["requirement_node_depth"] assignee = resolve_user_reference(payload) if assignee: update_data["assignee_id"] = assignee["id"] elif "assignee_id" in payload and payload["assignee_id"] is not None: update_data["assignee_id"] = payload["assignee_id"] issues_collection.update_one({"id": issue_id}, {"$set": update_data}) return issues_collection.find_one({"id": issue_id}, {"_id": 0}) def update_project_from_agent(project_id: str, payload: Dict[str, Any]) -> Dict[str, Any]: project = projects_collection.find_one({"id": project_id}, {"_id": 0}) if not project: raise HTTPException(status_code=404, detail="Project not found") update_data: Dict[str, Any] = {"updated_at": get_vn_now().isoformat()} for field_name in ["name", "description", "team_ids", "member_ids"]: value = payload.get(field_name) if value is None: continue if isinstance(value, str): update_data[field_name] = value.strip() else: update_data[field_name] = value projects_collection.update_one({"id": project_id}, {"$set": update_data}) return projects_collection.find_one({"id": project_id}, {"_id": 0}) def add_team_member_from_agent(team_id: str, payload: Dict[str, Any]) -> 
Dict[str, Any]: team = teams_collection.find_one({"id": team_id}, {"_id": 0}) if not team: raise HTTPException(status_code=404, detail="Team not found") member = resolve_user_reference(payload) if not member: raise HTTPException(status_code=404, detail="User not found") teams_collection.update_one( {"id": team_id}, {"$addToSet": {"member_ids": member["id"]}, "$set": {"updated_at": get_vn_now().isoformat()}}, ) return {"team_id": team_id, "member": member} def add_project_member_from_agent(project_id: str, payload: Dict[str, Any]) -> Dict[str, Any]: project = projects_collection.find_one({"id": project_id}, {"_id": 0}) if not project: raise HTTPException(status_code=404, detail="Project not found") member = resolve_user_reference(payload) if not member: raise HTTPException(status_code=404, detail="User not found") projects_collection.update_one( {"id": project_id}, {"$addToSet": {"member_ids": member["id"]}, "$set": {"updated_at": get_vn_now().isoformat()}}, ) return {"project_id": project_id, "member": member}