Compare commits

...

4 Commits

Author SHA1 Message Date
Chester Curme
01b1122a81 example implementation for SummarizationMiddleware 2025-10-28 11:55:29 -04:00
Chester Curme
cb71e10f95 Revert "example implementation for SummarizationMiddleware"
This reverts commit c568d2fc62.
2025-10-28 11:33:57 -04:00
Chester Curme
c568d2fc62 example implementation for SummarizationMiddleware 2025-10-28 11:24:33 -04:00
Chester Curme
abd46163a2 example implementation 2025-10-28 10:45:36 -04:00
2 changed files with 105 additions and 14 deletions

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import itertools
from typing import TYPE_CHECKING, Annotated, Any, cast, get_args, get_origin, get_type_hints
import warnings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
@@ -612,8 +613,6 @@ def create_agent( # noqa: PLR0915
# Convert response format and setup structured output tools
# Raw schemas are wrapped in AutoStrategy to preserve auto-detection intent.
# AutoStrategy is converted to ToolStrategy upfront to calculate tools during agent creation,
# but may be replaced with ProviderStrategy later based on model capabilities.
initial_response_format: ToolStrategy | ProviderStrategy | AutoStrategy | None
if response_format is None:
initial_response_format = None
@@ -624,15 +623,29 @@ def create_agent( # noqa: PLR0915
# AutoStrategy provided - preserve it for later auto-detection
initial_response_format = response_format
else:
# Raw schema - wrap in AutoStrategy to enable auto-detection
initial_response_format = AutoStrategy(schema=response_format)
if not isinstance(response_format, AutoStrategy):
schema = response_format # raw schema
else:
schema = response_format.schema
try:
model_capabilities = model.capabilities
if model_capabilities["structured_output"]:
response_format = ProviderStrategy(schema=schema)
else:
response_format = ToolStrategy(schema=schema)
initial_response_format = response_format
except ImportError:
warning_msg = (
"Could not infer structured output strategy. pip install "
"langchain-llm-capabilities to enable model capability detection. "
"Defaulting to ToolStrategy."
)
warnings.warn(warning_msg, ImportWarning)
response_format = ToolStrategy(schema=schema)
initial_response_format = response_format
# For AutoStrategy, convert to ToolStrategy to setup tools upfront
# (may be replaced with ProviderStrategy later based on model)
tool_strategy_for_setup: ToolStrategy | None = None
if isinstance(initial_response_format, AutoStrategy):
tool_strategy_for_setup = ToolStrategy(schema=initial_response_format.schema)
elif isinstance(initial_response_format, ToolStrategy):
if isinstance(initial_response_format, ToolStrategy):
tool_strategy_for_setup = initial_response_format
structured_output_tools: dict[str, OutputToolBinding] = {}

View File

@@ -3,6 +3,7 @@
import uuid
from collections.abc import Callable, Iterable
from typing import Any, cast
import warnings
from langchain_core.messages import (
AIMessage,
@@ -75,17 +76,28 @@ class SummarizationMiddleware(AgentMiddleware):
token_counter: TokenCounter = count_tokens_approximately,
summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
summary_prefix: str = SUMMARY_PREFIX,
*,
target_retention_pct: float = 0.3,
buffer_tokens: int = 500,
) -> None:
"""Initialize the summarization middleware.
Args:
model: The language model to use for generating summaries.
max_tokens_before_summary: Token threshold to trigger summarization.
If `None`, summarization is disabled.
If `None` and model capabilities are unavailable, summarization is disabled.
Deprecated in favor of dynamic token management using model capabilities.
messages_to_keep: Number of recent messages to preserve after summarization.
Used as fallback when token-based calculation is not possible.
token_counter: Function to count tokens in messages.
summary_prompt: Prompt template for generating summaries.
summary_prefix: Prefix added to system message when including summary.
target_retention_pct: Target percentage of max_input_tokens to retain after
summarization (default: 0.3 or 30%). Only used when model capabilities
are available via langchain-llm-features.
buffer_tokens: Safety buffer in tokens to prevent hitting limits
(default: 500). Only used when model capabilities are available via
langchain-llm-features.
"""
super().__init__()
@@ -98,6 +110,8 @@ class SummarizationMiddleware(AgentMiddleware):
self.token_counter = token_counter
self.summary_prompt = summary_prompt
self.summary_prefix = summary_prefix
self.target_retention_pct = target_retention_pct
self.buffer_tokens = buffer_tokens
def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Process messages before model invocation, potentially triggering summarization."""
@@ -105,13 +119,37 @@ class SummarizationMiddleware(AgentMiddleware):
self._ensure_message_ids(messages)
total_tokens = self.token_counter(messages)
# Try to use model capabilities for dynamic token management
should_summarize = False
target_token_count: int | None = None
try:
max_input_tokens = self.model.capabilities.get("max_input_tokens")
max_output_tokens = self.model.capabilities.get("max_output_tokens")
if max_input_tokens is not None and max_output_tokens is not None:
if total_tokens + max_output_tokens + self.buffer_tokens > max_input_tokens:
should_summarize = True
target_token_count = int(max_input_tokens * self.target_retention_pct)
else:
return None
except ImportError:
warning_msg = (
"Defaulting to None max_tokens_before_summary. pip install "
"langchain-llm-capabilities to enable setting based on model "
"context window size."
)
warnings.warn(warning_msg, ImportWarning)
if (
self.max_tokens_before_summary is not None
and total_tokens < self.max_tokens_before_summary
and not should_summarize
):
return None
cutoff_index = self._find_safe_cutoff(messages)
cutoff_index = self._find_safe_cutoff(messages, target_token_count=target_token_count)
if cutoff_index <= 0:
return None
@@ -151,12 +189,52 @@ class SummarizationMiddleware(AgentMiddleware):
return messages_to_summarize, preserved_messages
def _find_safe_cutoff(self, messages: list[AnyMessage]) -> int:
def _find_safe_cutoff(
self,
messages: list[AnyMessage],
target_token_count: int | None = None,
) -> int:
"""Find safe cutoff point that preserves AI/Tool message pairs.
Returns the index where messages can be safely cut without separating
related AI and Tool messages. Returns 0 if no safe cutoff is found.
Args:
messages: List of messages to find a cutoff point for.
target_token_count: Target token count to retain. If provided, keeps as many
messages as possible while staying under this limit. If `None`, falls back
to keeping `messages_to_keep` recent messages.
Returns:
The index where messages can be safely cut without separating
related AI and Tool messages. Returns 0 if no safe cutoff is found.
"""
if target_token_count is not None:
# Token-based cutoff: use trim_messages to keep as many as possible
try:
trimmed = trim_messages(
messages,
max_tokens=target_token_count,
token_counter=self.token_counter,
strategy="last",
allow_partial=False,
)
# Find where the trimmed messages start in the original list
if not trimmed:
return 0
# Calculate cutoff index based on how many messages were kept
num_kept = len(trimmed)
target_cutoff = len(messages) - num_kept
# Ensure we have a safe cutoff point (doesn't separate AI/Tool pairs)
for i in range(target_cutoff, -1, -1):
if self._is_safe_cutoff_point(messages, i):
return i
return 0
except Exception: # noqa: BLE001
# Fall back to message-count-based approach if trim_messages fails
pass
# Legacy message-count-based cutoff
if len(messages) <= self.messages_to_keep:
return 0