From 44a60a6f09a3e723e3ee4682c28528885e9078a2 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 6 Sep 2025 13:11:02 -0700 Subject: [PATCH] cr --- .../agents/middleware/human_in_the_loop.py | 90 ++++++++++--------- .../agents/middleware/summarization.py | 52 +++-------- 2 files changed, 56 insertions(+), 86 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py index b0df8e9edf4..d19932f9203 100644 --- a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py +++ b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py @@ -48,53 +48,55 @@ class HumanInTheLoopMiddleware(AgentMiddleware): approved_tool_calls = auto_approved_tool_calls.copy() - # Process all tool calls that need interrupts in parallel - requests = [] + # Right now, we do not support multiple tool calls with interrupts + if len(interrupt_tool_calls) > 1: + raise ValueError("Does not currently support multiple tool calls with interrupts") - for tool_call in interrupt_tool_calls: - tool_name = tool_call["name"] - tool_args = tool_call["args"] - description = f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}" - tool_config = self.tool_configs[tool_name] + # Right now, we do not support interrupting a tool call if other tool calls exist + if auto_approved_tool_calls: + raise ValueError("Does not currently support interrupting a tool call if other tool calls exist") - request: HumanInterrupt = { - "action_request": ActionRequest( - action=tool_name, - args=tool_args, - ), - "config": tool_config, - "description": description, + # Only one tool call will need interrupts + tool_call = interrupt_tool_calls[0] + tool_name = tool_call["name"] + tool_args = tool_call["args"] + description = f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}" + tool_config = self.tool_configs[tool_name] + + request: HumanInterrupt = { + "action_request": ActionRequest( + action=tool_name, + args=tool_args, + ), + "config": tool_config, + "description": description, + } + + responses: list[HumanResponse] = interrupt([request]) + response = responses[0] + + if response["type"] == "accept": + approved_tool_calls.append(tool_call) + elif response["type"] == "edit": + edited: ActionRequest = response["args"] + new_tool_call = { + "type": "tool_call", + "name": tool_call["name"], + "args": edited["args"], + "id": tool_call["id"], } - requests.append(request) - - responses: list[HumanResponse] = interrupt(requests) - - for i, response in enumerate(responses): - tool_call = interrupt_tool_calls[i] - - if response["type"] == "accept": - approved_tool_calls.append(tool_call) - elif response["type"] == "edit": - edited: ActionRequest = response["args"] - new_tool_call = { - "name": tool_call["name"], - "args": edited["args"], - "id": tool_call["id"], - } - approved_tool_calls.append(new_tool_call) - elif response["type"] == "ignore": - # NOTE: does not work with multiple interrupts - return {"goto": "__end__"} - elif response["type"] == "response": - # NOTE: does not work with multiple interrupts - tool_message = { - "role": "tool", - "tool_call_id": tool_call["id"], - "content": response["args"], - } - return {"messages": [tool_message], "goto": "model"} - else: - raise ValueError(f"Unknown response type: {response['type']}") + approved_tool_calls.append(new_tool_call) + elif response["type"] == "ignore": + return {"goto": "__end__"} + elif response["type"] == "response": + tool_message = { + "role": "tool", + "tool_call_id": tool_call["id"], + "content": response["args"], + } + return {"messages": [tool_message], "goto": "model"} + else: + raise ValueError(f"Unknown response type: {response['type']}") last_message.tool_calls = approved_tool_calls diff --git a/libs/langchain_v1/langchain/agents/middleware/summarization.py b/libs/langchain_v1/langchain/agents/middleware/summarization.py index 8edea01358d..7f3da15d61e 100644 --- a/libs/langchain_v1/langchain/agents/middleware/summarization.py +++ b/libs/langchain_v1/langchain/agents/middleware/summarization.py @@ -6,7 +6,6 @@ from langchain_core.messages import ( AIMessage, AnyMessage, MessageLikeRepresentation, - SystemMessage, ToolMessage, ) from langchain_core.messages.utils import count_tokens_approximately, trim_messages @@ -101,79 +100,48 @@ class SummarizationMiddleware(AgentMiddleware): ): return None - system_message, conversation_messages = self._split_system_message(messages) - cutoff_index = self._find_safe_cutoff(conversation_messages) + cutoff_index = self._find_safe_cutoff(messages) if cutoff_index <= 0: return None messages_to_summarize, preserved_messages = self._partition_messages( - system_message, conversation_messages, cutoff_index + messages, cutoff_index ) summary = self._create_summary(messages_to_summarize) - updated_system_message = self._build_updated_system_message(system_message, summary) + new_messages = self._build_new_messages(summary) return { "messages": [ RemoveMessage(id=REMOVE_ALL_MESSAGES), - updated_system_message, + *new_messages, *preserved_messages, ] } + def _build_new_messages(self, summary: str): + return [ + {"role": "user", "content": f"Here is a summary of the conversation to date:\n\n{summary}"} + ] + def _ensure_message_ids(self, messages: list[AnyMessage]) -> None: """Ensure all messages have unique IDs for the add_messages reducer.""" for msg in messages: if msg.id is None: msg.id = str(uuid.uuid4()) - def _split_system_message( - self, messages: list[AnyMessage] - ) -> tuple[SystemMessage | None, list[AnyMessage]]: - """Separate system message from conversation messages.""" - if messages and isinstance(messages[0], SystemMessage): - return messages[0], messages[1:] - return None, messages - def _partition_messages( self, - system_message: SystemMessage | None, conversation_messages: list[AnyMessage], cutoff_index: int, ) -> tuple[list[AnyMessage], list[AnyMessage]]: - """Partition messages into those to summarize and those to preserve. - - We include the system message so that we can capture previous summaries. - """ + """Partition messages into those to summarize and those to preserve.""" messages_to_summarize = conversation_messages[:cutoff_index] preserved_messages = conversation_messages[cutoff_index:] - if system_message is not None: - messages_to_summarize = [system_message, *messages_to_summarize] - return messages_to_summarize, preserved_messages - def _build_updated_system_message( - self, original_system_message: SystemMessage | None, summary: str - ) -> SystemMessage: - """Build new system message incorporating the summary.""" - if original_system_message is None: - original_content = "" - else: - content = cast("str", original_system_message.content) - original_content = content.split(self.summary_prefix)[0].strip() - - if original_content: - content = f"{original_content}\n{self.summary_prefix}\n{summary}" - else: - content = f"{self.summary_prefix}\n{summary}" - - return SystemMessage( - content=content, - id=original_system_message.id if original_system_message else str(uuid.uuid4()), - ) - def _find_safe_cutoff(self, messages: list[AnyMessage]) -> int: """Find safe cutoff point that preserves AI/Tool message pairs.