# ═══════════════════════════════════════════════════════════════════════════════
# File: app/core/prompt_cache.py
# Description: SHA256-based Prompt Cache with TTL for Token Storm Prevention
# ═══════════════════════════════════════════════════════════════════════════════
"""
Prompt Cache - Prevents token storms during retries and fallbacks.

KEY DESIGN:
- SHA256 fingerprint of (system_prompt + conversation_history + user_message + role)
- TTL: 30-120 seconds, chosen per entry at ``set()`` time (out-of-range values
  are clamped into that window)
- Per-session isolation (different sessions don't share cache)

USAGE:
    cache = PromptCache()
    fingerprint = cache.get_fingerprint(system, history, user_msg)

    if cached := cache.get(fingerprint, session_id):
        return cached  # Skip LLM call

    response = await llm.generate(...)
    cache.set(fingerprint, session_id, response, ttl=60)
"""

import hashlib
import json
import time
from typing import Any, Callable, Dict, Optional
from dataclasses import dataclass, field
from threading import Lock


@dataclass
class CacheEntry:
    """Single cache entry with expiration."""

    response: Any  # The cached value (an LLMResponse in normal use)
    expires_at: float  # Absolute deadline, measured on the owning cache's clock
    model: str  # Track which model generated this
    fingerprint: str  # Key under which this entry is stored in its session


class PromptCache:
    """
    Thread-safe prompt cache with SHA256 fingerprinting and TTL.

    Prevents token storms by returning cached responses for identical prompts.
    Scoped per-session to avoid cross-contamination. Every read and write of
    the cache contents and of the statistics happens under a single lock.
    """

    DEFAULT_TTL = 60  # seconds
    MIN_TTL = 30
    MAX_TTL = 120
    MAX_ENTRIES_PER_SESSION = 50  # Prevent memory bloat

    def __init__(self, *, clock: Optional[Callable[[], float]] = None):
        """
        Create an empty cache.

        Args:
            clock: Optional zero-argument callable returning the current time
                in seconds. Defaults to ``time.time``. Pass
                ``time.monotonic`` to make TTLs immune to wall-clock jumps,
                or a fake clock in tests.
        """
        # session_id -> {fingerprint -> entry}
        self._cache: Dict[str, Dict[str, CacheEntry]] = {}
        self._lock = Lock()
        self._stats = {"hits": 0, "misses": 0, "sets": 0, "evictions": 0}
        self._clock = clock

    def _now(self) -> float:
        """Return the current time from the injected clock, or ``time.time()``."""
        # time.time is looked up at call time rather than bound in __init__, so
        # code that patches time.time still affects the module-level instance.
        return self._clock() if self._clock is not None else time.time()

    def get_fingerprint(
        self,
        system_prompt: str,
        conversation_history: Optional[list],
        user_message: str,
        role: str = "",
    ) -> str:
        """
        Generate SHA256 fingerprint for a prompt combination.

        Args:
            system_prompt: The system prompt
            conversation_history: List of message dicts (None is treated as [])
            user_message: Current user message
            role: Optional role context (FAST_CHAT, SMART_REASONING, etc.)

        Returns:
            The first 16 hex characters (64 bits) of the SHA256 digest.

        Raises:
            TypeError: If conversation_history holds values json cannot serialize.
        """
        # sort_keys makes the dict key order inside messages irrelevant.
        normalized_history = json.dumps(conversation_history or [], sort_keys=True)
        # Encode the parts as a JSON array instead of joining them with a
        # separator: JSON quoting/escaping makes the encoding unambiguous, so
        # ("a", [], "b|||c", "") and ("a", [], "b", "c|||") cannot collide.
        # Scalars are formatted with f-strings so non-str values (e.g. enum
        # roles) are accepted.
        raw_content = json.dumps(
            [f"{system_prompt}", normalized_history, f"{user_message}", f"{role}"]
        )
        full_hash = hashlib.sha256(raw_content.encode("utf-8")).hexdigest()
        # 64-bit truncation: accidental collisions only become likely around
        # 2**32 distinct prompts (birthday bound), far beyond the per-session cap.
        return full_hash[:16]

    def get(
        self,
        fingerprint: str,
        session_id: str = "default"
    ) -> Optional[Any]:
        """
        Retrieve cached response if exists and not expired.

        An expired entry is removed on access, and its session is dropped once
        empty.

        Args:
            fingerprint: SHA256 fingerprint from get_fingerprint()
            session_id: Session identifier for isolation

        Returns:
            Cached LLMResponse or None if miss/expired
        """
        with self._lock:
            session_cache = self._cache.get(session_id)
            entry = session_cache.get(fingerprint) if session_cache else None

            if entry is None:
                self._stats["misses"] += 1
                return None

            if self._now() > entry.expires_at:
                # Expired - remove it, and the session too if nothing is left.
                del session_cache[fingerprint]
                if not session_cache:
                    del self._cache[session_id]
                self._stats["misses"] += 1
                self._stats["evictions"] += 1
                return None

            self._stats["hits"] += 1
            return entry.response

    def set(
        self,
        fingerprint: str,
        session_id: str,
        response: Any,
        model: str = "unknown",
        ttl: Optional[int] = None,
    ) -> None:
        """
        Store response in cache with TTL.

        Overwriting an existing fingerprint replaces that entry (and its
        deadline) without evicting anything else.

        Args:
            fingerprint: SHA256 fingerprint
            session_id: Session identifier
            response: LLMResponse to cache
            model: Model that generated this response
            ttl: Time-to-live in seconds; None means DEFAULT_TTL. The value is
                clamped to [MIN_TTL, MAX_TTL], so 0 or negatives become MIN_TTL.
        """
        if ttl is None:
            ttl = self.DEFAULT_TTL
        ttl = max(self.MIN_TTL, min(self.MAX_TTL, ttl))

        with self._lock:
            now = self._now()
            session_cache = self._cache.setdefault(session_id, {})

            # Only inserting a *new* key can push the session over its cap.
            if (
                fingerprint not in session_cache
                and len(session_cache) >= self.MAX_ENTRIES_PER_SESSION
            ):
                self._make_room(session_cache, now)

            session_cache[fingerprint] = CacheEntry(
                response=response,
                expires_at=now + ttl,
                model=model,
                fingerprint=fingerprint,
            )
            self._stats["sets"] += 1

    @staticmethod
    def _purge_expired(session_cache: Dict[str, CacheEntry], now: float) -> int:
        """Delete entries whose deadline has passed; return how many were removed."""
        expired = [fp for fp, entry in session_cache.items() if now > entry.expires_at]
        for fp in expired:
            del session_cache[fp]
        return len(expired)

    def _make_room(self, session_cache: Dict[str, CacheEntry], now: float) -> None:
        """Free space in a full session. Caller must hold ``self._lock``."""
        # Dead entries go first; a live entry is sacrificed only if none expired.
        evicted = self._purge_expired(session_cache, now)
        if session_cache and len(session_cache) >= self.MAX_ENTRIES_PER_SESSION:
            # Drop the live entry closest to expiry (not necessarily the oldest).
            victim = min(session_cache, key=lambda fp: session_cache[fp].expires_at)
            del session_cache[victim]
            evicted += 1
        self._stats["evictions"] += evicted

    def clear_session(self, session_id: str) -> int:
        """Clear all cached entries for a session. Returns count cleared."""
        with self._lock:
            removed = self._cache.pop(session_id, None)
        return len(removed) if removed is not None else 0

    def clear_expired(self) -> int:
        """Remove all expired entries across all sessions. Returns count evicted."""
        evicted = 0
        with self._lock:
            now = self._now()
            for session_id in list(self._cache):
                session_cache = self._cache[session_id]
                evicted += self._purge_expired(session_cache, now)
                # Remove empty sessions
                if not session_cache:
                    del self._cache[session_id]
            # Updated under the lock: a bare += on shared state is not atomic.
            self._stats["evictions"] += evicted
        return evicted

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Counters ``hits``, ``misses``, ``sets``, ``evictions`` plus
            ``total_entries``, ``sessions`` and a float ``hit_rate``
            (0.0 when no lookups have happened yet).
        """
        with self._lock:
            stats: Dict[str, Any] = dict(self._stats)
            lookups = stats["hits"] + stats["misses"]
            stats["total_entries"] = sum(len(sc) for sc in self._cache.values())
            stats["sessions"] = len(self._cache)
            stats["hit_rate"] = stats["hits"] / lookups if lookups else 0.0
        return stats


# Global cache instance
prompt_cache = PromptCache()

__all__ = ["PromptCache", "prompt_cache"]