chopratejas commited on
Commit
4102402
·
1 Parent(s): 39a55b4

Add Phase 2 Progressive Summarization and fix Agno integration tests

Browse files

Phase 2 - Progressive Summarization:
- Add ProgressiveSummarizer with callback pattern for external summarization
- Add AnchoredSummary for tracking which message positions were summarized
- Add SummarizationResult for tracking summarization operations
- Add extractive_summarizer fallback when no LLM callback provided
- Integrate CCR for storing originals and enabling retrieval
- Add SUMMARIZE strategy to IntelligentContextManager
- Add comprehensive tests (59 total for intelligent context)

Agno Integration Fix:
- Add _ensure_message_objects() to convert dicts to Agno Message objects
- Fix response(), response_stream(), aresponse(), aresponse_stream() to
ensure messages are Message objects before calling super()
- Update test mocks to use proper ModelResponse and Metrics objects
- All 66 Agno tests now pass

headroom/integrations/agno/model.py CHANGED
@@ -232,17 +232,14 @@ class HeadroomAgnoModel(Model): # type: ignore[misc]
232
  result.append({"role": "user", "content": content})
233
  return result
234
 
235
- def _convert_messages_from_openai(
236
- self, messages: list[dict[str, Any]], original_messages: list[Any]
237
- ) -> list[Any]:
238
- """Convert OpenAI format messages back to Agno Message objects.
239
 
240
- The Agno base model's response() method expects Message objects,
241
- not dicts, because it calls .log() on them internally.
242
 
243
  Args:
244
- messages: The optimized messages in OpenAI dict format
245
- original_messages: The original Agno Message objects (for reference)
246
 
247
  Returns:
248
  List of Agno Message objects
@@ -252,8 +249,7 @@ class HeadroomAgnoModel(Model): # type: ignore[misc]
252
  result = []
253
  for msg in messages:
254
  if isinstance(msg, dict):
255
- # Convert dict back to Agno Message
256
- # Handle the basic fields that Headroom might have modified
257
  try:
258
  result.append(AgnoMessage.from_dict(msg))
259
  except Exception:
@@ -271,6 +267,24 @@ class HeadroomAgnoModel(Model): # type: ignore[misc]
271
  result.append(msg)
272
  return result
273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  def _optimize_messages(self, messages: list[Any]) -> tuple[list[Any], OptimizationMetrics]:
275
  """Apply Headroom optimization to messages.
276
 
@@ -375,7 +389,9 @@ class HeadroomAgnoModel(Model): # type: ignore[misc]
375
 
376
  This ensures tool outputs are compressed on subsequent API calls.
377
  """
378
- # Don't optimize here - let the tool loop in Model.response() call invoke(),
 
 
379
  # which will optimize messages for EACH API call (including tool results)
380
  return super().response(messages, **kwargs)
381
 
@@ -385,6 +401,8 @@ class HeadroomAgnoModel(Model): # type: ignore[misc]
385
  Like response(), delegates to inherited Model.response_stream() which
386
  calls self.invoke_stream() for each API call.
387
  """
 
 
388
  # Let the inherited streaming method handle the tool loop
389
  yield from super().response_stream(messages, **kwargs)
390
 
@@ -394,6 +412,8 @@ class HeadroomAgnoModel(Model): # type: ignore[misc]
394
  Delegates to inherited Model.aresponse() which calls self.ainvoke()
395
  for each API call, ensuring tool outputs are optimized.
396
  """
 
 
397
  # Let the inherited async method handle the tool loop
398
  return await super().aresponse(messages, **kwargs)
399
 
@@ -403,6 +423,8 @@ class HeadroomAgnoModel(Model): # type: ignore[misc]
403
  Delegates to inherited Model.aresponse_stream() which calls self.ainvoke_stream()
404
  for each API call, ensuring tool outputs are optimized.
405
  """
 
 
406
  # Let the inherited async streaming method handle the tool loop
407
  async for chunk in super().aresponse_stream(messages, **kwargs):
408
  yield chunk
 
232
  result.append({"role": "user", "content": content})
233
  return result
234
 
235
+ def _ensure_message_objects(self, messages: list[Any]) -> list[Any]:
236
+ """Ensure all messages are Agno Message objects (not dicts).
 
 
237
 
238
+ Agno's base Model methods call _log_messages() which requires
239
+ Message objects with a .log() method.
240
 
241
  Args:
242
+ messages: List of messages (may be dicts or Message objects)
 
243
 
244
  Returns:
245
  List of Agno Message objects
 
249
  result = []
250
  for msg in messages:
251
  if isinstance(msg, dict):
252
+ # Convert dict to Agno Message
 
253
  try:
254
  result.append(AgnoMessage.from_dict(msg))
255
  except Exception:
 
267
  result.append(msg)
268
  return result
269
 
270
+ def _convert_messages_from_openai(
271
+ self, messages: list[dict[str, Any]], original_messages: list[Any]
272
+ ) -> list[Any]:
273
+ """Convert OpenAI format messages back to Agno Message objects.
274
+
275
+ The Agno base model's response() method expects Message objects,
276
+ not dicts, because it calls .log() on them internally.
277
+
278
+ Args:
279
+ messages: The optimized messages in OpenAI dict format
280
+ original_messages: The original Agno Message objects (for reference)
281
+
282
+ Returns:
283
+ List of Agno Message objects
284
+ """
285
+ # Reuse the ensure method which handles the conversion
286
+ return self._ensure_message_objects(messages)
287
+
288
  def _optimize_messages(self, messages: list[Any]) -> tuple[list[Any], OptimizationMetrics]:
289
  """Apply Headroom optimization to messages.
290
 
 
389
 
390
  This ensures tool outputs are compressed on subsequent API calls.
391
  """
392
+ # Ensure messages are Message objects (Agno's _log_messages requires .log() method)
393
+ messages = self._ensure_message_objects(messages)
394
+ # Let the tool loop in Model.response() call invoke(),
395
  # which will optimize messages for EACH API call (including tool results)
396
  return super().response(messages, **kwargs)
397
 
 
401
  Like response(), delegates to inherited Model.response_stream() which
402
  calls self.invoke_stream() for each API call.
403
  """
404
+ # Ensure messages are Message objects (Agno's _log_messages requires .log() method)
405
+ messages = self._ensure_message_objects(messages)
406
  # Let the inherited streaming method handle the tool loop
407
  yield from super().response_stream(messages, **kwargs)
408
 
 
412
  Delegates to inherited Model.aresponse() which calls self.ainvoke()
413
  for each API call, ensuring tool outputs are optimized.
414
  """
415
+ # Ensure messages are Message objects (Agno's _log_messages requires .log() method)
416
+ messages = self._ensure_message_objects(messages)
417
  # Let the inherited async method handle the tool loop
418
  return await super().aresponse(messages, **kwargs)
419
 
 
423
  Delegates to inherited Model.aresponse_stream() which calls self.ainvoke_stream()
424
  for each API call, ensuring tool outputs are optimized.
425
  """
426
+ # Ensure messages are Message objects (Agno's _log_messages requires .log() method)
427
+ messages = self._ensure_message_objects(messages)
428
  # Let the inherited async streaming method handle the tool loop
429
  async for chunk in super().aresponse_stream(messages, **kwargs):
430
  yield chunk
headroom/transforms/intelligent_context.py CHANGED
@@ -9,10 +9,12 @@ All importance signals are derived from:
9
  2. TOIN-learned patterns (field_semantics, retrieval_rate)
10
  3. Embedding similarity (optional)
11
 
12
- Strategy Selection:
13
  - NONE: Under budget, no action needed
14
  - COMPRESS_FIRST: When <compress_threshold over budget, try deeper compression
15
  of tool outputs using ContentRouter before dropping messages
 
 
16
  - DROP_BY_SCORE: When significantly over budget, drop lowest-scored messages
17
  """
18
 
@@ -32,6 +34,7 @@ from .scoring import MessageScore, MessageScorer
32
  if TYPE_CHECKING:
33
  from ..telemetry.toin import ToolIntelligenceNetwork
34
  from .content_router import ContentRouter
 
35
 
36
  logger = logging.getLogger(__name__)
37
 
@@ -41,6 +44,7 @@ class ContextStrategy(Enum):
41
 
42
  NONE = "none" # Under budget, do nothing
43
  COMPRESS_FIRST = "compress" # Try deeper compression first
 
44
  DROP_BY_SCORE = "drop_scored" # Drop lowest-scored messages
45
  HYBRID = "hybrid" # Combination of strategies
46
 
@@ -72,6 +76,7 @@ class IntelligentContextManager(Transform):
72
  self,
73
  config: IntelligentContextConfig | None = None,
74
  toin: ToolIntelligenceNetwork | None = None,
 
75
  ):
76
  """
77
  Initialize intelligent context manager.
@@ -79,11 +84,15 @@ class IntelligentContextManager(Transform):
79
  Args:
80
  config: Configuration for context management.
81
  toin: Optional TOIN instance for learned patterns.
 
 
 
82
  """
83
  from ..config import IntelligentContextConfig
84
 
85
  self.config = config or IntelligentContextConfig()
86
  self.toin = toin
 
87
 
88
  # Initialize scorer with TOIN if available
89
  self.scorer = MessageScorer(
@@ -95,6 +104,9 @@ class IntelligentContextManager(Transform):
95
  # Lazy-loaded content router for COMPRESS_FIRST strategy
96
  self._content_router: ContentRouter | None = None
97
 
 
 
 
98
  def should_apply(
99
  self,
100
  messages: list[dict[str, Any]],
@@ -187,16 +199,61 @@ class IntelligentContextManager(Transform):
187
  warnings=warnings,
188
  )
189
 
190
- # Still over budget, fall through to DROP_BY_SCORE
191
  logger.debug(
192
  "IntelligentContextManager: COMPRESS_FIRST saved %d tokens but still "
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  "over budget (%d > %d), proceeding to DROP_BY_SCORE",
194
  tokens_saved,
195
  current_tokens,
196
  available,
197
  )
198
  strategy = ContextStrategy.DROP_BY_SCORE
199
- # Need to recalculate protected indices after compression
200
  protected = self._get_protected_indices(result_messages)
201
 
202
  # ========== DROP_BY_SCORE STRATEGY ==========
@@ -301,15 +358,28 @@ class IntelligentContextManager(Transform):
301
  )
302
 
303
  def _select_strategy(self, current_tokens: int, available: int) -> ContextStrategy:
304
- """Select strategy based on how much over budget we are."""
 
 
 
 
 
 
 
305
  if current_tokens <= available:
306
  return ContextStrategy.NONE
307
 
308
  over_ratio = (current_tokens - available) / available
309
 
 
310
  if over_ratio < self.config.compress_threshold:
311
  return ContextStrategy.COMPRESS_FIRST
312
 
 
 
 
 
 
313
  return ContextStrategy.DROP_BY_SCORE
314
 
315
  def _get_content_router(self) -> ContentRouter | None:
@@ -684,3 +754,74 @@ class IntelligentContextManager(Transform):
684
  )
685
 
686
  return scores
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  2. TOIN-learned patterns (field_semantics, retrieval_rate)
10
  3. Embedding similarity (optional)
11
 
12
+ Strategy Selection (in order of preference):
13
  - NONE: Under budget, no action needed
14
  - COMPRESS_FIRST: When <compress_threshold over budget, try deeper compression
15
  of tool outputs using ContentRouter before dropping messages
16
+ - SUMMARIZE: When <summarize_threshold over budget and summarization_enabled,
17
+ create anchored summaries of older messages (requires summarize_fn callback)
18
  - DROP_BY_SCORE: When significantly over budget, drop lowest-scored messages
19
  """
20
 
 
34
  if TYPE_CHECKING:
35
  from ..telemetry.toin import ToolIntelligenceNetwork
36
  from .content_router import ContentRouter
37
+ from .progressive_summarizer import ProgressiveSummarizer, SummarizeFn
38
 
39
  logger = logging.getLogger(__name__)
40
 
 
44
 
45
  NONE = "none" # Under budget, do nothing
46
  COMPRESS_FIRST = "compress" # Try deeper compression first
47
+ SUMMARIZE = "summarize" # Create anchored summaries of older messages
48
  DROP_BY_SCORE = "drop_scored" # Drop lowest-scored messages
49
  HYBRID = "hybrid" # Combination of strategies
50
 
 
76
  self,
77
  config: IntelligentContextConfig | None = None,
78
  toin: ToolIntelligenceNetwork | None = None,
79
+ summarize_fn: SummarizeFn | None = None,
80
  ):
81
  """
82
  Initialize intelligent context manager.
 
84
  Args:
85
  config: Configuration for context management.
86
  toin: Optional TOIN instance for learned patterns.
87
+ summarize_fn: Optional callback for summarization.
88
+ If provided and summarization_enabled=True, enables SUMMARIZE strategy.
89
+ Signature: (messages: list[dict], context: str) -> str
90
  """
91
  from ..config import IntelligentContextConfig
92
 
93
  self.config = config or IntelligentContextConfig()
94
  self.toin = toin
95
+ self._summarize_fn = summarize_fn
96
 
97
  # Initialize scorer with TOIN if available
98
  self.scorer = MessageScorer(
 
104
  # Lazy-loaded content router for COMPRESS_FIRST strategy
105
  self._content_router: ContentRouter | None = None
106
 
107
+ # Lazy-loaded progressive summarizer for SUMMARIZE strategy
108
+ self._progressive_summarizer: ProgressiveSummarizer | None = None
109
+
110
  def should_apply(
111
  self,
112
  messages: list[dict[str, Any]],
 
199
  warnings=warnings,
200
  )
201
 
202
+ # Still over budget, fall through to SUMMARIZE or DROP_BY_SCORE
203
  logger.debug(
204
  "IntelligentContextManager: COMPRESS_FIRST saved %d tokens but still "
205
+ "over budget (%d > %d), checking next strategy",
206
+ tokens_saved,
207
+ current_tokens,
208
+ available,
209
+ )
210
+ # Check if we should try summarization next
211
+ over_ratio = (current_tokens - available) / available
212
+ if self.config.summarization_enabled and over_ratio < self.config.summarize_threshold:
213
+ strategy = ContextStrategy.SUMMARIZE
214
+ else:
215
+ strategy = ContextStrategy.DROP_BY_SCORE
216
+ # Need to recalculate protected indices after compression
217
+ protected = self._get_protected_indices(result_messages)
218
+
219
+ # ========== SUMMARIZE STRATEGY ==========
220
+ # Create anchored summaries of older messages
221
+ if strategy == ContextStrategy.SUMMARIZE:
222
+ result_messages, summarize_transforms, tokens_saved = self._apply_summarize(
223
+ result_messages, tokenizer, protected, available
224
+ )
225
+ transforms_applied.extend(summarize_transforms)
226
+
227
+ # Recheck token count after summarization
228
+ current_tokens = tokenizer.count_messages(result_messages)
229
+
230
+ # If now under budget, we're done!
231
+ if current_tokens <= available:
232
+ logger.info(
233
+ "IntelligentContextManager: SUMMARIZE succeeded, saved %d tokens: %d -> %d",
234
+ tokens_saved,
235
+ tokens_before,
236
+ current_tokens,
237
+ )
238
+ return TransformResult(
239
+ messages=result_messages,
240
+ tokens_before=tokens_before,
241
+ tokens_after=current_tokens,
242
+ transforms_applied=transforms_applied,
243
+ markers_inserted=markers_inserted,
244
+ warnings=warnings,
245
+ )
246
+
247
+ # Still over budget, fall through to DROP_BY_SCORE
248
+ logger.debug(
249
+ "IntelligentContextManager: SUMMARIZE saved %d tokens but still "
250
  "over budget (%d > %d), proceeding to DROP_BY_SCORE",
251
  tokens_saved,
252
  current_tokens,
253
  available,
254
  )
255
  strategy = ContextStrategy.DROP_BY_SCORE
256
+ # Need to recalculate protected indices after summarization
257
  protected = self._get_protected_indices(result_messages)
258
 
259
  # ========== DROP_BY_SCORE STRATEGY ==========
 
358
  )
359
 
360
  def _select_strategy(self, current_tokens: int, available: int) -> ContextStrategy:
361
+ """Select strategy based on how much over budget we are.
362
+
363
+ Strategy selection order:
364
+ 1. NONE: Under budget
365
+ 2. COMPRESS_FIRST: < compress_threshold (default 10%) over budget
366
+ 3. SUMMARIZE: < summarize_threshold (default 25%) over budget AND enabled
367
+ 4. DROP_BY_SCORE: >= summarize_threshold over budget
368
+ """
369
  if current_tokens <= available:
370
  return ContextStrategy.NONE
371
 
372
  over_ratio = (current_tokens - available) / available
373
 
374
+ # Tier 1: Try compression first for small overages
375
  if over_ratio < self.config.compress_threshold:
376
  return ContextStrategy.COMPRESS_FIRST
377
 
378
+ # Tier 2: Try summarization for moderate overages (if enabled)
379
+ if self.config.summarization_enabled and over_ratio < self.config.summarize_threshold:
380
+ return ContextStrategy.SUMMARIZE
381
+
382
+ # Tier 3: Drop by score for large overages
383
  return ContextStrategy.DROP_BY_SCORE
384
 
385
  def _get_content_router(self) -> ContentRouter | None:
 
754
  )
755
 
756
  return scores
757
+
758
+ def _get_progressive_summarizer(self) -> ProgressiveSummarizer | None:
759
+ """Get or create progressive summarizer for SUMMARIZE strategy (lazy load)."""
760
+ if self._progressive_summarizer is None:
761
+ try:
762
+ from .progressive_summarizer import ProgressiveSummarizer
763
+
764
+ self._progressive_summarizer = ProgressiveSummarizer(
765
+ summarize_fn=self._summarize_fn,
766
+ max_summary_tokens=self.config.summary_max_tokens,
767
+ min_messages_to_summarize=3,
768
+ store_for_retrieval=True,
769
+ )
770
+ except ImportError:
771
+ logger.debug("ProgressiveSummarizer not available for SUMMARIZE")
772
+ return self._progressive_summarizer
773
+
774
+ def _apply_summarize(
775
+ self,
776
+ messages: list[dict[str, Any]],
777
+ tokenizer: Tokenizer,
778
+ protected: set[int],
779
+ target_tokens: int,
780
+ ) -> tuple[list[dict[str, Any]], list[str], int]:
781
+ """Apply progressive summarization to older messages.
782
+
783
+ This is the SUMMARIZE strategy: create anchored summaries of older
784
+ messages to reduce token count while maintaining retrievability.
785
+
786
+ Args:
787
+ messages: List of messages to summarize.
788
+ tokenizer: Tokenizer for counting.
789
+ protected: Set of protected message indices.
790
+ target_tokens: Target token budget.
791
+
792
+ Returns:
793
+ Tuple of (summarized_messages, transforms_applied, tokens_saved).
794
+ """
795
+ summarizer = self._get_progressive_summarizer()
796
+ if summarizer is None:
797
+ return messages, [], 0
798
+
799
+ # Get recent messages for context
800
+ context_messages = []
801
+ for i in sorted(protected):
802
+ if i < len(messages):
803
+ context_messages.append(messages[i])
804
+
805
+ try:
806
+ result = summarizer.summarize_messages(
807
+ messages=messages,
808
+ tokenizer=tokenizer,
809
+ protected_indices=protected,
810
+ target_tokens=target_tokens,
811
+ context_messages=context_messages[-5:], # Last 5 for context
812
+ )
813
+
814
+ tokens_saved = result.tokens_before - result.tokens_after
815
+
816
+ if tokens_saved > 0:
817
+ logger.info(
818
+ "SUMMARIZE: created %d summaries, saved %d tokens",
819
+ len(result.summaries_created),
820
+ tokens_saved,
821
+ )
822
+
823
+ return result.messages, result.transforms_applied, tokens_saved
824
+
825
+ except Exception as e:
826
+ logger.warning("SUMMARIZE: summarization failed: %s", e)
827
+ return messages, [], 0
headroom/transforms/progressive_summarizer.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Progressive summarization for Headroom SDK.
2
+
3
+ This module provides anchored summarization that progressively summarizes
4
+ older messages while maintaining retrieval capability via CCR.
5
+
6
+ Design principles:
7
+ 1. CALLBACK PATTERN: Summarization is done via a callback, not internal LLM calls
8
+ 2. ANCHORED: Summaries track which message positions they represent
9
+ 3. REVERSIBLE: Original content stored in CompressionStore for CCR retrieval
10
+ 4. INCREMENTAL: Only summarize newly dropped spans, then merge
11
+
12
+ Usage:
13
+ from headroom.transforms import ProgressiveSummarizer
14
+
15
+ # With custom summarizer callback
16
+ def my_summarizer(messages: list[dict], context: str) -> str:
17
+ # Your summarization logic (LLM call, extractive, etc.)
18
+ return "Summary of messages..."
19
+
20
+ summarizer = ProgressiveSummarizer(
21
+ summarize_fn=my_summarizer,
22
+ max_summary_tokens=500,
23
+ )
24
+
25
+ result = summarizer.summarize_messages(messages, tokenizer, protected)
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ import hashlib
31
+ import json
32
+ import logging
33
+ import time
34
+ from dataclasses import dataclass, field
35
+ from typing import TYPE_CHECKING, Any, Protocol
36
+
37
+ if TYPE_CHECKING:
38
+ from ..cache.compression_store import CompressionStore
39
+ from ..tokenizer import Tokenizer
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ class SummarizeFn(Protocol):
45
+ """Protocol for summarization callback functions.
46
+
47
+ The callback receives:
48
+ - messages: List of messages to summarize
49
+ - context: Optional context string (e.g., recent messages for relevance)
50
+
51
+ Returns:
52
+ - Summary string
53
+ """
54
+
55
+ def __call__(
56
+ self,
57
+ messages: list[dict[str, Any]],
58
+ context: str = "",
59
+ ) -> str: ...
60
+
61
+
62
+ @dataclass
63
+ class AnchoredSummary:
64
+ """A summary anchored to specific message positions.
65
+
66
+ Tracks which messages were summarized for:
67
+ - Retrieval: Can reconstruct original messages via CCR
68
+ - Merging: Can merge with adjacent summaries
69
+ - Positioning: Know where in conversation this summary belongs
70
+ """
71
+
72
+ summary_text: str
73
+ start_index: int # First message index summarized
74
+ end_index: int # Last message index summarized (inclusive)
75
+ original_message_count: int
76
+ original_tokens: int
77
+ summary_tokens: int
78
+ cache_hash: str | None = None # Hash for CCR retrieval
79
+ tool_names: list[str] = field(default_factory=list)
80
+ created_at: float = field(default_factory=time.time)
81
+
82
+ @property
83
+ def compression_ratio(self) -> float:
84
+ """Ratio of summary tokens to original tokens (lower = more compression)."""
85
+ if self.original_tokens == 0:
86
+ return 1.0
87
+ return self.summary_tokens / self.original_tokens
88
+
89
+ @property
90
+ def tokens_saved(self) -> int:
91
+ """Number of tokens saved by summarization."""
92
+ return max(0, self.original_tokens - self.summary_tokens)
93
+
94
+
95
+ @dataclass
96
+ class SummarizationResult:
97
+ """Result of a summarization operation."""
98
+
99
+ messages: list[dict[str, Any]]
100
+ summaries_created: list[AnchoredSummary]
101
+ tokens_before: int
102
+ tokens_after: int
103
+ transforms_applied: list[str]
104
+
105
+ @property
106
+ def tokens_saved(self) -> int:
107
+ """Total tokens saved."""
108
+ return max(0, self.tokens_before - self.tokens_after)
109
+
110
+
111
+ def extractive_summarizer(
112
+ messages: list[dict[str, Any]],
113
+ context: str = "",
114
+ max_items_per_role: int = 2,
115
+ ) -> str:
116
+ """Default extractive summarizer (no LLM required).
117
+
118
+ Creates a summary by extracting key content from messages:
119
+ - First and last message of each role
120
+ - Error indicators
121
+ - Tool names and brief results
122
+
123
+ This is a fallback when no LLM summarizer is provided.
124
+
125
+ Args:
126
+ messages: Messages to summarize.
127
+ context: Optional context (unused in extractive mode).
128
+ max_items_per_role: Max items to keep per role type.
129
+
130
+ Returns:
131
+ Extractive summary string.
132
+ """
133
+ if not messages:
134
+ return "[No messages to summarize]"
135
+
136
+ parts: list[str] = []
137
+ parts.append(f"[Summary of {len(messages)} messages]")
138
+
139
+ # Group by role
140
+ by_role: dict[str, list[dict[str, Any]]] = {}
141
+ for msg in messages:
142
+ role = msg.get("role", "unknown")
143
+ by_role.setdefault(role, []).append(msg)
144
+
145
+ # Extract key content from each role
146
+ for role, role_msgs in by_role.items():
147
+ if role == "tool":
148
+ # For tool messages, extract tool names and brief status
149
+ tool_names = set()
150
+ has_error = False
151
+ for msg in role_msgs:
152
+ content = msg.get("content", "")
153
+ # Try to detect tool name from context
154
+ tool_call_id = msg.get("tool_call_id", "")
155
+ if tool_call_id:
156
+ tool_names.add(f"tool:{tool_call_id[:8]}")
157
+
158
+ # Check for errors
159
+ content_lower = content.lower() if isinstance(content, str) else ""
160
+ if any(err in content_lower for err in ["error", "failed", "exception"]):
161
+ has_error = True
162
+
163
+ status = "with errors" if has_error else "successful"
164
+ parts.append(f"- {len(role_msgs)} tool outputs ({status})")
165
+
166
+ elif role == "assistant":
167
+ # Extract first and last assistant responses
168
+ if len(role_msgs) == 1:
169
+ content = role_msgs[0].get("content", "")
170
+ if isinstance(content, str):
171
+ preview = content[:100] + "..." if len(content) > 100 else content
172
+ parts.append(f"- Assistant: {preview}")
173
+ else:
174
+ parts.append(f"- {len(role_msgs)} assistant messages")
175
+
176
+ elif role == "user":
177
+ # Count user messages
178
+ parts.append(f"- {len(role_msgs)} user messages")
179
+
180
+ elif role == "system":
181
+ # Note system messages (shouldn't be summarized usually)
182
+ parts.append(f"- {len(role_msgs)} system messages")
183
+
184
+ return "\n".join(parts)
185
+
186
+
187
+ class ProgressiveSummarizer:
188
+ """Progressive summarization with anchoring and CCR integration.
189
+
190
+ This class implements the SUMMARIZE strategy for IntelligentContextManager:
191
+ 1. Identifies candidate messages (low-scored, non-protected)
192
+ 2. Groups consecutive messages for summarization
193
+ 3. Calls summarizer callback to create summaries
194
+ 4. Stores originals in CompressionStore for CCR retrieval
195
+ 5. Replaces messages with anchored summary message
196
+
197
+ Key features:
198
+ - Callback pattern: No LLM calls inside, summarization logic is external
199
+ - Anchored: Summaries track original positions for context
200
+ - Reversible: Originals cached for retrieval
201
+ - Incremental: Can merge adjacent summaries
202
+ """
203
+
204
+ def __init__(
205
+ self,
206
+ summarize_fn: SummarizeFn | None = None,
207
+ max_summary_tokens: int = 500,
208
+ min_messages_to_summarize: int = 3,
209
+ compression_store: CompressionStore | None = None,
210
+ store_for_retrieval: bool = True,
211
+ ):
212
+ """Initialize the progressive summarizer.
213
+
214
+ Args:
215
+ summarize_fn: Callback function for summarization.
216
+ If None, uses extractive_summarizer as fallback.
217
+ max_summary_tokens: Target max tokens for each summary.
218
+ min_messages_to_summarize: Minimum messages in a group to summarize.
219
+ compression_store: Optional CompressionStore for CCR integration.
220
+ store_for_retrieval: Whether to store originals for retrieval.
221
+ """
222
+ self.summarize_fn = summarize_fn or extractive_summarizer
223
+ self.max_summary_tokens = max_summary_tokens
224
+ self.min_messages_to_summarize = min_messages_to_summarize
225
+ self._compression_store = compression_store
226
+ self.store_for_retrieval = store_for_retrieval
227
+
228
+ def _get_compression_store(self) -> CompressionStore | None:
229
+ """Get or create compression store (lazy load)."""
230
+ if self._compression_store is None and self.store_for_retrieval:
231
+ try:
232
+ from ..cache.compression_store import get_compression_store
233
+
234
+ self._compression_store = get_compression_store()
235
+ except ImportError:
236
+ logger.debug("CompressionStore not available for CCR")
237
+ return self._compression_store
238
+
239
+ def summarize_messages(
240
+ self,
241
+ messages: list[dict[str, Any]],
242
+ tokenizer: Tokenizer,
243
+ protected_indices: set[int],
244
+ target_tokens: int | None = None,
245
+ context_messages: list[dict[str, Any]] | None = None,
246
+ ) -> SummarizationResult:
247
+ """Summarize messages to reduce token count.
248
+
249
+ Args:
250
+ messages: List of messages to process.
251
+ tokenizer: Tokenizer for counting.
252
+ protected_indices: Indices that cannot be summarized.
253
+ target_tokens: Target token count (optional, summarizes all candidates if None).
254
+ context_messages: Recent messages for context in summarization.
255
+
256
+ Returns:
257
+ SummarizationResult with summarized messages.
258
+ """
259
+ from ..utils import deep_copy_messages
260
+
261
+ tokens_before = tokenizer.count_messages(messages)
262
+ result_messages = deep_copy_messages(messages)
263
+ transforms_applied: list[str] = []
264
+ summaries_created: list[AnchoredSummary] = []
265
+
266
+ # Find candidate groups for summarization
267
+ candidate_groups = self._find_summarization_candidates(result_messages, protected_indices)
268
+
269
+ if not candidate_groups:
270
+ logger.debug("ProgressiveSummarizer: no candidates for summarization")
271
+ return SummarizationResult(
272
+ messages=result_messages,
273
+ summaries_created=[],
274
+ tokens_before=tokens_before,
275
+ tokens_after=tokens_before,
276
+ transforms_applied=[],
277
+ )
278
+
279
+ # Build context string from recent messages
280
+ context_str = ""
281
+ if context_messages:
282
+ context_parts = []
283
+ for msg in context_messages[-3:]: # Last 3 messages for context
284
+ role = msg.get("role", "")
285
+ content = msg.get("content", "")
286
+ if isinstance(content, str) and content:
287
+ preview = content[:200] if len(content) > 200 else content
288
+ context_parts.append(f"{role}: {preview}")
289
+ context_str = "\n".join(context_parts)
290
+
291
+ # Process groups in reverse order (so indices stay valid)
292
+ current_tokens = tokens_before
293
+
294
+ for group in reversed(candidate_groups):
295
+ # Check if we've reached target
296
+ if target_tokens and current_tokens <= target_tokens:
297
+ break
298
+
299
+ start_idx, end_idx = group
300
+ group_messages = result_messages[start_idx : end_idx + 1]
301
+
302
+ # Skip if too few messages
303
+ if len(group_messages) < self.min_messages_to_summarize:
304
+ continue
305
+
306
+ # Calculate group tokens
307
+ group_tokens = sum(tokenizer.count_message(msg) for msg in group_messages)
308
+
309
+ # Skip small groups
310
+ if group_tokens < 100:
311
+ continue
312
+
313
+ # Create summary using callback
314
+ try:
315
+ summary_text = self.summarize_fn(group_messages, context_str)
316
+ except Exception as e:
317
+ logger.warning(
318
+ "ProgressiveSummarizer: summarization failed for group %d-%d: %s",
319
+ start_idx,
320
+ end_idx,
321
+ e,
322
+ )
323
+ continue
324
+
325
+ summary_tokens = tokenizer.count_text(summary_text)
326
+
327
+ # Only use summary if it saves tokens
328
+ if summary_tokens >= group_tokens:
329
+ logger.debug(
330
+ "ProgressiveSummarizer: summary not smaller (%d >= %d), skipping",
331
+ summary_tokens,
332
+ group_tokens,
333
+ )
334
+ continue
335
+
336
+ # Store original for CCR retrieval
337
+ cache_hash = None
338
+ if self.store_for_retrieval:
339
+ cache_hash = self._store_for_retrieval(
340
+ group_messages, summary_text, group_tokens, summary_tokens
341
+ )
342
+
343
+ # Extract tool names
344
+ tool_names = []
345
+ for msg in group_messages:
346
+ if msg.get("role") == "tool":
347
+ tool_call_id = msg.get("tool_call_id", "")
348
+ if tool_call_id:
349
+ tool_names.append(tool_call_id[:8])
350
+
351
+ # Create anchored summary
352
+ anchored = AnchoredSummary(
353
+ summary_text=summary_text,
354
+ start_index=start_idx,
355
+ end_index=end_idx,
356
+ original_message_count=len(group_messages),
357
+ original_tokens=group_tokens,
358
+ summary_tokens=summary_tokens,
359
+ cache_hash=cache_hash,
360
+ tool_names=tool_names,
361
+ )
362
+ summaries_created.append(anchored)
363
+
364
+ # Create summary message with retrieval marker
365
+ summary_content = summary_text
366
+ if cache_hash:
367
+ summary_content += f"\n[Retrieve full content: hash={cache_hash}]"
368
+
369
+ summary_message = {
370
+ "role": "user",
371
+ "content": summary_content,
372
+ }
373
+
374
+ # Replace group with summary message
375
+ result_messages = (
376
+ result_messages[:start_idx] + [summary_message] + result_messages[end_idx + 1 :]
377
+ )
378
+
379
+ # Update token count
380
+ tokens_saved = group_tokens - summary_tokens
381
+ current_tokens -= tokens_saved
382
+
383
+ transforms_applied.append(f"summarize:{start_idx}-{end_idx}:{len(group_messages)}")
384
+
385
+ logger.debug(
386
+ "ProgressiveSummarizer: summarized %d messages (%d-%d), saved %d tokens (%d -> %d)",
387
+ len(group_messages),
388
+ start_idx,
389
+ end_idx,
390
+ tokens_saved,
391
+ group_tokens,
392
+ summary_tokens,
393
+ )
394
+
395
+ # Update protected indices for subsequent groups
396
+ # (indices shift after replacement)
397
+ shift = len(group_messages) - 1 # We replaced N messages with 1
398
+ protected_indices = {idx - shift if idx > end_idx else idx for idx in protected_indices}
399
+
400
+ tokens_after = tokenizer.count_messages(result_messages)
401
+
402
+ if summaries_created:
403
+ logger.info(
404
+ "ProgressiveSummarizer: created %d summaries, saved %d tokens (%d -> %d)",
405
+ len(summaries_created),
406
+ tokens_before - tokens_after,
407
+ tokens_before,
408
+ tokens_after,
409
+ )
410
+
411
+ return SummarizationResult(
412
+ messages=result_messages,
413
+ summaries_created=summaries_created,
414
+ tokens_before=tokens_before,
415
+ tokens_after=tokens_after,
416
+ transforms_applied=transforms_applied,
417
+ )
418
+
419
+ def _find_summarization_candidates(
420
+ self,
421
+ messages: list[dict[str, Any]],
422
+ protected: set[int],
423
+ ) -> list[tuple[int, int]]:
424
+ """Find groups of consecutive messages that can be summarized.
425
+
426
+ Returns list of (start_index, end_index) tuples for candidate groups.
427
+ Groups are consecutive non-protected messages.
428
+
429
+ Args:
430
+ messages: List of messages.
431
+ protected: Set of protected indices.
432
+
433
+ Returns:
434
+ List of (start, end) tuples for candidate groups.
435
+ """
436
+ groups: list[tuple[int, int]] = []
437
+ current_start: int | None = None
438
+
439
+ for i, _msg in enumerate(messages):
440
+ if i in protected:
441
+ # End current group if exists
442
+ if current_start is not None:
443
+ if i - 1 >= current_start:
444
+ groups.append((current_start, i - 1))
445
+ current_start = None
446
+ else:
447
+ # Start or continue group
448
+ if current_start is None:
449
+ current_start = i
450
+
451
+ # Handle final group
452
+ if current_start is not None and len(messages) - 1 >= current_start:
453
+ groups.append((current_start, len(messages) - 1))
454
+
455
+ # Filter groups that are too small
456
+ groups = [
457
+ (start, end)
458
+ for start, end in groups
459
+ if end - start + 1 >= self.min_messages_to_summarize
460
+ ]
461
+
462
+ return groups
463
+
464
+ def _store_for_retrieval(
465
+ self,
466
+ messages: list[dict[str, Any]],
467
+ summary: str,
468
+ original_tokens: int,
469
+ summary_tokens: int,
470
+ ) -> str | None:
471
+ """Store original messages in CompressionStore for CCR retrieval.
472
+
473
+ Args:
474
+ messages: Original messages.
475
+ summary: Summary text.
476
+ original_tokens: Token count of originals.
477
+ summary_tokens: Token count of summary.
478
+
479
+ Returns:
480
+ Cache hash for retrieval, or None if storage failed.
481
+ """
482
+ store = self._get_compression_store()
483
+ if store is None:
484
+ return None
485
+
486
+ try:
487
+ # Serialize messages for storage
488
+ original_content = json.dumps(messages, ensure_ascii=False)
489
+
490
+ # Generate hash
491
+ content_hash = hashlib.sha256(original_content.encode()).hexdigest()[:24]
492
+
493
+ # Store in compression store
494
+ store.store(
495
+ original=original_content,
496
+ compressed=summary,
497
+ original_tokens=original_tokens,
498
+ compressed_tokens=summary_tokens,
499
+ original_item_count=len(messages),
500
+ compressed_item_count=1,
501
+ tool_name="progressive_summarizer",
502
+ )
503
+
504
+ return content_hash
505
+
506
+ except Exception as e:
507
+ logger.debug("Failed to store for CCR retrieval: %s", e)
508
+ return None
tests/test_integrations/agno/test_model.py CHANGED
@@ -29,6 +29,8 @@ pytestmark = pytest.mark.skipif(not AGNO_AVAILABLE, reason="Agno not installed")
29
  @pytest.fixture
30
  def mock_agno_model():
31
  """Create a mock Agno model (OpenAIChat-like)."""
 
 
32
  mock = MagicMock()
33
  mock.__class__.__name__ = "OpenAIChat"
34
  mock.__class__.__module__ = "agno.models.openai"
@@ -46,12 +48,45 @@ def mock_agno_model():
46
 
47
  mock.response = MagicMock(side_effect=mock_response)
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # Mock streaming response
50
  def mock_stream(messages, **kwargs):
51
  yield MagicMock(content="Streaming...")
52
 
53
  mock.response_stream = MagicMock(side_effect=mock_stream)
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  return mock
56
 
57
 
 
29
  @pytest.fixture
30
  def mock_agno_model():
31
  """Create a mock Agno model (OpenAIChat-like)."""
32
+ from agno.models.response import ModelResponse
33
+
34
  mock = MagicMock()
35
  mock.__class__.__name__ = "OpenAIChat"
36
  mock.__class__.__module__ = "agno.models.openai"
 
48
 
49
  mock.response = MagicMock(side_effect=mock_response)
50
 
51
+ # Mock invoke method (returns ModelResponse for Agno's response() loop)
52
+ def mock_invoke(messages, **kwargs):
53
+ from agno.models.metrics import Metrics
54
+
55
+ # Create a proper ModelResponse that Agno's response() can process
56
+ return ModelResponse(
57
+ role="assistant",
58
+ content="Hello! I'm a mock response.",
59
+ response_usage=Metrics(
60
+ input_tokens=10,
61
+ output_tokens=5,
62
+ total_tokens=15,
63
+ ),
64
+ )
65
+
66
+ mock.invoke = MagicMock(side_effect=mock_invoke)
67
+
68
  # Mock streaming response
69
  def mock_stream(messages, **kwargs):
70
  yield MagicMock(content="Streaming...")
71
 
72
  mock.response_stream = MagicMock(side_effect=mock_stream)
73
 
74
+ # Mock invoke_stream for streaming
75
+ def mock_invoke_stream(messages, **kwargs):
76
+ from agno.models.metrics import Metrics
77
+
78
+ yield ModelResponse(
79
+ role="assistant",
80
+ content="Streaming...",
81
+ response_usage=Metrics(
82
+ input_tokens=10,
83
+ output_tokens=5,
84
+ total_tokens=15,
85
+ ),
86
+ )
87
+
88
+ mock.invoke_stream = MagicMock(side_effect=mock_invoke_stream)
89
+
90
  return mock
91
 
92
 
tests/test_transforms/test_intelligent_context.py CHANGED
@@ -1279,3 +1279,540 @@ class TestCompressFirstEdgeCases:
1279
  # The recent messages should be protected
1280
  # With 6 messages and keep_last_turns=5, most should be protected
1281
  assert len(protected) > 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1279
  # The recent messages should be protected
1280
  # With 6 messages and keep_last_turns=5, most should be protected
1281
  assert len(protected) > 0
1282
+
1283
+
1284
+ # ==============================================================================
1285
+ # SUMMARIZE STRATEGY TESTS
1286
+ # ==============================================================================
1287
+
1288
+
1289
+ class TestSummarizeStrategySelection:
1290
+ """Tests for SUMMARIZE strategy selection logic."""
1291
+
1292
+ def test_summarize_strategy_selected_when_enabled(self, tokenizer: Tokenizer):
1293
+ """SUMMARIZE should be selected when enabled and in threshold range."""
1294
+ messages = [
1295
+ {"role": "system", "content": "System"},
1296
+ {"role": "user", "content": "Hello " * 100},
1297
+ {"role": "assistant", "content": "Response " * 100},
1298
+ {"role": "user", "content": "More " * 100},
1299
+ {"role": "assistant", "content": "More response " * 100},
1300
+ {"role": "user", "content": "Final"},
1301
+ ]
1302
+
1303
+ config = IntelligentContextConfig(
1304
+ summarization_enabled=True,
1305
+ compress_threshold=0.05, # 5% triggers COMPRESS_FIRST
1306
+ summarize_threshold=0.30, # 30% is threshold for DROP_BY_SCORE
1307
+ keep_last_turns=1,
1308
+ )
1309
+ manager = IntelligentContextManager(config=config)
1310
+
1311
+ tokens = tokenizer.count_messages(messages)
1312
+ # Set limit so we're ~15% over (between compress and summarize thresholds)
1313
+ available = int(tokens / 1.15)
1314
+
1315
+ strategy = manager._select_strategy(tokens, available)
1316
+ assert strategy == ContextStrategy.SUMMARIZE
1317
+
1318
+ def test_summarize_not_selected_when_disabled(self, tokenizer: Tokenizer):
1319
+ """SUMMARIZE should not be selected when disabled."""
1320
+ messages = [
1321
+ {"role": "system", "content": "System"},
1322
+ {"role": "user", "content": "Hello " * 100},
1323
+ {"role": "assistant", "content": "Response " * 100},
1324
+ {"role": "user", "content": "Final"},
1325
+ ]
1326
+
1327
+ config = IntelligentContextConfig(
1328
+ summarization_enabled=False, # Disabled
1329
+ compress_threshold=0.05,
1330
+ summarize_threshold=0.30,
1331
+ )
1332
+ manager = IntelligentContextManager(config=config)
1333
+
1334
+ tokens = tokenizer.count_messages(messages)
1335
+ available = int(tokens / 1.15) # 15% over
1336
+
1337
+ strategy = manager._select_strategy(tokens, available)
1338
+ # Should skip SUMMARIZE and go to DROP_BY_SCORE
1339
+ assert strategy == ContextStrategy.DROP_BY_SCORE
1340
+
1341
+ def test_drop_strategy_when_over_summarize_threshold(self, tokenizer: Tokenizer):
1342
+ """DROP_BY_SCORE when over summarize_threshold even if enabled."""
1343
+ messages = [
1344
+ {"role": "system", "content": "System"},
1345
+ {"role": "user", "content": "Hello " * 100},
1346
+ {"role": "assistant", "content": "Response " * 100},
1347
+ ]
1348
+
1349
+ config = IntelligentContextConfig(
1350
+ summarization_enabled=True,
1351
+ compress_threshold=0.05,
1352
+ summarize_threshold=0.20,
1353
+ )
1354
+ manager = IntelligentContextManager(config=config)
1355
+
1356
+ tokens = tokenizer.count_messages(messages)
1357
+ available = int(tokens / 1.50) # 50% over - way over threshold
1358
+
1359
+ strategy = manager._select_strategy(tokens, available)
1360
+ assert strategy == ContextStrategy.DROP_BY_SCORE
1361
+
1362
+
1363
+ class TestSummarizeStrategy:
1364
+ """Tests for SUMMARIZE strategy execution."""
1365
+
1366
+ def test_summarize_reduces_tokens(self, tokenizer: Tokenizer):
1367
+ """SUMMARIZE should reduce token count."""
1368
+ # Create conversation with many messages to summarize
1369
+ messages = [
1370
+ {"role": "system", "content": "You are a helpful assistant."},
1371
+ ]
1372
+
1373
+ # Add many user/assistant turns
1374
+ for i in range(10):
1375
+ messages.append({"role": "user", "content": f"Question {i}: " + "explain this " * 20})
1376
+ messages.append(
1377
+ {"role": "assistant", "content": f"Answer {i}: " + "here is my response " * 30}
1378
+ )
1379
+
1380
+ messages.append({"role": "user", "content": "Final question"})
1381
+ messages.append({"role": "assistant", "content": "Final answer"})
1382
+
1383
+ config = IntelligentContextConfig(
1384
+ summarization_enabled=True,
1385
+ compress_threshold=0.05, # Low, so we skip COMPRESS_FIRST
1386
+ summarize_threshold=0.30,
1387
+ keep_last_turns=2, # Protect last 2 turns
1388
+ )
1389
+ manager = IntelligentContextManager(config=config)
1390
+
1391
+ tokens_before = tokenizer.count_messages(messages)
1392
+ # Set limit to trigger SUMMARIZE (15% over)
1393
+ target_limit = int(tokens_before / 1.15)
1394
+
1395
+ result = manager.apply(
1396
+ messages,
1397
+ tokenizer,
1398
+ model_limit=target_limit,
1399
+ output_buffer=50,
1400
+ )
1401
+
1402
+ # Should have reduced tokens
1403
+ assert result.tokens_after < result.tokens_before
1404
+
1405
+ def test_summarize_with_custom_summarizer(self, tokenizer: Tokenizer):
1406
+ """SUMMARIZE should use custom summarizer callback."""
1407
+ summarizer_called = []
1408
+
1409
+ def custom_summarizer(messages: list[dict], context: str = "") -> str:
1410
+ summarizer_called.append(len(messages))
1411
+ return f"[Summary of {len(messages)} messages]"
1412
+
1413
+ messages = [
1414
+ {"role": "system", "content": "System"},
1415
+ ]
1416
+ for i in range(8):
1417
+ messages.append({"role": "user", "content": f"Question {i} " * 30})
1418
+ messages.append({"role": "assistant", "content": f"Answer {i} " * 30})
1419
+ messages.append({"role": "user", "content": "Final"})
1420
+
1421
+ config = IntelligentContextConfig(
1422
+ summarization_enabled=True,
1423
+ compress_threshold=0.05,
1424
+ summarize_threshold=0.30,
1425
+ keep_last_turns=1,
1426
+ )
1427
+ manager = IntelligentContextManager(
1428
+ config=config,
1429
+ summarize_fn=custom_summarizer,
1430
+ )
1431
+
1432
+ tokens_before = tokenizer.count_messages(messages)
1433
+ target_limit = int(tokens_before / 1.15)
1434
+
1435
+ result = manager.apply(
1436
+ messages,
1437
+ tokenizer,
1438
+ model_limit=target_limit,
1439
+ output_buffer=50,
1440
+ )
1441
+
1442
+ # Summarizer should have been called
1443
+ assert len(summarizer_called) > 0
1444
+ # Should have reduced tokens
1445
+ assert result.tokens_after < result.tokens_before
1446
+
1447
+ def test_summarize_fallback_to_drop_when_not_enough(self, tokenizer: Tokenizer):
1448
+ """SUMMARIZE should fall back to DROP_BY_SCORE when not enough."""
1449
+
1450
+ # Custom summarizer that doesn't save much
1451
+ def ineffective_summarizer(messages: list[dict], context: str = "") -> str:
1452
+ # Return almost as long as original
1453
+ return "This is a very long summary " * 50
1454
+
1455
+ messages = [
1456
+ {"role": "system", "content": "System"},
1457
+ ]
1458
+ for i in range(6):
1459
+ messages.append({"role": "user", "content": f"Q{i} " * 20})
1460
+ messages.append({"role": "assistant", "content": f"A{i} " * 20})
1461
+ messages.append({"role": "user", "content": "Final"})
1462
+
1463
+ config = IntelligentContextConfig(
1464
+ summarization_enabled=True,
1465
+ compress_threshold=0.05,
1466
+ summarize_threshold=0.30,
1467
+ keep_last_turns=1,
1468
+ )
1469
+ manager = IntelligentContextManager(
1470
+ config=config,
1471
+ summarize_fn=ineffective_summarizer,
1472
+ )
1473
+
1474
+ tokens_before = tokenizer.count_messages(messages)
1475
+ # Very aggressive limit
1476
+ target_limit = int(tokens_before / 2.0)
1477
+
1478
+ result = manager.apply(
1479
+ messages,
1480
+ tokenizer,
1481
+ model_limit=target_limit,
1482
+ output_buffer=50,
1483
+ )
1484
+
1485
+ # Should still reduce tokens (via DROP_BY_SCORE fallback)
1486
+ assert result.tokens_after < result.tokens_before
1487
+
1488
+ def test_summarize_preserves_protected_messages(self, tokenizer: Tokenizer):
1489
+ """SUMMARIZE should never summarize protected messages."""
1490
+ messages = [
1491
+ {"role": "system", "content": "Important system prompt " * 20},
1492
+ {"role": "user", "content": "Old question " * 30},
1493
+ {"role": "assistant", "content": "Old answer " * 30},
1494
+ {"role": "user", "content": "Recent question " * 30},
1495
+ {"role": "assistant", "content": "Recent answer " * 30},
1496
+ {"role": "user", "content": "Final question"},
1497
+ ]
1498
+
1499
+ config = IntelligentContextConfig(
1500
+ summarization_enabled=True,
1501
+ compress_threshold=0.05,
1502
+ summarize_threshold=0.30,
1503
+ keep_system=True,
1504
+ keep_last_turns=2, # Protect last 2 user turns
1505
+ )
1506
+ manager = IntelligentContextManager(config=config)
1507
+
1508
+ tokens_before = tokenizer.count_messages(messages)
1509
+ target_limit = int(tokens_before / 1.15)
1510
+
1511
+ result = manager.apply(
1512
+ messages,
1513
+ tokenizer,
1514
+ model_limit=target_limit,
1515
+ output_buffer=50,
1516
+ )
1517
+
1518
+ # System message should still be present
1519
+ system_messages = [m for m in result.messages if m.get("role") == "system"]
1520
+ assert len(system_messages) >= 1
1521
+ assert "Important system prompt" in system_messages[0].get("content", "")
1522
+
1523
+
1524
+ class TestProgressiveSummarizer:
1525
+ """Tests for ProgressiveSummarizer component."""
1526
+
1527
+ def test_extractive_summarizer_default(self, tokenizer: Tokenizer):
1528
+ """Default extractive summarizer should work."""
1529
+ from headroom.transforms.progressive_summarizer import (
1530
+ extractive_summarizer,
1531
+ )
1532
+
1533
+ messages = [
1534
+ {"role": "user", "content": "Question 1 " * 20},
1535
+ {"role": "assistant", "content": "Answer 1 " * 30},
1536
+ {"role": "user", "content": "Question 2 " * 20},
1537
+ {"role": "assistant", "content": "Answer 2 " * 30},
1538
+ ]
1539
+
1540
+ # Test extractive summarizer directly
1541
+ summary = extractive_summarizer(messages)
1542
+ assert "[Summary of" in summary
1543
+ assert "4 messages" in summary
1544
+
1545
+ def test_progressive_summarizer_groups_messages(self, tokenizer: Tokenizer):
1546
+ """ProgressiveSummarizer should identify message groups correctly."""
1547
+ from headroom.transforms.progressive_summarizer import ProgressiveSummarizer
1548
+
1549
+ summarizer = ProgressiveSummarizer(
1550
+ min_messages_to_summarize=2,
1551
+ store_for_retrieval=False,
1552
+ )
1553
+
1554
+ messages = [
1555
+ {"role": "system", "content": "System"},
1556
+ {"role": "user", "content": "Q1 " * 30},
1557
+ {"role": "assistant", "content": "A1 " * 30},
1558
+ {"role": "user", "content": "Q2 " * 30},
1559
+ {"role": "assistant", "content": "A2 " * 30},
1560
+ {"role": "user", "content": "Final"},
1561
+ ]
1562
+
1563
+ # Protect only system (0) and final (5)
1564
+ protected = {0, 5}
1565
+
1566
+ groups = summarizer._find_summarization_candidates(messages, protected)
1567
+
1568
+ # Should find the middle messages as a group
1569
+ assert len(groups) >= 1
1570
+ # Group should include indices 1-4
1571
+ found_middle_group = any(start <= 1 and end >= 4 for start, end in groups)
1572
+ assert found_middle_group
1573
+
1574
+ def test_progressive_summarizer_respects_min_messages(self, tokenizer: Tokenizer):
1575
+ """ProgressiveSummarizer should respect min_messages_to_summarize."""
1576
+ from headroom.transforms.progressive_summarizer import ProgressiveSummarizer
1577
+
1578
+ summarizer = ProgressiveSummarizer(
1579
+ min_messages_to_summarize=5, # High threshold
1580
+ store_for_retrieval=False,
1581
+ )
1582
+
1583
+ messages = [
1584
+ {"role": "system", "content": "System"},
1585
+ {"role": "user", "content": "Q1"},
1586
+ {"role": "assistant", "content": "A1"},
1587
+ {"role": "user", "content": "Final"},
1588
+ ]
1589
+
1590
+ protected = {0, 3}
1591
+
1592
+ groups = summarizer._find_summarization_candidates(messages, protected)
1593
+
1594
+ # Should not find any groups (only 2 unprotected messages)
1595
+ assert len(groups) == 0
1596
+
1597
+ def test_progressive_summarizer_summarizes_messages(self, tokenizer: Tokenizer):
1598
+ """ProgressiveSummarizer should create summaries correctly."""
1599
+ from headroom.transforms.progressive_summarizer import ProgressiveSummarizer
1600
+
1601
+ summarizer = ProgressiveSummarizer(
1602
+ min_messages_to_summarize=3,
1603
+ store_for_retrieval=False,
1604
+ )
1605
+
1606
+ messages = [
1607
+ {"role": "system", "content": "System"},
1608
+ {"role": "user", "content": "Q1 " * 50},
1609
+ {"role": "assistant", "content": "A1 " * 50},
1610
+ {"role": "user", "content": "Q2 " * 50},
1611
+ {"role": "assistant", "content": "A2 " * 50},
1612
+ {"role": "user", "content": "Final question"},
1613
+ ]
1614
+
1615
+ protected = {0, 5} # System and final
1616
+
1617
+ result = summarizer.summarize_messages(
1618
+ messages=messages,
1619
+ tokenizer=tokenizer,
1620
+ protected_indices=protected,
1621
+ )
1622
+
1623
+ # Should have reduced message count
1624
+ assert len(result.messages) < len(messages)
1625
+ # Should have created summaries
1626
+ assert len(result.summaries_created) > 0
1627
+ # Should have saved tokens
1628
+ assert result.tokens_after < result.tokens_before
1629
+
1630
+
1631
+ class TestAnchoredSummary:
1632
+ """Tests for AnchoredSummary data structure."""
1633
+
1634
+ def test_anchored_summary_compression_ratio(self):
1635
+ """AnchoredSummary should calculate compression ratio correctly."""
1636
+ from headroom.transforms.progressive_summarizer import AnchoredSummary
1637
+
1638
+ summary = AnchoredSummary(
1639
+ summary_text="Summary",
1640
+ start_index=0,
1641
+ end_index=5,
1642
+ original_message_count=6,
1643
+ original_tokens=1000,
1644
+ summary_tokens=100,
1645
+ )
1646
+
1647
+ assert summary.compression_ratio == 0.1
1648
+ assert summary.tokens_saved == 900
1649
+
1650
+ def test_anchored_summary_zero_original_tokens(self):
1651
+ """AnchoredSummary should handle zero original tokens."""
1652
+ from headroom.transforms.progressive_summarizer import AnchoredSummary
1653
+
1654
+ summary = AnchoredSummary(
1655
+ summary_text="Summary",
1656
+ start_index=0,
1657
+ end_index=0,
1658
+ original_message_count=1,
1659
+ original_tokens=0,
1660
+ summary_tokens=10,
1661
+ )
1662
+
1663
+ assert summary.compression_ratio == 1.0
1664
+ assert summary.tokens_saved == 0
1665
+
1666
+
1667
+ class TestSummarizeEdgeCases:
1668
+ """Edge case tests for SUMMARIZE strategy."""
1669
+
1670
+ def test_summarize_empty_messages(self, tokenizer: Tokenizer):
1671
+ """SUMMARIZE should handle empty messages list."""
1672
+ from headroom.transforms.progressive_summarizer import ProgressiveSummarizer
1673
+
1674
+ summarizer = ProgressiveSummarizer(store_for_retrieval=False)
1675
+
1676
+ result = summarizer.summarize_messages(
1677
+ messages=[],
1678
+ tokenizer=tokenizer,
1679
+ protected_indices=set(),
1680
+ )
1681
+
1682
+ assert result.messages == []
1683
+ assert len(result.summaries_created) == 0
1684
+
1685
+ def test_summarize_all_protected(self, tokenizer: Tokenizer):
1686
+ """SUMMARIZE should handle when all messages are protected."""
1687
+ from headroom.transforms.progressive_summarizer import ProgressiveSummarizer
1688
+
1689
+ summarizer = ProgressiveSummarizer(store_for_retrieval=False)
1690
+
1691
+ messages = [
1692
+ {"role": "system", "content": "System"},
1693
+ {"role": "user", "content": "Question"},
1694
+ {"role": "assistant", "content": "Answer"},
1695
+ ]
1696
+
1697
+ result = summarizer.summarize_messages(
1698
+ messages=messages,
1699
+ tokenizer=tokenizer,
1700
+ protected_indices={0, 1, 2}, # All protected
1701
+ )
1702
+
1703
+ # Should return unchanged messages
1704
+ assert len(result.messages) == len(messages)
1705
+ assert len(result.summaries_created) == 0
1706
+
1707
+ def test_summarize_with_tool_messages(self, tokenizer: Tokenizer):
1708
+ """SUMMARIZE should handle tool messages."""
1709
+ import json
1710
+
1711
+ from headroom.transforms.progressive_summarizer import ProgressiveSummarizer
1712
+
1713
+ summarizer = ProgressiveSummarizer(
1714
+ min_messages_to_summarize=3,
1715
+ store_for_retrieval=False,
1716
+ )
1717
+
1718
+ messages = [
1719
+ {"role": "system", "content": "System"},
1720
+ {"role": "user", "content": "Search for data " * 20},
1721
+ {
1722
+ "role": "assistant",
1723
+ "content": "",
1724
+ "tool_calls": [
1725
+ {
1726
+ "id": "c1",
1727
+ "type": "function",
1728
+ "function": {"name": "search", "arguments": "{}"},
1729
+ }
1730
+ ],
1731
+ },
1732
+ {
1733
+ "role": "tool",
1734
+ "tool_call_id": "c1",
1735
+ "content": json.dumps([{"id": i, "data": f"result_{i}"} for i in range(20)]),
1736
+ },
1737
+ {"role": "assistant", "content": "Here are the results " * 20},
1738
+ {"role": "user", "content": "Final"},
1739
+ ]
1740
+
1741
+ protected = {0, 5}
1742
+
1743
+ result = summarizer.summarize_messages(
1744
+ messages=messages,
1745
+ tokenizer=tokenizer,
1746
+ protected_indices=protected,
1747
+ )
1748
+
1749
+ # Should complete without error
1750
+ assert result.messages is not None
1751
+ # Protected messages should be preserved
1752
+ assert result.messages[0].get("role") == "system"
1753
+
1754
+ def test_summarize_skips_small_token_groups(self, tokenizer: Tokenizer):
1755
+ """SUMMARIZE should skip groups with few tokens."""
1756
+ from headroom.transforms.progressive_summarizer import ProgressiveSummarizer
1757
+
1758
+ summarizer = ProgressiveSummarizer(
1759
+ min_messages_to_summarize=3,
1760
+ store_for_retrieval=False,
1761
+ )
1762
+
1763
+ # Very short messages
1764
+ messages = [
1765
+ {"role": "system", "content": "S"},
1766
+ {"role": "user", "content": "Q1"},
1767
+ {"role": "assistant", "content": "A1"},
1768
+ {"role": "user", "content": "Q2"},
1769
+ {"role": "assistant", "content": "A2"},
1770
+ {"role": "user", "content": "F"},
1771
+ ]
1772
+
1773
+ protected = {0, 5}
1774
+
1775
+ result = summarizer.summarize_messages(
1776
+ messages=messages,
1777
+ tokenizer=tokenizer,
1778
+ protected_indices=protected,
1779
+ )
1780
+
1781
+ # Should not create summaries (groups too small token-wise)
1782
+ # The summarizer checks for group_tokens < 100
1783
+ assert len(result.summaries_created) == 0
1784
+
1785
+ def test_summarize_callback_exception_handled(self, tokenizer: Tokenizer):
1786
+ """SUMMARIZE should handle callback exceptions gracefully."""
1787
+ from headroom.transforms.progressive_summarizer import ProgressiveSummarizer
1788
+
1789
+ def failing_summarizer(messages: list[dict], context: str = "") -> str:
1790
+ raise ValueError("Summarization failed!")
1791
+
1792
+ summarizer = ProgressiveSummarizer(
1793
+ summarize_fn=failing_summarizer,
1794
+ min_messages_to_summarize=3,
1795
+ store_for_retrieval=False,
1796
+ )
1797
+
1798
+ messages = [
1799
+ {"role": "system", "content": "System"},
1800
+ {"role": "user", "content": "Q " * 50},
1801
+ {"role": "assistant", "content": "A " * 50},
1802
+ {"role": "user", "content": "Q2 " * 50},
1803
+ {"role": "assistant", "content": "A2 " * 50},
1804
+ {"role": "user", "content": "Final"},
1805
+ ]
1806
+
1807
+ protected = {0, 5}
1808
+
1809
+ # Should not raise, should return original messages
1810
+ result = summarizer.summarize_messages(
1811
+ messages=messages,
1812
+ tokenizer=tokenizer,
1813
+ protected_indices=protected,
1814
+ )
1815
+
1816
+ assert result.messages is not None
1817
+ # No summaries created due to exception
1818
+ assert len(result.summaries_created) == 0