mirror of
https://github.com/hwchase17/langchain.git
synced 2026-07-01 06:42:37 +00:00
feat(langchain): add ProviderToolSearchMiddleware (#37969)
[Docs](https://github.com/langchain-ai/docs/pull/4355) Adds `ProviderToolSearchMiddleware` to let agents defer selected tools behind OpenAI/Anthropic provider-native tool search while preserving existing `extras={"defer_loading": True}` behavior. The middleware validates searchable tool names, injects the provider search tool only when a tool is deferred, and rejects unsupported providers up front. Made by [Open SWE](https://openswe.vercel.app) --------- Co-authored-by: Alexander Olsen <13665641+aolsenjazz@users.noreply.github.com> Co-authored-by: open-swe[bot] <open-swe@users.noreply.github.com> Co-authored-by: Mason Daugherty <github@mdrxy.com> Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
@@ -12,6 +12,7 @@ from langchain.agents.middleware.model_call_limit import ModelCallLimitMiddlewar
|
||||
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
|
||||
from langchain.agents.middleware.model_retry import ModelRetryMiddleware
|
||||
from langchain.agents.middleware.pii import PIIDetectionError, PIIMiddleware
|
||||
from langchain.agents.middleware.provider_tool_search import ProviderToolSearchMiddleware
|
||||
from langchain.agents.middleware.shell_tool import (
|
||||
CodexSandboxExecutionPolicy,
|
||||
DockerExecutionPolicy,
|
||||
@@ -65,6 +66,7 @@ __all__ = [
|
||||
"ModelRetryMiddleware",
|
||||
"PIIDetectionError",
|
||||
"PIIMiddleware",
|
||||
"ProviderToolSearchMiddleware",
|
||||
"RedactionRule",
|
||||
"Runtime",
|
||||
"ShellToolMiddleware",
|
||||
|
||||
@@ -0,0 +1,303 @@
|
||||
"""Provider-side tool search middleware."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
|
||||
ToolIdentifier: TypeAlias = str | BaseTool
|
||||
"""Tool name or tool instance that can be deferred behind provider tool search."""
|
||||
|
||||
|
||||
class _ServerToolSearchSpec(TypedDict):
    """Provider-native tool search tool descriptor sent to the model as a tool."""

    # Provider-defined tool type identifier; present in every registered spec.
    type: str
    # Tool name; optional because not every registered spec carries one.
    name: NotRequired[str]
|
||||
|
||||
|
||||
# Provider-native tool search descriptors keyed by normalized provider name (see
# `_normalize_provider`). This mapping is the single source of truth for which
# providers support server-side tool search: a provider absent from it is
# rejected with a `ValueError` as soon as a tool is deferred.
#
# The identifiers below are version-stamped by the providers and can go stale;
# re-verify against provider docs when updating:
# - Anthropic: https://docs.langchain.com/oss/python/integrations/chat/anthropic#tool-search
# - OpenAI: server-side `tool_search` tool.
_SERVER_TOOL_SEARCH_TOOLS: dict[str, _ServerToolSearchSpec] = {
    "anthropic": {
        "type": "tool_search_tool_bm25_20251119",
        "name": "tool_search_tool_bm25",
    },
    "openai": {"type": "tool_search"},
}
|
||||
|
||||
|
||||
class ProviderToolSearchMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
    """Defer selected tools behind provider-native tool search.

    Instead of sending every tool schema on every turn, this middleware marks
    selected tools as deferred (via `extras["defer_loading"]`) and injects the
    provider's server-side tool search tool. The provider then retrieves the
    full schema of a deferred tool only when the model needs it, which keeps the
    request payload small when many tools are bound.

    A tool is deferred when its name (or instance) is passed in `searchable_tools`,
    or when it already carries `extras["defer_loading"] is True`.

    Only providers with server-side tool search are supported (currently
    Anthropic and OpenAI). The provider is inferred from the bound model.

    Injection is idempotent: if the request already contains a dict-form tool
    with the same `type` as the provider's search tool (because it was passed
    explicitly, or another instance of this middleware ran first), it is not
    appended a second time.

    !!! warning

        This relies on provider-native tool search and only takes effect for
        supported providers. If a tool is deferred but the model's provider
        cannot be identified or does not support tool search, the model call
        raises `ValueError`. When no tool is deferred, the middleware passes the
        request through unchanged regardless of provider.

    Example:
        ```python
        from langchain.agents import create_agent
        from langchain.agents.middleware import ProviderToolSearchMiddleware

        agent = create_agent(
            "anthropic:claude-opus-4-8",
            tools=[get_weather, send_email, lookup_order],
            middleware=[ProviderToolSearchMiddleware(searchable_tools=["lookup_order"])],
        )
        ```
    """

    def __init__(self, *, searchable_tools: list[ToolIdentifier] | None = None) -> None:
        """Initialize provider-side tool search.

        Args:
            searchable_tools: Tools or tool names to defer behind provider-native
                tool search.
        """
        super().__init__()
        self.searchable_tool_names = _to_tool_names(searchable_tools)

    def _prepare_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
        """Prepare a model request with deferred tools and provider search.

        Validates that every name in `searchable_tools` is bound to the model,
        then (only when at least one tool is deferred) resolves the model's
        provider and injects the provider-native tool search tool. Requests with
        no deferred tools pass through unchanged.

        Args:
            request: Model request to prepare.

        Returns:
            The original request when nothing is deferred, otherwise a new
            request with deferred tools and the provider search tool appended
            (unless a search tool of the same `type` is already present).

        Raises:
            ValueError: If `searchable_tools` references a tool not bound to the
                model, or if a tool is deferred but the model's provider cannot
                be identified or does not support server-side tool search.
        """
        tools = request.tools
        if self.searchable_tool_names:
            # Dict-form tools carry no name to match, so only `BaseTool`s count as bound.
            available = {tool.name for tool in tools if isinstance(tool, BaseTool)}
            unknown = sorted(self.searchable_tool_names - available)
            if unknown:
                msg = (
                    "ProviderToolSearchMiddleware: searchable_tools references "
                    f"tool(s) not bound to the model: {', '.join(unknown)}"
                )
                raise ValueError(msg)

        # Provider resolution is skipped entirely when nothing is deferred, so
        # unsupported providers only fail when deferral is actually requested.
        if not any(_is_deferred_tool(tool, self.searchable_tool_names) for tool in tools):
            return request

        provider = _get_model_provider(request.model, request.runtime)
        if provider is None:
            msg = (
                "ProviderToolSearchMiddleware could not determine the provider for "
                f"model {request.model.__class__.__name__!r}; server-side tool search "
                f"supports: {', '.join(sorted(_SERVER_TOOL_SEARCH_TOOLS))}"
            )
            raise ValueError(msg)
        if provider not in _SERVER_TOOL_SEARCH_TOOLS:
            msg = (
                "ProviderToolSearchMiddleware requires a provider with server-side "
                f"tool search, but got {provider!r}; supported providers: "
                f"{', '.join(sorted(_SERVER_TOOL_SEARCH_TOOLS))}"
            )
            raise ValueError(msg)

        # Copy the registry entry so downstream mutation cannot corrupt it.
        search_tool: dict[str, Any] = dict(_SERVER_TOOL_SEARCH_TOOLS[provider])
        prepared_tools = [_defer_tool_if_needed(tool, self.searchable_tool_names) for tool in tools]

        # Avoid sending the same search tool twice when it was supplied
        # explicitly or injected by an earlier instance of this middleware.
        search_tool_present = any(
            isinstance(tool, dict) and tool.get("type") == search_tool["type"] for tool in tools
        )
        if not search_tool_present:
            prepared_tools.append(search_tool)
        return request.override(tools=prepared_tools)

    def wrap_model_call(
        self,
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
    ) -> ModelResponse[ResponseT] | AIMessage:
        """Defer tools before invoking the model.

        Args:
            request: Model request to execute.
            handler: Callback that executes the model request.

        Returns:
            The model call result.

        Raises:
            ValueError: If `searchable_tools` references a tool not bound to the
                model, or if a tool is deferred but the model's provider cannot
                be identified or does not support server-side tool search.
        """
        return handler(self._prepare_request(request))

    async def awrap_model_call(
        self,
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
    ) -> ModelResponse[ResponseT] | AIMessage:
        """Defer tools before asynchronously invoking the model.

        Args:
            request: Model request to execute.
            handler: Callback that executes the model request.

        Returns:
            The model call result.

        Raises:
            ValueError: If `searchable_tools` references a tool not bound to the
                model, or if a tool is deferred but the model's provider cannot
                be identified or does not support server-side tool search.
        """
        return await handler(self._prepare_request(request))
|
||||
|
||||
|
||||
def _to_tool_names(tools: list[ToolIdentifier] | None) -> set[str]:
|
||||
"""Convert tool identifiers to names."""
|
||||
if tools is None:
|
||||
return set()
|
||||
return {tool if isinstance(tool, str) else tool.name for tool in tools}
|
||||
|
||||
|
||||
def _is_deferred_tool(tool: BaseTool | dict[str, Any], tool_names: set[str]) -> bool:
    """Tell whether `tool` must be deferred behind provider tool search.

    A `BaseTool` is deferred when its `extras` dict has `defer_loading` set to
    exactly `True`, or when its name is in `tool_names`. Dict-form tools (e.g.
    provider tool specs) are never deferred.
    """
    if isinstance(tool, BaseTool):
        metadata = tool.extras
        flagged = isinstance(metadata, dict) and metadata.get("defer_loading") is True
        return flagged or tool.name in tool_names
    return False
|
||||
|
||||
|
||||
def _defer_tool_if_needed(
    tool: BaseTool | dict[str, Any], tool_names: set[str]
) -> BaseTool | dict[str, Any]:
    """Return a copy of `tool` flagged with `defer_loading`, or `tool` itself.

    The input is handed back untouched when it is not a `BaseTool` or when it
    does not qualify for deferral. Otherwise a copy is returned whose `extras`
    keep every existing key and additionally set `defer_loading` to `True`.
    """
    if not isinstance(tool, BaseTool) or not _is_deferred_tool(tool, tool_names):
        return tool
    merged_extras = dict(tool.extras or {})
    merged_extras["defer_loading"] = True
    return tool.model_copy(update={"extras": merged_extras})
|
||||
|
||||
|
||||
def _get_model_provider(model: BaseChatModel, runtime: Any) -> str | None:
    """Work out the normalized provider name of `model`.

    Signals are consulted in this order, first hit wins:

    1. `model._model_params(config)` merged over `model._default_config`
       (runtime values take precedence over defaults).
    2. `model._default_config` on its own.
    3. The `ls_provider` entry of `model._get_ls_params()`.
    4. The model's class name.

    Returns `None` when none of them identifies a provider, which lets callers
    tell a detection failure apart from an unsupported provider.
    """
    defaults = getattr(model, "_default_config", None)
    has_defaults = isinstance(defaults, dict)

    resolve_params = getattr(model, "_model_params", None)
    if callable(resolve_params):
        runtime_config = getattr(runtime, "config", None)
        # Anything that is not a mapping (including a missing config) is passed
        # as `None` so a malformed config cannot blow up inside `_model_params`.
        if not isinstance(runtime_config, Mapping):
            runtime_config = None
        resolved = resolve_params(runtime_config)
        if isinstance(resolved, dict):
            merged = {**defaults, **resolved} if has_defaults else resolved
            detected = _provider_from_params(merged)
            if detected:
                return detected

    if has_defaults:
        detected = _provider_from_params(defaults)
        if detected:
            return detected

    ls_params_fn = getattr(model, "_get_ls_params", None)
    if callable(ls_params_fn):
        ls_params = ls_params_fn()
        if isinstance(ls_params, dict):
            ls_provider = ls_params.get("ls_provider")
            if isinstance(ls_provider, str):
                return _normalize_provider(ls_provider)

    return _provider_from_class_name(model.__class__.__name__)
|
||||
|
||||
|
||||
def _provider_from_params(params: dict[str, Any]) -> str | None:
|
||||
"""Infer the provider from model parameters, or `None` if absent."""
|
||||
provider = params.get("model_provider")
|
||||
if isinstance(provider, str):
|
||||
return _normalize_provider(provider)
|
||||
model_name = params.get("model")
|
||||
if isinstance(model_name, str):
|
||||
return _provider_from_model_name(model_name)
|
||||
return None
|
||||
|
||||
|
||||
def _provider_from_model_name(model_name: str) -> str | None:
|
||||
"""Infer the provider from a model name, or `None` if unrecognized."""
|
||||
provider, _, rest = model_name.partition(":")
|
||||
if rest:
|
||||
return _normalize_provider(provider)
|
||||
model_lower = model_name.lower()
|
||||
if model_lower.startswith("claude"):
|
||||
return "anthropic"
|
||||
if model_lower.startswith(("gpt-", "o1", "o3", "chatgpt")):
|
||||
return "openai"
|
||||
return None
|
||||
|
||||
|
||||
def _provider_from_class_name(class_name: str) -> str | None:
|
||||
"""Infer the provider from a model class name, or `None` if unrecognized."""
|
||||
if class_name in {"ChatAnthropic", "AnthropicChat"}:
|
||||
return "anthropic"
|
||||
if class_name in {"ChatOpenAI", "OpenAIChat"}:
|
||||
return "openai"
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_provider(provider: str) -> str:
|
||||
"""Normalize a provider identifier by lowercasing and mapping `-` to `_`."""
|
||||
return provider.replace("-", "_").lower()
|
||||
@@ -0,0 +1,432 @@
|
||||
"""Unit tests for provider tool search middleware."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.tools import BaseTool, tool
|
||||
|
||||
from langchain.agents.middleware import (
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ProviderToolSearchMiddleware,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
|
||||
ANTHROPIC_SEARCH_TOOL = {
|
||||
"type": "tool_search_tool_bm25_20251119",
|
||||
"name": "tool_search_tool_bm25",
|
||||
}
|
||||
OPENAI_SEARCH_TOOL = {"type": "tool_search"}
|
||||
|
||||
|
||||
@tool
|
||||
def get_weather(city: str) -> str:
|
||||
"""Get the weather for a city."""
|
||||
return f"Sunny in {city}"
|
||||
|
||||
|
||||
@tool
|
||||
def send_email(to: str) -> str:
|
||||
"""Send an email."""
|
||||
return f"Sent to {to}"
|
||||
|
||||
|
||||
@tool(extras={"defer_loading": True})
|
||||
def lookup_order(order_id: str) -> str:
|
||||
"""Look up an order."""
|
||||
return f"Order {order_id} shipped"
|
||||
|
||||
|
||||
@tool(extras={"category": "billing"})
|
||||
def refund_order(order_id: str) -> str:
|
||||
"""Refund an order."""
|
||||
return f"Refunded {order_id}"
|
||||
|
||||
|
||||
class FakeModel:
    """Model stand-in that reports its provider only through `_get_ls_params`."""

    def __init__(self, provider: str) -> None:
        self.provider = provider

    def _get_ls_params(self) -> dict[str, str]:
        ls_params = {}
        ls_params["ls_provider"] = self.provider
        return ls_params
|
||||
|
||||
|
||||
class FakeConfigurableModel:
    """Stand-in for a configurable model exposing `_default_config` and `_model_params`."""

    def __init__(
        self,
        default_config: dict[str, Any] | None = None,
        model_params: dict[str, Any] | None = None,
    ) -> None:
        self._default_config = default_config if default_config else {}
        self.model_params = model_params if model_params else {}

    def _model_params(self, config: dict[str, Any] | None = None) -> dict[str, Any]:
        # Without a config the canned params are returned; with one, its
        # `configurable` section stands in for the resolved params.
        if config is None:
            return self.model_params
        return config.get("configurable", {})
|
||||
|
||||
|
||||
class ChatAnthropic:
    """Bare model whose class name is the only provider signal.

    The class defines no attributes or methods, so provider detection can only
    succeed through the class-name fallback.
    """
|
||||
|
||||
|
||||
class ChatOpenAI:
    """Bare model whose class name is the only provider signal.

    The class defines no attributes or methods, so provider detection can only
    succeed through the class-name fallback.
    """
|
||||
|
||||
|
||||
class MysteryModel:
    """Model that exposes no provider signal at all.

    It has no params, no `_get_ls_params`, and a class name that matches no
    known provider, so detection is expected to fail.
    """
|
||||
|
||||
|
||||
class FakeRuntime:
    """Minimal runtime stand-in that exposes only a `config` attribute."""

    def __init__(self, config: Any) -> None:
        # Stored verbatim; tests pass both mappings and malformed values.
        self.config = config
|
||||
|
||||
|
||||
def _request(provider: str, tools: list[BaseTool | dict[str, Any]]) -> ModelRequest:
    """Build a one-message request whose fake model reports `provider`."""
    fake_model = cast("BaseChatModel", FakeModel(provider))
    greeting = [HumanMessage("hi")]
    return ModelRequest(model=fake_model, messages=greeting, tools=tools)
|
||||
)
|
||||
|
||||
|
||||
def _invoke(middleware: ProviderToolSearchMiddleware, request: ModelRequest) -> ModelRequest:
    """Run `wrap_model_call` and return the request the handler last received."""
    seen: list[ModelRequest] = []

    def handler(model_request: ModelRequest) -> ModelResponse:
        seen.append(model_request)
        return ModelResponse(result=[AIMessage("ok")])

    middleware.wrap_model_call(request, handler)
    # The handler must have run, and with an actual request.
    assert seen
    assert seen[-1] is not None
    return seen[-1]
|
||||
|
||||
|
||||
def test_passes_through_when_no_tools_are_deferred() -> None:
    """With nothing to defer, the handler receives the very same request object."""
    original = _request("anthropic", [get_weather, send_email])

    forwarded = _invoke(ProviderToolSearchMiddleware(), original)

    assert forwarded is original
|
||||
|
||||
|
||||
def test_defers_tools_named_in_searchable_tools() -> None:
    """Named tools gain `defer_loading`, the rest are untouched, and the search tool is added."""
    original = _request("anthropic", [get_weather, send_email])

    forwarded = _invoke(ProviderToolSearchMiddleware(searchable_tools=["send_email"]), original)

    by_name = {t.name: t for t in forwarded.tools if isinstance(t, BaseTool)}
    assert by_name["send_email"].extras == {"defer_loading": True}
    assert by_name["get_weather"].extras is None
    assert ANTHROPIC_SEARCH_TOOL in forwarded.tools
|
||||
|
||||
|
||||
def test_accepts_tool_instances_in_searchable_tools() -> None:
    """`searchable_tools` may hold `BaseTool` instances as well as plain names."""
    original = _request("openai", [get_weather, send_email])

    forwarded = _invoke(ProviderToolSearchMiddleware(searchable_tools=[send_email]), original)

    by_name = {t.name: t for t in forwarded.tools if isinstance(t, BaseTool)}
    assert by_name["send_email"].extras == {"defer_loading": True}
    assert OPENAI_SEARCH_TOOL in forwarded.tools
|
||||
|
||||
|
||||
def test_honors_tools_pre_marked_with_defer_loading() -> None:
    """A tool already flagged `defer_loading` is deferred without any `searchable_tools`."""
    original = _request("anthropic", [get_weather, lookup_order])

    forwarded = _invoke(ProviderToolSearchMiddleware(), original)

    by_name = {t.name: t for t in forwarded.tools if isinstance(t, BaseTool)}
    assert by_name["lookup_order"].extras == {"defer_loading": True}
    assert ANTHROPIC_SEARCH_TOOL in forwarded.tools
|
||||
|
||||
|
||||
def test_preserves_existing_extras_when_deferring() -> None:
    """Deferral adds to a tool's existing `extras` instead of replacing them."""
    original = _request("anthropic", [refund_order])

    forwarded = _invoke(ProviderToolSearchMiddleware(searchable_tools=["refund_order"]), original)

    by_name = {t.name: t for t in forwarded.tools if isinstance(t, BaseTool)}
    assert by_name["refund_order"].extras == {"category": "billing", "defer_loading": True}
|
||||
|
||||
|
||||
def test_passes_dict_tools_through_untouched() -> None:
    """Dict-form tools survive unchanged alongside deferred `BaseTool`s."""
    provider_tool = {"type": "web_search"}
    original = _request("anthropic", [get_weather, provider_tool])

    forwarded = _invoke(ProviderToolSearchMiddleware(searchable_tools=["get_weather"]), original)

    for expected in (provider_tool, ANTHROPIC_SEARCH_TOOL):
        assert expected in forwarded.tools
|
||||
|
||||
|
||||
def test_raises_when_searchable_tool_is_not_bound() -> None:
    """An unknown name in `searchable_tools` raises instead of being silently ignored."""
    original = _request("anthropic", [get_weather])
    under_test = ProviderToolSearchMiddleware(searchable_tools=["missing_tool"])

    with pytest.raises(ValueError, match="missing_tool"):
        _invoke(under_test, original)
|
||||
|
||||
|
||||
def test_unbound_tool_error_is_sorted() -> None:
    """Unknown tool names appear alphabetically in the error, keeping it deterministic."""
    original = _request("anthropic", [get_weather])
    under_test = ProviderToolSearchMiddleware(searchable_tools=["zzz", "aaa"])

    with pytest.raises(ValueError, match="aaa, zzz"):
        _invoke(under_test, original)
|
||||
|
||||
|
||||
def test_unbound_tool_check_precedes_provider_check() -> None:
    """The unbound-tool error is raised before any provider validation happens."""
    original = _request("mistralai", [get_weather])
    under_test = ProviderToolSearchMiddleware(searchable_tools=["missing_tool"])

    with pytest.raises(ValueError, match="not bound to the model"):
        _invoke(under_test, original)
|
||||
|
||||
|
||||
def test_passes_through_unsupported_provider_when_nothing_deferred() -> None:
    """With no deferred tool, an unsupported provider is not an error."""
    original = _request("mistralai", [get_weather])

    forwarded = _invoke(ProviderToolSearchMiddleware(), original)

    assert forwarded is original
|
||||
|
||||
|
||||
def test_passes_through_unsupported_provider_with_empty_tools() -> None:
    """An empty tool list is forwarded as-is without consulting the provider."""
    original = _request("mistralai", [])

    forwarded = _invoke(ProviderToolSearchMiddleware(), original)

    assert forwarded is original
|
||||
|
||||
|
||||
def test_raises_for_unsupported_provider_when_tool_deferred() -> None:
    """Deferring a tool on a provider without tool search raises rather than dropping it."""
    original = _request("mistralai", [send_email])
    under_test = ProviderToolSearchMiddleware(searchable_tools=["send_email"])

    with pytest.raises(ValueError, match="server-side tool search"):
        _invoke(under_test, original)
|
||||
|
||||
|
||||
def test_raises_when_provider_cannot_be_determined() -> None:
    """A model with no provider signal produces the dedicated detection-failure error."""
    opaque_model = cast("BaseChatModel", MysteryModel())
    original = ModelRequest(
        model=opaque_model,
        messages=[HumanMessage("hi")],
        tools=[lookup_order],
    )
    under_test = ProviderToolSearchMiddleware()

    with pytest.raises(ValueError, match="could not determine the provider"):
        _invoke(under_test, original)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("model_factory", "expected_tool"),
    [
        (ChatAnthropic, ANTHROPIC_SEARCH_TOOL),
        (ChatOpenAI, OPENAI_SEARCH_TOOL),
    ],
)
def test_detects_provider_from_class_name(
    model_factory: type, expected_tool: dict[str, str]
) -> None:
    """With no params or `ls_provider`, the class name alone selects the search tool."""
    bare_model = cast("BaseChatModel", model_factory())
    original = ModelRequest(
        model=bare_model,
        messages=[HumanMessage("hi")],
        tools=[send_email],
    )

    forwarded = _invoke(ProviderToolSearchMiddleware(searchable_tools=["send_email"]), original)

    assert expected_tool in forwarded.tools
|
||||
|
||||
|
||||
def test_normalizes_provider_casing_and_hyphens() -> None:
    """Provider identifiers are lowercased and have `-` mapped to `_` before lookup.

    Casing is checked through a successful registry match. Hyphen handling is
    checked through the unsupported-provider error, which echoes the normalized
    identifier (no supported provider name contains a hyphen or underscore).
    """
    middleware = ProviderToolSearchMiddleware(searchable_tools=["send_email"])

    # Mixed casing still resolves to the registered "anthropic" entry.
    modified_request = _invoke(middleware, _request("Anthropic", [send_email]))
    assert ANTHROPIC_SEARCH_TOOL in modified_request.tools

    # A hyphenated, mixed-case identifier is reported in its normalized form.
    hyphenated_request = _request("Mistral-AI", [send_email])
    with pytest.raises(ValueError, match="mistral_ai"):
        _invoke(middleware, hyphenated_request)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("model_name", "expected_tool"),
    [
        ("claude-sonnet-4-5", ANTHROPIC_SEARCH_TOOL),
        ("gpt-5.4", OPENAI_SEARCH_TOOL),
        ("o3-mini", OPENAI_SEARCH_TOOL),
        ("chatgpt-4o-latest", OPENAI_SEARCH_TOOL),
    ],
)
def test_detects_provider_from_bare_model_name(
    model_name: str, expected_tool: dict[str, str]
) -> None:
    """Model names without a `provider:` prefix resolve through the prefix heuristics."""
    configurable = FakeConfigurableModel(model_params={"model": model_name})
    original = ModelRequest(
        model=cast("BaseChatModel", configurable),
        messages=[HumanMessage("hi")],
        tools=[send_email],
    )

    forwarded = _invoke(ProviderToolSearchMiddleware(searchable_tools=["send_email"]), original)

    assert expected_tool in forwarded.tools
|
||||
|
||||
|
||||
def test_raises_for_unrecognized_bare_model_name() -> None:
    """An unknown bare model name is a detection failure, never a guessed provider."""
    configurable = FakeConfigurableModel(model_params={"model": "llama-3.1"})
    original = ModelRequest(
        model=cast("BaseChatModel", configurable),
        messages=[HumanMessage("hi")],
        tools=[send_email],
    )
    under_test = ProviderToolSearchMiddleware(searchable_tools=["send_email"])

    with pytest.raises(ValueError, match="could not determine the provider"):
        _invoke(under_test, original)
|
||||
|
||||
|
||||
def test_detects_provider_from_configurable_model() -> None:
    """A `provider:model` string in `_default_config` is enough to pick the provider."""
    configurable = FakeConfigurableModel({"model": "openai:gpt-5.4"})
    original = ModelRequest(
        model=cast("BaseChatModel", configurable),
        messages=[HumanMessage("hi")],
        tools=[send_email],
    )

    forwarded = _invoke(ProviderToolSearchMiddleware(searchable_tools=["send_email"]), original)

    assert OPENAI_SEARCH_TOOL in forwarded.tools
|
||||
|
||||
|
||||
def test_detects_provider_from_runtime_configurable_model() -> None:
    """A provider supplied only in the runtime `configurable` reaches detection via `_model_params`."""
    runtime = FakeRuntime({"configurable": {"model_provider": "openai"}})
    original = ModelRequest(
        model=cast("BaseChatModel", FakeConfigurableModel()),
        messages=[HumanMessage("hi")],
        tools=[send_email],
        runtime=cast("Any", runtime),
    )

    forwarded = _invoke(ProviderToolSearchMiddleware(searchable_tools=["send_email"]), original)

    assert OPENAI_SEARCH_TOOL in forwarded.tools
|
||||
|
||||
|
||||
def test_runtime_model_override_uses_default_configurable_model_provider() -> None:
    """A default `model_provider` beats a runtime `model` name that hints at another provider."""
    configurable = FakeConfigurableModel({"model_provider": "openai"})
    runtime = FakeRuntime({"configurable": {"model": "claude-sonnet-4-5"}})
    original = ModelRequest(
        model=cast("BaseChatModel", configurable),
        messages=[HumanMessage("hi")],
        tools=[send_email],
        runtime=cast("Any", runtime),
    )

    forwarded = _invoke(ProviderToolSearchMiddleware(searchable_tools=["send_email"]), original)

    assert OPENAI_SEARCH_TOOL in forwarded.tools
|
||||
|
||||
|
||||
def test_runtime_provider_override_uses_runtime_configurable_model_provider() -> None:
    """A runtime `model_provider` takes precedence over the one in the default config."""
    configurable = FakeConfigurableModel({"model_provider": "openai"})
    runtime = FakeRuntime({"configurable": {"model_provider": "anthropic"}})
    original = ModelRequest(
        model=cast("BaseChatModel", configurable),
        messages=[HumanMessage("hi")],
        tools=[send_email],
        runtime=cast("Any", runtime),
    )

    forwarded = _invoke(ProviderToolSearchMiddleware(searchable_tools=["send_email"]), original)

    assert ANTHROPIC_SEARCH_TOOL in forwarded.tools
|
||||
|
||||
|
||||
def test_ignores_non_mapping_runtime_config() -> None:
    """A runtime config that is not a mapping is treated as missing rather than crashing."""
    configurable = FakeConfigurableModel(model_params={"model_provider": "openai"})
    malformed_runtime = FakeRuntime("not-a-mapping")
    original = ModelRequest(
        model=cast("BaseChatModel", configurable),
        messages=[HumanMessage("hi")],
        tools=[send_email],
        runtime=cast("Any", malformed_runtime),
    )

    forwarded = _invoke(ProviderToolSearchMiddleware(searchable_tools=["send_email"]), original)

    assert OPENAI_SEARCH_TOOL in forwarded.tools
|
||||
|
||||
|
||||
async def test_async_wrap_model_call_defers_tools() -> None:
    """`awrap_model_call` hands the handler a request carrying the provider search tool."""
    original = _request("openai", [send_email])
    under_test = ProviderToolSearchMiddleware(searchable_tools=["send_email"])
    seen: list[ModelRequest] = []

    async def handler(model_request: ModelRequest) -> ModelResponse:
        seen.append(model_request)
        return ModelResponse(result=[AIMessage("ok")])

    await under_test.awrap_model_call(original, handler)

    # The handler must have run, and with an actual request.
    assert seen
    assert seen[-1] is not None
    assert OPENAI_SEARCH_TOOL in seen[-1].tools
|
||||
Reference in New Issue
Block a user