chore(langchain): fix types in test_wrap_tool_call (#34600)

This commit is contained in:
Christophe Bornet
2026-01-05 20:38:31 +01:00
committed by GitHub
parent 5ae53fdfb3
commit c4babed5c6

View File

@@ -6,9 +6,10 @@ focusing on the handler pattern (not generators).
import time import time
from collections.abc import Callable from collections.abc import Callable
from typing import Any
from langchain_core.messages import HumanMessage, ToolCall, ToolMessage from langchain_core.messages import HumanMessage, ToolCall, ToolMessage
from langchain_core.tools import tool from langchain_core.tools import BaseTool, tool
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
from langgraph.types import Command from langgraph.types import Command
@@ -41,7 +42,9 @@ def test_wrap_tool_call_basic_passthrough() -> None:
call_log = [] call_log = []
@wrap_tool_call @wrap_tool_call
def passthrough(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def passthrough(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("called") call_log.append("called")
return handler(request) return handler(request)
@@ -76,7 +79,10 @@ def test_wrap_tool_call_logging() -> None:
call_log = [] call_log = []
@wrap_tool_call @wrap_tool_call
def logging_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def logging_middleware(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
assert isinstance(request.tool, BaseTool)
call_log.append(f"before_{request.tool.name}") call_log.append(f"before_{request.tool.name}")
response = handler(request) response = handler(request)
call_log.append(f"after_{request.tool.name}") call_log.append(f"after_{request.tool.name}")
@@ -110,7 +116,10 @@ def test_wrap_tool_call_modify_args() -> None:
"""Test modifying tool arguments with wrap_tool_call decorator.""" """Test modifying tool arguments with wrap_tool_call decorator."""
@wrap_tool_call @wrap_tool_call
def modify_args(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def modify_args(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
assert isinstance(request.tool, BaseTool)
# Modify the query argument before execution # Modify the query argument before execution
if request.tool.name == "search": if request.tool.name == "search":
request.tool_call["args"]["query"] = "modified query" request.tool_call["args"]["query"] = "modified query"
@@ -145,7 +154,9 @@ def test_wrap_tool_call_access_state() -> None:
state_data = [] state_data = []
@wrap_tool_call @wrap_tool_call
def access_state(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def access_state(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
# Access state from request # Access state from request
if request.state is not None: if request.state is not None:
messages = request.state.get("messages", []) messages = request.state.get("messages", [])
@@ -181,7 +192,9 @@ def test_wrap_tool_call_access_runtime() -> None:
runtime_data = [] runtime_data = []
@wrap_tool_call @wrap_tool_call
def access_runtime(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def access_runtime(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
# Access runtime from request # Access runtime from request
if request.runtime is not None: if request.runtime is not None:
# Runtime object is available (has context, store, stream_writer, previous) # Runtime object is available (has context, store, stream_writer, previous)
@@ -217,7 +230,9 @@ def test_wrap_tool_call_retry_on_error() -> None:
attempt_counts = [] attempt_counts = []
@wrap_tool_call @wrap_tool_call
def retry_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def retry_middleware(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
max_retries = 3 max_retries = 3
last_error = None last_error = None
for attempt in range(max_retries): for attempt in range(max_retries):
@@ -275,7 +290,9 @@ def test_wrap_tool_call_short_circuit() -> None:
handler_called = [] handler_called = []
@wrap_tool_call @wrap_tool_call
def short_circuit(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def short_circuit(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
# Don't call handler, return custom response directly # Don't call handler, return custom response directly
handler_called.append(False) handler_called.append(False)
return ToolMessage( return ToolMessage(
@@ -314,7 +331,9 @@ def test_wrap_tool_call_response_modification() -> None:
"""Test modifying tool response with wrap_tool_call decorator.""" """Test modifying tool response with wrap_tool_call decorator."""
@wrap_tool_call @wrap_tool_call
def modify_response(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def modify_response(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
response = handler(request) response = handler(request)
# Modify the response # Modify the response
@@ -355,14 +374,18 @@ def test_wrap_tool_call_multiple_middleware_composition() -> None:
call_log = [] call_log = []
@wrap_tool_call @wrap_tool_call
def outer_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def outer_middleware(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("outer_before") call_log.append("outer_before")
response = handler(request) response = handler(request)
call_log.append("outer_after") call_log.append("outer_after")
return response return response
@wrap_tool_call @wrap_tool_call
def inner_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def inner_middleware(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("inner_before") call_log.append("inner_before")
response = handler(request) response = handler(request)
call_log.append("inner_after") call_log.append("inner_after")
@@ -399,7 +422,10 @@ def test_wrap_tool_call_multiple_tools() -> None:
call_log = [] call_log = []
@wrap_tool_call @wrap_tool_call
def log_tool_calls(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def log_tool_calls(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
assert isinstance(request.tool, BaseTool)
call_log.append(request.tool.name) call_log.append(request.tool.name)
return handler(request) return handler(request)
@@ -438,7 +464,9 @@ def test_wrap_tool_call_with_custom_name() -> None:
"""Test wrap_tool_call decorator with custom middleware name.""" """Test wrap_tool_call decorator with custom middleware name."""
@wrap_tool_call(name="CustomToolWrapper") @wrap_tool_call(name="CustomToolWrapper")
def my_wrapper(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def my_wrapper(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
return handler(request) return handler(request)
# Verify custom name was applied # Verify custom name was applied
@@ -454,7 +482,9 @@ def test_wrap_tool_call_with_tools_parameter() -> None:
return f"Extra: {value}" return f"Extra: {value}"
@wrap_tool_call(tools=[extra_tool]) @wrap_tool_call(tools=[extra_tool])
def wrapper_with_tools(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def wrapper_with_tools(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
return handler(request) return handler(request)
# Verify tools were registered # Verify tools were registered
@@ -466,21 +496,27 @@ def test_wrap_tool_call_three_levels_composition() -> None:
call_log = [] call_log = []
@wrap_tool_call(name="OuterWrapper") @wrap_tool_call(name="OuterWrapper")
def outer(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def outer(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("outer_before") call_log.append("outer_before")
response = handler(request) response = handler(request)
call_log.append("outer_after") call_log.append("outer_after")
return response return response
@wrap_tool_call(name="MiddleWrapper") @wrap_tool_call(name="MiddleWrapper")
def middle(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def middle(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("middle_before") call_log.append("middle_before")
response = handler(request) response = handler(request)
call_log.append("middle_after") call_log.append("middle_after")
return response return response
@wrap_tool_call(name="InnerWrapper") @wrap_tool_call(name="InnerWrapper")
def inner(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def inner(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("inner_before") call_log.append("inner_before")
response = handler(request) response = handler(request)
call_log.append("inner_after") call_log.append("inner_after")
@@ -524,7 +560,9 @@ def test_wrap_tool_call_outer_intercepts_inner() -> None:
call_log = [] call_log = []
@wrap_tool_call(name="InterceptingOuter") @wrap_tool_call(name="InterceptingOuter")
def intercepting_outer(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def intercepting_outer(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("outer_before") call_log.append("outer_before")
handler(request) handler(request)
call_log.append("outer_after") call_log.append("outer_after")
@@ -537,7 +575,9 @@ def test_wrap_tool_call_outer_intercepts_inner() -> None:
) )
@wrap_tool_call(name="InnerWrapper") @wrap_tool_call(name="InnerWrapper")
def inner(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def inner(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("inner_called") call_log.append("inner_called")
response = handler(request) response = handler(request)
call_log.append("inner_got_response") call_log.append("inner_got_response")
@@ -580,7 +620,9 @@ def test_wrap_tool_call_inner_short_circuits() -> None:
call_log = [] call_log = []
@wrap_tool_call(name="OuterWrapper") @wrap_tool_call(name="OuterWrapper")
def outer(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def outer(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("outer_before") call_log.append("outer_before")
response = handler(request) response = handler(request)
call_log.append("outer_after") call_log.append("outer_after")
@@ -595,7 +637,9 @@ def test_wrap_tool_call_inner_short_circuits() -> None:
return response return response
@wrap_tool_call(name="InnerShortCircuit") @wrap_tool_call(name="InnerShortCircuit")
def inner_short_circuit(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def inner_short_circuit(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("inner_short_circuit") call_log.append("inner_short_circuit")
# Don't call handler, return custom response # Don't call handler, return custom response
return ToolMessage( return ToolMessage(
@@ -636,14 +680,18 @@ def test_wrap_tool_call_mixed_passthrough_and_intercepting() -> None:
call_log = [] call_log = []
@wrap_tool_call(name="FirstPassthrough") @wrap_tool_call(name="FirstPassthrough")
def first_passthrough(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def first_passthrough(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("first_before") call_log.append("first_before")
response = handler(request) response = handler(request)
call_log.append("first_after") call_log.append("first_after")
return response return response
@wrap_tool_call(name="SecondIntercepting") @wrap_tool_call(name="SecondIntercepting")
def second_intercepting(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def second_intercepting(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("second_intercept") call_log.append("second_intercept")
# Call handler but ignore result # Call handler but ignore result
_ = handler(request) _ = handler(request)
@@ -655,7 +703,9 @@ def test_wrap_tool_call_mixed_passthrough_and_intercepting() -> None:
) )
@wrap_tool_call(name="ThirdPassthrough") @wrap_tool_call(name="ThirdPassthrough")
def third_passthrough(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def third_passthrough(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("third_called") call_log.append("third_called")
response = handler(request) response = handler(request)
call_log.append("third_after") call_log.append("third_after")
@@ -698,7 +748,9 @@ def test_wrap_tool_call_uses_function_name_as_default() -> None:
"""Test that wrap_tool_call uses function name as default middleware name.""" """Test that wrap_tool_call uses function name as default middleware name."""
@wrap_tool_call @wrap_tool_call
def my_custom_wrapper(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def my_custom_wrapper(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
return handler(request) return handler(request)
# Verify that function name is used as middleware class name # Verify that function name is used as middleware class name
@@ -707,11 +759,14 @@ def test_wrap_tool_call_uses_function_name_as_default() -> None:
def test_wrap_tool_call_caching_pattern() -> None: def test_wrap_tool_call_caching_pattern() -> None:
"""Test caching pattern with wrap_tool_call decorator.""" """Test caching pattern with wrap_tool_call decorator."""
cache = {} cache: dict[tuple[str, str], Any] = {}
handler_calls = [] handler_calls = []
@wrap_tool_call @wrap_tool_call
def with_cache(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def with_cache(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
assert isinstance(request.tool, BaseTool)
# Create cache key from tool name and args # Create cache key from tool name and args
cache_key = (request.tool.name, str(request.tool_call["args"])) cache_key = (request.tool.name, str(request.tool_call["args"]))
@@ -765,17 +820,21 @@ def test_wrap_tool_call_monitoring_pattern() -> None:
metrics = [] metrics = []
@wrap_tool_call @wrap_tool_call
def monitor_execution(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command: def monitor_execution(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
start_time = time.time() start_time = time.time()
response = handler(request) response = handler(request)
execution_time = time.time() - start_time execution_time = time.time() - start_time
assert isinstance(request.tool, BaseTool)
assert isinstance(response, ToolMessage)
assert isinstance(response.content, str)
metrics.append( metrics.append(
{ {
"tool": request.tool.name, "tool": request.tool.name,
"execution_time": execution_time, "execution_time": execution_time,
"success": isinstance(response, ToolMessage) "success": not response.content.startswith("Error:"),
and not response.content.startswith("Error:"),
} }
) )
@@ -804,4 +863,5 @@ def test_wrap_tool_call_monitoring_pattern() -> None:
assert len(metrics) == 1 assert len(metrics) == 1
assert metrics[0]["tool"] == "search" assert metrics[0]["tool"] == "search"
assert metrics[0]["success"] is True assert metrics[0]["success"] is True
assert isinstance(metrics[0]["execution_time"], float)
assert metrics[0]["execution_time"] >= 0 assert metrics[0]["execution_time"] >= 0