Compare commits

...

2 Commits

Author SHA1 Message Date
Bagatur
6c46c48ae8 fmt 2024-06-21 11:12:34 -07:00
Bagatur
ec74a22bb1 rfc: tool call / message filter 2024-06-21 11:09:12 -07:00

View File

@@ -310,6 +310,8 @@ def filter_messages(
exclude_types: Optional[Sequence[Union[str, Type[BaseMessage]]]] = None,
include_ids: Optional[Sequence[str]] = None,
exclude_ids: Optional[Sequence[str]] = None,
include_tool_names: Optional[Sequence[str]] = None,
exclude_tool_names: Optional[Sequence[str]] = None,
) -> List[BaseMessage]:
"""Filter messages based on name, type or id.
@@ -325,6 +327,10 @@ def filter_messages(
SystemMessage, HumanMessage, AIMessage, ...).
include_ids: Message IDs to include.
exclude_ids: Message IDs to exclude.
include_tool_names: Remove any ToolCalls (on AIMessages) and ToolMessages that
don't have one of these tool names.
exclude_tool_names: Remove any ToolCalls (on AIMessages) and ToolMessage that
have one of these tool names.
Returns:
A list of Messages that meets at least one of the incl_* conditions and none
@@ -363,6 +369,19 @@ def filter_messages(
""" # noqa: E501
messages = convert_to_messages(messages)
filtered: List[BaseMessage] = []
include_tool_names = include_tool_names or ()
exclude_tool_names = exclude_tool_names or ()
def exclude_tool_name_condition(name: Optional[str]) -> bool:
if name is None:
return False
elif name in exclude_names:
return True
elif include_tool_names and name not in include_tool_names:
return True
else:
return False
for msg in messages:
if exclude_names and msg.name in exclude_names:
continue
@@ -373,6 +392,71 @@ def filter_messages(
else:
pass
if include_tool_names or exclude_tool_names:
if (
isinstance(msg, AIMessage)
and msg.tool_calls
or (
isinstance(msg.content, list)
and any(
block.get("type") == "tool_use"
for block in msg.content
if isinstance(block, dict)
)
)
):
msg = msg.copy(deep=True)
filtered_tool_calls = []
for tc in msg.tool_calls:
if exclude_tool_name_condition(tc["name"]):
continue
else:
filtered_tool_calls.append(tc)
msg.tool_calls = filtered_tool_calls
if isinstance(msg.content, list):
filtered_content = []
for block in msg.content:
if (
isinstance(block, dict)
and block.get("type") == "tool_use"
and exclude_tool_name_condition(block.get("name"))
):
continue
else:
filtered_content.append(block)
msg.content = filtered_content
elif isinstance(msg, ToolMessage):
tool_name = _tool_call_id_to_tool_name(messages, msg.tool_call_id)
if exclude_tool_name_condition(tool_name):
continue
elif (
isinstance(msg, HumanMessage)
and isinstance(msg.content, list)
and any(
block.get("type") == "tool_result"
for block in msg.content
if isinstance(block, dict)
)
):
msg = msg.copy(deep=True)
filtered_content = []
for block in msg.content:
if (
isinstance(block, dict)
and block.get("type") == "tool_result"
and exclude_tool_name_condition(
_tool_call_id_to_tool_name(
messages, block.get("tool_use_id")
)
)
):
continue
else:
filtered_content.append(block)
msg.content = filtered_content
else:
pass
# default to inclusion when no inclusion criteria given.
if not (include_types or include_ids or include_names):
filtered.append(msg)
@@ -388,6 +472,19 @@ def filter_messages(
return filtered
def _tool_call_id_to_tool_name(
messages: List[BaseMessage], tool_call_id: str
) -> Optional[str]:
tool_calls = [
tc
for msg in messages
if isinstance(msg, AIMessage)
for tc in msg.tool_calls
if tc["id"] == tool_call_id
]
return tool_calls[0]["name"] if tool_calls else None
@_runnable_support
def merge_message_runs(
messages: Sequence[MessageLikeRepresentation],