diff --git a/libs/partners/openrouter/langchain_openrouter/chat_models.py b/libs/partners/openrouter/langchain_openrouter/chat_models.py index 2eca53b1a0f..3fe32a2812d 100644 --- a/libs/partners/openrouter/langchain_openrouter/chat_models.py +++ b/libs/partners/openrouter/langchain_openrouter/chat_models.py @@ -83,6 +83,24 @@ def _get_default_model_profile(model_name: str) -> ModelProfile: return default.copy() +def _create_stream_generation_info( + chunk_dict: dict[str, Any], choice: dict[str, Any], model_name: str +) -> dict[str, Any]: + generation_info = {"finish_reason": choice["finish_reason"]} + generation_info["model_name"] = chunk_dict.get("model") or model_name + if system_fingerprint := chunk_dict.get("system_fingerprint"): + generation_info["system_fingerprint"] = system_fingerprint + if native_finish_reason := choice.get("native_finish_reason"): + generation_info["native_finish_reason"] = native_finish_reason + if response_id := chunk_dict.get("id"): + generation_info["id"] = response_id + if created := chunk_dict.get("created"): + generation_info["created"] = int(created) + if object_ := chunk_dict.get("object"): + generation_info["object"] = object_ + return generation_info + + class ChatOpenRouter(BaseChatModel): """OpenRouter chat model integration. @@ -534,7 +552,7 @@ class ChatOpenRouter(BaseChatModel): response = await self.client.chat.send_async(messages=message_dicts, **params) return self._create_chat_result(response) - def _stream( # noqa: C901, PLR0912 + def _stream( # noqa: C901 self, messages: list[BaseMessage], stop: list[str] | None = None, @@ -548,6 +566,7 @@ class ChatOpenRouter(BaseChatModel): _strip_internal_kwargs(params) default_chunk_class: type[BaseMessageChunk] = AIMessageChunk + terminal_generation_info: dict[str, Any] = {} for chunk in self.client.chat.send(messages=message_dicts, **params): chunk_dict = chunk.model_dump(by_alias=True) if not chunk_dict.get("choices"): @@ -576,21 +595,18 @@ class ChatOpenRouter(BaseChatModel): chunk_dict, default_chunk_class ) generation_info: dict[str, Any] = {} - if finish_reason := choice.get("finish_reason"): - generation_info["finish_reason"] = finish_reason - # Include response-level metadata on the final chunk - response_model = chunk_dict.get("model") - generation_info["model_name"] = response_model or self.model_name - if system_fingerprint := chunk_dict.get("system_fingerprint"): - generation_info["system_fingerprint"] = system_fingerprint - if native_finish_reason := choice.get("native_finish_reason"): - generation_info["native_finish_reason"] = native_finish_reason - if response_id := chunk_dict.get("id"): - generation_info["id"] = response_id - if created := chunk_dict.get("created"): - generation_info["created"] = int(created) - if object_ := chunk_dict.get("object"): - generation_info["object"] = object_ + if choice.get("finish_reason"): + candidate_generation_info = _create_stream_generation_info( + chunk_dict, choice, self.model_name + ) + generation_info.update( + { + key: value + for key, value in candidate_generation_info.items() + if key not in terminal_generation_info + } + ) + terminal_generation_info.update(generation_info) logprobs = choice.get("logprobs") if logprobs: generation_info["logprobs"] = logprobs @@ -619,7 +635,7 @@ class ChatOpenRouter(BaseChatModel): ) yield generation_chunk - async def _astream( # noqa: C901, PLR0912 + async def _astream( # noqa: C901 self, messages: list[BaseMessage], stop: list[str] | None = None, @@ -633,6 +649,7 @@ class ChatOpenRouter(BaseChatModel): _strip_internal_kwargs(params) default_chunk_class: type[BaseMessageChunk] = AIMessageChunk + terminal_generation_info: dict[str, Any] = {} async for chunk in await self.client.chat.send_async( messages=message_dicts, **params ): @@ -663,21 +680,18 @@ class ChatOpenRouter(BaseChatModel): chunk_dict, default_chunk_class ) generation_info: dict[str, Any] = {} - if finish_reason := choice.get("finish_reason"): - generation_info["finish_reason"] = finish_reason - # Include response-level metadata on the final chunk - response_model = chunk_dict.get("model") - generation_info["model_name"] = response_model or self.model_name - if system_fingerprint := chunk_dict.get("system_fingerprint"): - generation_info["system_fingerprint"] = system_fingerprint - if native_finish_reason := choice.get("native_finish_reason"): - generation_info["native_finish_reason"] = native_finish_reason - if response_id := chunk_dict.get("id"): - generation_info["id"] = response_id - if created := chunk_dict.get("created"): - generation_info["created"] = int(created) # UNIX timestamp - if object_ := chunk_dict.get("object"): - generation_info["object"] = object_ + if choice.get("finish_reason"): + candidate_generation_info = _create_stream_generation_info( + chunk_dict, choice, self.model_name + ) + generation_info.update( + { + key: value + for key, value in candidate_generation_info.items() + if key not in terminal_generation_info + } + ) + terminal_generation_info.update(generation_info) logprobs = choice.get("logprobs") if logprobs: generation_info["logprobs"] = logprobs diff --git a/libs/partners/openrouter/tests/unit_tests/test_chat_models.py b/libs/partners/openrouter/tests/unit_tests/test_chat_models.py index 037dcf26416..218a7b2e5ed 100644 --- a/libs/partners/openrouter/tests/unit_tests/test_chat_models.py +++ b/libs/partners/openrouter/tests/unit_tests/test_chat_models.py @@ -144,6 +144,39 @@ _STREAM_CHUNKS: list[dict[str, Any]] = [ }, ] +_DUPLICATE_FINISH_STREAM_CHUNKS: list[dict[str, Any]] = [ + { + "choices": [{"delta": {"role": "assistant", "content": "Hello"}, "index": 0}], + "model": MODEL_NAME, + "object": "chat.completion.chunk", + "created": 1700000000.0, + "id": "gen-stream1", + }, + { + "choices": [{"delta": {}, "finish_reason": "stop", "index": 0}], + "model": MODEL_NAME, + "object": "chat.completion.chunk", + "created": 1700000000.0, + "id": "gen-stream1", + }, + { + "choices": [ + { + "delta": {}, + "finish_reason": "stop", + "native_finish_reason": "end_turn", + "index": 0, + } + ], + "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}, + "model": MODEL_NAME, + "object": "chat.completion.chunk", + "created": 1700000000.0, + "id": "gen-stream1", + "system_fingerprint": "fp_duplicate", + }, +] + def _make_sdk_response(response_dict: dict[str, Any]) -> MagicMock: """Build a MagicMock that behaves like an SDK ChatResponse.""" @@ -152,11 +185,34 @@ def _make_sdk_response(response_dict: dict[str, Any]) -> MagicMock: return mock +def _assert_duplicate_finish_result(result: Any) -> None: + generation = result.generations[0][0] + assert generation.text == "Hello" + assert generation.generation_info == { + "finish_reason": "stop", + "model_name": MODEL_NAME, + "id": "gen-stream1", + "created": 1700000000, + "object": "chat.completion.chunk", + "model_provider": "openrouter", + "system_fingerprint": "fp_duplicate", + "native_finish_reason": "end_turn", + } + assert generation.message.response_metadata == generation.generation_info + assert generation.message.usage_metadata == { + "input_tokens": 5, + "output_tokens": 2, + "total_tokens": 7, + } + + class _MockSyncStream: """Synchronous iterator that mimics the SDK EventStream.""" def __init__(self, chunks: list[dict[str, Any]]) -> None: - self._chunks = chunks + # Copy so `__next__`'s `pop(0)` never drains a caller-supplied list + # (e.g. a shared module-level fixture), mirroring `_MockAsyncStream`. + self._chunks = list(chunks) def __iter__(self) -> _MockSyncStream: return self @@ -3336,6 +3392,83 @@ class TestStreamUsage: assert usage["output_tokens"] == 5 assert usage["total_tokens"] == 15 + @pytest.mark.parametrize("stream_usage", [True, False]) + def test_generate_duplicate_finish_chunks_deduplicates_generation_info( + self, stream_usage: Literal[True, False] + ) -> None: + """Test duplicate finish chunks do not concatenate metadata.""" + model = _make_model(streaming=True, stream_usage=stream_usage) + model.client = MagicMock() + model.client.chat.send.return_value = _MockSyncStream( + _DUPLICATE_FINISH_STREAM_CHUNKS + ) + + result = model.generate([[HumanMessage(content="Hello")]]) + + _assert_duplicate_finish_result(result) + + @pytest.mark.parametrize("stream_usage", [True, False]) + async def test_agenerate_duplicate_finish_chunks_deduplicates_generation_info( + self, stream_usage: Literal[True, False] + ) -> None: + """Test async duplicate finish chunks do not concatenate metadata.""" + model = _make_model(streaming=True, stream_usage=stream_usage) + model.client = MagicMock() + model.client.chat.send_async = AsyncMock( + return_value=_MockAsyncStream(_DUPLICATE_FINISH_STREAM_CHUNKS) + ) + + result = await model.agenerate([[HumanMessage(content="Hello")]]) + + _assert_duplicate_finish_result(result) + + def test_stream_differing_finish_reasons_keeps_first(self) -> None: + """First finish reason wins across repeated terminal chunks. + + OpenRouter can emit several chunks bearing a `finish_reason`. Later + terminal chunks should fill only missing metadata fields, so a differing + later reason must not surface in any emitted chunk. + """ + finish_chunk: dict[str, Any] = { + "choices": [{"delta": {}, "finish_reason": "stop", "index": 0}], + "model": MODEL_NAME, + "object": "chat.completion.chunk", + "created": 1700000000.0, + "id": "gen-stream1", + } + chunks: list[dict[str, Any]] = [ + { + "choices": [ + {"delta": {"role": "assistant", "content": "Hi"}, "index": 0} + ], + "model": MODEL_NAME, + "object": "chat.completion.chunk", + "created": 1700000000.0, + "id": "gen-stream1", + }, + finish_chunk, + { + **finish_chunk, + "choices": [{"delta": {}, "finish_reason": "length", "index": 0}], + }, + { + **finish_chunk, + "choices": [{"delta": {}, "finish_reason": "stop", "index": 0}], + }, + ] + model = _make_model() + model.client = MagicMock() + model.client.chat.send.return_value = _MockSyncStream(chunks) + + emitted = list(model.stream("Hi")) + + finish_reasons = [ + c.response_metadata["finish_reason"] + for c in emitted + if "finish_reason" in c.response_metadata + ] + assert finish_reasons == ["stop"] + async def test_astream_options_passed_by_default(self) -> None: """Test that async stream sends stream_options by default.""" model = _make_model()