fix(core): preserve chunk additional_kwargs across v3 stream assembly (#37435)

The v3 streaming path drops `additional_kwargs` from per-chunk
`AIMessageChunk`s during assembly: `chunks_to_events` emits no event
field for them, and `ChatModelStream._assemble_message` constructs the
final `AIMessage` without an `additional_kwargs` argument. Non-streaming
`ainvoke` returns the provider message unchanged, so streaming and
non-streaming diverge for any provider that uses `additional_kwargs` to
carry data outside the typed protocol blocks.

## How this surfaces

The concrete failure mode is Gemini's
`__gemini_function_call_thought_signatures__` — a per-tool-call
signature blob the Google GenAI integration places in
`additional_kwargs`, keyed by `tool_call_id`. Gemini requires that
signature on follow-up turns to replay the prior thought trace; without
it, multi-turn streaming flows lose thought continuity (and may
regenerate thinking, charging additional reasoning tokens, or in some
cases refuse to respond). Other providers that use `additional_kwargs` (e.g. older
`function_call` accumulators, custom routing metadata) hit the same gap;
the fix is intentionally provider-agnostic.

## Fix

The fix is provider-agnostic and touches two seams:

- `_compat_bridge` accumulates `msg.additional_kwargs` across chunks
with `merge_dicts` (matching `AIMessageChunk`'s own merge semantics for
fields that accumulate, like `function_call`) and emits the merged dict
on the `message-finish` event as an off-spec extension. The bridge
already uses one such extension (`metadata` on `MessageFinishData`);
this PR follows the same pattern for `additional_kwargs`.
- `ChatModelStream._finish` reads the new field; `_assemble_message`
threads it onto the final `AIMessage` only when non-empty, preserving
today's behavior of leaving `additional_kwargs` empty when no provider
data needs to ride on it.
This commit is contained in:
Nick Hollon
2026-05-14 11:19:45 -07:00
committed by GitHub
parent 649d82f206
commit f42d80ca1c
3 changed files with 202 additions and 8 deletions

View File

@@ -57,6 +57,7 @@ from langchain_protocol.protocol import (
)
from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.utils._merge import merge_dicts
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator
@@ -525,6 +526,7 @@ def _build_message_finish(
*,
usage: dict[str, Any] | None,
response_metadata: dict[str, Any] | None,
additional_kwargs: dict[str, Any] | None = None,
) -> MessageFinishData:
# Protocol 0.0.9 removed the top-level `reason` field from
# `MessageFinishData`; the provider's raw `finish_reason` /
@@ -536,6 +538,16 @@ def _build_message_finish(
finish_data["usage"] = usage_info
if response_metadata:
finish_data["metadata"] = dict(response_metadata)
# `additional_kwargs` is an off-spec extension on the message-finish
# event (parallel to `metadata`, which `MessageFinishData` also doesn't
# formally declare but the consumer reads). It carries provider-side
# kwargs that don't map onto a typed protocol field — notably Gemini's
# `__gemini_function_call_thought_signatures__`, which the model
# requires on follow-up turns to replay prior thinking. Without this,
# streaming-assembled messages would silently drop data that
# `ainvoke` preserves, breaking multi-turn streaming flows.
if additional_kwargs:
finish_data["additional_kwargs"] = dict(additional_kwargs)
return cast("MessageFinishData", finish_data)
@@ -582,6 +594,7 @@ def chunks_to_events(
next_wire_idx = 0
usage: dict[str, Any] | None = None
response_metadata: dict[str, Any] = {}
additional_kwargs: dict[str, Any] = {}
for chunk in chunks:
msg = chunk.message
@@ -604,6 +617,17 @@ def chunks_to_events(
if merged_rm:
response_metadata.update(merged_rm)
# Carry chunks' `additional_kwargs` through to the assembled
# message. Provider-side fields that don't map onto a typed
# protocol block (e.g. Gemini's per-tool-call thought signatures)
# live here on non-streaming `ainvoke` results; dropping them on
# the streaming path silently diverges multi-turn behavior. Use
# `merge_dicts` because the same key can arrive in pieces across
# chunks (e.g. an accumulating `function_call`), matching how
# `AIMessageChunk` merges itself.
if msg.additional_kwargs:
additional_kwargs = merge_dicts(additional_kwargs, msg.additional_kwargs)
if not started:
started = True
yield _build_message_start(msg, message_id)
@@ -646,6 +670,7 @@ def chunks_to_events(
yield _build_message_finish(
usage=usage,
response_metadata=response_metadata,
additional_kwargs=additional_kwargs,
)
@@ -660,6 +685,7 @@ async def achunks_to_events(
next_wire_idx = 0
usage: dict[str, Any] | None = None
response_metadata: dict[str, Any] = {}
additional_kwargs: dict[str, Any] = {}
async for chunk in chunks:
msg = chunk.message
@@ -676,6 +702,12 @@ async def achunks_to_events(
if merged_rm:
response_metadata.update(merged_rm)
# See sync twin: carry chunk `additional_kwargs` through so
# provider-specific data (e.g. Gemini thought signatures) reaches
# the assembled message instead of being dropped.
if msg.additional_kwargs:
additional_kwargs = merge_dicts(additional_kwargs, msg.additional_kwargs)
if not started:
started = True
yield _build_message_start(msg, message_id)
@@ -718,6 +750,7 @@ async def achunks_to_events(
yield _build_message_finish(
usage=usage,
response_metadata=response_metadata,
additional_kwargs=additional_kwargs,
)

View File

@@ -519,6 +519,7 @@ class _ChatModelStreamBase:
self._usage_value: UsageInfo | None = None
self._start_metadata: MessageMetadata | None = None
self._finish_metadata: dict[str, Any] | None = None
self._additional_kwargs: dict[str, Any] | None = None
self._done: bool = False
self._error: BaseException | None = None
self._output_message: AIMessage | None = None
@@ -896,6 +897,15 @@ class _ChatModelStreamBase:
self._done = True
self._usage_value = data.get("usage")
self._finish_metadata = cast("dict[str, Any] | None", data.get("metadata"))
# Off-spec extension carrying provider-side `additional_kwargs`
# that don't map onto a typed protocol field (e.g. Gemini's
# `__gemini_function_call_thought_signatures__`). The compat
# bridge emits this on `message-finish` so the assembled message
# carries the same data `ainvoke` would have preserved.
self._additional_kwargs = cast(
"dict[str, Any] | None",
cast("dict[str, Any]", data).get("additional_kwargs"),
)
# Finalize any unswept chunks — both client- and server-side.
_sweep_chunk_store(
@@ -1011,14 +1021,17 @@ class _ChatModelStreamBase:
for itc in self._invalid_tool_calls_acc
]
return AIMessage(
content=content,
id=self._message_id,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
usage_metadata=self._usage_value,
response_metadata=response_metadata,
)
message_kwargs: dict[str, Any] = {
"content": content,
"id": self._message_id,
"tool_calls": tool_calls,
"invalid_tool_calls": invalid_tool_calls,
"usage_metadata": self._usage_value,
"response_metadata": response_metadata,
}
if self._additional_kwargs:
message_kwargs["additional_kwargs"] = dict(self._additional_kwargs)
return AIMessage(**message_kwargs)
# ---------------------------------------------------------------------------

View File

@@ -15,6 +15,10 @@ from langchain_core.language_models._compat_bridge import (
chunks_to_events,
message_to_events,
)
from langchain_core.language_models.chat_model_stream import (
AsyncChatModelStream,
ChatModelStream,
)
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
@@ -649,6 +653,150 @@ async def test_achunks_to_events_reasoning_then_tool_call_no_index() -> None:
assert "tool_call" in finish_types
def test_chunks_to_events_preserves_additional_kwargs_on_assembled_message() -> None:
    """Streaming-assembled `AIMessage` must retain chunk `additional_kwargs`.

    When Gemini emits a `tool_call` after `thinking`, the source chunk
    carries a `__gemini_function_call_thought_signatures__` entry in
    `additional_kwargs`, keyed by `tool_call_id`. This signature is required
    on follow-up turns so Gemini can replay the prior thought trace.

    Non-streaming `ainvoke` returns the provider's `AIMessage` unchanged, so
    the signature is preserved. The v3 streaming path runs chunks through
    `chunks_to_events` -> `ChatModelStream._assemble_message`; before this
    fix, that path built a fresh `AIMessage` without forwarding
    `additional_kwargs`, so the signature was silently dropped and
    multi-turn streaming Gemini diverged from non-streaming.

    Provider-specific kwargs aren't this layer's business individually; the
    invariant is the general one — chunks' `additional_kwargs` survive into
    the assembled message in *some* form a follow-up turn can reach.
    """
    thought_signature = "CiIBDDnWx-EXAMPLE-SIGNATURE-PAYLOAD=="
    tool_call_id = "tc-abc"
    chunks = [
        # First chunk: reasoning content only — no additional_kwargs.
        ChatGenerationChunk(
            message=AIMessageChunk(
                content=[{"type": "reasoning", "reasoning": "Thinking..."}],
                response_metadata={
                    "output_version": "v1",
                    "model_provider": "google_genai",
                },
            )
        ),
        # Second chunk: the tool call, with the per-tool-call thought
        # signature riding in additional_kwargs — the data under test.
        ChatGenerationChunk(
            message=AIMessageChunk(
                content=[
                    {
                        "type": "tool_call",
                        "id": tool_call_id,
                        "name": "get_weather",
                        "args": {"city": "San Francisco"},
                    }
                ],
                additional_kwargs={
                    "__gemini_function_call_thought_signatures__": {
                        tool_call_id: thought_signature,
                    },
                },
                response_metadata={
                    "output_version": "v1",
                    "model_provider": "google_genai",
                },
            )
        ),
    ]
    # Drive the full assembly pipeline: bridge events -> stream consumer.
    stream = ChatModelStream()
    for event in chunks_to_events(iter(chunks), message_id="msg-1"):
        stream.dispatch(event)
    msg = stream.output
    # Reachable through *some* channel a downstream consumer can route on.
    # Either `additional_kwargs` (mirrors non-streaming), or an `extras`
    # field on the `tool_call` block (the v1 protocol's slot for
    # provider-specific data).
    via_additional_kwargs = msg.additional_kwargs.get(
        "__gemini_function_call_thought_signatures__", {}
    ).get(tool_call_id)
    tool_call_blocks = [
        b
        for b in (msg.content if isinstance(msg.content, list) else [])
        if isinstance(b, dict) and b.get("type") == "tool_call"
    ]
    via_block_extras = next(
        (
            b.get("extras", {}).get("thought_signature")
            for b in tool_call_blocks
            if b.get("id") == tool_call_id
        ),
        None,
    )
    # Accept either channel; the assertion message dumps both so a failure
    # shows exactly where the signature was lost.
    signature_preserved = via_additional_kwargs or via_block_extras
    assert signature_preserved == thought_signature, (
        "Chunk-level additional_kwargs (Gemini thought signature) was dropped "
        "during v3 stream assembly. Streaming-assembled AIMessage exposes "
        f"additional_kwargs={msg.additional_kwargs!r}, tool_call blocks="
        f"{tool_call_blocks!r}. Non-streaming `ainvoke` preserves this "
        "signature in additional_kwargs unchanged; streaming should not "
        "diverge."
    )
@pytest.mark.asyncio
async def test_achunks_to_events_preserves_additional_kwargs_on_assembled_message() -> (
    None
):
    """Async twin of the additional_kwargs preservation regression.

    Same fixture as the sync test: a reasoning chunk followed by a
    tool-call chunk whose `additional_kwargs` carries a Gemini thought
    signature keyed by `tool_call_id`. Asserts only on the
    `additional_kwargs` channel (the sync twin documents the
    dual-channel rationale).
    """
    thought_signature = "CiIBDDnWx-EXAMPLE-SIGNATURE-PAYLOAD=="
    tool_call_id = "tc-abc"
    chunks = [
        # Reasoning-only chunk — no additional_kwargs.
        ChatGenerationChunk(
            message=AIMessageChunk(
                content=[{"type": "reasoning", "reasoning": "Thinking..."}],
                response_metadata={
                    "output_version": "v1",
                    "model_provider": "google_genai",
                },
            )
        ),
        # Tool-call chunk carrying the thought signature under test.
        ChatGenerationChunk(
            message=AIMessageChunk(
                content=[
                    {
                        "type": "tool_call",
                        "id": tool_call_id,
                        "name": "get_weather",
                        "args": {"city": "San Francisco"},
                    }
                ],
                additional_kwargs={
                    "__gemini_function_call_thought_signatures__": {
                        tool_call_id: thought_signature,
                    },
                },
                response_metadata={
                    "output_version": "v1",
                    "model_provider": "google_genai",
                },
            )
        ),
    ]
    # Drive the async assembly pipeline end-to-end.
    stream = AsyncChatModelStream()
    async for event in achunks_to_events(_aiter_chunks(chunks), message_id="msg-1"):
        stream.dispatch(event)
    msg = await stream.output
    via_additional_kwargs = msg.additional_kwargs.get(
        "__gemini_function_call_thought_signatures__", {}
    ).get(tool_call_id)
    assert via_additional_kwargs == thought_signature
def test_chunks_to_events_reasoning_in_additional_kwargs() -> None:
"""Reasoning packed into additional_kwargs surfaces as a reasoning block."""
chunks = [