fix tests + support builtins

This commit is contained in:
Sydney Runkle
2025-10-03 10:31:22 -04:00
parent 894ffa0be5
commit 6001543093
4 changed files with 37 additions and 50 deletions

View File

@@ -349,11 +349,7 @@ def create_agent( # noqa: PLR0915
# Only create ToolNode if we have tools
tool_node = ToolNode(tools=all_tools) if all_tools else None
# Get the actual tool objects from the tool node (this converts Callables to tools)
if tool_node:
default_tools = list(tool_node.tools_by_name.values()) + builtin_tools
else:
default_tools = builtin_tools
default_tools = regular_tools + builtin_tools + structured_tools + middleware_tools
elif isinstance(tools, ToolNode):
# tools is ToolNode or None
tool_node = tools
@@ -498,26 +494,8 @@ def create_agent( # noqa: PLR0915
def _get_bound_model(request: ModelRequest) -> Runnable:
"""Get the model with appropriate tool bindings."""
# Get actual tool objects from tool names
tools_by_name = {t.name: t for t in default_tools}
unknown_tools = [name for name in request.tools if name not in tools_by_name]
if unknown_tools:
available_tools = sorted(tools_by_name.keys())
msg = (
f"Middleware returned unknown tool names: {unknown_tools}\n\n"
f"Available tools: {available_tools}\n\n"
"To fix this issue:\n"
"1. Ensure the tools are passed to create_agent() via "
"the 'tools' parameter\n"
"2. If using custom middleware with tools, ensure "
"they're registered via middleware.tools attribute\n"
"3. Verify that tool names in ModelRequest.tools match "
"the actual tool.name values"
)
raise ValueError(msg)
requested_tools = [tools_by_name[name] for name in request.tools]
# request.tools contains BaseTool | dict objects
requested_tools = request.tools
if isinstance(response_format, ProviderStrategy):
# Use native structured output
@@ -541,7 +519,7 @@ def create_agent( # noqa: PLR0915
"""Sync model request handler with sequential middleware processing."""
request = ModelRequest(
model=model,
tools=[t.name for t in default_tools],
tools=default_tools,
system_prompt=system_prompt,
response_format=response_format,
messages=state["messages"],
@@ -580,7 +558,7 @@ def create_agent( # noqa: PLR0915
"""Async model request handler with sequential middleware processing."""
request = ModelRequest(
model=model,
tools=[t.name for t in default_tools],
tools=default_tools,
system_prompt=system_prompt,
response_format=response_format,
messages=state["messages"],

View File

@@ -62,7 +62,7 @@ class ModelRequest:
system_prompt: str | None
messages: list[AnyMessage] # excluding system prompt
tool_choice: Any | None
tools: list[str]
tools: list[BaseTool | dict[str, Any]]
response_format: ResponseFormat | None
model_settings: dict[str, Any] = field(default_factory=dict)

View File

@@ -10,8 +10,8 @@ from .model import FakeToolCallingModel
from langgraph.runtime import Runtime
def test_model_request_tools_are_strings() -> None:
"""Test that ModelRequest.tools contains tool names as strings, not tool objects."""
def test_model_request_tools_are_objects() -> None:
"""Test that ModelRequest.tools contains tool objects (BaseTool | dict)."""
captured_requests: list[ModelRequest] = []
@tool
@@ -43,16 +43,16 @@ def test_model_request_tools_are_strings() -> None:
# Verify that at least one request was captured
assert len(captured_requests) > 0
# Check that tools in the request are strings (tool names)
# Check that tools in the request are tool objects
request = captured_requests[0]
assert isinstance(request.tools, list)
assert len(request.tools) == 2
assert all(isinstance(tool_name, str) for tool_name in request.tools)
assert set(request.tools) == {"search_tool", "calculator"}
tool_names = {t.name for t in request.tools}
assert tool_names == {"search_tool", "calculator"}
def test_middleware_can_modify_tool_names() -> None:
"""Test that middleware can modify the list of tool names in ModelRequest."""
def test_middleware_can_modify_tools() -> None:
"""Test that middleware can modify the list of tools in ModelRequest."""
@tool
def tool_a(input: str) -> str:
@@ -74,7 +74,7 @@ def test_middleware_can_modify_tool_names() -> None:
self, request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
# Only allow tool_a and tool_b
request.tools = ["tool_a", "tool_b"]
request.tools = [t for t in request.tools if t.name in ["tool_a", "tool_b"]]
return request
# Model will try to call tool_a
@@ -98,31 +98,37 @@ def test_middleware_can_modify_tool_names() -> None:
assert tool_messages[0].name == "tool_a"
def test_unknown_tool_name_raises_error() -> None:
"""Test that using an unknown tool name in ModelRequest raises a clear error."""
def test_middleware_can_add_custom_tools() -> None:
"""Test that middleware can add custom tool objects to ModelRequest."""
@tool
def known_tool(input: str) -> str:
"""A known tool."""
return "result"
class BadMiddleware(AgentMiddleware):
@tool
def custom_tool(input: str) -> str:
"""A custom tool added by middleware."""
return "custom result"
class ToolAddingMiddleware(AgentMiddleware):
def modify_model_request(
self, request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
# Add an unknown tool name
request.tools = ["known_tool", "unknown_tool"]
# Add a custom tool
request.tools = list(request.tools) + [custom_tool]
return request
agent = create_agent(
model=FakeToolCallingModel(),
tools=[known_tool],
system_prompt="You are a helpful assistant.",
middleware=[BadMiddleware()],
middleware=[ToolAddingMiddleware()],
)
with pytest.raises(ValueError, match="Middleware returned unknown tool names"):
agent.invoke({"messages": [HumanMessage("Hello")]})
# Should work fine with custom tools added
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result
def test_middleware_can_add_and_remove_tools() -> None:
@@ -149,7 +155,7 @@ def test_middleware_can_add_and_remove_tools() -> None:
) -> ModelRequest:
# Remove admin_tool if not admin
if not state.get("is_admin", False):
request.tools = [name for name in request.tools if name != "admin_tool"]
request.tools = [t for t in request.tools if t.name != "admin_tool"]
return request
model = FakeToolCallingModel()
@@ -224,20 +230,20 @@ def test_tools_preserved_across_multiple_middleware() -> None:
def modify_model_request(
self, request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
modification_order.append(request.tools.copy())
modification_order.append([t.name for t in request.tools])
# Remove tool_c
request.tools = [name for name in request.tools if name != "tool_c"]
request.tools = [t for t in request.tools if t.name != "tool_c"]
return request
class SecondMiddleware(AgentMiddleware):
def modify_model_request(
self, request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
modification_order.append(request.tools.copy())
modification_order.append([t.name for t in request.tools])
# Should not see tool_c here
assert "tool_c" not in request.tools
assert "tool_c" not in [t.name for t in request.tools]
# Remove tool_b
request.tools = [name for name in request.tools if name != "tool_b"]
request.tools = [t for t in request.tools if t.name != "tool_b"]
return request
agent = create_agent(

View File

@@ -7,6 +7,7 @@ from typing import Union
from langchain_core.messages import HumanMessage
from langchain.agents import create_agent
from langchain_core.tools import tool
from langchain.agents.structured_output import (
MultipleStructuredOutputsError,
ProviderStrategy,
@@ -74,12 +75,14 @@ location_json_schema = {
}
@tool
def get_weather() -> str:
"""Get the weather."""
return "The weather is sunny and 75°F."
@tool
def get_location() -> str:
"""Get the current location."""