Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-03 15:55:44 +00:00)

Compare commits: 1 commit, langchain-... sr/port-bi, SHA1 00ff74ac55
@@ -1,6 +1,7 @@
"""Middleware plugins for agents."""

from .human_in_the_loop import HumanInTheLoopMiddleware
from .llm_tool_selector import LLMToolSelectorMiddleware
from .planning import PlanningMiddleware
from .prompt_caching import AnthropicPromptCachingMiddleware
from .summarization import SummarizationMiddleware
@@ -20,6 +21,7 @@ __all__ = [
    # should move to langchain-anthropic if we decide to keep it
    "AnthropicPromptCachingMiddleware",
    "HumanInTheLoopMiddleware",
    "LLMToolSelectorMiddleware",
    "ModelRequest",
    "PlanningMiddleware",
    "SummarizationMiddleware",
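For context, a minimal import sketch showing how the newly exported middleware could be pulled in once this `__init__.py` hunk lands. The `langchain.agents.middleware` package path is an assumption inferred from the test imports later in this diff, not something this hunk states:

```python
# Assumed package path (inferred from the test module imports below).
from langchain.agents.middleware import LLMToolSelectorMiddleware

# Keyword-only construction; the options mirror LLMToolSelectorConfig defined in the new module.
selector = LLMToolSelectorMiddleware(max_tools=2)
```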
@@ -0,0 +1,244 @@
"""LLM-based tool selection middleware for agents."""

from __future__ import annotations

import json

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field
from typing_extensions import NotRequired, TypedDict

from langchain.chat_models import init_chat_model

from .types import AgentMiddleware, AgentState, ModelRequest, Runtime


class ToolSelectionSchema(BaseModel):
    """Schema for tool selection structured output."""

    selected_tools: list[str] = Field(description="List of selected tool names")


class LLMToolSelectorConfig(TypedDict):
    """Configuration options for the LLM Tool Selector middleware."""

    model: NotRequired[str | BaseChatModel]
    """The language model to use for tool selection

    default: the provided model from the agent options."""

    system_prompt: NotRequired[str]
    """System prompt for the tool selection model."""

    max_tools: NotRequired[int]
    """Maximum number of tools to select."""

    include_full_history: NotRequired[bool]
    """Whether to include the full conversation history in the tool selection prompt."""

    max_retries: NotRequired[int]
    """Maximum number of retries if the model selects incorrect tools."""


DEFAULT_SYSTEM_PROMPT = (
    "Your goal is to select the most relevant tool for answering the user's query."
)
DEFAULT_INCLUDE_FULL_HISTORY = False
DEFAULT_MAX_RETRIES = 3


class LLMToolSelectorMiddleware(AgentMiddleware):
    """Middleware for selecting tools using an LLM-based strategy.

    This middleware analyzes the user's query and available tools to select
    the most relevant tools for the task, reducing the cognitive load on the
    main model and improving response quality.

    Args:
        model: The language model to use for tool selection
            default: the provided model from the agent options.
        system_prompt: System prompt for the tool selection model.
        max_tools: Maximum number of tools to select.
        include_full_history: Whether to include the full conversation
            history in the tool selection prompt.
        max_retries: Maximum number of retries if the model selects incorrect tools.

    Example:
        ```python
        from langchain.agents.middleware.llm_tool_selector import LLMToolSelectorMiddleware
        from langchain.agents import create_agent

        middleware = LLMToolSelectorMiddleware(
            max_tools=3, system_prompt="Select the most relevant tools for the user's query."
        )

        agent = create_agent(
            model="openai:gpt-4o",
            tools=[tool1, tool2, tool3, tool4, tool5],
            middleware=[middleware],
        )
        ```
    """

    def __init__(
        self,
        *,
        model: str | BaseChatModel | None = None,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
        max_tools: int | None = None,
        include_full_history: bool = DEFAULT_INCLUDE_FULL_HISTORY,
        max_retries: int = DEFAULT_MAX_RETRIES,
    ) -> None:
        """Initialize the LLM Tool Selector middleware.

        Args:
            model: The language model to use for tool selection (default: the provided model from the agent options).
            system_prompt: System prompt for the tool selection model.
            max_tools: Maximum number of tools to select.
            include_full_history: Whether to include the full conversation history in the tool selection prompt.
            max_retries: Maximum number of retries if the model selects incorrect tools.
        """
        super().__init__()
        self.model = model
        self.system_prompt = system_prompt
        self.max_tools = max_tools
        self.include_full_history = include_full_history
        self.max_retries = max_retries

    def modify_model_request(
        self,
        request: ModelRequest,
        state: AgentState,  # noqa: ARG002
        runtime: Runtime,
    ) -> ModelRequest:
        """Modify the model request to filter tools based on LLM selection.

        Args:
            request: The original model request.
            state: The current agent state.
            runtime: The runtime context.

        Returns:
            The modified model request with filtered tools.
        """
        # If no tools available, return request unchanged
        if not request.tools or len(request.tools) == 0:
            return request

        # Extract tool information
        tool_info = []
        for tool in runtime.tools:
            tool_info.append(
                {
                    "name": tool.name,
                    "description": tool.description,
                    "tool": tool,
                }
            )

        # Build tool representation for the prompt
        tool_representation = "\n".join(
            f"- {info['name']}: {info['description']}" for info in tool_info
        )

        # Build system message
        system_message = f"""You are an agent that can use the following tools:
{tool_representation}
{self.system_prompt}"""

        if self.include_full_history:
            user_messages = [
                msg.content for msg in request.messages if isinstance(msg, HumanMessage)
            ]
            system_message += f"\nThe full conversation history is:\n{chr(10).join(user_messages)}"

        if self.max_tools is not None:
            system_message += f" You can select up to {self.max_tools} tools."

        # Get the latest user message
        latest_message = request.messages[-1] if request.messages else None
        user_content = (
            latest_message.content
            if isinstance(latest_message, HumanMessage) and isinstance(latest_message.content, str)
            else json.dumps(latest_message.content)
            if latest_message
            else ""
        )

        # Create tool selection model
        tool_selection_model = (
            request.model
            if self.model is None
            else init_chat_model(self.model)
            if isinstance(self.model, str)
            else self.model
        )

        valid_tool_names = [info["name"] for info in tool_info]
        structured_model = tool_selection_model.with_structured_output(ToolSelectionSchema)

        attempts = 0
        selected_tool_names: list[str] = valid_tool_names.copy()

        while attempts <= self.max_retries:
            try:
                response = structured_model.invoke(
                    [
                        {"role": "system", "content": system_message},
                        {"role": "user", "content": user_content},
                    ]
                )

                selected_tool_names = response.selected_tools if response.selected_tools else []

                # Validate that selected tools exist
                invalid_tools = [
                    name for name in selected_tool_names if name not in valid_tool_names
                ]

                if len(selected_tool_names) == 0:
                    system_message += "\n\nNote: You have not selected any tools. Please select at least one tool."
                    attempts += 1
                elif (
                    len(invalid_tools) == 0
                    and self.max_tools is not None
                    and len(selected_tool_names) > self.max_tools
                ):
                    system_message += f"\n\nNote: You have selected more tools than the maximum allowed. You can select up to {self.max_tools} tools."
                    attempts += 1
                elif len(invalid_tools) == 0:
                    # Success
                    break
                elif attempts < self.max_retries:
                    # Retry with feedback about invalid tools
                    system_message += (
                        f"\n\nNote: The following tools are not available: "
                        f"{', '.join(invalid_tools)}. "
                        "Please select only from the available tools."
                    )
                    attempts += 1
                else:
                    # Filter out invalid tools on final attempt
                    selected_tool_names = [
                        name for name in selected_tool_names if name in valid_tool_names
                    ]
                    break
            except Exception:
                # Fall back to using all tools
                if attempts >= self.max_retries:
                    return request
                attempts += 1

        # Filter tools based on selection
        selected_tools = [info["name"] for info in tool_info if info["name"] in selected_tool_names]

        return ModelRequest(
            model=request.model,
            system_prompt=request.system_prompt,
            messages=request.messages,
            tool_choice=request.tool_choice,
            tools=selected_tools,
            response_format=request.response_format,
            model_settings=request.model_settings,
        )
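At its core, `modify_model_request` builds a tool catalogue from `runtime.tools`, asks a selector model for tool names via structured output, and keeps only names that actually exist before handing back a narrowed `ModelRequest`. Below is a standalone sketch of just that selection step, outside the middleware; the tool names and the `"openai:gpt-4o-mini"` model id are illustrative assumptions, while `init_chat_model`, `with_structured_output`, and `ToolSelectionSchema` are the same calls and schema used above:

```python
from langchain.agents.middleware.llm_tool_selector import ToolSelectionSchema
from langchain.chat_models import init_chat_model

# Hypothetical tool catalogue; inside the middleware this comes from runtime.tools.
valid_tool_names = ["search", "calculator", "weather"]
tool_representation = "\n".join(f"- {name}" for name in valid_tool_names)

selector_model = init_chat_model("openai:gpt-4o-mini")  # assumed model id
structured = selector_model.with_structured_output(ToolSelectionSchema)

response = structured.invoke(
    [
        {
            "role": "system",
            "content": f"You can use the following tools:\n{tool_representation}\n"
            "Select the most relevant tools. You can select up to 2 tools.",
        },
        {"role": "user", "content": "What is 12 * 7?"},
    ]
)

# Keep only names that really exist, mirroring the validation the retry loop performs above.
selected = [name for name in response.selected_tools if name in valid_tool_names]
print(selected)  # e.g. ["calculator"]
```

The middleware then rebuilds the `ModelRequest` with `tools=selected`, leaving the messages, system prompt, and response format untouched.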
@@ -0,0 +1,304 @@
"""Tests for LLM Tool Selector middleware."""

from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import BaseModel

from langchain.agents.middleware.llm_tool_selector import (
    LLMToolSelectorMiddleware,
    ToolSelectionSchema,
)
from langchain.agents.middleware.types import ModelRequest
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_core.tools import BaseTool


class MockTool(BaseTool):
    """Mock tool for testing."""

    name: str = "mock_tool"
    description: str = "A mock tool for testing"

    def _run(self, *args: Any, **kwargs: Any) -> str:
        return "mock result"

    async def _arun(self, *args: Any, **kwargs: Any) -> str:
        return "mock result"


class MockChatModel(BaseChatModel):
    """Mock chat model for testing."""

    def _generate(self, messages, stop=None, run_manager=None, **kwargs):
        # This is a placeholder - we'll mock the structured output
        pass

    async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
        # This is a placeholder - we'll mock the structured output
        pass

    @property
    def _llm_type(self) -> str:
        return "mock"


def test_tool_selection_schema():
    """Test that ToolSelectionSchema works correctly."""
    schema = ToolSelectionSchema(selected_tools=["tool1", "tool2"])
    assert schema.selected_tools == ["tool1", "tool2"]


def test_middleware_initialization():
    """Test that middleware initializes with correct defaults."""
    middleware = LLMToolSelectorMiddleware()

    assert middleware.model is None
    assert (
        middleware.system_prompt
        == "Your goal is to select the most relevant tool for answering the user's query."
    )
    assert middleware.max_tools is None
    assert middleware.include_full_history is False
    assert middleware.max_retries == 3


def test_middleware_initialization_with_custom_values():
    """Test that middleware initializes with custom values."""
    middleware = LLMToolSelectorMiddleware(
        model="openai:gpt-4o",
        system_prompt="Custom prompt",
        max_tools=5,
        include_full_history=True,
        max_retries=2,
    )

    assert middleware.model == "openai:gpt-4o"
    assert middleware.system_prompt == "Custom prompt"
    assert middleware.max_tools == 5
    assert middleware.include_full_history is True
    assert middleware.max_retries == 2


@patch("langchain.agents.middleware.llm_tool_selector.init_chat_model")
def test_modify_model_request_no_tools(mock_init_chat_model):
    """Test that middleware returns request unchanged when no tools are available."""
    middleware = LLMToolSelectorMiddleware()

    request = ModelRequest(
        model=MockChatModel(),
        system_prompt="Test prompt",
        messages=[HumanMessage("Hello")],
        tool_choice=None,
        tools=[],
        response_format=None,
    )

    runtime = MagicMock()
    runtime.tools = []

    result = middleware.modify_model_request(request, {}, runtime)

    assert result == request
    mock_init_chat_model.assert_not_called()


@patch("langchain.agents.middleware.llm_tool_selector.init_chat_model")
def test_modify_model_request_with_tool_selection(mock_init_chat_model):
    """Test that middleware filters tools based on LLM selection."""
    # Create mock tools
    tool1 = MockTool(name="tool1", description="First tool")
    tool2 = MockTool(name="tool2", description="Second tool")
    tool3 = MockTool(name="tool3", description="Third tool")

    # Create mock structured model
    mock_structured_model = MagicMock()
    mock_response = ToolSelectionSchema(selected_tools=["tool1", "tool3"])
    mock_structured_model.invoke.return_value = mock_response

    # Create mock chat model
    mock_chat_model = MagicMock()
    mock_chat_model.with_structured_output.return_value = mock_structured_model

    middleware = LLMToolSelectorMiddleware(model=mock_chat_model)

    request = ModelRequest(
        model=MockChatModel(),
        system_prompt="Test prompt",
        messages=[HumanMessage("Hello")],
        tool_choice=None,
        tools=["tool1", "tool2", "tool3"],
        response_format=None,
    )

    runtime = MagicMock()
    runtime.tools = [tool1, tool2, tool3]

    result = middleware.modify_model_request(request, {}, runtime)

    # Verify that only selected tools are returned
    assert result.tools == ["tool1", "tool3"]
    assert result.model == request.model
    assert result.system_prompt == request.system_prompt
    assert result.messages == request.messages


@patch("langchain.agents.middleware.llm_tool_selector.init_chat_model")
def test_modify_model_request_with_string_model(mock_init_chat_model):
    """Test that middleware works with string model specification."""
    # Create mock tools
    tool1 = MockTool(name="tool1", description="First tool")

    # Create mock structured model
    mock_structured_model = MagicMock()
    mock_response = ToolSelectionSchema(selected_tools=["tool1"])
    mock_structured_model.invoke.return_value = mock_response

    # Create mock chat model
    mock_chat_model = MagicMock()
    mock_chat_model.with_structured_output.return_value = mock_structured_model

    mock_init_chat_model.return_value = mock_chat_model

    middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o")

    request = ModelRequest(
        model=MockChatModel(),
        system_prompt="Test prompt",
        messages=[HumanMessage("Hello")],
        tool_choice=None,
        tools=["tool1"],
        response_format=None,
    )

    runtime = MagicMock()
    runtime.tools = [tool1]

    result = middleware.modify_model_request(request, {}, runtime)

    # Verify that init_chat_model was called with the string model
    mock_init_chat_model.assert_called_once_with("openai:gpt-4o")

    # Verify that only selected tools are returned
    assert result.tools == ["tool1"]


@patch("langchain.agents.middleware.llm_tool_selector.init_chat_model")
def test_modify_model_request_with_retry_logic(mock_init_chat_model):
    """Test that middleware retries on invalid tool selection."""
    # Create mock tools
    tool1 = MockTool(name="tool1", description="First tool")
    tool2 = MockTool(name="tool2", description="Second tool")

    # Create mock structured model that returns invalid tools first, then valid ones
    mock_structured_model = MagicMock()
    mock_responses = [
        ToolSelectionSchema(selected_tools=["invalid_tool"]),  # First attempt - invalid
        ToolSelectionSchema(selected_tools=["tool1"]),  # Second attempt - valid
    ]
    mock_structured_model.invoke.side_effect = mock_responses

    # Create mock chat model
    mock_chat_model = MagicMock()
    mock_chat_model.with_structured_output.return_value = mock_structured_model

    middleware = LLMToolSelectorMiddleware(model=mock_chat_model, max_retries=3)

    request = ModelRequest(
        model=MockChatModel(),
        system_prompt="Test prompt",
        messages=[HumanMessage("Hello")],
        tool_choice=None,
        tools=["tool1", "tool2"],
        response_format=None,
    )

    runtime = MagicMock()
    runtime.tools = [tool1, tool2]

    result = middleware.modify_model_request(request, {}, runtime)

    # Verify that the model was called twice (initial + retry)
    assert mock_structured_model.invoke.call_count == 2

    # Verify that only valid tools are returned
    assert result.tools == ["tool1"]


@patch("langchain.agents.middleware.llm_tool_selector.init_chat_model")
def test_modify_model_request_with_max_tools_limit(mock_init_chat_model):
    """Test that middleware enforces max_tools limit."""
    # Create mock tools
    tool1 = MockTool(name="tool1", description="First tool")
    tool2 = MockTool(name="tool2", description="Second tool")
    tool3 = MockTool(name="tool3", description="Third tool")

    # Create mock structured model that returns too many tools first, then correct amount
    mock_structured_model = MagicMock()
    mock_responses = [
        ToolSelectionSchema(selected_tools=["tool1", "tool2", "tool3"]),  # Too many
        ToolSelectionSchema(selected_tools=["tool1", "tool2"]),  # Correct amount
    ]
    mock_structured_model.invoke.side_effect = mock_responses

    # Create mock chat model
    mock_chat_model = MagicMock()
    mock_chat_model.with_structured_output.return_value = mock_structured_model

    middleware = LLMToolSelectorMiddleware(model=mock_chat_model, max_tools=2)

    request = ModelRequest(
        model=MockChatModel(),
        system_prompt="Test prompt",
        messages=[HumanMessage("Hello")],
        tool_choice=None,
        tools=["tool1", "tool2", "tool3"],
        response_format=None,
    )

    runtime = MagicMock()
    runtime.tools = [tool1, tool2, tool3]

    result = middleware.modify_model_request(request, {}, runtime)

    # Verify that the model was called twice (initial + retry)
    assert mock_structured_model.invoke.call_count == 2

    # Verify that only the allowed number of tools are returned
    assert result.tools == ["tool1", "tool2"]


@patch("langchain.agents.middleware.llm_tool_selector.init_chat_model")
def test_modify_model_request_fallback_on_exception(mock_init_chat_model):
    """Test that middleware falls back to original request on exception."""
    # Create mock tools
    tool1 = MockTool(name="tool1", description="First tool")

    # Create mock structured model that raises an exception
    mock_structured_model = MagicMock()
    mock_structured_model.invoke.side_effect = Exception("Model error")

    # Create mock chat model
    mock_chat_model = MagicMock()
    mock_chat_model.with_structured_output.return_value = mock_structured_model

    middleware = LLMToolSelectorMiddleware(model=mock_chat_model, max_retries=1)

    request = ModelRequest(
        model=MockChatModel(),
        system_prompt="Test prompt",
        messages=[HumanMessage("Hello")],
        tool_choice=None,
        tools=["tool1"],
        response_format=None,
    )

    runtime = MagicMock()
    runtime.tools = [tool1]

    result = middleware.modify_model_request(request, {}, runtime)

    # Verify that the original request is returned on exception
    assert result == request
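The tests above patch `langchain.agents.middleware.llm_tool_selector.init_chat_model` and stub `with_structured_output` with `MagicMock`, so they run offline with no API keys. A minimal sketch of invoking them programmatically; the test-file path is a guess, since this diff does not show where the new module was added:

```python
# Hypothetical path: adjust to wherever the new test module lives in the repo.
import pytest

raise SystemExit(pytest.main(["-q", "tests/unit_tests/agents/test_llm_tool_selector.py"]))
```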
libs/langchain_v1/uv.lock (generated, 4 changed lines)
@@ -2082,7 +2082,7 @@ wheels = [

[[package]]
name = "langchain-core"
-version = "0.3.76"
+version = "0.3.77"
source = { editable = "../core" }
dependencies = [
    { name = "jsonpatch" },
@@ -2323,7 +2323,7 @@ wheels = [

[[package]]
name = "langchain-tests"
-version = "0.3.21"
+version = "0.3.22"
source = { editable = "../standard-tests" }
dependencies = [
    { name = "httpx" },