example implementation for SummarizationMiddleware

This commit is contained in:
Chester Curme
2025-10-28 11:24:33 -04:00
parent abd46163a2
commit c568d2fc62

View File

@@ -3,6 +3,7 @@
import uuid
from collections.abc import Callable, Iterable
from typing import Any, cast
import warnings
from langchain_core.messages import (
AIMessage,
@@ -93,12 +94,31 @@ class SummarizationMiddleware(AgentMiddleware):
model = init_chat_model(model)
self.model = model
self.max_tokens_before_summary = max_tokens_before_summary
self.messages_to_keep = messages_to_keep
self.token_counter = token_counter
self.summary_prompt = summary_prompt
self.summary_prefix = summary_prefix
if max_tokens_before_summary is None:
# Need to resolve https://github.com/langchain-ai/langchain/issues/33701
# to clarify desired behavior in None case
self.max_tokens_before_summary = None
try:
model_capabilities = model.capabilities
if model_capabilities["max_input_tokens"]:
self.max_tokens_before_summary = (
# 75% of max input tokens as threshold
# Could also parametrize this
int(model_capabilities["max_input_tokens"] * 0.75)
)
except ImportError:
warning_msg = (
"Defaulting to None max_tokens_before_summary. pip install "
"langchain-llm-capabilities to enable setting based on model "
"context window size."
)
warnings.warn(warning_msg, ImportWarning)
def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Process messages before model invocation, potentially triggering summarization."""
messages = state["messages"]