From df2ecd944895e230b52ff52cc13f25db32a141e7 Mon Sep 17 00:00:00 2001
From: Eugene Yurtsev
Date: Sun, 5 Oct 2025 15:55:55 -0400
Subject: [PATCH] feat(langchain_v1): add LLM tool selection middleware (#33272)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add LLM-based tool selection middleware.
* We may want some form of caching for when the agent is inside an active
  tool-calling loop, since the tool selection isn't expected to change during
  that time.

API:

```python
class LLMToolSelectorMiddleware(AgentMiddleware):
    """Uses an LLM to select relevant tools before calling the main model.

    When an agent has many tools available, this middleware filters them down
    to only the most relevant ones for the user's query. This reduces token
    usage and helps the main model focus on the right tools.

    Examples:
        Limit to 3 tools:
        ```python
        from langchain.agents.middleware import LLMToolSelectorMiddleware

        middleware = LLMToolSelectorMiddleware(max_tools=3)

        agent = create_agent(
            model="openai:gpt-4o",
            tools=[tool1, tool2, tool3, tool4, tool5],
            middleware=[middleware],
        )
        ```

        Use a smaller model for selection:
        ```python
        middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o-mini", max_tools=2)
        ```
    """

    def __init__(
        self,
        *,
        model: str | BaseChatModel | None = None,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
        max_tools: int | None = None,
        always_include: list[str] | None = None,
    ) -> None:
        """Initialize the tool selector.

        Args:
            model: Model to use for selection. If not provided, uses the agent's
                main model. Can be a model identifier string or a BaseChatModel
                instance.
            system_prompt: Instructions for the selection model.
            max_tools: Maximum number of tools to select. If the model selects
                more, only the first max_tools will be used. No limit if not
                specified.
            always_include: Tool names to always include regardless of selection.
                These do not count against the max_tools limit.
        """
```
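For completeness, a hypothetical `always_include` setup (the tool definitions
and names below are illustrative, not part of this patch):

```python
from langchain.agents import create_agent
from langchain.agents.middleware import LLMToolSelectorMiddleware
from langchain_core.tools import tool


# Hypothetical tools, defined only for this example.
@tool
def search_docs(query: str) -> str:
    """Search internal documentation."""
    return f"Docs for: {query}"


@tool
def file_ticket(summary: str) -> str:
    """File a support ticket."""
    return f"Ticket filed: {summary}"


@tool
def escalate_to_human(reason: str) -> str:
    """Hand the conversation off to a human agent."""
    return f"Escalated: {reason}"


# escalate_to_human is always passed to the main model and does not
# count against the max_tools limit.
middleware = LLMToolSelectorMiddleware(
    max_tools=2,
    always_include=["escalate_to_human"],
)

agent = create_agent(
    model="openai:gpt-4o",
    tools=[search_docs, file_ticket, escalate_to_human],
    middleware=[middleware],
)
```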
""" ``` ```python """Test script for LLM tool selection middleware.""" from langchain.agents import create_agent from langchain.agents.middleware import LLMToolSelectorMiddleware from langchain_core.tools import tool @tool def get_weather(location: str) -> str: """Get current weather for a location.""" return f"Weather in {location}: 72°F, sunny" @tool def search_web(query: str) -> str: """Search the web for information.""" return f"Search results for: {query}" @tool def calculate(expression: str) -> str: """Perform mathematical calculations.""" return f"Result of {expression}: 42" @tool def send_email(to: str, subject: str) -> str: """Send an email to someone.""" return f"Email sent to {to} with subject: {subject}" @tool def get_stock_price(symbol: str) -> str: """Get current stock price for a symbol.""" return f"Stock price for {symbol}: $150.25" @tool def translate_text(text: str, target_language: str) -> str: """Translate text to another language.""" return f"Translated '{text}' to {target_language}" @tool def set_reminder(task: str, time: str) -> str: """Set a reminder for a task.""" return f"Reminder set: {task} at {time}" @tool def get_news(topic: str) -> str: """Get latest news about a topic.""" return f"Latest news about {topic}" @tool def book_flight(destination: str, date: str) -> str: """Book a flight to a destination.""" return f"Flight booked to {destination} on {date}" @tool def get_restaurant_recommendations(city: str, cuisine: str) -> str: """Get restaurant recommendations.""" return f"Top {cuisine} restaurants in {city}" # Create agent with tool selection middleware middleware = LLMToolSelectorMiddleware( model="openai:gpt-4o-mini", max_tools=3, ) agent = create_agent( model="openai:gpt-4o", tools=[ get_weather, search_web, calculate, send_email, get_stock_price, translate_text, set_reminder, get_news, book_flight, get_restaurant_recommendations, ], middleware=[middleware], ) # Test with a query that should select specific tools response = agent.invoke( {"messages": [{"role": "user", "content": "I need to find restaurants"}]} ) print(response) ``` --- .../langchain/agents/middleware/__init__.py | 2 + .../agents/middleware/tool_selection.py | 293 +++++++++ .../middleware/test_llm_tool_selection.py | 598 ++++++++++++++++++ 3 files changed, 893 insertions(+) create mode 100644 libs/langchain_v1/langchain/agents/middleware/tool_selection.py create mode 100644 libs/langchain_v1/tests/unit_tests/agents/middleware/test_llm_tool_selection.py diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index 011e6fe294d..823bfeaf354 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -6,6 +6,7 @@ from .planning import PlanningMiddleware from .prompt_caching import AnthropicPromptCachingMiddleware from .summarization import SummarizationMiddleware from .tool_call_limit import ToolCallLimitMiddleware +from .tool_selection import LLMToolSelectorMiddleware from .types import ( AgentMiddleware, AgentState, @@ -23,6 +24,7 @@ __all__ = [ # should move to langchain-anthropic if we decide to keep it "AnthropicPromptCachingMiddleware", "HumanInTheLoopMiddleware", + "LLMToolSelectorMiddleware", "ModelRequest", "PIIDetectionError", "PIIMiddleware", diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_selection.py b/libs/langchain_v1/langchain/agents/middleware/tool_selection.py new file mode 100644 index 00000000000..050478f7015 
--- /dev/null
+++ b/libs/langchain_v1/langchain/agents/middleware/tool_selection.py
@@ -0,0 +1,293 @@
+"""LLM-based tool selector middleware."""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Annotated, Literal, Union
+
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import HumanMessage
+from pydantic import Field, TypeAdapter
+from typing_extensions import TypedDict
+
+from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest, StateT
+from langchain.chat_models.base import init_chat_model
+
+if TYPE_CHECKING:
+    from langgraph.runtime import Runtime
+    from langgraph.typing import ContextT
+
+    from langchain.tools import BaseTool
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_SYSTEM_PROMPT = (
+    "Your goal is to select the most relevant tools for answering the user's query."
+)
+
+
+@dataclass
+class _SelectionRequest:
+    """Prepared inputs for tool selection."""
+
+    available_tools: list[BaseTool]
+    system_message: str
+    last_user_message: HumanMessage
+    model: BaseChatModel
+    valid_tool_names: list[str]
+
+
+def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter:
+    """Create a structured output schema for tool selection.
+
+    Args:
+        tools: Available tools to include in the schema.
+
+    Returns:
+        TypeAdapter for a schema where each tool name is a Literal with its description.
+    """
+    if not tools:
+        msg = "Invalid usage: tools must be non-empty"
+        raise AssertionError(msg)
+
+    # Create a Union of Annotated Literal types for each tool name with description
+    # Example: Union[Annotated[Literal["tool1"], Field(description="...")], ...]  # noqa: ERA001
+    literals = [
+        Annotated[Literal[tool.name], Field(description=tool.description)] for tool in tools
+    ]
+    selected_tool_type = Union[tuple(literals)]  # type: ignore[valid-type]  # noqa: UP007
+
+    description = "Tools to use. Place the most relevant tools first."
+
+    class ToolSelectionResponse(TypedDict):
+        """Use to select relevant tools."""
+
+        tools: Annotated[list[selected_tool_type], Field(description=description)]  # type: ignore[valid-type]
+
+    return TypeAdapter(ToolSelectionResponse)
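+
+# Illustrative note (added for readers; the tool names are assumed): for tools
+# named "get_weather" and "search_web", _create_tool_selection_response builds
+# a TypeAdapter whose JSON schema describes payloads shaped like
+#     {"tools": ["get_weather", "search_web"]}
+# with each list entry constrained, via the Literal union above, to one of the
+# known tool names and carrying that tool's description.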
+
+
+def _render_tool_list(tools: list[BaseTool]) -> str:
+    """Format tools as a markdown list.
+
+    Args:
+        tools: Tools to format.
+
+    Returns:
+        Markdown string with each tool on a new line.
+    """
+    return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)
+
+
+class LLMToolSelectorMiddleware(AgentMiddleware):
+    """Uses an LLM to select relevant tools before calling the main model.
+
+    When an agent has many tools available, this middleware filters them down
+    to only the most relevant ones for the user's query. This reduces token
+    usage and helps the main model focus on the right tools.
+
+    Examples:
+        Limit to 3 tools:
+        ```python
+        from langchain.agents.middleware import LLMToolSelectorMiddleware
+
+        middleware = LLMToolSelectorMiddleware(max_tools=3)
+
+        agent = create_agent(
+            model="openai:gpt-4o",
+            tools=[tool1, tool2, tool3, tool4, tool5],
+            middleware=[middleware],
+        )
+        ```
+
+        Use a smaller model for selection:
+        ```python
+        middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o-mini", max_tools=2)
+        ```
+    """
+
+    def __init__(
+        self,
+        *,
+        model: str | BaseChatModel | None = None,
+        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
+        max_tools: int | None = None,
+        always_include: list[str] | None = None,
+    ) -> None:
+        """Initialize the tool selector.
+
+        Args:
+            model: Model to use for selection. If not provided, uses the agent's main model.
+                Can be a model identifier string or a BaseChatModel instance.
+            system_prompt: Instructions for the selection model.
+            max_tools: Maximum number of tools to select. If the model selects more,
+                only the first max_tools will be used. No limit if not specified.
+            always_include: Tool names to always include regardless of selection.
+                These do not count against the max_tools limit.
+        """
+        super().__init__()
+        self.system_prompt = system_prompt
+        self.max_tools = max_tools
+        self.always_include = always_include or []
+
+        if isinstance(model, (BaseChatModel, type(None))):
+            self.model: BaseChatModel | None = model
+        else:
+            self.model = init_chat_model(model)
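+
+    # Illustrative round trip (assumed tool names, added for readers of this
+    # patch): given tools get_weather / search_web / calculate and the query
+    # "What's the weather in Paris?", the selection model is asked for a
+    # ToolSelectionResponse and might return {"tools": ["get_weather"]};
+    # the main model then sees only get_weather, plus any always_include tools.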
+
+    def _prepare_selection_request(self, request: ModelRequest) -> _SelectionRequest | None:
+        """Prepare inputs for tool selection.
+
+        Returns:
+            _SelectionRequest with prepared inputs, or None if no selection is needed.
+        """
+        # If no tools are available, return None
+        if not request.tools:
+            return None
+
+        # Validate that always_include tools exist
+        if self.always_include:
+            available_tool_names = {tool.name for tool in request.tools}
+            missing_tools = [
+                name for name in self.always_include if name not in available_tool_names
+            ]
+            if missing_tools:
+                msg = (
+                    f"Tools in always_include not found in request: {missing_tools}. "
+                    f"Available tools: {sorted(available_tool_names)}"
+                )
+                raise ValueError(msg)
+
+        # Separate tools that are always included from those available for selection
+        available_tools = [tool for tool in request.tools if tool.name not in self.always_include]
+
+        # If no tools are available for selection, return None
+        if not available_tools:
+            return None
+
+        system_message = self.system_prompt
+        # If there's a max_tools limit, append instructions to the system prompt
+        if self.max_tools is not None:
+            system_message += (
+                f"\nIMPORTANT: List the tool names in order of relevance, "
+                f"with the most relevant first. "
+                f"If you exceed the maximum number of tools, "
+                f"only the first {self.max_tools} will be used."
+            )
+
+        # Get the last user message from the conversation history
+        last_user_message: HumanMessage
+        for message in reversed(request.messages):
+            if isinstance(message, HumanMessage):
+                last_user_message = message
+                break
+        else:
+            msg = "No user message found in request messages"
+            raise AssertionError(msg)
+
+        model = self.model or request.model
+        valid_tool_names = [tool.name for tool in available_tools]
+
+        return _SelectionRequest(
+            available_tools=available_tools,
+            system_message=system_message,
+            last_user_message=last_user_message,
+            model=model,
+            valid_tool_names=valid_tool_names,
+        )
+
+    def _process_selection_response(
+        self,
+        response: dict,
+        available_tools: list[BaseTool],
+        valid_tool_names: list[str],
+        request: ModelRequest,
+    ) -> ModelRequest:
+        """Process the selection response and return the filtered ModelRequest."""
+        selected_tool_names: list[str] = []
+        invalid_tool_selections = []
+
+        for tool_name in response["tools"]:
+            if tool_name not in valid_tool_names:
+                invalid_tool_selections.append(tool_name)
+                continue
+
+            # Only add if not already selected and within the max_tools limit
+            if tool_name not in selected_tool_names and (
+                self.max_tools is None or len(selected_tool_names) < self.max_tools
+            ):
+                selected_tool_names.append(tool_name)
+
+        if invalid_tool_selections:
+            msg = f"Model selected invalid tools: {invalid_tool_selections}"
+            raise ValueError(msg)
+
+        # Filter tools based on selection and append always-included tools
+        selected_tools = [tool for tool in available_tools if tool.name in selected_tool_names]
+        always_included_tools = [tool for tool in request.tools if tool.name in self.always_include]
+        selected_tools.extend(always_included_tools)
+        request.tools = selected_tools
+        return request
+
+    def modify_model_request(
+        self,
+        request: ModelRequest,
+        state: StateT,  # noqa: ARG002
+        runtime: Runtime[ContextT],  # noqa: ARG002
+    ) -> ModelRequest:
+        """Modify the model request to filter tools based on LLM selection."""
+        selection_request = self._prepare_selection_request(request)
+        if selection_request is None:
+            return request
+
+        # Create a dynamic response model with a Literal enum of available tool names
+        type_adapter = _create_tool_selection_response(selection_request.available_tools)
+        schema = type_adapter.json_schema()
+        structured_model = selection_request.model.with_structured_output(schema)
+
+        response = structured_model.invoke(
+            [
+                {"role": "system", "content": selection_request.system_message},
+                selection_request.last_user_message,
+            ]
+        )
+
+        # Response should be a dict since we're passing a schema (not a Pydantic model class)
+        if not isinstance(response, dict):
+            msg = f"Expected dict response, got {type(response)}"
+            raise AssertionError(msg)
+        return self._process_selection_response(
+            response, selection_request.available_tools, selection_request.valid_tool_names, request
+        )
+
+    async def amodify_model_request(
+        self,
+        request: ModelRequest,
+        state: AgentState,  # noqa: ARG002
+        runtime: Runtime,  # noqa: ARG002
+    ) -> ModelRequest:
+        """Modify the model request to filter tools based on LLM selection."""
+        selection_request = self._prepare_selection_request(request)
+        if selection_request is None:
+            return request
+
+        # Create a dynamic response model with a Literal enum of available tool names
+        type_adapter = _create_tool_selection_response(selection_request.available_tools)
+        schema = type_adapter.json_schema()
+        structured_model = selection_request.model.with_structured_output(schema)
+
+        response = await structured_model.ainvoke(
+            [
+                {"role": "system", "content": selection_request.system_message},
+                selection_request.last_user_message,
+            ]
+        )
+
+        # Response should be a dict since we're passing a schema (not a Pydantic model class)
+        if not isinstance(response, dict):
+            msg = f"Expected dict response, got {type(response)}"
+            raise AssertionError(msg)
+        return self._process_selection_response(
+            response, selection_request.available_tools, selection_request.valid_tool_names, request
+        )
+ {"role": "system", "content": selection_request.system_message}, + selection_request.last_user_message, + ] + ) + + # Response should be a dict since we're passing a schema (not a Pydantic model class) + if not isinstance(response, dict): + msg = f"Expected dict response, got {type(response)}" + raise AssertionError(msg) + return self._process_selection_response( + response, selection_request.available_tools, selection_request.valid_tool_names, request + ) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_llm_tool_selection.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_llm_tool_selection.py new file mode 100644 index 00000000000..6098cf433fc --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_llm_tool_selection.py @@ -0,0 +1,598 @@ +"""Unit tests for LLM tool selection middleware.""" + +import typing +from typing import Union, Any, Literal + +from itertools import cycle +from pydantic import BaseModel + +from langchain.agents import create_agent +from langchain.agents.middleware import AgentState, ModelRequest, modify_model_request +from langchain.agents.middleware import LLMToolSelectorMiddleware +from langchain.messages import AIMessage +from langchain_core.language_models import LanguageModelInput +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel +from langchain_core.messages import BaseMessage +from langchain_core.messages import HumanMessage +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool +from langchain_core.tools import tool + + +@tool +def get_weather(location: str) -> str: + """Get current weather for a location.""" + return f"Weather in {location}: 72°F, sunny" + + +@tool +def search_web(query: str) -> str: + """Search the web for information.""" + return f"Search results for: {query}" + + +@tool +def calculate(expression: str) -> str: + """Perform mathematical calculations.""" + return f"Result of {expression}: 42" + + +@tool +def send_email(to: str, subject: str) -> str: + """Send an email to someone.""" + return f"Email sent to {to}" + + +@tool +def get_stock_price(symbol: str) -> str: + """Get current stock price for a symbol.""" + return f"Stock price for {symbol}: $150.25" + + +class FakeModel(GenericFakeChatModel): + tool_style: Literal["openai", "anthropic"] = "openai" + + def bind_tools( + self, + tools: typing.Sequence[Union[dict[str, Any], type[BaseModel], typing.Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + if len(tools) == 0: + msg = "Must provide at least one tool" + raise ValueError(msg) + + tool_dicts = [] + for tool in tools: + if isinstance(tool, dict): + tool_dicts.append(tool) + continue + if not isinstance(tool, BaseTool): + msg = "Only BaseTool and dict is supported by FakeToolCallingModel.bind_tools" + raise TypeError(msg) + + # NOTE: this is a simplified tool spec for testing purposes only + if self.tool_style == "openai": + tool_dicts.append( + { + "type": "function", + "function": { + "name": tool.name, + }, + } + ) + elif self.tool_style == "anthropic": + tool_dicts.append( + { + "name": tool.name, + } + ) + + return self.bind(tools=tool_dicts) + + +class TestLLMToolSelectorBasic: + """Test basic tool selection functionality.""" + + def test_sync_basic_selection(self) -> None: + """Test synchronous tool selection.""" + # First call: selector picks tools + # Second call: agent uses selected tools + tool_calls = [ + [ + { + "name": "ToolSelectionResponse", + "id": "1", + 
"args": {"tools": ["get_weather", "calculate"]}, + } + ], + [{"name": "get_weather", "id": "2", "args": {"location": "Paris"}}], + ] + + model_requests = [] + + @modify_model_request + def trace_model_requests(request: ModelRequest, state: AgentState, runtime) -> ModelRequest: + """Middleware to select relevant tools based on state/context.""" + # Select a small, relevant subset of tools based on state/context + model_requests.append(request) + return request + + tool_selection_model = FakeModel( + messages=cycle( + [ + AIMessage( + content="", + tool_calls=[ + { + "name": "ToolSelectionResponse", + "id": "1", + "args": {"tools": ["get_weather", "calculate"]}, + } + ], + ), + ] + ) + ) + + model = FakeModel( + messages=iter( + [ + AIMessage( + content="", + tool_calls=[ + {"name": "get_weather", "id": "2", "args": {"location": "Paris"}} + ], + ), + AIMessage(content="The weather in Paris is 72°F and sunny."), + ] + ) + ) + + tool_selector = LLMToolSelectorMiddleware(max_tools=2, model=tool_selection_model) + + agent = create_agent( + model=model, + tools=[get_weather, search_web, calculate, send_email, get_stock_price], + middleware=[tool_selector, trace_model_requests], + ) + + response = agent.invoke({"messages": [HumanMessage("What's the weather in Paris?")]}) + + assert isinstance(response["messages"][-1], AIMessage) + + for request in model_requests: + selected_tool_names = [tool.name for tool in request.tools] if request.tools else [] + assert selected_tool_names == ["get_weather", "calculate"] + + async def test_async_basic_selection(self) -> None: + """Test asynchronous tool selection.""" + tool_selection_model = FakeModel( + messages=cycle( + [ + AIMessage( + content="", + tool_calls=[ + { + "name": "ToolSelectionResponse", + "id": "1", + "args": {"tools": ["search_web"]}, + } + ], + ), + ] + ) + ) + + model = FakeModel( + messages=iter( + [ + AIMessage( + content="", + tool_calls=[{"name": "search_web", "id": "2", "args": {"query": "Python"}}], + ), + AIMessage(content="Search results found."), + ] + ) + ) + + tool_selector = LLMToolSelectorMiddleware(max_tools=1, model=tool_selection_model) + + agent = create_agent( + model=model, + tools=[get_weather, search_web, calculate], + middleware=[tool_selector], + ) + + response = await agent.ainvoke({"messages": [HumanMessage("Search for Python tutorials")]}) + + assert isinstance(response["messages"][-1], AIMessage) + + +class TestMaxToolsLimiting: + """Test max_tools limiting behavior.""" + + def test_max_tools_limits_selection(self) -> None: + """Test that max_tools limits selection when model selects too many tools.""" + model_requests = [] + + @modify_model_request + def trace_model_requests(request: ModelRequest, state: AgentState, runtime) -> ModelRequest: + model_requests.append(request) + return request + + # Selector model tries to select 4 tools + tool_selection_model = FakeModel( + messages=cycle( + [ + AIMessage( + content="", + tool_calls=[ + { + "name": "ToolSelectionResponse", + "id": "1", + "args": { + "tools": [ + "get_weather", + "search_web", + "calculate", + "send_email", + ] + }, + } + ], + ), + ] + ) + ) + + model = FakeModel(messages=iter([AIMessage(content="Done")])) + + # But max_tools=2, so only first 2 should be used + tool_selector = LLMToolSelectorMiddleware(max_tools=2, model=tool_selection_model) + + agent = create_agent( + model=model, + tools=[get_weather, search_web, calculate, send_email], + middleware=[tool_selector, trace_model_requests], + ) + + agent.invoke({"messages": [HumanMessage("test")]}) 
+
+        # Verify only 2 tools were passed to the main model
+        assert len(model_requests) > 0
+        for request in model_requests:
+            assert len(request.tools) == 2
+            tool_names = [tool.name for tool in request.tools]
+            # Should be the first 2 from the selection
+            assert tool_names == ["get_weather", "search_web"]
+
+    def test_no_max_tools_uses_all_selected(self) -> None:
+        """Test that when max_tools is None, all selected tools are used."""
+        model_requests = []
+
+        @modify_model_request
+        def trace_model_requests(request: ModelRequest, state: AgentState, runtime) -> ModelRequest:
+            model_requests.append(request)
+            return request
+
+        tool_selection_model = FakeModel(
+            messages=cycle(
+                [
+                    AIMessage(
+                        content="",
+                        tool_calls=[
+                            {
+                                "name": "ToolSelectionResponse",
+                                "id": "1",
+                                "args": {
+                                    "tools": [
+                                        "get_weather",
+                                        "search_web",
+                                        "calculate",
+                                        "get_stock_price",
+                                    ]
+                                },
+                            }
+                        ],
+                    ),
+                ]
+            )
+        )
+
+        model = FakeModel(messages=iter([AIMessage(content="Done")]))
+
+        # No max_tools specified
+        tool_selector = LLMToolSelectorMiddleware(model=tool_selection_model)
+
+        agent = create_agent(
+            model=model,
+            tools=[get_weather, search_web, calculate, send_email, get_stock_price],
+            middleware=[tool_selector, trace_model_requests],
+        )
+
+        agent.invoke({"messages": [HumanMessage("test")]})
+
+        # All 4 selected tools should be present
+        assert len(model_requests) > 0
+        for request in model_requests:
+            assert len(request.tools) == 4
+            tool_names = [tool.name for tool in request.tools]
+            assert set(tool_names) == {
+                "get_weather",
+                "search_web",
+                "calculate",
+                "get_stock_price",
+            }
+
+
+class TestAlwaysInclude:
+    """Test always_include functionality."""
+
+    def test_always_include_tools_present(self) -> None:
+        """Test that always_include tools are always present in the request."""
+        model_requests = []
+
+        @modify_model_request
+        def trace_model_requests(request: ModelRequest, state: AgentState, runtime) -> ModelRequest:
+            model_requests.append(request)
+            return request
+
+        # Selector picks only search_web
+        tool_selection_model = FakeModel(
+            messages=cycle(
+                [
+                    AIMessage(
+                        content="",
+                        tool_calls=[
+                            {
+                                "name": "ToolSelectionResponse",
+                                "id": "1",
+                                "args": {"tools": ["search_web"]},
+                            }
+                        ],
+                    ),
+                ]
+            )
+        )
+
+        model = FakeModel(messages=iter([AIMessage(content="Done")]))
+
+        # But send_email is always included
+        tool_selector = LLMToolSelectorMiddleware(
+            max_tools=1, always_include=["send_email"], model=tool_selection_model
+        )
+
+        agent = create_agent(
+            model=model,
+            tools=[get_weather, search_web, send_email],
+            middleware=[tool_selector, trace_model_requests],
+        )
+
+        agent.invoke({"messages": [HumanMessage("test")]})
+
+        # Both selected and always_include tools should be present
+        assert len(model_requests) > 0
+        for request in model_requests:
+            tool_names = [tool.name for tool in request.tools]
+            assert "search_web" in tool_names
+            assert "send_email" in tool_names
+            assert len(tool_names) == 2
+
+    def test_always_include_not_counted_against_max(self) -> None:
+        """Test that always_include tools don't count against the max_tools limit."""
+        model_requests = []
+
+        @modify_model_request
+        def trace_model_requests(request: ModelRequest, state: AgentState, runtime) -> ModelRequest:
+            model_requests.append(request)
+            return request
+
+        # Selector picks 2 tools
+        tool_selection_model = FakeModel(
+            messages=cycle(
+                [
+                    AIMessage(
+                        content="",
+                        tool_calls=[
+                            {
+                                "name": "ToolSelectionResponse",
+                                "id": "1",
+                                "args": {"tools": ["get_weather", "search_web"]},
+                            }
+                        ],
+                    ),
+                ]
+            )
+        )
+
+        model = FakeModel(messages=iter([AIMessage(content="Done")]))
+
+        # max_tools=2, but we also have 2 always_include tools
+        tool_selector = LLMToolSelectorMiddleware(
+            max_tools=2,
+            always_include=["send_email", "calculate"],
+            model=tool_selection_model,
+        )
+
+        agent = create_agent(
+            model=model,
+            tools=[get_weather, search_web, calculate, send_email],
+            middleware=[tool_selector, trace_model_requests],
+        )
+
+        agent.invoke({"messages": [HumanMessage("test")]})
+
+        # Should have 2 selected + 2 always_include = 4 total
+        assert len(model_requests) > 0
+        for request in model_requests:
+            assert len(request.tools) == 4
+            tool_names = [tool.name for tool in request.tools]
+            assert "get_weather" in tool_names
+            assert "search_web" in tool_names
+            assert "send_email" in tool_names
+            assert "calculate" in tool_names
+
+    def test_multiple_always_include_tools(self) -> None:
+        """Test that multiple always_include tools are all present."""
+        model_requests = []
+
+        @modify_model_request
+        def trace_model_requests(request: ModelRequest, state: AgentState, runtime) -> ModelRequest:
+            model_requests.append(request)
+            return request
+
+        # Selector picks 1 tool
+        tool_selection_model = FakeModel(
+            messages=cycle(
+                [
+                    AIMessage(
+                        content="",
+                        tool_calls=[
+                            {
+                                "name": "ToolSelectionResponse",
+                                "id": "1",
+                                "args": {"tools": ["get_weather"]},
+                            }
+                        ],
+                    ),
+                ]
+            )
+        )
+
+        model = FakeModel(messages=iter([AIMessage(content="Done")]))
+
+        tool_selector = LLMToolSelectorMiddleware(
+            max_tools=1,
+            always_include=["send_email", "calculate", "get_stock_price"],
+            model=tool_selection_model,
+        )
+
+        agent = create_agent(
+            model=model,
+            tools=[get_weather, search_web, send_email, calculate, get_stock_price],
+            middleware=[tool_selector, trace_model_requests],
+        )
+
+        agent.invoke({"messages": [HumanMessage("test")]})
+
+        # Should have 1 selected + 3 always_include = 4 total
+        assert len(model_requests) > 0
+        for request in model_requests:
+            assert len(request.tools) == 4
+            tool_names = [tool.name for tool in request.tools]
+            assert "get_weather" in tool_names
+            assert "send_email" in tool_names
+            assert "calculate" in tool_names
+            assert "get_stock_price" in tool_names
+
+
+class TestDuplicateAndInvalidTools:
+    """Test handling of duplicate and invalid tool selections."""
+
+    def test_duplicate_tool_selection_deduplicated(self) -> None:
+        """Test that duplicate tool selections are deduplicated."""
+        model_requests = []
+
+        @modify_model_request
+        def trace_model_requests(request: ModelRequest, state: AgentState, runtime) -> ModelRequest:
+            model_requests.append(request)
+            return request
+
+        # Selector returns duplicates
+        tool_selection_model = FakeModel(
+            messages=cycle(
+                [
+                    AIMessage(
+                        content="",
+                        tool_calls=[
+                            {
+                                "name": "ToolSelectionResponse",
+                                "id": "1",
+                                "args": {
+                                    "tools": [
+                                        "get_weather",
+                                        "get_weather",
+                                        "search_web",
+                                        "search_web",
+                                    ]
+                                },
+                            }
+                        ],
+                    ),
+                ]
+            )
+        )
+
+        model = FakeModel(messages=iter([AIMessage(content="Done")]))
+
+        tool_selector = LLMToolSelectorMiddleware(max_tools=5, model=tool_selection_model)
+
+        agent = create_agent(
+            model=model,
+            tools=[get_weather, search_web, calculate],
+            middleware=[tool_selector, trace_model_requests],
+        )
+
+        agent.invoke({"messages": [HumanMessage("test")]})
+
+        # Duplicates should be removed
+        assert len(model_requests) > 0
+        for request in model_requests:
+            tool_names = [tool.name for tool in request.tools]
+            assert tool_names == ["get_weather", "search_web"]
+            assert len(tool_names) == 2
+
+    def test_max_tools_with_duplicates(self) -> None:
+        """Test that max_tools works correctly with duplicate selections."""
+        model_requests = []
+
+        @modify_model_request
+        def trace_model_requests(request: ModelRequest, state: AgentState, runtime) -> ModelRequest:
+            model_requests.append(request)
+            return request
+
+        # Selector returns duplicates but max_tools=2
+        tool_selection_model = FakeModel(
+            messages=cycle(
+                [
+                    AIMessage(
+                        content="",
+                        tool_calls=[
+                            {
+                                "name": "ToolSelectionResponse",
+                                "id": "1",
+                                "args": {
+                                    "tools": [
+                                        "get_weather",
+                                        "get_weather",
+                                        "search_web",
+                                        "search_web",
+                                        "calculate",
+                                    ]
+                                },
+                            }
+                        ],
+                    ),
+                ]
+            )
+        )
+
+        model = FakeModel(messages=iter([AIMessage(content="Done")]))
+
+        tool_selector = LLMToolSelectorMiddleware(max_tools=2, model=tool_selection_model)
+
+        agent = create_agent(
+            model=model,
+            tools=[get_weather, search_web, calculate],
+            middleware=[tool_selector, trace_model_requests],
+        )
+
+        agent.invoke({"messages": [HumanMessage("test")]})
+
+        # Should deduplicate and respect max_tools
+        assert len(model_requests) > 0
+        for request in model_requests:
+            tool_names = [tool.name for tool in request.tools]
+            assert len(tool_names) == 2
+            assert "get_weather" in tool_names
+            assert "search_web" in tool_names