minhvtt commited on
Commit
c6c0abc
·
verified ·
1 Parent(s): ce85c02

Upload routes_team_chat.py

Browse files
Files changed (1) hide show
  1. routes_team_chat.py +59 -23
routes_team_chat.py CHANGED
@@ -225,34 +225,44 @@ def _has_requirement_node_payload(payload: Dict[str, Any]) -> bool:
225
  )
226
 
227
 
 
 
 
 
228
  def _truncate_prompt_context(
229
  ctx: Dict[str, Any],
230
  *,
231
  max_section_content: int = 1500,
232
- max_sections: int = 10,
233
- max_messages: int = 15,
234
- max_msg_content: int = 500,
 
235
  max_grounded_answer: int = 2000,
236
  max_qa_memory: int = 1500,
237
- max_index_nodes: int = 30,
 
238
  ) -> Dict[str, Any]:
239
  """Return a size-bounded copy of prompt_context before sending to NVIDIA."""
240
  ctx = dict(ctx)
241
 
242
- # Truncate document sections (most expensive — can be huge)
243
- sections = list(ctx.get("documents_sections") or [])
244
- if len(sections) > max_sections:
245
- sections = sections[:max_sections]
246
  safe_sections = []
247
  for sec in sections:
248
  sec = dict(sec)
249
- content = str(sec.get("content") or "")
250
- if len(content) > max_section_content:
251
- sec["content"] = content[:max_section_content] + "…[truncated]"
 
 
 
 
 
 
252
  safe_sections.append(sec)
253
  ctx["documents_sections"] = safe_sections
254
 
255
- # Limit + truncate chat messages
256
  for key in ("selected_messages", "fallback_messages"):
257
  msgs = list(ctx.get(key) or [])
258
  if len(msgs) > max_messages:
@@ -262,30 +272,50 @@ def _truncate_prompt_context(
262
  m = dict(m)
263
  content = str(m.get("content") or "")
264
  if len(content) > max_msg_content:
265
- m["content"] = content[:max_msg_content] + "…"
266
  trimmed.append(m)
267
  ctx[key] = trimmed
268
 
269
- # Truncate LLM-generated answer that may be large
270
  answer = str(ctx.get("document_grounded_answer") or "")
271
  if len(answer) > max_grounded_answer:
272
- ctx["document_grounded_answer"] = answer[:max_grounded_answer] + "…"
273
 
274
- # Truncate accumulated QA memory
275
  memory = str(ctx.get("doc_qa_memory") or "")
276
  if len(memory) > max_qa_memory:
277
- ctx["doc_qa_memory"] = memory[:max_qa_memory] + "…"
278
 
279
- # Limit index nodes per document (already minimal fields, just guard count)
280
  doc_indexes = list(ctx.get("documents_index") or [])
281
  for di in doc_indexes:
282
  nodes = di.get("nodes") or []
283
  if len(nodes) > max_index_nodes:
284
  di["nodes"] = nodes[:max_index_nodes]
 
 
 
 
 
285
  ctx["documents_index"] = doc_indexes
286
 
287
- # Drop large debug metadata that the agent does not use
288
  ctx.pop("documents_retrieval_meta", None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
 
290
  return ctx
291
 
@@ -555,10 +585,16 @@ async def team_chat(req: TeamChatRequest, x_session_token: Optional[str] = Heade
555
  )
556
  qa_memory = get_team_doc_qa_memory(req.team_id, req.project_id)
557
 
 
 
 
558
  doc_indexes = []
559
  for doc in team_docs:
560
  tree = doc.get("tree") or {}
561
  nodes = tree.get("nodes") or []
 
 
 
562
  doc_indexes.append(
563
  {
564
  "document_id": doc.get("id"),
@@ -568,12 +604,12 @@ async def team_chat(req: TeamChatRequest, x_session_token: Optional[str] = Heade
568
  {
569
  "id": node.get("id"),
570
  "parent_id": node.get("parent_id"),
571
- "title": node.get("title"),
572
- "summary": node.get("summary"),
573
- "scope": node.get("scope"),
574
  "level": node.get("level"),
575
  }
576
- for node in nodes
577
  ],
578
  }
579
  )
 
225
  )
226
 
227
 
228
+ def _clip(text: str, limit: int) -> str:
229
+ return text if len(text) <= limit else text[:limit] + "…[truncated]"
230
+
231
+
232
  def _truncate_prompt_context(
233
  ctx: Dict[str, Any],
234
  *,
235
  max_section_content: int = 1500,
236
+ max_section_summary: int = 300,
237
+ max_sections: int = 8,
238
+ max_messages: int = 12,
239
+ max_msg_content: int = 400,
240
  max_grounded_answer: int = 2000,
241
  max_qa_memory: int = 1500,
242
+ max_index_nodes: int = 25,
243
+ max_index_summary: int = 200,
244
  ) -> Dict[str, Any]:
245
  """Return a size-bounded copy of prompt_context before sending to NVIDIA."""
246
  ctx = dict(ctx)
247
 
248
+ # ── Truncate document sections (biggest offender) ──
249
+ sections = list(ctx.get("documents_sections") or [])[:max_sections]
 
 
250
  safe_sections = []
251
  for sec in sections:
252
  sec = dict(sec)
253
+ # Field names from retrieve_document_context_with_tree are section_content / section_summary / section_context
254
+ for field in ("section_content", "content"):
255
+ val = sec.get(field)
256
+ if val and len(str(val)) > max_section_content:
257
+ sec[field] = _clip(str(val), max_section_content)
258
+ for field in ("section_summary", "summary", "section_context"):
259
+ val = sec.get(field)
260
+ if val and len(str(val)) > max_section_summary:
261
+ sec[field] = _clip(str(val), max_section_summary)
262
  safe_sections.append(sec)
263
  ctx["documents_sections"] = safe_sections
264
 
265
+ # ── Limit + truncate chat messages ──
266
  for key in ("selected_messages", "fallback_messages"):
267
  msgs = list(ctx.get(key) or [])
268
  if len(msgs) > max_messages:
 
272
  m = dict(m)
273
  content = str(m.get("content") or "")
274
  if len(content) > max_msg_content:
275
+ m["content"] = _clip(content, max_msg_content)
276
  trimmed.append(m)
277
  ctx[key] = trimmed
278
 
279
+ # ── Truncate LLM-generated answer ──
280
  answer = str(ctx.get("document_grounded_answer") or "")
281
  if len(answer) > max_grounded_answer:
282
+ ctx["document_grounded_answer"] = _clip(answer, max_grounded_answer)
283
 
284
+ # ── Truncate accumulated QA memory ──
285
  memory = str(ctx.get("doc_qa_memory") or "")
286
  if len(memory) > max_qa_memory:
287
+ ctx["doc_qa_memory"] = _clip(memory, max_qa_memory)
288
 
289
+ # ── Limit document index nodes (82K nodes × 2 docs = main overflow source) ──
290
  doc_indexes = list(ctx.get("documents_index") or [])
291
  for di in doc_indexes:
292
  nodes = di.get("nodes") or []
293
  if len(nodes) > max_index_nodes:
294
  di["nodes"] = nodes[:max_index_nodes]
295
+ for node in di.get("nodes", []):
296
+ for field in ("summary", "scope"):
297
+ val = node.get(field)
298
+ if val and len(str(val)) > max_index_summary:
299
+ node[field] = _clip(str(val), max_index_summary)
300
  ctx["documents_index"] = doc_indexes
301
 
302
+ # ── Drop large debug metadata the agent does not use ──
303
  ctx.pop("documents_retrieval_meta", None)
304
+ # Also strip citations detail (agent doesn't act on them)
305
+ ctx.pop("documents_citations", None)
306
+ ctx.pop("document_grounded_citations", None)
307
+
308
+ # ── Hard safety cap: if serialized prompt still too large, aggressively trim ──
309
+ _MAX_PROMPT_CHARS = 180_000 # ~60K tokens, well under 262K token limit
310
+ serialized = json.dumps(ctx, ensure_ascii=False, default=str)
311
+ if len(serialized) > _MAX_PROMPT_CHARS:
312
+ # Emergency: drop the heaviest fields until under budget
313
+ for drop_key in ("documents_sections", "documents_index", "agent_context", "doc_qa_memory"):
314
+ if len(serialized) <= _MAX_PROMPT_CHARS:
315
+ break
316
+ if drop_key in ctx:
317
+ ctx[drop_key] = [] if isinstance(ctx.get(drop_key), list) else ""
318
+ serialized = json.dumps(ctx, ensure_ascii=False, default=str)
319
 
320
  return ctx
321
 
 
585
  )
586
  qa_memory = get_team_doc_qa_memory(req.team_id, req.project_id)
587
 
588
+ # Build lightweight index — only top-level nodes (level <= 2) to avoid
589
+ # sending 82K+ nodes per document to the NVIDIA agent.
590
+ _MAX_INDEX_NODES_PER_DOC = 25
591
  doc_indexes = []
592
  for doc in team_docs:
593
  tree = doc.get("tree") or {}
594
  nodes = tree.get("nodes") or []
595
+ top_nodes = [n for n in nodes if (n.get("level") or 0) <= 2][:_MAX_INDEX_NODES_PER_DOC]
596
+ if not top_nodes:
597
+ top_nodes = nodes[:_MAX_INDEX_NODES_PER_DOC]
598
  doc_indexes.append(
599
  {
600
  "document_id": doc.get("id"),
 
604
  {
605
  "id": node.get("id"),
606
  "parent_id": node.get("parent_id"),
607
+ "title": _clip(str(node.get("title") or ""), 120),
608
+ "summary": _clip(str(node.get("summary") or ""), 200),
609
+ "scope": _clip(str(node.get("scope") or ""), 100),
610
  "level": node.get("level"),
611
  }
612
+ for node in top_nodes
613
  ],
614
  }
615
  )