feat(langchain): add ProviderToolSearchMiddleware (#37969)

[Docs](https://github.com/langchain-ai/docs/pull/4355)

Adds `ProviderToolSearchMiddleware` to let agents defer selected tools
behind OpenAI/Anthropic provider-native tool search while preserving
existing `extras={"defer_loading": True}` behavior. The middleware
validates searchable tool names, injects the provider search tool only
when a tool is deferred, and rejects unsupported providers up front.

Made by [Open SWE](https://openswe.vercel.app)

---------

Co-authored-by: Alexander Olsen <13665641+aolsenjazz@users.noreply.github.com>
Co-authored-by: open-swe[bot] <open-swe@users.noreply.github.com>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Alexander Olsen
2026-06-10 14:03:28 -04:00
committed by GitHub
parent 23ce677870
commit 92ee772761
3 changed files with 737 additions and 0 deletions

View File

@@ -12,6 +12,7 @@ from langchain.agents.middleware.model_call_limit import ModelCallLimitMiddlewar
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
from langchain.agents.middleware.model_retry import ModelRetryMiddleware
from langchain.agents.middleware.pii import PIIDetectionError, PIIMiddleware
from langchain.agents.middleware.provider_tool_search import ProviderToolSearchMiddleware
from langchain.agents.middleware.shell_tool import (
CodexSandboxExecutionPolicy,
DockerExecutionPolicy,
@@ -65,6 +66,7 @@ __all__ = [
"ModelRetryMiddleware",
"PIIDetectionError",
"PIIMiddleware",
"ProviderToolSearchMiddleware",
"RedactionRule",
"Runtime",
"ShellToolMiddleware",

View File

@@ -0,0 +1,303 @@
"""Provider-side tool search middleware."""
from __future__ import annotations
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, TypeAlias
from langchain_core.tools import BaseTool
from typing_extensions import NotRequired, TypedDict
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ModelRequest,
ModelResponse,
ResponseT,
)
ToolIdentifier: TypeAlias = str | BaseTool
"""Tool name or tool instance that can be deferred behind provider tool search."""
class _ServerToolSearchSpec(TypedDict):
"""Provider-native tool search tool descriptor sent to the model as a tool."""
type: str
name: NotRequired[str]
# Provider-native tool search descriptors keyed by normalized provider name (see
# `_normalize_provider`). This mapping is the single source of truth for which
# providers support server-side tool search.
#
# The identifiers below are version-stamped by the providers and can go stale;
# re-verify against provider docs when updating:
# - Anthropic: https://docs.langchain.com/oss/python/integrations/chat/anthropic#tool-search
# - OpenAI: server-side `tool_search` tool.
_SERVER_TOOL_SEARCH_TOOLS: dict[str, _ServerToolSearchSpec] = {
"anthropic": {
"type": "tool_search_tool_bm25_20251119",
"name": "tool_search_tool_bm25",
},
"openai": {"type": "tool_search"},
}
class ProviderToolSearchMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Defer selected tools behind provider-native tool search.
Instead of sending every tool schema on every turn, this middleware marks
selected tools as deferred (via `extras["defer_loading"]`) and injects the
provider's server-side tool search tool. The provider then retrieves the
full schema of a deferred tool only when the model needs it, which keeps the
request payload small when many tools are bound.
A tool is deferred when its name (or instance) is passed in `searchable_tools`,
or when it already carries `extras["defer_loading"] is True`.
Only providers with server-side tool search are supported (currently
Anthropic and OpenAI). The provider is inferred from the bound model.
!!! warning
This relies on provider-native tool search and only takes effect for
supported providers. If a tool is deferred but the model's provider
cannot be identified or does not support tool search, the model call
raises `ValueError`. When no tool is deferred, the middleware passes the
request through unchanged regardless of provider.
Example:
```python
from langchain.agents import create_agent
from langchain.agents.middleware import ProviderToolSearchMiddleware
agent = create_agent(
"anthropic:claude-opus-4-8",
tools=[get_weather, send_email, lookup_order],
middleware=[ProviderToolSearchMiddleware(searchable_tools=["lookup_order"])],
)
```
"""
def __init__(self, *, searchable_tools: list[ToolIdentifier] | None = None) -> None:
"""Initialize provider-side tool search.
Args:
searchable_tools: Tools or tool names to defer behind provider-native
tool search.
"""
super().__init__()
self.searchable_tool_names = _to_tool_names(searchable_tools)
def _prepare_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
"""Prepare a model request with deferred tools and provider search.
Validates that every name in `searchable_tools` is bound to the model,
then (only when at least one tool is deferred) resolves the model's
provider and injects the provider-native tool search tool. Requests with
no deferred tools pass through unchanged.
Args:
request: Model request to prepare.
Returns:
The original request when nothing is deferred, otherwise a new
request with deferred tools and the provider search tool appended.
Raises:
ValueError: If `searchable_tools` references a tool not bound to the
model, or if a tool is deferred but the model's provider cannot
be identified or does not support server-side tool search.
"""
tools = request.tools
if self.searchable_tool_names:
available = {tool.name for tool in tools if isinstance(tool, BaseTool)}
unknown = sorted(self.searchable_tool_names - available)
if unknown:
msg = (
"ProviderToolSearchMiddleware: searchable_tools references "
f"tool(s) not bound to the model: {', '.join(unknown)}"
)
raise ValueError(msg)
if not any(_is_deferred_tool(tool, self.searchable_tool_names) for tool in tools):
return request
provider = _get_model_provider(request.model, request.runtime)
if provider is None:
msg = (
"ProviderToolSearchMiddleware could not determine the provider for "
f"model {request.model.__class__.__name__!r}; server-side tool search "
f"supports: {', '.join(sorted(_SERVER_TOOL_SEARCH_TOOLS))}"
)
raise ValueError(msg)
if provider not in _SERVER_TOOL_SEARCH_TOOLS:
msg = (
"ProviderToolSearchMiddleware requires a provider with server-side "
f"tool search, but got {provider!r}; supported providers: "
f"{', '.join(sorted(_SERVER_TOOL_SEARCH_TOOLS))}"
)
raise ValueError(msg)
bound_tools = [_defer_tool_if_needed(tool, self.searchable_tool_names) for tool in tools]
return request.override(tools=[*bound_tools, dict(_SERVER_TOOL_SEARCH_TOOLS[provider])])
def wrap_model_call(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> ModelResponse[ResponseT] | AIMessage:
"""Defer tools before invoking the model.
Args:
request: Model request to execute.
handler: Callback that executes the model request.
Returns:
The model call result.
Raises:
ValueError: If `searchable_tools` references a tool not bound to the
model, or if a tool is deferred but the model's provider cannot
be identified or does not support server-side tool search.
"""
return handler(self._prepare_request(request))
async def awrap_model_call(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> ModelResponse[ResponseT] | AIMessage:
"""Defer tools before asynchronously invoking the model.
Args:
request: Model request to execute.
handler: Callback that executes the model request.
Returns:
The model call result.
Raises:
ValueError: If `searchable_tools` references a tool not bound to the
model, or if a tool is deferred but the model's provider cannot
be identified or does not support server-side tool search.
"""
return await handler(self._prepare_request(request))
def _to_tool_names(tools: list[ToolIdentifier] | None) -> set[str]:
"""Convert tool identifiers to names."""
if tools is None:
return set()
return {tool if isinstance(tool, str) else tool.name for tool in tools}
def _is_deferred_tool(tool: BaseTool | dict[str, Any], tool_names: set[str]) -> bool:
"""Return whether a tool should be deferred.
Only `BaseTool` instances can be deferred; dict-form tools (e.g. provider
tool specs) have no `extras` or name to match and are never deferred.
"""
if not isinstance(tool, BaseTool):
return False
extras = tool.extras if isinstance(tool.extras, dict) else {}
return extras.get("defer_loading") is True or tool.name in tool_names
def _defer_tool_if_needed(
tool: BaseTool | dict[str, Any], tool_names: set[str]
) -> BaseTool | dict[str, Any]:
"""Return the tool with `defer_loading` set, or unchanged if not deferred.
Returns the input unchanged when the tool should not be deferred or is not a
`BaseTool` (only `BaseTool` instances carry the `extras` that flags deferral).
"""
if not _is_deferred_tool(tool, tool_names):
return tool
if not isinstance(tool, BaseTool):
return tool
extras = {**(tool.extras or {}), "defer_loading": True}
return tool.model_copy(update={"extras": extras})
def _get_model_provider(model: BaseChatModel, runtime: Any) -> str | None:
"""Infer the normalized provider name for server-side tool search.
Returns `None` when no provider can be identified, so callers can
distinguish a detection failure from a provider that is simply unsupported.
"""
default_config = getattr(model, "_default_config", None)
model_params_fn = getattr(model, "_model_params", None)
if callable(model_params_fn):
config = getattr(runtime, "config", None)
# `_model_params` expects a config mapping (or None); coerce a malformed
# non-mapping config to None so it is treated as "no config" rather than
# raising deep inside the configurable model.
if config is not None and not isinstance(config, Mapping):
config = None
model_params = model_params_fn(config)
if isinstance(model_params, dict):
params = (
{**default_config, **model_params}
if isinstance(default_config, dict)
else model_params
)
if provider := _provider_from_params(params):
return provider
if isinstance(default_config, dict) and (provider := _provider_from_params(default_config)):
return provider
get_ls_params = getattr(model, "_get_ls_params", None)
if callable(get_ls_params):
ls_params = get_ls_params()
if isinstance(ls_params, dict) and isinstance(ls_params.get("ls_provider"), str):
return _normalize_provider(ls_params["ls_provider"])
return _provider_from_class_name(model.__class__.__name__)
def _provider_from_params(params: dict[str, Any]) -> str | None:
"""Infer the provider from model parameters, or `None` if absent."""
provider = params.get("model_provider")
if isinstance(provider, str):
return _normalize_provider(provider)
model_name = params.get("model")
if isinstance(model_name, str):
return _provider_from_model_name(model_name)
return None
def _provider_from_model_name(model_name: str) -> str | None:
"""Infer the provider from a model name, or `None` if unrecognized."""
provider, _, rest = model_name.partition(":")
if rest:
return _normalize_provider(provider)
model_lower = model_name.lower()
if model_lower.startswith("claude"):
return "anthropic"
if model_lower.startswith(("gpt-", "o1", "o3", "chatgpt")):
return "openai"
return None
def _provider_from_class_name(class_name: str) -> str | None:
"""Infer the provider from a model class name, or `None` if unrecognized."""
if class_name in {"ChatAnthropic", "AnthropicChat"}:
return "anthropic"
if class_name in {"ChatOpenAI", "OpenAIChat"}:
return "openai"
return None
def _normalize_provider(provider: str) -> str:
"""Normalize a provider identifier by lowercasing and mapping `-` to `_`."""
return provider.replace("-", "_").lower()

View File

@@ -0,0 +1,432 @@
"""Unit tests for provider tool search middleware."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, cast
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import BaseTool, tool
from langchain.agents.middleware import (
ModelRequest,
ModelResponse,
ProviderToolSearchMiddleware,
)
if TYPE_CHECKING:
from langchain_core.language_models.chat_models import BaseChatModel
ANTHROPIC_SEARCH_TOOL = {
"type": "tool_search_tool_bm25_20251119",
"name": "tool_search_tool_bm25",
}
OPENAI_SEARCH_TOOL = {"type": "tool_search"}
@tool
def get_weather(city: str) -> str:
"""Get the weather for a city."""
return f"Sunny in {city}"
@tool
def send_email(to: str) -> str:
"""Send an email."""
return f"Sent to {to}"
@tool(extras={"defer_loading": True})
def lookup_order(order_id: str) -> str:
"""Look up an order."""
return f"Order {order_id} shipped"
@tool(extras={"category": "billing"})
def refund_order(order_id: str) -> str:
"""Refund an order."""
return f"Refunded {order_id}"
class FakeModel:
def __init__(self, provider: str) -> None:
self.provider = provider
def _get_ls_params(self) -> dict[str, str]:
return {"ls_provider": self.provider}
class FakeConfigurableModel:
def __init__(
self,
default_config: dict[str, Any] | None = None,
model_params: dict[str, Any] | None = None,
) -> None:
self._default_config = default_config or {}
self.model_params = model_params or {}
def _model_params(self, config: dict[str, Any] | None = None) -> dict[str, Any]:
return self.model_params if config is None else config.get("configurable", {})
class ChatAnthropic:
"""Bare model whose class name is the only provider signal."""
class ChatOpenAI:
"""Bare model whose class name is the only provider signal."""
class MysteryModel:
"""Model that exposes no provider signal at all."""
class FakeRuntime:
def __init__(self, config: Any) -> None:
self.config = config
def _request(provider: str, tools: list[BaseTool | dict[str, Any]]) -> ModelRequest:
return ModelRequest(
model=cast("BaseChatModel", FakeModel(provider)),
messages=[HumanMessage("hi")],
tools=tools,
)
def _invoke(middleware: ProviderToolSearchMiddleware, request: ModelRequest) -> ModelRequest:
captured_request = None
def handler(model_request: ModelRequest) -> ModelResponse:
nonlocal captured_request
captured_request = model_request
return ModelResponse(result=[AIMessage("ok")])
middleware.wrap_model_call(request, handler)
assert captured_request is not None
return captured_request
def test_passes_through_when_no_tools_are_deferred() -> None:
"""Without deferral the request must be returned untouched (no needless copy)."""
request = _request("anthropic", [get_weather, send_email])
middleware = ProviderToolSearchMiddleware()
modified_request = _invoke(middleware, request)
assert modified_request is request
def test_defers_tools_named_in_searchable_tools() -> None:
"""Only named tools get `defer_loading`; others stay intact and the search tool is added."""
request = _request("anthropic", [get_weather, send_email])
middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"])
modified_request = _invoke(middleware, request)
email_tool = next(
tool
for tool in modified_request.tools
if isinstance(tool, BaseTool) and tool.name == "send_email"
)
weather_tool = next(
tool
for tool in modified_request.tools
if isinstance(tool, BaseTool) and tool.name == "get_weather"
)
assert email_tool.extras == {"defer_loading": True}
assert weather_tool.extras is None
assert ANTHROPIC_SEARCH_TOOL in modified_request.tools
def test_accepts_tool_instances_in_searchable_tools() -> None:
"""`searchable_tools` accepts `BaseTool` instances, not just names (the other input form)."""
request = _request("openai", [get_weather, send_email])
middleware = ProviderToolSearchMiddleware(searchable_tools=[send_email])
modified_request = _invoke(middleware, request)
email_tool = next(
tool
for tool in modified_request.tools
if isinstance(tool, BaseTool) and tool.name == "send_email"
)
assert email_tool.extras == {"defer_loading": True}
assert OPENAI_SEARCH_TOOL in modified_request.tools
def test_honors_tools_pre_marked_with_defer_loading() -> None:
"""A pre-marked `defer_loading` tool triggers deferral even with no `searchable_tools`."""
request = _request("anthropic", [get_weather, lookup_order])
middleware = ProviderToolSearchMiddleware()
modified_request = _invoke(middleware, request)
order_tool = next(
tool
for tool in modified_request.tools
if isinstance(tool, BaseTool) and tool.name == "lookup_order"
)
assert order_tool.extras == {"defer_loading": True}
assert ANTHROPIC_SEARCH_TOOL in modified_request.tools
def test_preserves_existing_extras_when_deferring() -> None:
"""Deferral merges into existing `extras` rather than overwriting a user's other keys."""
request = _request("anthropic", [refund_order])
middleware = ProviderToolSearchMiddleware(searchable_tools=["refund_order"])
modified_request = _invoke(middleware, request)
refund_tool = next(
tool
for tool in modified_request.tools
if isinstance(tool, BaseTool) and tool.name == "refund_order"
)
assert refund_tool.extras == {"category": "billing", "defer_loading": True}
def test_passes_dict_tools_through_untouched() -> None:
"""Dict-form tools have no name/extras and must pass through without an `AttributeError`."""
web_search = {"type": "web_search"}
request = _request("anthropic", [get_weather, web_search])
middleware = ProviderToolSearchMiddleware(searchable_tools=["get_weather"])
modified_request = _invoke(middleware, request)
assert web_search in modified_request.tools
assert ANTHROPIC_SEARCH_TOOL in modified_request.tools
def test_raises_when_searchable_tool_is_not_bound() -> None:
"""A typo'd `searchable_tools` name fails loudly instead of silently never deferring."""
request = _request("anthropic", [get_weather])
middleware = ProviderToolSearchMiddleware(searchable_tools=["missing_tool"])
with pytest.raises(ValueError, match="missing_tool"):
_invoke(middleware, request)
def test_unbound_tool_error_is_sorted() -> None:
"""The error lists unknown tools in sorted order so the message is deterministic."""
request = _request("anthropic", [get_weather])
middleware = ProviderToolSearchMiddleware(searchable_tools=["zzz", "aaa"])
with pytest.raises(ValueError, match="aaa, zzz"):
_invoke(middleware, request)
def test_unbound_tool_check_precedes_provider_check() -> None:
"""Config errors (unbound tool) surface before provider errors, pinning the check order."""
request = _request("mistralai", [get_weather])
middleware = ProviderToolSearchMiddleware(searchable_tools=["missing_tool"])
with pytest.raises(ValueError, match="not bound to the model"):
_invoke(middleware, request)
def test_passes_through_unsupported_provider_when_nothing_deferred() -> None:
"""An unsupported provider must not raise when no tool is actually deferred."""
request = _request("mistralai", [get_weather])
middleware = ProviderToolSearchMiddleware()
modified_request = _invoke(middleware, request)
assert modified_request is request
def test_passes_through_unsupported_provider_with_empty_tools() -> None:
"""Empty tool list is a clean no-op and never trips the provider guard."""
request = _request("mistralai", [])
middleware = ProviderToolSearchMiddleware()
modified_request = _invoke(middleware, request)
assert modified_request is request
def test_raises_for_unsupported_provider_when_tool_deferred() -> None:
"""Deferring against a provider without tool search is a hard error, not a silent drop."""
request = _request("mistralai", [send_email])
middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"])
with pytest.raises(ValueError, match="server-side tool search"):
_invoke(middleware, request)
def test_raises_when_provider_cannot_be_determined() -> None:
"""Detection failure raises a distinct, actionable error rather than a misleading one."""
request = ModelRequest(
model=cast("BaseChatModel", MysteryModel()),
messages=[HumanMessage("hi")],
tools=[lookup_order],
)
middleware = ProviderToolSearchMiddleware()
with pytest.raises(ValueError, match="could not determine the provider"):
_invoke(middleware, request)
@pytest.mark.parametrize(
("model_factory", "expected_tool"),
[
(ChatAnthropic, ANTHROPIC_SEARCH_TOOL),
(ChatOpenAI, OPENAI_SEARCH_TOOL),
],
)
def test_detects_provider_from_class_name(
model_factory: type, expected_tool: dict[str, str]
) -> None:
"""Class-name fallback identifies the provider when no params/ls_provider are exposed."""
request = ModelRequest(
model=cast("BaseChatModel", model_factory()),
messages=[HumanMessage("hi")],
tools=[send_email],
)
middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"])
modified_request = _invoke(middleware, request)
assert expected_tool in modified_request.tools
def test_normalizes_provider_casing_and_hyphens() -> None:
"""Provider identifiers are normalized so mixed casing still matches the registry."""
request = _request("Anthropic", [send_email])
middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"])
modified_request = _invoke(middleware, request)
assert ANTHROPIC_SEARCH_TOOL in modified_request.tools
@pytest.mark.parametrize(
("model_name", "expected_tool"),
[
("claude-sonnet-4-5", ANTHROPIC_SEARCH_TOOL),
("gpt-5.4", OPENAI_SEARCH_TOOL),
("o3-mini", OPENAI_SEARCH_TOOL),
("chatgpt-4o-latest", OPENAI_SEARCH_TOOL),
],
)
def test_detects_provider_from_bare_model_name(
model_name: str, expected_tool: dict[str, str]
) -> None:
"""Bare model names (no `provider:` prefix) are mapped via the name heuristics."""
request = ModelRequest(
model=cast("BaseChatModel", FakeConfigurableModel(model_params={"model": model_name})),
messages=[HumanMessage("hi")],
tools=[send_email],
)
middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"])
modified_request = _invoke(middleware, request)
assert expected_tool in modified_request.tools
def test_raises_for_unrecognized_bare_model_name() -> None:
"""An unrecognized bare name yields detection failure, not a wrong-provider guess."""
request = ModelRequest(
model=cast("BaseChatModel", FakeConfigurableModel(model_params={"model": "llama-3.1"})),
messages=[HumanMessage("hi")],
tools=[send_email],
)
middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"])
with pytest.raises(ValueError, match="could not determine the provider"):
_invoke(middleware, request)
def test_detects_provider_from_configurable_model() -> None:
"""A configurable model's `_default_config` `provider:model` string resolves the provider."""
request = ModelRequest(
model=cast("BaseChatModel", FakeConfigurableModel({"model": "openai:gpt-5.4"})),
messages=[HumanMessage("hi")],
tools=[send_email],
)
middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"])
modified_request = _invoke(middleware, request)
assert OPENAI_SEARCH_TOOL in modified_request.tools
def test_detects_provider_from_runtime_configurable_model() -> None:
"""Provider set only via runtime `configurable` is read through `_model_params`."""
request = ModelRequest(
model=cast("BaseChatModel", FakeConfigurableModel()),
messages=[HumanMessage("hi")],
tools=[send_email],
runtime=cast("Any", FakeRuntime({"configurable": {"model_provider": "openai"}})),
)
middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"])
modified_request = _invoke(middleware, request)
assert OPENAI_SEARCH_TOOL in modified_request.tools
def test_runtime_model_override_uses_default_configurable_model_provider() -> None:
"""`model_provider` in the merged params wins over a runtime `model` of another provider."""
request = ModelRequest(
model=cast("BaseChatModel", FakeConfigurableModel({"model_provider": "openai"})),
messages=[HumanMessage("hi")],
tools=[send_email],
runtime=cast("Any", FakeRuntime({"configurable": {"model": "claude-sonnet-4-5"}})),
)
middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"])
modified_request = _invoke(middleware, request)
assert OPENAI_SEARCH_TOOL in modified_request.tools
def test_runtime_provider_override_uses_runtime_configurable_model_provider() -> None:
"""Runtime `model_provider` overrides the default config's, confirming the merge precedence."""
request = ModelRequest(
model=cast("BaseChatModel", FakeConfigurableModel({"model_provider": "openai"})),
messages=[HumanMessage("hi")],
tools=[send_email],
runtime=cast("Any", FakeRuntime({"configurable": {"model_provider": "anthropic"}})),
)
middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"])
modified_request = _invoke(middleware, request)
assert ANTHROPIC_SEARCH_TOOL in modified_request.tools
def test_ignores_non_mapping_runtime_config() -> None:
"""A malformed (non-mapping) runtime config is treated as absent, not propagated to a crash."""
request = ModelRequest(
model=cast(
"BaseChatModel",
FakeConfigurableModel(model_params={"model_provider": "openai"}),
),
messages=[HumanMessage("hi")],
tools=[send_email],
runtime=cast("Any", FakeRuntime("not-a-mapping")),
)
middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"])
modified_request = _invoke(middleware, request)
assert OPENAI_SEARCH_TOOL in modified_request.tools
async def test_async_wrap_model_call_defers_tools() -> None:
"""The async path applies the same deferral as the sync path through shared logic."""
request = _request("openai", [send_email])
middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"])
captured_request = None
async def handler(model_request: ModelRequest) -> ModelResponse:
nonlocal captured_request
captured_request = model_request
return ModelResponse(result=[AIMessage("ok")])
await middleware.awrap_model_call(request, handler)
assert captured_request is not None
assert OPENAI_SEARCH_TOOL in captured_request.tools