anthropic[patch]: allow tool call mutation (#22130)

If tool_use blocks and tool_calls with overlapping IDs are present,
prefer the values of the tool_calls. Allows for mutating AIMessages just
via tool_calls.
This commit is contained in:
Bagatur 2024-05-24 08:18:14 -07:00 committed by GitHub
parent c838de5027
commit baa3c975cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 69 additions and 48 deletions

View File

@ -172,25 +172,29 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
content = [] content = []
for item in message.content: for item in message.content:
if isinstance(item, str): if isinstance(item, str):
content.append( content.append({"type": "text", "text": item})
{
"type": "text",
"text": item,
}
)
elif isinstance(item, dict): elif isinstance(item, dict):
if "type" not in item: if "type" not in item:
raise ValueError("Dict content item must have a type key") raise ValueError("Dict content item must have a type key")
elif item["type"] == "image_url": elif item["type"] == "image_url":
# convert format # convert format
source = _format_image(item["image_url"]["url"]) source = _format_image(item["image_url"]["url"])
content.append( content.append({"type": "image", "source": source})
{
"type": "image",
"source": source,
}
)
elif item["type"] == "tool_use": elif item["type"] == "tool_use":
# If a tool_call with the same id as a tool_use content block
# exists, the tool_call is preferred.
if isinstance(message, AIMessage) and item["id"] in [
tc["id"] for tc in message.tool_calls
]:
overlapping = [
tc
for tc in message.tool_calls
if tc["id"] == item["id"]
]
content.extend(
_lc_tool_calls_to_anthropic_tool_use_blocks(overlapping)
)
else:
item.pop("text", None) item.pop("text", None)
content.append(item) content.append(item)
elif item["type"] == "text": elif item["type"] == "text":
@ -199,23 +203,14 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
# accepted. # accepted.
# https://github.com/anthropics/anthropic-sdk-python/issues/461 # https://github.com/anthropics/anthropic-sdk-python/issues/461
if text.strip(): if text.strip():
content.append( content.append({"type": "text", "text": text})
{
"type": "text",
"text": text,
}
)
else: else:
content.append(item) content.append(item)
else: else:
raise ValueError( raise ValueError(
f"Content items must be str or dict, instead was: {type(item)}" f"Content items must be str or dict, instead was: {type(item)}"
) )
elif ( elif isinstance(message, AIMessage) and message.tool_calls:
isinstance(message, AIMessage)
and not isinstance(message.content, list)
and message.tool_calls
):
content = ( content = (
[] []
if not message.content if not message.content
@ -228,12 +223,7 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
else: else:
content = message.content content = message.content
formatted_messages.append( formatted_messages.append({"role": role, "content": content})
{
"role": role,
"content": content,
}
)
return system, formatted_messages return system, formatted_messages

View File

@ -352,10 +352,7 @@ def test__format_messages_with_str_content_and_tool_calls() -> None:
"thought", "thought",
tool_calls=[{"name": "bar", "id": "1", "args": {"baz": "buzz"}}], tool_calls=[{"name": "bar", "id": "1", "args": {"baz": "buzz"}}],
) )
tool = ToolMessage( tool = ToolMessage("blurb", tool_call_id="1")
"blurb",
tool_call_id="1",
)
messages = [system, human, ai, tool] messages = [system, human, ai, tool]
expected = ( expected = (
"fuzz", "fuzz",
@ -364,10 +361,7 @@ def test__format_messages_with_str_content_and_tool_calls() -> None:
{ {
"role": "assistant", "role": "assistant",
"content": [ "content": [
{ {"type": "text", "text": "thought"},
"type": "text",
"text": "thought",
},
{ {
"type": "tool_use", "type": "tool_use",
"name": "bar", "name": "bar",
@ -394,12 +388,7 @@ def test__format_messages_with_list_content_and_tool_calls() -> None:
# If content and tool_calls are specified and content is a list, then content is # If content and tool_calls are specified and content is a list, then content is
# preferred. # preferred.
ai = AIMessage( ai = AIMessage(
[ [{"type": "text", "text": "thought"}],
{
"type": "text",
"text": "thought",
}
],
tool_calls=[{"name": "bar", "id": "1", "args": {"baz": "buzz"}}], tool_calls=[{"name": "bar", "id": "1", "args": {"baz": "buzz"}}],
) )
tool = ToolMessage( tool = ToolMessage(
@ -413,11 +402,53 @@ def test__format_messages_with_list_content_and_tool_calls() -> None:
{"role": "user", "content": "foo"}, {"role": "user", "content": "foo"},
{ {
"role": "assistant", "role": "assistant",
"content": [ "content": [{"type": "text", "text": "thought"}],
},
{ {
"type": "text", "role": "user",
"text": "thought", "content": [
} {"type": "tool_result", "content": "blurb", "tool_use_id": "1"}
],
},
],
)
actual = _format_messages(messages)
assert expected == actual
def test__format_messages_with_tool_use_blocks_and_tool_calls() -> None:
"""Show that tool_calls are preferred to tool_use blocks when both have same id."""
system = SystemMessage("fuzz")
human = HumanMessage("foo")
# NOTE: tool_use block in contents and tool_calls have different arguments.
ai = AIMessage(
[
{"type": "text", "text": "thought"},
{
"type": "tool_use",
"name": "bar",
"id": "1",
"input": {"baz": "NOT_BUZZ"},
},
],
tool_calls=[{"name": "bar", "id": "1", "args": {"baz": "BUZZ"}}],
)
tool = ToolMessage("blurb", tool_call_id="1")
messages = [system, human, ai, tool]
expected = (
"fuzz",
[
{"role": "user", "content": "foo"},
{
"role": "assistant",
"content": [
{"type": "text", "text": "thought"},
{
"type": "tool_use",
"name": "bar",
"id": "1",
"input": {"baz": "BUZZ"}, # tool_calls value preferred.
},
], ],
}, },
{ {