This commit is contained in:
Mason Daugherty
2026-02-13 15:03:21 -05:00
parent 2d22c2c4ef
commit d4f739c5fd
2 changed files with 53 additions and 19 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import contextlib
import json
import warnings
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
@@ -48,6 +49,7 @@ from langchain_core.messages.ai import (
from langchain_core.messages.block_translators.openai import (
convert_to_openai_data_block,
)
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
@@ -135,7 +137,10 @@ class ChatOpenRouter(BaseChatModel):
"""
client: Any = Field(default=None, exclude=True)
"""OpenRouter client instance (`openrouter.OpenRouter`)."""
"""Underlying SDK client (`openrouter.OpenRouter`).
Created automatically during validation.
"""
model_name: str = Field(alias="model")
"""The name of the model, e.g. `'anthropic/claude-sonnet-4-5'`."""
@@ -539,6 +544,14 @@ class ChatOpenRouter(BaseChatModel):
if not isinstance(response, dict):
response = response.model_dump(by_alias=True)
if error := response.get("error"):
msg = (
f"OpenRouter API returned an error: "
f"{error.get('message', str(error))} "
f"(code: {error.get('code', 'unknown')})"
)
raise ValueError(msg)
generations = []
token_usage = response.get("usage") or {}
@@ -573,20 +586,21 @@ class ChatOpenRouter(BaseChatModel):
if token_usage is not None:
for k, v in token_usage.items():
if v is None:
overall_token_usage.setdefault(k, v)
elif k not in overall_token_usage:
overall_token_usage[k] = v
elif isinstance(v, dict):
for nested_k, nested_v in v.items():
if (
nested_k in overall_token_usage[k]
and nested_v is not None
):
overall_token_usage[k][nested_k] += nested_v
else:
overall_token_usage[k][nested_k] = nested_v
continue
if k in overall_token_usage:
if isinstance(v, dict):
for nested_k, nested_v in v.items():
if (
nested_k in overall_token_usage[k]
and nested_v is not None
):
overall_token_usage[k][nested_k] += nested_v
else:
overall_token_usage[k][nested_k] = nested_v
else:
overall_token_usage[k] += v
else:
overall_token_usage[k] += v
overall_token_usage[k] = v
return {"token_usage": overall_token_usage, "model_name": self.model_name}
def bind_tools(
@@ -889,7 +903,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: # noqa:
tool_call_id=_dict.get("tool_call_id"),
additional_kwargs=additional_kwargs,
)
return ChatMessage(content=_dict.get("content", ""), role=role) # type: ignore[arg-type]
if role is None:
msg = (
f"OpenRouter response message is missing the 'role' field. "
f"Message keys: {list(_dict.keys())}"
)
raise ValueError(msg)
return ChatMessage(content=_dict.get("content", ""), role=role)
def _convert_chunk_to_message_chunk(
@@ -909,9 +929,19 @@ def _convert_chunk_to_message_chunk(
role = cast("str", _dict.get("role"))
content = cast("str", _dict.get("content") or "")
additional_kwargs: dict = {}
tool_call_chunks: list = []
if _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = _dict["tool_calls"]
if raw_tool_calls := _dict.get("tool_calls"):
with contextlib.suppress(KeyError):
tool_call_chunks = [
tool_call_chunk(
name=rtc["function"].get("name"),
args=rtc["function"].get("arguments"),
id=rtc.get("id"),
index=rtc["index"],
)
for rtc in raw_tool_calls
]
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
@@ -920,13 +950,13 @@ def _convert_chunk_to_message_chunk(
additional_kwargs["reasoning_content"] = reasoning
if reasoning_details := _dict.get("reasoning_details"):
additional_kwargs["reasoning_details"] = reasoning_details
# Extract usage from chunk if present
usage_metadata = None
if usage := chunk.get("usage"):
usage_metadata = _create_usage_metadata(usage)
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
usage_metadata=usage_metadata, # type: ignore[arg-type]
response_metadata={"model_provider": "openrouter"},
)

View File

@@ -1042,7 +1042,11 @@ class TestStreamingChunks:
}
message_chunk = _convert_chunk_to_message_chunk(chunk, AIMessageChunk)
assert isinstance(message_chunk, AIMessageChunk)
assert "tool_calls" in message_chunk.additional_kwargs
assert len(message_chunk.tool_call_chunks) == 1
assert message_chunk.tool_call_chunks[0]["name"] == "get_weather"
assert message_chunk.tool_call_chunks[0]["args"] == '{"loc'
assert message_chunk.tool_call_chunks[0]["id"] == "call_1"
assert message_chunk.tool_call_chunks[0]["index"] == 0
def test_chunk_with_user_role(self) -> None:
"""Test that a chunk with role=user produces HumanMessageChunk."""