mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
Compare commits
2 Commits
langchain-
...
eugene/add
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
637d755e58 | ||
|
|
ec78645a1f |
@@ -54,6 +54,7 @@ __all__ = [
|
||||
"hook_config",
|
||||
"modify_model_request",
|
||||
"on_model_call",
|
||||
"on_tool_call",
|
||||
]
|
||||
|
||||
JumpTo = Literal["tools", "model", "end"]
|
||||
@@ -348,6 +349,20 @@ class _CallableReturningModelResponseGenerator(Protocol[StateT_contra, ContextT]
|
||||
...
|
||||
|
||||
|
||||
class _CallableReturningToolResponseGenerator(Protocol[StateT_contra, ContextT]):
|
||||
"""Callable returning generator for tool call interception.
|
||||
|
||||
Returns sync generator that works with both sync and async tool execution.
|
||||
Generator receives ToolMessage or Command via .send().
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self, request: ToolCallRequest, state: StateT_contra, runtime: Runtime[ContextT]
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
"""Return generator to intercept tool execution."""
|
||||
...
|
||||
|
||||
|
||||
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
|
||||
|
||||
|
||||
@@ -1327,3 +1342,141 @@ def on_model_call(
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
return decorator
|
||||
|
||||
|
||||
@overload
|
||||
def on_tool_call(
|
||||
func: _CallableReturningToolResponseGenerator[StateT, ContextT],
|
||||
) -> AgentMiddleware[StateT, ContextT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def on_tool_call(
|
||||
func: None = None,
|
||||
*,
|
||||
state_schema: type[StateT] | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Callable[
|
||||
[_CallableReturningToolResponseGenerator[StateT, ContextT]],
|
||||
AgentMiddleware[StateT, ContextT],
|
||||
]: ...
|
||||
|
||||
|
||||
def on_tool_call(
|
||||
func: _CallableReturningToolResponseGenerator[StateT, ContextT] | None = None,
|
||||
*,
|
||||
state_schema: type[StateT] | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
name: str | None = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[_CallableReturningToolResponseGenerator[StateT, ContextT]],
|
||||
AgentMiddleware[StateT, ContextT],
|
||||
]
|
||||
| AgentMiddleware[StateT, ContextT]
|
||||
):
|
||||
"""Create middleware with on_tool_call hook from a generator function.
|
||||
|
||||
Converts a generator function into middleware that can intercept tool calls,
|
||||
implement retry logic, modify requests, cache results, or handle errors using
|
||||
standard Python exception handling.
|
||||
|
||||
Args:
|
||||
func: Generator function accepting (request, state, runtime) that yields
|
||||
ToolCallRequest, ToolMessage, or Command, receives ToolMessage or Command
|
||||
via .send() on completion, and optionally handles exceptions via .throw().
|
||||
state_schema: Custom state schema. Defaults to AgentState.
|
||||
tools: Additional tools to register with this middleware.
|
||||
name: Middleware class name. Defaults to function name.
|
||||
|
||||
Returns:
|
||||
AgentMiddleware instance if func provided, otherwise a decorator.
|
||||
|
||||
Examples:
|
||||
Basic retry logic:
|
||||
```python
|
||||
@on_tool_call
|
||||
def retry_on_error(
|
||||
request: ToolCallRequest, state: AgentState, runtime: Runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
response = yield request
|
||||
if isinstance(response, ToolMessage) and not response.status == "error":
|
||||
return
|
||||
if attempt == max_retries - 1:
|
||||
return
|
||||
```
|
||||
|
||||
Modify tool arguments:
|
||||
```python
|
||||
@on_tool_call
|
||||
def scale_numeric_args(
|
||||
request: ToolCallRequest, state: AgentState, runtime: Runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
# Double numeric arguments
|
||||
for key, value in request.tool_call["args"].items():
|
||||
if isinstance(value, (int, float)):
|
||||
request.tool_call["args"][key] = value * 2
|
||||
yield request
|
||||
```
|
||||
|
||||
Cache tool results:
|
||||
```python
|
||||
@on_tool_call
|
||||
def cache_results(
|
||||
request: ToolCallRequest, state: AgentState, runtime: Runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
cache_key = f"{request.tool_call['name']}:{request.tool_call['args']}"
|
||||
if cached := get_cache(cache_key):
|
||||
yield cached # Short-circuit with cached result
|
||||
else:
|
||||
result = yield request
|
||||
save_cache(cache_key, result)
|
||||
```
|
||||
|
||||
Error handling with fallback:
|
||||
```python
|
||||
@on_tool_call
|
||||
def handle_errors(
|
||||
request: ToolCallRequest, state: AgentState, runtime: Runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
try:
|
||||
yield request
|
||||
except Exception as e:
|
||||
# Return error message instead of propagating
|
||||
yield ToolMessage(
|
||||
content=f"Tool execution failed: {e}",
|
||||
tool_call_id=request.tool_call["id"],
|
||||
status="error",
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: _CallableReturningToolResponseGenerator[StateT, ContextT],
|
||||
) -> AgentMiddleware[StateT, ContextT]:
|
||||
def wrapped(
|
||||
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||
request: ToolCallRequest,
|
||||
state: StateT,
|
||||
runtime: Runtime[ContextT],
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
return func(request, state, runtime)
|
||||
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "OnToolCallMiddleware"))
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"on_tool_call": wrapped,
|
||||
},
|
||||
)()
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
return decorator
|
||||
|
||||
@@ -0,0 +1,587 @@
|
||||
"""Unit tests for the @on_tool_call decorator."""
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
on_tool_call,
|
||||
)
|
||||
from langchain.tools.tool_node import ToolCallRequest
|
||||
from tests.unit_tests.agents.test_middleware_agent import FakeToolCallingModel
|
||||
|
||||
|
||||
@tool
|
||||
def search_tool(query: str) -> str:
|
||||
"""Search for information."""
|
||||
return f"Results for: {query}"
|
||||
|
||||
|
||||
@tool
|
||||
def calculator_tool(expression: str) -> str:
|
||||
"""Calculate an expression."""
|
||||
return f"Result: {expression}"
|
||||
|
||||
|
||||
@tool
|
||||
def failing_tool(input: str) -> str:
|
||||
"""Tool that always fails."""
|
||||
msg = f"Failed: {input}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
class TestOnToolCallDecorator:
|
||||
"""Test the @on_tool_call decorator for creating middleware."""
|
||||
|
||||
def test_basic_decorator_usage(self) -> None:
|
||||
"""Test basic decorator usage without parameters."""
|
||||
|
||||
@on_tool_call
|
||||
def passthrough_middleware(
|
||||
request: ToolCallRequest, state: AgentState, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
response = yield request
|
||||
|
||||
# Should return an AgentMiddleware instance
|
||||
assert isinstance(passthrough_middleware, AgentMiddleware)
|
||||
|
||||
# Should work in agent
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
ToolCall(
|
||||
name="search_tool",
|
||||
args={"query": "test"},
|
||||
id="call_1",
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
[], # Empty to signal agent should stop
|
||||
],
|
||||
tool_style="openai",
|
||||
)
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[search_tool],
|
||||
middleware=[passthrough_middleware],
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
result = agent.invoke(
|
||||
{"messages": [HumanMessage("Search for test")]},
|
||||
{"configurable": {"thread_id": "test"}},
|
||||
)
|
||||
|
||||
# Should have human message, AI message with tool call, tool message, and final AI message
|
||||
assert len(result["messages"]) >= 3
|
||||
tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)]
|
||||
assert len(tool_messages) == 1
|
||||
assert "Results for: test" in tool_messages[0].content
|
||||
|
||||
def test_decorator_with_custom_name(self) -> None:
|
||||
"""Test decorator with custom middleware name."""
|
||||
|
||||
@on_tool_call(name="CustomToolMiddleware")
|
||||
def my_middleware(
|
||||
request: ToolCallRequest, state: AgentState, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
response = yield request
|
||||
|
||||
assert isinstance(my_middleware, AgentMiddleware)
|
||||
assert my_middleware.__class__.__name__ == "CustomToolMiddleware"
|
||||
|
||||
def test_decorator_logging(self) -> None:
|
||||
"""Test decorator for logging tool calls."""
|
||||
call_log = []
|
||||
|
||||
@on_tool_call
|
||||
def logging_middleware(
|
||||
request: ToolCallRequest, state: AgentState, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
call_log.append(f"before_{request.tool.name}")
|
||||
response = yield request
|
||||
call_log.append(f"after_{request.tool.name}")
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
ToolCall(
|
||||
name="search_tool",
|
||||
args={"query": "test"},
|
||||
id="call_1",
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
[], # Empty to signal agent should stop
|
||||
],
|
||||
tool_style="openai",
|
||||
)
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[search_tool],
|
||||
middleware=[logging_middleware],
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
agent.invoke(
|
||||
{"messages": [HumanMessage("Search")]},
|
||||
{"configurable": {"thread_id": "test"}},
|
||||
)
|
||||
|
||||
assert call_log == ["before_search_tool", "after_search_tool"]
|
||||
|
||||
def test_decorator_modifying_args(self) -> None:
|
||||
"""Test decorator modifying tool arguments."""
|
||||
|
||||
@on_tool_call
|
||||
def modify_args_middleware(
|
||||
request: ToolCallRequest, state: AgentState, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
# Modify the query argument
|
||||
request.tool_call["args"]["query"] = "modified query"
|
||||
response = yield request
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
ToolCall(
|
||||
name="search_tool",
|
||||
args={"query": "original"},
|
||||
id="call_1",
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
[], # Empty to signal agent should stop
|
||||
],
|
||||
tool_style="openai",
|
||||
)
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[search_tool],
|
||||
middleware=[modify_args_middleware],
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
result = agent.invoke(
|
||||
{"messages": [HumanMessage("Search")]},
|
||||
{"configurable": {"thread_id": "test"}},
|
||||
)
|
||||
|
||||
# Tool should have been called with modified args
|
||||
tool_message = result["messages"][2]
|
||||
assert "Results for: modified query" in tool_message.content
|
||||
|
||||
def test_decorator_response_inspection(self) -> None:
|
||||
"""Test decorator inspecting tool responses."""
|
||||
inspected_values = []
|
||||
|
||||
@on_tool_call
|
||||
def inspect_middleware(
|
||||
request: ToolCallRequest, state: AgentState, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
response = yield request
|
||||
inspected_values.append(response.content)
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
ToolCall(
|
||||
name="search_tool",
|
||||
args={"query": "test"},
|
||||
id="call_1",
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
[], # Empty to signal agent should stop
|
||||
],
|
||||
tool_style="openai",
|
||||
)
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[search_tool],
|
||||
middleware=[inspect_middleware],
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
result = agent.invoke(
|
||||
{"messages": [HumanMessage("Test")]},
|
||||
{"configurable": {"thread_id": "test"}},
|
||||
)
|
||||
|
||||
# Should have inspected the response
|
||||
assert len(inspected_values) == 1
|
||||
assert "Results for: test" in inspected_values[0]
|
||||
|
||||
def test_decorator_with_state_access(self) -> None:
|
||||
"""Test decorator accessing agent state."""
|
||||
state_values = []
|
||||
|
||||
@on_tool_call
|
||||
def log_state(
|
||||
request: ToolCallRequest, state: AgentState, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
state_values.append(len(state.get("messages", [])))
|
||||
response = yield request
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
ToolCall(
|
||||
name="search_tool",
|
||||
args={"query": "test"},
|
||||
id="call_1",
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
[], # Empty to signal agent should stop
|
||||
],
|
||||
tool_style="openai",
|
||||
)
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[search_tool],
|
||||
middleware=[log_state],
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
agent.invoke(
|
||||
{"messages": [HumanMessage("Search")]},
|
||||
{"configurable": {"thread_id": "test"}},
|
||||
)
|
||||
|
||||
# State should have been logged
|
||||
assert len(state_values) == 1
|
||||
assert state_values[0] == 2 # Human message + AI message with tool call
|
||||
|
||||
def test_multiple_decorated_middleware(self) -> None:
|
||||
"""Test composition of multiple decorated middleware."""
|
||||
execution_order = []
|
||||
|
||||
@on_tool_call
|
||||
def outer_middleware(
|
||||
request: ToolCallRequest, state: AgentState, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
execution_order.append("outer-before")
|
||||
response = yield request
|
||||
execution_order.append("outer-after")
|
||||
|
||||
@on_tool_call
|
||||
def inner_middleware(
|
||||
request: ToolCallRequest, state: AgentState, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
execution_order.append("inner-before")
|
||||
response = yield request
|
||||
execution_order.append("inner-after")
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
ToolCall(
|
||||
name="search_tool",
|
||||
args={"query": "test"},
|
||||
id="call_1",
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
[], # Empty to signal agent should stop
|
||||
],
|
||||
tool_style="openai",
|
||||
)
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[search_tool],
|
||||
middleware=[outer_middleware, inner_middleware],
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
agent.invoke(
|
||||
{"messages": [HumanMessage("Search")]},
|
||||
{"configurable": {"thread_id": "test"}},
|
||||
)
|
||||
|
||||
assert execution_order == [
|
||||
"outer-before",
|
||||
"inner-before",
|
||||
"inner-after",
|
||||
"outer-after",
|
||||
]
|
||||
|
||||
def test_decorator_with_custom_state_schema(self) -> None:
|
||||
"""Test decorator with custom state schema."""
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
class CustomState(TypedDict):
|
||||
messages: list
|
||||
custom_field: str
|
||||
|
||||
@on_tool_call(state_schema=CustomState)
|
||||
def middleware_with_schema(
|
||||
request: ToolCallRequest, state, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
response = yield request
|
||||
|
||||
assert isinstance(middleware_with_schema, AgentMiddleware)
|
||||
assert middleware_with_schema.state_schema == CustomState
|
||||
|
||||
def test_decorator_with_tools_parameter(self) -> None:
|
||||
"""Test decorator with tools parameter."""
|
||||
|
||||
@tool
|
||||
def test_tool(query: str) -> str:
|
||||
"""A test tool."""
|
||||
return f"Result: {query}"
|
||||
|
||||
@on_tool_call(tools=[test_tool])
|
||||
def middleware_with_tools(
|
||||
request: ToolCallRequest, state: AgentState, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
response = yield request
|
||||
|
||||
assert isinstance(middleware_with_tools, AgentMiddleware)
|
||||
assert len(middleware_with_tools.tools) == 1
|
||||
assert middleware_with_tools.tools[0].name == "test_tool"
|
||||
|
||||
def test_decorator_parentheses_optional(self) -> None:
|
||||
"""Test that decorator works both with and without parentheses."""
|
||||
|
||||
# Without parentheses
|
||||
@on_tool_call
|
||||
def middleware_no_parens(
|
||||
request: ToolCallRequest, state: AgentState, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
response = yield request
|
||||
|
||||
# With parentheses
|
||||
@on_tool_call()
|
||||
def middleware_with_parens(
|
||||
request: ToolCallRequest, state: AgentState, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
response = yield request
|
||||
|
||||
assert isinstance(middleware_no_parens, AgentMiddleware)
|
||||
assert isinstance(middleware_with_parens, AgentMiddleware)
|
||||
|
||||
def test_decorator_preserves_function_name(self) -> None:
|
||||
"""Test that decorator uses function name for class name."""
|
||||
|
||||
@on_tool_call
|
||||
def my_custom_middleware(
|
||||
request: ToolCallRequest, state: AgentState, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
response = yield request
|
||||
|
||||
assert my_custom_middleware.__class__.__name__ == "my_custom_middleware"
|
||||
|
||||
def test_decorator_mixed_with_class_middleware(self) -> None:
|
||||
"""Test decorated middleware mixed with class-based middleware."""
|
||||
execution_order = []
|
||||
|
||||
@on_tool_call
|
||||
def decorated_middleware(
|
||||
request: ToolCallRequest, state: AgentState, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
execution_order.append("decorated-before")
|
||||
response = yield request
|
||||
execution_order.append("decorated-after")
|
||||
|
||||
class ClassMiddleware(AgentMiddleware):
|
||||
def on_tool_call(
|
||||
self, request: ToolCallRequest, state, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
execution_order.append("class-before")
|
||||
response = yield request
|
||||
execution_order.append("class-after")
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
ToolCall(
|
||||
name="search_tool",
|
||||
args={"query": "test"},
|
||||
id="call_1",
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
[], # Empty to signal agent should stop
|
||||
],
|
||||
tool_style="openai",
|
||||
)
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[search_tool],
|
||||
middleware=[decorated_middleware, ClassMiddleware()],
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
agent.invoke(
|
||||
{"messages": [HumanMessage("Search")]},
|
||||
{"configurable": {"thread_id": "test"}},
|
||||
)
|
||||
|
||||
assert execution_order == [
|
||||
"decorated-before",
|
||||
"class-before",
|
||||
"class-after",
|
||||
"decorated-after",
|
||||
]
|
||||
|
||||
def test_decorator_short_circuit_with_cached_result(self) -> None:
|
||||
"""Test decorator short-circuiting with cached result."""
|
||||
cache = {}
|
||||
|
||||
@on_tool_call
|
||||
def caching_middleware(
|
||||
request: ToolCallRequest, state: AgentState, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
cache_key = f"{request.tool_call['name']}:{request.tool_call['args']}"
|
||||
if cache_key in cache:
|
||||
# Short-circuit with cached result
|
||||
yield cache[cache_key]
|
||||
else:
|
||||
# Execute tool and cache result
|
||||
response = yield request
|
||||
cache[cache_key] = response
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
ToolCall(
|
||||
name="search_tool",
|
||||
args={"query": "test"},
|
||||
id="call_1",
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
[
|
||||
ToolCall(
|
||||
name="search_tool",
|
||||
args={"query": "test"},
|
||||
id="call_2",
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
[], # Empty to signal agent should stop
|
||||
],
|
||||
tool_style="openai",
|
||||
)
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[search_tool],
|
||||
middleware=[caching_middleware],
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
# First call - should execute tool
|
||||
result1 = agent.invoke(
|
||||
{"messages": [HumanMessage("Search")]},
|
||||
{"configurable": {"thread_id": "test1"}},
|
||||
)
|
||||
|
||||
# Cache should be populated
|
||||
assert len(cache) == 1
|
||||
|
||||
# Second call - should use cache
|
||||
result2 = agent.invoke(
|
||||
{"messages": [HumanMessage("Search again")]},
|
||||
{"configurable": {"thread_id": "test2"}},
|
||||
)
|
||||
|
||||
# Both results should have tool messages with same content
|
||||
assert "Results for: test" in result1["messages"][2].content
|
||||
assert "Results for: test" in result2["messages"][2].content
|
||||
|
||||
async def test_decorator_with_async_agent(self) -> None:
|
||||
"""Test that decorated middleware works with async agent invocation."""
|
||||
call_log = []
|
||||
|
||||
@on_tool_call
|
||||
def logging_middleware(
|
||||
request: ToolCallRequest, state: AgentState, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
call_log.append("before")
|
||||
response = yield request
|
||||
call_log.append("after")
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
ToolCall(
|
||||
name="search_tool",
|
||||
args={"query": "test"},
|
||||
id="call_1",
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
[], # Empty to signal agent should stop
|
||||
],
|
||||
tool_style="openai",
|
||||
)
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[search_tool],
|
||||
middleware=[logging_middleware],
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
result = await agent.ainvoke(
|
||||
{"messages": [HumanMessage("Search")]},
|
||||
{"configurable": {"thread_id": "test"}},
|
||||
)
|
||||
|
||||
assert call_log == ["before", "after"]
|
||||
assert "Results for: test" in result["messages"][2].content
|
||||
|
||||
def test_decorator_multiple_tools_called(self) -> None:
|
||||
"""Test decorator handling multiple tool calls in one turn."""
|
||||
call_log = []
|
||||
|
||||
@on_tool_call
|
||||
def logging_middleware(
|
||||
request: ToolCallRequest, state: AgentState, runtime
|
||||
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
|
||||
call_log.append(request.tool.name)
|
||||
response = yield request
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
ToolCall(
|
||||
name="search_tool",
|
||||
args={"query": "test"},
|
||||
id="call_1",
|
||||
type="tool_call",
|
||||
),
|
||||
ToolCall(
|
||||
name="calculator_tool",
|
||||
args={"expression": "2+2"},
|
||||
id="call_2",
|
||||
type="tool_call",
|
||||
),
|
||||
],
|
||||
[], # Empty to signal agent should stop
|
||||
],
|
||||
tool_style="openai",
|
||||
)
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[search_tool, calculator_tool],
|
||||
middleware=[logging_middleware],
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
result = agent.invoke(
|
||||
{"messages": [HumanMessage("Search and calculate")]},
|
||||
{"configurable": {"thread_id": "test"}},
|
||||
)
|
||||
|
||||
# Both tools should have been called
|
||||
assert "search_tool" in call_log
|
||||
assert "calculator_tool" in call_log
|
||||
assert len(result["messages"]) >= 4 # human, ai, tool1, tool2, (final ai)
|
||||
4
libs/langchain_v1/uv.lock
generated
4
libs/langchain_v1/uv.lock
generated
@@ -1742,7 +1742,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "1.0.0a7"
|
||||
version = "1.0.0a8"
|
||||
source = { editable = "../core" }
|
||||
dependencies = [
|
||||
{ name = "jsonpatch" },
|
||||
@@ -1907,7 +1907,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-openai"
|
||||
version = "1.0.0a3"
|
||||
version = "1.0.0a4"
|
||||
source = { editable = "../partners/openai" }
|
||||
dependencies = [
|
||||
{ name = "langchain-core" },
|
||||
|
||||
Reference in New Issue
Block a user