diff --git a/libs/partners/openrouter/langchain_openrouter/chat_models.py b/libs/partners/openrouter/langchain_openrouter/chat_models.py index 7a510912766..6ff873f5ce8 100644 --- a/libs/partners/openrouter/langchain_openrouter/chat_models.py +++ b/libs/partners/openrouter/langchain_openrouter/chat_models.py @@ -2,6 +2,7 @@ from __future__ import annotations +import contextlib import json import warnings from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence @@ -48,6 +49,7 @@ from langchain_core.messages.ai import ( from langchain_core.messages.block_translators.openai import ( convert_to_openai_data_block, ) +from langchain_core.messages.tool import tool_call_chunk from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_tools import ( @@ -135,7 +137,10 @@ class ChatOpenRouter(BaseChatModel): """ client: Any = Field(default=None, exclude=True) - """OpenRouter client instance (`openrouter.OpenRouter`).""" + """Underlying SDK client (`openrouter.OpenRouter`). + + Created automatically during validation. + """ model_name: str = Field(alias="model") """The name of the model, e.g. `'anthropic/claude-sonnet-4-5'`.""" @@ -539,6 +544,14 @@ class ChatOpenRouter(BaseChatModel): if not isinstance(response, dict): response = response.model_dump(by_alias=True) + if error := response.get("error"): + msg = ( + f"OpenRouter API returned an error: " + f"{error.get('message', str(error))} " + f"(code: {error.get('code', 'unknown')})" + ) + raise ValueError(msg) + generations = [] token_usage = response.get("usage") or {} @@ -573,20 +586,21 @@ class ChatOpenRouter(BaseChatModel): if token_usage is not None: for k, v in token_usage.items(): if v is None: - overall_token_usage.setdefault(k, v) - elif k not in overall_token_usage: - overall_token_usage[k] = v - elif isinstance(v, dict): - for nested_k, nested_v in v.items(): - if ( - nested_k in overall_token_usage[k] - and nested_v is not None - ): - overall_token_usage[k][nested_k] += nested_v - else: - overall_token_usage[k][nested_k] = nested_v + continue + if k in overall_token_usage: + if isinstance(v, dict): + for nested_k, nested_v in v.items(): + if ( + nested_k in overall_token_usage[k] + and nested_v is not None + ): + overall_token_usage[k][nested_k] += nested_v + else: + overall_token_usage[k][nested_k] = nested_v + else: + overall_token_usage[k] += v else: - overall_token_usage[k] += v + overall_token_usage[k] = v return {"token_usage": overall_token_usage, "model_name": self.model_name} def bind_tools( @@ -889,7 +903,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: # noqa: tool_call_id=_dict.get("tool_call_id"), additional_kwargs=additional_kwargs, ) - return ChatMessage(content=_dict.get("content", ""), role=role) # type: ignore[arg-type] + if role is None: + msg = ( + f"OpenRouter response message is missing the 'role' field. " + f"Message keys: {list(_dict.keys())}" + ) + raise ValueError(msg) + return ChatMessage(content=_dict.get("content", ""), role=role) def _convert_chunk_to_message_chunk( @@ -909,9 +929,19 @@ def _convert_chunk_to_message_chunk( role = cast("str", _dict.get("role")) content = cast("str", _dict.get("content") or "") additional_kwargs: dict = {} + tool_call_chunks: list = [] - if _dict.get("tool_calls"): - additional_kwargs["tool_calls"] = _dict["tool_calls"] + if raw_tool_calls := _dict.get("tool_calls"): + with contextlib.suppress(KeyError): + tool_call_chunks = [ + tool_call_chunk( + name=rtc["function"].get("name"), + args=rtc["function"].get("arguments"), + id=rtc.get("id"), + index=rtc["index"], + ) + for rtc in raw_tool_calls + ] if role == "user" or default_class == HumanMessageChunk: return HumanMessageChunk(content=content) @@ -920,13 +950,13 @@ def _convert_chunk_to_message_chunk( additional_kwargs["reasoning_content"] = reasoning if reasoning_details := _dict.get("reasoning_details"): additional_kwargs["reasoning_details"] = reasoning_details - # Extract usage from chunk if present usage_metadata = None if usage := chunk.get("usage"): usage_metadata = _create_usage_metadata(usage) return AIMessageChunk( content=content, additional_kwargs=additional_kwargs, + tool_call_chunks=tool_call_chunks, # type: ignore[arg-type] usage_metadata=usage_metadata, # type: ignore[arg-type] response_metadata={"model_provider": "openrouter"}, ) diff --git a/libs/partners/openrouter/tests/unit_tests/test_chat_models.py b/libs/partners/openrouter/tests/unit_tests/test_chat_models.py index 4849e45ffbf..37022ca0769 100644 --- a/libs/partners/openrouter/tests/unit_tests/test_chat_models.py +++ b/libs/partners/openrouter/tests/unit_tests/test_chat_models.py @@ -1042,7 +1042,11 @@ class TestStreamingChunks: } message_chunk = _convert_chunk_to_message_chunk(chunk, AIMessageChunk) assert isinstance(message_chunk, AIMessageChunk) - assert "tool_calls" in message_chunk.additional_kwargs + assert len(message_chunk.tool_call_chunks) == 1 + assert message_chunk.tool_call_chunks[0]["name"] == "get_weather" + assert message_chunk.tool_call_chunks[0]["args"] == '{"loc' + assert message_chunk.tool_call_chunks[0]["id"] == "call_1" + assert message_chunk.tool_call_chunks[0]["index"] == 0 def test_chunk_with_user_role(self) -> None: """Test that a chunk with role=user produces HumanMessageChunk."""