mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-25 01:16:55 +00:00
chore(langchain_v1): use list[str] for modifyModelRequest (#33166)
Update model request to return tools by name. This will decrease the odds of misusing the API. We'll need to extend the type for built-in tools later.
This commit is contained in:
@@ -59,7 +59,7 @@ class ModelRequest:
|
||||
system_prompt: str | None
|
||||
messages: list[AnyMessage] # excluding system prompt
|
||||
tool_choice: Any | None
|
||||
tools: list[BaseTool]
|
||||
tools: list[str]
|
||||
response_format: ResponseFormat | None
|
||||
model_settings: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@@ -328,21 +328,42 @@ def create_agent( # noqa: PLR0915
|
||||
|
||||
def _get_bound_model(request: ModelRequest) -> Runnable:
|
||||
"""Get the model with appropriate tool bindings."""
|
||||
# Get actual tool objects from tool names
|
||||
tools_by_name = {t.name: t for t in default_tools}
|
||||
|
||||
unknown_tools = [name for name in request.tools if name not in tools_by_name]
|
||||
if unknown_tools:
|
||||
available_tools = sorted(tools_by_name.keys())
|
||||
msg = (
|
||||
f"Middleware returned unknown tool names: {unknown_tools}\n\n"
|
||||
f"Available tools: {available_tools}\n\n"
|
||||
"To fix this issue:\n"
|
||||
"1. Ensure the tools are passed to create_agent() via "
|
||||
"the 'tools' parameter\n"
|
||||
"2. If using custom middleware with tools, ensure "
|
||||
"they're registered via middleware.tools attribute\n"
|
||||
"3. Verify that tool names in ModelRequest.tools match "
|
||||
"the actual tool.name values"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
requested_tools = [tools_by_name[name] for name in request.tools]
|
||||
|
||||
if isinstance(response_format, ProviderStrategy):
|
||||
# Use native structured output
|
||||
kwargs = response_format.to_model_kwargs()
|
||||
return request.model.bind_tools(
|
||||
request.tools, strict=True, **kwargs, **request.model_settings
|
||||
requested_tools, strict=True, **kwargs, **request.model_settings
|
||||
)
|
||||
if isinstance(response_format, ToolStrategy):
|
||||
tool_choice = "any" if structured_output_tools else request.tool_choice
|
||||
return request.model.bind_tools(
|
||||
request.tools, tool_choice=tool_choice, **request.model_settings
|
||||
requested_tools, tool_choice=tool_choice, **request.model_settings
|
||||
)
|
||||
# Standard model binding
|
||||
if request.tools:
|
||||
if requested_tools:
|
||||
return request.model.bind_tools(
|
||||
request.tools, tool_choice=request.tool_choice, **request.model_settings
|
||||
requested_tools, tool_choice=request.tool_choice, **request.model_settings
|
||||
)
|
||||
return request.model.bind(**request.model_settings)
|
||||
|
||||
@@ -357,7 +378,7 @@ def create_agent( # noqa: PLR0915
|
||||
"""Sync model request handler with sequential middleware processing."""
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
tools=default_tools,
|
||||
tools=[t.name for t in default_tools],
|
||||
system_prompt=system_prompt,
|
||||
response_format=response_format,
|
||||
messages=state["messages"],
|
||||
@@ -385,7 +406,7 @@ def create_agent( # noqa: PLR0915
|
||||
# Start with the base model request
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
tools=default_tools,
|
||||
tools=[t.name for t in default_tools],
|
||||
system_prompt=system_prompt,
|
||||
response_format=response_format,
|
||||
messages=state["messages"],
|
||||
|
||||
@@ -0,0 +1,283 @@
|
||||
"""Test Middleware handling of tools in agents."""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
|
||||
from langchain.agents.middleware_agent import create_agent
|
||||
from langchain_core.messages import HumanMessage, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from .model import FakeToolCallingModel
|
||||
|
||||
|
||||
def test_model_request_tools_are_strings() -> None:
|
||||
"""Test that ModelRequest.tools contains tool names as strings, not tool objects."""
|
||||
captured_requests: list[ModelRequest] = []
|
||||
|
||||
@tool
|
||||
def search_tool(query: str) -> str:
|
||||
"""Search for information."""
|
||||
return f"Results for: {query}"
|
||||
|
||||
@tool
|
||||
def calculator(expression: str) -> str:
|
||||
"""Calculate a mathematical expression."""
|
||||
return f"Result: {expression}"
|
||||
|
||||
class RequestCapturingMiddleware(AgentMiddleware):
|
||||
def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest:
|
||||
captured_requests.append(request)
|
||||
return request
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[search_tool, calculator],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[RequestCapturingMiddleware()],
|
||||
).compile()
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
|
||||
# Verify that at least one request was captured
|
||||
assert len(captured_requests) > 0
|
||||
|
||||
# Check that tools in the request are strings (tool names)
|
||||
request = captured_requests[0]
|
||||
assert isinstance(request.tools, list)
|
||||
assert len(request.tools) == 2
|
||||
assert all(isinstance(tool_name, str) for tool_name in request.tools)
|
||||
assert set(request.tools) == {"search_tool", "calculator"}
|
||||
|
||||
|
||||
def test_middleware_can_modify_tool_names() -> None:
|
||||
"""Test that middleware can modify the list of tool names in ModelRequest."""
|
||||
|
||||
@tool
|
||||
def tool_a(input: str) -> str:
|
||||
"""Tool A."""
|
||||
return "A"
|
||||
|
||||
@tool
|
||||
def tool_b(input: str) -> str:
|
||||
"""Tool B."""
|
||||
return "B"
|
||||
|
||||
@tool
|
||||
def tool_c(input: str) -> str:
|
||||
"""Tool C."""
|
||||
return "C"
|
||||
|
||||
class ToolFilteringMiddleware(AgentMiddleware):
|
||||
def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest:
|
||||
# Only allow tool_a and tool_b
|
||||
request.tools = ["tool_a", "tool_b"]
|
||||
return request
|
||||
|
||||
# Model will try to call tool_a
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[[{"args": {"input": "test"}, "id": "1", "name": "tool_a"}], []]
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[tool_a, tool_b, tool_c],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[ToolFilteringMiddleware()],
|
||||
).compile()
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Use tool_a")]})
|
||||
|
||||
# Verify that the tool was executed successfully
|
||||
messages = result["messages"]
|
||||
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
|
||||
assert len(tool_messages) == 1
|
||||
assert tool_messages[0].name == "tool_a"
|
||||
|
||||
|
||||
def test_unknown_tool_name_raises_error() -> None:
|
||||
"""Test that using an unknown tool name in ModelRequest raises a clear error."""
|
||||
|
||||
@tool
|
||||
def known_tool(input: str) -> str:
|
||||
"""A known tool."""
|
||||
return "result"
|
||||
|
||||
class BadMiddleware(AgentMiddleware):
|
||||
def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest:
|
||||
# Add an unknown tool name
|
||||
request.tools = ["known_tool", "unknown_tool"]
|
||||
return request
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[known_tool],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[BadMiddleware()],
|
||||
).compile()
|
||||
|
||||
with pytest.raises(ValueError, match="Middleware returned unknown tool names"):
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
|
||||
|
||||
def test_middleware_can_add_and_remove_tools() -> None:
|
||||
"""Test that middleware can dynamically add/remove tools based on state."""
|
||||
|
||||
@tool
|
||||
def search(query: str) -> str:
|
||||
"""Search for information."""
|
||||
return f"Search results for: {query}"
|
||||
|
||||
@tool
|
||||
def admin_tool(command: str) -> str:
|
||||
"""Admin-only tool."""
|
||||
return f"Admin: {command}"
|
||||
|
||||
class AdminState(AgentState):
|
||||
is_admin: bool
|
||||
|
||||
class ConditionalToolMiddleware(AgentMiddleware[AdminState]):
|
||||
state_schema = AdminState
|
||||
|
||||
def modify_model_request(self, request: ModelRequest, state: AdminState) -> ModelRequest:
|
||||
# Remove admin_tool if not admin
|
||||
if not state.get("is_admin", False):
|
||||
request.tools = [name for name in request.tools if name != "admin_tool"]
|
||||
return request
|
||||
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[search, admin_tool],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[ConditionalToolMiddleware()],
|
||||
).compile()
|
||||
|
||||
# Test non-admin user - should not have access to admin_tool
|
||||
# We can't directly inspect the bound model, but we can verify the agent runs
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")], "is_admin": False})
|
||||
assert "messages" in result
|
||||
|
||||
# Test admin user - should have access to all tools
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")], "is_admin": True})
|
||||
assert "messages" in result
|
||||
|
||||
|
||||
def test_empty_tools_list_is_valid() -> None:
|
||||
"""Test that middleware can set tools to an empty list."""
|
||||
|
||||
@tool
|
||||
def some_tool(input: str) -> str:
|
||||
"""Some tool."""
|
||||
return "result"
|
||||
|
||||
class NoToolsMiddleware(AgentMiddleware):
|
||||
def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest:
|
||||
# Remove all tools
|
||||
request.tools = []
|
||||
return request
|
||||
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[some_tool],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoToolsMiddleware()],
|
||||
).compile()
|
||||
|
||||
# Should run without error even with no tools
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
assert "messages" in result
|
||||
|
||||
|
||||
def test_tools_preserved_across_multiple_middleware() -> None:
|
||||
"""Test that tool modifications by one middleware are visible to the next."""
|
||||
modification_order: list[list[str]] = []
|
||||
|
||||
@tool
|
||||
def tool_a(input: str) -> str:
|
||||
"""Tool A."""
|
||||
return "A"
|
||||
|
||||
@tool
|
||||
def tool_b(input: str) -> str:
|
||||
"""Tool B."""
|
||||
return "B"
|
||||
|
||||
@tool
|
||||
def tool_c(input: str) -> str:
|
||||
"""Tool C."""
|
||||
return "C"
|
||||
|
||||
class FirstMiddleware(AgentMiddleware):
|
||||
def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest:
|
||||
modification_order.append(request.tools.copy())
|
||||
# Remove tool_c
|
||||
request.tools = [name for name in request.tools if name != "tool_c"]
|
||||
return request
|
||||
|
||||
class SecondMiddleware(AgentMiddleware):
|
||||
def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest:
|
||||
modification_order.append(request.tools.copy())
|
||||
# Should not see tool_c here
|
||||
assert "tool_c" not in request.tools
|
||||
# Remove tool_b
|
||||
request.tools = [name for name in request.tools if name != "tool_b"]
|
||||
return request
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[tool_a, tool_b, tool_c],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[FirstMiddleware(), SecondMiddleware()],
|
||||
).compile()
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
|
||||
# Verify the modification sequence
|
||||
assert len(modification_order) == 2
|
||||
# First middleware sees all three tools
|
||||
assert set(modification_order[0]) == {"tool_a", "tool_b", "tool_c"}
|
||||
# Second middleware sees tool_c removed
|
||||
assert set(modification_order[1]) == {"tool_a", "tool_b"}
|
||||
|
||||
|
||||
def test_middleware_with_additional_tools() -> None:
|
||||
"""Test middleware that provides additional tools via tools attribute."""
|
||||
|
||||
@tool
|
||||
def base_tool(input: str) -> str:
|
||||
"""Base tool."""
|
||||
return "base"
|
||||
|
||||
@tool
|
||||
def middleware_tool(input: str) -> str:
|
||||
"""Tool provided by middleware."""
|
||||
return "middleware"
|
||||
|
||||
class ToolProvidingMiddleware(AgentMiddleware):
|
||||
tools = [middleware_tool]
|
||||
|
||||
# Model calls the middleware-provided tool
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[{"args": {"input": "test"}, "id": "1", "name": "middleware_tool"}],
|
||||
[],
|
||||
]
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[base_tool],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[ToolProvidingMiddleware()],
|
||||
).compile()
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Use middleware tool")]})
|
||||
|
||||
# Verify that the middleware tool was executed
|
||||
messages = result["messages"]
|
||||
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
|
||||
assert len(tool_messages) == 1
|
||||
assert tool_messages[0].name == "middleware_tool"
|
||||
assert "middleware" in tool_messages[0].content.lower()
|
||||
Reference in New Issue
Block a user