diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py index 011c1d2ed46..ccaf1f04ef0 100644 --- a/libs/core/langchain_core/messages/ai.py +++ b/libs/core/langchain_core/messages/ai.py @@ -545,6 +545,9 @@ class AIMessageChunk(AIMessage, BaseMessageChunk): and call_id in id_to_tc ): self.content[idx] = cast("dict[str, Any]", id_to_tc[call_id]) + if "extras" in block: + # mypy does not account for instance check for dict above + self.content[idx]["extras"] = block["extras"] # type: ignore[index] return self diff --git a/libs/core/tests/unit_tests/messages/test_ai.py b/libs/core/tests/unit_tests/messages/test_ai.py index a65033f82c1..6ced0a59e08 100644 --- a/libs/core/tests/unit_tests/messages/test_ai.py +++ b/libs/core/tests/unit_tests/messages/test_ai.py @@ -358,6 +358,8 @@ def test_content_blocks() -> None: # test v1 content chunk_1.content = cast("str | list[str | dict]", chunk_1.content_blocks) + assert len(chunk_1.content) == 1 + chunk_1.content[0]["extras"] = {"baz": "qux"} # type: ignore[index] chunk_1.response_metadata["output_version"] = "v1" chunk_2.content = cast("str | list[str | dict]", chunk_2.content_blocks) @@ -368,6 +370,7 @@ def test_content_blocks() -> None: "name": "foo", "args": {"foo": "bar"}, "id": "abc_123", + "extras": {"baz": "qux"}, } ]