Compare commits

...

4 Commits

Author SHA1 Message Date
Liam Bush
256d1e6af3 fix: resolve mypy type error for ToolMessage content
- Handle case where msg.content could be list or string
- Convert to string before calling .lower()
- Mypy now passes cleanly
2025-10-22 14:01:25 -07:00
Liam Bush
9d046a1147 fix: resolve linting and formatting issues
- Move Callable and Command imports to TYPE_CHECKING block
- Fix line length violations (max 100 characters)
- Run ruff format to fix formatting
- All 17 tests passing
2025-10-22 13:55:29 -07:00
Liam Bush
749fb4f677 refactor: clean up code style and add sequential execution test
- Simplify verbose comments in tool_call_limit.py
- Use .lower() for case-insensitive string comparison
- Add test_end_tools_sequential() to explicitly test sequential tool calls
- All 17 tests passing (10 existing + 7 new)
2025-10-22 13:49:31 -07:00
Liam Bush
55bc0eb08e feat(agents): add end_tools exit behavior to ToolCallLimitMiddleware
Add new `end_tools` exit behavior that allows agents to continue with partial tool results when limits are exceeded, instead of stopping entirely.

## Changes

- Add `end_tools` exit behavior to ToolCallLimitMiddleware
- Implement `wrap_tool_call` hook to block tool execution when limits exceeded
- Update counting logic to work for both parallel and sequential execution
- Add comprehensive tests for all three exit behaviors (end, end_tools, error)

## Behavior

- `end`: Stops entire agent execution (unchanged)
- `end_tools` (new): Blocks tool execution but lets agent continue with partial results
- `error`: Raises exception (unchanged)

## Tests

- All 16 tests passing
- Added 6 new tests covering parallel execution, sequential execution, specific tool limiting, and thread limits
- Works correctly for both parallel (batch) and sequential (one-at-a-time) tool calls
2025-10-22 13:40:43 -07:00
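A minimal configuration sketch of the three exit behaviors this commit describes, based on the constructor shown in the diff below (the agent wiring at the end is illustrative and commented out):

```python
from langchain.agents import create_agent
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware

# "end" (default): stop the whole agent once the limit is hit.
stop_agent = ToolCallLimitMiddleware(run_limit=10, exit_behavior="end")

# "end_tools" (new): block further tool executions but let the agent
# keep going with the partial results it already collected.
partial_results = ToolCallLimitMiddleware(
    tool_name="search", run_limit=3, exit_behavior="end_tools"
)

# "error": raise ToolCallLimitExceededError when a limit is exceeded.
strict = ToolCallLimitMiddleware(thread_limit=20, exit_behavior="error")

# Illustrative wiring (model and tools omitted):
# agent = create_agent(model=..., tools=[...], middleware=[stop_agent, partial_results])
```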
3 changed files with 602 additions and 9 deletions


@@ -4,7 +4,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Annotated, Any, Literal
-from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
+from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
from langgraph.channels.untracked_value import UntrackedValue
from typing_extensions import NotRequired
@@ -16,7 +16,12 @@ from langchain.agents.middleware.types import (
)
if TYPE_CHECKING:
from collections.abc import Callable
from langgraph.runtime import Runtime
from langgraph.types import Command
from langchain.tools.tool_node import ToolCallRequest
class ToolCallLimitState(AgentState):
@@ -163,12 +168,12 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
from langchain.agents import create_agent
-# Limit all tool calls globally
+# Limit all tool calls globally - stop entire agent when exceeded
global_limiter = ToolCallLimitMiddleware(thread_limit=20, run_limit=10, exit_behavior="end")
-# Limit a specific tool
+# Limit a specific tool - block tool execution but let agent continue
search_limiter = ToolCallLimitMiddleware(
tool_name="search", thread_limit=5, run_limit=3, exit_behavior="end"
tool_name="search", thread_limit=5, run_limit=3, exit_behavior="end_tools"
)
# Use both in the same agent
@@ -186,7 +191,7 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
tool_name: str | None = None,
thread_limit: int | None = None,
run_limit: int | None = None,
exit_behavior: Literal["end", "error"] = "end",
exit_behavior: Literal["end", "end_tools", "error"] = "end",
) -> None:
"""Initialize the tool call limit middleware.
@@ -200,6 +205,9 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
exit_behavior: What to do when limits are exceeded.
- "end": Jump to the end of the agent execution and
inject an artificial AI message indicating that the limit was exceeded.
- "end_tools": Allow the model to request tools, but block tool execution
when limits are exceeded. The agent receives warning messages and can
continue with partial results.
- "error": Raise a ToolCallLimitExceededError
Defaults to "end".
@@ -212,8 +220,8 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
msg = "At least one limit must be specified (thread_limit or run_limit)"
raise ValueError(msg)
if exit_behavior not in ("end", "error"):
msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end' or 'error'"
if exit_behavior not in ("end", "end_tools", "error"):
msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end', 'end_tools', or 'error'"
raise ValueError(msg)
self.tool_name = tool_name
@@ -237,18 +245,80 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
def before_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Check tool call limits before making a model call.
For `end` and `error` behaviors, this prevents the model from being
called if limits are already exceeded. For `end_tools` behavior, this first
counts successful tool executions from the previous iteration, then allows
the model to run (blocking happens during tool execution instead).
Args:
state: The current agent state containing tool call counts.
runtime: The langgraph runtime.
Returns:
If limits are exceeded and exit_behavior is "end", returns
-a Command to jump to the end with a limit exceeded message. Otherwise returns None.
+a Command to jump to the end with a limit exceeded message.
+For end_tools, returns state updates with updated counts.
+Otherwise returns None.
Raises:
ToolCallLimitExceededError: If limits are exceeded and exit_behavior
is "error".
"""
# For end_tools behavior, count executions from current run
if self.exit_behavior == "end_tools":
messages = state.get("messages", [])
if not messages:
return None
# Only look at messages from the current run (after last HumanMessage)
run_messages = _get_run_messages(messages)
if not run_messages:
return None
# Count successful tool executions in the current run
count_key = self.tool_name if self.tool_name else "__all__"
successful_executions = 0
for msg in run_messages:
if not isinstance(msg, ToolMessage):
continue
# Check if this is a limit warning (not a successful execution)
content = msg.content if isinstance(msg.content, str) else str(msg.content)
is_limit_warning = "tool call limits exceeded" in content.lower()
# Check if this tool matches our filter
if self.tool_name is not None and msg.name != self.tool_name:
continue
if not is_limit_warning:
successful_executions += 1
if successful_executions == 0:
return None
# Check if we've already updated to this count
current_run_count = state.get("run_tool_call_count", {}).get(count_key, 0)
# If we've already counted all executions, don't update again
if current_run_count >= successful_executions:
return None
# Update counts with the delta
thread_counts = state.get("thread_tool_call_count", {}).copy()
run_counts = state.get("run_tool_call_count", {}).copy()
# Calculate how many new executions we haven't counted yet
new_executions = successful_executions - current_run_count
thread_counts[count_key] = thread_counts.get(count_key, 0) + new_executions
run_counts[count_key] = successful_executions
return {
"thread_tool_call_count": thread_counts,
"run_tool_call_count": run_counts,
}
# Get the count key for this middleware instance
count_key = self.tool_name if self.tool_name else "__all__"
@@ -285,6 +355,10 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
def after_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Increment tool call counts after a model call (when tool calls are made).
For `end_tools` behavior, counting happens in `before_model` on the next
iteration (after tools execute). For other behaviors, this increments the
count based on how many tool calls the model made.
Args:
state: The current agent state.
runtime: The langgraph runtime.
@@ -292,6 +366,10 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
Returns:
State updates with incremented tool call counts if tool calls were made.
"""
+# For end_tools, counting happens in before_model (after tools finish)
+if self.exit_behavior == "end_tools":
+    return None
# Get the last AIMessage to check for tool calls
messages = state.get("messages", [])
if not messages:
@@ -331,3 +409,106 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
"thread_tool_call_count": thread_counts,
"run_tool_call_count": run_counts,
}
def wrap_tool_call(
self,
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept tool execution to enforce limits for end_tools behavior.
For `end_tools` behavior, this method checks if executing this specific
tool would exceed the limits. If so, it returns a warning message instead
of executing the tool. This allows the agent to continue with partial results.
The position of the tool call in the model's response is used to determine
which tools should execute and which should be blocked, even with parallel
tool execution.
Args:
request: The tool call request containing the tool call and state.
execute: Function to execute the tool call.
Returns:
ToolMessage with the tool result, or a warning message if the limit is exceeded.
"""
# Only intercept for end_tools behavior
if self.exit_behavior != "end_tools":
return execute(request)
# Check if this tool matches our filter
if self.tool_name is not None and request.tool_call["name"] != self.tool_name:
# This tool doesn't match our filter, execute it without counting
return execute(request)
# Get the count key for this middleware instance
count_key = self.tool_name if self.tool_name else "__all__"
# Find the last AI message to get the tool call position
messages = request.state.get("messages", [])
last_ai_message = None
for message in reversed(messages):
if isinstance(message, AIMessage):
last_ai_message = message
break
if not last_ai_message or not last_ai_message.tool_calls:
# No AI message with tool calls found, execute normally
return execute(request)
# Find the position of this tool call in the list
# Only count tool calls that match our filter
tool_call_position = None
for idx, tc in enumerate(last_ai_message.tool_calls):
# Match by tool_call_id
if tc["id"] == request.tool_call["id"]:
# Count how many matching tool calls come before this one
matching_before = sum(
1
for i in range(idx)
if self.tool_name is None
or last_ai_message.tool_calls[i]["name"] == self.tool_name
)
tool_call_position = matching_before
break
# Shouldn't happen, but safety check
if tool_call_position is None:
return execute(request)
# Get current counts from state
thread_counts = request.state.get("thread_tool_call_count", {})
run_counts = request.state.get("run_tool_call_count", {})
current_thread_count = thread_counts.get(count_key, 0)
current_run_count = run_counts.get(count_key, 0)
# Calculate count after this tool executes (based on position)
count_after_this_tool = current_thread_count + tool_call_position + 1
run_count_after_this_tool = current_run_count + tool_call_position + 1
# Check if this tool call would exceed limits
thread_limit_exceeded = (
self.thread_limit is not None and count_after_this_tool > self.thread_limit
)
run_limit_exceeded = (
self.run_limit is not None and run_count_after_this_tool > self.run_limit
)
if thread_limit_exceeded or run_limit_exceeded:
# This tool would exceed the limit - return warning message
limit_message = _build_tool_limit_exceeded_message(
thread_count=current_thread_count + tool_call_position,
run_count=current_run_count + tool_call_position,
thread_limit=self.thread_limit,
run_limit=self.run_limit,
tool_name=self.tool_name,
)
return ToolMessage(
content=f"{limit_message} Do not call any more tools.",
tool_call_id=request.tool_call["id"],
name=request.tool_call["name"],
)
# Within limit - execute the tool
return execute(request)
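
To make the position-based check in `wrap_tool_call` easier to follow, here is a small standalone sketch (plain Python with assumed values, not the middleware code): with `run_limit=2`, no prior calls counted, and three matching tool calls in one AI message, the calls at positions 0 and 1 execute and the call at position 2 is blocked.

```python
# Standalone illustration of the position-based limit check.
# Assumed values: run_limit=2, no prior calls, 3 parallel tool calls in one AI message.
run_limit = 2
current_run_count = 0

for position in range(3):  # position of each matching tool call in the AI message
    count_after_this_tool = current_run_count + position + 1
    blocked = count_after_this_tool > run_limit
    print(f"tool call at position {position}: {'blocked' if blocked else 'executes'}")

# Output:
# tool call at position 0: executes
# tool call at position 1: executes
# tool call at position 2: blocked
```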


@@ -452,3 +452,415 @@ def test_exception_error_messages():
error_msg = str(exc_info.value)
assert "'search' tool call limits exceeded" in error_msg
assert "thread limit (2/2)" in error_msg
def test_end_behavior():
"""Test end behavior: stops entire agent when limit hit."""
@tool
def tool1(query: str) -> str:
"""First tool."""
return f"tool1: {query}"
@tool
def tool2(query: str) -> str:
"""Second tool."""
return f"tool2: {query}"
@tool
def tool3(query: str) -> str:
"""Third tool."""
return f"tool3: {query}"
# Model tries to call all 3 tools
model = FakeToolCallingModel(
tool_calls=[
[
ToolCall(name="tool1", args={"query": "test"}, id="1"),
ToolCall(name="tool2", args={"query": "test"}, id="2"),
ToolCall(name="tool3", args={"query": "test"}, id="3"),
]
]
)
# Limit to 2 tool calls with end behavior
middleware = ToolCallLimitMiddleware(run_limit=2, exit_behavior="end")
agent = create_agent(
model=model,
tools=[tool1, tool2, tool3],
middleware=[middleware],
)
result = agent.invoke({"messages": [HumanMessage("Call all tools")]})
# Agent should hit limit before calling model, jumping to end
# The model makes 3 tool calls in first response
# after_model increments count to 3
# before_model on next iteration sees count >= limit and jumps to end
messages = result["messages"]
# Should have: HumanMessage, AIMessage with tool calls, 3 ToolMessages, AIMessage with limit message
ai_messages = [m for m in messages if isinstance(m, AIMessage)]
# Last message should be the limit exceeded message
assert (
"run limit" in ai_messages[-1].content
or "Tool call limits exceeded" in ai_messages[-1].content
)
def test_end_tools_behavior():
"""Test end_tools behavior: blocks tool execution but lets agent continue."""
call_count = {"tool1": 0, "tool2": 0, "tool3": 0}
@tool
def tool1(query: str) -> str:
"""First tool."""
call_count["tool1"] += 1
return f"tool1 result: {query}"
@tool
def tool2(query: str) -> str:
"""Second tool."""
call_count["tool2"] += 1
return f"tool2 result: {query}"
@tool
def tool3(query: str) -> str:
"""Third tool."""
call_count["tool3"] += 1
return f"tool3 result: {query}"
# Model makes 3 tool calls, then responds with no tool calls
model = FakeToolCallingModel(
tool_calls=[
[
ToolCall(name="tool1", args={"query": "test"}, id="1"),
ToolCall(name="tool2", args={"query": "test"}, id="2"),
ToolCall(name="tool3", args={"query": "test"}, id="3"),
],
[], # Final response with no tool calls
]
)
# Limit to 2 tool calls with end_tools behavior
middleware = ToolCallLimitMiddleware(run_limit=2, exit_behavior="end_tools")
agent = create_agent(
model=model,
tools=[tool1, tool2, tool3],
middleware=[middleware],
)
result = agent.invoke({"messages": [HumanMessage("Call all tools")]})
# Verify that only the first 2 tools were actually executed
assert call_count["tool1"] == 1, "tool1 should have been called"
assert call_count["tool2"] == 1, "tool2 should have been called"
assert call_count["tool3"] == 0, "tool3 should have been blocked"
# Check messages - should have results from tool1 and tool2, warning from tool3
tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)]
assert len(tool_messages) == 3
# First two should have actual results
assert "tool1 result" in tool_messages[0].content
assert "tool2 result" in tool_messages[1].content
# Third should be a limit warning
assert "Tool call limits exceeded" in tool_messages[2].content
assert "Do not call any more tools" in tool_messages[2].content
def test_error_behavior():
"""Test error behavior: raises exception when limit hit."""
@tool
def search(query: str) -> str:
"""Search for information."""
return f"Results for: {query}"
# Model makes 3 tool calls
model = FakeToolCallingModel(
tool_calls=[
[
ToolCall(name="search", args={"query": "test1"}, id="1"),
ToolCall(name="search", args={"query": "test2"}, id="2"),
ToolCall(name="search", args={"query": "test3"}, id="3"),
]
]
)
# Limit to 2 tool calls with error behavior
middleware = ToolCallLimitMiddleware(run_limit=2, exit_behavior="error")
agent = create_agent(
model=model,
tools=[search],
middleware=[middleware],
)
# Should raise exception when limit is hit
with pytest.raises(ToolCallLimitExceededError) as exc_info:
agent.invoke({"messages": [HumanMessage("Search for things")]})
assert "run limit" in str(exc_info.value)
assert exc_info.value.run_limit == 2
def test_end_tools_sequential():
"""Test end_tools behavior with sequential tool calls (one at a time)."""
call_count = {"search": 0}
@tool
def search(query: str) -> str:
"""Search for information."""
call_count["search"] += 1
return f"search result {call_count['search']}: {query}"
# Model makes tool calls one at a time across multiple iterations
# Each iteration: make one tool call, then continue
model = FakeToolCallingModel(
tool_calls=[
[ToolCall(name="search", args={"query": "first"}, id="1")],
[ToolCall(name="search", args={"query": "second"}, id="2")],
[ToolCall(name="search", args={"query": "third"}, id="3")],
[], # Final response with no tool calls
]
)
# Limit to 2 tool calls with end_tools behavior
middleware = ToolCallLimitMiddleware(run_limit=2, exit_behavior="end_tools")
agent = create_agent(
model=model,
tools=[search],
middleware=[middleware],
)
result = agent.invoke({"messages": [HumanMessage("Search sequentially")]})
# First 2 tools should execute, third should be blocked
assert call_count["search"] == 2, "Only first 2 searches should execute"
# Check tool messages
tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)]
assert len(tool_messages) == 3
# First two should have actual results
assert "search result 1" in tool_messages[0].content
assert "search result 2" in tool_messages[1].content
# Third should be blocked
assert "tool call limits exceeded" in tool_messages[2].content.lower()
def test_end_tools_with_specific_tool():
"""Test end_tools behavior with specific tool limiting."""
call_count = {"search": 0, "calculator": 0}
@tool
def search(query: str) -> str:
"""Search for information."""
call_count["search"] += 1
return f"search result: {query}"
@tool
def calculator(expression: str) -> str:
"""Calculate an expression."""
call_count["calculator"] += 1
return f"calculator result: {expression}"
# Model makes 3 search calls and 2 calculator calls
model = FakeToolCallingModel(
tool_calls=[
[
ToolCall(name="search", args={"query": "test1"}, id="1"),
ToolCall(name="search", args={"query": "test2"}, id="2"),
ToolCall(name="calculator", args={"expression": "1+1"}, id="3"),
ToolCall(name="search", args={"query": "test3"}, id="4"),
ToolCall(name="calculator", args={"expression": "2+2"}, id="5"),
],
[],
]
)
# Limit only search to 2 calls with end_tools behavior
middleware = ToolCallLimitMiddleware(tool_name="search", run_limit=2, exit_behavior="end_tools")
agent = create_agent(
model=model,
tools=[search, calculator],
middleware=[middleware],
)
result = agent.invoke({"messages": [HumanMessage("Search and calculate")]})
# First 2 search calls should execute, third should be blocked
assert call_count["search"] == 2, "Only first 2 search calls should execute"
# All calculator calls should execute (not limited)
assert call_count["calculator"] == 2, "All calculator calls should execute"
# Check tool messages
tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)]
search_messages = [m for m in tool_messages if m.name == "search"]
calc_messages = [m for m in tool_messages if m.name == "calculator"]
# Should have 3 search messages (2 results + 1 warning)
assert len(search_messages) == 3
assert "search result" in search_messages[0].content
assert "search result" in search_messages[1].content
assert "tool call limits exceeded" in search_messages[2].content.lower()
# Should have 2 calculator results
assert len(calc_messages) == 2
assert "calculator result" in calc_messages[0].content
assert "calculator result" in calc_messages[1].content
def test_end_tools_thread_limit():
"""Test end_tools behavior with thread limit across multiple runs."""
call_count = {"search": 0}
@tool
def search(query: str) -> str:
"""Search for information."""
call_count["search"] += 1
return f"search result {call_count['search']}: {query}"
# First run: 2 search calls
# Second run: 2 more search calls (total 4)
model = FakeToolCallingModel(
tool_calls=[
[
ToolCall(name="search", args={"query": "test1"}, id="1"),
ToolCall(name="search", args={"query": "test2"}, id="2"),
],
[],
[
ToolCall(name="search", args={"query": "test3"}, id="3"),
ToolCall(name="search", args={"query": "test4"}, id="4"),
],
[],
]
)
# Thread limit of 3 with end_tools behavior
middleware = ToolCallLimitMiddleware(thread_limit=3, exit_behavior="end_tools")
agent = create_agent(
model=model,
tools=[search],
middleware=[middleware],
checkpointer=InMemorySaver(),
)
thread_config = {"configurable": {"thread_id": "test_thread"}}
# First run: 2 calls, both should execute
result1 = agent.invoke({"messages": [HumanMessage("First search")]}, thread_config)
assert call_count["search"] == 2
tool_messages_1 = [m for m in result1["messages"] if isinstance(m, ToolMessage)]
assert len(tool_messages_1) == 2
assert "search result" in tool_messages_1[0].content
assert "search result" in tool_messages_1[1].content
# Second run: would add 2 more (total 4), but limit is 3
# So only first call should execute, second should be blocked
result2 = agent.invoke({"messages": [HumanMessage("Second search")]}, thread_config)
assert call_count["search"] == 3, "Only 1 more search should execute (3 total)"
tool_messages_2 = [
m for m in result2["messages"] if isinstance(m, ToolMessage) and m not in tool_messages_1
]
# Should have 2 new tool messages: 1 result, 1 warning
assert len(tool_messages_2) == 2
# First is the successful execution
assert "search result 3" in tool_messages_2[0].content
# Second is the blocked execution
assert "Tool call limits exceeded" in tool_messages_2[1].content
def test_comparison_all_three_behaviors():
"""
Comprehensive test comparing all three exit behaviors.
This demonstrates the key differences:
- end: Stops entire agent execution
- end_tools: Blocks individual tool executions, agent continues
- error: Raises exception
"""
@tool
def search(query: str) -> str:
"""Search for information."""
return f"search: {query}"
# Model tries to make 3 tool calls
def make_model():
return FakeToolCallingModel(
tool_calls=[
[
ToolCall(name="search", args={"query": "q1"}, id="1"),
ToolCall(name="search", args={"query": "q2"}, id="2"),
ToolCall(name="search", args={"query": "q3"}, id="3"),
],
[],
]
)
# Test 1: end behavior
agent_end = create_agent(
model=make_model(),
tools=[search],
middleware=[ToolCallLimitMiddleware(run_limit=2, exit_behavior="end")],
)
result_end = agent_end.invoke({"messages": [HumanMessage("Search")]})
# With end: agent completes first iteration (3 tool calls),
# then before_model on next iteration sees limit exceeded and jumps to end
messages_end = result_end["messages"]
ai_messages_end = [m for m in messages_end if isinstance(m, AIMessage)]
# Should have limit exceeded message
assert any("limit" in m.content.lower() for m in ai_messages_end)
# Test 2: end_tools behavior
agent_end_tools = create_agent(
model=make_model(),
tools=[search],
middleware=[ToolCallLimitMiddleware(run_limit=2, exit_behavior="end_tools")],
)
result_end_tools = agent_end_tools.invoke({"messages": [HumanMessage("Search")]})
# With end_tools: first 2 tools execute, third gets warning, agent continues
tool_messages_end_tools = [
m for m in result_end_tools["messages"] if isinstance(m, ToolMessage)
]
assert len(tool_messages_end_tools) == 3
assert "search: q1" in tool_messages_end_tools[0].content
assert "search: q2" in tool_messages_end_tools[1].content
assert "Tool call limits exceeded" in tool_messages_end_tools[2].content
# Test 3: error behavior
agent_error = create_agent(
model=make_model(),
tools=[search],
middleware=[ToolCallLimitMiddleware(run_limit=2, exit_behavior="error")],
)
# With error: should raise exception
with pytest.raises(ToolCallLimitExceededError) as exc_info:
agent_error.invoke({"messages": [HumanMessage("Search")]})
assert exc_info.value.run_limit == 2


@@ -1743,7 +1743,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "1.0.0"
version = "1.0.1"
source = { editable = "../core" }
dependencies = [
{ name = "jsonpatch" },