mirror of https://github.com/hwchase17/langchain.git
synced 2026-05-05 11:12:11 +00:00
example implementation for SummarizationMiddleware
@@ -3,6 +3,7 @@
import uuid
from collections.abc import Callable, Iterable
from typing import Any, cast
import warnings

from langchain_core.messages import (
    AIMessage,
@@ -93,12 +94,31 @@ class SummarizationMiddleware(AgentMiddleware):
            model = init_chat_model(model)

        self.model = model
        self.max_tokens_before_summary = max_tokens_before_summary
        self.messages_to_keep = messages_to_keep
        self.token_counter = token_counter
        self.summary_prompt = summary_prompt
        self.summary_prefix = summary_prefix

        if max_tokens_before_summary is None:
            # Need to resolve https://github.com/langchain-ai/langchain/issues/33701
            # to clarify desired behavior in None case
            self.max_tokens_before_summary = None
            try:
                model_capabilities = model.capabilities
                if model_capabilities["max_input_tokens"]:
                    self.max_tokens_before_summary = (
                        # 75% of max input tokens as threshold
                        # Could also parametrize this
                        int(model_capabilities["max_input_tokens"] * 0.75)
                    )
            except ImportError:
                warning_msg = (
                    "Defaulting to None max_tokens_before_summary. pip install "
                    "langchain-llm-capabilities to enable setting based on model "
                    "context window size."
                )
                warnings.warn(warning_msg, ImportWarning)

    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
        """Process messages before model invocation, potentially triggering summarization."""
        messages = state["messages"]
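A minimal usage sketch of the behavior this commit adds (the import path and model identifier below are assumptions for illustration, not part of the diff): when max_tokens_before_summary is omitted, the middleware tries to derive a threshold of 75% of the model's advertised max input tokens via the optional langchain-llm-capabilities package, and otherwise leaves it as None and emits an ImportWarning.

    from langchain.agents.middleware import SummarizationMiddleware  # assumed import path

    # max_tokens_before_summary is intentionally omitted here:
    # - with langchain-llm-capabilities installed, it resolves to
    #   int(max_input_tokens * 0.75) based on the model's capability metadata
    # - without it, it stays None and an ImportWarning is raised
    middleware = SummarizationMiddleware(
        model="openai:gpt-4o-mini",  # any init_chat_model-style identifier; illustrative only
        messages_to_keep=20,
    )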