mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
chore(langchain): fix types in test_wrap_tool_call (#34600)
This commit is contained in:
committed by
GitHub
parent
5ae53fdfb3
commit
c4babed5c6
@@ -6,9 +6,10 @@ focusing on the handler pattern (not generators).
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, ToolCall, ToolMessage
|
from langchain_core.messages import HumanMessage, ToolCall, ToolMessage
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import BaseTool, tool
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
|
||||||
@@ -41,7 +42,9 @@ def test_wrap_tool_call_basic_passthrough() -> None:
|
|||||||
call_log = []
|
call_log = []
|
||||||
|
|
||||||
@wrap_tool_call
|
@wrap_tool_call
|
||||||
def passthrough(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def passthrough(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
call_log.append("called")
|
call_log.append("called")
|
||||||
return handler(request)
|
return handler(request)
|
||||||
|
|
||||||
@@ -76,7 +79,10 @@ def test_wrap_tool_call_logging() -> None:
|
|||||||
call_log = []
|
call_log = []
|
||||||
|
|
||||||
@wrap_tool_call
|
@wrap_tool_call
|
||||||
def logging_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def logging_middleware(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
|
assert isinstance(request.tool, BaseTool)
|
||||||
call_log.append(f"before_{request.tool.name}")
|
call_log.append(f"before_{request.tool.name}")
|
||||||
response = handler(request)
|
response = handler(request)
|
||||||
call_log.append(f"after_{request.tool.name}")
|
call_log.append(f"after_{request.tool.name}")
|
||||||
@@ -110,7 +116,10 @@ def test_wrap_tool_call_modify_args() -> None:
|
|||||||
"""Test modifying tool arguments with wrap_tool_call decorator."""
|
"""Test modifying tool arguments with wrap_tool_call decorator."""
|
||||||
|
|
||||||
@wrap_tool_call
|
@wrap_tool_call
|
||||||
def modify_args(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def modify_args(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
|
assert isinstance(request.tool, BaseTool)
|
||||||
# Modify the query argument before execution
|
# Modify the query argument before execution
|
||||||
if request.tool.name == "search":
|
if request.tool.name == "search":
|
||||||
request.tool_call["args"]["query"] = "modified query"
|
request.tool_call["args"]["query"] = "modified query"
|
||||||
@@ -145,7 +154,9 @@ def test_wrap_tool_call_access_state() -> None:
|
|||||||
state_data = []
|
state_data = []
|
||||||
|
|
||||||
@wrap_tool_call
|
@wrap_tool_call
|
||||||
def access_state(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def access_state(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
# Access state from request
|
# Access state from request
|
||||||
if request.state is not None:
|
if request.state is not None:
|
||||||
messages = request.state.get("messages", [])
|
messages = request.state.get("messages", [])
|
||||||
@@ -181,7 +192,9 @@ def test_wrap_tool_call_access_runtime() -> None:
|
|||||||
runtime_data = []
|
runtime_data = []
|
||||||
|
|
||||||
@wrap_tool_call
|
@wrap_tool_call
|
||||||
def access_runtime(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def access_runtime(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
# Access runtime from request
|
# Access runtime from request
|
||||||
if request.runtime is not None:
|
if request.runtime is not None:
|
||||||
# Runtime object is available (has context, store, stream_writer, previous)
|
# Runtime object is available (has context, store, stream_writer, previous)
|
||||||
@@ -217,7 +230,9 @@ def test_wrap_tool_call_retry_on_error() -> None:
|
|||||||
attempt_counts = []
|
attempt_counts = []
|
||||||
|
|
||||||
@wrap_tool_call
|
@wrap_tool_call
|
||||||
def retry_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def retry_middleware(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
max_retries = 3
|
max_retries = 3
|
||||||
last_error = None
|
last_error = None
|
||||||
for attempt in range(max_retries):
|
for attempt in range(max_retries):
|
||||||
@@ -275,7 +290,9 @@ def test_wrap_tool_call_short_circuit() -> None:
|
|||||||
handler_called = []
|
handler_called = []
|
||||||
|
|
||||||
@wrap_tool_call
|
@wrap_tool_call
|
||||||
def short_circuit(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def short_circuit(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
# Don't call handler, return custom response directly
|
# Don't call handler, return custom response directly
|
||||||
handler_called.append(False)
|
handler_called.append(False)
|
||||||
return ToolMessage(
|
return ToolMessage(
|
||||||
@@ -314,7 +331,9 @@ def test_wrap_tool_call_response_modification() -> None:
|
|||||||
"""Test modifying tool response with wrap_tool_call decorator."""
|
"""Test modifying tool response with wrap_tool_call decorator."""
|
||||||
|
|
||||||
@wrap_tool_call
|
@wrap_tool_call
|
||||||
def modify_response(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def modify_response(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
response = handler(request)
|
response = handler(request)
|
||||||
|
|
||||||
# Modify the response
|
# Modify the response
|
||||||
@@ -355,14 +374,18 @@ def test_wrap_tool_call_multiple_middleware_composition() -> None:
|
|||||||
call_log = []
|
call_log = []
|
||||||
|
|
||||||
@wrap_tool_call
|
@wrap_tool_call
|
||||||
def outer_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def outer_middleware(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
call_log.append("outer_before")
|
call_log.append("outer_before")
|
||||||
response = handler(request)
|
response = handler(request)
|
||||||
call_log.append("outer_after")
|
call_log.append("outer_after")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@wrap_tool_call
|
@wrap_tool_call
|
||||||
def inner_middleware(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def inner_middleware(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
call_log.append("inner_before")
|
call_log.append("inner_before")
|
||||||
response = handler(request)
|
response = handler(request)
|
||||||
call_log.append("inner_after")
|
call_log.append("inner_after")
|
||||||
@@ -399,7 +422,10 @@ def test_wrap_tool_call_multiple_tools() -> None:
|
|||||||
call_log = []
|
call_log = []
|
||||||
|
|
||||||
@wrap_tool_call
|
@wrap_tool_call
|
||||||
def log_tool_calls(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def log_tool_calls(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
|
assert isinstance(request.tool, BaseTool)
|
||||||
call_log.append(request.tool.name)
|
call_log.append(request.tool.name)
|
||||||
return handler(request)
|
return handler(request)
|
||||||
|
|
||||||
@@ -438,7 +464,9 @@ def test_wrap_tool_call_with_custom_name() -> None:
|
|||||||
"""Test wrap_tool_call decorator with custom middleware name."""
|
"""Test wrap_tool_call decorator with custom middleware name."""
|
||||||
|
|
||||||
@wrap_tool_call(name="CustomToolWrapper")
|
@wrap_tool_call(name="CustomToolWrapper")
|
||||||
def my_wrapper(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def my_wrapper(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
return handler(request)
|
return handler(request)
|
||||||
|
|
||||||
# Verify custom name was applied
|
# Verify custom name was applied
|
||||||
@@ -454,7 +482,9 @@ def test_wrap_tool_call_with_tools_parameter() -> None:
|
|||||||
return f"Extra: {value}"
|
return f"Extra: {value}"
|
||||||
|
|
||||||
@wrap_tool_call(tools=[extra_tool])
|
@wrap_tool_call(tools=[extra_tool])
|
||||||
def wrapper_with_tools(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def wrapper_with_tools(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
return handler(request)
|
return handler(request)
|
||||||
|
|
||||||
# Verify tools were registered
|
# Verify tools were registered
|
||||||
@@ -466,21 +496,27 @@ def test_wrap_tool_call_three_levels_composition() -> None:
|
|||||||
call_log = []
|
call_log = []
|
||||||
|
|
||||||
@wrap_tool_call(name="OuterWrapper")
|
@wrap_tool_call(name="OuterWrapper")
|
||||||
def outer(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def outer(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
call_log.append("outer_before")
|
call_log.append("outer_before")
|
||||||
response = handler(request)
|
response = handler(request)
|
||||||
call_log.append("outer_after")
|
call_log.append("outer_after")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@wrap_tool_call(name="MiddleWrapper")
|
@wrap_tool_call(name="MiddleWrapper")
|
||||||
def middle(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def middle(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
call_log.append("middle_before")
|
call_log.append("middle_before")
|
||||||
response = handler(request)
|
response = handler(request)
|
||||||
call_log.append("middle_after")
|
call_log.append("middle_after")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@wrap_tool_call(name="InnerWrapper")
|
@wrap_tool_call(name="InnerWrapper")
|
||||||
def inner(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def inner(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
call_log.append("inner_before")
|
call_log.append("inner_before")
|
||||||
response = handler(request)
|
response = handler(request)
|
||||||
call_log.append("inner_after")
|
call_log.append("inner_after")
|
||||||
@@ -524,7 +560,9 @@ def test_wrap_tool_call_outer_intercepts_inner() -> None:
|
|||||||
call_log = []
|
call_log = []
|
||||||
|
|
||||||
@wrap_tool_call(name="InterceptingOuter")
|
@wrap_tool_call(name="InterceptingOuter")
|
||||||
def intercepting_outer(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def intercepting_outer(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
call_log.append("outer_before")
|
call_log.append("outer_before")
|
||||||
handler(request)
|
handler(request)
|
||||||
call_log.append("outer_after")
|
call_log.append("outer_after")
|
||||||
@@ -537,7 +575,9 @@ def test_wrap_tool_call_outer_intercepts_inner() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@wrap_tool_call(name="InnerWrapper")
|
@wrap_tool_call(name="InnerWrapper")
|
||||||
def inner(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def inner(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
call_log.append("inner_called")
|
call_log.append("inner_called")
|
||||||
response = handler(request)
|
response = handler(request)
|
||||||
call_log.append("inner_got_response")
|
call_log.append("inner_got_response")
|
||||||
@@ -580,7 +620,9 @@ def test_wrap_tool_call_inner_short_circuits() -> None:
|
|||||||
call_log = []
|
call_log = []
|
||||||
|
|
||||||
@wrap_tool_call(name="OuterWrapper")
|
@wrap_tool_call(name="OuterWrapper")
|
||||||
def outer(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def outer(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
call_log.append("outer_before")
|
call_log.append("outer_before")
|
||||||
response = handler(request)
|
response = handler(request)
|
||||||
call_log.append("outer_after")
|
call_log.append("outer_after")
|
||||||
@@ -595,7 +637,9 @@ def test_wrap_tool_call_inner_short_circuits() -> None:
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
@wrap_tool_call(name="InnerShortCircuit")
|
@wrap_tool_call(name="InnerShortCircuit")
|
||||||
def inner_short_circuit(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def inner_short_circuit(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
call_log.append("inner_short_circuit")
|
call_log.append("inner_short_circuit")
|
||||||
# Don't call handler, return custom response
|
# Don't call handler, return custom response
|
||||||
return ToolMessage(
|
return ToolMessage(
|
||||||
@@ -636,14 +680,18 @@ def test_wrap_tool_call_mixed_passthrough_and_intercepting() -> None:
|
|||||||
call_log = []
|
call_log = []
|
||||||
|
|
||||||
@wrap_tool_call(name="FirstPassthrough")
|
@wrap_tool_call(name="FirstPassthrough")
|
||||||
def first_passthrough(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def first_passthrough(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
call_log.append("first_before")
|
call_log.append("first_before")
|
||||||
response = handler(request)
|
response = handler(request)
|
||||||
call_log.append("first_after")
|
call_log.append("first_after")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@wrap_tool_call(name="SecondIntercepting")
|
@wrap_tool_call(name="SecondIntercepting")
|
||||||
def second_intercepting(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def second_intercepting(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
call_log.append("second_intercept")
|
call_log.append("second_intercept")
|
||||||
# Call handler but ignore result
|
# Call handler but ignore result
|
||||||
_ = handler(request)
|
_ = handler(request)
|
||||||
@@ -655,7 +703,9 @@ def test_wrap_tool_call_mixed_passthrough_and_intercepting() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@wrap_tool_call(name="ThirdPassthrough")
|
@wrap_tool_call(name="ThirdPassthrough")
|
||||||
def third_passthrough(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def third_passthrough(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
call_log.append("third_called")
|
call_log.append("third_called")
|
||||||
response = handler(request)
|
response = handler(request)
|
||||||
call_log.append("third_after")
|
call_log.append("third_after")
|
||||||
@@ -698,7 +748,9 @@ def test_wrap_tool_call_uses_function_name_as_default() -> None:
|
|||||||
"""Test that wrap_tool_call uses function name as default middleware name."""
|
"""Test that wrap_tool_call uses function name as default middleware name."""
|
||||||
|
|
||||||
@wrap_tool_call
|
@wrap_tool_call
|
||||||
def my_custom_wrapper(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def my_custom_wrapper(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
return handler(request)
|
return handler(request)
|
||||||
|
|
||||||
# Verify that function name is used as middleware class name
|
# Verify that function name is used as middleware class name
|
||||||
@@ -707,11 +759,14 @@ def test_wrap_tool_call_uses_function_name_as_default() -> None:
|
|||||||
|
|
||||||
def test_wrap_tool_call_caching_pattern() -> None:
|
def test_wrap_tool_call_caching_pattern() -> None:
|
||||||
"""Test caching pattern with wrap_tool_call decorator."""
|
"""Test caching pattern with wrap_tool_call decorator."""
|
||||||
cache = {}
|
cache: dict[tuple[str, str], Any] = {}
|
||||||
handler_calls = []
|
handler_calls = []
|
||||||
|
|
||||||
@wrap_tool_call
|
@wrap_tool_call
|
||||||
def with_cache(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def with_cache(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
|
assert isinstance(request.tool, BaseTool)
|
||||||
# Create cache key from tool name and args
|
# Create cache key from tool name and args
|
||||||
cache_key = (request.tool.name, str(request.tool_call["args"]))
|
cache_key = (request.tool.name, str(request.tool_call["args"]))
|
||||||
|
|
||||||
@@ -765,17 +820,21 @@ def test_wrap_tool_call_monitoring_pattern() -> None:
|
|||||||
metrics = []
|
metrics = []
|
||||||
|
|
||||||
@wrap_tool_call
|
@wrap_tool_call
|
||||||
def monitor_execution(request: ToolCallRequest, handler: Callable) -> ToolMessage | Command:
|
def monitor_execution(
|
||||||
|
request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]]
|
||||||
|
) -> ToolMessage | Command[Any]:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
response = handler(request)
|
response = handler(request)
|
||||||
execution_time = time.time() - start_time
|
execution_time = time.time() - start_time
|
||||||
|
|
||||||
|
assert isinstance(request.tool, BaseTool)
|
||||||
|
assert isinstance(response, ToolMessage)
|
||||||
|
assert isinstance(response.content, str)
|
||||||
metrics.append(
|
metrics.append(
|
||||||
{
|
{
|
||||||
"tool": request.tool.name,
|
"tool": request.tool.name,
|
||||||
"execution_time": execution_time,
|
"execution_time": execution_time,
|
||||||
"success": isinstance(response, ToolMessage)
|
"success": not response.content.startswith("Error:"),
|
||||||
and not response.content.startswith("Error:"),
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -804,4 +863,5 @@ def test_wrap_tool_call_monitoring_pattern() -> None:
|
|||||||
assert len(metrics) == 1
|
assert len(metrics) == 1
|
||||||
assert metrics[0]["tool"] == "search"
|
assert metrics[0]["tool"] == "search"
|
||||||
assert metrics[0]["success"] is True
|
assert metrics[0]["success"] is True
|
||||||
|
assert isinstance(metrics[0]["execution_time"], float)
|
||||||
assert metrics[0]["execution_time"] >= 0
|
assert metrics[0]["execution_time"] >= 0
|
||||||
|
|||||||
Reference in New Issue
Block a user