import asyncio
import json
import os
import re
import tempfile
from typing import Any, Dict

import whisper
from fastapi import APIRouter, File, HTTPException, UploadFile
from fastapi.responses import Response, StreamingResponse
from google.genai import types

from core import (
    GOOGLE_API_KEY,
    MODEL_NAME,
    TTS_MAX_CHARS,
    TTS_MAX_TOTAL_CHARS,
    TTS_MODEL_NAME,
    VOICE_MAX_SECONDS,
    executor,
    model_client,
    proposals_collection,
    tasks_collection,
)
from prompts import COMPACT_PROMPT, GATHER_PROMPT, PROPOSAL_PROMPT, VOICE_COMPACT_PROMPT
from schemas import ManualTaskRequest, ScheduleRequest, TTSRequest
from services import (
    compact_chat_with_prompt,
    get_daily_chat,
    get_memory,
    get_vn_now,
    get_vn_time_str,
    get_whisper_model,
    maybe_auto_compact_voice_chat,
    normalize_tts_text,
    rewrite_text_for_speech,
    save_chat_message,
    synthesize_vi_speech,
    synthesize_vi_speech_long,
)

router = APIRouter()


# Tool handed to the model in gather_context: returns every stored task with
# Mongo's "_id" field projected out.
# NOTE(review): documented with comments on purpose -- the SDK may derive the
# tool description from the docstring, so adding one could change what the
# model sees. Confirm before converting this to a docstring.
def list_tasks():
    return list(tasks_collection.find({}, {"_id": 0}))


# Tool handed to the model in gather_context: returns stored tasks whose
# [start_time, end_time] range overlaps the given one (task starts before
# `end_time` and ends after `start_time`). The comparison is done by MongoDB
# on the stored values -- presumably sortable timestamp strings; verify
# against the task schema.
# NOTE(review): kept docstring-free for the same reason as list_tasks.
def check_conflicts(start_time: str, end_time: str):
    return list(
        tasks_collection.find(
            {"$or": [{"start_time": {"$lt": end_time}, "end_time": {"$gt": start_time}}]},
            {"_id": 0},
        )
    )


def gather_context(user_text: str, curr_time: str) -> str:
    """Ask the model to summarise the schedule state relevant to *user_text*.

    The model may call ``list_tasks`` and ``check_conflicts`` (at most five
    automatic remote calls) while answering. This function is synchronous and
    is run through ``asyncio.to_thread`` by ``handle_stream``.

    Args:
        user_text: The raw user request.
        curr_time: Current time string appended to the system instruction.

    Returns:
        The model's text, a fixed "schedule is empty" message when the model
        returns no text, or an error message if the call raised.
    """
    try:
        response = model_client.models.generate_content(
            model=MODEL_NAME,
            contents=user_text,
            config=types.GenerateContentConfig(
                system_instruction=GATHER_PROMPT + f"\nThời gian VN hiện tại: {curr_time}",
                tools=[list_tasks, check_conflicts],
                automatic_function_calling=types.AutomaticFunctionCallingConfig(
                    disable=False,
                    maximum_remote_calls=5,
                ),
            ),
        )
        return response.text or "Lịch hiện tại trống."
    except Exception as e:
        # Deliberate best-effort: the failure text is fed into the proposal
        # prompt instead of aborting the whole request.
        return f"Không thể kiểm tra lịch: {str(e)}"


@router.post("/schedule/stream")
async def handle_stream(req: ScheduleRequest):
    """Stream a scheduling proposal for the user's request as server-sent events.

    Saves the user message, gathers schedule context, builds a prompt from
    recent history, memory and the new request, then streams the model's
    answer. Each event is a JSON object with ``type`` of ``"chunk"``,
    ``"done"`` or ``"error"``.

    Raises:
        HTTPException: 500 when ``GOOGLE_API_KEY`` is not configured.
    """
    if not GOOGLE_API_KEY:
        raise HTTPException(status_code=500, detail="GOOGLE_API_KEY not set.")

    curr_time = req.current_time or get_vn_time_str()
    save_chat_message("user", req.text)
    memory = get_memory()
    context_data = await asyncio.to_thread(gather_context, req.text, curr_time)

    # Up to eight messages preceding the newest entry; the newest is skipped,
    # presumably because it is the user message saved just above.
    history = get_daily_chat()[-9:-1]
    history_text = "\n".join(
        f"{'USER' if m['role'] == 'user' else 'NOMUS'}: {m['content']}" for m in history
    )
    history_section = ("=== Lịch sử hội thoại ===\n" + history_text + "\n\n") if history_text else ""
    memory_section = ("=== Memory về bạn ===\n" + memory + "\n\n") if memory else ""

    full_prompt = (
        f"{history_section}"
        f"=== Lịch trình hiện tại ===\n{context_data}\n\n"
        f"{memory_section}"
        f"=== Yêu cầu mới ===\n{req.text}"
    )

    # We are inside a coroutine, so the running loop is the one the worker
    # thread must hand its results back to.
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def emit(item: Dict[str, Any]) -> None:
        # Thread-safe hand-off from the worker thread to the event loop.
        asyncio.run_coroutine_threadsafe(queue.put(item), loop)

    def run_stream():
        # Runs in the executor: the streaming client is a blocking iterator.
        parts = []
        try:
            for chunk in model_client.models.generate_content_stream(
                model=MODEL_NAME,
                contents=full_prompt,
                config=types.GenerateContentConfig(
                    system_instruction=PROPOSAL_PROMPT + f"\nThời gian VN hiện tại: {curr_time}",
                ),
            ):
                piece = chunk.text
                if piece:
                    parts.append(piece)
                    emit({"type": "chunk", "data": piece})
            save_chat_message("assistant", "".join(parts))
            # Fire-and-forget: the compaction result is not awaited here.
            asyncio.run_coroutine_threadsafe(maybe_auto_compact_voice_chat(VOICE_COMPACT_PROMPT), loop)
            emit({"type": "done"})
        except Exception as e:
            # Surface the failure to the client as a terminal SSE event.
            emit({"type": "error", "data": str(e)})

    loop.run_in_executor(executor, run_stream)

    async def event_stream():
        # Relay queued items until a terminal event has been sent.
        while True:
            item = await queue.get()
            yield f"data: {json.dumps(item, ensure_ascii=False)}\n\n"
            if item["type"] in ("done", "error"):
                break

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
@router.post("/chat/compact")
async def compact_chat():
    """Compact the chat history into memory and report what was removed.

    Returns a dict with a user-facing ``message`` and the ``deleted`` count;
    on success it also carries ``hard_delete`` and ``memory`` from the
    compaction result. Failures are reported in the payload, not raised.
    """
    try:
        result = await compact_chat_with_prompt(system_prompt=COMPACT_PROMPT, min_messages=6)
        if not result.get("ok"):
            if result.get("reason") == "not_enough_messages":
                return {"message": "Cần ít nhất 6 tin nhắn để compact", "deleted": 0}
            return {"message": "Không thể phân tích kết quả compact", "deleted": 0}
        return {
            "message": f"Đã compact {result['deleted']} tin nhắn, cập nhật memory",
            "deleted": result["deleted"],
            "hard_delete": result["hard_delete"],
            "memory": result["memory"],
        }
    except Exception as e:
        # Deliberate best-effort: the client gets the error text in a normal response.
        return {"message": f"Lỗi compact: {str(e)}", "deleted": 0}


@router.get("/chat")
async def get_chat():
    """Return the chat history from ``get_daily_chat(include_compacted=False)``."""
    return {"history": get_daily_chat(include_compacted=False)}


@router.get("/tasks")
async def get_tasks():
    """Return every stored task with Mongo's ``_id`` field projected out."""
    return {"tasks": list(tasks_collection.find({}, {"_id": 0}))}


@router.get("/health")
async def health_check():
    """Return a static status payload with the model names and current VN time."""
    return {
        "status": "ok",
        "message": f"Nomus AI V3 (Streaming + Proposals) – {MODEL_NAME}",
        "vn_time": get_vn_time_str(),
        "tts_engine": "mms",
        "tts_model_fallback": TTS_MODEL_NAME,
    }


@router.post("/tasks")
async def create_manual_task(task: ManualTaskRequest):
    """Insert a manually created task and return its generated id.

    The id is the current VN timestamp with the decimal point removed. When
    no reminder is supplied it defaults to the task's ``start_time``.
    """
    task_data = task.model_dump()
    task_data["id"] = str(get_vn_now().timestamp()).replace(".", "")
    if not task_data.get("reminder"):
        task_data["reminder"] = task_data["start_time"]
    tasks_collection.insert_one(task_data)
    return {"message": "Task created", "id": task_data["id"]}


@router.patch("/tasks/{task_id}")
async def update_task(task_id: str, update: Dict[str, Any]):
    """Apply a partial update to the task identified by *task_id*.

    Raises:
        HTTPException: 404 when no task has that id.
    """
    if not update:
        return {"message": "No data"}
    result = tasks_collection.update_one({"id": task_id}, {"$set": update})
    # matched_count, not modified_count: a patch that leaves an existing task
    # unchanged modifies nothing but must not be reported as "not found".
    if result.matched_count == 0:
        raise HTTPException(status_code=404, detail="Task not found")
    return {"message": "Task updated"}


@router.delete("/tasks/{task_id}")
async def delete_task(task_id: str):
    """Delete the task with the given id; succeeds even if none existed."""
    tasks_collection.delete_one({"id": task_id})
    return {"message": "Task deleted"}


@router.get("/proposals/statuses")
async def get_proposal_statuses():
    """Return the keys of all rejected proposals."""
    docs = list(proposals_collection.find({}, {"_id": 0, "key": 1}))
    # Skip documents without a "key" field instead of failing the whole request.
    return {"rejected_keys": [d["key"] for d in docs if "key" in d]}


@router.post("/proposals/reject")
async def reject_proposal(req: Dict[str, str]):
    """Mark a proposal as rejected, creating or refreshing its record.

    Raises:
        HTTPException: 400 when the body has no non-empty ``key``.
    """
    key = req.get("key")
    if not key:
        raise HTTPException(status_code=400, detail="Missing key")
    proposals_collection.update_one(
        {"key": key},
        {"$set": {"key": key, "rejected_at": get_vn_now().isoformat()}},
        upsert=True,
    )
    return {"message": "Proposal rejected", "key": key}


@router.delete("/proposals/reject/{key:path}")
async def undo_reject_proposal(key: str):
    """Remove the rejection record for *key*; the key may contain slashes."""
    proposals_collection.delete_one({"key": key})
    return {"message": "Rejection removed", "key": key}


@router.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with Whisper.

    The upload is written to a temporary file, which is always removed
    afterwards.

    Returns:
        Dict with ``text``, ``language``, ``duration_ms`` and ``model``.

    Raises:
        HTTPException: 400 when no file is given, 413 when the audio is
            longer than ``VOICE_MAX_SECONDS``.
    """
    if not file:
        raise HTTPException(status_code=400, detail="Audio file is required.")

    file_ext = os.path.splitext(file.filename or "voice.webm")[1] or ".webm"
    language = os.environ.get("WHISPER_LANGUAGE", "vi")
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp:
            tmp.write(await file.read())
            tmp_path = tmp.name

        # Decoding is blocking work; keep it off the event loop, as the TTS
        # endpoint does for its blocking calls.
        audio = await asyncio.to_thread(whisper.load_audio, tmp_path)
        # Assumes load_audio yields 16 kHz samples, matching the divisor.
        duration_sec = len(audio) / 16000.0
        if duration_sec > VOICE_MAX_SECONDS:
            raise HTTPException(
                status_code=413,
                detail=f"Audio too long. Max supported length is {VOICE_MAX_SECONDS}s.",
            )

        model = get_whisper_model()
        # NOTE(review): transcription still runs on the event loop. Moving it
        # to a thread would let requests share one model instance
        # concurrently -- confirm that is safe (or add a lock) first.
        result = model.transcribe(tmp_path, language=language, task="transcribe", fp16=False)
        text = (result.get("text") or "").strip()
        return {
            "text": text,
            "language": result.get("language") or language,
            "duration_ms": int(duration_sec * 1000),
            "model": os.environ.get("WHISPER_MODEL", "base"),
        }
    finally:
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)


@router.post("/tts")
async def text_to_speech(req: TTSRequest):
    """Synthesize speech for the request text and return it as WAV audio.

    ``speech_text`` takes precedence over ``text``. The text is optionally
    rewritten for speech, then normalized; texts longer than
    ``TTS_MAX_CHARS`` go through the long-text synthesizer.

    Raises:
        HTTPException: 413 when the source text exceeds
            ``TTS_MAX_TOTAL_CHARS``, 400 when nothing is left to speak.
    """
    source_text = (req.speech_text or req.text or "").strip()
    if len(source_text) > TTS_MAX_TOTAL_CHARS:
        raise HTTPException(
            status_code=413,
            detail=f"Text too long. Max supported total length is {TTS_MAX_TOTAL_CHARS} characters.",
        )

    speech_text = source_text
    if req.rewrite_for_speech or (req.speech_prompt and req.speech_prompt.strip()):
        speech_text = await asyncio.to_thread(rewrite_text_for_speech, source_text, req.speech_prompt)

    text = normalize_tts_text(speech_text)
    if not text:
        raise HTTPException(status_code=400, detail="Text is required.")

    # Pick the synthesizer by length, then run it off the event loop.
    synthesize = synthesize_vi_speech if len(text) <= TTS_MAX_CHARS else synthesize_vi_speech_long
    wav_bytes = await asyncio.to_thread(synthesize, text)
    return Response(content=wav_bytes, media_type="audio/wav")