mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-25 01:16:55 +00:00
core: populate tool_calls when initializing AIMessage via content_blocks
This commit is contained in:
@@ -203,6 +203,13 @@ class AIMessage(BaseMessage):
|
||||
) -> None:
|
||||
"""Specify content as a positional arg or content_blocks for typing support."""
|
||||
if content_blocks is not None:
|
||||
# If there are tool calls in content_blocks, but not in tool_calls, add them
|
||||
content_tool_calls = [
|
||||
block for block in content_blocks if block.get("type") == "tool_call"
|
||||
]
|
||||
if content_tool_calls and "tool_calls" not in kwargs:
|
||||
kwargs["tool_calls"] = content_tool_calls
|
||||
|
||||
super().__init__(
|
||||
content=cast("Union[str, list[Union[str, dict]]]", content_blocks),
|
||||
**kwargs,
|
||||
@@ -273,7 +280,9 @@ class AIMessage(BaseMessage):
|
||||
# Ensure "type" is properly set on all tool call-like dicts.
|
||||
if tool_calls := values.get("tool_calls"):
|
||||
values["tool_calls"] = [
|
||||
create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
|
||||
create_tool_call(
|
||||
**{k: v for k, v in tc.items() if k not in ("type", "extras")}
|
||||
)
|
||||
for tc in tool_calls
|
||||
]
|
||||
if invalid_tool_calls := values.get("invalid_tool_calls"):
|
||||
|
||||
@@ -253,7 +253,7 @@ def test_content_blocks() -> None:
|
||||
"id": "abc_123",
|
||||
},
|
||||
]
|
||||
missing_tool_call = {
|
||||
missing_tool_call: types.ToolCall = {
|
||||
"type": "tool_call",
|
||||
"name": "bar",
|
||||
"args": {"c": "d"},
|
||||
@@ -267,3 +267,20 @@ def test_content_blocks() -> None:
|
||||
],
|
||||
)
|
||||
assert message.content_blocks == [*standard_content, missing_tool_call]
|
||||
|
||||
# Check we auto-populate tool_calls
|
||||
standard_content = [
|
||||
{"type": "text", "text": "foo"},
|
||||
{
|
||||
"type": "tool_call",
|
||||
"name": "foo",
|
||||
"args": {"a": "b"},
|
||||
"id": "abc_123",
|
||||
},
|
||||
missing_tool_call,
|
||||
]
|
||||
message = AIMessage(content_blocks=standard_content)
|
||||
assert message.tool_calls == [
|
||||
{"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"},
|
||||
missing_tool_call,
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user