This commit is contained in:
Christian Bromann
2026-04-28 20:18:12 -07:00
parent cb97a91d37
commit 8227fd329f
4 changed files with 38 additions and 37 deletions

View File

@@ -575,7 +575,7 @@ class _ChatModelStreamBase:
# -- Event ingestion (public) ------------------------------------------
def dispatch(self, event: MessagesData) -> None:
def dispatch(self, event: Mapping[str, Any]) -> None:
"""Route a protocol event to the appropriate internal handler.
Public entry point for feeding events into the stream. Called by
@@ -598,9 +598,9 @@ class _ChatModelStreamBase:
# -- Internal push API (called by dispatch) ----------------------------
def _record_event(self, event: MessagesData) -> None:
def _record_event(self, event: Mapping[str, Any]) -> None:
"""Append a raw event to the replay buffer."""
self._events.append(event)
self._events.append(cast("MessagesData", event))
def _push_message_start(self, data: MessageStartData) -> None:
"""Process a `message-start` event."""

View File

@@ -16,7 +16,7 @@ from langchain_core.language_models.chat_model_stream import (
)
if TYPE_CHECKING:
from langchain_protocol.protocol import ContentBlockFinishData, MessagesData
from langchain_protocol.protocol import ContentBlockFinishData
# ---------------------------------------------------------------------------
# Projection unit tests
@@ -244,7 +244,7 @@ class TestAsyncProjection:
"""Concurrent `stream.text` + `await stream.output` both drive the pump."""
stream = AsyncChatModelStream(message_id="m1")
events: list[MessagesData] = [
events: list[dict[str, Any]] = [
{
"event": "message-start",
"role": "ai",
@@ -313,7 +313,7 @@ class TestChatModelStream:
def test_text_deltas_via_pump(self) -> None:
stream = ChatModelStream()
events: list[MessagesData] = [
events: list[dict[str, Any]] = [
{"event": "message-start", "role": "ai"},
{
"event": "content-block-delta",
@@ -363,7 +363,7 @@ class TestChatModelStream:
}
)
stream.dispatch(
{ # type: ignore[arg-type,misc]
{
"event": "content-block-delta",
"index": 0,
"content_block": {

View File

@@ -28,6 +28,11 @@ if TYPE_CHECKING:
)
def _event_metadata(event: Any) -> dict[str, Any]:
"""Return event metadata for protocol versions that type it as extensible."""
return cast("dict[str, Any]", cast("dict[str, Any]", event).get("metadata") or {})
# ---------------------------------------------------------------------------
# Pure helpers
# ---------------------------------------------------------------------------
@@ -134,7 +139,7 @@ def test_chunks_to_events_text_only() -> None:
# No provider finish_reason in fixtures — metadata carries no
# `finish_reason` key (the bridge passes response_metadata through
# unchanged).
assert "finish_reason" not in (finish.get("metadata") or {})
assert "finish_reason" not in _event_metadata(finish)
def test_chunks_to_events_empty_iterator() -> None:
@@ -291,9 +296,7 @@ def test_chunks_to_events_tool_call_multichunk() -> None:
# not synthesize one. It deliberately does not infer `"tool_use"`
# from the presence of a valid tool_call either; terminal reasons
# are provider-specific (see `_build_message_finish`).
assert "finish_reason" not in (
cast("MessageFinishData", events[-1]).get("metadata") or {}
)
assert "finish_reason" not in _event_metadata(events[-1])
def test_chunks_to_events_invalid_tool_call_keeps_stop_reason() -> None:
@@ -323,9 +326,7 @@ def test_chunks_to_events_invalid_tool_call_keeps_stop_reason() -> None:
]
assert len(finish_events) == 1
assert finish_events[0]["content"]["type"] == "invalid_tool_call"
assert "finish_reason" not in (
cast("MessageFinishData", events[-1]).get("metadata") or {}
)
assert "finish_reason" not in _event_metadata(events[-1])
def test_chunks_to_events_anthropic_server_tool_use_routes_through_translator() -> None:
@@ -453,7 +454,7 @@ def test_message_to_events_text_only() -> None:
assert delta_event["delta"] == {"type": "text-delta", "text": "Hello world"}
final = cast("MessageFinishData", events[-1])
assert "finish_reason" not in (final.get("metadata") or {})
assert "finish_reason" not in _event_metadata(final)
def test_message_to_events_empty_content_yields_start_finish_only() -> None:
@@ -512,7 +513,7 @@ def test_message_to_events_tool_call_skips_delta() -> None:
# bridge does not synthesize one and does not second-guess based on
# the presence of a tool_call.
final = cast("MessageFinishData", events[-1])
assert "finish_reason" not in (final.get("metadata") or {})
assert "finish_reason" not in _event_metadata(final)
def test_message_to_events_invalid_tool_calls_surfaced_from_field() -> None:
@@ -557,7 +558,7 @@ def test_message_to_events_preserves_finish_reason_and_metadata() -> None:
# Passthrough: response_metadata lands on `metadata` unchanged,
# including the raw provider `finish_reason`.
final = cast("MessageFinishData", events[-1])
assert final["metadata"] == {
assert _event_metadata(final) == {
"finish_reason": "length",
"model_name": "test-model",
"stop_sequence": "</end>",

View File

@@ -389,14 +389,14 @@ class TestPerBlockAccumulation:
ContentBlockDeltaData(
event="content-block-delta",
index=0,
content_block=TextContentBlock(type="text", text="A"),
delta={"type": "text-delta", "text": "A"},
)
)
stream.dispatch(
ContentBlockFinishData(
event="content-block-finish",
index=0,
content_block=TextContentBlock(type="text", text="A"),
content=TextContentBlock(type="text", text="A"),
)
)
# Block 1: "B"
@@ -404,14 +404,14 @@ class TestPerBlockAccumulation:
ContentBlockDeltaData(
event="content-block-delta",
index=1,
content_block=TextContentBlock(type="text", text="B"),
delta={"type": "text-delta", "text": "B"},
)
)
stream.dispatch(
ContentBlockFinishData(
event="content-block-finish",
index=1,
content_block=TextContentBlock(type="text", text="B"),
content=TextContentBlock(type="text", text="B"),
)
)
stream.dispatch(MessageFinishData(event="message-finish"))
@@ -435,14 +435,14 @@ class TestPerBlockAccumulation:
ContentBlockDeltaData(
event="content-block-delta",
index=0,
content_block=ReasoningContentBlock(type="reasoning", reasoning="one"),
delta={"type": "reasoning-delta", "reasoning": "one"},
)
)
stream.dispatch(
ContentBlockFinishData(
event="content-block-finish",
index=0,
content_block=ReasoningContentBlock(type="reasoning", reasoning="one"),
content=ReasoningContentBlock(type="reasoning", reasoning="one"),
)
)
# Block 1: "two"
@@ -450,14 +450,14 @@ class TestPerBlockAccumulation:
ContentBlockDeltaData(
event="content-block-delta",
index=1,
content_block=ReasoningContentBlock(type="reasoning", reasoning="two"),
delta={"type": "reasoning-delta", "reasoning": "two"},
)
)
stream.dispatch(
ContentBlockFinishData(
event="content-block-finish",
index=1,
content_block=ReasoningContentBlock(type="reasoning", reasoning="two"),
content=ReasoningContentBlock(type="reasoning", reasoning="two"),
)
)
stream.dispatch(MessageFinishData(event="message-finish"))
@@ -482,14 +482,14 @@ class TestPerBlockAccumulation:
ContentBlockDeltaData(
event="content-block-delta",
index=0,
content_block=TextContentBlock(type="text", text="hel"),
delta={"type": "text-delta", "text": "hel"},
)
)
stream.dispatch(
ContentBlockFinishData(
event="content-block-finish",
index=0,
content_block=TextContentBlock(type="text", text="hello"),
content=TextContentBlock(type="text", text="hello"),
)
)
stream.dispatch(MessageFinishData(event="message-finish"))
@@ -520,7 +520,7 @@ class TestPerBlockAccumulation:
ContentBlockDeltaData(
event="content-block-delta",
index=0,
content_block=TextContentBlock(type="text", text="aaa"),
delta={"type": "text-delta", "text": "aaa"},
)
)
# Block 1 streams deltas before block 0 finishes.
@@ -528,7 +528,7 @@ class TestPerBlockAccumulation:
ContentBlockDeltaData(
event="content-block-delta",
index=1,
content_block=TextContentBlock(type="text", text="bb"),
delta={"type": "text-delta", "text": "bb"},
)
)
# Block 0 finishes with authoritative text different from deltas.
@@ -536,14 +536,14 @@ class TestPerBlockAccumulation:
ContentBlockFinishData(
event="content-block-finish",
index=0,
content_block=TextContentBlock(type="text", text="XXX"),
content=TextContentBlock(type="text", text="XXX"),
)
)
stream.dispatch(
ContentBlockFinishData(
event="content-block-finish",
index=1,
content_block=TextContentBlock(type="text", text="bb"),
content=TextContentBlock(type="text", text="bb"),
)
)
stream.dispatch(MessageFinishData(event="message-finish"))
@@ -566,7 +566,7 @@ class TestPerBlockAccumulation:
ContentBlockDeltaData(
event="content-block-delta",
index=0,
content_block=ReasoningContentBlock(type="reasoning", reasoning="thi"),
delta={"type": "reasoning-delta", "reasoning": "thi"},
)
)
stream.dispatch(
@@ -594,14 +594,14 @@ class TestPerBlockAccumulation:
ContentBlockDeltaData(
event="content-block-delta",
index=0,
content_block=TextContentBlock(type="text", text="before"),
delta={"type": "text-delta", "text": "before"},
)
)
stream.dispatch(
ContentBlockFinishData(
event="content-block-finish",
index=0,
content_block=TextContentBlock(type="text", text="before"),
content=TextContentBlock(type="text", text="before"),
)
)
# Block 1: tool_call
@@ -609,7 +609,7 @@ class TestPerBlockAccumulation:
ContentBlockFinishData(
event="content-block-finish",
index=1,
content_block=ToolCall(
content=ToolCall(
type="tool_call",
id="tc1",
name="search",
@@ -622,14 +622,14 @@ class TestPerBlockAccumulation:
ContentBlockDeltaData(
event="content-block-delta",
index=2,
content_block=TextContentBlock(type="text", text="after"),
delta={"type": "text-delta", "text": "after"},
)
)
stream.dispatch(
ContentBlockFinishData(
event="content-block-finish",
index=2,
content_block=TextContentBlock(type="text", text="after"),
content=TextContentBlock(type="text", text="after"),
)
)
stream.dispatch(MessageFinishData(event="message-finish"))