mirror of
https://github.com/hwchase17/langchain.git
synced 2026-05-14 19:05:21 +00:00
fix(core): preserve chunk additional_kwargs across v3 stream assembly (#37435)
The v3 streaming path drops `additional_kwargs` from per-chunk `AIMessageChunk`s during assembly: `chunks_to_events` emits no event field for them, and `ChatModelStream._assemble_message` constructs the final `AIMessage` without an `additional_kwargs` argument. Non-streaming `ainvoke` returns the provider message unchanged, so streaming and non-streaming diverge for any provider that uses `additional_kwargs` to carry data outside the typed protocol blocks. ## How this surfaces The concrete failure mode is Gemini's `__gemini_function_call_thought_signatures__` — a per-tool-call signature blob the Google GenAI integration places in `additional_kwargs`, keyed by `tool_call_id`. Gemini requires that signature on follow-up turns to replay the prior thought trace; without it, multi-turn streaming flows lose thought continuity (and may regenerate thinking, charging additional reasoning tokens, or in some cases refuse). Other providers that use `additional_kwargs` (e.g. older `function_call` accumulators, custom routing metadata) hit the same gap; the fix is intentionally provider-agnostic. ## Fix Provider-agnostic, two seams: - `_compat_bridge` accumulates `msg.additional_kwargs` across chunks with `merge_dicts` (matching `AIMessageChunk`'s own merge semantics for fields that accumulate, like `function_call`) and emits the merged dict on the `message-finish` event as an off-spec extension. The bridge already uses one such extension (`metadata` on `MessageFinishData`); this PR follows the same pattern for `additional_kwargs`. - `ChatModelStream._finish` reads the new field; `_assemble_message` threads it onto the final `AIMessage` only when non-empty, preserving today's behavior of leaving `additional_kwargs` empty when no provider data needs to ride on it.
This commit is contained in:
@@ -57,6 +57,7 @@ from langchain_protocol.protocol import (
|
||||
)
|
||||
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessage
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
@@ -525,6 +526,7 @@ def _build_message_finish(
|
||||
*,
|
||||
usage: dict[str, Any] | None,
|
||||
response_metadata: dict[str, Any] | None,
|
||||
additional_kwargs: dict[str, Any] | None = None,
|
||||
) -> MessageFinishData:
|
||||
# Protocol 0.0.9 removed the top-level `reason` field from
|
||||
# `MessageFinishData`; the provider's raw `finish_reason` /
|
||||
@@ -536,6 +538,16 @@ def _build_message_finish(
|
||||
finish_data["usage"] = usage_info
|
||||
if response_metadata:
|
||||
finish_data["metadata"] = dict(response_metadata)
|
||||
# `additional_kwargs` is an off-spec extension on the message-finish
|
||||
# event (parallel to `metadata`, which `MessageFinishData` also doesn't
|
||||
# formally declare but the consumer reads). It carries provider-side
|
||||
# kwargs that don't map onto a typed protocol field — notably Gemini's
|
||||
# `__gemini_function_call_thought_signatures__`, which the model
|
||||
# requires on follow-up turns to replay prior thinking. Without this,
|
||||
# streaming-assembled messages would silently drop data that
|
||||
# `ainvoke` preserves, breaking multi-turn streaming flows.
|
||||
if additional_kwargs:
|
||||
finish_data["additional_kwargs"] = dict(additional_kwargs)
|
||||
return cast("MessageFinishData", finish_data)
|
||||
|
||||
|
||||
@@ -582,6 +594,7 @@ def chunks_to_events(
|
||||
next_wire_idx = 0
|
||||
usage: dict[str, Any] | None = None
|
||||
response_metadata: dict[str, Any] = {}
|
||||
additional_kwargs: dict[str, Any] = {}
|
||||
|
||||
for chunk in chunks:
|
||||
msg = chunk.message
|
||||
@@ -604,6 +617,17 @@ def chunks_to_events(
|
||||
if merged_rm:
|
||||
response_metadata.update(merged_rm)
|
||||
|
||||
# Carry chunks' `additional_kwargs` through to the assembled
|
||||
# message. Provider-side fields that don't map onto a typed
|
||||
# protocol block (e.g. Gemini's per-tool-call thought signatures)
|
||||
# live here on non-streaming `ainvoke` results; dropping them on
|
||||
# the streaming path silently diverges multi-turn behavior. Use
|
||||
# `merge_dicts` because the same key can arrive in pieces across
|
||||
# chunks (e.g. an accumulating `function_call`), matching how
|
||||
# `AIMessageChunk` merges itself.
|
||||
if msg.additional_kwargs:
|
||||
additional_kwargs = merge_dicts(additional_kwargs, msg.additional_kwargs)
|
||||
|
||||
if not started:
|
||||
started = True
|
||||
yield _build_message_start(msg, message_id)
|
||||
@@ -646,6 +670,7 @@ def chunks_to_events(
|
||||
yield _build_message_finish(
|
||||
usage=usage,
|
||||
response_metadata=response_metadata,
|
||||
additional_kwargs=additional_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -660,6 +685,7 @@ async def achunks_to_events(
|
||||
next_wire_idx = 0
|
||||
usage: dict[str, Any] | None = None
|
||||
response_metadata: dict[str, Any] = {}
|
||||
additional_kwargs: dict[str, Any] = {}
|
||||
|
||||
async for chunk in chunks:
|
||||
msg = chunk.message
|
||||
@@ -676,6 +702,12 @@ async def achunks_to_events(
|
||||
if merged_rm:
|
||||
response_metadata.update(merged_rm)
|
||||
|
||||
# See sync twin: carry chunk `additional_kwargs` through so
|
||||
# provider-specific data (e.g. Gemini thought signatures) reaches
|
||||
# the assembled message instead of being dropped.
|
||||
if msg.additional_kwargs:
|
||||
additional_kwargs = merge_dicts(additional_kwargs, msg.additional_kwargs)
|
||||
|
||||
if not started:
|
||||
started = True
|
||||
yield _build_message_start(msg, message_id)
|
||||
@@ -718,6 +750,7 @@ async def achunks_to_events(
|
||||
yield _build_message_finish(
|
||||
usage=usage,
|
||||
response_metadata=response_metadata,
|
||||
additional_kwargs=additional_kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -519,6 +519,7 @@ class _ChatModelStreamBase:
|
||||
self._usage_value: UsageInfo | None = None
|
||||
self._start_metadata: MessageMetadata | None = None
|
||||
self._finish_metadata: dict[str, Any] | None = None
|
||||
self._additional_kwargs: dict[str, Any] | None = None
|
||||
self._done: bool = False
|
||||
self._error: BaseException | None = None
|
||||
self._output_message: AIMessage | None = None
|
||||
@@ -896,6 +897,15 @@ class _ChatModelStreamBase:
|
||||
self._done = True
|
||||
self._usage_value = data.get("usage")
|
||||
self._finish_metadata = cast("dict[str, Any] | None", data.get("metadata"))
|
||||
# Off-spec extension carrying provider-side `additional_kwargs`
|
||||
# that don't map onto a typed protocol field (e.g. Gemini's
|
||||
# `__gemini_function_call_thought_signatures__`). The compat
|
||||
# bridge emits this on `message-finish` so the assembled message
|
||||
# carries the same data `ainvoke` would have preserved.
|
||||
self._additional_kwargs = cast(
|
||||
"dict[str, Any] | None",
|
||||
cast("dict[str, Any]", data).get("additional_kwargs"),
|
||||
)
|
||||
|
||||
# Finalize any unswept chunks — both client- and server-side.
|
||||
_sweep_chunk_store(
|
||||
@@ -1011,14 +1021,17 @@ class _ChatModelStreamBase:
|
||||
for itc in self._invalid_tool_calls_acc
|
||||
]
|
||||
|
||||
return AIMessage(
|
||||
content=content,
|
||||
id=self._message_id,
|
||||
tool_calls=tool_calls,
|
||||
invalid_tool_calls=invalid_tool_calls,
|
||||
usage_metadata=self._usage_value,
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
message_kwargs: dict[str, Any] = {
|
||||
"content": content,
|
||||
"id": self._message_id,
|
||||
"tool_calls": tool_calls,
|
||||
"invalid_tool_calls": invalid_tool_calls,
|
||||
"usage_metadata": self._usage_value,
|
||||
"response_metadata": response_metadata,
|
||||
}
|
||||
if self._additional_kwargs:
|
||||
message_kwargs["additional_kwargs"] = dict(self._additional_kwargs)
|
||||
return AIMessage(**message_kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -15,6 +15,10 @@ from langchain_core.language_models._compat_bridge import (
|
||||
chunks_to_events,
|
||||
message_to_events,
|
||||
)
|
||||
from langchain_core.language_models.chat_model_stream import (
|
||||
AsyncChatModelStream,
|
||||
ChatModelStream,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.outputs import ChatGenerationChunk
|
||||
|
||||
@@ -649,6 +653,150 @@ async def test_achunks_to_events_reasoning_then_tool_call_no_index() -> None:
|
||||
assert "tool_call" in finish_types
|
||||
|
||||
|
||||
def test_chunks_to_events_preserves_additional_kwargs_on_assembled_message() -> None:
|
||||
"""Streaming-assembled `AIMessage` must retain chunk `additional_kwargs`.
|
||||
|
||||
When Gemini emits a `tool_call` after `thinking`, the source chunk
|
||||
carries a `__gemini_function_call_thought_signatures__` entry in
|
||||
`additional_kwargs`, keyed by `tool_call_id`. This signature is required
|
||||
on follow-up turns so Gemini can replay the prior thought trace.
|
||||
|
||||
Non-streaming `ainvoke` returns the provider's `AIMessage` unchanged, so
|
||||
the signature is preserved. The v3 streaming path runs chunks through
|
||||
`chunks_to_events` -> `ChatModelStream._assemble_message`; before this
|
||||
fix, that path built a fresh `AIMessage` without forwarding
|
||||
`additional_kwargs`, so the signature was silently dropped and
|
||||
multi-turn streaming Gemini diverged from non-streaming.
|
||||
|
||||
Provider-specific kwargs aren't this layer's business individually; the
|
||||
invariant is the general one — chunks' `additional_kwargs` survive into
|
||||
the assembled message in *some* form a follow-up turn can reach.
|
||||
"""
|
||||
thought_signature = "CiIBDDnWx-EXAMPLE-SIGNATURE-PAYLOAD=="
|
||||
tool_call_id = "tc-abc"
|
||||
|
||||
chunks = [
|
||||
ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content=[{"type": "reasoning", "reasoning": "Thinking..."}],
|
||||
response_metadata={
|
||||
"output_version": "v1",
|
||||
"model_provider": "google_genai",
|
||||
},
|
||||
)
|
||||
),
|
||||
ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content=[
|
||||
{
|
||||
"type": "tool_call",
|
||||
"id": tool_call_id,
|
||||
"name": "get_weather",
|
||||
"args": {"city": "San Francisco"},
|
||||
}
|
||||
],
|
||||
additional_kwargs={
|
||||
"__gemini_function_call_thought_signatures__": {
|
||||
tool_call_id: thought_signature,
|
||||
},
|
||||
},
|
||||
response_metadata={
|
||||
"output_version": "v1",
|
||||
"model_provider": "google_genai",
|
||||
},
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
stream = ChatModelStream()
|
||||
for event in chunks_to_events(iter(chunks), message_id="msg-1"):
|
||||
stream.dispatch(event)
|
||||
msg = stream.output
|
||||
|
||||
# Reachable through *some* channel a downstream consumer can route on.
|
||||
# Either `additional_kwargs` (mirrors non-streaming), or an `extras`
|
||||
# field on the `tool_call` block (the v1 protocol's slot for
|
||||
# provider-specific data).
|
||||
via_additional_kwargs = msg.additional_kwargs.get(
|
||||
"__gemini_function_call_thought_signatures__", {}
|
||||
).get(tool_call_id)
|
||||
tool_call_blocks = [
|
||||
b
|
||||
for b in (msg.content if isinstance(msg.content, list) else [])
|
||||
if isinstance(b, dict) and b.get("type") == "tool_call"
|
||||
]
|
||||
via_block_extras = next(
|
||||
(
|
||||
b.get("extras", {}).get("thought_signature")
|
||||
for b in tool_call_blocks
|
||||
if b.get("id") == tool_call_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
signature_preserved = via_additional_kwargs or via_block_extras
|
||||
assert signature_preserved == thought_signature, (
|
||||
"Chunk-level additional_kwargs (Gemini thought signature) was dropped "
|
||||
"during v3 stream assembly. Streaming-assembled AIMessage exposes "
|
||||
f"additional_kwargs={msg.additional_kwargs!r}, tool_call blocks="
|
||||
f"{tool_call_blocks!r}. Non-streaming `ainvoke` preserves this "
|
||||
"signature in additional_kwargs unchanged; streaming should not "
|
||||
"diverge."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_achunks_to_events_preserves_additional_kwargs_on_assembled_message() -> (
|
||||
None
|
||||
):
|
||||
"""Async twin of the additional_kwargs preservation regression."""
|
||||
thought_signature = "CiIBDDnWx-EXAMPLE-SIGNATURE-PAYLOAD=="
|
||||
tool_call_id = "tc-abc"
|
||||
|
||||
chunks = [
|
||||
ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content=[{"type": "reasoning", "reasoning": "Thinking..."}],
|
||||
response_metadata={
|
||||
"output_version": "v1",
|
||||
"model_provider": "google_genai",
|
||||
},
|
||||
)
|
||||
),
|
||||
ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content=[
|
||||
{
|
||||
"type": "tool_call",
|
||||
"id": tool_call_id,
|
||||
"name": "get_weather",
|
||||
"args": {"city": "San Francisco"},
|
||||
}
|
||||
],
|
||||
additional_kwargs={
|
||||
"__gemini_function_call_thought_signatures__": {
|
||||
tool_call_id: thought_signature,
|
||||
},
|
||||
},
|
||||
response_metadata={
|
||||
"output_version": "v1",
|
||||
"model_provider": "google_genai",
|
||||
},
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
stream = AsyncChatModelStream()
|
||||
async for event in achunks_to_events(_aiter_chunks(chunks), message_id="msg-1"):
|
||||
stream.dispatch(event)
|
||||
msg = await stream.output
|
||||
|
||||
via_additional_kwargs = msg.additional_kwargs.get(
|
||||
"__gemini_function_call_thought_signatures__", {}
|
||||
).get(tool_call_id)
|
||||
assert via_additional_kwargs == thought_signature
|
||||
|
||||
|
||||
def test_chunks_to_events_reasoning_in_additional_kwargs() -> None:
|
||||
"""Reasoning packed into additional_kwargs surfaces as a reasoning block."""
|
||||
chunks = [
|
||||
|
||||
Reference in New Issue
Block a user