mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-08 18:19:21 +00:00
Compare commits
4 Commits
langchain-
...
feat/tool-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
256d1e6af3 | ||
|
|
9d046a1147 | ||
|
|
749fb4f677 | ||
|
|
55bc0eb08e |
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
||||
|
||||
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
|
||||
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
|
||||
from langgraph.channels.untracked_value import UntrackedValue
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
@@ -16,7 +16,12 @@ from langchain.agents.middleware.types import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.tools.tool_node import ToolCallRequest
|
||||
|
||||
|
||||
class ToolCallLimitState(AgentState):
|
||||
@@ -163,12 +168,12 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
|
||||
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
|
||||
from langchain.agents import create_agent
|
||||
|
||||
# Limit all tool calls globally
|
||||
# Limit all tool calls globally - stop entire agent when exceeded
|
||||
global_limiter = ToolCallLimitMiddleware(thread_limit=20, run_limit=10, exit_behavior="end")
|
||||
|
||||
# Limit a specific tool
|
||||
# Limit a specific tool - block tool execution but let agent continue
|
||||
search_limiter = ToolCallLimitMiddleware(
|
||||
tool_name="search", thread_limit=5, run_limit=3, exit_behavior="end"
|
||||
tool_name="search", thread_limit=5, run_limit=3, exit_behavior="end_tools"
|
||||
)
|
||||
|
||||
# Use both in the same agent
|
||||
@@ -186,7 +191,7 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
|
||||
tool_name: str | None = None,
|
||||
thread_limit: int | None = None,
|
||||
run_limit: int | None = None,
|
||||
exit_behavior: Literal["end", "error"] = "end",
|
||||
exit_behavior: Literal["end", "end_tools", "error"] = "end",
|
||||
) -> None:
|
||||
"""Initialize the tool call limit middleware.
|
||||
|
||||
@@ -200,6 +205,9 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
|
||||
exit_behavior: What to do when limits are exceeded.
|
||||
- "end": Jump to the end of the agent execution and
|
||||
inject an artificial AI message indicating that the limit was exceeded.
|
||||
- "end_tools": Allow the model to request tools, but block tool execution
|
||||
when limits are exceeded. The agent receives warning messages and can
|
||||
continue with partial results.
|
||||
- "error": Raise a ToolCallLimitExceededError
|
||||
Defaults to "end".
|
||||
|
||||
@@ -212,8 +220,8 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
|
||||
msg = "At least one limit must be specified (thread_limit or run_limit)"
|
||||
raise ValueError(msg)
|
||||
|
||||
if exit_behavior not in ("end", "error"):
|
||||
msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end' or 'error'"
|
||||
if exit_behavior not in ("end", "end_tools", "error"):
|
||||
msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end', 'end_tools', or 'error'"
|
||||
raise ValueError(msg)
|
||||
|
||||
self.tool_name = tool_name
|
||||
@@ -237,18 +245,80 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
|
||||
def before_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
|
||||
"""Check tool call limits before making a model call.
|
||||
|
||||
For `end` and `error` behaviors, this prevents the model from being
|
||||
called if limits are already exceeded. For `end_tools` behavior, this first
|
||||
counts successful tool executions from the previous iteration, then allows
|
||||
the model to run (blocking happens during tool execution instead).
|
||||
|
||||
Args:
|
||||
state: The current agent state containing tool call counts.
|
||||
runtime: The langgraph runtime.
|
||||
|
||||
Returns:
|
||||
If limits are exceeded and exit_behavior is "end", returns
|
||||
a Command to jump to the end with a limit exceeded message. Otherwise returns None.
|
||||
a Command to jump to the end with a limit exceeded message.
|
||||
For end_tools, returns state updates with updated counts.
|
||||
Otherwise returns None.
|
||||
|
||||
Raises:
|
||||
ToolCallLimitExceededError: If limits are exceeded and exit_behavior
|
||||
is "error".
|
||||
"""
|
||||
# For end_tools behavior, count executions from current run
|
||||
if self.exit_behavior == "end_tools":
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Only look at messages from the current run (after last HumanMessage)
|
||||
run_messages = _get_run_messages(messages)
|
||||
if not run_messages:
|
||||
return None
|
||||
|
||||
# Count successful tool executions in the current run
|
||||
count_key = self.tool_name if self.tool_name else "__all__"
|
||||
successful_executions = 0
|
||||
|
||||
for msg in run_messages:
|
||||
if not isinstance(msg, ToolMessage):
|
||||
continue
|
||||
|
||||
# Check if this is a limit warning (not a successful execution)
|
||||
content = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
is_limit_warning = "tool call limits exceeded" in content.lower()
|
||||
|
||||
# Check if this tool matches our filter
|
||||
if self.tool_name is not None and msg.name != self.tool_name:
|
||||
continue
|
||||
|
||||
if not is_limit_warning:
|
||||
successful_executions += 1
|
||||
|
||||
if successful_executions == 0:
|
||||
return None
|
||||
|
||||
# Check if we've already updated to this count
|
||||
current_run_count = state.get("run_tool_call_count", {}).get(count_key, 0)
|
||||
|
||||
# If we've already counted all executions, don't update again
|
||||
if current_run_count >= successful_executions:
|
||||
return None
|
||||
|
||||
# Update counts with the delta
|
||||
thread_counts = state.get("thread_tool_call_count", {}).copy()
|
||||
run_counts = state.get("run_tool_call_count", {}).copy()
|
||||
|
||||
# Calculate how many new executions we haven't counted yet
|
||||
new_executions = successful_executions - current_run_count
|
||||
|
||||
thread_counts[count_key] = thread_counts.get(count_key, 0) + new_executions
|
||||
run_counts[count_key] = successful_executions
|
||||
|
||||
return {
|
||||
"thread_tool_call_count": thread_counts,
|
||||
"run_tool_call_count": run_counts,
|
||||
}
|
||||
|
||||
# Get the count key for this middleware instance
|
||||
count_key = self.tool_name if self.tool_name else "__all__"
|
||||
|
||||
@@ -285,6 +355,10 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
|
||||
def after_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
|
||||
"""Increment tool call counts after a model call (when tool calls are made).
|
||||
|
||||
For `end_tools` behavior, counting happens in `before_model` on the next
|
||||
iteration (after tools execute). For other behaviors, this increments the
|
||||
count based on how many tool calls the model made.
|
||||
|
||||
Args:
|
||||
state: The current agent state.
|
||||
runtime: The langgraph runtime.
|
||||
@@ -292,6 +366,10 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
|
||||
Returns:
|
||||
State updates with incremented tool call counts if tool calls were made.
|
||||
"""
|
||||
# For end_tools, counting happens in before_model (after tools finish)
|
||||
if self.exit_behavior == "end_tools":
|
||||
return None
|
||||
|
||||
# Get the last AIMessage to check for tool calls
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
@@ -331,3 +409,106 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
|
||||
"thread_tool_call_count": thread_counts,
|
||||
"run_tool_call_count": run_counts,
|
||||
}
|
||||
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept tool execution to enforce limits for end_tools behavior.
|
||||
|
||||
For `end_tools` behavior, this method checks if executing this specific
|
||||
tool would exceed the limits. If so, it returns a warning message instead
|
||||
of executing the tool. This allows the agent to continue with partial results.
|
||||
|
||||
The position of the tool call in the model's response is used to determine
|
||||
which tools should execute and which should be blocked, even with parallel
|
||||
tool execution.
|
||||
|
||||
Args:
|
||||
request: The tool call request containing the tool call and state.
|
||||
execute: Function to execute the tool call.
|
||||
|
||||
Returns:
|
||||
ToolMessage with the tool result, or a warning message if limit exceeded.
|
||||
"""
|
||||
# Only intercept for end_tools behavior
|
||||
if self.exit_behavior != "end_tools":
|
||||
return execute(request)
|
||||
|
||||
# Check if this tool matches our filter
|
||||
if self.tool_name is not None and request.tool_call["name"] != self.tool_name:
|
||||
# This tool doesn't match our filter, execute it without counting
|
||||
return execute(request)
|
||||
|
||||
# Get the count key for this middleware instance
|
||||
count_key = self.tool_name if self.tool_name else "__all__"
|
||||
|
||||
# Find the last AI message to get the tool call position
|
||||
messages = request.state.get("messages", [])
|
||||
last_ai_message = None
|
||||
for message in reversed(messages):
|
||||
if isinstance(message, AIMessage):
|
||||
last_ai_message = message
|
||||
break
|
||||
|
||||
if not last_ai_message or not last_ai_message.tool_calls:
|
||||
# No AI message with tool calls found, execute normally
|
||||
return execute(request)
|
||||
|
||||
# Find the position of this tool call in the list
|
||||
# Only count tool calls that match our filter
|
||||
tool_call_position = None
|
||||
for idx, tc in enumerate(last_ai_message.tool_calls):
|
||||
# Match by tool_call_id
|
||||
if tc["id"] == request.tool_call["id"]:
|
||||
# Count how many matching tool calls come before this one
|
||||
matching_before = sum(
|
||||
1
|
||||
for i in range(idx)
|
||||
if self.tool_name is None
|
||||
or last_ai_message.tool_calls[i]["name"] == self.tool_name
|
||||
)
|
||||
tool_call_position = matching_before
|
||||
break
|
||||
|
||||
# Shouldn't happen, but safety check
|
||||
if tool_call_position is None:
|
||||
return execute(request)
|
||||
|
||||
# Get current counts from state
|
||||
thread_counts = request.state.get("thread_tool_call_count", {})
|
||||
run_counts = request.state.get("run_tool_call_count", {})
|
||||
|
||||
current_thread_count = thread_counts.get(count_key, 0)
|
||||
current_run_count = run_counts.get(count_key, 0)
|
||||
|
||||
# Calculate count after this tool executes (based on position)
|
||||
count_after_this_tool = current_thread_count + tool_call_position + 1
|
||||
run_count_after_this_tool = current_run_count + tool_call_position + 1
|
||||
|
||||
# Check if this tool call would exceed limits
|
||||
thread_limit_exceeded = (
|
||||
self.thread_limit is not None and count_after_this_tool > self.thread_limit
|
||||
)
|
||||
run_limit_exceeded = (
|
||||
self.run_limit is not None and run_count_after_this_tool > self.run_limit
|
||||
)
|
||||
|
||||
if thread_limit_exceeded or run_limit_exceeded:
|
||||
# This tool would exceed the limit - return warning message
|
||||
limit_message = _build_tool_limit_exceeded_message(
|
||||
thread_count=current_thread_count + tool_call_position,
|
||||
run_count=current_run_count + tool_call_position,
|
||||
thread_limit=self.thread_limit,
|
||||
run_limit=self.run_limit,
|
||||
tool_name=self.tool_name,
|
||||
)
|
||||
return ToolMessage(
|
||||
content=f"{limit_message} Do not call any more tools.",
|
||||
tool_call_id=request.tool_call["id"],
|
||||
name=request.tool_call["name"],
|
||||
)
|
||||
|
||||
# Within limit - execute the tool
|
||||
return execute(request)
|
||||
|
||||
@@ -452,3 +452,415 @@ def test_exception_error_messages():
|
||||
error_msg = str(exc_info.value)
|
||||
assert "'search' tool call limits exceeded" in error_msg
|
||||
assert "thread limit (2/2)" in error_msg
|
||||
|
||||
|
||||
def test_end_behavior():
|
||||
"""Test end behavior: stops entire agent when limit hit."""
|
||||
|
||||
@tool
|
||||
def tool1(query: str) -> str:
|
||||
"""First tool."""
|
||||
return f"tool1: {query}"
|
||||
|
||||
@tool
|
||||
def tool2(query: str) -> str:
|
||||
"""Second tool."""
|
||||
return f"tool2: {query}"
|
||||
|
||||
@tool
|
||||
def tool3(query: str) -> str:
|
||||
"""Third tool."""
|
||||
return f"tool3: {query}"
|
||||
|
||||
# Model tries to call all 3 tools
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
ToolCall(name="tool1", args={"query": "test"}, id="1"),
|
||||
ToolCall(name="tool2", args={"query": "test"}, id="2"),
|
||||
ToolCall(name="tool3", args={"query": "test"}, id="3"),
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
# Limit to 2 tool calls with end behavior
|
||||
middleware = ToolCallLimitMiddleware(run_limit=2, exit_behavior="end")
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[tool1, tool2, tool3],
|
||||
middleware=[middleware],
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Call all tools")]})
|
||||
|
||||
# Agent should hit limit before calling model, jumping to end
|
||||
# The model makes 3 tool calls in first response
|
||||
# after_model increments count to 3
|
||||
# before_model on next iteration sees count >= limit and jumps to end
|
||||
messages = result["messages"]
|
||||
|
||||
# Should have: HumanMessage, AIMessage with tool calls, 3 ToolMessages, AIMessage with limit message
|
||||
ai_messages = [m for m in messages if isinstance(m, AIMessage)]
|
||||
|
||||
# Last message should be the limit exceeded message
|
||||
assert (
|
||||
"run limit" in ai_messages[-1].content
|
||||
or "Tool call limits exceeded" in ai_messages[-1].content
|
||||
)
|
||||
|
||||
|
||||
def test_end_tools_behavior():
|
||||
"""Test end_tools behavior: blocks tool execution but lets agent continue."""
|
||||
|
||||
call_count = {"tool1": 0, "tool2": 0, "tool3": 0}
|
||||
|
||||
@tool
|
||||
def tool1(query: str) -> str:
|
||||
"""First tool."""
|
||||
call_count["tool1"] += 1
|
||||
return f"tool1 result: {query}"
|
||||
|
||||
@tool
|
||||
def tool2(query: str) -> str:
|
||||
"""Second tool."""
|
||||
call_count["tool2"] += 1
|
||||
return f"tool2 result: {query}"
|
||||
|
||||
@tool
|
||||
def tool3(query: str) -> str:
|
||||
"""Third tool."""
|
||||
call_count["tool3"] += 1
|
||||
return f"tool3 result: {query}"
|
||||
|
||||
# Model makes 3 tool calls, then responds with no tool calls
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
ToolCall(name="tool1", args={"query": "test"}, id="1"),
|
||||
ToolCall(name="tool2", args={"query": "test"}, id="2"),
|
||||
ToolCall(name="tool3", args={"query": "test"}, id="3"),
|
||||
],
|
||||
[], # Final response with no tool calls
|
||||
]
|
||||
)
|
||||
|
||||
# Limit to 2 tool calls with end_tools behavior
|
||||
middleware = ToolCallLimitMiddleware(run_limit=2, exit_behavior="end_tools")
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[tool1, tool2, tool3],
|
||||
middleware=[middleware],
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Call all tools")]})
|
||||
|
||||
# Verify that only the first 2 tools were actually executed
|
||||
assert call_count["tool1"] == 1, "tool1 should have been called"
|
||||
assert call_count["tool2"] == 1, "tool2 should have been called"
|
||||
assert call_count["tool3"] == 0, "tool3 should have been blocked"
|
||||
|
||||
# Check messages - should have results from tool1 and tool2, warning from tool3
|
||||
tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)]
|
||||
assert len(tool_messages) == 3
|
||||
|
||||
# First two should have actual results
|
||||
assert "tool1 result" in tool_messages[0].content
|
||||
assert "tool2 result" in tool_messages[1].content
|
||||
|
||||
# Third should be a limit warning
|
||||
assert "Tool call limits exceeded" in tool_messages[2].content
|
||||
assert "Do not call any more tools" in tool_messages[2].content
|
||||
|
||||
|
||||
def test_error_behavior():
|
||||
"""Test error behavior: raises exception when limit hit."""
|
||||
|
||||
@tool
|
||||
def search(query: str) -> str:
|
||||
"""Search for information."""
|
||||
return f"Results for: {query}"
|
||||
|
||||
# Model makes 3 tool calls
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
ToolCall(name="search", args={"query": "test1"}, id="1"),
|
||||
ToolCall(name="search", args={"query": "test2"}, id="2"),
|
||||
ToolCall(name="search", args={"query": "test3"}, id="3"),
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
# Limit to 2 tool calls with error behavior
|
||||
middleware = ToolCallLimitMiddleware(run_limit=2, exit_behavior="error")
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[search],
|
||||
middleware=[middleware],
|
||||
)
|
||||
|
||||
# Should raise exception when limit is hit
|
||||
with pytest.raises(ToolCallLimitExceededError) as exc_info:
|
||||
agent.invoke({"messages": [HumanMessage("Search for things")]})
|
||||
|
||||
assert "run limit" in str(exc_info.value)
|
||||
assert exc_info.value.run_limit == 2
|
||||
|
||||
|
||||
def test_end_tools_sequential():
|
||||
"""Test end_tools behavior with sequential tool calls (one at a time)."""
|
||||
|
||||
call_count = {"search": 0}
|
||||
|
||||
@tool
|
||||
def search(query: str) -> str:
|
||||
"""Search for information."""
|
||||
call_count["search"] += 1
|
||||
return f"search result {call_count['search']}: {query}"
|
||||
|
||||
# Model makes tool calls one at a time across multiple iterations
|
||||
# Each iteration: make one tool call, then continue
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[ToolCall(name="search", args={"query": "first"}, id="1")],
|
||||
[ToolCall(name="search", args={"query": "second"}, id="2")],
|
||||
[ToolCall(name="search", args={"query": "third"}, id="3")],
|
||||
[], # Final response with no tool calls
|
||||
]
|
||||
)
|
||||
|
||||
# Limit to 2 tool calls with end_tools behavior
|
||||
middleware = ToolCallLimitMiddleware(run_limit=2, exit_behavior="end_tools")
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[search],
|
||||
middleware=[middleware],
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Search sequentially")]})
|
||||
|
||||
# First 2 tools should execute, third should be blocked
|
||||
assert call_count["search"] == 2, "Only first 2 searches should execute"
|
||||
|
||||
# Check tool messages
|
||||
tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)]
|
||||
assert len(tool_messages) == 3
|
||||
|
||||
# First two should have actual results
|
||||
assert "search result 1" in tool_messages[0].content
|
||||
assert "search result 2" in tool_messages[1].content
|
||||
|
||||
# Third should be blocked
|
||||
assert "tool call limits exceeded" in tool_messages[2].content.lower()
|
||||
|
||||
|
||||
def test_end_tools_with_specific_tool():
|
||||
"""Test end_tools behavior with specific tool limiting."""
|
||||
|
||||
call_count = {"search": 0, "calculator": 0}
|
||||
|
||||
@tool
|
||||
def search(query: str) -> str:
|
||||
"""Search for information."""
|
||||
call_count["search"] += 1
|
||||
return f"search result: {query}"
|
||||
|
||||
@tool
|
||||
def calculator(expression: str) -> str:
|
||||
"""Calculate an expression."""
|
||||
call_count["calculator"] += 1
|
||||
return f"calculator result: {expression}"
|
||||
|
||||
# Model makes 3 search calls and 2 calculator calls
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
ToolCall(name="search", args={"query": "test1"}, id="1"),
|
||||
ToolCall(name="search", args={"query": "test2"}, id="2"),
|
||||
ToolCall(name="calculator", args={"expression": "1+1"}, id="3"),
|
||||
ToolCall(name="search", args={"query": "test3"}, id="4"),
|
||||
ToolCall(name="calculator", args={"expression": "2+2"}, id="5"),
|
||||
],
|
||||
[],
|
||||
]
|
||||
)
|
||||
|
||||
# Limit only search to 2 calls with end_tools behavior
|
||||
middleware = ToolCallLimitMiddleware(tool_name="search", run_limit=2, exit_behavior="end_tools")
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[search, calculator],
|
||||
middleware=[middleware],
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Search and calculate")]})
|
||||
|
||||
# First 2 search calls should execute, third should be blocked
|
||||
assert call_count["search"] == 2, "Only first 2 search calls should execute"
|
||||
# All calculator calls should execute (not limited)
|
||||
assert call_count["calculator"] == 2, "All calculator calls should execute"
|
||||
|
||||
# Check tool messages
|
||||
tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)]
|
||||
|
||||
search_messages = [m for m in tool_messages if m.name == "search"]
|
||||
calc_messages = [m for m in tool_messages if m.name == "calculator"]
|
||||
|
||||
# Should have 3 search messages (2 results + 1 warning)
|
||||
assert len(search_messages) == 3
|
||||
assert "search result" in search_messages[0].content
|
||||
assert "search result" in search_messages[1].content
|
||||
assert "tool call limits exceeded" in search_messages[2].content.lower()
|
||||
|
||||
# Should have 2 calculator results
|
||||
assert len(calc_messages) == 2
|
||||
assert "calculator result" in calc_messages[0].content
|
||||
assert "calculator result" in calc_messages[1].content
|
||||
|
||||
|
||||
def test_end_tools_thread_limit():
|
||||
"""Test end_tools behavior with thread limit across multiple runs."""
|
||||
|
||||
call_count = {"search": 0}
|
||||
|
||||
@tool
|
||||
def search(query: str) -> str:
|
||||
"""Search for information."""
|
||||
call_count["search"] += 1
|
||||
return f"search result {call_count['search']}: {query}"
|
||||
|
||||
# First run: 2 search calls
|
||||
# Second run: 2 more search calls (total 4)
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
ToolCall(name="search", args={"query": "test1"}, id="1"),
|
||||
ToolCall(name="search", args={"query": "test2"}, id="2"),
|
||||
],
|
||||
[],
|
||||
[
|
||||
ToolCall(name="search", args={"query": "test3"}, id="3"),
|
||||
ToolCall(name="search", args={"query": "test4"}, id="4"),
|
||||
],
|
||||
[],
|
||||
]
|
||||
)
|
||||
|
||||
# Thread limit of 3 with end_tools behavior
|
||||
middleware = ToolCallLimitMiddleware(thread_limit=3, exit_behavior="end_tools")
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[search],
|
||||
middleware=[middleware],
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
thread_config = {"configurable": {"thread_id": "test_thread"}}
|
||||
|
||||
# First run: 2 calls, both should execute
|
||||
result1 = agent.invoke({"messages": [HumanMessage("First search")]}, thread_config)
|
||||
assert call_count["search"] == 2
|
||||
|
||||
tool_messages_1 = [m for m in result1["messages"] if isinstance(m, ToolMessage)]
|
||||
assert len(tool_messages_1) == 2
|
||||
assert "search result" in tool_messages_1[0].content
|
||||
assert "search result" in tool_messages_1[1].content
|
||||
|
||||
# Second run: would add 2 more (total 4), but limit is 3
|
||||
# So only first call should execute, second should be blocked
|
||||
result2 = agent.invoke({"messages": [HumanMessage("Second search")]}, thread_config)
|
||||
assert call_count["search"] == 3, "Only 1 more search should execute (3 total)"
|
||||
|
||||
tool_messages_2 = [
|
||||
m for m in result2["messages"] if isinstance(m, ToolMessage) and m not in tool_messages_1
|
||||
]
|
||||
|
||||
# Should have 2 new tool messages: 1 result, 1 warning
|
||||
assert len(tool_messages_2) == 2
|
||||
# First is the successful execution
|
||||
assert "search result 3" in tool_messages_2[0].content
|
||||
# Second is the blocked execution
|
||||
assert "Tool call limits exceeded" in tool_messages_2[1].content
|
||||
|
||||
|
||||
def test_comparison_all_three_behaviors():
|
||||
"""
|
||||
Comprehensive test comparing all three exit behaviors.
|
||||
|
||||
This demonstrates the key differences:
|
||||
- end: Stops entire agent execution
|
||||
- end_tools: Blocks individual tool executions, agent continues
|
||||
- error: Raises exception
|
||||
"""
|
||||
|
||||
@tool
|
||||
def search(query: str) -> str:
|
||||
"""Search for information."""
|
||||
return f"search: {query}"
|
||||
|
||||
# Model tries to make 3 tool calls
|
||||
def make_model():
|
||||
return FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
ToolCall(name="search", args={"query": "q1"}, id="1"),
|
||||
ToolCall(name="search", args={"query": "q2"}, id="2"),
|
||||
ToolCall(name="search", args={"query": "q3"}, id="3"),
|
||||
],
|
||||
[],
|
||||
]
|
||||
)
|
||||
|
||||
# Test 1: end behavior
|
||||
agent_end = create_agent(
|
||||
model=make_model(),
|
||||
tools=[search],
|
||||
middleware=[ToolCallLimitMiddleware(run_limit=2, exit_behavior="end")],
|
||||
)
|
||||
|
||||
result_end = agent_end.invoke({"messages": [HumanMessage("Search")]})
|
||||
|
||||
# With end: agent completes first iteration (3 tool calls),
|
||||
# then before_model on next iteration sees limit exceeded and jumps to end
|
||||
messages_end = result_end["messages"]
|
||||
ai_messages_end = [m for m in messages_end if isinstance(m, AIMessage)]
|
||||
|
||||
# Should have limit exceeded message
|
||||
assert any("limit" in m.content.lower() for m in ai_messages_end)
|
||||
|
||||
# Test 2: end_tools behavior
|
||||
agent_end_tools = create_agent(
|
||||
model=make_model(),
|
||||
tools=[search],
|
||||
middleware=[ToolCallLimitMiddleware(run_limit=2, exit_behavior="end_tools")],
|
||||
)
|
||||
|
||||
result_end_tools = agent_end_tools.invoke({"messages": [HumanMessage("Search")]})
|
||||
|
||||
# With end_tools: first 2 tools execute, third gets warning, agent continues
|
||||
tool_messages_end_tools = [
|
||||
m for m in result_end_tools["messages"] if isinstance(m, ToolMessage)
|
||||
]
|
||||
assert len(tool_messages_end_tools) == 3
|
||||
assert "search: q1" in tool_messages_end_tools[0].content
|
||||
assert "search: q2" in tool_messages_end_tools[1].content
|
||||
assert "Tool call limits exceeded" in tool_messages_end_tools[2].content
|
||||
|
||||
# Test 3: error behavior
|
||||
agent_error = create_agent(
|
||||
model=make_model(),
|
||||
tools=[search],
|
||||
middleware=[ToolCallLimitMiddleware(run_limit=2, exit_behavior="error")],
|
||||
)
|
||||
|
||||
# With error: should raise exception
|
||||
with pytest.raises(ToolCallLimitExceededError) as exc_info:
|
||||
agent_error.invoke({"messages": [HumanMessage("Search")]})
|
||||
|
||||
assert exc_info.value.run_limit == 2
|
||||
|
||||
2
libs/langchain_v1/uv.lock
generated
2
libs/langchain_v1/uv.lock
generated
@@ -1743,7 +1743,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "1.0.0"
|
||||
version = "1.0.1"
|
||||
source = { editable = "../core" }
|
||||
dependencies = [
|
||||
{ name = "jsonpatch" },
|
||||
|
||||
Reference in New Issue
Block a user