mirror of
https://github.com/hwchase17/langchain.git
synced 2025-10-24 20:20:50 +00:00
fix(core): propagate extras when aggregating tool calls in v1 content (#33494)
This commit is contained in:
@@ -545,6 +545,9 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
|||||||
and call_id in id_to_tc
|
and call_id in id_to_tc
|
||||||
):
|
):
|
||||||
self.content[idx] = cast("dict[str, Any]", id_to_tc[call_id])
|
self.content[idx] = cast("dict[str, Any]", id_to_tc[call_id])
|
||||||
|
if "extras" in block:
|
||||||
|
# mypy does not account for instance check for dict above
|
||||||
|
self.content[idx]["extras"] = block["extras"] # type: ignore[index]
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|||||||
@@ -358,6 +358,8 @@ def test_content_blocks() -> None:
|
|||||||
|
|
||||||
# test v1 content
|
# test v1 content
|
||||||
chunk_1.content = cast("str | list[str | dict]", chunk_1.content_blocks)
|
chunk_1.content = cast("str | list[str | dict]", chunk_1.content_blocks)
|
||||||
|
assert len(chunk_1.content) == 1
|
||||||
|
chunk_1.content[0]["extras"] = {"baz": "qux"} # type: ignore[index]
|
||||||
chunk_1.response_metadata["output_version"] = "v1"
|
chunk_1.response_metadata["output_version"] = "v1"
|
||||||
chunk_2.content = cast("str | list[str | dict]", chunk_2.content_blocks)
|
chunk_2.content = cast("str | list[str | dict]", chunk_2.content_blocks)
|
||||||
|
|
||||||
@@ -368,6 +370,7 @@ def test_content_blocks() -> None:
|
|||||||
"name": "foo",
|
"name": "foo",
|
||||||
"args": {"foo": "bar"},
|
"args": {"foo": "bar"},
|
||||||
"id": "abc_123",
|
"id": "abc_123",
|
||||||
|
"extras": {"baz": "qux"},
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user