mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-19 20:04:11 +00:00
Compare commits
17 Commits
jacob/patc
...
nh/content
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c11d57a86d | ||
|
|
06dcfaa596 | ||
|
|
bf64733f74 | ||
|
|
a8ce29ab8c | ||
|
|
cee4dd3852 | ||
|
|
84e0365438 | ||
|
|
2c449ca1f5 | ||
|
|
63ca3f2831 | ||
|
|
0efc5d538e | ||
|
|
416d55b3d6 | ||
|
|
204c6af2f1 | ||
|
|
6aef1fd4fe | ||
|
|
8773cb8c4e | ||
|
|
a1d331e8f0 | ||
|
|
352a725d5c | ||
|
|
6b203f082d | ||
|
|
14442f4d10 |
@@ -9,6 +9,7 @@ if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_protocol.protocol import MessagesData
|
||||
from tenacity import RetryCallState
|
||||
from typing_extensions import Self
|
||||
|
||||
@@ -124,6 +125,43 @@ class LLMManagerMixin:
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
|
||||
def on_stream_event(
|
||||
self,
|
||||
event: MessagesData,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
tags: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on each protocol event produced by `stream_v2` / `astream_v2`.
|
||||
|
||||
Fires once per `MessagesData` event — `message-start`, per-block
|
||||
`content-block-start` / `content-block-delta` /
|
||||
`content-block-finish`, and `message-finish`. Analogous to
|
||||
`on_llm_new_token` in v1 streaming, but at event granularity rather
|
||||
than chunk: a single chunk can map to multiple events (e.g. a
|
||||
`content-block-start` plus its first `content-block-delta`), and
|
||||
lifecycle boundaries are explicit.
|
||||
|
||||
Fires uniformly whether the provider emits events natively via
|
||||
`_stream_chat_model_events` or goes through the chunk-to-event
|
||||
compat bridge. Observers see the same event stream regardless of
|
||||
how the underlying model produces output.
|
||||
|
||||
Not fired from v1 `stream()` / `astream()`; for those, keep using
|
||||
`on_llm_new_token`. Purely additive — `on_chat_model_start`,
|
||||
`on_llm_end`, and `on_llm_error` still fire around a v2 call as
|
||||
they do around a v1 call.
|
||||
|
||||
Args:
|
||||
event: The protocol event.
|
||||
run_id: The ID of the current run.
|
||||
parent_run_id: The ID of the parent run.
|
||||
tags: The tags.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
|
||||
|
||||
class ChainManagerMixin:
|
||||
"""Mixin for chain callbacks."""
|
||||
@@ -288,10 +326,10 @@ class CallbackManagerMixin:
|
||||
!!! note
|
||||
|
||||
When overriding this method, the signature **must** include the two
|
||||
required positional arguments ``serialized`` and ``messages``. Avoid
|
||||
using ``*args`` in your override — doing so causes an ``IndexError``
|
||||
in the fallback path when the callback system converts ``messages``
|
||||
to prompt strings for ``on_llm_start``. Always declare the
|
||||
required positional arguments `serialized` and `messages`. Avoid
|
||||
using `*args` in your override — doing so causes an `IndexError`
|
||||
in the fallback path when the callback system converts `messages`
|
||||
to prompt strings for `on_llm_start`. Always declare the
|
||||
signature explicitly:
|
||||
|
||||
.. code-block:: python
|
||||
@@ -557,10 +595,10 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
!!! note
|
||||
|
||||
When overriding this method, the signature **must** include the two
|
||||
required positional arguments ``serialized`` and ``messages``. Avoid
|
||||
using ``*args`` in your override — doing so causes an ``IndexError``
|
||||
in the fallback path when the callback system converts ``messages``
|
||||
to prompt strings for ``on_llm_start``. Always declare the
|
||||
required positional arguments `serialized` and `messages`. Avoid
|
||||
using `*args` in your override — doing so causes an `IndexError`
|
||||
in the fallback path when the callback system converts `messages`
|
||||
to prompt strings for `on_llm_start`. Always declare the
|
||||
signature explicitly:
|
||||
|
||||
.. code-block:: python
|
||||
@@ -652,6 +690,31 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
the error occurred.
|
||||
"""
|
||||
|
||||
async def on_stream_event(
|
||||
self,
|
||||
event: MessagesData,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
tags: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on each protocol event produced by `astream_v2`.
|
||||
|
||||
See :meth:`LLMManagerMixin.on_stream_event` for the full contract.
|
||||
Fires once per `MessagesData` event at event granularity, uniformly
|
||||
across native and compat-bridge providers, and is purely additive
|
||||
to the existing `on_chat_model_start` / `on_llm_end` /
|
||||
`on_llm_error` callbacks.
|
||||
|
||||
Args:
|
||||
event: The protocol event.
|
||||
run_id: The ID of the current run.
|
||||
parent_run_id: The ID of the parent run.
|
||||
tags: The tags.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
|
||||
@@ -35,6 +35,7 @@ if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator, Coroutine, Generator, Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_protocol.protocol import MessagesData
|
||||
from tenacity import RetryCallState
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
@@ -747,6 +748,26 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_stream_event(self, event: MessagesData, **kwargs: Any) -> None:
|
||||
"""Run on each protocol event from `stream_v2`.
|
||||
|
||||
Args:
|
||||
event: The protocol event.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_stream_event",
|
||||
"ignore_llm",
|
||||
event,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
"""Async callback manager for LLM run."""
|
||||
@@ -849,6 +870,26 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def on_stream_event(self, event: MessagesData, **kwargs: Any) -> None:
|
||||
"""Run on each protocol event from `astream_v2`.
|
||||
|
||||
Args:
|
||||
event: The protocol event.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
await ahandle_event(
|
||||
self.handlers,
|
||||
"on_stream_event",
|
||||
"ignore_llm",
|
||||
event,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
||||
"""Callback manager for chain run."""
|
||||
|
||||
618
libs/core/langchain_core/language_models/_compat_bridge.py
Normal file
618
libs/core/langchain_core/language_models/_compat_bridge.py
Normal file
@@ -0,0 +1,618 @@
|
||||
"""Compat bridge: convert `AIMessageChunk` streams to protocol events.
|
||||
|
||||
The bridge trusts :meth:`AIMessageChunk.content_blocks` as the single
|
||||
protocol view of any chunk. That property runs the three-tier lookup
|
||||
(`output_version == "v1"` short-circuit, registered translator, or
|
||||
best-effort parsing) and returns a `list[ContentBlock]` for every
|
||||
well-formed message — whether the provider is a registered partner, an
|
||||
unregistered community model, or not tagged at all.
|
||||
|
||||
Per-chunk `content_blocks` output is a **delta slice**, not accumulated
|
||||
state: providers in this ecosystem emit SSE-style chunks that each carry
|
||||
their own increment. The bridge therefore forwards each slice straight
|
||||
through as a `content-block-delta` event, and accumulates per-index
|
||||
state only so the final `content-block-finish` event can report a
|
||||
finalized block (e.g. `tool_call_chunk` args parsed to a dict).
|
||||
|
||||
Lifecycle::
|
||||
|
||||
message-start
|
||||
-> content-block-start (first time each index is observed)
|
||||
-> content-block-delta* (per chunk, carrying the slice)
|
||||
-> content-block-finish (finalized block)
|
||||
-> message-finish
|
||||
|
||||
Public API:
|
||||
|
||||
- :func:`chunks_to_events` / :func:`achunks_to_events` — for live streams
|
||||
where chunks arrive over time.
|
||||
- :func:`message_to_events` / :func:`amessage_to_events` — for replaying a
|
||||
finalized :class:`AIMessage` (cache hit, checkpoint restore, graph-node
|
||||
return value) as a synthetic event lifecycle.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from langchain_protocol.protocol import (
|
||||
ContentBlock,
|
||||
ContentBlockDeltaData,
|
||||
ContentBlockFinishData,
|
||||
ContentBlockStartData,
|
||||
FinalizedContentBlock,
|
||||
FinishReason,
|
||||
InvalidToolCallBlock,
|
||||
MessageFinishData,
|
||||
MessageMetadata,
|
||||
MessagesData,
|
||||
MessageStartData,
|
||||
ReasoningBlock,
|
||||
ServerToolCallBlock,
|
||||
ServerToolCallChunkBlock,
|
||||
TextBlock,
|
||||
ToolCallBlock,
|
||||
ToolCallChunkBlock,
|
||||
UsageInfo,
|
||||
)
|
||||
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
|
||||
from langchain_core.outputs import ChatGenerationChunk
|
||||
|
||||
|
||||
CompatBlock = dict[str, Any]
|
||||
"""Internal working type for a content block.
|
||||
|
||||
The bridge works with plain dicts internally because two separate but
|
||||
structurally similar `ContentBlock` Unions exist — one in
|
||||
:mod:`langchain_core.messages.content` (returned by
|
||||
`msg.content_blocks`), one in :mod:`langchain_protocol.protocol` (the
|
||||
wire/event shape). They are not mypy-compatible despite being
|
||||
near-isomorphic. Passing through `dict[str, Any]` launders between
|
||||
them. See :func:`_to_protocol_block` for the single seam where the
|
||||
laundering cast lives.
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type laundering between core and protocol `ContentBlock` unions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _to_protocol_block(block: CompatBlock) -> ContentBlock:
|
||||
"""Narrow an internal working dict to a protocol `ContentBlock`.
|
||||
|
||||
Single seam between the two `ContentBlock` type systems:
|
||||
:mod:`langchain_core.messages.content` (what `msg.content_blocks`
|
||||
returns) and :mod:`langchain_protocol.protocol` (what event payloads
|
||||
require). The two Unions overlap structurally but are nominally
|
||||
distinct to mypy, so we launder through `dict[str, Any]`. When the
|
||||
Unions are unified, this helper and its finalized counterpart can be
|
||||
deleted.
|
||||
"""
|
||||
return cast("ContentBlock", block)
|
||||
|
||||
|
||||
def _to_finalized_block(block: CompatBlock) -> FinalizedContentBlock:
|
||||
"""Counterpart of :func:`_to_protocol_block` for finalized blocks."""
|
||||
return cast("FinalizedContentBlock", block)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Block iteration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _iter_protocol_blocks(msg: BaseMessage) -> list[tuple[int, CompatBlock]]:
|
||||
"""Read per-chunk protocol blocks from `msg.content_blocks`.
|
||||
|
||||
Returns `(index, block)` pairs. Block indices come from each
|
||||
block's `index` field when present, falling back to positional.
|
||||
|
||||
For finalized :class:`AIMessage`, also surfaces `invalid_tool_calls`
|
||||
— which `AIMessage.content_blocks` currently omits from its return
|
||||
value even though they are a defined protocol block type.
|
||||
"""
|
||||
try:
|
||||
raw = msg.content_blocks
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
result: list[tuple[int, CompatBlock]] = []
|
||||
for i, block in enumerate(raw):
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
raw_idx = block.get("index", i)
|
||||
idx = raw_idx if isinstance(raw_idx, int) else i
|
||||
result.append((idx, dict(block)))
|
||||
|
||||
if not isinstance(msg, AIMessageChunk):
|
||||
# Finalized AIMessage: pull invalid_tool_calls from the dedicated
|
||||
# field — AIMessage.content_blocks does not currently include them.
|
||||
for itc in getattr(msg, "invalid_tool_calls", None) or []:
|
||||
itc_block: CompatBlock = {"type": "invalid_tool_call"}
|
||||
for key in ("id", "name", "args", "error"):
|
||||
if itc.get(key) is not None:
|
||||
itc_block[key] = itc[key]
|
||||
result.append((len(result), itc_block))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-block helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _start_skeleton(block: CompatBlock) -> ContentBlock:
|
||||
"""Empty-content placeholder for the `content-block-start` event.
|
||||
|
||||
Deltaable block types (text, reasoning, the `_chunk` tool variants)
|
||||
get an empty payload so the lifecycle's "start" signal is distinct
|
||||
from the first incremental delta. Self-contained or already-finalized
|
||||
block types pass through unchanged — their `start` event is also
|
||||
their only content-bearing event.
|
||||
"""
|
||||
btype = block.get("type", "text")
|
||||
if btype == "text":
|
||||
return TextBlock(type="text", text="")
|
||||
if btype == "reasoning":
|
||||
return ReasoningBlock(type="reasoning", reasoning="")
|
||||
if btype == "tool_call_chunk":
|
||||
skel = ToolCallChunkBlock(type="tool_call_chunk", args="")
|
||||
if block.get("id") is not None:
|
||||
skel["id"] = block["id"]
|
||||
if block.get("name") is not None:
|
||||
skel["name"] = block["name"]
|
||||
return skel
|
||||
if btype == "server_tool_call_chunk":
|
||||
s_skel = ServerToolCallChunkBlock(
|
||||
type="server_tool_call_chunk",
|
||||
args="",
|
||||
)
|
||||
if block.get("id") is not None:
|
||||
s_skel["id"] = block["id"]
|
||||
if block.get("name") is not None:
|
||||
s_skel["name"] = block["name"]
|
||||
return s_skel
|
||||
return _to_protocol_block(block)
|
||||
|
||||
|
||||
def _should_emit_delta(block: CompatBlock) -> bool:
|
||||
"""Whether a per-chunk block carries content worth a delta event.
|
||||
|
||||
Deltaable types emit only when they have fresh content. Self-contained
|
||||
/ already-finalized types skip the delta entirely — the `finish`
|
||||
event carries them.
|
||||
"""
|
||||
btype = block.get("type")
|
||||
if btype == "text":
|
||||
return bool(block.get("text"))
|
||||
if btype == "reasoning":
|
||||
return bool(block.get("reasoning"))
|
||||
if btype in ("tool_call_chunk", "server_tool_call_chunk"):
|
||||
return bool(
|
||||
block.get("args") or block.get("id") or block.get("name"),
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def _accumulate(state: CompatBlock | None, delta: CompatBlock) -> CompatBlock:
|
||||
"""Merge a per-chunk delta slice into accumulated per-index state.
|
||||
|
||||
Used only for the finalization pass — live delta events are emitted
|
||||
directly from the per-chunk block, without round-tripping through
|
||||
accumulated state.
|
||||
"""
|
||||
if state is None:
|
||||
return dict(delta)
|
||||
btype = state.get("type")
|
||||
dtype = delta.get("type")
|
||||
if btype == "text" and dtype == "text":
|
||||
state["text"] = state.get("text", "") + delta.get("text", "")
|
||||
elif btype == "reasoning" and dtype == "reasoning":
|
||||
state["reasoning"] = state.get("reasoning", "") + delta.get("reasoning", "")
|
||||
elif btype in ("tool_call_chunk", "server_tool_call_chunk") and dtype == btype:
|
||||
state["args"] = state.get("args", "") + (delta.get("args") or "")
|
||||
if delta.get("id") is not None:
|
||||
state["id"] = delta["id"]
|
||||
if delta.get("name") is not None:
|
||||
state["name"] = delta["name"]
|
||||
else:
|
||||
# Self-contained or already-finalized types: replace wholesale.
|
||||
state.clear()
|
||||
state.update(delta)
|
||||
return state
|
||||
|
||||
|
||||
def _finalize_block(block: CompatBlock) -> FinalizedContentBlock:
|
||||
"""Promote chunk variants to their finalized form.
|
||||
|
||||
`tool_call_chunk` becomes `tool_call` — or `invalid_tool_call`
|
||||
if the accumulated `args` don't parse as JSON.
|
||||
`server_tool_call_chunk` becomes `server_tool_call` under the same
|
||||
rule. Everything else passes through: text/reasoning blocks carry
|
||||
their accumulated snapshot, and self-contained types are already in
|
||||
their terminal shape.
|
||||
"""
|
||||
btype = block.get("type")
|
||||
if btype in ("tool_call_chunk", "server_tool_call_chunk"):
|
||||
raw = block.get("args") or "{}"
|
||||
try:
|
||||
parsed = json.loads(raw) if raw else {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
invalid = InvalidToolCallBlock(
|
||||
type="invalid_tool_call",
|
||||
args=raw,
|
||||
error="Failed to parse tool call arguments as JSON",
|
||||
)
|
||||
if block.get("id") is not None:
|
||||
invalid["id"] = block["id"]
|
||||
if block.get("name") is not None:
|
||||
invalid["name"] = block["name"]
|
||||
return invalid
|
||||
if btype == "tool_call_chunk":
|
||||
return ToolCallBlock(
|
||||
type="tool_call",
|
||||
id=block.get("id", ""),
|
||||
name=block.get("name", ""),
|
||||
args=parsed,
|
||||
)
|
||||
return ServerToolCallBlock(
|
||||
type="server_tool_call",
|
||||
id=block.get("id", ""),
|
||||
name=block.get("name", ""),
|
||||
args=parsed,
|
||||
)
|
||||
return _to_finalized_block(block)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metadata, usage, finish-reason
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _extract_start_metadata(response_metadata: dict[str, Any]) -> MessageMetadata:
|
||||
"""Pull provider/model hints for the `message-start` event."""
|
||||
metadata: MessageMetadata = {}
|
||||
if "model_provider" in response_metadata:
|
||||
metadata["provider"] = response_metadata["model_provider"]
|
||||
if "model_name" in response_metadata:
|
||||
metadata["model"] = response_metadata["model_name"]
|
||||
return metadata
|
||||
|
||||
|
||||
def _normalize_finish_reason(value: Any) -> FinishReason:
|
||||
"""Map provider-specific stop reasons to protocol finish reasons."""
|
||||
if value == "length":
|
||||
return "length"
|
||||
if value == "content_filter":
|
||||
return "content_filter"
|
||||
if value in ("tool_use", "tool_calls"):
|
||||
return "tool_use"
|
||||
return "stop"
|
||||
|
||||
|
||||
def _accumulate_usage(
|
||||
current: dict[str, Any] | None, delta: Any
|
||||
) -> dict[str, Any] | None:
|
||||
"""Sum usage counts and merge detail dicts across chunks."""
|
||||
if not isinstance(delta, dict):
|
||||
return current
|
||||
if current is None:
|
||||
return dict(delta)
|
||||
for key in ("input_tokens", "output_tokens", "total_tokens", "cached_tokens"):
|
||||
if key in delta:
|
||||
current[key] = current.get(key, 0) + delta[key]
|
||||
for detail_key in ("input_token_details", "output_token_details"):
|
||||
if detail_key in delta and isinstance(delta[detail_key], dict):
|
||||
if detail_key not in current:
|
||||
current[detail_key] = {}
|
||||
current[detail_key].update(delta[detail_key])
|
||||
return current
|
||||
|
||||
|
||||
def _to_protocol_usage(usage: dict[str, Any] | None) -> UsageInfo | None:
|
||||
"""Convert accumulated usage to the protocol's `UsageInfo` shape."""
|
||||
if usage is None:
|
||||
return None
|
||||
result: UsageInfo = {}
|
||||
for key in ("input_tokens", "output_tokens", "total_tokens", "cached_tokens"):
|
||||
if key in usage:
|
||||
result[key] = usage[key]
|
||||
return result or None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event builders
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_message_start(
|
||||
msg: BaseMessage,
|
||||
message_id: str | None,
|
||||
) -> MessageStartData:
|
||||
start_data = MessageStartData(event="message-start", role="ai")
|
||||
resolved_id = message_id if message_id is not None else getattr(msg, "id", None)
|
||||
if resolved_id:
|
||||
start_data["message_id"] = resolved_id
|
||||
start_metadata = _extract_start_metadata(msg.response_metadata or {})
|
||||
if start_metadata:
|
||||
start_data["metadata"] = start_metadata
|
||||
return start_data
|
||||
|
||||
|
||||
def _build_message_finish(
|
||||
*,
|
||||
finish_reason: FinishReason,
|
||||
has_valid_tool_call: bool,
|
||||
usage: dict[str, Any] | None,
|
||||
response_metadata: dict[str, Any] | None,
|
||||
) -> MessageFinishData:
|
||||
# Infer tool_use only from finalized (parsed) tool_calls. An
|
||||
# invalid_tool_call means parsing failed — the model didn't
|
||||
# successfully request a tool, so leave finish_reason alone.
|
||||
if finish_reason == "stop" and has_valid_tool_call:
|
||||
finish_reason = "tool_use"
|
||||
finish_data = MessageFinishData(event="message-finish", reason=finish_reason)
|
||||
usage_info = _to_protocol_usage(usage)
|
||||
if usage_info is not None:
|
||||
finish_data["usage"] = usage_info
|
||||
if response_metadata:
|
||||
metadata = {
|
||||
k: v
|
||||
for k, v in response_metadata.items()
|
||||
if k not in ("finish_reason", "stop_reason")
|
||||
}
|
||||
if metadata:
|
||||
finish_data["metadata"] = metadata
|
||||
return finish_data
|
||||
|
||||
|
||||
def _finish_all_blocks(
|
||||
state: dict[int, CompatBlock],
|
||||
) -> tuple[list[MessagesData], bool]:
|
||||
"""Emit `content-block-finish` events for every open block.
|
||||
|
||||
Returns the event list plus a flag indicating whether any finalized
|
||||
block was a valid `tool_call` (used for finish-reason inference).
|
||||
"""
|
||||
events: list[MessagesData] = []
|
||||
has_valid_tool_call = False
|
||||
for idx in sorted(state):
|
||||
finalized = _finalize_block(state[idx])
|
||||
if finalized.get("type") == "tool_call":
|
||||
has_valid_tool_call = True
|
||||
events.append(
|
||||
ContentBlockFinishData(
|
||||
event="content-block-finish",
|
||||
index=idx,
|
||||
content_block=finalized,
|
||||
)
|
||||
)
|
||||
return events, has_valid_tool_call
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main generators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def chunks_to_events(
|
||||
chunks: Iterator[ChatGenerationChunk],
|
||||
*,
|
||||
message_id: str | None = None,
|
||||
) -> Iterator[MessagesData]:
|
||||
"""Convert a stream of `ChatGenerationChunk` to protocol events.
|
||||
|
||||
Args:
|
||||
chunks: Iterator of `ChatGenerationChunk` from `_stream()`.
|
||||
message_id: Optional stable message ID.
|
||||
|
||||
Yields:
|
||||
`MessagesData` lifecycle events.
|
||||
"""
|
||||
started = False
|
||||
state: dict[int, CompatBlock] = {}
|
||||
first_seen: set[int] = set()
|
||||
usage: dict[str, Any] | None = None
|
||||
response_metadata: dict[str, Any] = {}
|
||||
finish_reason: FinishReason = "stop"
|
||||
|
||||
for chunk in chunks:
|
||||
msg = chunk.message
|
||||
if not isinstance(msg, AIMessageChunk):
|
||||
continue
|
||||
|
||||
if msg.response_metadata:
|
||||
response_metadata.update(msg.response_metadata)
|
||||
|
||||
if not started:
|
||||
started = True
|
||||
yield _build_message_start(msg, message_id)
|
||||
|
||||
for idx, block in _iter_protocol_blocks(msg):
|
||||
if idx not in first_seen:
|
||||
first_seen.add(idx)
|
||||
yield ContentBlockStartData(
|
||||
event="content-block-start",
|
||||
index=idx,
|
||||
content_block=_start_skeleton(block),
|
||||
)
|
||||
if _should_emit_delta(block):
|
||||
yield ContentBlockDeltaData(
|
||||
event="content-block-delta",
|
||||
index=idx,
|
||||
content_block=_to_protocol_block(block),
|
||||
)
|
||||
state[idx] = _accumulate(state.get(idx), block)
|
||||
|
||||
if msg.usage_metadata:
|
||||
usage = _accumulate_usage(usage, msg.usage_metadata)
|
||||
|
||||
rm = msg.response_metadata or {}
|
||||
raw_reason = rm.get("finish_reason") or rm.get("stop_reason")
|
||||
if raw_reason:
|
||||
finish_reason = _normalize_finish_reason(raw_reason)
|
||||
|
||||
if not started:
|
||||
return
|
||||
|
||||
finish_events, has_valid_tool_call = _finish_all_blocks(state)
|
||||
yield from finish_events
|
||||
yield _build_message_finish(
|
||||
finish_reason=finish_reason,
|
||||
has_valid_tool_call=has_valid_tool_call,
|
||||
usage=usage,
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
|
||||
|
||||
async def achunks_to_events(
|
||||
chunks: AsyncIterator[ChatGenerationChunk],
|
||||
*,
|
||||
message_id: str | None = None,
|
||||
) -> AsyncIterator[MessagesData]:
|
||||
"""Async variant of :func:`chunks_to_events`."""
|
||||
started = False
|
||||
state: dict[int, CompatBlock] = {}
|
||||
first_seen: set[int] = set()
|
||||
usage: dict[str, Any] | None = None
|
||||
response_metadata: dict[str, Any] = {}
|
||||
finish_reason: FinishReason = "stop"
|
||||
|
||||
async for chunk in chunks:
|
||||
msg = chunk.message
|
||||
if not isinstance(msg, AIMessageChunk):
|
||||
continue
|
||||
|
||||
if msg.response_metadata:
|
||||
response_metadata.update(msg.response_metadata)
|
||||
|
||||
if not started:
|
||||
started = True
|
||||
yield _build_message_start(msg, message_id)
|
||||
|
||||
for idx, block in _iter_protocol_blocks(msg):
|
||||
if idx not in first_seen:
|
||||
first_seen.add(idx)
|
||||
yield ContentBlockStartData(
|
||||
event="content-block-start",
|
||||
index=idx,
|
||||
content_block=_start_skeleton(block),
|
||||
)
|
||||
if _should_emit_delta(block):
|
||||
yield ContentBlockDeltaData(
|
||||
event="content-block-delta",
|
||||
index=idx,
|
||||
content_block=_to_protocol_block(block),
|
||||
)
|
||||
state[idx] = _accumulate(state.get(idx), block)
|
||||
|
||||
if msg.usage_metadata:
|
||||
usage = _accumulate_usage(usage, msg.usage_metadata)
|
||||
|
||||
rm = msg.response_metadata or {}
|
||||
raw_reason = rm.get("finish_reason") or rm.get("stop_reason")
|
||||
if raw_reason:
|
||||
finish_reason = _normalize_finish_reason(raw_reason)
|
||||
|
||||
if not started:
|
||||
return
|
||||
|
||||
finish_events, has_valid_tool_call = _finish_all_blocks(state)
|
||||
for event in finish_events:
|
||||
yield event
|
||||
yield _build_message_finish(
|
||||
finish_reason=finish_reason,
|
||||
has_valid_tool_call=has_valid_tool_call,
|
||||
usage=usage,
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
|
||||
|
||||
def message_to_events(
|
||||
msg: BaseMessage,
|
||||
*,
|
||||
message_id: str | None = None,
|
||||
) -> Iterator[MessagesData]:
|
||||
"""Replay a finalized message as a synthetic event lifecycle.
|
||||
|
||||
For a message returned whole (from a graph node, checkpoint, or
|
||||
cache), produce the same `message-start` / per-block /
|
||||
`message-finish` event stream a live call would produce. Consumers
|
||||
downstream see a uniform event shape regardless of source.
|
||||
|
||||
Text and reasoning blocks emit a single `content-block-delta` with
|
||||
the full accumulated content. Already-finalized blocks (tool_call,
|
||||
server_tool_call, image, etc.) skip the delta and rely on the
|
||||
`content-block-finish` event alone.
|
||||
|
||||
Args:
|
||||
msg: The finalized message — typically an `AIMessage`.
|
||||
message_id: Optional stable message ID; falls back to `msg.id`.
|
||||
|
||||
Yields:
|
||||
`MessagesData` lifecycle events.
|
||||
"""
|
||||
response_metadata = msg.response_metadata or {}
|
||||
yield _build_message_start(msg, message_id)
|
||||
|
||||
has_valid_tool_call = False
|
||||
for idx, block in _iter_protocol_blocks(msg):
|
||||
yield ContentBlockStartData(
|
||||
event="content-block-start",
|
||||
index=idx,
|
||||
content_block=_start_skeleton(block),
|
||||
)
|
||||
if _should_emit_delta(block):
|
||||
yield ContentBlockDeltaData(
|
||||
event="content-block-delta",
|
||||
index=idx,
|
||||
content_block=_to_protocol_block(block),
|
||||
)
|
||||
finalized = _finalize_block(block)
|
||||
if finalized.get("type") == "tool_call":
|
||||
has_valid_tool_call = True
|
||||
yield ContentBlockFinishData(
|
||||
event="content-block-finish",
|
||||
index=idx,
|
||||
content_block=finalized,
|
||||
)
|
||||
|
||||
raw_reason = response_metadata.get("finish_reason") or response_metadata.get(
|
||||
"stop_reason"
|
||||
)
|
||||
finish_reason: FinishReason = (
|
||||
_normalize_finish_reason(raw_reason) if raw_reason else "stop"
|
||||
)
|
||||
yield _build_message_finish(
|
||||
finish_reason=finish_reason,
|
||||
has_valid_tool_call=has_valid_tool_call,
|
||||
usage=getattr(msg, "usage_metadata", None),
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
|
||||
|
||||
async def amessage_to_events(
|
||||
msg: BaseMessage,
|
||||
*,
|
||||
message_id: str | None = None,
|
||||
) -> AsyncIterator[MessagesData]:
|
||||
"""Async variant of :func:`message_to_events`."""
|
||||
for event in message_to_events(msg, message_id=message_id):
|
||||
yield event
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CompatBlock",
|
||||
"achunks_to_events",
|
||||
"amessage_to_events",
|
||||
"chunks_to_events",
|
||||
"message_to_events",
|
||||
]
|
||||
1011
libs/core/langchain_core/language_models/chat_model_stream.py
Normal file
1011
libs/core/langchain_core/language_models/chat_model_stream.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -24,6 +24,10 @@ from langchain_core.callbacks import (
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.globals import get_llm_cache
|
||||
from langchain_core.language_models._compat_bridge import (
|
||||
achunks_to_events,
|
||||
chunks_to_events,
|
||||
)
|
||||
from langchain_core.language_models._utils import (
|
||||
_normalize_messages,
|
||||
_update_message_content_to_blocks,
|
||||
@@ -33,6 +37,10 @@ from langchain_core.language_models.base import (
|
||||
LangSmithParams,
|
||||
LanguageModelInput,
|
||||
)
|
||||
from langchain_core.language_models.chat_model_stream import (
|
||||
AsyncChatModelStream,
|
||||
ChatModelStream,
|
||||
)
|
||||
from langchain_core.language_models.model_profile import (
|
||||
ModelProfile,
|
||||
_warn_unknown_profile_keys,
|
||||
@@ -68,7 +76,10 @@ from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPro
|
||||
from langchain_core.rate_limiters import BaseRateLimiter
|
||||
from langchain_core.runnables import RunnableMap, RunnablePassthrough
|
||||
from langchain_core.runnables.config import ensure_config, run_in_executor
|
||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||
from langchain_core.tracers._streaming import (
|
||||
_StreamingCallbackHandler,
|
||||
_V2StreamingCallbackHandler,
|
||||
)
|
||||
from langchain_core.utils.function_calling import (
|
||||
convert_to_json_schema,
|
||||
convert_to_openai_tool,
|
||||
@@ -80,6 +91,8 @@ if TYPE_CHECKING:
|
||||
import builtins
|
||||
import uuid
|
||||
|
||||
from langchain_protocol.protocol import MessagesData
|
||||
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
from langchain_core.tools import BaseTool
|
||||
@@ -528,6 +541,143 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
handlers = run_manager.handlers if run_manager else []
|
||||
return any(isinstance(h, _StreamingCallbackHandler) for h in handlers)
|
||||
|
||||
def _should_stream_v2(
|
||||
self,
|
||||
*,
|
||||
async_api: bool,
|
||||
run_manager: CallbackManagerForLLMRun
|
||||
| AsyncCallbackManagerForLLMRun
|
||||
| None = None,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
"""Determine whether an invoke should route through the v2 event path.
|
||||
|
||||
Runs alongside `_should_stream` inside `_generate_with_cache` /
|
||||
`_agenerate_with_cache` — after the run manager is open — and
|
||||
wins over the v1 streaming branch when a handler has declared
|
||||
itself a `_V2StreamingCallbackHandler`.
|
||||
|
||||
Args:
|
||||
async_api: Whether the caller is on the async path.
|
||||
run_manager: The active LLM run manager.
|
||||
**kwargs: Call kwargs; inspected for `disable_streaming`
|
||||
semantics and an explicit `stream=False` override.
|
||||
|
||||
Returns:
|
||||
`True` if any attached handler inherits
|
||||
`_V2StreamingCallbackHandler` and the model can drive the v2
|
||||
event generator (natively or via the `_stream` compat
|
||||
bridge).
|
||||
"""
|
||||
# v2 fallback bridges through `_stream` / `_astream`, so streaming
|
||||
# must be implemented for the requested flavor.
|
||||
sync_not_implemented = type(self)._stream == BaseChatModel._stream # noqa: SLF001
|
||||
async_not_implemented = type(self)._astream == BaseChatModel._astream # noqa: SLF001
|
||||
native_sync = getattr(type(self), "_stream_chat_model_events", None) is not None
|
||||
native_async = (
|
||||
getattr(type(self), "_astream_chat_model_events", None) is not None
|
||||
)
|
||||
if not async_api and not (native_sync or not sync_not_implemented):
|
||||
return False
|
||||
if async_api and not (
|
||||
native_async
|
||||
or native_sync
|
||||
or not async_not_implemented
|
||||
or not sync_not_implemented
|
||||
):
|
||||
return False
|
||||
|
||||
if self.disable_streaming is True:
|
||||
return False
|
||||
if self.disable_streaming == "tool_calling" and kwargs.get("tools"):
|
||||
return False
|
||||
if "stream" in kwargs and not kwargs["stream"]:
|
||||
return False
|
||||
|
||||
handlers = run_manager.handlers if run_manager else []
|
||||
return any(isinstance(h, _V2StreamingCallbackHandler) for h in handlers)
|
||||
|
||||
def _iter_v2_events(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
*,
|
||||
run_manager: CallbackManagerForLLMRun,
|
||||
stream: ChatModelStream,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[MessagesData]:
|
||||
"""Drive the v2 event generator with per-event dispatch.
|
||||
|
||||
Shared between `stream_v2`'s pump and the invoke-time v2 branch
|
||||
in `_generate_with_cache`. Picks the native
|
||||
`_stream_chat_model_events` hook when the subclass provides one,
|
||||
else bridges `_stream` chunks via `chunks_to_events`. Each event
|
||||
is dispatched into `stream` and fired as `on_stream_event` on
|
||||
the run manager. Run-lifecycle callbacks
|
||||
(`on_chat_model_start` / `on_llm_end` / `on_llm_error`) and
|
||||
rate-limiter acquisition are the caller's responsibility.
|
||||
|
||||
Args:
|
||||
messages: Normalized input messages.
|
||||
run_manager: Active LLM run manager; receives
|
||||
`on_stream_event` per event.
|
||||
stream: Accumulator owned by the caller; receives each
|
||||
event via `stream.dispatch`.
|
||||
stop: Optional stop sequences.
|
||||
**kwargs: Forwarded to the event producer.
|
||||
|
||||
Yields:
|
||||
Each protocol event produced by the model.
|
||||
"""
|
||||
native = cast(
|
||||
"Callable[..., Iterator[MessagesData]] | None",
|
||||
getattr(self, "_stream_chat_model_events", None),
|
||||
)
|
||||
if native is not None:
|
||||
event_iter: Iterator[MessagesData] = native(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
else:
|
||||
event_iter = chunks_to_events(
|
||||
self._stream(messages, stop=stop, run_manager=run_manager, **kwargs),
|
||||
message_id=stream.message_id,
|
||||
)
|
||||
for event in event_iter:
|
||||
stream.dispatch(event)
|
||||
run_manager.on_stream_event(event)
|
||||
yield event
|
||||
|
||||
async def _aiter_v2_events(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForLLMRun,
|
||||
stream: AsyncChatModelStream,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[MessagesData]:
|
||||
"""Async counterpart to :meth:`_iter_v2_events`.
|
||||
|
||||
See :meth:`_iter_v2_events` for the shared contract.
|
||||
"""
|
||||
native = cast(
|
||||
"Callable[..., AsyncIterator[MessagesData]] | None",
|
||||
getattr(self, "_astream_chat_model_events", None),
|
||||
)
|
||||
if native is not None:
|
||||
event_iter: AsyncIterator[MessagesData] = native(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
else:
|
||||
event_iter = achunks_to_events(
|
||||
self._astream(messages, stop=stop, run_manager=run_manager, **kwargs),
|
||||
message_id=stream.message_id,
|
||||
)
|
||||
async for event in event_iter:
|
||||
stream.dispatch(event)
|
||||
await run_manager.on_stream_event(event)
|
||||
yield event
|
||||
|
||||
@override
|
||||
def stream(
|
||||
self,
|
||||
@@ -784,6 +934,198 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
LLMResult(generations=[[generation]]),
|
||||
)
|
||||
|
||||
# --- stream_v2 / astream_v2 ---
|
||||
|
||||
def stream_v2(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatModelStream:
|
||||
"""Stream content-block lifecycle events for a single model call.
|
||||
|
||||
Returns a :class:`ChatModelStream` with typed projections
|
||||
(`.text`, `.reasoning`, `.tool_calls`, `.usage`,
|
||||
`.output`).
|
||||
|
||||
.. warning::
|
||||
This API is experimental and may change.
|
||||
|
||||
Args:
|
||||
input: The model input.
|
||||
config: Optional runnable config.
|
||||
stop: Optional list of stop words.
|
||||
**kwargs: Additional keyword arguments passed to the model.
|
||||
|
||||
Returns:
|
||||
A :class:`ChatModelStream` with typed projections.
|
||||
"""
|
||||
config = ensure_config(config)
|
||||
messages = self._convert_input(input).to_messages()
|
||||
input_messages = _normalize_messages(messages)
|
||||
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
options = {"stop": stop, **kwargs}
|
||||
inheritable_metadata = {
|
||||
**(config.get("metadata") or {}),
|
||||
**self._get_ls_params_with_defaults(stop=stop, **kwargs),
|
||||
}
|
||||
callback_manager = CallbackManager.configure(
|
||||
config.get("callbacks"),
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
config.get("tags"),
|
||||
self.tags,
|
||||
inheritable_metadata,
|
||||
self.metadata,
|
||||
)
|
||||
(run_manager,) = callback_manager.on_chat_model_start(
|
||||
self._serialized,
|
||||
[_format_for_tracing(messages)],
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
batch_size=1,
|
||||
)
|
||||
|
||||
run_id = "-".join((LC_ID_PREFIX, str(run_manager.run_id)))
|
||||
stream = ChatModelStream(message_id=run_id)
|
||||
|
||||
event_iter_ref = iter(
|
||||
self._iter_v2_events(
|
||||
input_messages,
|
||||
run_manager=run_manager,
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
rate_limiter_acquired = self.rate_limiter is None
|
||||
|
||||
def pump_one() -> bool:
|
||||
nonlocal rate_limiter_acquired
|
||||
if not rate_limiter_acquired:
|
||||
assert self.rate_limiter is not None # noqa: S101
|
||||
self.rate_limiter.acquire(blocking=True)
|
||||
rate_limiter_acquired = True
|
||||
try:
|
||||
next(event_iter_ref)
|
||||
except StopIteration:
|
||||
return False
|
||||
except BaseException as exc:
|
||||
stream.fail(exc)
|
||||
run_manager.on_llm_error(
|
||||
exc,
|
||||
response=LLMResult(generations=[]),
|
||||
)
|
||||
return False
|
||||
if stream.done and stream.output_message is not None:
|
||||
run_manager.on_llm_end(
|
||||
LLMResult(
|
||||
generations=[
|
||||
[ChatGeneration(message=stream.output_message)],
|
||||
],
|
||||
),
|
||||
)
|
||||
return True
|
||||
|
||||
stream.bind_pump(pump_one)
|
||||
return stream
|
||||
|
||||
async def astream_v2(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncChatModelStream:
|
||||
"""Async variant of :meth:`stream_v2`.
|
||||
|
||||
Returns an :class:`AsyncChatModelStream` whose projections are
|
||||
async-iterable and awaitable.
|
||||
|
||||
.. warning::
|
||||
This API is experimental and may change.
|
||||
|
||||
Args:
|
||||
input: The model input.
|
||||
config: Optional runnable config.
|
||||
stop: Optional list of stop words.
|
||||
**kwargs: Additional keyword arguments passed to the model.
|
||||
|
||||
Returns:
|
||||
An :class:`AsyncChatModelStream` with typed projections.
|
||||
"""
|
||||
config = ensure_config(config)
|
||||
messages = self._convert_input(input).to_messages()
|
||||
input_messages = _normalize_messages(messages)
|
||||
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
options = {"stop": stop, **kwargs}
|
||||
inheritable_metadata = {
|
||||
**(config.get("metadata") or {}),
|
||||
**self._get_ls_params_with_defaults(stop=stop, **kwargs),
|
||||
}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
config.get("callbacks"),
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
config.get("tags"),
|
||||
self.tags,
|
||||
inheritable_metadata,
|
||||
self.metadata,
|
||||
)
|
||||
(run_manager,) = await callback_manager.on_chat_model_start(
|
||||
self._serialized,
|
||||
[_format_for_tracing(messages)],
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
batch_size=1,
|
||||
)
|
||||
|
||||
run_id = "-".join((LC_ID_PREFIX, str(run_manager.run_id)))
|
||||
stream = AsyncChatModelStream(message_id=run_id)
|
||||
|
||||
async def _produce() -> None:
|
||||
try:
|
||||
if self.rate_limiter:
|
||||
await self.rate_limiter.aacquire(blocking=True)
|
||||
|
||||
async for _event in self._aiter_v2_events(
|
||||
input_messages,
|
||||
run_manager=run_manager,
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
**kwargs,
|
||||
):
|
||||
pass
|
||||
if stream.done and stream.output_message is not None:
|
||||
await run_manager.on_llm_end(
|
||||
LLMResult(
|
||||
generations=[
|
||||
[ChatGeneration(message=stream.output_message)],
|
||||
],
|
||||
),
|
||||
)
|
||||
except asyncio.CancelledError as exc:
|
||||
stream.fail(exc)
|
||||
raise
|
||||
except BaseException as exc:
|
||||
stream.fail(exc)
|
||||
await run_manager.on_llm_error(
|
||||
exc,
|
||||
response=LLMResult(generations=[]),
|
||||
)
|
||||
|
||||
stream._producer_task = asyncio.get_running_loop().create_task(_produce()) # noqa: SLF001
|
||||
return stream
|
||||
|
||||
# --- Custom methods ---
|
||||
|
||||
def _combine_llm_outputs(self, _llm_outputs: list[dict | None], /) -> dict:
|
||||
@@ -1237,9 +1579,39 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
if self.rate_limiter:
|
||||
self.rate_limiter.acquire(blocking=True)
|
||||
|
||||
# v2 streaming: preferred over v1 when any attached handler opts in via
|
||||
# `_V2StreamingCallbackHandler`. Drives the protocol event generator
|
||||
# (native or `_stream` compat bridge) through the shared helper so
|
||||
# `on_stream_event` fires per event, then returns a normal `ChatResult`
|
||||
# so caching / `on_llm_end` stay on the existing generate path.
|
||||
if self._should_stream_v2(
|
||||
async_api=False,
|
||||
run_manager=run_manager,
|
||||
**kwargs,
|
||||
):
|
||||
stream_accum = ChatModelStream(
|
||||
message_id=(
|
||||
f"{LC_ID_PREFIX}-{run_manager.run_id}" if run_manager else None
|
||||
)
|
||||
)
|
||||
assert run_manager is not None # noqa: S101
|
||||
for _event in self._iter_v2_events(
|
||||
messages,
|
||||
run_manager=run_manager,
|
||||
stream=stream_accum,
|
||||
stop=stop,
|
||||
**kwargs,
|
||||
):
|
||||
pass
|
||||
if stream_accum.output_message is None:
|
||||
msg = "v2 stream finished without producing a message"
|
||||
raise RuntimeError(msg)
|
||||
result = ChatResult(
|
||||
generations=[ChatGeneration(message=stream_accum.output_message)]
|
||||
)
|
||||
# If stream is not explicitly set, check if implicitly requested by
|
||||
# astream_events() or astream_log(). Bail out if _stream not implemented
|
||||
if self._should_stream(
|
||||
elif self._should_stream(
|
||||
async_api=False,
|
||||
run_manager=run_manager,
|
||||
**kwargs,
|
||||
@@ -1363,9 +1735,35 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
if self.rate_limiter:
|
||||
await self.rate_limiter.aacquire(blocking=True)
|
||||
|
||||
# v2 streaming: see sync counterpart in `_generate_with_cache`.
|
||||
if self._should_stream_v2(
|
||||
async_api=True,
|
||||
run_manager=run_manager,
|
||||
**kwargs,
|
||||
):
|
||||
stream_accum = AsyncChatModelStream(
|
||||
message_id=(
|
||||
f"{LC_ID_PREFIX}-{run_manager.run_id}" if run_manager else None
|
||||
)
|
||||
)
|
||||
assert run_manager is not None # noqa: S101
|
||||
async for _event in self._aiter_v2_events(
|
||||
messages,
|
||||
run_manager=run_manager,
|
||||
stream=stream_accum,
|
||||
stop=stop,
|
||||
**kwargs,
|
||||
):
|
||||
pass
|
||||
if stream_accum.output_message is None:
|
||||
msg = "v2 stream finished without producing a message"
|
||||
raise RuntimeError(msg)
|
||||
result = ChatResult(
|
||||
generations=[ChatGeneration(message=stream_accum.output_message)]
|
||||
)
|
||||
# If stream is not explicitly set, check if implicitly requested by
|
||||
# astream_events() or astream_log(). Bail out if _astream not implemented
|
||||
if self._should_stream(
|
||||
elif self._should_stream(
|
||||
async_api=True,
|
||||
run_manager=run_manager,
|
||||
**kwargs,
|
||||
|
||||
@@ -5889,6 +5889,41 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[
|
||||
):
|
||||
yield item
|
||||
|
||||
def stream_v2(
|
||||
self,
|
||||
input: Input,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Any:
|
||||
"""Forward `stream_v2` to the bound runnable with bound kwargs merged.
|
||||
|
||||
Chat-model-specific: the bound runnable must implement `stream_v2`
|
||||
(see `BaseChatModel`). Without this override, `__getattr__` would
|
||||
forward the call but drop `self.kwargs` — losing tools bound via
|
||||
`bind_tools`, `stop` sequences, etc.
|
||||
"""
|
||||
return self.bound.stream_v2( # type: ignore[attr-defined]
|
||||
input,
|
||||
self._merge_configs(config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
async def astream_v2(
|
||||
self,
|
||||
input: Input,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Any:
|
||||
"""Forward `astream_v2` to the bound runnable with bound kwargs merged.
|
||||
|
||||
Async variant of `stream_v2`. See that method for the full rationale.
|
||||
"""
|
||||
return await self.bound.astream_v2( # type: ignore[attr-defined]
|
||||
input,
|
||||
self._merge_configs(config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
@override
|
||||
async def astream_events(
|
||||
self,
|
||||
|
||||
@@ -28,6 +28,25 @@ class _StreamingCallbackHandler(typing.Protocol[T]):
|
||||
"""Used for internal astream_log and astream events implementations."""
|
||||
|
||||
|
||||
# THIS IS USED IN LANGGRAPH.
|
||||
class _V2StreamingCallbackHandler:
|
||||
"""Marker base class for handlers that consume `on_stream_event` (v2).
|
||||
|
||||
A handler inheriting from this class signals that it wants content-
|
||||
block lifecycle events from `stream_v2` / `astream_v2` rather than
|
||||
the v1 `on_llm_new_token` chunks. `BaseChatModel.invoke` uses
|
||||
`isinstance(handler, _V2StreamingCallbackHandler)` to decide whether
|
||||
to route an invoke through the v2 event generator.
|
||||
|
||||
Implemented as a concrete marker class (not a `Protocol`) so opt-in
|
||||
is explicit via inheritance. An empty `runtime_checkable` Protocol
|
||||
would match every object and misroute every call. The event
|
||||
delivery contract itself lives on
|
||||
`BaseCallbackHandler.on_stream_event`.
|
||||
"""
|
||||
|
||||
|
||||
__all__ = [
|
||||
"_StreamingCallbackHandler",
|
||||
"_V2StreamingCallbackHandler",
|
||||
]
|
||||
|
||||
@@ -32,6 +32,7 @@ dependencies = [
|
||||
"packaging>=23.2.0",
|
||||
"pydantic>=2.7.4,<3.0.0",
|
||||
"uuid-utils>=0.12.0,<1.0",
|
||||
"langchain-protocol>=0.0.8",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@@ -0,0 +1,934 @@
|
||||
"""Tests for ChatModelStream, AsyncChatModelStream, and projections."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.language_models.chat_model_stream import (
|
||||
AsyncChatModelStream,
|
||||
AsyncProjection,
|
||||
ChatModelStream,
|
||||
SyncProjection,
|
||||
SyncTextProjection,
|
||||
dispatch_event,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_protocol.protocol import ContentBlockFinishData, MessagesData
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Projection unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSyncProjection:
|
||||
"""Test SyncProjection push/pull mechanics."""
|
||||
|
||||
def test_push_and_iterate(self) -> None:
|
||||
proj = SyncProjection()
|
||||
proj.push("a")
|
||||
proj.push("b")
|
||||
proj.complete(["a", "b"])
|
||||
assert list(proj) == ["a", "b"]
|
||||
|
||||
def test_get_returns_final_value(self) -> None:
|
||||
proj = SyncProjection()
|
||||
proj.push("x")
|
||||
proj.complete("final")
|
||||
assert proj.get() == "final"
|
||||
|
||||
def test_request_more_pulls(self) -> None:
|
||||
proj = SyncProjection()
|
||||
calls = iter(["a", "b", None])
|
||||
|
||||
def pump() -> bool:
|
||||
val = next(calls)
|
||||
if val is None:
|
||||
proj.complete("ab")
|
||||
return True
|
||||
proj.push(val)
|
||||
return True
|
||||
|
||||
proj._request_more = pump
|
||||
assert list(proj) == ["a", "b"]
|
||||
assert proj.get() == "ab"
|
||||
|
||||
def test_error_propagation(self) -> None:
|
||||
proj = SyncProjection()
|
||||
proj.push("partial")
|
||||
proj.fail(ValueError("boom"))
|
||||
with pytest.raises(ValueError, match="boom"):
|
||||
list(proj)
|
||||
|
||||
def test_error_on_get(self) -> None:
|
||||
proj = SyncProjection()
|
||||
proj.fail(ValueError("boom"))
|
||||
with pytest.raises(ValueError, match="boom"):
|
||||
proj.get()
|
||||
|
||||
def test_multi_cursor_replay(self) -> None:
|
||||
proj = SyncProjection()
|
||||
proj.push("a")
|
||||
proj.push("b")
|
||||
proj.complete(None)
|
||||
assert list(proj) == ["a", "b"]
|
||||
assert list(proj) == ["a", "b"] # Second iteration replays
|
||||
|
||||
def test_empty_projection(self) -> None:
|
||||
proj = SyncProjection()
|
||||
proj.complete([])
|
||||
assert list(proj) == []
|
||||
assert proj.get() == []
|
||||
|
||||
|
||||
class TestSyncTextProjection:
|
||||
"""Test SyncTextProjection string convenience methods."""
|
||||
|
||||
def test_str_drains(self) -> None:
|
||||
proj = SyncTextProjection()
|
||||
proj.push("Hello")
|
||||
proj.push(" world")
|
||||
proj.complete("Hello world")
|
||||
assert str(proj) == "Hello world"
|
||||
|
||||
def test_str_with_pump(self) -> None:
|
||||
proj = SyncTextProjection()
|
||||
done = False
|
||||
|
||||
def pump() -> bool:
|
||||
nonlocal done
|
||||
if not done:
|
||||
proj.push("Hi")
|
||||
proj.complete("Hi")
|
||||
done = True
|
||||
return True
|
||||
return False
|
||||
|
||||
proj._request_more = pump
|
||||
assert str(proj) == "Hi"
|
||||
|
||||
def test_bool_nonempty(self) -> None:
|
||||
proj = SyncTextProjection()
|
||||
assert not proj
|
||||
proj.push("x")
|
||||
assert proj
|
||||
|
||||
def test_repr(self) -> None:
|
||||
proj = SyncTextProjection()
|
||||
proj.push("hello")
|
||||
assert repr(proj) == "'hello'"
|
||||
proj.complete("hello")
|
||||
assert repr(proj) == "'hello'"
|
||||
|
||||
|
||||
class TestAsyncProjection:
|
||||
"""Test AsyncProjection async iteration and awaiting."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_await_final_value(self) -> None:
|
||||
proj = AsyncProjection()
|
||||
proj.push("a")
|
||||
proj.complete("final")
|
||||
assert await proj == "final"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_iter(self) -> None:
|
||||
proj = AsyncProjection()
|
||||
|
||||
async def produce() -> None:
|
||||
await asyncio.sleep(0)
|
||||
proj.push("x")
|
||||
await asyncio.sleep(0)
|
||||
proj.push("y")
|
||||
await asyncio.sleep(0)
|
||||
proj.complete("xy")
|
||||
|
||||
asyncio.get_running_loop().create_task(produce())
|
||||
deltas = [d async for d in proj]
|
||||
assert deltas == ["x", "y"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_on_await(self) -> None:
|
||||
proj = AsyncProjection()
|
||||
proj.fail(ValueError("async boom"))
|
||||
with pytest.raises(ValueError, match="async boom"):
|
||||
await proj
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_on_iter(self) -> None:
|
||||
proj = AsyncProjection()
|
||||
proj.push("partial")
|
||||
proj.fail(ValueError("mid-stream"))
|
||||
with pytest.raises(ValueError, match="mid-stream"):
|
||||
async for _ in proj:
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arequest_more_drives_iteration(self) -> None:
|
||||
"""Cursor drives the async pump when the buffer is empty."""
|
||||
proj = AsyncProjection()
|
||||
deltas = iter(["a", "b", "c"])
|
||||
|
||||
async def pump() -> bool:
|
||||
try:
|
||||
proj.push(next(deltas))
|
||||
except StopIteration:
|
||||
proj.complete("abc")
|
||||
return False
|
||||
return True
|
||||
|
||||
proj.set_arequest_more(pump)
|
||||
collected = [d async for d in proj]
|
||||
assert collected == ["a", "b", "c"]
|
||||
assert await proj == "abc"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arequest_more_drives_await(self) -> None:
|
||||
"""`await projection` drives the pump too, not just iteration."""
|
||||
proj = AsyncProjection()
|
||||
steps = iter([("push", "x"), ("push", "y"), ("complete", "xy")])
|
||||
|
||||
async def pump() -> bool:
|
||||
try:
|
||||
action, value = next(steps)
|
||||
except StopIteration:
|
||||
return False
|
||||
if action == "push":
|
||||
proj.push(value)
|
||||
else:
|
||||
proj.complete(value)
|
||||
return True
|
||||
|
||||
proj.set_arequest_more(pump)
|
||||
assert await proj == "xy"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arequest_more_stops_when_pump_exhausts(self) -> None:
|
||||
"""Pump returning False without completing ends iteration cleanly."""
|
||||
proj = AsyncProjection()
|
||||
pushed = [False]
|
||||
|
||||
async def pump() -> bool:
|
||||
if not pushed[0]:
|
||||
proj.push("only")
|
||||
pushed[0] = True
|
||||
return True
|
||||
return False
|
||||
|
||||
proj.set_arequest_more(pump)
|
||||
collected = [d async for d in proj]
|
||||
assert collected == ["only"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_chat_model_stream_set_arequest_more_fans_out(self) -> None:
|
||||
"""`set_arequest_more` wires every projection on AsyncChatModelStream."""
|
||||
stream = AsyncChatModelStream(message_id="m1")
|
||||
|
||||
async def pump() -> bool:
|
||||
return False
|
||||
|
||||
stream.set_arequest_more(pump)
|
||||
for proj in (
|
||||
stream._text_proj,
|
||||
stream._reasoning_proj,
|
||||
stream._tool_calls_proj,
|
||||
stream._usage_proj,
|
||||
stream._output_proj,
|
||||
stream._events_proj,
|
||||
):
|
||||
assert cast("AsyncProjection", proj)._arequest_more is pump
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_text_and_output_share_pump(self) -> None:
|
||||
"""Concurrent `stream.text` + `await stream.output` both drive the pump."""
|
||||
stream = AsyncChatModelStream(message_id="m1")
|
||||
|
||||
events: list[MessagesData] = [
|
||||
{
|
||||
"event": "message-start",
|
||||
"message_id": "m1",
|
||||
"metadata": {"provider": "test", "model": "fake"},
|
||||
},
|
||||
{
|
||||
"event": "content-block-delta",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": "hello "},
|
||||
},
|
||||
{
|
||||
"event": "content-block-delta",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": "world"},
|
||||
},
|
||||
{
|
||||
"event": "content-block-finish",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": "hello world"},
|
||||
},
|
||||
{"event": "message-finish", "reason": "stop"},
|
||||
]
|
||||
cursor = iter(events)
|
||||
pump_lock = asyncio.Lock()
|
||||
|
||||
async def pump() -> bool:
|
||||
async with pump_lock:
|
||||
try:
|
||||
evt = next(cursor)
|
||||
except StopIteration:
|
||||
return False
|
||||
stream.dispatch(evt)
|
||||
return True
|
||||
|
||||
stream.set_arequest_more(pump)
|
||||
|
||||
async def drain_text() -> str:
|
||||
buf = [delta async for delta in stream.text]
|
||||
return "".join(buf)
|
||||
|
||||
text, message = await asyncio.gather(drain_text(), stream.output)
|
||||
assert text == "hello world"
|
||||
assert message.content == "hello world"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChatModelStream unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestChatModelStream:
|
||||
"""Test sync ChatModelStream with dispatch_event."""
|
||||
|
||||
def test_text_projection_cached(self) -> None:
|
||||
stream = ChatModelStream()
|
||||
assert stream.text is stream.text
|
||||
|
||||
def test_reasoning_projection_cached(self) -> None:
|
||||
stream = ChatModelStream()
|
||||
assert stream.reasoning is stream.reasoning
|
||||
|
||||
def test_tool_calls_projection_cached(self) -> None:
|
||||
stream = ChatModelStream()
|
||||
assert stream.tool_calls is stream.tool_calls
|
||||
|
||||
def test_text_deltas_via_pump(self) -> None:
|
||||
stream = ChatModelStream()
|
||||
events: list[MessagesData] = [
|
||||
{"event": "message-start", "role": "ai"},
|
||||
{
|
||||
"event": "content-block-delta",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": "Hi"},
|
||||
},
|
||||
{
|
||||
"event": "content-block-delta",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": " there"},
|
||||
},
|
||||
{
|
||||
"event": "content-block-finish",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": "Hi there"},
|
||||
},
|
||||
{"event": "message-finish", "reason": "stop"},
|
||||
]
|
||||
idx = 0
|
||||
|
||||
def pump() -> bool:
|
||||
nonlocal idx
|
||||
if idx >= len(events):
|
||||
return False
|
||||
dispatch_event(events[idx], stream)
|
||||
idx += 1
|
||||
return True
|
||||
|
||||
stream.bind_pump(pump)
|
||||
assert list(stream.text) == ["Hi", " there"]
|
||||
assert str(stream.text) == "Hi there"
|
||||
|
||||
def test_tool_call_chunk_streaming(self) -> None:
|
||||
stream = ChatModelStream()
|
||||
dispatch_event({"event": "message-start", "role": "ai"}, stream)
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-delta",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": "tool_call_chunk",
|
||||
"id": "tc1",
|
||||
"name": "search",
|
||||
"args": '{"q":',
|
||||
"index": 0,
|
||||
},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-delta",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": "tool_call_chunk",
|
||||
"args": '"test"}',
|
||||
"index": 0,
|
||||
},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-finish",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": "tool_call",
|
||||
"id": "tc1",
|
||||
"name": "search",
|
||||
"args": {"q": "test"},
|
||||
},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event({"event": "message-finish", "reason": "tool_use"}, stream)
|
||||
|
||||
# Check chunk deltas were pushed
|
||||
chunks = list(stream.tool_calls)
|
||||
assert len(chunks) == 2 # two chunk deltas
|
||||
assert chunks[0]["type"] == "tool_call_chunk"
|
||||
assert chunks[0]["name"] == "search"
|
||||
|
||||
# Check finalized tool calls
|
||||
finalized = stream.tool_calls.get()
|
||||
assert len(finalized) == 1
|
||||
assert finalized[0]["name"] == "search"
|
||||
assert finalized[0]["args"] == {"q": "test"}
|
||||
|
||||
def test_multi_tool_parallel(self) -> None:
|
||||
stream = ChatModelStream()
|
||||
dispatch_event({"event": "message-start", "role": "ai"}, stream)
|
||||
# Tool 1 starts
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-delta",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": "tool_call_chunk",
|
||||
"id": "t1",
|
||||
"name": "foo",
|
||||
"args": '{"a":',
|
||||
"index": 0,
|
||||
},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
# Tool 2 starts
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-delta",
|
||||
"index": 1,
|
||||
"content_block": {
|
||||
"type": "tool_call_chunk",
|
||||
"id": "t2",
|
||||
"name": "bar",
|
||||
"args": '{"b":',
|
||||
"index": 1,
|
||||
},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
# Tool 1 finishes
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-finish",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": "tool_call",
|
||||
"id": "t1",
|
||||
"name": "foo",
|
||||
"args": {"a": 1},
|
||||
},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
# Tool 2 finishes
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-finish",
|
||||
"index": 1,
|
||||
"content_block": {
|
||||
"type": "tool_call",
|
||||
"id": "t2",
|
||||
"name": "bar",
|
||||
"args": {"b": 2},
|
||||
},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event({"event": "message-finish", "reason": "tool_use"}, stream)
|
||||
|
||||
finalized = stream.tool_calls.get()
|
||||
assert len(finalized) == 2
|
||||
assert finalized[0]["name"] == "foo"
|
||||
assert finalized[1]["name"] == "bar"
|
||||
|
||||
def test_output_assembles_aimessage(self) -> None:
|
||||
stream = ChatModelStream(message_id="msg-1")
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "message-start",
|
||||
"role": "ai",
|
||||
"metadata": {"provider": "anthropic", "model": "claude-4"},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-delta",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": "Hello"},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-finish",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": "Hello"},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "message-finish",
|
||||
"reason": "stop",
|
||||
"usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
|
||||
msg = stream.output
|
||||
assert msg.content == "Hello"
|
||||
assert msg.id == "msg-1"
|
||||
assert msg.response_metadata["finish_reason"] == "stop"
|
||||
assert msg.response_metadata["model_provider"] == "anthropic"
|
||||
assert msg.usage_metadata is not None
|
||||
assert msg.usage_metadata["input_tokens"] == 10
|
||||
|
||||
def test_error_propagates_to_projections(self) -> None:
|
||||
stream = ChatModelStream()
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-delta",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": "partial"},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
stream.fail(RuntimeError("connection lost"))
|
||||
|
||||
with pytest.raises(RuntimeError, match="connection lost"):
|
||||
str(stream.text)
|
||||
|
||||
with pytest.raises(RuntimeError, match="connection lost"):
|
||||
stream.tool_calls.get()
|
||||
|
||||
def test_raw_event_iteration(self) -> None:
|
||||
stream = ChatModelStream()
|
||||
dispatch_event({"event": "message-start", "role": "ai"}, stream)
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-delta",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": "hi"},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event({"event": "message-finish", "reason": "stop"}, stream)
|
||||
|
||||
events = list(stream)
|
||||
assert len(events) == 3
|
||||
assert events[0]["event"] == "message-start"
|
||||
assert events[2]["event"] == "message-finish"
|
||||
|
||||
def test_raw_event_multi_cursor(self) -> None:
|
||||
stream = ChatModelStream()
|
||||
dispatch_event({"event": "message-start", "role": "ai"}, stream)
|
||||
dispatch_event({"event": "message-finish", "reason": "stop"}, stream)
|
||||
|
||||
assert list(stream) == list(stream) # Replay
|
||||
|
||||
def test_invalid_tool_call_preserved_on_finish(self) -> None:
|
||||
"""An `invalid_tool_call` finish lands on `invalid_tool_calls`."""
|
||||
stream = ChatModelStream()
|
||||
dispatch_event({"event": "message-start", "role": "ai"}, stream)
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-finish",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": "invalid_tool_call",
|
||||
"id": "call_1",
|
||||
"name": "search",
|
||||
"args": '{"q": ', # malformed
|
||||
"error": "Failed to parse tool call arguments as JSON",
|
||||
},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event({"event": "message-finish", "reason": "stop"}, stream)
|
||||
|
||||
msg = stream.output
|
||||
assert msg.tool_calls == []
|
||||
assert len(msg.invalid_tool_calls) == 1
|
||||
assert msg.invalid_tool_calls[0]["name"] == "search"
|
||||
assert msg.invalid_tool_calls[0]["args"] == '{"q": '
|
||||
assert msg.invalid_tool_calls[0]["error"] == (
|
||||
"Failed to parse tool call arguments as JSON"
|
||||
)
|
||||
|
||||
def test_invalid_tool_call_survives_sweep(self) -> None:
|
||||
"""Regression: finish deletes stale chunk, sweep cannot revive it."""
|
||||
stream = ChatModelStream()
|
||||
dispatch_event({"event": "message-start", "role": "ai"}, stream)
|
||||
# Stream a tool_call_chunk with malformed JSON args
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-delta",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": "tool_call_chunk",
|
||||
"id": "call_1",
|
||||
"name": "search",
|
||||
"args": '{"q": ',
|
||||
"index": 0,
|
||||
},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
# Finish event declares the call invalid
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-finish",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": "invalid_tool_call",
|
||||
"id": "call_1",
|
||||
"name": "search",
|
||||
"args": '{"q": ',
|
||||
"error": "Failed to parse tool call arguments as JSON",
|
||||
},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event({"event": "message-finish", "reason": "stop"}, stream)
|
||||
|
||||
msg = stream.output
|
||||
# The sweep must NOT have revived the chunk as an empty-args tool_call.
|
||||
assert msg.tool_calls == []
|
||||
assert len(msg.invalid_tool_calls) == 1
|
||||
|
||||
def test_output_content_uses_protocol_tool_call_shape(self) -> None:
|
||||
"""`.output.content` must emit `type: tool_call`, not legacy tool_use."""
|
||||
stream = ChatModelStream()
|
||||
dispatch_event({"event": "message-start", "role": "ai"}, stream)
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-delta",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": "Let me search."},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-finish",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": "Let me search."},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-finish",
|
||||
"index": 1,
|
||||
"content_block": {
|
||||
"type": "tool_call",
|
||||
"id": "call_1",
|
||||
"name": "search",
|
||||
"args": {"q": "weather"},
|
||||
},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event({"event": "message-finish", "reason": "tool_use"}, stream)
|
||||
|
||||
msg = stream.output
|
||||
assert isinstance(msg.content, list)
|
||||
content = cast("list[dict[str, Any]]", msg.content)
|
||||
types = [b.get("type") for b in content]
|
||||
assert types == ["text", "tool_call"]
|
||||
tool_block = content[1]
|
||||
assert tool_block["name"] == "search"
|
||||
assert tool_block["args"] == {"q": "weather"}
|
||||
# Legacy shape fields must be absent
|
||||
assert "input" not in tool_block
|
||||
assert tool_block.get("type") != "tool_use"
|
||||
|
||||
def test_server_tool_call_finish_lands_in_output_content(self) -> None:
|
||||
"""Server-executed tool call finish events flow into .output.content."""
|
||||
stream = ChatModelStream()
|
||||
dispatch_event({"event": "message-start", "role": "ai"}, stream)
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-finish",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": "server_tool_call",
|
||||
"id": "srv_1",
|
||||
"name": "web_search",
|
||||
"args": {"q": "weather"},
|
||||
},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event(
|
||||
cast(
|
||||
"ContentBlockFinishData",
|
||||
{
|
||||
"event": "content-block-finish",
|
||||
"index": 1,
|
||||
"content_block": {
|
||||
"type": "server_tool_result",
|
||||
"tool_call_id": "srv_1",
|
||||
"status": "success",
|
||||
"output": "62F, clear",
|
||||
},
|
||||
},
|
||||
),
|
||||
stream,
|
||||
)
|
||||
dispatch_event({"event": "message-finish", "reason": "stop"}, stream)
|
||||
|
||||
msg = stream.output
|
||||
assert isinstance(msg.content, list)
|
||||
content = cast("list[dict[str, Any]]", msg.content)
|
||||
types = [b.get("type") for b in content]
|
||||
assert types == ["server_tool_call", "server_tool_result"]
|
||||
# Regular tool_calls projection must NOT include server-executed ones
|
||||
assert msg.tool_calls == []
|
||||
|
||||
def test_server_tool_call_chunk_sweep(self) -> None:
|
||||
"""Unfinished server_tool_call_chunks get swept to server_tool_call."""
|
||||
stream = ChatModelStream()
|
||||
dispatch_event({"event": "message-start", "role": "ai"}, stream)
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-delta",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": "server_tool_call_chunk",
|
||||
"id": "srv_1",
|
||||
"name": "web_search",
|
||||
"args": '{"q":',
|
||||
},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-delta",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": "server_tool_call_chunk",
|
||||
"args": ' "weather"}',
|
||||
},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event({"event": "message-finish", "reason": "stop"}, stream)
|
||||
|
||||
msg = stream.output
|
||||
assert isinstance(msg.content, list)
|
||||
content = cast("list[dict[str, Any]]", msg.content)
|
||||
assert content[0]["type"] == "server_tool_call"
|
||||
assert content[0]["args"] == {"q": "weather"}
|
||||
assert content[0]["name"] == "web_search"
|
||||
|
||||
def test_image_block_pass_through(self) -> None:
|
||||
"""An image block finished via the event stream reaches .output.content."""
|
||||
stream = ChatModelStream()
|
||||
dispatch_event({"event": "message-start", "role": "ai"}, stream)
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-finish",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": "image",
|
||||
"url": "https://example.com/cat.png",
|
||||
"mime_type": "image/png",
|
||||
},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event({"event": "message-finish", "reason": "stop"}, stream)
|
||||
|
||||
msg = stream.output
|
||||
assert isinstance(msg.content, list)
|
||||
assert msg.content[0] == {
|
||||
"type": "image",
|
||||
"url": "https://example.com/cat.png",
|
||||
"mime_type": "image/png",
|
||||
}
|
||||
|
||||
def test_sweep_of_unfinished_malformed_chunk_produces_invalid_tool_call(
|
||||
self,
|
||||
) -> None:
|
||||
"""Unfinished chunk with malformed JSON sweeps to invalid_tool_call."""
|
||||
stream = ChatModelStream()
|
||||
dispatch_event({"event": "message-start", "role": "ai"}, stream)
|
||||
dispatch_event(
|
||||
{
|
||||
"event": "content-block-delta",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": "tool_call_chunk",
|
||||
"id": "call_1",
|
||||
"name": "search",
|
||||
"args": '{"q": ', # malformed, never completed
|
||||
"index": 0,
|
||||
},
|
||||
},
|
||||
stream,
|
||||
)
|
||||
dispatch_event({"event": "message-finish", "reason": "stop"}, stream)
|
||||
|
||||
msg = stream.output
|
||||
assert msg.tool_calls == []
|
||||
assert len(msg.invalid_tool_calls) == 1
|
||||
itc = msg.invalid_tool_calls[0]
|
||||
assert itc["name"] == "search"
|
||||
assert itc["args"] == '{"q": '
|
||||
assert "Failed to parse" in (itc["error"] or "")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AsyncChatModelStream unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAsyncChatModelStream:
    """Test async ChatModelStream.

    Producers run as concurrent tasks on the running loop; each
    `await asyncio.sleep(0)` yields control so the consuming side can
    interleave with event dispatch. The interleaving is part of what is
    under test — do not reorder the awaits.
    """

    @pytest.mark.asyncio
    async def test_await_output(self) -> None:
        # Awaiting the stream itself must resolve to the assembled message.
        stream = AsyncChatModelStream(message_id="m1")

        async def produce() -> None:
            # Yield once so the consumer is already awaiting before events flow.
            await asyncio.sleep(0)
            dispatch_event({"event": "message-start", "role": "ai"}, stream)
            dispatch_event(
                {
                    "event": "content-block-delta",
                    "index": 0,
                    "content_block": {"type": "text", "text": "Hi"},
                },
                stream,
            )
            dispatch_event({"event": "message-finish", "reason": "stop"}, stream)

        asyncio.get_running_loop().create_task(produce())
        msg = await stream
        assert msg.content == "Hi"

    @pytest.mark.asyncio
    async def test_async_text_deltas(self) -> None:
        # Text deltas must arrive one-by-one while the producer interleaves.
        stream = AsyncChatModelStream()

        async def produce() -> None:
            await asyncio.sleep(0)
            dispatch_event({"event": "message-start", "role": "ai"}, stream)
            await asyncio.sleep(0)
            dispatch_event(
                {
                    "event": "content-block-delta",
                    "index": 0,
                    "content_block": {"type": "text", "text": "a"},
                },
                stream,
            )
            await asyncio.sleep(0)
            dispatch_event(
                {
                    "event": "content-block-delta",
                    "index": 0,
                    "content_block": {"type": "text", "text": "b"},
                },
                stream,
            )
            await asyncio.sleep(0)
            dispatch_event({"event": "message-finish", "reason": "stop"}, stream)

        asyncio.get_running_loop().create_task(produce())
        deltas = [d async for d in stream.text]
        assert deltas == ["a", "b"]

    @pytest.mark.asyncio
    async def test_await_tool_calls(self) -> None:
        # Events dispatched before awaiting must still be visible: the
        # projection is awaited only after the stream has fully finished.
        stream = AsyncChatModelStream()
        dispatch_event({"event": "message-start", "role": "ai"}, stream)
        dispatch_event(
            {
                "event": "content-block-delta",
                "index": 0,
                "content_block": {
                    "type": "tool_call_chunk",
                    "id": "tc1",
                    "name": "search",
                    "args": '{"q":"hi"}',
                    "index": 0,
                },
            },
            stream,
        )
        dispatch_event(
            {
                "event": "content-block-finish",
                "index": 0,
                "content_block": {
                    "type": "tool_call",
                    "id": "tc1",
                    "name": "search",
                    "args": {"q": "hi"},
                },
            },
            stream,
        )
        dispatch_event({"event": "message-finish", "reason": "tool_use"}, stream)

        result = await stream.tool_calls
        assert len(result) == 1
        assert result[0]["name"] == "search"

    @pytest.mark.asyncio
    async def test_async_raw_event_iteration(self) -> None:
        # Async iteration over the stream yields the raw protocol events.
        stream = AsyncChatModelStream()

        async def produce() -> None:
            await asyncio.sleep(0)
            dispatch_event({"event": "message-start", "role": "ai"}, stream)
            await asyncio.sleep(0)
            dispatch_event({"event": "message-finish", "reason": "stop"}, stream)

        asyncio.get_running_loop().create_task(produce())
        events = [e async for e in stream]
        assert len(events) == 2

    @pytest.mark.asyncio
    async def test_error_propagation(self) -> None:
        # A failed stream must raise from both the text projection and the
        # awaited stream itself, not swallow the error.
        stream = AsyncChatModelStream()
        stream.fail(RuntimeError("async fail"))

        with pytest.raises(RuntimeError, match="async fail"):
            await stream.text
        with pytest.raises(RuntimeError, match="async fail"):
            await stream
|
||||
@@ -0,0 +1,342 @@
|
||||
"""Tests for BaseChatModel.stream_v2() / astream_v2()."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_core.callbacks import AsyncCallbackHandler, BaseCallbackHandler
|
||||
from langchain_core.language_models.chat_model_stream import (
|
||||
AsyncChatModelStream,
|
||||
ChatModelStream,
|
||||
)
|
||||
from langchain_core.language_models.fake_chat_models import FakeListChatModel
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_protocol.protocol import MessagesData
|
||||
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
|
||||
class TestStreamV2Sync:
    """Exercise `BaseChatModel.stream_v2()` against `FakeListChatModel`."""

    def test_stream_text(self) -> None:
        chat = FakeListChatModel(responses=["Hello world!"])
        stream = chat.stream_v2("test")

        assert isinstance(stream, ChatModelStream)
        collected = "".join(stream.text)
        assert collected == "Hello world!"
        assert stream.done

    def test_stream_output(self) -> None:
        chat = FakeListChatModel(responses=["Hello!"])
        stream = chat.stream_v2("test")

        final = stream.output
        assert final.content == "Hello!"
        assert final.id is not None

    def test_stream_usage_none_for_fake(self) -> None:
        chat = FakeListChatModel(responses=["Hi"])
        stream = chat.stream_v2("test")
        list(stream.text)  # drain to completion
        assert stream.usage is None

    def test_stream_raw_events(self) -> None:
        chat = FakeListChatModel(responses=["ab"])
        stream = chat.stream_v2("test")

        kinds = [e.get("event") for e in stream]
        assert kinds[0] == "message-start"
        assert kinds[-1] == "message-finish"
        assert "content-block-delta" in kinds
|
||||
|
||||
|
||||
class TestAstreamV2:
    """Exercise `BaseChatModel.astream_v2()` against `FakeListChatModel`."""

    @pytest.mark.asyncio
    async def test_astream_text_await(self) -> None:
        chat = FakeListChatModel(responses=["Hello!"])
        stream = await chat.astream_v2("test")

        assert isinstance(stream, AsyncChatModelStream)
        collected = await stream.text
        assert collected == "Hello!"

    @pytest.mark.asyncio
    async def test_astream_text_deltas(self) -> None:
        chat = FakeListChatModel(responses=["Hi"])
        stream = await chat.astream_v2("test")

        pieces = [piece async for piece in stream.text]
        assert "".join(pieces) == "Hi"

    @pytest.mark.asyncio
    async def test_astream_await_output(self) -> None:
        chat = FakeListChatModel(responses=["Hey"])
        stream = await chat.astream_v2("test")

        final = await stream
        assert final.content == "Hey"
|
||||
|
||||
|
||||
class _RecordingHandler(BaseCallbackHandler):
    """Sync callback handler that records lifecycle hook invocations.

    Records what fired, in what order, and with what payloads, so tests can
    assert on the recorded state afterwards instead of inside the hooks.
    """

    def __init__(self) -> None:
        # Ordered names of lifecycle hooks that have fired so far.
        self.events: list[str] = []
        # Every protocol event delivered to `on_stream_event`, in order.
        self.stream_events: list[MessagesData] = []
        # The LLMResult handed to the most recent `on_llm_end` call.
        self.last_llm_end_response: LLMResult | None = None

    def on_chat_model_start(self, *args: Any, **kwargs: Any) -> None:
        # Payloads are irrelevant here; only the hook name is recorded.
        del args, kwargs
        self.events.append("on_chat_model_start")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        del kwargs
        self.events.append("on_llm_end")
        # Kept so tests can inspect the assembled generations afterwards.
        self.last_llm_end_response = response

    def on_llm_error(self, *args: Any, **kwargs: Any) -> None:
        del args, kwargs
        self.events.append("on_llm_error")

    def on_stream_event(self, event: MessagesData, **kwargs: Any) -> None:
        # Stream events go to a separate list so lifecycle ordering
        # assertions on `self.events` are not interleaved with them.
        del kwargs
        self.stream_events.append(event)
|
||||
|
||||
|
||||
class _AsyncRecordingHandler(AsyncCallbackHandler):
    """Async callback handler that records lifecycle hook invocations.

    Async twin of `_RecordingHandler`: same recorded state, coroutine hooks.
    """

    def __init__(self) -> None:
        # Ordered names of lifecycle hooks that have fired so far.
        self.events: list[str] = []
        # Every protocol event delivered to `on_stream_event`, in order.
        self.stream_events: list[MessagesData] = []
        # The LLMResult handed to the most recent `on_llm_end` call.
        self.last_llm_end_response: LLMResult | None = None

    async def on_chat_model_start(self, *args: Any, **kwargs: Any) -> None:
        # Payloads are irrelevant here; only the hook name is recorded.
        del args, kwargs
        self.events.append("on_chat_model_start")

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        del kwargs
        self.events.append("on_llm_end")
        # Kept so tests can inspect the assembled generations afterwards.
        self.last_llm_end_response = response

    async def on_llm_error(self, *args: Any, **kwargs: Any) -> None:
        del args, kwargs
        self.events.append("on_llm_error")

    async def on_stream_event(self, event: MessagesData, **kwargs: Any) -> None:
        # Stream events go to a separate list so lifecycle ordering
        # assertions on `self.events` are not interleaved with them.
        del kwargs
        self.stream_events.append(event)
|
||||
|
||||
|
||||
class TestCallbacks:
    """Verify stream_v2 fires on_llm_end / on_llm_error callbacks."""

    def test_on_llm_end_fires_after_drain(self) -> None:
        recorder = _RecordingHandler()
        chat = FakeListChatModel(responses=["done"], callbacks=[recorder])
        stream = chat.stream_v2("test")
        list(stream.text)  # drain
        _ = stream.output

        assert "on_chat_model_start" in recorder.events
        assert "on_llm_end" in recorder.events
        start_at = recorder.events.index("on_chat_model_start")
        end_at = recorder.events.index("on_llm_end")
        assert end_at > start_at

    @pytest.mark.asyncio
    async def test_on_llm_end_fires_async(self) -> None:
        recorder = _AsyncRecordingHandler()
        chat = FakeListChatModel(responses=["done"], callbacks=[recorder])
        stream = await chat.astream_v2("test")
        _ = await stream

        assert "on_chat_model_start" in recorder.events
        assert "on_llm_end" in recorder.events

    def test_on_llm_end_receives_assembled_message(self) -> None:
        """The LLMResult passed to on_llm_end must carry the final message.

        Without this, LangSmith traces would see an empty generations list.
        """
        recorder = _RecordingHandler()
        chat = FakeListChatModel(responses=["hello"], callbacks=[recorder])
        stream = chat.stream_v2("test")
        _ = stream.output

        final = recorder.last_llm_end_response
        assert final is not None
        assert final.generations
        first = final.generations[0][0]
        assert isinstance(first, ChatGeneration)
        assert first.message.content == "hello"

    @pytest.mark.asyncio
    async def test_on_llm_end_receives_assembled_message_async(self) -> None:
        recorder = _AsyncRecordingHandler()
        chat = FakeListChatModel(responses=["hello"], callbacks=[recorder])
        stream = await chat.astream_v2("test")
        _ = await stream

        final = recorder.last_llm_end_response
        assert final is not None
        assert final.generations
        first = final.generations[0][0]
        assert isinstance(first, ChatGeneration)
        assert first.message.content == "hello"
|
||||
|
||||
|
||||
class TestOnStreamEvent:
    """`on_stream_event` must fire once per protocol event from stream_v2."""

    def test_on_stream_event_fires_for_every_event_sync(self) -> None:
        recorder = _RecordingHandler()
        chat = FakeListChatModel(responses=["Hi"], callbacks=[recorder])
        stream = chat.stream_v2("test")
        _ = stream.output

        # The observer must have seen exactly the events the stream replays.
        assert len(recorder.stream_events) == len(list(stream))
        kinds = [e["event"] for e in recorder.stream_events]
        assert kinds[0] == "message-start"
        assert kinds[-1] == "message-finish"
        assert "content-block-delta" in kinds

    @pytest.mark.asyncio
    async def test_on_stream_event_fires_for_every_event_async(self) -> None:
        recorder = _AsyncRecordingHandler()
        chat = FakeListChatModel(responses=["Hi"], callbacks=[recorder])
        stream = await chat.astream_v2("test")
        _ = await stream

        kinds = [e["event"] for e in recorder.stream_events]
        assert kinds[0] == "message-start"
        assert kinds[-1] == "message-finish"
        assert "content-block-delta" in kinds

    def test_on_stream_event_ordering_relative_to_lifecycle(self) -> None:
        """Stream events must all fire between on_chat_model_start and on_llm_end."""
        recorder = _RecordingHandler()
        chat = FakeListChatModel(responses=["Hi"], callbacks=[recorder])
        stream = chat.stream_v2("test")
        _ = stream.output

        # `on_stream_event` records into a separate list, so `events` holds
        # only the lifecycle hooks that bracket the run.
        assert recorder.events[0] == "on_chat_model_start"
        assert recorder.events[-1] == "on_llm_end"
        # And stream events were delivered somewhere inside that bracket.
        assert recorder.stream_events
|
||||
|
||||
|
||||
class TestCancellation:
    """Cancellation of `astream_v2` must propagate, not be swallowed."""

    @pytest.mark.asyncio
    async def test_astream_v2_cancellation_propagates(self) -> None:
        """Cancelling the producer task must raise CancelledError.

        Regression test: the producer's `except BaseException` previously
        swallowed `asyncio.CancelledError`, converting it into an
        `on_llm_error` + `stream._fail` pair that never propagated.
        """
        # `sleep` keeps the fake producer emitting long enough to cancel it
        # while it is genuinely mid-stream.
        model = FakeListChatModel(responses=["abcdefghij"], sleep=0.05)
        stream = await model.astream_v2("test")
        # Reaches into a private attribute deliberately: the test targets the
        # producer task itself, which is not part of the public API.
        task = stream._producer_task
        assert task is not None

        # Let the producer actually start before cancelling it.
        await asyncio.sleep(0.01)
        task.cancel()

        with pytest.raises(asyncio.CancelledError):
            await task
        # The stream must record the cancellation as its terminal error.
        assert isinstance(stream._error, asyncio.CancelledError)
|
||||
|
||||
|
||||
class _KwargRecordingModel(FakeListChatModel):
    """Fake model that records kwargs passed to `_stream` / `_astream`."""

    # One dict appended per `_stream`/`_astream` invocation; tests reset this
    # to `[]` before use so assertions only see their own call.
    received_kwargs: list[dict[str, Any]] = Field(default_factory=list)

    def _stream(
        self,
        messages: Any,
        stop: Any = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> Any:
        # Capture `stop` alongside the extra kwargs, then defer to the fake.
        self.received_kwargs.append({"stop": stop, **kwargs})
        return super()._stream(messages, stop=stop, run_manager=run_manager, **kwargs)

    async def _astream(
        self,
        messages: Any,
        stop: Any = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> Any:
        self.received_kwargs.append({"stop": stop, **kwargs})
        # Re-yield instead of returning: `_astream` must stay an async
        # generator for the streaming machinery to consume it.
        async for chunk in super()._astream(
            messages, stop=stop, run_manager=run_manager, **kwargs
        ):
            yield chunk
|
||||
|
||||
|
||||
class TestRunnableBindingForwarding:
    """`RunnableBinding.stream_v2` must merge bound kwargs into the call.

    Without the explicit override on `RunnableBinding`, `__getattr__`
    forwards the call but drops `self.kwargs` — so tools bound via
    `bind_tools`, stop sequences bound via `bind`, etc. would be silently
    ignored.
    """

    def test_bound_kwargs_reach_stream_v2(self) -> None:
        fake = _KwargRecordingModel(responses=["hi"])
        fake.received_kwargs = []
        bound = fake.bind(my_marker="sentinel-42")

        stream = bound.stream_v2("test")  # type: ignore[attr-defined]
        list(stream.text)  # drain

        assert len(fake.received_kwargs) == 1
        assert fake.received_kwargs[0].get("my_marker") == "sentinel-42"

    def test_call_kwargs_override_bound_kwargs(self) -> None:
        fake = _KwargRecordingModel(responses=["hi"])
        fake.received_kwargs = []
        bound = fake.bind(my_marker="from-bind")

        stream = bound.stream_v2("test", my_marker="from-call")  # type: ignore[attr-defined]
        list(stream.text)  # drain

        assert fake.received_kwargs[0].get("my_marker") == "from-call"

    @pytest.mark.asyncio
    async def test_bound_kwargs_reach_astream_v2(self) -> None:
        fake = _KwargRecordingModel(responses=["hi"])
        fake.received_kwargs = []
        bound = fake.bind(my_marker="sentinel-async")

        stream = await bound.astream_v2("test")  # type: ignore[attr-defined]
        _ = await stream

        assert len(fake.received_kwargs) == 1
        assert fake.received_kwargs[0].get("my_marker") == "sentinel-async"
|
||||
496
libs/core/tests/unit_tests/language_models/test_compat_bridge.py
Normal file
496
libs/core/tests/unit_tests/language_models/test_compat_bridge.py
Normal file
@@ -0,0 +1,496 @@
|
||||
"""Tests for the compat bridge (chunk-to-event conversion)."""
|
||||
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.language_models._compat_bridge import (
|
||||
CompatBlock,
|
||||
_finalize_block,
|
||||
_normalize_finish_reason,
|
||||
_to_protocol_usage,
|
||||
amessage_to_events,
|
||||
chunks_to_events,
|
||||
message_to_events,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.outputs import ChatGenerationChunk
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_protocol.protocol import (
|
||||
ContentBlockDeltaData,
|
||||
InvalidToolCallBlock,
|
||||
MessageFinishData,
|
||||
MessageStartData,
|
||||
ReasoningBlock,
|
||||
ServerToolCallBlock,
|
||||
TextBlock,
|
||||
ToolCallBlock,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_finalize_block_text_passes_through() -> None:
    """A text block comes out of `_finalize_block` unchanged."""
    raw: CompatBlock = {"type": "text", "text": "hello"}
    finalized = cast("TextBlock", _finalize_block(raw))
    assert finalized["type"] == "text"
    assert finalized["text"] == "hello"
|
||||
|
||||
|
||||
def test_finalize_block_tool_call_chunk_valid_json() -> None:
    """A chunk whose args parse as JSON finalizes to a full tool_call."""
    raw: CompatBlock = {
        "type": "tool_call_chunk",
        "args": '{"query": "test"}',
        "id": "tc1",
        "name": "search",
    }
    finalized = cast("ToolCallBlock", _finalize_block(raw))
    assert finalized["type"] == "tool_call"
    assert finalized["id"] == "tc1"
    assert finalized["name"] == "search"
    assert finalized["args"] == {"query": "test"}
|
||||
|
||||
|
||||
def test_finalize_block_tool_call_chunk_invalid_json() -> None:
    """A chunk whose args fail to parse finalizes to invalid_tool_call."""
    raw: CompatBlock = {
        "type": "tool_call_chunk",
        "args": "not json",
        "id": "tc1",
        "name": "search",
    }
    finalized = cast("InvalidToolCallBlock", _finalize_block(raw))
    assert finalized["type"] == "invalid_tool_call"
    assert finalized.get("error") is not None
|
||||
|
||||
|
||||
def test_finalize_block_server_tool_call_chunk_valid_json() -> None:
    """A server chunk with parseable args finalizes to server_tool_call."""
    raw: CompatBlock = {
        "type": "server_tool_call_chunk",
        "args": '{"q": "weather"}',
        "id": "srv_1",
        "name": "web_search",
    }
    finalized = cast("ServerToolCallBlock", _finalize_block(raw))
    assert finalized["type"] == "server_tool_call"
    assert finalized["id"] == "srv_1"
    assert finalized["name"] == "web_search"
    assert finalized["args"] == {"q": "weather"}
|
||||
|
||||
|
||||
def test_finalize_block_server_tool_call_chunk_invalid_json() -> None:
    """A server chunk with unparseable args finalizes to invalid_tool_call."""
    raw: CompatBlock = {
        "type": "server_tool_call_chunk",
        "args": "not json",
        "id": "srv_1",
        "name": "web_search",
    }
    finalized = cast("InvalidToolCallBlock", _finalize_block(raw))
    assert finalized["type"] == "invalid_tool_call"
    assert finalized.get("error") is not None
|
||||
|
||||
|
||||
def test_normalize_finish_reason() -> None:
    """Provider-specific finish reasons collapse onto the protocol vocabulary."""
    mappings = {
        "stop": "stop",
        "end_turn": "stop",
        "length": "length",
        "tool_use": "tool_use",
        "tool_calls": "tool_use",
        "content_filter": "content_filter",
    }
    for raw, normalized in mappings.items():
        assert _normalize_finish_reason(raw) == normalized
    # Missing reason defaults to a clean stop.
    assert _normalize_finish_reason(None) == "stop"
|
||||
|
||||
|
||||
def test_to_protocol_usage_present() -> None:
    """A populated usage dict converts with token counts preserved."""
    converted = _to_protocol_usage(
        {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}
    )
    assert converted is not None
    assert converted["input_tokens"] == 10
    assert converted["output_tokens"] == 20
|
||||
|
||||
|
||||
def test_to_protocol_usage_none() -> None:
    """Absent usage stays absent."""
    converted = _to_protocol_usage(None)
    assert converted is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# chunks_to_events: streaming lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_chunks_to_events_text_only() -> None:
    """A multi-chunk text stream yields the full event lifecycle."""
    chunks = [
        ChatGenerationChunk(message=AIMessageChunk(content=piece, id="msg-1"))
        for piece in ("Hello", " world")
    ]

    events = list(chunks_to_events(iter(chunks), message_id="msg-1"))
    kinds = [e["event"] for e in events]

    assert kinds[0] == "message-start"
    assert "content-block-start" in kinds
    assert kinds.count("content-block-delta") == 2
    assert "content-block-finish" in kinds
    assert kinds[-1] == "message-finish"

    last = cast("MessageFinishData", events[-1])
    assert last["reason"] == "stop"
|
||||
|
||||
|
||||
def test_chunks_to_events_empty_iterator() -> None:
    """An empty chunk iterator yields no events at all."""
    produced = list(chunks_to_events(iter([])))
    assert produced == []
|
||||
|
||||
|
||||
def test_chunks_to_events_tool_call_multichunk() -> None:
|
||||
"""Partial tool-call args across chunks finalize to a single tool_call."""
|
||||
chunks = [
|
||||
ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content="",
|
||||
id="msg-1",
|
||||
tool_call_chunks=[
|
||||
{
|
||||
"index": 0,
|
||||
"id": "tc1",
|
||||
"name": "search",
|
||||
"args": '{"q":',
|
||||
"type": "tool_call_chunk",
|
||||
}
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content="",
|
||||
id="msg-1",
|
||||
tool_call_chunks=[
|
||||
{
|
||||
"index": 0,
|
||||
"id": None,
|
||||
"name": None,
|
||||
"args": ' "test"}',
|
||||
"type": "tool_call_chunk",
|
||||
}
|
||||
],
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
events = list(chunks_to_events(iter(chunks), message_id="msg-1"))
|
||||
event_types = [e["event"] for e in events]
|
||||
|
||||
assert event_types[0] == "message-start"
|
||||
assert "content-block-start" in event_types
|
||||
assert "content-block-finish" in event_types
|
||||
assert event_types[-1] == "message-finish"
|
||||
|
||||
# Exactly one block finalized, args parsed to a dict.
|
||||
finish_events = [e for e in events if e["event"] == "content-block-finish"]
|
||||
assert len(finish_events) == 1
|
||||
finalized = cast("ToolCallBlock", finish_events[0]["content_block"])
|
||||
assert finalized["type"] == "tool_call"
|
||||
assert finalized["args"] == {"q": "test"}
|
||||
|
||||
# Valid tool_call at finish => finish_reason flips to tool_use.
|
||||
assert cast("MessageFinishData", events[-1])["reason"] == "tool_use"
|
||||
|
||||
|
||||
def test_chunks_to_events_invalid_tool_call_keeps_stop_reason() -> None:
|
||||
"""Malformed tool-args become invalid_tool_call; finish_reason stays `stop`."""
|
||||
chunks = [
|
||||
ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content="",
|
||||
id="msg-bad",
|
||||
tool_call_chunks=[
|
||||
{
|
||||
"index": 0,
|
||||
"id": "tc1",
|
||||
"name": "search",
|
||||
"args": "{oops",
|
||||
"type": "tool_call_chunk",
|
||||
},
|
||||
],
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
events = list(chunks_to_events(iter(chunks), message_id="msg-bad"))
|
||||
|
||||
finish_events = [e for e in events if e["event"] == "content-block-finish"]
|
||||
assert len(finish_events) == 1
|
||||
assert finish_events[0]["content_block"]["type"] == "invalid_tool_call"
|
||||
assert cast("MessageFinishData", events[-1])["reason"] == "stop"
|
||||
|
||||
|
||||
def test_chunks_to_events_anthropic_server_tool_use_routes_through_translator() -> None:
|
||||
"""`server_tool_use` shape + anthropic provider tag becomes `server_tool_call`."""
|
||||
chunks = [
|
||||
ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content=[
|
||||
{"type": "text", "text": "Let me search. "},
|
||||
{
|
||||
"type": "server_tool_use",
|
||||
"id": "srvtoolu_01",
|
||||
"name": "web_search",
|
||||
"input": {"query": "weather"},
|
||||
},
|
||||
],
|
||||
response_metadata={"model_provider": "anthropic"},
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
events = list(chunks_to_events(iter(chunks)))
|
||||
finish_blocks = [
|
||||
e["content_block"] for e in events if e["event"] == "content-block-finish"
|
||||
]
|
||||
block_types = [b.get("type") for b in finish_blocks]
|
||||
assert "server_tool_call" in block_types
|
||||
assert "text" in block_types
|
||||
|
||||
|
||||
def test_chunks_to_events_unregistered_provider_falls_back() -> None:
|
||||
"""Unknown provider tag doesn't crash; best-effort parsing surfaces text."""
|
||||
chunks = [
|
||||
ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content="Hello",
|
||||
response_metadata={"model_provider": "totally-made-up-provider"},
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
events = list(chunks_to_events(iter(chunks)))
|
||||
finish_events = [e for e in events if e["event"] == "content-block-finish"]
|
||||
assert [e["content_block"]["type"] for e in finish_events] == ["text"]
|
||||
|
||||
|
||||
def test_chunks_to_events_no_provider_text_plus_tool_call() -> None:
|
||||
"""Without a provider tag, text + tool_call_chunks both come through.
|
||||
|
||||
This is the case the old legacy path silently dropped the tool call
|
||||
because it re-mined tool_call_chunks on top of the positional index
|
||||
already used by the text block. Trusting content_blocks keeps them
|
||||
on distinct indices.
|
||||
"""
|
||||
chunks = [
|
||||
ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content="Hello",
|
||||
tool_call_chunks=[
|
||||
{
|
||||
"index": 1,
|
||||
"id": "t1",
|
||||
"name": "search",
|
||||
"args": '{"q": "x"}',
|
||||
"type": "tool_call_chunk",
|
||||
},
|
||||
],
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
events = list(chunks_to_events(iter(chunks)))
|
||||
finish_blocks = [
|
||||
e["content_block"] for e in events if e["event"] == "content-block-finish"
|
||||
]
|
||||
types = [b.get("type") for b in finish_blocks]
|
||||
assert "text" in types
|
||||
assert "tool_call" in types
|
||||
|
||||
|
||||
def test_chunks_to_events_reasoning_in_additional_kwargs() -> None:
|
||||
"""Reasoning packed into additional_kwargs surfaces as a reasoning block."""
|
||||
chunks = [
|
||||
ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content=[{"type": "text", "text": "2+2=4"}],
|
||||
additional_kwargs={"reasoning_content": "Adding two and two..."},
|
||||
response_metadata={"model_provider": "unknown-open-model"},
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
events = list(chunks_to_events(iter(chunks)))
|
||||
finish_blocks = [
|
||||
e["content_block"] for e in events if e["event"] == "content-block-finish"
|
||||
]
|
||||
types = [b.get("type") for b in finish_blocks]
|
||||
assert "reasoning" in types
|
||||
assert "text" in types
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# message_to_events: finalized-message replay
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_message_to_events_text_only() -> None:
|
||||
msg = AIMessage(content="Hello world", id="msg-1")
|
||||
events = list(message_to_events(msg))
|
||||
|
||||
event_types = [e["event"] for e in events]
|
||||
assert event_types == [
|
||||
"message-start",
|
||||
"content-block-start",
|
||||
"content-block-delta",
|
||||
"content-block-finish",
|
||||
"message-finish",
|
||||
]
|
||||
start = cast("MessageStartData", events[0])
|
||||
assert start["message_id"] == "msg-1"
|
||||
|
||||
delta_event = cast("ContentBlockDeltaData", events[2])
|
||||
delta = cast("TextBlock", delta_event["content_block"])
|
||||
assert delta["text"] == "Hello world"
|
||||
|
||||
final = cast("MessageFinishData", events[-1])
|
||||
assert final["reason"] == "stop"
|
||||
|
||||
|
||||
def test_message_to_events_empty_content_yields_start_finish_only() -> None:
|
||||
msg = AIMessage(content="", id="msg-empty")
|
||||
events = list(message_to_events(msg))
|
||||
event_types = [e["event"] for e in events]
|
||||
assert event_types == ["message-start", "message-finish"]
|
||||
|
||||
|
||||
def test_message_to_events_reasoning_text_order() -> None:
|
||||
msg = AIMessage(
|
||||
content=[
|
||||
{"type": "reasoning", "reasoning": "think hard"},
|
||||
{"type": "text", "text": "the answer"},
|
||||
],
|
||||
id="msg-2",
|
||||
)
|
||||
events = list(message_to_events(msg))
|
||||
|
||||
starts = [e for e in events if e["event"] == "content-block-start"]
|
||||
finishes = [e for e in events if e["event"] == "content-block-finish"]
|
||||
assert [s["content_block"]["type"] for s in starts] == ["reasoning", "text"]
|
||||
assert [f["content_block"]["type"] for f in finishes] == ["reasoning", "text"]
|
||||
|
||||
deltas = [e for e in events if e["event"] == "content-block-delta"]
|
||||
assert len(deltas) == 2
|
||||
assert cast("ReasoningBlock", deltas[0]["content_block"])["reasoning"] == (
|
||||
"think hard"
|
||||
)
|
||||
assert cast("TextBlock", deltas[1]["content_block"])["text"] == "the answer"
|
||||
|
||||
|
||||
def test_message_to_events_tool_call_skips_delta_and_infers_tool_use() -> None:
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
id="msg-3",
|
||||
tool_calls=[
|
||||
{"id": "tc1", "name": "search", "args": {"q": "hi"}, "type": "tool_call"},
|
||||
],
|
||||
)
|
||||
events = list(message_to_events(msg))
|
||||
|
||||
# Finalized tool_call blocks carry no useful incremental text,
|
||||
# so no content-block-delta is emitted.
|
||||
deltas = [e for e in events if e["event"] == "content-block-delta"]
|
||||
assert deltas == []
|
||||
|
||||
finishes = [e for e in events if e["event"] == "content-block-finish"]
|
||||
assert len(finishes) == 1
|
||||
tc = cast("ToolCallBlock", finishes[0]["content_block"])
|
||||
assert tc["type"] == "tool_call"
|
||||
assert tc["args"] == {"q": "hi"}
|
||||
|
||||
final = cast("MessageFinishData", events[-1])
|
||||
assert final["reason"] == "tool_use"
|
||||
|
||||
|
||||
def test_message_to_events_invalid_tool_calls_surfaced_from_field() -> None:
|
||||
"""`invalid_tool_calls` on AIMessage surface as protocol blocks.
|
||||
|
||||
`AIMessage.content_blocks` does not currently include
|
||||
`invalid_tool_calls`, so the bridge merges them in explicitly.
|
||||
"""
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
invalid_tool_calls=[
|
||||
{
|
||||
"type": "invalid_tool_call",
|
||||
"id": "call_1",
|
||||
"name": "search",
|
||||
"args": '{"q":',
|
||||
"error": "bad json",
|
||||
}
|
||||
],
|
||||
)
|
||||
events = list(message_to_events(msg))
|
||||
finishes = [e for e in events if e["event"] == "content-block-finish"]
|
||||
types = [f["content_block"]["type"] for f in finishes]
|
||||
assert "invalid_tool_call" in types
|
||||
|
||||
|
||||
def test_message_to_events_preserves_finish_reason_and_metadata() -> None:
|
||||
msg = AIMessage(
|
||||
content="done",
|
||||
id="msg-4",
|
||||
response_metadata={
|
||||
"finish_reason": "length",
|
||||
"model_name": "test-model",
|
||||
"stop_sequence": "</end>",
|
||||
},
|
||||
)
|
||||
events = list(message_to_events(msg))
|
||||
|
||||
start = cast("MessageStartData", events[0])
|
||||
assert start["metadata"] == {"model": "test-model"}
|
||||
|
||||
final = cast("MessageFinishData", events[-1])
|
||||
assert final["reason"] == "length"
|
||||
# finish_reason stripped from metadata; stop_sequence preserved
|
||||
assert final["metadata"] == {"model_name": "test-model", "stop_sequence": "</end>"}
|
||||
|
||||
|
||||
def test_message_to_events_propagates_usage() -> None:
|
||||
msg = AIMessage(
|
||||
content="hi",
|
||||
id="msg-5",
|
||||
usage_metadata={"input_tokens": 10, "output_tokens": 2, "total_tokens": 12},
|
||||
)
|
||||
events = list(message_to_events(msg))
|
||||
|
||||
final = cast("MessageFinishData", events[-1])
|
||||
assert final["usage"] == {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 2,
|
||||
"total_tokens": 12,
|
||||
}
|
||||
|
||||
|
||||
def test_message_to_events_message_id_override() -> None:
|
||||
msg = AIMessage(content="x", id="msg-orig")
|
||||
events = list(message_to_events(msg, message_id="msg-override"))
|
||||
start = cast("MessageStartData", events[0])
|
||||
assert start["message_id"] == "msg-override"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_amessage_to_events_matches_sync() -> None:
|
||||
msg = AIMessage(
|
||||
content=[
|
||||
{"type": "reasoning", "reasoning": "why"},
|
||||
{"type": "text", "text": "because"},
|
||||
],
|
||||
id="msg-async",
|
||||
)
|
||||
sync_events = list(message_to_events(msg))
|
||||
async_events = [e async for e in amessage_to_events(msg)]
|
||||
assert async_events == sync_events
|
||||
371
libs/core/tests/unit_tests/language_models/test_stream_v2.py
Normal file
371
libs/core/tests/unit_tests/language_models/test_stream_v2.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""Tests for stream_v2 / astream_v2 and ChatModelStream."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import pytest
|
||||
from langchain_protocol.protocol import (
|
||||
ContentBlockDeltaData,
|
||||
ContentBlockFinishData,
|
||||
MessageFinishData,
|
||||
ReasoningBlock,
|
||||
TextBlock,
|
||||
ToolCallBlock,
|
||||
UsageInfo,
|
||||
)
|
||||
|
||||
from langchain_core.language_models.chat_model_stream import (
|
||||
AsyncChatModelStream,
|
||||
ChatModelStream,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.language_models.fake_chat_models import FakeListChatModel
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
class _MalformedToolCallModel(BaseChatModel):
|
||||
"""Fake model that emits a tool_call_chunk with malformed JSON args."""
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "malformed-tool-call-fake"
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
del messages, stop, run_manager, kwargs
|
||||
raise NotImplementedError
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
del messages, stop, run_manager, kwargs
|
||||
yield ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content="",
|
||||
tool_call_chunks=[
|
||||
{
|
||||
"name": "search",
|
||||
"args": '{"q": ', # malformed JSON
|
||||
"id": "call_1",
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class _AnthropicStyleServerToolModel(BaseChatModel):
|
||||
"""Fake model that streams Anthropic-native server_tool_use shapes.
|
||||
|
||||
Exercises Phase E: the bridge should call `content_blocks` (which
|
||||
invokes the Anthropic translator) to convert `server_tool_use` into
|
||||
protocol `server_tool_call` blocks instead of silently dropping them.
|
||||
"""
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "anthropic-style-fake"
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
del messages, stop, run_manager, kwargs
|
||||
raise NotImplementedError
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
del messages, stop, run_manager, kwargs
|
||||
# Single chunk carrying a complete server_tool_use block — what
|
||||
# Anthropic typically emits once input_json_delta finishes.
|
||||
yield ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content=[
|
||||
{
|
||||
"type": "server_tool_use",
|
||||
"id": "srvtoolu_01",
|
||||
"name": "web_search",
|
||||
"input": {"query": "weather today"},
|
||||
},
|
||||
{"type": "text", "text": "Based on the search..."},
|
||||
],
|
||||
response_metadata={"model_provider": "anthropic"},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TestChatModelStream:
|
||||
"""Test the sync ChatModelStream object."""
|
||||
|
||||
def test_push_text_delta(self) -> None:
|
||||
stream = ChatModelStream()
|
||||
stream._push_content_block_delta(
|
||||
ContentBlockDeltaData(
|
||||
event="content-block-delta",
|
||||
index=0,
|
||||
content_block=TextBlock(type="text", text="Hello"),
|
||||
)
|
||||
)
|
||||
assert stream._text_acc == "Hello"
|
||||
|
||||
def test_push_reasoning_delta(self) -> None:
|
||||
stream = ChatModelStream()
|
||||
stream._push_content_block_delta(
|
||||
ContentBlockDeltaData(
|
||||
event="content-block-delta",
|
||||
index=0,
|
||||
content_block=ReasoningBlock(type="reasoning", reasoning="think"),
|
||||
)
|
||||
)
|
||||
assert stream._reasoning_acc == "think"
|
||||
|
||||
def test_push_content_block_finish_tool_call(self) -> None:
|
||||
stream = ChatModelStream()
|
||||
stream._push_content_block_finish(
|
||||
ContentBlockFinishData(
|
||||
event="content-block-finish",
|
||||
index=0,
|
||||
content_block=ToolCallBlock(
|
||||
type="tool_call",
|
||||
id="tc1",
|
||||
name="search",
|
||||
args={"q": "test"},
|
||||
),
|
||||
)
|
||||
)
|
||||
assert len(stream._tool_calls_acc) == 1
|
||||
assert stream._tool_calls_acc[0]["name"] == "search"
|
||||
|
||||
def test_finish(self) -> None:
|
||||
stream = ChatModelStream()
|
||||
assert not stream.done
|
||||
usage = UsageInfo(input_tokens=10, output_tokens=5, total_tokens=15)
|
||||
stream._finish(
|
||||
MessageFinishData(event="message-finish", reason="stop", usage=usage)
|
||||
)
|
||||
assert stream.done
|
||||
assert stream._usage_value == usage
|
||||
|
||||
def test_fail(self) -> None:
|
||||
stream = ChatModelStream()
|
||||
stream.fail(RuntimeError("test"))
|
||||
assert stream.done
|
||||
|
||||
def test_pump_driven_text(self) -> None:
|
||||
"""Test text projection with pump binding."""
|
||||
stream = ChatModelStream()
|
||||
deltas: list[ContentBlockDeltaData] = [
|
||||
ContentBlockDeltaData(
|
||||
event="content-block-delta",
|
||||
index=0,
|
||||
content_block=TextBlock(type="text", text="Hi"),
|
||||
),
|
||||
ContentBlockDeltaData(
|
||||
event="content-block-delta",
|
||||
index=0,
|
||||
content_block=TextBlock(type="text", text=" there"),
|
||||
),
|
||||
]
|
||||
finish = MessageFinishData(event="message-finish", reason="stop")
|
||||
idx = 0
|
||||
|
||||
def pump_one() -> bool:
|
||||
nonlocal idx
|
||||
if idx < len(deltas):
|
||||
stream._push_content_block_delta(deltas[idx])
|
||||
idx += 1
|
||||
return True
|
||||
if idx == len(deltas):
|
||||
stream._finish(finish)
|
||||
idx += 1
|
||||
return True
|
||||
return False
|
||||
|
||||
stream.bind_pump(pump_one)
|
||||
|
||||
text_deltas = list(stream.text)
|
||||
assert text_deltas == ["Hi", " there"]
|
||||
assert stream.done
|
||||
|
||||
|
||||
class TestAsyncChatModelStream:
|
||||
"""Test the async ChatModelStream object."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_await(self) -> None:
|
||||
stream = AsyncChatModelStream()
|
||||
stream._push_content_block_delta(
|
||||
ContentBlockDeltaData(
|
||||
event="content-block-delta",
|
||||
index=0,
|
||||
content_block=TextBlock(type="text", text="Hello"),
|
||||
)
|
||||
)
|
||||
stream._push_content_block_delta(
|
||||
ContentBlockDeltaData(
|
||||
event="content-block-delta",
|
||||
index=0,
|
||||
content_block=TextBlock(type="text", text=" world"),
|
||||
)
|
||||
)
|
||||
stream._finish(MessageFinishData(event="message-finish", reason="stop"))
|
||||
|
||||
full = await stream.text
|
||||
assert full == "Hello world"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_async_iter(self) -> None:
|
||||
stream = AsyncChatModelStream()
|
||||
|
||||
async def produce() -> None:
|
||||
await asyncio.sleep(0)
|
||||
stream._push_content_block_delta(
|
||||
ContentBlockDeltaData(
|
||||
event="content-block-delta",
|
||||
index=0,
|
||||
content_block=TextBlock(type="text", text="a"),
|
||||
)
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
stream._push_content_block_delta(
|
||||
ContentBlockDeltaData(
|
||||
event="content-block-delta",
|
||||
index=0,
|
||||
content_block=TextBlock(type="text", text="b"),
|
||||
)
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
stream._finish(MessageFinishData(event="message-finish", reason="stop"))
|
||||
|
||||
asyncio.get_running_loop().create_task(produce())
|
||||
|
||||
deltas = [d async for d in stream.text]
|
||||
assert deltas == ["a", "b"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_calls_await(self) -> None:
|
||||
stream = AsyncChatModelStream()
|
||||
stream._push_content_block_finish(
|
||||
ContentBlockFinishData(
|
||||
event="content-block-finish",
|
||||
index=0,
|
||||
content_block=ToolCallBlock(
|
||||
type="tool_call",
|
||||
id="tc1",
|
||||
name="search",
|
||||
args={"q": "test"},
|
||||
),
|
||||
)
|
||||
)
|
||||
stream._finish(MessageFinishData(event="message-finish", reason="tool_use"))
|
||||
|
||||
tool_calls = await stream.tool_calls
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0]["name"] == "search"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_propagation(self) -> None:
|
||||
stream = AsyncChatModelStream()
|
||||
stream.fail(RuntimeError("boom"))
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
await stream.text
|
||||
|
||||
|
||||
class TestStreamV2:
|
||||
"""Test BaseChatModel.stream_v2() with FakeListChatModel."""
|
||||
|
||||
def test_stream_v2_text(self) -> None:
|
||||
model = FakeListChatModel(responses=["Hello world!"])
|
||||
stream = model.stream_v2("test")
|
||||
|
||||
assert isinstance(stream, ChatModelStream)
|
||||
deltas = list(stream.text)
|
||||
assert "".join(deltas) == "Hello world!"
|
||||
assert stream.done
|
||||
|
||||
def test_stream_v2_usage(self) -> None:
|
||||
model = FakeListChatModel(responses=["Hi"])
|
||||
stream = model.stream_v2("test")
|
||||
|
||||
# Drain stream
|
||||
for _ in stream.text:
|
||||
pass
|
||||
# FakeListChatModel doesn't emit usage, so it should be None
|
||||
assert stream.usage is None
|
||||
assert stream.done
|
||||
|
||||
def test_stream_v2_malformed_tool_args_produce_invalid_tool_call(self) -> None:
|
||||
"""End-to-end: malformed tool-call JSON becomes invalid_tool_calls."""
|
||||
model = _MalformedToolCallModel()
|
||||
stream = model.stream_v2("test")
|
||||
msg = stream.output
|
||||
|
||||
assert msg.tool_calls == []
|
||||
assert len(msg.invalid_tool_calls) == 1
|
||||
itc = msg.invalid_tool_calls[0]
|
||||
assert itc["name"] == "search"
|
||||
assert itc["args"] == '{"q": '
|
||||
assert itc["id"] == "call_1"
|
||||
|
||||
def test_stream_v2_translates_anthropic_server_tool_use_to_protocol(self) -> None:
|
||||
"""Phase E end-to-end: server_tool_use becomes server_tool_call in output."""
|
||||
model = _AnthropicStyleServerToolModel()
|
||||
stream = model.stream_v2("weather?")
|
||||
msg = stream.output
|
||||
|
||||
assert isinstance(msg.content, list)
|
||||
types = [b.get("type") for b in msg.content if isinstance(b, dict)]
|
||||
# The server tool call must appear in the output content.
|
||||
assert "server_tool_call" in types
|
||||
# Text block should also be present.
|
||||
assert "text" in types
|
||||
# Regular tool_calls should NOT include the server-executed call.
|
||||
assert msg.tool_calls == []
|
||||
|
||||
|
||||
class TestAstreamV2:
|
||||
"""Test BaseChatModel.astream_v2() with FakeListChatModel."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_astream_v2_text(self) -> None:
|
||||
model = FakeListChatModel(responses=["Hello!"])
|
||||
stream = await model.astream_v2("test")
|
||||
|
||||
assert isinstance(stream, AsyncChatModelStream)
|
||||
full = await stream.text
|
||||
assert full == "Hello!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_astream_v2_deltas(self) -> None:
|
||||
model = FakeListChatModel(responses=["Hi"])
|
||||
stream = await model.astream_v2("test")
|
||||
|
||||
deltas = [d async for d in stream.text]
|
||||
assert "".join(deltas) == "Hi"
|
||||
368
libs/core/tests/unit_tests/language_models/test_v1_parity.py
Normal file
368
libs/core/tests/unit_tests/language_models/test_v1_parity.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""V1 parity tests: stream_v2() output must match model.stream() output.
|
||||
|
||||
These are the acceptance criteria for streaming v2 — if any test fails,
|
||||
v2 has a regression vs v1.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import pytest
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.language_models.fake_chat_models import FakeListChatModel
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
class _ScriptedChunkModel(BaseChatModel):
|
||||
"""Fake chat model that streams a fixed, pre-built sequence of chunks.
|
||||
|
||||
Lets us write parity tests that exercise tool calls, reasoning,
|
||||
usage metadata, and response metadata — shapes `FakeListChatModel`
|
||||
cannot produce.
|
||||
"""
|
||||
|
||||
scripted_chunks: list[AIMessageChunk]
|
||||
raise_after: bool = False
|
||||
"""If True, raise `_FakeStreamError` after yielding all scripted chunks."""
|
||||
|
||||
@property
|
||||
@override
|
||||
def _llm_type(self) -> str:
|
||||
return "scripted-chunk-fake"
|
||||
|
||||
def _merged(self) -> AIMessageChunk:
|
||||
merged = self.scripted_chunks[0]
|
||||
for c in self.scripted_chunks[1:]:
|
||||
merged = merged + c
|
||||
return merged
|
||||
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
merged = self._merged()
|
||||
final = AIMessage(
|
||||
content=merged.content,
|
||||
id=merged.id,
|
||||
tool_calls=merged.tool_calls,
|
||||
usage_metadata=merged.usage_metadata,
|
||||
response_metadata=merged.response_metadata,
|
||||
)
|
||||
return ChatResult(generations=[ChatGeneration(message=final)])
|
||||
|
||||
@override
|
||||
def _stream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
for chunk in self.scripted_chunks:
|
||||
yield ChatGenerationChunk(message=chunk)
|
||||
if self.raise_after:
|
||||
msg = "scripted failure"
|
||||
raise _FakeStreamError(msg)
|
||||
|
||||
@override
|
||||
async def _astream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
for chunk in self.scripted_chunks:
|
||||
yield ChatGenerationChunk(message=chunk)
|
||||
if self.raise_after:
|
||||
msg = "scripted failure"
|
||||
raise _FakeStreamError(msg)
|
||||
|
||||
|
||||
class _FakeStreamError(RuntimeError):
|
||||
"""Marker exception raised by `_ScriptedChunkModel` during streaming."""
|
||||
|
||||
|
||||
def _collect_v1_message(model: BaseChatModel, input_text: str) -> AIMessage:
|
||||
"""Run model.stream() and merge chunks into a single AIMessage."""
|
||||
chunks: list[AIMessageChunk] = [
|
||||
chunk for chunk in model.stream(input_text) if isinstance(chunk, AIMessageChunk)
|
||||
]
|
||||
if not chunks:
|
||||
msg = "No chunks produced"
|
||||
raise RuntimeError(msg)
|
||||
merged = chunks[0]
|
||||
for c in chunks[1:]:
|
||||
merged = merged + c
|
||||
return AIMessage(
|
||||
content=merged.content,
|
||||
id=merged.id,
|
||||
tool_calls=merged.tool_calls,
|
||||
usage_metadata=merged.usage_metadata,
|
||||
response_metadata=merged.response_metadata,
|
||||
)
|
||||
|
||||
|
||||
def _collect_v2_message(model: BaseChatModel, input_text: str) -> AIMessage:
|
||||
"""Run model.stream_v2() and get .output."""
|
||||
stream = model.stream_v2(input_text)
|
||||
return stream.output
|
||||
|
||||
|
||||
class TestV1ParityBasic:
|
||||
"""Smoke-level parity using the simple text-only fake."""
|
||||
|
||||
def test_text_only_content_matches(self) -> None:
|
||||
model = FakeListChatModel(responses=["Hello world!"])
|
||||
v1 = _collect_v1_message(model, "test")
|
||||
model.i = 0
|
||||
v2 = _collect_v2_message(model, "test")
|
||||
|
||||
assert v1.content == v2.content
|
||||
|
||||
def test_message_id_present(self) -> None:
|
||||
model = FakeListChatModel(responses=["Hi"])
|
||||
v1 = _collect_v1_message(model, "test")
|
||||
model.i = 0
|
||||
v2 = _collect_v2_message(model, "test")
|
||||
|
||||
assert v1.id is not None
|
||||
assert v2.id is not None
|
||||
|
||||
def test_empty_response(self) -> None:
|
||||
model = FakeListChatModel(responses=[""])
|
||||
stream = model.stream_v2("test")
|
||||
msg = stream.output
|
||||
assert msg.content == ""
|
||||
|
||||
def test_multi_character_response(self) -> None:
|
||||
text = "The quick brown fox"
|
||||
model = FakeListChatModel(responses=[text])
|
||||
v2 = _collect_v2_message(model, "test")
|
||||
assert v2.content == text
|
||||
|
||||
def test_text_deltas_reconstruct_content(self) -> None:
|
||||
model = FakeListChatModel(responses=["Hello!"])
|
||||
stream = model.stream_v2("test")
|
||||
|
||||
deltas = list(stream.text)
|
||||
assert "".join(deltas) == stream.output.content
|
||||
|
||||
|
||||
class TestV1ParityToolCalls:
|
||||
"""Tool-call parity — the most load-bearing v1 shape."""
|
||||
|
||||
@staticmethod
|
||||
def _make_model() -> _ScriptedChunkModel:
|
||||
chunks = [
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
id="run-tool-1",
|
||||
tool_call_chunks=[
|
||||
{"index": 0, "id": "call_1", "name": "get_weather", "args": ""},
|
||||
],
|
||||
),
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
id="run-tool-1",
|
||||
tool_call_chunks=[
|
||||
{"index": 0, "id": None, "name": None, "args": '{"city": "'},
|
||||
],
|
||||
),
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
id="run-tool-1",
|
||||
tool_call_chunks=[
|
||||
{"index": 0, "id": None, "name": None, "args": 'Paris"}'},
|
||||
],
|
||||
response_metadata={"finish_reason": "tool_use"},
|
||||
),
|
||||
]
|
||||
return _ScriptedChunkModel(scripted_chunks=chunks)
|
||||
|
||||
def test_tool_calls_match(self) -> None:
|
||||
model = self._make_model()
|
||||
v1 = _collect_v1_message(model, "weather?")
|
||||
v2 = _collect_v2_message(self._make_model(), "weather?")
|
||||
|
||||
assert len(v1.tool_calls) == 1
|
||||
assert len(v2.tool_calls) == 1
|
||||
assert v1.tool_calls[0]["id"] == v2.tool_calls[0]["id"] == "call_1"
|
||||
assert v1.tool_calls[0]["name"] == v2.tool_calls[0]["name"] == "get_weather"
|
||||
assert v1.tool_calls[0]["args"] == v2.tool_calls[0]["args"] == {"city": "Paris"}
|
||||
|
||||
def test_tool_calls_via_projection(self) -> None:
|
||||
model = self._make_model()
|
||||
stream = model.stream_v2("weather?")
|
||||
finalized = stream.tool_calls.get()
|
||||
assert len(finalized) == 1
|
||||
assert finalized[0]["name"] == "get_weather"
|
||||
assert finalized[0]["args"] == {"city": "Paris"}
|
||||
|
||||
def test_finish_reason_tool_use(self) -> None:
|
||||
model = self._make_model()
|
||||
v2 = _collect_v2_message(model, "weather?")
|
||||
assert v2.response_metadata.get("finish_reason") == "tool_use"
|
||||
|
||||
|
||||
class TestV1ParityUsage:
    """v1/v2 parity for usage (token-count) metadata."""

    @staticmethod
    def _make_model() -> _ScriptedChunkModel:
        """Build a scripted model whose final chunk carries usage metadata."""
        scripted = [
            AIMessageChunk(content="Hi", id="run-usage-1"),
            AIMessageChunk(
                content=" there",
                id="run-usage-1",
                usage_metadata={
                    "input_tokens": 10,
                    "output_tokens": 5,
                    "total_tokens": 15,
                },
                response_metadata={"finish_reason": "stop"},
            ),
        ]
        return _ScriptedChunkModel(scripted_chunks=scripted)

    def test_usage_metadata_present(self) -> None:
        """Both streaming paths report identical token counts."""
        legacy = _collect_v1_message(self._make_model(), "hello")
        modern = _collect_v2_message(self._make_model(), "hello")

        assert legacy.usage_metadata is not None
        assert modern.usage_metadata is not None
        for key in ("input_tokens", "output_tokens", "total_tokens"):
            assert legacy.usage_metadata[key] == modern.usage_metadata[key]

    def test_usage_projection_matches(self) -> None:
        """``stream.usage`` is populated once the text projection is drained."""
        stream = self._make_model().stream_v2("hello")
        for _ in stream.text:  # drain so usage is available
            pass

        assert stream.usage is not None
        assert stream.usage["input_tokens"] == 10
        assert stream.usage["output_tokens"] == 5
||||
|
||||
class TestV1ParityResponseMetadata:
    """Response metadata preservation (fix 5b)."""

    @staticmethod
    def _make_model() -> _ScriptedChunkModel:
        """Build a single-chunk model carrying provider-specific metadata."""
        only_chunk = AIMessageChunk(
            content="ok",
            id="run-meta-1",
            response_metadata={
                "finish_reason": "stop",
                "model_provider": "fake-provider",
                "stop_sequence": None,
            },
        )
        return _ScriptedChunkModel(scripted_chunks=[only_chunk])

    def test_finish_reason_preserved(self) -> None:
        """``finish_reason`` must appear on the collected v2 message."""
        message = _collect_v2_message(self._make_model(), "hi")
        assert message.response_metadata.get("finish_reason") == "stop"

    def test_provider_metadata_preserved(self) -> None:
        """Non-finish-reason keys should survive the round-trip."""
        message = _collect_v2_message(self._make_model(), "hi")
        # stop_sequence originated on the chunk's response_metadata; the
        # bridge should carry it through via MessageFinishData.metadata.
        assert "stop_sequence" in message.response_metadata
||||
|
||||
class TestV1ParityReasoning:
    """Reasoning content parity — block order must be preserved."""

    @staticmethod
    def _make_model() -> _ScriptedChunkModel:
        """Build a model that streams two reasoning deltas, then a text block."""
        scripted = [
            AIMessageChunk(
                content=[
                    {"type": "reasoning", "reasoning": "Let me think. ", "index": 0},
                ],
                id="run-reason-1",
            ),
            AIMessageChunk(
                content=[
                    {"type": "reasoning", "reasoning": "Done.", "index": 0},
                ],
                id="run-reason-1",
            ),
            AIMessageChunk(
                content=[
                    {"type": "text", "text": "The answer is 42.", "index": 1},
                ],
                id="run-reason-1",
                response_metadata={"finish_reason": "stop"},
            ),
        ]
        return _ScriptedChunkModel(scripted_chunks=scripted)

    def test_reasoning_text_order(self) -> None:
        """Reasoning block should come before text block in .output.content."""
        message = _collect_v2_message(self._make_model(), "think")
        assert isinstance(message.content, list)

        observed_types = []
        for block in message.content:
            if isinstance(block, dict):
                observed_types.append(block.get("type"))
        assert observed_types == ["reasoning", "text"]

    def test_reasoning_projection(self) -> None:
        """The reasoning projection concatenates deltas in emission order."""
        stream = self._make_model().stream_v2("think")
        assert str(stream.reasoning) == "Let me think. Done."
|
||||
|
||||
class TestV1ParityError:
    """Errors during streaming must propagate on both sync and async paths."""

    def test_error_propagates_sync(self) -> None:
        """A mid-stream error surfaces while draining or at ``.output``."""
        failing_model = _ScriptedChunkModel(
            scripted_chunks=[AIMessageChunk(content="partial", id="run-err-1")],
            raise_after=True,
        )
        stream = failing_model.stream_v2("boom")

        # Drain first; the error may surface here or at .output access.
        try:
            for _ in stream.text:
                pass
        except _FakeStreamError:
            return  # surfaced during iteration — that counts as propagation
        with pytest.raises(_FakeStreamError):
            _ = stream.output

    @pytest.mark.asyncio
    async def test_error_propagates_async(self) -> None:
        """Same contract as the sync path, over ``astream_v2``."""
        failing_model = _ScriptedChunkModel(
            scripted_chunks=[AIMessageChunk(content="partial", id="run-err-2")],
            raise_after=True,
        )
        stream = await failing_model.astream_v2("boom")

        try:
            async for _ in stream.text:
                pass
        except _FakeStreamError:
            return  # surfaced during iteration — that counts as propagation
        with pytest.raises(_FakeStreamError):
            _ = await stream
14
libs/core/uv.lock
generated
14
libs/core/uv.lock
generated
@@ -999,6 +999,7 @@ version = "1.3.0a2"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "jsonpatch" },
|
||||
{ name = "langchain-protocol" },
|
||||
{ name = "langsmith" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pydantic" },
|
||||
@@ -1045,6 +1046,7 @@ typing = [
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "jsonpatch", specifier = ">=1.33.0,<2.0.0" },
|
||||
{ name = "langchain-protocol", specifier = ">=0.0.8" },
|
||||
{ name = "langsmith", specifier = ">=0.3.45,<1.0.0" },
|
||||
{ name = "packaging", specifier = ">=23.2.0" },
|
||||
{ name = "pydantic", specifier = ">=2.7.4,<3.0.0" },
|
||||
@@ -1087,6 +1089,18 @@ typing = [
|
||||
{ name = "types-requests", specifier = ">=2.28.11.5,<3.0.0.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "langchain-protocol"
|
||||
version = "0.0.8"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/40/0b/34d23ad37c4ef14f96cf6990b619e2e7c4f9e58c7f1089f044f963af3b32/langchain_protocol-0.0.8.tar.gz", hash = "sha256:28fc94f3278cf0da6b9b2e8cc4cd40cafc9e79b6f2de8dc2d06879327af0762c", size = 6357, upload-time = "2026-04-16T20:01:38.218Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a7/02/2bd9075e6f7fb75155b1e8208535ca78be0f4f16c03994295c74c01cbc04/langchain_protocol-0.0.8-py3-none-any.whl", hash = "sha256:39c7b28f1f7a98317ca5353d2ddb111cbbab9d295d15246ffd34449417c0b614", size = 6559, upload-time = "2026-04-16T20:01:37.364Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "langchain-tests"
|
||||
version = "1.1.6"
|
||||
|
||||
Reference in New Issue
Block a user