From f42d80ca1c31dd5ee598dcb53d2d67d33bbe9d2d Mon Sep 17 00:00:00 2001 From: Nick Hollon Date: Thu, 14 May 2026 11:19:45 -0700 Subject: [PATCH] fix(core): preserve chunk `additional_kwargs` across v3 stream assembly (#37435) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The v3 streaming path drops `additional_kwargs` from per-chunk `AIMessageChunk`s during assembly: `chunks_to_events` emits no event field for them, and `ChatModelStream._assemble_message` constructs the final `AIMessage` without an `additional_kwargs` argument. Non-streaming `ainvoke` returns the provider message unchanged, so streaming and non-streaming diverge for any provider that uses `additional_kwargs` to carry data outside the typed protocol blocks. ## How this surfaces The concrete failure mode is Gemini's `__gemini_function_call_thought_signatures__` — a per-tool-call signature blob the Google GenAI integration places in `additional_kwargs`, keyed by `tool_call_id`. Gemini requires that signature on follow-up turns to replay the prior thought trace; without it, multi-turn streaming flows lose thought continuity (and may regenerate thinking, charging additional reasoning tokens, or in some cases refuse). Other providers that use `additional_kwargs` (e.g. older `function_call` accumulators, custom routing metadata) hit the same gap; the fix is intentionally provider-agnostic. ## Fix Provider-agnostic, two seams: - `_compat_bridge` accumulates `msg.additional_kwargs` across chunks with `merge_dicts` (matching `AIMessageChunk`'s own merge semantics for fields that accumulate, like `function_call`) and emits the merged dict on the `message-finish` event as an off-spec extension. The bridge already uses one such extension (`metadata` on `MessageFinishData`); this PR follows the same pattern for `additional_kwargs`. - `ChatModelStream._finish` reads the new field; `_assemble_message` threads it onto the final `AIMessage` only when non-empty, preserving today's behavior of leaving `additional_kwargs` empty when no provider data needs to ride on it. --- .../language_models/_compat_bridge.py | 33 ++++ .../language_models/chat_model_stream.py | 29 +++- .../language_models/test_compat_bridge.py | 148 ++++++++++++++++++ 3 files changed, 202 insertions(+), 8 deletions(-) diff --git a/libs/core/langchain_core/language_models/_compat_bridge.py b/libs/core/langchain_core/language_models/_compat_bridge.py index 527bd652d4b..43490f0717c 100644 --- a/libs/core/langchain_core/language_models/_compat_bridge.py +++ b/libs/core/langchain_core/language_models/_compat_bridge.py @@ -57,6 +57,7 @@ from langchain_protocol.protocol import ( ) from langchain_core.messages import AIMessageChunk, BaseMessage +from langchain_core.utils._merge import merge_dicts if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterator @@ -525,6 +526,7 @@ def _build_message_finish( *, usage: dict[str, Any] | None, response_metadata: dict[str, Any] | None, + additional_kwargs: dict[str, Any] | None = None, ) -> MessageFinishData: # Protocol 0.0.9 removed the top-level `reason` field from # `MessageFinishData`; the provider's raw `finish_reason` / @@ -536,6 +538,16 @@ def _build_message_finish( finish_data["usage"] = usage_info if response_metadata: finish_data["metadata"] = dict(response_metadata) + # `additional_kwargs` is an off-spec extension on the message-finish + # event (parallel to `metadata`, which `MessageFinishData` also doesn't + # formally declare but the consumer reads). It carries provider-side + # kwargs that don't map onto a typed protocol field — notably Gemini's + # `__gemini_function_call_thought_signatures__`, which the model + # requires on follow-up turns to replay prior thinking. Without this, + # streaming-assembled messages would silently drop data that + # `ainvoke` preserves, breaking multi-turn streaming flows. + if additional_kwargs: + finish_data["additional_kwargs"] = dict(additional_kwargs) return cast("MessageFinishData", finish_data) @@ -582,6 +594,7 @@ def chunks_to_events( next_wire_idx = 0 usage: dict[str, Any] | None = None response_metadata: dict[str, Any] = {} + additional_kwargs: dict[str, Any] = {} for chunk in chunks: msg = chunk.message @@ -604,6 +617,17 @@ def chunks_to_events( if merged_rm: response_metadata.update(merged_rm) + # Carry chunks' `additional_kwargs` through to the assembled + # message. Provider-side fields that don't map onto a typed + # protocol block (e.g. Gemini's per-tool-call thought signatures) + # live here on non-streaming `ainvoke` results; dropping them on + # the streaming path silently diverges multi-turn behavior. Use + # `merge_dicts` because the same key can arrive in pieces across + # chunks (e.g. an accumulating `function_call`), matching how + # `AIMessageChunk` merges itself. + if msg.additional_kwargs: + additional_kwargs = merge_dicts(additional_kwargs, msg.additional_kwargs) + if not started: started = True yield _build_message_start(msg, message_id) @@ -646,6 +670,7 @@ def chunks_to_events( yield _build_message_finish( usage=usage, response_metadata=response_metadata, + additional_kwargs=additional_kwargs, ) @@ -660,6 +685,7 @@ async def achunks_to_events( next_wire_idx = 0 usage: dict[str, Any] | None = None response_metadata: dict[str, Any] = {} + additional_kwargs: dict[str, Any] = {} async for chunk in chunks: msg = chunk.message @@ -676,6 +702,12 @@ async def achunks_to_events( if merged_rm: response_metadata.update(merged_rm) + # See sync twin: carry chunk `additional_kwargs` through so + # provider-specific data (e.g. Gemini thought signatures) reaches + # the assembled message instead of being dropped. + if msg.additional_kwargs: + additional_kwargs = merge_dicts(additional_kwargs, msg.additional_kwargs) + if not started: started = True yield _build_message_start(msg, message_id) @@ -718,6 +750,7 @@ async def achunks_to_events( yield _build_message_finish( usage=usage, response_metadata=response_metadata, + additional_kwargs=additional_kwargs, ) diff --git a/libs/core/langchain_core/language_models/chat_model_stream.py b/libs/core/langchain_core/language_models/chat_model_stream.py index 6b25a66fd2a..82da8c28ac7 100644 --- a/libs/core/langchain_core/language_models/chat_model_stream.py +++ b/libs/core/langchain_core/language_models/chat_model_stream.py @@ -519,6 +519,7 @@ class _ChatModelStreamBase: self._usage_value: UsageInfo | None = None self._start_metadata: MessageMetadata | None = None self._finish_metadata: dict[str, Any] | None = None + self._additional_kwargs: dict[str, Any] | None = None self._done: bool = False self._error: BaseException | None = None self._output_message: AIMessage | None = None @@ -896,6 +897,15 @@ class _ChatModelStreamBase: self._done = True self._usage_value = data.get("usage") self._finish_metadata = cast("dict[str, Any] | None", data.get("metadata")) + # Off-spec extension carrying provider-side `additional_kwargs` + # that don't map onto a typed protocol field (e.g. Gemini's + # `__gemini_function_call_thought_signatures__`). The compat + # bridge emits this on `message-finish` so the assembled message + # carries the same data `ainvoke` would have preserved. + self._additional_kwargs = cast( + "dict[str, Any] | None", + cast("dict[str, Any]", data).get("additional_kwargs"), + ) # Finalize any unswept chunks — both client- and server-side. _sweep_chunk_store( @@ -1011,14 +1021,17 @@ class _ChatModelStreamBase: for itc in self._invalid_tool_calls_acc ] - return AIMessage( - content=content, - id=self._message_id, - tool_calls=tool_calls, - invalid_tool_calls=invalid_tool_calls, - usage_metadata=self._usage_value, - response_metadata=response_metadata, - ) + message_kwargs: dict[str, Any] = { + "content": content, + "id": self._message_id, + "tool_calls": tool_calls, + "invalid_tool_calls": invalid_tool_calls, + "usage_metadata": self._usage_value, + "response_metadata": response_metadata, + } + if self._additional_kwargs: + message_kwargs["additional_kwargs"] = dict(self._additional_kwargs) + return AIMessage(**message_kwargs) # --------------------------------------------------------------------------- diff --git a/libs/core/tests/unit_tests/language_models/test_compat_bridge.py b/libs/core/tests/unit_tests/language_models/test_compat_bridge.py index e32591ef4e9..9480e80b56c 100644 --- a/libs/core/tests/unit_tests/language_models/test_compat_bridge.py +++ b/libs/core/tests/unit_tests/language_models/test_compat_bridge.py @@ -15,6 +15,10 @@ from langchain_core.language_models._compat_bridge import ( chunks_to_events, message_to_events, ) +from langchain_core.language_models.chat_model_stream import ( + AsyncChatModelStream, + ChatModelStream, +) from langchain_core.messages import AIMessage, AIMessageChunk from langchain_core.outputs import ChatGenerationChunk @@ -649,6 +653,150 @@ async def test_achunks_to_events_reasoning_then_tool_call_no_index() -> None: assert "tool_call" in finish_types +def test_chunks_to_events_preserves_additional_kwargs_on_assembled_message() -> None: + """Streaming-assembled `AIMessage` must retain chunk `additional_kwargs`. + + When Gemini emits a `tool_call` after `thinking`, the source chunk + carries a `__gemini_function_call_thought_signatures__` entry in + `additional_kwargs`, keyed by `tool_call_id`. This signature is required + on follow-up turns so Gemini can replay the prior thought trace. + + Non-streaming `ainvoke` returns the provider's `AIMessage` unchanged, so + the signature is preserved. The v3 streaming path runs chunks through + `chunks_to_events` -> `ChatModelStream._assemble_message`; before this + fix, that path built a fresh `AIMessage` without forwarding + `additional_kwargs`, so the signature was silently dropped and + multi-turn streaming Gemini diverged from non-streaming. + + Provider-specific kwargs aren't this layer's business individually; the + invariant is the general one — chunks' `additional_kwargs` survive into + the assembled message in *some* form a follow-up turn can reach. + """ + thought_signature = "CiIBDDnWx-EXAMPLE-SIGNATURE-PAYLOAD==" + tool_call_id = "tc-abc" + + chunks = [ + ChatGenerationChunk( + message=AIMessageChunk( + content=[{"type": "reasoning", "reasoning": "Thinking..."}], + response_metadata={ + "output_version": "v1", + "model_provider": "google_genai", + }, + ) + ), + ChatGenerationChunk( + message=AIMessageChunk( + content=[ + { + "type": "tool_call", + "id": tool_call_id, + "name": "get_weather", + "args": {"city": "San Francisco"}, + } + ], + additional_kwargs={ + "__gemini_function_call_thought_signatures__": { + tool_call_id: thought_signature, + }, + }, + response_metadata={ + "output_version": "v1", + "model_provider": "google_genai", + }, + ) + ), + ] + + stream = ChatModelStream() + for event in chunks_to_events(iter(chunks), message_id="msg-1"): + stream.dispatch(event) + msg = stream.output + + # Reachable through *some* channel a downstream consumer can route on. + # Either `additional_kwargs` (mirrors non-streaming), or an `extras` + # field on the `tool_call` block (the v1 protocol's slot for + # provider-specific data). + via_additional_kwargs = msg.additional_kwargs.get( + "__gemini_function_call_thought_signatures__", {} + ).get(tool_call_id) + tool_call_blocks = [ + b + for b in (msg.content if isinstance(msg.content, list) else []) + if isinstance(b, dict) and b.get("type") == "tool_call" + ] + via_block_extras = next( + ( + b.get("extras", {}).get("thought_signature") + for b in tool_call_blocks + if b.get("id") == tool_call_id + ), + None, + ) + + signature_preserved = via_additional_kwargs or via_block_extras + assert signature_preserved == thought_signature, ( + "Chunk-level additional_kwargs (Gemini thought signature) was dropped " + "during v3 stream assembly. Streaming-assembled AIMessage exposes " + f"additional_kwargs={msg.additional_kwargs!r}, tool_call blocks=" + f"{tool_call_blocks!r}. Non-streaming `ainvoke` preserves this " + "signature in additional_kwargs unchanged; streaming should not " + "diverge." + ) + + +@pytest.mark.asyncio +async def test_achunks_to_events_preserves_additional_kwargs_on_assembled_message() -> ( + None +): + """Async twin of the additional_kwargs preservation regression.""" + thought_signature = "CiIBDDnWx-EXAMPLE-SIGNATURE-PAYLOAD==" + tool_call_id = "tc-abc" + + chunks = [ + ChatGenerationChunk( + message=AIMessageChunk( + content=[{"type": "reasoning", "reasoning": "Thinking..."}], + response_metadata={ + "output_version": "v1", + "model_provider": "google_genai", + }, + ) + ), + ChatGenerationChunk( + message=AIMessageChunk( + content=[ + { + "type": "tool_call", + "id": tool_call_id, + "name": "get_weather", + "args": {"city": "San Francisco"}, + } + ], + additional_kwargs={ + "__gemini_function_call_thought_signatures__": { + tool_call_id: thought_signature, + }, + }, + response_metadata={ + "output_version": "v1", + "model_provider": "google_genai", + }, + ) + ), + ] + + stream = AsyncChatModelStream() + async for event in achunks_to_events(_aiter_chunks(chunks), message_id="msg-1"): + stream.dispatch(event) + msg = await stream.output + + via_additional_kwargs = msg.additional_kwargs.get( + "__gemini_function_call_thought_signatures__", {} + ).get(tool_call_id) + assert via_additional_kwargs == thought_signature + + def test_chunks_to_events_reasoning_in_additional_kwargs() -> None: """Reasoning packed into additional_kwargs surfaces as a reasoning block.""" chunks = [