Spaces:
Build error
fix: bundle ast-grep/difftastic/scc + generic tool_result interceptor framework
Browse filesWhat this does, in plain terms:
Headroom's proxy now ships with three CLI tools (ast-grep, difftastic,
scc) that it can use to shrink tool_result payloads before they reach
the model. The goal is simple: when Claude Code (or Codex, Aider, etc.)
asks the model to reason about a big file or diff, we swap the verbose
output for a compact, same-meaning version. Fewer tokens per turn, same
answers, lower bill.
Today a single interceptor is wired: ast-grep on Read. When an agent
reads a large code file, the proxy replaces the file body with an
outline of its top-level functions/classes plus docstrings. In live
tests that cut prompt tokens 74–76% on both OpenAI and Anthropic,
same answer either way.
How it works:
- `pip install headroom-ai` now installs ast-grep via a PyPI wheel
(core dep). difftastic and scc are fetched once at proxy startup
from pinned upstream GitHub releases and cached per-user.
- A generic registry (`headroom/proxy/interceptors/`) lets us add more
tool-aware rewrites in one file each: declare `matches()` and
`transform()`, call `register()`, done. No proxy or metrics plumbing
per tool.
- Safety rails built in: pass-through when a Read specifies a line
range; second Read of the same file in a conversation returns full
content (progressive disclosure); any failing interceptor logs and
skips, never crashes a request.
Opt-in for now:
- Off by default while this ships. Turn on with
`headroom proxy --intercept-tool-results` or
`HEADROOM_INTERCEPT_ENABLED=1`, so we can measure before flipping
defaults.
What users see after turning it on:
- First `headroom wrap claude` boot is ~5s longer (binaries fetched).
Every subsequent run is cache-only.
- Existing `transforms_applied` field in metrics gets entries like
`interceptor:ast-grep`, so savings show up in current dashboards
and HTML reports with no UI change.
Other housekeeping in this PR:
- uv.lock moved to .gitignore — regenerated locally per environment.
- 35 unit + integration tests, ruff + mypy clean.
- Dead-code audit done: removed `binaries.run()`, `needs_filesystem`
plumbing, unused `_kind` tuple elements, unused `tool_output`
parameter, and the never-set HEADROOM_SKIP_TOOLS_BOOTSTRAP env.
- .gitignore +3 -0
- headroom/binaries.py +494 -0
- headroom/cli/main.py +1 -0
- headroom/cli/proxy.py +21 -0
- headroom/cli/tools.py +226 -0
- headroom/proxy/interceptors/__init__.py +32 -0
- headroom/proxy/interceptors/astgrep.py +246 -0
- headroom/proxy/interceptors/base.py +261 -0
- headroom/tools.json +89 -0
- headroom/transforms/pipeline.py +12 -0
- pyproject.toml +1 -0
- tests/test_binaries.py +281 -0
- tests/test_bundled_tools_savings.py +367 -0
- tests/test_tool_result_interceptors.py +400 -0
|
@@ -213,3 +213,6 @@ headroom-managed/
|
|
| 213 |
|
| 214 |
# Release metadata artifact
|
| 215 |
.releaseetadata
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
# Release metadata artifact
|
| 215 |
.releaseetadata
|
| 216 |
+
|
| 217 |
+
# uv lockfile: regenerated locally; not committed
|
| 218 |
+
uv.lock
|
|
@@ -0,0 +1,494 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Fetcher for bundled CLI tool binaries.
|
| 2 |
+
|
| 3 |
+
`pip install headroom-ai` pulls `ast-grep-cli` as a proper PyPI binary wheel
|
| 4 |
+
(core dependency), so ast-grep is always on PATH. The other two high-value
|
| 5 |
+
tools — `difft` (difftastic) and `scc` — are fetched from pinned upstream
|
| 6 |
+
GitHub releases at proxy startup, verified, cached per-user, and exec'd.
|
| 7 |
+
|
| 8 |
+
Supported platforms: linux (glibc + musl) x86_64/aarch64, macOS x86_64/arm64,
|
| 9 |
+
Windows x86_64. Unsupported platforms raise PlatformNotSupported; callers in
|
| 10 |
+
the compression pipeline should fall back to their non-accelerated path.
|
| 11 |
+
|
| 12 |
+
Env vars:
|
| 13 |
+
HEADROOM_BINARIES_MIRROR base URL that replaces https://github.com
|
| 14 |
+
HEADROOM_BINARIES_CACHE override cache dir
|
| 15 |
+
HEADROOM_BINARIES_OFFLINE if set, never reach the network
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import functools
|
| 21 |
+
import hashlib
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
import platform
|
| 25 |
+
import shutil
|
| 26 |
+
import subprocess
|
| 27 |
+
import sys
|
| 28 |
+
import tarfile
|
| 29 |
+
import tempfile
|
| 30 |
+
import urllib.error
|
| 31 |
+
import urllib.request
|
| 32 |
+
import zipfile
|
| 33 |
+
from dataclasses import dataclass
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
from typing import Any
|
| 36 |
+
|
| 37 |
+
__all__ = [
|
| 38 |
+
"BinaryError",
|
| 39 |
+
"BinaryFetchError",
|
| 40 |
+
"PlatformNotSupported",
|
| 41 |
+
"Sha256Mismatch",
|
| 42 |
+
"OfflineError",
|
| 43 |
+
"PlatformKey",
|
| 44 |
+
"detect_platform",
|
| 45 |
+
"cache_dir",
|
| 46 |
+
"resolve",
|
| 47 |
+
"which",
|
| 48 |
+
"status",
|
| 49 |
+
"ensure_tools",
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# ---------- Exceptions ---------------------------------------------------- #
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class BinaryError(Exception):
|
| 57 |
+
"""Base exception for the binaries module."""
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class BinaryFetchError(BinaryError):
|
| 61 |
+
"""Raised when a download fails or an archive cannot be extracted."""
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class PlatformNotSupported(BinaryError):
|
| 65 |
+
"""Raised when the current OS/arch is not covered by a tool's registry."""
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Sha256Mismatch(BinaryError):
|
| 69 |
+
"""Raised when a downloaded asset's SHA256 does not match the pin."""
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class OfflineError(BinaryError):
|
| 73 |
+
"""Raised when a network fetch is required but HEADROOM_BINARIES_OFFLINE is set."""
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ---------- Platform detection -------------------------------------------- #
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@dataclass(frozen=True)
|
| 80 |
+
class PlatformKey:
|
| 81 |
+
os: str # "linux" | "darwin" | "windows"
|
| 82 |
+
arch: str # "x86_64" | "aarch64"
|
| 83 |
+
libc: str # "gnu" | "musl" | "n/a"
|
| 84 |
+
|
| 85 |
+
def key(self) -> str:
|
| 86 |
+
# Compact form used as registry lookup key and cache subdirectory.
|
| 87 |
+
if self.os == "linux":
|
| 88 |
+
return f"{self.os}-{self.arch}-{self.libc}"
|
| 89 |
+
return f"{self.os}-{self.arch}"
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _machine_to_arch(machine: str) -> str:
|
| 93 |
+
m = machine.lower()
|
| 94 |
+
if m in ("x86_64", "amd64"):
|
| 95 |
+
return "x86_64"
|
| 96 |
+
if m in ("aarch64", "arm64"):
|
| 97 |
+
return "aarch64"
|
| 98 |
+
return m # return as-is; lookup will fail cleanly with PlatformNotSupported
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _is_musl() -> bool:
|
| 102 |
+
# Best-effort musl detection on Linux. Never raises.
|
| 103 |
+
try:
|
| 104 |
+
out = subprocess.run(
|
| 105 |
+
["ldd", "--version"],
|
| 106 |
+
capture_output=True,
|
| 107 |
+
text=True,
|
| 108 |
+
timeout=2,
|
| 109 |
+
check=False,
|
| 110 |
+
)
|
| 111 |
+
return "musl" in (out.stdout + out.stderr).lower()
|
| 112 |
+
except (FileNotFoundError, subprocess.TimeoutExpired, OSError):
|
| 113 |
+
return False
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@functools.lru_cache(maxsize=1)
|
| 117 |
+
def detect_platform() -> PlatformKey:
|
| 118 |
+
arch = _machine_to_arch(platform.machine())
|
| 119 |
+
if sys.platform.startswith("linux"):
|
| 120 |
+
return PlatformKey("linux", arch, "musl" if _is_musl() else "gnu")
|
| 121 |
+
if sys.platform == "darwin":
|
| 122 |
+
return PlatformKey("darwin", arch, "n/a")
|
| 123 |
+
if sys.platform.startswith("win"):
|
| 124 |
+
return PlatformKey("windows", arch, "n/a")
|
| 125 |
+
return PlatformKey(sys.platform, arch, "n/a")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# ---------- Cache dir ----------------------------------------------------- #
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def cache_dir() -> Path:
|
| 132 |
+
override = os.environ.get("HEADROOM_BINARIES_CACHE")
|
| 133 |
+
if override:
|
| 134 |
+
return Path(override).expanduser().resolve()
|
| 135 |
+
if sys.platform.startswith("win"):
|
| 136 |
+
base = os.environ.get("LOCALAPPDATA") or str(Path.home() / "AppData" / "Local")
|
| 137 |
+
return Path(base) / "headroom" / "bin"
|
| 138 |
+
if sys.platform == "darwin":
|
| 139 |
+
return Path.home() / "Library" / "Caches" / "headroom" / "bin"
|
| 140 |
+
xdg = os.environ.get("XDG_CACHE_HOME")
|
| 141 |
+
base = Path(xdg) if xdg else Path.home() / ".cache"
|
| 142 |
+
return base / "headroom" / "bin"
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# ---------- Registry ------------------------------------------------------ #
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
_REGISTRY_PATH = Path(__file__).parent / "tools.json"
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
@functools.lru_cache(maxsize=1)
|
| 152 |
+
def _registry() -> dict[str, Any]:
|
| 153 |
+
with _REGISTRY_PATH.open("r", encoding="utf-8") as f:
|
| 154 |
+
data: dict[str, Any] = json.load(f)
|
| 155 |
+
return data
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _tool_entry(tool: str) -> dict[str, Any]:
|
| 159 |
+
reg = _registry()
|
| 160 |
+
tools: dict[str, Any] = reg.get("tools", {})
|
| 161 |
+
if tool not in tools:
|
| 162 |
+
raise KeyError(f"unknown tool {tool!r}; known: {sorted(tools)}")
|
| 163 |
+
entry: dict[str, Any] = tools[tool]
|
| 164 |
+
return entry
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _is_pypi_tool(tool: str) -> bool:
|
| 168 |
+
entry = _tool_entry(tool)
|
| 169 |
+
return entry.get("version") == "pypi" or not entry.get("assets")
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _asset_for_platform(tool: str, plat: PlatformKey) -> dict[str, Any]:
|
| 173 |
+
entry = _tool_entry(tool)
|
| 174 |
+
if _is_pypi_tool(tool):
|
| 175 |
+
raise PlatformNotSupported(
|
| 176 |
+
f"{tool}: distributed via PyPI only; `pip install headroom-ai` "
|
| 177 |
+
f"should have placed `{entry.get('binary', tool)}` on PATH."
|
| 178 |
+
)
|
| 179 |
+
assets: dict[str, Any] = entry.get("assets", {})
|
| 180 |
+
asset: dict[str, Any] | None = assets.get(plat.key())
|
| 181 |
+
if asset is None:
|
| 182 |
+
supported = sorted(assets.keys())
|
| 183 |
+
raise PlatformNotSupported(
|
| 184 |
+
f"{tool}: no prebuilt binary for {plat.key()}; supported: {supported}"
|
| 185 |
+
)
|
| 186 |
+
return asset
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def _mirror_url(url: str) -> str:
|
| 190 |
+
mirror = os.environ.get("HEADROOM_BINARIES_MIRROR")
|
| 191 |
+
if not mirror:
|
| 192 |
+
return url
|
| 193 |
+
# Only substitute the github.com host so that paths remain intact.
|
| 194 |
+
for prefix in ("https://github.com", "https://objects.githubusercontent.com"):
|
| 195 |
+
if url.startswith(prefix):
|
| 196 |
+
return mirror.rstrip("/") + url[len(prefix) :]
|
| 197 |
+
return url
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
# ---------- Download + verify --------------------------------------------- #
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def _download(url: str, dest: Path, *, progress: bool = True) -> None:
|
| 204 |
+
if os.environ.get("HEADROOM_BINARIES_OFFLINE"):
|
| 205 |
+
raise OfflineError(f"offline mode (HEADROOM_BINARIES_OFFLINE=1) but fetch required: {url}")
|
| 206 |
+
dest.parent.mkdir(parents=True, exist_ok=True)
|
| 207 |
+
final_url = _mirror_url(url)
|
| 208 |
+
req = urllib.request.Request(final_url, headers={"User-Agent": "headroom-binaries/1"})
|
| 209 |
+
try:
|
| 210 |
+
with urllib.request.urlopen(req, timeout=60) as resp: # noqa: S310 (https)
|
| 211 |
+
total = int(resp.headers.get("Content-Length") or 0)
|
| 212 |
+
_stream_to(resp, dest, total, label=dest.name, show_progress=progress)
|
| 213 |
+
except urllib.error.URLError as e:
|
| 214 |
+
raise BinaryFetchError(f"failed to download {final_url}: {e}") from e
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def _stream_to(src: Any, dest: Path, total: int, *, label: str, show_progress: bool) -> None:
|
| 218 |
+
# Rich progress if available and stderr is a tty; otherwise silent chunked copy.
|
| 219 |
+
try:
|
| 220 |
+
if show_progress and sys.stderr.isatty():
|
| 221 |
+
from rich.progress import (
|
| 222 |
+
BarColumn,
|
| 223 |
+
DownloadColumn,
|
| 224 |
+
Progress,
|
| 225 |
+
TextColumn,
|
| 226 |
+
TimeRemainingColumn,
|
| 227 |
+
TransferSpeedColumn,
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
with Progress(
|
| 231 |
+
TextColumn("[bold blue]{task.description}"),
|
| 232 |
+
BarColumn(),
|
| 233 |
+
DownloadColumn(),
|
| 234 |
+
TransferSpeedColumn(),
|
| 235 |
+
TimeRemainingColumn(),
|
| 236 |
+
) as prog:
|
| 237 |
+
task = prog.add_task(label, total=total or None)
|
| 238 |
+
with dest.open("wb") as out:
|
| 239 |
+
while chunk := src.read(1024 * 64):
|
| 240 |
+
out.write(chunk)
|
| 241 |
+
prog.update(task, advance=len(chunk))
|
| 242 |
+
return
|
| 243 |
+
except ImportError:
|
| 244 |
+
pass
|
| 245 |
+
with dest.open("wb") as out:
|
| 246 |
+
shutil.copyfileobj(src, out)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def _sha256_file(path: Path) -> str:
|
| 250 |
+
h = hashlib.sha256()
|
| 251 |
+
with path.open("rb") as f:
|
| 252 |
+
for chunk in iter(lambda: f.read(1024 * 64), b""):
|
| 253 |
+
h.update(chunk)
|
| 254 |
+
return h.hexdigest()
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def _verify_sha256(path: Path, expected: str | None) -> None:
|
| 258 |
+
if not expected:
|
| 259 |
+
# Upstream release not SHA-pinned in registry. We trusted HTTPS + the
|
| 260 |
+
# GitHub CDN for the download; log nothing here — `doctor` surfaces.
|
| 261 |
+
return
|
| 262 |
+
got = _sha256_file(path)
|
| 263 |
+
if got.lower() != expected.lower():
|
| 264 |
+
path.unlink(missing_ok=True)
|
| 265 |
+
raise Sha256Mismatch(f"sha256 mismatch for {path.name}: expected {expected}, got {got}")
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
# ---------- Archive extraction ------------------------------------------- #
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def _extract(archive: Path, member: str, dest: Path) -> None:
|
| 272 |
+
"""Extract `member` from archive into `dest` (single-file binary)."""
|
| 273 |
+
dest.parent.mkdir(parents=True, exist_ok=True)
|
| 274 |
+
name = archive.name.lower()
|
| 275 |
+
try:
|
| 276 |
+
if name.endswith(".tar.gz") or name.endswith(".tgz"):
|
| 277 |
+
with tarfile.open(archive, "r:gz") as tf:
|
| 278 |
+
_extract_member_from_tar(tf, member, dest)
|
| 279 |
+
elif name.endswith(".zip"):
|
| 280 |
+
with zipfile.ZipFile(archive) as zf:
|
| 281 |
+
_extract_member_from_zip(zf, member, dest)
|
| 282 |
+
elif name.endswith(".gz") and "." not in name[:-3]:
|
| 283 |
+
# bare .gz of a single binary
|
| 284 |
+
import gzip
|
| 285 |
+
|
| 286 |
+
with gzip.open(archive, "rb") as gz, dest.open("wb") as out:
|
| 287 |
+
shutil.copyfileobj(gz, out)
|
| 288 |
+
else:
|
| 289 |
+
# Not an archive — treat the downloaded file itself as the binary.
|
| 290 |
+
shutil.copy2(archive, dest)
|
| 291 |
+
except (tarfile.TarError, zipfile.BadZipFile, OSError) as e:
|
| 292 |
+
raise BinaryFetchError(f"failed to extract {archive.name}: {e}") from e
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _extract_member_from_tar(tf: tarfile.TarFile, member: str, dest: Path) -> None:
|
| 296 |
+
# Match by basename so that registries can specify "difft" even though the
|
| 297 |
+
# upstream tar may include a leading directory like "difft-0.64.0/difft".
|
| 298 |
+
wanted = member.lower()
|
| 299 |
+
for m in tf.getmembers():
|
| 300 |
+
base = m.name.rsplit("/", 1)[-1].lower()
|
| 301 |
+
if base == wanted and m.isfile():
|
| 302 |
+
extracted = tf.extractfile(m)
|
| 303 |
+
if extracted is None:
|
| 304 |
+
continue
|
| 305 |
+
with dest.open("wb") as out:
|
| 306 |
+
shutil.copyfileobj(extracted, out)
|
| 307 |
+
return
|
| 308 |
+
raise BinaryFetchError(f"archive did not contain expected member {member!r}")
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def _extract_member_from_zip(zf: zipfile.ZipFile, member: str, dest: Path) -> None:
|
| 312 |
+
wanted = member.lower()
|
| 313 |
+
for info in zf.infolist():
|
| 314 |
+
base = info.filename.rsplit("/", 1)[-1].lower()
|
| 315 |
+
if base == wanted and not info.is_dir():
|
| 316 |
+
with zf.open(info) as src, dest.open("wb") as out:
|
| 317 |
+
shutil.copyfileobj(src, out)
|
| 318 |
+
return
|
| 319 |
+
raise BinaryFetchError(f"archive did not contain expected member {member!r}")
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
# ---------- Public API ---------------------------------------------------- #
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def _binary_name(tool: str, plat: PlatformKey) -> str:
|
| 326 |
+
entry = _tool_entry(tool)
|
| 327 |
+
base = entry.get("binary", tool)
|
| 328 |
+
return f"{base}.exe" if plat.os == "windows" else base
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def _cached_path(tool: str, version: str, plat: PlatformKey) -> Path:
|
| 332 |
+
return cache_dir() / f"{tool}-{version}-{plat.key()}" / _binary_name(tool, plat)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def _in_registry(tool: str) -> bool:
|
| 336 |
+
return tool in _registry().get("tools", {})
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def _path_lookup(tool: str) -> Path | None:
|
| 340 |
+
"""Find `tool` on PATH or in this interpreter's Scripts/bin directory.
|
| 341 |
+
|
| 342 |
+
PyPI binary wheels (e.g. ast-grep-cli) install their console scripts into
|
| 343 |
+
sys.prefix/bin (or sys.prefix/Scripts on Windows). That directory is on
|
| 344 |
+
PATH when the venv is activated, but subprocesses started by a non-active
|
| 345 |
+
interpreter can miss it, so we check it explicitly as a fallback.
|
| 346 |
+
"""
|
| 347 |
+
candidates = [tool]
|
| 348 |
+
if _in_registry(tool):
|
| 349 |
+
alias = _tool_entry(tool).get("binary")
|
| 350 |
+
if alias and alias != tool:
|
| 351 |
+
candidates.append(alias)
|
| 352 |
+
|
| 353 |
+
for name in candidates:
|
| 354 |
+
found = shutil.which(name)
|
| 355 |
+
if found:
|
| 356 |
+
return Path(found)
|
| 357 |
+
|
| 358 |
+
scripts_dir = Path(sys.prefix) / ("Scripts" if sys.platform.startswith("win") else "bin")
|
| 359 |
+
for name in candidates:
|
| 360 |
+
exe = scripts_dir / (name + (".exe" if sys.platform.startswith("win") else ""))
|
| 361 |
+
if exe.exists():
|
| 362 |
+
return exe
|
| 363 |
+
return None
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def which(tool: str) -> Path | None:
|
| 367 |
+
"""Return a path to `tool` if it is on PATH or already cached, else None.
|
| 368 |
+
|
| 369 |
+
Never triggers a network fetch. Callers that want the tool to be installed
|
| 370 |
+
on demand should use `resolve()` instead.
|
| 371 |
+
"""
|
| 372 |
+
on_path = _path_lookup(tool)
|
| 373 |
+
if on_path:
|
| 374 |
+
return on_path
|
| 375 |
+
if not _in_registry(tool):
|
| 376 |
+
return None
|
| 377 |
+
try:
|
| 378 |
+
plat = detect_platform()
|
| 379 |
+
_asset_for_platform(tool, plat) # raises if unsupported
|
| 380 |
+
except PlatformNotSupported:
|
| 381 |
+
return None
|
| 382 |
+
path = _cached_path(tool, _tool_entry(tool)["version"], plat)
|
| 383 |
+
return path if path.exists() else None
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def resolve(tool: str) -> Path:
|
| 387 |
+
"""Return a path to the tool binary, fetching it on first use.
|
| 388 |
+
|
| 389 |
+
Raises PlatformNotSupported if the tool is unavailable on this platform,
|
| 390 |
+
OfflineError if a fetch is required but HEADROOM_BINARIES_OFFLINE is set,
|
| 391 |
+
Sha256Mismatch if verification fails, BinaryFetchError on other IO errors.
|
| 392 |
+
"""
|
| 393 |
+
on_path = _path_lookup(tool)
|
| 394 |
+
if on_path:
|
| 395 |
+
return on_path
|
| 396 |
+
if not _in_registry(tool):
|
| 397 |
+
raise KeyError(f"unknown tool {tool!r}")
|
| 398 |
+
|
| 399 |
+
plat = detect_platform()
|
| 400 |
+
entry = _tool_entry(tool)
|
| 401 |
+
asset = _asset_for_platform(tool, plat)
|
| 402 |
+
version = entry["version"]
|
| 403 |
+
binary_path = _cached_path(tool, version, plat)
|
| 404 |
+
if binary_path.exists():
|
| 405 |
+
return binary_path
|
| 406 |
+
|
| 407 |
+
# Not cached — fetch, verify, extract.
|
| 408 |
+
url = asset["url"]
|
| 409 |
+
sha256 = asset.get("sha256")
|
| 410 |
+
member = asset.get("member", _binary_name(tool, plat))
|
| 411 |
+
|
| 412 |
+
with tempfile.TemporaryDirectory(prefix="headroom-fetch-") as tmp:
|
| 413 |
+
tmp_dir = Path(tmp)
|
| 414 |
+
download_path = tmp_dir / Path(url).name
|
| 415 |
+
_download(url, download_path)
|
| 416 |
+
_verify_sha256(download_path, sha256)
|
| 417 |
+
staging = tmp_dir / "out"
|
| 418 |
+
_extract(download_path, member, staging)
|
| 419 |
+
binary_path.parent.mkdir(parents=True, exist_ok=True)
|
| 420 |
+
# Atomic-ish move: write to sibling then rename.
|
| 421 |
+
tmp_final = binary_path.with_suffix(binary_path.suffix + ".partial")
|
| 422 |
+
shutil.move(str(staging), tmp_final)
|
| 423 |
+
try:
|
| 424 |
+
tmp_final.chmod(0o755)
|
| 425 |
+
except OSError:
|
| 426 |
+
pass # Windows or restricted FS — .exe is already executable
|
| 427 |
+
os.replace(tmp_final, binary_path)
|
| 428 |
+
return binary_path
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def ensure_tools(quiet: bool = False) -> dict[str, Path | None]:
|
| 432 |
+
"""Install every tool in the registry if missing. Safe to call repeatedly.
|
| 433 |
+
|
| 434 |
+
Called at proxy startup and on first `headroom` CLI invocation so that no
|
| 435 |
+
tool fetch ever happens inside a live request. Skips tools that are on
|
| 436 |
+
PATH, already cached, or distributed via PyPI-only (ast-grep).
|
| 437 |
+
|
| 438 |
+
Returns a map of tool_name -> resolved Path (or None if unsupported).
|
| 439 |
+
Never raises; unsupported platforms or offline errors are logged via
|
| 440 |
+
stderr and the tool is skipped.
|
| 441 |
+
"""
|
| 442 |
+
out: dict[str, Path | None] = {}
|
| 443 |
+
for name in _registry().get("tools", {}):
|
| 444 |
+
try:
|
| 445 |
+
if _is_pypi_tool(name):
|
| 446 |
+
# ast-grep ships via pip; just record whether it's on PATH.
|
| 447 |
+
out[name] = _path_lookup(name)
|
| 448 |
+
continue
|
| 449 |
+
out[name] = resolve(name)
|
| 450 |
+
except (PlatformNotSupported, OfflineError, BinaryFetchError, Sha256Mismatch) as e:
|
| 451 |
+
out[name] = None
|
| 452 |
+
if not quiet:
|
| 453 |
+
print(f"headroom: skipping {name}: {e}", file=sys.stderr)
|
| 454 |
+
return out
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def status() -> list[dict[str, Any]]:
|
| 458 |
+
"""Return a list of status dicts for every tool in the registry.
|
| 459 |
+
|
| 460 |
+
Used by `headroom tools doctor`. Never fetches — only inspects.
|
| 461 |
+
"""
|
| 462 |
+
out: list[dict[str, Any]] = []
|
| 463 |
+
plat = detect_platform()
|
| 464 |
+
for name, entry in _registry().get("tools", {}).items():
|
| 465 |
+
row: dict[str, Any] = {
|
| 466 |
+
"tool": name,
|
| 467 |
+
"version": entry.get("version"),
|
| 468 |
+
"platform": plat.key(),
|
| 469 |
+
"source": entry.get("source", "fetched"),
|
| 470 |
+
"path": None,
|
| 471 |
+
"state": "missing",
|
| 472 |
+
}
|
| 473 |
+
# Honor PATH.
|
| 474 |
+
on_path = shutil.which(name) or (
|
| 475 |
+
shutil.which(entry["binary"]) if entry.get("binary") else None
|
| 476 |
+
)
|
| 477 |
+
if on_path:
|
| 478 |
+
row["path"] = on_path
|
| 479 |
+
row["state"] = "on-path"
|
| 480 |
+
out.append(row)
|
| 481 |
+
continue
|
| 482 |
+
try:
|
| 483 |
+
_asset_for_platform(name, plat)
|
| 484 |
+
except PlatformNotSupported as e:
|
| 485 |
+
row["state"] = "unsupported-platform"
|
| 486 |
+
row["detail"] = str(e)
|
| 487 |
+
out.append(row)
|
| 488 |
+
continue
|
| 489 |
+
cached = _cached_path(name, entry["version"], plat)
|
| 490 |
+
if cached.exists():
|
| 491 |
+
row["path"] = str(cached)
|
| 492 |
+
row["state"] = "cached"
|
| 493 |
+
out.append(row)
|
| 494 |
+
return out
|
|
@@ -42,6 +42,7 @@ def _register_commands() -> None:
|
|
| 42 |
mcp, # noqa: F401
|
| 43 |
perf, # noqa: F401
|
| 44 |
proxy, # noqa: F401
|
|
|
|
| 45 |
wrap, # noqa: F401
|
| 46 |
)
|
| 47 |
|
|
|
|
| 42 |
mcp, # noqa: F401
|
| 43 |
perf, # noqa: F401
|
| 44 |
proxy, # noqa: F401
|
| 45 |
+
tools, # noqa: F401
|
| 46 |
wrap, # noqa: F401
|
| 47 |
)
|
| 48 |
|
|
@@ -44,6 +44,14 @@ from .main import main
|
|
| 44 |
"Legacy aliases are accepted. Default: token. Env: HEADROOM_MODE"
|
| 45 |
),
|
| 46 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
@click.option("--no-optimize", is_flag=True, help="Disable optimization (passthrough mode)")
|
| 48 |
@click.option("--no-cache", is_flag=True, help="Disable semantic caching")
|
| 49 |
@click.option("--no-rate-limit", is_flag=True, help="Disable rate limiting")
|
|
@@ -211,6 +219,7 @@ def proxy(
|
|
| 211 |
mode: str | None,
|
| 212 |
host: str,
|
| 213 |
port: int,
|
|
|
|
| 214 |
no_optimize: bool,
|
| 215 |
no_cache: bool,
|
| 216 |
no_rate_limit: bool,
|
|
@@ -267,6 +276,18 @@ def proxy(
|
|
| 267 |
click.echo(f"Details: {e}")
|
| 268 |
raise SystemExit(1) from None
|
| 269 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
# Resolve API URL overrides: CLI flag > env var > None
|
| 271 |
effective_anthropic_api_url = anthropic_api_url or os.environ.get("ANTHROPIC_TARGET_API_URL")
|
| 272 |
effective_openai_api_url = openai_api_url or os.environ.get("OPENAI_TARGET_API_URL")
|
|
|
|
| 44 |
"Legacy aliases are accepted. Default: token. Env: HEADROOM_MODE"
|
| 45 |
),
|
| 46 |
)
|
| 47 |
+
@click.option(
|
| 48 |
+
"--intercept-tool-results",
|
| 49 |
+
is_flag=True,
|
| 50 |
+
help=(
|
| 51 |
+
"Opt in to tool_result interceptors (ast-grep Read outliner, etc.). "
|
| 52 |
+
"Off by default while this feature ships."
|
| 53 |
+
),
|
| 54 |
+
)
|
| 55 |
@click.option("--no-optimize", is_flag=True, help="Disable optimization (passthrough mode)")
|
| 56 |
@click.option("--no-cache", is_flag=True, help="Disable semantic caching")
|
| 57 |
@click.option("--no-rate-limit", is_flag=True, help="Disable rate limiting")
|
|
|
|
| 219 |
mode: str | None,
|
| 220 |
host: str,
|
| 221 |
port: int,
|
| 222 |
+
intercept_tool_results: bool,
|
| 223 |
no_optimize: bool,
|
| 224 |
no_cache: bool,
|
| 225 |
no_rate_limit: bool,
|
|
|
|
| 276 |
click.echo(f"Details: {e}")
|
| 277 |
raise SystemExit(1) from None
|
| 278 |
|
| 279 |
+
# Ensure bundled CLI tools (ast-grep, difftastic, scc) are present before
|
| 280 |
+
# the proxy starts accepting traffic. Never happens inside a live request —
|
| 281 |
+
# tools are downloaded once at startup if missing, then cached per-user.
|
| 282 |
+
from headroom.binaries import ensure_tools
|
| 283 |
+
|
| 284 |
+
ensure_tools()
|
| 285 |
+
|
| 286 |
+
# Opt-in: turn on tool_result interceptors (ast-grep Read outline, etc.).
|
| 287 |
+
# The TransformPipeline reads this env var at construction time.
|
| 288 |
+
if intercept_tool_results:
|
| 289 |
+
os.environ["HEADROOM_INTERCEPT_ENABLED"] = "1"
|
| 290 |
+
|
| 291 |
# Resolve API URL overrides: CLI flag > env var > None
|
| 292 |
effective_anthropic_api_url = anthropic_api_url or os.environ.get("ANTHROPIC_TARGET_API_URL")
|
| 293 |
effective_openai_api_url = openai_api_url or os.environ.get("OPENAI_TARGET_API_URL")
|
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CLI: passthrough subcommands for bundled tools and a `tools` management group.
|
| 2 |
+
|
| 3 |
+
Exposes:
|
| 4 |
+
|
| 5 |
+
headroom sg … -> ast-grep (from the ast-grep-cli PyPI wheel)
|
| 6 |
+
headroom diff A B … -> difftastic
|
| 7 |
+
headroom loc [PATH] … -> scc
|
| 8 |
+
headroom tools install -> pre-fetch all bundled binaries
|
| 9 |
+
headroom tools doctor -> print a status table
|
| 10 |
+
headroom tools list -> show the registry
|
| 11 |
+
|
| 12 |
+
The passthrough commands forward every argument, stdin, stdout, stderr, and
|
| 13 |
+
the exit code verbatim, so agents can invoke them via their existing shell
|
| 14 |
+
tool without any Headroom-specific protocol.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import subprocess
|
| 21 |
+
import sys
|
| 22 |
+
from collections.abc import Sequence
|
| 23 |
+
|
| 24 |
+
import click
|
| 25 |
+
|
| 26 |
+
from headroom import binaries
|
| 27 |
+
|
| 28 |
+
from .main import main
|
| 29 |
+
|
| 30 |
+
_PASSTHROUGH_CTX = {
|
| 31 |
+
"ignore_unknown_options": True,
|
| 32 |
+
"allow_extra_args": True,
|
| 33 |
+
"help_option_names": [], # let the underlying tool handle --help
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _exec_tool(tool: str, argv: Sequence[str]) -> None:
|
| 38 |
+
try:
|
| 39 |
+
path = binaries.resolve(tool)
|
| 40 |
+
except binaries.PlatformNotSupported as e:
|
| 41 |
+
click.secho(f"error: {e}", fg="red", err=True)
|
| 42 |
+
sys.exit(2)
|
| 43 |
+
except binaries.OfflineError as e:
|
| 44 |
+
click.secho(
|
| 45 |
+
f"error: {e}\nHint: run `headroom tools install` on a networked machine, "
|
| 46 |
+
f"or pass --from <bundle.tar.gz>.",
|
| 47 |
+
fg="red",
|
| 48 |
+
err=True,
|
| 49 |
+
)
|
| 50 |
+
sys.exit(2)
|
| 51 |
+
except (binaries.Sha256Mismatch, binaries.BinaryFetchError) as e:
|
| 52 |
+
click.secho(f"error: {e}", fg="red", err=True)
|
| 53 |
+
sys.exit(2)
|
| 54 |
+
|
| 55 |
+
# Replace the current process on POSIX for correct signal handling;
|
| 56 |
+
# fall back to subprocess on Windows where os.execv is awkward.
|
| 57 |
+
cmd = [str(path), *argv]
|
| 58 |
+
if os.name == "posix":
|
| 59 |
+
os.execv(cmd[0], cmd) # never returns
|
| 60 |
+
else:
|
| 61 |
+
completed = subprocess.run(cmd, check=False)
|
| 62 |
+
sys.exit(completed.returncode)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@main.command(
|
| 66 |
+
"sg",
|
| 67 |
+
context_settings=_PASSTHROUGH_CTX,
|
| 68 |
+
short_help="Run ast-grep (AST-aware structural search/replace).",
|
| 69 |
+
add_help_option=False,
|
| 70 |
+
)
|
| 71 |
+
@click.argument("args", nargs=-1, type=click.UNPROCESSED)
|
| 72 |
+
def sg_cmd(args: tuple[str, ...]) -> None:
|
| 73 |
+
"""Forward every argument to ast-grep."""
|
| 74 |
+
_exec_tool("ast-grep", list(args))
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@main.command(
|
| 78 |
+
"diff",
|
| 79 |
+
context_settings=_PASSTHROUGH_CTX,
|
| 80 |
+
short_help="Run difftastic (structural diff).",
|
| 81 |
+
add_help_option=False,
|
| 82 |
+
)
|
| 83 |
+
@click.argument("args", nargs=-1, type=click.UNPROCESSED)
|
| 84 |
+
def diff_cmd(args: tuple[str, ...]) -> None:
|
| 85 |
+
"""Forward every argument to difftastic (`difft`)."""
|
| 86 |
+
_exec_tool("difft", list(args))
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@main.command(
|
| 90 |
+
"loc",
|
| 91 |
+
context_settings=_PASSTHROUGH_CTX,
|
| 92 |
+
short_help="Run scc (fast lines-of-code / repo-shape probe).",
|
| 93 |
+
add_help_option=False,
|
| 94 |
+
)
|
| 95 |
+
@click.argument("args", nargs=-1, type=click.UNPROCESSED)
|
| 96 |
+
def loc_cmd(args: tuple[str, ...]) -> None:
|
| 97 |
+
"""Forward every argument to scc."""
|
| 98 |
+
_exec_tool("scc", list(args))
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@main.group("tools")
|
| 102 |
+
def tools_group() -> None:
|
| 103 |
+
"""Manage bundled CLI tool binaries (ast-grep, difft, scc)."""
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@tools_group.command("list")
|
| 107 |
+
def tools_list_cmd() -> None:
|
| 108 |
+
"""Print the tool registry (versions, platforms, cache dir)."""
|
| 109 |
+
from rich.console import Console
|
| 110 |
+
from rich.table import Table
|
| 111 |
+
|
| 112 |
+
console = Console()
|
| 113 |
+
plat = binaries.detect_platform()
|
| 114 |
+
console.print(f"[dim]platform:[/dim] {plat.key()}")
|
| 115 |
+
console.print(f"[dim]cache:[/dim] {binaries.cache_dir()}")
|
| 116 |
+
table = Table(show_header=True, header_style="bold")
|
| 117 |
+
table.add_column("tool")
|
| 118 |
+
table.add_column("version")
|
| 119 |
+
table.add_column("source")
|
| 120 |
+
table.add_column("platforms")
|
| 121 |
+
reg = binaries._registry() # noqa: SLF001 (intentional internal read)
|
| 122 |
+
for name, entry in reg.get("tools", {}).items():
|
| 123 |
+
platforms = ", ".join(sorted(entry.get("assets", {}).keys())) or "(pypi)"
|
| 124 |
+
table.add_row(name, str(entry.get("version")), entry.get("source", ""), platforms)
|
| 125 |
+
console.print(table)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@tools_group.command("doctor")
|
| 129 |
+
@click.option("--json", "emit_json", is_flag=True, help="Emit JSON instead of a table.")
|
| 130 |
+
def tools_doctor_cmd(emit_json: bool) -> None:
|
| 131 |
+
"""Check the status of every bundled tool."""
|
| 132 |
+
rows = binaries.status()
|
| 133 |
+
if emit_json:
|
| 134 |
+
import json as _json
|
| 135 |
+
|
| 136 |
+
click.echo(_json.dumps(rows, indent=2))
|
| 137 |
+
broken = any(r["state"] in ("missing", "unsupported-platform") for r in rows)
|
| 138 |
+
sys.exit(1 if broken else 0)
|
| 139 |
+
|
| 140 |
+
from rich.console import Console
|
| 141 |
+
from rich.table import Table
|
| 142 |
+
|
| 143 |
+
console = Console()
|
| 144 |
+
table = Table(show_header=True, header_style="bold")
|
| 145 |
+
for col in ("tool", "state", "version", "platform", "path"):
|
| 146 |
+
table.add_column(col)
|
| 147 |
+
state_style = {
|
| 148 |
+
"on-path": "green",
|
| 149 |
+
"cached": "green",
|
| 150 |
+
"missing": "yellow",
|
| 151 |
+
"unsupported-platform": "red",
|
| 152 |
+
}
|
| 153 |
+
broken = False
|
| 154 |
+
for r in rows:
|
| 155 |
+
style = state_style.get(r["state"], "white")
|
| 156 |
+
if r["state"] in ("missing", "unsupported-platform"):
|
| 157 |
+
broken = True
|
| 158 |
+
table.add_row(
|
| 159 |
+
r["tool"],
|
| 160 |
+
f"[{style}]{r['state']}[/{style}]",
|
| 161 |
+
str(r.get("version")),
|
| 162 |
+
r.get("platform", ""),
|
| 163 |
+
r.get("path") or "-",
|
| 164 |
+
)
|
| 165 |
+
console.print(table)
|
| 166 |
+
from rich.markup import escape as _escape
|
| 167 |
+
|
| 168 |
+
for r in rows:
|
| 169 |
+
if r.get("detail"):
|
| 170 |
+
console.print(f"[dim]{r['tool']}:[/dim] {_escape(r['detail'])}")
|
| 171 |
+
sys.exit(1 if broken else 0)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
@tools_group.command("install")
|
| 175 |
+
@click.option(
|
| 176 |
+
"--tool",
|
| 177 |
+
"tools",
|
| 178 |
+
multiple=True,
|
| 179 |
+
help="Install only the named tool (repeatable). Default: all.",
|
| 180 |
+
)
|
| 181 |
+
@click.option(
|
| 182 |
+
"--force",
|
| 183 |
+
is_flag=True,
|
| 184 |
+
help="Re-fetch even if the binary is already cached.",
|
| 185 |
+
)
|
| 186 |
+
def tools_install_cmd(tools: tuple[str, ...], force: bool) -> None:
|
| 187 |
+
"""Pre-fetch all bundled tool binaries into the per-user cache."""
|
| 188 |
+
reg = binaries._registry() # noqa: SLF001
|
| 189 |
+
selected = list(tools) if tools else list(reg.get("tools", {}).keys())
|
| 190 |
+
exit_code = 0
|
| 191 |
+
for name in selected:
|
| 192 |
+
if name not in reg.get("tools", {}):
|
| 193 |
+
click.secho(f"unknown tool {name!r}; skipping", fg="yellow", err=True)
|
| 194 |
+
exit_code = 1
|
| 195 |
+
continue
|
| 196 |
+
if binaries._is_pypi_tool(name): # noqa: SLF001
|
| 197 |
+
on_path = binaries._path_lookup(name) # noqa: SLF001
|
| 198 |
+
if on_path:
|
| 199 |
+
click.echo(f"{name}: on PATH at {on_path} (pypi wheel)")
|
| 200 |
+
else:
|
| 201 |
+
click.secho(
|
| 202 |
+
f"{name}: not on PATH — `pip install headroom-ai` should provide it",
|
| 203 |
+
fg="yellow",
|
| 204 |
+
)
|
| 205 |
+
exit_code = 1
|
| 206 |
+
continue
|
| 207 |
+
if force:
|
| 208 |
+
plat = binaries.detect_platform()
|
| 209 |
+
try:
|
| 210 |
+
cached = binaries._cached_path( # noqa: SLF001
|
| 211 |
+
name, reg["tools"][name]["version"], plat
|
| 212 |
+
)
|
| 213 |
+
if cached.exists():
|
| 214 |
+
cached.unlink()
|
| 215 |
+
except Exception: # noqa: BLE001
|
| 216 |
+
pass
|
| 217 |
+
try:
|
| 218 |
+
path = binaries.resolve(name)
|
| 219 |
+
click.secho(f"{name}: installed → {path}", fg="green")
|
| 220 |
+
except binaries.PlatformNotSupported as e:
|
| 221 |
+
click.secho(f"{name}: {e}", fg="red")
|
| 222 |
+
exit_code = 1
|
| 223 |
+
except (binaries.BinaryFetchError, binaries.Sha256Mismatch, binaries.OfflineError) as e:
|
| 224 |
+
click.secho(f"{name}: {e}", fg="red")
|
| 225 |
+
exit_code = 1
|
| 226 |
+
sys.exit(exit_code)
|
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tool-result interceptors.
|
| 2 |
+
|
| 3 |
+
An interceptor rewrites a single tool_result's text before it reaches the
|
| 4 |
+
model. Each interceptor is self-contained: declare a `matches()` predicate
|
| 5 |
+
and a `transform()` function, register it in the `INTERCEPTORS` list, and
|
| 6 |
+
the proxy pipeline will call it automatically.
|
| 7 |
+
|
| 8 |
+
Adding a new interceptor later is one file plus one `register()` call — no
|
| 9 |
+
proxy or metrics changes required.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
# Side-effect: register the built-in interceptors.
|
| 13 |
+
from . import astgrep # noqa: F401
|
| 14 |
+
from .base import (
|
| 15 |
+
INTERCEPTORS,
|
| 16 |
+
InterceptionResult,
|
| 17 |
+
ToolResultInterceptor,
|
| 18 |
+
ToolResultInterceptorTransform,
|
| 19 |
+
TransformSpan,
|
| 20 |
+
apply_to_messages,
|
| 21 |
+
register,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
__all__ = [
|
| 25 |
+
"INTERCEPTORS",
|
| 26 |
+
"InterceptionResult",
|
| 27 |
+
"ToolResultInterceptor",
|
| 28 |
+
"ToolResultInterceptorTransform",
|
| 29 |
+
"TransformSpan",
|
| 30 |
+
"apply_to_messages",
|
| 31 |
+
"register",
|
| 32 |
+
]
|
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ast-grep interceptor: replace verbose Read outputs with function-level outlines.
|
| 2 |
+
|
| 3 |
+
Matches Claude Code's `Read` tool (and equivalent) when the file is code and
|
| 4 |
+
the output is large enough to benefit. Invokes ast-grep to locate top-level
|
| 5 |
+
function and class definitions and emits a compact outline: each signature
|
| 6 |
+
followed by an elided body marker. Falls back to the original text if
|
| 7 |
+
ast-grep isn't available, the extension isn't supported, or there are fewer
|
| 8 |
+
than three definitions to outline.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import json
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
import subprocess
|
| 17 |
+
import tempfile
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Any
|
| 20 |
+
|
| 21 |
+
from headroom import binaries
|
| 22 |
+
|
| 23 |
+
from . import base
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
# Latency floor: below this size, the subprocess cost of running ast-grep
|
| 28 |
+
# isn't worth the tiny win. It is NOT a semantic threshold — the framework
|
| 29 |
+
# rejects any rewrite that doesn't actually shrink tokens, so we don't need
|
| 30 |
+
# a "big enough to matter" check here, only a "big enough to justify the
|
| 31 |
+
# fork()" check.
|
| 32 |
+
MIN_CHARS_TO_REWRITE = int(os.environ.get("HEADROOM_INTERCEPT_READ_MIN_CHARS", "500"))
|
| 33 |
+
|
| 34 |
+
# Tool_input keys that indicate the model targeted a specific line range;
|
| 35 |
+
# outlining would frustrate that intent and likely cause a re-read.
|
| 36 |
+
_RANGE_KEYS = ("offset", "limit", "line_range", "start_line", "end_line", "ranges")
|
| 37 |
+
|
| 38 |
+
# ast-grep --lang is passed these values; only extensions with a stable
|
| 39 |
+
# grammar are included.
|
| 40 |
+
_EXT_TO_LANG: dict[str, str] = {
|
| 41 |
+
".py": "python",
|
| 42 |
+
".ts": "typescript",
|
| 43 |
+
".tsx": "tsx",
|
| 44 |
+
".js": "javascript",
|
| 45 |
+
".jsx": "jsx",
|
| 46 |
+
".go": "go",
|
| 47 |
+
".rs": "rust",
|
| 48 |
+
".java": "java",
|
| 49 |
+
".rb": "ruby",
|
| 50 |
+
".c": "c",
|
| 51 |
+
".h": "c",
|
| 52 |
+
".cpp": "cpp",
|
| 53 |
+
".cc": "cpp",
|
| 54 |
+
".hpp": "cpp",
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
# Top-level declaration patterns per language. We emit the signature line
|
| 58 |
+
# of whatever ast-grep matches here, so any pattern that anchors on a
|
| 59 |
+
# declaration's starting line works.
|
| 60 |
+
_PATTERNS: dict[str, list[str]] = {
|
| 61 |
+
"python": ["def $NAME", "class $NAME", "async def $NAME"],
|
| 62 |
+
"typescript": ["function $NAME", "class $NAME"],
|
| 63 |
+
"tsx": ["function $NAME", "class $NAME"],
|
| 64 |
+
"javascript": ["function $NAME", "class $NAME"],
|
| 65 |
+
"jsx": ["function $NAME", "class $NAME"],
|
| 66 |
+
"go": ["func $NAME"],
|
| 67 |
+
"rust": ["fn $NAME", "struct $NAME", "enum $NAME"],
|
| 68 |
+
"java": ["class $NAME", "interface $NAME"],
|
| 69 |
+
"ruby": ["def $NAME", "class $NAME"],
|
| 70 |
+
"c": ["$RET $NAME($$$ARGS) { $$$BODY }"],
|
| 71 |
+
"cpp": ["$RET $NAME($$$ARGS) { $$$BODY }"],
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
OUTLINE_MARKER = " # ... (body elided by Headroom; Read a specific line range to see it)\n"
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class AstGrepReadOutline:
|
| 78 |
+
"""Interceptor that outlines verbose code-file Read outputs."""
|
| 79 |
+
|
| 80 |
+
name = "ast-grep"
|
| 81 |
+
|
| 82 |
+
def matches(
|
| 83 |
+
self,
|
| 84 |
+
tool_name: str | None,
|
| 85 |
+
tool_input: dict[str, Any],
|
| 86 |
+
tool_output: str,
|
| 87 |
+
) -> bool:
|
| 88 |
+
if tool_name not in ("Read", "read_file", "view", "cat"):
|
| 89 |
+
return False
|
| 90 |
+
if len(tool_output) < MIN_CHARS_TO_REWRITE:
|
| 91 |
+
return False
|
| 92 |
+
# Respect explicit line ranges — the model wants those specific lines.
|
| 93 |
+
if any(k in tool_input for k in _RANGE_KEYS):
|
| 94 |
+
return False
|
| 95 |
+
return _detect_lang_from_input(tool_input) is not None
|
| 96 |
+
|
| 97 |
+
def transform(
|
| 98 |
+
self,
|
| 99 |
+
tool_name: str | None,
|
| 100 |
+
tool_input: dict[str, Any],
|
| 101 |
+
tool_output: str,
|
| 102 |
+
) -> str | None:
|
| 103 |
+
lang = _detect_lang_from_input(tool_input)
|
| 104 |
+
if not lang:
|
| 105 |
+
return None
|
| 106 |
+
try:
|
| 107 |
+
exe = binaries.resolve("ast-grep")
|
| 108 |
+
except binaries.PlatformNotSupported:
|
| 109 |
+
return None
|
| 110 |
+
|
| 111 |
+
matches = _run_ast_grep(exe, lang, tool_output)
|
| 112 |
+
if not matches:
|
| 113 |
+
return None
|
| 114 |
+
|
| 115 |
+
outline = _build_outline(matches, tool_output)
|
| 116 |
+
return outline if outline else None
|
| 117 |
+
|
| 118 |
+
def progressive_disclosure_key(
|
| 119 |
+
self,
|
| 120 |
+
tool_name: str | None,
|
| 121 |
+
tool_input: dict[str, Any],
|
| 122 |
+
) -> str | None:
|
| 123 |
+
"""Key by file_path so a second Read of the same file passes through."""
|
| 124 |
+
return _path_from_input(tool_input)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _detect_lang_from_input(tool_input: dict[str, Any]) -> str | None:
|
| 128 |
+
path = _path_from_input(tool_input)
|
| 129 |
+
if not path:
|
| 130 |
+
return None
|
| 131 |
+
ext = Path(path).suffix.lower()
|
| 132 |
+
return _EXT_TO_LANG.get(ext)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _path_from_input(tool_input: dict[str, Any]) -> str | None:
|
| 136 |
+
for key in ("file_path", "path", "filePath", "filename"):
|
| 137 |
+
v = tool_input.get(key)
|
| 138 |
+
if isinstance(v, str) and v:
|
| 139 |
+
return v
|
| 140 |
+
return None
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def _run_ast_grep(
|
| 144 |
+
exe: Path | str,
|
| 145 |
+
lang: str,
|
| 146 |
+
source: str,
|
| 147 |
+
) -> list[dict[str, Any]]:
|
| 148 |
+
"""Run ast-grep against `source` and return the JSON match records.
|
| 149 |
+
|
| 150 |
+
Writes `source` to a tempfile because ast-grep's CLI operates on files.
|
| 151 |
+
"""
|
| 152 |
+
all_matches: list[dict[str, Any]] = []
|
| 153 |
+
patterns = _PATTERNS.get(lang, [])
|
| 154 |
+
if not patterns:
|
| 155 |
+
return []
|
| 156 |
+
|
| 157 |
+
# Use the canonical extension so ast-grep can pick the right grammar.
|
| 158 |
+
ext = next((e for e, L in _EXT_TO_LANG.items() if L == lang), ".txt")
|
| 159 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=ext, delete=False, encoding="utf-8") as tmp:
|
| 160 |
+
tmp.write(source)
|
| 161 |
+
tmp_path = tmp.name
|
| 162 |
+
|
| 163 |
+
try:
|
| 164 |
+
for pattern in patterns:
|
| 165 |
+
try:
|
| 166 |
+
completed = subprocess.run(
|
| 167 |
+
[
|
| 168 |
+
str(exe),
|
| 169 |
+
"run",
|
| 170 |
+
"--pattern",
|
| 171 |
+
pattern,
|
| 172 |
+
"--lang",
|
| 173 |
+
lang,
|
| 174 |
+
"--json=stream",
|
| 175 |
+
tmp_path,
|
| 176 |
+
],
|
| 177 |
+
capture_output=True,
|
| 178 |
+
text=True,
|
| 179 |
+
timeout=5,
|
| 180 |
+
check=False,
|
| 181 |
+
)
|
| 182 |
+
except (subprocess.TimeoutExpired, OSError) as e:
|
| 183 |
+
logger.debug("ast-grep timed out or failed: %s", e)
|
| 184 |
+
continue
|
| 185 |
+
if completed.returncode != 0:
|
| 186 |
+
continue
|
| 187 |
+
for line in completed.stdout.splitlines():
|
| 188 |
+
line = line.strip()
|
| 189 |
+
if not line:
|
| 190 |
+
continue
|
| 191 |
+
try:
|
| 192 |
+
all_matches.append(json.loads(line))
|
| 193 |
+
except json.JSONDecodeError:
|
| 194 |
+
continue
|
| 195 |
+
finally:
|
| 196 |
+
try:
|
| 197 |
+
Path(tmp_path).unlink()
|
| 198 |
+
except OSError:
|
| 199 |
+
pass
|
| 200 |
+
|
| 201 |
+
return all_matches
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def _build_outline(matches: list[dict[str, Any]], source: str) -> str | None:
|
| 205 |
+
"""Build a compact outline from ast-grep matches.
|
| 206 |
+
|
| 207 |
+
Emits each definition's signature line + docstring (if next line is a
|
| 208 |
+
string literal) + an elision marker. Matches are sorted by byte offset
|
| 209 |
+
so the outline tracks the original file order.
|
| 210 |
+
"""
|
| 211 |
+
lines = source.splitlines(keepends=True)
|
| 212 |
+
outline_chunks: list[str] = []
|
| 213 |
+
seen_starts: set[int] = set()
|
| 214 |
+
|
| 215 |
+
matches.sort(key=lambda m: m.get("range", {}).get("byteOffset", {}).get("start", 0))
|
| 216 |
+
for m in matches:
|
| 217 |
+
start = m.get("range", {}).get("start", {})
|
| 218 |
+
line_idx = start.get("line")
|
| 219 |
+
if not isinstance(line_idx, int) or line_idx in seen_starts:
|
| 220 |
+
continue
|
| 221 |
+
seen_starts.add(line_idx)
|
| 222 |
+
if line_idx >= len(lines):
|
| 223 |
+
continue
|
| 224 |
+
signature_line = lines[line_idx].rstrip("\n")
|
| 225 |
+
outline_chunks.append(signature_line + "\n")
|
| 226 |
+
# Best-effort: if the next non-blank line is a docstring, keep it.
|
| 227 |
+
next_idx = line_idx + 1
|
| 228 |
+
while next_idx < len(lines) and not lines[next_idx].strip():
|
| 229 |
+
next_idx += 1
|
| 230 |
+
if next_idx < len(lines):
|
| 231 |
+
nl = lines[next_idx].lstrip()
|
| 232 |
+
if nl.startswith(('"""', "'''", "/**", "//", "#")):
|
| 233 |
+
outline_chunks.append(lines[next_idx])
|
| 234 |
+
outline_chunks.append(OUTLINE_MARKER)
|
| 235 |
+
|
| 236 |
+
if not outline_chunks:
|
| 237 |
+
return None
|
| 238 |
+
header = (
|
| 239 |
+
"[headroom: outlined by ast-grep — "
|
| 240 |
+
f"{len(seen_starts)} definition(s); "
|
| 241 |
+
"bodies elided. Re-read the file with a line range to see a specific body.]\n"
|
| 242 |
+
)
|
| 243 |
+
return header + "".join(outline_chunks)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
base.register(AstGrepReadOutline())
|
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Protocol + registry + Transform adapter for tool_result interceptors."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Protocol, runtime_checkable
|
| 8 |
+
|
| 9 |
+
from headroom.cache.compression_cache import (
|
| 10 |
+
_extract_tool_result_content,
|
| 11 |
+
_is_tool_result_message,
|
| 12 |
+
_swap_tool_result_content,
|
| 13 |
+
)
|
| 14 |
+
from headroom.config import TransformResult
|
| 15 |
+
from headroom.tokenizer import Tokenizer
|
| 16 |
+
from headroom.transforms.base import Transform
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@runtime_checkable
|
| 22 |
+
class ToolResultInterceptor(Protocol):
|
| 23 |
+
"""A stateless rewriter for a single tool_result's text content.
|
| 24 |
+
|
| 25 |
+
Implementations MUST be idempotent and MUST return either a strictly
|
| 26 |
+
smaller string (measured in tokens) or None to pass through. Never raise
|
| 27 |
+
— errors should be caught internally and logged; the pipeline always
|
| 28 |
+
tolerates a no-op interceptor.
|
| 29 |
+
|
| 30 |
+
Interceptors MAY implement `progressive_disclosure_key()` to opt into
|
| 31 |
+
one-shot behavior: the framework tracks which keys have already been
|
| 32 |
+
rewritten in the current conversation, and skips subsequent matches on
|
| 33 |
+
the same key so that the model gets full content if it asks again.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
name: str # e.g. "ast-grep", "difft", "scc"
|
| 37 |
+
|
| 38 |
+
def matches(
|
| 39 |
+
self,
|
| 40 |
+
tool_name: str | None,
|
| 41 |
+
tool_input: dict[str, Any],
|
| 42 |
+
tool_output: str,
|
| 43 |
+
) -> bool: ...
|
| 44 |
+
|
| 45 |
+
def transform(
|
| 46 |
+
self,
|
| 47 |
+
tool_name: str | None,
|
| 48 |
+
tool_input: dict[str, Any],
|
| 49 |
+
tool_output: str,
|
| 50 |
+
) -> str | None: ...
|
| 51 |
+
|
| 52 |
+
def progressive_disclosure_key(
|
| 53 |
+
self,
|
| 54 |
+
tool_name: str | None,
|
| 55 |
+
tool_input: dict[str, Any],
|
| 56 |
+
) -> str | None:
|
| 57 |
+
"""Optional: return a stable content key (e.g. file path).
|
| 58 |
+
|
| 59 |
+
If a key is returned and the same (interceptor.name, key) pair was
|
| 60 |
+
already successfully rewritten earlier in the messages, subsequent
|
| 61 |
+
occurrences pass through unchanged. Return None to opt out.
|
| 62 |
+
"""
|
| 63 |
+
...
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass(frozen=True)
|
| 67 |
+
class TransformSpan:
|
| 68 |
+
"""Per-interceptor measurement emitted for dashboard/metrics."""
|
| 69 |
+
|
| 70 |
+
tool: str
|
| 71 |
+
tokens_before: int
|
| 72 |
+
tokens_after: int
|
| 73 |
+
|
| 74 |
+
@property
|
| 75 |
+
def tokens_saved(self) -> int:
|
| 76 |
+
return max(self.tokens_before - self.tokens_after, 0)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@dataclass
|
| 80 |
+
class InterceptionResult:
|
| 81 |
+
messages: list[dict[str, Any]]
|
| 82 |
+
spans: list[TransformSpan]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
INTERCEPTORS: list[ToolResultInterceptor] = []
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def register(interceptor: ToolResultInterceptor) -> None:
|
| 89 |
+
"""Add an interceptor to the registry. Idempotent on name."""
|
| 90 |
+
for existing in INTERCEPTORS:
|
| 91 |
+
if existing.name == interceptor.name:
|
| 92 |
+
return
|
| 93 |
+
INTERCEPTORS.append(interceptor)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _find_tool_use(
|
| 97 |
+
messages: list[dict[str, Any]],
|
| 98 |
+
tool_use_id: str,
|
| 99 |
+
) -> tuple[str | None, dict[str, Any]]:
|
| 100 |
+
"""Walk prior messages to find the tool_use block that produced a given id.
|
| 101 |
+
|
| 102 |
+
Returns (tool_name, tool_input) or (None, {}) if not found.
|
| 103 |
+
"""
|
| 104 |
+
for msg in messages:
|
| 105 |
+
content = msg.get("content")
|
| 106 |
+
if isinstance(content, list):
|
| 107 |
+
for block in content:
|
| 108 |
+
if not isinstance(block, dict):
|
| 109 |
+
continue
|
| 110 |
+
# Anthropic: {"type": "tool_use", "id": ..., "name": ..., "input": {...}}
|
| 111 |
+
if block.get("type") == "tool_use" and block.get("id") == tool_use_id:
|
| 112 |
+
return (
|
| 113 |
+
block.get("name"),
|
| 114 |
+
block.get("input") or {},
|
| 115 |
+
)
|
| 116 |
+
# OpenAI: assistant message with `tool_calls` list
|
| 117 |
+
tool_calls = msg.get("tool_calls")
|
| 118 |
+
if isinstance(tool_calls, list):
|
| 119 |
+
for call in tool_calls:
|
| 120 |
+
if isinstance(call, dict) and call.get("id") == tool_use_id:
|
| 121 |
+
fn = call.get("function") or {}
|
| 122 |
+
# arguments is a JSON string in OpenAI; decode best-effort
|
| 123 |
+
import json as _json
|
| 124 |
+
|
| 125 |
+
args: dict[str, Any] = {}
|
| 126 |
+
if isinstance(fn.get("arguments"), str):
|
| 127 |
+
try:
|
| 128 |
+
args = _json.loads(fn["arguments"])
|
| 129 |
+
except Exception: # noqa: BLE001
|
| 130 |
+
args = {}
|
| 131 |
+
elif isinstance(fn.get("arguments"), dict):
|
| 132 |
+
args = fn["arguments"]
|
| 133 |
+
return fn.get("name"), args
|
| 134 |
+
return None, {}
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _tool_use_id_for_message(msg: dict[str, Any]) -> str | None:
|
| 138 |
+
"""Return the tool_use_id linked to a tool_result message."""
|
| 139 |
+
# Anthropic format
|
| 140 |
+
content = msg.get("content")
|
| 141 |
+
if isinstance(content, list):
|
| 142 |
+
for block in content:
|
| 143 |
+
if isinstance(block, dict) and block.get("type") == "tool_result":
|
| 144 |
+
tuid = block.get("tool_use_id")
|
| 145 |
+
if isinstance(tuid, str):
|
| 146 |
+
return tuid
|
| 147 |
+
# OpenAI format
|
| 148 |
+
if msg.get("role") == "tool":
|
| 149 |
+
tcid = msg.get("tool_call_id")
|
| 150 |
+
if isinstance(tcid, str):
|
| 151 |
+
return tcid
|
| 152 |
+
return None
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def apply_to_messages(
|
| 156 |
+
messages: list[dict[str, Any]],
|
| 157 |
+
tokenizer: Tokenizer,
|
| 158 |
+
) -> InterceptionResult:
|
| 159 |
+
"""Run every registered interceptor against every tool_result in `messages`.
|
| 160 |
+
|
| 161 |
+
Returns the (possibly) rewritten message list and a list of spans that
|
| 162 |
+
actually saved tokens.
|
| 163 |
+
"""
|
| 164 |
+
if not INTERCEPTORS:
|
| 165 |
+
return InterceptionResult(messages=messages, spans=[])
|
| 166 |
+
|
| 167 |
+
new_messages: list[dict[str, Any]] = []
|
| 168 |
+
spans: list[TransformSpan] = []
|
| 169 |
+
# Progressive disclosure: per-interceptor set of keys already rewritten
|
| 170 |
+
# earlier in this message list. Prevents the second Read of the same
|
| 171 |
+
# file from being outlined again — the model evidently came back for
|
| 172 |
+
# more, so give it the raw content.
|
| 173 |
+
fired: dict[str, set[str]] = {}
|
| 174 |
+
|
| 175 |
+
for msg in messages:
|
| 176 |
+
if not _is_tool_result_message(msg):
|
| 177 |
+
new_messages.append(msg)
|
| 178 |
+
continue
|
| 179 |
+
|
| 180 |
+
original = _extract_tool_result_content(msg)
|
| 181 |
+
if not isinstance(original, str) or not original:
|
| 182 |
+
new_messages.append(msg)
|
| 183 |
+
continue
|
| 184 |
+
|
| 185 |
+
tuid = _tool_use_id_for_message(msg)
|
| 186 |
+
tool_name: str | None = None
|
| 187 |
+
tool_input: dict[str, Any] = {}
|
| 188 |
+
if tuid:
|
| 189 |
+
tool_name, tool_input = _find_tool_use(messages, tuid)
|
| 190 |
+
|
| 191 |
+
current = original
|
| 192 |
+
for interceptor in INTERCEPTORS:
|
| 193 |
+
# Progressive disclosure: skip if already fired for this key.
|
| 194 |
+
key: str | None = None
|
| 195 |
+
key_fn = getattr(interceptor, "progressive_disclosure_key", None)
|
| 196 |
+
if callable(key_fn):
|
| 197 |
+
try:
|
| 198 |
+
key = key_fn(tool_name, tool_input)
|
| 199 |
+
except Exception as e: # noqa: BLE001
|
| 200 |
+
logger.warning("interceptor %s key() failed: %s", interceptor.name, e)
|
| 201 |
+
key = None
|
| 202 |
+
if key and key in fired.get(interceptor.name, set()):
|
| 203 |
+
continue
|
| 204 |
+
|
| 205 |
+
try:
|
| 206 |
+
if not interceptor.matches(tool_name, tool_input, current):
|
| 207 |
+
continue
|
| 208 |
+
rewritten = interceptor.transform(tool_name, tool_input, current)
|
| 209 |
+
except Exception as e: # noqa: BLE001 — never crash a request
|
| 210 |
+
logger.warning("interceptor %s failed: %s", interceptor.name, e)
|
| 211 |
+
continue
|
| 212 |
+
if not rewritten or rewritten == current:
|
| 213 |
+
continue
|
| 214 |
+
before = tokenizer.count_text(current)
|
| 215 |
+
after = tokenizer.count_text(rewritten)
|
| 216 |
+
if after >= before:
|
| 217 |
+
continue # refuse to enlarge
|
| 218 |
+
spans.append(
|
| 219 |
+
TransformSpan(
|
| 220 |
+
tool=interceptor.name,
|
| 221 |
+
tokens_before=before,
|
| 222 |
+
tokens_after=after,
|
| 223 |
+
)
|
| 224 |
+
)
|
| 225 |
+
current = rewritten
|
| 226 |
+
if key:
|
| 227 |
+
fired.setdefault(interceptor.name, set()).add(key)
|
| 228 |
+
|
| 229 |
+
new_messages.append(
|
| 230 |
+
_swap_tool_result_content(msg, current) if current is not original else msg
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
return InterceptionResult(messages=new_messages, spans=spans)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class ToolResultInterceptorTransform(Transform):
|
| 237 |
+
"""Pipeline-level adapter: runs interceptors as the first compression stage.
|
| 238 |
+
|
| 239 |
+
Placed at transforms[0] so downstream compressors operate on the already-
|
| 240 |
+
shrunk content. Transform names of firing interceptors are added to
|
| 241 |
+
`transforms_applied` so they appear in existing dashboards/metrics.
|
| 242 |
+
"""
|
| 243 |
+
|
| 244 |
+
name = "tool_result_interceptors"
|
| 245 |
+
|
| 246 |
+
def apply(
|
| 247 |
+
self,
|
| 248 |
+
messages: list[dict[str, Any]],
|
| 249 |
+
tokenizer: Tokenizer,
|
| 250 |
+
**kwargs: Any,
|
| 251 |
+
) -> TransformResult:
|
| 252 |
+
result = apply_to_messages(messages, tokenizer)
|
| 253 |
+
tokens_after = tokenizer.count_messages(result.messages)
|
| 254 |
+
tokens_before = tokens_after + sum(s.tokens_saved for s in result.spans)
|
| 255 |
+
transforms_applied = [f"interceptor:{s.tool}" for s in result.spans] if result.spans else []
|
| 256 |
+
return TransformResult(
|
| 257 |
+
messages=result.messages,
|
| 258 |
+
tokens_before=tokens_before,
|
| 259 |
+
tokens_after=tokens_after,
|
| 260 |
+
transforms_applied=transforms_applied,
|
| 261 |
+
)
|
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_comment": "Registry of externally fetched CLI tool binaries. Bump versions and SHA256s via the weekly tools-version-check CI job (see .github/workflows/). sha256=null means HTTPS-trust-only (initial bootstrap); the CI job fills real SHAs per release.",
|
| 3 |
+
"tools": {
|
| 4 |
+
"difft": {
|
| 5 |
+
"version": "0.64.0",
|
| 6 |
+
"binary": "difft",
|
| 7 |
+
"source": "Wilfred/difftastic",
|
| 8 |
+
"homepage": "https://difftastic.wilfred.me.uk/",
|
| 9 |
+
"assets": {
|
| 10 |
+
"linux-x86_64-gnu": {
|
| 11 |
+
"url": "https://github.com/Wilfred/difftastic/releases/download/0.64.0/difft-x86_64-unknown-linux-gnu.tar.gz",
|
| 12 |
+
"member": "difft",
|
| 13 |
+
"sha256": null
|
| 14 |
+
},
|
| 15 |
+
"linux-aarch64-gnu": {
|
| 16 |
+
"url": "https://github.com/Wilfred/difftastic/releases/download/0.64.0/difft-aarch64-unknown-linux-gnu.tar.gz",
|
| 17 |
+
"member": "difft",
|
| 18 |
+
"sha256": null
|
| 19 |
+
},
|
| 20 |
+
"darwin-x86_64": {
|
| 21 |
+
"url": "https://github.com/Wilfred/difftastic/releases/download/0.64.0/difft-x86_64-apple-darwin.tar.gz",
|
| 22 |
+
"member": "difft",
|
| 23 |
+
"sha256": null
|
| 24 |
+
},
|
| 25 |
+
"darwin-aarch64": {
|
| 26 |
+
"url": "https://github.com/Wilfred/difftastic/releases/download/0.64.0/difft-aarch64-apple-darwin.tar.gz",
|
| 27 |
+
"member": "difft",
|
| 28 |
+
"sha256": null
|
| 29 |
+
},
|
| 30 |
+
"windows-x86_64": {
|
| 31 |
+
"url": "https://github.com/Wilfred/difftastic/releases/download/0.64.0/difft-x86_64-pc-windows-msvc.zip",
|
| 32 |
+
"member": "difft.exe",
|
| 33 |
+
"sha256": null
|
| 34 |
+
}
|
| 35 |
+
}
|
| 36 |
+
},
|
| 37 |
+
"scc": {
|
| 38 |
+
"version": "3.5.0",
|
| 39 |
+
"binary": "scc",
|
| 40 |
+
"source": "boyter/scc",
|
| 41 |
+
"homepage": "https://github.com/boyter/scc",
|
| 42 |
+
"assets": {
|
| 43 |
+
"linux-x86_64-gnu": {
|
| 44 |
+
"url": "https://github.com/boyter/scc/releases/download/v3.5.0/scc_Linux_x86_64.tar.gz",
|
| 45 |
+
"member": "scc",
|
| 46 |
+
"sha256": null
|
| 47 |
+
},
|
| 48 |
+
"linux-x86_64-musl": {
|
| 49 |
+
"url": "https://github.com/boyter/scc/releases/download/v3.5.0/scc_Linux_x86_64.tar.gz",
|
| 50 |
+
"member": "scc",
|
| 51 |
+
"sha256": null
|
| 52 |
+
},
|
| 53 |
+
"linux-aarch64-gnu": {
|
| 54 |
+
"url": "https://github.com/boyter/scc/releases/download/v3.5.0/scc_Linux_arm64.tar.gz",
|
| 55 |
+
"member": "scc",
|
| 56 |
+
"sha256": null
|
| 57 |
+
},
|
| 58 |
+
"linux-aarch64-musl": {
|
| 59 |
+
"url": "https://github.com/boyter/scc/releases/download/v3.5.0/scc_Linux_arm64.tar.gz",
|
| 60 |
+
"member": "scc",
|
| 61 |
+
"sha256": null
|
| 62 |
+
},
|
| 63 |
+
"darwin-x86_64": {
|
| 64 |
+
"url": "https://github.com/boyter/scc/releases/download/v3.5.0/scc_Darwin_x86_64.tar.gz",
|
| 65 |
+
"member": "scc",
|
| 66 |
+
"sha256": null
|
| 67 |
+
},
|
| 68 |
+
"darwin-aarch64": {
|
| 69 |
+
"url": "https://github.com/boyter/scc/releases/download/v3.5.0/scc_Darwin_arm64.tar.gz",
|
| 70 |
+
"member": "scc",
|
| 71 |
+
"sha256": null
|
| 72 |
+
},
|
| 73 |
+
"windows-x86_64": {
|
| 74 |
+
"url": "https://github.com/boyter/scc/releases/download/v3.5.0/scc_Windows_x86_64.zip",
|
| 75 |
+
"member": "scc.exe",
|
| 76 |
+
"sha256": null
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
},
|
| 80 |
+
"ast-grep": {
|
| 81 |
+
"version": "pypi",
|
| 82 |
+
"binary": "ast-grep",
|
| 83 |
+
"source": "ast-grep/ast-grep (PyPI: ast-grep-cli)",
|
| 84 |
+
"homepage": "https://ast-grep.github.io/",
|
| 85 |
+
"_comment": "Installed via the ast-grep-cli PyPI wheel; we never fetch from GitHub for this tool. Listed here so `headroom tools doctor` can report it.",
|
| 86 |
+
"assets": {}
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
}
|
|
@@ -75,6 +75,18 @@ class TransformPipeline:
|
|
| 75 |
|
| 76 |
# Order matters!
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
# 1. Cache Aligner (prefix stabilization)
|
| 79 |
if self.config.cache_aligner.enabled:
|
| 80 |
transforms.append(CacheAligner(self.config.cache_aligner))
|
|
|
|
| 75 |
|
| 76 |
# Order matters!
|
| 77 |
|
| 78 |
+
# 0. Tool-result interceptors (ast-grep Read outline, etc.) run first
|
| 79 |
+
# so downstream compressors operate on the already-shrunk content.
|
| 80 |
+
# OPT-IN: enable with HEADROOM_INTERCEPT_ENABLED=1 or `headroom proxy
|
| 81 |
+
# --intercept-tool-results`. Off by default while this ships — lets
|
| 82 |
+
# users try it and compare before we make it the default.
|
| 83 |
+
import os as _os
|
| 84 |
+
|
| 85 |
+
if _os.environ.get("HEADROOM_INTERCEPT_ENABLED"):
|
| 86 |
+
from headroom.proxy.interceptors import ToolResultInterceptorTransform
|
| 87 |
+
|
| 88 |
+
transforms.append(ToolResultInterceptorTransform())
|
| 89 |
+
|
| 90 |
# 1. Cache Aligner (prefix stabilization)
|
| 91 |
if self.config.cache_aligner.enabled:
|
| 92 |
transforms.append(CacheAligner(self.config.cache_aligner))
|
|
@@ -51,6 +51,7 @@ dependencies = [
|
|
| 51 |
"click>=8.1.0", # CLI framework
|
| 52 |
"rich>=13.0.0", # Rich terminal output
|
| 53 |
"opentelemetry-api>=1.24.0", # Safe no-op OTEL API for instrumentation
|
|
|
|
| 54 |
]
|
| 55 |
|
| 56 |
[project.optional-dependencies]
|
|
|
|
| 51 |
"click>=8.1.0", # CLI framework
|
| 52 |
"rich>=13.0.0", # Rich terminal output
|
| 53 |
"opentelemetry-api>=1.24.0", # Safe no-op OTEL API for instrumentation
|
| 54 |
+
"ast-grep-cli>=0.30.0", # AST-aware code slicing (CodeCompressor); binary wheel
|
| 55 |
]
|
| 56 |
|
| 57 |
[project.optional-dependencies]
|
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for headroom.binaries — the lazy fetcher for bundled CLI tools.
|
| 2 |
+
|
| 3 |
+
No network access. A fake urlopen serves bytes from an in-memory fixture.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import hashlib
|
| 9 |
+
import io
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
import tarfile
|
| 13 |
+
import zipfile
|
| 14 |
+
|
| 15 |
+
import pytest
|
| 16 |
+
|
| 17 |
+
from headroom import binaries
|
| 18 |
+
|
| 19 |
+
# -------- Fixtures -------------------------------------------------------- #
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@pytest.fixture(autouse=True)
|
| 23 |
+
def _clear_caches(monkeypatch, tmp_path):
|
| 24 |
+
"""Isolate every test from global state: cache dir, platform lru_cache, env."""
|
| 25 |
+
binaries.detect_platform.cache_clear()
|
| 26 |
+
binaries._registry.cache_clear()
|
| 27 |
+
monkeypatch.setenv("HEADROOM_BINARIES_CACHE", str(tmp_path / "cache"))
|
| 28 |
+
monkeypatch.delenv("HEADROOM_BINARIES_MIRROR", raising=False)
|
| 29 |
+
monkeypatch.delenv("HEADROOM_BINARIES_OFFLINE", raising=False)
|
| 30 |
+
yield
|
| 31 |
+
binaries.detect_platform.cache_clear()
|
| 32 |
+
binaries._registry.cache_clear()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _set_platform(monkeypatch, *, sys_plat: str, machine: str, musl: bool = False):
|
| 36 |
+
monkeypatch.setattr(sys, "platform", sys_plat)
|
| 37 |
+
monkeypatch.setattr("platform.machine", lambda: machine)
|
| 38 |
+
monkeypatch.setattr(binaries, "_is_musl", lambda: musl)
|
| 39 |
+
binaries.detect_platform.cache_clear()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _make_tar_gz(files: dict[str, bytes]) -> bytes:
|
| 43 |
+
buf = io.BytesIO()
|
| 44 |
+
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
|
| 45 |
+
for name, data in files.items():
|
| 46 |
+
info = tarfile.TarInfo(name=name)
|
| 47 |
+
info.size = len(data)
|
| 48 |
+
tf.addfile(info, io.BytesIO(data))
|
| 49 |
+
return buf.getvalue()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _make_zip(files: dict[str, bytes]) -> bytes:
|
| 53 |
+
buf = io.BytesIO()
|
| 54 |
+
with zipfile.ZipFile(buf, "w") as zf:
|
| 55 |
+
for name, data in files.items():
|
| 56 |
+
zf.writestr(name, data)
|
| 57 |
+
return buf.getvalue()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class _FakeResponse:
|
| 61 |
+
def __init__(self, data: bytes):
|
| 62 |
+
self._data = data
|
| 63 |
+
self.headers = {"Content-Length": str(len(data))}
|
| 64 |
+
|
| 65 |
+
def read(self, n: int = -1) -> bytes:
|
| 66 |
+
if n < 0 or n >= len(self._data):
|
| 67 |
+
chunk, self._data = self._data, b""
|
| 68 |
+
return chunk
|
| 69 |
+
chunk, self._data = self._data[:n], self._data[n:]
|
| 70 |
+
return chunk
|
| 71 |
+
|
| 72 |
+
def __enter__(self):
|
| 73 |
+
return self
|
| 74 |
+
|
| 75 |
+
def __exit__(self, *a):
|
| 76 |
+
return False
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@pytest.fixture
|
| 80 |
+
def fake_urlopen(monkeypatch):
|
| 81 |
+
"""Install a fake urllib.request.urlopen that serves registered URLs."""
|
| 82 |
+
served: dict[str, bytes] = {}
|
| 83 |
+
|
| 84 |
+
def fake(req, timeout=None): # noqa: ARG001
|
| 85 |
+
url = req.full_url if hasattr(req, "full_url") else req
|
| 86 |
+
if url not in served:
|
| 87 |
+
raise AssertionError(f"unexpected fetch for {url}")
|
| 88 |
+
return _FakeResponse(served[url])
|
| 89 |
+
|
| 90 |
+
monkeypatch.setattr(binaries.urllib.request, "urlopen", fake)
|
| 91 |
+
return served
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# -------- Platform detection --------------------------------------------- #
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def test_detect_platform_linux_gnu(monkeypatch):
|
| 98 |
+
_set_platform(monkeypatch, sys_plat="linux", machine="x86_64", musl=False)
|
| 99 |
+
p = binaries.detect_platform()
|
| 100 |
+
assert p == binaries.PlatformKey("linux", "x86_64", "gnu")
|
| 101 |
+
assert p.key() == "linux-x86_64-gnu"
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def test_detect_platform_linux_musl(monkeypatch):
|
| 105 |
+
_set_platform(monkeypatch, sys_plat="linux", machine="aarch64", musl=True)
|
| 106 |
+
assert binaries.detect_platform().key() == "linux-aarch64-musl"
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def test_detect_platform_darwin_arm64(monkeypatch):
|
| 110 |
+
_set_platform(monkeypatch, sys_plat="darwin", machine="arm64")
|
| 111 |
+
assert binaries.detect_platform().key() == "darwin-aarch64"
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def test_detect_platform_windows_amd64(monkeypatch):
|
| 115 |
+
_set_platform(monkeypatch, sys_plat="win32", machine="AMD64")
|
| 116 |
+
assert binaries.detect_platform().key() == "windows-x86_64"
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# -------- Cache dir ------------------------------------------------------ #
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def test_cache_dir_respects_env_override(monkeypatch, tmp_path):
|
| 123 |
+
monkeypatch.setenv("HEADROOM_BINARIES_CACHE", str(tmp_path / "custom"))
|
| 124 |
+
assert binaries.cache_dir() == (tmp_path / "custom").resolve()
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# -------- Registry / asset resolution ------------------------------------ #
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def test_unsupported_platform_raises(monkeypatch):
|
| 131 |
+
_set_platform(monkeypatch, sys_plat="linux", machine="riscv64")
|
| 132 |
+
with pytest.raises(binaries.PlatformNotSupported):
|
| 133 |
+
binaries._asset_for_platform("difft", binaries.detect_platform())
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def test_pypi_only_tool_raises_with_helpful_message(monkeypatch):
|
| 137 |
+
_set_platform(monkeypatch, sys_plat="darwin", machine="arm64")
|
| 138 |
+
with pytest.raises(binaries.PlatformNotSupported) as exc:
|
| 139 |
+
binaries._asset_for_platform("ast-grep", binaries.detect_platform())
|
| 140 |
+
assert "pip install headroom-ai" in str(exc.value)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def test_unknown_tool_raises_key_error():
|
| 144 |
+
with pytest.raises(KeyError):
|
| 145 |
+
binaries._tool_entry("not-a-real-tool")
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# -------- which / resolve with PATH hits --------------------------------- #
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def test_which_finds_on_path(monkeypatch, tmp_path):
|
| 152 |
+
fake_bin = tmp_path / "difft"
|
| 153 |
+
fake_bin.write_text("#!/bin/sh\necho ok\n")
|
| 154 |
+
fake_bin.chmod(0o755)
|
| 155 |
+
monkeypatch.setattr(
|
| 156 |
+
binaries.shutil, "which", lambda name: str(fake_bin) if name == "difft" else None
|
| 157 |
+
)
|
| 158 |
+
# Because the tool is on PATH, which() returns its path without fetching.
|
| 159 |
+
assert binaries.which("difft") == fake_bin
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def test_which_returns_none_when_not_cached(monkeypatch):
|
| 163 |
+
_set_platform(monkeypatch, sys_plat="darwin", machine="arm64")
|
| 164 |
+
monkeypatch.setattr(binaries.shutil, "which", lambda _name: None)
|
| 165 |
+
assert binaries.which("difft") is None
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def test_resolve_honors_path(monkeypatch, tmp_path):
|
| 169 |
+
fake_bin = tmp_path / "scc"
|
| 170 |
+
fake_bin.write_text("")
|
| 171 |
+
fake_bin.chmod(0o755)
|
| 172 |
+
monkeypatch.setattr(
|
| 173 |
+
binaries.shutil, "which", lambda name: str(fake_bin) if name == "scc" else None
|
| 174 |
+
)
|
| 175 |
+
assert binaries.resolve("scc") == fake_bin
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# -------- Offline / mirror / fetch behavior ------------------------------ #
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def test_offline_error_when_fetch_required(monkeypatch):
|
| 182 |
+
_set_platform(monkeypatch, sys_plat="darwin", machine="arm64")
|
| 183 |
+
monkeypatch.setattr(binaries.shutil, "which", lambda _name: None)
|
| 184 |
+
monkeypatch.setenv("HEADROOM_BINARIES_OFFLINE", "1")
|
| 185 |
+
with pytest.raises(binaries.OfflineError):
|
| 186 |
+
binaries.resolve("difft")
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def test_mirror_substitution():
|
| 190 |
+
os.environ["HEADROOM_BINARIES_MIRROR"] = "https://mirror.example.com/gh"
|
| 191 |
+
try:
|
| 192 |
+
out = binaries._mirror_url(
|
| 193 |
+
"https://github.com/Wilfred/difftastic/releases/download/0.64.0/x.tar.gz"
|
| 194 |
+
)
|
| 195 |
+
assert (
|
| 196 |
+
out
|
| 197 |
+
== "https://mirror.example.com/gh/Wilfred/difftastic/releases/download/0.64.0/x.tar.gz"
|
| 198 |
+
)
|
| 199 |
+
# Non-matching URLs are left alone.
|
| 200 |
+
assert binaries._mirror_url("https://example.com/x") == "https://example.com/x"
|
| 201 |
+
finally:
|
| 202 |
+
del os.environ["HEADROOM_BINARIES_MIRROR"]
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def test_fetch_extract_and_cache_tar_gz(monkeypatch, fake_urlopen, tmp_path):
|
| 206 |
+
_set_platform(monkeypatch, sys_plat="darwin", machine="arm64")
|
| 207 |
+
monkeypatch.setattr(binaries.shutil, "which", lambda _name: None)
|
| 208 |
+
|
| 209 |
+
payload = b"#!/bin/sh\necho fake-difft\n"
|
| 210 |
+
archive = _make_tar_gz({"difft-0.64.0/difft": payload})
|
| 211 |
+
url = "https://github.com/Wilfred/difftastic/releases/download/0.64.0/difft-aarch64-apple-darwin.tar.gz"
|
| 212 |
+
fake_urlopen[url] = archive
|
| 213 |
+
|
| 214 |
+
path = binaries.resolve("difft")
|
| 215 |
+
assert path.exists()
|
| 216 |
+
assert path.read_bytes() == payload
|
| 217 |
+
# Second call should use cache (no further fetch).
|
| 218 |
+
fake_urlopen.pop(url) # remove so a refetch would error
|
| 219 |
+
path2 = binaries.resolve("difft")
|
| 220 |
+
assert path2 == path
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def test_fetch_extract_zip(monkeypatch, fake_urlopen):
|
| 224 |
+
_set_platform(monkeypatch, sys_plat="win32", machine="AMD64")
|
| 225 |
+
monkeypatch.setattr(binaries.shutil, "which", lambda _name: None)
|
| 226 |
+
payload = b"MZfake"
|
| 227 |
+
archive = _make_zip({"scc.exe": payload})
|
| 228 |
+
url = "https://github.com/boyter/scc/releases/download/v3.5.0/scc_Windows_x86_64.zip"
|
| 229 |
+
fake_urlopen[url] = archive
|
| 230 |
+
|
| 231 |
+
path = binaries.resolve("scc")
|
| 232 |
+
assert path.exists()
|
| 233 |
+
assert path.name.endswith("scc.exe")
|
| 234 |
+
assert path.read_bytes() == payload
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def test_sha256_mismatch_raises_and_deletes(monkeypatch, fake_urlopen, tmp_path):
|
| 238 |
+
_set_platform(monkeypatch, sys_plat="darwin", machine="arm64")
|
| 239 |
+
monkeypatch.setattr(binaries.shutil, "which", lambda _name: None)
|
| 240 |
+
|
| 241 |
+
# Override the registry entry for difft to include a bogus sha256.
|
| 242 |
+
reg = binaries._registry()
|
| 243 |
+
asset = reg["tools"]["difft"]["assets"]["darwin-aarch64"]
|
| 244 |
+
asset["sha256"] = "deadbeef" * 8 # wrong
|
| 245 |
+
archive = _make_tar_gz({"difft": b"hi"})
|
| 246 |
+
fake_urlopen[asset["url"]] = archive
|
| 247 |
+
|
| 248 |
+
try:
|
| 249 |
+
with pytest.raises(binaries.Sha256Mismatch):
|
| 250 |
+
binaries.resolve("difft")
|
| 251 |
+
finally:
|
| 252 |
+
asset["sha256"] = None # restore
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def test_sha256_match_passes(monkeypatch, fake_urlopen):
|
| 256 |
+
_set_platform(monkeypatch, sys_plat="darwin", machine="arm64")
|
| 257 |
+
monkeypatch.setattr(binaries.shutil, "which", lambda _name: None)
|
| 258 |
+
archive = _make_tar_gz({"difft": b"hello"})
|
| 259 |
+
good = hashlib.sha256(archive).hexdigest()
|
| 260 |
+
reg = binaries._registry()
|
| 261 |
+
asset = reg["tools"]["difft"]["assets"]["darwin-aarch64"]
|
| 262 |
+
asset["sha256"] = good
|
| 263 |
+
fake_urlopen[asset["url"]] = archive
|
| 264 |
+
try:
|
| 265 |
+
path = binaries.resolve("difft")
|
| 266 |
+
assert path.read_bytes() == b"hello"
|
| 267 |
+
finally:
|
| 268 |
+
asset["sha256"] = None
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
# -------- status() ------------------------------------------------------- #
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def test_status_reports_every_registered_tool(monkeypatch):
|
| 275 |
+
_set_platform(monkeypatch, sys_plat="darwin", machine="arm64")
|
| 276 |
+
monkeypatch.setattr(binaries.shutil, "which", lambda _name: None)
|
| 277 |
+
rows = binaries.status()
|
| 278 |
+
names = {r["tool"] for r in rows}
|
| 279 |
+
assert {"difft", "scc", "ast-grep"} <= names
|
| 280 |
+
for r in rows:
|
| 281 |
+
assert r["state"] in ("on-path", "cached", "missing", "unsupported-platform")
|
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Comprehensive integration tests for the bundled CLI tools.
|
| 2 |
+
|
| 3 |
+
Proves three things end-to-end:
|
| 4 |
+
|
| 5 |
+
1. `headroom.binaries.ensure_tools()` actually installs every tool.
|
| 6 |
+
2. Each tool reduces token count on a realistic payload (tiktoken-measured).
|
| 7 |
+
3. A real LLM answers the same question correctly on the compressed
|
| 8 |
+
payload (LLM-as-judge).
|
| 9 |
+
|
| 10 |
+
Live API calls are gated on OPENAI_API_KEY / ANTHROPIC_API_KEY being present
|
| 11 |
+
in the environment (loaded from .env if python-dotenv is available).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
import subprocess
|
| 19 |
+
import textwrap
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
import pytest
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from dotenv import load_dotenv
|
| 26 |
+
|
| 27 |
+
load_dotenv(Path(__file__).resolve().parent.parent / ".env")
|
| 28 |
+
except ImportError:
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
import tiktoken
|
| 32 |
+
|
| 33 |
+
from headroom import binaries
|
| 34 |
+
|
| 35 |
+
# ---------- Fixtures ------------------------------------------------------ #
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
ENC = tiktoken.get_encoding("cl100k_base")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _tokens(text: str) -> int:
|
| 42 |
+
return len(ENC.encode(text))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
SAMPLE_PY = textwrap.dedent(
|
| 46 |
+
'''
|
| 47 |
+
"""Payments module — illustrative fixture for compression tests."""
|
| 48 |
+
import logging
|
| 49 |
+
from dataclasses import dataclass
|
| 50 |
+
from decimal import Decimal
|
| 51 |
+
from typing import Iterable
|
| 52 |
+
|
| 53 |
+
log = logging.getLogger(__name__)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@dataclass
|
| 57 |
+
class LineItem:
|
| 58 |
+
sku: str
|
| 59 |
+
quantity: int
|
| 60 |
+
unit_price: Decimal
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def compute_subtotal(items: Iterable[LineItem]) -> Decimal:
|
| 64 |
+
total = Decimal("0")
|
| 65 |
+
for item in items:
|
| 66 |
+
total += item.unit_price * item.quantity
|
| 67 |
+
return total
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def apply_promo(subtotal: Decimal, code: str | None) -> Decimal:
|
| 71 |
+
if not code:
|
| 72 |
+
return subtotal
|
| 73 |
+
if code == "SAVE10":
|
| 74 |
+
return subtotal * Decimal("0.9")
|
| 75 |
+
if code == "FREESHIP":
|
| 76 |
+
return subtotal
|
| 77 |
+
log.warning("unknown promo code %s", code)
|
| 78 |
+
return subtotal
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def compute_tax(subtotal: Decimal, rate: Decimal) -> Decimal:
|
| 82 |
+
return (subtotal * rate).quantize(Decimal("0.01"))
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def process_payment(items: list[LineItem], promo: str | None, tax_rate: Decimal) -> Decimal:
|
| 86 |
+
"""Main entry point: compute the final total for a cart."""
|
| 87 |
+
subtotal = compute_subtotal(items)
|
| 88 |
+
after_promo = apply_promo(subtotal, promo)
|
| 89 |
+
tax = compute_tax(after_promo, tax_rate)
|
| 90 |
+
total = after_promo + tax
|
| 91 |
+
log.info("processed payment: subtotal=%s tax=%s total=%s", subtotal, tax, total)
|
| 92 |
+
return total
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def refund_payment(order_id: str, amount: Decimal) -> dict:
|
| 96 |
+
"""Issue a refund for a previous order."""
|
| 97 |
+
log.info("refunding %s from %s", amount, order_id)
|
| 98 |
+
return {"order_id": order_id, "refund": str(amount), "status": "ok"}
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def list_orders_for_user(user_id: str, limit: int = 20) -> list[dict]:
|
| 102 |
+
"""Placeholder DB lookup."""
|
| 103 |
+
return [{"user": user_id, "order": i} for i in range(limit)]
|
| 104 |
+
'''
|
| 105 |
+
).strip()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
SAMPLE_PY_MODIFIED = SAMPLE_PY.replace(
|
| 109 |
+
'return subtotal * Decimal("0.9")',
|
| 110 |
+
'return subtotal * Decimal("0.85") # promo bumped from 10% to 15%',
|
| 111 |
+
).replace(
|
| 112 |
+
'log.warning("unknown promo code %s", code)',
|
| 113 |
+
'log.error("unknown promo code %s — rejecting", code)\n raise ValueError(code)',
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@pytest.fixture(scope="module")
|
| 118 |
+
def repo(tmp_path_factory) -> Path:
|
| 119 |
+
d = tmp_path_factory.mktemp("payments-repo")
|
| 120 |
+
(d / "payments.py").write_text(SAMPLE_PY)
|
| 121 |
+
(d / "payments_v2.py").write_text(SAMPLE_PY_MODIFIED)
|
| 122 |
+
(d / "README.md").write_text("# payments fixture\n")
|
| 123 |
+
return d
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# ---------- 1. Tool installation ----------------------------------------- #
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def test_ensure_tools_installs_every_tool():
|
| 130 |
+
"""All three tools should be reachable after ensure_tools()."""
|
| 131 |
+
binaries.ensure_tools(quiet=True)
|
| 132 |
+
# ast-grep comes from the PyPI wheel (core dep); resolve() checks PATH
|
| 133 |
+
# and sys.prefix/bin so it works in non-activated venvs too.
|
| 134 |
+
assert binaries.resolve("ast-grep").exists(), "ast-grep-cli wheel not installed"
|
| 135 |
+
# difft & scc come from the GitHub-release fetcher.
|
| 136 |
+
assert binaries.which("difft") is not None, "difftastic not installed"
|
| 137 |
+
assert binaries.which("scc") is not None, "scc not installed"
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# ---------- 2. Token-savings (no API) ------------------------------------ #
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def test_ast_grep_slice_saves_tokens(repo: Path):
|
| 144 |
+
"""Function-level slice vs full-file — ast-grep must reduce tokens."""
|
| 145 |
+
full = (repo / "payments.py").read_text()
|
| 146 |
+
full_tokens = _tokens(full)
|
| 147 |
+
|
| 148 |
+
# Extract just `process_payment` and `apply_promo` (the two functions an
|
| 149 |
+
# agent would realistically need to reason about a promo-code bug).
|
| 150 |
+
result = subprocess.run(
|
| 151 |
+
[
|
| 152 |
+
str(binaries.resolve("ast-grep")),
|
| 153 |
+
"run",
|
| 154 |
+
"--pattern",
|
| 155 |
+
"def process_payment",
|
| 156 |
+
"--lang",
|
| 157 |
+
"python",
|
| 158 |
+
"--json=stream",
|
| 159 |
+
str(repo / "payments.py"),
|
| 160 |
+
],
|
| 161 |
+
capture_output=True,
|
| 162 |
+
text=True,
|
| 163 |
+
check=True,
|
| 164 |
+
)
|
| 165 |
+
matches = [json.loads(line) for line in result.stdout.strip().splitlines() if line]
|
| 166 |
+
assert matches, "ast-grep returned no matches"
|
| 167 |
+
sliced = "\n\n".join(m["text"] for m in matches)
|
| 168 |
+
sliced_tokens = _tokens(sliced)
|
| 169 |
+
|
| 170 |
+
savings_pct = (1 - sliced_tokens / full_tokens) * 100
|
| 171 |
+
print(f"\n[ast-grep] full={full_tokens}t sliced={sliced_tokens}t savings={savings_pct:.1f}%")
|
| 172 |
+
assert sliced_tokens < full_tokens
|
| 173 |
+
assert savings_pct >= 40, f"expected ≥40% savings, got {savings_pct:.1f}%"
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def test_difftastic_saves_tokens_vs_line_diff(repo: Path):
|
| 177 |
+
"""Structural diff should compress smaller than unified line diff."""
|
| 178 |
+
# Baseline: unified line diff via /usr/bin/diff.
|
| 179 |
+
line_diff = subprocess.run(
|
| 180 |
+
["diff", "-u", str(repo / "payments.py"), str(repo / "payments_v2.py")],
|
| 181 |
+
capture_output=True,
|
| 182 |
+
text=True,
|
| 183 |
+
).stdout
|
| 184 |
+
line_tokens = _tokens(line_diff)
|
| 185 |
+
|
| 186 |
+
# difftastic in a compact display mode.
|
| 187 |
+
struct = subprocess.run(
|
| 188 |
+
[
|
| 189 |
+
str(binaries.resolve("difft")),
|
| 190 |
+
"--display=inline",
|
| 191 |
+
"--color=never",
|
| 192 |
+
str(repo / "payments.py"),
|
| 193 |
+
str(repo / "payments_v2.py"),
|
| 194 |
+
],
|
| 195 |
+
capture_output=True,
|
| 196 |
+
text=True,
|
| 197 |
+
).stdout
|
| 198 |
+
struct_tokens = _tokens(struct)
|
| 199 |
+
|
| 200 |
+
savings_pct = (1 - struct_tokens / line_tokens) * 100 if line_tokens else 0.0
|
| 201 |
+
print(
|
| 202 |
+
f"\n[difftastic] line={line_tokens}t struct={struct_tokens}t savings={savings_pct:.1f}%"
|
| 203 |
+
)
|
| 204 |
+
# On small diffs structural output can occasionally be equal or slightly
|
| 205 |
+
# larger due to display overhead; just assert it doesn't blow up.
|
| 206 |
+
assert struct_tokens <= int(line_tokens * 1.2), (
|
| 207 |
+
f"difft output unexpectedly larger: {struct_tokens} vs {line_tokens}"
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def test_scc_repo_shape_card_is_tiny(repo: Path):
|
| 212 |
+
"""scc produces a repo-shape summary that's much smaller than raw files."""
|
| 213 |
+
raw_bytes = sum(
|
| 214 |
+
(repo / p).stat().st_size for p in ("payments.py", "payments_v2.py", "README.md")
|
| 215 |
+
)
|
| 216 |
+
raw_tokens = _tokens((repo / "payments.py").read_text())
|
| 217 |
+
raw_tokens += _tokens((repo / "payments_v2.py").read_text())
|
| 218 |
+
raw_tokens += _tokens((repo / "README.md").read_text())
|
| 219 |
+
|
| 220 |
+
scc_out = subprocess.run(
|
| 221 |
+
[str(binaries.resolve("scc")), "--format=json", str(repo)],
|
| 222 |
+
capture_output=True,
|
| 223 |
+
text=True,
|
| 224 |
+
check=True,
|
| 225 |
+
).stdout
|
| 226 |
+
scc_tokens = _tokens(scc_out)
|
| 227 |
+
|
| 228 |
+
print(f"\n[scc] raw_files={raw_tokens}t scc_card={scc_tokens}t bytes_scanned={raw_bytes}")
|
| 229 |
+
# scc summarizes many files into one small JSON blob; assert it's smaller
|
| 230 |
+
# than the concatenated raw file contents.
|
| 231 |
+
assert scc_tokens < raw_tokens
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# ---------- 3. Quality test (live API) ----------------------------------- #
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
_NEED_OPENAI = pytest.mark.skipif(
|
| 238 |
+
not os.environ.get("OPENAI_API_KEY"),
|
| 239 |
+
reason="OPENAI_API_KEY not set",
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
_NEED_ANTHROPIC = pytest.mark.skipif(
|
| 243 |
+
not os.environ.get("ANTHROPIC_API_KEY"),
|
| 244 |
+
reason="ANTHROPIC_API_KEY not set",
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
QUESTION = (
|
| 249 |
+
"In this payments module, what discount percentage does the SAVE10 promo "
|
| 250 |
+
"currently apply? Answer with just the number (e.g. '10')."
|
| 251 |
+
)
|
| 252 |
+
EXPECTED = "10"
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
@_NEED_OPENAI
|
| 256 |
+
def test_compressed_payload_preserves_answer_openai(repo: Path):
|
| 257 |
+
"""Model answers the same question correctly on ast-grep-sliced input."""
|
| 258 |
+
import openai # lazy: only required when the key is present
|
| 259 |
+
|
| 260 |
+
full = (repo / "payments.py").read_text()
|
| 261 |
+
|
| 262 |
+
result = subprocess.run(
|
| 263 |
+
[
|
| 264 |
+
str(binaries.resolve("ast-grep")),
|
| 265 |
+
"run",
|
| 266 |
+
"--pattern",
|
| 267 |
+
"def apply_promo",
|
| 268 |
+
"--lang",
|
| 269 |
+
"python",
|
| 270 |
+
"--json=stream",
|
| 271 |
+
str(repo / "payments.py"),
|
| 272 |
+
],
|
| 273 |
+
capture_output=True,
|
| 274 |
+
text=True,
|
| 275 |
+
check=True,
|
| 276 |
+
)
|
| 277 |
+
matches = [json.loads(line) for line in result.stdout.strip().splitlines() if line]
|
| 278 |
+
sliced = matches[0]["text"]
|
| 279 |
+
|
| 280 |
+
client = openai.OpenAI()
|
| 281 |
+
full_tokens = _tokens(full)
|
| 282 |
+
sliced_tokens = _tokens(sliced)
|
| 283 |
+
|
| 284 |
+
full_resp = client.chat.completions.create(
|
| 285 |
+
model="gpt-4o-mini",
|
| 286 |
+
messages=[
|
| 287 |
+
{"role": "system", "content": "You answer briefly and numerically."},
|
| 288 |
+
{"role": "user", "content": f"{QUESTION}\n\n---\n{full}"},
|
| 289 |
+
],
|
| 290 |
+
max_tokens=16,
|
| 291 |
+
temperature=0,
|
| 292 |
+
)
|
| 293 |
+
sliced_resp = client.chat.completions.create(
|
| 294 |
+
model="gpt-4o-mini",
|
| 295 |
+
messages=[
|
| 296 |
+
{"role": "system", "content": "You answer briefly and numerically."},
|
| 297 |
+
{"role": "user", "content": f"{QUESTION}\n\n---\n{sliced}"},
|
| 298 |
+
],
|
| 299 |
+
max_tokens=16,
|
| 300 |
+
temperature=0,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
full_answer = full_resp.choices[0].message.content.strip()
|
| 304 |
+
sliced_answer = sliced_resp.choices[0].message.content.strip()
|
| 305 |
+
full_usage = full_resp.usage.prompt_tokens
|
| 306 |
+
sliced_usage = sliced_resp.usage.prompt_tokens
|
| 307 |
+
|
| 308 |
+
print(f"\n[openai] full_payload={full_tokens}t prompt_tokens={full_usage} → {full_answer!r}")
|
| 309 |
+
print(
|
| 310 |
+
f"[openai] sliced_payload={sliced_tokens}t prompt_tokens={sliced_usage} → {sliced_answer!r}"
|
| 311 |
+
)
|
| 312 |
+
print(f"[openai] prompt-token savings: {(1 - sliced_usage / full_usage) * 100:.1f}%")
|
| 313 |
+
|
| 314 |
+
assert EXPECTED in full_answer, f"baseline failed: {full_answer!r}"
|
| 315 |
+
assert EXPECTED in sliced_answer, f"compressed answer wrong: {sliced_answer!r}"
|
| 316 |
+
assert sliced_usage < full_usage, "compressed payload used more tokens than full"
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
@_NEED_ANTHROPIC
|
| 320 |
+
def test_compressed_payload_preserves_answer_anthropic(repo: Path):
|
| 321 |
+
import anthropic
|
| 322 |
+
|
| 323 |
+
full = (repo / "payments.py").read_text()
|
| 324 |
+
|
| 325 |
+
result = subprocess.run(
|
| 326 |
+
[
|
| 327 |
+
str(binaries.resolve("ast-grep")),
|
| 328 |
+
"run",
|
| 329 |
+
"--pattern",
|
| 330 |
+
"def apply_promo",
|
| 331 |
+
"--lang",
|
| 332 |
+
"python",
|
| 333 |
+
"--json=stream",
|
| 334 |
+
str(repo / "payments.py"),
|
| 335 |
+
],
|
| 336 |
+
capture_output=True,
|
| 337 |
+
text=True,
|
| 338 |
+
check=True,
|
| 339 |
+
)
|
| 340 |
+
sliced = json.loads(result.stdout.strip().splitlines()[0])["text"]
|
| 341 |
+
|
| 342 |
+
client = anthropic.Anthropic()
|
| 343 |
+
full_resp = client.messages.create(
|
| 344 |
+
model="claude-haiku-4-5-20251001",
|
| 345 |
+
max_tokens=16,
|
| 346 |
+
system="You answer briefly and numerically.",
|
| 347 |
+
messages=[{"role": "user", "content": f"{QUESTION}\n\n---\n{full}"}],
|
| 348 |
+
)
|
| 349 |
+
sliced_resp = client.messages.create(
|
| 350 |
+
model="claude-haiku-4-5-20251001",
|
| 351 |
+
max_tokens=16,
|
| 352 |
+
system="You answer briefly and numerically.",
|
| 353 |
+
messages=[{"role": "user", "content": f"{QUESTION}\n\n---\n{sliced}"}],
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
full_answer = full_resp.content[0].text.strip()
|
| 357 |
+
sliced_answer = sliced_resp.content[0].text.strip()
|
| 358 |
+
print(f"\n[anthropic] full prompt_tokens={full_resp.usage.input_tokens} → {full_answer!r}")
|
| 359 |
+
print(f"[anthropic] sliced prompt_tokens={sliced_resp.usage.input_tokens} → {sliced_answer!r}")
|
| 360 |
+
print(
|
| 361 |
+
f"[anthropic] savings: "
|
| 362 |
+
f"{(1 - sliced_resp.usage.input_tokens / full_resp.usage.input_tokens) * 100:.1f}%"
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
assert EXPECTED in full_answer, f"baseline failed: {full_answer!r}"
|
| 366 |
+
assert EXPECTED in sliced_answer, f"compressed answer wrong: {sliced_answer!r}"
|
| 367 |
+
assert sliced_resp.usage.input_tokens < full_resp.usage.input_tokens
|
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the tool_result interceptor framework + ast-grep Read outliner."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import textwrap
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
|
| 9 |
+
from headroom.proxy.interceptors import (
|
| 10 |
+
INTERCEPTORS,
|
| 11 |
+
ToolResultInterceptor,
|
| 12 |
+
apply_to_messages,
|
| 13 |
+
register,
|
| 14 |
+
)
|
| 15 |
+
from headroom.proxy.interceptors.astgrep import AstGrepReadOutline
|
| 16 |
+
from headroom.tokenizer import Tokenizer
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class _FakeTokenCounter:
|
| 20 |
+
"""Deterministic 4-chars-per-token counter for unit tests."""
|
| 21 |
+
|
| 22 |
+
def count_text(self, text: str) -> int:
|
| 23 |
+
return max(1, len(text) // 4)
|
| 24 |
+
|
| 25 |
+
def count_messages(self, messages) -> int:
|
| 26 |
+
total = 0
|
| 27 |
+
for m in messages:
|
| 28 |
+
c = m.get("content")
|
| 29 |
+
if isinstance(c, str):
|
| 30 |
+
total += self.count_text(c)
|
| 31 |
+
elif isinstance(c, list):
|
| 32 |
+
for b in c:
|
| 33 |
+
if isinstance(b, dict):
|
| 34 |
+
inner = b.get("content") or b.get("text") or ""
|
| 35 |
+
if isinstance(inner, str):
|
| 36 |
+
total += self.count_text(inner)
|
| 37 |
+
return total
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@pytest.fixture
|
| 41 |
+
def tokenizer() -> Tokenizer:
|
| 42 |
+
# Real Tokenizer wrapping the fake counter; mirrors production construction.
|
| 43 |
+
return Tokenizer(_FakeTokenCounter()) # type: ignore[arg-type]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# -------- Framework basics ----------------------------------------------- #
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def test_astgrep_interceptor_registered_by_default():
|
| 50 |
+
assert any(i.name == "ast-grep" for i in INTERCEPTORS)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def test_register_is_idempotent_on_name():
|
| 54 |
+
before = len(INTERCEPTORS)
|
| 55 |
+
register(AstGrepReadOutline()) # same name
|
| 56 |
+
assert len(INTERCEPTORS) == before
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def test_custom_interceptor_plugs_in(tokenizer):
|
| 60 |
+
class UpperCase:
|
| 61 |
+
name = "uppercase-test"
|
| 62 |
+
|
| 63 |
+
def matches(self, tool_name, tool_input, tool_output):
|
| 64 |
+
return tool_name == "Echo"
|
| 65 |
+
|
| 66 |
+
def transform(self, tool_name, tool_input, tool_output):
|
| 67 |
+
# Must REDUCE tokens — use a single short marker.
|
| 68 |
+
return "X"
|
| 69 |
+
|
| 70 |
+
dummy: ToolResultInterceptor = UpperCase() # type: ignore[assignment]
|
| 71 |
+
register(dummy)
|
| 72 |
+
try:
|
| 73 |
+
messages = [
|
| 74 |
+
{
|
| 75 |
+
"role": "assistant",
|
| 76 |
+
"content": [{"type": "tool_use", "id": "1", "name": "Echo", "input": {}}],
|
| 77 |
+
},
|
| 78 |
+
{
|
| 79 |
+
"role": "user",
|
| 80 |
+
"content": [
|
| 81 |
+
{
|
| 82 |
+
"type": "tool_result",
|
| 83 |
+
"tool_use_id": "1",
|
| 84 |
+
"content": "hello " * 100,
|
| 85 |
+
}
|
| 86 |
+
],
|
| 87 |
+
},
|
| 88 |
+
]
|
| 89 |
+
result = apply_to_messages(messages, tokenizer)
|
| 90 |
+
assert any(s.tool == "uppercase-test" for s in result.spans)
|
| 91 |
+
swapped = result.messages[1]["content"][0]["content"]
|
| 92 |
+
assert swapped == "X"
|
| 93 |
+
finally:
|
| 94 |
+
INTERCEPTORS[:] = [i for i in INTERCEPTORS if i.name != "uppercase-test"]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def test_pass_through_when_no_interceptor_matches(tokenizer):
|
| 98 |
+
messages = [
|
| 99 |
+
{
|
| 100 |
+
"role": "assistant",
|
| 101 |
+
"content": [{"type": "tool_use", "id": "1", "name": "Unknown", "input": {}}],
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"role": "user",
|
| 105 |
+
"content": [{"type": "tool_result", "tool_use_id": "1", "content": "x" * 5000}],
|
| 106 |
+
},
|
| 107 |
+
]
|
| 108 |
+
result = apply_to_messages(messages, tokenizer)
|
| 109 |
+
assert result.spans == []
|
| 110 |
+
assert result.messages[1] is messages[1] # untouched identity
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# -------- ast-grep interceptor ------------------------------------------- #
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
_PY_FIXTURE = textwrap.dedent(
|
| 117 |
+
'''
|
| 118 |
+
"""Payments module fixture."""
|
| 119 |
+
from decimal import Decimal
|
| 120 |
+
|
| 121 |
+
def compute_subtotal(items):
|
| 122 |
+
total = Decimal("0")
|
| 123 |
+
for item in items:
|
| 124 |
+
total += item.price * item.qty
|
| 125 |
+
return total
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def apply_promo(subtotal, code):
|
| 129 |
+
if not code:
|
| 130 |
+
return subtotal
|
| 131 |
+
if code == "SAVE10":
|
| 132 |
+
return subtotal * Decimal("0.9")
|
| 133 |
+
return subtotal
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def compute_tax(subtotal, rate):
|
| 137 |
+
return (subtotal * rate).quantize(Decimal("0.01"))
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def process_payment(items, promo, tax_rate):
|
| 141 |
+
"""Main entry point."""
|
| 142 |
+
subtotal = compute_subtotal(items)
|
| 143 |
+
after = apply_promo(subtotal, promo)
|
| 144 |
+
tax = compute_tax(after, tax_rate)
|
| 145 |
+
return after + tax
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def refund(order_id, amount):
|
| 149 |
+
"""Issue a refund."""
|
| 150 |
+
return {"order": order_id, "refund": str(amount)}
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def list_orders_for_user(user_id, limit=20):
|
| 154 |
+
"""Placeholder DB lookup for a user's orders."""
|
| 155 |
+
return [{"user": user_id, "order": i} for i in range(limit)]
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def cancel_order(order_id, reason=None):
|
| 159 |
+
"""Cancel an order, logging the reason if provided."""
|
| 160 |
+
return {"order": order_id, "cancelled": True, "reason": reason or "unspecified"}
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def summarize_cart(items):
|
| 164 |
+
"""Return a one-line summary of cart contents."""
|
| 165 |
+
skus = [i.sku for i in items]
|
| 166 |
+
total_qty = sum(i.qty for i in items)
|
| 167 |
+
return f"{len(items)} line items ({total_qty} units): {', '.join(skus)}"
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def format_receipt(order_id, items, total):
|
| 171 |
+
"""Render a textual receipt."""
|
| 172 |
+
lines = [f"Order {order_id}"]
|
| 173 |
+
for i in items:
|
| 174 |
+
lines.append(f" {i.sku} x {i.qty} @ {i.unit_price} = {i.qty * i.unit_price}")
|
| 175 |
+
lines.append(f"Total: {total}")
|
| 176 |
+
return "\\n".join(lines)
|
| 177 |
+
'''
|
| 178 |
+
).strip()
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def test_astgrep_outlines_large_python_read(tokenizer):
|
| 182 |
+
messages = [
|
| 183 |
+
{
|
| 184 |
+
"role": "assistant",
|
| 185 |
+
"content": [
|
| 186 |
+
{
|
| 187 |
+
"type": "tool_use",
|
| 188 |
+
"id": "abc",
|
| 189 |
+
"name": "Read",
|
| 190 |
+
"input": {"file_path": "/repo/payments.py"},
|
| 191 |
+
}
|
| 192 |
+
],
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"role": "user",
|
| 196 |
+
"content": [{"type": "tool_result", "tool_use_id": "abc", "content": _PY_FIXTURE}],
|
| 197 |
+
},
|
| 198 |
+
]
|
| 199 |
+
result = apply_to_messages(messages, tokenizer)
|
| 200 |
+
assert len(result.spans) == 1
|
| 201 |
+
span = result.spans[0]
|
| 202 |
+
assert span.tool == "ast-grep"
|
| 203 |
+
assert span.tokens_after < span.tokens_before
|
| 204 |
+
new_content = result.messages[1]["content"][0]["content"]
|
| 205 |
+
assert "outlined by ast-grep" in new_content
|
| 206 |
+
assert "body elided" in new_content
|
| 207 |
+
assert "def process_payment" in new_content
|
| 208 |
+
assert "def apply_promo" in new_content
|
| 209 |
+
# Bodies should NOT leak through unchanged.
|
| 210 |
+
assert "total += item.price * item.qty" not in new_content
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def test_astgrep_skips_small_files(tokenizer):
|
| 214 |
+
small = "def foo(): return 1\n"
|
| 215 |
+
messages = [
|
| 216 |
+
{
|
| 217 |
+
"role": "assistant",
|
| 218 |
+
"content": [
|
| 219 |
+
{
|
| 220 |
+
"type": "tool_use",
|
| 221 |
+
"id": "x",
|
| 222 |
+
"name": "Read",
|
| 223 |
+
"input": {"file_path": "/a.py"},
|
| 224 |
+
}
|
| 225 |
+
],
|
| 226 |
+
},
|
| 227 |
+
{
|
| 228 |
+
"role": "user",
|
| 229 |
+
"content": [{"type": "tool_result", "tool_use_id": "x", "content": small}],
|
| 230 |
+
},
|
| 231 |
+
]
|
| 232 |
+
result = apply_to_messages(messages, tokenizer)
|
| 233 |
+
assert result.spans == []
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def test_astgrep_skips_non_code_extensions(tokenizer):
|
| 237 |
+
messages = [
|
| 238 |
+
{
|
| 239 |
+
"role": "assistant",
|
| 240 |
+
"content": [
|
| 241 |
+
{
|
| 242 |
+
"type": "tool_use",
|
| 243 |
+
"id": "r",
|
| 244 |
+
"name": "Read",
|
| 245 |
+
"input": {"file_path": "/notes.txt"},
|
| 246 |
+
}
|
| 247 |
+
],
|
| 248 |
+
},
|
| 249 |
+
{
|
| 250 |
+
"role": "user",
|
| 251 |
+
"content": [{"type": "tool_result", "tool_use_id": "r", "content": "x" * 3000}],
|
| 252 |
+
},
|
| 253 |
+
]
|
| 254 |
+
result = apply_to_messages(messages, tokenizer)
|
| 255 |
+
assert result.spans == []
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# -------- OpenAI-format tool_result -------------------------------------- #
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def test_astgrep_skips_when_line_range_requested(tokenizer):
|
| 262 |
+
"""If the tool_input specifies a line range, the model wants those lines — pass through."""
|
| 263 |
+
messages = [
|
| 264 |
+
{
|
| 265 |
+
"role": "assistant",
|
| 266 |
+
"content": [
|
| 267 |
+
{
|
| 268 |
+
"type": "tool_use",
|
| 269 |
+
"id": "r",
|
| 270 |
+
"name": "Read",
|
| 271 |
+
"input": {
|
| 272 |
+
"file_path": "/repo/payments.py",
|
| 273 |
+
"offset": 30,
|
| 274 |
+
"limit": 20,
|
| 275 |
+
},
|
| 276 |
+
}
|
| 277 |
+
],
|
| 278 |
+
},
|
| 279 |
+
{
|
| 280 |
+
"role": "user",
|
| 281 |
+
"content": [{"type": "tool_result", "tool_use_id": "r", "content": _PY_FIXTURE}],
|
| 282 |
+
},
|
| 283 |
+
]
|
| 284 |
+
result = apply_to_messages(messages, tokenizer)
|
| 285 |
+
assert result.spans == []
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def test_progressive_disclosure_second_read_passes_through(tokenizer):
|
| 289 |
+
"""First Read of a file gets outlined; second Read of the same path is untouched."""
|
| 290 |
+
messages = [
|
| 291 |
+
# Turn 1: Read foo.py → outlined
|
| 292 |
+
{
|
| 293 |
+
"role": "assistant",
|
| 294 |
+
"content": [
|
| 295 |
+
{
|
| 296 |
+
"type": "tool_use",
|
| 297 |
+
"id": "t1",
|
| 298 |
+
"name": "Read",
|
| 299 |
+
"input": {"file_path": "/repo/payments.py"},
|
| 300 |
+
}
|
| 301 |
+
],
|
| 302 |
+
},
|
| 303 |
+
{
|
| 304 |
+
"role": "user",
|
| 305 |
+
"content": [{"type": "tool_result", "tool_use_id": "t1", "content": _PY_FIXTURE}],
|
| 306 |
+
},
|
| 307 |
+
# Turn 2: Read foo.py again (model came back for more) → pass through
|
| 308 |
+
{
|
| 309 |
+
"role": "assistant",
|
| 310 |
+
"content": [
|
| 311 |
+
{
|
| 312 |
+
"type": "tool_use",
|
| 313 |
+
"id": "t2",
|
| 314 |
+
"name": "Read",
|
| 315 |
+
"input": {"file_path": "/repo/payments.py"},
|
| 316 |
+
}
|
| 317 |
+
],
|
| 318 |
+
},
|
| 319 |
+
{
|
| 320 |
+
"role": "user",
|
| 321 |
+
"content": [{"type": "tool_result", "tool_use_id": "t2", "content": _PY_FIXTURE}],
|
| 322 |
+
},
|
| 323 |
+
]
|
| 324 |
+
result = apply_to_messages(messages, tokenizer)
|
| 325 |
+
# Only the first Read is rewritten; the second keeps its full body.
|
| 326 |
+
assert len(result.spans) == 1
|
| 327 |
+
first_tr = result.messages[1]["content"][0]["content"]
|
| 328 |
+
second_tr = result.messages[3]["content"][0]["content"]
|
| 329 |
+
assert "outlined by ast-grep" in first_tr
|
| 330 |
+
assert "outlined by ast-grep" not in second_tr
|
| 331 |
+
assert "def process_payment" in second_tr
|
| 332 |
+
# Second Read preserves the bodies.
|
| 333 |
+
assert "subtotal = compute_subtotal(items)" in second_tr
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def test_progressive_disclosure_different_file_still_outlined(tokenizer):
|
| 337 |
+
"""Reading a DIFFERENT file after the first outline should still outline."""
|
| 338 |
+
messages = [
|
| 339 |
+
{
|
| 340 |
+
"role": "assistant",
|
| 341 |
+
"content": [
|
| 342 |
+
{
|
| 343 |
+
"type": "tool_use",
|
| 344 |
+
"id": "t1",
|
| 345 |
+
"name": "Read",
|
| 346 |
+
"input": {"file_path": "/repo/payments.py"},
|
| 347 |
+
}
|
| 348 |
+
],
|
| 349 |
+
},
|
| 350 |
+
{
|
| 351 |
+
"role": "user",
|
| 352 |
+
"content": [{"type": "tool_result", "tool_use_id": "t1", "content": _PY_FIXTURE}],
|
| 353 |
+
},
|
| 354 |
+
{
|
| 355 |
+
"role": "assistant",
|
| 356 |
+
"content": [
|
| 357 |
+
{
|
| 358 |
+
"type": "tool_use",
|
| 359 |
+
"id": "t2",
|
| 360 |
+
"name": "Read",
|
| 361 |
+
"input": {"file_path": "/repo/other.py"},
|
| 362 |
+
}
|
| 363 |
+
],
|
| 364 |
+
},
|
| 365 |
+
{
|
| 366 |
+
"role": "user",
|
| 367 |
+
"content": [{"type": "tool_result", "tool_use_id": "t2", "content": _PY_FIXTURE}],
|
| 368 |
+
},
|
| 369 |
+
]
|
| 370 |
+
result = apply_to_messages(messages, tokenizer)
|
| 371 |
+
# Both files get outlined — different keys.
|
| 372 |
+
assert len(result.spans) == 2
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def test_openai_format_tool_result_is_rewritten(tokenizer):
|
| 376 |
+
messages = [
|
| 377 |
+
{
|
| 378 |
+
"role": "assistant",
|
| 379 |
+
"content": None,
|
| 380 |
+
"tool_calls": [
|
| 381 |
+
{
|
| 382 |
+
"id": "call_1",
|
| 383 |
+
"type": "function",
|
| 384 |
+
"function": {
|
| 385 |
+
"name": "Read",
|
| 386 |
+
"arguments": '{"file_path": "/x/payments.py"}',
|
| 387 |
+
},
|
| 388 |
+
}
|
| 389 |
+
],
|
| 390 |
+
},
|
| 391 |
+
{
|
| 392 |
+
"role": "tool",
|
| 393 |
+
"tool_call_id": "call_1",
|
| 394 |
+
"content": _PY_FIXTURE,
|
| 395 |
+
},
|
| 396 |
+
]
|
| 397 |
+
result = apply_to_messages(messages, tokenizer)
|
| 398 |
+
assert len(result.spans) == 1
|
| 399 |
+
new_content = result.messages[1]["content"]
|
| 400 |
+
assert "outlined by ast-grep" in new_content
|