"""OpenRouter chat models.""" from __future__ import annotations import json import warnings from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence from operator import itemgetter from typing import Any, Literal, cast from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain_core.language_models import ( LanguageModelInput, ModelProfile, ModelProfileRegistry, ) from langchain_core.language_models.chat_models import ( BaseChatModel, LangSmithParams, agenerate_from_stream, generate_from_stream, ) from langchain_core.messages import ( AIMessage, AIMessageChunk, BaseMessage, BaseMessageChunk, ChatMessage, ChatMessageChunk, HumanMessage, HumanMessageChunk, InvalidToolCall, SystemMessage, SystemMessageChunk, ToolCall, ToolMessage, ToolMessageChunk, is_data_content_block, ) from langchain_core.messages.ai import ( InputTokenDetails, OutputTokenDetails, UsageMetadata, ) from langchain_core.messages.block_translators.openai import ( convert_to_openai_data_block, ) from langchain_core.messages.tool import tool_call_chunk from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_tools import ( JsonOutputKeyToolsParser, PydanticToolsParser, make_invalid_tool_call, parse_tool_call, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env from langchain_core.utils.function_calling import ( convert_to_json_schema, convert_to_openai_tool, ) from langchain_core.utils.pydantic import is_basemodel_subclass from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from typing_extensions import Self from langchain_openrouter.data._profiles import _PROFILES _MODEL_PROFILES = cast("ModelProfileRegistry", _PROFILES) # LangChain-internal kwargs that must not be forwarded to the SDK. _INTERNAL_KWARGS = frozenset({"ls_structured_output_format"}) def _get_default_model_profile(model_name: str) -> ModelProfile: default = _MODEL_PROFILES.get(model_name) or {} return default.copy() class ChatOpenRouter(BaseChatModel): """OpenRouter chat model integration. OpenRouter is a unified API that provides access to hundreds of models from multiple providers (OpenAI, Anthropic, Google, Meta, etc.). ???+ info "Setup" Install `langchain-openrouter` and set environment variable `OPENROUTER_API_KEY`. ```bash pip install -U langchain-openrouter ``` ```bash export OPENROUTER_API_KEY="your-api-key" ``` ??? info "Key init args — completion params" | Param | Type | Description | | ----- | ---- | ----------- | | `model` | `str` | Model name, e.g. `'openai/gpt-4o-mini'`. | | `temperature` | `float | None` | Sampling temperature. | | `max_tokens` | `int | None` | Max tokens to generate. | ??? info "Key init args — client params" | Param | Type | Description | | ----- | ---- | ----------- | | `api_key` | `str | None` | OpenRouter API key. | | `base_url` | `str | None` | Base URL for API requests. | | `timeout` | `int | None` | Timeout in milliseconds. | | `app_url` | `str | None` | App URL for attribution. | | `app_title` | `str | None` | App title for attribution. | | `app_categories` | `list[str] | None` | Marketplace attribution categories. | | `session_id` | `str | None` | Group related requests for observability. | | `trace` | `dict[str, Any] | None` | Trace metadata for broadcasts. | | `max_retries` | `int` | Max retries (default `2`). Set to `0` to disable. | ??? info "Instantiate" ```python from langchain_openrouter import ChatOpenRouter model = ChatOpenRouter( model="anthropic/claude-sonnet-4-5", temperature=0, # api_key="...", # openrouter_provider={"order": ["Anthropic"]}, ) ``` See https://openrouter.ai/docs for platform documentation. """ client: Any = Field(default=None, exclude=True) """Underlying SDK client (`openrouter.OpenRouter`).""" openrouter_api_key: SecretStr | None = Field( alias="api_key", default_factory=secret_from_env("OPENROUTER_API_KEY", default=None), ) """OpenRouter API key.""" openrouter_api_base: str | None = Field( default_factory=from_env("OPENROUTER_API_BASE", default=None), alias="base_url", ) """OpenRouter API base URL. Maps to SDK `server_url`.""" app_url: str | None = Field( default_factory=from_env( "OPENROUTER_APP_URL", default="https://docs.langchain.com", ), ) """Application URL for OpenRouter attribution. Maps to `HTTP-Referer` header. Defaults to LangChain docs URL. Set this to your app's URL to get attribution for API usage in the OpenRouter dashboard. See https://openrouter.ai/docs/app-attribution for details. """ app_title: str | None = Field( default_factory=from_env("OPENROUTER_APP_TITLE", default="LangChain"), ) """Application title for OpenRouter attribution. Maps to `X-Title` header. Defaults to `'LangChain'`. Set this to your app's name to get attribution for API usage in the OpenRouter dashboard. See https://openrouter.ai/docs/app-attribution for details. """ app_categories: list[str] | None = Field( default=None, ) """Marketplace categories for OpenRouter attribution. Maps to `X-OpenRouter-Categories` header. Pass a list of lowercase, hyphen-separated category strings (max 30 characters each), e.g. `['cli-agent', 'programming-app']`. Only recognized categories are accepted (unrecognized values are silently dropped by OpenRouter). See https://openrouter.ai/docs/app-attribution for recognized categories. """ request_timeout: int | None = Field(default=None, alias="timeout") """Timeout for requests in milliseconds. Maps to SDK `timeout_ms`.""" max_retries: int = 2 """Maximum number of retries. Each unit adds ~150 seconds to the backoff window via the SDK's `max_elapsed_time` (e.g. `max_retries=2` allows up to ~300 s). Set to `0` to disable retries. """ model_name: str = Field(alias="model") """The name of the model, e.g. `'anthropic/claude-sonnet-4-5'`.""" @property def model(self) -> str: """Same as model_name.""" return self.model_name temperature: float | None = None """Sampling temperature.""" max_tokens: int | None = None """Maximum number of tokens to generate.""" max_completion_tokens: int | None = None """Maximum number of completion tokens to generate.""" top_p: float | None = None """Nucleus sampling parameter.""" frequency_penalty: float | None = None """Frequency penalty for generation.""" presence_penalty: float | None = None """Presence penalty for generation.""" seed: int | None = None """Random seed for reproducibility.""" stop: list[str] | str | None = Field(default=None, alias="stop_sequences") """Default stop sequences.""" n: int = Field(default=1, ge=1) """Number of chat completions to generate for each prompt.""" streaming: bool = False """Whether to stream the results or not.""" stream_usage: bool = True """Whether to include usage metadata in streaming output. If `True`, additional message chunks will be generated during the stream including usage metadata. """ model_kwargs: dict[str, Any] = Field(default_factory=dict) """Any extra model parameters for the OpenRouter API.""" reasoning: dict[str, Any] | None = None """Reasoning settings to pass to OpenRouter. Controls how many tokens the model allocates for internal chain-of-thought reasoning. Accepts an `openrouter.components.OpenResponsesReasoningConfig` or an equivalent dict. Supported keys: - `effort`: Controls reasoning token budget. Values: `'xhigh'`, `'high'`, `'medium'`, `'low'`, `'minimal'`, `'none'`. - `summary`: Controls verbosity of the reasoning summary returned in the response. Values: `'auto'`, `'concise'`, `'detailed'`. Example: `{"effort": "high", "summary": "auto"}` See https://openrouter.ai/docs/guides/best-practices/reasoning-tokens """ openrouter_provider: dict[str, Any] | None = None """Provider preferences to pass to OpenRouter. Example: `{"order": ["Anthropic", "OpenAI"]}` """ route: str | None = None """Route preference for OpenRouter, e.g. `'fallback'`.""" plugins: list[dict[str, Any]] | None = None """Plugins configuration for OpenRouter.""" session_id: str | None = Field( default_factory=from_env("OPENROUTER_SESSION_ID", default=None), ) """Identifier used by OpenRouter to group related requests together. Useful any time multiple requests should share an observability grouping (e.g. a conversation, an agent workflow, a batch job, or a CI run). Equivalent to setting the `x-session-id` HTTP header on the underlying request. OpenRouter rejects values longer than 128 characters. Falls back to the `OPENROUTER_SESSION_ID` environment variable when unset, so callers can group all requests from a process without threading the value through application code. Empty strings are treated as unset. Example: `"conv-2026-04-30-abc"` See https://openrouter.ai/docs/guides/features/broadcast/overview """ trace: dict[str, Any] | None = None """Trace metadata for observability tools (e.g. Langfuse, LangSmith). Forwarded by OpenRouter to configured broadcast destinations. Common keys include `trace_id`, `trace_name`, `span_name`, `generation_name`, and `parent_span_id`; see the OpenRouter broadcast docs for the current full set. Unknown keys are forwarded as custom metadata. No environment-variable fallback — set per-call or on the constructor. Example: `{"trace_id": "abc-123", "span_name": "summarize"}` See https://openrouter.ai/docs/guides/features/broadcast/overview """ model_config = ConfigDict(populate_by_name=True) @model_validator(mode="before") @classmethod def build_extra(cls, values: dict[str, Any]) -> Any: """Build extra kwargs from additional params that were passed in.""" all_required_field_names = get_pydantic_field_names(cls) extra = values.get("model_kwargs", {}) for field_name in list(values): if field_name in extra: msg = f"Found {field_name} supplied twice." raise ValueError(msg) if field_name not in all_required_field_names: warnings.warn( f"""WARNING! {field_name} is not default parameter. {field_name} was transferred to model_kwargs. Please confirm that {field_name} is what you intended.""", stacklevel=2, ) extra[field_name] = values.pop(field_name) invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) if invalid_model_kwargs: msg = ( f"Parameters {invalid_model_kwargs} should be specified explicitly. " f"Instead they were passed in as part of `model_kwargs` parameter." ) raise ValueError(msg) values["model_kwargs"] = extra return values def _build_client(self) -> Any: """Build and return an `openrouter.OpenRouter` SDK client. Returns: An `openrouter.OpenRouter` SDK client instance. """ import openrouter # noqa: PLC0415 from openrouter.utils import ( # noqa: PLC0415 BackoffStrategy, RetryConfig, ) client_kwargs: dict[str, Any] = { "api_key": self.openrouter_api_key.get_secret_value(), # type: ignore[union-attr] } if self.openrouter_api_base: client_kwargs["server_url"] = self.openrouter_api_base extra_headers: dict[str, str] = {} if self.app_url: extra_headers["HTTP-Referer"] = self.app_url if self.app_title: extra_headers["X-Title"] = self.app_title if self.app_categories: extra_headers["X-OpenRouter-Categories"] = ",".join(self.app_categories) if extra_headers: import httpx # noqa: PLC0415 client_kwargs["client"] = httpx.Client( headers=extra_headers, follow_redirects=True ) client_kwargs["async_client"] = httpx.AsyncClient( headers=extra_headers, follow_redirects=True ) if self.request_timeout is not None: client_kwargs["timeout_ms"] = self.request_timeout if self.max_retries > 0: client_kwargs["retry_config"] = RetryConfig( strategy="backoff", backoff=BackoffStrategy( initial_interval=500, max_interval=60000, exponent=1.5, max_elapsed_time=self.max_retries * 150_000, ), retry_connection_errors=True, ) return openrouter.OpenRouter(**client_kwargs) @model_validator(mode="after") def validate_environment(self) -> Self: """Validate configuration and build the SDK client.""" if not (self.openrouter_api_key and self.openrouter_api_key.get_secret_value()): msg = "OPENROUTER_API_KEY must be set." raise ValueError(msg) if self.n > 1 and self.streaming: msg = "n must be 1 when streaming." raise ValueError(msg) if not self.client: try: import openrouter # noqa: PLC0415, F401 self.client = self._build_client() except ImportError as e: msg = ( "Could not import the `openrouter` Python SDK. " "Please install it with: pip install openrouter" ) raise ImportError(msg) from e return self def _resolve_model_profile(self) -> ModelProfile | None: return _get_default_model_profile(self.model_name) or None # # Serializable class method overrides # @property def lc_secrets(self) -> dict[str, str]: """A map of constructor argument names to secret ids.""" return {"openrouter_api_key": "OPENROUTER_API_KEY"} @classmethod def is_lc_serializable(cls) -> bool: """Return whether this model can be serialized by LangChain.""" return True # # BaseChatModel method overrides # @property def _llm_type(self) -> str: """Return type of chat model.""" return "openrouter-chat" @property def _identifying_params(self) -> dict[str, Any]: """Get the identifying parameters.""" return { "model": self.model_name, "temperature": self.temperature, "max_tokens": self.max_tokens, "top_p": self.top_p, "streaming": self.streaming, "reasoning": self.reasoning, "openrouter_provider": self.openrouter_provider, "route": self.route, "model_kwargs": self.model_kwargs, } def _get_ls_params( self, stop: list[str] | None = None, **kwargs: Any, ) -> LangSmithParams: """Get standard params for tracing.""" params = self._get_invocation_params(stop=stop, **kwargs) ls_params = LangSmithParams( ls_provider="openrouter", ls_model_name=params.get("model", self.model_name), ls_model_type="chat", ls_temperature=params.get("temperature", self.temperature), ) if ls_max_tokens := params.get("max_tokens", self.max_tokens): ls_params["ls_max_tokens"] = ls_max_tokens if ls_stop := stop or params.get("stop", None) or self.stop: ls_params["ls_stop"] = ls_stop if isinstance(ls_stop, list) else [ls_stop] return ls_params def _generate( self, messages: list[BaseMessage], stop: list[str] | None = None, run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: if self.streaming: stream_iter = self._stream( messages, stop=stop, run_manager=run_manager, **kwargs ) return generate_from_stream(stream_iter) message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs} _strip_internal_kwargs(params) sdk_messages = _wrap_messages_for_sdk(message_dicts) response = self.client.chat.send(messages=sdk_messages, **params) return self._create_chat_result(response) async def _agenerate( self, messages: list[BaseMessage], stop: list[str] | None = None, run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: if self.streaming: stream_iter = self._astream( messages, stop=stop, run_manager=run_manager, **kwargs ) return await agenerate_from_stream(stream_iter) message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs} _strip_internal_kwargs(params) sdk_messages = _wrap_messages_for_sdk(message_dicts) response = await self.client.chat.send_async(messages=sdk_messages, **params) return self._create_chat_result(response) def _stream( # noqa: C901, PLR0912 self, messages: list[BaseMessage], stop: list[str] | None = None, run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs, "stream": True} if self.stream_usage: params["stream_options"] = {"include_usage": True} _strip_internal_kwargs(params) sdk_messages = _wrap_messages_for_sdk(message_dicts) default_chunk_class: type[BaseMessageChunk] = AIMessageChunk for chunk in self.client.chat.send(messages=sdk_messages, **params): chunk_dict = chunk.model_dump(by_alias=True) if not chunk_dict.get("choices"): if error := chunk_dict.get("error"): msg = ( f"OpenRouter API returned an error during streaming: " f"{error.get('message', str(error))} " f"(code: {error.get('code', 'unknown')})" ) raise ValueError(msg) # Usage-only chunk (no choices) — emit with usage_metadata if usage := chunk_dict.get("usage"): usage_metadata = _create_usage_metadata(usage) usage_chunk = AIMessageChunk( content="", usage_metadata=usage_metadata ) generation_chunk = ChatGenerationChunk(message=usage_chunk) if run_manager: run_manager.on_llm_new_token( generation_chunk.text, chunk=generation_chunk ) yield generation_chunk continue choice = chunk_dict["choices"][0] message_chunk = _convert_chunk_to_message_chunk( chunk_dict, default_chunk_class ) generation_info: dict[str, Any] = {} if finish_reason := choice.get("finish_reason"): generation_info["finish_reason"] = finish_reason # Include response-level metadata on the final chunk response_model = chunk_dict.get("model") generation_info["model_name"] = response_model or self.model_name if system_fingerprint := chunk_dict.get("system_fingerprint"): generation_info["system_fingerprint"] = system_fingerprint if native_finish_reason := choice.get("native_finish_reason"): generation_info["native_finish_reason"] = native_finish_reason if response_id := chunk_dict.get("id"): generation_info["id"] = response_id if created := chunk_dict.get("created"): generation_info["created"] = int(created) if object_ := chunk_dict.get("object"): generation_info["object"] = object_ logprobs = choice.get("logprobs") if logprobs: generation_info["logprobs"] = logprobs if generation_info: generation_info["model_provider"] = "openrouter" message_chunk = message_chunk.model_copy( update={ "response_metadata": { **message_chunk.response_metadata, **generation_info, } } ) default_chunk_class = message_chunk.__class__ generation_chunk = ChatGenerationChunk( message=message_chunk, generation_info=generation_info or None ) if run_manager: run_manager.on_llm_new_token( generation_chunk.text, chunk=generation_chunk, logprobs=logprobs, ) yield generation_chunk async def _astream( # noqa: C901, PLR0912 self, messages: list[BaseMessage], stop: list[str] | None = None, run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs, "stream": True} if self.stream_usage: params["stream_options"] = {"include_usage": True} _strip_internal_kwargs(params) sdk_messages = _wrap_messages_for_sdk(message_dicts) default_chunk_class: type[BaseMessageChunk] = AIMessageChunk async for chunk in await self.client.chat.send_async( messages=sdk_messages, **params ): chunk_dict = chunk.model_dump(by_alias=True) if not chunk_dict.get("choices"): if error := chunk_dict.get("error"): msg = ( f"OpenRouter API returned an error during streaming: " f"{error.get('message', str(error))} " f"(code: {error.get('code', 'unknown')})" ) raise ValueError(msg) # Usage-only chunk (no choices) — emit with usage_metadata if usage := chunk_dict.get("usage"): usage_metadata = _create_usage_metadata(usage) usage_chunk = AIMessageChunk( content="", usage_metadata=usage_metadata ) generation_chunk = ChatGenerationChunk(message=usage_chunk) if run_manager: await run_manager.on_llm_new_token( token=generation_chunk.text, chunk=generation_chunk ) yield generation_chunk continue choice = chunk_dict["choices"][0] message_chunk = _convert_chunk_to_message_chunk( chunk_dict, default_chunk_class ) generation_info: dict[str, Any] = {} if finish_reason := choice.get("finish_reason"): generation_info["finish_reason"] = finish_reason # Include response-level metadata on the final chunk response_model = chunk_dict.get("model") generation_info["model_name"] = response_model or self.model_name if system_fingerprint := chunk_dict.get("system_fingerprint"): generation_info["system_fingerprint"] = system_fingerprint if native_finish_reason := choice.get("native_finish_reason"): generation_info["native_finish_reason"] = native_finish_reason if response_id := chunk_dict.get("id"): generation_info["id"] = response_id if created := chunk_dict.get("created"): generation_info["created"] = int(created) # UNIX timestamp if object_ := chunk_dict.get("object"): generation_info["object"] = object_ logprobs = choice.get("logprobs") if logprobs: generation_info["logprobs"] = logprobs if generation_info: generation_info["model_provider"] = "openrouter" message_chunk = message_chunk.model_copy( update={ "response_metadata": { **message_chunk.response_metadata, **generation_info, } } ) default_chunk_class = message_chunk.__class__ generation_chunk = ChatGenerationChunk( message=message_chunk, generation_info=generation_info or None ) if run_manager: await run_manager.on_llm_new_token( token=generation_chunk.text, chunk=generation_chunk, logprobs=logprobs, ) yield generation_chunk # # Internal methods # @property def _default_params(self) -> dict[str, Any]: # noqa: C901, PLR0912 """Get the default parameters for calling OpenRouter API.""" params: dict[str, Any] = { "model": self.model_name, "stream": self.streaming, **self.model_kwargs, } if self.temperature is not None: params["temperature"] = self.temperature if self.max_tokens is not None: params["max_tokens"] = self.max_tokens if self.max_completion_tokens is not None: params["max_completion_tokens"] = self.max_completion_tokens if self.top_p is not None: params["top_p"] = self.top_p if self.frequency_penalty is not None: params["frequency_penalty"] = self.frequency_penalty if self.presence_penalty is not None: params["presence_penalty"] = self.presence_penalty if self.seed is not None: params["seed"] = self.seed if self.n > 1: params["n"] = self.n if self.stop is not None: params["stop"] = self.stop # OpenRouter-specific params if self.reasoning is not None: params["reasoning"] = self.reasoning if self.openrouter_provider is not None: params["provider"] = self.openrouter_provider if self.route is not None: params["route"] = self.route if self.plugins is not None: params["plugins"] = self.plugins if self.session_id: params["session_id"] = self.session_id if self.trace is not None: params["trace"] = self.trace return params def _create_message_dicts( self, messages: list[BaseMessage], stop: list[str] | None ) -> tuple[list[dict[str, Any]], dict[str, Any]]: params = self._default_params if stop is not None: params["stop"] = stop message_dicts = [_convert_message_to_dict(m) for m in messages] return message_dicts, params def _create_chat_result(self, response: Any) -> ChatResult: # noqa: C901, PLR0912 """Create a `ChatResult` from an OpenRouter SDK response.""" if not isinstance(response, dict): response = response.model_dump(by_alias=True) if error := response.get("error"): msg = ( f"OpenRouter API returned an error: " f"{error.get('message', str(error))} " f"(code: {error.get('code', 'unknown')})" ) raise ValueError(msg) generations = [] token_usage = response.get("usage") or {} choices = response.get("choices", []) if not choices: msg = ( "OpenRouter API returned a response with no choices. " "This may indicate a problem with the request or model availability." ) raise ValueError(msg) # Extract top-level response metadata response_model = response.get("model") system_fingerprint = response.get("system_fingerprint") for res in choices: message = _convert_dict_to_message(res["message"]) if token_usage and isinstance(message, AIMessage): message.usage_metadata = _create_usage_metadata(token_usage) # Surface OpenRouter cost data in response_metadata if "cost" in token_usage: message.response_metadata["cost"] = token_usage["cost"] if "cost_details" in token_usage: message.response_metadata["cost_details"] = token_usage[ "cost_details" ] if isinstance(message, AIMessage): if system_fingerprint: message.response_metadata["system_fingerprint"] = system_fingerprint if native_finish_reason := res.get("native_finish_reason"): message.response_metadata["native_finish_reason"] = ( native_finish_reason ) generation_info: dict[str, Any] = { "finish_reason": res.get("finish_reason"), } if "logprobs" in res: generation_info["logprobs"] = res["logprobs"] gen = ChatGeneration( message=message, generation_info=generation_info, ) generations.append(gen) llm_output: dict[str, Any] = { "model_name": response_model or self.model_name, } if response_id := response.get("id"): llm_output["id"] = response_id if created := response.get("created"): llm_output["created"] = int(created) if object_ := response.get("object"): llm_output["object"] = object_ return ChatResult(generations=generations, llm_output=llm_output) def bind_tools( self, tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool], *, tool_choice: dict | str | bool | None = None, strict: bool | None = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, AIMessage]: """Bind tool-like objects to this chat model. Args: tools: A list of tool definitions to bind to this chat model. Supports any tool definition handled by `langchain_core.utils.function_calling.convert_to_openai_tool`. tool_choice: Which tool to require the model to call. strict: If `True`, model output is guaranteed to exactly match the JSON Schema provided in the tool definition. If `None`, the `strict` argument will not be passed to the model. **kwargs: Any additional parameters. """ formatted_tools = [ convert_to_openai_tool(tool, strict=strict) for tool in tools ] if tool_choice is not None and tool_choice: if tool_choice == "any": tool_choice = "required" if isinstance(tool_choice, str) and ( tool_choice not in ("auto", "none", "required") ): tool_choice = {"type": "function", "function": {"name": tool_choice}} if isinstance(tool_choice, bool): if len(tools) > 1: msg = ( "tool_choice can only be True when there is one tool. Received " f"{len(tools)} tools." ) raise ValueError(msg) tool_name = formatted_tools[0]["function"]["name"] tool_choice = { "type": "function", "function": {"name": tool_name}, } kwargs["tool_choice"] = tool_choice return super().bind(tools=formatted_tools, **kwargs) def with_structured_output( # type: ignore[override] self, schema: dict | type[BaseModel] | None = None, *, method: Literal["function_calling", "json_schema"] = "function_calling", include_raw: bool = False, strict: bool | None = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, dict | BaseModel]: """Model wrapper that returns outputs formatted to match the given schema. Args: schema: The output schema as a Pydantic class, TypedDict, JSON Schema, or OpenAI function schema. method: The method for steering model generation. include_raw: If `True` then both the raw model response and the parsed model response will be returned. strict: If `True`, model output is guaranteed to exactly match the JSON Schema provided in the schema definition. If `None`, the `strict` argument will not be passed to the model. **kwargs: Any additional parameters. Returns: A `Runnable` that takes same inputs as a `BaseChatModel`. """ if method == "json_mode": warnings.warn( "Unrecognized structured output method 'json_mode'. " "Defaulting to 'json_schema' method.", stacklevel=2, ) method = "json_schema" is_pydantic_schema = _is_pydantic_class(schema) if method == "function_calling": if schema is None: msg = ( "schema must be specified when method is 'function_calling'. " "Received None." ) raise ValueError(msg) formatted_tool = convert_to_openai_tool(schema) tool_name = formatted_tool["function"]["name"] llm = self.bind_tools( [schema], tool_choice=tool_name, strict=strict, ls_structured_output_format={ "kwargs": {"method": "function_calling", "strict": strict}, "schema": formatted_tool, }, **kwargs, ) if is_pydantic_schema: output_parser: OutputParserLike = PydanticToolsParser( tools=[schema], # type: ignore[list-item] first_tool_only=True, # type: ignore[list-item] ) else: output_parser = JsonOutputKeyToolsParser( key_name=tool_name, first_tool_only=True ) elif method == "json_schema": if schema is None: msg = ( "schema must be specified when method is 'json_schema'. " "Received None." ) raise ValueError(msg) json_schema = convert_to_json_schema(schema) schema_name = json_schema.get("title", "") json_schema_spec: dict[str, Any] = { "name": schema_name, "schema": json_schema, } if strict is not None: json_schema_spec["strict"] = strict response_format = { "type": "json_schema", "json_schema": json_schema_spec, } ls_format_info = { "kwargs": {"method": "json_schema", "strict": strict}, "schema": json_schema, } llm = self.bind( response_format=response_format, ls_structured_output_format=ls_format_info, **kwargs, ) output_parser = ( PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type] if is_pydantic_schema else JsonOutputParser() ) else: msg = ( f"Unrecognized method argument. Expected one of 'function_calling' " f"or 'json_schema'. Received: '{method}'" ) raise ValueError(msg) if include_raw: parser_assign = RunnablePassthrough.assign( parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None ) parser_none = RunnablePassthrough.assign(parsed=lambda _: None) parser_with_fallback = parser_assign.with_fallbacks( [parser_none], exception_key="parsing_error" ) return RunnableMap(raw=llm) | parser_with_fallback return llm | output_parser def _is_pydantic_class(obj: Any) -> bool: return isinstance(obj, type) and is_basemodel_subclass(obj) def _strip_internal_kwargs(params: dict[str, Any]) -> None: """Remove LangChain-internal keys that the SDK does not accept.""" for key in _INTERNAL_KWARGS: params.pop(key, None) def _has_file_content_blocks(message_dicts: list[dict[str, Any]]) -> bool: """Return `True` if any message dict contains a `file` content block.""" for msg in message_dicts: content = msg.get("content") if isinstance(content, list): for block in content: if isinstance(block, dict) and block.get("type") == "file": return True return False def _wrap_messages_for_sdk( message_dicts: list[dict[str, Any]], ) -> list[dict[str, Any]] | list[Any]: """Wrap message dicts as SDK Pydantic models when file blocks are present. The OpenRouter Python SDK does not include `file` in its `ChatMessageContentItem` discriminated union, so Pydantic validation rejects file content blocks even though the OpenRouter **API** supports them. Using `model_construct` on the SDK's message classes bypasses validation while still producing the correct JSON payload. When no file blocks are detected the original dicts are returned unchanged so the normal (validated) code path is preserved. Args: message_dicts: Message dicts produced by `_convert_message_to_dict`. Returns: The original list when no file blocks are present, or a list of SDK Pydantic model instances otherwise. """ if not _has_file_content_blocks(message_dicts): return message_dicts try: from openrouter import components # noqa: PLC0415 except ImportError: warnings.warn( "Could not import openrouter.components; file content blocks " "will be sent as raw dicts which may cause validation errors.", stacklevel=2, ) return message_dicts role_to_model: dict[str, type[BaseModel]] = { "user": components.ChatUserMessage, "system": components.ChatSystemMessage, "assistant": components.ChatAssistantMessage, "tool": components.ChatToolMessage, "developer": components.ChatDeveloperMessage, } wrapped: list[Any] = [] for msg in message_dicts: model_cls = role_to_model.get(msg.get("role", "")) if model_cls is None: warnings.warn( f"Unknown message role {msg.get('role')!r} encountered during " f"SDK wrapping; passing raw dict to the API.", stacklevel=2, ) wrapped.append(msg) continue wrapped.append(model_cls.model_construct(**msg)) return wrapped # # Type conversion helpers # def _convert_video_block_to_openrouter(block: dict[str, Any]) -> dict[str, Any]: """Convert a LangChain video content block to OpenRouter's `video_url` format. Args: block: A LangChain `VideoContentBlock`. Returns: A dict in OpenRouter's `video_url` format. Raises: ValueError: If no video source is provided. """ if "url" in block: return {"type": "video_url", "video_url": {"url": block["url"]}} if "base64" in block or block.get("source_type") == "base64": base64_data = block["data"] if "source_type" in block else block["base64"] mime_type = block.get("mime_type", "video/mp4") return { "type": "video_url", "video_url": {"url": f"data:{mime_type};base64,{base64_data}"}, } msg = "Video block must have either 'url' or 'base64' data." raise ValueError(msg) def _convert_file_block_to_openrouter(block: dict[str, Any]) -> dict[str, Any]: """Convert a LangChain file content block to OpenRouter's `file` format. OpenRouter accepts files as:: {"type": "file", "file": {"filename": "...", "file_data": "..."}} where `file_data` is either a public URL or a `data:` URI. Args: block: A LangChain file content block. Returns: A dict in OpenRouter's `file` format. Raises: ValueError: If the block contains neither a URL, base64 data, nor a file ID. """ file: dict[str, str] = {} # --- resolve file_data --------------------------------------------------- if "url" in block: file["file_data"] = block["url"] elif block.get("source_type") == "base64" or "base64" in block: base64_data = block["data"] if "source_type" in block else block["base64"] mime_type = block.get("mime_type", "application/octet-stream") file["file_data"] = f"data:{mime_type};base64,{base64_data}" elif block.get("source_type") == "id" or "file_id" in block: msg = "OpenRouter does not support file IDs." raise ValueError(msg) else: msg = "File block must have either 'url' or 'base64' data." raise ValueError(msg) # --- resolve filename ---------------------------------------------------- if filename := block.get("filename"): file["filename"] = filename elif ((extras := block.get("extras")) and "filename" in extras) or ( (extras := block.get("metadata")) and "filename" in extras ): file["filename"] = extras["filename"] return {"type": "file", "file": file} def _format_message_content(content: Any) -> Any: """Format message content for OpenRouter API. Converts LangChain data content blocks to the expected format. Args: content: The message content (string or list of content blocks). Returns: Formatted content suitable for the OpenRouter API. """ if content and isinstance(content, list): formatted: list = [] for block in content: if isinstance(block, dict) and is_data_content_block(block): if block.get("type") == "video": formatted.append(_convert_video_block_to_openrouter(block)) elif block.get("type") == "file": formatted.append(_convert_file_block_to_openrouter(block)) else: formatted.append(convert_to_openai_data_block(block)) else: formatted.append(block) return formatted return content def _merge_reasoning_run(run: list[dict[str, Any]]) -> dict[str, Any]: """Merge a run of consecutive same-`(type, index)` reasoning fragments.""" merged_entry: dict[str, Any] = {} text_parts: list[str] = [] has_text = False for frag in run: for k, v in frag.items(): if k == "text": has_text = True if v: text_parts.append(v) elif v is not None: merged_entry[k] = v if has_text: merged_entry["text"] = "".join(text_parts) return merged_entry def _merge_reasoning_details( details: list[dict[str, Any]], ) -> list[dict[str, Any]]: """Merge fragmented `reasoning_details` from streaming chunk concatenation. During streaming, `AIMessageChunk.__add__` list-concatenates `reasoning_details` in `additional_kwargs`, fragmenting a single entry into many. When serialized back to the API via `_convert_message_to_dict`, these fragments cause `BadRequestResponseError` on multi-turn conversations (the provider rejects the malformed thinking block with `Invalid signature`). Streaming deltas tag each fragment with the `index` of the entry it belongs to in the original (non-streamed) array, so this function groups consecutive entries by `(type, index)` and merges each group into one. Entries without an `index` are preserved as-is, since non-streaming responses can legitimately contain multiple entries. Within a merged group, `text` values are concatenated in order. Other metadata fields (e.g. `format`, `signature`) use last-non-`None`-wins semantics, which preserves stable provider metadata without concatenating repeated strings — Anthropic-style reasoning streams emit a single signature-bearing fragment at the end of the block. A list with zero or one items passes through unchanged. """ if not isinstance(details, list) or len(details) <= 1: return details merged: list[dict[str, Any]] = [] i = 0 while i < len(details): entry = details[i] # Without an index we cannot distinguish streaming fragments from # distinct non-streaming entries, so leave them alone. Same for any # non-dict items that may have slipped in upstream. if not isinstance(entry, dict) or entry.get("index") is None: merged.append(entry) i += 1 continue entry_type = entry.get("type", "") entry_index = entry["index"] run = [entry] i += 1 while i < len(details): nxt = details[i] if ( isinstance(nxt, dict) and nxt.get("type", "") == entry_type and nxt.get("index") == entry_index ): run.append(nxt) i += 1 else: break merged.append(entry if len(run) == 1 else _merge_reasoning_run(run)) return merged def _convert_message_to_dict(message: BaseMessage) -> dict[str, Any]: # noqa: C901, PLR0912 """Convert a LangChain message to an OpenRouter-compatible dict payload. Handles role mapping, multimodal content formatting, tool call serialization, and reasoning content preservation for multi-turn conversations. Args: message: The LangChain message. Returns: A dict suitable for the OpenRouter chat API `messages` parameter. """ message_dict: dict[str, Any] if isinstance(message, ChatMessage): message_dict = {"role": message.role, "content": message.content} elif isinstance(message, HumanMessage): message_dict = { "role": "user", "content": _format_message_content(message.content), } elif isinstance(message, AIMessage): message_dict = {"role": "assistant", "content": message.content} # Filter out non-text blocks from list content if isinstance(message.content, list): text_blocks = [ block for block in message.content if isinstance(block, dict) and block.get("type") == "text" ] message_dict["content"] = text_blocks or "" if message.tool_calls or message.invalid_tool_calls: message_dict["tool_calls"] = [ _lc_tool_call_to_openrouter_tool_call(tc) for tc in message.tool_calls ] + [ _lc_invalid_tool_call_to_openrouter_tool_call(tc) for tc in message.invalid_tool_calls ] if message_dict["content"] == "" or ( isinstance(message_dict["content"], list) and not message_dict["content"] ): message_dict["content"] = None elif "tool_calls" in message.additional_kwargs: message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] if message_dict["content"] == "" or ( isinstance(message_dict["content"], list) and not message_dict["content"] ): message_dict["content"] = None # Preserve reasoning content for multi-turn conversations (e.g. # tool-calling loops). OpenRouter stores reasoning text in `reasoning` # and structured fragment details in `reasoning_details`; the latter # is merged before serialization to undo streaming fragmentation. if "reasoning_content" in message.additional_kwargs: message_dict["reasoning"] = message.additional_kwargs["reasoning_content"] if "reasoning_details" in message.additional_kwargs: message_dict["reasoning_details"] = _merge_reasoning_details( message.additional_kwargs["reasoning_details"] ) elif isinstance(message, SystemMessage): message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolMessage): message_dict = { "role": "tool", "content": message.content, "tool_call_id": message.tool_call_id, } else: msg = f"Got unknown type {message}" raise TypeError(msg) if "name" in message.additional_kwargs: message_dict["name"] = message.additional_kwargs["name"] return message_dict def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: # noqa: C901 """Convert an OpenRouter API response message dict to a LangChain message. Extracts tool calls, reasoning content, and maps roles to the appropriate LangChain message type (`HumanMessage`, `AIMessage`, `SystemMessage`, `ToolMessage`, or `ChatMessage`). Args: _dict: The message dictionary from the API response. Returns: The corresponding LangChain message. """ id_ = _dict.get("id") role = _dict.get("role") if role == "user": return HumanMessage(content=_dict.get("content", "")) if role == "assistant": content = _dict.get("content", "") or "" additional_kwargs: dict = {} if reasoning := _dict.get("reasoning"): additional_kwargs["reasoning_content"] = reasoning if reasoning_details := _dict.get("reasoning_details"): additional_kwargs["reasoning_details"] = reasoning_details tool_calls = [] invalid_tool_calls = [] if raw_tool_calls := _dict.get("tool_calls"): for raw_tool_call in raw_tool_calls: try: tool_calls.append(parse_tool_call(raw_tool_call, return_id=True)) except Exception as e: # noqa: BLE001, PERF203 invalid_tool_calls.append( make_invalid_tool_call(raw_tool_call, str(e)) ) return AIMessage( content=content, id=id_, additional_kwargs=additional_kwargs, tool_calls=tool_calls, invalid_tool_calls=invalid_tool_calls, response_metadata={"model_provider": "openrouter"}, ) if role == "system": return SystemMessage(content=_dict.get("content", "")) if role == "tool": additional_kwargs = {} if "name" in _dict: additional_kwargs["name"] = _dict["name"] return ToolMessage( content=_dict.get("content", ""), tool_call_id=_dict.get("tool_call_id"), additional_kwargs=additional_kwargs, ) if role is None: msg = ( f"OpenRouter response message is missing the 'role' field. " f"Message keys: {list(_dict.keys())}" ) raise ValueError(msg) warnings.warn( f"Unrecognized message role '{role}' from OpenRouter. " f"Falling back to ChatMessage.", stacklevel=2, ) return ChatMessage(content=_dict.get("content", ""), role=role) def _convert_chunk_to_message_chunk( # noqa: C901, PLR0911, PLR0912 chunk: Mapping[str, Any], default_class: type[BaseMessageChunk] ) -> BaseMessageChunk: """Convert a streaming chunk dict to a LangChain message chunk. Args: chunk: The streaming chunk dictionary. default_class: Default message chunk class. Returns: The LangChain message chunk. """ choice = chunk["choices"][0] _dict = choice.get("delta", {}) role = cast("str", _dict.get("role")) content = cast("str", _dict.get("content") or "") additional_kwargs: dict = {} tool_call_chunks: list = [] if raw_tool_calls := _dict.get("tool_calls"): for rtc in raw_tool_calls: try: tool_call_chunks.append( tool_call_chunk( name=rtc["function"].get("name"), args=rtc["function"].get("arguments"), id=rtc.get("id"), index=rtc["index"], ) ) except (KeyError, TypeError, AttributeError): # noqa: PERF203 warnings.warn( f"Skipping malformed tool call chunk during streaming: " f"unexpected structure in {rtc!r}.", stacklevel=2, ) if role == "user" or default_class == HumanMessageChunk: return HumanMessageChunk(content=content) if role == "assistant" or default_class == AIMessageChunk: if reasoning := _dict.get("reasoning"): additional_kwargs["reasoning_content"] = reasoning if reasoning_details := _dict.get("reasoning_details"): additional_kwargs["reasoning_details"] = reasoning_details usage_metadata = None response_metadata: dict[str, Any] = {"model_provider": "openrouter"} if usage := chunk.get("usage"): usage_metadata = _create_usage_metadata(usage) # Surface OpenRouter cost data in response_metadata if "cost" in usage: response_metadata["cost"] = usage["cost"] if "cost_details" in usage: response_metadata["cost_details"] = usage["cost_details"] return AIMessageChunk( content=content, additional_kwargs=additional_kwargs, tool_call_chunks=tool_call_chunks, # type: ignore[arg-type] usage_metadata=usage_metadata, # type: ignore[arg-type] response_metadata=response_metadata, ) if role == "system" or default_class == SystemMessageChunk: return SystemMessageChunk(content=content) if role == "tool" or default_class == ToolMessageChunk: return ToolMessageChunk( content=content, tool_call_id=_dict.get("tool_call_id", "") ) if role: warnings.warn( f"Unrecognized streaming chunk role '{role}' from OpenRouter. " f"Falling back to ChatMessageChunk.", stacklevel=2, ) return ChatMessageChunk(content=content, role=role) if default_class is ChatMessageChunk: return ChatMessageChunk(content=content, role=role or "") return default_class(content=content) # type: ignore[call-arg] def _lc_tool_call_to_openrouter_tool_call(tool_call: ToolCall) -> dict[str, Any]: """Convert a LangChain ``ToolCall`` to an OpenRouter tool call dict. Serializes `args` (a dict) via `json.dumps`. """ return { "type": "function", "id": tool_call["id"], "function": { "name": tool_call["name"], "arguments": json.dumps(tool_call["args"], ensure_ascii=False), }, } def _lc_invalid_tool_call_to_openrouter_tool_call( invalid_tool_call: InvalidToolCall, ) -> dict[str, Any]: """Convert a LangChain `InvalidToolCall` to an OpenRouter tool call dict. Unlike the valid variant, `args` is already a raw string (not a dict) and is passed through as-is. """ return { "type": "function", "id": invalid_tool_call["id"], "function": { "name": invalid_tool_call["name"], "arguments": invalid_tool_call["args"], }, } def _create_usage_metadata(token_usage: dict[str, Any]) -> UsageMetadata: """Create usage metadata from OpenRouter token usage response. OpenRouter may return token counts as floats rather than ints, so all values are explicitly cast to int. Args: token_usage: Token usage dict from the API response. Returns: Usage metadata with input/output token details. """ _input = token_usage.get("prompt_tokens") input_tokens = int( _input if _input is not None else (token_usage.get("input_tokens") or 0) ) _output = token_usage.get("completion_tokens") output_tokens = int( _output if _output is not None else (token_usage.get("output_tokens") or 0) ) _total = token_usage.get("total_tokens") total_tokens = int(_total if _total is not None else input_tokens + output_tokens) input_details_dict = ( token_usage.get("prompt_tokens_details") or token_usage.get("input_tokens_details") or {} ) output_details_dict = ( token_usage.get("completion_tokens_details") or token_usage.get("output_tokens_details") or {} ) cache_read = input_details_dict.get("cached_tokens") cache_creation = input_details_dict.get("cache_write_tokens") input_token_details: dict = { "cache_read": int(cache_read) if cache_read is not None else None, "cache_creation": int(cache_creation) if cache_creation is not None else None, } reasoning_tokens = output_details_dict.get("reasoning_tokens") output_token_details: dict = { "reasoning": int(reasoning_tokens) if reasoning_tokens is not None else None, } usage_metadata: UsageMetadata = { "input_tokens": input_tokens, "output_tokens": output_tokens, "total_tokens": total_tokens, } filtered_input = {k: v for k, v in input_token_details.items() if v is not None} if filtered_input: usage_metadata["input_token_details"] = InputTokenDetails(**filtered_input) # type: ignore[typeddict-item] filtered_output = {k: v for k, v in output_token_details.items() if v is not None} if filtered_output: usage_metadata["output_token_details"] = OutputTokenDetails(**filtered_output) # type: ignore[typeddict-item] return usage_metadata