mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
core: add tool_call exclusion in filter_message (#30289)
Extend functionallity to allow to filter pairs of tool calls (ai + tool). --------- Co-authored-by: vbarda <vadym@langchain.dev>
This commit is contained in:
parent
673ec00030
commit
b75573e858
@ -404,6 +404,7 @@ def filter_messages(
|
||||
exclude_types: Optional[Sequence[Union[str, type[BaseMessage]]]] = None,
|
||||
include_ids: Optional[Sequence[str]] = None,
|
||||
exclude_ids: Optional[Sequence[str]] = None,
|
||||
exclude_tool_calls: Optional[Sequence[str] | bool] = None,
|
||||
) -> list[BaseMessage]:
|
||||
"""Filter messages based on name, type or id.
|
||||
|
||||
@ -419,6 +420,13 @@ def filter_messages(
|
||||
SystemMessage, HumanMessage, AIMessage, ...). Default is None.
|
||||
include_ids: Message IDs to include. Default is None.
|
||||
exclude_ids: Message IDs to exclude. Default is None.
|
||||
exclude_tool_calls: Tool call IDs to exclude. Default is None.
|
||||
Can be one of the following:
|
||||
- `True`: all AIMessages with tool calls and all ToolMessages will be excluded.
|
||||
- a sequence of tool call IDs to exclude:
|
||||
- ToolMessages with the corresponding tool call ID will be excluded.
|
||||
- The `tool_calls` in the AIMessage will be updated to exclude matching tool calls.
|
||||
If all tool_calls are filtered from an AIMessage, the whole message is excluded.
|
||||
|
||||
Returns:
|
||||
A list of Messages that meets at least one of the incl_* conditions and none
|
||||
@ -467,6 +475,43 @@ def filter_messages(
|
||||
else:
|
||||
pass
|
||||
|
||||
if exclude_tool_calls is True and (
|
||||
(isinstance(msg, AIMessage) and msg.tool_calls)
|
||||
or isinstance(msg, ToolMessage)
|
||||
):
|
||||
continue
|
||||
|
||||
if isinstance(exclude_tool_calls, (list, tuple, set)):
|
||||
if isinstance(msg, AIMessage) and msg.tool_calls:
|
||||
tool_calls = [
|
||||
tool_call
|
||||
for tool_call in msg.tool_calls
|
||||
if tool_call["id"] not in exclude_tool_calls
|
||||
]
|
||||
if not tool_calls:
|
||||
continue
|
||||
|
||||
content = msg.content
|
||||
# handle Anthropic content blocks
|
||||
if isinstance(msg.content, list):
|
||||
content = [
|
||||
content_block
|
||||
for content_block in msg.content
|
||||
if (
|
||||
not isinstance(content_block, dict)
|
||||
or content_block.get("type") != "tool_use"
|
||||
or content_block.get("id") not in exclude_tool_calls
|
||||
)
|
||||
]
|
||||
|
||||
msg = msg.model_copy(
|
||||
update={"tool_calls": tool_calls, "content": content}
|
||||
)
|
||||
elif (
|
||||
isinstance(msg, ToolMessage) and msg.tool_call_id in exclude_tool_calls
|
||||
):
|
||||
continue
|
||||
|
||||
# default to inclusion when no inclusion criteria given.
|
||||
if (
|
||||
not (include_types or include_ids or include_names)
|
||||
|
@ -165,6 +165,94 @@ def test_filter_message(filters: dict) -> None:
|
||||
assert messages == messages_model_copy
|
||||
|
||||
|
||||
def test_filter_message_exclude_tool_calls() -> None:
|
||||
tool_calls = [
|
||||
{"name": "foo", "id": "1", "args": {}, "type": "tool_call"},
|
||||
{"name": "bar", "id": "2", "args": {}, "type": "tool_call"},
|
||||
]
|
||||
messages = [
|
||||
HumanMessage("foo", name="blah", id="1"),
|
||||
AIMessage("foo-response", name="blah", id="2"),
|
||||
HumanMessage("bar", name="blur", id="3"),
|
||||
AIMessage(
|
||||
"bar-response",
|
||||
tool_calls=tool_calls,
|
||||
id="4",
|
||||
),
|
||||
ToolMessage("baz", tool_call_id="1", id="5"),
|
||||
ToolMessage("qux", tool_call_id="2", id="6"),
|
||||
]
|
||||
messages_model_copy = [m.model_copy(deep=True) for m in messages]
|
||||
expected = messages[:3]
|
||||
|
||||
# test excluding all tool calls
|
||||
actual = filter_messages(messages, exclude_tool_calls=True)
|
||||
assert expected == actual
|
||||
|
||||
# test explicitly excluding all tool calls
|
||||
actual = filter_messages(messages, exclude_tool_calls={"1", "2"})
|
||||
assert expected == actual
|
||||
|
||||
# test excluding a specific tool call
|
||||
expected = messages[:5]
|
||||
expected[3] = expected[3].model_copy(update={"tool_calls": [tool_calls[0]]})
|
||||
actual = filter_messages(messages, exclude_tool_calls=["2"])
|
||||
assert expected == actual
|
||||
|
||||
# assert that we didn't mutate the original messages
|
||||
assert messages == messages_model_copy
|
||||
|
||||
|
||||
def test_filter_message_exclude_tool_calls_content_blocks() -> None:
|
||||
tool_calls = [
|
||||
{"name": "foo", "id": "1", "args": {}, "type": "tool_call"},
|
||||
{"name": "bar", "id": "2", "args": {}, "type": "tool_call"},
|
||||
]
|
||||
messages = [
|
||||
HumanMessage("foo", name="blah", id="1"),
|
||||
AIMessage("foo-response", name="blah", id="2"),
|
||||
HumanMessage("bar", name="blur", id="3"),
|
||||
AIMessage(
|
||||
[
|
||||
{"text": "bar-response", "type": "text"},
|
||||
{"name": "foo", "type": "tool_use", "id": "1"},
|
||||
{"name": "bar", "type": "tool_use", "id": "2"},
|
||||
],
|
||||
tool_calls=tool_calls,
|
||||
id="4",
|
||||
),
|
||||
ToolMessage("baz", tool_call_id="1", id="5"),
|
||||
ToolMessage("qux", tool_call_id="2", id="6"),
|
||||
]
|
||||
messages_model_copy = [m.model_copy(deep=True) for m in messages]
|
||||
expected = messages[:3]
|
||||
|
||||
# test excluding all tool calls
|
||||
actual = filter_messages(messages, exclude_tool_calls=True)
|
||||
assert expected == actual
|
||||
|
||||
# test explicitly excluding all tool calls
|
||||
actual = filter_messages(messages, exclude_tool_calls={"1", "2"})
|
||||
assert expected == actual
|
||||
|
||||
# test excluding a specific tool call
|
||||
expected = messages[:4] + messages[-1:]
|
||||
expected[3] = expected[3].model_copy(
|
||||
update={
|
||||
"tool_calls": [tool_calls[1]],
|
||||
"content": [
|
||||
{"text": "bar-response", "type": "text"},
|
||||
{"name": "bar", "type": "tool_use", "id": "2"},
|
||||
],
|
||||
}
|
||||
)
|
||||
actual = filter_messages(messages, exclude_tool_calls=["1"])
|
||||
assert expected == actual
|
||||
|
||||
# assert that we didn't mutate the original messages
|
||||
assert messages == messages_model_copy
|
||||
|
||||
|
||||
_MESSAGES_TO_TRIM = [
|
||||
SystemMessage("This is a 4 token text."),
|
||||
HumanMessage("This is a 4 token text.", id="first"),
|
||||
|
Loading…
Reference in New Issue
Block a user