mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-23 20:23:59 +00:00
fix tests + support builtins
This commit is contained in:
@@ -349,11 +349,7 @@ def create_agent( # noqa: PLR0915
|
||||
|
||||
# Only create ToolNode if we have tools
|
||||
tool_node = ToolNode(tools=all_tools) if all_tools else None
|
||||
# Get the actual tool objects from the tool node (this converts Callables to tools)
|
||||
if tool_node:
|
||||
default_tools = list(tool_node.tools_by_name.values()) + builtin_tools
|
||||
else:
|
||||
default_tools = builtin_tools
|
||||
default_tools = regular_tools + builtin_tools + structured_tools + middleware_tools
|
||||
elif isinstance(tools, ToolNode):
|
||||
# tools is ToolNode or None
|
||||
tool_node = tools
|
||||
@@ -498,26 +494,8 @@ def create_agent( # noqa: PLR0915
|
||||
|
||||
def _get_bound_model(request: ModelRequest) -> Runnable:
|
||||
"""Get the model with appropriate tool bindings."""
|
||||
# Get actual tool objects from tool names
|
||||
tools_by_name = {t.name: t for t in default_tools}
|
||||
|
||||
unknown_tools = [name for name in request.tools if name not in tools_by_name]
|
||||
if unknown_tools:
|
||||
available_tools = sorted(tools_by_name.keys())
|
||||
msg = (
|
||||
f"Middleware returned unknown tool names: {unknown_tools}\n\n"
|
||||
f"Available tools: {available_tools}\n\n"
|
||||
"To fix this issue:\n"
|
||||
"1. Ensure the tools are passed to create_agent() via "
|
||||
"the 'tools' parameter\n"
|
||||
"2. If using custom middleware with tools, ensure "
|
||||
"they're registered via middleware.tools attribute\n"
|
||||
"3. Verify that tool names in ModelRequest.tools match "
|
||||
"the actual tool.name values"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
requested_tools = [tools_by_name[name] for name in request.tools]
|
||||
# request.tools contains BaseTool | dict objects
|
||||
requested_tools = request.tools
|
||||
|
||||
if isinstance(response_format, ProviderStrategy):
|
||||
# Use native structured output
|
||||
@@ -541,7 +519,7 @@ def create_agent( # noqa: PLR0915
|
||||
"""Sync model request handler with sequential middleware processing."""
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
tools=[t.name for t in default_tools],
|
||||
tools=default_tools,
|
||||
system_prompt=system_prompt,
|
||||
response_format=response_format,
|
||||
messages=state["messages"],
|
||||
@@ -580,7 +558,7 @@ def create_agent( # noqa: PLR0915
|
||||
"""Async model request handler with sequential middleware processing."""
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
tools=[t.name for t in default_tools],
|
||||
tools=default_tools,
|
||||
system_prompt=system_prompt,
|
||||
response_format=response_format,
|
||||
messages=state["messages"],
|
||||
|
||||
@@ -62,7 +62,7 @@ class ModelRequest:
|
||||
system_prompt: str | None
|
||||
messages: list[AnyMessage] # excluding system prompt
|
||||
tool_choice: Any | None
|
||||
tools: list[str]
|
||||
tools: list[BaseTool | dict[str, Any]]
|
||||
response_format: ResponseFormat | None
|
||||
model_settings: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@@ -10,8 +10,8 @@ from .model import FakeToolCallingModel
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
|
||||
def test_model_request_tools_are_strings() -> None:
|
||||
"""Test that ModelRequest.tools contains tool names as strings, not tool objects."""
|
||||
def test_model_request_tools_are_objects() -> None:
|
||||
"""Test that ModelRequest.tools contains tool objects (BaseTool | dict)."""
|
||||
captured_requests: list[ModelRequest] = []
|
||||
|
||||
@tool
|
||||
@@ -43,16 +43,16 @@ def test_model_request_tools_are_strings() -> None:
|
||||
# Verify that at least one request was captured
|
||||
assert len(captured_requests) > 0
|
||||
|
||||
# Check that tools in the request are strings (tool names)
|
||||
# Check that tools in the request are tool objects
|
||||
request = captured_requests[0]
|
||||
assert isinstance(request.tools, list)
|
||||
assert len(request.tools) == 2
|
||||
assert all(isinstance(tool_name, str) for tool_name in request.tools)
|
||||
assert set(request.tools) == {"search_tool", "calculator"}
|
||||
tool_names = {t.name for t in request.tools}
|
||||
assert tool_names == {"search_tool", "calculator"}
|
||||
|
||||
|
||||
def test_middleware_can_modify_tool_names() -> None:
|
||||
"""Test that middleware can modify the list of tool names in ModelRequest."""
|
||||
def test_middleware_can_modify_tools() -> None:
|
||||
"""Test that middleware can modify the list of tools in ModelRequest."""
|
||||
|
||||
@tool
|
||||
def tool_a(input: str) -> str:
|
||||
@@ -74,7 +74,7 @@ def test_middleware_can_modify_tool_names() -> None:
|
||||
self, request: ModelRequest, state: AgentState, runtime: Runtime
|
||||
) -> ModelRequest:
|
||||
# Only allow tool_a and tool_b
|
||||
request.tools = ["tool_a", "tool_b"]
|
||||
request.tools = [t for t in request.tools if t.name in ["tool_a", "tool_b"]]
|
||||
return request
|
||||
|
||||
# Model will try to call tool_a
|
||||
@@ -98,31 +98,37 @@ def test_middleware_can_modify_tool_names() -> None:
|
||||
assert tool_messages[0].name == "tool_a"
|
||||
|
||||
|
||||
def test_unknown_tool_name_raises_error() -> None:
|
||||
"""Test that using an unknown tool name in ModelRequest raises a clear error."""
|
||||
def test_middleware_can_add_custom_tools() -> None:
|
||||
"""Test that middleware can add custom tool objects to ModelRequest."""
|
||||
|
||||
@tool
|
||||
def known_tool(input: str) -> str:
|
||||
"""A known tool."""
|
||||
return "result"
|
||||
|
||||
class BadMiddleware(AgentMiddleware):
|
||||
@tool
|
||||
def custom_tool(input: str) -> str:
|
||||
"""A custom tool added by middleware."""
|
||||
return "custom result"
|
||||
|
||||
class ToolAddingMiddleware(AgentMiddleware):
|
||||
def modify_model_request(
|
||||
self, request: ModelRequest, state: AgentState, runtime: Runtime
|
||||
) -> ModelRequest:
|
||||
# Add an unknown tool name
|
||||
request.tools = ["known_tool", "unknown_tool"]
|
||||
# Add a custom tool
|
||||
request.tools = list(request.tools) + [custom_tool]
|
||||
return request
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[known_tool],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[BadMiddleware()],
|
||||
middleware=[ToolAddingMiddleware()],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Middleware returned unknown tool names"):
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
# Should work fine with custom tools added
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
assert "messages" in result
|
||||
|
||||
|
||||
def test_middleware_can_add_and_remove_tools() -> None:
|
||||
@@ -149,7 +155,7 @@ def test_middleware_can_add_and_remove_tools() -> None:
|
||||
) -> ModelRequest:
|
||||
# Remove admin_tool if not admin
|
||||
if not state.get("is_admin", False):
|
||||
request.tools = [name for name in request.tools if name != "admin_tool"]
|
||||
request.tools = [t for t in request.tools if t.name != "admin_tool"]
|
||||
return request
|
||||
|
||||
model = FakeToolCallingModel()
|
||||
@@ -224,20 +230,20 @@ def test_tools_preserved_across_multiple_middleware() -> None:
|
||||
def modify_model_request(
|
||||
self, request: ModelRequest, state: AgentState, runtime: Runtime
|
||||
) -> ModelRequest:
|
||||
modification_order.append(request.tools.copy())
|
||||
modification_order.append([t.name for t in request.tools])
|
||||
# Remove tool_c
|
||||
request.tools = [name for name in request.tools if name != "tool_c"]
|
||||
request.tools = [t for t in request.tools if t.name != "tool_c"]
|
||||
return request
|
||||
|
||||
class SecondMiddleware(AgentMiddleware):
|
||||
def modify_model_request(
|
||||
self, request: ModelRequest, state: AgentState, runtime: Runtime
|
||||
) -> ModelRequest:
|
||||
modification_order.append(request.tools.copy())
|
||||
modification_order.append([t.name for t in request.tools])
|
||||
# Should not see tool_c here
|
||||
assert "tool_c" not in request.tools
|
||||
assert "tool_c" not in [t.name for t in request.tools]
|
||||
# Remove tool_b
|
||||
request.tools = [name for name in request.tools if name != "tool_b"]
|
||||
request.tools = [t for t in request.tools if t.name != "tool_b"]
|
||||
return request
|
||||
|
||||
agent = create_agent(
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Union
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.tools import tool
|
||||
from langchain.agents.structured_output import (
|
||||
MultipleStructuredOutputsError,
|
||||
ProviderStrategy,
|
||||
@@ -74,12 +75,14 @@ location_json_schema = {
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
def get_weather() -> str:
|
||||
"""Get the weather."""
|
||||
|
||||
return "The weather is sunny and 75°F."
|
||||
|
||||
|
||||
@tool
|
||||
def get_location() -> str:
|
||||
"""Get the current location."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user