mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-25 14:35:49 +00:00
cr
This commit is contained in:
@@ -48,53 +48,55 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
|
||||
approved_tool_calls = auto_approved_tool_calls.copy()
|
||||
|
||||
# Process all tool calls that need interrupts in parallel
|
||||
requests = []
|
||||
# Right now, we do not support multiple tool calls with interrupts
|
||||
if len(interrupt_tool_calls) > 1:
|
||||
raise ValueError("Does not currently support multiple tool calls with interrupts")
|
||||
|
||||
for tool_call in interrupt_tool_calls:
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = tool_call["args"]
|
||||
description = f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
|
||||
tool_config = self.tool_configs[tool_name]
|
||||
# Right now, we do not support interrupting a tool call if other tool calls exist
|
||||
if auto_approved_tool_calls:
|
||||
raise ValueError("Does not currently support interrupting a tool call if other tool calls exist")
|
||||
|
||||
request: HumanInterrupt = {
|
||||
"action_request": ActionRequest(
|
||||
action=tool_name,
|
||||
args=tool_args,
|
||||
),
|
||||
"config": tool_config,
|
||||
"description": description,
|
||||
# Only one tool call will need interrupts
|
||||
tool_call = interrupt_tool_calls[0]
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = tool_call["args"]
|
||||
description = f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
|
||||
tool_config = self.tool_configs[tool_name]
|
||||
|
||||
request: HumanInterrupt = {
|
||||
"action_request": ActionRequest(
|
||||
action=tool_name,
|
||||
args=tool_args,
|
||||
),
|
||||
"config": tool_config,
|
||||
"description": description,
|
||||
}
|
||||
|
||||
responses: list[HumanResponse] = interrupt([request])
|
||||
response = responses[0]
|
||||
|
||||
if response["type"] == "accept":
|
||||
approved_tool_calls.append(tool_call)
|
||||
elif response["type"] == "edit":
|
||||
edited: ActionRequest = response["args"]
|
||||
new_tool_call = {
|
||||
"type": "tool_call",
|
||||
"name": tool_call["name"],
|
||||
"args": edited["args"],
|
||||
"id": tool_call["id"],
|
||||
}
|
||||
requests.append(request)
|
||||
|
||||
responses: list[HumanResponse] = interrupt(requests)
|
||||
|
||||
for i, response in enumerate(responses):
|
||||
tool_call = interrupt_tool_calls[i]
|
||||
|
||||
if response["type"] == "accept":
|
||||
approved_tool_calls.append(tool_call)
|
||||
elif response["type"] == "edit":
|
||||
edited: ActionRequest = response["args"]
|
||||
new_tool_call = {
|
||||
"name": tool_call["name"],
|
||||
"args": edited["args"],
|
||||
"id": tool_call["id"],
|
||||
}
|
||||
approved_tool_calls.append(new_tool_call)
|
||||
elif response["type"] == "ignore":
|
||||
# NOTE: does not work with multiple interrupts
|
||||
return {"goto": "__end__"}
|
||||
elif response["type"] == "response":
|
||||
# NOTE: does not work with multiple interrupts
|
||||
tool_message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call["id"],
|
||||
"content": response["args"],
|
||||
}
|
||||
return {"messages": [tool_message], "goto": "model"}
|
||||
else:
|
||||
raise ValueError(f"Unknown response type: {response['type']}")
|
||||
approved_tool_calls.append(new_tool_call)
|
||||
elif response["type"] == "ignore":
|
||||
return {"goto": "__end__"}
|
||||
elif response["type"] == "response":
|
||||
tool_message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call["id"],
|
||||
"content": response["args"],
|
||||
}
|
||||
return {"messages": [tool_message], "goto": "model"}
|
||||
else:
|
||||
raise ValueError(f"Unknown response type: {response['type']}")
|
||||
|
||||
last_message.tool_calls = approved_tool_calls
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
MessageLikeRepresentation,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.utils import count_tokens_approximately, trim_messages
|
||||
@@ -101,79 +100,48 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
):
|
||||
return None
|
||||
|
||||
system_message, conversation_messages = self._split_system_message(messages)
|
||||
cutoff_index = self._find_safe_cutoff(conversation_messages)
|
||||
cutoff_index = self._find_safe_cutoff(messages)
|
||||
|
||||
if cutoff_index <= 0:
|
||||
return None
|
||||
|
||||
messages_to_summarize, preserved_messages = self._partition_messages(
|
||||
system_message, conversation_messages, cutoff_index
|
||||
messages, cutoff_index
|
||||
)
|
||||
|
||||
summary = self._create_summary(messages_to_summarize)
|
||||
updated_system_message = self._build_updated_system_message(system_message, summary)
|
||||
new_messages = self._build_new_messages(summary)
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
RemoveMessage(id=REMOVE_ALL_MESSAGES),
|
||||
updated_system_message,
|
||||
*new_messages,
|
||||
*preserved_messages,
|
||||
]
|
||||
}
|
||||
|
||||
def _build_new_messages(self, summary: str):
|
||||
return [
|
||||
{"role": "user", "content": f"Here is a summary of the conversation to date:\n\n{summary}"}
|
||||
]
|
||||
|
||||
def _ensure_message_ids(self, messages: list[AnyMessage]) -> None:
|
||||
"""Ensure all messages have unique IDs for the add_messages reducer."""
|
||||
for msg in messages:
|
||||
if msg.id is None:
|
||||
msg.id = str(uuid.uuid4())
|
||||
|
||||
def _split_system_message(
|
||||
self, messages: list[AnyMessage]
|
||||
) -> tuple[SystemMessage | None, list[AnyMessage]]:
|
||||
"""Separate system message from conversation messages."""
|
||||
if messages and isinstance(messages[0], SystemMessage):
|
||||
return messages[0], messages[1:]
|
||||
return None, messages
|
||||
|
||||
def _partition_messages(
|
||||
self,
|
||||
system_message: SystemMessage | None,
|
||||
conversation_messages: list[AnyMessage],
|
||||
cutoff_index: int,
|
||||
) -> tuple[list[AnyMessage], list[AnyMessage]]:
|
||||
"""Partition messages into those to summarize and those to preserve.
|
||||
|
||||
We include the system message so that we can capture previous summaries.
|
||||
"""
|
||||
"""Partition messages into those to summarize and those to preserve."""
|
||||
messages_to_summarize = conversation_messages[:cutoff_index]
|
||||
preserved_messages = conversation_messages[cutoff_index:]
|
||||
|
||||
if system_message is not None:
|
||||
messages_to_summarize = [system_message, *messages_to_summarize]
|
||||
|
||||
return messages_to_summarize, preserved_messages
|
||||
|
||||
def _build_updated_system_message(
|
||||
self, original_system_message: SystemMessage | None, summary: str
|
||||
) -> SystemMessage:
|
||||
"""Build new system message incorporating the summary."""
|
||||
if original_system_message is None:
|
||||
original_content = ""
|
||||
else:
|
||||
content = cast("str", original_system_message.content)
|
||||
original_content = content.split(self.summary_prefix)[0].strip()
|
||||
|
||||
if original_content:
|
||||
content = f"{original_content}\n{self.summary_prefix}\n{summary}"
|
||||
else:
|
||||
content = f"{self.summary_prefix}\n{summary}"
|
||||
|
||||
return SystemMessage(
|
||||
content=content,
|
||||
id=original_system_message.id if original_system_message else str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
def _find_safe_cutoff(self, messages: list[AnyMessage]) -> int:
|
||||
"""Find safe cutoff point that preserves AI/Tool message pairs.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user