chore(langchain): fix types in test_wrap_tool_call (#34600)

This commit is contained in:
Christophe Bornet
2026-01-05 20:38:31 +01:00
committed by GitHub
parent 5ae53fdfb3
commit c4babed5c6

View File

@@ -6,9 +6,10 @@ focusing on the handler pattern (not generators).
import time
from collections.abc import Callable
from typing import Any
from langchain_core.messages import HumanMessage, ToolCall, ToolMessage
from langchain_core.tools import tool
from langchain_core.tools import BaseTool, tool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.types import Command
@@ -41,7 +42,9 @@ def test_wrap_tool_call_basic_passthrough() -> None:
call_log = []
@wrap_tool_call
def passthrough(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def passthrough(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("called")
return handler(request)
@@ -76,7 +79,10 @@ def test_wrap_tool_call_logging() -> None:
call_log = []
@wrap_tool_call
def logging_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def logging_middleware(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
assert isinstance(request.tool, BaseTool)
call_log.append(f"before_{request.tool.name}")
response = handler(request)
call_log.append(f"after_{request.tool.name}")
@@ -110,7 +116,10 @@ def test_wrap_tool_call_modify_args() -> None:
"""Test modifying tool arguments with wrap_tool_call decorator."""
@wrap_tool_call
def modify_args(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def modify_args(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
assert isinstance(request.tool, BaseTool)
# Modify the query argument before execution
if request.tool.name == "search":
request.tool_call["args"]["query"] = "modified query"
@@ -145,7 +154,9 @@ def test_wrap_tool_call_access_state() -> None:
state_data = []
@wrap_tool_call
def access_state(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def access_state(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
# Access state from request
if request.state is not None:
messages = request.state.get("messages", [])
@@ -181,7 +192,9 @@ def test_wrap_tool_call_access_runtime() -> None:
runtime_data = []
@wrap_tool_call
def access_runtime(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def access_runtime(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
# Access runtime from request
if request.runtime is not None:
# Runtime object is available (has context, store, stream_writer, previous)
@@ -217,7 +230,9 @@ def test_wrap_tool_call_retry_on_error() -> None:
attempt_counts = []
@wrap_tool_call
def retry_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def retry_middleware(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
max_retries = 3
last_error = None
for attempt in range(max_retries):
@@ -275,7 +290,9 @@ def test_wrap_tool_call_short_circuit() -> None:
handler_called = []
@wrap_tool_call
def short_circuit(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def short_circuit(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
# Don't call handler, return custom response directly
handler_called.append(False)
return ToolMessage(
@@ -314,7 +331,9 @@ def test_wrap_tool_call_response_modification() -> None:
"""Test modifying tool response with wrap_tool_call decorator."""
@wrap_tool_call
def modify_response(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def modify_response(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
response = handler(request)
# Modify the response
@@ -355,14 +374,18 @@ def test_wrap_tool_call_multiple_middleware_composition() -> None:
call_log = []
@wrap_tool_call
def outer_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def outer_middleware(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("outer_before")
response = handler(request)
call_log.append("outer_after")
return response
@wrap_tool_call
def inner_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def inner_middleware(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("inner_before")
response = handler(request)
call_log.append("inner_after")
@@ -399,7 +422,10 @@ def test_wrap_tool_call_multiple_tools() -> None:
call_log = []
@wrap_tool_call
def log_tool_calls(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def log_tool_calls(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
assert isinstance(request.tool, BaseTool)
call_log.append(request.tool.name)
return handler(request)
@@ -438,7 +464,9 @@ def test_wrap_tool_call_with_custom_name() -> None:
"""Test wrap_tool_call decorator with custom middleware name."""
@wrap_tool_call(name="CustomToolWrapper")
def my_wrapper(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def my_wrapper(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
return handler(request)
# Verify custom name was applied
@@ -454,7 +482,9 @@ def test_wrap_tool_call_with_tools_parameter() -> None:
return f"Extra: {value}"
@wrap_tool_call(tools=[extra_tool])
def wrapper_with_tools(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def wrapper_with_tools(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
return handler(request)
# Verify tools were registered
@@ -466,21 +496,27 @@ def test_wrap_tool_call_three_levels_composition() -> None:
call_log = []
@wrap_tool_call(name="OuterWrapper")
def outer(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def outer(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("outer_before")
response = handler(request)
call_log.append("outer_after")
return response
@wrap_tool_call(name="MiddleWrapper")
def middle(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def middle(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("middle_before")
response = handler(request)
call_log.append("middle_after")
return response
@wrap_tool_call(name="InnerWrapper")
def inner(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def inner(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("inner_before")
response = handler(request)
call_log.append("inner_after")
@@ -524,7 +560,9 @@ def test_wrap_tool_call_outer_intercepts_inner() -> None:
call_log = []
@wrap_tool_call(name="InterceptingOuter")
def intercepting_outer(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def intercepting_outer(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("outer_before")
handler(request)
call_log.append("outer_after")
@@ -537,7 +575,9 @@ def test_wrap_tool_call_outer_intercepts_inner() -> None:
)
@wrap_tool_call(name="InnerWrapper")
def inner(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def inner(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("inner_called")
response = handler(request)
call_log.append("inner_got_response")
@@ -580,7 +620,9 @@ def test_wrap_tool_call_inner_short_circuits() -> None:
call_log = []
@wrap_tool_call(name="OuterWrapper")
def outer(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def outer(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("outer_before")
response = handler(request)
call_log.append("outer_after")
@@ -595,7 +637,9 @@ def test_wrap_tool_call_inner_short_circuits() -> None:
return response
@wrap_tool_call(name="InnerShortCircuit")
def inner_short_circuit(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def inner_short_circuit(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("inner_short_circuit")
# Don't call handler, return custom response
return ToolMessage(
@@ -636,14 +680,18 @@ def test_wrap_tool_call_mixed_passthrough_and_intercepting() -> None:
call_log = []
@wrap_tool_call(name="FirstPassthrough")
def first_passthrough(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def first_passthrough(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("first_before")
response = handler(request)
call_log.append("first_after")
return response
@wrap_tool_call(name="SecondIntercepting")
def second_intercepting(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def second_intercepting(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("second_intercept")
# Call handler but ignore result
_ = handler(request)
@@ -655,7 +703,9 @@ def test_wrap_tool_call_mixed_passthrough_and_intercepting() -> None:
)
@wrap_tool_call(name="ThirdPassthrough")
def third_passthrough(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def third_passthrough(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
call_log.append("third_called")
response = handler(request)
call_log.append("third_after")
@@ -698,7 +748,9 @@ def test_wrap_tool_call_uses_function_name_as_default() -> None:
"""Test that wrap_tool_call uses function name as default middleware name."""
@wrap_tool_call
def my_custom_wrapper(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def my_custom_wrapper(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
return handler(request)
# Verify that function name is used as middleware class name
@@ -707,11 +759,14 @@ def test_wrap_tool_call_uses_function_name_as_default() -> None:
def test_wrap_tool_call_caching_pattern() -> None:
"""Test caching pattern with wrap_tool_call decorator."""
cache = {}
cache: dict[tuple[str, str], Any] = {}
handler_calls = []
@wrap_tool_call
def with_cache(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def with_cache(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
assert isinstance(request.tool, BaseTool)
# Create cache key from tool name and args
cache_key = (request.tool.name, str(request.tool_call["args"]))
@@ -765,17 +820,21 @@ def test_wrap_tool_call_monitoring_pattern() -> None:
metrics = []
@wrap_tool_call
def monitor_execution(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
def monitor_execution(
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
) -> ToolMessage | Command[Any]:
start_time = time.time()
response = handler(request)
execution_time = time.time() - start_time
assert isinstance(request.tool, BaseTool)
assert isinstance(response, ToolMessage)
assert isinstance(response.content, str)
metrics.append(
{
"tool": request.tool.name,
"execution_time": execution_time,
"success": isinstance(response, ToolMessage)
and not response.content.startswith("Error:"),
"success": not response.content.startswith("Error:"),
}
)
@@ -804,4 +863,5 @@ def test_wrap_tool_call_monitoring_pattern() -> None:
assert len(metrics) == 1
assert metrics[0]["tool"] == "search"
assert metrics[0]["success"] is True
assert isinstance(metrics[0]["execution_time"], float)
assert metrics[0]["execution_time"] >= 0