Compare commits

...

2 Commits

Author SHA1 Message Date
Sydney Runkle
d003296684 async 2026-01-21 13:18:54 -05:00
Sydney Runkle
a6f70fff41 enabling dynamic tools 2026-01-21 09:11:05 -05:00
2 changed files with 74 additions and 71 deletions

View File

@@ -773,27 +773,38 @@ def create_agent(
regular_tools = [t for t in tools if not isinstance(t, dict)]
# Tools that require client-side execution (must be in ToolNode)
available_tools = middleware_tools + regular_tools
static_tools = middleware_tools + regular_tools
# Only create ToolNode if we have client-side tools
# Container to track tools from the last wrap_model_call execution
# This allows middleware to dynamically add tools that ToolNode can execute
current_tools_container: dict[str, list[BaseTool | dict[str, Any]]] = {"tools": []}
async def get_dynamic_tools() -> list[BaseTool]:
"""Return tools from the last model request, or static tools if none set."""
tools_from_request = current_tools_container.get("tools", [])
if tools_from_request:
# Filter to only BaseTool instances (not built-in dict tools)
return [t for t in tools_from_request if isinstance(t, BaseTool)]
# static_tools contains BaseTool or Callable instances (not dicts)
# ToolNode will convert callables to BaseTool internally
return list(static_tools)
# Only create ToolNode if we have static tools
# Dynamic tools from middleware will use this ToolNode via the callable
tool_node = (
ToolNode(
tools=available_tools,
tools=get_dynamic_tools,
wrap_tool_call=wrap_tool_call_wrapper,
awrap_tool_call=awrap_tool_call_wrapper,
)
if available_tools
if static_tools
else None
)
# Default tools for ModelRequest initialization
# Use converted BaseTool instances from ToolNode (not raw callables)
# Include built-ins and converted tools (can be changed dynamically by middleware)
# Include built-ins and static tools (can be changed dynamically by middleware)
# Structured tools are NOT included - they're added dynamically based on response_format
if tool_node:
default_tools = list(tool_node.tools_by_name.values()) + built_in_tools
else:
default_tools = list(built_in_tools)
default_tools: list[BaseTool | dict[str, Any]] = list(static_tools) + list(built_in_tools)
# validate middleware
if len({m.name for m in middleware}) != len(middleware):
@@ -993,41 +1004,8 @@ def create_agent(
initial if auto-detected).
Raises:
ValueError: If middleware returned unknown client-side tool names.
ValueError: If `ToolStrategy` specifies tools not declared upfront.
"""
# Validate ONLY client-side tools that need to exist in tool_node
# Build map of available client-side tools from the ToolNode
# (which has already converted callables)
available_tools_by_name = {}
if tool_node:
available_tools_by_name = tool_node.tools_by_name.copy()
# Check if any requested tools are unknown CLIENT-SIDE tools
unknown_tool_names = []
for t in request.tools:
# Only validate BaseTool instances (skip built-in dict tools)
if isinstance(t, dict):
continue
if isinstance(t, BaseTool) and t.name not in available_tools_by_name:
unknown_tool_names.append(t.name)
if unknown_tool_names:
available_tool_names = sorted(available_tools_by_name.keys())
msg = (
f"Middleware returned unknown tool names: {unknown_tool_names}\n\n"
f"Available client-side tools: {available_tool_names}\n\n"
"To fix this issue:\n"
"1. Ensure the tools are passed to create_agent() via "
"the 'tools' parameter\n"
"2. If using custom middleware with tools, ensure "
"they're registered via middleware.tools attribute\n"
"3. Verify that tool names in ModelRequest.tools match "
"the actual tool.name values\n"
"Note: Built-in provider tools (dict format) can be added dynamically."
)
raise ValueError(msg)
# Determine effective response format (auto-detect if needed)
effective_response_format: ResponseFormat[Any] | None
if isinstance(request.response_format, AutoStrategy):
@@ -1103,6 +1081,9 @@ def create_agent(
This is the core model execution logic wrapped by `wrap_model_call` handlers.
Raises any exceptions that occur during model invocation.
"""
# Store the current tools for ToolNode to use (supports dynamic tools from middleware)
current_tools_container["tools"] = request.tools
# Get the bound model (with auto-detection if needed)
model_, effective_response_format = _get_bound_model(request)
messages = request.messages
@@ -1158,6 +1139,9 @@ def create_agent(
Raises any exceptions that occur during model invocation.
"""
# Store the current tools for ToolNode to use (supports dynamic tools from middleware)
current_tools_container["tools"] = request.tools
# Get the bound model (with auto-detection if needed)
model_, effective_response_format = _get_bound_model(request)
messages = request.messages
@@ -1327,23 +1311,21 @@ def create_agent(
graph.add_edge(START, entry_node)
# add conditional edges only if tools exist
if tool_node is not None:
# Only include exit_node in destinations if any tool has return_direct=True
# Include exit_node in destinations if any static tool has return_direct=True
# or if there are structured output tools
tools_to_model_destinations = [loop_entry_node]
if (
any(tool.return_direct for tool in tool_node.tools_by_name.values())
or structured_output_tools
):
has_return_direct_tool = any(getattr(tool, "return_direct", False) for tool in static_tools)
if has_return_direct_tool or structured_output_tools:
tools_to_model_destinations.append(exit_node)
graph.add_conditional_edges(
"tools",
RunnableCallable(
_make_tools_to_model_edge(
tool_node=tool_node,
model_destination=loop_entry_node,
structured_output_tools=structured_output_tools,
end_destination=exit_node,
current_tools_container=current_tools_container,
),
trace=False,
),
@@ -1613,21 +1595,27 @@ def _make_model_to_model_edge(
def _make_tools_to_model_edge(
*,
tool_node: ToolNode,
model_destination: str,
structured_output_tools: dict[str, OutputToolBinding[Any]],
end_destination: str,
current_tools_container: dict[str, list[BaseTool | dict[str, Any]]],
) -> Callable[[dict[str, Any]], str | None]:
def tools_to_model(state: dict[str, Any]) -> str | None:
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
# Build tools_by_name from current_tools_container for dynamic tools support
tools_by_name: dict[str, BaseTool] = {}
for t in current_tools_container.get("tools", []):
if isinstance(t, BaseTool):
tools_by_name[t.name] = t
# 1. Exit condition: All executed tools have return_direct=True
# Filter to only client-side tools (provider tools are not in tool_node)
# Filter to only client-side tools (provider tools are not in tools_by_name)
client_side_tool_calls = [
c for c in last_ai_message.tool_calls if c["name"] in tool_node.tools_by_name
c for c in last_ai_message.tool_calls if c["name"] in tools_by_name
]
if client_side_tool_calls and all(
tool_node.tools_by_name[c["name"]].return_direct for c in client_side_tool_calls
tools_by_name[c["name"]].return_direct for c in client_side_tool_calls
):
return end_destination

View File

@@ -123,37 +123,52 @@ def test_middleware_can_modify_tools() -> None:
assert tool_messages[0].name == "tool_a"
def test_unknown_tool_raises_error() -> None:
"""Test that using an unknown tool in ModelRequest raises a clear error."""
def test_dynamic_tool_via_middleware() -> None:
"""Test that middleware can dynamically add tools that weren't pre-registered."""
@tool
def known_tool(value: str) -> str:
"""A known tool."""
return "result"
def static_tool(value: str) -> str:
"""A static tool passed to create_agent."""
return "static result"
@tool
def unknown_tool(value: str) -> str:
"""An unknown tool not passed to create_agent."""
return "unknown"
def dynamic_tool(value: str) -> str:
"""A dynamic tool added by middleware, not pre-registered."""
return f"dynamic result: {value}"
class BadMiddleware(AgentMiddleware):
class DynamicToolMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
# Add an unknown tool
return handler(request.override(tools=[*request.tools, unknown_tool]))
# Add a dynamic tool that wasn't passed to create_agent
return handler(request.override(tools=[*request.tools, dynamic_tool]))
agent = create_agent(
model=FakeToolCallingModel(),
tools=[known_tool],
system_prompt="You are a helpful assistant.",
middleware=[BadMiddleware()],
# Model will call the dynamic tool
model = FakeToolCallingModel(
tool_calls=[
[{"args": {"value": "test input"}, "id": "1", "name": "dynamic_tool"}],
[],
]
)
with pytest.raises(ValueError, match="Middleware returned unknown tool names"):
agent.invoke({"messages": [HumanMessage("Hello")]})
# Note: dynamic_tool is NOT passed to create_agent, only static_tool is
agent = create_agent(
model=model,
tools=[static_tool],
system_prompt="You are a helpful assistant.",
middleware=[DynamicToolMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage("Use the dynamic tool")]})
# Verify that the dynamic tool was executed successfully
messages = result["messages"]
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
assert len(tool_messages) == 1
assert tool_messages[0].name == "dynamic_tool"
assert "dynamic result: test input" in tool_messages[0].content
def test_middleware_can_add_and_remove_tools() -> None: