diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index c5303fbed5f..f34c289cd13 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -404,6 +404,7 @@ def filter_messages( exclude_types: Optional[Sequence[Union[str, type[BaseMessage]]]] = None, include_ids: Optional[Sequence[str]] = None, exclude_ids: Optional[Sequence[str]] = None, + exclude_tool_calls: Optional[Sequence[str] | bool] = None, ) -> list[BaseMessage]: """Filter messages based on name, type or id. @@ -419,6 +420,13 @@ def filter_messages( SystemMessage, HumanMessage, AIMessage, ...). Default is None. include_ids: Message IDs to include. Default is None. exclude_ids: Message IDs to exclude. Default is None. + exclude_tool_calls: Tool call IDs to exclude. Default is None. + Can be one of the following: + - `True`: all AIMessages with tool calls and all ToolMessages will be excluded. + - a sequence of tool call IDs to exclude: + - ToolMessages with the corresponding tool call ID will be excluded. + - The `tool_calls` in the AIMessage will be updated to exclude matching tool calls. + If all tool_calls are filtered from an AIMessage, the whole message is excluded. Returns: A list of Messages that meets at least one of the incl_* conditions and none @@ -467,6 +475,43 @@ def filter_messages( else: pass + if exclude_tool_calls is True and ( + (isinstance(msg, AIMessage) and msg.tool_calls) + or isinstance(msg, ToolMessage) + ): + continue + + if isinstance(exclude_tool_calls, (list, tuple, set)): + if isinstance(msg, AIMessage) and msg.tool_calls: + tool_calls = [ + tool_call + for tool_call in msg.tool_calls + if tool_call["id"] not in exclude_tool_calls + ] + if not tool_calls: + continue + + content = msg.content + # handle Anthropic content blocks + if isinstance(msg.content, list): + content = [ + content_block + for content_block in msg.content + if ( + not isinstance(content_block, dict) + or content_block.get("type") != "tool_use" + or content_block.get("id") not in exclude_tool_calls + ) + ] + + msg = msg.model_copy( + update={"tool_calls": tool_calls, "content": content} + ) + elif ( + isinstance(msg, ToolMessage) and msg.tool_call_id in exclude_tool_calls + ): + continue + # default to inclusion when no inclusion criteria given. if ( not (include_types or include_ids or include_names) diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index 6e3c608939d..b199a7443bf 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -165,6 +165,94 @@ def test_filter_message(filters: dict) -> None: assert messages == messages_model_copy +def test_filter_message_exclude_tool_calls() -> None: + tool_calls = [ + {"name": "foo", "id": "1", "args": {}, "type": "tool_call"}, + {"name": "bar", "id": "2", "args": {}, "type": "tool_call"}, + ] + messages = [ + HumanMessage("foo", name="blah", id="1"), + AIMessage("foo-response", name="blah", id="2"), + HumanMessage("bar", name="blur", id="3"), + AIMessage( + "bar-response", + tool_calls=tool_calls, + id="4", + ), + ToolMessage("baz", tool_call_id="1", id="5"), + ToolMessage("qux", tool_call_id="2", id="6"), + ] + messages_model_copy = [m.model_copy(deep=True) for m in messages] + expected = messages[:3] + + # test excluding all tool calls + actual = filter_messages(messages, exclude_tool_calls=True) + assert expected == actual + + # test explicitly excluding all tool calls + actual = filter_messages(messages, exclude_tool_calls={"1", "2"}) + assert expected == actual + + # test excluding a specific tool call + expected = messages[:5] + expected[3] = expected[3].model_copy(update={"tool_calls": [tool_calls[0]]}) + actual = filter_messages(messages, exclude_tool_calls=["2"]) + assert expected == actual + + # assert that we didn't mutate the original messages + assert messages == messages_model_copy + + +def test_filter_message_exclude_tool_calls_content_blocks() -> None: + tool_calls = [ + {"name": "foo", "id": "1", "args": {}, "type": "tool_call"}, + {"name": "bar", "id": "2", "args": {}, "type": "tool_call"}, + ] + messages = [ + HumanMessage("foo", name="blah", id="1"), + AIMessage("foo-response", name="blah", id="2"), + HumanMessage("bar", name="blur", id="3"), + AIMessage( + [ + {"text": "bar-response", "type": "text"}, + {"name": "foo", "type": "tool_use", "id": "1"}, + {"name": "bar", "type": "tool_use", "id": "2"}, + ], + tool_calls=tool_calls, + id="4", + ), + ToolMessage("baz", tool_call_id="1", id="5"), + ToolMessage("qux", tool_call_id="2", id="6"), + ] + messages_model_copy = [m.model_copy(deep=True) for m in messages] + expected = messages[:3] + + # test excluding all tool calls + actual = filter_messages(messages, exclude_tool_calls=True) + assert expected == actual + + # test explicitly excluding all tool calls + actual = filter_messages(messages, exclude_tool_calls={"1", "2"}) + assert expected == actual + + # test excluding a specific tool call + expected = messages[:4] + messages[-1:] + expected[3] = expected[3].model_copy( + update={ + "tool_calls": [tool_calls[1]], + "content": [ + {"text": "bar-response", "type": "text"}, + {"name": "bar", "type": "tool_use", "id": "2"}, + ], + } + ) + actual = filter_messages(messages, exclude_tool_calls=["1"]) + assert expected == actual + + # assert that we didn't mutate the original messages + assert messages == messages_model_copy + + _MESSAGES_TO_TRIM = [ SystemMessage("This is a 4 token text."), HumanMessage("This is a 4 token text.", id="first"),