diff --git a/libs/partners/openrouter/langchain_openrouter/chat_models.py b/libs/partners/openrouter/langchain_openrouter/chat_models.py new file mode 100644 index 00000000000..7a510912766 --- /dev/null +++ b/libs/partners/openrouter/langchain_openrouter/chat_models.py @@ -0,0 +1,1012 @@ +"""OpenRouter chat models.""" + +from __future__ import annotations + +import json +import warnings +from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence +from operator import itemgetter +from typing import Any, Literal, cast + +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models import ( + LanguageModelInput, + ModelProfile, + ModelProfileRegistry, +) +from langchain_core.language_models.chat_models import ( + BaseChatModel, + LangSmithParams, + agenerate_from_stream, + generate_from_stream, +) +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + ChatMessage, + ChatMessageChunk, + HumanMessage, + HumanMessageChunk, + InvalidToolCall, + SystemMessage, + SystemMessageChunk, + ToolCall, + ToolMessage, + ToolMessageChunk, + is_data_content_block, +) +from langchain_core.messages.ai import ( + InputTokenDetails, + OutputTokenDetails, + UsageMetadata, +) +from langchain_core.messages.block_translators.openai import ( + convert_to_openai_data_block, +) +from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser +from langchain_core.output_parsers.base import OutputParserLike +from langchain_core.output_parsers.openai_tools import ( + JsonOutputKeyToolsParser, + PydanticToolsParser, + make_invalid_tool_call, + parse_tool_call, +) +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough +from langchain_core.tools import BaseTool +from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env +from langchain_core.utils.function_calling import ( + convert_to_json_schema, + convert_to_openai_tool, +) +from langchain_core.utils.pydantic import is_basemodel_subclass +from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator +from typing_extensions import Self + +from langchain_openrouter.data._profiles import _PROFILES + +_MODEL_PROFILES = cast("ModelProfileRegistry", _PROFILES) + +# LangChain-internal kwargs that must not be forwarded to the SDK. +_INTERNAL_KWARGS = frozenset({"ls_structured_output_format"}) + + +def _get_default_model_profile(model_name: str) -> ModelProfile: + default = _MODEL_PROFILES.get(model_name) or {} + return default.copy() + + +class ChatOpenRouter(BaseChatModel): + """OpenRouter chat model integration. + + OpenRouter is a unified API that provides access to models from + multiple providers (OpenAI, Anthropic, Google, Meta, etc.). + + ???+ info "Setup" + + Install `langchain-openrouter` and set environment variable + `OPENROUTER_API_KEY`. + + ```bash + pip install -U langchain-openrouter + ``` + + ```bash + export OPENROUTER_API_KEY="your-api-key" + ``` + + ??? info "Key init args — completion params" + + | Param | Type | Description | + | ----- | ---- | ----------- | + | `model` | `str` | Model name, e.g. `'openai/gpt-4o-mini'`. | + | `temperature` | `float | None` | Sampling temperature. | + | `max_tokens` | `int | None` | Max tokens to generate. | + + ??? 
info "Key init args — client params" + + | Param | Type | Description | + | ----- | ---- | ----------- | + | `api_key` | `str | None` | OpenRouter API key. | + | `base_url` | `str | None` | Base URL for API requests. | + | `timeout` | `int | None` | Timeout in milliseconds. | + | `app_url` | `str | None` | App URL for attribution. | + | `app_title` | `str | None` | App title for attribution. | + + ??? info "Instantiate" + + ```python + from langchain_openrouter import ChatOpenRouter + + model = ChatOpenRouter( + model="anthropic/claude-sonnet-4-5", + temperature=0, + # api_key="...", + # openrouter_provider={"order": ["Anthropic"]}, + ) + ``` + + See https://openrouter.ai/docs for platform documentation. + """ + + client: Any = Field(default=None, exclude=True) + """OpenRouter client instance (`openrouter.OpenRouter`).""" + + model_name: str = Field(alias="model") + """The name of the model, e.g. `'anthropic/claude-sonnet-4-5'`.""" + + temperature: float | None = None + """Sampling temperature.""" + + max_tokens: int | None = None + """Maximum number of tokens to generate.""" + + max_completion_tokens: int | None = None + """Maximum number of completion tokens to generate.""" + + top_p: float | None = None + """Nucleus sampling parameter.""" + + frequency_penalty: float | None = None + """Frequency penalty for generation.""" + + presence_penalty: float | None = None + """Presence penalty for generation.""" + + seed: int | None = None + """Random seed for reproducibility.""" + + stop: list[str] | str | None = Field(default=None, alias="stop_sequences") + """Default stop sequences.""" + + n: int = 1 + """Number of chat completions to generate for each prompt.""" + + streaming: bool = False + """Whether to stream the results or not.""" + + model_kwargs: dict[str, Any] = Field(default_factory=dict) + """Any extra model parameters for the OpenRouter API.""" + + openrouter_api_key: SecretStr | None = Field( + alias="api_key", + default_factory=secret_from_env("OPENROUTER_API_KEY", default=None), + ) + """OpenRouter API key.""" + + openrouter_api_base: str | None = Field( + default_factory=from_env("OPENROUTER_API_BASE", default=None), + alias="base_url", + ) + """OpenRouter API base URL. Maps to SDK `server_url`.""" + + app_url: str | None = Field( + default_factory=from_env("OPENROUTER_APP_URL", default=None), + ) + """Application URL for OpenRouter attribution. Maps to `HTTP-Referer` header.""" + + app_title: str | None = Field( + default_factory=from_env("OPENROUTER_APP_TITLE", default=None), + ) + """Application title for OpenRouter attribution. Maps to `X-Title` header.""" + + request_timeout: int | None = Field(default=None, alias="timeout") + """Timeout for requests in milliseconds. Maps to SDK `timeout_ms`.""" + + max_retries: int = 2 + """Maximum number of retries. + + Controls the retry backoff window via the SDK's `max_elapsed_time`. + + Set to `0` to disable retries. + """ + + openrouter_reasoning: dict[str, Any] | None = None + """Reasoning settings to pass to OpenRouter. + + Example: `{"effort": "high"}` + """ + + openrouter_provider: dict[str, Any] | None = None + """Provider preferences to pass to OpenRouter. + + Example: `{"order": ["Anthropic", "OpenAI"]}` + """ + + openrouter_route: str | None = None + """Route preference for OpenRouter. E.g. 
`'fallback'`.""" + + openrouter_plugins: list[dict[str, Any]] | None = None + """Plugins configuration for OpenRouter.""" + + model_config = ConfigDict(populate_by_name=True) + + @model_validator(mode="before") + @classmethod + def build_extra(cls, values: dict[str, Any]) -> Any: + """Build extra kwargs from additional params that were passed in.""" + all_required_field_names = get_pydantic_field_names(cls) + extra = values.get("model_kwargs", {}) + for field_name in list(values): + if field_name in extra: + msg = f"Found {field_name} supplied twice." + raise ValueError(msg) + if field_name not in all_required_field_names: + warnings.warn( + f"""WARNING! {field_name} is not default parameter. + {field_name} was transferred to model_kwargs. + Please confirm that {field_name} is what you intended.""", + stacklevel=2, + ) + extra[field_name] = values.pop(field_name) + + invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) + if invalid_model_kwargs: + msg = ( + f"Parameters {invalid_model_kwargs} should be specified explicitly. " + f"Instead they were passed in as part of `model_kwargs` parameter." + ) + raise ValueError(msg) + + values["model_kwargs"] = extra + return values + + @model_validator(mode="after") + def validate_environment(self) -> Self: # noqa: C901 + """Validate configuration and build the SDK client.""" + if not (self.openrouter_api_key and self.openrouter_api_key.get_secret_value()): + msg = "OPENROUTER_API_KEY must be set." + raise ValueError(msg) + if self.n < 1: + msg = "n must be at least 1." + raise ValueError(msg) + if self.n > 1 and self.streaming: + msg = "n must be 1 when streaming." + raise ValueError(msg) + + if not self.client: + try: + import openrouter # noqa: PLC0415 + from openrouter.utils import ( # noqa: PLC0415 + BackoffStrategy, + RetryConfig, + ) + except ImportError as e: + msg = ( + "Could not import the `openrouter` Python SDK. 
" + "Please install it with: pip install openrouter" + ) + raise ImportError(msg) from e + + client_kwargs: dict[str, Any] = { + "api_key": self.openrouter_api_key.get_secret_value(), + } + if self.openrouter_api_base: + client_kwargs["server_url"] = self.openrouter_api_base + if self.app_url: + client_kwargs["http_referer"] = self.app_url + if self.app_title: + client_kwargs["x_title"] = self.app_title + if self.request_timeout is not None: + client_kwargs["timeout_ms"] = self.request_timeout + if self.max_retries > 0: + client_kwargs["retry_config"] = RetryConfig( + strategy="backoff", + backoff=BackoffStrategy( + initial_interval=500, + max_interval=60000, + exponent=1.5, + max_elapsed_time=self.max_retries * 150_000, + ), + retry_connection_errors=True, + ) + self.client = openrouter.OpenRouter(**client_kwargs) + return self + + @model_validator(mode="after") + def _set_model_profile(self) -> Self: + """Set model profile if not overridden.""" + if self.profile is None: + self.profile = _get_default_model_profile(self.model_name) + return self + + # + # Serializable class method overrides + # + @property + def lc_secrets(self) -> dict[str, str]: + """A map of constructor argument names to secret ids.""" + return {"openrouter_api_key": "OPENROUTER_API_KEY"} + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this model can be serialized by LangChain.""" + return True + + # + # BaseChatModel method overrides + # + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "openrouter-chat" + + def _get_ls_params( + self, + stop: list[str] | None = None, + **kwargs: Any, + ) -> LangSmithParams: + """Get standard params for tracing.""" + params = self._get_invocation_params(stop=stop, **kwargs) + ls_params = LangSmithParams( + ls_provider="openrouter", + ls_model_name=params.get("model", self.model_name), + ls_model_type="chat", + ls_temperature=params.get("temperature", self.temperature), + ) + if ls_max_tokens := params.get("max_tokens", self.max_tokens): + ls_params["ls_max_tokens"] = ls_max_tokens + if ls_stop := stop or params.get("stop", None) or self.stop: + ls_params["ls_stop"] = ls_stop if isinstance(ls_stop, list) else [ls_stop] + return ls_params + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + if self.streaming: + stream_iter = self._stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return generate_from_stream(stream_iter) + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs} + _strip_internal_kwargs(params) + response = self.client.chat.send(messages=message_dicts, **params) + return self._create_chat_result(response) + + async def _agenerate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + if self.streaming: + stream_iter = self._astream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return await agenerate_from_stream(stream_iter) + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs} + _strip_internal_kwargs(params) + response = await self.client.chat.send_async(messages=message_dicts, **params) + return self._create_chat_result(response) + + def _stream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: 
CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs, "stream": True} + _strip_internal_kwargs(params) + + default_chunk_class: type[BaseMessageChunk] = AIMessageChunk + for chunk in self.client.chat.send(messages=message_dicts, **params): + chunk_dict = chunk.model_dump(by_alias=True) + if not chunk_dict.get("choices"): + continue + choice = chunk_dict["choices"][0] + message_chunk = _convert_chunk_to_message_chunk( + chunk_dict, default_chunk_class + ) + generation_info: dict[str, Any] = {} + if finish_reason := choice.get("finish_reason"): + generation_info["finish_reason"] = finish_reason + generation_info["model_name"] = self.model_name + logprobs = choice.get("logprobs") + if logprobs: + generation_info["logprobs"] = logprobs + + if generation_info: + message_chunk = message_chunk.model_copy( + update={"response_metadata": generation_info} + ) + + default_chunk_class = message_chunk.__class__ + generation_chunk = ChatGenerationChunk( + message=message_chunk, generation_info=generation_info or None + ) + + if run_manager: + run_manager.on_llm_new_token( + generation_chunk.text, + chunk=generation_chunk, + logprobs=logprobs, + ) + yield generation_chunk + + async def _astream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs, "stream": True} + _strip_internal_kwargs(params) + + default_chunk_class: type[BaseMessageChunk] = AIMessageChunk + async for chunk in await self.client.chat.send_async( + messages=message_dicts, **params + ): + chunk_dict = chunk.model_dump(by_alias=True) + if not chunk_dict.get("choices"): + continue + choice = chunk_dict["choices"][0] + message_chunk = _convert_chunk_to_message_chunk( + chunk_dict, default_chunk_class + ) + generation_info: dict[str, Any] = {} + if finish_reason := choice.get("finish_reason"): + generation_info["finish_reason"] = finish_reason + generation_info["model_name"] = self.model_name + logprobs = choice.get("logprobs") + if logprobs: + generation_info["logprobs"] = logprobs + + if generation_info: + message_chunk = message_chunk.model_copy( + update={"response_metadata": generation_info} + ) + + default_chunk_class = message_chunk.__class__ + generation_chunk = ChatGenerationChunk( + message=message_chunk, generation_info=generation_info or None + ) + + if run_manager: + await run_manager.on_llm_new_token( + token=generation_chunk.text, + chunk=generation_chunk, + logprobs=logprobs, + ) + yield generation_chunk + + # + # Internal methods + # + @property + def _default_params(self) -> dict[str, Any]: # noqa: C901, PLR0912 + """Get the default parameters for calling OpenRouter API.""" + params: dict[str, Any] = { + "model": self.model_name, + "stream": self.streaming, + **self.model_kwargs, + } + if self.temperature is not None: + params["temperature"] = self.temperature + if self.max_tokens is not None: + params["max_tokens"] = self.max_tokens + if self.max_completion_tokens is not None: + params["max_completion_tokens"] = self.max_completion_tokens + if self.top_p is not None: + params["top_p"] = self.top_p + if self.frequency_penalty is not None: + params["frequency_penalty"] = self.frequency_penalty + if self.presence_penalty is not None: + 
params["presence_penalty"] = self.presence_penalty + if self.seed is not None: + params["seed"] = self.seed + if self.n > 1: + params["n"] = self.n + if self.stop is not None: + params["stop"] = self.stop + # OpenRouter-specific params + if self.openrouter_reasoning is not None: + params["reasoning"] = self.openrouter_reasoning + if self.openrouter_provider is not None: + params["provider"] = self.openrouter_provider + if self.openrouter_route is not None: + params["route"] = self.openrouter_route + if self.openrouter_plugins is not None: + params["plugins"] = self.openrouter_plugins + return params + + def _create_message_dicts( + self, messages: list[BaseMessage], stop: list[str] | None + ) -> tuple[list[dict[str, Any]], dict[str, Any]]: + params = self._default_params + if stop is not None: + params["stop"] = stop + message_dicts = [_convert_message_to_dict(m) for m in messages] + return message_dicts, params + + def _create_chat_result(self, response: Any) -> ChatResult: + """Create a `ChatResult` from an OpenRouter SDK response.""" + if not isinstance(response, dict): + response = response.model_dump(by_alias=True) + + generations = [] + token_usage = response.get("usage") or {} + + for res in response.get("choices", []): + message = _convert_dict_to_message(res["message"]) + if token_usage and isinstance(message, AIMessage): + message.usage_metadata = _create_usage_metadata(token_usage) + generation_info: dict[str, Any] = { + "finish_reason": res.get("finish_reason"), + } + if "logprobs" in res: + generation_info["logprobs"] = res["logprobs"] + gen = ChatGeneration( + message=message, + generation_info=generation_info, + ) + generations.append(gen) + + llm_output: dict[str, Any] = { + "token_usage": token_usage, + "model_name": self.model_name, + } + return ChatResult(generations=generations, llm_output=llm_output) + + def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict: + """Combine multiple LLM outputs by aggregating token usage.""" + overall_token_usage: dict = {} + for output in llm_outputs: + if output is None: + continue + token_usage = output.get("token_usage") + if token_usage is not None: + for k, v in token_usage.items(): + if v is None: + overall_token_usage.setdefault(k, v) + elif k not in overall_token_usage: + overall_token_usage[k] = v + elif isinstance(v, dict): + for nested_k, nested_v in v.items(): + if ( + nested_k in overall_token_usage[k] + and nested_v is not None + ): + overall_token_usage[k][nested_k] += nested_v + else: + overall_token_usage[k][nested_k] = nested_v + else: + overall_token_usage[k] += v + return {"token_usage": overall_token_usage, "model_name": self.model_name} + + def bind_tools( + self, + tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool], + *, + tool_choice: dict | str | bool | None = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, AIMessage]: + """Bind tool-like objects to this chat model. + + Args: + tools: A list of tool definitions to bind to this chat model. + + Supports any tool definition handled by + `langchain_core.utils.function_calling.convert_to_openai_tool`. + tool_choice: Which tool to require the model to call. + **kwargs: Any additional parameters. 
+ """ + _ = kwargs.pop("strict", None) + formatted_tools = [convert_to_openai_tool(tool) for tool in tools] + if tool_choice is not None and tool_choice: + if tool_choice == "any": + tool_choice = "required" + if isinstance(tool_choice, str) and ( + tool_choice not in ("auto", "none", "required") + ): + tool_choice = {"type": "function", "function": {"name": tool_choice}} + if isinstance(tool_choice, bool): + if len(tools) > 1: + msg = ( + "tool_choice can only be True when there is one tool. Received " + f"{len(tools)} tools." + ) + raise ValueError(msg) + tool_name = formatted_tools[0]["function"]["name"] + tool_choice = { + "type": "function", + "function": {"name": tool_name}, + } + kwargs["tool_choice"] = tool_choice + return super().bind(tools=formatted_tools, **kwargs) + + def with_structured_output( + self, + schema: dict | type[BaseModel] | None = None, + *, + method: Literal[ + "function_calling", "json_mode", "json_schema" + ] = "function_calling", + include_raw: bool = False, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, dict | BaseModel]: + """Model wrapper that returns outputs formatted to match the given schema. + + Args: + schema: The output schema as a Pydantic class, TypedDict, JSON Schema, + or OpenAI function schema. + method: The method for steering model generation. + include_raw: If `True` then both the raw model response and the + parsed model response will be returned. + **kwargs: Any additional parameters. + + Returns: + A `Runnable` that takes same inputs as a `BaseChatModel`. + """ + _ = kwargs.pop("strict", None) + is_pydantic_schema = _is_pydantic_class(schema) + if method == "function_calling": + if schema is None: + msg = ( + "schema must be specified when method is 'function_calling'. " + "Received None." + ) + raise ValueError(msg) + formatted_tool = convert_to_openai_tool(schema) + tool_name = formatted_tool["function"]["name"] + llm = self.bind_tools( + [schema], + tool_choice=tool_name, + ls_structured_output_format={ + "kwargs": {"method": "function_calling"}, + "schema": formatted_tool, + }, + **kwargs, + ) + if is_pydantic_schema: + output_parser: OutputParserLike = PydanticToolsParser( + tools=[schema], # type: ignore[list-item] + first_tool_only=True, # type: ignore[list-item] + ) + else: + output_parser = JsonOutputKeyToolsParser( + key_name=tool_name, first_tool_only=True + ) + elif method == "json_schema": + if schema is None: + msg = ( + "schema must be specified when method is 'json_schema'. " + "Received None." + ) + raise ValueError(msg) + json_schema = convert_to_json_schema(schema) + schema_name = json_schema.get("title", "") + response_format = { + "type": "json_schema", + "json_schema": {"name": schema_name, "schema": json_schema}, + } + ls_format_info = { + "kwargs": {"method": "json_schema"}, + "schema": json_schema, + } + llm = self.bind( + response_format=response_format, + ls_structured_output_format=ls_format_info, + **kwargs, + ) + output_parser = ( + PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type] + if is_pydantic_schema + else JsonOutputParser() + ) + elif method == "json_mode": + llm = self.bind( + response_format={"type": "json_object"}, + ls_structured_output_format={ + "kwargs": {"method": "json_mode"}, + "schema": schema, + }, + **kwargs, + ) + output_parser = ( + PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type] + if is_pydantic_schema + else JsonOutputParser() + ) + else: + msg = ( + f"Unrecognized method argument. 
Expected one of 'function_calling', " + f"'json_schema', or 'json_mode'. Received: '{method}'" + ) + raise ValueError(msg) + + if include_raw: + parser_assign = RunnablePassthrough.assign( + parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None + ) + parser_none = RunnablePassthrough.assign(parsed=lambda _: None) + parser_with_fallback = parser_assign.with_fallbacks( + [parser_none], exception_key="parsing_error" + ) + return RunnableMap(raw=llm) | parser_with_fallback + return llm | output_parser + + +def _is_pydantic_class(obj: Any) -> bool: + return isinstance(obj, type) and is_basemodel_subclass(obj) + + +def _strip_internal_kwargs(params: dict[str, Any]) -> None: + """Remove LangChain-internal keys that the SDK does not accept.""" + for key in _INTERNAL_KWARGS: + params.pop(key, None) + + +# +# Type conversion helpers +# +def _format_message_content(content: Any) -> Any: + """Format message content for OpenRouter API. + + Converts LangChain data content blocks to the expected format. + + Args: + content: The message content (string or list of content blocks). + + Returns: + Formatted content suitable for the OpenRouter API. + """ + if content and isinstance(content, list): + formatted: list = [] + for block in content: + if isinstance(block, dict) and is_data_content_block(block): + formatted.append(convert_to_openai_data_block(block)) + else: + formatted.append(block) + return formatted + return content + + +def _convert_message_to_dict(message: BaseMessage) -> dict: # noqa: C901 + """Convert a LangChain message to a dictionary. + + Args: + message: The LangChain message. + + Returns: + The dictionary. + """ + message_dict: dict[str, Any] + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = { + "role": "user", + "content": _format_message_content(message.content), + } + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + # Filter out non-text blocks from list content + if isinstance(message.content, list): + text_blocks = [ + block + for block in message.content + if isinstance(block, dict) and block.get("type") == "text" + ] + message_dict["content"] = text_blocks or "" + if message.tool_calls or message.invalid_tool_calls: + message_dict["tool_calls"] = [ + _lc_tool_call_to_openrouter_tool_call(tc) for tc in message.tool_calls + ] + [ + _lc_invalid_tool_call_to_openrouter_tool_call(tc) + for tc in message.invalid_tool_calls + ] + if message_dict["content"] == "" or ( + isinstance(message_dict["content"], list) + and not message_dict["content"] + ): + message_dict["content"] = None + elif "tool_calls" in message.additional_kwargs: + message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] + if message_dict["content"] == "" or ( + isinstance(message_dict["content"], list) + and not message_dict["content"] + ): + message_dict["content"] = None + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, ToolMessage): + message_dict = { + "role": "tool", + "content": message.content, + "tool_call_id": message.tool_call_id, + } + else: + msg = f"Got unknown type {message}" + raise TypeError(msg) + if "name" in message.additional_kwargs: + message_dict["name"] = message.additional_kwargs["name"] + return message_dict + + +def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: # noqa: C901 + """Convert a 
dictionary to a LangChain message. + + Args: + _dict: The dictionary. + + Returns: + The LangChain message. + """ + id_ = _dict.get("id") + role = _dict.get("role") + if role == "user": + return HumanMessage(content=_dict.get("content", "")) + if role == "assistant": + content = _dict.get("content", "") or "" + additional_kwargs: dict = {} + if reasoning := _dict.get("reasoning"): + additional_kwargs["reasoning_content"] = reasoning + if reasoning_details := _dict.get("reasoning_details"): + additional_kwargs["reasoning_details"] = reasoning_details + tool_calls = [] + invalid_tool_calls = [] + if raw_tool_calls := _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = raw_tool_calls + for raw_tool_call in raw_tool_calls: + try: + tool_calls.append(parse_tool_call(raw_tool_call, return_id=True)) + except Exception as e: # noqa: BLE001, PERF203 + invalid_tool_calls.append( + make_invalid_tool_call(raw_tool_call, str(e)) + ) + return AIMessage( + content=content, + id=id_, + additional_kwargs=additional_kwargs, + tool_calls=tool_calls, + invalid_tool_calls=invalid_tool_calls, + response_metadata={"model_provider": "openrouter"}, + ) + if role == "system": + return SystemMessage(content=_dict.get("content", "")) + if role == "tool": + additional_kwargs = {} + if "name" in _dict: + additional_kwargs["name"] = _dict["name"] + return ToolMessage( + content=_dict.get("content", ""), + tool_call_id=_dict.get("tool_call_id"), + additional_kwargs=additional_kwargs, + ) + return ChatMessage(content=_dict.get("content", ""), role=role) # type: ignore[arg-type] + + +def _convert_chunk_to_message_chunk( + chunk: Mapping[str, Any], default_class: type[BaseMessageChunk] +) -> BaseMessageChunk: + """Convert a streaming chunk dict to a LangChain message chunk. + + Args: + chunk: The streaming chunk dictionary. + default_class: Default message chunk class. + + Returns: + The LangChain message chunk. 
+ """ + choice = chunk["choices"][0] + _dict = choice.get("delta", {}) + role = cast("str", _dict.get("role")) + content = cast("str", _dict.get("content") or "") + additional_kwargs: dict = {} + + if _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = _dict["tool_calls"] + + if role == "user" or default_class == HumanMessageChunk: + return HumanMessageChunk(content=content) + if role == "assistant" or default_class == AIMessageChunk: + if reasoning := _dict.get("reasoning"): + additional_kwargs["reasoning_content"] = reasoning + if reasoning_details := _dict.get("reasoning_details"): + additional_kwargs["reasoning_details"] = reasoning_details + # Extract usage from chunk if present + usage_metadata = None + if usage := chunk.get("usage"): + usage_metadata = _create_usage_metadata(usage) + return AIMessageChunk( + content=content, + additional_kwargs=additional_kwargs, + usage_metadata=usage_metadata, # type: ignore[arg-type] + response_metadata={"model_provider": "openrouter"}, + ) + if role == "system" or default_class == SystemMessageChunk: + return SystemMessageChunk(content=content) + if role == "tool" or default_class == ToolMessageChunk: + return ToolMessageChunk( + content=content, tool_call_id=_dict.get("tool_call_id", "") + ) + if role or default_class == ChatMessageChunk: + return ChatMessageChunk(content=content, role=role) + return default_class(content=content) # type: ignore[call-arg] + + +def _lc_tool_call_to_openrouter_tool_call(tool_call: ToolCall) -> dict: + return { + "type": "function", + "id": tool_call["id"], + "function": { + "name": tool_call["name"], + "arguments": json.dumps(tool_call["args"], ensure_ascii=False), + }, + } + + +def _lc_invalid_tool_call_to_openrouter_tool_call( + invalid_tool_call: InvalidToolCall, +) -> dict: + return { + "type": "function", + "id": invalid_tool_call["id"], + "function": { + "name": invalid_tool_call["name"], + "arguments": invalid_tool_call["args"], + }, + } + + +def _create_usage_metadata(token_usage: dict) -> UsageMetadata: + """Create usage metadata from OpenRouter token usage response. + + Args: + token_usage: Token usage dict from the API response. + + Returns: + Usage metadata with input/output token details. 
+ """ + input_tokens = int( + token_usage.get("prompt_tokens") or token_usage.get("input_tokens") or 0 + ) + output_tokens = int( + token_usage.get("completion_tokens") or token_usage.get("output_tokens") or 0 + ) + total_tokens = int(token_usage.get("total_tokens") or input_tokens + output_tokens) + + input_details_dict = ( + token_usage.get("prompt_tokens_details") + or token_usage.get("input_tokens_details") + or {} + ) + output_details_dict = ( + token_usage.get("completion_tokens_details") + or token_usage.get("output_tokens_details") + or {} + ) + + input_token_details: dict = { + "cache_read": input_details_dict.get("cached_tokens"), + } + output_token_details: dict = { + "reasoning": output_details_dict.get("reasoning_tokens"), + } + usage_metadata: UsageMetadata = { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": total_tokens, + } + + if filtered_input := {k: v for k, v in input_token_details.items() if v}: + usage_metadata["input_token_details"] = InputTokenDetails(**filtered_input) # type: ignore[typeddict-item] + if filtered_output := {k: v for k, v in output_token_details.items() if v}: + usage_metadata["output_token_details"] = OutputTokenDetails(**filtered_output) # type: ignore[typeddict-item] + return usage_metadata diff --git a/libs/partners/openrouter/tests/integration_tests/test_chat_models.py b/libs/partners/openrouter/tests/integration_tests/test_chat_models.py new file mode 100644 index 00000000000..22087929b0f --- /dev/null +++ b/libs/partners/openrouter/tests/integration_tests/test_chat_models.py @@ -0,0 +1,69 @@ +"""Integration tests for `ChatOpenRouter` chat model.""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessageChunk, BaseMessageChunk +from pydantic import BaseModel, Field + +from langchain_openrouter.chat_models import ChatOpenRouter + + +def test_basic_invoke() -> None: + """Test basic invocation.""" + model = ChatOpenRouter(model="openai/gpt-4o-mini", temperature=0) + response = model.invoke("Say 'hello' and nothing else.") + assert response.content + assert response.response_metadata.get("model_provider") == "openrouter" + + +def test_streaming() -> None: + """Test streaming.""" + model = ChatOpenRouter(model="openai/gpt-4o-mini", temperature=0) + full: BaseMessageChunk | None = None + for chunk in model.stream("Say 'hello' and nothing else."): + full = chunk if full is None else full + chunk + assert isinstance(full, AIMessageChunk) + assert full.content + + +def test_tool_calling() -> None: + """Test tool calling via OpenRouter.""" + + class GetWeather(BaseModel): + """Get the current weather in a given location.""" + + location: str = Field(description="The city and state") + + model = ChatOpenRouter(model="openai/gpt-4o-mini", temperature=0) + model_with_tools = model.bind_tools([GetWeather]) + response = model_with_tools.invoke("What's the weather in San Francisco?") + assert response.tool_calls + + +def test_structured_output() -> None: + """Test structured output via OpenRouter.""" + + class Joke(BaseModel): + """A joke.""" + + setup: str = Field(description="The setup of the joke") + punchline: str = Field(description="The punchline of the joke") + + model = ChatOpenRouter(model="openai/gpt-4o-mini", temperature=0) + structured = model.with_structured_output(Joke) + result = structured.invoke("Tell me a joke about programming") + assert isinstance(result, Joke) + assert result.setup + assert result.punchline + + +@pytest.mark.xfail(reason="Depends on reasoning 
model availability on OpenRouter.") +def test_reasoning_content() -> None: + """Test reasoning content from a reasoning model.""" + model = ChatOpenRouter( + model="openai/o3-mini", + openrouter_reasoning={"effort": "low"}, + ) + response = model.invoke("What is 2 + 2?") + assert response.content diff --git a/libs/partners/openrouter/tests/integration_tests/test_standard.py b/libs/partners/openrouter/tests/integration_tests/test_standard.py new file mode 100644 index 00000000000..e7da6c80769 --- /dev/null +++ b/libs/partners/openrouter/tests/integration_tests/test_standard.py @@ -0,0 +1,39 @@ +"""Standard integration tests for `ChatOpenRouter`.""" + +from langchain_tests.integration_tests import ChatModelIntegrationTests + +from langchain_openrouter.chat_models import ChatOpenRouter + +MODEL_NAME = "openai/gpt-4o-mini" + + +class TestChatOpenRouter(ChatModelIntegrationTests): + """Test `ChatOpenRouter` chat model.""" + + @property + def chat_model_class(self) -> type[ChatOpenRouter]: + """Return class of chat model being tested.""" + return ChatOpenRouter + + @property + def chat_model_params(self) -> dict: + """Parameters to create chat model instance for testing.""" + return { + "model": MODEL_NAME, + "temperature": 0, + } + + @property + def supports_json_mode(self) -> bool: + """Whether the chat model supports JSON mode.""" + return True + + @property + def supports_image_inputs(self) -> bool: + """Whether the chat model supports image inputs.""" + return True + + @property + def supports_image_urls(self) -> bool: + """Whether the chat model supports image inputs from URLs.""" + return True diff --git a/libs/partners/openrouter/tests/unit_tests/test_chat_models.py b/libs/partners/openrouter/tests/unit_tests/test_chat_models.py new file mode 100644 index 00000000000..5a2e3dcc165 --- /dev/null +++ b/libs/partners/openrouter/tests/unit_tests/test_chat_models.py @@ -0,0 +1,1309 @@ +"""Unit tests for `ChatOpenRouter` chat model.""" + +from __future__ import annotations + +from typing import Any, Literal +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from langchain_core.load import dumpd, dumps, load +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + ChatMessage, + ChatMessageChunk, + HumanMessage, + HumanMessageChunk, + SystemMessage, + SystemMessageChunk, + ToolMessage, +) +from langchain_core.runnables import RunnableBinding +from pydantic import BaseModel, Field, SecretStr + +from langchain_openrouter.chat_models import ( + ChatOpenRouter, + _convert_chunk_to_message_chunk, + _convert_dict_to_message, + _convert_message_to_dict, + _create_usage_metadata, +) + +MODEL_NAME = "openai/gpt-4o-mini" + + +def _make_model(**kwargs: Any) -> ChatOpenRouter: + """Create a `ChatOpenRouter` with sane defaults for unit tests.""" + defaults: dict[str, Any] = {"model": MODEL_NAME, "api_key": SecretStr("test-key")} + defaults.update(kwargs) + return ChatOpenRouter(**defaults) + + +# --------------------------------------------------------------------------- +# Pydantic schemas used across multiple test classes +# --------------------------------------------------------------------------- + + +class GetWeather(BaseModel): + """Get the current weather in a given location.""" + + location: str = Field(description="The city and state") + + +class GenerateUsername(BaseModel): + """Generate a username from a full name.""" + + name: str = Field(description="The full name") + hair_color: str = Field(description="The hair color") + + +# 
--------------------------------------------------------------------------- +# Mock helpers for SDK responses +# --------------------------------------------------------------------------- + +_SIMPLE_RESPONSE_DICT: dict[str, Any] = { + "id": "gen-abc123", + "choices": [ + { + "message": {"role": "assistant", "content": "Hello!"}, + "finish_reason": "stop", + "index": 0, + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + }, + "model": MODEL_NAME, + "object": "chat.completion", + "created": 1700000000.0, +} + +_TOOL_RESPONSE_DICT: dict[str, Any] = { + "id": "gen-tool123", + "choices": [ + { + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "GetWeather", + "arguments": '{"location": "San Francisco"}', + }, + } + ], + }, + "finish_reason": "tool_calls", + "index": 0, + } + ], + "usage": {"prompt_tokens": 20, "completion_tokens": 10, "total_tokens": 30}, + "model": MODEL_NAME, + "object": "chat.completion", + "created": 1700000000.0, +} + +_STREAM_CHUNKS: list[dict[str, Any]] = [ + { + "choices": [{"delta": {"role": "assistant", "content": ""}, "index": 0}], + "model": MODEL_NAME, + "object": "chat.completion.chunk", + "created": 1700000000.0, + "id": "gen-stream1", + }, + { + "choices": [{"delta": {"content": "Hello"}, "index": 0}], + "model": MODEL_NAME, + "object": "chat.completion.chunk", + "created": 1700000000.0, + "id": "gen-stream1", + }, + { + "choices": [{"delta": {"content": " world"}, "index": 0}], + "model": MODEL_NAME, + "object": "chat.completion.chunk", + "created": 1700000000.0, + "id": "gen-stream1", + }, + { + "choices": [{"delta": {}, "finish_reason": "stop", "index": 0}], + "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}, + "model": MODEL_NAME, + "object": "chat.completion.chunk", + "created": 1700000000.0, + "id": "gen-stream1", + }, +] + + +def _make_sdk_response(response_dict: dict[str, Any]) -> MagicMock: + """Build a MagicMock that behaves like an SDK ChatResponse.""" + mock = MagicMock() + mock.model_dump.return_value = response_dict + return mock + + +class _MockSyncStream: + """Synchronous iterator that mimics the SDK EventStream.""" + + def __init__(self, chunks: list[dict[str, Any]]) -> None: + self._chunks = chunks + + def __iter__(self) -> _MockSyncStream: + return self + + def __next__(self) -> MagicMock: + if not self._chunks: + raise StopIteration + chunk = self._chunks.pop(0) + mock = MagicMock() + mock.model_dump.return_value = chunk + return mock + + +class _MockAsyncStream: + """Async iterator that mimics the SDK EventStreamAsync.""" + + def __init__(self, chunks: list[dict[str, Any]]) -> None: + self._chunks = list(chunks) + + def __aiter__(self) -> _MockAsyncStream: + return self + + async def __anext__(self) -> MagicMock: + if not self._chunks: + raise StopAsyncIteration + chunk = self._chunks.pop(0) + mock = MagicMock() + mock.model_dump.return_value = chunk + return mock + + +# =========================================================================== +# Instantiation tests +# =========================================================================== + + +class TestChatOpenRouterInstantiation: + """Tests for `ChatOpenRouter` instantiation.""" + + def test_basic_instantiation(self) -> None: + """Test basic model instantiation with required params.""" + model = _make_model() + assert model.model_name == MODEL_NAME + assert model.openrouter_api_base is None + + def 
test_api_key_from_field(self) -> None: + """Test that API key is properly set.""" + model = _make_model() + assert model.openrouter_api_key is not None + assert model.openrouter_api_key.get_secret_value() == "test-key" + + def test_api_key_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that API key is read from OPENROUTER_API_KEY env var.""" + monkeypatch.setenv("OPENROUTER_API_KEY", "env-key-123") + model = ChatOpenRouter(model=MODEL_NAME) + assert model.openrouter_api_key is not None + assert model.openrouter_api_key.get_secret_value() == "env-key-123" + + def test_missing_api_key_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that missing API key raises ValueError.""" + monkeypatch.delenv("OPENROUTER_API_KEY", raising=False) + with pytest.raises(ValueError, match="OPENROUTER_API_KEY must be set"): + ChatOpenRouter(model=MODEL_NAME) + + def test_model_required(self) -> None: + """Test that model name is required.""" + with pytest.raises((ValueError, TypeError)): + ChatOpenRouter(api_key=SecretStr("test-key")) # type: ignore[call-arg] + + def test_secret_masking(self) -> None: + """Test that API key is not exposed in string representation.""" + model = _make_model(api_key=SecretStr("super-secret")) + model_str = str(model) + assert "super-secret" not in model_str + + def test_secret_masking_repr(self) -> None: + """Test that API key is masked in repr too.""" + model = _make_model(api_key=SecretStr("super-secret")) + assert "super-secret" not in repr(model) + + def test_api_key_is_secret_str(self) -> None: + """Test that openrouter_api_key is a SecretStr instance.""" + model = _make_model() + assert isinstance(model.openrouter_api_key, SecretStr) + + def test_llm_type(self) -> None: + """Test _llm_type property.""" + model = _make_model() + assert model._llm_type == "openrouter-chat" + + def test_ls_params(self) -> None: + """Test LangSmith params include openrouter provider.""" + model = _make_model() + ls_params = model._get_ls_params() + assert ls_params["ls_provider"] == "openrouter" + + def test_client_created(self) -> None: + """Test that OpenRouter SDK client is created.""" + model = _make_model() + assert model.client is not None + + def test_client_reused_for_same_params(self) -> None: + """Test that the SDK client is reused when model is re-validated.""" + model = _make_model() + client_1 = model.client + # Re-validate does not replace the existing client + model.validate_environment() # type: ignore[operator] + assert model.client is client_1 + + def test_app_url_passed_to_client(self) -> None: + """Test that app_url is passed as http_referer to the SDK client.""" + with patch("openrouter.OpenRouter") as mock_cls: + mock_cls.return_value = MagicMock() + ChatOpenRouter( + model=MODEL_NAME, + api_key=SecretStr("test-key"), + app_url="https://myapp.com", + ) + call_kwargs = mock_cls.call_args[1] + assert call_kwargs["http_referer"] == "https://myapp.com" + + def test_app_title_passed_to_client(self) -> None: + """Test that app_title is passed as x_title to the SDK client.""" + with patch("openrouter.OpenRouter") as mock_cls: + mock_cls.return_value = MagicMock() + ChatOpenRouter( + model=MODEL_NAME, + api_key=SecretStr("test-key"), + app_title="My App", + ) + call_kwargs = mock_cls.call_args[1] + assert call_kwargs["x_title"] == "My App" + + def test_openrouter_reasoning_in_params(self) -> None: + """Test that openrouter_reasoning is included in default params.""" + model = _make_model(openrouter_reasoning={"effort": "high"}) + params = 
model._default_params + assert params["reasoning"] == {"effort": "high"} + + def test_openrouter_provider_in_params(self) -> None: + """Test that openrouter_provider is included in default params.""" + model = _make_model(openrouter_provider={"order": ["Anthropic"]}) + params = model._default_params + assert params["provider"] == {"order": ["Anthropic"]} + + def test_openrouter_route_in_params(self) -> None: + """Test that openrouter_route is included in default params.""" + model = _make_model(openrouter_route="fallback") + params = model._default_params + assert params["route"] == "fallback" + + def test_optional_params_excluded_when_none(self) -> None: + """Test that None optional params are not in default params.""" + model = _make_model() + params = model._default_params + assert "temperature" not in params + assert "max_tokens" not in params + assert "top_p" not in params + assert "reasoning" not in params + + def test_temperature_included_when_set(self) -> None: + """Test that temperature is included when explicitly set.""" + model = _make_model(temperature=0.5) + params = model._default_params + assert params["temperature"] == 0.5 + + +# =========================================================================== +# Serialization tests +# =========================================================================== + + +class TestSerialization: + """Tests for serialization round-trips.""" + + def test_is_lc_serializable(self) -> None: + """Test that ChatOpenRouter declares itself as serializable.""" + assert ChatOpenRouter.is_lc_serializable() is True + + def test_dumpd_load_roundtrip(self) -> None: + """Test that dumpd/load round-trip preserves model config.""" + model = _make_model(temperature=0.7, max_tokens=100) + serialized = dumpd(model) + deserialized = load( + serialized, + valid_namespaces=["langchain_openrouter"], + allowed_objects="all", + secrets_from_env=False, + secrets_map={"OPENROUTER_API_KEY": "test-key"}, + ) + assert isinstance(deserialized, ChatOpenRouter) + assert deserialized.model_name == MODEL_NAME + assert deserialized.temperature == 0.7 + assert deserialized.max_tokens == 100 + + def test_dumps_does_not_leak_secrets(self) -> None: + """Test that dumps output does not contain the raw API key.""" + model = _make_model(api_key=SecretStr("super-secret-key")) + serialized = dumps(model) + assert "super-secret-key" not in serialized + + +# =========================================================================== +# Mocked generate / stream tests +# =========================================================================== + + +class TestMockedGenerate: + """Tests for _generate / _agenerate with a mocked SDK client.""" + + def test_invoke_basic(self) -> None: + """Test basic invoke returns an AIMessage via mocked SDK.""" + model = _make_model() + model.client = MagicMock() + model.client.chat.send.return_value = _make_sdk_response(_SIMPLE_RESPONSE_DICT) + + result = model.invoke("Hello") + assert isinstance(result, AIMessage) + assert result.content == "Hello!" 
+ model.client.chat.send.assert_called_once() + + def test_invoke_with_tool_response(self) -> None: + """Test invoke that returns tool calls.""" + model = _make_model() + model.client = MagicMock() + model.client.chat.send.return_value = _make_sdk_response(_TOOL_RESPONSE_DICT) + + result = model.invoke("What's the weather?") + assert isinstance(result, AIMessage) + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "GetWeather" + + def test_invoke_passes_correct_messages(self) -> None: + """Test that invoke converts messages and passes them to the SDK.""" + model = _make_model() + model.client = MagicMock() + model.client.chat.send.return_value = _make_sdk_response(_SIMPLE_RESPONSE_DICT) + + model.invoke([HumanMessage(content="Hi")]) + call_kwargs = model.client.chat.send.call_args[1] + assert call_kwargs["messages"] == [{"role": "user", "content": "Hi"}] + + def test_invoke_strips_internal_kwargs(self) -> None: + """Test that LangChain-internal kwargs are stripped before SDK call.""" + model = _make_model() + model.client = MagicMock() + model.client.chat.send.return_value = _make_sdk_response(_SIMPLE_RESPONSE_DICT) + + model._generate( + [HumanMessage(content="Hi")], + ls_structured_output_format={"kwargs": {"method": "function_calling"}}, + ) + call_kwargs = model.client.chat.send.call_args[1] + assert "ls_structured_output_format" not in call_kwargs + + def test_invoke_usage_metadata(self) -> None: + """Test that usage metadata is populated on the response.""" + model = _make_model() + model.client = MagicMock() + model.client.chat.send.return_value = _make_sdk_response(_SIMPLE_RESPONSE_DICT) + + result = model.invoke("Hello") + assert isinstance(result, AIMessage) + assert result.usage_metadata is not None + assert result.usage_metadata["input_tokens"] == 10 + assert result.usage_metadata["output_tokens"] == 5 + assert result.usage_metadata["total_tokens"] == 15 + + def test_stream_basic(self) -> None: + """Test streaming returns AIMessageChunks via mocked SDK.""" + model = _make_model() + model.client = MagicMock() + model.client.chat.send.return_value = _MockSyncStream( + [dict(c) for c in _STREAM_CHUNKS] + ) + + chunks = list(model.stream("Hello")) + assert len(chunks) > 0 + assert all(isinstance(c, AIMessageChunk) for c in chunks) + # Concatenated content should be "Hello world" + full_content = "".join(c.content for c in chunks if isinstance(c.content, str)) + assert "Hello" in full_content + assert "world" in full_content + + def test_stream_passes_stream_true(self) -> None: + """Test that stream sends stream=True to the SDK.""" + model = _make_model() + model.client = MagicMock() + model.client.chat.send.return_value = _MockSyncStream( + [dict(c) for c in _STREAM_CHUNKS] + ) + + list(model.stream("Hello")) + call_kwargs = model.client.chat.send.call_args[1] + assert call_kwargs["stream"] is True + + def test_invoke_with_streaming_flag(self) -> None: + """Test that invoke delegates to stream when streaming=True.""" + model = _make_model(streaming=True) + model.client = MagicMock() + model.client.chat.send.return_value = _MockSyncStream( + [dict(c) for c in _STREAM_CHUNKS] + ) + + result = model.invoke("Hello") + assert isinstance(result, AIMessage) + call_kwargs = model.client.chat.send.call_args[1] + assert call_kwargs["stream"] is True + + async def test_ainvoke_basic(self) -> None: + """Test async invoke returns an AIMessage via mocked SDK.""" + model = _make_model() + model.client = MagicMock() + model.client.chat.send_async = AsyncMock( + 
+            return_value=_make_sdk_response(_SIMPLE_RESPONSE_DICT)
+        )
+
+        result = await model.ainvoke("Hello")
+        assert isinstance(result, AIMessage)
+        assert result.content == "Hello!"
+        model.client.chat.send_async.assert_awaited_once()
+
+    async def test_astream_basic(self) -> None:
+        """Test async streaming returns AIMessageChunks via mocked SDK."""
+        model = _make_model()
+        model.client = MagicMock()
+        model.client.chat.send_async = AsyncMock(
+            return_value=_MockAsyncStream(_STREAM_CHUNKS)
+        )
+
+        chunks = [c async for c in model.astream("Hello")]
+        assert len(chunks) > 0
+        assert all(isinstance(c, AIMessageChunk) for c in chunks)
+
+
+# ===========================================================================
+# Request payload verification
+# ===========================================================================
+
+
+class TestRequestPayload:
+    """Tests verifying the exact dict sent to the SDK."""
+
+    def test_message_format_in_payload(self) -> None:
+        """Test that messages are formatted correctly in the SDK call."""
+        model = _make_model(temperature=0)
+        model.client = MagicMock()
+        model.client.chat.send.return_value = _make_sdk_response(_SIMPLE_RESPONSE_DICT)
+
+        model.invoke(
+            [
+                SystemMessage(content="You are helpful."),
+                HumanMessage(content="Hi"),
+            ]
+        )
+        call_kwargs = model.client.chat.send.call_args[1]
+        assert call_kwargs["messages"] == [
+            {"role": "system", "content": "You are helpful."},
+            {"role": "user", "content": "Hi"},
+        ]
+
+    def test_model_kwargs_forwarded(self) -> None:
+        """Test that extra model_kwargs are included in the SDK call."""
+        model = _make_model(model_kwargs={"top_k": 50})
+        model.client = MagicMock()
+        model.client.chat.send.return_value = _make_sdk_response(_SIMPLE_RESPONSE_DICT)
+
+        model.invoke("Hi")
+        call_kwargs = model.client.chat.send.call_args[1]
+        assert call_kwargs["top_k"] == 50
+
+    def test_stop_sequences_in_payload(self) -> None:
+        """Test that stop sequences are passed to the SDK."""
+        model = _make_model()
+        model.client = MagicMock()
+        model.client.chat.send.return_value = _make_sdk_response(_SIMPLE_RESPONSE_DICT)
+
+        model.invoke("Hi", stop=["END"])
+        call_kwargs = model.client.chat.send.call_args[1]
+        assert call_kwargs["stop"] == ["END"]
+
+    def test_tool_format_in_payload(self) -> None:
+        """Test that tools are formatted in OpenAI-compatible structure."""
+        model = _make_model()
+        model.client = MagicMock()
+        model.client.chat.send.return_value = _make_sdk_response(_TOOL_RESPONSE_DICT)
+
+        bound = model.bind_tools([GetWeather])
+        bound.invoke("What's the weather?")
+        call_kwargs = model.client.chat.send.call_args[1]
+        tools = call_kwargs["tools"]
+        assert len(tools) == 1
+        assert tools[0]["type"] == "function"
+        assert tools[0]["function"]["name"] == "GetWeather"
+        assert "parameters" in tools[0]["function"]
+
+    def test_openrouter_params_in_payload(self) -> None:
+        """Test that OpenRouter-specific params appear in the SDK call."""
+        model = _make_model(
+            openrouter_reasoning={"effort": "high"},
+            openrouter_provider={"order": ["Anthropic"]},
+            openrouter_route="fallback",
+        )
+        model.client = MagicMock()
+        model.client.chat.send.return_value = _make_sdk_response(_SIMPLE_RESPONSE_DICT)
+
+        model.invoke("Hi")
+        call_kwargs = model.client.chat.send.call_args[1]
+        assert call_kwargs["reasoning"] == {"effort": "high"}
+        assert call_kwargs["provider"] == {"order": ["Anthropic"]}
+        assert call_kwargs["route"] == "fallback"
+
+
+# ===========================================================================
+# bind_tools tests
+# ===========================================================================
+
+
+class TestBindTools:
+    """Tests for the bind_tools public method."""
+
+    @pytest.mark.parametrize(
+        "tool_choice",
+        [
+            "auto",
+            "none",
+            "required",
+            "GetWeather",
+            {"type": "function", "function": {"name": "GetWeather"}},
+            None,
+        ],
+    )
+    def test_bind_tools_tool_choice(self, tool_choice: Any) -> None:
+        """Test bind_tools accepts various tool_choice values."""
+        model = _make_model()
+        bound = model.bind_tools(
+            [GetWeather, GenerateUsername], tool_choice=tool_choice
+        )
+        assert isinstance(bound, RunnableBinding)
+
+    def test_bind_tools_bool_true_single_tool(self) -> None:
+        """Test bind_tools with tool_choice=True and a single tool."""
+        model = _make_model()
+        bound = model.bind_tools([GetWeather], tool_choice=True)
+        assert isinstance(bound, RunnableBinding)
+        kwargs = bound.kwargs
+        assert kwargs["tool_choice"] == {
+            "type": "function",
+            "function": {"name": "GetWeather"},
+        }
+
+    def test_bind_tools_bool_true_multiple_tools_raises(self) -> None:
+        """Test bind_tools with tool_choice=True and multiple tools raises."""
+        model = _make_model()
+        with pytest.raises(ValueError, match="tool_choice can only be True"):
+            model.bind_tools([GetWeather, GenerateUsername], tool_choice=True)
+
+    def test_bind_tools_any_maps_to_required(self) -> None:
+        """Test that tool_choice='any' is mapped to 'required'."""
+        model = _make_model()
+        bound = model.bind_tools([GetWeather], tool_choice="any")
+        assert isinstance(bound, RunnableBinding)
+        assert bound.kwargs["tool_choice"] == "required"
+
+    def test_bind_tools_string_name_becomes_dict(self) -> None:
+        """Test that a specific tool name string is converted to a dict."""
+        model = _make_model()
+        bound = model.bind_tools([GetWeather], tool_choice="GetWeather")
+        assert isinstance(bound, RunnableBinding)
+        assert bound.kwargs["tool_choice"] == {
+            "type": "function",
+            "function": {"name": "GetWeather"},
+        }
+
+    def test_bind_tools_formats_tools_correctly(self) -> None:
+        """Test that tools are converted to OpenAI format."""
+        model = _make_model()
+        bound = model.bind_tools([GetWeather])
+        assert isinstance(bound, RunnableBinding)
+        tools = bound.kwargs["tools"]
+        assert len(tools) == 1
+        assert tools[0]["type"] == "function"
+        assert tools[0]["function"]["name"] == "GetWeather"
+
+    def test_bind_tools_no_choice_omits_key(self) -> None:
+        """Test that tool_choice=None does not set tool_choice in kwargs."""
+        model = _make_model()
+        bound = model.bind_tools([GetWeather], tool_choice=None)
+        assert isinstance(bound, RunnableBinding)
+        assert "tool_choice" not in bound.kwargs
+
+    def test_bind_tools_strict_ignored(self) -> None:
+        """Test that strict param is accepted but ignored."""
+        model = _make_model()
+        bound = model.bind_tools([GetWeather], strict=True)
+        assert isinstance(bound, RunnableBinding)
+
+
+# ===========================================================================
+# with_structured_output tests
+# ===========================================================================
+
+
+class TestWithStructuredOutput:
+    """Tests for the with_structured_output public method."""
+
+    @pytest.mark.parametrize("method", ["function_calling", "json_schema", "json_mode"])
+    @pytest.mark.parametrize("include_raw", ["yes", "no"])
+    def test_with_structured_output_pydantic(
+        self,
+        method: Literal["function_calling", "json_mode", "json_schema"],
+        include_raw: str,
+    ) -> None:
+        """Test with_structured_output using a Pydantic schema."""
+        model = _make_model()
+        structured = model.with_structured_output(
+            GenerateUsername, method=method, include_raw=(include_raw == "yes")
+        )
+        assert structured is not None
+
+    @pytest.mark.parametrize("method", ["function_calling", "json_schema", "json_mode"])
+    def test_with_structured_output_dict_schema(
+        self,
+        method: Literal["function_calling", "json_mode", "json_schema"],
+    ) -> None:
+        """Test with_structured_output using a JSON schema dict."""
+        schema = GenerateUsername.model_json_schema()
+        model = _make_model()
+        structured = model.with_structured_output(schema, method=method)
+        assert structured is not None
+
+    def test_with_structured_output_none_schema_function_calling_raises(self) -> None:
+        """Test that schema=None with function_calling raises ValueError."""
+        model = _make_model()
+        with pytest.raises(ValueError, match="schema must be specified"):
+            model.with_structured_output(None, method="function_calling")
+
+    def test_with_structured_output_none_schema_json_schema_raises(self) -> None:
+        """Test that schema=None with json_schema raises ValueError."""
+        model = _make_model()
+        with pytest.raises(ValueError, match="schema must be specified"):
+            model.with_structured_output(None, method="json_schema")
+
+    def test_with_structured_output_invalid_method_raises(self) -> None:
+        """Test that an unrecognized method raises ValueError."""
+        model = _make_model()
+        with pytest.raises(ValueError, match="Unrecognized method"):
+            model.with_structured_output(
+                GenerateUsername,
+                method="invalid",  # type: ignore[arg-type]
+            )
+
+    def test_with_structured_output_json_schema_sets_response_format(self) -> None:
+        """Test that json_schema method sets response_format correctly."""
+        model = _make_model()
+        structured = model.with_structured_output(
+            GenerateUsername, method="json_schema"
+        )
+        # The first step in the chain should be the bound model
+        bound = structured.first  # type: ignore[attr-defined]
+        assert isinstance(bound, RunnableBinding)
+        rf = bound.kwargs["response_format"]
+        assert rf["type"] == "json_schema"
+        assert rf["json_schema"]["name"] == "GenerateUsername"
+
+    def test_with_structured_output_json_mode_sets_response_format(self) -> None:
+        """Test that json_mode method sets response_format correctly."""
+        model = _make_model()
+        structured = model.with_structured_output(GenerateUsername, method="json_mode")
+        bound = structured.first  # type: ignore[attr-defined]
+        assert isinstance(bound, RunnableBinding)
+        rf = bound.kwargs["response_format"]
+        assert rf["type"] == "json_object"
+
+    def test_with_structured_output_strict_ignored(self) -> None:
+        """Test that strict param is accepted but ignored."""
+        model = _make_model()
+        structured = model.with_structured_output(
+            GenerateUsername, method="function_calling", strict=True
+        )
+        assert structured is not None
+
+
+# ===========================================================================
+# Message conversion tests
+# ===========================================================================
+
+
+class TestMessageConversion:
+    """Tests for message conversion functions."""
+
+    def test_human_message_to_dict(self) -> None:
+        """Test converting HumanMessage to dict."""
+        msg = HumanMessage(content="Hello")
+        result = _convert_message_to_dict(msg)
+        assert result == {"role": "user", "content": "Hello"}
+
+    def test_system_message_to_dict(self) -> None:
+        """Test converting SystemMessage to dict."""
+        msg = SystemMessage(content="You are helpful.")
+        result = _convert_message_to_dict(msg)
+        assert result == {"role": "system", "content": "You are helpful."}
+
+    def test_ai_message_to_dict(self) -> None:
+        """Test converting AIMessage to dict."""
+        msg = AIMessage(content="Hi there!")
+        result = _convert_message_to_dict(msg)
+        assert result == {"role": "assistant", "content": "Hi there!"}
+
+    def test_tool_message_to_dict(self) -> None:
+        """Test converting ToolMessage to dict."""
+        msg = ToolMessage(content="result", tool_call_id="call_123")
+        result = _convert_message_to_dict(msg)
+        assert result == {
+            "role": "tool",
+            "content": "result",
+            "tool_call_id": "call_123",
+        }
+
+    def test_chat_message_to_dict(self) -> None:
+        """Test converting ChatMessage to dict."""
+        msg = ChatMessage(content="Hello", role="developer")
+        result = _convert_message_to_dict(msg)
+        assert result == {"role": "developer", "content": "Hello"}
+
+    def test_ai_message_with_tool_calls_to_dict(self) -> None:
+        """Test converting AIMessage with tool calls to dict."""
+        msg = AIMessage(
+            content="",
+            tool_calls=[
+                {
+                    "name": "get_weather",
+                    "args": {"location": "SF"},
+                    "id": "call_1",
+                    "type": "tool_call",
+                }
+            ],
+        )
+        result = _convert_message_to_dict(msg)
+        assert result["role"] == "assistant"
+        assert result["content"] is None
+        assert len(result["tool_calls"]) == 1
+        assert result["tool_calls"][0]["function"]["name"] == "get_weather"
+
+    def test_dict_to_ai_message(self) -> None:
+        """Test converting dict to AIMessage."""
+        d = {"role": "assistant", "content": "Hello!"}
+        msg = _convert_dict_to_message(d)
+        assert isinstance(msg, AIMessage)
+        assert msg.content == "Hello!"
+
+    def test_dict_to_ai_message_with_reasoning(self) -> None:
+        """Test that reasoning is extracted from response dict."""
+        d = {
+            "role": "assistant",
+            "content": "Answer",
+            "reasoning": "Let me think...",
+        }
+        msg = _convert_dict_to_message(d)
+        assert isinstance(msg, AIMessage)
+        assert msg.additional_kwargs["reasoning_content"] == "Let me think..."
+
+    def test_dict_to_ai_message_with_tool_calls(self) -> None:
+        """Test converting dict with tool calls to AIMessage."""
+        d = {
+            "role": "assistant",
+            "content": "",
+            "tool_calls": [
+                {
+                    "id": "call_1",
+                    "type": "function",
+                    "function": {
+                        "name": "get_weather",
+                        "arguments": '{"location": "SF"}',
+                    },
+                }
+            ],
+        }
+        msg = _convert_dict_to_message(d)
+        assert isinstance(msg, AIMessage)
+        assert len(msg.tool_calls) == 1
+        assert msg.tool_calls[0]["name"] == "get_weather"
+
+    def test_dict_to_ai_message_with_invalid_tool_calls(self) -> None:
+        """Test that malformed tool calls produce invalid_tool_calls."""
+        d = {
+            "role": "assistant",
+            "content": "",
+            "tool_calls": [
+                {
+                    "id": "call_bad",
+                    "type": "function",
+                    "function": {
+                        "name": "get_weather",
+                        "arguments": "not-valid-json{{{",
+                    },
+                }
+            ],
+        }
+        msg = _convert_dict_to_message(d)
+        assert isinstance(msg, AIMessage)
+        assert len(msg.invalid_tool_calls) == 1
+        assert len(msg.tool_calls) == 0
+        assert msg.invalid_tool_calls[0]["name"] == "get_weather"
+
+    def test_dict_to_human_message(self) -> None:
+        """Test converting dict to HumanMessage."""
+        d = {"role": "user", "content": "Hi"}
+        msg = _convert_dict_to_message(d)
+        assert isinstance(msg, HumanMessage)
+
+    def test_dict_to_system_message(self) -> None:
+        """Test converting dict to SystemMessage."""
+        d = {"role": "system", "content": "Be helpful"}
+        msg = _convert_dict_to_message(d)
+        assert isinstance(msg, SystemMessage)
+
+    def test_dict_to_tool_message(self) -> None:
+        """Test converting dict with role=tool to ToolMessage."""
+        d = {
+            "role": "tool",
+            "content": "result data",
+            "tool_call_id": "call_42",
+            "name": "get_weather",
+        }
+        msg = _convert_dict_to_message(d)
+        assert isinstance(msg, ToolMessage)
+        assert msg.content == "result data"
+        assert msg.tool_call_id == "call_42"
+        assert msg.additional_kwargs["name"] == "get_weather"
+
+    def test_dict_to_chat_message_unknown_role(self) -> None:
+        """Test that unrecognized roles fall back to ChatMessage."""
+        d = {"role": "developer", "content": "Some content"}
+        msg = _convert_dict_to_message(d)
+        assert isinstance(msg, ChatMessage)
+        assert msg.role == "developer"
+        assert msg.content == "Some content"
+
+    def test_ai_message_with_list_content_filters_non_text(self) -> None:
+        """Test that non-text blocks are filtered from AIMessage list content."""
+        msg = AIMessage(
+            content=[
+                {"type": "text", "text": "Hello"},
+                {"type": "image_url", "image_url": {"url": "http://example.com"}},
+            ]
+        )
+        result = _convert_message_to_dict(msg)
+        assert result["content"] == [{"type": "text", "text": "Hello"}]
+
+
+# ===========================================================================
+# _create_chat_result tests
+# ===========================================================================
+
+
+class TestCreateChatResult:
+    """Tests for _create_chat_result."""
+
+    def test_model_provider_in_response_metadata(self) -> None:
+        """Test that model_provider is set in response metadata."""
+        model = _make_model()
+        result = model._create_chat_result(_SIMPLE_RESPONSE_DICT)
+        assert (
+            result.generations[0].message.response_metadata.get("model_provider")
+            == "openrouter"
+        )
+
+    def test_reasoning_from_response(self) -> None:
+        """Test that reasoning content is extracted from response."""
+        model = _make_model()
+        response_dict: dict[str, Any] = {
+            "choices": [
+                {
+                    "message": {
+                        "role": "assistant",
+                        "content": "Answer",
+                        "reasoning": "Let me think...",
+                    },
+                    "finish_reason": "stop",
+                }
+            ],
+        }
+        result = model._create_chat_result(response_dict)
+        assert (
+            result.generations[0].message.additional_kwargs.get("reasoning_content")
+            == "Let me think..."
+        )
+
+    def test_usage_metadata_created(self) -> None:
+        """Test that usage metadata is created from token usage."""
+        model = _make_model()
+        result = model._create_chat_result(_SIMPLE_RESPONSE_DICT)
+        msg = result.generations[0].message
+        assert isinstance(msg, AIMessage)
+        usage = msg.usage_metadata
+        assert usage is not None
+        assert usage["input_tokens"] == 10
+        assert usage["output_tokens"] == 5
+        assert usage["total_tokens"] == 15
+
+    def test_tool_calls_in_response(self) -> None:
+        """Test that tool calls are extracted from response."""
+        model = _make_model()
+        result = model._create_chat_result(_TOOL_RESPONSE_DICT)
+        msg = result.generations[0].message
+        assert isinstance(msg, AIMessage)
+        assert len(msg.tool_calls) == 1
+        assert msg.tool_calls[0]["name"] == "GetWeather"
+
+
+# ===========================================================================
+# Streaming chunk tests
+# ===========================================================================
+
+
+class TestStreamingChunks:
+    """Tests for streaming chunk conversion."""
+
+    def test_reasoning_in_streaming_chunk(self) -> None:
+        """Test that reasoning is extracted from streaming delta."""
+        chunk: dict[str, Any] = {
+            "choices": [
+                {
+                    "delta": {
+                        "content": "Main content",
+                        "reasoning": "Streaming reasoning",
+                    },
+                },
+            ],
+        }
+        message_chunk = _convert_chunk_to_message_chunk(chunk, AIMessageChunk)
+        assert isinstance(message_chunk, AIMessageChunk)
+        assert (
+            message_chunk.additional_kwargs.get("reasoning_content")
+            == "Streaming reasoning"
+        )
+
+    def test_model_provider_in_streaming_chunk(self) -> None:
+        """Test that model_provider is set in streaming chunk metadata."""
+        chunk: dict[str, Any] = {
+            "choices": [
+                {
+                    "delta": {"content": "Hello"},
+                },
+            ],
+        }
+        message_chunk = _convert_chunk_to_message_chunk(chunk, AIMessageChunk)
+        assert isinstance(message_chunk, AIMessageChunk)
+        assert message_chunk.response_metadata.get("model_provider") == "openrouter"
+
+    def test_chunk_without_reasoning(self) -> None:
+        """Test that chunk without reasoning fields works correctly."""
+        chunk: dict[str, Any] = {"choices": [{"delta": {"content": "Hello"}}]}
+        message_chunk = _convert_chunk_to_message_chunk(chunk, AIMessageChunk)
+        assert isinstance(message_chunk, AIMessageChunk)
+        assert message_chunk.additional_kwargs.get("reasoning_content") is None
+
+    def test_chunk_with_empty_delta(self) -> None:
+        """Test that chunk with empty delta works correctly."""
+        chunk: dict[str, Any] = {"choices": [{"delta": {}}]}
+        message_chunk = _convert_chunk_to_message_chunk(chunk, AIMessageChunk)
+        assert isinstance(message_chunk, AIMessageChunk)
+        assert message_chunk.additional_kwargs.get("reasoning_content") is None
+
+    def test_chunk_with_tool_calls(self) -> None:
+        """Test that tool calls are extracted from streaming delta."""
+        chunk: dict[str, Any] = {
+            "choices": [
+                {
+                    "delta": {
+                        "tool_calls": [
+                            {
+                                "index": 0,
+                                "id": "call_1",
+                                "type": "function",
+                                "function": {
+                                    "name": "get_weather",
+                                    "arguments": '{"loc',
+                                },
+                            }
+                        ],
+                    },
+                },
+            ],
+        }
+        message_chunk = _convert_chunk_to_message_chunk(chunk, AIMessageChunk)
+        assert isinstance(message_chunk, AIMessageChunk)
+        assert "tool_calls" in message_chunk.additional_kwargs
+
+    def test_chunk_with_user_role(self) -> None:
+        """Test that a chunk with role=user produces HumanMessageChunk."""
+        chunk: dict[str, Any] = {
+            "choices": [{"delta": {"role": "user", "content": "test"}}]
+        }
+        msg = _convert_chunk_to_message_chunk(chunk, AIMessageChunk)
+        assert isinstance(msg, HumanMessageChunk)
+
+    def test_chunk_with_system_role(self) -> None:
+        """Test that a chunk with role=system produces SystemMessageChunk."""
+        chunk: dict[str, Any] = {
+            "choices": [{"delta": {"role": "system", "content": "test"}}]
+        }
+        # Use ChatMessageChunk default so role dispatch isn't short-circuited
+        msg = _convert_chunk_to_message_chunk(chunk, ChatMessageChunk)
+        assert isinstance(msg, SystemMessageChunk)
+
+    def test_chunk_with_unknown_role(self) -> None:
+        """Test that an unknown role falls back to ChatMessageChunk."""
+        chunk: dict[str, Any] = {
+            "choices": [{"delta": {"role": "developer", "content": "test"}}]
+        }
+        msg = _convert_chunk_to_message_chunk(chunk, ChatMessageChunk)
+        assert isinstance(msg, ChatMessageChunk)
+
+    def test_chunk_with_usage(self) -> None:
+        """Test that usage metadata is extracted from streaming chunk."""
+        chunk: dict[str, Any] = {
+            "choices": [{"delta": {"content": ""}}],
+            "usage": {
+                "prompt_tokens": 10,
+                "completion_tokens": 5,
+                "total_tokens": 15,
+            },
+        }
+        message_chunk = _convert_chunk_to_message_chunk(chunk, AIMessageChunk)
+        assert isinstance(message_chunk, AIMessageChunk)
+        assert message_chunk.usage_metadata is not None
+        assert message_chunk.usage_metadata["input_tokens"] == 10
+
+
+# ===========================================================================
+# Usage metadata tests
+# ===========================================================================
+
+
+class TestUsageMetadata:
+    """Tests for _create_usage_metadata."""
+
+    def test_basic_usage(self) -> None:
+        """Test basic usage metadata creation."""
+        usage = _create_usage_metadata(
+            {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
+        )
+        assert usage["input_tokens"] == 10
+        assert usage["output_tokens"] == 5
+        assert usage["total_tokens"] == 15
+
+    def test_float_tokens_cast_to_int(self) -> None:
+        """Test that float token counts are cast to int."""
+        usage = _create_usage_metadata(
+            {"prompt_tokens": 10.0, "completion_tokens": 5.0, "total_tokens": 15.0}
+        )
+        assert usage["input_tokens"] == 10
+        assert isinstance(usage["input_tokens"], int)
+
+    def test_missing_tokens_default_to_zero(self) -> None:
+        """Test that missing token fields default to zero."""
+        usage = _create_usage_metadata({})
+        assert usage["input_tokens"] == 0
+        assert usage["output_tokens"] == 0
+        assert usage["total_tokens"] == 0
+
+    def test_total_tokens_computed_if_missing(self) -> None:
+        """Test that total_tokens is computed if not provided."""
+        usage = _create_usage_metadata({"prompt_tokens": 10, "completion_tokens": 5})
+        assert usage["total_tokens"] == 15
+
+    def test_token_details(self) -> None:
+        """Test that token details are extracted."""
+        usage = _create_usage_metadata(
+            {
+                "prompt_tokens": 100,
+                "completion_tokens": 50,
+                "total_tokens": 150,
+                "prompt_tokens_details": {"cached_tokens": 20},
+                "completion_tokens_details": {"reasoning_tokens": 10},
+            }
+        )
+        assert "input_token_details" in usage
+        assert "output_token_details" in usage
+
+    def test_alternative_token_key_names(self) -> None:
+        """Test fallback to input_tokens/output_tokens key names."""
+        usage = _create_usage_metadata(
+            {
+                "input_tokens": 10,
+                "output_tokens": 5,
+                "total_tokens": 15,
+            }
+        )
+        assert usage["input_tokens"] == 10
+        assert usage["output_tokens"] == 5
+        assert usage["total_tokens"] == 15
+
+
+# ===========================================================================
+# _combine_llm_outputs tests
+# ===========================================================================
+
+
+class TestCombineLLMOutputs:
+    """Tests for _combine_llm_outputs."""
+
+    def test_single_output(self) -> None:
+        """Test combining a single output."""
+        model = _make_model()
+        result = model._combine_llm_outputs(
+            [{"token_usage": {"prompt_tokens": 10, "completion_tokens": 5}}]
+        )
+        assert result["token_usage"]["prompt_tokens"] == 10
+        assert result["token_usage"]["completion_tokens"] == 5
+
+    def test_multiple_outputs_accumulated(self) -> None:
+        """Test that token counts from multiple outputs are accumulated."""
+        model = _make_model()
+        result = model._combine_llm_outputs(
+            [
+                {
+                    "token_usage": {
+                        "prompt_tokens": 10,
+                        "completion_tokens": 5,
+                        "total_tokens": 15,
+                    }
+                },
+                {
+                    "token_usage": {
+                        "prompt_tokens": 20,
+                        "completion_tokens": 10,
+                        "total_tokens": 30,
+                    }
+                },
+            ]
+        )
+        assert result["token_usage"]["prompt_tokens"] == 30
+        assert result["token_usage"]["completion_tokens"] == 15
+        assert result["token_usage"]["total_tokens"] == 45
+
+    def test_none_outputs_skipped(self) -> None:
+        """Test that None outputs are skipped."""
+        model = _make_model()
+        result = model._combine_llm_outputs(
+            [None, {"token_usage": {"prompt_tokens": 10}}, None]
+        )
+        assert result["token_usage"]["prompt_tokens"] == 10
+
+    def test_empty_list(self) -> None:
+        """Test combining an empty list."""
+        model = _make_model()
+        result = model._combine_llm_outputs([])
+        assert result["token_usage"] == {}
+        assert result["model_name"] == MODEL_NAME
+
+    def test_nested_dict_accumulation(self) -> None:
+        """Test that nested dicts (e.g. token details) are accumulated."""
+        model = _make_model()
+        result = model._combine_llm_outputs(
+            [
+                {
+                    "token_usage": {
+                        "prompt_tokens": 10,
+                        "prompt_tokens_details": {"cached_tokens": 5},
+                    }
+                },
+                {
+                    "token_usage": {
+                        "prompt_tokens": 20,
+                        "prompt_tokens_details": {"cached_tokens": 3},
+                    }
+                },
+            ]
+        )
+        assert result["token_usage"]["prompt_tokens"] == 30
+        assert result["token_usage"]["prompt_tokens_details"]["cached_tokens"] == 8
+
+    def test_none_token_usage_skipped(self) -> None:
+        """Test that outputs with token_usage=None are handled."""
+        model = _make_model()
+        result = model._combine_llm_outputs(
+            [{"token_usage": None}, {"token_usage": {"prompt_tokens": 5}}]
+        )
+        assert result["token_usage"]["prompt_tokens"] == 5
+
+
+# ===========================================================================
+# Error-path tests
+# ===========================================================================
+
+
+class TestErrorPaths:
+    """Tests for error handling in various code paths."""
+
+    def test_n_less_than_1_raises(self) -> None:
+        """Test that n < 1 raises ValueError."""
+        with pytest.raises(ValueError, match="n must be at least 1"):
+            _make_model(n=0)
+
+    def test_n_greater_than_1_with_streaming_raises(self) -> None:
+        """Test that n > 1 with streaming raises ValueError."""
+        with pytest.raises(ValueError, match="n must be 1 when streaming"):
+            _make_model(n=2, streaming=True)
+
+    def test_n_forwarded_in_params(self) -> None:
+        """Test that n > 1 is included in _default_params."""
+        model = _make_model(n=3)
+        assert model._default_params["n"] == 3
+
+    def test_n_default_excluded_from_params(self) -> None:
+        """Test that n=1 (default) is not in _default_params."""
+        model = _make_model()
+        assert "n" not in model._default_params
+
+    def test_unknown_message_type_raises(self) -> None:
+        """Test that unknown message types raise TypeError."""
+        from langchain_core.messages import FunctionMessage  # noqa: PLC0415
+
+        msg = FunctionMessage(content="result", name="fn")
+        with pytest.raises(TypeError, match="Got unknown type"):
+            _convert_message_to_dict(msg)
+
+    def test_duplicate_model_kwargs_raises(self) -> None:
+        """Test that passing a param in both field and model_kwargs raises."""
+        with pytest.raises(ValueError, match="supplied twice"):
+            _make_model(temperature=0.5, model_kwargs={"temperature": 0.7})
+
+    def test_known_field_in_model_kwargs_raises(self) -> None:
+        """Test that a known field passed in model_kwargs raises."""
+        with pytest.raises(ValueError, match="should be specified explicitly"):
+            _make_model(model_kwargs={"model_name": "some-model"})
+
+    def test_max_retries_zero_disables_retries(self) -> None:
+        """Test that max_retries=0 does not configure retry."""
+        with patch("openrouter.OpenRouter") as mock_cls:
+            mock_cls.return_value = MagicMock()
+            ChatOpenRouter(
+                model=MODEL_NAME,
+                api_key=SecretStr("test-key"),
+                max_retries=0,
+            )
+            call_kwargs = mock_cls.call_args[1]
+            assert "retry_config" not in call_kwargs
+
+    def test_max_retries_scales_elapsed_time(self) -> None:
+        """Test that max_retries value scales max_elapsed_time."""
+        with patch("openrouter.OpenRouter") as mock_cls:
+            mock_cls.return_value = MagicMock()
+            ChatOpenRouter(
+                model=MODEL_NAME,
+                api_key=SecretStr("test-key"),
+                max_retries=4,
+            )
+            call_kwargs = mock_cls.call_args[1]
+            retry_config = call_kwargs["retry_config"]
+            assert retry_config.backoff.max_elapsed_time == 4 * 150_000
diff --git a/libs/partners/openrouter/tests/unit_tests/test_imports.py b/libs/partners/openrouter/tests/unit_tests/test_imports.py
new file mode 100644
index 00000000000..d46ab65340d
--- /dev/null
+++ b/libs/partners/openrouter/tests/unit_tests/test_imports.py
@@ -0,0 +1,12 @@
+"""Test `langchain_openrouter` public API surface."""
+
+from langchain_openrouter import __all__
+
+EXPECTED_ALL = [
+    "ChatOpenRouter",
+]
+
+
+def test_all_imports() -> None:
+    """Verify that __all__ exports match the expected public API."""
+    assert sorted(EXPECTED_ALL) == sorted(__all__)
diff --git a/libs/partners/openrouter/tests/unit_tests/test_standard.py b/libs/partners/openrouter/tests/unit_tests/test_standard.py
new file mode 100644
index 00000000000..19a2858fc92
--- /dev/null
+++ b/libs/partners/openrouter/tests/unit_tests/test_standard.py
@@ -0,0 +1,39 @@
+"""Standard unit tests for `ChatOpenRouter`."""
+
+from langchain_tests.unit_tests import ChatModelUnitTests
+
+from langchain_openrouter.chat_models import ChatOpenRouter
+
+MODEL_NAME = "openai/gpt-4o-mini"
+
+
+class TestChatOpenRouterUnit(ChatModelUnitTests):
+    """Standard unit tests for `ChatOpenRouter` chat model."""
+
+    @property
+    def chat_model_class(self) -> type[ChatOpenRouter]:
+        """Chat model class being tested."""
+        return ChatOpenRouter
+
+    @property
+    def init_from_env_params(self) -> tuple[dict, dict, dict]:
+        """Parameters to initialize from environment variables."""
+        return (
+            {
+                "OPENROUTER_API_KEY": "api_key",
+            },
+            {
+                "model": MODEL_NAME,
+            },
+            {
+                "openrouter_api_key": "api_key",
+            },
+        )
+
+    @property
+    def chat_model_params(self) -> dict:
+        """Parameters to create chat model instance for testing."""
+        return {
+            "model": MODEL_NAME,
+            "api_key": "test-api-key",
+        }