mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-05 00:30:18 +00:00
Compare commits
2 Commits
langchain-
...
wrap_tool_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a5c882e484 | ||
|
|
dcb954d03c |
@@ -18,7 +18,7 @@ if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolCall, ToolMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph._internal._runnable import RunnableCallable
|
||||
from langgraph.constants import END, START
|
||||
@@ -59,7 +59,7 @@ if TYPE_CHECKING:
|
||||
from langgraph.store.base import BaseStore
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from langchain.tools.tool_node import ToolCallHandler, ToolCallRequest
|
||||
from langchain.tools.tool_node import ToolCallRequest, ToolCallWrapper
|
||||
|
||||
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
|
||||
|
||||
@@ -373,8 +373,8 @@ def _handle_structured_output_error(
|
||||
|
||||
|
||||
def _chain_tool_call_handlers(
|
||||
handlers: Sequence[ToolCallHandler],
|
||||
) -> ToolCallHandler | None:
|
||||
handlers: Sequence[ToolCallWrapper],
|
||||
) -> ToolCallWrapper | None:
|
||||
"""Compose handlers into middleware stack (first = outermost).
|
||||
|
||||
Args:
|
||||
@@ -394,19 +394,23 @@ def _chain_tool_call_handlers(
|
||||
if len(handlers) == 1:
|
||||
return handlers[0]
|
||||
|
||||
def compose_two(outer: ToolCallHandler, inner: ToolCallHandler) -> ToolCallHandler:
|
||||
def compose_two(outer: ToolCallWrapper, inner: ToolCallWrapper) -> ToolCallWrapper:
|
||||
"""Compose two handlers where outer wraps inner."""
|
||||
|
||||
def composed(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
# Create a callable that invokes inner with the original execute
|
||||
def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
|
||||
return inner(req, execute)
|
||||
# Create a wrapper that calls inner when outer calls its execute callable
|
||||
# When outer calls this wrapper with a ToolCall, we invoke inner
|
||||
# which receives the request and the base execute callable
|
||||
def call_inner_wrapper(_tool_call: ToolCall) -> ToolMessage | Command:
|
||||
# Outer may have modified the tool_call, but we ignore it here
|
||||
# Inner receives the original request and base execute
|
||||
return inner(request, execute)
|
||||
|
||||
# Outer can call call_inner multiple times
|
||||
return outer(request, call_inner)
|
||||
# Call outer handler with request and the wrapper that invokes inner
|
||||
return outer(request, call_inner_wrapper)
|
||||
|
||||
return composed
|
||||
|
||||
@@ -567,7 +571,7 @@ def create_agent( # noqa: PLR0915
|
||||
|
||||
# Only create ToolNode if we have client-side tools
|
||||
tool_node = (
|
||||
ToolNode(tools=available_tools, on_tool_call=wrap_tool_call_handler)
|
||||
ToolNode(tools=available_tools, wrap_tool_call=wrap_tool_call_handler)
|
||||
if available_tools
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
|
||||
from langchain.tools.tool_node import ToolCallRequest
|
||||
|
||||
# needed as top level import for pydantic schema generation on AgentState
|
||||
from langchain_core.messages import AIMessage, AnyMessage, ToolMessage # noqa: TC002
|
||||
from langchain_core.messages import AIMessage, AnyMessage, ToolCall, ToolMessage # noqa: TC002
|
||||
from langgraph.channels.ephemeral_value import EphemeralValue
|
||||
from langgraph.channels.untracked_value import UntrackedValue
|
||||
from langgraph.graph.message import add_messages
|
||||
@@ -264,7 +264,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept tool execution for retries, monitoring, or modification.
|
||||
|
||||
@@ -279,22 +279,24 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
Returns:
|
||||
ToolMessage or Command (the final result).
|
||||
|
||||
The handler callable can be invoked multiple times for retry logic.
|
||||
Each call to handler is independent and stateless.
|
||||
The handler callable can be invoked multiple times for retry logic
|
||||
with potentially modified tool calls. Each call to handler is independent
|
||||
and stateless.
|
||||
|
||||
Examples:
|
||||
Modify request before execution:
|
||||
|
||||
def wrap_tool_call(self, request, handler):
|
||||
request.tool_call["args"]["value"] *= 2
|
||||
return handler(request)
|
||||
modified_call = request.tool_call.copy()
|
||||
modified_call["args"]["value"] *= 2
|
||||
return handler(modified_call)
|
||||
|
||||
Retry on error (call handler multiple times):
|
||||
|
||||
def wrap_tool_call(self, request, handler):
|
||||
for attempt in range(3):
|
||||
try:
|
||||
result = handler(request)
|
||||
result = handler(request.tool_call)
|
||||
if is_valid(result):
|
||||
return result
|
||||
except Exception:
|
||||
@@ -306,7 +308,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
|
||||
def wrap_tool_call(self, request, handler):
|
||||
for attempt in range(3):
|
||||
result = handler(request)
|
||||
result = handler(request.tool_call)
|
||||
if isinstance(result, ToolMessage) and result.status != "error":
|
||||
return result
|
||||
if attempt < 2:
|
||||
@@ -353,12 +355,13 @@ class _CallableReturningToolResponse(Protocol):
|
||||
"""Callable for tool call interception with handler callback.
|
||||
|
||||
Receives handler callback to execute tool and returns final ToolMessage or Command.
|
||||
Handler takes ToolCall dict and returns ToolMessage or Command.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept tool execution via handler callback."""
|
||||
...
|
||||
@@ -1267,7 +1270,7 @@ def wrap_tool_call(
|
||||
```python
|
||||
@wrap_tool_call
|
||||
def passthrough(request, handler):
|
||||
return handler(request)
|
||||
return handler(request.tool_call)
|
||||
```
|
||||
|
||||
Retry logic:
|
||||
@@ -1277,7 +1280,7 @@ def wrap_tool_call(
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return handler(request)
|
||||
return handler(request.tool_call)
|
||||
except Exception:
|
||||
if attempt == max_retries - 1:
|
||||
raise
|
||||
@@ -1287,8 +1290,9 @@ def wrap_tool_call(
|
||||
```python
|
||||
@wrap_tool_call
|
||||
def modify_args(request, handler):
|
||||
request.tool_call["args"]["value"] *= 2
|
||||
return handler(request)
|
||||
modified_call = request.tool_call.copy()
|
||||
modified_call["args"]["value"] *= 2
|
||||
return handler(modified_call)
|
||||
```
|
||||
|
||||
Short-circuit with cached result:
|
||||
@@ -1297,7 +1301,7 @@ def wrap_tool_call(
|
||||
def with_cache(request, handler):
|
||||
if cached := get_cache(request):
|
||||
return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
|
||||
result = handler(request)
|
||||
result = handler(request.tool_call)
|
||||
save_cache(request, result)
|
||||
return result
|
||||
```
|
||||
@@ -1309,7 +1313,7 @@ def wrap_tool_call(
|
||||
def wrapped(
|
||||
self: AgentMiddleware, # noqa: ARG001
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
return func(request, handler)
|
||||
|
||||
|
||||
@@ -120,21 +120,22 @@ class ToolCallRequest:
|
||||
runtime: Any
|
||||
|
||||
|
||||
ToolCallHandler = Callable[
|
||||
[ToolCallRequest, Callable[[ToolCallRequest], ToolMessage | Command]],
|
||||
ToolCallWrapper = Callable[
|
||||
[ToolCallRequest, Callable[[ToolCall], ToolMessage | Command]],
|
||||
ToolMessage | Command,
|
||||
]
|
||||
"""Handler-based tool call interceptor with multi-call support.
|
||||
"""Wrapper-based tool call interceptor with multi-call support.
|
||||
|
||||
Handler receives:
|
||||
Wrapper receives:
|
||||
request: ToolCallRequest with tool_call, tool, state, and runtime.
|
||||
execute: Callable to execute the tool (CAN BE CALLED MULTIPLE TIMES).
|
||||
Takes a ToolCall dict and returns ToolMessage or Command.
|
||||
|
||||
Returns:
|
||||
ToolMessage or Command (the final result).
|
||||
|
||||
The execute callable can be invoked multiple times for retry logic,
|
||||
with potentially modified requests each time. Each call to execute
|
||||
with potentially modified tool calls each time. Each call to execute
|
||||
is independent and stateless.
|
||||
|
||||
Note:
|
||||
@@ -145,21 +146,22 @@ Note:
|
||||
Examples:
|
||||
Passthrough (execute once):
|
||||
|
||||
def handler(request, execute):
|
||||
return execute(request)
|
||||
def wrapper(request, execute):
|
||||
return execute(request.tool_call)
|
||||
|
||||
Modify request before execution:
|
||||
|
||||
def handler(request, execute):
|
||||
request.tool_call["args"]["value"] *= 2
|
||||
return execute(request)
|
||||
def wrapper(request, execute):
|
||||
modified_call = request.tool_call.copy()
|
||||
modified_call["args"]["value"] *= 2
|
||||
return execute(modified_call)
|
||||
|
||||
Retry on error (execute multiple times):
|
||||
|
||||
def handler(request, execute):
|
||||
def wrapper(request, execute):
|
||||
for attempt in range(3):
|
||||
try:
|
||||
result = execute(request)
|
||||
result = execute(request.tool_call)
|
||||
if is_valid(result):
|
||||
return result
|
||||
except Exception:
|
||||
@@ -169,9 +171,9 @@ Examples:
|
||||
|
||||
Conditional retry based on response:
|
||||
|
||||
def handler(request, execute):
|
||||
def wrapper(request, execute):
|
||||
for attempt in range(3):
|
||||
result = execute(request)
|
||||
result = execute(request.tool_call)
|
||||
if isinstance(result, ToolMessage) and result.status != "error":
|
||||
return result
|
||||
if attempt < 2:
|
||||
@@ -180,10 +182,10 @@ Examples:
|
||||
|
||||
Cache/short-circuit without calling execute:
|
||||
|
||||
def handler(request, execute):
|
||||
def wrapper(request, execute):
|
||||
if cached := get_cache(request):
|
||||
return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
|
||||
result = execute(request)
|
||||
result = execute(request.tool_call)
|
||||
save_cache(request, result)
|
||||
return result
|
||||
"""
|
||||
@@ -499,7 +501,7 @@ class ToolNode(RunnableCallable):
|
||||
| type[Exception]
|
||||
| tuple[type[Exception], ...] = _default_handle_tool_errors,
|
||||
messages_key: str = "messages",
|
||||
on_tool_call: ToolCallHandler | None = None,
|
||||
wrap_tool_call: ToolCallWrapper | None = None,
|
||||
) -> None:
|
||||
"""Initialize ToolNode with tools and configuration.
|
||||
|
||||
@@ -509,10 +511,10 @@ class ToolNode(RunnableCallable):
|
||||
tags: Optional metadata tags.
|
||||
handle_tool_errors: Error handling configuration.
|
||||
messages_key: State key containing messages.
|
||||
on_tool_call: Generator handler to intercept tool execution. Receives
|
||||
ToolCallRequest, yields requests, messages, or Commands; receives
|
||||
ToolMessage or Command via .send(). Final result is last value sent to
|
||||
handler. Enables retries, caching, request modification, and control flow.
|
||||
wrap_tool_call: Wrapper to intercept tool execution. Receives ToolCallRequest
|
||||
and execute callable. Execute callable takes ToolCall dict and returns
|
||||
ToolMessage or Command. Enables retries, caching, request modification,
|
||||
and control flow by calling execute multiple times with modified tool calls.
|
||||
"""
|
||||
super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False)
|
||||
self._tools_by_name: dict[str, BaseTool] = {}
|
||||
@@ -520,7 +522,7 @@ class ToolNode(RunnableCallable):
|
||||
self._tool_to_store_arg: dict[str, str | None] = {}
|
||||
self._handle_tool_errors = handle_tool_errors
|
||||
self._messages_key = messages_key
|
||||
self._on_tool_call = on_tool_call
|
||||
self._wrap_tool_call = wrap_tool_call
|
||||
for tool in tools:
|
||||
if not isinstance(tool, BaseTool):
|
||||
tool_ = create_tool(cast("type[BaseTool]", tool))
|
||||
@@ -627,14 +629,16 @@ class ToolNode(RunnableCallable):
|
||||
|
||||
def _execute_tool_sync(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
call: ToolCall,
|
||||
tool: BaseTool,
|
||||
input_type: Literal["list", "dict", "tool_calls"],
|
||||
config: RunnableConfig,
|
||||
) -> ToolMessage | Command:
|
||||
"""Execute tool call with configured error handling.
|
||||
|
||||
Args:
|
||||
request: Tool execution request.
|
||||
call: Tool call dict with name, args, and id.
|
||||
tool: BaseTool instance to invoke.
|
||||
input_type: Input format.
|
||||
config: Runnable configuration.
|
||||
|
||||
@@ -644,8 +648,6 @@ class ToolNode(RunnableCallable):
|
||||
Raises:
|
||||
Exception: If tool fails and handle_tool_errors is False.
|
||||
"""
|
||||
call = request.tool_call
|
||||
tool = request.tool
|
||||
call_args = {**call, "type": "tool_call"}
|
||||
|
||||
try:
|
||||
@@ -698,7 +700,7 @@ class ToolNode(RunnableCallable):
|
||||
# Process successful response
|
||||
if isinstance(response, Command):
|
||||
# Validate Command before returning to handler
|
||||
return self._validate_tool_command(response, request.tool_call, input_type)
|
||||
return self._validate_tool_command(response, call, input_type)
|
||||
if isinstance(response, ToolMessage):
|
||||
response.content = cast("str | list", msg_content_output(response.content))
|
||||
return response
|
||||
@@ -714,7 +716,7 @@ class ToolNode(RunnableCallable):
|
||||
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
||||
runtime: Any,
|
||||
) -> ToolMessage | Command:
|
||||
"""Execute single tool call with on_tool_call handler if configured.
|
||||
"""Execute single tool call with wrap_tool_call wrapper if configured.
|
||||
|
||||
Args:
|
||||
call: Tool call dict.
|
||||
@@ -742,18 +744,18 @@ class ToolNode(RunnableCallable):
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
if self._on_tool_call is None:
|
||||
if self._wrap_tool_call is None:
|
||||
# No handler - execute directly
|
||||
return self._execute_tool_sync(tool_request, input_type, config)
|
||||
return self._execute_tool_sync(call, tool, input_type, config)
|
||||
|
||||
# Define execute callable that can be called multiple times
|
||||
def execute(req: ToolCallRequest) -> ToolMessage | Command:
|
||||
"""Execute tool with given request. Can be called multiple times."""
|
||||
return self._execute_tool_sync(req, input_type, config)
|
||||
def execute(tool_call: ToolCall) -> ToolMessage | Command:
|
||||
"""Execute tool with given tool call. Can be called multiple times."""
|
||||
return self._execute_tool_sync(tool_call, tool, input_type, config)
|
||||
|
||||
# Call handler with request and execute callable
|
||||
try:
|
||||
return self._on_tool_call(tool_request, execute)
|
||||
return self._wrap_tool_call(tool_request, execute)
|
||||
except Exception as e:
|
||||
# Handler threw an exception
|
||||
if not self._handle_tool_errors:
|
||||
@@ -769,14 +771,16 @@ class ToolNode(RunnableCallable):
|
||||
|
||||
async def _execute_tool_async(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
call: ToolCall,
|
||||
tool: BaseTool,
|
||||
input_type: Literal["list", "dict", "tool_calls"],
|
||||
config: RunnableConfig,
|
||||
) -> ToolMessage | Command:
|
||||
"""Execute tool call asynchronously with configured error handling.
|
||||
|
||||
Args:
|
||||
request: Tool execution request.
|
||||
call: Tool call dict with name, args, and id.
|
||||
tool: BaseTool instance to invoke.
|
||||
input_type: Input format.
|
||||
config: Runnable configuration.
|
||||
|
||||
@@ -786,8 +790,6 @@ class ToolNode(RunnableCallable):
|
||||
Raises:
|
||||
Exception: If tool fails and handle_tool_errors is False.
|
||||
"""
|
||||
call = request.tool_call
|
||||
tool = request.tool
|
||||
call_args = {**call, "type": "tool_call"}
|
||||
|
||||
try:
|
||||
@@ -840,7 +842,7 @@ class ToolNode(RunnableCallable):
|
||||
# Process successful response
|
||||
if isinstance(response, Command):
|
||||
# Validate Command before returning to handler
|
||||
return self._validate_tool_command(response, request.tool_call, input_type)
|
||||
return self._validate_tool_command(response, call, input_type)
|
||||
if isinstance(response, ToolMessage):
|
||||
response.content = cast("str | list", msg_content_output(response.content))
|
||||
return response
|
||||
@@ -856,7 +858,7 @@ class ToolNode(RunnableCallable):
|
||||
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
||||
runtime: Any,
|
||||
) -> ToolMessage | Command:
|
||||
"""Execute single tool call asynchronously with on_tool_call handler if configured.
|
||||
"""Execute single tool call asynchronously with wrap_tool_call wrapper if configured.
|
||||
|
||||
Args:
|
||||
call: Tool call dict.
|
||||
@@ -884,19 +886,19 @@ class ToolNode(RunnableCallable):
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
if self._on_tool_call is None:
|
||||
if self._wrap_tool_call is None:
|
||||
# No handler - execute directly
|
||||
return await self._execute_tool_async(tool_request, input_type, config)
|
||||
return await self._execute_tool_async(call, tool, input_type, config)
|
||||
|
||||
# Define async execute callable that can be called multiple times
|
||||
async def execute(req: ToolCallRequest) -> ToolMessage | Command:
|
||||
"""Execute tool with given request. Can be called multiple times."""
|
||||
return await self._execute_tool_async(req, input_type, config)
|
||||
async def execute(tool_call: ToolCall) -> ToolMessage | Command:
|
||||
"""Execute tool with given tool call. Can be called multiple times."""
|
||||
return await self._execute_tool_async(tool_call, tool, input_type, config)
|
||||
|
||||
# Call handler with request and execute callable
|
||||
# Note: handler is sync, but execute callable is async
|
||||
try:
|
||||
result = self._on_tool_call(tool_request, execute) # type: ignore[arg-type]
|
||||
result = self._wrap_tool_call(tool_request, execute) # type: ignore[arg-type]
|
||||
# If result is a coroutine, await it (though handler should be sync)
|
||||
return await result if hasattr(result, "__await__") else result
|
||||
except Exception as e:
|
||||
@@ -972,7 +974,7 @@ class ToolNode(RunnableCallable):
|
||||
input: The input which may be raw state or ToolCallWithContext.
|
||||
|
||||
Returns:
|
||||
The actual state to pass to on_tool_call handlers.
|
||||
The actual state to pass to wrap_tool_call wrappers.
|
||||
"""
|
||||
if isinstance(input, dict) and input.get("__type") == "tool_call_with_context":
|
||||
return input["state"]
|
||||
|
||||
@@ -44,7 +44,7 @@ def test_wrap_tool_call_basic_passthrough() -> None:
|
||||
@wrap_tool_call
|
||||
def passthrough(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
call_log.append("called")
|
||||
return handler(request)
|
||||
return handler(request.tool_call)
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
@@ -79,7 +79,7 @@ def test_wrap_tool_call_logging() -> None:
|
||||
@wrap_tool_call
|
||||
def logging_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
call_log.append(f"before_{request.tool.name}")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append(f"after_{request.tool.name}")
|
||||
return response
|
||||
|
||||
@@ -115,7 +115,7 @@ def test_wrap_tool_call_modify_args() -> None:
|
||||
# Modify the query argument before execution
|
||||
if request.tool.name == "search":
|
||||
request.tool_call["args"]["query"] = "modified query"
|
||||
return handler(request)
|
||||
return handler(request.tool_call)
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
@@ -151,7 +151,7 @@ def test_wrap_tool_call_access_state() -> None:
|
||||
if request.state is not None:
|
||||
messages = request.state.get("messages", [])
|
||||
state_data.append(len(messages))
|
||||
return handler(request)
|
||||
return handler(request.tool_call)
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
@@ -187,7 +187,7 @@ def test_wrap_tool_call_access_runtime() -> None:
|
||||
if request.runtime is not None:
|
||||
# Runtime object is available (has context, store, stream_writer, previous)
|
||||
runtime_data.append(type(request.runtime).__name__)
|
||||
return handler(request)
|
||||
return handler(request.tool_call)
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
@@ -224,7 +224,7 @@ def test_wrap_tool_call_retry_on_error() -> None:
|
||||
for attempt in range(max_retries):
|
||||
attempt_counts.append(attempt)
|
||||
try:
|
||||
return handler(request)
|
||||
return handler(request.tool_call)
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
if attempt == max_retries - 1:
|
||||
@@ -316,7 +316,7 @@ def test_wrap_tool_call_response_modification() -> None:
|
||||
|
||||
@wrap_tool_call
|
||||
def modify_response(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
|
||||
# Modify the response
|
||||
if isinstance(response, ToolMessage):
|
||||
@@ -359,14 +359,14 @@ def test_wrap_tool_call_multiple_middleware_composition() -> None:
|
||||
@wrap_tool_call
|
||||
def outer_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
call_log.append("outer_before")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("outer_after")
|
||||
return response
|
||||
|
||||
@wrap_tool_call
|
||||
def inner_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
call_log.append("inner_before")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("inner_after")
|
||||
return response
|
||||
|
||||
@@ -403,7 +403,7 @@ def test_wrap_tool_call_multiple_tools() -> None:
|
||||
@wrap_tool_call
|
||||
def log_tool_calls(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
call_log.append(request.tool.name)
|
||||
return handler(request)
|
||||
return handler(request.tool_call)
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
@@ -441,7 +441,7 @@ def test_wrap_tool_call_with_custom_name() -> None:
|
||||
|
||||
@wrap_tool_call(name="CustomToolWrapper")
|
||||
def my_wrapper(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
return handler(request)
|
||||
return handler(request.tool_call)
|
||||
|
||||
# Verify custom name was applied
|
||||
assert my_wrapper.__class__.__name__ == "CustomToolWrapper"
|
||||
@@ -457,7 +457,7 @@ def test_wrap_tool_call_with_tools_parameter() -> None:
|
||||
|
||||
@wrap_tool_call(tools=[extra_tool])
|
||||
def wrapper_with_tools(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
return handler(request)
|
||||
return handler(request.tool_call)
|
||||
|
||||
# Verify tools were registered
|
||||
assert wrapper_with_tools.tools == [extra_tool]
|
||||
@@ -470,21 +470,21 @@ def test_wrap_tool_call_three_levels_composition() -> None:
|
||||
@wrap_tool_call(name="OuterWrapper")
|
||||
def outer(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
call_log.append("outer_before")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("outer_after")
|
||||
return response
|
||||
|
||||
@wrap_tool_call(name="MiddleWrapper")
|
||||
def middle(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
call_log.append("middle_before")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("middle_after")
|
||||
return response
|
||||
|
||||
@wrap_tool_call(name="InnerWrapper")
|
||||
def inner(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
call_log.append("inner_before")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("inner_after")
|
||||
return response
|
||||
|
||||
@@ -528,7 +528,7 @@ def test_wrap_tool_call_outer_intercepts_inner() -> None:
|
||||
@wrap_tool_call(name="InterceptingOuter")
|
||||
def intercepting_outer(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
call_log.append("outer_before")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("outer_after")
|
||||
|
||||
# Return modified message
|
||||
@@ -541,7 +541,7 @@ def test_wrap_tool_call_outer_intercepts_inner() -> None:
|
||||
@wrap_tool_call(name="InnerWrapper")
|
||||
def inner(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
call_log.append("inner_called")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("inner_got_response")
|
||||
return response
|
||||
|
||||
@@ -584,7 +584,7 @@ def test_wrap_tool_call_inner_short_circuits() -> None:
|
||||
@wrap_tool_call(name="OuterWrapper")
|
||||
def outer(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
call_log.append("outer_before")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("outer_after")
|
||||
|
||||
# Wrap inner's response
|
||||
@@ -640,7 +640,7 @@ def test_wrap_tool_call_mixed_passthrough_and_intercepting() -> None:
|
||||
@wrap_tool_call(name="FirstPassthrough")
|
||||
def first_passthrough(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
call_log.append("first_before")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("first_after")
|
||||
return response
|
||||
|
||||
@@ -648,7 +648,7 @@ def test_wrap_tool_call_mixed_passthrough_and_intercepting() -> None:
|
||||
def second_intercepting(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
call_log.append("second_intercept")
|
||||
# Call handler but ignore result
|
||||
_ = handler(request)
|
||||
_ = handler(request.tool_call)
|
||||
# Return custom result
|
||||
return ToolMessage(
|
||||
content="intercepted_result",
|
||||
@@ -659,7 +659,7 @@ def test_wrap_tool_call_mixed_passthrough_and_intercepting() -> None:
|
||||
@wrap_tool_call(name="ThirdPassthrough")
|
||||
def third_passthrough(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
call_log.append("third_called")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("third_after")
|
||||
return response
|
||||
|
||||
@@ -701,7 +701,7 @@ def test_wrap_tool_call_uses_function_name_as_default() -> None:
|
||||
|
||||
@wrap_tool_call
|
||||
def my_custom_wrapper(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
||||
return handler(request)
|
||||
return handler(request.tool_call)
|
||||
|
||||
# Verify that function name is used as middleware class name
|
||||
assert my_custom_wrapper.__class__.__name__ == "my_custom_wrapper"
|
||||
@@ -727,7 +727,7 @@ def test_wrap_tool_call_caching_pattern() -> None:
|
||||
|
||||
# Execute tool and cache result
|
||||
handler_calls.append("executed")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
|
||||
if isinstance(response, ToolMessage):
|
||||
cache[cache_key] = response.content
|
||||
@@ -771,7 +771,7 @@ def test_wrap_tool_call_monitoring_pattern() -> None:
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
metrics.append(
|
||||
|
||||
@@ -43,10 +43,10 @@ def test_simple_logging_middleware() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
call_log.append(f"before_{request.tool.name}")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append(f"after_{request.tool.name}")
|
||||
return response
|
||||
|
||||
@@ -84,13 +84,13 @@ def test_request_modification_middleware() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
# Add prefix to query
|
||||
if request.tool.name == "search":
|
||||
original_query = request.tool_call["args"]["query"]
|
||||
request.tool_call["args"]["query"] = f"modified: {original_query}"
|
||||
return handler(request)
|
||||
return handler(request.tool_call)
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
@@ -126,9 +126,9 @@ def test_response_inspection_middleware() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
|
||||
# Record response details
|
||||
if isinstance(response, ToolMessage):
|
||||
@@ -176,13 +176,13 @@ def test_conditional_retry_middleware() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
nonlocal call_count
|
||||
max_retries = 2
|
||||
|
||||
for attempt in range(max_retries):
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_count += 1
|
||||
|
||||
# Check if we should retry based on content
|
||||
@@ -235,10 +235,10 @@ def test_multiple_middleware_composition() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
call_log.append("outer_before")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("outer_after")
|
||||
return response
|
||||
|
||||
@@ -248,10 +248,10 @@ def test_multiple_middleware_composition() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
call_log.append("inner_before")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("inner_after")
|
||||
return response
|
||||
|
||||
@@ -290,10 +290,10 @@ def test_middleware_with_multiple_tool_calls() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
call_log.append(request.tool.name)
|
||||
return handler(request)
|
||||
return handler(request.tool_call)
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
@@ -336,7 +336,7 @@ def test_middleware_access_to_state() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
# Record state - state is now in request.state
|
||||
state = request.state
|
||||
@@ -347,7 +347,7 @@ def test_middleware_access_to_state() -> None:
|
||||
state_seen.append(("list", len(state)))
|
||||
else:
|
||||
state_seen.append(("other", type(state).__name__))
|
||||
return handler(request)
|
||||
return handler(request.tool_call)
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
@@ -416,11 +416,11 @@ def test_generator_composition_immediate_outer_return() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
call_log.append("outer_yield")
|
||||
# Call handler, receive response from inner
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("outer_got_response")
|
||||
# Return modified message
|
||||
modified = ToolMessage(
|
||||
@@ -436,10 +436,10 @@ def test_generator_composition_immediate_outer_return() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
call_log.append("inner_called")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("inner_got_response")
|
||||
return response
|
||||
|
||||
@@ -480,10 +480,10 @@ def test_generator_composition_short_circuit() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
call_log.append("outer_before")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("outer_after")
|
||||
# Modify response from inner
|
||||
if isinstance(response, ToolMessage):
|
||||
@@ -501,7 +501,7 @@ def test_generator_composition_short_circuit() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
call_log.append("inner_short_circuit")
|
||||
# Don't call handler, return custom response directly
|
||||
@@ -548,11 +548,11 @@ def test_generator_composition_nested_retries() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
for outer_attempt in range(2):
|
||||
call_log.append(f"outer_{outer_attempt}")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
|
||||
if isinstance(response, ToolMessage) and response.content == "inner_final_failure":
|
||||
# Inner failed, retry once
|
||||
@@ -566,11 +566,11 @@ def test_generator_composition_nested_retries() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
for inner_attempt in range(2):
|
||||
call_log.append(f"inner_{inner_attempt}")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
|
||||
# Check for error in tool result
|
||||
if isinstance(response, ToolMessage):
|
||||
@@ -625,10 +625,10 @@ def test_generator_composition_three_levels() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
call_log.append("outer_before")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("outer_after")
|
||||
return response
|
||||
|
||||
@@ -638,10 +638,10 @@ def test_generator_composition_three_levels() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
call_log.append("middle_before")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("middle_after")
|
||||
return response
|
||||
|
||||
@@ -651,10 +651,10 @@ def test_generator_composition_three_levels() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
call_log.append("inner_before")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("inner_after")
|
||||
return response
|
||||
|
||||
@@ -701,9 +701,9 @@ def test_generator_composition_return_value_extraction() -> None:
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
|
||||
# Return a modified response
|
||||
if isinstance(response, ToolMessage):
|
||||
@@ -753,10 +753,10 @@ def test_generator_composition_with_mixed_passthrough_and_intercepting() -> None
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
call_log.append("first_before")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("first_after")
|
||||
return response
|
||||
|
||||
@@ -766,11 +766,11 @@ def test_generator_composition_with_mixed_passthrough_and_intercepting() -> None
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
call_log.append("second_intercept")
|
||||
# Call handler but ignore the result
|
||||
_ = handler(request)
|
||||
_ = handler(request.tool_call)
|
||||
# Return custom result
|
||||
return ToolMessage(
|
||||
content="intercepted_result",
|
||||
@@ -784,10 +784,10 @@ def test_generator_composition_with_mixed_passthrough_and_intercepting() -> None
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
call_log.append("third_called")
|
||||
response = handler(request)
|
||||
response = handler(request.tool_call)
|
||||
call_log.append("third_after")
|
||||
return response
|
||||
|
||||
|
||||
@@ -39,12 +39,12 @@ def test_passthrough_handler() -> None:
|
||||
|
||||
def passthrough_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Simple passthrough handler."""
|
||||
return execute(request)
|
||||
return execute(request.tool_call)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=passthrough_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=passthrough_handler)
|
||||
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
@@ -75,12 +75,12 @@ async def test_passthrough_handler_async() -> None:
|
||||
|
||||
def passthrough_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Simple passthrough handler."""
|
||||
return execute(request)
|
||||
return execute(request.tool_call)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=passthrough_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=passthrough_handler)
|
||||
|
||||
result = await tool_node.ainvoke(
|
||||
{
|
||||
@@ -110,16 +110,18 @@ def test_modify_arguments() -> None:
|
||||
|
||||
def modify_args_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that doubles the input arguments."""
|
||||
# Modify the arguments
|
||||
request.tool_call["args"]["a"] *= 2
|
||||
request.tool_call["args"]["b"] *= 2
|
||||
# Modify the arguments by creating a modified tool call
|
||||
modified_call = request.tool_call.copy()
|
||||
modified_call["args"] = modified_call["args"].copy()
|
||||
modified_call["args"]["a"] *= 2
|
||||
modified_call["args"]["b"] *= 2
|
||||
|
||||
return execute(request)
|
||||
return execute(modified_call)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=modify_args_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=modify_args_handler)
|
||||
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
@@ -149,12 +151,12 @@ def test_handler_validation_no_return() -> None:
|
||||
|
||||
def handler_with_explicit_none(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that executes and returns result."""
|
||||
return execute(request)
|
||||
return execute(request.tool_call)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=handler_with_explicit_none)
|
||||
tool_node = ToolNode([add], wrap_tool_call=handler_with_explicit_none)
|
||||
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
@@ -185,13 +187,13 @@ def test_handler_validation_no_yield() -> None:
|
||||
|
||||
def bad_handler(
|
||||
_request: ToolCallRequest,
|
||||
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
_execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that doesn't call execute - will cause type error."""
|
||||
# Don't call execute, just return None (invalid)
|
||||
return None # type: ignore[return-value]
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=bad_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=bad_handler)
|
||||
|
||||
# This will return None wrapped in messages
|
||||
result = tool_node.invoke(
|
||||
@@ -217,20 +219,22 @@ def test_handler_validation_no_yield() -> None:
|
||||
|
||||
|
||||
def test_handler_with_handle_tool_errors_true() -> None:
|
||||
"""Test that handle_tool_errors=True works with on_tool_call handler."""
|
||||
"""Test that handle_tool_errors=True works with wrap_tool_call handler."""
|
||||
|
||||
def passthrough_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Simple passthrough handler."""
|
||||
message = execute(request)
|
||||
message = execute(request.tool_call)
|
||||
# When handle_tool_errors=True, errors should be converted to error messages
|
||||
assert isinstance(message, ToolMessage)
|
||||
assert message.status == "error"
|
||||
return message
|
||||
|
||||
tool_node = ToolNode([failing_tool], on_tool_call=passthrough_handler, handle_tool_errors=True)
|
||||
tool_node = ToolNode(
|
||||
[failing_tool], wrap_tool_call=passthrough_handler, handle_tool_errors=True
|
||||
)
|
||||
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
@@ -260,14 +264,14 @@ def test_multiple_tool_calls_with_handler() -> None:
|
||||
|
||||
def counting_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that counts calls."""
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return execute(request)
|
||||
return execute(request.tool_call)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=counting_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=counting_handler)
|
||||
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
@@ -333,15 +337,17 @@ async def test_handler_with_async_execution() -> None:
|
||||
|
||||
def modifying_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that modifies arguments."""
|
||||
# Add 10 to both arguments
|
||||
request.tool_call["args"]["a"] += 10
|
||||
request.tool_call["args"]["b"] += 10
|
||||
return execute(request)
|
||||
# Add 10 to both arguments by creating a modified tool call
|
||||
modified_call = request.tool_call.copy()
|
||||
modified_call["args"] = modified_call["args"].copy()
|
||||
modified_call["args"]["a"] += 10
|
||||
modified_call["args"]["b"] += 10
|
||||
return execute(modified_call)
|
||||
|
||||
tool_node = ToolNode([async_add], on_tool_call=modifying_handler)
|
||||
tool_node = ToolNode([async_add], wrap_tool_call=modifying_handler)
|
||||
|
||||
result = await tool_node.ainvoke(
|
||||
{
|
||||
@@ -371,7 +377,7 @@ def test_short_circuit_with_tool_message() -> None:
|
||||
|
||||
def short_circuit_handler(
|
||||
request: ToolCallRequest,
|
||||
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
_execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that returns cached result without executing tool."""
|
||||
# Return a ToolMessage directly instead of calling execute
|
||||
@@ -381,7 +387,7 @@ def test_short_circuit_with_tool_message() -> None:
|
||||
name=request.tool_call["name"],
|
||||
)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=short_circuit_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=short_circuit_handler)
|
||||
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
@@ -412,7 +418,7 @@ async def test_short_circuit_with_tool_message_async() -> None:
|
||||
|
||||
def short_circuit_handler(
|
||||
request: ToolCallRequest,
|
||||
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
_execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that returns cached result without executing tool."""
|
||||
return ToolMessage(
|
||||
@@ -421,7 +427,7 @@ async def test_short_circuit_with_tool_message_async() -> None:
|
||||
name=request.tool_call["name"],
|
||||
)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=short_circuit_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=short_circuit_handler)
|
||||
|
||||
result = await tool_node.ainvoke(
|
||||
{
|
||||
@@ -452,7 +458,7 @@ def test_conditional_short_circuit() -> None:
|
||||
|
||||
def conditional_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that caches even numbers, executes odd."""
|
||||
call_count["count"] += 1
|
||||
@@ -466,9 +472,9 @@ def test_conditional_short_circuit() -> None:
|
||||
name=request.tool_call["name"],
|
||||
)
|
||||
# Odd: execute normally
|
||||
return execute(request)
|
||||
return execute(request.tool_call)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=conditional_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=conditional_handler)
|
||||
|
||||
# Test with even number (should be cached)
|
||||
result1 = tool_node.invoke(
|
||||
@@ -518,7 +524,7 @@ def test_direct_return_tool_message() -> None:
|
||||
|
||||
def direct_return_handler(
|
||||
request: ToolCallRequest,
|
||||
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
_execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that returns ToolMessage directly."""
|
||||
# Return ToolMessage directly instead of calling execute
|
||||
@@ -528,7 +534,7 @@ def test_direct_return_tool_message() -> None:
|
||||
name=request.tool_call["name"],
|
||||
)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=direct_return_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=direct_return_handler)
|
||||
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
@@ -559,7 +565,7 @@ async def test_direct_return_tool_message_async() -> None:
|
||||
|
||||
def direct_return_handler(
|
||||
request: ToolCallRequest,
|
||||
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
_execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that returns ToolMessage directly."""
|
||||
return ToolMessage(
|
||||
@@ -568,7 +574,7 @@ async def test_direct_return_tool_message_async() -> None:
|
||||
name=request.tool_call["name"],
|
||||
)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=direct_return_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=direct_return_handler)
|
||||
|
||||
result = await tool_node.ainvoke(
|
||||
{
|
||||
@@ -598,7 +604,7 @@ def test_conditional_direct_return() -> None:
|
||||
|
||||
def conditional_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that returns cached or executes based on condition."""
|
||||
a = request.tool_call["args"]["a"]
|
||||
@@ -611,9 +617,9 @@ def test_conditional_direct_return() -> None:
|
||||
name=request.tool_call["name"],
|
||||
)
|
||||
# Execute tool normally
|
||||
return execute(request)
|
||||
return execute(request.tool_call)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=conditional_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=conditional_handler)
|
||||
|
||||
# Test with zero (should return directly)
|
||||
result1 = tool_node.invoke(
|
||||
@@ -663,17 +669,17 @@ def test_handler_can_throw_exception() -> None:
|
||||
|
||||
def throwing_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that throws an exception after receiving response."""
|
||||
response = execute(request)
|
||||
response = execute(request.tool_call)
|
||||
# Check response and throw if invalid
|
||||
if isinstance(response, ToolMessage):
|
||||
msg = "Handler rejected the response"
|
||||
raise TypeError(msg)
|
||||
return response
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=throwing_handler, handle_tool_errors=True)
|
||||
tool_node = ToolNode([add], wrap_tool_call=throwing_handler, handle_tool_errors=True)
|
||||
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
@@ -705,14 +711,14 @@ def test_handler_throw_without_handle_errors() -> None:
|
||||
|
||||
def throwing_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that throws an exception."""
|
||||
execute(request)
|
||||
execute(request.tool_call)
|
||||
msg = "Handler error"
|
||||
raise ValueError(msg)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=throwing_handler, handle_tool_errors=False)
|
||||
tool_node = ToolNode([add], wrap_tool_call=throwing_handler, handle_tool_errors=False)
|
||||
|
||||
with pytest.raises(ValueError, match="Handler error"):
|
||||
tool_node.invoke(
|
||||
@@ -739,14 +745,14 @@ def test_retry_middleware_with_exception() -> None:
|
||||
|
||||
def retry_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that can retry by calling execute multiple times."""
|
||||
max_retries = 3
|
||||
|
||||
for _attempt in range(max_retries):
|
||||
attempt_count["count"] += 1
|
||||
response = execute(request)
|
||||
response = execute(request.tool_call)
|
||||
|
||||
# Simulate checking for retriable errors
|
||||
# In real use case, would check response.status or content
|
||||
@@ -757,7 +763,7 @@ def test_retry_middleware_with_exception() -> None:
|
||||
# If we exhausted retries, return last response
|
||||
return response
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=retry_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=retry_handler)
|
||||
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
@@ -789,14 +795,14 @@ async def test_async_handler_can_throw_exception() -> None:
|
||||
|
||||
def throwing_handler(
|
||||
_request: ToolCallRequest,
|
||||
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
_execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that throws an exception before calling execute."""
|
||||
# Throw exception before executing (to avoid async/await complications)
|
||||
msg = "Async handler rejected the request"
|
||||
raise ValueError(msg)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=throwing_handler, handle_tool_errors=True)
|
||||
tool_node = ToolNode([add], wrap_tool_call=throwing_handler, handle_tool_errors=True)
|
||||
|
||||
result = await tool_node.ainvoke(
|
||||
{
|
||||
@@ -831,12 +837,12 @@ def test_handler_cannot_yield_multiple_tool_messages() -> None:
|
||||
|
||||
def single_return_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that returns once (as all handlers do)."""
|
||||
return execute(request)
|
||||
return execute(request.tool_call)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=single_return_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=single_return_handler)
|
||||
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
@@ -867,13 +873,13 @@ def test_handler_cannot_yield_request_after_tool_message() -> None:
|
||||
|
||||
def single_return_handler(
|
||||
request: ToolCallRequest,
|
||||
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
_execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that returns cached result."""
|
||||
# Return cached result (short-circuit)
|
||||
return ToolMessage("cached", tool_call_id=request.tool_call["id"], name="add")
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=single_return_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=single_return_handler)
|
||||
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
@@ -902,13 +908,13 @@ def test_handler_can_short_circuit_with_command() -> None:
|
||||
|
||||
def command_handler(
|
||||
_request: ToolCallRequest,
|
||||
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
_execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that short-circuits with Command."""
|
||||
# Short-circuit with Command instead of executing tool
|
||||
return Command(goto="end")
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=command_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=command_handler)
|
||||
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
@@ -941,12 +947,12 @@ def test_handler_cannot_yield_multiple_commands() -> None:
|
||||
|
||||
def single_command_handler(
|
||||
_request: ToolCallRequest,
|
||||
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
_execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that returns Command once."""
|
||||
return Command(goto="step1")
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=single_command_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=single_command_handler)
|
||||
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
@@ -979,12 +985,12 @@ def test_handler_cannot_yield_request_after_command() -> None:
|
||||
|
||||
def command_handler(
|
||||
_request: ToolCallRequest,
|
||||
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
_execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that returns Command."""
|
||||
return Command(goto="somewhere")
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=command_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=command_handler)
|
||||
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
@@ -1016,16 +1022,16 @@ def test_tool_returning_command_sent_to_handler() -> None:
|
||||
|
||||
def command_inspector_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that inspects Command returned by tool."""
|
||||
result = execute(request)
|
||||
result = execute(request.tool_call)
|
||||
# Should receive Command from tool
|
||||
if isinstance(result, Command):
|
||||
received_commands.append(result)
|
||||
return result
|
||||
|
||||
tool_node = ToolNode([command_tool], on_tool_call=command_inspector_handler)
|
||||
tool_node = ToolNode([command_tool], wrap_tool_call=command_inspector_handler)
|
||||
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
@@ -1060,16 +1066,16 @@ def test_handler_can_modify_command_from_tool() -> None:
|
||||
|
||||
def command_modifier_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that modifies Command returned by tool."""
|
||||
result = execute(request)
|
||||
result = execute(request.tool_call)
|
||||
# Modify the Command
|
||||
if isinstance(result, Command):
|
||||
return Command(goto=f"modified_{result.goto}")
|
||||
return result
|
||||
|
||||
tool_node = ToolNode([command_tool], on_tool_call=command_modifier_handler)
|
||||
tool_node = ToolNode([command_tool], wrap_tool_call=command_modifier_handler)
|
||||
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
@@ -1101,13 +1107,13 @@ def test_state_extraction_with_dict_input() -> None:
|
||||
|
||||
def state_inspector_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that records the state it receives."""
|
||||
state_seen.append(request.state)
|
||||
return execute(request)
|
||||
return execute(request.tool_call)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=state_inspector_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=state_inspector_handler)
|
||||
|
||||
input_state = {
|
||||
"messages": [
|
||||
@@ -1136,13 +1142,13 @@ def test_state_extraction_with_list_input() -> None:
|
||||
|
||||
def state_inspector_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that records the state it receives."""
|
||||
state_seen.append(request.state)
|
||||
return execute(request)
|
||||
return execute(request.tool_call)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=state_inspector_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=state_inspector_handler)
|
||||
|
||||
input_state = [
|
||||
AIMessage(
|
||||
@@ -1170,13 +1176,13 @@ def test_state_extraction_with_tool_call_with_context() -> None:
|
||||
|
||||
def state_inspector_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that records the state it receives."""
|
||||
state_seen.append(request.state)
|
||||
return execute(request)
|
||||
return execute(request.tool_call)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=state_inspector_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=state_inspector_handler)
|
||||
|
||||
# Simulate ToolCallWithContext as used by create_agent with Send API
|
||||
actual_state = {
|
||||
@@ -1213,13 +1219,13 @@ async def test_state_extraction_with_tool_call_with_context_async() -> None:
|
||||
|
||||
def state_inspector_handler(
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
execute: Callable[[ToolCall], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Handler that records the state it receives."""
|
||||
state_seen.append(request.state)
|
||||
return execute(request)
|
||||
return execute(request.tool_call)
|
||||
|
||||
tool_node = ToolNode([add], on_tool_call=state_inspector_handler)
|
||||
tool_node = ToolNode([add], wrap_tool_call=state_inspector_handler)
|
||||
|
||||
# Simulate ToolCallWithContext as used by create_agent with Send API
|
||||
actual_state = {
|
||||
|
||||
Reference in New Issue
Block a user