commit 44a60a6f09 (parent 26bef498e8)
Author: Harrison Chase
Date: 2025-09-06 13:11:02 -07:00

2 changed files with 56 additions and 86 deletions

View File

@@ -48,53 +48,55 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
         approved_tool_calls = auto_approved_tool_calls.copy()
-        # Process all tool calls that need interrupts in parallel
-        requests = []
-        for tool_call in interrupt_tool_calls:
-            tool_name = tool_call["name"]
-            tool_args = tool_call["args"]
-            description = f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
-            tool_config = self.tool_configs[tool_name]
-            request: HumanInterrupt = {
-                "action_request": ActionRequest(
-                    action=tool_name,
-                    args=tool_args,
-                ),
-                "config": tool_config,
-                "description": description,
-            }
-            requests.append(request)
-        responses: list[HumanResponse] = interrupt(requests)
-        for i, response in enumerate(responses):
-            tool_call = interrupt_tool_calls[i]
-            if response["type"] == "accept":
-                approved_tool_calls.append(tool_call)
-            elif response["type"] == "edit":
-                edited: ActionRequest = response["args"]
-                new_tool_call = {
-                    "name": tool_call["name"],
-                    "args": edited["args"],
-                    "id": tool_call["id"],
-                }
-                approved_tool_calls.append(new_tool_call)
-            elif response["type"] == "ignore":
-                # NOTE: does not work with multiple interrupts
-                return {"goto": "__end__"}
-            elif response["type"] == "response":
-                # NOTE: does not work with multiple interrupts
-                tool_message = {
-                    "role": "tool",
-                    "tool_call_id": tool_call["id"],
-                    "content": response["args"],
-                }
-                return {"messages": [tool_message], "goto": "model"}
-            else:
-                raise ValueError(f"Unknown response type: {response['type']}")
+        # Right now, we do not support multiple tool calls with interrupts
+        if len(interrupt_tool_calls) > 1:
+            raise ValueError("Does not currently support multiple tool calls with interrupts")
+        # Right now, we do not support interrupting a tool call if other tool calls exist
+        if auto_approved_tool_calls:
+            raise ValueError("Does not currently support interrupting a tool call if other tool calls exist")
+        # Only one tool call will need interrupts
+        tool_call = interrupt_tool_calls[0]
+        tool_name = tool_call["name"]
+        tool_args = tool_call["args"]
+        description = f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
+        tool_config = self.tool_configs[tool_name]
+        request: HumanInterrupt = {
+            "action_request": ActionRequest(
+                action=tool_name,
+                args=tool_args,
+            ),
+            "config": tool_config,
+            "description": description,
+        }
+        responses: list[HumanResponse] = interrupt([request])
+        response = responses[0]
+        if response["type"] == "accept":
+            approved_tool_calls.append(tool_call)
+        elif response["type"] == "edit":
+            edited: ActionRequest = response["args"]
+            new_tool_call = {
+                "type": "tool_call",
+                "name": tool_call["name"],
+                "args": edited["args"],
+                "id": tool_call["id"],
+            }
+            approved_tool_calls.append(new_tool_call)
+        elif response["type"] == "ignore":
+            return {"goto": "__end__"}
+        elif response["type"] == "response":
+            tool_message = {
+                "role": "tool",
+                "tool_call_id": tool_call["id"],
+                "content": response["args"],
+            }
+            return {"messages": [tool_message], "goto": "model"}
+        else:
+            raise ValueError(f"Unknown response type: {response['type']}")
         last_message.tool_calls = approved_tool_calls
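For context, here is a minimal, self-contained sketch (not part of this commit) of the single-interrupt pattern the new code settles on: one `interrupt([request])` call that pauses the graph, and exactly one `HumanResponse` returned via `Command(resume=[...])`. The toy state, node, and request payload below are invented for illustration; only the `interrupt`/`Command` calls and the `accept`/`edit`/`response` shapes mirror the middleware above, and a checkpointer is required for `interrupt` to work.

from typing_extensions import TypedDict

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import Command, interrupt


class State(TypedDict):
    decision: str


def review_tool_call(state: State) -> State:
    # Mirrors the middleware: exactly one request in, one HumanResponse back.
    responses = interrupt(
        [
            {
                "action_request": {"action": "delete_rows", "args": {"table": "users"}},
                "config": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
                "description": "Tool execution requires approval",
            }
        ]
    )
    response = responses[0]
    if response["type"] == "accept":
        return {"decision": "run tool with original args"}
    if response["type"] == "edit":
        return {"decision": f"run tool with edited args: {response['args']['args']}"}
    if response["type"] == "response":
        return {"decision": f"tool skipped, reply to model: {response['args']}"}
    return {"decision": "ignored"}


builder = StateGraph(State)
builder.add_node("review", review_tool_call)
builder.add_edge(START, "review")
builder.add_edge("review", END)
app = builder.compile(checkpointer=MemorySaver())

cfg = {"configurable": {"thread_id": "demo"}}
print(app.invoke({"decision": ""}, cfg))                      # pauses at the interrupt
print(app.invoke(Command(resume=[{"type": "accept"}]), cfg))  # resume with a single response

Resuming instead with `Command(resume=[{"type": "edit", "args": {"action": "delete_rows", "args": {"table": "users", "dry_run": True}}}])` would exercise the edit branch, which is why the middleware reads the new arguments from `response["args"]["args"]`.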

View File

@@ -6,7 +6,6 @@ from langchain_core.messages import (
     AIMessage,
     AnyMessage,
     MessageLikeRepresentation,
-    SystemMessage,
     ToolMessage,
 )
 from langchain_core.messages.utils import count_tokens_approximately, trim_messages
@@ -101,79 +100,48 @@ class SummarizationMiddleware(AgentMiddleware):
         ):
             return None
-        system_message, conversation_messages = self._split_system_message(messages)
-        cutoff_index = self._find_safe_cutoff(conversation_messages)
+        cutoff_index = self._find_safe_cutoff(messages)
         if cutoff_index <= 0:
             return None
         messages_to_summarize, preserved_messages = self._partition_messages(
-            system_message, conversation_messages, cutoff_index
+            messages, cutoff_index
         )
         summary = self._create_summary(messages_to_summarize)
-        updated_system_message = self._build_updated_system_message(system_message, summary)
+        new_messages = self._build_new_messages(summary)
         return {
             "messages": [
                 RemoveMessage(id=REMOVE_ALL_MESSAGES),
-                updated_system_message,
+                *new_messages,
                 *preserved_messages,
             ]
         }
+    def _build_new_messages(self, summary: str):
+        return [
+            {"role": "user", "content": f"Here is a summary of the conversation to date:\n\n{summary}"}
+        ]
     def _ensure_message_ids(self, messages: list[AnyMessage]) -> None:
         """Ensure all messages have unique IDs for the add_messages reducer."""
         for msg in messages:
             if msg.id is None:
                 msg.id = str(uuid.uuid4())
-    def _split_system_message(
-        self, messages: list[AnyMessage]
-    ) -> tuple[SystemMessage | None, list[AnyMessage]]:
-        """Separate system message from conversation messages."""
-        if messages and isinstance(messages[0], SystemMessage):
-            return messages[0], messages[1:]
-        return None, messages
     def _partition_messages(
         self,
-        system_message: SystemMessage | None,
         conversation_messages: list[AnyMessage],
         cutoff_index: int,
     ) -> tuple[list[AnyMessage], list[AnyMessage]]:
-        """Partition messages into those to summarize and those to preserve.
-        We include the system message so that we can capture previous summaries.
-        """
+        """Partition messages into those to summarize and those to preserve."""
         messages_to_summarize = conversation_messages[:cutoff_index]
         preserved_messages = conversation_messages[cutoff_index:]
-        if system_message is not None:
-            messages_to_summarize = [system_message, *messages_to_summarize]
         return messages_to_summarize, preserved_messages
-    def _build_updated_system_message(
-        self, original_system_message: SystemMessage | None, summary: str
-    ) -> SystemMessage:
-        """Build new system message incorporating the summary."""
-        if original_system_message is None:
-            original_content = ""
-        else:
-            content = cast("str", original_system_message.content)
-            original_content = content.split(self.summary_prefix)[0].strip()
-        if original_content:
-            content = f"{original_content}\n{self.summary_prefix}\n{summary}"
-        else:
-            content = f"{self.summary_prefix}\n{summary}"
-        return SystemMessage(
-            content=content,
-            id=original_system_message.id if original_system_message else str(uuid.uuid4()),
-        )
     def _find_safe_cutoff(self, messages: list[AnyMessage]) -> int:
         """Find safe cutoff point that preserves AI/Tool message pairs.