feat(langchain_v1): add llm selection middleware (#33272)

* Add LLM-based tool selection middleware.
* Note that we might want some form of caching for when the agent is
inside an active tool-calling loop, since the tool selection isn't
expected to change during that time.

API:

```python
class LLMToolSelectorMiddleware(AgentMiddleware):
    """Uses an LLM to select relevant tools before calling the main model.

    When an agent has many tools available, this middleware filters them down
    to only the most relevant ones for the user's query. This reduces token usage
    and helps the main model focus on the right tools.

    Examples:
        Limit to 3 tools:
        ```python
        from langchain.agents.middleware import LLMToolSelectorMiddleware

        middleware = LLMToolSelectorMiddleware(max_tools=3)

        agent = create_agent(
            model="openai:gpt-4o",
            tools=[tool1, tool2, tool3, tool4, tool5],
            middleware=[middleware],
        )
        ```

        Use a smaller model for selection:
        ```python
        middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o-mini", max_tools=2)
        ```
    """

    def __init__(
        self,
        *,
        model: str | BaseChatModel | None = None,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
        max_tools: int | None = None,
        always_include: list[str] | None = None,
    ) -> None:
        """Initialize the tool selector.

        Args:
            model: Model to use for selection. If not provided, uses the agent's main model.
                Can be a model identifier string or BaseChatModel instance.
            system_prompt: Instructions for the selection model.
            max_tools: Maximum number of tools to select. If the model selects more,
                only the first max_tools will be used. No limit if not specified.
            always_include: Tool names to always include regardless of selection.
                These do not count against the max_tools limit.
        """
```
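The `always_include` parameter is documented above but not shown in the examples; a minimal sketch (the `search_docs` tool name is hypothetical):

```python
from langchain.agents.middleware import LLMToolSelectorMiddleware

# A cheap selector model picks at most 3 tools per query, while the hypothetical
# `search_docs` tool is always forwarded to the main model and does not count
# against the max_tools limit.
middleware = LLMToolSelectorMiddleware(
    model="openai:gpt-4o-mini",
    max_tools=3,
    always_include=["search_docs"],
)
```

The description also includes the following end-to-end script exercising the middleware: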



```python
"""Test script for LLM tool selection middleware."""

from langchain.agents import create_agent
from langchain.agents.middleware import LLMToolSelectorMiddleware
from langchain_core.tools import tool


@tool
def get_weather(location: str) -> str:
    """Get current weather for a location."""
    return f"Weather in {location}: 72°F, sunny"


@tool
def search_web(query: str) -> str:
    """Search the web for information."""
    return f"Search results for: {query}"


@tool
def calculate(expression: str) -> str:
    """Perform mathematical calculations."""
    return f"Result of {expression}: 42"


@tool
def send_email(to: str, subject: str) -> str:
    """Send an email to someone."""
    return f"Email sent to {to} with subject: {subject}"


@tool
def get_stock_price(symbol: str) -> str:
    """Get current stock price for a symbol."""
    return f"Stock price for {symbol}: $150.25"


@tool
def translate_text(text: str, target_language: str) -> str:
    """Translate text to another language."""
    return f"Translated '{text}' to {target_language}"


@tool
def set_reminder(task: str, time: str) -> str:
    """Set a reminder for a task."""
    return f"Reminder set: {task} at {time}"


@tool
def get_news(topic: str) -> str:
    """Get latest news about a topic."""
    return f"Latest news about {topic}"


@tool
def book_flight(destination: str, date: str) -> str:
    """Book a flight to a destination."""
    return f"Flight booked to {destination} on {date}"


@tool
def get_restaurant_recommendations(city: str, cuisine: str) -> str:
    """Get restaurant recommendations."""
    return f"Top {cuisine} restaurants in {city}"


# Create agent with tool selection middleware
middleware = LLMToolSelectorMiddleware(
    model="openai:gpt-4o-mini",
    max_tools=3,
)

agent = create_agent(
    model="openai:gpt-4o",
    tools=[
        get_weather,
        search_web,
        calculate,
        send_email,
        get_stock_price,
        translate_text,
        set_reminder,
        get_news,
        book_flight,
        get_restaurant_recommendations,
    ],
    middleware=[middleware],
)

# Test with a query that should select specific tools
response = agent.invoke(
    {"messages": [{"role": "user", "content": "I need to find restaurants"}]}
)

print(response)
```
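To check which tools the selector actually forwarded to the main model, one option is the request-tracing pattern used by the unit tests below (a rough sketch; `trace_model_requests` is a name introduced here for illustration):

```python
from langchain.agents.middleware import AgentState, ModelRequest, modify_model_request

# Collect every ModelRequest that reaches the main model so the filtered tool
# set can be inspected after the run.
model_requests: list[ModelRequest] = []


@modify_model_request
def trace_model_requests(request: ModelRequest, state: AgentState, runtime) -> ModelRequest:
    model_requests.append(request)
    return request


# Add the tracer after the selector, e.g.:
# agent = create_agent(model="openai:gpt-4o", tools=[...], middleware=[middleware, trace_model_requests])
# After agent.invoke(...):
# print([tool.name for tool in model_requests[0].tools])
```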
Eugene Yurtsev · 2025-10-05 15:55:55 -04:00 · committed by GitHub
commit df2ecd9448 (parent bdb7dbbf16)
3 changed files with 893 additions and 0 deletions


@@ -6,6 +6,7 @@ from .planning import PlanningMiddleware
from .prompt_caching import AnthropicPromptCachingMiddleware
from .summarization import SummarizationMiddleware
from .tool_call_limit import ToolCallLimitMiddleware
from .tool_selection import LLMToolSelectorMiddleware
from .types import (
    AgentMiddleware,
    AgentState,
@@ -23,6 +24,7 @@ __all__ = [
    # should move to langchain-anthropic if we decide to keep it
    "AnthropicPromptCachingMiddleware",
    "HumanInTheLoopMiddleware",
    "LLMToolSelectorMiddleware",
    "ModelRequest",
    "PIIDetectionError",
    "PIIMiddleware",


@@ -0,0 +1,293 @@
"""LLM-based tool selector middleware."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Annotated, Literal, Union
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage
from pydantic import Field, TypeAdapter
from typing_extensions import TypedDict
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest, StateT
from langchain.chat_models.base import init_chat_model
if TYPE_CHECKING:
from langgraph.runtime import Runtime
from langgraph.typing import ContextT
from langchain.tools import BaseTool
logger = logging.getLogger(__name__)
DEFAULT_SYSTEM_PROMPT = (
"Your goal is to select the most relevant tools for answering the user's query."
)
@dataclass
class _SelectionRequest:
"""Prepared inputs for tool selection."""
available_tools: list[BaseTool]
system_message: str
last_user_message: HumanMessage
model: BaseChatModel
valid_tool_names: list[str]
def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter:
"""Create a structured output schema for tool selection.
Args:
tools: Available tools to include in the schema.
Returns:
TypeAdapter for a schema where each tool name is a Literal with its description.
"""
if not tools:
msg = "Invalid usage: tools must be non-empty"
raise AssertionError(msg)
# Create a Union of Annotated Literal types for each tool name with description
# Example: Union[Annotated[Literal["tool1"], Field(description="...")], ...] noqa: ERA001
literals = [
Annotated[Literal[tool.name], Field(description=tool.description)] for tool in tools
]
selected_tool_type = Union[tuple(literals)] # type: ignore[valid-type] # noqa: UP007
description = "Tools to use. Place the most relevant tools first."
class ToolSelectionResponse(TypedDict):
"""Use to select relevant tools."""
tools: Annotated[list[selected_tool_type], Field(description=description)] # type: ignore[valid-type]
return TypeAdapter(ToolSelectionResponse)
def _render_tool_list(tools: list[BaseTool]) -> str:
"""Format tools as markdown list.
Args:
tools: Tools to format.
Returns:
Markdown string with each tool on a new line.
"""
return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)
class LLMToolSelectorMiddleware(AgentMiddleware):
"""Uses an LLM to select relevant tools before calling the main model.
When an agent has many tools available, this middleware filters them down
to only the most relevant ones for the user's query. This reduces token usage
and helps the main model focus on the right tools.
Examples:
Limit to 3 tools:
```python
from langchain.agents.middleware import LLMToolSelectorMiddleware
middleware = LLMToolSelectorMiddleware(max_tools=3)
agent = create_agent(
model="openai:gpt-4o",
tools=[tool1, tool2, tool3, tool4, tool5],
middleware=[middleware],
)
```
Use a smaller model for selection:
```python
middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o-mini", max_tools=2)
```
"""
def __init__(
self,
*,
model: str | BaseChatModel | None = None,
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
max_tools: int | None = None,
always_include: list[str] | None = None,
) -> None:
"""Initialize the tool selector.
Args:
model: Model to use for selection. If not provided, uses the agent's main model.
Can be a model identifier string or BaseChatModel instance.
system_prompt: Instructions for the selection model.
max_tools: Maximum number of tools to select. If the model selects more,
only the first max_tools will be used. No limit if not specified.
always_include: Tool names to always include regardless of selection.
These do not count against the max_tools limit.
"""
super().__init__()
self.system_prompt = system_prompt
self.max_tools = max_tools
self.always_include = always_include or []
if isinstance(model, (BaseChatModel, type(None))):
self.model: BaseChatModel | None = model
else:
self.model = init_chat_model(model)
def _prepare_selection_request(self, request: ModelRequest) -> _SelectionRequest | None:
"""Prepare inputs for tool selection.
Returns:
SelectionRequest with prepared inputs, or None if no selection is needed.
"""
# If no tools available, return None
if not request.tools or len(request.tools) == 0:
return None
# Validate that always_include tools exist
if self.always_include:
available_tool_names = {tool.name for tool in request.tools}
missing_tools = [
name for name in self.always_include if name not in available_tool_names
]
if missing_tools:
msg = (
f"Tools in always_include not found in request: {missing_tools}. "
f"Available tools: {sorted(available_tool_names)}"
)
raise ValueError(msg)
# Separate tools that are always included from those available for selection
available_tools = [tool for tool in request.tools if tool.name not in self.always_include]
# If no tools available for selection, return None
if not available_tools:
return None
system_message = self.system_prompt
# If there's a max_tools limit, append instructions to the system prompt
if self.max_tools is not None:
system_message += (
f"\nIMPORTANT: List the tool names in order of relevance, "
f"with the most relevant first. "
f"If you exceed the maximum number of tools, "
f"only the first {self.max_tools} will be used."
)
# Get the last user message from the conversation history
last_user_message: HumanMessage
for message in request.messages:
if isinstance(message, HumanMessage):
last_user_message = message
break
else:
msg = "No user message found in request messages"
raise AssertionError(msg)
model = self.model or request.model
valid_tool_names = [tool.name for tool in available_tools]
return _SelectionRequest(
available_tools=available_tools,
system_message=system_message,
last_user_message=last_user_message,
model=model,
valid_tool_names=valid_tool_names,
)
def _process_selection_response(
self,
response: dict,
available_tools: list[BaseTool],
valid_tool_names: list[str],
request: ModelRequest,
) -> ModelRequest:
"""Process the selection response and return filtered ModelRequest."""
selected_tool_names: list[str] = []
invalid_tool_selections = []
for tool_name in response["tools"]:
if tool_name not in valid_tool_names:
invalid_tool_selections.append(tool_name)
continue
# Only add if not already selected and within max_tools limit
if tool_name not in selected_tool_names and (
self.max_tools is None or len(selected_tool_names) < self.max_tools
):
selected_tool_names.append(tool_name)
if invalid_tool_selections:
msg = f"Model selected invalid tools: {invalid_tool_selections}"
raise ValueError(msg)
# Filter tools based on selection and append always-included tools
selected_tools = [tool for tool in available_tools if tool.name in selected_tool_names]
always_included_tools = [tool for tool in request.tools if tool.name in self.always_include]
selected_tools.extend(always_included_tools)
request.tools = selected_tools
return request
def modify_model_request(
self,
request: ModelRequest,
state: StateT, # noqa: ARG002
runtime: Runtime[ContextT], # noqa: ARG002
) -> ModelRequest:
"""Modify the model request to filter tools based on LLM selection."""
selection_request = self._prepare_selection_request(request)
if selection_request is None:
return request
# Create dynamic response model with Literal enum of available tool names
type_adapter = _create_tool_selection_response(selection_request.available_tools)
schema = type_adapter.json_schema()
structured_model = selection_request.model.with_structured_output(schema)
response = structured_model.invoke(
[
{"role": "system", "content": selection_request.system_message},
selection_request.last_user_message,
]
)
# Response should be a dict since we're passing a schema (not a Pydantic model class)
if not isinstance(response, dict):
msg = f"Expected dict response, got {type(response)}"
raise AssertionError(msg)
return self._process_selection_response(
response, selection_request.available_tools, selection_request.valid_tool_names, request
)
async def amodify_model_request(
self,
request: ModelRequest,
state: AgentState, # noqa: ARG002
runtime: Runtime, # noqa: ARG002
) -> ModelRequest:
"""Modify the model request to filter tools based on LLM selection."""
selection_request = self._prepare_selection_request(request)
if selection_request is None:
return request
# Create dynamic response model with Literal enum of available tool names
type_adapter = _create_tool_selection_response(selection_request.available_tools)
schema = type_adapter.json_schema()
structured_model = selection_request.model.with_structured_output(schema)
response = await structured_model.ainvoke(
[
{"role": "system", "content": selection_request.system_message},
selection_request.last_user_message,
]
)
# Response should be a dict since we're passing a schema (not a Pydantic model class)
if not isinstance(response, dict):
msg = f"Expected dict response, got {type(response)}"
raise AssertionError(msg)
return self._process_selection_response(
response, selection_request.available_tools, selection_request.valid_tool_names, request
)
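For reference, here is roughly what `_create_tool_selection_response` builds: a `TypedDict` whose `tools` field is a list of a `Union` of `Annotated` `Literal` tool names, each carrying the tool's description. A standalone sketch, with two made-up tool names and descriptions:

```python
from typing import Annotated, Literal, Union

from pydantic import Field, TypeAdapter
from typing_extensions import TypedDict

# Hypothetical tools: the schema constrains the selection model to these names.
literals = [
    Annotated[Literal["get_weather"], Field(description="Get current weather for a location.")],
    Annotated[Literal["search_web"], Field(description="Search the web for information.")],
]
SelectedTool = Union[tuple(literals)]


class ToolSelectionResponse(TypedDict):
    """Use to select relevant tools."""

    tools: Annotated[list[SelectedTool], Field(description="Tools to use. Place the most relevant tools first.")]


adapter = TypeAdapter(ToolSelectionResponse)
# This JSON schema is what gets passed to `model.with_structured_output(...)`.
print(adapter.json_schema())
```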


@@ -0,0 +1,598 @@
"""Unit tests for LLM tool selection middleware."""
import typing
from typing import Union, Any, Literal
from itertools import cycle
from pydantic import BaseModel
from langchain.agents import create_agent
from langchain.agents.middleware import AgentState, ModelRequest, modify_model_request
from langchain.agents.middleware import LLMToolSelectorMiddleware
from langchain.messages import AIMessage
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.tools import tool
@tool
def get_weather(location: str) -> str:
"""Get current weather for a location."""
return f"Weather in {location}: 72°F, sunny"
@tool
def search_web(query: str) -> str:
"""Search the web for information."""
return f"Search results for: {query}"
@tool
def calculate(expression: str) -> str:
"""Perform mathematical calculations."""
return f"Result of {expression}: 42"
@tool
def send_email(to: str, subject: str) -> str:
"""Send an email to someone."""
return f"Email sent to {to}"
@tool
def get_stock_price(symbol: str) -> str:
"""Get current stock price for a symbol."""
return f"Stock price for {symbol}: $150.25"
class FakeModel(GenericFakeChatModel):
tool_style: Literal["openai", "anthropic"] = "openai"
def bind_tools(
self,
tools: typing.Sequence[Union[dict[str, Any], type[BaseModel], typing.Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
if len(tools) == 0:
msg = "Must provide at least one tool"
raise ValueError(msg)
tool_dicts = []
for tool in tools:
if isinstance(tool, dict):
tool_dicts.append(tool)
continue
if not isinstance(tool, BaseTool):
msg = "Only BaseTool and dict is supported by FakeToolCallingModel.bind_tools"
raise TypeError(msg)
# NOTE: this is a simplified tool spec for testing purposes only
if self.tool_style == "openai":
tool_dicts.append(
{
"type": "function",
"function": {
"name": tool.name,
},
}
)
elif self.tool_style == "anthropic":
tool_dicts.append(
{
"name": tool.name,
}
)
return self.bind(tools=tool_dicts)
class TestLLMToolSelectorBasic:
"""Test basic tool selection functionality."""
def test_sync_basic_selection(self) -> None:
"""Test synchronous tool selection."""
# First call: selector picks tools
# Second call: agent uses selected tools
tool_calls = [
[
{
"name": "ToolSelectionResponse",
"id": "1",
"args": {"tools": ["get_weather", "calculate"]},
}
],
[{"name": "get_weather", "id": "2", "args": {"location": "Paris"}}],
]
model_requests = []
@modify_model_request
def trace_model_requests(request: ModelRequest, state: AgentState, runtime) -> ModelRequest:
"""Middleware to select relevant tools based on state/context."""
# Select a small, relevant subset of tools based on state/context
model_requests.append(request)
return request
tool_selection_model = FakeModel(
messages=cycle(
[
AIMessage(
content="",
tool_calls=[
{
"name": "ToolSelectionResponse",
"id": "1",
"args": {"tools": ["get_weather", "calculate"]},
}
],
),
]
)
)
model = FakeModel(
messages=iter(
[
AIMessage(
content="",
tool_calls=[
{"name": "get_weather", "id": "2", "args": {"location": "Paris"}}
],
),
AIMessage(content="The weather in Paris is 72°F and sunny."),
]
)
)
tool_selector = LLMToolSelectorMiddleware(max_tools=2, model=tool_selection_model)
agent = create_agent(
model=model,
tools=[get_weather, search_web, calculate, send_email, get_stock_price],
middleware=[tool_selector, trace_model_requests],
)
response = agent.invoke({"messages": [HumanMessage("What's the weather in Paris?")]})
assert isinstance(response["messages"][-1], AIMessage)
for request in model_requests:
selected_tool_names = [tool.name for tool in request.tools] if request.tools else []
assert selected_tool_names == ["get_weather", "calculate"]
async def test_async_basic_selection(self) -> None:
"""Test asynchronous tool selection."""
tool_selection_model = FakeModel(
messages=cycle(
[
AIMessage(
content="",
tool_calls=[
{
"name": "ToolSelectionResponse",
"id": "1",
"args": {"tools": ["search_web"]},
}
],
),
]
)
)
model = FakeModel(
messages=iter(
[
AIMessage(
content="",
tool_calls=[{"name": "search_web", "id": "2", "args": {"query": "Python"}}],
),
AIMessage(content="Search results found."),
]
)
)
tool_selector = LLMToolSelectorMiddleware(max_tools=1, model=tool_selection_model)
agent = create_agent(
model=model,
tools=[get_weather, search_web, calculate],
middleware=[tool_selector],
)
response = await agent.ainvoke({"messages": [HumanMessage("Search for Python tutorials")]})
assert isinstance(response["messages"][-1], AIMessage)
class TestMaxToolsLimiting:
"""Test max_tools limiting behavior."""
def test_max_tools_limits_selection(self) -> None:
"""Test that max_tools limits selection when model selects too many tools."""
model_requests = []
@modify_model_request
def trace_model_requests(request: ModelRequest, state: AgentState, runtime) -> ModelRequest:
model_requests.append(request)
return request
# Selector model tries to select 4 tools
tool_selection_model = FakeModel(
messages=cycle(
[
AIMessage(
content="",
tool_calls=[
{
"name": "ToolSelectionResponse",
"id": "1",
"args": {
"tools": [
"get_weather",
"search_web",
"calculate",
"send_email",
]
},
}
],
),
]
)
)
model = FakeModel(messages=iter([AIMessage(content="Done")]))
# But max_tools=2, so only first 2 should be used
tool_selector = LLMToolSelectorMiddleware(max_tools=2, model=tool_selection_model)
agent = create_agent(
model=model,
tools=[get_weather, search_web, calculate, send_email],
middleware=[tool_selector, trace_model_requests],
)
agent.invoke({"messages": [HumanMessage("test")]})
# Verify only 2 tools were passed to the main model
assert len(model_requests) > 0
for request in model_requests:
assert len(request.tools) == 2
tool_names = [tool.name for tool in request.tools]
# Should be first 2 from the selection
assert tool_names == ["get_weather", "search_web"]
def test_no_max_tools_uses_all_selected(self) -> None:
"""Test that when max_tools is None, all selected tools are used."""
model_requests = []
@modify_model_request
def trace_model_requests(request: ModelRequest, state: AgentState, runtime) -> ModelRequest:
model_requests.append(request)
return request
tool_selection_model = FakeModel(
messages=cycle(
[
AIMessage(
content="",
tool_calls=[
{
"name": "ToolSelectionResponse",
"id": "1",
"args": {
"tools": [
"get_weather",
"search_web",
"calculate",
"get_stock_price",
]
},
}
],
),
]
)
)
model = FakeModel(messages=iter([AIMessage(content="Done")]))
# No max_tools specified
tool_selector = LLMToolSelectorMiddleware(model=tool_selection_model)
agent = create_agent(
model=model,
tools=[get_weather, search_web, calculate, send_email, get_stock_price],
middleware=[tool_selector, trace_model_requests],
)
agent.invoke({"messages": [HumanMessage("test")]})
# All 4 selected tools should be present
assert len(model_requests) > 0
for request in model_requests:
assert len(request.tools) == 4
tool_names = [tool.name for tool in request.tools]
assert set(tool_names) == {
"get_weather",
"search_web",
"calculate",
"get_stock_price",
}
class TestAlwaysInclude:
"""Test always_include functionality."""
def test_always_include_tools_present(self) -> None:
"""Test that always_include tools are always present in the request."""
model_requests = []
@modify_model_request
def trace_model_requests(request: ModelRequest, state: AgentState, runtime) -> ModelRequest:
model_requests.append(request)
return request
# Selector picks only search_web
tool_selection_model = FakeModel(
messages=cycle(
[
AIMessage(
content="",
tool_calls=[
{
"name": "ToolSelectionResponse",
"id": "1",
"args": {"tools": ["search_web"]},
}
],
),
]
)
)
model = FakeModel(messages=iter([AIMessage(content="Done")]))
# But send_email is always included
tool_selector = LLMToolSelectorMiddleware(
max_tools=1, always_include=["send_email"], model=tool_selection_model
)
agent = create_agent(
model=model,
tools=[get_weather, search_web, send_email],
middleware=[tool_selector, trace_model_requests],
)
agent.invoke({"messages": [HumanMessage("test")]})
# Both selected and always_include tools should be present
assert len(model_requests) > 0
for request in model_requests:
tool_names = [tool.name for tool in request.tools]
assert "search_web" in tool_names
assert "send_email" in tool_names
assert len(tool_names) == 2
def test_always_include_not_counted_against_max(self) -> None:
"""Test that always_include tools don't count against max_tools limit."""
model_requests = []
@modify_model_request
def trace_model_requests(request: ModelRequest, state: AgentState, runtime) -> ModelRequest:
model_requests.append(request)
return request
# Selector picks 2 tools
tool_selection_model = FakeModel(
messages=cycle(
[
AIMessage(
content="",
tool_calls=[
{
"name": "ToolSelectionResponse",
"id": "1",
"args": {"tools": ["get_weather", "search_web"]},
}
],
),
]
)
)
model = FakeModel(messages=iter([AIMessage(content="Done")]))
# max_tools=2, but we also have 2 always_include tools
tool_selector = LLMToolSelectorMiddleware(
max_tools=2,
always_include=["send_email", "calculate"],
model=tool_selection_model,
)
agent = create_agent(
model=model,
tools=[get_weather, search_web, calculate, send_email],
middleware=[tool_selector, trace_model_requests],
)
agent.invoke({"messages": [HumanMessage("test")]})
# Should have 2 selected + 2 always_include = 4 total
assert len(model_requests) > 0
for request in model_requests:
assert len(request.tools) == 4
tool_names = [tool.name for tool in request.tools]
assert "get_weather" in tool_names
assert "search_web" in tool_names
assert "send_email" in tool_names
assert "calculate" in tool_names
def test_multiple_always_include_tools(self) -> None:
"""Test that multiple always_include tools are all present."""
model_requests = []
@modify_model_request
def trace_model_requests(request: ModelRequest, state: AgentState, runtime) -> ModelRequest:
model_requests.append(request)
return request
# Selector picks 1 tool
tool_selection_model = FakeModel(
messages=cycle(
[
AIMessage(
content="",
tool_calls=[
{
"name": "ToolSelectionResponse",
"id": "1",
"args": {"tools": ["get_weather"]},
}
],
),
]
)
)
model = FakeModel(messages=iter([AIMessage(content="Done")]))
tool_selector = LLMToolSelectorMiddleware(
max_tools=1,
always_include=["send_email", "calculate", "get_stock_price"],
model=tool_selection_model,
)
agent = create_agent(
model=model,
tools=[get_weather, search_web, send_email, calculate, get_stock_price],
middleware=[tool_selector, trace_model_requests],
)
agent.invoke({"messages": [HumanMessage("test")]})
# Should have 1 selected + 3 always_include = 4 total
assert len(model_requests) > 0
for request in model_requests:
assert len(request.tools) == 4
tool_names = [tool.name for tool in request.tools]
assert "get_weather" in tool_names
assert "send_email" in tool_names
assert "calculate" in tool_names
assert "get_stock_price" in tool_names
class TestDuplicateAndInvalidTools:
"""Test handling of duplicate and invalid tool selections."""
def test_duplicate_tool_selection_deduplicated(self) -> None:
"""Test that duplicate tool selections are deduplicated."""
model_requests = []
@modify_model_request
def trace_model_requests(request: ModelRequest, state: AgentState, runtime) -> ModelRequest:
model_requests.append(request)
return request
# Selector returns duplicates
tool_selection_model = FakeModel(
messages=cycle(
[
AIMessage(
content="",
tool_calls=[
{
"name": "ToolSelectionResponse",
"id": "1",
"args": {
"tools": [
"get_weather",
"get_weather",
"search_web",
"search_web",
]
},
}
],
),
]
)
)
model = FakeModel(messages=iter([AIMessage(content="Done")]))
tool_selector = LLMToolSelectorMiddleware(max_tools=5, model=tool_selection_model)
agent = create_agent(
model=model,
tools=[get_weather, search_web, calculate],
middleware=[tool_selector, trace_model_requests],
)
agent.invoke({"messages": [HumanMessage("test")]})
# Duplicates should be removed
assert len(model_requests) > 0
for request in model_requests:
tool_names = [tool.name for tool in request.tools]
assert tool_names == ["get_weather", "search_web"]
assert len(tool_names) == 2
def test_max_tools_with_duplicates(self) -> None:
"""Test that max_tools works correctly with duplicate selections."""
model_requests = []
@modify_model_request
def trace_model_requests(request: ModelRequest, state: AgentState, runtime) -> ModelRequest:
model_requests.append(request)
return request
# Selector returns duplicates but max_tools=2
tool_selection_model = FakeModel(
messages=cycle(
[
AIMessage(
content="",
tool_calls=[
{
"name": "ToolSelectionResponse",
"id": "1",
"args": {
"tools": [
"get_weather",
"get_weather",
"search_web",
"search_web",
"calculate",
]
},
}
],
),
]
)
)
model = FakeModel(messages=iter([AIMessage(content="Done")]))
tool_selector = LLMToolSelectorMiddleware(max_tools=2, model=tool_selection_model)
agent = create_agent(
model=model,
tools=[get_weather, search_web, calculate],
middleware=[tool_selector, trace_model_requests],
)
agent.invoke({"messages": [HumanMessage("test")]})
# Should deduplicate and respect max_tools
assert len(model_requests) > 0
for request in model_requests:
tool_names = [tool.name for tool in request.tools]
assert len(tool_names) == 2
assert "get_weather" in tool_names
assert "search_web" in tool_names