diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py index 6f03a3f6709..383b256cb26 100644 --- a/libs/core/langchain_core/callbacks/base.py +++ b/libs/core/langchain_core/callbacks/base.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from collections.abc import Sequence from uuid import UUID + from langchain_protocol.protocol import MessagesData from tenacity import RetryCallState from typing_extensions import Self @@ -124,6 +125,43 @@ class LLMManagerMixin: **kwargs: Additional keyword arguments. """ + def on_stream_event( + self, + event: MessagesData, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + **kwargs: Any, + ) -> Any: + """Run on each protocol event produced by `stream_v2` / `astream_v2`. + + Fires once per `MessagesData` event — `message-start`, per-block + `content-block-start` / `content-block-delta` / + `content-block-finish`, and `message-finish`. Analogous to + `on_llm_new_token` in v1 streaming, but at event granularity rather + than chunk: a single chunk can map to multiple events (e.g. a + `content-block-start` plus its first `content-block-delta`), and + lifecycle boundaries are explicit. + + Fires uniformly whether the provider emits events natively via + `_stream_chat_model_events` or goes through the chunk-to-event + compat bridge. Observers see the same event stream regardless of + how the underlying model produces output. + + Not fired from v1 `stream()` / `astream()`; for those, keep using + `on_llm_new_token`. Purely additive — `on_chat_model_start`, + `on_llm_end`, and `on_llm_error` still fire around a v2 call as + they do around a v1 call. + + Args: + event: The protocol event. + run_id: The ID of the current run. + parent_run_id: The ID of the parent run. + tags: The tags. + **kwargs: Additional keyword arguments. + """ + class ChainManagerMixin: """Mixin for chain callbacks.""" @@ -288,10 +326,10 @@ class CallbackManagerMixin: !!! note When overriding this method, the signature **must** include the two - required positional arguments ``serialized`` and ``messages``. Avoid - using ``*args`` in your override — doing so causes an ``IndexError`` - in the fallback path when the callback system converts ``messages`` - to prompt strings for ``on_llm_start``. Always declare the + required positional arguments `serialized` and `messages`. Avoid + using `*args` in your override — doing so causes an `IndexError` + in the fallback path when the callback system converts `messages` + to prompt strings for `on_llm_start`. Always declare the signature explicitly: .. code-block:: python @@ -557,10 +595,10 @@ class AsyncCallbackHandler(BaseCallbackHandler): !!! note When overriding this method, the signature **must** include the two - required positional arguments ``serialized`` and ``messages``. Avoid - using ``*args`` in your override — doing so causes an ``IndexError`` - in the fallback path when the callback system converts ``messages`` - to prompt strings for ``on_llm_start``. Always declare the + required positional arguments `serialized` and `messages`. Avoid + using `*args` in your override — doing so causes an `IndexError` + in the fallback path when the callback system converts `messages` + to prompt strings for `on_llm_start`. Always declare the signature explicitly: .. code-block:: python @@ -652,6 +690,31 @@ class AsyncCallbackHandler(BaseCallbackHandler): the error occurred. """ + async def on_stream_event( + self, + event: MessagesData, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + **kwargs: Any, + ) -> None: + """Run on each protocol event produced by `astream_v2`. + + See :meth:`LLMManagerMixin.on_stream_event` for the full contract. + Fires once per `MessagesData` event at event granularity, uniformly + across native and compat-bridge providers, and is purely additive + to the existing `on_chat_model_start` / `on_llm_end` / + `on_llm_error` callbacks. + + Args: + event: The protocol event. + run_id: The ID of the current run. + parent_run_id: The ID of the parent run. + tags: The tags. + **kwargs: Additional keyword arguments. + """ + async def on_chain_start( self, serialized: dict[str, Any], diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index b7251b754dd..0efab4ad04c 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -35,6 +35,7 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator, Coroutine, Generator, Sequence from uuid import UUID + from langchain_protocol.protocol import MessagesData from tenacity import RetryCallState from langchain_core.agents import AgentAction, AgentFinish @@ -747,6 +748,26 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): **kwargs, ) + def on_stream_event(self, event: MessagesData, **kwargs: Any) -> None: + """Run on each protocol event from `stream_v2`. + + Args: + event: The protocol event. + **kwargs: Additional keyword arguments. + """ + if not self.handlers: + return + handle_event( + self.handlers, + "on_stream_event", + "ignore_llm", + event, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): """Async callback manager for LLM run.""" @@ -849,6 +870,26 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): **kwargs, ) + async def on_stream_event(self, event: MessagesData, **kwargs: Any) -> None: + """Run on each protocol event from `astream_v2`. + + Args: + event: The protocol event. + **kwargs: Additional keyword arguments. + """ + if not self.handlers: + return + await ahandle_event( + self.handlers, + "on_stream_event", + "ignore_llm", + event, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): """Callback manager for chain run.""" diff --git a/libs/core/langchain_core/language_models/_compat_bridge.py b/libs/core/langchain_core/language_models/_compat_bridge.py new file mode 100644 index 00000000000..bc6928d8981 --- /dev/null +++ b/libs/core/langchain_core/language_models/_compat_bridge.py @@ -0,0 +1,722 @@ +"""Compat bridge: convert `AIMessageChunk` streams to protocol events. + +The bridge trusts `AIMessageChunk.content_blocks` as the single +protocol view of any chunk. That property runs the three-tier lookup +(`output_version == "v1"` short-circuit, registered translator, or +best-effort parsing) and returns a `list[ContentBlock]` for every +well-formed message — whether the provider is a registered partner, an +unregistered community model, or not tagged at all. + +Per-chunk `content_blocks` output is a **delta slice**, not accumulated +state: providers in this ecosystem emit SSE-style chunks that each carry +their own increment. The bridge therefore forwards each slice straight +through as a `content-block-delta` event, and accumulates per-index +state only so the final `content-block-finish` event can report a +finalized block (e.g. `tool_call_chunk` args parsed to a dict). + +Lifecycle:: + + message-start + -> content-block-start (first time each index is observed) + -> content-block-delta* (per chunk, carrying the slice) + -> content-block-finish (finalized block) + -> message-finish + +Public API: + +- `chunks_to_events` / `achunks_to_events` — for live streams where + chunks arrive over time. +- `message_to_events` / `amessage_to_events` — for replaying a finalized + `AIMessage` (cache hit, checkpoint restore, graph-node return value) + as a synthetic event lifecycle. +""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any, cast + +from langchain_protocol.protocol import ( + ContentBlock, + ContentBlockDeltaData, + ContentBlockFinishData, + ContentBlockStartData, + FinalizedContentBlock, + InvalidToolCall, + MessageFinishData, + MessageMetadata, + MessagesData, + MessageStartData, + ReasoningContentBlock, + ServerToolCall, + ServerToolCallChunk, + TextContentBlock, + ToolCall, + ToolCallChunk, + UsageInfo, +) + +from langchain_core.messages import AIMessageChunk, BaseMessage + +if TYPE_CHECKING: + from collections.abc import AsyncIterator, Iterator + + from langchain_core.outputs import ChatGenerationChunk + + +CompatBlock = dict[str, Any] +"""Internal working type for a content block. + +The bridge works with plain dicts internally because two separate but +structurally similar `ContentBlock` Unions exist — one in +`langchain_core.messages.content` (returned by `msg.content_blocks`), +one in `langchain_protocol.protocol` (the wire/event shape). They are +not mypy-compatible despite being near-isomorphic. Passing through +`dict[str, Any]` launders between them. See `_to_protocol_block` for +the single seam where the laundering cast lives. +""" + + +# --------------------------------------------------------------------------- +# Type laundering between core and protocol `ContentBlock` unions +# --------------------------------------------------------------------------- + + +def _to_protocol_block(block: CompatBlock) -> ContentBlock: + """Narrow an internal working dict to a protocol `ContentBlock`. + + Single seam between the two `ContentBlock` type systems: + `langchain_core.messages.content` (what `msg.content_blocks` + returns) and `langchain_protocol.protocol` (what event payloads + require). The two Unions overlap structurally but are nominally + distinct to mypy, so we launder through `dict[str, Any]`. When the + Unions are unified, this helper and its finalized counterpart can be + deleted. + """ + return cast("ContentBlock", block) + + +def _to_finalized_block(block: CompatBlock) -> FinalizedContentBlock: + """Counterpart of `_to_protocol_block` for finalized blocks.""" + return cast("FinalizedContentBlock", block) + + +# --------------------------------------------------------------------------- +# Block iteration +# --------------------------------------------------------------------------- + + +def _iter_protocol_blocks(msg: BaseMessage) -> list[tuple[Any, CompatBlock]]: + """Read per-chunk protocol blocks from `msg.content_blocks`. + + Returns `(key, block)` pairs. The key is the block's stable identifier + across the stream: the block's `index` field when present (can be an + int or a string — some providers use string identifiers like + `"lc_rs_305f30"`), or the positional index within the message as a + fallback. Callers are responsible for allocating wire-level `uint` + indices; this helper only surfaces the source-side identity. + + For finalized `AIMessage`, also surfaces `invalid_tool_calls` + — which `AIMessage.content_blocks` currently omits from its return + value even though they are a defined protocol block type. + + The positional fallback is a known fragility: when a provider emits + blocks without an `index` field (e.g. Anthropic's `_stream` with + `coerce_content_to_string=True`, where text chunks lose their + source-side index), every such chunk gets positional key 0 and + successive chunks merge into one block. This works correctly for + single-type streams (pure-text responses merge cleanly) because all + chunks share the same key and the open-block logic collapses them. + It would miscategorise a stream that mixed indexed structured + blocks with non-indexed coerced-text blocks, since an indexed + block with `index == 0` would collide with the anonymous text + block's positional-0 key. In the anthropic integration this + cannot currently occur: coerce-to-string mode is only selected + when no tools, thinking, or documents are present, and any of + those flips the stream to structured mode where every block + carries an integer index. A native `_stream_chat_model_events` + hook per provider (or a bridge-level "continue the open block when + the source has no identity" rule) would close the gap if another + integration ever emits mixed content. + """ + try: + raw = msg.content_blocks + except Exception: + return [] + + result: list[tuple[Any, CompatBlock]] = [] + for i, block in enumerate(raw): + if not isinstance(block, dict): + continue + key = block.get("index", i) + result.append((key, dict(block))) + + if not isinstance(msg, AIMessageChunk): + # Finalized AIMessage: pull invalid_tool_calls from the dedicated + # field — AIMessage.content_blocks does not currently include them. + for itc in getattr(msg, "invalid_tool_calls", None) or []: + itc_block: CompatBlock = {"type": "invalid_tool_call"} + for key_name in ("id", "name", "args", "error"): + if itc.get(key_name) is not None: + itc_block[key_name] = itc[key_name] + result.append((len(result), itc_block)) + + return result + + +# --------------------------------------------------------------------------- +# Per-block helpers +# --------------------------------------------------------------------------- + + +# Fields that can carry large payloads (inline base64 media, parsed args, +# arbitrary dicts). Stripped from `content-block-start` for self-contained +# block types so the payload rides on `content-block-finish` alone instead +# of being serialized twice on the wire. +_HEAVY_FIELDS = frozenset({"args", "data", "output", "transcript", "value"}) + + +def _start_skeleton(block: CompatBlock) -> ContentBlock: + """Empty-content placeholder for the `content-block-start` event. + + Deltaable block types (text, reasoning, the `_chunk` tool variants) + get an empty payload so the lifecycle's "start" signal is distinct + from the first incremental delta. Self-contained types (image, + audio, video, file, non_standard, finalized tool calls) drop their + heavy payload fields; those are carried by `content-block-finish`. + Correlation fields (id, name, toolCallId) and small metadata + (mime_type, url, status, …) are preserved on the start event. + """ + btype = block.get("type", "text") + if btype == "text": + return TextContentBlock(type="text", text="") + if btype == "reasoning": + return ReasoningContentBlock(type="reasoning", reasoning="") + if btype == "tool_call_chunk": + return ToolCallChunk( + type="tool_call_chunk", + id=block.get("id"), + name=block.get("name"), + args="", + ) + if btype == "server_tool_call_chunk": + s_skel = ServerToolCallChunk( + type="server_tool_call_chunk", + args="", + ) + if block.get("id") is not None: + s_skel["id"] = block["id"] + if block.get("name") is not None: + s_skel["name"] = block["name"] + return s_skel + + stripped: CompatBlock = {k: v for k, v in block.items() if k not in _HEAVY_FIELDS} + # Restore required-but-heavy fields with minimal placeholders so the + # start event still validates against the CDDL shape of the block type. + if btype in ("tool_call", "server_tool_call"): + stripped["args"] = {} + elif btype == "non_standard": + stripped["value"] = {} + return _to_protocol_block(stripped) + + +def _should_emit_delta(block: CompatBlock) -> bool: + """Whether a per-chunk block carries content worth a delta event. + + Deltaable types emit only when they have fresh content. Self-contained + / already-finalized types skip the delta entirely — the `finish` + event carries them. + """ + btype = block.get("type") + if btype == "text": + return bool(block.get("text")) + if btype == "reasoning": + return bool(block.get("reasoning")) + if btype in ("tool_call_chunk", "server_tool_call_chunk"): + return bool( + block.get("args") or block.get("id") or block.get("name"), + ) + return False + + +def _accumulate(state: CompatBlock | None, delta: CompatBlock) -> CompatBlock: + """Merge a per-chunk delta slice into accumulated per-index state. + + Used only for the finalization pass — live delta events are emitted + directly from the per-chunk block, without round-tripping through + accumulated state. + """ + if state is None: + return dict(delta) + btype = state.get("type") + dtype = delta.get("type") + if btype == "text" and dtype == "text": + state["text"] = state.get("text", "") + delta.get("text", "") + # Providers may send non-text fields (like `id`, or annotations) + # on later deltas. Merging (not replacing) keeps earlier keys + # intact while picking up these late-arriving fields. + for key, value in delta.items(): + if key in ("type", "text") or value is None: + continue + if key == "extras" and isinstance(value, dict): + state["extras"] = {**(state.get("extras") or {}), **value} + else: + state[key] = value + elif btype == "reasoning" and dtype == "reasoning": + state["reasoning"] = state.get("reasoning", "") + delta.get("reasoning", "") + # Providers may ship non-text fields on later deltas. Claude's + # `signature_delta` arrives after the reasoning text, surfaced + # as `extras.signature`; merging (not replacing) keeps earlier + # keys intact. + for key, value in delta.items(): + if key in ("type", "reasoning") or value is None: + continue + if key == "extras" and isinstance(value, dict): + state["extras"] = {**(state.get("extras") or {}), **value} + else: + state[key] = value + elif btype in ("tool_call_chunk", "server_tool_call_chunk") and dtype == btype: + state["args"] = (state.get("args", "") or "") + (delta.get("args") or "") + if delta.get("id") is not None: + state["id"] = delta["id"] + if delta.get("name") is not None: + state["name"] = delta["name"] + else: + # Self-contained or already-finalized types: replace wholesale. + state.clear() + state.update(delta) + return state + + +def finalize_tool_call_chunk( + *, + raw_args: str | None, + id_: str | None, + name: str | None, + extras: dict[str, Any], + finalized_type: str, +) -> FinalizedContentBlock: + """Parse accumulated tool-chunk args into a finalized block. + + Shared between the compat bridge's `_finalize_block` and the + `ChatModelStream` end-of-stream sweep. Parses `raw_args` as JSON: + on success builds the requested finalized type (`tool_call` or + `server_tool_call`) with provider-specific fields (`extras`) + preserved; on failure falls back to `invalid_tool_call` carrying + the raw string so downstream consumers can still introspect the + malformed payload. + + Args: + raw_args: Accumulated partial-JSON string; `None` or empty + treated as `{}`. + id_: Tool-call id collected across chunks. + name: Tool name collected across chunks. + extras: Provider-specific fields to carry onto the finalized + block. Callers are responsible for having already dropped + keys they don't want propagated (notably `type`, `id`, + `name`, `args`, and `index` on client-side `tool_call`). + finalized_type: `"tool_call"` or `"server_tool_call"`. + + Returns: + A `ToolCall`, `ServerToolCall`, or `InvalidToolCall` — the + latter when `raw_args` is non-empty but not valid JSON. + """ + raw = raw_args or "{}" + try: + parsed = json.loads(raw) if raw else {} + except (json.JSONDecodeError, TypeError): + invalid = InvalidToolCall( + type="invalid_tool_call", + id=id_, + name=name, + args=raw, + error="Failed to parse tool call arguments as JSON", + ) + invalid.update(extras) # type: ignore[typeddict-item] + return invalid + if finalized_type == "tool_call": + finalized_tc = ToolCall( + type="tool_call", + id=id_ or "", + name=name or "", + args=parsed, + ) + finalized_tc.update(extras) # type: ignore[typeddict-item] + return finalized_tc + finalized_stc = ServerToolCall( + type="server_tool_call", + id=id_ or "", + name=name or "", + args=parsed, + ) + finalized_stc.update(extras) # type: ignore[typeddict-item] + return finalized_stc + + +def _finalize_block(block: CompatBlock) -> FinalizedContentBlock: + """Promote chunk variants to their finalized form. + + `tool_call_chunk` becomes `tool_call` — or `invalid_tool_call` + if the accumulated `args` don't parse as JSON. + `server_tool_call_chunk` becomes `server_tool_call` under the same + rule. Everything else passes through: text/reasoning blocks carry + their accumulated snapshot, and self-contained types are already in + their terminal shape. + """ + btype = block.get("type") + if btype in ("tool_call_chunk", "server_tool_call_chunk"): + # Carry provider-specific fields from the accumulated chunk onto + # the finalized block. Drop the chunk-only keys we rewrite + # explicitly. `index` is stripped on client-side + # `tool_call` / `invalid_tool_call` finalizations to match v1 + # (`AIMessage.init_tool_calls` rebuilds tool_call blocks without + # `index`), preventing `merge_lists` from re-merging further + # chunks into an already-parsed args dict. `server_tool_call` + # retains `index` because v1's `init_server_tool_calls` + # finalizes in-place and preserves it. + client_tool_call = btype == "tool_call_chunk" + extras_drop = {"type", "id", "name", "args"} + if client_tool_call: + extras_drop = extras_drop | {"index"} + extras = { + k: v for k, v in block.items() if k not in extras_drop and v is not None + } + return finalize_tool_call_chunk( + raw_args=block.get("args"), + id_=block.get("id"), + name=block.get("name"), + extras=extras, + finalized_type="tool_call" if client_tool_call else "server_tool_call", + ) + return _to_finalized_block(block) + + +# --------------------------------------------------------------------------- +# Metadata, usage, finish-reason +# --------------------------------------------------------------------------- + + +def _extract_start_metadata(response_metadata: dict[str, Any]) -> MessageMetadata: + """Pull provider/model hints for the `message-start` event.""" + metadata: MessageMetadata = {} + if "model_provider" in response_metadata: + metadata["provider"] = response_metadata["model_provider"] + if "model_name" in response_metadata: + metadata["model"] = response_metadata["model_name"] + return metadata + + +def _accumulate_usage( + current: dict[str, Any] | None, delta: Any +) -> dict[str, Any] | None: + """Sum usage counts and merge detail dicts across chunks.""" + if not isinstance(delta, dict): + return current + if current is None: + return dict(delta) + for key in ("input_tokens", "output_tokens", "total_tokens", "cached_tokens"): + if key in delta: + current[key] = current.get(key, 0) + delta[key] + for detail_key in ("input_token_details", "output_token_details"): + if detail_key in delta and isinstance(delta[detail_key], dict): + if detail_key not in current: + current[detail_key] = {} + current[detail_key].update(delta[detail_key]) + return current + + +def _to_protocol_usage(usage: dict[str, Any] | None) -> UsageInfo | None: + """Convert accumulated usage to the protocol's `UsageInfo` shape.""" + if usage is None: + return None + result: UsageInfo = {} + for key in ("input_tokens", "output_tokens", "total_tokens", "cached_tokens"): + if key in usage: + result[key] = usage[key] + return result or None + + +# --------------------------------------------------------------------------- +# Event builders +# --------------------------------------------------------------------------- + + +def _build_message_start( + msg: BaseMessage, + message_id: str | None, +) -> MessageStartData: + start_data = MessageStartData(event="message-start", role="ai") + resolved_id = message_id if message_id is not None else getattr(msg, "id", None) + if resolved_id: + start_data["message_id"] = resolved_id + start_metadata = _extract_start_metadata(msg.response_metadata or {}) + if start_metadata: + start_data["metadata"] = start_metadata + return start_data + + +def _build_message_finish( + *, + usage: dict[str, Any] | None, + response_metadata: dict[str, Any] | None, +) -> MessageFinishData: + # Protocol 0.0.9 removed the top-level `reason` field from + # `MessageFinishData`; the provider's raw `finish_reason` / + # `stop_reason` now rides inside `metadata` alongside other + # response metadata. Pass it through unchanged. + finish_data = MessageFinishData(event="message-finish") + usage_info = _to_protocol_usage(usage) + if usage_info is not None: + finish_data["usage"] = usage_info + if response_metadata: + finish_data["metadata"] = dict(response_metadata) + return finish_data + + +def _finalize_and_build_finish( + wire_idx: int, + block: CompatBlock, +) -> MessagesData: + """Finalize a block and wrap it in a `content-block-finish` event.""" + return ContentBlockFinishData( + event="content-block-finish", + index=wire_idx, + content_block=_finalize_block(block), + ) + + +# --------------------------------------------------------------------------- +# Main generators +# --------------------------------------------------------------------------- + + +def chunks_to_events( + chunks: Iterator[ChatGenerationChunk], + *, + message_id: str | None = None, +) -> Iterator[MessagesData]: + """Convert a stream of `ChatGenerationChunk` to protocol events. + + Blocks stream one at a time: when a chunk carries a different block + identifier than the currently-open one, the open block is finished + before the new block starts, matching the protocol's no-interleave + rule. Source-side identifiers (from the block's `index` field, which + may be int or string) are translated to sequential `uint` wire + indices. + + Args: + chunks: Iterator of `ChatGenerationChunk` from `_stream()`. + message_id: Optional stable message ID. + + Yields: + `MessagesData` lifecycle events. + """ + started = False + open_key: Any = None + open_block: CompatBlock | None = None + open_wire_idx: int = 0 + next_wire_idx = 0 + usage: dict[str, Any] | None = None + response_metadata: dict[str, Any] = {} + + for chunk in chunks: + msg = chunk.message + if not isinstance(msg, AIMessageChunk): + continue + + # The v1 `stream()` wrapper merges `generation_info` into + # `response_metadata` before yielding (`chat_models.py` via + # `_gen_info_and_msg_metadata`). We bypass that wrapper by reading + # `_stream` directly, so reproduce the merge here with the same + # priority: `generation_info` first, then `message.response_metadata` + # overlays. This is how provider fields like `model_name`, + # `system_fingerprint`, and `finish_reason` reach the bridge when + # a provider emits them via `generation_info` instead of the + # message's `response_metadata`. + merged_rm: dict[str, Any] = { + **(chunk.generation_info or {}), + **(msg.response_metadata or {}), + } + if merged_rm: + response_metadata.update(merged_rm) + + if not started: + started = True + yield _build_message_start(msg, message_id) + + for key, block in _iter_protocol_blocks(msg): + if key != open_key: + if open_block is not None: + yield _finalize_and_build_finish(open_wire_idx, open_block) + open_key = key + open_wire_idx = next_wire_idx + next_wire_idx += 1 + open_block = dict(block) + yield ContentBlockStartData( + event="content-block-start", + index=open_wire_idx, + content_block=_start_skeleton(block), + ) + else: + open_block = _accumulate(open_block, block) + if _should_emit_delta(block): + yield ContentBlockDeltaData( + event="content-block-delta", + index=open_wire_idx, + content_block=_to_protocol_block(block), + ) + + if msg.usage_metadata: + usage = _accumulate_usage(usage, msg.usage_metadata) + + if not started: + return + + if open_block is not None: + yield _finalize_and_build_finish(open_wire_idx, open_block) + + yield _build_message_finish( + usage=usage, + response_metadata=response_metadata, + ) + + +async def achunks_to_events( + chunks: AsyncIterator[ChatGenerationChunk], + *, + message_id: str | None = None, +) -> AsyncIterator[MessagesData]: + """Async variant of `chunks_to_events`.""" + started = False + open_key: Any = None + open_block: CompatBlock | None = None + open_wire_idx: int = 0 + next_wire_idx = 0 + usage: dict[str, Any] | None = None + response_metadata: dict[str, Any] = {} + + async for chunk in chunks: + msg = chunk.message + if not isinstance(msg, AIMessageChunk): + continue + + # See sync twin for rationale: merge `generation_info` into the + # accumulated `response_metadata` with the same priority as the + # v1 `stream()` wrapper. + merged_rm: dict[str, Any] = { + **(chunk.generation_info or {}), + **(msg.response_metadata or {}), + } + if merged_rm: + response_metadata.update(merged_rm) + + if not started: + started = True + yield _build_message_start(msg, message_id) + + for key, block in _iter_protocol_blocks(msg): + if key != open_key: + if open_block is not None: + yield _finalize_and_build_finish(open_wire_idx, open_block) + open_key = key + open_wire_idx = next_wire_idx + next_wire_idx += 1 + open_block = dict(block) + yield ContentBlockStartData( + event="content-block-start", + index=open_wire_idx, + content_block=_start_skeleton(block), + ) + else: + open_block = _accumulate(open_block, block) + if _should_emit_delta(block): + yield ContentBlockDeltaData( + event="content-block-delta", + index=open_wire_idx, + content_block=_to_protocol_block(block), + ) + + if msg.usage_metadata: + usage = _accumulate_usage(usage, msg.usage_metadata) + + if not started: + return + + if open_block is not None: + yield _finalize_and_build_finish(open_wire_idx, open_block) + + yield _build_message_finish( + usage=usage, + response_metadata=response_metadata, + ) + + +def message_to_events( + msg: BaseMessage, + *, + message_id: str | None = None, +) -> Iterator[MessagesData]: + """Replay a finalized message as a synthetic event lifecycle. + + For a message returned whole (from a graph node, checkpoint, or + cache), produce the same `message-start` / per-block / + `message-finish` event stream a live call would produce. Consumers + downstream see a uniform event shape regardless of source. + + Text and reasoning blocks emit a single `content-block-delta` with + the full accumulated content. Already-finalized blocks (tool_call, + server_tool_call, image, etc.) skip the delta and rely on the + `content-block-finish` event alone. + + Args: + msg: The finalized message — typically an `AIMessage`. + message_id: Optional stable message ID; falls back to `msg.id`. + + Yields: + `MessagesData` lifecycle events. + """ + response_metadata = msg.response_metadata or {} + yield _build_message_start(msg, message_id) + + for wire_idx, (_key, block) in enumerate(_iter_protocol_blocks(msg)): + yield ContentBlockStartData( + event="content-block-start", + index=wire_idx, + content_block=_start_skeleton(block), + ) + if _should_emit_delta(block): + yield ContentBlockDeltaData( + event="content-block-delta", + index=wire_idx, + content_block=_to_protocol_block(block), + ) + yield ContentBlockFinishData( + event="content-block-finish", + index=wire_idx, + content_block=_finalize_block(block), + ) + + yield _build_message_finish( + usage=getattr(msg, "usage_metadata", None), + response_metadata=response_metadata, + ) + + +async def amessage_to_events( + msg: BaseMessage, + *, + message_id: str | None = None, +) -> AsyncIterator[MessagesData]: + """Async variant of `message_to_events`.""" + for event in message_to_events(msg, message_id=message_id): + yield event + + +__all__ = [ + "CompatBlock", + "achunks_to_events", + "amessage_to_events", + "chunks_to_events", + "finalize_tool_call_chunk", + "message_to_events", +] diff --git a/libs/core/langchain_core/language_models/chat_model_stream.py b/libs/core/langchain_core/language_models/chat_model_stream.py new file mode 100644 index 00000000000..27b200b58bd --- /dev/null +++ b/libs/core/langchain_core/language_models/chat_model_stream.py @@ -0,0 +1,1317 @@ +"""Per-message streaming objects for content-block protocol events. + +`ChatModelStream` is the synchronous variant returned by +`BaseChatModel.stream_v2()`. `AsyncChatModelStream` is the +asynchronous variant returned by `BaseChatModel.astream_v2()`. + +Both expose typed projection properties (`.text`, `.reasoning`, +`.tool_calls`, `.usage`, `.output`) that accumulate protocol +events as they arrive. Projections can be iterated for deltas or +drained for the final accumulated value. + +Raw protocol events are also available via direct iteration on the +stream object (replay-buffer semantics — multiple independent +consumers supported). +""" + +from __future__ import annotations + +import asyncio +import contextlib +from typing import TYPE_CHECKING, Any, cast + +from langchain_core.language_models._compat_bridge import finalize_tool_call_chunk +from langchain_core.messages import AIMessage + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable, Generator, Iterator + + from langchain_protocol.protocol import ( + ContentBlockDeltaData, + ContentBlockFinishData, + FinalizedContentBlock, + InvalidToolCall, + MessageFinishData, + MessageMetadata, + MessagesData, + MessageStartData, + ReasoningContentBlock, + ServerToolCallChunk, + TextContentBlock, + ToolCall, + ToolCallChunk, + UsageInfo, + ) + from typing_extensions import Self + + +# --------------------------------------------------------------------------- +# Tool-call chunk helpers (shared by tool_call_chunk and server_tool_call_chunk) +# --------------------------------------------------------------------------- + + +def _merge_chunk_into_store( + store: dict[int, dict[str, Any]], + idx: int, + block: dict[str, Any], +) -> None: + """Merge a tool-call-chunk delta: sticky id/name, concat args.""" + existing = store.get(idx, {}) + if block.get("id") and "id" not in existing: + existing["id"] = block["id"] + if block.get("name") and "name" not in existing: + existing["name"] = block["name"] + existing["args"] = existing.get("args", "") + (block.get("args") or "") + store[idx] = existing + + +def _sweep_chunk_store( + store: dict[int, dict[str, Any]], + *, + finalized_type: str, + finalized_blocks: dict[int, FinalizedContentBlock], + tool_calls_acc: list[ToolCall] | None, + invalid_acc: list[InvalidToolCall], +) -> None: + """Parse each unswept chunk's `args`; record as `finalized_type` or invalid. + + `tool_calls_acc` is only populated when `finalized_type == "tool_call"` + (server-side calls don't surface through `.tool_calls`). + + Deliberately does not backfill `index` onto finalized tool-call blocks: + matches v1 (`AIMessage.init_tool_calls` drops `index` when substituting + `tool_call_chunk` → `tool_call`) and prevents `merge_lists` from + re-merging further chunks into an already-parsed args dict. + """ + for idx in sorted(store): + chunk = store[idx] + # Carry over any non-finalize-rewritten fields the chunk collected + # (e.g., `extras`). `_merge_chunk_into_store` only populates + # `id` / `name` / `args`, so this is empty in practice today; + # future provider-specific fields would flow through here. + extras = { + k: v + for k, v in chunk.items() + if k not in ("type", "id", "name", "args") and v is not None + } + final_block = finalize_tool_call_chunk( + raw_args=chunk.get("args"), + id_=chunk.get("id"), + name=chunk.get("name"), + extras=extras, + finalized_type=finalized_type, + ) + if final_block["type"] == "invalid_tool_call": + invalid_acc.append(final_block) + elif tool_calls_acc is not None and finalized_type == "tool_call": + tool_calls_acc.append(cast("ToolCall", final_block)) + finalized_blocks[idx] = final_block + store.clear() + + +# --------------------------------------------------------------------------- +# Projection base — shared producer API +# --------------------------------------------------------------------------- + + +class _ProjectionBase: + """Shared state and producer API for sync and async projections. + + The `push` / `complete` / `fail` methods are the producer-side + API — called by the stream as events arrive. Subclasses add the + consumer protocol (sync iteration or async iteration + await). + + `done` and `error` are safe read-only views of the terminal state + for iterators and other siblings that need to observe lifecycle + without reaching into the underlying fields. + """ + + __slots__ = ("_deltas", "_done", "_error", "_final_set", "_final_value") + + def __init__(self) -> None: + """Initialize empty projection state.""" + self._deltas: list[Any] = [] + self._final_value: Any = None + self._final_set: bool = False + self._done: bool = False + self._error: BaseException | None = None + + @property + def done(self) -> bool: + """Whether the projection has finished (successfully or via error).""" + return self._done + + @property + def error(self) -> BaseException | None: + """The terminal error, if any.""" + return self._error + + def push(self, delta: Any) -> None: + """Append a delta value. Producer-side API.""" + self._deltas.append(delta) + + def complete(self, final_value: Any) -> None: + """Set the final accumulated value and mark as done. Producer-side API.""" + self._final_value = final_value + self._final_set = True + self._done = True + + def fail(self, error: BaseException) -> None: + """Mark as errored. Producer-side API.""" + self._error = error + self._done = True + + +# --------------------------------------------------------------------------- +# Sync projections +# --------------------------------------------------------------------------- + + +class SyncProjection(_ProjectionBase): + """Sync iterable of deltas with pull-based backpressure. + + Follows the same `_request_more` convention as langgraph's + `EventLog`: when the cursor catches up to the buffer and the + projection is not done, it calls `_request_more()` to pull more + events from the producer. + + Each call to `__iter__` creates a new cursor at position 0. + Multiple iterators replay all deltas from the start. + """ + + __slots__ = ("_ensure_started", "_request_more") + + def __init__(self) -> None: + """Initialize with no pull callback.""" + super().__init__() + self._ensure_started: Callable[[], None] | None = None + self._request_more: Callable[[], bool] | None = None + + def set_start(self, cb: Callable[[], None] | None) -> None: + """Install a lazy-start callback invoked on first consumption.""" + self._ensure_started = cb + + def set_request_more(self, cb: Callable[[], bool] | None) -> None: + """Install the pull callback the iterator uses to drain the source.""" + self._request_more = cb + + def __iter__(self) -> Iterator[Any]: + """Yield deltas, pulling via `_request_more` when caught up.""" + if self._ensure_started is not None: + self._ensure_started() + cursor = 0 + while True: + if cursor < len(self._deltas): + yield self._deltas[cursor] + cursor += 1 + elif self._error is not None: + raise self._error + elif self._done: + return + elif self._request_more is not None: + while cursor >= len(self._deltas) and not self._done: + if not self._request_more(): + break + if cursor >= len(self._deltas): + if self._error is not None: + raise self._error + return + else: + return + + def get(self) -> Any: + """Drain via `_request_more` and return the final value.""" + if self._ensure_started is not None: + self._ensure_started() + if not self._done and self._request_more is not None: + while not self._done: + if not self._request_more(): + break + if self._error is not None: + raise self._error + return self._final_value + + +class SyncTextProjection(SyncProjection): + """String-specialized sync projection. + + Adds `__str__`, `__bool__`, `__repr__` for ergonomic use with + `.text` and `.reasoning` projections. + """ + + __slots__ = () + + def __str__(self) -> str: + """Drain and return the full accumulated string.""" + val = self.get() + return val if val is not None else "" + + def __bool__(self) -> bool: + """Return whether any deltas have been pushed.""" + return len(self._deltas) > 0 + + def __repr__(self) -> str: + """Return repr of the accumulated text so far.""" + if self._final_set: + return repr(self._final_value) + return repr("".join(self._deltas)) + + +# --------------------------------------------------------------------------- +# Async projection +# --------------------------------------------------------------------------- + + +class AsyncProjection(_ProjectionBase): + """Async iterable of deltas that is also awaitable for the final value. + + Uses an `asyncio.Event` to notify consumers of state changes. Each + waiter — the awaitable (`__await__`) and each async iterator cursor + — shares the event and re-checks its own condition on wake. The event + is cleared before a waiter awaits, so stale "something happened" + signals don't cause spin loops. + + This is single-loop only — producers and consumers must share an + event loop. If cross-thread wake is ever required, revert to a + list-of-futures pattern with `call_soon_threadsafe`. + """ + + __slots__ = ("_arequest_more", "_ensure_started", "_event") + + def __init__(self) -> None: + """Initialize with an un-set event and no pump callback.""" + super().__init__() + self._event = asyncio.Event() + self._arequest_more: Callable[[], Awaitable[bool]] | None = None + self._ensure_started: Callable[[], Awaitable[None]] | None = None + + def set_start(self, cb: Callable[[], Awaitable[None]] | None) -> None: + """Install a lazy-start callback invoked on first consumption.""" + self._ensure_started = cb + + def set_arequest_more(self, cb: Callable[[], Awaitable[bool]] | None) -> None: + """Wire the async pull callback iterators use to drive the source. + + Mirrors `SyncProjection.set_request_more`. Under caller-driven + streaming, consumers call this callback when their buffer is + empty so that the owning graph advances one step. + + Args: + cb: Async no-arg callable returning `True` when a new event + was produced, `False` when the source is exhausted. Pass + `None` to unwire. + """ + self._arequest_more = cb + + def push(self, delta: Any) -> None: + """Append a delta and notify waiters.""" + super().push(delta) + self._event.set() + + def complete(self, final_value: Any) -> None: + """Set the final value, mark done, and notify waiters.""" + super().complete(final_value) + self._event.set() + + def fail(self, error: BaseException) -> None: + """Mark errored and notify waiters.""" + super().fail(error) + self._event.set() + + # -- Async iterable (yields deltas) ------------------------------------ + + def __aiter__(self) -> _AsyncProjectionIterator: + """Return an async iterator over deltas.""" + return _AsyncProjectionIterator(self) + + # -- Awaitable (returns final value) ----------------------------------- + + def __await__(self) -> Generator[Any, None, Any]: + """Await the final accumulated value.""" + return self._await_impl().__await__() + + async def _await_impl(self) -> Any: + """Wait until the final value is set and return it. + + When a caller-driven pump is wired via `set_arequest_more`, drive + it instead of blocking on `self._event`; otherwise fall back to + the event (used by tests that dispatch manually). + """ + if self._ensure_started is not None: + await self._ensure_started() + while not self._final_set: + if self._error is not None: + raise self._error + if self._arequest_more is not None: + if not await self._arequest_more() and not self._final_set: + # Pump exhausted without completing this projection — + # nothing more will arrive. Return current state and + # let callers observe the missing final via the + # returned None / unset error. + break + else: + self._event.clear() + await self._event.wait() + if self._error is not None: + raise self._error + return self._final_value + + +class _AsyncProjectionIterator: + """Async iterator over an `AsyncProjection`'s deltas.""" + + __slots__ = ("_offset", "_proj") + + def __init__(self, proj: AsyncProjection) -> None: + """Initialize cursor at position 0.""" + self._proj = proj + self._offset = 0 + + def __aiter__(self) -> _AsyncProjectionIterator: + """Return self for the async iteration protocol.""" + return self + + async def __anext__(self) -> Any: + """Return the next delta, awaiting if necessary. + + When the projection has an `_arequest_more` pump wired, drain it + in an inner loop (mirrors `SyncProjection.__iter__`) until this + cursor advances or the pump reports exhaustion. Without a pump, + fall back to waiting on the shared event. + """ + proj = self._proj + if proj._ensure_started is not None: # noqa: SLF001 + await proj._ensure_started() # noqa: SLF001 + while True: + # Direct access to the projection's internal list/event is + # intentional — the iterator is the projection's sidekick and + # depends on reading the shared buffer by cursor. + if self._offset < len(proj._deltas): # noqa: SLF001 + item = proj._deltas[self._offset] # noqa: SLF001 + self._offset += 1 + return item + if proj.error is not None: + raise proj.error + if proj.done: + raise StopAsyncIteration + if proj._arequest_more is not None: # noqa: SLF001 + # Caller-driven: drive the producer. Pump may land new + # deltas for a sibling projection — loop until our cursor + # advances, the projection terminates, or the pump is + # exhausted. + while ( + self._offset >= len(proj._deltas) # noqa: SLF001 + and not proj.done + ): + if not await proj._arequest_more(): # noqa: SLF001 + break + if ( + self._offset >= len(proj._deltas) # noqa: SLF001 + and not proj.done + ): + if proj.error is not None: + raise proj.error + raise StopAsyncIteration + else: + proj._event.clear() # noqa: SLF001 + await proj._event.wait() # noqa: SLF001 + + +# --------------------------------------------------------------------------- +# Sync stream +# --------------------------------------------------------------------------- + + +class _ChatModelStreamBase: + """Shared state and event dispatch for chat-model streams. + + Holds accumulated protocol state (text, reasoning, tool calls, + usage, metadata) and the event-dispatch machinery that drives the + typed projections. `ChatModelStream` (sync) and + `AsyncChatModelStream` (async) inherit from this base and add the + projection types and consumer APIs for their flavor. + """ + + # Projection instances — concrete subclasses create them as sync or + # async variants in their own __init__ after calling super(). + _text_proj: _ProjectionBase + _reasoning_proj: _ProjectionBase + _tool_calls_proj: _ProjectionBase + + def __init__( + self, + *, + namespace: list[str] | None = None, + node: str | None = None, + message_id: str | None = None, + ) -> None: + self._namespace = namespace or [] + self._node = node + self._message_id = message_id + + # Accumulated state + self._text_acc: str = "" + self._reasoning_acc: str = "" + # Per-block text / reasoning storage keyed by wire index. Used to + # populate the finalized block payload without cross-contaminating + # other blocks of the same type in the same message. Without + # per-block storage the message-wide accumulator would bleed + # earlier block text into later finalized blocks. + self._text_per_block: dict[int, str] = {} + self._reasoning_per_block: dict[int, str] = {} + self._tool_call_chunks: dict[int, dict[str, Any]] = {} + self._tool_calls_acc: list[ToolCall] = [] + self._invalid_tool_calls_acc: list[InvalidToolCall] = [] + self._server_tool_call_chunks: dict[int, dict[str, Any]] = {} + # Ordered snapshot of every finalized block, keyed by event index. + # Single source of truth for .output.content. Typed accumulators + # (text/reasoning/tool_calls/invalid_tool_calls) continue to serve + # the public projections. + self._blocks: dict[int, FinalizedContentBlock] = {} + self._usage_value: UsageInfo | None = None + self._start_metadata: MessageMetadata | None = None + self._finish_metadata: dict[str, Any] | None = None + self._done: bool = False + self._error: BaseException | None = None + self._output_message: AIMessage | None = None + + # Raw event replay buffer + self._events: list[MessagesData] = [] + + # -- Common properties ------------------------------------------------ + + @property + def namespace(self) -> list[str]: + """Graph namespace path for this message.""" + return self._namespace + + @property + def node(self) -> str | None: + """Graph node that produced this message.""" + return self._node + + @property + def message_id(self) -> str | None: + """Stable message identifier.""" + return self._message_id + + def set_message_id(self, message_id: str) -> None: + """Assign the stable message identifier once the run starts. + + Called by the stream driver (`stream_v2` / `astream_v2`) after + `on_chat_model_start` produces a run id. Not intended for + end-user code. + """ + self._message_id = message_id + + @property + def done(self) -> bool: + """Whether the stream has finished.""" + return self._done + + @property + def has_events(self) -> bool: + """Whether any protocol events have been recorded.""" + return bool(self._events) + + @property + def output_message(self) -> AIMessage | None: + """The assembled message if the stream has finished, else `None`. + + Unlike `ChatModelStream.output` (which blocks until the stream + finishes), this never pumps, blocks, or raises. Intended for the + stream driver (`stream_v2` / `astream_v2`) to check whether the + stream produced a message before firing `on_llm_end` callbacks. + """ + return self._output_message + + # -- Event ingestion (public) ------------------------------------------ + + def dispatch(self, event: MessagesData) -> None: + """Route a protocol event to the appropriate internal handler. + + Public entry point for feeding events into the stream. Called by + the stream driver (`stream_v2` / `astream_v2`'s pump) and by + any observer or test that needs to inject protocol events. + """ + self._record_event(event) + event_type = event.get("event") + if event_type == "message-start": + self._push_message_start(cast("MessageStartData", event)) + elif event_type == "content-block-delta": + self._push_content_block_delta(cast("ContentBlockDeltaData", event)) + elif event_type == "content-block-finish": + self._push_content_block_finish(cast("ContentBlockFinishData", event)) + elif event_type == "message-finish": + self._finish(cast("MessageFinishData", event)) + elif event_type == "error": + self.fail(RuntimeError(event.get("message", "Unknown error"))) + # content-block-start is informational — no accumulation needed + + # -- Internal push API (called by dispatch) ---------------------------- + + def _record_event(self, event: MessagesData) -> None: + """Append a raw event to the replay buffer.""" + self._events.append(event) + + def _push_message_start(self, data: MessageStartData) -> None: + """Process a `message-start` event.""" + self._start_metadata = data.get("metadata") + + def _push_content_block_delta(self, data: ContentBlockDeltaData) -> None: + """Process a `content-block-delta` event.""" + block = data.get("content_block") + if block is None: + return + btype = block.get("type", "") + event_idx = data.get("index") + + if btype == "text": + text_block = cast("TextContentBlock", block) + delta_text = text_block.get("text", "") + if delta_text: + self._text_acc += delta_text + if event_idx is not None: + self._text_per_block[event_idx] = ( + self._text_per_block.get(event_idx, "") + delta_text + ) + self._text_proj.push(delta_text) + elif btype == "reasoning": + reasoning_block = cast("ReasoningContentBlock", block) + delta_r = reasoning_block.get("reasoning", "") + if delta_r: + self._reasoning_acc += delta_r + if event_idx is not None: + self._reasoning_per_block[event_idx] = ( + self._reasoning_per_block.get(event_idx, "") + delta_r + ) + self._reasoning_proj.push(delta_r) + elif btype == "tool_call_chunk": + tcc = cast("ToolCallChunk", block) + # The protocol puts the block index on the event + # (`ContentBlockDeltaData`), not inside `content_block`. + # Fall back to `content_block.index` for providers that echo + # it there. + idx = data.get("index") + if idx is None: + idx = tcc.get("index", len(self._tool_call_chunks)) + _merge_chunk_into_store(self._tool_call_chunks, idx, dict(tcc)) + chunk_block: ToolCallChunk = { + "type": "tool_call_chunk", + "id": tcc.get("id"), + "name": tcc.get("name"), + "args": tcc.get("args"), + } + if "index" in tcc: + chunk_block["index"] = tcc["index"] + self._tool_calls_proj.push(chunk_block) + elif btype == "server_tool_call_chunk": + stcc = cast("ServerToolCallChunk", block) + idx = data.get("index") + if idx is None: + idx = len(self._server_tool_call_chunks) + _merge_chunk_into_store( + self._server_tool_call_chunks, + idx, + dict(stcc), + ) + + def _resolve_block_text(self, idx: int | None, full_text: str) -> str: + """Return authoritative text for a single text block at `idx`. + + Prefers per-block delta accumulation; reconciles with the finish + event's `full_text` when the provider emits authoritative text + that differs from what the deltas built up. + + Does not mutate `self._text_acc` (the delta-sum accumulator) — + the message-wide projection value is derived from per-block + storage at `_finish` time, so reconciliation remains correct + regardless of finish ordering across blocks. + """ + if idx is None: + # No wire index — legacy behavior: use the message-wide + # accumulator. Preserved for pre-index semantics; not + # exercised by the compat bridge or any in-tree provider. + if full_text and full_text != self._text_acc: + self._text_acc = full_text + return self._text_acc + existing = self._text_per_block.get(idx, "") + if full_text and full_text != existing: + if not existing: + # No deltas arrived for this block — surface the full + # text as a single delta so the stream projection + # reflects it. + self._text_acc += full_text + self._text_proj.push(full_text) + elif full_text.startswith(existing): + # Authoritative text extends the partial deltas — emit + # the tail so delta consumers see the completion. + tail = full_text[len(existing) :] + self._text_acc += tail + self._text_proj.push(tail) + # else: authoritative text replaces the partial deltas + # entirely. No corrective delta is emitted (semantics + # would be ambiguous mid-stream). `_text_acc` is not + # spliced — the final value is computed from per-block + # storage at `_finish`, so this remains correct even when + # other blocks have added to `_text_acc` in between. + self._text_per_block[idx] = full_text + return self._text_per_block.get(idx, "") + + def _resolve_block_reasoning(self, idx: int | None, full_r: str) -> str: + """Return authoritative reasoning text for a single block at `idx`. + + Mirrors `_resolve_block_text` for the reasoning projection. + """ + if idx is None: + if full_r and full_r != self._reasoning_acc: + self._reasoning_acc = full_r + return self._reasoning_acc + existing = self._reasoning_per_block.get(idx, "") + if full_r and full_r != existing: + if not existing: + self._reasoning_acc += full_r + self._reasoning_proj.push(full_r) + elif full_r.startswith(existing): + tail = full_r[len(existing) :] + self._reasoning_acc += tail + self._reasoning_proj.push(tail) + self._reasoning_per_block[idx] = full_r + return self._reasoning_per_block.get(idx, "") + + def _push_content_block_finish(self, data: ContentBlockFinishData) -> None: + """Process a `content-block-finish` event.""" + block = data.get("content_block") + if block is None: + return + btype = block.get("type", "") + idx = data.get("index") + finalized: FinalizedContentBlock | None = None + + if btype == "text": + text_block = cast("TextContentBlock", block) + full_text = text_block.get("text", "") + block_text = self._resolve_block_text(idx, full_text) + finalized = cast( + "FinalizedContentBlock", + { + **text_block, + "type": "text", + "text": block_text, + }, + ) + elif btype == "reasoning": + reasoning_block = cast("ReasoningContentBlock", block) + full_r = reasoning_block.get("reasoning", "") + block_reasoning = self._resolve_block_reasoning(idx, full_r) + # Keep provider-specific fields alongside the accumulated + # reasoning text. Anthropic's `signature` arrives under + # `extras` and is required on follow-up turns. Only overwrite + # `reasoning` when we have accumulated content; OpenAI can + # emit a reasoning block with no text deltas, and writing an + # empty string there makes downstream serializers synthesize + # an empty summary entry. + finalized_dict: dict[str, Any] = {**reasoning_block, "type": "reasoning"} + if block_reasoning: + finalized_dict["reasoning"] = block_reasoning + finalized = cast("FinalizedContentBlock", finalized_dict) + elif btype == "tool_call": + tcb = cast("ToolCall", block) + # Preserve provider-specific fields (extras, etc.) on the + # content block. `_assemble_message` separately projects the + # minimal {id, name, args, type} shape onto + # `AIMessage.tool_calls`. Strip `index` to match v1 + # (`AIMessage.init_tool_calls` rebuilds the block without + # `index`); see `_finalize_block` in `_compat_bridge.py`. + tc = cast( + "ToolCall", + { + **{k: v for k, v in tcb.items() if k != "index"}, + "type": "tool_call", + "id": tcb.get("id", ""), + "name": tcb.get("name", ""), + "args": tcb.get("args", {}), + }, + ) + self._tool_calls_acc.append(tc) + if idx is not None and idx in self._tool_call_chunks: + del self._tool_call_chunks[idx] + finalized = tc + elif btype == "invalid_tool_call": + itc = cast("InvalidToolCall", block) + # Strip `index` on the stored block to stay symmetric with + # the `tool_call` path. + itc = cast( + "InvalidToolCall", + {k: v for k, v in itc.items() if k != "index"}, + ) + self._invalid_tool_calls_acc.append(itc) + # Critical: drop the stale chunk so _finish's sweep doesn't revive + # it as an empty-args ToolCall. + if idx is not None and idx in self._tool_call_chunks: + del self._tool_call_chunks[idx] + if idx is not None and idx in self._server_tool_call_chunks: + del self._server_tool_call_chunks[idx] + finalized = itc + elif btype in ( + "server_tool_call", + "server_tool_result", + "image", + "audio", + "video", + "file", + "non_standard", + ): + if btype == "server_tool_call" and idx is not None: + self._server_tool_call_chunks.pop(idx, None) + finalized = block + + if finalized is not None and idx is not None: + # Backfill the wire index onto the finalized block when the + # source didn't supply one. `langchain_core.utils._merge`'s + # block-merger (used by `AIMessageChunk.__add__` / + # `add_ai_message_chunks`) keys on `block["index"]` to group + # deltas into the same output block — without it, a v2- + # assembled `AIMessage` that later re-enters the chunk + # aggregation path won't merge cleanly. Client-side + # `tool_call` / `invalid_tool_call` blocks are excluded: v1 + # finalization drops `index` on them so further deltas + # cannot clobber already-parsed args, and v2 mirrors that. + if btype not in ("tool_call", "invalid_tool_call"): + finalized.setdefault("index", idx) + self._blocks[idx] = finalized + + def _finish(self, data: MessageFinishData) -> None: + """Process a `message-finish` event.""" + self._done = True + self._usage_value = data.get("usage") + self._finish_metadata = data.get("metadata") + + # Finalize any unswept chunks — both client- and server-side. + _sweep_chunk_store( + self._tool_call_chunks, + finalized_type="tool_call", + finalized_blocks=self._blocks, + tool_calls_acc=self._tool_calls_acc, + invalid_acc=self._invalid_tool_calls_acc, + ) + _sweep_chunk_store( + self._server_tool_call_chunks, + finalized_type="server_tool_call", + finalized_blocks=self._blocks, + tool_calls_acc=None, + invalid_acc=self._invalid_tool_calls_acc, + ) + + # Prefer the per-block sum when any indexed text / reasoning + # arrived — it stays correct regardless of finish ordering and + # of whether finish events carried authoritative text that + # differed from the deltas. Fall back to the delta-sum + # accumulator only for the legacy no-index path. + if self._text_per_block: + text_final = "".join( + self._text_per_block[i] for i in sorted(self._text_per_block) + ) + else: + text_final = self._text_acc + if self._reasoning_per_block: + reasoning_final = "".join( + self._reasoning_per_block[i] for i in sorted(self._reasoning_per_block) + ) + else: + reasoning_final = self._reasoning_acc + + self._text_proj.complete(text_final) + self._reasoning_proj.complete(reasoning_final) + self._tool_calls_proj.complete(self._tool_calls_acc) + self._output_message = self._assemble_message() + + def fail(self, error: BaseException) -> None: + """Mark the stream as errored and propagate to all projections. + + Public API — called by the stream driver (`stream_v2` / + `astream_v2`) when the underlying producer raises, by + `dispatch` when an `error` protocol event arrives, and by + cancellation paths. + """ + self._done = True + self._error = error + self._text_proj.fail(error) + self._reasoning_proj.fail(error) + self._tool_calls_proj.fail(error) + + def _assemble_message(self) -> AIMessage: + """Build an `AIMessage` from accumulated state. + + Content is built from `self._blocks`, an index-ordered snapshot of + finalized protocol blocks. The bare-string fast path is used when + the message has exactly one `text` block (the common chat case); + otherwise content is a list of protocol-shape block dicts. + """ + content: Any + if not self._blocks: + # No protocol blocks ever arrived. Fall back to the accumulated + # text (possibly empty) as bare-string content. + content = self._text_acc + else: + # `ChatModelStream` is the v1 content-block surface: content + # is always a list of protocol blocks when any block arrived. + # Do not collapse a single text block down to a bare string — + # that would drop block-level fields (`id`, `index`, + # annotations, extras) that downstream serializers need to + # round-trip the message on a follow-up turn. + ordered_blocks = [self._blocks[idx] for idx in sorted(self._blocks)] + content = [dict(b) for b in ordered_blocks] + + response_metadata: dict[str, Any] = {} + if self._start_metadata: + if "provider" in self._start_metadata: + response_metadata["model_provider"] = self._start_metadata["provider"] + if "model" in self._start_metadata: + response_metadata["model_name"] = self._start_metadata["model"] + if self._finish_metadata: + response_metadata.update(self._finish_metadata) + # Pin `output_version` last: `stream_v2` always assembles content as v1 + # protocol blocks, regardless of the provider's configured output format. + # A provider-supplied `output_version` in finish metadata (e.g. + # `"responses/v1"` from `ChatOpenAI(use_responses_api=True, ...)`) would + # otherwise cause `AIMessage.content_blocks` to re-run the wrong + # translator on already-v1 content. + response_metadata["output_version"] = "v1" + + tool_calls = [ + { + "id": tc.get("id", ""), + "name": tc.get("name", ""), + "args": tc.get("args", {}), + "type": "tool_call", + } + for tc in self._tool_calls_acc + ] + + invalid_tool_calls = [ + { + "type": "invalid_tool_call", + "id": itc.get("id") or None, + "name": itc.get("name") or None, + "args": itc.get("args") or None, + "error": itc.get("error"), + } + for itc in self._invalid_tool_calls_acc + ] + + return AIMessage( + content=content, + id=self._message_id, + tool_calls=tool_calls, + invalid_tool_calls=invalid_tool_calls, + usage_metadata=self._usage_value, + response_metadata=response_metadata, + ) + + +# --------------------------------------------------------------------------- +# Sync stream +# --------------------------------------------------------------------------- + + +class ChatModelStream(_ChatModelStreamBase): + """Synchronous per-message streaming object for a single LLM response. + + Returned by `BaseChatModel.stream_v2()`. Content-block protocol + events are fed into this object and accumulated into typed projections. + + Projections (always return the same cached object): + + - `.text` — iterable of `str` deltas; `str()` for full text + - `.reasoning` — same as `.text` for reasoning content + - `.tool_calls` — iterable of `ToolCallChunk` deltas; + `.get()` returns `list[ToolCall]` + - `.output` — blocking property, returns assembled `AIMessage` + + Usage info is available on `.output.usage_metadata` once the stream + has finished. + + !!! note "Output shape is always v1 content blocks" + + `.output.content` is always a list of v1 protocol blocks + (text, reasoning, tool_call, image, …), regardless of the + underlying model's `output_version` setting. That attribute + only controls the legacy `stream()` / `astream()` / `invoke()` + paths; `ChatModelStream` is built on the content-block + protocol and emits v1 shapes by construction. + + Raw event iteration:: + + for event in stream: + print(event) # MessagesData dicts + """ + + _text_proj: SyncTextProjection + _reasoning_proj: SyncTextProjection + _tool_calls_proj: SyncProjection + + def __init__( # noqa: D107 + self, + *, + namespace: list[str] | None = None, + node: str | None = None, + message_id: str | None = None, + ) -> None: + super().__init__(namespace=namespace, node=node, message_id=message_id) + # Projections — created eagerly + self._text_proj = SyncTextProjection() + self._reasoning_proj = SyncTextProjection() + self._tool_calls_proj = SyncProjection() + # Pull callback (set by bind_pump or set_request_more) + self._ensure_started: Callable[[], None] | None = None + self._request_more: Callable[[], bool] | None = None + + # -- Pump/pull wiring -------------------------------------------------- + + def bind_pump(self, pump_one: Callable[[], bool]) -> None: + """Bind a pump for standalone streaming. + + Delegates to `set_request_more`. Used by + `BaseChatModel.stream_v2()`. + """ + self.set_request_more(pump_one) + + def set_start(self, cb: Callable[[], None] | None) -> None: + """Install a lazy-start callback on this stream and its projections.""" + self._ensure_started = cb + self._text_proj.set_start(cb) + self._reasoning_proj.set_start(cb) + self._tool_calls_proj.set_start(cb) + + def set_request_more(self, cb: Callable[[], bool]) -> None: + """Set the pull callback on this stream and all its projections. + + Used by langgraph's `GraphRunStream._wire_request_more` to + connect the shared graph pump. + """ + self._request_more = cb + self._text_proj.set_request_more(cb) + self._reasoning_proj.set_request_more(cb) + self._tool_calls_proj.set_request_more(cb) + + # -- Public projections ------------------------------------------------ + + @property + def text(self) -> SyncTextProjection: + """Text content — iterable of `str` deltas, `str()` for full.""" + return self._text_proj + + @property + def reasoning(self) -> SyncTextProjection: + """Reasoning content — same interface as :attr:`text`.""" + return self._reasoning_proj + + @property + def tool_calls(self) -> SyncProjection: + """Tool calls — iterable of `ToolCallChunk` deltas. + + `.get()` returns finalized `list[ToolCall]`. + """ + return self._tool_calls_proj + + @property + def output(self) -> AIMessage: + """Assembled `AIMessage` — blocks until the stream finishes.""" + self._drain() + if self._error is not None: + raise self._error + if self._output_message is None: + msg = "Stream finished without producing a message" + raise RuntimeError(msg) + return self._output_message + + # -- Raw event iteration (replay buffer) ------------------------------- + + def __iter__(self) -> Iterator[MessagesData]: + """Iterate raw protocol events with replay-buffer semantics.""" + if self._ensure_started is not None: + self._ensure_started() + cursor = 0 + while True: + if cursor < len(self._events): + yield self._events[cursor] + cursor += 1 + elif self._error is not None: + raise self._error + elif self._done: + return + elif self._request_more is not None: + while cursor >= len(self._events) and not self._done: + if not self._request_more(): + break + if cursor >= len(self._events): + if self._error is not None: + raise self._error + return + else: + return + + # -- Internal helpers -------------------------------------------------- + + def _drain(self) -> None: + """Pull all remaining events until done.""" + if self._done: + return + if self._ensure_started is not None: + self._ensure_started() + if self._request_more is not None: + while not self._done: + if not self._request_more(): + break + + +# --------------------------------------------------------------------------- +# Async stream +# --------------------------------------------------------------------------- + + +class AsyncChatModelStream(_ChatModelStreamBase): + """Asynchronous per-message streaming object for a single LLM response. + + Returned by `BaseChatModel.astream_v2()`. Content-block events + are fed into this object by a background producer task. + + Projections: + + - `.text` — async iterable of text deltas; awaitable for full text + - `.reasoning` — async iterable of reasoning deltas; awaitable + - `.tool_calls` — async iterable of `ToolCallChunk` deltas; + awaitable for `list[ToolCall]` + - `.output` — awaitable for assembled `AIMessage` + + Usage info is available on `.output.usage_metadata` once the stream + has finished. + + !!! note "Output shape is always v1 content blocks" + + The assembled message's content is always a list of v1 + protocol blocks, regardless of the model's `output_version` + setting — see `ChatModelStream` for the full rationale. + + The stream itself is awaitable (`msg = await stream`) and + async-iterable (`async for event in stream`). + """ + + _text_proj: AsyncProjection + _reasoning_proj: AsyncProjection + _tool_calls_proj: AsyncProjection + + def __init__( # noqa: D107 + self, + *, + namespace: list[str] | None = None, + node: str | None = None, + message_id: str | None = None, + ) -> None: + super().__init__(namespace=namespace, node=node, message_id=message_id) + self._text_proj = AsyncProjection() + self._reasoning_proj = AsyncProjection() + self._tool_calls_proj = AsyncProjection() + self._output_proj = AsyncProjection() + self._events_proj = AsyncProjection() + self._ensure_started: Callable[[], Awaitable[None]] | None = None + self._producer_task: asyncio.Task[None] | None = None + # Teardown callback invoked by `aclose()` only when the producer + # task was cancelled before its body ran (so the normal + # `_produce` CancelledError handler — which fires + # `on_llm_error` — never executed). Set by `astream_v2`. + self._on_aclose_fail: Callable[[BaseException], Awaitable[None]] | None = None + + # -- Pump/pull wiring (async) ------------------------------------------ + + def set_arequest_more(self, cb: Callable[[], Awaitable[bool]] | None) -> None: + """Fan the async pump callback out to every projection. + + Used by langgraph's `AsyncGraphRunStream._wire_arequest_more` so + cursors on `stream.text`, `stream.reasoning`, etc. can drive the + shared graph pump when their buffer is empty. + + Args: + cb: Async no-arg callable returning `True` when a new event + was produced, `False` when the source is exhausted. Pass + `None` to unwire. + """ + for proj in ( + self._text_proj, + self._reasoning_proj, + self._tool_calls_proj, + self._output_proj, + self._events_proj, + ): + proj.set_arequest_more(cb) + + def set_start(self, cb: Callable[[], Awaitable[None]] | None) -> None: + """Install a lazy-start callback on this stream and its projections.""" + self._ensure_started = cb + for proj in ( + self._text_proj, + self._reasoning_proj, + self._tool_calls_proj, + self._output_proj, + self._events_proj, + ): + proj.set_start(cb) + + # -- Public projections ------------------------------------------------ + + @property + def text(self) -> AsyncProjection: + """Text content — async iterable of deltas, awaitable for full.""" + return self._text_proj + + @property + def reasoning(self) -> AsyncProjection: + """Reasoning content — same interface as :attr:`text`.""" + return self._reasoning_proj + + @property + def tool_calls(self) -> AsyncProjection: + """Tool calls — async iterable, awaitable for finalized list.""" + return self._tool_calls_proj + + @property + def output(self) -> AsyncProjection: + """Assembled `AIMessage` — awaitable.""" + return self._output_proj + + def __await__(self) -> Generator[Any, None, AIMessage]: + """Await the assembled `AIMessage` and full producer lifecycle. + + The producer task is awaited after the output projection resolves so + that post-stream work (notably `on_llm_end` callbacks) has run by + the time the caller's `await` returns. + """ + return self._await_full().__await__() + + async def _await_full(self) -> AIMessage: + if self._ensure_started is not None: + await self._ensure_started() + message: AIMessage = await self._output_proj + if self._producer_task is not None: + await self._producer_task + return message + + def __aiter__(self) -> _AsyncProjectionIterator: + """Iterate raw protocol events asynchronously.""" + return _AsyncProjectionIterator(self._events_proj) + + # -- Cleanup ----------------------------------------------------------- + + async def aclose(self) -> None: + """Cancel the background producer task and release resources. + + If a consumer cancels mid-stream or decides to stop iterating + early, the producer task keeps pumping the provider HTTP call to + completion because `asyncio.Task` has no implicit link to its + awaiter. Call this method to cancel the producer explicitly; the + stream transitions to an errored state with `CancelledError`. + + If the stream has already produced a message successfully (for + example, after `await stream.output`), the producer may still be + running post-stream work such as `on_llm_end` callbacks. In that + case `aclose()` awaits the task rather than cancelling it — + turning a successful run into a cancelled one would drop the + end callback and corrupt tracing. + + Idempotent: safe to call multiple times, including after the + stream has finished normally. Also invoked by the async context + manager protocol on `__aexit__`. + """ + if self._ensure_started is not None and self._producer_task is None: + await self._ensure_started() + + task = self._producer_task + if task is None: + return + if task.done() and self._done: + return + + we_cancelled = not (self._output_message is not None and self._error is None) + if we_cancelled and not task.done(): + task.cancel() + + # Wait for the task via a linked `Future`, not by awaiting the + # task directly. Awaiting the task would raise `CancelledError` + # in two indistinguishable cases: (1) the task we just cancelled + # completed, (2) our caller cancelled us. `asyncio.Task.cancelling()` + # disambiguates on 3.11+ but doesn't exist on 3.10. + # + # The `done_future` resolves with `None` whenever the task + # finishes (any reason). It is not a `Task` itself, so its + # `await` only raises when our caller is cancelled — giving us + # a portable, unambiguous signal to propagate. + if not task.done(): + loop = asyncio.get_running_loop() + done_future: asyncio.Future[None] = loop.create_future() + + def _link(_: asyncio.Task[None]) -> None: + if not done_future.done(): + done_future.set_result(None) + + task.add_done_callback(_link) + try: + await done_future + finally: + task.remove_done_callback(_link) + + # If the task was cancelled before `_produce` ran (e.g. + # `astream_v2()` immediately followed by `aclose()`), the stream + # never reached `_produce`'s CancelledError handler — its + # projections are still pending and no end-of-lifecycle callback + # has fired. Resolve both here so callers of `await stream.output` + # don't hang and tracing sees a matching end event. + if we_cancelled and not self._done: + cancel_exc = asyncio.CancelledError() + self.fail(cancel_exc) + teardown = self._on_aclose_fail + if teardown is not None: + with contextlib.suppress(Exception): + await teardown(cancel_exc) + + async def __aenter__(self) -> Self: + """Enter the async context — returns self.""" + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: object, + ) -> None: + """Exit the async context — cancels the producer via `aclose()`.""" + del exc_type, exc, tb + await self.aclose() + + # -- Internal API (extend base to drive async projections) ------------- + + def _record_event(self, event: MessagesData) -> None: + """Record event and push to async event replay projection.""" + super()._record_event(event) + self._events_proj.push(event) + + def _finish(self, data: MessageFinishData) -> None: + """Finish base projections and async-only projections.""" + super()._finish(data) + self._output_proj.complete(self._output_message) + self._events_proj.complete(self._events) + + def fail(self, error: BaseException) -> None: + """Fail base projections and async-only projections.""" + super().fail(error) + self._output_proj.fail(error) + self._events_proj.fail(error) + + +__all__ = [ + "AsyncChatModelStream", + "AsyncProjection", + "ChatModelStream", + "SyncProjection", + "SyncTextProjection", +] diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 5cbb92547b5..c38a9d3470f 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -12,6 +12,7 @@ from functools import cached_property from operator import itemgetter from typing import TYPE_CHECKING, Any, Literal, cast +from langchain_protocol.protocol import MessageFinishData from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Self, override @@ -24,6 +25,12 @@ from langchain_core.callbacks import ( Callbacks, ) from langchain_core.globals import get_llm_cache +from langchain_core.language_models._compat_bridge import ( + achunks_to_events, + amessage_to_events, + chunks_to_events, + message_to_events, +) from langchain_core.language_models._utils import ( _filter_invocation_params_for_tracing, _normalize_messages, @@ -34,6 +41,10 @@ from langchain_core.language_models.base import ( LangSmithParams, LanguageModelInput, ) +from langchain_core.language_models.chat_model_stream import ( + AsyncChatModelStream, + ChatModelStream, +) from langchain_core.language_models.model_profile import ( ModelProfile, _warn_unknown_profile_keys, @@ -69,7 +80,10 @@ from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPro from langchain_core.rate_limiters import BaseRateLimiter from langchain_core.runnables import RunnableMap, RunnablePassthrough from langchain_core.runnables.config import ensure_config, run_in_executor -from langchain_core.tracers._streaming import _StreamingCallbackHandler +from langchain_core.tracers._streaming import ( + _StreamingCallbackHandler, + _V2StreamingCallbackHandler, +) from langchain_core.utils.function_calling import ( convert_to_json_schema, convert_to_openai_tool, @@ -81,6 +95,8 @@ if TYPE_CHECKING: import builtins import uuid + from langchain_protocol.protocol import MessagesData + from langchain_core.output_parsers.base import OutputParserLike from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.tools import BaseTool @@ -489,6 +505,30 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): "AIMessage", cast("ChatGeneration", llm_result.generations[0][0]).message ) + def _streaming_disabled(self, **kwargs: Any) -> bool: + """Return whether streaming is hard-disabled for this call. + + Shared opt-outs honored by both `_should_stream` and + `_should_stream_v2` — these override any affirmative trigger + (attached handler, `stream=True`, etc.): + + - `self.disable_streaming is True` + - `self.disable_streaming == "tool_calling"` with `tools` passed + - `stream=` in call kwargs + - `self.streaming is False` on the instance + """ + if self.disable_streaming is True: + return True + # We assume tools are passed in via "tools" kwarg in all models. + if self.disable_streaming == "tool_calling" and kwargs.get("tools"): + return True + if "stream" in kwargs and not kwargs["stream"]: + return True + return ( + "streaming" in self.model_fields_set + and getattr(self, "streaming", None) is False + ) + def _should_stream( self, *, @@ -509,26 +549,163 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): if async_api and async_not_implemented and sync_not_implemented: return False - # Check if streaming has been disabled on this instance. - if self.disable_streaming is True: - return False - # We assume tools are passed in via "tools" kwarg in all models. - if self.disable_streaming == "tool_calling" and kwargs.get("tools"): + if self._streaming_disabled(**kwargs): return False - # Check if a runtime streaming flag has been passed in. - if "stream" in kwargs: - return bool(kwargs["stream"]) + # Affirmative: explicit `stream=` kwarg. + if kwargs.get("stream"): + return True - if "streaming" in self.model_fields_set: - streaming_value = getattr(self, "streaming", None) - if isinstance(streaming_value, bool): - return streaming_value + # Affirmative: instance-level `streaming=True` attribute. + if ( + "streaming" in self.model_fields_set + and getattr(self, "streaming", None) is True + ): + return True - # Check if any streaming callback handlers have been passed in. + # Affirmative: a v1 streaming callback handler is attached. handlers = run_manager.handlers if run_manager else [] return any(isinstance(h, _StreamingCallbackHandler) for h in handlers) + def _should_stream_v2( + self, + *, + async_api: bool, + run_manager: CallbackManagerForLLMRun + | AsyncCallbackManagerForLLMRun + | None = None, + **kwargs: Any, + ) -> bool: + """Determine whether an invoke should route through the v2 event path. + + Runs alongside `_should_stream` inside `_generate_with_cache` / + `_agenerate_with_cache` — after the run manager is open — and + wins over the v1 streaming branch when a handler has declared + itself a `_V2StreamingCallbackHandler`. Parallel to + `_should_stream` rather than a delegation — v1 and v2 have + disjoint affirmative triggers. + + Args: + async_api: Whether the caller is on the async path. + run_manager: The active LLM run manager. + **kwargs: Call kwargs; inspected for `disable_streaming` + semantics and an explicit `stream=False` override. + + Returns: + `True` if any attached handler inherits + `_V2StreamingCallbackHandler` and the model can drive the v2 + event generator (natively or via the `_stream` compat + bridge). + """ + # Opt-in: only route through v2 when a v2 handler is attached. + handlers = run_manager.handlers if run_manager else [] + if not any(isinstance(h, _V2StreamingCallbackHandler) for h in handlers): + return False + + # Need a source of v2 events on the requested flavor. A native + # `_(a)stream_chat_model_events` hook bypasses the bridge; + # otherwise the bridge wraps `_stream` / `_astream`. Async can + # fall back to sync. + # + # `cls._stream is not BaseChatModel._stream` is an identity + # check for "subclass overrode `_stream`" — same pattern as + # `_should_stream`. + cls = type(self) + has_native_sync = getattr(cls, "_stream_chat_model_events", None) is not None + has_native_async = getattr(cls, "_astream_chat_model_events", None) is not None + overrides_sync = cls._stream is not BaseChatModel._stream + overrides_async = cls._astream is not BaseChatModel._astream + has_sync_source = has_native_sync or overrides_sync + has_async_source = has_native_async or overrides_async + has_source = ( + (has_sync_source or has_async_source) if async_api else has_sync_source + ) + if not has_source: + return False + + return not self._streaming_disabled(**kwargs) + + def _iter_v2_events( + self, + messages: list[BaseMessage], + *, + run_manager: CallbackManagerForLLMRun, + stream: ChatModelStream, + stop: list[str] | None = None, + **kwargs: Any, + ) -> Iterator[MessagesData]: + """Drive the v2 event generator with per-event dispatch. + + Shared between `stream_v2`'s pump and the invoke-time v2 branch + in `_generate_with_cache`. Picks the native + `_stream_chat_model_events` hook when the subclass provides one, + else bridges `_stream` chunks via `chunks_to_events`. Each event + is dispatched into `stream` and fired as `on_stream_event` on + the run manager. Run-lifecycle callbacks + (`on_chat_model_start` / `on_llm_end` / `on_llm_error`) and + rate-limiter acquisition are the caller's responsibility. + + Args: + messages: Normalized input messages. + run_manager: Active LLM run manager; receives + `on_stream_event` per event. + stream: Accumulator owned by the caller; receives each + event via `stream.dispatch`. + stop: Optional stop sequences. + **kwargs: Forwarded to the event producer. + + Yields: + Each protocol event produced by the model. + """ + native = cast( + "Callable[..., Iterator[MessagesData]] | None", + getattr(self, "_stream_chat_model_events", None), + ) + if native is not None: + event_iter: Iterator[MessagesData] = native( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + else: + event_iter = chunks_to_events( + self._stream(messages, stop=stop, run_manager=run_manager, **kwargs), + message_id=stream.message_id, + ) + for event in event_iter: + stream.dispatch(event) + run_manager.on_stream_event(event) + yield event + + async def _aiter_v2_events( + self, + messages: list[BaseMessage], + *, + run_manager: AsyncCallbackManagerForLLMRun, + stream: AsyncChatModelStream, + stop: list[str] | None = None, + **kwargs: Any, + ) -> AsyncIterator[MessagesData]: + """Async counterpart to `_iter_v2_events`. + + See `_iter_v2_events` for the shared contract. + """ + native = cast( + "Callable[..., AsyncIterator[MessagesData]] | None", + getattr(self, "_astream_chat_model_events", None), + ) + if native is not None: + event_iter: AsyncIterator[MessagesData] = native( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + else: + event_iter = achunks_to_events( + self._astream(messages, stop=stop, run_manager=run_manager, **kwargs), + message_id=stream.message_id, + ) + async for event in event_iter: + stream.dispatch(event) + await run_manager.on_stream_event(event) + yield event + @override def stream( self, @@ -791,6 +968,334 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): LLMResult(generations=[[generation]]), ) + # --- stream_v2 / astream_v2 --- + + def stream_v2( + self, + input: LanguageModelInput, + config: RunnableConfig | None = None, + *, + stop: list[str] | None = None, + **kwargs: Any, + ) -> ChatModelStream: + """Stream content-block lifecycle events for a single model call. + + Returns a `ChatModelStream` with typed projections + (`.text`, `.reasoning`, `.tool_calls`, `.output`). + + !!! warning + + This API is experimental and may change. + + !!! note "Always produces v1-shaped content" + + `ChatModelStream.output.content` is always a list of v1 + content blocks (text / reasoning / tool_call / image / …), + regardless of the model's `output_version` attribute. The + setting only affects the legacy `stream()` / `astream()` / + `invoke()` paths. If you're mixing `stream_v2` with those + paths in the same pipeline and need a consistent output + shape across them, set `output_version="v1"` on the model. + + Args: + input: The model input. + config: Optional runnable config. + stop: Optional list of stop words. + **kwargs: Additional keyword arguments passed to the model. + + Returns: + A `ChatModelStream` with typed projections. + """ + config = ensure_config(config) + messages = self._convert_input(input).to_messages() + input_messages = _normalize_messages(messages) + + # Strip tracing-only kwargs before forwarding to `_stream` — matches + # `stream()` / `astream()`. Provider clients reject unknown kwargs, so + # `.with_structured_output().stream_v2(...)` and any other binding that + # carries `ls_structured_output_format` / `structured_output_format` + # would raise without this pop. + ls_structured_output_format = kwargs.pop( + "ls_structured_output_format", None + ) or kwargs.pop("structured_output_format", None) + ls_structured_output_format_dict = _format_ls_structured_output( + ls_structured_output_format + ) + + params = self._get_invocation_params(stop=stop, **kwargs) + options = {"stop": stop, **kwargs, **ls_structured_output_format_dict} + inheritable_metadata = { + **(config.get("metadata") or {}), + **self._get_ls_params_with_defaults(stop=stop, **kwargs), + } + callback_manager = CallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + self.tags, + inheritable_metadata, + self.metadata, + langsmith_inheritable_metadata=_filter_invocation_params_for_tracing( + params + ), + ) + stream = ChatModelStream() + run_manager: CallbackManagerForLLMRun | None = None + event_iter_ref: Iterator[MessagesData] | None = None + rate_limiter_acquired = self.rate_limiter is None + run_name = config.get("run_name") + run_id = config.pop("run_id", None) + + def ensure_started() -> None: + nonlocal event_iter_ref, run_manager + if event_iter_ref is not None: + return + + (run_manager,) = callback_manager.on_chat_model_start( + self._serialized, + [_format_for_tracing(messages)], + invocation_params=params, + options=options, + name=run_name, + run_id=run_id, + batch_size=1, + ) + stream.set_message_id("-".join((LC_ID_PREFIX, str(run_manager.run_id)))) + event_iter_ref = iter( + self._iter_v2_events( + input_messages, + run_manager=run_manager, + stream=stream, + stop=stop, + **kwargs, + ) + ) + + def pump_one() -> bool: + nonlocal rate_limiter_acquired + ensure_started() + if not rate_limiter_acquired: + assert self.rate_limiter is not None # noqa: S101 + self.rate_limiter.acquire(blocking=True) + rate_limiter_acquired = True + assert event_iter_ref is not None # noqa: S101 + assert run_manager is not None # noqa: S101 + try: + next(event_iter_ref) + except StopIteration: + if not stream.done: + if stream.has_events: + # Native event producers may omit the terminal + # `message-finish`. Close the lifecycle here so + # `on_llm_end` still observes the assembled + # message. A truly empty stream remains an error + # for parity with `stream()`. + stream.dispatch(MessageFinishData(event="message-finish")) + else: + err = ValueError("No generation chunks were returned") + stream.fail(err) + run_manager.on_llm_error( + err, + response=LLMResult(generations=[]), + ) + return False + if stream.done and stream.output_message is not None: + run_manager.on_llm_end( + LLMResult( + generations=[ + [ChatGeneration(message=stream.output_message)], + ], + ), + ) + return False + except BaseException as exc: + stream.fail(exc) + run_manager.on_llm_error( + exc, + response=LLMResult(generations=[]), + ) + return False + if stream.done and stream.output_message is not None: + run_manager.on_llm_end( + LLMResult( + generations=[ + [ChatGeneration(message=stream.output_message)], + ], + ), + ) + return True + + stream.set_start(ensure_started) + stream.bind_pump(pump_one) + return stream + + async def astream_v2( + self, + input: LanguageModelInput, + config: RunnableConfig | None = None, + *, + stop: list[str] | None = None, + **kwargs: Any, + ) -> AsyncChatModelStream: + """Async variant of `stream_v2`. + + Returns an `AsyncChatModelStream` whose projections are + async-iterable and awaitable. + + !!! warning + + This API is experimental and may change. + + !!! note "Always produces v1-shaped content" + + The assembled message's content is always a list of v1 + content blocks, regardless of the model's `output_version` + attribute — see `stream_v2` for the full rationale. + + Args: + input: The model input. + config: Optional runnable config. + stop: Optional list of stop words. + **kwargs: Additional keyword arguments passed to the model. + + Returns: + An `AsyncChatModelStream` with typed projections. + """ + config = ensure_config(config) + messages = self._convert_input(input).to_messages() + input_messages = _normalize_messages(messages) + + # Strip tracing-only kwargs before forwarding — see `stream_v2` for the + # full rationale. + ls_structured_output_format = kwargs.pop( + "ls_structured_output_format", None + ) or kwargs.pop("structured_output_format", None) + ls_structured_output_format_dict = _format_ls_structured_output( + ls_structured_output_format + ) + + params = self._get_invocation_params(stop=stop, **kwargs) + options = {"stop": stop, **kwargs, **ls_structured_output_format_dict} + inheritable_metadata = { + **(config.get("metadata") or {}), + **self._get_ls_params_with_defaults(stop=stop, **kwargs), + } + callback_manager = AsyncCallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + self.tags, + inheritable_metadata, + self.metadata, + langsmith_inheritable_metadata=_filter_invocation_params_for_tracing( + params + ), + ) + stream = AsyncChatModelStream() + run_manager: AsyncCallbackManagerForLLMRun | None = None + run_name = config.get("run_name") + run_id = config.pop("run_id", None) + start_lock = asyncio.Lock() + + async def _produce() -> None: + assert run_manager is not None # noqa: S101 + try: + if self.rate_limiter: + await self.rate_limiter.aacquire(blocking=True) + + async for _event in self._aiter_v2_events( + input_messages, + run_manager=run_manager, + stream=stream, + stop=stop, + **kwargs, + ): + pass + if not stream.done: + if stream.has_events: + # Native event producers may omit the terminal + # `message-finish`. Close the lifecycle here so + # `on_llm_end` sees the finalized message. A + # truly empty stream remains an error for parity + # with `astream()`. + stream.dispatch(MessageFinishData(event="message-finish")) + else: + err = ValueError("No generation chunks were returned") + stream.fail(err) + await run_manager.on_llm_error( + err, + response=LLMResult(generations=[]), + ) + return + if stream.done and stream.output_message is not None: + await run_manager.on_llm_end( + LLMResult( + generations=[ + [ChatGeneration(message=stream.output_message)], + ], + ), + ) + except asyncio.CancelledError as exc: + stream.fail(exc) + # Close the callback lifecycle so tracing observes a + # matching end event for the earlier `on_chat_model_start`. + # `on_llm_error` is `@shielded`, so the callback runs to + # completion in the background even though the `await` + # here re-raises our cancellation. + with contextlib.suppress(Exception): + await run_manager.on_llm_error( + exc, + response=LLMResult(generations=[]), + ) + raise + except BaseException as exc: + stream.fail(exc) + await run_manager.on_llm_error( + exc, + response=LLMResult(generations=[]), + ) + + async def ensure_started() -> None: + nonlocal run_manager + if stream._producer_task is not None: # noqa: SLF001 + return + + async with start_lock: + if stream._producer_task is not None: # noqa: SLF001 + return + + (run_manager,) = await callback_manager.on_chat_model_start( + self._serialized, + [_format_for_tracing(messages)], + invocation_params=params, + options=options, + name=run_name, + run_id=run_id, + batch_size=1, + ) + stream.set_message_id("-".join((LC_ID_PREFIX, str(run_manager.run_id)))) + stream._producer_task = asyncio.get_running_loop().create_task( # noqa: SLF001 + _produce() + ) + + async def _on_aclose_fail(exc: BaseException) -> None: + assert run_manager is not None # noqa: S101 + # Invoked by `stream.aclose()` only when the producer was + # cancelled before `_produce` ran — so `on_llm_error` from + # the CancelledError handler never fired. Shielded by the + # callback manager; runs to completion even if our caller + # is being cancelled. + await run_manager.on_llm_error( + exc, + response=LLMResult(generations=[]), + ) + + stream.set_start(ensure_started) + stream._on_aclose_fail = _on_aclose_fail # noqa: SLF001 + return stream + # --- Custom methods --- def _combine_llm_outputs(self, _llm_outputs: list[dict | None], /) -> dict: @@ -835,6 +1340,52 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): converted_generations.append(gen) return converted_generations + def _replay_v2_events_for_cache_hit( + self, + generations: list[ChatGeneration], + *, + run_manager: CallbackManagerForLLMRun | None, + **kwargs: Any, + ) -> None: + """Replay cached messages as v2 events when a v2 handler is attached. + + A warm cache must produce the same `on_stream_event` stream as a cold + call so LangGraph-style consumers do not observe behavior that depends + on cache state. Gated by `_should_stream_v2` so a `disable_streaming` + config that suppresses v2 on cold calls also suppresses it here. + """ + if run_manager is None or not self._should_stream_v2( + async_api=False, run_manager=run_manager, **kwargs + ): + return + message_id = f"{LC_ID_PREFIX}-{run_manager.run_id}" + for gen in generations: + msg = getattr(gen, "message", None) + if not isinstance(msg, AIMessage): + continue + for event in message_to_events(msg, message_id=message_id): + run_manager.on_stream_event(event) + + async def _areplay_v2_events_for_cache_hit( + self, + generations: list[ChatGeneration], + *, + run_manager: AsyncCallbackManagerForLLMRun | None, + **kwargs: Any, + ) -> None: + """Async counterpart to `_replay_v2_events_for_cache_hit`.""" + if run_manager is None or not self._should_stream_v2( + async_api=True, run_manager=run_manager, **kwargs + ): + return + message_id = f"{LC_ID_PREFIX}-{run_manager.run_id}" + for gen in generations: + msg = getattr(gen, "message", None) + if not isinstance(msg, AIMessage): + continue + async for event in amessage_to_events(msg, message_id=message_id): + await run_manager.on_stream_event(event) + def _get_invocation_params( self, stop: list[str] | None = None, @@ -1237,6 +1788,11 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): cache_val = llm_cache.lookup(prompt, llm_string) if isinstance(cache_val, list): converted_generations = self._convert_cached_generations(cache_val) + self._replay_v2_events_for_cache_hit( + converted_generations, + run_manager=run_manager, + **kwargs, + ) return ChatResult(generations=converted_generations) elif self.cache is None: pass @@ -1250,9 +1806,39 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): if self.rate_limiter: self.rate_limiter.acquire(blocking=True) + # v2 streaming: preferred over v1 when any attached handler opts in via + # `_V2StreamingCallbackHandler`. Drives the protocol event generator + # (native or `_stream` compat bridge) through the shared helper so + # `on_stream_event` fires per event, then returns a normal `ChatResult` + # so caching / `on_llm_end` stay on the existing generate path. + if self._should_stream_v2( + async_api=False, + run_manager=run_manager, + **kwargs, + ): + stream_accum = ChatModelStream( + message_id=( + f"{LC_ID_PREFIX}-{run_manager.run_id}" if run_manager else None + ) + ) + assert run_manager is not None # noqa: S101 + for _event in self._iter_v2_events( + messages, + run_manager=run_manager, + stream=stream_accum, + stop=stop, + **kwargs, + ): + pass + if stream_accum.output_message is None: + msg = "v2 stream finished without producing a message" + raise RuntimeError(msg) + result = ChatResult( + generations=[ChatGeneration(message=stream_accum.output_message)] + ) # If stream is not explicitly set, check if implicitly requested by # astream_events() or astream_log(). Bail out if _stream not implemented - if self._should_stream( + elif self._should_stream( async_api=False, run_manager=run_manager, **kwargs, @@ -1363,6 +1949,11 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): cache_val = await llm_cache.alookup(prompt, llm_string) if isinstance(cache_val, list): converted_generations = self._convert_cached_generations(cache_val) + await self._areplay_v2_events_for_cache_hit( + converted_generations, + run_manager=run_manager, + **kwargs, + ) return ChatResult(generations=converted_generations) elif self.cache is None: pass @@ -1376,9 +1967,35 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): if self.rate_limiter: await self.rate_limiter.aacquire(blocking=True) + # v2 streaming: see sync counterpart in `_generate_with_cache`. + if self._should_stream_v2( + async_api=True, + run_manager=run_manager, + **kwargs, + ): + stream_accum = AsyncChatModelStream( + message_id=( + f"{LC_ID_PREFIX}-{run_manager.run_id}" if run_manager else None + ) + ) + assert run_manager is not None # noqa: S101 + async for _event in self._aiter_v2_events( + messages, + run_manager=run_manager, + stream=stream_accum, + stop=stop, + **kwargs, + ): + pass + if stream_accum.output_message is None: + msg = "v2 stream finished without producing a message" + raise RuntimeError(msg) + result = ChatResult( + generations=[ChatGeneration(message=stream_accum.output_message)] + ) # If stream is not explicitly set, check if implicitly requested by # astream_events() or astream_log(). Bail out if _astream not implemented - if self._should_stream( + elif self._should_stream( async_api=True, run_manager=run_manager, **kwargs, diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 29a7d8ed731..c63b5c6fce0 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -103,6 +103,10 @@ if TYPE_CHECKING: AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, ) + from langchain_core.language_models.chat_model_stream import ( + AsyncChatModelStream, + ChatModelStream, + ) from langchain_core.prompts.base import BasePromptTemplate from langchain_core.runnables.fallbacks import ( RunnableWithFallbacks as RunnableWithFallbacksT, @@ -1169,6 +1173,46 @@ class Runnable(ABC, Generic[Input, Output]): """ yield await self.ainvoke(input, config, **kwargs) + def stream_v2( + self, + input: Input, + config: RunnableConfig | None = None, + **kwargs: Any | None, + ) -> ChatModelStream: + """Stream content-block lifecycle events (v2 protocol). + + Implemented by `BaseChatModel` (and forwarded by `RunnableBinding`). + Generic `Runnable`s don't participate in the v2 event protocol — + use `.stream()` instead. + + Raises: + NotImplementedError: Always, on the base `Runnable` class. + """ + msg = ( + f"{type(self).__name__} does not implement `stream_v2`. " + "`stream_v2` is only implemented by chat models; use `.stream()` " + "for generic Runnables." + ) + raise NotImplementedError(msg) + + async def astream_v2( + self, + input: Input, + config: RunnableConfig | None = None, + **kwargs: Any | None, + ) -> AsyncChatModelStream: + """Async variant of `stream_v2`. See that method. + + Raises: + NotImplementedError: Always, on the base `Runnable` class. + """ + msg = ( + f"{type(self).__name__} does not implement `astream_v2`. " + "`astream_v2` is only implemented by chat models; use `.astream()` " + "for generic Runnables." + ) + raise NotImplementedError(msg) + @overload def astream_log( self, @@ -5889,6 +5933,43 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ ): yield item + @override + def stream_v2( + self, + input: Input, + config: RunnableConfig | None = None, + **kwargs: Any | None, + ) -> ChatModelStream: + """Forward `stream_v2` to the bound runnable with bound kwargs merged. + + Chat-model-specific: the bound runnable must implement `stream_v2` + (see `BaseChatModel`). Without this override, `__getattr__` would + forward the call but drop `self.kwargs` — losing tools bound via + `bind_tools`, `stop` sequences, etc. + """ + return self.bound.stream_v2( + input, + self._merge_configs(config), + **{**self.kwargs, **kwargs}, + ) + + @override + async def astream_v2( + self, + input: Input, + config: RunnableConfig | None = None, + **kwargs: Any | None, + ) -> AsyncChatModelStream: + """Forward `astream_v2` to the bound runnable with bound kwargs merged. + + Async variant of `stream_v2`. See that method for the full rationale. + """ + return await self.bound.astream_v2( + input, + self._merge_configs(config), + **{**self.kwargs, **kwargs}, + ) + @override async def astream_events( self, diff --git a/libs/core/langchain_core/tracers/_streaming.py b/libs/core/langchain_core/tracers/_streaming.py index 7ed7dcf747a..2c2b54c0e49 100644 --- a/libs/core/langchain_core/tracers/_streaming.py +++ b/libs/core/langchain_core/tracers/_streaming.py @@ -28,6 +28,25 @@ class _StreamingCallbackHandler(typing.Protocol[T]): """Used for internal astream_log and astream events implementations.""" +# THIS IS USED IN LANGGRAPH. +class _V2StreamingCallbackHandler: + """Marker base class for handlers that consume `on_stream_event` (v2). + + A handler inheriting from this class signals that it wants content- + block lifecycle events from `stream_v2` / `astream_v2` rather than + the v1 `on_llm_new_token` chunks. `BaseChatModel.invoke` uses + `isinstance(handler, _V2StreamingCallbackHandler)` to decide whether + to route an invoke through the v2 event generator. + + Implemented as a concrete marker class (not a `Protocol`) so opt-in + is explicit via inheritance. An empty `runtime_checkable` Protocol + would match every object and misroute every call. The event + delivery contract itself lives on + `BaseCallbackHandler.on_stream_event`. + """ + + __all__ = [ "_StreamingCallbackHandler", + "_V2StreamingCallbackHandler", ] diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 0c50ca58460..51fc7fe6b70 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "packaging>=23.2.0", "pydantic>=2.7.4,<3.0.0", "uuid-utils>=0.12.0,<1.0", + "langchain-protocol>=0.0.10", ] [project.urls] @@ -93,7 +94,6 @@ enable_error_code = "deprecated" # TODO: activate for 'strict' checking disallow_any_generics = false - [tool.ruff.format] docstring-code-format = true diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index 4b7206de46b..b1668123a65 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -12,6 +12,8 @@ from pydantic import model_validator from typing_extensions import Self, override from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + BaseCallbackHandler, CallbackManagerForLLMRun, ) from langchain_core.language_models import ( @@ -42,6 +44,7 @@ from langchain_core.messages import ( from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs.llm_result import LLMResult from langchain_core.tracers import LogStreamCallbackHandler +from langchain_core.tracers._streaming import _V2StreamingCallbackHandler from langchain_core.tracers.base import BaseTracer from langchain_core.tracers.context import collect_runs from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler @@ -463,6 +466,26 @@ async def test_streaming_attribute_overrides_streaming_callback() -> None: ).content == "invoke" +class _FakeV2Handler(BaseCallbackHandler, _V2StreamingCallbackHandler): + """Minimal v2 handler marker for routing tests; records nothing.""" + + +async def test_streaming_attribute_overrides_v2_callback() -> None: + """`self.streaming=False` must opt out of the v2 event path too. + + `_should_stream_v2` shares the `_streaming_disabled` opt-outs with + `_should_stream`, so an instance-level `streaming=False` takes + precedence over an attached `_V2StreamingCallbackHandler`. + """ + model = StreamingModel(streaming=False) + assert ( + await model.ainvoke([], config={"callbacks": [_FakeV2Handler()]}) + ).content == "invoke" + assert ( + model.invoke([], config={"callbacks": [_FakeV2Handler()]}) + ).content == "invoke" + + @pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"]) def test_disable_streaming_no_streaming_model( *, @@ -1469,6 +1492,32 @@ class FakeChatModelWithInvocationParams(SimpleChatModel): return "test response" +class FakeStreamingChatModelWithInvocationParams(FakeChatModelWithInvocationParams): + """Streaming counterpart for tracer metadata tests.""" + + @override + def _stream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + del messages, stop, run_manager, kwargs + yield ChatGenerationChunk(message=AIMessageChunk(content="test response")) + + @override + async def _astream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + del messages, stop, run_manager, kwargs + yield ChatGenerationChunk(message=AIMessageChunk(content="test response")) + + def test_invocation_params_passed_to_tracer_metadata() -> None: """Test that invocation params are passed to tracer metadata.""" llm = FakeChatModelWithInvocationParams() @@ -1510,3 +1559,44 @@ def test_invocation_params_passed_to_tracer_metadata() -> None: "runtime": run.extra["runtime"], } assert run.metadata == run.extra["metadata"] + + +def test_stream_v2_invocation_params_passed_to_tracer_metadata() -> None: + """`stream_v2()` must preserve filtered invocation params for tracing.""" + llm = FakeStreamingChatModelWithInvocationParams() + collector = LangChainTracerRunCollector() + + with collector.tracing_callback() as tracer: + _ = llm.stream_v2( + [HumanMessage(content="Hello")], + config={"callbacks": [tracer]}, + stop=["done"], + ).output + + assert len(collector.runs) == 1 + metadata = collector.runs[0].extra["metadata"] + + assert metadata["_type"] == "fake-chat-model-with-invocation-params" + assert metadata["stop"] == ["done"] + assert metadata["temperature"] == 0.7 + + +async def test_astream_v2_invocation_params_passed_to_tracer_metadata() -> None: + """`astream_v2()` must preserve filtered invocation params for tracing.""" + llm = FakeStreamingChatModelWithInvocationParams() + collector = LangChainTracerRunCollector() + + with collector.tracing_callback() as tracer: + stream = await llm.astream_v2( + [HumanMessage(content="Hello")], + config={"callbacks": [tracer]}, + stop=["done"], + ) + _ = await stream + + assert len(collector.runs) == 1 + metadata = collector.runs[0].extra["metadata"] + + assert metadata["_type"] == "fake-chat-model-with-invocation-params" + assert metadata["stop"] == ["done"] + assert metadata["temperature"] == 0.7 diff --git a/libs/core/tests/unit_tests/language_models/test_chat_model_stream.py b/libs/core/tests/unit_tests/language_models/test_chat_model_stream.py new file mode 100644 index 00000000000..a27a9e9d103 --- /dev/null +++ b/libs/core/tests/unit_tests/language_models/test_chat_model_stream.py @@ -0,0 +1,904 @@ +"""Tests for ChatModelStream, AsyncChatModelStream, and projections.""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Any, cast + +import pytest + +from langchain_core.language_models.chat_model_stream import ( + AsyncChatModelStream, + AsyncProjection, + ChatModelStream, + SyncProjection, + SyncTextProjection, +) + +if TYPE_CHECKING: + from langchain_protocol.protocol import ContentBlockFinishData, MessagesData + +# --------------------------------------------------------------------------- +# Projection unit tests +# --------------------------------------------------------------------------- + + +class TestSyncProjection: + """Test SyncProjection push/pull mechanics.""" + + def test_push_and_iterate(self) -> None: + proj = SyncProjection() + proj.push("a") + proj.push("b") + proj.complete(["a", "b"]) + assert list(proj) == ["a", "b"] + + def test_get_returns_final_value(self) -> None: + proj = SyncProjection() + proj.push("x") + proj.complete("final") + assert proj.get() == "final" + + def test_request_more_pulls(self) -> None: + proj = SyncProjection() + calls = iter(["a", "b", None]) + + def pump() -> bool: + val = next(calls) + if val is None: + proj.complete("ab") + return True + proj.push(val) + return True + + proj._request_more = pump + assert list(proj) == ["a", "b"] + assert proj.get() == "ab" + + def test_error_propagation(self) -> None: + proj = SyncProjection() + proj.push("partial") + proj.fail(ValueError("boom")) + with pytest.raises(ValueError, match="boom"): + list(proj) + + def test_error_on_get(self) -> None: + proj = SyncProjection() + proj.fail(ValueError("boom")) + with pytest.raises(ValueError, match="boom"): + proj.get() + + def test_multi_cursor_replay(self) -> None: + proj = SyncProjection() + proj.push("a") + proj.push("b") + proj.complete(None) + assert list(proj) == ["a", "b"] + assert list(proj) == ["a", "b"] # Second iteration replays + + def test_empty_projection(self) -> None: + proj = SyncProjection() + proj.complete([]) + assert list(proj) == [] + assert proj.get() == [] + + +class TestSyncTextProjection: + """Test SyncTextProjection string convenience methods.""" + + def test_str_drains(self) -> None: + proj = SyncTextProjection() + proj.push("Hello") + proj.push(" world") + proj.complete("Hello world") + assert str(proj) == "Hello world" + + def test_str_with_pump(self) -> None: + proj = SyncTextProjection() + done = False + + def pump() -> bool: + nonlocal done + if not done: + proj.push("Hi") + proj.complete("Hi") + done = True + return True + return False + + proj._request_more = pump + assert str(proj) == "Hi" + + def test_bool_nonempty(self) -> None: + proj = SyncTextProjection() + assert not proj + proj.push("x") + assert proj + + def test_repr(self) -> None: + proj = SyncTextProjection() + proj.push("hello") + assert repr(proj) == "'hello'" + proj.complete("hello") + assert repr(proj) == "'hello'" + + +class TestAsyncProjection: + """Test AsyncProjection async iteration and awaiting.""" + + @pytest.mark.asyncio + async def test_await_final_value(self) -> None: + proj = AsyncProjection() + proj.push("a") + proj.complete("final") + assert await proj == "final" + + @pytest.mark.asyncio + async def test_async_iter(self) -> None: + proj = AsyncProjection() + + async def produce() -> None: + await asyncio.sleep(0) + proj.push("x") + await asyncio.sleep(0) + proj.push("y") + await asyncio.sleep(0) + proj.complete("xy") + + asyncio.get_running_loop().create_task(produce()) + deltas = [d async for d in proj] + assert deltas == ["x", "y"] + + @pytest.mark.asyncio + async def test_error_on_await(self) -> None: + proj = AsyncProjection() + proj.fail(ValueError("async boom")) + with pytest.raises(ValueError, match="async boom"): + await proj + + @pytest.mark.asyncio + async def test_error_on_iter(self) -> None: + proj = AsyncProjection() + proj.push("partial") + proj.fail(ValueError("mid-stream")) + with pytest.raises(ValueError, match="mid-stream"): + async for _ in proj: + pass + + @pytest.mark.asyncio + async def test_arequest_more_drives_iteration(self) -> None: + """Cursor drives the async pump when the buffer is empty.""" + proj = AsyncProjection() + deltas = iter(["a", "b", "c"]) + + async def pump() -> bool: + try: + proj.push(next(deltas)) + except StopIteration: + proj.complete("abc") + return False + return True + + proj.set_arequest_more(pump) + collected = [d async for d in proj] + assert collected == ["a", "b", "c"] + assert await proj == "abc" + + @pytest.mark.asyncio + async def test_arequest_more_drives_await(self) -> None: + """`await projection` drives the pump too, not just iteration.""" + proj = AsyncProjection() + steps = iter([("push", "x"), ("push", "y"), ("complete", "xy")]) + + async def pump() -> bool: + try: + action, value = next(steps) + except StopIteration: + return False + if action == "push": + proj.push(value) + else: + proj.complete(value) + return True + + proj.set_arequest_more(pump) + assert await proj == "xy" + + @pytest.mark.asyncio + async def test_arequest_more_stops_when_pump_exhausts(self) -> None: + """Pump returning False without completing ends iteration cleanly.""" + proj = AsyncProjection() + pushed = [False] + + async def pump() -> bool: + if not pushed[0]: + proj.push("only") + pushed[0] = True + return True + return False + + proj.set_arequest_more(pump) + collected = [d async for d in proj] + assert collected == ["only"] + + @pytest.mark.asyncio + async def test_async_chat_model_stream_set_arequest_more_fans_out(self) -> None: + """`set_arequest_more` wires every projection on AsyncChatModelStream.""" + stream = AsyncChatModelStream(message_id="m1") + + async def pump() -> bool: + return False + + stream.set_arequest_more(pump) + for proj in ( + stream._text_proj, + stream._reasoning_proj, + stream._tool_calls_proj, + stream._output_proj, + stream._events_proj, + ): + assert proj._arequest_more is pump + + @pytest.mark.asyncio + async def test_concurrent_text_and_output_share_pump(self) -> None: + """Concurrent `stream.text` + `await stream.output` both drive the pump.""" + stream = AsyncChatModelStream(message_id="m1") + + events: list[MessagesData] = [ + { + "event": "message-start", + "role": "ai", + "message_id": "m1", + "metadata": {"provider": "test", "model": "fake"}, + }, + { + "event": "content-block-delta", + "index": 0, + "content_block": {"type": "text", "text": "hello "}, + }, + { + "event": "content-block-delta", + "index": 0, + "content_block": {"type": "text", "text": "world"}, + }, + { + "event": "content-block-finish", + "index": 0, + "content_block": {"type": "text", "text": "hello world"}, + }, + {"event": "message-finish"}, + ] + cursor = iter(events) + pump_lock = asyncio.Lock() + + async def pump() -> bool: + async with pump_lock: + try: + evt = next(cursor) + except StopIteration: + return False + stream.dispatch(evt) + return True + + stream.set_arequest_more(pump) + + async def drain_text() -> str: + buf = [delta async for delta in stream.text] + return "".join(buf) + + text, message = await asyncio.gather(drain_text(), stream.output) + assert text == "hello world" + assert message.content == [{"type": "text", "text": "hello world", "index": 0}] + + +# --------------------------------------------------------------------------- +# ChatModelStream unit tests +# --------------------------------------------------------------------------- + + +class TestChatModelStream: + """Test sync ChatModelStream via `stream.dispatch`.""" + + def test_text_projection_cached(self) -> None: + stream = ChatModelStream() + assert stream.text is stream.text + + def test_reasoning_projection_cached(self) -> None: + stream = ChatModelStream() + assert stream.reasoning is stream.reasoning + + def test_tool_calls_projection_cached(self) -> None: + stream = ChatModelStream() + assert stream.tool_calls is stream.tool_calls + + def test_text_deltas_via_pump(self) -> None: + stream = ChatModelStream() + events: list[MessagesData] = [ + {"event": "message-start", "role": "ai"}, + { + "event": "content-block-delta", + "index": 0, + "content_block": {"type": "text", "text": "Hi"}, + }, + { + "event": "content-block-delta", + "index": 0, + "content_block": {"type": "text", "text": " there"}, + }, + { + "event": "content-block-finish", + "index": 0, + "content_block": {"type": "text", "text": "Hi there"}, + }, + {"event": "message-finish"}, + ] + idx = 0 + + def pump() -> bool: + nonlocal idx + if idx >= len(events): + return False + stream.dispatch(events[idx]) + idx += 1 + return True + + stream.bind_pump(pump) + assert list(stream.text) == ["Hi", " there"] + assert str(stream.text) == "Hi there" + + def test_tool_call_chunk_streaming(self) -> None: + stream = ChatModelStream() + stream.dispatch({"event": "message-start", "role": "ai"}) + stream.dispatch( + { + "event": "content-block-delta", + "index": 0, + "content_block": { + "type": "tool_call_chunk", + "id": "tc1", + "name": "search", + "args": '{"q":', + "index": 0, + }, + } + ) + stream.dispatch( + { # type: ignore[arg-type,misc] + "event": "content-block-delta", + "index": 0, + "content_block": { + "type": "tool_call_chunk", + "args": '"test"}', + "index": 0, + }, + } + ) + stream.dispatch( + { + "event": "content-block-finish", + "index": 0, + "content_block": { + "type": "tool_call", + "id": "tc1", + "name": "search", + "args": {"q": "test"}, + }, + } + ) + stream.dispatch({"event": "message-finish"}) + + # Check chunk deltas were pushed + chunks = list(stream.tool_calls) + assert len(chunks) == 2 # two chunk deltas + assert chunks[0]["type"] == "tool_call_chunk" + assert chunks[0]["name"] == "search" + + # Check finalized tool calls + finalized = stream.tool_calls.get() + assert len(finalized) == 1 + assert finalized[0]["name"] == "search" + assert finalized[0]["args"] == {"q": "test"} + + def test_multi_tool_parallel(self) -> None: + stream = ChatModelStream() + stream.dispatch({"event": "message-start", "role": "ai"}) + # Tool 1 starts + stream.dispatch( + { + "event": "content-block-delta", + "index": 0, + "content_block": { + "type": "tool_call_chunk", + "id": "t1", + "name": "foo", + "args": '{"a":', + "index": 0, + }, + } + ) + # Tool 2 starts + stream.dispatch( + { + "event": "content-block-delta", + "index": 1, + "content_block": { + "type": "tool_call_chunk", + "id": "t2", + "name": "bar", + "args": '{"b":', + "index": 1, + }, + } + ) + # Tool 1 finishes + stream.dispatch( + { + "event": "content-block-finish", + "index": 0, + "content_block": { + "type": "tool_call", + "id": "t1", + "name": "foo", + "args": {"a": 1}, + }, + } + ) + # Tool 2 finishes + stream.dispatch( + { + "event": "content-block-finish", + "index": 1, + "content_block": { + "type": "tool_call", + "id": "t2", + "name": "bar", + "args": {"b": 2}, + }, + } + ) + stream.dispatch({"event": "message-finish"}) + + finalized = stream.tool_calls.get() + assert len(finalized) == 2 + assert finalized[0]["name"] == "foo" + assert finalized[1]["name"] == "bar" + + def test_output_assembles_aimessage(self) -> None: + stream = ChatModelStream(message_id="msg-1") + stream.dispatch( + { + "event": "message-start", + "role": "ai", + "metadata": {"provider": "anthropic", "model": "claude-4"}, + } + ) + stream.dispatch( + { + "event": "content-block-delta", + "index": 0, + "content_block": {"type": "text", "text": "Hello"}, + } + ) + stream.dispatch( + { + "event": "content-block-finish", + "index": 0, + "content_block": {"type": "text", "text": "Hello"}, + } + ) + stream.dispatch( + { + "event": "message-finish", + "metadata": {"finish_reason": "stop"}, + "usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, + } + ) + + msg = stream.output + assert msg.content == [{"type": "text", "text": "Hello", "index": 0}] + assert msg.id == "msg-1" + assert msg.response_metadata["finish_reason"] == "stop" + assert msg.response_metadata["model_provider"] == "anthropic" + assert msg.usage_metadata is not None + assert msg.usage_metadata["input_tokens"] == 10 + + def test_error_propagates_to_projections(self) -> None: + stream = ChatModelStream() + stream.dispatch( + { + "event": "content-block-delta", + "index": 0, + "content_block": {"type": "text", "text": "partial"}, + } + ) + stream.fail(RuntimeError("connection lost")) + + with pytest.raises(RuntimeError, match="connection lost"): + str(stream.text) + + with pytest.raises(RuntimeError, match="connection lost"): + stream.tool_calls.get() + + def test_raw_event_iteration(self) -> None: + stream = ChatModelStream() + stream.dispatch({"event": "message-start", "role": "ai"}) + stream.dispatch( + { + "event": "content-block-delta", + "index": 0, + "content_block": {"type": "text", "text": "hi"}, + } + ) + stream.dispatch({"event": "message-finish"}) + + events = list(stream) + assert len(events) == 3 + assert events[0]["event"] == "message-start" + assert events[2]["event"] == "message-finish" + + def test_raw_event_multi_cursor(self) -> None: + stream = ChatModelStream() + stream.dispatch({"event": "message-start", "role": "ai"}) + stream.dispatch({"event": "message-finish"}) + + assert list(stream) == list(stream) # Replay + + def test_invalid_tool_call_preserved_on_finish(self) -> None: + """An `invalid_tool_call` finish lands on `invalid_tool_calls`.""" + stream = ChatModelStream() + stream.dispatch({"event": "message-start", "role": "ai"}) + stream.dispatch( + { + "event": "content-block-finish", + "index": 0, + "content_block": { + "type": "invalid_tool_call", + "id": "call_1", + "name": "search", + "args": '{"q": ', # malformed + "error": "Failed to parse tool call arguments as JSON", + }, + } + ) + stream.dispatch({"event": "message-finish"}) + + msg = stream.output + assert msg.tool_calls == [] + assert len(msg.invalid_tool_calls) == 1 + assert msg.invalid_tool_calls[0]["name"] == "search" + assert msg.invalid_tool_calls[0]["args"] == '{"q": ' + assert msg.invalid_tool_calls[0]["error"] == ( + "Failed to parse tool call arguments as JSON" + ) + + def test_invalid_tool_call_survives_sweep(self) -> None: + """Regression: finish deletes stale chunk, sweep cannot revive it.""" + stream = ChatModelStream() + stream.dispatch({"event": "message-start", "role": "ai"}) + # Stream a tool_call_chunk with malformed JSON args + stream.dispatch( + { + "event": "content-block-delta", + "index": 0, + "content_block": { + "type": "tool_call_chunk", + "id": "call_1", + "name": "search", + "args": '{"q": ', + "index": 0, + }, + } + ) + # Finish event declares the call invalid + stream.dispatch( + { + "event": "content-block-finish", + "index": 0, + "content_block": { + "type": "invalid_tool_call", + "id": "call_1", + "name": "search", + "args": '{"q": ', + "error": "Failed to parse tool call arguments as JSON", + }, + } + ) + stream.dispatch({"event": "message-finish"}) + + msg = stream.output + # The sweep must NOT have revived the chunk as an empty-args tool_call. + assert msg.tool_calls == [] + assert len(msg.invalid_tool_calls) == 1 + + def test_output_content_uses_protocol_tool_call_shape(self) -> None: + """`.output.content` must emit `type: tool_call`, not legacy tool_use.""" + stream = ChatModelStream() + stream.dispatch({"event": "message-start", "role": "ai"}) + stream.dispatch( + { + "event": "content-block-delta", + "index": 0, + "content_block": {"type": "text", "text": "Let me search."}, + } + ) + stream.dispatch( + { + "event": "content-block-finish", + "index": 0, + "content_block": {"type": "text", "text": "Let me search."}, + } + ) + stream.dispatch( + { + "event": "content-block-finish", + "index": 1, + "content_block": { + "type": "tool_call", + "id": "call_1", + "name": "search", + "args": {"q": "weather"}, + }, + } + ) + stream.dispatch({"event": "message-finish"}) + + msg = stream.output + assert isinstance(msg.content, list) + content = cast("list[dict[str, Any]]", msg.content) + types = [b.get("type") for b in content] + assert types == ["text", "tool_call"] + tool_block = content[1] + assert tool_block["name"] == "search" + assert tool_block["args"] == {"q": "weather"} + # Legacy shape fields must be absent + assert "input" not in tool_block + assert tool_block.get("type") != "tool_use" + + def test_server_tool_call_finish_lands_in_output_content(self) -> None: + """Server-executed tool call finish events flow into .output.content.""" + stream = ChatModelStream() + stream.dispatch({"event": "message-start", "role": "ai"}) + stream.dispatch( + { + "event": "content-block-finish", + "index": 0, + "content_block": { + "type": "server_tool_call", + "id": "srv_1", + "name": "web_search", + "args": {"q": "weather"}, + }, + } + ) + stream.dispatch( + cast( + "ContentBlockFinishData", + { + "event": "content-block-finish", + "index": 1, + "content_block": { + "type": "server_tool_result", + "tool_call_id": "srv_1", + "status": "success", + "output": "62F, clear", + }, + }, + ) + ) + stream.dispatch({"event": "message-finish"}) + + msg = stream.output + assert isinstance(msg.content, list) + content = cast("list[dict[str, Any]]", msg.content) + types = [b.get("type") for b in content] + assert types == ["server_tool_call", "server_tool_result"] + # Regular tool_calls projection must NOT include server-executed ones + assert msg.tool_calls == [] + + def test_server_tool_call_chunk_sweep(self) -> None: + """Unfinished server_tool_call_chunks get swept to server_tool_call.""" + stream = ChatModelStream() + stream.dispatch({"event": "message-start", "role": "ai"}) + stream.dispatch( + { + "event": "content-block-delta", + "index": 0, + "content_block": { + "type": "server_tool_call_chunk", + "id": "srv_1", + "name": "web_search", + "args": '{"q":', + }, + } + ) + stream.dispatch( + { + "event": "content-block-delta", + "index": 0, + "content_block": { + "type": "server_tool_call_chunk", + "args": ' "weather"}', + }, + } + ) + stream.dispatch({"event": "message-finish"}) + + msg = stream.output + assert isinstance(msg.content, list) + content = cast("list[dict[str, Any]]", msg.content) + assert content[0]["type"] == "server_tool_call" + assert content[0]["args"] == {"q": "weather"} + assert content[0]["name"] == "web_search" + + def test_image_block_pass_through(self) -> None: + """An image block finished via the event stream reaches .output.content.""" + stream = ChatModelStream() + stream.dispatch({"event": "message-start", "role": "ai"}) + stream.dispatch( + { + "event": "content-block-finish", + "index": 0, + "content_block": { + "type": "image", + "url": "https://example.com/cat.png", + "mime_type": "image/png", + }, + } + ) + stream.dispatch({"event": "message-finish"}) + + msg = stream.output + assert isinstance(msg.content, list) + assert msg.content[0] == { + "type": "image", + "url": "https://example.com/cat.png", + "mime_type": "image/png", + "index": 0, + } + + def test_sweep_of_unfinished_malformed_chunk_produces_invalid_tool_call( + self, + ) -> None: + """Unfinished chunk with malformed JSON sweeps to invalid_tool_call.""" + stream = ChatModelStream() + stream.dispatch({"event": "message-start", "role": "ai"}) + stream.dispatch( + { + "event": "content-block-delta", + "index": 0, + "content_block": { + "type": "tool_call_chunk", + "id": "call_1", + "name": "search", + "args": '{"q": ', # malformed, never completed + "index": 0, + }, + } + ) + stream.dispatch({"event": "message-finish"}) + + msg = stream.output + assert msg.tool_calls == [] + assert len(msg.invalid_tool_calls) == 1 + itc = msg.invalid_tool_calls[0] + assert itc["name"] == "search" + assert itc["args"] == '{"q": ' + assert "Failed to parse" in (itc["error"] or "") + + +# --------------------------------------------------------------------------- +# AsyncChatModelStream unit tests +# --------------------------------------------------------------------------- + + +class TestAsyncChatModelStream: + """Test async ChatModelStream.""" + + @pytest.mark.asyncio + async def test_await_output(self) -> None: + stream = AsyncChatModelStream(message_id="m1") + + async def produce() -> None: + await asyncio.sleep(0) + stream.dispatch({"event": "message-start", "role": "ai"}) + stream.dispatch( + { + "event": "content-block-delta", + "index": 0, + "content_block": {"type": "text", "text": "Hi"}, + } + ) + stream.dispatch({"event": "message-finish"}) + + asyncio.get_running_loop().create_task(produce()) + msg = await stream + assert msg.content == "Hi" + + @pytest.mark.asyncio + async def test_async_text_deltas(self) -> None: + stream = AsyncChatModelStream() + + async def produce() -> None: + await asyncio.sleep(0) + stream.dispatch({"event": "message-start", "role": "ai"}) + await asyncio.sleep(0) + stream.dispatch( + { + "event": "content-block-delta", + "index": 0, + "content_block": {"type": "text", "text": "a"}, + } + ) + await asyncio.sleep(0) + stream.dispatch( + { + "event": "content-block-delta", + "index": 0, + "content_block": {"type": "text", "text": "b"}, + } + ) + await asyncio.sleep(0) + stream.dispatch({"event": "message-finish"}) + + asyncio.get_running_loop().create_task(produce()) + deltas = [d async for d in stream.text] + assert deltas == ["a", "b"] + + @pytest.mark.asyncio + async def test_await_tool_calls(self) -> None: + stream = AsyncChatModelStream() + stream.dispatch({"event": "message-start", "role": "ai"}) + stream.dispatch( + { + "event": "content-block-delta", + "index": 0, + "content_block": { + "type": "tool_call_chunk", + "id": "tc1", + "name": "search", + "args": '{"q":"hi"}', + "index": 0, + }, + } + ) + stream.dispatch( + { + "event": "content-block-finish", + "index": 0, + "content_block": { + "type": "tool_call", + "id": "tc1", + "name": "search", + "args": {"q": "hi"}, + }, + } + ) + stream.dispatch({"event": "message-finish"}) + + result = await stream.tool_calls + assert len(result) == 1 + assert result[0]["name"] == "search" + + @pytest.mark.asyncio + async def test_async_raw_event_iteration(self) -> None: + stream = AsyncChatModelStream() + + async def produce() -> None: + await asyncio.sleep(0) + stream.dispatch({"event": "message-start", "role": "ai"}) + await asyncio.sleep(0) + stream.dispatch({"event": "message-finish"}) + + asyncio.get_running_loop().create_task(produce()) + events = [e async for e in stream] + assert len(events) == 2 + + @pytest.mark.asyncio + async def test_error_propagation(self) -> None: + stream = AsyncChatModelStream() + stream.fail(RuntimeError("async fail")) + + with pytest.raises(RuntimeError, match="async fail"): + await stream.text + with pytest.raises(RuntimeError, match="async fail"): + await stream diff --git a/libs/core/tests/unit_tests/language_models/test_chat_model_streamer.py b/libs/core/tests/unit_tests/language_models/test_chat_model_streamer.py new file mode 100644 index 00000000000..4390825beef --- /dev/null +++ b/libs/core/tests/unit_tests/language_models/test_chat_model_streamer.py @@ -0,0 +1,445 @@ +"""Tests for BaseChatModel.stream_v2() / astream_v2().""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Any + +import pytest +from pydantic import Field + +from langchain_core.callbacks import AsyncCallbackHandler, BaseCallbackHandler +from langchain_core.language_models.chat_model_stream import ( + AsyncChatModelStream, + ChatModelStream, +) +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.language_models.fake_chat_models import FakeListChatModel +from langchain_core.messages import AIMessageChunk +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult + +if TYPE_CHECKING: + from collections.abc import AsyncIterator, Iterator + + from langchain_protocol.protocol import MessagesData + + from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, + ) + from langchain_core.messages import BaseMessage + from langchain_core.outputs import LLMResult + + +class TestStreamV2Sync: + """Test BaseChatModel.stream_v2() with FakeListChatModel.""" + + def test_stream_text(self) -> None: + model = FakeListChatModel(responses=["Hello world!"]) + stream = model.stream_v2("test") + + assert isinstance(stream, ChatModelStream) + deltas = list(stream.text) + assert "".join(deltas) == "Hello world!" + assert stream.done + + def test_stream_output(self) -> None: + model = FakeListChatModel(responses=["Hello!"]) + stream = model.stream_v2("test") + + msg = stream.output + assert isinstance(msg.content, list) + assert msg.content == [{"type": "text", "text": "Hello!", "index": 0}] + assert msg.id is not None + + def test_stream_usage_none_for_fake(self) -> None: + model = FakeListChatModel(responses=["Hi"]) + stream = model.stream_v2("test") + # Drain + for _ in stream.text: + pass + assert stream.output.usage_metadata is None + + def test_stream_raw_events(self) -> None: + model = FakeListChatModel(responses=["ab"]) + stream = model.stream_v2("test") + + events = list(stream) + event_types = [e.get("event") for e in events] + assert event_types[0] == "message-start" + assert event_types[-1] == "message-finish" + assert "content-block-delta" in event_types + + +class TestAstreamV2: + """Test BaseChatModel.astream_v2() with FakeListChatModel.""" + + @pytest.mark.asyncio + async def test_astream_text_await(self) -> None: + model = FakeListChatModel(responses=["Hello!"]) + stream = await model.astream_v2("test") + + assert isinstance(stream, AsyncChatModelStream) + full = await stream.text + assert full == "Hello!" + + @pytest.mark.asyncio + async def test_astream_text_deltas(self) -> None: + model = FakeListChatModel(responses=["Hi"]) + stream = await model.astream_v2("test") + + deltas = [d async for d in stream.text] + assert "".join(deltas) == "Hi" + + @pytest.mark.asyncio + async def test_astream_await_output(self) -> None: + model = FakeListChatModel(responses=["Hey"]) + stream = await model.astream_v2("test") + + msg = await stream + assert msg.content == [{"type": "text", "text": "Hey", "index": 0}] + + +class _RecordingHandler(BaseCallbackHandler): + """Sync callback handler that records lifecycle hook invocations.""" + + def __init__(self) -> None: + self.events: list[str] = [] + self.stream_events: list[MessagesData] = [] + self.last_llm_end_response: LLMResult | None = None + + def on_chat_model_start(self, *args: Any, **kwargs: Any) -> None: + del args, kwargs + self.events.append("on_chat_model_start") + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + del kwargs + self.events.append("on_llm_end") + self.last_llm_end_response = response + + def on_llm_error(self, *args: Any, **kwargs: Any) -> None: + del args, kwargs + self.events.append("on_llm_error") + + def on_stream_event(self, event: MessagesData, **kwargs: Any) -> None: + del kwargs + self.stream_events.append(event) + + +class _AsyncRecordingHandler(AsyncCallbackHandler): + """Async callback handler that records lifecycle hook invocations.""" + + def __init__(self) -> None: + self.events: list[str] = [] + self.stream_events: list[MessagesData] = [] + self.last_llm_end_response: LLMResult | None = None + + async def on_chat_model_start(self, *args: Any, **kwargs: Any) -> None: + del args, kwargs + self.events.append("on_chat_model_start") + + async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + del kwargs + self.events.append("on_llm_end") + self.last_llm_end_response = response + + async def on_llm_error(self, *args: Any, **kwargs: Any) -> None: + del args, kwargs + self.events.append("on_llm_error") + + async def on_stream_event(self, event: MessagesData, **kwargs: Any) -> None: + del kwargs + self.stream_events.append(event) + + +class _EmptyStreamModel(BaseChatModel): + """Fake chat model whose stream producers yield no chunks.""" + + @property + def _llm_type(self) -> str: + return "empty-stream-fake" + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + del messages, stop, run_manager, kwargs + raise NotImplementedError + + def _stream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + del messages, stop, run_manager, kwargs + if False: + yield ChatGenerationChunk(message=AIMessageChunk(content="")) + + async def _astream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + del messages, stop, run_manager, kwargs + if False: + yield ChatGenerationChunk(message=AIMessageChunk(content="")) + + +class TestCallbacks: + """Verify stream_v2 fires on_llm_end / on_llm_error callbacks.""" + + def test_stream_v2_defers_on_chat_model_start_until_consumed(self) -> None: + handler = _RecordingHandler() + model = FakeListChatModel(responses=["done"], callbacks=[handler]) + + stream = model.stream_v2("test") + + assert handler.events == [] + + _ = stream.output + + assert handler.events[0] == "on_chat_model_start" + + def test_on_llm_end_fires_after_drain(self) -> None: + handler = _RecordingHandler() + model = FakeListChatModel(responses=["done"], callbacks=[handler]) + stream = model.stream_v2("test") + for _ in stream.text: + pass + _ = stream.output + + assert "on_chat_model_start" in handler.events + assert "on_llm_end" in handler.events + assert handler.events.index("on_llm_end") > handler.events.index( + "on_chat_model_start" + ) + + @pytest.mark.asyncio + async def test_on_llm_end_fires_async(self) -> None: + handler = _AsyncRecordingHandler() + model = FakeListChatModel(responses=["done"], callbacks=[handler]) + stream = await model.astream_v2("test") + _ = await stream + + assert "on_chat_model_start" in handler.events + assert "on_llm_end" in handler.events + + @pytest.mark.asyncio + async def test_astream_v2_defers_on_chat_model_start_until_consumed(self) -> None: + handler = _AsyncRecordingHandler() + model = FakeListChatModel(responses=["done"], callbacks=[handler]) + + stream = await model.astream_v2("test") + + assert handler.events == [] + + _ = await stream + + assert handler.events[0] == "on_chat_model_start" + + def test_on_llm_end_receives_assembled_message(self) -> None: + """The LLMResult passed to on_llm_end must carry the final message. + + Without this, LangSmith traces would see an empty generations list. + """ + handler = _RecordingHandler() + model = FakeListChatModel(responses=["hello"], callbacks=[handler]) + stream = model.stream_v2("test") + _ = stream.output + + response = handler.last_llm_end_response + assert response is not None + assert response.generations + gen = response.generations[0][0] + assert isinstance(gen, ChatGeneration) + assert gen.message.content == [{"type": "text", "text": "hello", "index": 0}] + + @pytest.mark.asyncio + async def test_on_llm_end_receives_assembled_message_async(self) -> None: + handler = _AsyncRecordingHandler() + model = FakeListChatModel(responses=["hello"], callbacks=[handler]) + stream = await model.astream_v2("test") + _ = await stream + + response = handler.last_llm_end_response + assert response is not None + assert response.generations + gen = response.generations[0][0] + assert isinstance(gen, ChatGeneration) + assert gen.message.content == [{"type": "text", "text": "hello", "index": 0}] + + def test_empty_stream_reports_error_without_finish_only_lifecycle(self) -> None: + handler = _RecordingHandler() + stream = _EmptyStreamModel(callbacks=[handler]).stream_v2("test") + + with pytest.raises(ValueError, match="No generation chunks were returned"): + list(stream) + + assert handler.stream_events == [] + assert "on_llm_error" in handler.events + assert "on_llm_end" not in handler.events + + @pytest.mark.asyncio + async def test_empty_astream_reports_error(self) -> None: + handler = _AsyncRecordingHandler() + stream = await _EmptyStreamModel(callbacks=[handler]).astream_v2("test") + + with pytest.raises(ValueError, match="No generation chunks were returned"): + await stream + task = stream._producer_task + assert task is not None + await task + + assert handler.stream_events == [] + assert "on_llm_error" in handler.events + assert "on_llm_end" not in handler.events + + +class TestOnStreamEvent: + """`on_stream_event` must fire once per protocol event from stream_v2.""" + + def test_on_stream_event_fires_for_every_event_sync(self) -> None: + handler = _RecordingHandler() + model = FakeListChatModel(responses=["Hi"], callbacks=[handler]) + stream = model.stream_v2("test") + _ = stream.output + + # Every event the stream sees should also reach the observer. + assert len(handler.stream_events) == len(list(stream)) + event_types = [e["event"] for e in handler.stream_events] + assert event_types[0] == "message-start" + assert event_types[-1] == "message-finish" + assert "content-block-delta" in event_types + + @pytest.mark.asyncio + async def test_on_stream_event_fires_for_every_event_async(self) -> None: + handler = _AsyncRecordingHandler() + model = FakeListChatModel(responses=["Hi"], callbacks=[handler]) + stream = await model.astream_v2("test") + _ = await stream + + event_types = [e["event"] for e in handler.stream_events] + assert event_types[0] == "message-start" + assert event_types[-1] == "message-finish" + assert "content-block-delta" in event_types + + def test_on_stream_event_ordering_relative_to_lifecycle(self) -> None: + """Stream events must all fire between on_chat_model_start and on_llm_end.""" + handler = _RecordingHandler() + model = FakeListChatModel(responses=["Hi"], callbacks=[handler]) + stream = model.stream_v2("test") + _ = stream.output + + # on_stream_event doesn't show up in `events` (different list), but + # on_chat_model_start and on_llm_end bracket the run. + assert handler.events[0] == "on_chat_model_start" + assert handler.events[-1] == "on_llm_end" + # And we did see stream events during that bracket. + assert handler.stream_events + + +class TestCancellation: + """Cancellation of `astream_v2` must propagate, not be swallowed.""" + + @pytest.mark.asyncio + async def test_astream_v2_cancellation_propagates(self) -> None: + """Cancelling the producer task must raise CancelledError. + + Regression test: the producer's `except BaseException` previously + swallowed `asyncio.CancelledError`, converting it into an + `on_llm_error` + `stream._fail` pair that never propagated. + """ + model = FakeListChatModel(responses=["abcdefghij"], sleep=0.05) + stream = await model.astream_v2("test") + aiter_ = stream.text.__aiter__() + await aiter_.__anext__() + task = stream._producer_task + assert task is not None + + await asyncio.sleep(0.01) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + assert isinstance(stream._error, asyncio.CancelledError) + + +class _KwargRecordingModel(FakeListChatModel): + """Fake model that records kwargs passed to `_stream` / `_astream`.""" + + received_kwargs: list[dict[str, Any]] = Field(default_factory=list) + + def _stream( + self, + messages: Any, + stop: Any = None, + run_manager: Any = None, + **kwargs: Any, + ) -> Any: + self.received_kwargs.append({"stop": stop, **kwargs}) + return super()._stream(messages, stop=stop, run_manager=run_manager, **kwargs) + + async def _astream( + self, + messages: Any, + stop: Any = None, + run_manager: Any = None, + **kwargs: Any, + ) -> Any: + self.received_kwargs.append({"stop": stop, **kwargs}) + async for chunk in super()._astream( + messages, stop=stop, run_manager=run_manager, **kwargs + ): + yield chunk + + +class TestRunnableBindingForwarding: + """`RunnableBinding.stream_v2` must merge bound kwargs into the call. + + Without the explicit override on `RunnableBinding`, `__getattr__` + forwards the call but drops `self.kwargs` — so tools bound via + `bind_tools`, stop sequences bound via `bind`, etc. would be silently + ignored. + """ + + def test_bound_kwargs_reach_stream_v2(self) -> None: + model = _KwargRecordingModel(responses=["hi"]) + model.received_kwargs = [] + bound = model.bind(my_marker="sentinel-42") + + stream = bound.stream_v2("test") + for _ in stream.text: + pass + + assert len(model.received_kwargs) == 1 + assert model.received_kwargs[0].get("my_marker") == "sentinel-42" + + def test_call_kwargs_override_bound_kwargs(self) -> None: + model = _KwargRecordingModel(responses=["hi"]) + model.received_kwargs = [] + bound = model.bind(my_marker="from-bind") + + stream = bound.stream_v2("test", my_marker="from-call") + for _ in stream.text: + pass + + assert model.received_kwargs[0].get("my_marker") == "from-call" + + @pytest.mark.asyncio + async def test_bound_kwargs_reach_astream_v2(self) -> None: + model = _KwargRecordingModel(responses=["hi"]) + model.received_kwargs = [] + bound = model.bind(my_marker="sentinel-async") + + stream = await bound.astream_v2("test") + _ = await stream + + assert len(model.received_kwargs) == 1 + assert model.received_kwargs[0].get("my_marker") == "sentinel-async" diff --git a/libs/core/tests/unit_tests/language_models/test_compat_bridge.py b/libs/core/tests/unit_tests/language_models/test_compat_bridge.py new file mode 100644 index 00000000000..ce9f2f75a5a --- /dev/null +++ b/libs/core/tests/unit_tests/language_models/test_compat_bridge.py @@ -0,0 +1,986 @@ +"""Tests for the compat bridge (chunk-to-event conversion).""" + +from typing import TYPE_CHECKING, Any, cast + +import pytest +from langchain_tests.utils.stream_lifecycle import assert_valid_event_stream + +from langchain_core.language_models._compat_bridge import ( + CompatBlock, + _finalize_block, + _to_protocol_usage, + amessage_to_events, + chunks_to_events, + message_to_events, +) +from langchain_core.messages import AIMessage, AIMessageChunk +from langchain_core.outputs import ChatGenerationChunk + +if TYPE_CHECKING: + from langchain_protocol.protocol import ( + ContentBlockDeltaData, + InvalidToolCall, + MessageFinishData, + MessageStartData, + ReasoningContentBlock, + ServerToolCall, + TextContentBlock, + ToolCall, + ) + + +# --------------------------------------------------------------------------- +# Pure helpers +# --------------------------------------------------------------------------- + + +def test_finalize_block_text_passes_through() -> None: + block: CompatBlock = {"type": "text", "text": "hello"} + result = _finalize_block(block) + text_result = cast("TextContentBlock", result) + assert text_result["type"] == "text" + assert text_result["text"] == "hello" + + +def test_finalize_block_tool_call_chunk_valid_json() -> None: + block: CompatBlock = { + "type": "tool_call_chunk", + "args": '{"query": "test"}', + "id": "tc1", + "name": "search", + } + result = _finalize_block(block) + tool_call = cast("ToolCall", result) + assert tool_call["type"] == "tool_call" + assert tool_call["id"] == "tc1" + assert tool_call["name"] == "search" + assert tool_call["args"] == {"query": "test"} + + +def test_finalize_block_tool_call_chunk_invalid_json() -> None: + block: CompatBlock = { + "type": "tool_call_chunk", + "args": "not json", + "id": "tc1", + "name": "search", + } + result = _finalize_block(block) + invalid = cast("InvalidToolCall", result) + assert invalid["type"] == "invalid_tool_call" + assert invalid.get("error") is not None + + +def test_finalize_block_server_tool_call_chunk_valid_json() -> None: + block: CompatBlock = { + "type": "server_tool_call_chunk", + "args": '{"q": "weather"}', + "id": "srv_1", + "name": "web_search", + } + result = _finalize_block(block) + server_result = cast("ServerToolCall", result) + assert server_result["type"] == "server_tool_call" + assert server_result["id"] == "srv_1" + assert server_result["name"] == "web_search" + assert server_result["args"] == {"q": "weather"} + + +def test_finalize_block_server_tool_call_chunk_invalid_json() -> None: + block: CompatBlock = { + "type": "server_tool_call_chunk", + "args": "not json", + "id": "srv_1", + "name": "web_search", + } + result = _finalize_block(block) + invalid = cast("InvalidToolCall", result) + assert invalid["type"] == "invalid_tool_call" + assert invalid.get("error") is not None + + +def test_to_protocol_usage_present() -> None: + usage = {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30} + result = _to_protocol_usage(usage) + assert result is not None + assert result["input_tokens"] == 10 + assert result["output_tokens"] == 20 + + +def test_to_protocol_usage_none() -> None: + assert _to_protocol_usage(None) is None + + +# --------------------------------------------------------------------------- +# chunks_to_events: streaming lifecycle +# --------------------------------------------------------------------------- + + +def test_chunks_to_events_text_only() -> None: + """Multi-chunk text stream produces a clean lifecycle.""" + chunks = [ + ChatGenerationChunk(message=AIMessageChunk(content="Hello", id="msg-1")), + ChatGenerationChunk(message=AIMessageChunk(content=" world", id="msg-1")), + ] + + events = list(chunks_to_events(iter(chunks), message_id="msg-1")) + event_types = [e["event"] for e in events] + + assert event_types[0] == "message-start" + assert "content-block-start" in event_types + assert event_types.count("content-block-delta") == 2 + assert "content-block-finish" in event_types + assert event_types[-1] == "message-finish" + + finish = cast("MessageFinishData", events[-1]) + # No provider finish_reason in fixtures — metadata carries no + # `finish_reason` key (the bridge passes response_metadata through + # unchanged). + assert "finish_reason" not in (finish.get("metadata") or {}) + + +def test_chunks_to_events_empty_iterator() -> None: + """No chunks means no events.""" + assert list(chunks_to_events(iter([]))) == [] + + +def test_chunks_to_events_block_transitions_close_previous_block() -> None: + """String-keyed blocks that transition mid-stream each get their own lifecycle. + + Regression test for OpenAI `responses/v1` style streams where + `content_blocks` uses string identifiers (e.g. `"lc_rs_305f30"`) to + distinguish blocks. Each distinct block must get its own + `content-block-start` / `content-block-finish` pair, with sequential + `uint` wire indices, and blocks must not interleave. + """ + chunks = [ + ChatGenerationChunk( + message=AIMessageChunk( + content=[ + {"type": "reasoning", "reasoning": "hmm", "index": "rs_a"}, + ], + id="msg-1", + ) + ), + ChatGenerationChunk( + message=AIMessageChunk( + content=[ + {"type": "reasoning", "reasoning": " then", "index": "rs_a"}, + ], + id="msg-1", + ) + ), + ChatGenerationChunk( + message=AIMessageChunk( + content=[ + {"type": "reasoning", "reasoning": "different", "index": "rs_b"}, + ], + id="msg-1", + ) + ), + ChatGenerationChunk( + message=AIMessageChunk( + content=[ + {"type": "text", "text": "answer: ", "index": "txt_1"}, + ], + id="msg-1", + ) + ), + ChatGenerationChunk( + message=AIMessageChunk( + content=[ + {"type": "text", "text": "42", "index": "txt_1"}, + ], + id="msg-1", + ) + ), + ] + + events = list(chunks_to_events(iter(chunks), message_id="msg-1")) + + starts: list[Any] = [e for e in events if e["event"] == "content-block-start"] + finishes: list[Any] = [e for e in events if e["event"] == "content-block-finish"] + assert [s["content_block"]["type"] for s in starts] == [ + "reasoning", + "reasoning", + "text", + ] + assert [f["content_block"]["type"] for f in finishes] == [ + "reasoning", + "reasoning", + "text", + ] + # Wire indices are sequential uints regardless of source-side keys. + assert [s["index"] for s in starts] == [0, 1, 2] + assert [f["index"] for f in finishes] == [0, 1, 2] + + # Finish events must be interleaved with starts (no-interleave rule): + # block 0 finishes before block 1 starts, etc. + events_any: list[Any] = events + lifecycle = [ + (e["event"], e["index"]) + for e in events_any + if e["event"] in ("content-block-start", "content-block-finish") + ] + assert lifecycle == [ + ("content-block-start", 0), + ("content-block-finish", 0), + ("content-block-start", 1), + ("content-block-finish", 1), + ("content-block-start", 2), + ("content-block-finish", 2), + ] + + # Each finish carries the accumulated content for its block. + assert finishes[0]["content_block"]["reasoning"] == "hmm then" + assert finishes[1]["content_block"]["reasoning"] == "different" + assert finishes[2]["content_block"]["text"] == "answer: 42" + + +def test_chunks_to_events_tool_call_multichunk() -> None: + """Partial tool-call args across chunks finalize to a single tool_call.""" + chunks = [ + ChatGenerationChunk( + message=AIMessageChunk( + content="", + id="msg-1", + tool_call_chunks=[ + { + "index": 0, + "id": "tc1", + "name": "search", + "args": '{"q":', + "type": "tool_call_chunk", + } + ], + ) + ), + ChatGenerationChunk( + message=AIMessageChunk( + content="", + id="msg-1", + tool_call_chunks=[ + { + "index": 0, + "id": None, + "name": None, + "args": ' "test"}', + "type": "tool_call_chunk", + } + ], + ) + ), + ] + + events = list(chunks_to_events(iter(chunks), message_id="msg-1")) + event_types = [e["event"] for e in events] + + assert event_types[0] == "message-start" + assert "content-block-start" in event_types + assert "content-block-finish" in event_types + assert event_types[-1] == "message-finish" + + # Exactly one block finalized, args parsed to a dict. + finish_events: list[Any] = [ + e for e in events if e["event"] == "content-block-finish" + ] + assert len(finish_events) == 1 + finalized = cast("ToolCall", finish_events[0]["content_block"]) + assert finalized["type"] == "tool_call" + assert finalized["args"] == {"q": "test"} + + # No provider finish_reason in the fixture chunks — the bridge does + # not synthesize one. It deliberately does not infer `"tool_use"` + # from the presence of a valid tool_call either; terminal reasons + # are provider-specific (see `_build_message_finish`). + assert "finish_reason" not in ( + cast("MessageFinishData", events[-1]).get("metadata") or {} + ) + + +def test_chunks_to_events_invalid_tool_call_keeps_stop_reason() -> None: + """Malformed tool-args become invalid_tool_call; finish_reason stays `stop`.""" + chunks = [ + ChatGenerationChunk( + message=AIMessageChunk( + content="", + id="msg-bad", + tool_call_chunks=[ + { + "index": 0, + "id": "tc1", + "name": "search", + "args": "{oops", + "type": "tool_call_chunk", + }, + ], + ) + ), + ] + + events = list(chunks_to_events(iter(chunks), message_id="msg-bad")) + + finish_events: list[Any] = [ + e for e in events if e["event"] == "content-block-finish" + ] + assert len(finish_events) == 1 + assert finish_events[0]["content_block"]["type"] == "invalid_tool_call" + assert "finish_reason" not in ( + cast("MessageFinishData", events[-1]).get("metadata") or {} + ) + + +def test_chunks_to_events_anthropic_server_tool_use_routes_through_translator() -> None: + """`server_tool_use` shape + anthropic provider tag becomes `server_tool_call`.""" + chunks = [ + ChatGenerationChunk( + message=AIMessageChunk( + content=[ + {"type": "text", "text": "Let me search. "}, + { + "type": "server_tool_use", + "id": "srvtoolu_01", + "name": "web_search", + "input": {"query": "weather"}, + }, + ], + response_metadata={"model_provider": "anthropic"}, + ) + ), + ] + + events = list(chunks_to_events(iter(chunks))) + finish_blocks: list[Any] = [ + e["content_block"] for e in events if e["event"] == "content-block-finish" + ] + block_types = [b.get("type") for b in finish_blocks] + assert "server_tool_call" in block_types + assert "text" in block_types + + +def test_chunks_to_events_unregistered_provider_falls_back() -> None: + """Unknown provider tag doesn't crash; best-effort parsing surfaces text.""" + chunks = [ + ChatGenerationChunk( + message=AIMessageChunk( + content="Hello", + response_metadata={"model_provider": "totally-made-up-provider"}, + ) + ), + ] + + events = list(chunks_to_events(iter(chunks))) + finish_events: list[Any] = [ + e for e in events if e["event"] == "content-block-finish" + ] + assert [e["content_block"]["type"] for e in finish_events] == ["text"] + + +def test_chunks_to_events_no_provider_text_plus_tool_call() -> None: + """Without a provider tag, text + tool_call_chunks both come through. + + This is the case the old legacy path silently dropped the tool call + because it re-mined tool_call_chunks on top of the positional index + already used by the text block. Trusting content_blocks keeps them + on distinct indices. + """ + chunks = [ + ChatGenerationChunk( + message=AIMessageChunk( + content="Hello", + tool_call_chunks=[ + { + "index": 1, + "id": "t1", + "name": "search", + "args": '{"q": "x"}', + "type": "tool_call_chunk", + }, + ], + ) + ), + ] + + events = list(chunks_to_events(iter(chunks))) + finish_blocks: list[Any] = [ + e["content_block"] for e in events if e["event"] == "content-block-finish" + ] + types = [b.get("type") for b in finish_blocks] + assert "text" in types + assert "tool_call" in types + + +def test_chunks_to_events_reasoning_in_additional_kwargs() -> None: + """Reasoning packed into additional_kwargs surfaces as a reasoning block.""" + chunks = [ + ChatGenerationChunk( + message=AIMessageChunk( + content=[{"type": "text", "text": "2+2=4"}], + additional_kwargs={"reasoning_content": "Adding two and two..."}, + response_metadata={"model_provider": "unknown-open-model"}, + ) + ), + ] + + events = list(chunks_to_events(iter(chunks))) + finish_blocks: list[Any] = [ + e["content_block"] for e in events if e["event"] == "content-block-finish" + ] + types = [b.get("type") for b in finish_blocks] + assert "reasoning" in types + assert "text" in types + + +# --------------------------------------------------------------------------- +# message_to_events: finalized-message replay +# --------------------------------------------------------------------------- + + +def test_message_to_events_text_only() -> None: + msg = AIMessage(content="Hello world", id="msg-1") + events = list(message_to_events(msg)) + + event_types = [e["event"] for e in events] + assert event_types == [ + "message-start", + "content-block-start", + "content-block-delta", + "content-block-finish", + "message-finish", + ] + start = cast("MessageStartData", events[0]) + assert start["message_id"] == "msg-1" + + delta_event = cast("ContentBlockDeltaData", events[2]) + delta = cast("TextContentBlock", delta_event["content_block"]) + assert delta["text"] == "Hello world" + + final = cast("MessageFinishData", events[-1]) + assert "finish_reason" not in (final.get("metadata") or {}) + + +def test_message_to_events_empty_content_yields_start_finish_only() -> None: + msg = AIMessage(content="", id="msg-empty") + events = list(message_to_events(msg)) + event_types = [e["event"] for e in events] + assert event_types == ["message-start", "message-finish"] + + +def test_message_to_events_reasoning_text_order() -> None: + msg = AIMessage( + content=[ + {"type": "reasoning", "reasoning": "think hard"}, + {"type": "text", "text": "the answer"}, + ], + id="msg-2", + ) + events = list(message_to_events(msg)) + + starts: list[Any] = [e for e in events if e["event"] == "content-block-start"] + finishes: list[Any] = [e for e in events if e["event"] == "content-block-finish"] + assert [s["content_block"]["type"] for s in starts] == ["reasoning", "text"] + assert [f["content_block"]["type"] for f in finishes] == ["reasoning", "text"] + + deltas: list[Any] = [e for e in events if e["event"] == "content-block-delta"] + assert len(deltas) == 2 + assert cast("ReasoningContentBlock", deltas[0]["content_block"])["reasoning"] == ( + "think hard" + ) + assert cast("TextContentBlock", deltas[1]["content_block"])["text"] == "the answer" + + +def test_message_to_events_tool_call_skips_delta() -> None: + msg = AIMessage( + content="", + id="msg-3", + tool_calls=[ + {"id": "tc1", "name": "search", "args": {"q": "hi"}, "type": "tool_call"}, + ], + ) + events = list(message_to_events(msg)) + + # Finalized tool_call blocks carry no useful incremental text, + # so no content-block-delta is emitted. + deltas: list[Any] = [e for e in events if e["event"] == "content-block-delta"] + assert deltas == [] + + finishes: list[Any] = [e for e in events if e["event"] == "content-block-finish"] + assert len(finishes) == 1 + tc = cast("ToolCall", finishes[0]["content_block"]) + assert tc["type"] == "tool_call" + assert tc["args"] == {"q": "hi"} + + # Message has no `finish_reason` / `stop_reason` in metadata; the + # bridge does not synthesize one and does not second-guess based on + # the presence of a tool_call. + final = cast("MessageFinishData", events[-1]) + assert "finish_reason" not in (final.get("metadata") or {}) + + +def test_message_to_events_invalid_tool_calls_surfaced_from_field() -> None: + """`invalid_tool_calls` on AIMessage surface as protocol blocks. + + `AIMessage.content_blocks` does not currently include + `invalid_tool_calls`, so the bridge merges them in explicitly. + """ + msg = AIMessage( + content="", + invalid_tool_calls=[ + { + "type": "invalid_tool_call", + "id": "call_1", + "name": "search", + "args": '{"q":', + "error": "bad json", + } + ], + ) + events = list(message_to_events(msg)) + finishes: list[Any] = [e for e in events if e["event"] == "content-block-finish"] + types = [f["content_block"]["type"] for f in finishes] + assert "invalid_tool_call" in types + + +def test_message_to_events_preserves_finish_reason_and_metadata() -> None: + msg = AIMessage( + content="done", + id="msg-4", + response_metadata={ + "finish_reason": "length", + "model_name": "test-model", + "stop_sequence": "", + }, + ) + events = list(message_to_events(msg)) + + start = cast("MessageStartData", events[0]) + assert start["metadata"] == {"model": "test-model"} + + # Passthrough: response_metadata lands on `metadata` unchanged, + # including the raw provider `finish_reason`. + final = cast("MessageFinishData", events[-1]) + assert final["metadata"] == { + "finish_reason": "length", + "model_name": "test-model", + "stop_sequence": "", + } + + +def test_message_to_events_propagates_usage() -> None: + msg = AIMessage( + content="hi", + id="msg-5", + usage_metadata={"input_tokens": 10, "output_tokens": 2, "total_tokens": 12}, + ) + events = list(message_to_events(msg)) + + final = cast("MessageFinishData", events[-1]) + assert final["usage"] == { + "input_tokens": 10, + "output_tokens": 2, + "total_tokens": 12, + } + + +def test_message_to_events_message_id_override() -> None: + msg = AIMessage(content="x", id="msg-orig") + events = list(message_to_events(msg, message_id="msg-override")) + start = cast("MessageStartData", events[0]) + assert start["message_id"] == "msg-override" + + +def test_message_to_events_self_contained_start_strips_heavy_fields() -> None: + """`content-block-start` must not duplicate heavy payload fields. + + For image/audio/video/file/non_standard and finalized tool_call blocks, + the large payload (base64 `data`, parsed `args`, arbitrary `value`) + should appear only on `content-block-finish`, not on `content-block-start`. + Start preserves correlation and small metadata fields. + """ + msg = AIMessage( + content=[ + { + "type": "image", + "id": "img-1", + "mime_type": "image/png", + "data": "A" * 1024, + }, + { + "type": "audio", + "id": "aud-1", + "mime_type": "audio/mp3", + "data": "B" * 1024, + "transcript": "hello", + }, + { + "type": "non_standard", + "id": "ns-1", + "value": {"big": "C" * 1024}, + }, + ], + id="msg-heavy", + ) + events = list(message_to_events(msg)) + + starts: list[Any] = [e for e in events if e["event"] == "content-block-start"] + assert [s["content_block"]["type"] for s in starts] == [ + "image", + "audio", + "non_standard", + ] + + image_start = starts[0]["content_block"] + assert image_start["id"] == "img-1" + assert image_start["mime_type"] == "image/png" + assert "data" not in image_start + + audio_start = starts[1]["content_block"] + assert audio_start["id"] == "aud-1" + assert audio_start["mime_type"] == "audio/mp3" + assert "data" not in audio_start + assert "transcript" not in audio_start + + ns_start = starts[2]["content_block"] + assert ns_start["type"] == "non_standard" + assert ns_start["value"] == {} + + finishes: list[Any] = [e for e in events if e["event"] == "content-block-finish"] + assert finishes[0]["content_block"]["data"] == "A" * 1024 + assert finishes[1]["content_block"]["data"] == "B" * 1024 + assert finishes[1]["content_block"]["transcript"] == "hello" + assert finishes[2]["content_block"]["value"] == {"big": "C" * 1024} + + +def test_message_to_events_finalized_tool_call_start_strips_args() -> None: + """Finalized `tool_call` keeps id/name on start but not parsed args.""" + msg = AIMessage( + content="", + id="msg-tc", + tool_calls=[ + { + "id": "tc1", + "name": "search", + "args": {"q": "big payload " * 100}, + "type": "tool_call", + }, + ], + ) + events = list(message_to_events(msg)) + + starts: list[Any] = [e for e in events if e["event"] == "content-block-start"] + assert len(starts) == 1 + tc_start = starts[0]["content_block"] + assert tc_start["type"] == "tool_call" + assert tc_start["id"] == "tc1" + assert tc_start["name"] == "search" + assert tc_start["args"] == {} + + finishes: list[Any] = [e for e in events if e["event"] == "content-block-finish"] + tc_finish = cast("ToolCall", finishes[0]["content_block"]) + assert tc_finish["args"] == {"q": "big payload " * 100} + + +@pytest.mark.asyncio +async def test_amessage_to_events_matches_sync() -> None: + msg = AIMessage( + content=[ + {"type": "reasoning", "reasoning": "why"}, + {"type": "text", "text": "because"}, + ], + id="msg-async", + ) + sync_events = list(message_to_events(msg)) + async_events = [e async for e in amessage_to_events(msg)] + assert async_events == sync_events + + +# --------------------------------------------------------------------------- +# Lifecycle validator: provider-style emission patterns +# --------------------------------------------------------------------------- + + +def _aimsg_chunk(blocks: list[CompatBlock], msg_id: str = "m") -> ChatGenerationChunk: + """Wrap a list of content blocks into a ChatGenerationChunk. + + Matches what a provider's `_stream` would yield per SSE event. + """ + return ChatGenerationChunk(message=AIMessageChunk(content=blocks, id=msg_id)) + + +def test_lifecycle_validator_openai_chat_completions_style() -> None: + """Text + streaming tool call with int indices, all at index 0/1. + + Mirrors OpenAI chat-completions API where each delta stays at the + same integer index and a new tool call bumps the index. + """ + chunks = [ + _aimsg_chunk([{"type": "text", "text": "Hello", "index": 0}]), + _aimsg_chunk([{"type": "text", "text": " there", "index": 0}]), + ] + # Tool-call chunks go via the tool_call_chunks channel, not content. + chunks.extend( + [ + ChatGenerationChunk( + message=AIMessageChunk( + content="", + id="m", + tool_call_chunks=[ + { + "type": "tool_call_chunk", + "index": 1, + "id": "tc1", + "name": "lookup", + "args": '{"q":', + } + ], + ) + ), + ChatGenerationChunk( + message=AIMessageChunk( + content="", + id="m", + tool_call_chunks=[ + { + "type": "tool_call_chunk", + "index": 1, + "id": None, + "name": None, + "args": ' "pie"}', + } + ], + ) + ), + ] + ) + + events = list(chunks_to_events(iter(chunks), message_id="m")) + assert_valid_event_stream(events) + + finishes: list[Any] = [e for e in events if e["event"] == "content-block-finish"] + types = [f["content_block"]["type"] for f in finishes] + assert types == ["text", "tool_call"] + assert finishes[0]["content_block"]["text"] == "Hello there" + assert finishes[1]["content_block"]["args"] == {"q": "pie"} + + +def test_lifecycle_validator_openai_responses_style() -> None: + """Reasoning → text → reasoning → text with string block identifiers. + + Mirrors OpenAI `responses/v1` output_version where each distinct + block has a string index like `lc_rs_305f30`. + """ + chunks = [ + _aimsg_chunk([{"type": "reasoning", "reasoning": "hmm", "index": "rs_a"}]), + _aimsg_chunk([{"type": "reasoning", "reasoning": " first", "index": "rs_a"}]), + _aimsg_chunk([{"type": "text", "text": "Answer: ", "index": "txt_a"}]), + _aimsg_chunk([{"type": "text", "text": "42", "index": "txt_a"}]), + _aimsg_chunk([{"type": "reasoning", "reasoning": "actually", "index": "rs_b"}]), + _aimsg_chunk([{"type": "text", "text": "42!", "index": "txt_b"}]), + ] + + events = list(chunks_to_events(iter(chunks), message_id="m")) + assert_valid_event_stream(events) + + starts: list[Any] = [e for e in events if e["event"] == "content-block-start"] + finishes: list[Any] = [e for e in events if e["event"] == "content-block-finish"] + # Four distinct blocks: reasoning, text, reasoning, text + assert [s["content_block"]["type"] for s in starts] == [ + "reasoning", + "text", + "reasoning", + "text", + ] + assert [s["index"] for s in starts] == [0, 1, 2, 3] + assert [f["index"] for f in finishes] == [0, 1, 2, 3] + assert finishes[0]["content_block"]["reasoning"] == "hmm first" + assert finishes[1]["content_block"]["text"] == "Answer: 42" + assert finishes[2]["content_block"]["reasoning"] == "actually" + assert finishes[3]["content_block"]["text"] == "42!" + + +def test_lifecycle_validator_anthropic_style_text_and_thinking() -> None: + """Interleaved text and thinking blocks with int indices. + + Mirrors Anthropic's per-event structure: one block per chunk, each + labeled with an int `index` from Anthropic's content_block_start / + content_block_delta events. + """ + chunks = [ + _aimsg_chunk([{"type": "reasoning", "reasoning": "Let me think", "index": 0}]), + _aimsg_chunk([{"type": "reasoning", "reasoning": " more", "index": 0}]), + _aimsg_chunk([{"type": "text", "text": "The answer is", "index": 1}]), + _aimsg_chunk([{"type": "text", "text": " 42.", "index": 1}]), + ] + + events = list(chunks_to_events(iter(chunks), message_id="m")) + assert_valid_event_stream(events) + + finishes: list[Any] = [e for e in events if e["event"] == "content-block-finish"] + assert [f["content_block"]["type"] for f in finishes] == ["reasoning", "text"] + assert finishes[0]["content_block"]["reasoning"] == "Let me think more" + assert finishes[1]["content_block"]["text"] == "The answer is 42." + + +def test_lifecycle_validator_anthropic_reasoning_preserves_signature() -> None: + """A later reasoning delta's `extras.signature` must land on the finish block. + + Anthropic emits reasoning content as `thinking_delta` events (text), + followed by a `signature_delta` event carrying the cryptographic + signature that the API requires on any follow-up turn. After the + content-block-start/delta translation, that signature arrives as + `extras.signature` on a reasoning delta that has no new text. If + the bridge drops it, Claude rejects the next request with + `messages..content..thinking.signature: Field required`. + """ + chunks = [ + _aimsg_chunk([{"type": "reasoning", "reasoning": "Let me think", "index": 0}]), + _aimsg_chunk([{"type": "reasoning", "reasoning": " more", "index": 0}]), + # signature_delta arrives after the text; no new reasoning text + # but carries the signature under `extras`. + _aimsg_chunk( + [ + { + "type": "reasoning", + "reasoning": "", + "index": 0, + "extras": {"signature": "sig-abc123"}, + } + ] + ), + _aimsg_chunk([{"type": "text", "text": "Hi.", "index": 1}]), + ] + + events = list(chunks_to_events(iter(chunks), message_id="m")) + assert_valid_event_stream(events) + + finishes: list[Any] = [e for e in events if e["event"] == "content-block-finish"] + reasoning_finish = finishes[0]["content_block"] + assert reasoning_finish["type"] == "reasoning" + assert reasoning_finish["reasoning"] == "Let me think more" + assert reasoning_finish.get("extras", {}).get("signature") == "sig-abc123" + + +def test_lifecycle_validator_anthropic_style_tool_use_after_text() -> None: + """Text then tool_use (tool_call_chunk) — Anthropic tool-calling pattern.""" + chunks = [ + _aimsg_chunk([{"type": "text", "text": "Looking up...", "index": 0}]), + ChatGenerationChunk( + message=AIMessageChunk( + content=[], + id="m", + tool_call_chunks=[ + { + "type": "tool_call_chunk", + "index": 1, + "id": "toolu_1", + "name": "search", + "args": "", + } + ], + ) + ), + ChatGenerationChunk( + message=AIMessageChunk( + content=[], + id="m", + tool_call_chunks=[ + { + "type": "tool_call_chunk", + "index": 1, + "id": None, + "name": None, + "args": '{"query": "42"}', + } + ], + ) + ), + ] + + events = list(chunks_to_events(iter(chunks), message_id="m")) + assert_valid_event_stream(events) + + finishes: list[Any] = [e for e in events if e["event"] == "content-block-finish"] + assert [f["content_block"]["type"] for f in finishes] == ["text", "tool_call"] + assert finishes[1]["content_block"]["args"] == {"query": "42"} + assert finishes[1]["content_block"]["id"] == "toolu_1" + + +def test_lifecycle_validator_inline_image_block() -> None: + """A self-contained image block gets start + finish with no delta.""" + chunks = [ + _aimsg_chunk( + [ + { + "type": "image", + "id": "img1", + "mime_type": "image/png", + "data": "AAAA", + "index": 0, + } + ] + ), + ] + events = list(chunks_to_events(iter(chunks), message_id="m")) + assert_valid_event_stream(events) + + starts: list[Any] = [e for e in events if e["event"] == "content-block-start"] + deltas: list[Any] = [e for e in events if e["event"] == "content-block-delta"] + finishes: list[Any] = [e for e in events if e["event"] == "content-block-finish"] + assert [s["content_block"]["type"] for s in starts] == ["image"] + # Self-contained block: no delta, and start has heavy fields stripped. + assert deltas == [] + assert "data" not in starts[0]["content_block"] + assert finishes[0]["content_block"]["data"] == "AAAA" + + +def test_lifecycle_validator_invalid_tool_call_args() -> None: + """Malformed JSON args finalize to invalid_tool_call; lifecycle still valid.""" + chunks = [ + ChatGenerationChunk( + message=AIMessageChunk( + content="", + id="m", + tool_call_chunks=[ + { + "type": "tool_call_chunk", + "index": 0, + "id": "bad1", + "name": "noop", + "args": "not json", + } + ], + ) + ), + ] + events = list(chunks_to_events(iter(chunks), message_id="m")) + assert_valid_event_stream(events) + + finishes: list[Any] = [e for e in events if e["event"] == "content-block-finish"] + assert len(finishes) == 1 + assert finishes[0]["content_block"]["type"] == "invalid_tool_call" + + +def test_lifecycle_validator_empty_stream() -> None: + """An empty chunk iterator produces no events (and still validates).""" + assert_valid_event_stream(list(chunks_to_events(iter([])))) + + +def test_lifecycle_validator_message_to_events_roundtrip() -> None: + """`message_to_events` also produces spec-conformant lifecycles.""" + msg = AIMessage( + content=[ + {"type": "reasoning", "reasoning": "think"}, + {"type": "text", "text": "answer"}, + { + "type": "image", + "id": "img1", + "mime_type": "image/png", + "data": "X" * 256, + }, + ], + id="msg-1", + tool_calls=[ + {"id": "t1", "name": "search", "args": {"q": "pie"}, "type": "tool_call"}, + ], + ) + events = list(message_to_events(msg)) + assert_valid_event_stream(events) diff --git a/libs/core/tests/unit_tests/language_models/test_stream_v2.py b/libs/core/tests/unit_tests/language_models/test_stream_v2.py new file mode 100644 index 00000000000..a46629a5fed --- /dev/null +++ b/libs/core/tests/unit_tests/language_models/test_stream_v2.py @@ -0,0 +1,1357 @@ +"""Tests for stream_v2 / astream_v2 and ChatModelStream.""" + +from __future__ import annotations + +import asyncio +import contextlib +from typing import TYPE_CHECKING, Any + +import pytest +from langchain_protocol.protocol import ( + ContentBlockDeltaData, + ContentBlockFinishData, + MessageFinishData, + ReasoningContentBlock, + TextContentBlock, + ToolCall, + UsageInfo, +) + +from langchain_core.caches import InMemoryCache +from langchain_core.callbacks import AsyncCallbackHandler, BaseCallbackHandler +from langchain_core.language_models.chat_model_stream import ( + AsyncChatModelStream, + ChatModelStream, +) +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.language_models.fake_chat_models import FakeListChatModel +from langchain_core.messages import AIMessageChunk +from langchain_core.outputs import ChatGenerationChunk, ChatResult +from langchain_core.tracers._streaming import _V2StreamingCallbackHandler + +if TYPE_CHECKING: + from collections.abc import Iterator + + from langchain_core.callbacks import CallbackManagerForLLMRun + from langchain_core.messages import BaseMessage + + +class _MalformedToolCallModel(BaseChatModel): + """Fake model that emits a tool_call_chunk with malformed JSON args.""" + + @property + def _llm_type(self) -> str: + return "malformed-tool-call-fake" + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + del messages, stop, run_manager, kwargs + raise NotImplementedError + + def _stream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + del messages, stop, run_manager, kwargs + yield ChatGenerationChunk( + message=AIMessageChunk( + content="", + tool_call_chunks=[ + { + "name": "search", + "args": '{"q": ', # malformed JSON + "id": "call_1", + "index": 0, + } + ], + ) + ) + + +class _AnthropicStyleServerToolModel(BaseChatModel): + """Fake model that streams Anthropic-native server_tool_use shapes. + + Exercises Phase E: the bridge should call `content_blocks` (which + invokes the Anthropic translator) to convert `server_tool_use` into + protocol `server_tool_call` blocks instead of silently dropping them. + """ + + @property + def _llm_type(self) -> str: + return "anthropic-style-fake" + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + del messages, stop, run_manager, kwargs + raise NotImplementedError + + def _stream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + del messages, stop, run_manager, kwargs + # Single chunk carrying a complete server_tool_use block — what + # Anthropic typically emits once input_json_delta finishes. + yield ChatGenerationChunk( + message=AIMessageChunk( + content=[ + { + "type": "server_tool_use", + "id": "srvtoolu_01", + "name": "web_search", + "input": {"query": "weather today"}, + }, + {"type": "text", "text": "Based on the search..."}, + ], + response_metadata={"model_provider": "anthropic"}, + ) + ) + + +class TestChatModelStream: + """Test the sync ChatModelStream object.""" + + def test_push_text_delta(self) -> None: + stream = ChatModelStream() + stream._push_content_block_delta( + ContentBlockDeltaData( + event="content-block-delta", + index=0, + content_block=TextContentBlock(type="text", text="Hello"), + ) + ) + assert stream._text_acc == "Hello" + + def test_push_reasoning_delta(self) -> None: + stream = ChatModelStream() + stream._push_content_block_delta( + ContentBlockDeltaData( + event="content-block-delta", + index=0, + content_block=ReasoningContentBlock( + type="reasoning", reasoning="think" + ), + ) + ) + assert stream._reasoning_acc == "think" + + def test_push_content_block_finish_tool_call(self) -> None: + stream = ChatModelStream() + stream._push_content_block_finish( + ContentBlockFinishData( + event="content-block-finish", + index=0, + content_block=ToolCall( + type="tool_call", + id="tc1", + name="search", + args={"q": "test"}, + ), + ) + ) + assert len(stream._tool_calls_acc) == 1 + assert stream._tool_calls_acc[0]["name"] == "search" + + def test_finish(self) -> None: + stream = ChatModelStream() + assert not stream.done + usage = UsageInfo(input_tokens=10, output_tokens=5, total_tokens=15) + stream._finish(MessageFinishData(event="message-finish", usage=usage)) + assert stream.done + assert stream._usage_value == usage + + def test_fail(self) -> None: + stream = ChatModelStream() + stream.fail(RuntimeError("test")) + assert stream.done + + def test_pump_driven_text(self) -> None: + """Test text projection with pump binding.""" + stream = ChatModelStream() + deltas: list[ContentBlockDeltaData] = [ + ContentBlockDeltaData( + event="content-block-delta", + index=0, + content_block=TextContentBlock(type="text", text="Hi"), + ), + ContentBlockDeltaData( + event="content-block-delta", + index=0, + content_block=TextContentBlock(type="text", text=" there"), + ), + ] + finish = MessageFinishData(event="message-finish") + idx = 0 + + def pump_one() -> bool: + nonlocal idx + if idx < len(deltas): + stream._push_content_block_delta(deltas[idx]) + idx += 1 + return True + if idx == len(deltas): + stream._finish(finish) + idx += 1 + return True + return False + + stream.bind_pump(pump_one) + + text_deltas = list(stream.text) + assert text_deltas == ["Hi", " there"] + assert stream.done + + +class TestAsyncChatModelStream: + """Test the async ChatModelStream object.""" + + @pytest.mark.asyncio + async def test_text_await(self) -> None: + stream = AsyncChatModelStream() + stream._push_content_block_delta( + ContentBlockDeltaData( + event="content-block-delta", + index=0, + content_block=TextContentBlock(type="text", text="Hello"), + ) + ) + stream._push_content_block_delta( + ContentBlockDeltaData( + event="content-block-delta", + index=0, + content_block=TextContentBlock(type="text", text=" world"), + ) + ) + stream._finish(MessageFinishData(event="message-finish")) + + full = await stream.text + assert full == "Hello world" + + @pytest.mark.asyncio + async def test_text_async_iter(self) -> None: + stream = AsyncChatModelStream() + + async def produce() -> None: + await asyncio.sleep(0) + stream._push_content_block_delta( + ContentBlockDeltaData( + event="content-block-delta", + index=0, + content_block=TextContentBlock(type="text", text="a"), + ) + ) + await asyncio.sleep(0) + stream._push_content_block_delta( + ContentBlockDeltaData( + event="content-block-delta", + index=0, + content_block=TextContentBlock(type="text", text="b"), + ) + ) + await asyncio.sleep(0) + stream._finish(MessageFinishData(event="message-finish")) + + asyncio.get_running_loop().create_task(produce()) + + deltas = [d async for d in stream.text] + assert deltas == ["a", "b"] + + @pytest.mark.asyncio + async def test_tool_calls_await(self) -> None: + stream = AsyncChatModelStream() + stream._push_content_block_finish( + ContentBlockFinishData( + event="content-block-finish", + index=0, + content_block=ToolCall( + type="tool_call", + id="tc1", + name="search", + args={"q": "test"}, + ), + ) + ) + stream._finish(MessageFinishData(event="message-finish")) + + tool_calls = await stream.tool_calls + assert len(tool_calls) == 1 + assert tool_calls[0]["name"] == "search" + + @pytest.mark.asyncio + async def test_error_propagation(self) -> None: + stream = AsyncChatModelStream() + stream.fail(RuntimeError("boom")) + + with pytest.raises(RuntimeError, match="boom"): + await stream.text + + +class TestStreamV2: + """Test BaseChatModel.stream_v2() with FakeListChatModel.""" + + def test_stream_v2_text(self) -> None: + model = FakeListChatModel(responses=["Hello world!"]) + stream = model.stream_v2("test") + + assert isinstance(stream, ChatModelStream) + deltas = list(stream.text) + assert "".join(deltas) == "Hello world!" + assert stream.done + + def test_stream_v2_usage(self) -> None: + model = FakeListChatModel(responses=["Hi"]) + stream = model.stream_v2("test") + + # Drain stream + for _ in stream.text: + pass + # FakeListChatModel doesn't emit usage, so it should be None + assert stream.output.usage_metadata is None + assert stream.done + + def test_stream_v2_malformed_tool_args_produce_invalid_tool_call(self) -> None: + """End-to-end: malformed tool-call JSON becomes invalid_tool_calls.""" + model = _MalformedToolCallModel() + stream = model.stream_v2("test") + msg = stream.output + + assert msg.tool_calls == [] + assert len(msg.invalid_tool_calls) == 1 + itc = msg.invalid_tool_calls[0] + assert itc["name"] == "search" + assert itc["args"] == '{"q": ' + assert itc["id"] == "call_1" + + def test_stream_v2_translates_anthropic_server_tool_use_to_protocol(self) -> None: + """Phase E end-to-end: server_tool_use becomes server_tool_call in output.""" + model = _AnthropicStyleServerToolModel() + stream = model.stream_v2("weather?") + msg = stream.output + + assert isinstance(msg.content, list) + types = [b.get("type") for b in msg.content if isinstance(b, dict)] + # The server tool call must appear in the output content. + assert "server_tool_call" in types + # Text block should also be present. + assert "text" in types + # Regular tool_calls should NOT include the server-executed call. + assert msg.tool_calls == [] + + +class TestAstreamV2: + """Test BaseChatModel.astream_v2() with FakeListChatModel.""" + + @pytest.mark.asyncio + async def test_astream_v2_text(self) -> None: + model = FakeListChatModel(responses=["Hello!"]) + stream = await model.astream_v2("test") + + assert isinstance(stream, AsyncChatModelStream) + full = await stream.text + assert full == "Hello!" + + @pytest.mark.asyncio + async def test_astream_v2_deltas(self) -> None: + model = FakeListChatModel(responses=["Hi"]) + stream = await model.astream_v2("test") + + deltas = [d async for d in stream.text] + assert "".join(deltas) == "Hi" + + +class TestPerBlockAccumulation: + """Regression: per-block text/reasoning must not cross-contaminate. + + When a message contains more than one `text` or `reasoning` block + (Anthropic interleaves text around `tool_use`; OpenAI Responses + emits multiple reasoning summary items), each finalized block must + carry only its own payload — not the running message-wide total. + """ + + def test_two_text_blocks_keep_their_own_text(self) -> None: + stream = ChatModelStream() + # Block 0: "A" + stream.dispatch( + ContentBlockDeltaData( + event="content-block-delta", + index=0, + content_block=TextContentBlock(type="text", text="A"), + ) + ) + stream.dispatch( + ContentBlockFinishData( + event="content-block-finish", + index=0, + content_block=TextContentBlock(type="text", text="A"), + ) + ) + # Block 1: "B" + stream.dispatch( + ContentBlockDeltaData( + event="content-block-delta", + index=1, + content_block=TextContentBlock(type="text", text="B"), + ) + ) + stream.dispatch( + ContentBlockFinishData( + event="content-block-finish", + index=1, + content_block=TextContentBlock(type="text", text="B"), + ) + ) + stream.dispatch(MessageFinishData(event="message-finish")) + + content = stream.output.content + assert isinstance(content, list) + text_blocks = [ + b for b in content if isinstance(b, dict) and b.get("type") == "text" + ] + assert [b["text"] for b in text_blocks] == ["A", "B"], ( + "Finalized text blocks must carry their own payloads, not the " + "concatenation of all earlier text blocks." + ) + # Message-wide projection still sums to the full text. + assert str(stream.text) == "AB" + + def test_two_reasoning_blocks_keep_their_own_text(self) -> None: + stream = ChatModelStream() + # Block 0: "one" + stream.dispatch( + ContentBlockDeltaData( + event="content-block-delta", + index=0, + content_block=ReasoningContentBlock(type="reasoning", reasoning="one"), + ) + ) + stream.dispatch( + ContentBlockFinishData( + event="content-block-finish", + index=0, + content_block=ReasoningContentBlock(type="reasoning", reasoning="one"), + ) + ) + # Block 1: "two" + stream.dispatch( + ContentBlockDeltaData( + event="content-block-delta", + index=1, + content_block=ReasoningContentBlock(type="reasoning", reasoning="two"), + ) + ) + stream.dispatch( + ContentBlockFinishData( + event="content-block-finish", + index=1, + content_block=ReasoningContentBlock(type="reasoning", reasoning="two"), + ) + ) + stream.dispatch(MessageFinishData(event="message-finish")) + + content = stream.output.content + assert isinstance(content, list) + reasoning_blocks = [ + b for b in content if isinstance(b, dict) and b.get("type") == "reasoning" + ] + assert [b["reasoning"] for b in reasoning_blocks] == ["one", "two"] + assert str(stream.reasoning) == "onetwo" + + def test_finish_text_reconciles_with_partial_deltas(self) -> None: + """`.text` must agree with `.output.content` when finish corrects deltas. + + If deltas stream "hel" and the `content-block-finish` payload + carries the authoritative "hello", both the per-block finalized + text and the message-wide projection must land on "hello". + """ + stream = ChatModelStream() + stream.dispatch( + ContentBlockDeltaData( + event="content-block-delta", + index=0, + content_block=TextContentBlock(type="text", text="hel"), + ) + ) + stream.dispatch( + ContentBlockFinishData( + event="content-block-finish", + index=0, + content_block=TextContentBlock(type="text", text="hello"), + ) + ) + stream.dispatch(MessageFinishData(event="message-finish")) + + content = stream.output.content + assert isinstance(content, list) + text_blocks = [ + b for b in content if isinstance(b, dict) and b.get("type") == "text" + ] + assert [b["text"] for b in text_blocks] == ["hello"] + assert str(stream.text) == "hello" + + def test_out_of_order_finish_still_produces_correct_final_text(self) -> None: + """Reconciliation must not depend on `_text_acc` suffix layout. + + If block 0 finishes with authoritative text *after* block 1 has + already emitted deltas (possible in theory for a native + `_stream_chat_model_events` provider, or any future mutation + path that touches `_text_acc`), the in-place splice would + corrupt the message-wide accumulator. The final value must be + derived from per-block storage so both `stream.output.content` + and `str(stream.text)` remain correct regardless of finish + ordering. + """ + stream = ChatModelStream() + # Block 0 streams deltas first. + stream.dispatch( + ContentBlockDeltaData( + event="content-block-delta", + index=0, + content_block=TextContentBlock(type="text", text="aaa"), + ) + ) + # Block 1 streams deltas before block 0 finishes. + stream.dispatch( + ContentBlockDeltaData( + event="content-block-delta", + index=1, + content_block=TextContentBlock(type="text", text="bb"), + ) + ) + # Block 0 finishes with authoritative text different from deltas. + stream.dispatch( + ContentBlockFinishData( + event="content-block-finish", + index=0, + content_block=TextContentBlock(type="text", text="XXX"), + ) + ) + stream.dispatch( + ContentBlockFinishData( + event="content-block-finish", + index=1, + content_block=TextContentBlock(type="text", text="bb"), + ) + ) + stream.dispatch(MessageFinishData(event="message-finish")) + + content = stream.output.content + assert isinstance(content, list) + text_blocks = [ + b for b in content if isinstance(b, dict) and b.get("type") == "text" + ] + assert [b["text"] for b in text_blocks] == ["XXX", "bb"] + # `str(stream.text)` must reflect the authoritative per-block + # concatenation, not the splice-in-place result ("aaXXX") that + # would have been left over from the old suffix assumption. + assert str(stream.text) == "XXXbb" + + def test_finish_reasoning_reconciles_with_partial_deltas(self) -> None: + """Same reconciliation invariant for the reasoning projection.""" + stream = ChatModelStream() + stream.dispatch( + ContentBlockDeltaData( + event="content-block-delta", + index=0, + content_block=ReasoningContentBlock(type="reasoning", reasoning="thi"), + ) + ) + stream.dispatch( + ContentBlockFinishData( + event="content-block-finish", + index=0, + content_block=ReasoningContentBlock( + type="reasoning", reasoning="thinking" + ), + ) + ) + stream.dispatch(MessageFinishData(event="message-finish")) + + content = stream.output.content + assert isinstance(content, list) + reasoning_blocks = [ + b for b in content if isinstance(b, dict) and b.get("type") == "reasoning" + ] + assert [b["reasoning"] for b in reasoning_blocks] == ["thinking"] + assert str(stream.reasoning) == "thinking" + + def test_interleaved_text_blocks_around_tool_call(self) -> None: + """Anthropic shape: text, then tool_call, then more text.""" + stream = ChatModelStream() + # Block 0: text "before" + stream.dispatch( + ContentBlockDeltaData( + event="content-block-delta", + index=0, + content_block=TextContentBlock(type="text", text="before"), + ) + ) + stream.dispatch( + ContentBlockFinishData( + event="content-block-finish", + index=0, + content_block=TextContentBlock(type="text", text="before"), + ) + ) + # Block 1: tool_call + stream.dispatch( + ContentBlockFinishData( + event="content-block-finish", + index=1, + content_block=ToolCall( + type="tool_call", + id="tc1", + name="search", + args={"q": "x"}, + ), + ) + ) + # Block 2: text "after" + stream.dispatch( + ContentBlockDeltaData( + event="content-block-delta", + index=2, + content_block=TextContentBlock(type="text", text="after"), + ) + ) + stream.dispatch( + ContentBlockFinishData( + event="content-block-finish", + index=2, + content_block=TextContentBlock(type="text", text="after"), + ) + ) + stream.dispatch(MessageFinishData(event="message-finish")) + + content = stream.output.content + assert isinstance(content, list) + text_blocks = [ + b for b in content if isinstance(b, dict) and b.get("type") == "text" + ] + assert [b["text"] for b in text_blocks] == ["before", "after"] + + +class _RecordingStreamModel(BaseChatModel): + """Fake model that records the kwargs passed to _stream / _astream.""" + + last_stream_kwargs: dict[str, Any] = {} # noqa: RUF012 + last_astream_kwargs: dict[str, Any] = {} # noqa: RUF012 + + @property + def _llm_type(self) -> str: + return "recording-fake" + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + del messages, stop, run_manager, kwargs + raise NotImplementedError + + def _stream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + del messages, stop, run_manager + type(self).last_stream_kwargs = dict(kwargs) + yield ChatGenerationChunk(message=AIMessageChunk(content="ok")) + + async def _astream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: Any = None, + **kwargs: Any, + ) -> Any: + del messages, stop, run_manager + type(self).last_astream_kwargs = dict(kwargs) + yield ChatGenerationChunk(message=AIMessageChunk(content="ok")) + + +class TestStructuredOutputKwargStripping: + """Regression: structured-output tracing kwargs must not reach _stream. + + `stream()` / `astream()` pop `ls_structured_output_format` and + `structured_output_format` before forwarding kwargs to `_stream` — + provider clients reject unknown kwargs. `stream_v2` / `astream_v2` + must do the same, or `.with_structured_output().stream_v2()` breaks. + """ + + def test_stream_v2_strips_ls_structured_output_format(self) -> None: + model = _RecordingStreamModel() + bound = model.bind(ls_structured_output_format={"schema": {"type": "object"}}) + stream = bound.stream_v2("test") + _ = stream.output # drain + recorded = _RecordingStreamModel.last_stream_kwargs + assert "ls_structured_output_format" not in recorded + assert "structured_output_format" not in recorded + + def test_stream_v2_strips_structured_output_format(self) -> None: + model = _RecordingStreamModel() + bound = model.bind(structured_output_format={"schema": {"type": "object"}}) + stream = bound.stream_v2("test") + _ = stream.output + recorded = _RecordingStreamModel.last_stream_kwargs + assert "ls_structured_output_format" not in recorded + assert "structured_output_format" not in recorded + + @pytest.mark.asyncio + async def test_astream_v2_strips_ls_structured_output_format(self) -> None: + model = _RecordingStreamModel() + bound = model.bind(ls_structured_output_format={"schema": {"type": "object"}}) + stream = await bound.astream_v2("test") + _ = await stream + assert ( + "ls_structured_output_format" + not in _RecordingStreamModel.last_astream_kwargs + ) + assert ( + "structured_output_format" not in _RecordingStreamModel.last_astream_kwargs + ) + + @pytest.mark.asyncio + async def test_astream_v2_strips_structured_output_format(self) -> None: + model = _RecordingStreamModel() + bound = model.bind(structured_output_format={"schema": {"type": "object"}}) + stream = await bound.astream_v2("test") + _ = await stream + assert ( + "ls_structured_output_format" + not in _RecordingStreamModel.last_astream_kwargs + ) + assert ( + "structured_output_format" not in _RecordingStreamModel.last_astream_kwargs + ) + + +class _SlowTeardownModel(BaseChatModel): + """Fake model whose `_astream` blocks cancellation teardown on a gate. + + Used to exercise the caller-cancellation path in `aclose()`: + cancelling the producer causes it to enter a `CancelledError` + handler that waits on `teardown_gate` before re-raising. That + keeps the producer task in a "cancelled-but-not-done" state long + enough for the test to cancel `aclose`'s caller deterministically. + """ + + def __init__(self, teardown_gate: asyncio.Event, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._teardown_gate = teardown_gate + + @property + def _llm_type(self) -> str: + return "slow-teardown-fake" + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + del messages, stop, run_manager, kwargs + raise NotImplementedError + + async def _astream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: Any = None, + **kwargs: Any, + ) -> Any: + del messages, stop, run_manager, kwargs + yield ChatGenerationChunk(message=AIMessageChunk(content="first")) + # Block forever; cancellation is the only way out. + try: + await asyncio.Event().wait() + except asyncio.CancelledError: + # Hold the cancellation teardown open until the test releases + # the gate. The task stays in a pending state while this + # handler is suspended, so `await task` on the `aclose()` + # side remains blocked. + await self._teardown_gate.wait() + raise + + +class _GatedStreamModel(BaseChatModel): + """Fake model whose _astream blocks on an event until released. + + Used to exercise consumer-cancellation cleanup: the producer task + is parked inside `_astream` awaiting the gate, and `aclose()` must + cancel it rather than leave it running. + """ + + def __init__(self, gate: asyncio.Event, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._gate = gate + self._cancelled = False + + @property + def _llm_type(self) -> str: + return "gated-fake" + + @property + def cancelled(self) -> bool: + return self._cancelled + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + del messages, stop, run_manager, kwargs + raise NotImplementedError + + async def _astream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: Any = None, + **kwargs: Any, + ) -> Any: + del messages, stop, run_manager, kwargs + yield ChatGenerationChunk(message=AIMessageChunk(content="first")) + try: + await self._gate.wait() + except asyncio.CancelledError: + self._cancelled = True + raise + yield ChatGenerationChunk(message=AIMessageChunk(content="second")) + + +class TestAsyncStreamAclose: + """Regression: aclose() must cancel the background producer task.""" + + @pytest.mark.asyncio + async def test_aclose_cancels_producer_task(self) -> None: + gate = asyncio.Event() + model = _GatedStreamModel(gate=gate) + stream = await model.astream_v2("test") + + # Pull the first delta so the producer enters the gated section. + aiter_ = stream.text.__aiter__() + first = await aiter_.__anext__() + assert first == "first" + assert stream._producer_task is not None + assert not stream._producer_task.done() + + await stream.aclose() + + assert stream._producer_task.done() + assert stream._producer_task.cancelled() or model.cancelled + + @pytest.mark.asyncio + async def test_aclose_is_idempotent(self) -> None: + gate = asyncio.Event() + model = _GatedStreamModel(gate=gate) + stream = await model.astream_v2("test") + aiter_ = stream.text.__aiter__() + await aiter_.__anext__() + + await stream.aclose() + await stream.aclose() # second call must not raise + + @pytest.mark.asyncio + async def test_async_context_manager_closes_stream(self) -> None: + gate = asyncio.Event() + model = _GatedStreamModel(gate=gate) + stream = await model.astream_v2("test") + + async with stream as s: + assert s is stream + aiter_ = stream.text.__aiter__() + await aiter_.__anext__() + + assert stream._producer_task is not None + assert stream._producer_task.done() + + @pytest.mark.asyncio + async def test_aclose_propagates_caller_cancellation(self) -> None: + """`aclose()` must not swallow cancellation of its caller. + + Uses `_SlowTeardownModel`, whose cancelled producer blocks + inside its `CancelledError` handler waiting on `teardown_gate`. + That keeps the producer task pending long enough for the test + to cancel the closer task while it is genuinely suspended + inside `aclose()` — exercising the caller-cancel propagation + path deterministically on all Python versions. + """ + teardown_gate = asyncio.Event() + model = _SlowTeardownModel(teardown_gate=teardown_gate) + stream = await model.astream_v2("test") + + # Prime the producer so it enters `_astream`'s forever-blocking + # await. + aiter_ = stream.text.__aiter__() + await aiter_.__anext__() + + closer_returned_normally = False + + async def closer() -> None: + nonlocal closer_returned_normally + await stream.aclose() + closer_returned_normally = True + + closer_task = asyncio.create_task(closer()) + # Pump the loop until the producer has been cancelled and has + # entered its cancellation-teardown suspension on + # `teardown_gate`. At that point `closer` is guaranteed to be + # suspended inside `aclose`'s linked-future await. + for _ in range(10): + await asyncio.sleep(0) + assert stream._producer_task is not None + if ( + stream._producer_task is not None + and not stream._producer_task.done() + and not closer_task.done() + ): + break + + assert stream._producer_task is not None + assert not stream._producer_task.done(), ( + "producer task should be parked in its cancellation teardown" + ) + assert not closer_task.done(), "closer must still be inside aclose" + + closer_task.cancel() + with pytest.raises(asyncio.CancelledError): + await closer_task + assert not closer_returned_normally + + # Release the producer so it can finish cancellation, then + # await it to avoid leaking a pending task out of the test. + teardown_gate.set() + with contextlib.suppress(BaseException): + await stream._producer_task + + @pytest.mark.asyncio + async def test_aclose_before_producer_starts_resolves_projections(self) -> None: + """Early-cancel path: `_produce` never runs. + + If a consumer calls `astream_v2()` and immediately `aclose()` + (or `async with` exits before the loop schedules `_produce`), + `task.cancel()` marks the task cancelled without ever invoking + its body — so neither `stream.fail` nor `on_llm_error` fires. + Consumers awaiting `stream.output` / `stream.text` would hang + forever without explicit cleanup in `aclose()`. + """ + error_events: list[BaseException] = [] + + class RecordingHandler(AsyncCallbackHandler): + async def on_llm_error(self, error: BaseException, **_: Any) -> None: + error_events.append(error) + + handler = RecordingHandler() + gate = asyncio.Event() + model = _GatedStreamModel(gate=gate) + stream = await model.astream_v2("test", config={"callbacks": [handler]}) + # No yield to the event loop between `astream_v2` returning and + # `aclose()` — the producer task has been created but its body + # has not executed. + await stream.aclose() + + # `await stream.output` must resolve (with CancelledError) + # rather than hang. + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(stream.output, timeout=1.0) + + # `on_llm_error` must have been invoked for tracing continuity, + # even though `_produce` never reached its CancelledError handler. + for _ in range(20): + if error_events: + break + await asyncio.sleep(0) + assert len(error_events) == 1 + assert isinstance(error_events[0], asyncio.CancelledError) + + @pytest.mark.asyncio + async def test_aclose_fires_on_llm_error_for_tracing(self) -> None: + """Cancellation via `aclose()` must close the callback lifecycle. + + Without this, handlers / tracing see a started run with no + matching end-or-error event for cancelled streams. + """ + end_events: list[Any] = [] + error_events: list[BaseException] = [] + + class RecordingHandler(AsyncCallbackHandler): + async def on_llm_end(self, response: Any, **_: Any) -> None: + end_events.append(response) + + async def on_llm_error(self, error: BaseException, **_: Any) -> None: + error_events.append(error) + + handler = RecordingHandler() + gate = asyncio.Event() + model = _GatedStreamModel(gate=gate) + stream = await model.astream_v2("test", config={"callbacks": [handler]}) + + aiter_ = stream.text.__aiter__() + await aiter_.__anext__() + + await stream.aclose() + + # Let the shielded callback finish. + for _ in range(10): + if error_events: + break + await asyncio.sleep(0) + + assert not end_events, "on_llm_end must not fire for cancelled stream" + assert len(error_events) == 1, ( + "aclose()-triggered cancellation must fire on_llm_error so " + "tracing observes a matching end event." + ) + assert isinstance(error_events[0], asyncio.CancelledError) + + @pytest.mark.asyncio + async def test_aclose_preserves_successful_stream_mid_on_llm_end(self) -> None: + """A successful stream must not be turned into CancelledError. + + After `message-finish` dispatches, `_output_proj` is already + complete, but `_producer_task` may still be inside + `run_manager.on_llm_end(...)`. Canceling unconditionally would + drop the end callback and corrupt an otherwise successful run. + """ + end_gate = asyncio.Event() + end_fired = asyncio.Event() + + class SlowEndHandler(AsyncCallbackHandler): + async def on_llm_end(self, response: Any, **_: Any) -> None: + del response + end_fired.set() + await end_gate.wait() + + handler = SlowEndHandler() + model = FakeListChatModel(responses=["ok"]) + stream = await model.astream_v2("test", config={"callbacks": [handler]}) + + # Wait until the stream has assembled the message and the + # slow on_llm_end handler has started running. + message = await stream.output + await end_fired.wait() + assert message.text == "ok" + assert stream._producer_task is not None + assert not stream._producer_task.done() + assert stream._error is None + + # Kick off aclose; release the callback so it completes. + close_task = asyncio.create_task(stream.aclose()) + await asyncio.sleep(0) + end_gate.set() + await close_task + + assert stream._producer_task.done() + assert not stream._producer_task.cancelled() + # The success path must be preserved — no error installed. + assert stream._error is None + # And the output projection is still resolvable. + assert (await stream.output).text == "ok" + + +class _V2RecordingHandler(BaseCallbackHandler, _V2StreamingCallbackHandler): + """Records every protocol event dispatched via `on_stream_event`.""" + + def __init__(self) -> None: + self.events: list[Any] = [] + + def on_stream_event(self, event: Any, **_: Any) -> None: + self.events.append(event) + + +class _AsyncV2RecordingHandler(AsyncCallbackHandler, _V2StreamingCallbackHandler): + """Async counterpart to `_V2RecordingHandler`.""" + + def __init__(self) -> None: + self.events: list[Any] = [] + + async def on_stream_event(self, event: Any, **_: Any) -> None: + self.events.append(event) + + +class TestCacheHitV2Replay: + """Cache hits must replay protocol events for v2 handlers. + + Without replay, `on_stream_event` fires on cache misses but not on + warm-cache calls — LangGraph-style consumers would see behavior + that depends on cache state alone. + """ + + def test_cache_hit_replays_events_to_v2_handler(self) -> None: + cache = InMemoryCache() + model = FakeListChatModel(responses=["Hello"], cache=cache) + handler = _V2RecordingHandler() + + # Cold call: populates cache and fires events. + model.invoke("prompt", config={"callbacks": [handler]}) + cold_events = list(handler.events) + handler.events.clear() + + # Warm call: events must fire again from the replayed cache hit. + model.invoke("prompt", config={"callbacks": [handler]}) + warm_events = list(handler.events) + + assert warm_events, "cache hit must replay v2 events" + warm_types = [e["event"] for e in warm_events] + # Lifecycle anchors must be present on the warm path, matching cold. + # Replay collapses per-chunk deltas into a single delta per block, + # so we assert shape equivalence at the anchor level rather than + # exact event-count equality. + assert warm_types[0] == "message-start" + assert warm_types[-1] == "message-finish" + assert "content-block-start" in warm_types + assert "content-block-finish" in warm_types + cold_types = [e["event"] for e in cold_events] + assert cold_types[0] == "message-start" + assert cold_types[-1] == "message-finish" + + def test_cache_hit_skips_replay_without_v2_handler(self) -> None: + """A v1-only callback set must not accidentally trigger v2 replay.""" + cache = InMemoryCache() + model = FakeListChatModel(responses=["Hi"], cache=cache) + + # Prime the cache. + model.invoke("prompt") + + class _V1OnlyHandler(BaseCallbackHandler): + def __init__(self) -> None: + self.stream_events: list[Any] = [] + + def on_stream_event(self, event: Any, **_: Any) -> None: + self.stream_events.append(event) + + handler = _V1OnlyHandler() + model.invoke("prompt", config={"callbacks": [handler]}) + # No `_V2StreamingCallbackHandler` marker -> no replay. + assert handler.stream_events == [] + + @pytest.mark.asyncio + async def test_acache_hit_replays_events_to_v2_handler(self) -> None: + cache = InMemoryCache() + model = FakeListChatModel(responses=["Hello"], cache=cache) + handler = _AsyncV2RecordingHandler() + + await model.ainvoke("prompt", config={"callbacks": [handler]}) + handler.events.clear() + + await model.ainvoke("prompt", config={"callbacks": [handler]}) + assert handler.events, "async cache hit must replay v2 events" + types = [e["event"] for e in handler.events] + assert "message-start" in types + assert "message-finish" in types + + +class _ProviderMetadataStreamModel(BaseChatModel): + """Fake model that advertises `output_version="responses/v1"` in metadata. + + Verifies `stream_v2` pins the assembled message's `output_version` to + `"v1"` — the shape it actually produces — regardless of what the + provider's chunk metadata claims. + """ + + @property + def _llm_type(self) -> str: + return "provider-metadata-fake" + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + del messages, stop, run_manager, kwargs + raise NotImplementedError + + def _stream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + del messages, stop, run_manager, kwargs + yield ChatGenerationChunk( + message=AIMessageChunk( + content="hi", + response_metadata={"output_version": "responses/v1"}, + ) + ) + + +class TestOutputVersionPinning: + """`stream_v2().output` always serializes as v1 content blocks.""" + + def test_output_version_pinned_to_v1(self) -> None: + model = _ProviderMetadataStreamModel() + stream = model.stream_v2("hi") + msg = stream.output + # Assembled message must claim `"v1"` even though the provider + # chunk metadata advertised `"responses/v1"`. + assert msg.response_metadata.get("output_version") == "v1" + + +class _BedrockConverseToolCallModel(BaseChatModel): + """Replays a captured `ChatBedrockConverse` tool-calling stream. + + Bedrock opens a tool block with `args=None` (name + id only) and + starts streaming JSON args in the *next* chunk. Other providers + emit `args=""` on the opener, so the compat bridge's accumulator + never saw `None` on the state side until Bedrock hit it. + """ + + @property + def _llm_type(self) -> str: + return "bedrock-converse-fake" + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + del messages, stop, run_manager, kwargs + raise NotImplementedError + + def _stream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + del messages, stop, run_manager, kwargs + meta = {"model_provider": "bedrock_converse", "ls_provider": "amazon_bedrock"} + + # Text at content index 0. + yield ChatGenerationChunk( + message=AIMessageChunk( + content=[{"type": "text", "text": "Hello ", "index": 0}], + response_metadata=meta, + ) + ) + yield ChatGenerationChunk( + message=AIMessageChunk( + content=[{"type": "text", "text": "Boston!", "index": 0}], + response_metadata=meta, + ) + ) + # Tool opener: name + id, args=None (the Bedrock-specific shape). + yield ChatGenerationChunk( + message=AIMessageChunk( + content=[ + { + "type": "tool_use", + "name": "get_weather", + "id": "tooluse_1", + "index": 1, + } + ], + response_metadata=meta, + tool_call_chunks=[ + { + "name": "get_weather", + "args": None, + "id": "tooluse_1", + "index": 1, + "type": "tool_call_chunk", + } + ], + ) + ) + # Args deltas; each intermediate JSON slice is itself unparseable. + for slice_ in ('{"location":', ' "Boston', '"}'): + yield ChatGenerationChunk( + message=AIMessageChunk( + content=[ + {"type": "tool_use", "input": slice_, "id": None, "index": 1} + ], + response_metadata=meta, + tool_call_chunks=[ + { + "name": None, + "args": slice_, + "id": None, + "index": 1, + "type": "tool_call_chunk", + } + ], + ) + ) + # Terminal chunk carrying usage + stop reason. + yield ChatGenerationChunk( + message=AIMessageChunk( + content="", + response_metadata={"stopReason": "tool_use", **meta}, + usage_metadata={ + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + }, + chunk_position="last", + ) + ) + + +class TestBedrockConverseToolCallArgs: + """Regression: Bedrock's `args=None` tool opener must not break accumulation. + + The compat bridge's `_accumulate` used to do + `state.get("args", "") + (delta.get("args") or "")`; `state.get("args", "")` + returns the stored value when the key exists, so a Bedrock opener that + stores `args=None` poisoned the state and the next delta raised + `TypeError: unsupported operand type(s) for +: 'NoneType' and 'str'`. + """ + + def test_bedrock_tool_call_assembles_without_error(self) -> None: + model = _BedrockConverseToolCallModel() + stream = model.stream_v2("What's the weather in Boston?") + # Drive the stream to completion — the raise would have surfaced here. + events = list(stream) + + kinds = [e["event"] for e in events] + assert kinds[0] == "message-start" + assert kinds[-1] == "message-finish" + + msg = stream.output + assert msg.tool_calls == [ + { + "name": "get_weather", + "args": {"location": "Boston"}, + "id": "tooluse_1", + "type": "tool_call", + } + ] + # The args are assembled by concatenating deltas, so no + # partial-JSON slice should register as an `invalid_tool_call`. + assert msg.invalid_tool_calls == [] + # Text block round-trips alongside the tool call. + text_blocks = [ + b for b in msg.content if isinstance(b, dict) and b.get("type") == "text" + ] + assert len(text_blocks) == 1 + assert text_blocks[0]["text"] == "Hello Boston!" diff --git a/libs/core/tests/unit_tests/language_models/test_v1_parity.py b/libs/core/tests/unit_tests/language_models/test_v1_parity.py new file mode 100644 index 00000000000..d7f3f3a602c --- /dev/null +++ b/libs/core/tests/unit_tests/language_models/test_v1_parity.py @@ -0,0 +1,392 @@ +"""V1 parity tests: stream_v2() output must match model.stream() output. + +These are the acceptance criteria for streaming v2 — if any test fails, +v2 has a regression vs v1. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest +from typing_extensions import override + +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.language_models.fake_chat_models import FakeListChatModel +from langchain_core.messages import AIMessage, AIMessageChunk +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult + +if TYPE_CHECKING: + from collections.abc import AsyncIterator, Iterator + + from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, + ) + from langchain_core.messages import BaseMessage + + +class _ScriptedChunkModel(BaseChatModel): + """Fake chat model that streams a fixed, pre-built sequence of chunks. + + Lets us write parity tests that exercise tool calls, reasoning, + usage metadata, and response metadata — shapes `FakeListChatModel` + cannot produce. + """ + + scripted_chunks: list[AIMessageChunk] + raise_after: bool = False + """If True, raise `_FakeStreamError` after yielding all scripted chunks.""" + + @property + @override + def _llm_type(self) -> str: + return "scripted-chunk-fake" + + def _merged(self) -> AIMessageChunk: + merged = self.scripted_chunks[0] + for c in self.scripted_chunks[1:]: + merged = merged + c + return merged + + @override + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + merged = self._merged() + final = AIMessage( + content=merged.content, + id=merged.id, + tool_calls=merged.tool_calls, + usage_metadata=merged.usage_metadata, + response_metadata=merged.response_metadata, + ) + return ChatResult(generations=[ChatGeneration(message=final)]) + + @override + def _stream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + for chunk in self.scripted_chunks: + yield ChatGenerationChunk(message=chunk) + if self.raise_after: + msg = "scripted failure" + raise _FakeStreamError(msg) + + @override + async def _astream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + for chunk in self.scripted_chunks: + yield ChatGenerationChunk(message=chunk) + if self.raise_after: + msg = "scripted failure" + raise _FakeStreamError(msg) + + +class _FakeStreamError(RuntimeError): + """Marker exception raised by `_ScriptedChunkModel` during streaming.""" + + +def _collect_v1_message(model: BaseChatModel, input_text: str) -> AIMessage: + """Run model.stream() (in v1 output mode) and merge chunks into an AIMessage. + + `ChatModelStream.output` is always v1-shaped (content is a list of + protocol blocks when blocks arrived). The legacy stream path only + emits v1-shaped content when `output_version="v1"` is set on the + model, so force it here for a like-for-like parity comparison. + """ + model.output_version = "v1" + chunks: list[AIMessageChunk] = [ + chunk for chunk in model.stream(input_text) if isinstance(chunk, AIMessageChunk) + ] + if not chunks: + msg = "No chunks produced" + raise RuntimeError(msg) + merged = chunks[0] + for c in chunks[1:]: + merged = merged + c + return AIMessage( + content=merged.content, + id=merged.id, + tool_calls=merged.tool_calls, + usage_metadata=merged.usage_metadata, + response_metadata=merged.response_metadata, + ) + + +def _collect_v2_message(model: BaseChatModel, input_text: str) -> AIMessage: + """Run model.stream_v2() and get .output.""" + stream = model.stream_v2(input_text) + return stream.output + + +class TestV1ParityBasic: + """Smoke-level parity using the simple text-only fake.""" + + def test_text_only_content_matches(self) -> None: + model = FakeListChatModel(responses=["Hello world!"]) + v1 = _collect_v1_message(model, "test") + model.i = 0 + v2 = _collect_v2_message(model, "test") + + assert v1.content == v2.content + + def test_message_id_present(self) -> None: + model = FakeListChatModel(responses=["Hi"]) + v1 = _collect_v1_message(model, "test") + model.i = 0 + v2 = _collect_v2_message(model, "test") + + assert v1.id is not None + assert v2.id is not None + + def test_empty_response(self) -> None: + """A truly empty stream is an error, matching `stream()` parity. + + `stream_v2` distinguishes "producer emitted events but no terminal + `message-finish`" (which is synthesized, for native-event providers + that omit it) from "producer emitted nothing at all" (which fails + with `ValueError`, same as `stream()`). + """ + model = FakeListChatModel(responses=[""]) + stream = model.stream_v2("test") + with pytest.raises(ValueError, match="No generation chunks"): + _ = stream.output + + def test_multi_character_response(self) -> None: + text = "The quick brown fox" + model = FakeListChatModel(responses=[text]) + v2 = _collect_v2_message(model, "test") + assert isinstance(v2.content, list) + assert len(v2.content) == 1 + text_block = v2.content[0] + assert isinstance(text_block, dict) + assert text_block["type"] == "text" + assert text_block["text"] == text + + def test_text_deltas_reconstruct_content(self) -> None: + model = FakeListChatModel(responses=["Hello!"]) + stream = model.stream_v2("test") + + deltas = list(stream.text) + content = stream.output.content + assert isinstance(content, list) + first_block = content[0] + assert isinstance(first_block, dict) + assert "".join(deltas) == first_block["text"] + + +class TestV1ParityToolCalls: + """Tool-call parity — the most load-bearing v1 shape.""" + + @staticmethod + def _make_model() -> _ScriptedChunkModel: + chunks = [ + AIMessageChunk( + content="", + id="run-tool-1", + tool_call_chunks=[ + {"index": 0, "id": "call_1", "name": "get_weather", "args": ""}, + ], + ), + AIMessageChunk( + content="", + id="run-tool-1", + tool_call_chunks=[ + {"index": 0, "id": None, "name": None, "args": '{"city": "'}, + ], + ), + AIMessageChunk( + content="", + id="run-tool-1", + tool_call_chunks=[ + {"index": 0, "id": None, "name": None, "args": 'Paris"}'}, + ], + response_metadata={"finish_reason": "tool_use"}, + ), + ] + return _ScriptedChunkModel(scripted_chunks=chunks) + + def test_tool_calls_match(self) -> None: + model = self._make_model() + v1 = _collect_v1_message(model, "weather?") + v2 = _collect_v2_message(self._make_model(), "weather?") + + assert len(v1.tool_calls) == 1 + assert len(v2.tool_calls) == 1 + assert v1.tool_calls[0]["id"] == v2.tool_calls[0]["id"] == "call_1" + assert v1.tool_calls[0]["name"] == v2.tool_calls[0]["name"] == "get_weather" + assert v1.tool_calls[0]["args"] == v2.tool_calls[0]["args"] == {"city": "Paris"} + + def test_tool_calls_via_projection(self) -> None: + model = self._make_model() + stream = model.stream_v2("weather?") + finalized = stream.tool_calls.get() + assert len(finalized) == 1 + assert finalized[0]["name"] == "get_weather" + assert finalized[0]["args"] == {"city": "Paris"} + + def test_finish_reason_tool_use(self) -> None: + model = self._make_model() + v2 = _collect_v2_message(model, "weather?") + assert v2.response_metadata.get("finish_reason") == "tool_use" + + +class TestV1ParityUsage: + """Usage metadata parity.""" + + @staticmethod + def _make_model() -> _ScriptedChunkModel: + chunks = [ + AIMessageChunk(content="Hi", id="run-usage-1"), + AIMessageChunk( + content=" there", + id="run-usage-1", + usage_metadata={ + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + }, + response_metadata={"finish_reason": "stop"}, + ), + ] + return _ScriptedChunkModel(scripted_chunks=chunks) + + def test_usage_metadata_present(self) -> None: + v1 = _collect_v1_message(self._make_model(), "hello") + v2 = _collect_v2_message(self._make_model(), "hello") + + assert v1.usage_metadata is not None + assert v2.usage_metadata is not None + assert v1.usage_metadata["input_tokens"] == v2.usage_metadata["input_tokens"] + assert v1.usage_metadata["output_tokens"] == v2.usage_metadata["output_tokens"] + assert v1.usage_metadata["total_tokens"] == v2.usage_metadata["total_tokens"] + + def test_usage_projection_matches(self) -> None: + stream = self._make_model().stream_v2("hello") + # Drain so usage is available + for _ in stream.text: + pass + usage = stream.output.usage_metadata + assert usage is not None + assert usage["input_tokens"] == 10 + assert usage["output_tokens"] == 5 + + +class TestV1ParityResponseMetadata: + """Response metadata preservation (fix 5b).""" + + @staticmethod + def _make_model() -> _ScriptedChunkModel: + chunks = [ + AIMessageChunk( + content="ok", + id="run-meta-1", + response_metadata={ + "finish_reason": "stop", + "model_provider": "fake-provider", + "stop_sequence": None, + }, + ), + ] + return _ScriptedChunkModel(scripted_chunks=chunks) + + def test_finish_reason_preserved(self) -> None: + v2 = _collect_v2_message(self._make_model(), "hi") + assert v2.response_metadata.get("finish_reason") == "stop" + + def test_provider_metadata_preserved(self) -> None: + """Non-finish-reason keys should survive the round-trip.""" + v2 = _collect_v2_message(self._make_model(), "hi") + # stop_sequence came from response_metadata on chunks; the bridge + # should carry it through via MessageFinishData.metadata. + assert "stop_sequence" in v2.response_metadata + + +class TestV1ParityReasoning: + """Reasoning content parity — order must be preserved.""" + + @staticmethod + def _make_model() -> _ScriptedChunkModel: + chunks = [ + AIMessageChunk( + content=[ + {"type": "reasoning", "reasoning": "Let me think. ", "index": 0}, + ], + id="run-reason-1", + ), + AIMessageChunk( + content=[ + {"type": "reasoning", "reasoning": "Done.", "index": 0}, + ], + id="run-reason-1", + ), + AIMessageChunk( + content=[ + {"type": "text", "text": "The answer is 42.", "index": 1}, + ], + id="run-reason-1", + response_metadata={"finish_reason": "stop"}, + ), + ] + return _ScriptedChunkModel(scripted_chunks=chunks) + + def test_reasoning_text_order(self) -> None: + """Reasoning block should come before text block in .output.content.""" + v2 = _collect_v2_message(self._make_model(), "think") + assert isinstance(v2.content, list) + types_in_order = [b.get("type") for b in v2.content if isinstance(b, dict)] + assert types_in_order == ["reasoning", "text"] + + def test_reasoning_projection(self) -> None: + stream = self._make_model().stream_v2("think") + full_reasoning = str(stream.reasoning) + assert full_reasoning == "Let me think. Done." + + +class TestV1ParityError: + """Errors during streaming must propagate on both paths.""" + + def test_error_propagates_sync(self) -> None: + chunks = [ + AIMessageChunk(content="partial", id="run-err-1"), + ] + model = _ScriptedChunkModel(scripted_chunks=chunks, raise_after=True) + + stream = model.stream_v2("boom") + # Drain first; error may surface here or at .output access. + try: + list(stream.text) + except _FakeStreamError: + return # Error surfaced during iteration — pass + with pytest.raises(_FakeStreamError): + _ = stream.output + + @pytest.mark.asyncio + async def test_error_propagates_async(self) -> None: + chunks = [ + AIMessageChunk(content="partial", id="run-err-2"), + ] + model = _ScriptedChunkModel(scripted_chunks=chunks, raise_after=True) + + stream = await model.astream_v2("boom") + try: + async for _ in stream.text: + pass + except _FakeStreamError: + return + with pytest.raises(_FakeStreamError): + _ = await stream diff --git a/libs/core/uv.lock b/libs/core/uv.lock index 728bc2d7826..6c66832404e 100644 --- a/libs/core/uv.lock +++ b/libs/core/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10.0, <4.0.0" resolution-markers = [ "python_full_version >= '3.14' and platform_python_implementation == 'PyPy'", @@ -999,6 +999,7 @@ version = "1.3.1" source = { editable = "." } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -1045,6 +1046,7 @@ typing = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -1087,6 +1089,18 @@ typing = [ { name = "types-requests", specifier = ">=2.28.11.5,<3.0.0.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/langchain/uv.lock b/libs/langchain/uv.lock index 8ad1fff9b38..21ece246315 100644 --- a/libs/langchain/uv.lock +++ b/libs/langchain/uv.lock @@ -2605,6 +2605,7 @@ version = "1.3.1" source = { editable = "../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -2617,6 +2618,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -2845,6 +2847,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/40/a7/ee0f47bfdf0cdc8a319eee66302988cfc741695d99a758ff714f93fc04a9/langchain_perplexity-1.1.0-py3-none-any.whl", hash = "sha256:74165561403869aa4dd01b215ecb051578a9d794e843e8e7cc13999c68ff69b5", size = 11159, upload-time = "2025-11-24T15:05:13.596Z" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/langchain_v1/uv.lock b/libs/langchain_v1/uv.lock index 1a73a323a1f..e09da8a997f 100644 --- a/libs/langchain_v1/uv.lock +++ b/libs/langchain_v1/uv.lock @@ -2212,6 +2212,7 @@ version = "1.3.1" source = { editable = "../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -2224,6 +2225,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -2452,6 +2454,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/40/a7/ee0f47bfdf0cdc8a319eee66302988cfc741695d99a758ff714f93fc04a9/langchain_perplexity-1.1.0-py3-none-any.whl", hash = "sha256:74165561403869aa4dd01b215ecb051578a9d794e843e8e7cc13999c68ff69b5", size = 11159, upload-time = "2025-11-24T15:05:13.596Z" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/model-profiles/uv.lock b/libs/model-profiles/uv.lock index a54a5826a79..48dd69dfdfa 100644 --- a/libs/model-profiles/uv.lock +++ b/libs/model-profiles/uv.lock @@ -536,6 +536,7 @@ version = "1.3.1" source = { editable = "../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -548,6 +549,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -708,6 +710,18 @@ typing = [ { name = "types-tqdm", specifier = ">=4.66.0.5,<5.0.0.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langgraph" version = "1.1.6" diff --git a/libs/partners/anthropic/tests/cassettes/test_streaming_tool_call_v1_v2_parity.yaml.gz b/libs/partners/anthropic/tests/cassettes/test_streaming_tool_call_v1_v2_parity.yaml.gz new file mode 100644 index 00000000000..7b453763d14 Binary files /dev/null and b/libs/partners/anthropic/tests/cassettes/test_streaming_tool_call_v1_v2_parity.yaml.gz differ diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py index 0c0299ec74d..113ae45c791 100644 --- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py @@ -6,7 +6,7 @@ import asyncio import json import os from base64 import b64encode -from typing import Literal, cast +from typing import Any, Literal, cast import anthropic import httpx @@ -28,6 +28,7 @@ from langchain_core.messages import ( from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.prompts import ChatPromptTemplate from langchain_core.tools import tool +from langchain_tests.utils.stream_lifecycle import assert_valid_event_stream from pydantic import BaseModel, Field from typing_extensions import TypedDict @@ -902,8 +903,17 @@ def test_agent_loop(output_version: Literal["v0", "v1"]) -> None: @pytest.mark.default_cassette("test_agent_loop_streaming.yaml.gz") @pytest.mark.vcr -@pytest.mark.parametrize("output_version", ["v0", "v1"]) -def test_agent_loop_streaming(output_version: Literal["v0", "v1"]) -> None: +@pytest.mark.parametrize( + ("output_version", "use_v2_stream"), + [ + ("v0", False), + ("v1", False), + ("v1", True), + ], +) +def test_agent_loop_streaming( + output_version: Literal["v0", "v1"], *, use_v2_stream: bool +) -> None: @tool def get_weather(location: str) -> str: """Get the weather for a location.""" @@ -916,7 +926,10 @@ def test_agent_loop_streaming(output_version: Literal["v0", "v1"]) -> None: ) llm_with_tools = llm.bind_tools([get_weather]) input_message = HumanMessage("What is the weather in San Francisco, CA?") - tool_call_message = llm_with_tools.invoke([input_message]) + if use_v2_stream: + tool_call_message = llm_with_tools.stream_v2([input_message]).output + else: + tool_call_message = llm_with_tools.invoke([input_message]) assert isinstance(tool_call_message, AIMessage) tool_calls = tool_call_message.tool_calls @@ -924,20 +937,68 @@ def test_agent_loop_streaming(output_version: Literal["v0", "v1"]) -> None: tool_call = tool_calls[0] tool_message = get_weather.invoke(tool_call) assert isinstance(tool_message, ToolMessage) - response = llm_with_tools.invoke( - [ - input_message, - tool_call_message, - tool_message, - ] + if use_v2_stream: + response = llm_with_tools.stream_v2( + [input_message, tool_call_message, tool_message] + ).output + else: + response = llm_with_tools.invoke( + [ + input_message, + tool_call_message, + tool_message, + ] + ) + assert isinstance(response, AIMessage) + + +@pytest.mark.default_cassette("test_agent_loop_streaming.yaml.gz") +@pytest.mark.vcr +async def test_agent_loop_streaming_astream_v2_v1() -> None: + """Async multi-turn through `astream_v2`. + + Mirrors `test_agent_loop_streaming` for `output_version="v1"` but + exercises `AsyncChatModelStream` end-to-end. + """ + + @tool + def get_weather(location: str) -> str: + """Get the weather for a location.""" + return "It's sunny." + + llm = ChatAnthropic( + model=MODEL_NAME, + streaming=True, + output_version="v1", # type: ignore[call-arg] + ) + llm_with_tools = llm.bind_tools([get_weather]) + input_message = HumanMessage("What is the weather in San Francisco, CA?") + tool_call_message = await (await llm_with_tools.astream_v2([input_message])) + assert isinstance(tool_call_message, AIMessage) + tool_calls = tool_call_message.tool_calls + assert len(tool_calls) == 1 + tool_call = tool_calls[0] + tool_message = get_weather.invoke(tool_call) + assert isinstance(tool_message, ToolMessage) + response = await ( + await llm_with_tools.astream_v2( + [input_message, tool_call_message, tool_message] + ) ) assert isinstance(response, AIMessage) @pytest.mark.default_cassette("test_citations.yaml.gz") @pytest.mark.vcr -@pytest.mark.parametrize("output_version", ["v0", "v1"]) -def test_citations(output_version: Literal["v0", "v1"]) -> None: +@pytest.mark.parametrize( + ("output_version", "use_v2_stream"), + [ + ("v0", False), + ("v1", False), + ("v1", True), + ], +) +def test_citations(output_version: Literal["v0", "v1"], *, use_v2_stream: bool) -> None: llm = ChatAnthropic(model=MODEL_NAME, output_version=output_version) # type: ignore[call-arg] messages = [ { @@ -967,10 +1028,19 @@ def test_citations(output_version: Literal["v0", "v1"]) -> None: assert any("citations" in block for block in response.content) # Test streaming - full: BaseMessageChunk | None = None - for chunk in llm.stream(messages): - full = cast("BaseMessageChunk", chunk) if full is None else full + chunk - assert isinstance(full, AIMessageChunk) + full: BaseMessage + if use_v2_stream: + full = llm.stream_v2(messages).output + else: + aggregated: BaseMessageChunk | None = None + for chunk in llm.stream(messages): + aggregated = ( + cast("BaseMessageChunk", chunk) + if aggregated is None + else aggregated + chunk + ) + assert isinstance(aggregated, AIMessageChunk) + full = aggregated assert isinstance(full.content, list) assert not any("citation" in block for block in full.content) if output_version == "v1": @@ -1029,7 +1099,8 @@ def test_thinking() -> None: @pytest.mark.default_cassette("test_thinking.yaml.gz") @pytest.mark.vcr -def test_thinking_v1() -> None: +@pytest.mark.parametrize("use_v2_stream", [False, True]) +def test_thinking_v1(*, use_v2_stream: bool) -> None: llm = ChatAnthropic( model="claude-sonnet-4-5-20250929", # type: ignore[call-arg] max_tokens=5_000, # type: ignore[call-arg] @@ -1051,10 +1122,19 @@ def test_thinking_v1() -> None: assert isinstance(signature, str) # Test streaming - full: BaseMessageChunk | None = None - for chunk in llm.stream([input_message]): - full = cast(BaseMessageChunk, chunk) if full is None else full + chunk - assert isinstance(full, AIMessageChunk) + full: BaseMessage + if use_v2_stream: + full = llm.stream_v2([input_message]).output + else: + aggregated: BaseMessageChunk | None = None + for chunk in llm.stream([input_message]): + aggregated = ( + cast(BaseMessageChunk, chunk) + if aggregated is None + else aggregated + chunk + ) + assert isinstance(aggregated, AIMessageChunk) + full = aggregated assert isinstance(full.content, list) assert any("reasoning" in block for block in full.content) for block in full.content: @@ -2516,3 +2596,96 @@ def test_compaction_streaming() -> None: third_response = llm.invoke(messages) content_blocks = third_response.content_blocks assert [block["type"] for block in content_blocks] == ["text"] + + +class _Person(BaseModel): + """A person with a name and age.""" + + name: str = Field(description="The person's name") + age: int = Field(description="The person's age in years") + + +def _stable_blocks(blocks: Any) -> list[dict[str, Any]]: + """Drop fields that vary between API calls so blocks can be compared. + + Tool-call ids, wire indices, and provider extras are not path- or call- + stable; strip them so the comparison targets the semantic content. + """ + volatile = {"id", "index", "extras"} + return [{k: v for k, v in b.items() if k not in volatile} for b in blocks] + + +@pytest.mark.default_cassette("test_streaming_tool_call_v1_v2_parity.yaml.gz") +@pytest.mark.vcr +def test_streaming_tool_call_v1_v2_parity() -> None: + """`AIMessage` parity between `stream()` reduction and `stream_v2().output`. + + Runs the same forced-tool-call prompt through both the legacy chunk + stream (reduced with `AIMessageChunk.__add__`) and the `stream_v2` + bridge path on a `v1`-output `ChatAnthropic`, then compares the + resulting messages on path-independent invariants: + + - tool call name and args (ids vary between calls and are ignored) + - exactly one tool call, no invalid tool calls + - `content_blocks` (the v1 projection, stripped of volatile fields) + - a valid tool-use `finish_reason` + + The v2 path is additionally validated against the full protocol + lifecycle via `assert_valid_event_stream`. + """ + llm = ChatAnthropic( + model=MODEL_NAME, + output_version="v1", # type: ignore[call-arg] + ) + with_tool = llm.bind_tools( + [_Person], + tool_choice={"type": "tool", "name": "_Person"}, + ) + prompt = "Extract: Erick is 27 years old." + + v1_full: AIMessageChunk | None = None + for chunk in with_tool.stream(prompt): + assert isinstance(chunk, AIMessageChunk) + v1_full = chunk if v1_full is None else v1_full + chunk + assert isinstance(v1_full, AIMessageChunk) + + stream = with_tool.stream_v2(prompt) + events = list(stream) + assert_valid_event_stream(events) + v2_message = stream.output + assert isinstance(v2_message, AIMessage) + + assert len(v1_full.tool_calls) == len(v2_message.tool_calls) == 1 + assert not v1_full.invalid_tool_calls + assert not v2_message.invalid_tool_calls + + v1_tc = v1_full.tool_calls[0] + v2_tc = v2_message.tool_calls[0] + assert v1_tc["name"] == v2_tc["name"] == "_Person" + assert v1_tc["args"] == v2_tc["args"] == {"name": "Erick", "age": 27} + + v1_blocks = _stable_blocks(v1_full.content_blocks) + v2_blocks = _stable_blocks(v2_message.content_blocks) + assert v1_blocks == v2_blocks + assert v1_blocks == [ + { + "type": "tool_call", + "name": "_Person", + "args": {"name": "Erick", "age": 27}, + } + ] + + # The compat bridge passes the provider's raw terminal reason through + # unchanged — Anthropic surfaces it under `stop_reason` on both paths. + # Accept either key on both sides rather than asserting a specific + # normalization that the bridge does not perform. + v1_finish = v1_full.response_metadata.get( + "finish_reason" + ) or v1_full.response_metadata.get("stop_reason") + v2_finish = v2_message.response_metadata.get( + "finish_reason" + ) or v2_message.response_metadata.get("stop_reason") + assert v1_finish is not None + assert v2_finish is not None + assert any(k in v1_finish for k in ("tool_use", "tool_calls", "stop")) + assert any(k in v2_finish for k in ("tool_use", "tool_calls", "stop")) diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py index 1a8db371380..b861cc1acaf 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py @@ -2843,3 +2843,161 @@ def test_no_task_budget_no_beta() -> None: betas = payload.get("betas") if betas: assert "task-budgets-2026-03-13" not in betas + + +def test_anthropic_stream_v2_lifecycle() -> None: + """Validate lifecycle events across a thinking + text + tool_use stream. + + Anthropic emits raw `content_block_start` / `content_block_delta` / + `content_block_stop` events with integer `index` fields, interleaved + with `message_start` and `message_delta`. This test threads a + realistic event sequence through `_stream` via a mocked raw client + and asserts that `stream_v2` produces a spec-conformant event + stream: paired start/finish per block, no interleaving, sequential + `uint` wire indices. + """ + from unittest.mock import patch + + from anthropic.types import ( + InputJSONDelta, + RawContentBlockDeltaEvent, + RawContentBlockStartEvent, + RawContentBlockStopEvent, + RawMessageDeltaEvent, + RawMessageStartEvent, + RawMessageStopEvent, + TextDelta, + ThinkingBlock, + ThinkingDelta, + ToolUseBlock, + ) + from anthropic.types.raw_message_delta_event import Delta as RawMessageDelta + from anthropic.types.raw_message_delta_event import ( + MessageDeltaUsage as RawMessageDeltaUsage, + ) + from langchain_tests.utils.stream_lifecycle import assert_valid_event_stream + + msg = Message( + id="msg_1", + content=[], + model=MODEL_NAME, + role="assistant", + stop_reason=None, + stop_sequence=None, + usage=Usage(input_tokens=10, output_tokens=0), + type="message", + ) + + events = [ + RawMessageStartEvent(message=msg, type="message_start"), + # thinking block (index=0) + RawContentBlockStartEvent( + content_block=ThinkingBlock(signature="", thinking="", type="thinking"), + index=0, + type="content_block_start", + ), + RawContentBlockDeltaEvent( + delta=ThinkingDelta(thinking="Let me ", type="thinking_delta"), + index=0, + type="content_block_delta", + ), + RawContentBlockDeltaEvent( + delta=ThinkingDelta(thinking="think.", type="thinking_delta"), + index=0, + type="content_block_delta", + ), + RawContentBlockStopEvent(index=0, type="content_block_stop"), + # text block (index=1) + RawContentBlockStartEvent( + content_block=TextBlock(text="", type="text"), + index=1, + type="content_block_start", + ), + RawContentBlockDeltaEvent( + delta=TextDelta(text="The answer ", type="text_delta"), + index=1, + type="content_block_delta", + ), + RawContentBlockDeltaEvent( + delta=TextDelta(text="is 42.", type="text_delta"), + index=1, + type="content_block_delta", + ), + RawContentBlockStopEvent(index=1, type="content_block_stop"), + # tool_use block (index=2) + RawContentBlockStartEvent( + content_block=ToolUseBlock( + id="toolu_1", + input={}, + name="search", + type="tool_use", + ), + index=2, + type="content_block_start", + ), + RawContentBlockDeltaEvent( + delta=InputJSONDelta(partial_json='{"q":', type="input_json_delta"), + index=2, + type="content_block_delta", + ), + RawContentBlockDeltaEvent( + delta=InputJSONDelta(partial_json=' "weather"}', type="input_json_delta"), + index=2, + type="content_block_delta", + ), + RawContentBlockStopEvent(index=2, type="content_block_stop"), + # message_delta with final usage and stop_reason + RawMessageDeltaEvent( + delta=RawMessageDelta(stop_reason="tool_use", stop_sequence=None), + type="message_delta", + usage=RawMessageDeltaUsage( + output_tokens=50, + input_tokens=10, + cache_read_input_tokens=0, + cache_creation_input_tokens=0, + ), + ), + RawMessageStopEvent(type="message_stop"), + ] + + # Enable thinking so `coerce_content_to_string=False` in `_stream`, + # which gives every content block an integer `index` field — the + # structured path the protocol bridge actually exercises. Default + # (no tools / thinking / documents) coerces text to a plain string, + # which strips indices and is a separate code path not covered here. + llm = ChatAnthropic( + model=MODEL_NAME, + thinking={"type": "enabled", "budget_tokens": 1024}, + ) + + def mock_create(_payload: Any) -> list: + return events + + with patch.object(llm, "_create", mock_create): + stream_events = list(llm.stream_v2("Test query")) + + assert_valid_event_stream(stream_events) + + finishes = [e for e in stream_events if e["event"] == "content-block-finish"] + types = [f["content_block"]["type"] for f in finishes] + assert types == ["reasoning", "text", "tool_call"] + + wire_indices = [f["index"] for f in finishes] + assert wire_indices == [0, 1, 2] + + # Content accumulation reaches content-block-finish intact. + reasoning_block = cast("dict[str, Any]", finishes[0]["content_block"]) + text_block = cast("dict[str, Any]", finishes[1]["content_block"]) + tool_block = cast("dict[str, Any]", finishes[2]["content_block"]) + assert reasoning_block["reasoning"] == "Let me think." + assert text_block["text"] == "The answer is 42." + assert tool_block["args"] == {"q": "weather"} + assert tool_block["name"] == "search" + + # message-finish carries the tool_use stop reason inside metadata + # (protocol 0.0.9 moved the finish reason off the top-level event + # and into `metadata`, where the bridge deposits the provider's raw + # `stop_reason` alongside other response metadata). + message_finish = stream_events[-1] + assert message_finish["event"] == "message-finish" + assert message_finish["metadata"]["stop_reason"] == "tool_use" diff --git a/libs/partners/anthropic/uv.lock b/libs/partners/anthropic/uv.lock index ed5cd57eeb3..8714343b2e4 100644 --- a/libs/partners/anthropic/uv.lock +++ b/libs/partners/anthropic/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10.0, <4.0.0" resolution-markers = [ "python_full_version >= '3.13' and platform_python_implementation == 'PyPy'", @@ -660,10 +660,11 @@ typing = [ [[package]] name = "langchain-core" -version = "1.3.0a3" +version = "1.3.1" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -676,6 +677,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -718,6 +720,18 @@ typing = [ { name = "types-requests", specifier = ">=2.28.11.5,<3.0.0.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/partners/chroma/uv.lock b/libs/partners/chroma/uv.lock index fb15799cf86..7e17280c880 100644 --- a/libs/partners/chroma/uv.lock +++ b/libs/partners/chroma/uv.lock @@ -828,6 +828,7 @@ version = "1.3.1" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -840,6 +841,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -882,6 +884,18 @@ typing = [ { name = "types-requests", specifier = ">=2.28.11.5,<3.0.0.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/partners/deepseek/uv.lock b/libs/partners/deepseek/uv.lock index 85ea0e972b3..55b04bb24d1 100644 --- a/libs/partners/deepseek/uv.lock +++ b/libs/partners/deepseek/uv.lock @@ -395,6 +395,7 @@ version = "1.3.1" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -407,6 +408,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -548,6 +550,18 @@ typing = [ { name = "types-tqdm", specifier = ">=4.66.0.5,<5.0.0.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/partners/exa/uv.lock b/libs/partners/exa/uv.lock index 57af0d208b0..12b8e4a5a03 100644 --- a/libs/partners/exa/uv.lock +++ b/libs/partners/exa/uv.lock @@ -421,6 +421,7 @@ version = "1.3.1" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -433,6 +434,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -535,6 +537,18 @@ typing = [ { name = "mypy", specifier = ">=1.10.0,<2.0.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/partners/fireworks/uv.lock b/libs/partners/fireworks/uv.lock index 1aee72cd371..15e33175cdd 100644 --- a/libs/partners/fireworks/uv.lock +++ b/libs/partners/fireworks/uv.lock @@ -701,6 +701,7 @@ version = "1.3.1" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -713,6 +714,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -823,6 +825,18 @@ typing = [ { name = "types-requests", specifier = ">=2.0.0,<3.0.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/partners/groq/uv.lock b/libs/partners/groq/uv.lock index b7f452e9221..f5c8a067cb6 100644 --- a/libs/partners/groq/uv.lock +++ b/libs/partners/groq/uv.lock @@ -339,6 +339,7 @@ version = "1.3.1" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -351,6 +352,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -452,6 +454,18 @@ typing = [ { name = "mypy", specifier = ">=1.10.0,<2.0.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/partners/huggingface/uv.lock b/libs/partners/huggingface/uv.lock index 9aa39b2c910..5850ba81d81 100644 --- a/libs/partners/huggingface/uv.lock +++ b/libs/partners/huggingface/uv.lock @@ -1057,6 +1057,7 @@ version = "1.3.1" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -1069,6 +1070,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -1193,6 +1195,18 @@ typing = [ { name = "mypy", specifier = ">=1.10.0,<2.0.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/partners/mistralai/uv.lock b/libs/partners/mistralai/uv.lock index 880bac06de7..346559cb042 100644 --- a/libs/partners/mistralai/uv.lock +++ b/libs/partners/mistralai/uv.lock @@ -374,6 +374,7 @@ version = "1.3.1" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -386,6 +387,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -486,6 +488,18 @@ typing = [ { name = "mypy", specifier = ">=1.10.0,<2.0.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/partners/nomic/uv.lock b/libs/partners/nomic/uv.lock index 85a5938bc8a..397b5403ef6 100644 --- a/libs/partners/nomic/uv.lock +++ b/libs/partners/nomic/uv.lock @@ -387,6 +387,7 @@ version = "1.3.1" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -399,6 +400,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -503,6 +505,18 @@ typing = [ { name = "mypy", specifier = ">=1.18.1,<1.19.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/partners/ollama/uv.lock b/libs/partners/ollama/uv.lock index eb82b7cec36..d95ae498817 100644 --- a/libs/partners/ollama/uv.lock +++ b/libs/partners/ollama/uv.lock @@ -313,6 +313,7 @@ version = "1.3.1" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -325,6 +326,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -423,6 +425,18 @@ typing = [ { name = "ty", specifier = ">=0.0.1,<1.0.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/partners/openai/tests/cassettes/test_reasoning_text_v1_v2_parity.yaml.gz b/libs/partners/openai/tests/cassettes/test_reasoning_text_v1_v2_parity.yaml.gz new file mode 100644 index 00000000000..774075e007b Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_reasoning_text_v1_v2_parity.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_streaming_tool_call_v1_v2_parity.yaml.gz b/libs/partners/openai/tests/cassettes/test_streaming_tool_call_v1_v2_parity.yaml.gz new file mode 100644 index 00000000000..d16c3989057 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_streaming_tool_call_v1_v2_parity.yaml.gz differ diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index 473da166922..d939877346b 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -22,7 +22,8 @@ from langchain_core.messages import ( ToolMessage, ) from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult -from pydantic import BaseModel, field_validator +from langchain_tests.utils.stream_lifecycle import assert_valid_event_stream +from pydantic import BaseModel, Field, field_validator from typing_extensions import TypedDict from langchain_openai import ChatOpenAI @@ -1277,3 +1278,63 @@ async def test_schema_parsing_failures_responses_api_async() -> None: assert e.response is not None # type: ignore[attr-defined] else: raise AssertionError + + +class _Person(BaseModel): + """A person with a name and age.""" + + name: str = Field(description="The person's name") + age: int = Field(description="The person's age in years") + + +@pytest.mark.vcr +def test_streaming_tool_call_v1_v2_parity() -> None: + """`stream()` and `stream_v2()` must agree on their final `AIMessage`. + + Both paths are invoked against the same HTTP response (the cassette's + single recorded interaction, replayed for both calls via + `allow_playback_repeats=True`). Any remaining divergence is a real + library issue, not a difference between two LLM calls. + """ + llm = ChatOpenAI( + model="gpt-4o-mini", + temperature=0, + output_version="v1", + ) + with_tool = llm.bind_tools([_Person], tool_choice="_Person") + prompt = "Extract: Erick is 27 years old." + + v1: AIMessageChunk | None = None + for chunk in with_tool.stream(prompt): + assert isinstance(chunk, AIMessageChunk) + v1 = chunk if v1 is None else v1 + chunk + assert isinstance(v1, AIMessageChunk) + + stream = with_tool.stream_v2(prompt) + events = list(stream) + assert_valid_event_stream(events) + v2 = stream.output + assert isinstance(v2, AIMessage) + + assert v1.tool_calls == v2.tool_calls + assert v1.invalid_tool_calls == v2.invalid_tool_calls + assert v1.content_blocks == v2.content_blocks + + # `usage_metadata` top-level counts must match. The detail dicts + # (`input_token_details`, `output_token_details`) survive in v1 but + # are dropped by the bridge's `_to_protocol_usage` because + # `langchain_protocol.UsageInfo` has no fields for them. Tracked + # as a protocol-repo change; compare counts strictly for now. + detail_keys = {"input_token_details", "output_token_details"} + v1_usage = { + k: v for k, v in (v1.usage_metadata or {}).items() if k not in detail_keys + } + v2_usage = { + k: v for k, v in (v2.usage_metadata or {}).items() if k not in detail_keys + } + assert v1_usage == v2_usage + + # `response_metadata` must match exactly: the bridge passes the + # provider's raw `finish_reason` through without normalization, so + # OpenAI's `"stop"` on a forced tool call appears on both paths. + assert v1.response_metadata == v2.response_metadata diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py index 219a3a748e8..33531867c8f 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py @@ -25,6 +25,7 @@ from langchain_core.messages import ( ) from langchain_core.tools import tool from langchain_core.utils.function_calling import convert_to_openai_tool +from langchain_tests.utils.stream_lifecycle import assert_valid_event_stream from pydantic import BaseModel from typing_extensions import TypedDict @@ -91,8 +92,17 @@ def test_incomplete_response() -> None: @pytest.mark.default_cassette("test_web_search.yaml.gz") @pytest.mark.vcr -@pytest.mark.parametrize("output_version", ["responses/v1", "v1"]) -def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None: +@pytest.mark.parametrize( + ("output_version", "use_v2_stream"), + [ + ("responses/v1", False), + ("v1", False), + ("v1", True), + ], +) +def test_web_search( + output_version: Literal["responses/v1", "v1"], use_v2_stream: bool +) -> None: llm = ChatOpenAI(model=MODEL_NAME, output_version=output_version) first_response = llm.invoke( "What was a positive news story from today?", @@ -101,13 +111,22 @@ def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None: _check_response(first_response) # Test streaming - full: BaseMessageChunk | None = None - for chunk in llm.stream( - "What was a positive news story from today?", - tools=[{"type": "web_search_preview"}], - ): - assert isinstance(chunk, AIMessageChunk) - full = chunk if full is None else full + chunk + full: BaseMessage + if use_v2_stream: + full = llm.stream_v2( + "What was a positive news story from today?", + tools=[{"type": "web_search_preview"}], + ).output + else: + aggregated: BaseMessageChunk | None = None + for chunk in llm.stream( + "What was a positive news story from today?", + tools=[{"type": "web_search_preview"}], + ): + assert isinstance(chunk, AIMessageChunk) + aggregated = chunk if aggregated is None else aggregated + chunk + assert aggregated is not None + full = aggregated _check_response(full) # Use OpenAI's stateful API @@ -238,8 +257,18 @@ def test_agent_loop(output_version: Literal["responses/v1", "v1"]) -> None: @pytest.mark.default_cassette("test_agent_loop_streaming.yaml.gz") @pytest.mark.vcr -@pytest.mark.parametrize("output_version", ["responses/v1", "v1"]) -def test_agent_loop_streaming(output_version: Literal["responses/v1", "v1"]) -> None: +@pytest.mark.parametrize( + ("output_version", "use_v2_stream"), + [ + ("responses/v1", False), + ("responses/v1", True), + ("v1", False), + ("v1", True), + ], +) +def test_agent_loop_streaming( + output_version: Literal["responses/v1", "v1"], use_v2_stream: bool +) -> None: @tool def get_weather(location: str) -> str: """Get the weather for a location.""" @@ -254,20 +283,70 @@ def test_agent_loop_streaming(output_version: Literal["responses/v1", "v1"]) -> ) llm_with_tools = llm.bind_tools([get_weather]) input_message = HumanMessage("What is the weather in San Francisco, CA?") - tool_call_message = llm_with_tools.invoke([input_message]) + if use_v2_stream: + tool_call_message = llm_with_tools.stream_v2([input_message]).output + else: + tool_call_message = llm_with_tools.invoke([input_message]) assert isinstance(tool_call_message, AIMessage) tool_calls = tool_call_message.tool_calls assert len(tool_calls) == 1 tool_call = tool_calls[0] tool_message = get_weather.invoke(tool_call) assert isinstance(tool_message, ToolMessage) - response = llm_with_tools.invoke( - [ - input_message, - tool_call_message, - tool_message, - ] + if use_v2_stream: + response = llm_with_tools.stream_v2( + [input_message, tool_call_message, tool_message] + ).output + else: + response = llm_with_tools.invoke( + [ + input_message, + tool_call_message, + tool_message, + ] + ) + assert isinstance(response, AIMessage) + + +@pytest.mark.default_cassette("test_agent_loop_streaming.yaml.gz") +@pytest.mark.vcr +async def test_agent_loop_streaming_astream_v2_v1() -> None: + """Async multi-turn through `astream_v2`. + + Mirrors `test_agent_loop_streaming` for `output_version="v1"` but + exercises `AsyncChatModelStream` end-to-end: aggregation in the + async state machine, async projections, and the background + producer task. Cassette byte-matches guarantee the aggregated + message serializes identically to the legacy path on the + follow-up turn. + """ + + @tool + def get_weather(location: str) -> str: + """Get the weather for a location.""" + return "It's sunny." + + llm = ChatOpenAI( + model="gpt-5.2", + use_responses_api=True, + reasoning={"effort": "medium", "summary": "auto"}, + streaming=True, + output_version="v1", ) + llm_with_tools = llm.bind_tools([get_weather]) + input_message = HumanMessage("What is the weather in San Francisco, CA?") + stream = await llm_with_tools.astream_v2([input_message]) + tool_call_message = await stream + assert isinstance(tool_call_message, AIMessage) + tool_calls = tool_call_message.tool_calls + assert len(tool_calls) == 1 + tool_call = tool_calls[0] + tool_message = get_weather.invoke(tool_call) + assert isinstance(tool_message, ToolMessage) + stream = await llm_with_tools.astream_v2( + [input_message, tool_call_message, tool_message] + ) + response = await stream assert isinstance(response, AIMessage) @@ -543,9 +622,18 @@ def test_file_search( @pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz") @pytest.mark.vcr -@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"]) +@pytest.mark.parametrize( + ("output_version", "use_v2_stream"), + [ + ("v0", False), + ("responses/v1", False), + ("v1", False), + ("v1", True), + ], +) def test_stream_reasoning_summary( output_version: Literal["v0", "responses/v1", "v1"], + use_v2_stream: bool, ) -> None: llm = ChatOpenAI( model="o4-mini", @@ -557,11 +645,16 @@ def test_stream_reasoning_summary( "role": "user", "content": "What was the third tallest buliding in the year 2000?", } - response_1: BaseMessageChunk | None = None - for chunk in llm.stream([message_1]): - assert isinstance(chunk, AIMessageChunk) - response_1 = chunk if response_1 is None else response_1 + chunk - assert isinstance(response_1, AIMessageChunk) + response_1: BaseMessage + if use_v2_stream: + response_1 = llm.stream_v2([message_1]).output + else: + aggregated: BaseMessageChunk | None = None + for chunk in llm.stream([message_1]): + assert isinstance(chunk, AIMessageChunk) + aggregated = chunk if aggregated is None else aggregated + chunk + assert isinstance(aggregated, AIMessageChunk) + response_1 = aggregated if output_version == "v0": reasoning = response_1.additional_kwargs["reasoning"] assert set(reasoning.keys()) == {"id", "type", "summary"} @@ -610,8 +703,18 @@ def test_stream_reasoning_summary( @pytest.mark.default_cassette("test_code_interpreter.yaml.gz") @pytest.mark.vcr -@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"]) -def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -> None: +@pytest.mark.parametrize( + ("output_version", "use_v2_stream"), + [ + ("v0", False), + ("responses/v1", False), + ("v1", False), + ("v1", True), + ], +) +def test_code_interpreter( + output_version: Literal["v0", "responses/v1", "v1"], use_v2_stream: bool +) -> None: llm = ChatOpenAI( model="o4-mini", use_responses_api=True, output_version=output_version ) @@ -664,11 +767,16 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) - [{"type": "code_interpreter", "container": container_id}] ) - full: BaseMessageChunk | None = None - for chunk in llm_with_tools.stream([input_message]): - assert isinstance(chunk, AIMessageChunk) - full = chunk if full is None else full + chunk - assert isinstance(full, AIMessageChunk) + full: BaseMessage + if use_v2_stream: + full = llm_with_tools.stream_v2([input_message]).output + else: + aggregated: BaseMessageChunk | None = None + for chunk in llm_with_tools.stream([input_message]): + assert isinstance(chunk, AIMessageChunk) + aggregated = chunk if aggregated is None else aggregated + chunk + assert isinstance(aggregated, AIMessageChunk) + full = aggregated if output_version == "v0": tool_outputs = [ item @@ -796,7 +904,8 @@ def test_mcp_builtin_zdr() -> None: @pytest.mark.default_cassette("test_mcp_builtin_zdr.yaml.gz") @pytest.mark.vcr -def test_mcp_builtin_zdr_v1() -> None: +@pytest.mark.parametrize("use_v2_stream", [False, True]) +def test_mcp_builtin_zdr_v1(use_v2_stream: bool) -> None: llm = ChatOpenAI( model="gpt-5-nano", output_version="v1", @@ -822,12 +931,18 @@ def test_mcp_builtin_zdr_v1() -> None: "spec (modelcontextprotocol/modelcontextprotocol) support?" ), } - full: BaseMessageChunk | None = None - for chunk in llm_with_tools.stream([input_message]): - assert isinstance(chunk, AIMessageChunk) - full = chunk if full is None else full + chunk + full: BaseMessage + if use_v2_stream: + full = llm_with_tools.stream_v2([input_message]).output + else: + aggregated: BaseMessageChunk | None = None + for chunk in llm_with_tools.stream([input_message]): + assert isinstance(chunk, AIMessageChunk) + aggregated = chunk if aggregated is None else aggregated + chunk + assert isinstance(aggregated, AIMessageChunk) + full = aggregated - assert isinstance(full, AIMessageChunk) + assert isinstance(full, AIMessage) assert all(isinstance(block, dict) for block in full.content) approval_message = HumanMessage( @@ -1229,8 +1344,17 @@ def test_compaction(output_version: Literal["responses/v1", "v1"]) -> None: @pytest.mark.default_cassette("test_compaction_streaming.yaml.gz") @pytest.mark.vcr -@pytest.mark.parametrize("output_version", ["responses/v1", "v1"]) -def test_compaction_streaming(output_version: Literal["responses/v1", "v1"]) -> None: +@pytest.mark.parametrize( + ("output_version", "use_v2_stream"), + [ + ("responses/v1", False), + ("v1", False), + ("v1", True), + ], +) +def test_compaction_streaming( + output_version: Literal["responses/v1", "v1"], use_v2_stream: bool +) -> None: """Test the compaction beta feature.""" llm = ChatOpenAI( model="gpt-5.2", @@ -1239,13 +1363,20 @@ def test_compaction_streaming(output_version: Literal["responses/v1", "v1"]) -> streaming=True, ) + def _run(messages: list) -> AIMessage: + if use_v2_stream: + return llm.stream_v2(messages).output + result = llm.invoke(messages) + assert isinstance(result, AIMessage) + return result + input_message = { "role": "user", "content": f"Generate a one-sentence summary of this:\n\n{'a' * 50000}", } messages: list = [input_message] - first_response = llm.invoke(messages) + first_response = _run(messages) messages.append(first_response) second_message = { @@ -1254,7 +1385,7 @@ def test_compaction_streaming(output_version: Literal["responses/v1", "v1"]) -> } messages.append(second_message) - second_response = llm.invoke(messages) + second_response = _run(messages) messages.append(second_response) content_blocks = second_response.content_blocks @@ -1270,7 +1401,7 @@ def test_compaction_streaming(output_version: Literal["responses/v1", "v1"]) -> "content": "What are we talking about?", } messages.append(third_message) - third_response = llm.invoke(messages) + third_response = _run(messages) assert third_response.text @@ -1610,3 +1741,68 @@ def test_client_executed_tool_search() -> None: assert isinstance(messages[4], ToolMessage) assert messages[5].text + + +@pytest.mark.default_cassette("test_reasoning_text_v1_v2_parity.yaml.gz") +@pytest.mark.vcr +def test_reasoning_text_v1_v2_parity() -> None: + """`stream()` and `stream_v2()` must agree on reasoning + text output. + + Exercises the non-tool-call branch of the parity claim: a reasoning + model (`o4-mini` via the Responses API) produces one or more + `reasoning` blocks followed by a `text` block. Both paths replay the + same recorded HTTP response (cassette with `allow_playback_repeats`), + so any remaining divergence is a library issue. + """ + llm = ChatOpenAI( + model="o4-mini", + reasoning={"effort": "low", "summary": "auto"}, + output_version="v1", + ) + prompt = {"role": "user", "content": "What is the capital of France?"} + + v1: AIMessageChunk | None = None + for chunk in llm.stream([prompt]): + assert isinstance(chunk, AIMessageChunk) + v1 = chunk if v1 is None else v1 + chunk + assert isinstance(v1, AIMessageChunk) + + stream = llm.stream_v2([prompt]) + events = list(stream) + assert_valid_event_stream(events) + v2 = stream.output + assert isinstance(v2, AIMessage) + + # No tool calls on either path. + assert v1.tool_calls == v2.tool_calls == [] + assert v1.invalid_tool_calls == v2.invalid_tool_calls == [] + assert v1.additional_kwargs == v2.additional_kwargs + + # Content structure must match: same block sequence, same accumulated + # text and reasoning payloads, same block identifiers. `content_blocks` + # is the v1-shaped projection and is canonical for both paths. + assert v1.content_blocks == v2.content_blocks + assert v1.content == v2.content + # Sanity-check that we actually exercised the reasoning + text path. + block_types = [b["type"] for b in v1.content_blocks] + assert "reasoning" in block_types + assert "text" in block_types + + # Usage: core counts must match; provider detail subdicts are + # dropped by `_to_protocol_usage` because `langchain_protocol.UsageInfo` + # doesn't list them. Tracked as a protocol-repo change. + detail_keys = {"input_token_details", "output_token_details"} + v1_usage = { + k: v for k, v in (v1.usage_metadata or {}).items() if k not in detail_keys + } + v2_usage = { + k: v for k, v in (v2.usage_metadata or {}).items() if k not in detail_keys + } + assert v1_usage == v2_usage + + # Response metadata must match. The Responses API doesn't put + # `finish_reason` in per-chunk metadata, so neither the v1 reduction + # nor the v2 bridge ends up with one. (Protocol 0.0.10 dropped the + # v2 bridge's default `"stop"` synthesis; provider metadata now + # passes through unchanged.) + assert v1.response_metadata == v2.response_metadata diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 481485242eb..88d1b97aa63 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -616,6 +616,28 @@ def test_openai_stream(mock_openai_completion: list) -> None: assert "stream_options" not in call_kwargs[-1] +def test_openai_stream_v2_lifecycle(mock_openai_completion: list) -> None: + """`stream_v2` on chat completions emits a spec-conformant lifecycle.""" + from langchain_tests.utils.stream_lifecycle import assert_valid_event_stream + + llm = ChatOpenAI(model="gpt-4o") + mock_client = MagicMock() + + def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager: + return MockSyncContextManager(mock_openai_completion) + + mock_client.create = mock_create + with patch.object(llm, "client", mock_client): + events = list(llm.stream_v2("你的名字叫什么?只回答名字")) + + assert_valid_event_stream(events) + # At minimum, a text block with the accumulated answer. + finishes = [e for e in events if e["event"] == "content-block-finish"] + assert len(finishes) >= 1 + text_finishes = [f for f in finishes if f["content_block"]["type"] == "text"] + assert len(text_finishes) == 1 + + @pytest.fixture def mock_completion() -> dict: return { diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py b/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py index 3679cf22eab..1a4edfca608 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py @@ -1,11 +1,12 @@ from __future__ import annotations import copy -from typing import Any +from typing import Any, cast from unittest.mock import MagicMock, patch import pytest from langchain_core.messages import AIMessageChunk, BaseMessageChunk +from langchain_tests.utils.stream_lifecycle import assert_valid_event_stream from openai.types.responses import ( ResponseCompletedEvent, ResponseContentPartAddedEvent, @@ -762,6 +763,68 @@ def test_responses_stream(output_version: str, expected_content: list[dict]) -> assert dumped == payload["input"][idx] +def test_responses_stream_v2_emits_reasoning_lifecycle() -> None: + """`stream_v2` must emit `content-block-finish` events for reasoning blocks. + + Regression test: the protocol bridge should surface the full lifecycle + (`content-block-start` / `content-block-delta` / `content-block-finish`) + for every reasoning block observed on the wire, not just text blocks. + """ + llm = ChatOpenAI(model="o4-mini", use_responses_api=True, output_version="v1") + mock_client = MagicMock() + + def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager: + return MockSyncContextManager(responses_stream) + + mock_client.responses.create = mock_create + + with patch.object(llm, "root_client", mock_client): + events = list(llm.stream_v2("test")) + + assert_valid_event_stream(events) + + reasoning_starts = [ + e + for e in events + if e["event"] == "content-block-start" + and e["content_block"]["type"] == "reasoning" + ] + reasoning_finishes = [ + e + for e in events + if e["event"] == "content-block-finish" + and e["content_block"]["type"] == "reasoning" + ] + + # The mock stream carries four reasoning summary parts (two per reasoning + # item, across two reasoning items), which surface as four reasoning + # content blocks in `output_version="v1"`. + assert len(reasoning_starts) == 4, ( + f"expected 4 reasoning start events, got {len(reasoning_starts)}" + ) + all_finish_types = [ + e["content_block"]["type"] + for e in events + if e["event"] == "content-block-finish" + ] + assert len(reasoning_finishes) == 4, ( + f"expected 4 reasoning finish events, got {len(reasoning_finishes)}: " + f"all finish events = {all_finish_types}" + ) + + # Finish events must carry the accumulated reasoning text. + reasoning_texts = [ + cast("dict[str, Any]", f["content_block"])["reasoning"] + for f in reasoning_finishes + ] + assert reasoning_texts == [ + "reasoning block one", + "another reasoning block", + "more reasoning", + "still more reasoning", + ] + + def test_responses_stream_with_image_generation_multiple_calls() -> None: """Test that streaming with image_generation tool works across multiple calls. diff --git a/libs/partners/openai/uv.lock b/libs/partners/openai/uv.lock index b82053c509b..60bf696a97b 100644 --- a/libs/partners/openai/uv.lock +++ b/libs/partners/openai/uv.lock @@ -624,10 +624,11 @@ typing = [ [[package]] name = "langchain-core" -version = "1.3.0" +version = "1.3.1" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -640,6 +641,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -769,6 +771,18 @@ typing = [ { name = "types-tqdm", specifier = ">=4.66.0.5,<5.0.0.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/partners/openrouter/uv.lock b/libs/partners/openrouter/uv.lock index d2e4890eb1b..487ea8ced55 100644 --- a/libs/partners/openrouter/uv.lock +++ b/libs/partners/openrouter/uv.lock @@ -350,6 +350,7 @@ version = "1.3.1" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -362,6 +363,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -454,6 +456,18 @@ test = [ test-integration = [] typing = [{ name = "mypy", specifier = ">=1.19.1,<2.0.0" }] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/partners/perplexity/uv.lock b/libs/partners/perplexity/uv.lock index 789d1203224..5a9e236d23d 100644 --- a/libs/partners/perplexity/uv.lock +++ b/libs/partners/perplexity/uv.lock @@ -434,10 +434,11 @@ wheels = [ [[package]] name = "langchain-core" -version = "1.3.0a2" +version = "1.3.1" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -450,6 +451,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -565,6 +567,18 @@ typing = [ { name = "types-tqdm", specifier = ">=4.66.0.5,<5.0.0.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/partners/qdrant/uv.lock b/libs/partners/qdrant/uv.lock index 1d01fbf3c89..fff7181363d 100644 --- a/libs/partners/qdrant/uv.lock +++ b/libs/partners/qdrant/uv.lock @@ -532,6 +532,7 @@ version = "1.3.1" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -544,6 +545,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -586,6 +588,18 @@ typing = [ { name = "types-requests", specifier = ">=2.28.11.5,<3.0.0.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-qdrant" version = "1.1.0" diff --git a/libs/partners/xai/uv.lock b/libs/partners/xai/uv.lock index 96c552cd757..7031485cf3f 100644 --- a/libs/partners/xai/uv.lock +++ b/libs/partners/xai/uv.lock @@ -680,6 +680,7 @@ version = "1.3.1" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -692,6 +693,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -784,6 +786,18 @@ typing = [ { name = "types-tqdm", specifier = ">=4.66.0.5,<5.0.0.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py index 3b49ac45765..02d842ca0ad 100644 --- a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py @@ -13,6 +13,10 @@ import httpx import pytest from langchain_core.callbacks import BaseCallbackHandler from langchain_core.language_models import BaseChatModel, GenericFakeChatModel +from langchain_core.language_models.chat_model_stream import ( + AsyncChatModelStream, + ChatModelStream, +) from langchain_core.messages import ( AIMessage, AIMessageChunk, @@ -35,6 +39,7 @@ from typing_extensions import TypedDict, override from langchain_tests.unit_tests.chat_models import ChatModelTests from langchain_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION +from langchain_tests.utils.stream_lifecycle import assert_valid_event_stream if TYPE_CHECKING: from pytest_benchmark.fixture import ( @@ -907,6 +912,69 @@ class ChatModelIntegrationTests(ChatModelTests): f"got {last_chunk.chunk_position!r}" ) + def test_stream_v2(self, model: BaseChatModel) -> None: + """Test that `model.stream_v2(simple_message)` works. + + Exercises the content-block-centric streaming protocol. Passing this + test indicates the model participates in `stream_v2` either natively + (via `_stream_chat_model_events`) or through the compat bridge that + converts `_stream` chunks into protocol events. + + ??? question "Troubleshooting" + + First, debug + `langchain_tests.integration_tests.chat_models.ChatModelIntegrationTests.test_stream` + — `stream_v2` falls back to the same `_stream` path via the compat + bridge when the model does not implement + `_stream_chat_model_events`. If `test_stream` passes but this does + not, inspect the raised lifecycle violation: it identifies the + event index and the rule broken. + """ + stream = model.stream_v2("Hello") + assert isinstance(stream, ChatModelStream) + + events = list(stream) + assert len(events) > 0 + assert_valid_event_stream(events) + + message = stream.output + assert isinstance(message, AIMessage) + assert message.content + assert len(message.content_blocks) == 1 + assert message.content_blocks[0]["type"] == "text" + # `stream_v2` always assembles content as v1 protocol blocks. + assert message.response_metadata.get("output_version") == "v1" + + async def test_astream_v2(self, model: BaseChatModel) -> None: + """Test that `await model.astream_v2(simple_message)` works. + + Async counterpart to `test_stream_v2`. Exercises the + `AsyncChatModelStream` path end-to-end: the background producer task, + replay-buffer-backed event iteration, and the awaitable `output` + projection. + + ??? question "Troubleshooting" + + First, debug + `langchain_tests.integration_tests.chat_models.ChatModelIntegrationTests.test_astream`. + If `test_astream` passes but this does not, inspect the raised + lifecycle violation; it identifies the event index and the rule + broken. + """ + stream = await model.astream_v2("Hello") + assert isinstance(stream, AsyncChatModelStream) + + events = [event async for event in stream] + assert len(events) > 0 + assert_valid_event_stream(events) + + message = await stream.output + assert isinstance(message, AIMessage) + assert message.content + assert len(message.content_blocks) == 1 + assert message.content_blocks[0]["type"] == "text" + assert message.response_metadata.get("output_version") == "v1" + def test_invoke_with_model_override(self, model: BaseChatModel) -> None: """Test that model name can be overridden at invoke time via kwargs. diff --git a/libs/standard-tests/langchain_tests/utils/stream_lifecycle.py b/libs/standard-tests/langchain_tests/utils/stream_lifecycle.py new file mode 100644 index 00000000000..29bb02ae9dc --- /dev/null +++ b/libs/standard-tests/langchain_tests/utils/stream_lifecycle.py @@ -0,0 +1,202 @@ +"""Validator for LangChain content-block protocol event streams. + +Checks that an event stream emitted by a chat model (via `stream_v2`, +or by the compat bridge's `chunks_to_events` / `message_to_events`) +conforms to the protocol lifecycle rules: + +- `message-start` opens and `message-finish` closes the stream. +- Content blocks do not interleave: each block runs + `content-block-start` → optional `content-block-delta`s → + `content-block-finish` before the next block begins. +- Wire indices on content-block events are sequential `uint` values + starting at 0. +- For deltaable block types (`text`, `reasoning`, `tool_call_chunk`, + `server_tool_call_chunk`), accumulated delta content matches the + final payload delivered on `content-block-finish`. + +The validator accepts any iterable of protocol event dicts. It raises +`AssertionError` on the first violation with a descriptive message. +""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Iterable + + +_DELTAABLE_TYPES = frozenset( + {"text", "reasoning", "tool_call_chunk", "server_tool_call_chunk"} +) + + +def assert_valid_event_stream(events: Iterable[Any]) -> None: + """Assert that a stream of protocol events obeys the lifecycle contract. + + Args: + events: Iterable of protocol event dicts (as yielded by + `stream_v2` or `chunks_to_events`). + + Raises: + AssertionError: On the first lifecycle violation found. The + message identifies the event index and the specific rule + that was broken. + """ + event_list = list(events) + if not event_list: + return + + first = event_list[0] + assert first["event"] == "message-start", ( + f"first event must be `message-start`, got {first['event']!r}" + ) + message_start_positions = [ + i for i, e in enumerate(event_list) if e["event"] == "message-start" + ] + assert message_start_positions == [0], ( + f"expected exactly one `message-start` at position 0, " + f"got positions {message_start_positions}" + ) + + message_finish_positions = [ + i for i, e in enumerate(event_list) if e["event"] == "message-finish" + ] + assert len(message_finish_positions) <= 1, ( + f"expected at most one `message-finish`, got {len(message_finish_positions)}" + ) + if message_finish_positions: + assert message_finish_positions[0] == len(event_list) - 1, ( + "`message-finish` must be the final event" + ) + + open_idx: int | None = None + expected_next_idx = 0 + start_events: dict[int, dict[str, Any]] = {} + finish_events: dict[int, dict[str, Any]] = {} + delta_accum: dict[int, dict[str, Any]] = {} + + for i, event in enumerate(event_list): + ev = event["event"] + if ev == "message-start": + assert i == 0, f"duplicate `message-start` at event {i}" + continue + if ev == "message-finish": + assert open_idx is None, ( + f"`message-finish` while block {open_idx} still open (event {i})" + ) + continue + if ev == "error": + continue + if ev == "content-block-start": + idx = event["index"] + assert isinstance(idx, int), ( + f"content-block-start wire index must be an int, " + f"got {idx!r} at event {i}" + ) + assert idx >= 0, ( + f"content-block-start wire index must be non-negative, " + f"got {idx} at event {i}" + ) + assert idx == expected_next_idx, ( + f"expected next wire index {expected_next_idx}, got {idx} at event {i}" + ) + assert open_idx is None, ( + f"content-block-start at idx={idx} while block {open_idx} " + f"still open (event {i}); blocks must not interleave" + ) + open_idx = idx + start_events[idx] = event["content_block"] + delta_accum[idx] = {} + expected_next_idx += 1 + elif ev == "content-block-delta": + idx = event["index"] + assert idx == open_idx, ( + f"content-block-delta at idx={idx} but currently-open block is " + f"{open_idx} (event {i})" + ) + block = event["content_block"] + _accumulate_delta(delta_accum[idx], block) + elif ev == "content-block-finish": + idx = event["index"] + assert idx == open_idx, ( + f"content-block-finish at idx={idx} but currently-open block is " + f"{open_idx} (event {i})" + ) + finish_events[idx] = event["content_block"] + open_idx = None + else: + # Unknown event types are accepted; the CDDL allows extensions. + continue + + assert open_idx is None, ( + f"block {open_idx} still open at end of stream — no content-block-finish" + ) + missing = set(start_events) - set(finish_events) + assert not missing, ( + f"the following block indices have no content-block-finish event: " + f"{sorted(missing)}" + ) + + for idx, finish_block in finish_events.items(): + _assert_delta_matches_finish(idx, delta_accum[idx], finish_block) + + +def _accumulate_delta(accum: dict[str, Any], block: dict[str, Any]) -> None: + """Fold a delta block into the running accumulator for its index.""" + btype = block.get("type") + if btype not in _DELTAABLE_TYPES: + return + if btype == "text": + accum["text"] = accum.get("text", "") + block.get("text", "") + elif btype == "reasoning": + accum["reasoning"] = accum.get("reasoning", "") + block.get("reasoning", "") + else: # tool_call_chunk / server_tool_call_chunk + accum["args"] = accum.get("args", "") + (block.get("args") or "") + if block.get("id") is not None: + accum["id"] = block["id"] + if block.get("name") is not None: + accum["name"] = block["name"] + + +def _assert_delta_matches_finish( + idx: int, + accum: dict[str, Any], + finish_block: dict[str, Any], +) -> None: + """Assert accumulated delta content is reflected in the finish payload.""" + ftype = finish_block.get("type") + if ftype == "text" and "text" in accum: + assert finish_block.get("text", "") == accum["text"], ( + f"block {idx} text accumulation {accum['text']!r} does not match " + f"finish text {finish_block.get('text', '')!r}" + ) + elif ftype == "reasoning" and "reasoning" in accum: + assert finish_block.get("reasoning", "") == accum["reasoning"], ( + f"block {idx} reasoning accumulation mismatch: " + f"accumulated {accum['reasoning']!r}, finish " + f"{finish_block.get('reasoning', '')!r}" + ) + elif ftype == "tool_call" and "args" in accum: + # tool_call_chunk args are concatenated partial-JSON strings that + # parse to a dict on finish. + try: + parsed = json.loads(accum["args"]) if accum["args"] else {} + except json.JSONDecodeError: + # Finish upgrades malformed args to invalid_tool_call, not + # tool_call — so a tool_call finish implies args parsed cleanly. + parsed = None + assert finish_block.get("args") == parsed, ( + f"block {idx} tool_call args mismatch: accumulated parse " + f"{parsed!r}, finish {finish_block.get('args')!r}" + ) + elif ftype == "server_tool_call" and "args" in accum: + try: + parsed = json.loads(accum["args"]) if accum["args"] else {} + except json.JSONDecodeError: + parsed = None + assert finish_block.get("args") == parsed + + +__all__ = ["assert_valid_event_stream"] diff --git a/libs/standard-tests/uv.lock b/libs/standard-tests/uv.lock index da832493151..ddbfc0c38d8 100644 --- a/libs/standard-tests/uv.lock +++ b/libs/standard-tests/uv.lock @@ -324,10 +324,11 @@ wheels = [ [[package]] name = "langchain-core" -version = "1.3.0a2" +version = "1.3.1" source = { editable = "../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -340,6 +341,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -382,6 +384,18 @@ typing = [ { name = "types-requests", specifier = ">=2.28.11.5,<3.0.0.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-tests" version = "1.1.6" diff --git a/libs/text-splitters/uv.lock b/libs/text-splitters/uv.lock index 05e61dd58de..b8b1acce6e7 100644 --- a/libs/text-splitters/uv.lock +++ b/libs/text-splitters/uv.lock @@ -1190,6 +1190,7 @@ version = "1.3.1" source = { editable = "../core" } dependencies = [ { name = "jsonpatch" }, + { name = "langchain-protocol" }, { name = "langsmith" }, { name = "packaging" }, { name = "pydantic" }, @@ -1202,6 +1203,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" }, + { name = "langchain-protocol", specifier = ">=0.0.10" }, { name = "langsmith", specifier = ">=0.3.45,<1.0.0" }, { name = "packaging", specifier = ">=23.2.0" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, @@ -1244,6 +1246,18 @@ typing = [ { name = "types-requests", specifier = ">=2.28.11.5,<3.0.0.0" }, ] +[[package]] +name = "langchain-protocol" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c3/0d3911d3274f097040e92133f18a425980cd4085e72b6cd65add1f25327c/langchain_protocol-0.0.10.tar.gz", hash = "sha256:5bc530e0b350d3a15a3ab6889abb8132692a2c8a15eed536bce46624751acaaf", size = 6528, upload-time = "2026-04-23T17:31:34.212Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/11/6c89bc86b5494cfe29ee23420c398406cc147a09b5cf756e323070e358d7/langchain_protocol-0.0.10-py3-none-any.whl", hash = "sha256:040bb2ae966a06ffcd0051a1d1ca7e4926f12e951e83b07440cb80e0e8e12268", size = 6677, upload-time = "2026-04-23T17:31:33.367Z" }, +] + [[package]] name = "langchain-text-splitters" version = "1.1.2"