mirror of
https://github.com/hwchase17/langchain.git
synced 2025-10-23 11:16:58 +00:00
fix(core): propagate extras when aggregating tool calls in v1 content (#33494)
This commit is contained in:
@@ -545,6 +545,9 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
and call_id in id_to_tc
|
||||
):
|
||||
self.content[idx] = cast("dict[str, Any]", id_to_tc[call_id])
|
||||
if "extras" in block:
|
||||
# mypy does not account for instance check for dict above
|
||||
self.content[idx]["extras"] = block["extras"] # type: ignore[index]
|
||||
|
||||
return self
|
||||
|
||||
|
@@ -358,6 +358,8 @@ def test_content_blocks() -> None:
|
||||
|
||||
# test v1 content
|
||||
chunk_1.content = cast("str | list[str | dict]", chunk_1.content_blocks)
|
||||
assert len(chunk_1.content) == 1
|
||||
chunk_1.content[0]["extras"] = {"baz": "qux"} # type: ignore[index]
|
||||
chunk_1.response_metadata["output_version"] = "v1"
|
||||
chunk_2.content = cast("str | list[str | dict]", chunk_2.content_blocks)
|
||||
|
||||
@@ -368,6 +370,7 @@ def test_content_blocks() -> None:
|
||||
"name": "foo",
|
||||
"args": {"foo": "bar"},
|
||||
"id": "abc_123",
|
||||
"extras": {"baz": "qux"},
|
||||
}
|
||||
]
|
||||
|
||||
|
Reference in New Issue
Block a user