# Mirror of https://github.com/hwchase17/langchain.git
# Synced 2026-06-09 10:17:00 +00:00 | 488 lines | 18 KiB | Python
"""Tool call limit middleware for agents."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
|
|
|
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
|
|
from langgraph.channels.untracked_value import UntrackedValue
|
|
from langgraph.typing import ContextT
|
|
from typing_extensions import NotRequired, override
|
|
|
|
from langchain.agents.middleware.types import (
|
|
AgentMiddleware,
|
|
AgentState,
|
|
PrivateStateAttr,
|
|
ResponseT,
|
|
hook_config,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from langgraph.runtime import Runtime
|
|
|
|
ExitBehavior = Literal[
    "continue",
    "error",
    "end",
]
"""Strategy applied once a tool call limit has been exceeded.

- `'continue'` (default): each tool call over the limit is answered with an error
    `ToolMessage`; tool calls that are still within the limit proceed.
- `'error'`: a `ToolCallLimitExceededError` is raised.
- `'end'`: execution jumps to the end after injecting an error `ToolMessage` for
    every blocked tool call plus a final `AIMessage` describing the exceeded limit.
    A `NotImplementedError` is raised when the middleware is restricted to one tool
    and the same AI message also contains calls to other tools (parallel tool
    calling).
"""
|
|
|
|
|
|
class ToolCallLimitState(AgentState[ResponseT]):
    """Agent state extended with the counters used by `ToolCallLimitMiddleware`.

    Both counters are dictionaries keyed by tool name, so several middleware
    instances (one per limited tool) can share the same state fields without
    overwriting each other. The reserved key `'__all__'` holds the count for an
    instance that limits every tool.

    Type Parameters:
        ResponseT: The type of the structured response. Defaults to `Any`.
    """

    # Thread-level counter: tool name (or '__all__') -> number of counted calls.
    thread_tool_call_count: NotRequired[Annotated[dict[str, int], PrivateStateAttr]]
    # Run-level counter with the same key scheme; additionally marked `UntrackedValue`.
    run_tool_call_count: NotRequired[Annotated[dict[str, int], UntrackedValue, PrivateStateAttr]]
|
|
|
|
|
|
def _build_tool_message_content(tool_name: str | None) -> str:
    """Return the text of the error `ToolMessage` for a blocked tool call.

    The text is read by the model, so it deliberately avoids thread/run wording
    (concepts the model has no notion of) and simply tells it to stop calling.

    Args:
        tool_name: Name of the limited tool, or `None` when the limit covers all tools.

    Returns:
        A short instruction telling the model not to repeat the call.
    """
    # Guard clause: no specific tool (None or empty name) -> generic instruction.
    if not tool_name:
        return "Tool call limit exceeded. Do not make additional tool calls."
    return f"Tool call limit exceeded. Do not call '{tool_name}' again."
|
|
|
|
|
|
def _build_final_ai_message_content(
    thread_count: int,
    run_count: int,
    thread_limit: int | None,
    run_limit: int | None,
    tool_name: str | None,
) -> str:
    """Build the final AI message content for `'end'` behavior.

    This message is displayed to the user, so it includes detailed information
    about which limits were exceeded. The same text is used as the message of
    `ToolCallLimitExceededError`.

    Args:
        thread_count: Current thread tool call count.
        run_count: Current run tool call count.
        thread_limit: Thread tool call limit (if set).
        run_limit: Run tool call limit (if set).
        tool_name: Tool name being limited (if specific tool), or `None` for all tools.

    Returns:
        A formatted message describing which limits were exceeded. When neither
        count is above its limit, a generic sentence without limit details is
        returned instead of a message with an empty detail section.
    """
    tool_desc = f"'{tool_name}' tool" if tool_name else "Tool"

    # A limit only contributes to the message when it is configured AND exceeded.
    exceeded_limits = [
        f"{scope} limit exceeded ({count}/{limit} calls)"
        for scope, count, limit in (
            ("thread", thread_count, thread_limit),
            ("run", run_count, run_limit),
        )
        if limit is not None and count > limit
    ]

    if not exceeded_limits:
        # Avoid the dangling "... call limit reached: ." when nothing is exceeded.
        return f"{tool_desc} call limit reached."

    limits_text = " and ".join(exceeded_limits)
    return f"{tool_desc} call limit reached: {limits_text}."
|
|
|
|
|
|
class ToolCallLimitExceededError(Exception):
    """Error signalling that a thread or run tool call limit was exceeded.

    Raised by the middleware when its exit behavior is `'error'`. The counts and
    limits that triggered the error are kept as attributes for callers to inspect.
    """

    def __init__(
        self,
        thread_count: int,
        run_count: int,
        thread_limit: int | None,
        run_limit: int | None,
        tool_name: str | None = None,
    ) -> None:
        """Store the call count details and derive the exception message.

        Args:
            thread_count: Current thread tool call count.
            run_count: Current run tool call count.
            thread_limit: Thread tool call limit (if set).
            run_limit: Run tool call limit (if set).
            tool_name: Tool name being limited (if specific tool), or None for all tools.
        """
        # The exception text is the same user-facing description used for the
        # final AI message in 'end' mode.
        super().__init__(
            _build_final_ai_message_content(
                thread_count, run_count, thread_limit, run_limit, tool_name
            )
        )

        self.tool_name = tool_name
        self.thread_limit = thread_limit
        self.run_limit = run_limit
        self.thread_count = thread_count
        self.run_count = run_count
|
|
|
|
|
|
class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState[ResponseT], ContextT, ResponseT]):
    """Track tool call counts and enforce limits during agent execution.

    This middleware monitors the number of tool calls made and can terminate or
    restrict execution when limits are exceeded. It supports both thread-level
    (persistent across runs) and run-level (per invocation) call counting.

    Configuration:
        - `exit_behavior`: How to handle when limits are exceeded
            - `'continue'`: Block exceeded tools, let execution continue (default)
            - `'error'`: Raise an exception
            - `'end'`: Stop immediately with an error `ToolMessage` for each blocked
                tool call plus a final AI message (raises `NotImplementedError` when
                `tool_name` is set and the same AI message also contains calls to
                other tools, due to parallel tool calling).

    Examples:
        !!! example "Continue execution with blocked tools (default)"

            ```python
            from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
            from langchain.agents import create_agent

            # Block exceeded tools but let other tools and model continue
            limiter = ToolCallLimitMiddleware(
                thread_limit=20,
                run_limit=10,
                exit_behavior="continue",  # default
            )

            agent = create_agent("openai:gpt-5.5", middleware=[limiter])
            ```

        !!! example "Stop immediately when limit exceeded"

            ```python
            # End execution immediately with an AI message
            limiter = ToolCallLimitMiddleware(run_limit=5, exit_behavior="end")

            agent = create_agent("openai:gpt-5.5", middleware=[limiter])
            ```

        !!! example "Raise exception on limit"

            ```python
            from langchain_core.messages import HumanMessage

            from langchain.agents.middleware.tool_call_limit import (
                ToolCallLimitExceededError,
                ToolCallLimitMiddleware,
            )

            # Strict limit with exception handling
            limiter = ToolCallLimitMiddleware(
                tool_name="search", thread_limit=5, exit_behavior="error"
            )

            agent = create_agent("openai:gpt-5.5", middleware=[limiter])

            try:
                result = agent.invoke({"messages": [HumanMessage("Task")]})
            except ToolCallLimitExceededError as e:
                print(f"Search limit exceeded: {e}")
            ```

    """

    state_schema = ToolCallLimitState  # type: ignore[assignment]

    def __init__(
        self,
        *,
        tool_name: str | None = None,
        thread_limit: int | None = None,
        run_limit: int | None = None,
        exit_behavior: ExitBehavior = "continue",
    ) -> None:
        """Initialize the tool call limit middleware.

        Args:
            tool_name: Name of the specific tool to limit. If `None`, limits apply
                to all tools.
            thread_limit: Maximum number of tool calls allowed per thread.
                `None` means no limit.
            run_limit: Maximum number of tool calls allowed per run.
                `None` means no limit.
            exit_behavior: How to handle when limits are exceeded.

                - `'continue'`: Block exceeded tools with error messages, let other
                    tools continue. Model decides when to end.
                - `'error'`: Raise a `ToolCallLimitExceededError` exception
                - `'end'`: Stop execution immediately with an error `ToolMessage`
                    for each blocked tool call plus a final AI message. Raises
                    `NotImplementedError` when `tool_name` is set and the same AI
                    message also contains parallel calls to other tools.

        Raises:
            ValueError: If both limits are `None`, if `exit_behavior` is invalid,
                or if `run_limit` exceeds `thread_limit`.
        """
        super().__init__()

        if thread_limit is None and run_limit is None:
            msg = "At least one limit must be specified (thread_limit or run_limit)"
            raise ValueError(msg)

        valid_behaviors = ("continue", "error", "end")
        if exit_behavior not in valid_behaviors:
            msg = f"Invalid exit_behavior: {exit_behavior!r}. Must be one of {valid_behaviors}"
            raise ValueError(msg)

        if thread_limit is not None and run_limit is not None and run_limit > thread_limit:
            msg = (
                f"run_limit ({run_limit}) cannot exceed thread_limit ({thread_limit}). "
                "The run limit should be less than or equal to the thread limit."
            )
            raise ValueError(msg)

        self.tool_name = tool_name
        self.thread_limit = thread_limit
        self.run_limit = run_limit
        self.exit_behavior = exit_behavior

    @property
    def name(self) -> str:
        """The name of the middleware instance.

        Includes the tool name if specified to allow multiple instances
        of this middleware with different tool names.
        """
        base_name = self.__class__.__name__
        if self.tool_name:
            return f"{base_name}[{self.tool_name}]"
        return base_name

    def _would_exceed_limit(self, thread_count: int, run_count: int) -> bool:
        """Check if incrementing the counts would exceed any configured limit.

        Args:
            thread_count: Current thread call count.
            run_count: Current run call count.

        Returns:
            True if either limit would be exceeded by one more call.
        """
        return (self.thread_limit is not None and thread_count + 1 > self.thread_limit) or (
            self.run_limit is not None and run_count + 1 > self.run_limit
        )

    def _matches_tool_filter(self, tool_call: ToolCall) -> bool:
        """Check if a tool call matches this middleware's tool filter.

        Args:
            tool_call: The tool call to check.

        Returns:
            True if this middleware should track this tool call.
        """
        return self.tool_name is None or tool_call["name"] == self.tool_name

    def _separate_tool_calls(
        self, tool_calls: list[ToolCall], thread_count: int, run_count: int
    ) -> tuple[list[ToolCall], list[ToolCall], int, int]:
        """Separate tool calls into allowed and blocked based on limits.

        Tool calls that do not match this middleware's tool filter are ignored:
        they appear in neither list and do not affect the counts.

        Args:
            tool_calls: List of tool calls to evaluate.
            thread_count: Current thread call count.
            run_count: Current run call count.

        Returns:
            Tuple of `(allowed_calls, blocked_calls, final_thread_count, final_run_count)`.
            The returned counts include the allowed calls only.
        """
        allowed_calls: list[ToolCall] = []
        blocked_calls: list[ToolCall] = []
        temp_thread_count = thread_count
        temp_run_count = run_count

        for tool_call in tool_calls:
            if not self._matches_tool_filter(tool_call):
                continue

            if self._would_exceed_limit(temp_thread_count, temp_run_count):
                blocked_calls.append(tool_call)
            else:
                # Only allowed calls advance the running counts, so later calls in
                # the same AI message are checked against the updated totals.
                allowed_calls.append(tool_call)
                temp_thread_count += 1
                temp_run_count += 1

        return allowed_calls, blocked_calls, temp_thread_count, temp_run_count

    @hook_config(can_jump_to=["end"])
    @override
    def after_model(
        self,
        state: ToolCallLimitState[ResponseT],
        runtime: Runtime[ContextT],
    ) -> dict[str, Any] | None:
        """Increment tool call counts after a model call and check limits.

        Args:
            state: The current agent state.
            runtime: The langgraph runtime.

        Returns:
            `None` when there is nothing to count (no messages, no `AIMessage` with
            tool calls, or no tool call matching this middleware). Otherwise, state
            updates with incremented tool call counts. When calls are blocked, the
            update also carries an error `ToolMessage` per blocked call; with
            `exit_behavior='end'` it additionally includes a final AI message and a
            jump to end.

        Raises:
            ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
                is `'error'`.
            NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
                `tool_name` is set, and the AI message also contains calls to
                other tools.
        """
        messages = state.get("messages", [])
        if not messages:
            return None

        # Find the most recent AIMessage; only its tool calls are evaluated.
        last_ai_message = next(
            (message for message in reversed(messages) if isinstance(message, AIMessage)),
            None,
        )
        if not last_ai_message or not last_ai_message.tool_calls:
            return None

        # Get the count key for this middleware instance
        count_key = self.tool_name or "__all__"

        # Copy the count dicts so the incoming state is never mutated in place.
        thread_counts = state.get("thread_tool_call_count", {}).copy()
        run_counts = state.get("run_tool_call_count", {}).copy()
        current_thread_count = thread_counts.get(count_key, 0)
        current_run_count = run_counts.get(count_key, 0)

        # Separate tool calls into allowed and blocked
        allowed_calls, blocked_calls, new_thread_count, new_run_count = self._separate_tool_calls(
            last_ai_message.tool_calls, current_thread_count, current_run_count
        )

        # Update counts to include only allowed calls for thread count
        # (blocked calls don't count towards thread-level tracking)
        # But run count includes blocked calls since they were attempted in this run
        thread_counts[count_key] = new_thread_count
        run_counts[count_key] = new_run_count + len(blocked_calls)

        # If no tool calls are blocked, just update counts
        if not blocked_calls:
            if allowed_calls:
                return {
                    "thread_tool_call_count": thread_counts,
                    "run_tool_call_count": run_counts,
                }
            return None

        # Get final counts for building messages
        final_thread_count = thread_counts[count_key]
        final_run_count = run_counts[count_key]

        # Hypothetical thread count (what it would have been if the blocked calls
        # had been counted), used to show which limit was actually exceeded.
        hypothetical_thread_count = final_thread_count + len(blocked_calls)

        # Handle different exit behaviors
        if self.exit_behavior == "error":
            raise ToolCallLimitExceededError(
                thread_count=hypothetical_thread_count,
                run_count=final_run_count,
                thread_limit=self.thread_limit,
                run_limit=self.run_limit,
                tool_name=self.tool_name,
            )

        # Build tool message content (sent to model - no thread/run details)
        tool_msg_content = _build_tool_message_content(self.tool_name)

        # Inject artificial error ToolMessages for blocked tool calls
        artificial_messages: list[ToolMessage | AIMessage] = [
            ToolMessage(
                content=tool_msg_content,
                tool_call_id=tool_call["id"],
                name=tool_call.get("name"),
                status="error",
            )
            for tool_call in blocked_calls
        ]

        if self.exit_behavior == "end":
            # Check if there are tool calls to other tools that would continue executing
            other_tools = [
                tc
                for tc in last_ai_message.tool_calls
                if self.tool_name is not None and tc["name"] != self.tool_name
            ]

            if other_tools:
                # Deduplicate, then sort so the error message is deterministic
                # (plain set iteration order varies between processes).
                tool_names = ", ".join(sorted({tc["name"] for tc in other_tools}))
                msg = (
                    "Cannot end execution with other tool calls pending. "
                    f"Found calls to: {tool_names}. Use 'continue' or 'error' behavior instead."
                )
                raise NotImplementedError(msg)

            # Build final AI message content (displayed to user - includes thread/run details)
            final_msg_content = _build_final_ai_message_content(
                hypothetical_thread_count,
                final_run_count,
                self.thread_limit,
                self.run_limit,
                self.tool_name,
            )
            artificial_messages.append(AIMessage(content=final_msg_content))

            return {
                "thread_tool_call_count": thread_counts,
                "run_tool_call_count": run_counts,
                "jump_to": "end",
                "messages": artificial_messages,
            }

        # For exit_behavior="continue", return error messages to block exceeded tools
        return {
            "thread_tool_call_count": thread_counts,
            "run_tool_call_count": run_counts,
            "messages": artificial_messages,
        }

    @hook_config(can_jump_to=["end"])
    async def aafter_model(
        self,
        state: ToolCallLimitState[ResponseT],
        runtime: Runtime[ContextT],
    ) -> dict[str, Any] | None:
        """Async increment tool call counts after a model call and check limits.

        Delegates to the synchronous `after_model`, which performs no I/O.

        Args:
            state: The current agent state.
            runtime: The langgraph runtime.

        Returns:
            The same state updates (or `None`) that `after_model` returns for the
            given state.

        Raises:
            ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
                is `'error'`.
            NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
                `tool_name` is set, and the AI message also contains calls to
                other tools.
        """
        return self.after_model(state, runtime)
|