From 92ee772761d7522f40d4c7e2070f58872b727381 Mon Sep 17 00:00:00 2001 From: Alexander Olsen Date: Wed, 10 Jun 2026 14:03:28 -0400 Subject: [PATCH] feat(langchain): add `ProviderToolSearchMiddleware` (#37969) [Docs](https://github.com/langchain-ai/docs/pull/4355) Adds `ProviderToolSearchMiddleware` to let agents defer selected tools behind OpenAI/Anthropic provider-native tool search while preserving existing `extras={"defer_loading": True}` behavior. The middleware validates searchable tool names, injects the provider search tool only when a tool is deferred, and rejects unsupported providers up front. Made by [Open SWE](https://openswe.vercel.app) --------- Co-authored-by: Alexander Olsen <13665641+aolsenjazz@users.noreply.github.com> Co-authored-by: open-swe[bot] Co-authored-by: Mason Daugherty Co-authored-by: Mason Daugherty --- .../langchain/agents/middleware/__init__.py | 2 + .../agents/middleware/provider_tool_search.py | 303 ++++++++++++ .../test_provider_tool_search.py | 432 ++++++++++++++++++ 3 files changed, 737 insertions(+) create mode 100644 libs/langchain_v1/langchain/agents/middleware/provider_tool_search.py create mode 100644 libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_provider_tool_search.py diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index d5f5e3a2ea2..b107e8fa043 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -12,6 +12,7 @@ from langchain.agents.middleware.model_call_limit import ModelCallLimitMiddlewar from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware from langchain.agents.middleware.model_retry import ModelRetryMiddleware from langchain.agents.middleware.pii import PIIDetectionError, PIIMiddleware +from langchain.agents.middleware.provider_tool_search import ProviderToolSearchMiddleware from langchain.agents.middleware.shell_tool import ( CodexSandboxExecutionPolicy, DockerExecutionPolicy, @@ -65,6 +66,7 @@ __all__ = [ "ModelRetryMiddleware", "PIIDetectionError", "PIIMiddleware", + "ProviderToolSearchMiddleware", "RedactionRule", "Runtime", "ShellToolMiddleware", diff --git a/libs/langchain_v1/langchain/agents/middleware/provider_tool_search.py b/libs/langchain_v1/langchain/agents/middleware/provider_tool_search.py new file mode 100644 index 00000000000..b6255a53333 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/provider_tool_search.py @@ -0,0 +1,303 @@ +"""Provider-side tool search middleware.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, TypeAlias + +from langchain_core.tools import BaseTool +from typing_extensions import NotRequired, TypedDict + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from langchain_core.language_models.chat_models import BaseChatModel + from langchain_core.messages import AIMessage + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ModelRequest, + ModelResponse, + ResponseT, +) + +ToolIdentifier: TypeAlias = str | BaseTool +"""Tool name or tool instance that can be deferred behind provider tool search.""" + + +class _ServerToolSearchSpec(TypedDict): + """Provider-native tool search tool descriptor sent to the model as a tool.""" + + type: str + name: NotRequired[str] + + +# Provider-native tool search descriptors keyed by normalized provider name (see +# `_normalize_provider`). This mapping is the single source of truth for which +# providers support server-side tool search. +# +# The identifiers below are version-stamped by the providers and can go stale; +# re-verify against provider docs when updating: +# - Anthropic: https://docs.langchain.com/oss/python/integrations/chat/anthropic#tool-search +# - OpenAI: server-side `tool_search` tool. +_SERVER_TOOL_SEARCH_TOOLS: dict[str, _ServerToolSearchSpec] = { + "anthropic": { + "type": "tool_search_tool_bm25_20251119", + "name": "tool_search_tool_bm25", + }, + "openai": {"type": "tool_search"}, +} + + +class ProviderToolSearchMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): + """Defer selected tools behind provider-native tool search. + + Instead of sending every tool schema on every turn, this middleware marks + selected tools as deferred (via `extras["defer_loading"]`) and injects the + provider's server-side tool search tool. The provider then retrieves the + full schema of a deferred tool only when the model needs it, which keeps the + request payload small when many tools are bound. + + A tool is deferred when its name (or instance) is passed in `searchable_tools`, + or when it already carries `extras["defer_loading"] is True`. + + Only providers with server-side tool search are supported (currently + Anthropic and OpenAI). The provider is inferred from the bound model. + + !!! warning + + This relies on provider-native tool search and only takes effect for + supported providers. If a tool is deferred but the model's provider + cannot be identified or does not support tool search, the model call + raises `ValueError`. When no tool is deferred, the middleware passes the + request through unchanged regardless of provider. + + Example: + ```python + from langchain.agents import create_agent + from langchain.agents.middleware import ProviderToolSearchMiddleware + + agent = create_agent( + "anthropic:claude-opus-4-8", + tools=[get_weather, send_email, lookup_order], + middleware=[ProviderToolSearchMiddleware(searchable_tools=["lookup_order"])], + ) + ``` + """ + + def __init__(self, *, searchable_tools: list[ToolIdentifier] | None = None) -> None: + """Initialize provider-side tool search. + + Args: + searchable_tools: Tools or tool names to defer behind provider-native + tool search. + """ + super().__init__() + self.searchable_tool_names = _to_tool_names(searchable_tools) + + def _prepare_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]: + """Prepare a model request with deferred tools and provider search. + + Validates that every name in `searchable_tools` is bound to the model, + then (only when at least one tool is deferred) resolves the model's + provider and injects the provider-native tool search tool. Requests with + no deferred tools pass through unchanged. + + Args: + request: Model request to prepare. + + Returns: + The original request when nothing is deferred, otherwise a new + request with deferred tools and the provider search tool appended. + + Raises: + ValueError: If `searchable_tools` references a tool not bound to the + model, or if a tool is deferred but the model's provider cannot + be identified or does not support server-side tool search. + """ + tools = request.tools + if self.searchable_tool_names: + available = {tool.name for tool in tools if isinstance(tool, BaseTool)} + unknown = sorted(self.searchable_tool_names - available) + if unknown: + msg = ( + "ProviderToolSearchMiddleware: searchable_tools references " + f"tool(s) not bound to the model: {', '.join(unknown)}" + ) + raise ValueError(msg) + + if not any(_is_deferred_tool(tool, self.searchable_tool_names) for tool in tools): + return request + + provider = _get_model_provider(request.model, request.runtime) + if provider is None: + msg = ( + "ProviderToolSearchMiddleware could not determine the provider for " + f"model {request.model.__class__.__name__!r}; server-side tool search " + f"supports: {', '.join(sorted(_SERVER_TOOL_SEARCH_TOOLS))}" + ) + raise ValueError(msg) + if provider not in _SERVER_TOOL_SEARCH_TOOLS: + msg = ( + "ProviderToolSearchMiddleware requires a provider with server-side " + f"tool search, but got {provider!r}; supported providers: " + f"{', '.join(sorted(_SERVER_TOOL_SEARCH_TOOLS))}" + ) + raise ValueError(msg) + + bound_tools = [_defer_tool_if_needed(tool, self.searchable_tool_names) for tool in tools] + return request.override(tools=[*bound_tools, dict(_SERVER_TOOL_SEARCH_TOOLS[provider])]) + + def wrap_model_call( + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], + ) -> ModelResponse[ResponseT] | AIMessage: + """Defer tools before invoking the model. + + Args: + request: Model request to execute. + handler: Callback that executes the model request. + + Returns: + The model call result. + + Raises: + ValueError: If `searchable_tools` references a tool not bound to the + model, or if a tool is deferred but the model's provider cannot + be identified or does not support server-side tool search. + """ + return handler(self._prepare_request(request)) + + async def awrap_model_call( + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], + ) -> ModelResponse[ResponseT] | AIMessage: + """Defer tools before asynchronously invoking the model. + + Args: + request: Model request to execute. + handler: Callback that executes the model request. + + Returns: + The model call result. + + Raises: + ValueError: If `searchable_tools` references a tool not bound to the + model, or if a tool is deferred but the model's provider cannot + be identified or does not support server-side tool search. + """ + return await handler(self._prepare_request(request)) + + +def _to_tool_names(tools: list[ToolIdentifier] | None) -> set[str]: + """Convert tool identifiers to names.""" + if tools is None: + return set() + return {tool if isinstance(tool, str) else tool.name for tool in tools} + + +def _is_deferred_tool(tool: BaseTool | dict[str, Any], tool_names: set[str]) -> bool: + """Return whether a tool should be deferred. + + Only `BaseTool` instances can be deferred; dict-form tools (e.g. provider + tool specs) have no `extras` or name to match and are never deferred. + """ + if not isinstance(tool, BaseTool): + return False + extras = tool.extras if isinstance(tool.extras, dict) else {} + return extras.get("defer_loading") is True or tool.name in tool_names + + +def _defer_tool_if_needed( + tool: BaseTool | dict[str, Any], tool_names: set[str] +) -> BaseTool | dict[str, Any]: + """Return the tool with `defer_loading` set, or unchanged if not deferred. + + Returns the input unchanged when the tool should not be deferred or is not a + `BaseTool` (only `BaseTool` instances carry the `extras` that flags deferral). + """ + if not _is_deferred_tool(tool, tool_names): + return tool + if not isinstance(tool, BaseTool): + return tool + extras = {**(tool.extras or {}), "defer_loading": True} + return tool.model_copy(update={"extras": extras}) + + +def _get_model_provider(model: BaseChatModel, runtime: Any) -> str | None: + """Infer the normalized provider name for server-side tool search. + + Returns `None` when no provider can be identified, so callers can + distinguish a detection failure from a provider that is simply unsupported. + """ + default_config = getattr(model, "_default_config", None) + model_params_fn = getattr(model, "_model_params", None) + if callable(model_params_fn): + config = getattr(runtime, "config", None) + # `_model_params` expects a config mapping (or None); coerce a malformed + # non-mapping config to None so it is treated as "no config" rather than + # raising deep inside the configurable model. + if config is not None and not isinstance(config, Mapping): + config = None + model_params = model_params_fn(config) + if isinstance(model_params, dict): + params = ( + {**default_config, **model_params} + if isinstance(default_config, dict) + else model_params + ) + if provider := _provider_from_params(params): + return provider + + if isinstance(default_config, dict) and (provider := _provider_from_params(default_config)): + return provider + + get_ls_params = getattr(model, "_get_ls_params", None) + if callable(get_ls_params): + ls_params = get_ls_params() + if isinstance(ls_params, dict) and isinstance(ls_params.get("ls_provider"), str): + return _normalize_provider(ls_params["ls_provider"]) + + return _provider_from_class_name(model.__class__.__name__) + + +def _provider_from_params(params: dict[str, Any]) -> str | None: + """Infer the provider from model parameters, or `None` if absent.""" + provider = params.get("model_provider") + if isinstance(provider, str): + return _normalize_provider(provider) + model_name = params.get("model") + if isinstance(model_name, str): + return _provider_from_model_name(model_name) + return None + + +def _provider_from_model_name(model_name: str) -> str | None: + """Infer the provider from a model name, or `None` if unrecognized.""" + provider, _, rest = model_name.partition(":") + if rest: + return _normalize_provider(provider) + model_lower = model_name.lower() + if model_lower.startswith("claude"): + return "anthropic" + if model_lower.startswith(("gpt-", "o1", "o3", "chatgpt")): + return "openai" + return None + + +def _provider_from_class_name(class_name: str) -> str | None: + """Infer the provider from a model class name, or `None` if unrecognized.""" + if class_name in {"ChatAnthropic", "AnthropicChat"}: + return "anthropic" + if class_name in {"ChatOpenAI", "OpenAIChat"}: + return "openai" + return None + + +def _normalize_provider(provider: str) -> str: + """Normalize a provider identifier by lowercasing and mapping `-` to `_`.""" + return provider.replace("-", "_").lower() diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_provider_tool_search.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_provider_tool_search.py new file mode 100644 index 00000000000..c6939c64e62 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_provider_tool_search.py @@ -0,0 +1,432 @@ +"""Unit tests for provider tool search middleware.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, cast + +import pytest +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.tools import BaseTool, tool + +from langchain.agents.middleware import ( + ModelRequest, + ModelResponse, + ProviderToolSearchMiddleware, +) + +if TYPE_CHECKING: + from langchain_core.language_models.chat_models import BaseChatModel + +ANTHROPIC_SEARCH_TOOL = { + "type": "tool_search_tool_bm25_20251119", + "name": "tool_search_tool_bm25", +} +OPENAI_SEARCH_TOOL = {"type": "tool_search"} + + +@tool +def get_weather(city: str) -> str: + """Get the weather for a city.""" + return f"Sunny in {city}" + + +@tool +def send_email(to: str) -> str: + """Send an email.""" + return f"Sent to {to}" + + +@tool(extras={"defer_loading": True}) +def lookup_order(order_id: str) -> str: + """Look up an order.""" + return f"Order {order_id} shipped" + + +@tool(extras={"category": "billing"}) +def refund_order(order_id: str) -> str: + """Refund an order.""" + return f"Refunded {order_id}" + + +class FakeModel: + def __init__(self, provider: str) -> None: + self.provider = provider + + def _get_ls_params(self) -> dict[str, str]: + return {"ls_provider": self.provider} + + +class FakeConfigurableModel: + def __init__( + self, + default_config: dict[str, Any] | None = None, + model_params: dict[str, Any] | None = None, + ) -> None: + self._default_config = default_config or {} + self.model_params = model_params or {} + + def _model_params(self, config: dict[str, Any] | None = None) -> dict[str, Any]: + return self.model_params if config is None else config.get("configurable", {}) + + +class ChatAnthropic: + """Bare model whose class name is the only provider signal.""" + + +class ChatOpenAI: + """Bare model whose class name is the only provider signal.""" + + +class MysteryModel: + """Model that exposes no provider signal at all.""" + + +class FakeRuntime: + def __init__(self, config: Any) -> None: + self.config = config + + +def _request(provider: str, tools: list[BaseTool | dict[str, Any]]) -> ModelRequest: + return ModelRequest( + model=cast("BaseChatModel", FakeModel(provider)), + messages=[HumanMessage("hi")], + tools=tools, + ) + + +def _invoke(middleware: ProviderToolSearchMiddleware, request: ModelRequest) -> ModelRequest: + captured_request = None + + def handler(model_request: ModelRequest) -> ModelResponse: + nonlocal captured_request + captured_request = model_request + return ModelResponse(result=[AIMessage("ok")]) + + middleware.wrap_model_call(request, handler) + assert captured_request is not None + return captured_request + + +def test_passes_through_when_no_tools_are_deferred() -> None: + """Without deferral the request must be returned untouched (no needless copy).""" + request = _request("anthropic", [get_weather, send_email]) + middleware = ProviderToolSearchMiddleware() + + modified_request = _invoke(middleware, request) + + assert modified_request is request + + +def test_defers_tools_named_in_searchable_tools() -> None: + """Only named tools get `defer_loading`; others stay intact and the search tool is added.""" + request = _request("anthropic", [get_weather, send_email]) + middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"]) + + modified_request = _invoke(middleware, request) + + email_tool = next( + tool + for tool in modified_request.tools + if isinstance(tool, BaseTool) and tool.name == "send_email" + ) + weather_tool = next( + tool + for tool in modified_request.tools + if isinstance(tool, BaseTool) and tool.name == "get_weather" + ) + assert email_tool.extras == {"defer_loading": True} + assert weather_tool.extras is None + assert ANTHROPIC_SEARCH_TOOL in modified_request.tools + + +def test_accepts_tool_instances_in_searchable_tools() -> None: + """`searchable_tools` accepts `BaseTool` instances, not just names (the other input form).""" + request = _request("openai", [get_weather, send_email]) + middleware = ProviderToolSearchMiddleware(searchable_tools=[send_email]) + + modified_request = _invoke(middleware, request) + + email_tool = next( + tool + for tool in modified_request.tools + if isinstance(tool, BaseTool) and tool.name == "send_email" + ) + assert email_tool.extras == {"defer_loading": True} + assert OPENAI_SEARCH_TOOL in modified_request.tools + + +def test_honors_tools_pre_marked_with_defer_loading() -> None: + """A pre-marked `defer_loading` tool triggers deferral even with no `searchable_tools`.""" + request = _request("anthropic", [get_weather, lookup_order]) + middleware = ProviderToolSearchMiddleware() + + modified_request = _invoke(middleware, request) + + order_tool = next( + tool + for tool in modified_request.tools + if isinstance(tool, BaseTool) and tool.name == "lookup_order" + ) + assert order_tool.extras == {"defer_loading": True} + assert ANTHROPIC_SEARCH_TOOL in modified_request.tools + + +def test_preserves_existing_extras_when_deferring() -> None: + """Deferral merges into existing `extras` rather than overwriting a user's other keys.""" + request = _request("anthropic", [refund_order]) + middleware = ProviderToolSearchMiddleware(searchable_tools=["refund_order"]) + + modified_request = _invoke(middleware, request) + + refund_tool = next( + tool + for tool in modified_request.tools + if isinstance(tool, BaseTool) and tool.name == "refund_order" + ) + assert refund_tool.extras == {"category": "billing", "defer_loading": True} + + +def test_passes_dict_tools_through_untouched() -> None: + """Dict-form tools have no name/extras and must pass through without an `AttributeError`.""" + web_search = {"type": "web_search"} + request = _request("anthropic", [get_weather, web_search]) + middleware = ProviderToolSearchMiddleware(searchable_tools=["get_weather"]) + + modified_request = _invoke(middleware, request) + + assert web_search in modified_request.tools + assert ANTHROPIC_SEARCH_TOOL in modified_request.tools + + +def test_raises_when_searchable_tool_is_not_bound() -> None: + """A typo'd `searchable_tools` name fails loudly instead of silently never deferring.""" + request = _request("anthropic", [get_weather]) + middleware = ProviderToolSearchMiddleware(searchable_tools=["missing_tool"]) + + with pytest.raises(ValueError, match="missing_tool"): + _invoke(middleware, request) + + +def test_unbound_tool_error_is_sorted() -> None: + """The error lists unknown tools in sorted order so the message is deterministic.""" + request = _request("anthropic", [get_weather]) + middleware = ProviderToolSearchMiddleware(searchable_tools=["zzz", "aaa"]) + + with pytest.raises(ValueError, match="aaa, zzz"): + _invoke(middleware, request) + + +def test_unbound_tool_check_precedes_provider_check() -> None: + """Config errors (unbound tool) surface before provider errors, pinning the check order.""" + request = _request("mistralai", [get_weather]) + middleware = ProviderToolSearchMiddleware(searchable_tools=["missing_tool"]) + + with pytest.raises(ValueError, match="not bound to the model"): + _invoke(middleware, request) + + +def test_passes_through_unsupported_provider_when_nothing_deferred() -> None: + """An unsupported provider must not raise when no tool is actually deferred.""" + request = _request("mistralai", [get_weather]) + middleware = ProviderToolSearchMiddleware() + + modified_request = _invoke(middleware, request) + + assert modified_request is request + + +def test_passes_through_unsupported_provider_with_empty_tools() -> None: + """Empty tool list is a clean no-op and never trips the provider guard.""" + request = _request("mistralai", []) + middleware = ProviderToolSearchMiddleware() + + modified_request = _invoke(middleware, request) + + assert modified_request is request + + +def test_raises_for_unsupported_provider_when_tool_deferred() -> None: + """Deferring against a provider without tool search is a hard error, not a silent drop.""" + request = _request("mistralai", [send_email]) + middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"]) + + with pytest.raises(ValueError, match="server-side tool search"): + _invoke(middleware, request) + + +def test_raises_when_provider_cannot_be_determined() -> None: + """Detection failure raises a distinct, actionable error rather than a misleading one.""" + request = ModelRequest( + model=cast("BaseChatModel", MysteryModel()), + messages=[HumanMessage("hi")], + tools=[lookup_order], + ) + middleware = ProviderToolSearchMiddleware() + + with pytest.raises(ValueError, match="could not determine the provider"): + _invoke(middleware, request) + + +@pytest.mark.parametrize( + ("model_factory", "expected_tool"), + [ + (ChatAnthropic, ANTHROPIC_SEARCH_TOOL), + (ChatOpenAI, OPENAI_SEARCH_TOOL), + ], +) +def test_detects_provider_from_class_name( + model_factory: type, expected_tool: dict[str, str] +) -> None: + """Class-name fallback identifies the provider when no params/ls_provider are exposed.""" + request = ModelRequest( + model=cast("BaseChatModel", model_factory()), + messages=[HumanMessage("hi")], + tools=[send_email], + ) + middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"]) + + modified_request = _invoke(middleware, request) + + assert expected_tool in modified_request.tools + + +def test_normalizes_provider_casing_and_hyphens() -> None: + """Provider identifiers are normalized so mixed casing still matches the registry.""" + request = _request("Anthropic", [send_email]) + middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"]) + + modified_request = _invoke(middleware, request) + + assert ANTHROPIC_SEARCH_TOOL in modified_request.tools + + +@pytest.mark.parametrize( + ("model_name", "expected_tool"), + [ + ("claude-sonnet-4-5", ANTHROPIC_SEARCH_TOOL), + ("gpt-5.4", OPENAI_SEARCH_TOOL), + ("o3-mini", OPENAI_SEARCH_TOOL), + ("chatgpt-4o-latest", OPENAI_SEARCH_TOOL), + ], +) +def test_detects_provider_from_bare_model_name( + model_name: str, expected_tool: dict[str, str] +) -> None: + """Bare model names (no `provider:` prefix) are mapped via the name heuristics.""" + request = ModelRequest( + model=cast("BaseChatModel", FakeConfigurableModel(model_params={"model": model_name})), + messages=[HumanMessage("hi")], + tools=[send_email], + ) + middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"]) + + modified_request = _invoke(middleware, request) + + assert expected_tool in modified_request.tools + + +def test_raises_for_unrecognized_bare_model_name() -> None: + """An unrecognized bare name yields detection failure, not a wrong-provider guess.""" + request = ModelRequest( + model=cast("BaseChatModel", FakeConfigurableModel(model_params={"model": "llama-3.1"})), + messages=[HumanMessage("hi")], + tools=[send_email], + ) + middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"]) + + with pytest.raises(ValueError, match="could not determine the provider"): + _invoke(middleware, request) + + +def test_detects_provider_from_configurable_model() -> None: + """A configurable model's `_default_config` `provider:model` string resolves the provider.""" + request = ModelRequest( + model=cast("BaseChatModel", FakeConfigurableModel({"model": "openai:gpt-5.4"})), + messages=[HumanMessage("hi")], + tools=[send_email], + ) + middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"]) + + modified_request = _invoke(middleware, request) + + assert OPENAI_SEARCH_TOOL in modified_request.tools + + +def test_detects_provider_from_runtime_configurable_model() -> None: + """Provider set only via runtime `configurable` is read through `_model_params`.""" + request = ModelRequest( + model=cast("BaseChatModel", FakeConfigurableModel()), + messages=[HumanMessage("hi")], + tools=[send_email], + runtime=cast("Any", FakeRuntime({"configurable": {"model_provider": "openai"}})), + ) + middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"]) + + modified_request = _invoke(middleware, request) + + assert OPENAI_SEARCH_TOOL in modified_request.tools + + +def test_runtime_model_override_uses_default_configurable_model_provider() -> None: + """`model_provider` in the merged params wins over a runtime `model` of another provider.""" + request = ModelRequest( + model=cast("BaseChatModel", FakeConfigurableModel({"model_provider": "openai"})), + messages=[HumanMessage("hi")], + tools=[send_email], + runtime=cast("Any", FakeRuntime({"configurable": {"model": "claude-sonnet-4-5"}})), + ) + middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"]) + + modified_request = _invoke(middleware, request) + + assert OPENAI_SEARCH_TOOL in modified_request.tools + + +def test_runtime_provider_override_uses_runtime_configurable_model_provider() -> None: + """Runtime `model_provider` overrides the default config's, confirming the merge precedence.""" + request = ModelRequest( + model=cast("BaseChatModel", FakeConfigurableModel({"model_provider": "openai"})), + messages=[HumanMessage("hi")], + tools=[send_email], + runtime=cast("Any", FakeRuntime({"configurable": {"model_provider": "anthropic"}})), + ) + middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"]) + + modified_request = _invoke(middleware, request) + + assert ANTHROPIC_SEARCH_TOOL in modified_request.tools + + +def test_ignores_non_mapping_runtime_config() -> None: + """A malformed (non-mapping) runtime config is treated as absent, not propagated to a crash.""" + request = ModelRequest( + model=cast( + "BaseChatModel", + FakeConfigurableModel(model_params={"model_provider": "openai"}), + ), + messages=[HumanMessage("hi")], + tools=[send_email], + runtime=cast("Any", FakeRuntime("not-a-mapping")), + ) + middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"]) + + modified_request = _invoke(middleware, request) + + assert OPENAI_SEARCH_TOOL in modified_request.tools + + +async def test_async_wrap_model_call_defers_tools() -> None: + """The async path applies the same deferral as the sync path through shared logic.""" + request = _request("openai", [send_email]) + middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"]) + captured_request = None + + async def handler(model_request: ModelRequest) -> ModelResponse: + nonlocal captured_request + captured_request = model_request + return ModelResponse(result=[AIMessage("ok")]) + + await middleware.awrap_model_call(request, handler) + + assert captured_request is not None + assert OPENAI_SEARCH_TOOL in captured_request.tools