diff --git a/libs/core/langchain_core/language_models/chat_model_stream.py b/libs/core/langchain_core/language_models/chat_model_stream.py index d73ea4cce94..ff88f5989f7 100644 --- a/libs/core/langchain_core/language_models/chat_model_stream.py +++ b/libs/core/langchain_core/language_models/chat_model_stream.py @@ -575,7 +575,7 @@ class _ChatModelStreamBase: # -- Event ingestion (public) ------------------------------------------ - def dispatch(self, event: MessagesData) -> None: + def dispatch(self, event: Mapping[str, Any]) -> None: """Route a protocol event to the appropriate internal handler. Public entry point for feeding events into the stream. Called by @@ -598,9 +598,9 @@ class _ChatModelStreamBase: # -- Internal push API (called by dispatch) ---------------------------- - def _record_event(self, event: MessagesData) -> None: + def _record_event(self, event: Mapping[str, Any]) -> None: """Append a raw event to the replay buffer.""" - self._events.append(event) + self._events.append(cast("MessagesData", event)) def _push_message_start(self, data: MessageStartData) -> None: """Process a `message-start` event.""" diff --git a/libs/core/tests/unit_tests/language_models/test_chat_model_stream.py b/libs/core/tests/unit_tests/language_models/test_chat_model_stream.py index a27a9e9d103..1230e2b23c6 100644 --- a/libs/core/tests/unit_tests/language_models/test_chat_model_stream.py +++ b/libs/core/tests/unit_tests/language_models/test_chat_model_stream.py @@ -16,7 +16,7 @@ from langchain_core.language_models.chat_model_stream import ( ) if TYPE_CHECKING: - from langchain_protocol.protocol import ContentBlockFinishData, MessagesData + from langchain_protocol.protocol import ContentBlockFinishData # --------------------------------------------------------------------------- # Projection unit tests @@ -244,7 +244,7 @@ class TestAsyncProjection: """Concurrent `stream.text` + `await stream.output` both drive the pump.""" stream = AsyncChatModelStream(message_id="m1") - events: list[MessagesData] = [ + events: list[dict[str, Any]] = [ { "event": "message-start", "role": "ai", @@ -313,7 +313,7 @@ class TestChatModelStream: def test_text_deltas_via_pump(self) -> None: stream = ChatModelStream() - events: list[MessagesData] = [ + events: list[dict[str, Any]] = [ {"event": "message-start", "role": "ai"}, { "event": "content-block-delta", @@ -363,7 +363,7 @@ class TestChatModelStream: } ) stream.dispatch( - { # type: ignore[arg-type,misc] + { "event": "content-block-delta", "index": 0, "content_block": { diff --git a/libs/core/tests/unit_tests/language_models/test_compat_bridge.py b/libs/core/tests/unit_tests/language_models/test_compat_bridge.py index d0a9b748c49..130e4d69c15 100644 --- a/libs/core/tests/unit_tests/language_models/test_compat_bridge.py +++ b/libs/core/tests/unit_tests/language_models/test_compat_bridge.py @@ -28,6 +28,11 @@ if TYPE_CHECKING: ) +def _event_metadata(event: Any) -> dict[str, Any]: + """Return event metadata for protocol versions that type it as extensible.""" + return cast("dict[str, Any]", cast("dict[str, Any]", event).get("metadata") or {}) + + # --------------------------------------------------------------------------- # Pure helpers # --------------------------------------------------------------------------- @@ -134,7 +139,7 @@ def test_chunks_to_events_text_only() -> None: # No provider finish_reason in fixtures — metadata carries no # `finish_reason` key (the bridge passes response_metadata through # unchanged). - assert "finish_reason" not in (finish.get("metadata") or {}) + assert "finish_reason" not in _event_metadata(finish) def test_chunks_to_events_empty_iterator() -> None: @@ -291,9 +296,7 @@ def test_chunks_to_events_tool_call_multichunk() -> None: # not synthesize one. It deliberately does not infer `"tool_use"` # from the presence of a valid tool_call either; terminal reasons # are provider-specific (see `_build_message_finish`). - assert "finish_reason" not in ( - cast("MessageFinishData", events[-1]).get("metadata") or {} - ) + assert "finish_reason" not in _event_metadata(events[-1]) def test_chunks_to_events_invalid_tool_call_keeps_stop_reason() -> None: @@ -323,9 +326,7 @@ def test_chunks_to_events_invalid_tool_call_keeps_stop_reason() -> None: ] assert len(finish_events) == 1 assert finish_events[0]["content"]["type"] == "invalid_tool_call" - assert "finish_reason" not in ( - cast("MessageFinishData", events[-1]).get("metadata") or {} - ) + assert "finish_reason" not in _event_metadata(events[-1]) def test_chunks_to_events_anthropic_server_tool_use_routes_through_translator() -> None: @@ -453,7 +454,7 @@ def test_message_to_events_text_only() -> None: assert delta_event["delta"] == {"type": "text-delta", "text": "Hello world"} final = cast("MessageFinishData", events[-1]) - assert "finish_reason" not in (final.get("metadata") or {}) + assert "finish_reason" not in _event_metadata(final) def test_message_to_events_empty_content_yields_start_finish_only() -> None: @@ -512,7 +513,7 @@ def test_message_to_events_tool_call_skips_delta() -> None: # bridge does not synthesize one and does not second-guess based on # the presence of a tool_call. final = cast("MessageFinishData", events[-1]) - assert "finish_reason" not in (final.get("metadata") or {}) + assert "finish_reason" not in _event_metadata(final) def test_message_to_events_invalid_tool_calls_surfaced_from_field() -> None: @@ -557,7 +558,7 @@ def test_message_to_events_preserves_finish_reason_and_metadata() -> None: # Passthrough: response_metadata lands on `metadata` unchanged, # including the raw provider `finish_reason`. final = cast("MessageFinishData", events[-1]) - assert final["metadata"] == { + assert _event_metadata(final) == { "finish_reason": "length", "model_name": "test-model", "stop_sequence": "", diff --git a/libs/core/tests/unit_tests/language_models/test_stream_v2.py b/libs/core/tests/unit_tests/language_models/test_stream_v2.py index e6860ee859c..9e847019d09 100644 --- a/libs/core/tests/unit_tests/language_models/test_stream_v2.py +++ b/libs/core/tests/unit_tests/language_models/test_stream_v2.py @@ -389,14 +389,14 @@ class TestPerBlockAccumulation: ContentBlockDeltaData( event="content-block-delta", index=0, - content_block=TextContentBlock(type="text", text="A"), + delta={"type": "text-delta", "text": "A"}, ) ) stream.dispatch( ContentBlockFinishData( event="content-block-finish", index=0, - content_block=TextContentBlock(type="text", text="A"), + content=TextContentBlock(type="text", text="A"), ) ) # Block 1: "B" @@ -404,14 +404,14 @@ class TestPerBlockAccumulation: ContentBlockDeltaData( event="content-block-delta", index=1, - content_block=TextContentBlock(type="text", text="B"), + delta={"type": "text-delta", "text": "B"}, ) ) stream.dispatch( ContentBlockFinishData( event="content-block-finish", index=1, - content_block=TextContentBlock(type="text", text="B"), + content=TextContentBlock(type="text", text="B"), ) ) stream.dispatch(MessageFinishData(event="message-finish")) @@ -435,14 +435,14 @@ class TestPerBlockAccumulation: ContentBlockDeltaData( event="content-block-delta", index=0, - content_block=ReasoningContentBlock(type="reasoning", reasoning="one"), + delta={"type": "reasoning-delta", "reasoning": "one"}, ) ) stream.dispatch( ContentBlockFinishData( event="content-block-finish", index=0, - content_block=ReasoningContentBlock(type="reasoning", reasoning="one"), + content=ReasoningContentBlock(type="reasoning", reasoning="one"), ) ) # Block 1: "two" @@ -450,14 +450,14 @@ class TestPerBlockAccumulation: ContentBlockDeltaData( event="content-block-delta", index=1, - content_block=ReasoningContentBlock(type="reasoning", reasoning="two"), + delta={"type": "reasoning-delta", "reasoning": "two"}, ) ) stream.dispatch( ContentBlockFinishData( event="content-block-finish", index=1, - content_block=ReasoningContentBlock(type="reasoning", reasoning="two"), + content=ReasoningContentBlock(type="reasoning", reasoning="two"), ) ) stream.dispatch(MessageFinishData(event="message-finish")) @@ -482,14 +482,14 @@ class TestPerBlockAccumulation: ContentBlockDeltaData( event="content-block-delta", index=0, - content_block=TextContentBlock(type="text", text="hel"), + delta={"type": "text-delta", "text": "hel"}, ) ) stream.dispatch( ContentBlockFinishData( event="content-block-finish", index=0, - content_block=TextContentBlock(type="text", text="hello"), + content=TextContentBlock(type="text", text="hello"), ) ) stream.dispatch(MessageFinishData(event="message-finish")) @@ -520,7 +520,7 @@ class TestPerBlockAccumulation: ContentBlockDeltaData( event="content-block-delta", index=0, - content_block=TextContentBlock(type="text", text="aaa"), + delta={"type": "text-delta", "text": "aaa"}, ) ) # Block 1 streams deltas before block 0 finishes. @@ -528,7 +528,7 @@ class TestPerBlockAccumulation: ContentBlockDeltaData( event="content-block-delta", index=1, - content_block=TextContentBlock(type="text", text="bb"), + delta={"type": "text-delta", "text": "bb"}, ) ) # Block 0 finishes with authoritative text different from deltas. @@ -536,14 +536,14 @@ class TestPerBlockAccumulation: ContentBlockFinishData( event="content-block-finish", index=0, - content_block=TextContentBlock(type="text", text="XXX"), + content=TextContentBlock(type="text", text="XXX"), ) ) stream.dispatch( ContentBlockFinishData( event="content-block-finish", index=1, - content_block=TextContentBlock(type="text", text="bb"), + content=TextContentBlock(type="text", text="bb"), ) ) stream.dispatch(MessageFinishData(event="message-finish")) @@ -566,7 +566,7 @@ class TestPerBlockAccumulation: ContentBlockDeltaData( event="content-block-delta", index=0, - content_block=ReasoningContentBlock(type="reasoning", reasoning="thi"), + delta={"type": "reasoning-delta", "reasoning": "thi"}, ) ) stream.dispatch( @@ -594,14 +594,14 @@ class TestPerBlockAccumulation: ContentBlockDeltaData( event="content-block-delta", index=0, - content_block=TextContentBlock(type="text", text="before"), + delta={"type": "text-delta", "text": "before"}, ) ) stream.dispatch( ContentBlockFinishData( event="content-block-finish", index=0, - content_block=TextContentBlock(type="text", text="before"), + content=TextContentBlock(type="text", text="before"), ) ) # Block 1: tool_call @@ -609,7 +609,7 @@ class TestPerBlockAccumulation: ContentBlockFinishData( event="content-block-finish", index=1, - content_block=ToolCall( + content=ToolCall( type="tool_call", id="tc1", name="search", @@ -622,14 +622,14 @@ class TestPerBlockAccumulation: ContentBlockDeltaData( event="content-block-delta", index=2, - content_block=TextContentBlock(type="text", text="after"), + delta={"type": "text-delta", "text": "after"}, ) ) stream.dispatch( ContentBlockFinishData( event="content-block-finish", index=2, - content_block=TextContentBlock(type="text", text="after"), + content=TextContentBlock(type="text", text="after"), ) ) stream.dispatch(MessageFinishData(event="message-finish"))