Spaces:
Running
Running
Elpida Deploy Bot
deploy: 0215407b [FABLE] bump HF Claude provider model: sonnet-4-20250514 -> sonnet-4-6 (#182)
#!/usr/bin/env python3
"""
Elpida Unified LLM Client
==========================
Single source of truth for all LLM provider calls.
Every module that needs to talk to an LLM imports from here.

Providers supported:
- Claude (Anthropic)
- OpenAI (GPT)
- Gemini (Google)
- Grok (xAI)
- Mistral
- Cohere
- Perplexity
- OpenRouter (failsafe)
- Groq
- HuggingFace
- DeepSeek
- Cerebras
- DoubleWord
"""
| import os | |
| import time | |
| import json | |
| import logging | |
| import requests | |
| from dataclasses import dataclass, field | |
| from typing import Optional, Dict, Any | |
| from enum import Enum | |
# Pull variables from a local .env file when python-dotenv is available, so
# API keys do not have to be exported in the shell beforehand.
try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    # python-dotenv is optional; keys then come only from the real environment.
    pass

logger = logging.getLogger("elpida.llm_client")
# ---------------------------------------------------------------------------
# Provider registry
# ---------------------------------------------------------------------------
class Provider(str, Enum):
    """Identifiers of every supported LLM backend.

    Members subclass ``str``, so each compares equal to its lowercase value
    (e.g. ``Provider.CLAUDE == "claude"``). Declaration order is the order in
    which ``for p in Provider`` iterates.
    """

    CLAUDE = "claude"
    OPENAI = "openai"
    GEMINI = "gemini"
    GROK = "grok"
    MISTRAL = "mistral"
    COHERE = "cohere"
    PERPLEXITY = "perplexity"
    OPENROUTER = "openrouter"
    GROQ = "groq"
    HUGGINGFACE = "huggingface"
    DEEPSEEK = "deepseek"
    CEREBRAS = "cerebras"
    DOUBLEWORD = "doubleword"
# Model used for each provider when call() is not given an explicit model=.
# NOTE(review): OpenRouter slugs usually omit the size suffix (e.g.
# "meta-llama/llama-4-scout"); confirm the OPENROUTER id below resolves there,
# since every OpenRouter failsafe call uses it.
DEFAULT_MODELS: Dict[str, str] = {
    Provider.CLAUDE:      "claude-sonnet-4-6",
    Provider.OPENAI:      "gpt-4o-mini",
    Provider.GEMINI:      "gemini-2.5-flash",
    Provider.GROK:        "grok-3",
    Provider.MISTRAL:     "mistral-small-latest",
    Provider.COHERE:      "command-a-03-2025",
    Provider.PERPLEXITY:  "sonar",
    Provider.OPENROUTER:  "meta-llama/llama-4-scout-17b-16e-instruct",
    Provider.GROQ:        "meta-llama/llama-4-scout-17b-16e-instruct",
    Provider.HUGGINGFACE: "Qwen/Qwen2.5-72B-Instruct",
    Provider.DEEPSEEK:    "deepseek-chat",
    Provider.CEREBRAS:    "qwen-3-235b-a22b-instruct-2507",
    Provider.DOUBLEWORD:  "Qwen/Qwen3.6-35B-A3B-FP8",
}
# Name of the environment variable that holds each provider's API key.
# LLMClient reads these once, at construction time, via os.getenv().
API_KEY_ENV: Dict[str, str] = {
    Provider.CLAUDE:      "ANTHROPIC_API_KEY",
    Provider.OPENAI:      "OPENAI_API_KEY",
    Provider.GEMINI:      "GEMINI_API_KEY",
    Provider.GROK:        "XAI_API_KEY",
    Provider.MISTRAL:     "MISTRAL_API_KEY",
    Provider.COHERE:      "COHERE_API_KEY",
    Provider.PERPLEXITY:  "PERPLEXITY_API_KEY",
    Provider.OPENROUTER:  "OPENROUTER_API_KEY",
    Provider.GROQ:        "GROQ_API_KEY",
    Provider.HUGGINGFACE: "HUGGINGFACE_API_KEY",
    Provider.DEEPSEEK:    "DEEPSEEK_API_KEY",
    Provider.CEREBRAS:    "CEREBRAS_API_KEY",
    Provider.DOUBLEWORD:  "DOUBLEWORD_API_KEY",
}
# Approximate cost per token, used only for the budget figures in ProviderStats.
# NOTE(review): the client multiplies this by different counts depending on
# the provider: Claude and Cohere use output tokens, while the OpenAI-compatible
# providers use total (prompt + completion) tokens. The Claude figure also looks
# like an input-token rate. Confirm which quantity the budget is meant to track.
COST_PER_TOKEN: Dict[str, float] = {
    Provider.CLAUDE:      0.000003,
    Provider.OPENAI:      0.0,         # gpt-4o-mini — treated as negligible
    Provider.GEMINI:      0.0,
    Provider.GROK:        0.0000003,
    Provider.MISTRAL:     0.000001,
    Provider.COHERE:      0.0000005,
    Provider.PERPLEXITY:  0.0,
    Provider.OPENROUTER:  0.0,
    Provider.GROQ:        0.0,
    Provider.HUGGINGFACE: 0.0,
    Provider.DEEPSEEK:    0.00000042,
    Provider.CEREBRAS:    0.0,
    Provider.DOUBLEWORD:  0.00000017,  # ~18× cheaper than Claude (per HERMES daily synthesis measured cost)
}
# ---------------------------------------------------------------------------
# Stats tracking
# ---------------------------------------------------------------------------
@dataclass
class ProviderStats:
    """Usage counters for a single provider.

    A dataclass, so every instance owns its own counters, can be built with
    explicit values (``ProviderStats(requests=2)``), and has a readable repr.
    """

    requests: int = 0   # successful calls
    tokens: int = 0     # tokens reported by the API, or estimated as len(text) // 4
    cost: float = 0.0   # approximate spend: tokens * COST_PER_TOKEN rate
    failures: int = 0   # failed attempts (HTTP errors, exceptions, empty replies)

    def to_dict(self) -> dict:
        """Return the counters as a JSON-friendly dict, with cost rounded to 6 dp."""
        return {"requests": self.requests, "tokens": self.tokens,
                "cost": round(self.cost, 6), "failures": self.failures}
# ---------------------------------------------------------------------------
# The unified client
# ---------------------------------------------------------------------------
class LLMClient:
    """
    Unified LLM client for Elpida.

    Usage:
        client = LLMClient()
        text = client.call("claude", prompt)
        text = client.call("openai", prompt, max_tokens=800, model="gpt-4o")

    Provider failures are reported by returning ``None``; provider methods do
    not raise. Each call is rate-limited per provider and guarded by a simple
    circuit breaker. When the requested provider fails, or its breaker is
    open, the prompt is retried through OpenRouter if ``openrouter_failsafe``
    is enabled.
    """

    # Endpoints for OpenAI-compatible providers
    _OPENAI_COMPAT_ENDPOINTS = {
        Provider.OPENAI:      "https://api.openai.com/v1/chat/completions",
        Provider.GROK:        "https://api.x.ai/v1/chat/completions",
        Provider.MISTRAL:     "https://api.mistral.ai/v1/chat/completions",
        Provider.PERPLEXITY:  "https://api.perplexity.ai/chat/completions",
        Provider.OPENROUTER:  "https://openrouter.ai/api/v1/chat/completions",
        Provider.GROQ:        "https://api.groq.com/openai/v1/chat/completions",
        Provider.HUGGINGFACE: "https://router.huggingface.co/v1/chat/completions",
        Provider.DEEPSEEK:    "https://api.deepseek.com/chat/completions",
        Provider.CEREBRAS:    "https://api.cerebras.ai/v1/chat/completions",
        Provider.DOUBLEWORD:  "https://api.doubleword.ai/v1/chat/completions",
    }

    def __init__(
        self,
        rate_limit_seconds: float = 1.5,
        default_max_tokens: int = 600,
        default_timeout: int = 60,
        openrouter_failsafe: bool = True,
    ):
        """
        Args:
            rate_limit_seconds: minimum spacing between two calls to the same provider.
            default_max_tokens: max_tokens used when call() is not given one.
            default_timeout: HTTP timeout (seconds) used when call() is not given one.
            openrouter_failsafe: retry failed calls through OpenRouter.
        """
        self.rate_limit_seconds = rate_limit_seconds
        self.default_max_tokens = default_max_tokens
        self.default_timeout = default_timeout
        self.openrouter_failsafe = openrouter_failsafe

        # API keys are read once from the environment at construction time.
        self.api_keys: Dict[str, Optional[str]] = {
            provider: os.getenv(env_var)
            for provider, env_var in API_KEY_ENV.items()
        }
        # Optional secondary Claude key (different Anthropic account) for 529 load-sharing
        self._claude_key_2: Optional[str] = os.getenv("ANTHROPIC_API_KEY_2")

        # Per-provider stats, keyed by the provider's string value.
        self.stats: Dict[str, ProviderStats] = {
            p.value: ProviderStats() for p in Provider
        }

        # Rate-limit timestamps (epoch seconds of the last call per provider).
        self._last_call: Dict[str, float] = {}

        # Circuit breaker: trip after N consecutive failures, then bypass the
        # provider for COOLDOWN_S seconds (call() goes straight to the failsafe).
        self._CB_THRESHOLD: int = 3
        self._CB_COOLDOWN_S: int = 300  # 5 minutes
        self._cb_consec: Dict[str, int] = {}    # consecutive failure count
        self._cb_until: Dict[str, float] = {}   # epoch time when cooldown expires

        # Dispatch table: provider -> handler with the common keyword signature.
        self._dispatch = {
            Provider.CLAUDE:      self._call_claude,
            Provider.OPENAI:      self._call_openai_compat,
            Provider.GEMINI:      self._call_gemini,
            Provider.GROK:        self._call_openai_compat,
            Provider.MISTRAL:     self._call_openai_compat,
            Provider.COHERE:      self._call_cohere,
            Provider.PERPLEXITY:  self._call_openai_compat,
            Provider.OPENROUTER:  self._call_openai_compat,
            Provider.GROQ:        self._call_openai_compat,
            Provider.HUGGINGFACE: self._call_openai_compat,
            Provider.DEEPSEEK:    self._call_openai_compat,
            Provider.CEREBRAS:    self._call_openai_compat,
            Provider.DOUBLEWORD:  self._call_openai_compat,
        }

    # ----- circuit breaker -------------------------------------------------

    def _is_tripped(self, provider: str) -> bool:
        """Return True if this provider is in its cooldown window."""
        until = self._cb_until.get(provider, 0.0)
        if until and time.time() < until:
            remaining = int(until - time.time())
            logger.info("[CB] %s is tripped — %ds cooldown remaining, routing to OpenRouter", provider, remaining)
            return True
        return False

    def _cb_record_success(self, provider: str):
        """A successful call resets the consecutive-failure counter."""
        if self._cb_consec.get(provider, 0) > 0:
            logger.info("[CB] %s circuit reset after success", provider)
        self._cb_consec[provider] = 0
        self._cb_until[provider] = 0.0

    def _cb_record_failure(self, provider: str):
        """Record a failure; trip the breaker once the threshold is reached.

        The counter is not reset when a cooldown expires, so the first failure
        after a cooldown trips the breaker again immediately.
        """
        count = self._cb_consec.get(provider, 0) + 1
        self._cb_consec[provider] = count
        if count >= self._CB_THRESHOLD:
            self._cb_until[provider] = time.time() + self._CB_COOLDOWN_S
            logger.warning(
                "[CB] %s tripped after %d consecutive failures — "
                "bypassing for %ds, all calls go to OpenRouter",
                provider, count, self._CB_COOLDOWN_S,
            )
            # Best-effort alert through the app's telegram bridge; alerting
            # must never break an LLM call, so failures are only logged.
            try:
                from elpidaapp.telegram_bridge import post_circuit_breaker
                post_circuit_breaker(provider, "trip", count, self._CB_COOLDOWN_S)
            except Exception as e:
                logger.debug("[CB] circuit-breaker alert not sent: %s", e)

    # ----- public API -------------------------------------------------------

    def call(
        self,
        provider: str,
        prompt: str,
        *,
        model: Optional[str] = None,
        max_tokens: Optional[int] = None,
        timeout: Optional[int] = None,
        system_prompt: Optional[str] = None,
    ) -> Optional[str]:
        """
        Send a prompt to *provider* and return the text response, or None.

        Unknown provider names are routed to OpenRouter. A provider whose
        circuit breaker is open is skipped. A failed Perplexity call first
        falls back to HuggingFace. If the result is still None and
        openrouter_failsafe is enabled, the prompt, including *system_prompt*,
        is retried via OpenRouter with its default model.
        """
        provider = provider.lower().strip()
        if provider not in {p.value for p in Provider}:
            logger.warning("Unknown provider '%s', routing to OpenRouter", provider)
            provider = Provider.OPENROUTER.value

        _model = model or DEFAULT_MODELS.get(provider, "")
        _max = max_tokens or self.default_max_tokens
        _timeout = timeout or self.default_timeout

        # OpenRouter is the fallback itself, so its breaker is never consulted.
        circuit_tripped = (
            provider != Provider.OPENROUTER.value
            and self._is_tripped(provider)
        )

        result: Optional[str] = None
        if not circuit_tripped:
            self._rate_limit(provider)
            try:
                handler = self._dispatch.get(Provider(provider))
                if handler:
                    result = handler(
                        provider=provider,
                        prompt=prompt,
                        model=_model,
                        max_tokens=_max,
                        timeout=_timeout,
                        system_prompt=system_prompt,
                    )
            except Exception as e:
                logger.error("%s exception: %s", provider, e)
                self.stats[provider].failures += 1
            if result is not None:
                self._cb_record_success(provider)
            elif self.api_keys.get(provider):
                # Only real call failures count toward the breaker. A missing
                # API key is a configuration problem that a cooldown cannot
                # fix; counting it would re-trip (and re-alert) every window.
                self._cb_record_failure(provider)

        # HuggingFace silent fallback for Perplexity
        if result is None and provider == Provider.PERPLEXITY.value:
            logger.info("Perplexity failed — silent fallback to HuggingFace")
            try:
                result = self._dispatch[Provider.HUGGINGFACE](
                    provider=Provider.HUGGINGFACE.value,
                    prompt=prompt,
                    model=DEFAULT_MODELS[Provider.HUGGINGFACE],
                    max_tokens=_max,
                    timeout=_timeout,
                    system_prompt=system_prompt,
                )
            except Exception as e:
                logger.error("HuggingFace fallback exception: %s", e)

        # OpenRouter failsafe (last resort) — keep the caller's system prompt.
        if result is None and self.openrouter_failsafe and provider != Provider.OPENROUTER.value:
            logger.info("%s failed — trying OpenRouter failsafe", provider)
            result = self._openrouter_failsafe(
                prompt, _max, _timeout, system_prompt=system_prompt,
            )
        return result

    def get_stats(self) -> Dict[str, dict]:
        """Return stats for every provider that was used, as a serialisable dict."""
        return {k: v.to_dict() for k, v in self.stats.items() if v.requests or v.failures}

    def call_with_citations(
        self,
        provider: str,
        prompt: str,
        *,
        model: Optional[str] = None,
        max_tokens: Optional[int] = None,
        timeout: Optional[int] = None,
        system_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Like call(), but returns {"text": str | None, "citations": list}.

        For an OpenAI-compatible provider with a configured key, the request
        is made directly so the top-level ``citations`` field (returned by
        Perplexity) can be captured. If that direct request fails, or the
        provider is not OpenAI-compatible, the call falls back to call().
        When the API supplies no citations, URLs found in the response text
        are used instead (deduplicated, first-seen order).
        """
        import re

        provider = provider.lower().strip()
        if provider not in {p.value for p in Provider}:
            provider = Provider.OPENROUTER.value
        self._rate_limit(provider)
        _model = model or DEFAULT_MODELS.get(provider, "")
        _max = max_tokens or self.default_max_tokens
        _timeout = timeout or self.default_timeout

        result: Optional[str] = None
        citations: list[str] = []
        try:
            key = self.api_keys.get(provider)
            endpoint = self._OPENAI_COMPAT_ENDPOINTS.get(Provider(provider))
            if key and endpoint:
                response = self._post_openai_compat(
                    endpoint, key,
                    model=_model, prompt=prompt, max_tokens=_max,
                    timeout=_timeout, system_prompt=system_prompt,
                )
                if response.status_code == 200:
                    data = response.json()
                    text = data["choices"][0]["message"]["content"]
                    if isinstance(text, str):
                        result = text
                        # Perplexity returns citations at top level; the field
                        # may be missing or null for other providers.
                        citations = list(data.get("citations") or [])
                        self._record_success(
                            provider,
                            self._usage_tokens(data.get("usage"), "total_tokens", text),
                        )
                    else:
                        logger.warning("%s: response had no text content", provider)
                        self.stats[provider].failures += 1
                else:
                    logger.warning("%s: HTTP %d", provider, response.status_code)
                    self.stats[provider].failures += 1
        except Exception as e:
            logger.error("%s citation call exception: %s", provider, e)
            self.stats[provider].failures += 1

        # Fallback to regular call if citation-aware call failed
        if result is None:
            result = self.call(
                provider, prompt,
                model=model, max_tokens=max_tokens,
                timeout=timeout, system_prompt=system_prompt,
            )

        # If no API-level citations, extract URLs from the text
        if not citations and result:
            urls = re.findall(r'https?://[^\s\)\]\}\,\"\'<>]+', result)
            # Strip trailing punctuation, then deduplicate preserving order.
            citations = list(dict.fromkeys(u.rstrip('.,;:!?)') for u in urls))

        return {"text": result, "citations": citations}

    def available_providers(self) -> list[str]:
        """Return the names (plain strings) of providers with an API key configured."""
        return [Provider(p).value for p, key in self.api_keys.items() if key]

    # ----- rate limiter -----------------------------------------------------

    def _rate_limit(self, provider: str):
        """Sleep so calls to *provider* are at least rate_limit_seconds apart."""
        now = time.time()
        last = self._last_call.get(provider, 0)
        wait = self.rate_limit_seconds - (now - last)
        if wait > 0:
            time.sleep(wait)
        self._last_call[provider] = time.time()

    # ----- shared helpers ---------------------------------------------------

    @staticmethod
    def _build_messages(prompt: str, system_prompt: Optional[str]) -> list:
        """Build a chat `messages` list: optional system turn, then the user prompt."""
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages

    @staticmethod
    def _usage_tokens(usage: Any, field_name: str, text: str) -> int:
        """Read a token count from a usage dict; estimate len(text) // 4 if absent.

        Tolerates ``usage`` being missing or null and the field being null,
        both of which would otherwise crash on an otherwise valid response.
        """
        if isinstance(usage, dict):
            value = usage.get(field_name)
            if isinstance(value, (int, float)):
                return int(value)
        return len(text) // 4

    def _record_success(self, provider: str, tokens: int) -> None:
        """Count one successful request and its tokens/approximate cost."""
        stats = self.stats[provider]
        stats.requests += 1
        stats.tokens += tokens
        stats.cost += tokens * COST_PER_TOKEN.get(provider, 0)

    # ----- provider implementations -----------------------------------------

    def _post_openai_compat(
        self, endpoint: str, key: str, *, model: str, prompt: str,
        max_tokens: int, timeout: int, system_prompt: Optional[str],
    ) -> requests.Response:
        """POST a chat/completions request with Bearer auth; returns the raw response."""
        return requests.post(
            endpoint,
            headers={
                "Authorization": f"Bearer {key}",
                "Content-Type": "application/json",
            },
            json={
                "model": model,
                "messages": self._build_messages(prompt, system_prompt),
                "max_tokens": max_tokens,
            },
            timeout=timeout,
        )

    def _call_openai_compat(
        self, *, provider: str, prompt: str, model: str,
        max_tokens: int, timeout: int, system_prompt: Optional[str],
    ) -> Optional[str]:
        """
        Generic handler for all OpenAI-compatible chat/completions APIs.
        Covers: OpenAI, Grok, Mistral, Perplexity, OpenRouter, Groq,
        HuggingFace, DeepSeek, Cerebras, DoubleWord.
        """
        key = self.api_keys.get(provider)
        if not key:
            logger.warning("%s: no API key (%s)", provider, API_KEY_ENV.get(provider, "?"))
            return None
        endpoint = self._OPENAI_COMPAT_ENDPOINTS.get(Provider(provider))
        if not endpoint:
            logger.error("%s: no endpoint configured", provider)
            return None
        try:
            response = self._post_openai_compat(
                endpoint, key,
                model=model, prompt=prompt, max_tokens=max_tokens,
                timeout=timeout, system_prompt=system_prompt,
            )
            if response.status_code != 200:
                logger.warning("%s: HTTP %d — %s", provider, response.status_code, response.text[:200])
                self.stats[provider].failures += 1
                return None
            data = response.json()
            text = data["choices"][0]["message"]["content"]
            if not isinstance(text, str):
                # e.g. content: null — treat as a failure so fallbacks run.
                logger.warning("%s: response had no text content", provider)
                self.stats[provider].failures += 1
                return None
            self._record_success(provider, self._usage_tokens(data.get("usage"), "total_tokens", text))
            return text
        except Exception as e:
            logger.error("%s exception: %s", provider, e)
            self.stats[provider].failures += 1
            return None

    def _post_claude(self, api_key: str, body: Dict[str, Any], timeout: int) -> requests.Response:
        """POST *body* to the Anthropic Messages API with *api_key*."""
        return requests.post(
            "https://api.anthropic.com/v1/messages",
            headers={
                "x-api-key": api_key,
                "anthropic-version": "2023-06-01",
                "Content-Type": "application/json",
            },
            json=body,
            timeout=timeout,
        )

    def _claude_text(self, response: requests.Response) -> str:
        """Extract the first content block's text from a 200 response and record stats."""
        data = response.json()
        text = data["content"][0]["text"]
        self._record_success(Provider.CLAUDE, self._usage_tokens(data.get("usage"), "output_tokens", text))
        return text

    def _call_claude(
        self, *, provider: str, prompt: str, model: str,
        max_tokens: int, timeout: int, system_prompt: Optional[str],
    ) -> Optional[str]:
        """Anthropic Messages API (non-OpenAI-compatible).

        Tries the primary key first, then ANTHROPIC_API_KEY_2 (if set) when a
        key answers 529 Overloaded. If every key is overloaded, retries the
        primary key after 2s/4s/8s waits. Any other error returns None
        immediately, so call() can fall back to OpenRouter.
        """
        key = self.api_keys.get(Provider.CLAUDE)
        if not key:
            logger.warning("Claude: no API key (ANTHROPIC_API_KEY)")
            return None

        body: Dict[str, Any] = {
            "model": model,
            "max_tokens": max_tokens,
            "messages": [{"role": "user", "content": prompt}],
        }
        if system_prompt:
            body["system"] = system_prompt

        stats = self.stats[Provider.CLAUDE]
        keys_to_try = [key] + ([self._claude_key_2] if self._claude_key_2 else [])

        for attempt, api_key in enumerate(keys_to_try):
            try:
                response = self._post_claude(api_key, body, timeout)
                if response.status_code == 200:
                    return self._claude_text(response)
                stats.failures += 1
                if response.status_code != 529:
                    logger.warning("Claude: HTTP %d — %s", response.status_code, response.text[:200])
                    return None
                label = "key_2" if attempt > 0 else "key_1"
                logger.warning("Claude (%s): 529 Overloaded — %s", label, response.text[:120])
            except Exception as e:
                logger.error("Claude exception: %s", e)
                stats.failures += 1
                return None

        # Every key returned 529 — backoff retries with the primary key.
        for delay in (2, 4, 8):
            logger.info("Claude 529 backoff — retrying in %ds", delay)
            time.sleep(delay)
            try:
                response = self._post_claude(key, body, timeout)
                if response.status_code == 200:
                    text = self._claude_text(response)
                    logger.info("Claude 529 recovered after %ds backoff", delay)
                    return text
                stats.failures += 1
                if response.status_code != 529:
                    logger.warning("Claude backoff retry: HTTP %d — %s", response.status_code, response.text[:200])
                    return None
            except Exception as e:
                logger.error("Claude backoff retry exception: %s", e)
                stats.failures += 1
                return None
        logger.warning("Claude: all 529 retries exhausted — handing off to OpenRouter")
        return None

    def _gemini_text(self, data: Dict[str, Any]) -> Optional[str]:
        """Extract text from a generateContent response and record stats.

        Joins every non-thought text part of the first candidate. Returns None
        (and counts a failure) when there is no candidate or no text part.
        """
        stats = self.stats[Provider.GEMINI]
        candidates = data.get("candidates") or []
        if not candidates:
            reason = (data.get("promptFeedback") or {}).get("blockReason", "unknown")
            logger.warning("Gemini: no candidates — blockReason=%s", reason)
            stats.failures += 1
            return None
        candidate = candidates[0]
        finish_reason = candidate.get("finishReason", "")
        parts = (candidate.get("content") or {}).get("parts") or []
        texts = [part["text"] for part in parts if "text" in part and not part.get("thought")]
        if not texts:
            logger.warning("Gemini: candidate has no text — finishReason=%s", finish_reason)
            stats.failures += 1
            return None
        text = "".join(texts)
        self._record_success(Provider.GEMINI, len(text) // 4)
        if finish_reason == "SAFETY":
            logger.warning("Gemini: finishReason=SAFETY — output likely truncated (%d chars)", len(text))
        elif finish_reason == "MAX_TOKENS":
            logger.info("Gemini: finishReason=MAX_TOKENS (%d chars)", len(text))
        return text

    def _call_gemini(
        self, *, provider: str, prompt: str, model: str,
        max_tokens: int, timeout: int, system_prompt: Optional[str],
    ) -> Optional[str]:
        """Google Generative AI API (non-OpenAI-compatible).

        On 429 it retries after 5s/15s/30s before giving up.
        """
        key = self.api_keys.get(Provider.GEMINI)
        if not key:
            logger.warning("Gemini: no API key (GEMINI_API_KEY)")
            return None

        contents = []
        if system_prompt:
            # The system prompt is emulated as a primed user/model exchange.
            contents.append({"parts": [{"text": system_prompt}], "role": "user"})
            contents.append({"parts": [{"text": "Understood."}], "role": "model"})
        contents.append({"parts": [{"text": prompt}], "role": "user"})

        # Gemini safety settings: governance/policy prompts trigger safety
        # filters that truncate output at ~150 chars. BLOCK_NONE disables
        # content filtering for all categories (matches MIND-side fix).
        safety_settings = [
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
        ]
        # Gemini 2.5 Flash is a "thinking" model — internal CoT tokens
        # consume the maxOutputTokens budget. Disable thinking (governance
        # prompts don't need CoT) AND raise maxOutputTokens to at least 8192.
        effective_max = max(max_tokens, 8192)

        url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
        payload = {
            "contents": contents,
            "generationConfig": {
                "maxOutputTokens": effective_max,
                "thinkingConfig": {"thinkingBudget": 0},
            },
            "safetySettings": safety_settings,
        }
        # The key travels in a header, not a ?key= query parameter: requests
        # exception messages include the URL, which is logged on failure.
        headers = {"Content-Type": "application/json", "x-goog-api-key": key}

        def _gemini_post() -> requests.Response:
            return requests.post(url, headers=headers, json=payload, timeout=timeout)

        stats = self.stats[Provider.GEMINI]
        try:
            response = _gemini_post()
            if response.status_code == 200:
                return self._gemini_text(response.json())
            if response.status_code != 429:
                logger.warning("Gemini: HTTP %d — %s", response.status_code, response.text[:200])
                stats.failures += 1
                return None

            logger.warning("Gemini: 429 Resource Exhausted — backoff retry")
            stats.failures += 1
            for delay in (5, 15, 30):
                logger.info("Gemini 429 backoff — retrying in %ds", delay)
                time.sleep(delay)
                try:
                    retry = _gemini_post()
                    if retry.status_code == 200:
                        text = self._gemini_text(retry.json())
                        if text is not None:
                            logger.info("Gemini 429 recovered after %ds backoff", delay)
                        return text
                    if retry.status_code != 429:
                        logger.warning("Gemini backoff retry: HTTP %d", retry.status_code)
                        stats.failures += 1
                        return None
                except Exception as e:
                    logger.error("Gemini backoff retry exception: %s", e)
                    stats.failures += 1
                    return None
            logger.warning("Gemini: all 429 retries exhausted — handing off to OpenRouter")
            return None
        except Exception as e:
            logger.error("Gemini exception: %s", e)
            stats.failures += 1
            return None

    def _call_cohere(
        self, *, provider: str, prompt: str, model: str,
        max_tokens: int, timeout: int, system_prompt: Optional[str],
    ) -> Optional[str]:
        """Cohere v2 Chat API (non-OpenAI-compatible response format)."""
        key = self.api_keys.get(Provider.COHERE)
        if not key:
            logger.warning("Cohere: no API key (COHERE_API_KEY)")
            return None
        try:
            response = requests.post(
                "https://api.cohere.com/v2/chat",
                headers={
                    "Authorization": f"Bearer {key}",
                    "Content-Type": "application/json",
                },
                json={
                    "model": model,
                    "messages": self._build_messages(prompt, system_prompt),
                    "max_tokens": max_tokens,
                },
                timeout=timeout,
            )
            if response.status_code != 200:
                logger.warning("Cohere: HTTP %d — %s", response.status_code, response.text[:200])
                self.stats[Provider.COHERE].failures += 1
                return None
            data = response.json()
            # Cohere v2 nests: data["message"]["content"][0]["text"]
            if "message" not in data or "content" not in data["message"]:
                logger.warning("Cohere: unexpected response format: %s", str(data)[:200])
                self.stats[Provider.COHERE].failures += 1
                return None
            text = data["message"]["content"][0]["text"]
            usage = data.get("usage") or {}
            billed = usage.get("billed_units") if isinstance(usage, dict) else None
            self._record_success(Provider.COHERE, self._usage_tokens(billed, "output_tokens", text))
            return text
        except Exception as e:
            logger.error("Cohere exception: %s", e)
            self.stats[Provider.COHERE].failures += 1
            return None

    # ----- OpenRouter failsafe ----------------------------------------------

    def _openrouter_failsafe(
        self, prompt: str, max_tokens: int, timeout: int,
        system_prompt: Optional[str] = None,
    ) -> Optional[str]:
        """Last-resort call via OpenRouter with its default model.

        *system_prompt* is forwarded so the fallback answers under the same
        instructions as the original request.
        """
        return self._call_openai_compat(
            provider=Provider.OPENROUTER.value,
            prompt=prompt,
            model=DEFAULT_MODELS[Provider.OPENROUTER],
            max_tokens=max_tokens,
            timeout=timeout,
            system_prompt=system_prompt,
        )
# ---------------------------------------------------------------------------
# Module-level convenience
# ---------------------------------------------------------------------------
_default_client: Optional[LLMClient] = None


def get_client(**kwargs) -> LLMClient:
    """Return (or lazily create) the module-level default LLMClient.

    ``kwargs`` are forwarded to ``LLMClient(...)`` only on the call that
    creates the singleton. Later calls return the existing client unchanged.
    If a later call passes kwargs, a warning is logged because those settings
    are ignored.
    """
    global _default_client
    if _default_client is None:
        _default_client = LLMClient(**kwargs)
    elif kwargs:
        logger.warning(
            "get_client(): default client already exists; ignoring kwargs %s",
            sorted(kwargs),
        )
    return _default_client
def call(provider: str, prompt: str, **kwargs) -> Optional[str]:
    """Convenience wrapper: send *prompt* to *provider* through the shared client.

    Keyword arguments are passed unchanged to ``LLMClient.call``.
    """
    client = get_client()
    return client.call(provider, prompt, **kwargs)