Compare commits

...

1 Commits

Author SHA1 Message Date
Sydney Runkle
00ff74ac55 initial port for llm tool selector 2025-10-01 18:12:11 -07:00
4 changed files with 552 additions and 2 deletions

View File

@@ -1,6 +1,7 @@
"""Middleware plugins for agents."""

from .human_in_the_loop import HumanInTheLoopMiddleware
from .llm_tool_selector import LLMToolSelectorMiddleware
from .planning import PlanningMiddleware
from .prompt_caching import AnthropicPromptCachingMiddleware
from .summarization import SummarizationMiddleware
@@ -20,6 +21,7 @@ __all__ = [
    # should move to langchain-anthropic if we decide to keep it
    "AnthropicPromptCachingMiddleware",
    "HumanInTheLoopMiddleware",
    "LLMToolSelectorMiddleware",
    "ModelRequest",
    "PlanningMiddleware",
    "SummarizationMiddleware",

View File

@@ -0,0 +1,244 @@
"""LLM-based tool selection middleware for agents."""
from __future__ import annotations
import json
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field
from typing_extensions import NotRequired, TypedDict
from langchain.chat_models import init_chat_model
from .types import AgentMiddleware, AgentState, ModelRequest, Runtime
class ToolSelectionSchema(BaseModel):
"""Schema for tool selection structured output."""
selected_tools: list[str] = Field(description="List of selected tool names")
class LLMToolSelectorConfig(TypedDict):
"""Configuration options for the LLM Tool Selector middleware."""
model: NotRequired[str | BaseChatModel]
"""The language model to use for tool selection
default: the provided model from the agent options."""
system_prompt: NotRequired[str]
"""System prompt for the tool selection model."""
max_tools: NotRequired[int]
"""Maximum number of tools to select."""
include_full_history: NotRequired[bool]
"""Whether to include the full conversation history in the tool selection prompt."""
max_retries: NotRequired[int]
"""Maximum number of retries if the model selects incorrect tools."""
DEFAULT_SYSTEM_PROMPT = (
"Your goal is to select the most relevant tool for answering the user's query."
)
DEFAULT_INCLUDE_FULL_HISTORY = False
DEFAULT_MAX_RETRIES = 3
class LLMToolSelectorMiddleware(AgentMiddleware):
    """Middleware for selecting tools using an LLM-based strategy.

    This middleware analyzes the user's query and available tools to select
    the most relevant tools for the task, reducing the cognitive load on the
    main model and improving response quality.

    Args:
        model: The language model to use for tool selection
            (default: the model provided in the agent options).
        system_prompt: System prompt for the tool selection model.
        max_tools: Maximum number of tools to select.
        include_full_history: Whether to include the full conversation
            history in the tool selection prompt.
        max_retries: Maximum number of retries if the model selects incorrect tools.

    Example:
        ```python
        from langchain.agents.middleware.llm_tool_selector import LLMToolSelectorMiddleware
        from langchain.agents import create_agent

        middleware = LLMToolSelectorMiddleware(
            max_tools=3, system_prompt="Select the most relevant tools for the user's query."
        )

        agent = create_agent(
            model="openai:gpt-4o",
            tools=[tool1, tool2, tool3, tool4, tool5],
            middleware=[middleware],
        )
        ```
    """

    def __init__(
        self,
        *,
        model: str | BaseChatModel | None = None,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
        max_tools: int | None = None,
        include_full_history: bool = DEFAULT_INCLUDE_FULL_HISTORY,
        max_retries: int = DEFAULT_MAX_RETRIES,
    ) -> None:
        """Initialize the LLM Tool Selector middleware.

        Args:
            model: The language model to use for tool selection
                (default: the model provided in the agent options).
            system_prompt: System prompt for the tool selection model.
            max_tools: Maximum number of tools to select.
            include_full_history: Whether to include the full conversation
                history in the tool selection prompt.
            max_retries: Maximum number of retries if the model selects incorrect tools.
        """
        super().__init__()
        self.model = model
        self.system_prompt = system_prompt
        self.max_tools = max_tools
        self.include_full_history = include_full_history
        self.max_retries = max_retries
    def modify_model_request(
        self,
        request: ModelRequest,
        state: AgentState,  # noqa: ARG002
        runtime: Runtime,
    ) -> ModelRequest:
        """Modify the model request to filter tools based on LLM selection.

        Args:
            request: The original model request.
            state: The current agent state.
            runtime: The runtime context.

        Returns:
            The modified model request with filtered tools.
        """
        # If no tools available, return request unchanged
        if not request.tools:
            return request

        # Extract tool information
        tool_info = []
        for tool in runtime.tools:
            tool_info.append(
                {
                    "name": tool.name,
                    "description": tool.description,
                    "tool": tool,
                }
            )

        # Build tool representation for the prompt
        tool_representation = "\n".join(
            f"- {info['name']}: {info['description']}" for info in tool_info
        )

        # Build system message
        system_message = f"""You are an agent that can use the following tools:
{tool_representation}
{self.system_prompt}"""

        if self.include_full_history:
            user_messages = [
                msg.content for msg in request.messages if isinstance(msg, HumanMessage)
            ]
            system_message += f"\nThe full conversation history is:\n{chr(10).join(user_messages)}"

        if self.max_tools is not None:
            system_message += f" You can select up to {self.max_tools} tools."

        # Get the latest user message
        latest_message = request.messages[-1] if request.messages else None
        user_content = (
            latest_message.content
            if isinstance(latest_message, HumanMessage) and isinstance(latest_message.content, str)
            else json.dumps(latest_message.content)
            if latest_message
            else ""
        )

        # Create tool selection model
        tool_selection_model = (
            request.model
            if self.model is None
            else init_chat_model(self.model)
            if isinstance(self.model, str)
            else self.model
        )

        valid_tool_names = [info["name"] for info in tool_info]
        structured_model = tool_selection_model.with_structured_output(ToolSelectionSchema)

        attempts = 0
        selected_tool_names: list[str] = valid_tool_names.copy()
        while attempts <= self.max_retries:
            try:
                response = structured_model.invoke(
                    [
                        {"role": "system", "content": system_message},
                        {"role": "user", "content": user_content},
                    ]
                )
                selected_tool_names = response.selected_tools if response.selected_tools else []

                # Validate that selected tools exist
                invalid_tools = [
                    name for name in selected_tool_names if name not in valid_tool_names
                ]

                if len(selected_tool_names) == 0:
                    system_message += (
                        "\n\nNote: You have not selected any tools. "
                        "Please select at least one tool."
                    )
                    attempts += 1
                elif (
                    len(invalid_tools) == 0
                    and self.max_tools is not None
                    and len(selected_tool_names) > self.max_tools
                ):
                    system_message += (
                        "\n\nNote: You have selected more tools than the maximum allowed. "
                        f"You can select up to {self.max_tools} tools."
                    )
                    attempts += 1
                elif len(invalid_tools) == 0:
                    # Success
                    break
                elif attempts < self.max_retries:
                    # Retry with feedback about invalid tools
                    system_message += (
                        "\n\nNote: The following tools are not available: "
                        f"{', '.join(invalid_tools)}. "
                        "Please select only from the available tools."
                    )
                    attempts += 1
                else:
                    # Filter out invalid tools on final attempt
                    selected_tool_names = [
                        name for name in selected_tool_names if name in valid_tool_names
                    ]
                    break
            except Exception:
                # Fall back to using all tools
                if attempts >= self.max_retries:
                    return request
                attempts += 1

        # Filter tools based on selection
        selected_tools = [info["name"] for info in tool_info if info["name"] in selected_tool_names]

        return ModelRequest(
            model=request.model,
            system_prompt=request.system_prompt,
            messages=request.messages,
            tool_choice=request.tool_choice,
            tools=selected_tools,
            response_format=request.response_format,
            model_settings=request.model_settings,
        )
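
For reference, a minimal sketch of the selection prompt that `modify_model_request` assembles before calling the structured-output model; the tool names and descriptions below are illustrative, not from the source:

```python
# Mirrors the prompt-building steps above with made-up example tools.
tools = [("search", "Look up documents"), ("calculator", "Evaluate arithmetic")]
tool_representation = "\n".join(f"- {name}: {desc}" for name, desc in tools)

system_prompt = "Your goal is to select the most relevant tool for answering the user's query."
system_message = f"""You are an agent that can use the following tools:
{tool_representation}
{system_prompt}"""

max_tools = 2  # the sentence below is only appended when max_tools is set
system_message += f" You can select up to {max_tools} tools."
print(system_message)
```

The resulting system message is paired with the latest user message and sent to the selection model wrapped in `with_structured_output(ToolSelectionSchema)`; the returned `selected_tools` list is then used to filter `request.tools`.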

View File

@@ -0,0 +1,304 @@
"""Tests for LLM Tool Selector middleware."""
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import BaseModel
from langchain.agents.middleware.llm_tool_selector import (
LLMToolSelectorMiddleware,
ToolSelectionSchema,
)
from langchain.agents.middleware.types import ModelRequest
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_core.tools import BaseTool
class MockTool(BaseTool):
"""Mock tool for testing."""
name: str = "mock_tool"
description: str = "A mock tool for testing"
def _run(self, *args: Any, **kwargs: Any) -> str:
return "mock result"
async def _arun(self, *args: Any, **kwargs: Any) -> str:
return "mock result"
class MockChatModel(BaseChatModel):
"""Mock chat model for testing."""
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
# This is a placeholder - we'll mock the structured output
pass
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
# This is a placeholder - we'll mock the structured output
pass
@property
def _llm_type(self) -> str:
return "mock"
def test_tool_selection_schema():
    """Test that ToolSelectionSchema works correctly."""
    schema = ToolSelectionSchema(selected_tools=["tool1", "tool2"])
    assert schema.selected_tools == ["tool1", "tool2"]


def test_middleware_initialization():
    """Test that middleware initializes with correct defaults."""
    middleware = LLMToolSelectorMiddleware()
    assert middleware.model is None
    assert (
        middleware.system_prompt
        == "Your goal is to select the most relevant tool for answering the user's query."
    )
    assert middleware.max_tools is None
    assert middleware.include_full_history is False
    assert middleware.max_retries == 3


def test_middleware_initialization_with_custom_values():
    """Test that middleware initializes with custom values."""
    middleware = LLMToolSelectorMiddleware(
        model="openai:gpt-4o",
        system_prompt="Custom prompt",
        max_tools=5,
        include_full_history=True,
        max_retries=2,
    )
    assert middleware.model == "openai:gpt-4o"
    assert middleware.system_prompt == "Custom prompt"
    assert middleware.max_tools == 5
    assert middleware.include_full_history is True
    assert middleware.max_retries == 2


@patch("langchain.agents.middleware.llm_tool_selector.init_chat_model")
def test_modify_model_request_no_tools(mock_init_chat_model):
    """Test that middleware returns request unchanged when no tools are available."""
    middleware = LLMToolSelectorMiddleware()
    request = ModelRequest(
        model=MockChatModel(),
        system_prompt="Test prompt",
        messages=[HumanMessage("Hello")],
        tool_choice=None,
        tools=[],
        response_format=None,
    )
    runtime = MagicMock()
    runtime.tools = []

    result = middleware.modify_model_request(request, {}, runtime)

    assert result == request
    mock_init_chat_model.assert_not_called()
@patch("langchain.agents.middleware.llm_tool_selector.init_chat_model")
def test_modify_model_request_with_tool_selection(mock_init_chat_model):
"""Test that middleware filters tools based on LLM selection."""
# Create mock tools
tool1 = MockTool(name="tool1", description="First tool")
tool2 = MockTool(name="tool2", description="Second tool")
tool3 = MockTool(name="tool3", description="Third tool")
# Create mock structured model
mock_structured_model = MagicMock()
mock_response = ToolSelectionSchema(selected_tools=["tool1", "tool3"])
mock_structured_model.invoke.return_value = mock_response
# Create mock chat model
mock_chat_model = MagicMock()
mock_chat_model.with_structured_output.return_value = mock_structured_model
middleware = LLMToolSelectorMiddleware(model=mock_chat_model)
request = ModelRequest(
model=MockChatModel(),
system_prompt="Test prompt",
messages=[HumanMessage("Hello")],
tool_choice=None,
tools=["tool1", "tool2", "tool3"],
response_format=None,
)
runtime = MagicMock()
runtime.tools = [tool1, tool2, tool3]
result = middleware.modify_model_request(request, {}, runtime)
# Verify that only selected tools are returned
assert result.tools == ["tool1", "tool3"]
assert result.model == request.model
assert result.system_prompt == request.system_prompt
assert result.messages == request.messages
@patch("langchain.agents.middleware.llm_tool_selector.init_chat_model")
def test_modify_model_request_with_string_model(mock_init_chat_model):
"""Test that middleware works with string model specification."""
# Create mock tools
tool1 = MockTool(name="tool1", description="First tool")
# Create mock structured model
mock_structured_model = MagicMock()
mock_response = ToolSelectionSchema(selected_tools=["tool1"])
mock_structured_model.invoke.return_value = mock_response
# Create mock chat model
mock_chat_model = MagicMock()
mock_chat_model.with_structured_output.return_value = mock_structured_model
mock_init_chat_model.return_value = mock_chat_model
middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o")
request = ModelRequest(
model=MockChatModel(),
system_prompt="Test prompt",
messages=[HumanMessage("Hello")],
tool_choice=None,
tools=["tool1"],
response_format=None,
)
runtime = MagicMock()
runtime.tools = [tool1]
result = middleware.modify_model_request(request, {}, runtime)
# Verify that init_chat_model was called with the string model
mock_init_chat_model.assert_called_once_with("openai:gpt-4o")
# Verify that only selected tools are returned
assert result.tools == ["tool1"]
@patch("langchain.agents.middleware.llm_tool_selector.init_chat_model")
def test_modify_model_request_with_retry_logic(mock_init_chat_model):
"""Test that middleware retries on invalid tool selection."""
# Create mock tools
tool1 = MockTool(name="tool1", description="First tool")
tool2 = MockTool(name="tool2", description="Second tool")
# Create mock structured model that returns invalid tools first, then valid ones
mock_structured_model = MagicMock()
mock_responses = [
ToolSelectionSchema(selected_tools=["invalid_tool"]), # First attempt - invalid
ToolSelectionSchema(selected_tools=["tool1"]), # Second attempt - valid
]
mock_structured_model.invoke.side_effect = mock_responses
# Create mock chat model
mock_chat_model = MagicMock()
mock_chat_model.with_structured_output.return_value = mock_structured_model
middleware = LLMToolSelectorMiddleware(model=mock_chat_model, max_retries=3)
request = ModelRequest(
model=MockChatModel(),
system_prompt="Test prompt",
messages=[HumanMessage("Hello")],
tool_choice=None,
tools=["tool1", "tool2"],
response_format=None,
)
runtime = MagicMock()
runtime.tools = [tool1, tool2]
result = middleware.modify_model_request(request, {}, runtime)
# Verify that the model was called twice (initial + retry)
assert mock_structured_model.invoke.call_count == 2
# Verify that only valid tools are returned
assert result.tools == ["tool1"]
@patch("langchain.agents.middleware.llm_tool_selector.init_chat_model")
def test_modify_model_request_with_max_tools_limit(mock_init_chat_model):
"""Test that middleware enforces max_tools limit."""
# Create mock tools
tool1 = MockTool(name="tool1", description="First tool")
tool2 = MockTool(name="tool2", description="Second tool")
tool3 = MockTool(name="tool3", description="Third tool")
# Create mock structured model that returns too many tools first, then correct amount
mock_structured_model = MagicMock()
mock_responses = [
ToolSelectionSchema(selected_tools=["tool1", "tool2", "tool3"]), # Too many
ToolSelectionSchema(selected_tools=["tool1", "tool2"]), # Correct amount
]
mock_structured_model.invoke.side_effect = mock_responses
# Create mock chat model
mock_chat_model = MagicMock()
mock_chat_model.with_structured_output.return_value = mock_structured_model
middleware = LLMToolSelectorMiddleware(model=mock_chat_model, max_tools=2)
request = ModelRequest(
model=MockChatModel(),
system_prompt="Test prompt",
messages=[HumanMessage("Hello")],
tool_choice=None,
tools=["tool1", "tool2", "tool3"],
response_format=None,
)
runtime = MagicMock()
runtime.tools = [tool1, tool2, tool3]
result = middleware.modify_model_request(request, {}, runtime)
# Verify that the model was called twice (initial + retry)
assert mock_structured_model.invoke.call_count == 2
# Verify that only the allowed number of tools are returned
assert result.tools == ["tool1", "tool2"]
@patch("langchain.agents.middleware.llm_tool_selector.init_chat_model")
def test_modify_model_request_fallback_on_exception(mock_init_chat_model):
"""Test that middleware falls back to original request on exception."""
# Create mock tools
tool1 = MockTool(name="tool1", description="First tool")
# Create mock structured model that raises an exception
mock_structured_model = MagicMock()
mock_structured_model.invoke.side_effect = Exception("Model error")
# Create mock chat model
mock_chat_model = MagicMock()
mock_chat_model.with_structured_output.return_value = mock_structured_model
middleware = LLMToolSelectorMiddleware(model=mock_chat_model, max_retries=1)
request = ModelRequest(
model=MockChatModel(),
system_prompt="Test prompt",
messages=[HumanMessage("Hello")],
tool_choice=None,
tools=["tool1"],
response_format=None,
)
runtime = MagicMock()
runtime.tools = [tool1]
result = middleware.modify_model_request(request, {}, runtime)
# Verify that the original request is returned on exception
assert result == request

View File

@@ -2082,7 +2082,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "0.3.76"
version = "0.3.77"
source = { editable = "../core" }
dependencies = [
{ name = "jsonpatch" },
@@ -2323,7 +2323,7 @@ wheels = [
[[package]]
name = "langchain-tests"
version = "0.3.21"
version = "0.3.22"
source = { editable = "../standard-tests" }
dependencies = [
{ name = "httpx" },