core: add tool_call exclusion in filter_message (#30289)

Extend functionality to allow filtering pairs of tool calls (AI message +
tool message).

---------

Co-authored-by: vbarda <vadym@langchain.dev>
This commit is contained in:
Adrián Panella 2025-03-21 17:05:29 -06:00 committed by GitHub
parent 673ec00030
commit b75573e858
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 133 additions and 0 deletions

View File

@ -404,6 +404,7 @@ def filter_messages(
exclude_types: Optional[Sequence[Union[str, type[BaseMessage]]]] = None,
include_ids: Optional[Sequence[str]] = None,
exclude_ids: Optional[Sequence[str]] = None,
exclude_tool_calls: Optional[Sequence[str] | bool] = None,
) -> list[BaseMessage]:
"""Filter messages based on name, type or id.
@ -419,6 +420,13 @@ def filter_messages(
SystemMessage, HumanMessage, AIMessage, ...). Default is None.
include_ids: Message IDs to include. Default is None.
exclude_ids: Message IDs to exclude. Default is None.
exclude_tool_calls: Tool call IDs to exclude. Default is None.
Can be one of the following:
- `True`: all AIMessages with tool calls and all ToolMessages will be excluded.
- a sequence of tool call IDs to exclude:
- ToolMessages with the corresponding tool call ID will be excluded.
- The `tool_calls` in the AIMessage will be updated to exclude matching tool calls.
If all tool_calls are filtered from an AIMessage, the whole message is excluded.
Returns:
A list of Messages that meets at least one of the incl_* conditions and none
@ -467,6 +475,43 @@ def filter_messages(
else:
pass
if exclude_tool_calls is True and (
(isinstance(msg, AIMessage) and msg.tool_calls)
or isinstance(msg, ToolMessage)
):
continue
if isinstance(exclude_tool_calls, (list, tuple, set)):
if isinstance(msg, AIMessage) and msg.tool_calls:
tool_calls = [
tool_call
for tool_call in msg.tool_calls
if tool_call["id"] not in exclude_tool_calls
]
if not tool_calls:
continue
content = msg.content
# handle Anthropic content blocks
if isinstance(msg.content, list):
content = [
content_block
for content_block in msg.content
if (
not isinstance(content_block, dict)
or content_block.get("type") != "tool_use"
or content_block.get("id") not in exclude_tool_calls
)
]
msg = msg.model_copy(
update={"tool_calls": tool_calls, "content": content}
)
elif (
isinstance(msg, ToolMessage) and msg.tool_call_id in exclude_tool_calls
):
continue
# default to inclusion when no inclusion criteria given.
if (
not (include_types or include_ids or include_names)

View File

@ -165,6 +165,94 @@ def test_filter_message(filters: dict) -> None:
assert messages == messages_model_copy
def test_filter_message_exclude_tool_calls() -> None:
    """Exercise ``exclude_tool_calls`` on an AI message with string content.

    The conversation holds one AIMessage carrying two tool calls ("1" and
    "2") followed by one ToolMessage per call. The test asserts that:

    * ``True`` and the explicit set ``{"1", "2"}`` both leave only the first
      three messages;
    * ``["2"]`` keeps the AIMessage with only tool call "1" and keeps the
      ToolMessage for call "1", dropping the one for call "2";
    * the input messages still equal the deep copies taken beforehand.
    """
    foo_call = {"name": "foo", "id": "1", "args": {}, "type": "tool_call"}
    bar_call = {"name": "bar", "id": "2", "args": {}, "type": "tool_call"}
    history = [
        HumanMessage("foo", name="blah", id="1"),
        AIMessage("foo-response", name="blah", id="2"),
        HumanMessage("bar", name="blur", id="3"),
        AIMessage("bar-response", tool_calls=[foo_call, bar_call], id="4"),
        ToolMessage("baz", tool_call_id="1", id="5"),
        ToolMessage("qux", tool_call_id="2", id="6"),
    ]
    # Snapshot taken before any filtering so mutation can be detected later.
    snapshot = [original.model_copy(deep=True) for original in history]

    # Excluding every tool call -- via the flag, then via explicit IDs --
    # must leave only the messages that precede the tool-calling AIMessage.
    without_tool_traffic = history[:3]
    for exclusion in (True, {"1", "2"}):
        filtered = filter_messages(history, exclude_tool_calls=exclusion)
        assert without_tool_traffic == filtered

    # Excluding only call "2": the AIMessage keeps call "1" and the matching
    # ToolMessage (id="5") survives, while the one for call "2" is dropped.
    trimmed_ai = history[3].model_copy(update={"tool_calls": [foo_call]})
    partially_filtered = [*history[:3], trimmed_ai, history[4]]
    filtered = filter_messages(history, exclude_tool_calls=["2"])
    assert partially_filtered == filtered

    # assert that we didn't mutate the original messages
    assert history == snapshot
def test_filter_message_exclude_tool_calls_content_blocks() -> None:
    """Exercise ``exclude_tool_calls`` on an AI message with list content.

    Same conversation shape as the string-content test, except that the
    tool-calling AIMessage's content is a list holding one text block and
    one ``tool_use`` block per tool call. The test asserts that:

    * ``True`` and the explicit set ``{"1", "2"}`` both leave only the first
      three messages;
    * ``["1"]`` removes tool call "1" from ``tool_calls``, removes the
      ``tool_use`` block with id "1" from the content, and drops the
      ToolMessage for call "1" while keeping the one for call "2";
    * the input messages still equal the deep copies taken beforehand.
    """
    foo_call = {"name": "foo", "id": "1", "args": {}, "type": "tool_call"}
    bar_call = {"name": "bar", "id": "2", "args": {}, "type": "tool_call"}
    ai_blocks = [
        {"text": "bar-response", "type": "text"},
        {"name": "foo", "type": "tool_use", "id": "1"},
        {"name": "bar", "type": "tool_use", "id": "2"},
    ]
    history = [
        HumanMessage("foo", name="blah", id="1"),
        AIMessage("foo-response", name="blah", id="2"),
        HumanMessage("bar", name="blur", id="3"),
        AIMessage(ai_blocks, tool_calls=[foo_call, bar_call], id="4"),
        ToolMessage("baz", tool_call_id="1", id="5"),
        ToolMessage("qux", tool_call_id="2", id="6"),
    ]
    # Snapshot taken before any filtering so mutation can be detected later.
    snapshot = [original.model_copy(deep=True) for original in history]

    # Excluding every tool call -- via the flag, then via explicit IDs --
    # must leave only the messages that precede the tool-calling AIMessage.
    without_tool_traffic = history[:3]
    for exclusion in (True, {"1", "2"}):
        filtered = filter_messages(history, exclude_tool_calls=exclusion)
        assert without_tool_traffic == filtered

    # Excluding only call "1": both the tool_calls entry and the matching
    # tool_use content block must go; the last ToolMessage (call "2") stays.
    trimmed_ai = history[3].model_copy(
        update={
            "tool_calls": [bar_call],
            "content": [
                {"text": "bar-response", "type": "text"},
                {"name": "bar", "type": "tool_use", "id": "2"},
            ],
        }
    )
    partially_filtered = [*history[:3], trimmed_ai, history[-1]]
    filtered = filter_messages(history, exclude_tool_calls=["1"])
    assert partially_filtered == filtered

    # assert that we didn't mutate the original messages
    assert history == snapshot
_MESSAGES_TO_TRIM = [
SystemMessage("This is a 4 token text."),
HumanMessage("This is a 4 token text.", id="first"),