"""Witness Engine: owns the two-call pattern. Call A: stream_interrogation — free-form prose, streamed token-by-token. Call B: take_deposition — JSON-constrained, one-shot, validated downstream. Two backends: - WitnessEngine — local llama.cpp via llama-cpp-python (GBNF-constrained). - InferenceProvidersWitnessEngine — hosted call via huggingface_hub.InferenceClient (response_format json_object + schema-in-prompt + parse_and_validate_mass). """ from __future__ import annotations import json import os from dataclasses import dataclass from typing import Iterator, Protocol from .grammar import gbnf_for_dock, parse_and_validate_mass class LlamaLike(Protocol): def create_chat_completion( self, messages, stream: bool = False, grammar=None, max_tokens: int = 512, temperature: float = 0.7, ): ... @dataclass class Turn: role: str # "user" | "assistant" content: str def load_llama(gguf_path: str, n_ctx: int = 4096) -> LlamaLike: """Lazy import so tests don't require llama-cpp-python installed.""" from llama_cpp import Llama # chat_format="chatml" matches Qwen2.5's actual chat template — without this, # llama-cpp-python's auto-detection sometimes drops/mangles tokens at template boundaries. return Llama( model_path=gguf_path, n_ctx=n_ctx, n_gpu_layers=-1, chat_format="chatml", verbose=False, ) class WitnessEngine: def __init__(self, llama: LlamaLike): self.llm = llama @staticmethod def _messages(system_prompt: str, turns: list[Turn], extra_user: str | None = None): msgs = [{"role": "system", "content": system_prompt}] for t in turns: msgs.append({"role": t.role, "content": t.content}) if extra_user is not None: msgs.append({"role": "user", "content": extra_user}) return msgs def stream_interrogation(self, system_prompt: str, turns: list[Turn]) -> Iterator[str]: msgs = self._messages(system_prompt, turns) stream = self.llm.create_chat_completion( messages=msgs, stream=True, temperature=0.6, top_p=0.9, repeat_penalty=1.05, max_tokens=256, ) for chunk in stream: delta = chunk.get("choices", [{}])[0].get("delta", {}) piece = delta.get("content") if piece: yield piece def take_deposition( self, system_prompt: str, turns: list[Turn], dock: list[str], ) -> tuple[dict[frozenset[str], float], str]: prompt = ( "Now, formally, state your belief about which of these suspects is responsible: " + ", ".join(dock) + ". Respond ONLY with a JSON object matching the required schema, " "where masses are non-negative and sum to 1." ) grammar_str = gbnf_for_dock(dock) try: from llama_cpp import LlamaGrammar grammar_obj = LlamaGrammar.from_string(grammar_str) except Exception: grammar_obj = grammar_str # tests pass a fake; structure-only msgs = self._messages(system_prompt, turns, extra_user=prompt) resp = self.llm.create_chat_completion( messages=msgs, stream=False, grammar=grammar_obj, temperature=0.2, max_tokens=400, ) raw = resp["choices"][0]["message"]["content"] return parse_and_validate_mass(raw, dock) DEFAULT_PROVIDER_MODEL = os.environ.get( "WITNESS_MODEL", "Qwen/Qwen2.5-7B-Instruct" ) def _deposition_schema_hint(dock: list[str]) -> str: sample = { "focal_elements": [ {"suspects": [dock[0]], "mass": 0.7}, {"suspects": list(dock), "mass": 0.3}, ], "one_line_summary": "One sentence summarising what you saw.", } return ( 'Respond with ONLY a JSON object of the form ' + json.dumps(sample) + '. Rules: "suspects" must be a non-empty subset of ' + json.dumps(dock) + '. "mass" values are in [0,1] and sum to ~1. The full-Dock subset ' '(all suspects) represents "I am not sure".' ) class InferenceProvidersWitnessEngine: """Witness engine backed by HF Inference Providers (OpenAI-compatible).""" def __init__(self, model: str = DEFAULT_PROVIDER_MODEL, token: str | None = None): from huggingface_hub import InferenceClient self.model = model self.client = InferenceClient( provider="auto", token=token or os.environ.get("HF_TOKEN"), ) @staticmethod def _messages(system_prompt: str, turns: list[Turn], extra_user: str | None = None): msgs = [{"role": "system", "content": system_prompt}] for t in turns: msgs.append({"role": t.role, "content": t.content}) if extra_user is not None: msgs.append({"role": "user", "content": extra_user}) return msgs def stream_interrogation(self, system_prompt: str, turns: list[Turn]) -> Iterator[str]: msgs = self._messages(system_prompt, turns) stream = self.client.chat.completions.create( model=self.model, messages=msgs, stream=True, temperature=0.6, top_p=0.9, max_tokens=256, ) for chunk in stream: delta = getattr(chunk.choices[0], "delta", None) piece = getattr(delta, "content", None) if delta else None if piece: yield piece def take_deposition( self, system_prompt: str, turns: list[Turn], dock: list[str], ) -> tuple[dict[frozenset[str], float], str]: schema_hint = _deposition_schema_hint(dock) prompt = ( "Formally, state your belief about which of these suspects is responsible: " + ", ".join(dock) + ". " + schema_hint ) msgs = self._messages(system_prompt, turns, extra_user=prompt) last_err: Exception | None = None for attempt in range(2): try: resp = self.client.chat.completions.create( model=self.model, messages=msgs, stream=False, temperature=0.2, max_tokens=400, response_format={"type": "json_object"}, ) raw = resp.choices[0].message.content or "" return parse_and_validate_mass(raw, dock) except Exception as e: last_err = e msgs.append({"role": "user", "content": "Your previous response was not valid. " + schema_hint}) raise RuntimeError(f"deposition failed: {last_err}")