Compare commits

...

2 Commits

Author SHA1 Message Date
Eugene Yurtsev
a5c882e484 x 2025-10-09 16:07:38 -04:00
Eugene Yurtsev
dcb954d03c x 2025-10-09 15:49:10 -04:00
6 changed files with 239 additions and 223 deletions

View File

@@ -18,7 +18,7 @@ if TYPE_CHECKING:
from collections.abc import Awaitable
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolCall, ToolMessage
from langchain_core.tools import BaseTool
from langgraph._internal._runnable import RunnableCallable
from langgraph.constants import END, START
@@ -59,7 +59,7 @@ if TYPE_CHECKING:
from langgraph.store.base import BaseStore
from langgraph.types import Checkpointer
from langchain.tools.tool_node import ToolCallHandler, ToolCallRequest
from langchain.tools.tool_node import ToolCallRequest, ToolCallWrapper
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
@@ -373,8 +373,8 @@ def _handle_structured_output_error(
def _chain_tool_call_handlers(
handlers: Sequence[ToolCallHandler],
) -> ToolCallHandler | None:
handlers: Sequence[ToolCallWrapper],
) -> ToolCallWrapper | None:
"""Compose handlers into middleware stack (first = outermost).
Args:
@@ -394,19 +394,23 @@ def _chain_tool_call_handlers(
if len(handlers) == 1:
return handlers[0]
def compose_two(outer: ToolCallHandler, inner: ToolCallHandler) -> ToolCallHandler:
def compose_two(outer: ToolCallWrapper, inner: ToolCallWrapper) -> ToolCallWrapper:
"""Compose two handlers where outer wraps inner."""
def composed(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
# Create a callable that invokes inner with the original execute
def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
return inner(req, execute)
# Create a wrapper that calls inner when outer calls its execute callable
# When outer calls this wrapper with a ToolCall, we invoke inner
# which receives the request and the base execute callable
def call_inner_wrapper(_tool_call: ToolCall) -> ToolMessage | Command:
# Outer may have modified the tool_call, but we ignore it here
# Inner receives the original request and base execute
return inner(request, execute)
# Outer can call call_inner multiple times
return outer(request, call_inner)
# Call outer handler with request and the wrapper that invokes inner
return outer(request, call_inner_wrapper)
return composed
@@ -567,7 +571,7 @@ def create_agent( # noqa: PLR0915
# Only create ToolNode if we have client-side tools
tool_node = (
ToolNode(tools=available_tools, on_tool_call=wrap_tool_call_handler)
ToolNode(tools=available_tools, wrap_tool_call=wrap_tool_call_handler)
if available_tools
else None
)

View File

@@ -22,7 +22,7 @@ if TYPE_CHECKING:
from langchain.tools.tool_node import ToolCallRequest
# needed as top level import for pydantic schema generation on AgentState
from langchain_core.messages import AIMessage, AnyMessage, ToolMessage # noqa: TC002
from langchain_core.messages import AIMessage, AnyMessage, ToolCall, ToolMessage # noqa: TC002
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.channels.untracked_value import UntrackedValue
from langgraph.graph.message import add_messages
@@ -264,7 +264,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept tool execution for retries, monitoring, or modification.
@@ -279,22 +279,24 @@ class AgentMiddleware(Generic[StateT, ContextT]):
Returns:
ToolMessage or Command (the final result).
The handler callable can be invoked multiple times for retry logic.
Each call to handler is independent and stateless.
The handler callable can be invoked multiple times for retry logic
with potentially modified tool calls. Each call to handler is independent
and stateless.
Examples:
Modify request before execution:
def wrap_tool_call(self, request, handler):
request.tool_call["args"]["value"] *= 2
return handler(request)
modified_call = request.tool_call.copy()
modified_call["args"]["value"] *= 2
return handler(modified_call)
Retry on error (call handler multiple times):
def wrap_tool_call(self, request, handler):
for attempt in range(3):
try:
result = handler(request)
result = handler(request.tool_call)
if is_valid(result):
return result
except Exception:
@@ -306,7 +308,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
def wrap_tool_call(self, request, handler):
for attempt in range(3):
result = handler(request)
result = handler(request.tool_call)
if isinstance(result, ToolMessage) and result.status != "error":
return result
if attempt < 2:
@@ -353,12 +355,13 @@ class _CallableReturningToolResponse(Protocol):
"""Callable for tool call interception with handler callback.
Receives handler callback to execute tool and returns final ToolMessage or Command.
Handler takes ToolCall dict and returns ToolMessage or Command.
"""
def __call__(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept tool execution via handler callback."""
...
@@ -1267,7 +1270,7 @@ def wrap_tool_call(
```python
@wrap_tool_call
def passthrough(request, handler):
return handler(request)
return handler(request.tool_call)
```
Retry logic:
@@ -1277,7 +1280,7 @@ def wrap_tool_call(
max_retries = 3
for attempt in range(max_retries):
try:
return handler(request)
return handler(request.tool_call)
except Exception:
if attempt == max_retries - 1:
raise
@@ -1287,8 +1290,9 @@ def wrap_tool_call(
```python
@wrap_tool_call
def modify_args(request, handler):
request.tool_call["args"]["value"] *= 2
return handler(request)
modified_call = request.tool_call.copy()
modified_call["args"]["value"] *= 2
return handler(modified_call)
```
Short-circuit with cached result:
@@ -1297,7 +1301,7 @@ def wrap_tool_call(
def with_cache(request, handler):
if cached := get_cache(request):
return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
result = handler(request)
result = handler(request.tool_call)
save_cache(request, result)
return result
```
@@ -1309,7 +1313,7 @@ def wrap_tool_call(
def wrapped(
self: AgentMiddleware, # noqa: ARG001
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
return func(request, handler)

View File

@@ -120,21 +120,22 @@ class ToolCallRequest:
runtime: Any
ToolCallHandler = Callable[
[ToolCallRequest, Callable[[ToolCallRequest], ToolMessage | Command]],
ToolCallWrapper = Callable[
[ToolCallRequest, Callable[[ToolCall], ToolMessage | Command]],
ToolMessage | Command,
]
"""Handler-based tool call interceptor with multi-call support.
"""Wrapper-based tool call interceptor with multi-call support.
Handler receives:
Wrapper receives:
request: ToolCallRequest with tool_call, tool, state, and runtime.
execute: Callable to execute the tool (CAN BE CALLED MULTIPLE TIMES).
Takes a ToolCall dict and returns ToolMessage or Command.
Returns:
ToolMessage or Command (the final result).
The execute callable can be invoked multiple times for retry logic,
with potentially modified requests each time. Each call to execute
with potentially modified tool calls each time. Each call to execute
is independent and stateless.
Note:
@@ -145,21 +146,22 @@ Note:
Examples:
Passthrough (execute once):
def handler(request, execute):
return execute(request)
def wrapper(request, execute):
return execute(request.tool_call)
Modify request before execution:
def handler(request, execute):
request.tool_call["args"]["value"] *= 2
return execute(request)
def wrapper(request, execute):
modified_call = request.tool_call.copy()
modified_call["args"]["value"] *= 2
return execute(modified_call)
Retry on error (execute multiple times):
def handler(request, execute):
def wrapper(request, execute):
for attempt in range(3):
try:
result = execute(request)
result = execute(request.tool_call)
if is_valid(result):
return result
except Exception:
@@ -169,9 +171,9 @@ Examples:
Conditional retry based on response:
def handler(request, execute):
def wrapper(request, execute):
for attempt in range(3):
result = execute(request)
result = execute(request.tool_call)
if isinstance(result, ToolMessage) and result.status != "error":
return result
if attempt < 2:
@@ -180,10 +182,10 @@ Examples:
Cache/short-circuit without calling execute:
def handler(request, execute):
def wrapper(request, execute):
if cached := get_cache(request):
return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
result = execute(request)
result = execute(request.tool_call)
save_cache(request, result)
return result
"""
@@ -499,7 +501,7 @@ class ToolNode(RunnableCallable):
| type[Exception]
| tuple[type[Exception], ...] = _default_handle_tool_errors,
messages_key: str = "messages",
on_tool_call: ToolCallHandler | None = None,
wrap_tool_call: ToolCallWrapper | None = None,
) -> None:
"""Initialize ToolNode with tools and configuration.
@@ -509,10 +511,10 @@ class ToolNode(RunnableCallable):
tags: Optional metadata tags.
handle_tool_errors: Error handling configuration.
messages_key: State key containing messages.
on_tool_call: Generator handler to intercept tool execution. Receives
ToolCallRequest, yields requests, messages, or Commands; receives
ToolMessage or Command via .send(). Final result is last value sent to
handler. Enables retries, caching, request modification, and control flow.
wrap_tool_call: Wrapper to intercept tool execution. Receives ToolCallRequest
and execute callable. Execute callable takes ToolCall dict and returns
ToolMessage or Command. Enables retries, caching, request modification,
and control flow by calling execute multiple times with modified tool calls.
"""
super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False)
self._tools_by_name: dict[str, BaseTool] = {}
@@ -520,7 +522,7 @@ class ToolNode(RunnableCallable):
self._tool_to_store_arg: dict[str, str | None] = {}
self._handle_tool_errors = handle_tool_errors
self._messages_key = messages_key
self._on_tool_call = on_tool_call
self._wrap_tool_call = wrap_tool_call
for tool in tools:
if not isinstance(tool, BaseTool):
tool_ = create_tool(cast("type[BaseTool]", tool))
@@ -627,14 +629,16 @@ class ToolNode(RunnableCallable):
def _execute_tool_sync(
self,
request: ToolCallRequest,
call: ToolCall,
tool: BaseTool,
input_type: Literal["list", "dict", "tool_calls"],
config: RunnableConfig,
) -> ToolMessage | Command:
"""Execute tool call with configured error handling.
Args:
request: Tool execution request.
call: Tool call dict with name, args, and id.
tool: BaseTool instance to invoke.
input_type: Input format.
config: Runnable configuration.
@@ -644,8 +648,6 @@ class ToolNode(RunnableCallable):
Raises:
Exception: If tool fails and handle_tool_errors is False.
"""
call = request.tool_call
tool = request.tool
call_args = {**call, "type": "tool_call"}
try:
@@ -698,7 +700,7 @@ class ToolNode(RunnableCallable):
# Process successful response
if isinstance(response, Command):
# Validate Command before returning to handler
return self._validate_tool_command(response, request.tool_call, input_type)
return self._validate_tool_command(response, call, input_type)
if isinstance(response, ToolMessage):
response.content = cast("str | list", msg_content_output(response.content))
return response
@@ -714,7 +716,7 @@ class ToolNode(RunnableCallable):
input: list[AnyMessage] | dict[str, Any] | BaseModel,
runtime: Any,
) -> ToolMessage | Command:
"""Execute single tool call with on_tool_call handler if configured.
"""Execute single tool call with wrap_tool_call wrapper if configured.
Args:
call: Tool call dict.
@@ -742,18 +744,18 @@ class ToolNode(RunnableCallable):
runtime=runtime,
)
if self._on_tool_call is None:
if self._wrap_tool_call is None:
# No handler - execute directly
return self._execute_tool_sync(tool_request, input_type, config)
return self._execute_tool_sync(call, tool, input_type, config)
# Define execute callable that can be called multiple times
def execute(req: ToolCallRequest) -> ToolMessage | Command:
"""Execute tool with given request. Can be called multiple times."""
return self._execute_tool_sync(req, input_type, config)
def execute(tool_call: ToolCall) -> ToolMessage | Command:
"""Execute tool with given tool call. Can be called multiple times."""
return self._execute_tool_sync(tool_call, tool, input_type, config)
# Call handler with request and execute callable
try:
return self._on_tool_call(tool_request, execute)
return self._wrap_tool_call(tool_request, execute)
except Exception as e:
# Handler threw an exception
if not self._handle_tool_errors:
@@ -769,14 +771,16 @@ class ToolNode(RunnableCallable):
async def _execute_tool_async(
self,
request: ToolCallRequest,
call: ToolCall,
tool: BaseTool,
input_type: Literal["list", "dict", "tool_calls"],
config: RunnableConfig,
) -> ToolMessage | Command:
"""Execute tool call asynchronously with configured error handling.
Args:
request: Tool execution request.
call: Tool call dict with name, args, and id.
tool: BaseTool instance to invoke.
input_type: Input format.
config: Runnable configuration.
@@ -786,8 +790,6 @@ class ToolNode(RunnableCallable):
Raises:
Exception: If tool fails and handle_tool_errors is False.
"""
call = request.tool_call
tool = request.tool
call_args = {**call, "type": "tool_call"}
try:
@@ -840,7 +842,7 @@ class ToolNode(RunnableCallable):
# Process successful response
if isinstance(response, Command):
# Validate Command before returning to handler
return self._validate_tool_command(response, request.tool_call, input_type)
return self._validate_tool_command(response, call, input_type)
if isinstance(response, ToolMessage):
response.content = cast("str | list", msg_content_output(response.content))
return response
@@ -856,7 +858,7 @@ class ToolNode(RunnableCallable):
input: list[AnyMessage] | dict[str, Any] | BaseModel,
runtime: Any,
) -> ToolMessage | Command:
"""Execute single tool call asynchronously with on_tool_call handler if configured.
"""Execute single tool call asynchronously with wrap_tool_call wrapper if configured.
Args:
call: Tool call dict.
@@ -884,19 +886,19 @@ class ToolNode(RunnableCallable):
runtime=runtime,
)
if self._on_tool_call is None:
if self._wrap_tool_call is None:
# No handler - execute directly
return await self._execute_tool_async(tool_request, input_type, config)
return await self._execute_tool_async(call, tool, input_type, config)
# Define async execute callable that can be called multiple times
async def execute(req: ToolCallRequest) -> ToolMessage | Command:
"""Execute tool with given request. Can be called multiple times."""
return await self._execute_tool_async(req, input_type, config)
async def execute(tool_call: ToolCall) -> ToolMessage | Command:
"""Execute tool with given tool call. Can be called multiple times."""
return await self._execute_tool_async(tool_call, tool, input_type, config)
# Call handler with request and execute callable
# Note: handler is sync, but execute callable is async
try:
result = self._on_tool_call(tool_request, execute) # type: ignore[arg-type]
result = self._wrap_tool_call(tool_request, execute) # type: ignore[arg-type]
# If result is a coroutine, await it (though handler should be sync)
return await result if hasattr(result, "__await__") else result
except Exception as e:
@@ -972,7 +974,7 @@ class ToolNode(RunnableCallable):
input: The input which may be raw state or ToolCallWithContext.
Returns:
The actual state to pass to on_tool_call handlers.
The actual state to pass to wrap_tool_call wrappers.
"""
if isinstance(input, dict) and input.get("__type") == "tool_call_with_context":
return input["state"]

View File

@@ -44,7 +44,7 @@ def test_wrap_tool_call_basic_passthrough() -> None:
@wrap_tool_call
def passthrough(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
call_log.append("called")
return handler(request)
return handler(request.tool_call)
model = FakeToolCallingModel(
tool_calls=[
@@ -79,7 +79,7 @@ def test_wrap_tool_call_logging() -> None:
@wrap_tool_call
def logging_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
call_log.append(f"before_{request.tool.name}")
response = handler(request)
response = handler(request.tool_call)
call_log.append(f"after_{request.tool.name}")
return response
@@ -115,7 +115,7 @@ def test_wrap_tool_call_modify_args() -> None:
# Modify the query argument before execution
if request.tool.name == "search":
request.tool_call["args"]["query"] = "modified query"
return handler(request)
return handler(request.tool_call)
model = FakeToolCallingModel(
tool_calls=[
@@ -151,7 +151,7 @@ def test_wrap_tool_call_access_state() -> None:
if request.state is not None:
messages = request.state.get("messages", [])
state_data.append(len(messages))
return handler(request)
return handler(request.tool_call)
model = FakeToolCallingModel(
tool_calls=[
@@ -187,7 +187,7 @@ def test_wrap_tool_call_access_runtime() -> None:
if request.runtime is not None:
# Runtime object is available (has context, store, stream_writer, previous)
runtime_data.append(type(request.runtime).__name__)
return handler(request)
return handler(request.tool_call)
model = FakeToolCallingModel(
tool_calls=[
@@ -224,7 +224,7 @@ def test_wrap_tool_call_retry_on_error() -> None:
for attempt in range(max_retries):
attempt_counts.append(attempt)
try:
return handler(request)
return handler(request.tool_call)
except Exception as e:
last_error = e
if attempt == max_retries - 1:
@@ -316,7 +316,7 @@ def test_wrap_tool_call_response_modification() -> None:
@wrap_tool_call
def modify_response(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
response = handler(request)
response = handler(request.tool_call)
# Modify the response
if isinstance(response, ToolMessage):
@@ -359,14 +359,14 @@ def test_wrap_tool_call_multiple_middleware_composition() -> None:
@wrap_tool_call
def outer_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
call_log.append("outer_before")
response = handler(request)
response = handler(request.tool_call)
call_log.append("outer_after")
return response
@wrap_tool_call
def inner_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
call_log.append("inner_before")
response = handler(request)
response = handler(request.tool_call)
call_log.append("inner_after")
return response
@@ -403,7 +403,7 @@ def test_wrap_tool_call_multiple_tools() -> None:
@wrap_tool_call
def log_tool_calls(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
call_log.append(request.tool.name)
return handler(request)
return handler(request.tool_call)
model = FakeToolCallingModel(
tool_calls=[
@@ -441,7 +441,7 @@ def test_wrap_tool_call_with_custom_name() -> None:
@wrap_tool_call(name="CustomToolWrapper")
def my_wrapper(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
return handler(request)
return handler(request.tool_call)
# Verify custom name was applied
assert my_wrapper.__class__.__name__ == "CustomToolWrapper"
@@ -457,7 +457,7 @@ def test_wrap_tool_call_with_tools_parameter() -> None:
@wrap_tool_call(tools=[extra_tool])
def wrapper_with_tools(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
return handler(request)
return handler(request.tool_call)
# Verify tools were registered
assert wrapper_with_tools.tools == [extra_tool]
@@ -470,21 +470,21 @@ def test_wrap_tool_call_three_levels_composition() -> None:
@wrap_tool_call(name="OuterWrapper")
def outer(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
call_log.append("outer_before")
response = handler(request)
response = handler(request.tool_call)
call_log.append("outer_after")
return response
@wrap_tool_call(name="MiddleWrapper")
def middle(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
call_log.append("middle_before")
response = handler(request)
response = handler(request.tool_call)
call_log.append("middle_after")
return response
@wrap_tool_call(name="InnerWrapper")
def inner(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
call_log.append("inner_before")
response = handler(request)
response = handler(request.tool_call)
call_log.append("inner_after")
return response
@@ -528,7 +528,7 @@ def test_wrap_tool_call_outer_intercepts_inner() -> None:
@wrap_tool_call(name="InterceptingOuter")
def intercepting_outer(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
call_log.append("outer_before")
response = handler(request)
response = handler(request.tool_call)
call_log.append("outer_after")
# Return modified message
@@ -541,7 +541,7 @@ def test_wrap_tool_call_outer_intercepts_inner() -> None:
@wrap_tool_call(name="InnerWrapper")
def inner(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
call_log.append("inner_called")
response = handler(request)
response = handler(request.tool_call)
call_log.append("inner_got_response")
return response
@@ -584,7 +584,7 @@ def test_wrap_tool_call_inner_short_circuits() -> None:
@wrap_tool_call(name="OuterWrapper")
def outer(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
call_log.append("outer_before")
response = handler(request)
response = handler(request.tool_call)
call_log.append("outer_after")
# Wrap inner's response
@@ -640,7 +640,7 @@ def test_wrap_tool_call_mixed_passthrough_and_intercepting() -> None:
@wrap_tool_call(name="FirstPassthrough")
def first_passthrough(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
call_log.append("first_before")
response = handler(request)
response = handler(request.tool_call)
call_log.append("first_after")
return response
@@ -648,7 +648,7 @@ def test_wrap_tool_call_mixed_passthrough_and_intercepting() -> None:
def second_intercepting(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
call_log.append("second_intercept")
# Call handler but ignore result
_ = handler(request)
_ = handler(request.tool_call)
# Return custom result
return ToolMessage(
content="intercepted_result",
@@ -659,7 +659,7 @@ def test_wrap_tool_call_mixed_passthrough_and_intercepting() -> None:
@wrap_tool_call(name="ThirdPassthrough")
def third_passthrough(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
call_log.append("third_called")
response = handler(request)
response = handler(request.tool_call)
call_log.append("third_after")
return response
@@ -701,7 +701,7 @@ def test_wrap_tool_call_uses_function_name_as_default() -> None:
@wrap_tool_call
def my_custom_wrapper(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
return handler(request)
return handler(request.tool_call)
# Verify that function name is used as middleware class name
assert my_custom_wrapper.__class__.__name__ == "my_custom_wrapper"
@@ -727,7 +727,7 @@ def test_wrap_tool_call_caching_pattern() -> None:
# Execute tool and cache result
handler_calls.append("executed")
response = handler(request)
response = handler(request.tool_call)
if isinstance(response, ToolMessage):
cache[cache_key] = response.content
@@ -771,7 +771,7 @@ def test_wrap_tool_call_monitoring_pattern() -> None:
import time
start_time = time.time()
response = handler(request)
response = handler(request.tool_call)
execution_time = time.time() - start_time
metrics.append(

View File

@@ -43,10 +43,10 @@ def test_simple_logging_middleware() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
call_log.append(f"before_{request.tool.name}")
response = handler(request)
response = handler(request.tool_call)
call_log.append(f"after_{request.tool.name}")
return response
@@ -84,13 +84,13 @@ def test_request_modification_middleware() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
# Add prefix to query
if request.tool.name == "search":
original_query = request.tool_call["args"]["query"]
request.tool_call["args"]["query"] = f"modified: {original_query}"
return handler(request)
return handler(request.tool_call)
model = FakeToolCallingModel(
tool_calls=[
@@ -126,9 +126,9 @@ def test_response_inspection_middleware() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
response = handler(request)
response = handler(request.tool_call)
# Record response details
if isinstance(response, ToolMessage):
@@ -176,13 +176,13 @@ def test_conditional_retry_middleware() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
nonlocal call_count
max_retries = 2
for attempt in range(max_retries):
response = handler(request)
response = handler(request.tool_call)
call_count += 1
# Check if we should retry based on content
@@ -235,10 +235,10 @@ def test_multiple_middleware_composition() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
call_log.append("outer_before")
response = handler(request)
response = handler(request.tool_call)
call_log.append("outer_after")
return response
@@ -248,10 +248,10 @@ def test_multiple_middleware_composition() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
call_log.append("inner_before")
response = handler(request)
response = handler(request.tool_call)
call_log.append("inner_after")
return response
@@ -290,10 +290,10 @@ def test_middleware_with_multiple_tool_calls() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
call_log.append(request.tool.name)
return handler(request)
return handler(request.tool_call)
model = FakeToolCallingModel(
tool_calls=[
@@ -336,7 +336,7 @@ def test_middleware_access_to_state() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
# Record state - state is now in request.state
state = request.state
@@ -347,7 +347,7 @@ def test_middleware_access_to_state() -> None:
state_seen.append(("list", len(state)))
else:
state_seen.append(("other", type(state).__name__))
return handler(request)
return handler(request.tool_call)
model = FakeToolCallingModel(
tool_calls=[
@@ -416,11 +416,11 @@ def test_generator_composition_immediate_outer_return() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
call_log.append("outer_yield")
# Call handler, receive response from inner
response = handler(request)
response = handler(request.tool_call)
call_log.append("outer_got_response")
# Return modified message
modified = ToolMessage(
@@ -436,10 +436,10 @@ def test_generator_composition_immediate_outer_return() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
call_log.append("inner_called")
response = handler(request)
response = handler(request.tool_call)
call_log.append("inner_got_response")
return response
@@ -480,10 +480,10 @@ def test_generator_composition_short_circuit() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
call_log.append("outer_before")
response = handler(request)
response = handler(request.tool_call)
call_log.append("outer_after")
# Modify response from inner
if isinstance(response, ToolMessage):
@@ -501,7 +501,7 @@ def test_generator_composition_short_circuit() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
call_log.append("inner_short_circuit")
# Don't call handler, return custom response directly
@@ -548,11 +548,11 @@ def test_generator_composition_nested_retries() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
for outer_attempt in range(2):
call_log.append(f"outer_{outer_attempt}")
response = handler(request)
response = handler(request.tool_call)
if isinstance(response, ToolMessage) and response.content == "inner_final_failure":
# Inner failed, retry once
@@ -566,11 +566,11 @@ def test_generator_composition_nested_retries() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
for inner_attempt in range(2):
call_log.append(f"inner_{inner_attempt}")
response = handler(request)
response = handler(request.tool_call)
# Check for error in tool result
if isinstance(response, ToolMessage):
@@ -625,10 +625,10 @@ def test_generator_composition_three_levels() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
call_log.append("outer_before")
response = handler(request)
response = handler(request.tool_call)
call_log.append("outer_after")
return response
@@ -638,10 +638,10 @@ def test_generator_composition_three_levels() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
call_log.append("middle_before")
response = handler(request)
response = handler(request.tool_call)
call_log.append("middle_after")
return response
@@ -651,10 +651,10 @@ def test_generator_composition_three_levels() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
call_log.append("inner_before")
response = handler(request)
response = handler(request.tool_call)
call_log.append("inner_after")
return response
@@ -701,9 +701,9 @@ def test_generator_composition_return_value_extraction() -> None:
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
response = handler(request)
response = handler(request.tool_call)
# Return a modified response
if isinstance(response, ToolMessage):
@@ -753,10 +753,10 @@ def test_generator_composition_with_mixed_passthrough_and_intercepting() -> None
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
call_log.append("first_before")
response = handler(request)
response = handler(request.tool_call)
call_log.append("first_after")
return response
@@ -766,11 +766,11 @@ def test_generator_composition_with_mixed_passthrough_and_intercepting() -> None
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
call_log.append("second_intercept")
# Call handler but ignore the result
_ = handler(request)
_ = handler(request.tool_call)
# Return custom result
return ToolMessage(
content="intercepted_result",
@@ -784,10 +784,10 @@ def test_generator_composition_with_mixed_passthrough_and_intercepting() -> None
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
call_log.append("third_called")
response = handler(request)
response = handler(request.tool_call)
call_log.append("third_after")
return response

View File

@@ -39,12 +39,12 @@ def test_passthrough_handler() -> None:
def passthrough_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Simple passthrough handler."""
return execute(request)
return execute(request.tool_call)
tool_node = ToolNode([add], on_tool_call=passthrough_handler)
tool_node = ToolNode([add], wrap_tool_call=passthrough_handler)
result = tool_node.invoke(
{
@@ -75,12 +75,12 @@ async def test_passthrough_handler_async() -> None:
def passthrough_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Simple passthrough handler."""
return execute(request)
return execute(request.tool_call)
tool_node = ToolNode([add], on_tool_call=passthrough_handler)
tool_node = ToolNode([add], wrap_tool_call=passthrough_handler)
result = await tool_node.ainvoke(
{
@@ -110,16 +110,18 @@ def test_modify_arguments() -> None:
def modify_args_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that doubles the input arguments."""
# Modify the arguments
request.tool_call["args"]["a"] *= 2
request.tool_call["args"]["b"] *= 2
# Modify the arguments by creating a modified tool call
modified_call = request.tool_call.copy()
modified_call["args"] = modified_call["args"].copy()
modified_call["args"]["a"] *= 2
modified_call["args"]["b"] *= 2
return execute(request)
return execute(modified_call)
tool_node = ToolNode([add], on_tool_call=modify_args_handler)
tool_node = ToolNode([add], wrap_tool_call=modify_args_handler)
result = tool_node.invoke(
{
@@ -149,12 +151,12 @@ def test_handler_validation_no_return() -> None:
def handler_with_explicit_none(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that executes and returns result."""
return execute(request)
return execute(request.tool_call)
tool_node = ToolNode([add], on_tool_call=handler_with_explicit_none)
tool_node = ToolNode([add], wrap_tool_call=handler_with_explicit_none)
result = tool_node.invoke(
{
@@ -185,13 +187,13 @@ def test_handler_validation_no_yield() -> None:
def bad_handler(
_request: ToolCallRequest,
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
_execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that doesn't call execute - will cause type error."""
# Don't call execute, just return None (invalid)
return None # type: ignore[return-value]
tool_node = ToolNode([add], on_tool_call=bad_handler)
tool_node = ToolNode([add], wrap_tool_call=bad_handler)
# This will return None wrapped in messages
result = tool_node.invoke(
@@ -217,20 +219,22 @@ def test_handler_validation_no_yield() -> None:
def test_handler_with_handle_tool_errors_true() -> None:
"""Test that handle_tool_errors=True works with on_tool_call handler."""
"""Test that handle_tool_errors=True works with wrap_tool_call handler."""
def passthrough_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Simple passthrough handler."""
message = execute(request)
message = execute(request.tool_call)
# When handle_tool_errors=True, errors should be converted to error messages
assert isinstance(message, ToolMessage)
assert message.status == "error"
return message
tool_node = ToolNode([failing_tool], on_tool_call=passthrough_handler, handle_tool_errors=True)
tool_node = ToolNode(
[failing_tool], wrap_tool_call=passthrough_handler, handle_tool_errors=True
)
result = tool_node.invoke(
{
@@ -260,14 +264,14 @@ def test_multiple_tool_calls_with_handler() -> None:
def counting_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that counts calls."""
nonlocal call_count
call_count += 1
return execute(request)
return execute(request.tool_call)
tool_node = ToolNode([add], on_tool_call=counting_handler)
tool_node = ToolNode([add], wrap_tool_call=counting_handler)
result = tool_node.invoke(
{
@@ -333,15 +337,17 @@ async def test_handler_with_async_execution() -> None:
def modifying_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that modifies arguments."""
# Add 10 to both arguments
request.tool_call["args"]["a"] += 10
request.tool_call["args"]["b"] += 10
return execute(request)
# Add 10 to both arguments by creating a modified tool call
modified_call = request.tool_call.copy()
modified_call["args"] = modified_call["args"].copy()
modified_call["args"]["a"] += 10
modified_call["args"]["b"] += 10
return execute(modified_call)
tool_node = ToolNode([async_add], on_tool_call=modifying_handler)
tool_node = ToolNode([async_add], wrap_tool_call=modifying_handler)
result = await tool_node.ainvoke(
{
@@ -371,7 +377,7 @@ def test_short_circuit_with_tool_message() -> None:
def short_circuit_handler(
request: ToolCallRequest,
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
_execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that returns cached result without executing tool."""
# Return a ToolMessage directly instead of calling execute
@@ -381,7 +387,7 @@ def test_short_circuit_with_tool_message() -> None:
name=request.tool_call["name"],
)
tool_node = ToolNode([add], on_tool_call=short_circuit_handler)
tool_node = ToolNode([add], wrap_tool_call=short_circuit_handler)
result = tool_node.invoke(
{
@@ -412,7 +418,7 @@ async def test_short_circuit_with_tool_message_async() -> None:
def short_circuit_handler(
request: ToolCallRequest,
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
_execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that returns cached result without executing tool."""
return ToolMessage(
@@ -421,7 +427,7 @@ async def test_short_circuit_with_tool_message_async() -> None:
name=request.tool_call["name"],
)
tool_node = ToolNode([add], on_tool_call=short_circuit_handler)
tool_node = ToolNode([add], wrap_tool_call=short_circuit_handler)
result = await tool_node.ainvoke(
{
@@ -452,7 +458,7 @@ def test_conditional_short_circuit() -> None:
def conditional_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that caches even numbers, executes odd."""
call_count["count"] += 1
@@ -466,9 +472,9 @@ def test_conditional_short_circuit() -> None:
name=request.tool_call["name"],
)
# Odd: execute normally
return execute(request)
return execute(request.tool_call)
tool_node = ToolNode([add], on_tool_call=conditional_handler)
tool_node = ToolNode([add], wrap_tool_call=conditional_handler)
# Test with even number (should be cached)
result1 = tool_node.invoke(
@@ -518,7 +524,7 @@ def test_direct_return_tool_message() -> None:
def direct_return_handler(
request: ToolCallRequest,
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
_execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that returns ToolMessage directly."""
# Return ToolMessage directly instead of calling execute
@@ -528,7 +534,7 @@ def test_direct_return_tool_message() -> None:
name=request.tool_call["name"],
)
tool_node = ToolNode([add], on_tool_call=direct_return_handler)
tool_node = ToolNode([add], wrap_tool_call=direct_return_handler)
result = tool_node.invoke(
{
@@ -559,7 +565,7 @@ async def test_direct_return_tool_message_async() -> None:
def direct_return_handler(
request: ToolCallRequest,
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
_execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that returns ToolMessage directly."""
return ToolMessage(
@@ -568,7 +574,7 @@ async def test_direct_return_tool_message_async() -> None:
name=request.tool_call["name"],
)
tool_node = ToolNode([add], on_tool_call=direct_return_handler)
tool_node = ToolNode([add], wrap_tool_call=direct_return_handler)
result = await tool_node.ainvoke(
{
@@ -598,7 +604,7 @@ def test_conditional_direct_return() -> None:
def conditional_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that returns cached or executes based on condition."""
a = request.tool_call["args"]["a"]
@@ -611,9 +617,9 @@ def test_conditional_direct_return() -> None:
name=request.tool_call["name"],
)
# Execute tool normally
return execute(request)
return execute(request.tool_call)
tool_node = ToolNode([add], on_tool_call=conditional_handler)
tool_node = ToolNode([add], wrap_tool_call=conditional_handler)
# Test with zero (should return directly)
result1 = tool_node.invoke(
@@ -663,17 +669,17 @@ def test_handler_can_throw_exception() -> None:
def throwing_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that throws an exception after receiving response."""
response = execute(request)
response = execute(request.tool_call)
# Check response and throw if invalid
if isinstance(response, ToolMessage):
msg = "Handler rejected the response"
raise TypeError(msg)
return response
tool_node = ToolNode([add], on_tool_call=throwing_handler, handle_tool_errors=True)
tool_node = ToolNode([add], wrap_tool_call=throwing_handler, handle_tool_errors=True)
result = tool_node.invoke(
{
@@ -705,14 +711,14 @@ def test_handler_throw_without_handle_errors() -> None:
def throwing_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that throws an exception."""
execute(request)
execute(request.tool_call)
msg = "Handler error"
raise ValueError(msg)
tool_node = ToolNode([add], on_tool_call=throwing_handler, handle_tool_errors=False)
tool_node = ToolNode([add], wrap_tool_call=throwing_handler, handle_tool_errors=False)
with pytest.raises(ValueError, match="Handler error"):
tool_node.invoke(
@@ -739,14 +745,14 @@ def test_retry_middleware_with_exception() -> None:
def retry_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that can retry by calling execute multiple times."""
max_retries = 3
for _attempt in range(max_retries):
attempt_count["count"] += 1
response = execute(request)
response = execute(request.tool_call)
# Simulate checking for retriable errors
# In real use case, would check response.status or content
@@ -757,7 +763,7 @@ def test_retry_middleware_with_exception() -> None:
# If we exhausted retries, return last response
return response
tool_node = ToolNode([add], on_tool_call=retry_handler)
tool_node = ToolNode([add], wrap_tool_call=retry_handler)
result = tool_node.invoke(
{
@@ -789,14 +795,14 @@ async def test_async_handler_can_throw_exception() -> None:
def throwing_handler(
_request: ToolCallRequest,
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
_execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that throws an exception before calling execute."""
# Throw exception before executing (to avoid async/await complications)
msg = "Async handler rejected the request"
raise ValueError(msg)
tool_node = ToolNode([add], on_tool_call=throwing_handler, handle_tool_errors=True)
tool_node = ToolNode([add], wrap_tool_call=throwing_handler, handle_tool_errors=True)
result = await tool_node.ainvoke(
{
@@ -831,12 +837,12 @@ def test_handler_cannot_yield_multiple_tool_messages() -> None:
def single_return_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that returns once (as all handlers do)."""
return execute(request)
return execute(request.tool_call)
tool_node = ToolNode([add], on_tool_call=single_return_handler)
tool_node = ToolNode([add], wrap_tool_call=single_return_handler)
result = tool_node.invoke(
{
@@ -867,13 +873,13 @@ def test_handler_cannot_yield_request_after_tool_message() -> None:
def single_return_handler(
request: ToolCallRequest,
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
_execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that returns cached result."""
# Return cached result (short-circuit)
return ToolMessage("cached", tool_call_id=request.tool_call["id"], name="add")
tool_node = ToolNode([add], on_tool_call=single_return_handler)
tool_node = ToolNode([add], wrap_tool_call=single_return_handler)
result = tool_node.invoke(
{
@@ -902,13 +908,13 @@ def test_handler_can_short_circuit_with_command() -> None:
def command_handler(
_request: ToolCallRequest,
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
_execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that short-circuits with Command."""
# Short-circuit with Command instead of executing tool
return Command(goto="end")
tool_node = ToolNode([add], on_tool_call=command_handler)
tool_node = ToolNode([add], wrap_tool_call=command_handler)
result = tool_node.invoke(
{
@@ -941,12 +947,12 @@ def test_handler_cannot_yield_multiple_commands() -> None:
def single_command_handler(
_request: ToolCallRequest,
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
_execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that returns Command once."""
return Command(goto="step1")
tool_node = ToolNode([add], on_tool_call=single_command_handler)
tool_node = ToolNode([add], wrap_tool_call=single_command_handler)
result = tool_node.invoke(
{
@@ -979,12 +985,12 @@ def test_handler_cannot_yield_request_after_command() -> None:
def command_handler(
_request: ToolCallRequest,
_execute: Callable[[ToolCallRequest], ToolMessage | Command],
_execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that returns Command."""
return Command(goto="somewhere")
tool_node = ToolNode([add], on_tool_call=command_handler)
tool_node = ToolNode([add], wrap_tool_call=command_handler)
result = tool_node.invoke(
{
@@ -1016,16 +1022,16 @@ def test_tool_returning_command_sent_to_handler() -> None:
def command_inspector_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that inspects Command returned by tool."""
result = execute(request)
result = execute(request.tool_call)
# Should receive Command from tool
if isinstance(result, Command):
received_commands.append(result)
return result
tool_node = ToolNode([command_tool], on_tool_call=command_inspector_handler)
tool_node = ToolNode([command_tool], wrap_tool_call=command_inspector_handler)
result = tool_node.invoke(
{
@@ -1060,16 +1066,16 @@ def test_handler_can_modify_command_from_tool() -> None:
def command_modifier_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that modifies Command returned by tool."""
result = execute(request)
result = execute(request.tool_call)
# Modify the Command
if isinstance(result, Command):
return Command(goto=f"modified_{result.goto}")
return result
tool_node = ToolNode([command_tool], on_tool_call=command_modifier_handler)
tool_node = ToolNode([command_tool], wrap_tool_call=command_modifier_handler)
result = tool_node.invoke(
{
@@ -1101,13 +1107,13 @@ def test_state_extraction_with_dict_input() -> None:
def state_inspector_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that records the state it receives."""
state_seen.append(request.state)
return execute(request)
return execute(request.tool_call)
tool_node = ToolNode([add], on_tool_call=state_inspector_handler)
tool_node = ToolNode([add], wrap_tool_call=state_inspector_handler)
input_state = {
"messages": [
@@ -1136,13 +1142,13 @@ def test_state_extraction_with_list_input() -> None:
def state_inspector_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that records the state it receives."""
state_seen.append(request.state)
return execute(request)
return execute(request.tool_call)
tool_node = ToolNode([add], on_tool_call=state_inspector_handler)
tool_node = ToolNode([add], wrap_tool_call=state_inspector_handler)
input_state = [
AIMessage(
@@ -1170,13 +1176,13 @@ def test_state_extraction_with_tool_call_with_context() -> None:
def state_inspector_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that records the state it receives."""
state_seen.append(request.state)
return execute(request)
return execute(request.tool_call)
tool_node = ToolNode([add], on_tool_call=state_inspector_handler)
tool_node = ToolNode([add], wrap_tool_call=state_inspector_handler)
# Simulate ToolCallWithContext as used by create_agent with Send API
actual_state = {
@@ -1213,13 +1219,13 @@ async def test_state_extraction_with_tool_call_with_context_async() -> None:
def state_inspector_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
execute: Callable[[ToolCall], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that records the state it receives."""
state_seen.append(request.state)
return execute(request)
return execute(request.tool_call)
tool_node = ToolNode([add], on_tool_call=state_inspector_handler)
tool_node = ToolNode([add], wrap_tool_call=state_inspector_handler)
# Simulate ToolCallWithContext as used by create_agent with Send API
actual_state = {