Mirror of https://github.com/hwchase17/langchain.git
Synced 2026-02-08 18:19:21 +00:00

Compare commits: langchain ... sr/refacto
12 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 9f6a29c44d | |
| | ec1bd054f1 | |
| | f0139330f7 | |
| | 984e1e984f | |
| | e529445b29 | |
| | 5b972fb0e9 | |
| | b8a20329d8 | |
| | 54e507f331 | |
| | 6001543093 | |
| | 894ffa0be5 | |
| | 56c93fa82f | |
| | 3743d596c7 | |
@@ -1,6 +1,7 @@
"""langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools."""

from langchain.agents.react_agent import AgentState, create_agent
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentState

__all__ = [
    "AgentState",

@@ -1,4 +1,4 @@
"""Middleware agent implementation."""
"""Agent factory implementation."""

import itertools
from collections.abc import Callable, Sequence
@@ -9,9 +9,10 @@ from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMe
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langgraph.constants import END, START
from langgraph.graph.state import StateGraph
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.runtime import Runtime
from langgraph.types import Send
from langgraph.store.base import BaseStore
from langgraph.types import Checkpointer, Send
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict, TypeVar

@@ -167,17 +168,142 @@ def _handle_structured_output_error(


def create_agent(  # noqa: PLR0915
    *,
    model: str | BaseChatModel,
    tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
    *,
    system_prompt: str | None = None,
    middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (),
    response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
    context_schema: type[ContextT] | None = None,
) -> StateGraph[
    checkpointer: Checkpointer | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    debug: bool = False,
    cache: Any = None,
    name: str | None = None,
) -> CompiledStateGraph[
    AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
]:
    """Create a middleware agent graph."""
    """Creates an agent graph that calls tools in a loop until a stopping condition is met.

    For more details on using `create_agent`,
    visit [Agents](https://langchain-ai.github.io/langgraph/agents/overview/) documentation.

    Args:
        model: The language model for the agent. Supports static model
            selection with string identifier (e.g., `"openai:gpt-4"`) or
            chat model instance (e.g., `ChatOpenAI()`).

        tools: A list of tools or a ToolNode instance.
            If an empty list is provided, the agent will consist of a single LLM node
            without tool calling.

        system_prompt: An optional system prompt for the LLM. This is converted to a
            SystemMessage and added to the beginning of the list of messages in
            state["messages"].

        middleware: An optional sequence of AgentMiddleware instances for customizing
            agent behavior. Middleware can:

            - **Intercept execution** with before_model/after_model hooks
            - **Modify model requests** with modify_model_request
            - **Add custom state** via state_schema
            - **Control flow** with jump_to (using @hook_config decorator)

            See AgentMiddleware documentation for implementation details.

        response_format: An optional configuration for structured responses.

            If provided, the agent will handle structured output via tool calls
            during the normal conversation flow.
            When the model calls a structured output tool, the response will be captured
            and returned in the 'structured_response' state key.
            If not provided, `structured_response` will not be present in the output state.

            Can be one of:

            - **Pydantic model class**: Automatically converted to ToolStrategy or
              ProviderStrategy based on model capabilities
            - **ToolStrategy**: Uses tool calling for structured output with error handling
            - **ProviderStrategy**: Uses provider-native structured output (OpenAI, etc.)

            !!! important
                `response_format` requires the model to support tool calling

            !!! note
                Structured responses are handled directly in the model call node via
                tool calls, eliminating the need for separate structured response nodes.

        context_schema: An optional schema for runtime context.

        checkpointer: An optional checkpoint saver object. This is used for persisting
            the state of the graph (e.g., as chat memory) for a single thread
            (e.g., a single conversation).

        store: An optional store object. This is used for persisting data
            across multiple threads (e.g., multiple conversations / users).

        interrupt_before: An optional list of node names to interrupt before.
            Should be one of the following: "model_request", "tools".
            This is useful if you want to add a user confirmation or other interrupt
            before taking an action.

        interrupt_after: An optional list of node names to interrupt after.
            Should be one of the following: "model_request", "tools".
            This is useful if you want to return directly or run additional processing on an output.

        debug: A flag indicating whether to enable debug mode.

        cache: An optional cache object for caching LLM responses.

        name: An optional name for the compiled graph.

    Returns:
        A compiled LangGraph agent that can be used for chat interactions.

        The "model_request" node calls the language model with the messages list
        (after applying the system prompt if provided).
        If the resulting AIMessage contains `tool_calls`,
        the graph will then call the ["tools"][langgraph.prebuilt.tool_node.ToolNode].
        The "tools" node executes the tools (1 tool per `tool_call`)
        and adds the responses to the messages list as `ToolMessage` objects.
        The model_request node then calls the language model again.
        The process repeats until no more `tool_calls` are present in the response.
        The agent then returns the full list of messages as a dictionary containing the key "messages".

        ``` mermaid
        sequenceDiagram
            participant U as User
            participant A as LLM
            participant T as Tools
            U->>A: Initial input
            Note over A: System Prompt + LLM
            loop while tool_calls present
                A->>T: Execute tools
                T-->>A: ToolMessage for each tool_calls
            end
            A->>U: Return final state
        ```

    Example:
        ```python
        from langchain.agents import create_agent

        def check_weather(location: str) -> str:
            '''Return the weather forecast for the specified location.'''
            return f"It's always sunny in {location}"

        graph = create_agent(
            model="anthropic:claude-3-7-sonnet-latest",
            tools=[check_weather],
            system_prompt="You are a helpful assistant",
        )
        inputs = {"messages": [{"role": "user", "content": "what is the weather in sf"}]}
        for chunk in graph.stream(inputs, stream_mode="updates"):
            print(chunk)
        ```
    """
    # init chat model
    if isinstance(model, str):
        model = init_chat_model(model)
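Taken together with the docstring above, here is a minimal sketch of how the new surface might be exercised end to end. The `WeatherReport` model, model identifier, and thread id are illustrative assumptions, not part of this diff; `response_format` and `checkpointer` are used as the Args section above describes.

```python
from pydantic import BaseModel

from langchain.agents import create_agent
from langgraph.checkpoint.memory import InMemorySaver


class WeatherReport(BaseModel):
    location: str
    forecast: str


def check_weather(location: str) -> str:
    """Return the weather forecast for the specified location."""
    return f"It's always sunny in {location}"


agent = create_agent(
    model="anthropic:claude-3-7-sonnet-latest",
    tools=[check_weather],
    system_prompt="You are a helpful assistant",
    response_format=WeatherReport,  # captured under state["structured_response"]
    checkpointer=InMemorySaver(),   # per-thread persistence (chat memory)
)

config = {"configurable": {"thread_id": "thread-1"}}
result = agent.invoke(
    {"messages": [{"role": "user", "content": "what is the weather in sf"}]},
    config,
)
print(result["structured_response"])
```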
@@ -368,26 +494,8 @@ def create_agent(  # noqa: PLR0915

    def _get_bound_model(request: ModelRequest) -> Runnable:
        """Get the model with appropriate tool bindings."""
        # Get actual tool objects from tool names
        tools_by_name = {t.name: t for t in default_tools}

        unknown_tools = [name for name in request.tools if name not in tools_by_name]
        if unknown_tools:
            available_tools = sorted(tools_by_name.keys())
            msg = (
                f"Middleware returned unknown tool names: {unknown_tools}\n\n"
                f"Available tools: {available_tools}\n\n"
                "To fix this issue:\n"
                "1. Ensure the tools are passed to create_agent() via "
                "the 'tools' parameter\n"
                "2. If using custom middleware with tools, ensure "
                "they're registered via middleware.tools attribute\n"
                "3. Verify that tool names in ModelRequest.tools match "
                "the actual tool.name values"
            )
            raise ValueError(msg)

        requested_tools = [tools_by_name[name] for name in request.tools]
        # request.tools contains BaseTool | dict objects
        requested_tools = request.tools

        if isinstance(response_format, ProviderStrategy):
            # Use native structured output
@@ -411,7 +519,7 @@ def create_agent(  # noqa: PLR0915
        """Sync model request handler with sequential middleware processing."""
        request = ModelRequest(
            model=model,
            tools=[t.name for t in default_tools],
            tools=default_tools,
            system_prompt=system_prompt,
            response_format=response_format,
            messages=state["messages"],
@@ -438,17 +546,19 @@ def create_agent(  # noqa: PLR0915
            messages = [SystemMessage(request.system_prompt), *messages]

        output = model_.invoke(messages)
        return {
            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
            **_handle_model_output(output),
        }
        result = _handle_model_output(output)

        # Always add call counts
        result["thread_model_call_count"] = state.get("thread_model_call_count", 0) + 1
        result["run_model_call_count"] = state.get("run_model_call_count", 0) + 1

        return result

    async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
        """Async model request handler with sequential middleware processing."""
        request = ModelRequest(
            model=model,
            tools=[t.name for t in default_tools],
            tools=default_tools,
            system_prompt=system_prompt,
            response_format=response_format,
            messages=state["messages"],
@@ -466,11 +576,13 @@ def create_agent(  # noqa: PLR0915
            messages = [SystemMessage(request.system_prompt), *messages]

        output = await model_.ainvoke(messages)
        return {
            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
            **_handle_model_output(output),
        }
        result = _handle_model_output(output)

        # Always add call counts
        result["thread_model_call_count"] = state.get("thread_model_call_count", 0) + 1
        result["run_model_call_count"] = state.get("run_model_call_count", 0) + 1

        return result

    # Use sync or async based on model capabilities
    from langgraph._internal._runnable import RunnableCallable
@@ -595,7 +707,15 @@ def create_agent(  # noqa: PLR0915
            can_jump_to=_get_can_jump_to(m1, "after_model"),
        )

    return graph
    return graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        debug=debug,
        cache=cache,
        name=name,
    )


def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
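Since `create_agent` now compiles the graph itself, the `interrupt_before`/`interrupt_after` options are forwarded straight to `graph.compile()` above. A hedged sketch of pausing before the `"tools"` node for human confirmation, assuming the standard LangGraph resume pattern of re-invoking the same thread with `None`; the tool, model identifier, and thread id are illustrative.

```python
from langchain.agents import create_agent
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver


@tool
def check_weather(location: str) -> str:
    """Return the weather forecast for the specified location."""
    return f"It's always sunny in {location}"


agent = create_agent(
    model="openai:gpt-4",
    tools=[check_weather],
    checkpointer=InMemorySaver(),  # persistence is needed to pause and resume a thread
    interrupt_before=["tools"],    # stop before any tool call executes
)

config = {"configurable": {"thread_id": "review-1"}}
agent.invoke({"messages": [{"role": "user", "content": "weather in sf?"}]}, config)

# Once a human has reviewed the pending tool calls, resume the same thread by
# invoking with no new input (assumed standard LangGraph resume pattern).
agent.invoke(None, config)
```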
@@ -626,8 +746,8 @@ def _fetch_last_ai_and_tool_messages(

def _make_model_to_tools_edge(
    first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
) -> Callable[[dict[str, Any]], str | list[Send] | None]:
    def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None:
) -> Callable[[dict[str, Any], Runtime], str | list[Send] | None]:
    def model_to_tools(state: dict[str, Any], runtime: Runtime) -> str | list[Send] | None:
        # 1. if there's an explicit jump_to in the state, use it
        if jump_to := state.get("jump_to"):
            return _resolve_jump(jump_to, first_node)
@@ -646,14 +766,19 @@ def _make_model_to_tools_edge(
            if c["id"] not in tool_message_ids and c["name"] not in structured_output_tools
        ]

        # 3. if there are pending tool calls, jump to the tool node
        # 3. if there are pending (non-structured) tool calls, jump to the tool node
        if pending_tool_calls:
            pending_tool_calls = [
                tool_node.inject_tool_args(call, state, None) for call in pending_tool_calls
                tool_node.inject_tool_args(call, state, runtime.store)
                for call in pending_tool_calls
            ]
            return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]

        # 4. AIMessage has tool calls, but there are no pending tool calls
        # 4. if we have a structured response and no pending tool calls, we're done
        if "structured_response" in state:
            return END

        # 5. AIMessage has tool calls, but there are no pending tool calls
        # which suggests the injection of artificial tool messages. jump to the first node
        return first_node

@@ -666,7 +791,8 @@ def _make_tools_to_model_edge(
    def tools_to_model(state: dict[str, Any]) -> str | None:
        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])

        if all(
        # Check if any tool call has return_direct=True
        if any(
            tool_node.tools_by_name[c["name"]].return_direct
            for c in last_ai_message.tool_calls
            if c["name"] in tool_node.tools_by_name
@@ -692,7 +818,6 @@ def _add_middleware_edge(

    Args:
        graph: The graph to add the edge to.
        method: The method to call for the middleware node.
        name: The name of the middleware node.
        default_destination: The default destination for the edge.
        model_destination: The destination for the edge to the model.
@@ -716,3 +841,6 @@ def _add_middleware_edge(

    else:
        graph.add_edge(name, default_destination)


__all__ = ["create_agent"]
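For the middleware hooks listed in the `create_agent` docstring, here is a minimal sketch using the `before_model` / `modify_model_request` decorators that the new unit tests import from `langchain.agents.middleware.types`; the message-count stopping condition and the temperature setting are illustrative assumptions.

```python
from typing import Any

from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import (
    AgentState,
    ModelRequest,
    before_model,
    modify_model_request,
)
from langgraph.runtime import Runtime


@before_model(can_jump_to=["end"])
def stop_when_long(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    # Jump straight to the end once the conversation grows past an arbitrary size.
    if len(state["messages"]) > 20:
        return {"jump_to": "end"}
    return None


@modify_model_request
def tune_model_settings(
    request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
    # model_settings is a plain dict forwarded to the model call.
    request.model_settings["temperature"] = 0  # illustrative setting
    return request


agent = create_agent(
    model="openai:gpt-4",
    tools=[],  # no tools: single LLM node, per the docstring
    middleware=[stop_when_long, tune_model_settings],
)
```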
@@ -62,7 +62,7 @@ class ModelRequest:
    system_prompt: str | None
    messages: list[AnyMessage]  # excluding system prompt
    tool_choice: Any | None
    tools: list[str]
    tools: list[BaseTool | dict[str, Any]]
    response_format: ResponseFormat | None
    model_settings: dict[str, Any] = field(default_factory=dict)


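A hedged sketch of constructing the revised `ModelRequest`, whose `tools` field now carries `BaseTool`/dict objects rather than tool-name strings. The `echo` tool and the chosen model are illustrative; the field set is taken from the hunk above, plus a `model` field as seen where the factory builds requests.

```python
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool

from langchain.agents.middleware.types import ModelRequest


@tool
def echo(text: str) -> str:
    """Echo the input text back."""
    return text


request = ModelRequest(
    model=init_chat_model("openai:gpt-4"),  # assumed: any chat model instance works here
    system_prompt="You are a helpful assistant",
    messages=[HumanMessage("hi")],
    tool_choice=None,
    tools=[echo],  # BaseTool / dict objects, no longer tool-name strings
    response_format=None,
    model_settings={},
)
```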
File diff suppressed because it is too large
@@ -85,14 +85,14 @@ langchain-text-splitters = { path = "../text-splitters", editable = true }
langchain-openai = { path = "../partners/openai", editable = true }

[tool.ruff]
exclude = ["tests/integration_tests/examples/non-utf8-encoding.py"]
exclude = ["tests/integration_tests/examples/non-utf8-encoding.py", "tests/unit_tests/tools/test_tool_node.py"]
line-length = 100

[tool.mypy]
strict = true
ignore_missing_imports = true
enable_error_code = "deprecated"
exclude = ["tests/unit_tests/agents/*", "tests/integration_tests/agents/*"]
exclude = ["tests/unit_tests/agents/*", "tests/integration_tests/agents/*", "tests/unit_tests/tools/test_tool_node.py"]

# TODO: activate for 'strict' checking
disallow_any_generics = false

@@ -0,0 +1 @@
"""Middleware tests for agents."""
@@ -0,0 +1,230 @@
|
||||
"""Tests for ModelCallLimitMiddleware."""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
|
||||
from langchain.agents.middleware.call_tracking import (
|
||||
ModelCallLimitMiddleware,
|
||||
ModelCallLimitExceededError,
|
||||
)
|
||||
from langchain.agents.factory import create_agent
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from ..model import FakeToolCallingModel
|
||||
|
||||
|
||||
@tool
|
||||
def simple_tool(input: str) -> str:
|
||||
"""A simple tool"""
|
||||
return input
|
||||
|
||||
|
||||
def test_middleware_unit_functionality():
|
||||
"""Test that the middleware works as expected in isolation."""
|
||||
# Test with end behavior
|
||||
middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1)
|
||||
|
||||
# Mock runtime (not used in current implementation)
|
||||
runtime = None
|
||||
|
||||
# Test when limits are not exceeded
|
||||
state = {"thread_model_call_count": 0, "run_model_call_count": 0}
|
||||
result = middleware.before_model(state, runtime)
|
||||
assert result is None
|
||||
|
||||
# Test when thread limit is exceeded
|
||||
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
|
||||
result = middleware.before_model(state, runtime)
|
||||
assert result is not None
|
||||
assert result["jump_to"] == "end"
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
assert "thread limit (2/2)" in result["messages"][0].content
|
||||
|
||||
# Test when run limit is exceeded
|
||||
state = {"thread_model_call_count": 1, "run_model_call_count": 1}
|
||||
result = middleware.before_model(state, runtime)
|
||||
assert result is not None
|
||||
assert result["jump_to"] == "end"
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
assert "run limit (1/1)" in result["messages"][0].content
|
||||
|
||||
# Test with error behavior
|
||||
middleware_exception = ModelCallLimitMiddleware(
|
||||
thread_limit=2, run_limit=1, exit_behavior="error"
|
||||
)
|
||||
|
||||
# Test exception when thread limit exceeded
|
||||
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
|
||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||
middleware_exception.before_model(state, runtime)
|
||||
|
||||
assert "thread limit (2/2)" in str(exc_info.value)
|
||||
|
||||
# Test exception when run limit exceeded
|
||||
state = {"thread_model_call_count": 1, "run_model_call_count": 1}
|
||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||
middleware_exception.before_model(state, runtime)
|
||||
|
||||
assert "run limit (1/1)" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_thread_limit_with_create_agent():
|
||||
"""Test that thread limits work correctly with create_agent."""
|
||||
# Create a model that will make 2 calls
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[{"name": "simple_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
[], # No tool calls on second call
|
||||
]
|
||||
)
|
||||
|
||||
# Set thread limit to 1 (should be exceeded after 1 call)
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[simple_tool],
|
||||
middleware=[ModelCallLimitMiddleware(thread_limit=1)],
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
# This should hit the thread limit after the first model call
|
||||
result = agent.invoke(
|
||||
{"messages": [HumanMessage("Hello")]}, {"configurable": {"thread_id": "thread1"}}
|
||||
)
|
||||
|
||||
assert "messages" in result
|
||||
# The agent should have made 1 model call then jumped to end with limit exceeded message
|
||||
# So we should have: Human + AI + Tool + Limit exceeded AI message
|
||||
assert len(result["messages"]) == 4 # Human + AI + Tool + Limit AI
|
||||
assert isinstance(result["messages"][0], HumanMessage)
|
||||
assert isinstance(result["messages"][1], AIMessage)
|
||||
assert isinstance(result["messages"][2], ToolMessage)
|
||||
# Limit exceeded message
|
||||
assert "thread limit" in result["messages"][3].content
|
||||
|
||||
# Second invocation should also hit the limit
|
||||
result2 = agent.invoke(
|
||||
{"messages": [HumanMessage("Hello again")]}, {"configurable": {"thread_id": "thread1"}}
|
||||
)
|
||||
assert len(result2["messages"]) == 2 # Human + Limit exceeded AI
|
||||
assert isinstance(result2["messages"][0], HumanMessage)
|
||||
# Limit exceeded message
|
||||
assert "thread limit" in result2["messages"][1].content
|
||||
|
||||
|
||||
def test_run_limit_with_create_agent():
|
||||
"""Test that run limits work correctly with create_agent."""
|
||||
# Create a model that will make 2 calls
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[{"name": "simple_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
[], # No tool calls on second call
|
||||
]
|
||||
)
|
||||
|
||||
# Set run limit to 1 (should be exceeded after 1 call)
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[simple_tool],
|
||||
middleware=[ModelCallLimitMiddleware(run_limit=1)],
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
# This should hit the run limit after the first model call
|
||||
result = agent.invoke(
|
||||
{"messages": [HumanMessage("Hello")]}, {"configurable": {"thread_id": "thread1"}}
|
||||
)
|
||||
|
||||
assert "messages" in result
|
||||
# The agent should have made 1 model call then jumped to end with limit exceeded message
|
||||
# So we should have: Human + AI + Tool + Limit exceeded AI message
|
||||
assert len(result["messages"]) == 4 # Human + AI + Tool + Limit AI
|
||||
assert isinstance(result["messages"][0], HumanMessage)
|
||||
assert isinstance(result["messages"][1], AIMessage)
|
||||
assert isinstance(result["messages"][2], ToolMessage)
|
||||
# Limit exceeded message
|
||||
assert "run limit" in result["messages"][3].content
|
||||
|
||||
|
||||
def test_middleware_initialization_validation():
|
||||
"""Test that middleware initialization validates parameters correctly."""
|
||||
# Test that at least one limit must be specified
|
||||
with pytest.raises(ValueError, match="At least one limit must be specified"):
|
||||
ModelCallLimitMiddleware()
|
||||
|
||||
# Test invalid exit behavior
|
||||
with pytest.raises(ValueError, match="Invalid exit_behavior"):
|
||||
ModelCallLimitMiddleware(thread_limit=5, exit_behavior="invalid")
|
||||
|
||||
# Test valid initialization
|
||||
middleware = ModelCallLimitMiddleware(thread_limit=5, run_limit=3)
|
||||
assert middleware.thread_limit == 5
|
||||
assert middleware.run_limit == 3
|
||||
assert middleware.exit_behavior == "end"
|
||||
|
||||
# Test with only thread limit
|
||||
middleware = ModelCallLimitMiddleware(thread_limit=5)
|
||||
assert middleware.thread_limit == 5
|
||||
assert middleware.run_limit is None
|
||||
|
||||
# Test with only run limit
|
||||
middleware = ModelCallLimitMiddleware(run_limit=3)
|
||||
assert middleware.thread_limit is None
|
||||
assert middleware.run_limit == 3
|
||||
|
||||
|
||||
def test_exception_error_message():
|
||||
"""Test that the exception provides clear error messages."""
|
||||
middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1, exit_behavior="error")
|
||||
|
||||
# Test thread limit exceeded
|
||||
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
|
||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||
middleware.before_model(state, None)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Model call limits exceeded" in error_msg
|
||||
assert "thread limit (2/2)" in error_msg
|
||||
|
||||
# Test run limit exceeded
|
||||
state = {"thread_model_call_count": 0, "run_model_call_count": 1}
|
||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||
middleware.before_model(state, None)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Model call limits exceeded" in error_msg
|
||||
assert "run limit (1/1)" in error_msg
|
||||
|
||||
# Test both limits exceeded
|
||||
state = {"thread_model_call_count": 2, "run_model_call_count": 1}
|
||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||
middleware.before_model(state, None)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Model call limits exceeded" in error_msg
|
||||
assert "thread limit (2/2)" in error_msg
|
||||
assert "run limit (1/1)" in error_msg
|
||||
|
||||
|
||||
def test_run_limit_resets_between_invocations() -> None:
|
||||
"""Test that run limits reset between invocations."""
|
||||
# First: No tool calls per invocation, so model does not increment call counts internally
|
||||
middleware = ModelCallLimitMiddleware(thread_limit=3, run_limit=1, exit_behavior="error")
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[[], [], [], []]
|
||||
) # No tool calls, so only model call per run
|
||||
|
||||
agent = create_agent(model=model, middleware=[middleware], checkpointer=InMemorySaver())
|
||||
|
||||
thread_config = {"configurable": {"thread_id": "test_thread"}}
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]}, thread_config)
|
||||
agent.invoke({"messages": [HumanMessage("Hello again")]}, thread_config)
|
||||
agent.invoke({"messages": [HumanMessage("Hello third")]}, thread_config)
|
||||
|
||||
# Fourth run: should raise, thread_model_call_count == 3 (limit)
|
||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||
agent.invoke({"messages": [HumanMessage("Hello fourth")]}, thread_config)
|
||||
error_msg = str(exc_info.value)
|
||||
assert "thread limit (3/3)" in error_msg
|
||||
@@ -0,0 +1,346 @@
|
||||
"""Tests for middleware decorators: before_model, after_model, and modify_model_request."""
|
||||
|
||||
import pytest
|
||||
from typing import Any
|
||||
from typing_extensions import NotRequired
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ModelRequest,
|
||||
before_model,
|
||||
after_model,
|
||||
modify_model_request,
|
||||
hook_config,
|
||||
)
|
||||
from langchain.agents.factory import create_agent, _get_can_jump_to
|
||||
|
||||
from ..model import FakeToolCallingModel
|
||||
|
||||
|
||||
class CustomState(AgentState):
|
||||
"""Custom state schema for testing."""
|
||||
|
||||
custom_field: NotRequired[str]
|
||||
|
||||
|
||||
@tool
|
||||
def test_tool(input: str) -> str:
|
||||
"""A test tool for middleware testing."""
|
||||
return f"Tool result: {input}"
|
||||
|
||||
|
||||
def test_before_model_decorator() -> None:
|
||||
"""Test before_model decorator with all configuration options."""
|
||||
|
||||
@before_model(
|
||||
state_schema=CustomState, tools=[test_tool], can_jump_to=["end"], name="CustomBeforeModel"
|
||||
)
|
||||
def custom_before_model(state: CustomState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {"jump_to": "end"}
|
||||
|
||||
assert isinstance(custom_before_model, AgentMiddleware)
|
||||
assert custom_before_model.state_schema == CustomState
|
||||
assert custom_before_model.tools == [test_tool]
|
||||
assert getattr(custom_before_model.__class__.before_model, "__can_jump_to__", []) == ["end"]
|
||||
|
||||
|
||||
def test_after_model_decorator() -> None:
|
||||
"""Test after_model decorator with all configuration options."""
|
||||
|
||||
@after_model(
|
||||
state_schema=CustomState, tools=[test_tool], can_jump_to=["end"], name="CustomAfterModel"
|
||||
)
|
||||
def custom_after_model(state: CustomState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {"jump_to": "end"}
|
||||
|
||||
assert isinstance(custom_after_model, AgentMiddleware)
|
||||
assert custom_after_model.state_schema == CustomState
|
||||
assert custom_after_model.tools == [test_tool]
|
||||
assert getattr(custom_after_model.__class__.after_model, "__can_jump_to__", []) == ["end"]
|
||||
|
||||
|
||||
def test_modify_model_request_decorator() -> None:
|
||||
"""Test modify_model_request decorator with all configuration options."""
|
||||
|
||||
@modify_model_request(state_schema=CustomState, tools=[test_tool], name="CustomModifyRequest")
|
||||
def custom_modify_request(
|
||||
request: ModelRequest, state: CustomState, runtime: Runtime
|
||||
) -> ModelRequest:
|
||||
return request
|
||||
|
||||
assert isinstance(custom_modify_request, AgentMiddleware)
|
||||
assert custom_modify_request.state_schema == CustomState
|
||||
assert custom_modify_request.tools == [test_tool]
|
||||
|
||||
|
||||
def test_decorator_with_minimal_config() -> None:
|
||||
"""Test decorators with minimal configuration."""
|
||||
|
||||
@before_model
|
||||
def minimal_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@after_model
|
||||
def minimal_after_model(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@modify_model_request
|
||||
def minimal_modify_request(
|
||||
request: ModelRequest, state: AgentState, runtime: Runtime
|
||||
) -> ModelRequest:
|
||||
return request
|
||||
|
||||
assert isinstance(minimal_before_model, AgentMiddleware)
|
||||
assert isinstance(minimal_after_model, AgentMiddleware)
|
||||
assert isinstance(minimal_modify_request, AgentMiddleware)
|
||||
|
||||
|
||||
def test_decorator_with_jump_to() -> None:
|
||||
"""Test decorators with can_jump_to configuration."""
|
||||
|
||||
@before_model(can_jump_to=["end", "custom_node"])
|
||||
def jump_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {"jump_to": "end"}
|
||||
|
||||
@after_model(can_jump_to=["end", "custom_node"])
|
||||
def jump_after_model(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {"jump_to": "end"}
|
||||
|
||||
assert getattr(jump_before_model.__class__.before_model, "__can_jump_to__", []) == [
|
||||
"end",
|
||||
"custom_node",
|
||||
]
|
||||
assert getattr(jump_after_model.__class__.after_model, "__can_jump_to__", []) == [
|
||||
"end",
|
||||
"custom_node",
|
||||
]
|
||||
|
||||
|
||||
def test_decorator_with_tools() -> None:
|
||||
"""Test decorators with tools configuration."""
|
||||
|
||||
@tool
|
||||
def tool1(input: str) -> str:
|
||||
"""Tool 1."""
|
||||
return "result1"
|
||||
|
||||
@tool
|
||||
def tool2(input: str) -> str:
|
||||
"""Tool 2."""
|
||||
return "result2"
|
||||
|
||||
@before_model(tools=[tool1, tool2])
|
||||
def tools_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
assert tools_before_model.tools == [tool1, tool2]
|
||||
|
||||
|
||||
def test_decorator_with_custom_state() -> None:
|
||||
"""Test decorators with custom state schema."""
|
||||
|
||||
class MyCustomState(AgentState):
|
||||
"""Custom state for testing."""
|
||||
|
||||
my_field: NotRequired[int]
|
||||
|
||||
@before_model(state_schema=MyCustomState)
|
||||
def custom_state_before_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
assert custom_state_before_model.state_schema == MyCustomState
|
||||
|
||||
|
||||
def test_decorator_with_name() -> None:
|
||||
"""Test decorators with custom name."""
|
||||
|
||||
@before_model(name="MyCustomMiddleware")
|
||||
def named_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
assert named_before_model.__class__.__name__ == "MyCustomMiddleware"
|
||||
|
||||
|
||||
def test_decorator_combined_config() -> None:
|
||||
"""Test decorators with combined configuration options."""
|
||||
|
||||
@tool
|
||||
def combined_tool(input: str) -> str:
|
||||
"""Combined tool."""
|
||||
return "combined"
|
||||
|
||||
class CombinedState(AgentState):
|
||||
"""Combined state."""
|
||||
|
||||
combined_field: NotRequired[str]
|
||||
|
||||
@before_model(
|
||||
state_schema=CombinedState,
|
||||
tools=[combined_tool],
|
||||
can_jump_to=["end", "custom"],
|
||||
name="CombinedMiddleware",
|
||||
)
|
||||
def combined_before_model(state: CombinedState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {"jump_to": "end"}
|
||||
|
||||
assert isinstance(combined_before_model, AgentMiddleware)
|
||||
assert combined_before_model.state_schema == CombinedState
|
||||
assert combined_before_model.tools == [combined_tool]
|
||||
assert getattr(combined_before_model.__class__.before_model, "__can_jump_to__", []) == [
|
||||
"end",
|
||||
"custom",
|
||||
]
|
||||
assert combined_before_model.__class__.__name__ == "CombinedMiddleware"
|
||||
|
||||
|
||||
def test_get_can_jump_to() -> None:
|
||||
"""Test _get_can_jump_to utility function."""
|
||||
|
||||
@before_model(can_jump_to=["end", "custom"])
|
||||
def test_middleware(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
can_jump_to = _get_can_jump_to(test_middleware)
|
||||
assert can_jump_to == ["end", "custom"]
|
||||
|
||||
# Test with no can_jump_to
|
||||
@before_model
|
||||
def no_jump_middleware(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
can_jump_to = _get_can_jump_to(no_jump_middleware)
|
||||
assert can_jump_to == []
|
||||
|
||||
|
||||
def test_decorator_functionality() -> None:
|
||||
"""Test that decorated functions work correctly in agents."""
|
||||
|
||||
@before_model
|
||||
def test_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {"custom_field": "test_value"}
|
||||
|
||||
@after_model
|
||||
def test_after_model(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {"processed": True}
|
||||
|
||||
@modify_model_request
|
||||
def test_modify_request(
|
||||
request: ModelRequest, state: AgentState, runtime: Runtime
|
||||
) -> ModelRequest:
|
||||
# Add custom header to model settings
|
||||
request.model_settings["custom_header"] = "test"
|
||||
return request
|
||||
|
||||
# Create agent with decorated middleware
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[test_tool],
|
||||
middleware=[test_before_model, test_after_model, test_modify_request],
|
||||
)
|
||||
|
||||
# Test that agent can be invoked
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
assert "messages" in result
|
||||
assert "custom_field" in result
|
||||
assert result["custom_field"] == "test_value"
|
||||
assert "processed" in result
|
||||
assert result["processed"] is True
|
||||
|
||||
|
||||
def test_decorator_with_jump_functionality() -> None:
|
||||
"""Test decorators with jump functionality."""
|
||||
|
||||
@before_model(can_jump_to=["end"])
|
||||
def jump_middleware(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {"jump_to": "end"}
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[test_tool],
|
||||
middleware=[jump_middleware],
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
assert "messages" in result
|
||||
# Should have jumped to end, so minimal messages
|
||||
assert len(result["messages"]) <= 2
|
||||
|
||||
|
||||
def test_decorator_with_tools_in_agent() -> None:
|
||||
"""Test that decorators with tools work correctly in agents."""
|
||||
|
||||
@tool
|
||||
def decorator_tool(input: str) -> str:
|
||||
"""Tool from decorator."""
|
||||
return f"Decorator tool result: {input}"
|
||||
|
||||
@before_model(tools=[decorator_tool])
|
||||
def tools_middleware(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[test_tool],
|
||||
middleware=[tools_middleware],
|
||||
)
|
||||
|
||||
# The agent should have access to both tools
|
||||
result = agent.invoke({"messages": [HumanMessage("Use both tools")]})
|
||||
assert "messages" in result
|
||||
|
||||
|
||||
def test_decorator_error_handling() -> None:
|
||||
"""Test decorator error handling."""
|
||||
|
||||
@before_model
|
||||
def error_middleware(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
raise ValueError("Test error")
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[test_tool],
|
||||
middleware=[error_middleware],
|
||||
)
|
||||
|
||||
# Should handle the error gracefully
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
|
||||
|
||||
def test_decorator_with_hook_config() -> None:
|
||||
"""Test decorators with hook_config."""
|
||||
|
||||
@hook_config
|
||||
def my_hook_config():
|
||||
return {"custom_config": "test"}
|
||||
|
||||
@before_model
|
||||
def config_middleware(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
# Test that hook_config can be applied
|
||||
assert my_hook_config() == {"custom_config": "test"}
|
||||
|
||||
|
||||
def test_decorator_snapshot_compatibility(snapshot: SnapshotAssertion) -> None:
|
||||
"""Test that decorators produce consistent snapshots."""
|
||||
|
||||
@before_model(name="SnapshotMiddleware")
|
||||
def snapshot_middleware(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {"snapshot_test": True}
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[test_tool],
|
||||
middleware=[snapshot_middleware],
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Snapshot test")]})
|
||||
assert result == snapshot
|
||||
@@ -0,0 +1,514 @@
|
||||
"""Tests for HumanInTheLoopMiddleware."""
|
||||
|
||||
import pytest
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langchain.agents.middleware.human_in_the_loop import (
|
||||
ActionRequest,
|
||||
HumanInTheLoopMiddleware,
|
||||
)
|
||||
|
||||
from ..model import FakeToolCallingModel
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_initialization() -> None:
|
||||
"""Test HumanInTheLoopMiddleware initialization."""
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
|
||||
},
|
||||
description_prefix="Custom prefix",
|
||||
)
|
||||
|
||||
assert middleware.interrupt_on == {
|
||||
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
|
||||
}
|
||||
assert middleware.description_prefix == "Custom prefix"
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_no_interrupts_needed() -> None:
|
||||
"""Test HumanInTheLoopMiddleware when no interrupts are needed."""
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
|
||||
}
|
||||
)
|
||||
|
||||
# Test with no messages
|
||||
state: dict[str, Any] = {"messages": []}
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is None
|
||||
|
||||
# Test with message but no tool calls
|
||||
state = {"messages": [HumanMessage(content="Hello"), AIMessage(content="Hi there")]}
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is None
|
||||
|
||||
# Test with tool calls that don't require interrupts
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "other_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_single_tool_accept() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with single tool accept response."""
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
|
||||
}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
def mock_accept(requests):
|
||||
return [{"type": "accept", "args": None}]
|
||||
|
||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
assert result["messages"][0] == ai_message
|
||||
assert result["messages"][0].tool_calls == ai_message.tool_calls
|
||||
|
||||
state["messages"].append(
|
||||
ToolMessage(content="Tool message", name="test_tool", tool_call_id="1")
|
||||
)
|
||||
state["messages"].append(AIMessage(content="test_tool called with result: Tool message"))
|
||||
|
||||
result = middleware.after_model(state, None)
|
||||
# No interrupts needed
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_single_tool_edit() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with single tool edit response."""
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
|
||||
}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
def mock_edit(requests):
|
||||
return [
|
||||
{
|
||||
"type": "edit",
|
||||
"args": ActionRequest(
|
||||
action="test_tool",
|
||||
args={"input": "edited"},
|
||||
),
|
||||
}
|
||||
]
|
||||
|
||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit):
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
|
||||
assert result["messages"][0].tool_calls[0]["id"] == "1" # ID should be preserved
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_single_tool_response() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with single tool response with custom message."""
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
|
||||
}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
def mock_response(requests):
|
||||
return [{"type": "response", "args": "Custom response message"}]
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_response
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 2
|
||||
assert isinstance(result["messages"][0], AIMessage)
|
||||
assert isinstance(result["messages"][1], ToolMessage)
|
||||
assert result["messages"][1].content == "Custom response message"
|
||||
assert result["messages"][1].name == "test_tool"
|
||||
assert result["messages"][1].tool_call_id == "1"
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_multiple_tools_mixed_responses() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with multiple tools and mixed response types."""
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"get_forecast": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
|
||||
"get_temperature": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
|
||||
}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you with weather",
|
||||
tool_calls=[
|
||||
{"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
|
||||
{"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
|
||||
],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}
|
||||
|
||||
def mock_mixed_responses(requests):
|
||||
return [
|
||||
{"type": "accept", "args": None},
|
||||
{"type": "response", "args": "User rejected this tool call"},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_mixed_responses
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert (
|
||||
len(result["messages"]) == 2
|
||||
) # AI message with accepted tool call + tool message for rejected
|
||||
|
||||
# First message should be the AI message with both tool calls
|
||||
updated_ai_message = result["messages"][0]
|
||||
assert len(updated_ai_message.tool_calls) == 2 # Both tool calls remain
|
||||
assert updated_ai_message.tool_calls[0]["name"] == "get_forecast" # Accepted
|
||||
assert updated_ai_message.tool_calls[1]["name"] == "get_temperature" # Got response
|
||||
|
||||
# Second message should be the tool message for the rejected tool call
|
||||
tool_message = result["messages"][1]
|
||||
assert isinstance(tool_message, ToolMessage)
|
||||
assert tool_message.content == "User rejected this tool call"
|
||||
assert tool_message.name == "get_temperature"
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_multiple_tools_edit_responses() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with multiple tools and edit responses."""
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"get_forecast": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
|
||||
"get_temperature": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
|
||||
}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you with weather",
|
||||
tool_calls=[
|
||||
{"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
|
||||
{"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
|
||||
],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}
|
||||
|
||||
def mock_edit_responses(requests):
|
||||
return [
|
||||
{
|
||||
"type": "edit",
|
||||
"args": ActionRequest(
|
||||
action="get_forecast",
|
||||
args={"location": "New York"},
|
||||
),
|
||||
},
|
||||
{
|
||||
"type": "edit",
|
||||
"args": ActionRequest(
|
||||
action="get_temperature",
|
||||
args={"location": "New York"},
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit_responses
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
|
||||
updated_ai_message = result["messages"][0]
|
||||
assert updated_ai_message.tool_calls[0]["args"] == {"location": "New York"}
|
||||
assert updated_ai_message.tool_calls[0]["id"] == "1" # ID preserved
|
||||
assert updated_ai_message.tool_calls[1]["args"] == {"location": "New York"}
|
||||
assert updated_ai_message.tool_calls[1]["id"] == "2" # ID preserved
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_edit_with_modified_args() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with edit action that includes modified args."""
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
|
||||
}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
def mock_edit_with_args(requests):
|
||||
return [
|
||||
{
|
||||
"type": "edit",
|
||||
"args": ActionRequest(
|
||||
action="test_tool",
|
||||
args={"input": "modified"},
|
||||
),
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
side_effect=mock_edit_with_args,
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
|
||||
# Should have modified args
|
||||
updated_ai_message = result["messages"][0]
|
||||
assert updated_ai_message.tool_calls[0]["args"] == {"input": "modified"}
|
||||
assert updated_ai_message.tool_calls[0]["id"] == "1" # ID preserved
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_unknown_response_type() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with unknown response type."""
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
|
||||
}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
def mock_unknown(requests):
|
||||
return [{"type": "unknown", "args": None}]
|
||||
|
||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_unknown):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Unexpected human response: {'type': 'unknown', 'args': None}. Response action 'unknown' is not allowed for tool 'test_tool'. Expected one of \['accept', 'edit', 'response'\] based on the tool's configuration.",
|
||||
):
|
||||
middleware.after_model(state, None)
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_disallowed_action() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with action not allowed by tool config."""
|
||||
# edit is not allowed by tool config
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"test_tool": {"allow_respond": True, "allow_edit": False, "allow_accept": True}
|
||||
}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
def mock_disallowed_action(requests):
|
||||
return [
|
||||
{
|
||||
"type": "edit",
|
||||
"args": ActionRequest(
|
||||
action="test_tool",
|
||||
args={"input": "modified"},
|
||||
),
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
side_effect=mock_disallowed_action,
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Unexpected human response: {'type': 'edit', 'args': {'action': 'test_tool', 'args': {'input': 'modified'}}}. Response action 'edit' is not allowed for tool 'test_tool'. Expected one of \['accept', 'response'\] based on the tool's configuration.",
|
||||
):
|
||||
middleware.after_model(state, None)
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_mixed_auto_approved_and_interrupt() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with mix of auto-approved and interrupt tools."""
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"interrupt_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
|
||||
}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[
|
||||
{"name": "auto_tool", "args": {"input": "auto"}, "id": "1"},
|
||||
{"name": "interrupt_tool", "args": {"input": "interrupt"}, "id": "2"},
|
||||
],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
def mock_accept(requests):
|
||||
return [{"type": "accept", "args": None}]
|
||||
|
||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
|
||||
updated_ai_message = result["messages"][0]
|
||||
# Should have both tools: auto-approved first, then interrupt tool
|
||||
assert len(updated_ai_message.tool_calls) == 2
|
||||
assert updated_ai_message.tool_calls[0]["name"] == "auto_tool"
|
||||
assert updated_ai_message.tool_calls[1]["name"] == "interrupt_tool"
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_interrupt_request_structure() -> None:
|
||||
"""Test that interrupt requests are structured correctly."""
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
|
||||
},
|
||||
description_prefix="Custom prefix",
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test", "location": "SF"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
captured_requests = []
|
||||
|
||||
def mock_capture_requests(requests):
|
||||
captured_requests.extend(requests)
|
||||
return [{"type": "accept", "args": None}]
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_capture_requests
|
||||
):
|
||||
middleware.after_model(state, None)
|
||||
|
||||
assert len(captured_requests) == 1
|
||||
request = captured_requests[0]
|
||||
|
||||
assert "action_request" in request
|
||||
assert "config" in request
|
||||
assert "description" in request
|
||||
|
||||
assert request["action_request"]["action"] == "test_tool"
|
||||
assert request["action_request"]["args"] == {"input": "test", "location": "SF"}
|
||||
expected_config = {"allow_accept": True, "allow_edit": True, "allow_respond": True}
|
||||
assert request["config"] == expected_config
|
||||
assert "Custom prefix" in request["description"]
|
||||
assert "Tool: test_tool" in request["description"]
|
||||
assert "Args: {'input': 'test', 'location': 'SF'}" in request["description"]
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_boolean_configs() -> None:
|
||||
"""Test HITL middleware with boolean tool configs."""
|
||||
middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": True})
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
# Test accept
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
return_value=[{"type": "accept", "args": None}],
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
assert result["messages"][0].tool_calls == ai_message.tool_calls
|
||||
|
||||
# Test edit
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
return_value=[
|
||||
{
|
||||
"type": "edit",
|
||||
"args": ActionRequest(
|
||||
action="test_tool",
|
||||
args={"input": "edited"},
|
||||
),
|
||||
}
|
||||
],
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
|
||||
|
||||
middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": False})
|
||||
|
||||
result = middleware.after_model(state, None)
|
||||
# No interruption should occur
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_sequence_mismatch() -> None:
|
||||
"""Test that sequence mismatch in resume raises an error."""
|
||||
middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": True})
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
# Test with too few responses
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
return_value=[], # No responses for 1 tool call
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Number of human responses \(0\) does not match number of hanging tool calls \(1\)\.",
|
||||
):
|
||||
middleware.after_model(state, None)
|
||||
|
||||
# Test with too many responses
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
return_value=[
|
||||
{"type": "accept", "args": None},
|
||||
{"type": "accept", "args": None},
|
||||
], # 2 responses for 1 tool call
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Number of human responses \(2\) does not match number of hanging tool calls \(1\)\.",
|
||||
):
|
||||
middleware.after_model(state, None)
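# Illustrative sketch, not exercised by the tests: the resume payload consumed by
# HumanInTheLoopMiddleware is a list with one decision per hanging tool call, in order.
# The shapes below simply restate what the tests above already use; the variable name
# is hypothetical.
_example_resume_decisions = [
    {"type": "accept", "args": None},  # run the original tool call unchanged
    {
        "type": "edit",  # replace the tool call's arguments before execution
        "args": ActionRequest(action="test_tool", args={"input": "edited"}),
    },
]
# The list length must match the number of hanging tool calls, otherwise after_model raises
# the "Number of human responses ... does not match ..." ValueError exercised above.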
|
||||
@@ -0,0 +1,232 @@
|
||||
"""Tests for PlanningMiddleware."""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain.agents.middleware.planning import (
|
||||
PlanningMiddleware,
|
||||
PlanningState,
|
||||
WRITE_TODOS_SYSTEM_PROMPT,
|
||||
write_todos,
|
||||
WRITE_TODOS_TOOL_DESCRIPTION,
|
||||
)
|
||||
from langchain.agents.middleware.types import ModelRequest
|
||||
from langchain.agents.factory import create_agent
|
||||
|
||||
from ..model import FakeToolCallingModel
|
||||
|
||||
|
||||
def test_planning_middleware_initialization() -> None:
|
||||
"""Test that PlanningMiddleware initializes correctly."""
|
||||
middleware = PlanningMiddleware()
|
||||
assert middleware.state_schema == PlanningState
|
||||
assert len(middleware.tools) == 1
|
||||
assert middleware.tools[0].name == "write_todos"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"original_prompt,expected_prompt_prefix",
|
||||
[
|
||||
("Original prompt", "Original prompt\n\n## `write_todos`"),
|
||||
(None, "## `write_todos`"),
|
||||
],
|
||||
)
|
||||
def test_planning_middleware_modify_model_request(original_prompt, expected_prompt_prefix) -> None:
|
||||
"""Test that modify_model_request handles system prompts correctly."""
|
||||
middleware = PlanningMiddleware()
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
system_prompt=original_prompt,
|
||||
messages=[HumanMessage(content="Hello")],
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
|
||||
modified_request = middleware.modify_model_request(request, state, None)
|
||||
assert modified_request.system_prompt.startswith(expected_prompt_prefix)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"todos,expected_message",
|
||||
[
|
||||
([], "Updated todo list to []"),
|
||||
(
|
||||
[{"content": "Task 1", "status": "pending"}],
|
||||
"Updated todo list to [{'content': 'Task 1', 'status': 'pending'}]",
|
||||
),
|
||||
(
|
||||
[
|
||||
{"content": "Task 1", "status": "pending"},
|
||||
{"content": "Task 2", "status": "in_progress"},
|
||||
],
|
||||
"Updated todo list to [{'content': 'Task 1', 'status': 'pending'}, {'content': 'Task 2', 'status': 'in_progress'}]",
|
||||
),
|
||||
(
|
||||
[
|
||||
{"content": "Task 1", "status": "pending"},
|
||||
{"content": "Task 2", "status": "in_progress"},
|
||||
{"content": "Task 3", "status": "completed"},
|
||||
],
|
||||
"Updated todo list to [{'content': 'Task 1', 'status': 'pending'}, {'content': 'Task 2', 'status': 'in_progress'}, {'content': 'Task 3', 'status': 'completed'}]",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_planning_middleware_write_todos_tool_execution(todos, expected_message) -> None:
|
||||
"""Test that the write_todos tool executes correctly."""
|
||||
tool_call = {
|
||||
"args": {"todos": todos},
|
||||
"name": "write_todos",
|
||||
"type": "tool_call",
|
||||
"id": "test_call",
|
||||
}
|
||||
result = write_todos.invoke(tool_call)
|
||||
assert result.update["todos"] == todos
|
||||
assert result.update["messages"][0].content == expected_message
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_todos",
|
||||
[
|
||||
[{"content": "Task 1", "status": "invalid_status"}],
|
||||
[{"status": "pending"}],
|
||||
],
|
||||
)
|
||||
def test_planning_middleware_write_todos_tool_validation_errors(invalid_todos) -> None:
|
||||
"""Test that the write_todos tool rejects invalid input."""
|
||||
tool_call = {
|
||||
"args": {"todos": invalid_todos},
|
||||
"name": "write_todos",
|
||||
"type": "tool_call",
|
||||
"id": "test_call",
|
||||
}
|
||||
with pytest.raises(Exception):
|
||||
write_todos.invoke(tool_call)
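# Informal shape of a todo item accepted by write_todos, inferred from the cases above:
# "content" is required and "status" must be one of "pending", "in_progress", "completed".
_example_todo = {"content": "Task 1", "status": "pending"}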
|
||||
|
||||
|
||||
def test_planning_middleware_agent_creation_with_middleware() -> None:
|
||||
"""Test that an agent can be created with the planning middleware."""
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
{
|
||||
"args": {"todos": [{"content": "Task 1", "status": "pending"}]},
|
||||
"name": "write_todos",
|
||||
"type": "tool_call",
|
||||
"id": "test_call",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"args": {"todos": [{"content": "Task 1", "status": "in_progress"}]},
|
||||
"name": "write_todos",
|
||||
"type": "tool_call",
|
||||
"id": "test_call",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"args": {"todos": [{"content": "Task 1", "status": "completed"}]},
|
||||
"name": "write_todos",
|
||||
"type": "tool_call",
|
||||
"id": "test_call",
|
||||
}
|
||||
],
|
||||
[],
|
||||
]
|
||||
)
|
||||
middleware = PlanningMiddleware()
|
||||
agent = create_agent(model=model, middleware=[middleware])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
assert result["todos"] == [{"content": "Task 1", "status": "completed"}]
|
||||
|
||||
# human message (1)
|
||||
# ai message (2) - initial todo
|
||||
# tool message (3)
|
||||
# ai message (4) - updated todo
|
||||
# tool message (5)
|
||||
# ai message (6) - complete todo
|
||||
# tool message (7)
|
||||
# ai message (8) - no tool calls
|
||||
assert len(result["messages"]) == 8
|
||||
|
||||
|
||||
def test_planning_middleware_custom_system_prompt() -> None:
|
||||
"""Test that PlanningMiddleware can be initialized with custom system prompt."""
|
||||
custom_system_prompt = "Custom todo system prompt for testing"
|
||||
middleware = PlanningMiddleware(system_prompt=custom_system_prompt)
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
system_prompt="Original prompt",
|
||||
messages=[HumanMessage(content="Hello")],
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
|
||||
modified_request = middleware.modify_model_request(request, state, None)
|
||||
assert modified_request.system_prompt == f"Original prompt\n\n{custom_system_prompt}"
|
||||
|
||||
|
||||
def test_planning_middleware_custom_tool_description() -> None:
|
||||
"""Test that PlanningMiddleware can be initialized with custom tool description."""
|
||||
custom_tool_description = "Custom tool description for testing"
|
||||
middleware = PlanningMiddleware(tool_description=custom_tool_description)
|
||||
|
||||
assert len(middleware.tools) == 1
|
||||
tool = middleware.tools[0]
|
||||
assert tool.description == custom_tool_description
|
||||
|
||||
|
||||
def test_planning_middleware_custom_system_prompt_and_tool_description() -> None:
|
||||
"""Test that PlanningMiddleware can be initialized with both custom prompts."""
|
||||
custom_system_prompt = "Custom system prompt"
|
||||
custom_tool_description = "Custom tool description"
|
||||
middleware = PlanningMiddleware(
|
||||
system_prompt=custom_system_prompt,
|
||||
tool_description=custom_tool_description,
|
||||
)
|
||||
|
||||
# Verify system prompt
|
||||
model = FakeToolCallingModel()
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
system_prompt=None,
|
||||
messages=[HumanMessage(content="Hello")],
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
|
||||
modified_request = middleware.modify_model_request(request, state, None)
|
||||
assert modified_request.system_prompt == custom_system_prompt
|
||||
|
||||
# Verify tool description
|
||||
assert len(middleware.tools) == 1
|
||||
tool = middleware.tools[0]
|
||||
assert tool.description == custom_tool_description
|
||||
|
||||
|
||||
def test_planning_middleware_default_prompts() -> None:
|
||||
"""Test that PlanningMiddleware uses default prompts when none provided."""
|
||||
middleware = PlanningMiddleware()
|
||||
|
||||
# Verify default system prompt
|
||||
assert middleware.system_prompt == WRITE_TODOS_SYSTEM_PROMPT
|
||||
|
||||
# Verify default tool description
|
||||
assert middleware.tool_description == WRITE_TODOS_TOOL_DESCRIPTION
|
||||
assert len(middleware.tools) == 1
|
||||
tool = middleware.tools[0]
|
||||
assert tool.description == WRITE_TODOS_TOOL_DESCRIPTION
|
||||
@@ -0,0 +1,109 @@
|
||||
"""Tests for AnthropicPromptCachingMiddleware."""
|
||||
|
||||
import warnings
|
||||
from types import ModuleType
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
|
||||
from langchain.agents.middleware.types import ModelRequest
|
||||
|
||||
from ..model import FakeToolCallingModel
|
||||
|
||||
|
||||
def test_anthropic_prompt_caching_middleware_initialization() -> None:
|
||||
"""Test AnthropicPromptCachingMiddleware initialization."""
|
||||
# Test with custom values
|
||||
middleware = AnthropicPromptCachingMiddleware(
|
||||
type="ephemeral", ttl="1h", min_messages_to_cache=5
|
||||
)
|
||||
assert middleware.type == "ephemeral"
|
||||
assert middleware.ttl == "1h"
|
||||
assert middleware.min_messages_to_cache == 5
|
||||
|
||||
# Test with default values
|
||||
middleware = AnthropicPromptCachingMiddleware()
|
||||
assert middleware.type == "ephemeral"
|
||||
assert middleware.ttl == "5m"
|
||||
assert middleware.min_messages_to_cache == 0
|
||||
|
||||
fake_request = ModelRequest(
|
||||
model=FakeToolCallingModel(),
|
||||
messages=[HumanMessage("Hello")],
|
||||
system_prompt=None,
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
assert middleware.modify_model_request(fake_request, {}, None).model_settings == {
|
||||
"cache_control": {"type": "ephemeral", "ttl": "5m"}
|
||||
}
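# Illustrative sketch, not executed here: constructing the middleware with a custom TTL.
# By assumption the custom values are mirrored into cache_control the same way the defaults
# are above, i.e. {"cache_control": {"type": "ephemeral", "ttl": "1h"}}, and presumably no
# cache settings are applied while the message count is below min_messages_to_cache.
_example_caching = AnthropicPromptCachingMiddleware(ttl="1h", min_messages_to_cache=5)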
|
||||
|
||||
|
||||
def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
|
||||
"""Test AnthropicPromptCachingMiddleware with unsupported model."""
|
||||
fake_request = ModelRequest(
|
||||
model=FakeToolCallingModel(),
|
||||
messages=[HumanMessage("Hello")],
|
||||
system_prompt=None,
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic.",
|
||||
):
|
||||
middleware.modify_model_request(fake_request, {}, None)
|
||||
|
||||
langchain_anthropic = ModuleType("langchain_anthropic")
|
||||
|
||||
class MockChatAnthropic:
|
||||
pass
|
||||
|
||||
langchain_anthropic.ChatAnthropic = MockChatAnthropic
|
||||
|
||||
with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of",
|
||||
):
|
||||
middleware.modify_model_request(fake_request, {}, None)
|
||||
|
||||
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
result = middleware.modify_model_request(fake_request, {}, None)
|
||||
assert len(w) == 1
|
||||
assert (
|
||||
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic."
|
||||
in str(w[-1].message)
|
||||
)
|
||||
assert result == fake_request
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
|
||||
result = middleware.modify_model_request(fake_request, {}, None)
|
||||
assert result is fake_request
|
||||
assert len(w) == 1
|
||||
assert (
|
||||
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of"
|
||||
in str(w[-1].message)
|
||||
)
|
||||
|
||||
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
|
||||
|
||||
result = middleware.modify_model_request(fake_request, {}, None)
|
||||
assert result is fake_request
|
||||
|
||||
with patch.dict("sys.modules", {"langchain_anthropic": {"ChatAnthropic": object()}}):
|
||||
result = middleware.modify_model_request(fake_request, {}, None)
|
||||
assert result is fake_request
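# Quick reference, restating the assertions above: how each unsupported_model_behavior mode
# reacts when the model is not a ChatAnthropic instance.
_unsupported_model_behaviors = {
    "raise": "ValueError is raised and the request is not modified",
    "warn": "a warning is emitted and the request is returned unchanged",
    "ignore": "the request is returned unchanged, silently",
}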
|
||||
@@ -0,0 +1,230 @@
|
||||
"""Tests for SummarizationMiddleware."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
|
||||
from langchain.agents.middleware.summarization import SummarizationMiddleware
|
||||
|
||||
from ..model import FakeToolCallingModel
|
||||
|
||||
|
||||
def test_summarization_middleware_initialization() -> None:
|
||||
"""Test SummarizationMiddleware initialization."""
|
||||
model = FakeToolCallingModel()
|
||||
middleware = SummarizationMiddleware(
|
||||
model=model,
|
||||
max_tokens_before_summary=1000,
|
||||
messages_to_keep=10,
|
||||
summary_prompt="Custom prompt: {messages}",
|
||||
summary_prefix="Custom prefix:",
|
||||
)
|
||||
|
||||
assert middleware.model == model
|
||||
assert middleware.max_tokens_before_summary == 1000
|
||||
assert middleware.messages_to_keep == 10
|
||||
assert middleware.summary_prompt == "Custom prompt: {messages}"
|
||||
assert middleware.summary_prefix == "Custom prefix:"
|
||||
|
||||
# Test with string model
|
||||
with patch(
|
||||
"langchain.agents.middleware.summarization.init_chat_model",
|
||||
return_value=FakeToolCallingModel(),
|
||||
):
|
||||
middleware = SummarizationMiddleware(model="fake-model")
|
||||
assert isinstance(middleware.model, FakeToolCallingModel)
|
||||
|
||||
|
||||
def test_summarization_middleware_no_summarization_cases() -> None:
|
||||
"""Test SummarizationMiddleware when summarization is not needed or disabled."""
|
||||
model = FakeToolCallingModel()
|
||||
middleware = SummarizationMiddleware(model=model, max_tokens_before_summary=1000)
|
||||
|
||||
# Test when summarization is disabled
|
||||
middleware_disabled = SummarizationMiddleware(model=model, max_tokens_before_summary=None)
|
||||
state = {"messages": [HumanMessage(content="Hello"), AIMessage(content="Hi")]}
|
||||
result = middleware_disabled.before_model(state, None)
|
||||
assert result is None
|
||||
|
||||
# Test when token count is below threshold
|
||||
def mock_token_counter(messages):
|
||||
return 500 # Below threshold
|
||||
|
||||
middleware.token_counter = mock_token_counter
|
||||
result = middleware.before_model(state, None)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_summarization_middleware_helper_methods() -> None:
|
||||
"""Test SummarizationMiddleware helper methods."""
|
||||
model = FakeToolCallingModel()
|
||||
middleware = SummarizationMiddleware(model=model, max_tokens_before_summary=1000)
|
||||
|
||||
# Test message ID assignment
|
||||
messages = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
|
||||
middleware._ensure_message_ids(messages)
|
||||
for msg in messages:
|
||||
assert msg.id is not None
|
||||
|
||||
# Test message partitioning
|
||||
messages = [
|
||||
HumanMessage(content="1"),
|
||||
HumanMessage(content="2"),
|
||||
HumanMessage(content="3"),
|
||||
HumanMessage(content="4"),
|
||||
HumanMessage(content="5"),
|
||||
]
|
||||
to_summarize, preserved = middleware._partition_messages(messages, 2)
|
||||
assert len(to_summarize) == 2
|
||||
assert len(preserved) == 3
|
||||
assert to_summarize == messages[:2]
|
||||
assert preserved == messages[2:]
|
||||
|
||||
# Test summary message building
|
||||
summary = "This is a test summary"
|
||||
new_messages = middleware._build_new_messages(summary)
|
||||
assert len(new_messages) == 1
|
||||
assert isinstance(new_messages[0], HumanMessage)
|
||||
assert "Here is a summary of the conversation to date:" in new_messages[0].content
|
||||
assert summary in new_messages[0].content
|
||||
|
||||
# Test tool call detection
|
||||
ai_message_no_tools = AIMessage(content="Hello")
|
||||
assert not middleware._has_tool_calls(ai_message_no_tools)
|
||||
|
||||
ai_message_with_tools = AIMessage(
|
||||
content="Hello", tool_calls=[{"name": "test", "args": {}, "id": "1"}]
|
||||
)
|
||||
assert middleware._has_tool_calls(ai_message_with_tools)
|
||||
|
||||
human_message = HumanMessage(content="Hello")
|
||||
assert not middleware._has_tool_calls(human_message)
|
||||
|
||||
|
||||
def test_summarization_middleware_tool_call_safety() -> None:
|
||||
"""Test SummarizationMiddleware tool call safety logic."""
|
||||
model = FakeToolCallingModel()
|
||||
middleware = SummarizationMiddleware(
|
||||
model=model, max_tokens_before_summary=1000, messages_to_keep=3
|
||||
)
|
||||
|
||||
# Test safe cutoff point detection with tool calls
|
||||
messages = [
|
||||
HumanMessage(content="1"),
|
||||
AIMessage(content="2", tool_calls=[{"name": "test", "args": {}, "id": "1"}]),
|
||||
ToolMessage(content="3", tool_call_id="1"),
|
||||
HumanMessage(content="4"),
|
||||
]
|
||||
|
||||
# Safe cutoff (doesn't separate AI/Tool pair)
|
||||
is_safe = middleware._is_safe_cutoff_point(messages, 0)
|
||||
assert is_safe is True
|
||||
|
||||
# Unsafe cutoff (separates AI/Tool pair)
|
||||
is_safe = middleware._is_safe_cutoff_point(messages, 2)
|
||||
assert is_safe is False
|
||||
|
||||
# Test tool call ID extraction
|
||||
ids = middleware._extract_tool_call_ids(messages[1])
|
||||
assert ids == {"1"}
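# Informal illustration of the invariant checked above: a cutoff index is "safe" only if it
# does not fall between an AIMessage carrying tool_calls and the ToolMessage answering it.
#
#   cutoff 0: | Human, AI(tool_calls), Tool, Human   -> safe, the AI/Tool pair stays together
#   cutoff 2: Human, AI(tool_calls) | Tool, Human    -> unsafe, the pair would be split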
|
||||
|
||||
|
||||
def test_summarization_middleware_summary_creation() -> None:
|
||||
"""Test SummarizationMiddleware summary creation."""
|
||||
|
||||
class MockModel(BaseChatModel):
|
||||
def invoke(self, prompt):
|
||||
return AIMessage(content="Generated summary")
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
return "mock"
|
||||
|
||||
middleware = SummarizationMiddleware(model=MockModel(), max_tokens_before_summary=1000)
|
||||
|
||||
# Test normal summary creation
|
||||
messages = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
|
||||
summary = middleware._create_summary(messages)
|
||||
assert summary == "Generated summary"
|
||||
|
||||
# Test empty messages
|
||||
summary = middleware._create_summary([])
|
||||
assert summary == "No previous conversation history."
|
||||
|
||||
# Test error handling
|
||||
class ErrorModel(BaseChatModel):
|
||||
def invoke(self, prompt):
|
||||
raise Exception("Model error")
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
return "mock"
|
||||
|
||||
middleware_error = SummarizationMiddleware(model=ErrorModel(), max_tokens_before_summary=1000)
|
||||
summary = middleware_error._create_summary(messages)
|
||||
assert "Error generating summary: Model error" in summary
|
||||
|
||||
|
||||
def test_summarization_middleware_full_workflow() -> None:
|
||||
"""Test SummarizationMiddleware complete summarization workflow."""
|
||||
|
||||
class MockModel(BaseChatModel):
|
||||
def invoke(self, prompt):
|
||||
return AIMessage(content="Generated summary")
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
return "mock"
|
||||
|
||||
middleware = SummarizationMiddleware(
|
||||
model=MockModel(), max_tokens_before_summary=1000, messages_to_keep=2
|
||||
)
|
||||
|
||||
# Mock high token count to trigger summarization
|
||||
def mock_token_counter(messages):
|
||||
return 1500 # Above threshold
|
||||
|
||||
middleware.token_counter = mock_token_counter
|
||||
|
||||
messages = [
|
||||
HumanMessage(content="1"),
|
||||
HumanMessage(content="2"),
|
||||
HumanMessage(content="3"),
|
||||
HumanMessage(content="4"),
|
||||
HumanMessage(content="5"),
|
||||
]
|
||||
|
||||
state = {"messages": messages}
|
||||
result = middleware.before_model(state, None)
|
||||
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) > 0
|
||||
|
||||
# Should have RemoveMessage for cleanup
|
||||
assert isinstance(result["messages"][0], RemoveMessage)
|
||||
assert result["messages"][0].id == REMOVE_ALL_MESSAGES
|
||||
|
||||
# Should have summary message
|
||||
summary_message = None
|
||||
for msg in result["messages"]:
|
||||
if isinstance(msg, HumanMessage) and "summary of the conversation" in msg.content:
|
||||
summary_message = msg
|
||||
break
|
||||
|
||||
assert summary_message is not None
|
||||
assert "Generated summary" in summary_message.content
|
||||
@@ -1,17 +1,18 @@
|
||||
"""Test Middleware handling of tools in agents."""
|
||||
"""Tests for middleware handling of tools in agents."""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
|
||||
from langchain.agents.middleware_agent import create_agent
|
||||
from langchain.agents.factory import create_agent
|
||||
from langchain_core.messages import HumanMessage, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from .model import FakeToolCallingModel
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from ..model import FakeToolCallingModel
|
||||
|
||||
def test_model_request_tools_are_strings() -> None:
|
||||
"""Test that ModelRequest.tools contains tool names as strings, not tool objects."""
|
||||
|
||||
def test_model_request_tools_are_objects() -> None:
|
||||
"""Test that ModelRequest.tools contains tool objects (BaseTool | dict)."""
|
||||
captured_requests: list[ModelRequest] = []
|
||||
|
||||
@tool
|
||||
@@ -36,23 +37,23 @@ def test_model_request_tools_are_strings() -> None:
|
||||
tools=[search_tool, calculator],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[RequestCapturingMiddleware()],
|
||||
).compile()
|
||||
)
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
|
||||
# Verify that at least one request was captured
|
||||
assert len(captured_requests) > 0
|
||||
|
||||
# Check that tools in the request are strings (tool names)
|
||||
# Check that tools in the request are tool objects
|
||||
request = captured_requests[0]
|
||||
assert isinstance(request.tools, list)
|
||||
assert len(request.tools) == 2
|
||||
assert all(isinstance(tool_name, str) for tool_name in request.tools)
|
||||
assert set(request.tools) == {"search_tool", "calculator"}
|
||||
tool_names = {t.name for t in request.tools}
|
||||
assert tool_names == {"search_tool", "calculator"}
|
||||
|
||||
|
||||
def test_middleware_can_modify_tool_names() -> None:
|
||||
"""Test that middleware can modify the list of tool names in ModelRequest."""
|
||||
def test_middleware_can_modify_tools() -> None:
|
||||
"""Test that middleware can modify the list of tools in ModelRequest."""
|
||||
|
||||
@tool
|
||||
def tool_a(input: str) -> str:
|
||||
@@ -74,7 +75,7 @@ def test_middleware_can_modify_tool_names() -> None:
|
||||
self, request: ModelRequest, state: AgentState, runtime: Runtime
|
||||
) -> ModelRequest:
|
||||
# Only allow tool_a and tool_b
|
||||
request.tools = ["tool_a", "tool_b"]
|
||||
request.tools = [t for t in request.tools if t.name in ["tool_a", "tool_b"]]
|
||||
return request
|
||||
|
||||
# Model will try to call tool_a
|
||||
@@ -87,7 +88,7 @@ def test_middleware_can_modify_tool_names() -> None:
|
||||
tools=[tool_a, tool_b, tool_c],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[ToolFilteringMiddleware()],
|
||||
).compile()
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Use tool_a")]})
|
||||
|
||||
@@ -98,88 +99,114 @@ def test_middleware_can_modify_tool_names() -> None:
|
||||
assert tool_messages[0].name == "tool_a"
|
||||
|
||||
|
||||
def test_unknown_tool_name_raises_error() -> None:
|
||||
"""Test that using an unknown tool name in ModelRequest raises a clear error."""
|
||||
def test_middleware_can_add_custom_tools() -> None:
|
||||
"""Test that middleware can add custom tool objects to ModelRequest."""
|
||||
|
||||
@tool
|
||||
def known_tool(input: str) -> str:
|
||||
"""A known tool."""
|
||||
return "result"
|
||||
def original_tool(input: str) -> str:
|
||||
"""Original tool."""
|
||||
return "original"
|
||||
|
||||
class BadMiddleware(AgentMiddleware):
|
||||
@tool
|
||||
def middleware_tool(input: str) -> str:
|
||||
"""Middleware-added tool."""
|
||||
return "middleware"
|
||||
|
||||
class ToolAddingMiddleware(AgentMiddleware):
|
||||
def modify_model_request(
|
||||
self, request: ModelRequest, state: AgentState, runtime: Runtime
|
||||
) -> ModelRequest:
|
||||
# Add an unknown tool name
|
||||
request.tools = ["known_tool", "unknown_tool"]
|
||||
# Add middleware tool to the request
|
||||
request.tools = request.tools + [middleware_tool]
|
||||
return request
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[known_tool],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[BadMiddleware()],
|
||||
).compile()
|
||||
|
||||
with pytest.raises(ValueError, match="Middleware returned unknown tool names"):
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
|
||||
|
||||
def test_middleware_can_add_and_remove_tools() -> None:
|
||||
"""Test that middleware can dynamically add/remove tools based on state."""
|
||||
|
||||
@tool
|
||||
def search(query: str) -> str:
|
||||
"""Search for information."""
|
||||
return f"Search results for: {query}"
|
||||
|
||||
@tool
|
||||
def admin_tool(command: str) -> str:
|
||||
"""Admin-only tool."""
|
||||
return f"Admin: {command}"
|
||||
|
||||
class AdminState(AgentState):
|
||||
is_admin: bool
|
||||
|
||||
class ConditionalToolMiddleware(AgentMiddleware[AdminState]):
|
||||
state_schema = AdminState
|
||||
|
||||
def modify_model_request(
|
||||
self, request: ModelRequest, state: AdminState, runtime: Runtime
|
||||
) -> ModelRequest:
|
||||
# Remove admin_tool if not admin
|
||||
if not state.get("is_admin", False):
|
||||
request.tools = [name for name in request.tools if name != "admin_tool"]
|
||||
return request
|
||||
|
||||
model = FakeToolCallingModel()
|
||||
# Model will try to call middleware_tool
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[[{"args": {"input": "test"}, "id": "1", "name": "middleware_tool"}], []]
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[search, admin_tool],
|
||||
tools=[original_tool],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[ToolAddingMiddleware()],
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Use middleware_tool")]})
|
||||
|
||||
# Verify that the middleware tool was executed successfully
|
||||
messages = result["messages"]
|
||||
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
|
||||
assert len(tool_messages) == 1
|
||||
assert tool_messages[0].name == "middleware_tool"
|
||||
assert tool_messages[0].content == "middleware"
|
||||
|
||||
|
||||
def test_middleware_can_conditionally_add_tools() -> None:
|
||||
"""Test that middleware can conditionally add tools based on state."""
|
||||
|
||||
@tool
|
||||
def admin_tool(input: str) -> str:
|
||||
"""Admin-only tool."""
|
||||
return "admin_result"
|
||||
|
||||
class AdminState(AgentState):
|
||||
is_admin: bool = False
|
||||
|
||||
class ConditionalToolMiddleware(AgentMiddleware[AdminState]):
|
||||
def modify_model_request(
|
||||
self, request: ModelRequest, state: AdminState, runtime: Runtime
|
||||
) -> ModelRequest:
|
||||
# Only add admin tool if user is admin
|
||||
if state.get("is_admin", False):
|
||||
request.tools = request.tools + [admin_tool]
|
||||
return request
|
||||
|
||||
# Model will try to call admin_tool when admin
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[[{"args": {"input": "test"}, "id": "1", "name": "admin_tool"}], []]
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[ConditionalToolMiddleware()],
|
||||
).compile()
|
||||
)
|
||||
|
||||
# Test non-admin user - should not have access to admin_tool
|
||||
# We can't directly inspect the bound model, but we can verify the agent runs
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")], "is_admin": False})
|
||||
assert "messages" in result
|
||||
# Test with admin user
|
||||
result = agent.invoke({"messages": [HumanMessage("Use admin tool")], "is_admin": True})
|
||||
messages = result["messages"]
|
||||
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
|
||||
assert len(tool_messages) == 1
|
||||
assert tool_messages[0].name == "admin_tool"
|
||||
|
||||
# Test admin user - should have access to all tools
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")], "is_admin": True})
|
||||
assert "messages" in result
|
||||
# Test with non-admin user (should not have admin tool available)
|
||||
model_no_admin = FakeToolCallingModel(tool_calls=[[], []]) # No tool calls
|
||||
agent_no_admin = create_agent(
|
||||
model=model_no_admin,
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[ConditionalToolMiddleware()],
|
||||
)
|
||||
|
||||
result_no_admin = agent_no_admin.invoke(
|
||||
{"messages": [HumanMessage("Hello")], "is_admin": False}
|
||||
)
|
||||
messages_no_admin = result_no_admin["messages"]
|
||||
tool_messages_no_admin = [m for m in messages_no_admin if isinstance(m, ToolMessage)]
|
||||
assert len(tool_messages_no_admin) == 0
|
||||
|
||||
|
||||
def test_empty_tools_list_is_valid() -> None:
|
||||
"""Test that middleware can set tools to an empty list."""
|
||||
"""Test that middleware can set an empty tools list."""
|
||||
|
||||
@tool
|
||||
def some_tool(input: str) -> str:
|
||||
"""Some tool."""
|
||||
return "result"
|
||||
def test_tool(input: str) -> str:
|
||||
"""Test tool."""
|
||||
return "test"
|
||||
|
||||
class NoToolsMiddleware(AgentMiddleware):
|
||||
class EmptyToolsMiddleware(AgentMiddleware):
|
||||
def modify_model_request(
|
||||
self, request: ModelRequest, state: AgentState, runtime: Runtime
|
||||
) -> ModelRequest:
|
||||
@@ -187,17 +214,18 @@ def test_empty_tools_list_is_valid() -> None:
|
||||
request.tools = []
|
||||
return request
|
||||
|
||||
model = FakeToolCallingModel()
|
||||
# Model should not make any tool calls
|
||||
model = FakeToolCallingModel(tool_calls=[[], []]) # No tool calls
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[some_tool],
|
||||
tools=[test_tool],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoToolsMiddleware()],
|
||||
).compile()
|
||||
middleware=[EmptyToolsMiddleware()],
|
||||
)
|
||||
|
||||
# Should run without error even with no tools
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
|
||||
assert "messages" in result
|
||||
|
||||
|
||||
@@ -224,20 +252,20 @@ def test_tools_preserved_across_multiple_middleware() -> None:
|
||||
def modify_model_request(
|
||||
self, request: ModelRequest, state: AgentState, runtime: Runtime
|
||||
) -> ModelRequest:
|
||||
modification_order.append(request.tools.copy())
|
||||
modification_order.append([t.name for t in request.tools])
|
||||
# Remove tool_c
|
||||
request.tools = [name for name in request.tools if name != "tool_c"]
|
||||
request.tools = [t for t in request.tools if t.name != "tool_c"]
|
||||
return request
|
||||
|
||||
class SecondMiddleware(AgentMiddleware):
|
||||
def modify_model_request(
|
||||
self, request: ModelRequest, state: AgentState, runtime: Runtime
|
||||
) -> ModelRequest:
|
||||
modification_order.append(request.tools.copy())
|
||||
modification_order.append([t.name for t in request.tools])
|
||||
# Should not see tool_c here
|
||||
assert "tool_c" not in request.tools
|
||||
assert "tool_c" not in [t.name for t in request.tools]
|
||||
# Remove tool_b
|
||||
request.tools = [name for name in request.tools if name != "tool_b"]
|
||||
request.tools = [t for t in request.tools if t.name != "tool_b"]
|
||||
return request
|
||||
|
||||
agent = create_agent(
|
||||
@@ -245,7 +273,7 @@ def test_tools_preserved_across_multiple_middleware() -> None:
|
||||
tools=[tool_a, tool_b, tool_c],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[FirstMiddleware(), SecondMiddleware()],
|
||||
).compile()
|
||||
)
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
|
||||
@@ -260,37 +288,38 @@ def test_tools_preserved_across_multiple_middleware() -> None:
|
||||
def test_middleware_with_additional_tools() -> None:
|
||||
"""Test middleware that provides additional tools via tools attribute."""
|
||||
|
||||
@tool
|
||||
def base_tool(input: str) -> str:
|
||||
"""Base tool."""
|
||||
return "base"
|
||||
|
||||
@tool
|
||||
def middleware_tool(input: str) -> str:
|
||||
"""Tool provided by middleware."""
|
||||
return "middleware"
|
||||
return f"Middleware tool result: {input}"
|
||||
|
||||
class ToolProvidingMiddleware(AgentMiddleware):
|
||||
tools = [middleware_tool]
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.tools = [middleware_tool]
|
||||
|
||||
# Model calls the middleware-provided tool
|
||||
def modify_model_request(
|
||||
self, request: ModelRequest, state: AgentState, runtime: Runtime
|
||||
) -> ModelRequest:
|
||||
# Add middleware tools to the request
|
||||
request.tools = request.tools + self.tools
|
||||
return request
|
||||
|
||||
# Model will try to call middleware_tool
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[{"args": {"input": "test"}, "id": "1", "name": "middleware_tool"}],
|
||||
[],
|
||||
]
|
||||
tool_calls=[[{"args": {"input": "test"}, "id": "1", "name": "middleware_tool"}], []]
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[base_tool],
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[ToolProvidingMiddleware()],
|
||||
).compile()
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Use middleware tool")]})
|
||||
|
||||
# Verify that the middleware tool was executed
|
||||
# Verify that the middleware tool was executed successfully
|
||||
messages = result["messages"]
|
||||
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
|
||||
assert len(tool_messages) == 1
|
||||
@@ -108,4 +108,4 @@ class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
|
||||
}
|
||||
)
|
||||
|
||||
return self.bind(tools=tool_dicts)
|
||||
return self.bind(tools=tool_dicts, **kwargs)
|
||||
|
||||
libs/langchain_v1/tests/unit_tests/agents/test_basic_agents.py
@@ -0,0 +1,381 @@
|
||||
"""Tests for basic agent functionality."""
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
HumanMessage,
|
||||
ToolCall,
|
||||
ToolMessage,
|
||||
)
|
||||
from typing import Any
|
||||
from langchain_core.tools import tool as dec_tool, InjectedToolCallId, ToolException
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.store.base import BaseStore
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
from langgraph.types import Command
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState
|
||||
from langchain.agents.structured_output import ToolStrategy
|
||||
from langchain.tools import (
|
||||
InjectedState,
|
||||
InjectedStore,
|
||||
ToolNode,
|
||||
)
|
||||
from langchain.tools.tool_node import (
|
||||
_get_state_args,
|
||||
_infer_handled_types,
|
||||
)
|
||||
|
||||
from .any_str import AnyStr
|
||||
from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
|
||||
from .model import FakeToolCallingModel
|
||||
|
||||
pytestmark = pytest.mark.anyio
|
||||
|
||||
|
||||
def test_no_prompt(sync_checkpointer: BaseCheckpointSaver) -> None:
|
||||
"""Test agent with no system prompt."""
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
agent = create_agent(
|
||||
model,
|
||||
[],
|
||||
checkpointer=sync_checkpointer,
|
||||
)
|
||||
inputs = [HumanMessage("hi?")]
|
||||
thread = {"configurable": {"thread_id": "123"}}
|
||||
response = agent.invoke({"messages": inputs}, thread, debug=True)
|
||||
expected_response = {"messages": [*inputs, AIMessage(content="hi?", id="0")]}
|
||||
assert response == expected_response
|
||||
|
||||
saved = sync_checkpointer.get_tuple(thread)
|
||||
assert saved is not None
|
||||
checkpoint_values = saved.checkpoint["channel_values"]
|
||||
assert checkpoint_values["messages"] == [
|
||||
_AnyIdHumanMessage(content="hi?"),
|
||||
AIMessage(content="hi?", id="0"),
|
||||
]
|
||||
assert checkpoint_values["thread_model_call_count"] == 1
|
||||
assert saved.metadata == {
|
||||
"parents": {},
|
||||
"source": "loop",
|
||||
"step": 1,
|
||||
}
|
||||
assert saved.pending_writes == []
|
||||
|
||||
|
||||
async def test_no_prompt_async(async_checkpointer: BaseCheckpointSaver) -> None:
|
||||
"""Test agent with no system prompt (async)."""
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
agent = create_agent(model, [], checkpointer=async_checkpointer)
|
||||
inputs = [HumanMessage("hi?")]
|
||||
thread = {"configurable": {"thread_id": "123"}}
|
||||
response = await agent.ainvoke({"messages": inputs}, thread, debug=True)
|
||||
expected_response = {"messages": [*inputs, AIMessage(content="hi?", id="0")]}
|
||||
assert response == expected_response
|
||||
|
||||
saved = await async_checkpointer.aget_tuple(thread)
|
||||
assert saved is not None
|
||||
checkpoint_values = saved.checkpoint["channel_values"]
|
||||
assert checkpoint_values["messages"] == [
|
||||
_AnyIdHumanMessage(content="hi?"),
|
||||
AIMessage(content="hi?", id="0"),
|
||||
]
|
||||
assert checkpoint_values["thread_model_call_count"] == 1
|
||||
assert saved.metadata == {
|
||||
"parents": {},
|
||||
"source": "loop",
|
||||
"step": 1,
|
||||
}
|
||||
assert saved.pending_writes == []
|
||||
|
||||
|
||||
def test_system_message_prompt() -> None:
|
||||
"""Test agent with system message prompt."""
|
||||
system_prompt = "Foo"
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
agent = create_agent(model, [], system_prompt=system_prompt)
|
||||
inputs = [HumanMessage("hi?")]
|
||||
response = agent.invoke({"messages": inputs}, debug=True)
|
||||
expected_response = {"messages": [*inputs, AIMessage(content="hi?", id="0")]}
|
||||
assert response == expected_response
|
||||
|
||||
|
||||
def test_system_message_prompt_with_tools() -> None:
|
||||
"""Test agent with system message prompt and tools."""
|
||||
system_prompt = "You are a helpful assistant."
|
||||
|
||||
@dec_tool
|
||||
def search_tool(query: str) -> str:
|
||||
"""Search for information."""
|
||||
return f"Results for: {query}"
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[[{"args": {"query": "test"}, "id": "1", "name": "search_tool"}], []]
|
||||
)
|
||||
|
||||
agent = create_agent(model, [search_tool], system_prompt=system_prompt)
|
||||
inputs = [HumanMessage("Search for something")]
|
||||
response = agent.invoke({"messages": inputs}, debug=True)
|
||||
|
||||
assert "messages" in response
|
||||
messages = response["messages"]
|
||||
assert len(messages) >= 2 # Human message + AI message
|
||||
assert isinstance(messages[0], HumanMessage)
|
||||
assert isinstance(messages[1], AIMessage)
|
||||
|
||||
|
||||
def test_agent_with_tools() -> None:
|
||||
"""Test agent with tools."""
|
||||
|
||||
@dec_tool
|
||||
def calculator(expression: str) -> str:
|
||||
"""Calculate a mathematical expression."""
|
||||
return f"Result: {expression}"
|
||||
|
||||
@dec_tool
|
||||
def search_tool(query: str) -> str:
|
||||
"""Search for information."""
|
||||
return f"Results for: {query}"
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[[{"args": {"expression": "2+2"}, "id": "1", "name": "calculator"}], []]
|
||||
)
|
||||
|
||||
agent = create_agent(model, [calculator, search_tool])
|
||||
inputs = [HumanMessage("Calculate 2+2")]
|
||||
response = agent.invoke({"messages": inputs}, debug=True)
|
||||
|
||||
assert "messages" in response
|
||||
messages = response["messages"]
|
||||
assert len(messages) >= 3 # Human + AI + Tool message
|
||||
assert isinstance(messages[0], HumanMessage)
|
||||
assert isinstance(messages[1], AIMessage)
|
||||
assert isinstance(messages[2], ToolMessage)
|
||||
|
||||
|
||||
def test_agent_with_structured_output() -> None:
|
||||
"""Test agent with structured output."""
|
||||
|
||||
class WeatherResponse(BaseModel):
|
||||
temperature: float = Field(description="Temperature in Fahrenheit")
|
||||
condition: str = Field(description="Weather condition")
|
||||
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
agent = create_agent(
|
||||
model,
|
||||
[],
|
||||
response_format=ToolStrategy(schema=WeatherResponse),
|
||||
)
|
||||
inputs = [HumanMessage("What's the weather like?")]
|
||||
response = agent.invoke({"messages": inputs}, debug=True)
|
||||
|
||||
assert "messages" in response
|
||||
messages = response["messages"]
|
||||
assert len(messages) >= 2
|
||||
assert isinstance(messages[0], HumanMessage)
|
||||
assert isinstance(messages[1], AIMessage)
|
||||
|
||||
|
||||
def test_agent_with_injected_tools() -> None:
|
||||
"""Test agent with tools that use injected state."""
|
||||
|
||||
@dec_tool
|
||||
def state_tool(input: str, state: Annotated[dict, InjectedState]) -> str:
|
||||
"""Tool that uses injected state."""
|
||||
return f"State: {state.get('custom_field', 'none')}"
|
||||
|
||||
@dec_tool
|
||||
def store_tool(input: str, store: Annotated[BaseStore, InjectedStore]) -> str:
|
||||
"""Tool that uses injected store."""
|
||||
return f"Store: {type(store).__name__}"
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[[{"args": {"input": "test"}, "id": "1", "name": "state_tool"}], []]
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model,
|
||||
[state_tool, store_tool],
|
||||
store=InMemoryStore(),
|
||||
)
|
||||
inputs = [HumanMessage("Use state tool")]
|
||||
response = agent.invoke({"messages": inputs, "custom_field": "test_value"}, debug=True)
|
||||
|
||||
assert "messages" in response
|
||||
messages = response["messages"]
|
||||
assert len(messages) >= 3 # Human + AI + Tool message
|
||||
assert isinstance(messages[2], ToolMessage)
|
||||
assert "test_value" in messages[2].content
|
||||
|
||||
|
||||
def test_agent_with_tool_exception() -> None:
|
||||
"""Test agent handling tool exceptions."""
|
||||
|
||||
@dec_tool
|
||||
def error_tool(input: str) -> str:
|
||||
"""Tool that raises an exception."""
|
||||
raise ToolException("Tool error occurred")
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[[{"args": {"input": "test"}, "id": "1", "name": "error_tool"}], []]
|
||||
)
|
||||
|
||||
agent = create_agent(model, [error_tool])
|
||||
inputs = [HumanMessage("Use error tool")]
|
||||
response = agent.invoke({"messages": inputs}, debug=True)
|
||||
|
||||
assert "messages" in response
|
||||
messages = response["messages"]
|
||||
assert len(messages) >= 3 # Human + AI + Tool message
|
||||
assert isinstance(messages[2], ToolMessage)
|
||||
assert "Tool error occurred" in messages[2].content
|
||||
|
||||
|
||||
def test_agent_with_multiple_tool_calls() -> None:
|
||||
"""Test agent with multiple tool calls in one response."""
|
||||
|
||||
@dec_tool
|
||||
def tool1(input: str) -> str:
|
||||
"""First tool."""
|
||||
return f"Tool1: {input}"
|
||||
|
||||
@dec_tool
|
||||
def tool2(input: str) -> str:
|
||||
"""Second tool."""
|
||||
return f"Tool2: {input}"
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
{"args": {"input": "test1"}, "id": "1", "name": "tool1"},
|
||||
{"args": {"input": "test2"}, "id": "2", "name": "tool2"},
|
||||
],
|
||||
[],
|
||||
]
|
||||
)
|
||||
|
||||
agent = create_agent(model, [tool1, tool2])
|
||||
inputs = [HumanMessage("Use both tools")]
|
||||
response = agent.invoke({"messages": inputs}, debug=True)
|
||||
|
||||
assert "messages" in response
|
||||
messages = response["messages"]
|
||||
assert len(messages) >= 4 # Human + AI + 2 Tool messages
|
||||
assert isinstance(messages[1], AIMessage)
|
||||
assert len(messages[1].tool_calls) == 2
|
||||
assert isinstance(messages[2], ToolMessage)
|
||||
assert isinstance(messages[3], ToolMessage)
|
||||
|
||||
|
||||
def test_agent_with_custom_middleware() -> None:
|
||||
"""Test agent with custom middleware."""
|
||||
|
||||
class CustomMiddleware(AgentMiddleware[AgentState]):
|
||||
def before_model(self, state: AgentState, runtime) -> dict[str, Any]:
|
||||
return {"custom_field": "middleware_value"}
|
||||
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
agent = create_agent(model, [], middleware=[CustomMiddleware()])
|
||||
inputs = [HumanMessage("Hello")]
|
||||
response = agent.invoke({"messages": inputs}, debug=True)
|
||||
|
||||
assert "messages" in response
|
||||
assert "custom_field" in response
|
||||
assert response["custom_field"] == "middleware_value"
|
||||
|
||||
|
||||
def test_agent_with_checkpointer() -> None:
|
||||
"""Test agent with checkpointer for state persistence."""
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
# Use an actual checkpointer here; InMemoryStore is a store, not a BaseCheckpointSaver.
agent = create_agent(model, [], checkpointer=InMemorySaver())
|
||||
inputs = [HumanMessage("Hello")]
|
||||
thread = {"configurable": {"thread_id": "test_thread"}}
|
||||
|
||||
# First invocation
|
||||
response1 = agent.invoke({"messages": inputs}, thread, debug=True)
|
||||
assert "messages" in response1
|
||||
|
||||
# Second invocation in same thread
|
||||
inputs2 = [HumanMessage("Hello again")]
|
||||
response2 = agent.invoke({"messages": inputs2}, thread, debug=True)
|
||||
assert "messages" in response2
|
||||
|
||||
# Should have conversation history
|
||||
messages = response2["messages"]
|
||||
assert len(messages) >= 2 # Should have previous messages
|
||||
|
||||
|
||||
def test_agent_with_store() -> None:
|
||||
"""Test agent with store for persistent data."""
|
||||
|
||||
@dec_tool
|
||||
def store_tool(input: str, store: Annotated[BaseStore, InjectedStore]) -> str:
|
||||
"""Tool that uses store."""
|
||||
# BaseStore is namespaced: put(namespace, key, value) with a dict value.
store.put(("test",), "test_key", {"value": "test_value"})
|
||||
return "Stored value"
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[[{"args": {"input": "test"}, "id": "1", "name": "store_tool"}], []]
|
||||
)
|
||||
|
||||
store = InMemoryStore()
|
||||
agent = create_agent(model, [store_tool], store=store)
|
||||
inputs = [HumanMessage("Use store tool")]
|
||||
response = agent.invoke({"messages": inputs}, debug=True)
|
||||
|
||||
assert "messages" in response
|
||||
# Verify store was used
|
||||
stored_item = store.get(("test",), "test_key")
|
||||
assert stored_item is not None
assert stored_item.value == {"value": "test_value"}
|
||||
|
||||
|
||||
def test_agent_debug_mode() -> None:
|
||||
"""Test agent in debug mode."""
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
agent = create_agent(model, [])
|
||||
inputs = [HumanMessage("Hello")]
|
||||
response = agent.invoke({"messages": inputs}, debug=True)
|
||||
|
||||
assert "messages" in response
|
||||
# Debug mode should provide additional information
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_agent_with_empty_tools() -> None:
|
||||
"""Test agent with empty tools list."""
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
agent = create_agent(model, [])
|
||||
inputs = [HumanMessage("Hello")]
|
||||
response = agent.invoke({"messages": inputs}, debug=True)
|
||||
|
||||
assert "messages" in response
|
||||
messages = response["messages"]
|
||||
assert len(messages) >= 2 # Human + AI message
|
||||
assert isinstance(messages[0], HumanMessage)
|
||||
assert isinstance(messages[1], AIMessage)
|
||||
|
||||
|
||||
def test_agent_with_none_system_prompt() -> None:
|
||||
"""Test agent with None system prompt."""
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
agent = create_agent(model, [], system_prompt=None)
|
||||
inputs = [HumanMessage("Hello")]
|
||||
response = agent.invoke({"messages": inputs}, debug=True)
|
||||
|
||||
assert "messages" in response
|
||||
messages = response["messages"]
|
||||
assert len(messages) >= 2
|
||||
assert isinstance(messages[0], HumanMessage)
|
||||
assert isinstance(messages[1], AIMessage)
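# Informal recap, not exercised directly: the optional keyword arguments used throughout this
# module (system_prompt, middleware, response_format, checkpointer, store) combine freely.
# The function name is a placeholder.
def _example_full_agent():
    return create_agent(
        FakeToolCallingModel(),
        [],
        system_prompt="You are a helpful assistant.",
        middleware=[],
        checkpointer=InMemorySaver(),
        store=InMemoryStore(),
    )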
|
||||
libs/langchain_v1/tests/unit_tests/agents/test_edge_cases.py
@@ -0,0 +1,356 @@
|
||||
"""Tests for edge cases and error conditions in agents."""
|
||||
|
||||
import pytest
|
||||
from typing import Any
from unittest.mock import Mock, patch
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.exceptions import LangChainException
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState
|
||||
from langchain.agents.structured_output import ToolStrategy
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .model import FakeToolCallingModel
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases and error conditions."""
|
||||
|
||||
def test_agent_with_invalid_model_string(self) -> None:
|
||||
"""Test agent creation with invalid model string."""
|
||||
with pytest.raises(Exception): # Should raise some kind of error
|
||||
create_agent("invalid:model", [])
|
||||
|
||||
def test_agent_with_none_tools(self) -> None:
|
||||
"""Test agent with None tools."""
|
||||
model = FakeToolCallingModel()
|
||||
agent = create_agent(model, None)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
assert "messages" in result
|
||||
|
||||
def test_agent_with_empty_tools_list(self) -> None:
|
||||
"""Test agent with empty tools list."""
|
||||
model = FakeToolCallingModel()
|
||||
agent = create_agent(model, [])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
assert "messages" in result
|
||||
|
||||
def test_agent_with_duplicate_middleware(self) -> None:
|
||||
"""Test agent with duplicate middleware instances."""
|
||||
|
||||
class TestMiddleware(AgentMiddleware):
|
||||
pass
|
||||
|
||||
middleware1 = TestMiddleware()
|
||||
middleware2 = TestMiddleware() # Same class, different instance
|
||||
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
# Should raise an error about duplicate middleware
|
||||
with pytest.raises(AssertionError, match="Please remove duplicate middleware instances"):
|
||||
create_agent(model, [], middleware=[middleware1, middleware2])
|
||||
|
||||
def test_agent_with_middleware_error(self) -> None:
|
||||
"""Test agent with middleware that raises an error."""
|
||||
|
||||
class ErrorMiddleware(AgentMiddleware):
|
||||
def before_model(self, state: AgentState, runtime) -> dict[str, Any]:
|
||||
raise ValueError("Middleware error")
|
||||
|
||||
model = FakeToolCallingModel()
|
||||
agent = create_agent(model, [], middleware=[ErrorMiddleware()])
|
||||
|
||||
with pytest.raises(ValueError, match="Middleware error"):
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
|
||||
def test_agent_with_invalid_structured_output(self) -> None:
|
||||
"""Test agent with invalid structured output configuration."""
|
||||
|
||||
class InvalidSchema:
|
||||
            pass  # Not a valid Pydantic model

        model = FakeToolCallingModel()

        # Should handle invalid schema gracefully
        with pytest.raises(Exception):
            create_agent(model, [], response_format=InvalidSchema)

    def test_agent_with_tool_exception(self) -> None:
        """Test agent handling tool exceptions."""

        @tool
        def error_tool(input: str) -> str:
            """Tool that always raises an exception."""
            raise Exception("Tool execution failed")

        model = FakeToolCallingModel(
            tool_calls=[[{"args": {"input": "test"}, "id": "1", "name": "error_tool"}], []]
        )

        agent = create_agent(model, [error_tool])
        result = agent.invoke({"messages": [HumanMessage("Use error tool")]})

        # Should handle tool exception gracefully
        assert "messages" in result
        messages = result["messages"]
        assert len(messages) >= 3  # Human + AI + Tool message
        assert isinstance(messages[2], ToolMessage)
        assert "Tool execution failed" in messages[2].content

    def test_agent_with_malformed_tool_call(self) -> None:
        """Test agent with malformed tool call from model."""
        model = FakeToolCallingModel(
            tool_calls=[[{"invalid": "tool_call"}], []]  # Malformed tool call
        )

        agent = create_agent(model, [])

        # Should handle malformed tool calls gracefully
        with pytest.raises(Exception):
            agent.invoke({"messages": [HumanMessage("Hello")]})

    def test_agent_with_empty_messages(self) -> None:
        """Test agent with empty messages list."""
        model = FakeToolCallingModel()
        agent = create_agent(model, [])

        result = agent.invoke({"messages": []})
        assert "messages" in result

    def test_agent_with_none_messages(self) -> None:
        """Test agent with None messages."""
        model = FakeToolCallingModel()
        agent = create_agent(model, [])

        with pytest.raises(Exception):
            agent.invoke({"messages": None})

    def test_agent_with_invalid_state(self) -> None:
        """Test agent with invalid state structure."""
        model = FakeToolCallingModel()
        agent = create_agent(model, [])

        with pytest.raises(Exception):
            agent.invoke({"invalid_key": "value"})

    def test_agent_with_large_message_history(self) -> None:
        """Test agent with large message history."""
        model = FakeToolCallingModel()
        agent = create_agent(model, [])

        # Create a large message history
        messages = [HumanMessage(f"Message {i}") for i in range(100)]

        result = agent.invoke({"messages": messages})
        assert "messages" in result

    def test_agent_with_special_characters(self) -> None:
        """Test agent with special characters in messages."""
        model = FakeToolCallingModel()
        agent = create_agent(model, [])

        special_message = HumanMessage("Hello! @#$%^&*()_+-=[]{}|;':\",./<>?")
        result = agent.invoke({"messages": [special_message]})
        assert "messages" in result

    def test_agent_with_unicode_messages(self) -> None:
        """Test agent with unicode messages."""
        model = FakeToolCallingModel()
        agent = create_agent(model, [])

        unicode_message = HumanMessage("Hello 世界! 🌍")
        result = agent.invoke({"messages": [unicode_message]})
        assert "messages" in result

    def test_agent_with_very_long_message(self) -> None:
        """Test agent with very long message."""
        model = FakeToolCallingModel()
        agent = create_agent(model, [])

        long_message = HumanMessage("A" * 10000)  # 10k character message
        result = agent.invoke({"messages": [long_message]})
        assert "messages" in result

    def test_agent_with_multiple_system_prompts(self) -> None:
        """Test agent behavior with multiple system prompts."""
        model = FakeToolCallingModel()
        agent = create_agent(model, [], system_prompt="First prompt")

        # Add system message to input
        messages = [HumanMessage("System: Second prompt"), HumanMessage("Hello")]

        result = agent.invoke({"messages": messages})
        assert "messages" in result

    def test_agent_with_tool_node_instead_of_list(self) -> None:
        """Test agent with ToolNode instead of tools list."""
        from langchain.tools import ToolNode

        @tool
        def test_tool(input: str) -> str:
            """Test tool."""
            return f"Result: {input}"

        tool_node = ToolNode([test_tool])
        model = FakeToolCallingModel()
        agent = create_agent(model, tool_node)

        result = agent.invoke({"messages": [HumanMessage("Hello")]})
        assert "messages" in result

    def test_agent_with_interrupt_before(self) -> None:
        """Test agent with interrupt_before configuration."""
        model = FakeToolCallingModel()
        agent = create_agent(model, [], interrupt_before=["model_request"])

        result = agent.invoke({"messages": [HumanMessage("Hello")]})
        assert "messages" in result

    def test_agent_with_interrupt_after(self) -> None:
        """Test agent with interrupt_after configuration."""
        model = FakeToolCallingModel()
        agent = create_agent(model, [], interrupt_after=["model_request"])

        result = agent.invoke({"messages": [HumanMessage("Hello")]})
        assert "messages" in result

    def test_agent_with_debug_mode(self) -> None:
        """Test agent with debug mode enabled."""
        model = FakeToolCallingModel()
        agent = create_agent(model, [], debug=True)

        result = agent.invoke({"messages": [HumanMessage("Hello")]})
        assert "messages" in result

    def test_agent_with_cache(self) -> None:
        """Test agent with cache configuration."""
        model = FakeToolCallingModel()
        cache = {}  # Simple dict cache
        agent = create_agent(model, [], cache=cache)

        result = agent.invoke({"messages": [HumanMessage("Hello")]})
        assert "messages" in result

    def test_agent_with_custom_name(self) -> None:
        """Test agent with custom name."""
        model = FakeToolCallingModel()
        agent = create_agent(model, [], name="CustomAgent")

        result = agent.invoke({"messages": [HumanMessage("Hello")]})
        assert "messages" in result

    def test_agent_with_context_schema(self) -> None:
        """Test agent with custom context schema."""
        from typing_extensions import TypedDict

        class CustomContext(TypedDict):
            user_id: str
            session_id: str

        model = FakeToolCallingModel()
        agent = create_agent(model, [], context_schema=CustomContext)

        result = agent.invoke({"messages": [HumanMessage("Hello")]})
        assert "messages" in result

    def test_agent_with_checkpointer_and_store(self) -> None:
        """Test agent with both checkpointer and store."""
        from langgraph.store.memory import InMemoryStore

        model = FakeToolCallingModel()
        checkpointer = InMemorySaver()
        store = InMemoryStore()

        agent = create_agent(model, [], checkpointer=checkpointer, store=store)

        result = agent.invoke({"messages": [HumanMessage("Hello")]})
        assert "messages" in result

    def test_agent_with_middleware_modifying_request(self) -> None:
        """Test agent with middleware that modifies model request."""

        class RequestModifyingMiddleware(AgentMiddleware):
            def modify_model_request(self, request, state, runtime):
                request.model_settings["custom_setting"] = "modified"
                return request

        model = FakeToolCallingModel()
        agent = create_agent(model, [], middleware=[RequestModifyingMiddleware()])

        result = agent.invoke({"messages": [HumanMessage("Hello")]})
        assert "messages" in result

    def test_agent_with_middleware_jumping_to_end(self) -> None:
        """Test agent with middleware that jumps to end."""

        class JumpToEndMiddleware(AgentMiddleware):
            def before_model(self, state: AgentState, runtime):
                return {"jump_to": "end"}

        model = FakeToolCallingModel()
        agent = create_agent(model, [], middleware=[JumpToEndMiddleware()])

        result = agent.invoke({"messages": [HumanMessage("Hello")]})
        assert "messages" in result
        # Should have minimal messages due to jumping to end
        assert len(result["messages"]) <= 2

    def test_agent_with_structured_output_error(self) -> None:
        """Test agent with structured output validation error."""

        class WeatherResponse(BaseModel):
            temperature: float = Field(description="Temperature in Fahrenheit")
            condition: str = Field(description="Weather condition")

        # Mock model that returns invalid structured output
        class InvalidStructuredModel(FakeToolCallingModel):
            def invoke(self, messages, **kwargs):
                # Return a message that doesn't match the schema
                return AIMessage(content="Invalid response")

        model = InvalidStructuredModel()
        agent = create_agent(model, [], response_format=ToolStrategy(schema=WeatherResponse))

        # Should handle structured output errors gracefully
        result = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
        assert "messages" in result

    def test_agent_with_concurrent_invocations(self) -> None:
        """Test agent with concurrent invocations."""
        import asyncio

        model = FakeToolCallingModel()
        agent = create_agent(model, [])

        async def run_concurrent():
            tasks = [agent.ainvoke({"messages": [HumanMessage(f"Message {i}")]}) for i in range(5)]
            results = await asyncio.gather(*tasks)
            return results

        # Run concurrent invocations
        results = asyncio.run(run_concurrent())

        # All should succeed
        assert len(results) == 5
        for result in results:
            assert "messages" in result

    def test_agent_with_streaming(self) -> None:
        """Test agent streaming functionality."""
        model = FakeToolCallingModel()
        agent = create_agent(model, [])

        # Test streaming
        stream = agent.stream({"messages": [HumanMessage("Hello")]})
        chunks = list(stream)

        # Should have at least one chunk
        assert len(chunks) > 0
        # Each chunk should be a dict
        for chunk in chunks:
            assert isinstance(chunk, dict)
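For reference, a minimal usage sketch of the configuration surface these edge-case tests exercise; names such as FakeToolCallingModel come from the test helpers above, and the InMemorySaver import path is an assumption, so treat this as an illustrative sketch rather than part of the diff:

# Illustrative sketch only: mirrors the keyword arguments the tests above pass to create_agent.
from langgraph.checkpoint.memory import InMemorySaver  # assumed import path for the saver used here

agent = create_agent(
    FakeToolCallingModel(),        # any chat model instance or model identifier
    [],                            # tools may be empty; the agent is then a single LLM node
    system_prompt="First prompt",  # optional system prompt
    checkpointer=InMemorySaver(),  # optional persistence
    debug=True,                    # optional verbose execution
)
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result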
@@ -1,574 +0,0 @@
"""Consolidated tests for middleware decorators: before_model, after_model, and modify_model_request."""

import pytest
from typing import Any
from typing_extensions import NotRequired
from syrupy.assertion import SnapshotAssertion

from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.tools import tool
from langgraph.runtime import Runtime
from langgraph.types import Command

from langchain.agents.middleware.types import (
    AgentMiddleware,
    AgentState,
    ModelRequest,
    before_model,
    after_model,
    modify_model_request,
    hook_config,
)
from langchain.agents.middleware_agent import create_agent, _get_can_jump_to
from .model import FakeToolCallingModel


class CustomState(AgentState):
    """Custom state schema for testing."""

    custom_field: NotRequired[str]


@tool
def test_tool(input: str) -> str:
    """A test tool for middleware testing."""
    return f"Tool result: {input}"


def test_before_model_decorator() -> None:
    """Test before_model decorator with all configuration options."""

    @before_model(
        state_schema=CustomState, tools=[test_tool], can_jump_to=["end"], name="CustomBeforeModel"
    )
    def custom_before_model(state: CustomState, runtime: Runtime) -> dict[str, Any]:
        return {"jump_to": "end"}

    assert isinstance(custom_before_model, AgentMiddleware)
    assert custom_before_model.state_schema == CustomState
    assert custom_before_model.tools == [test_tool]
    assert getattr(custom_before_model.__class__.before_model, "__can_jump_to__", []) == ["end"]
    assert custom_before_model.__class__.__name__ == "CustomBeforeModel"

    result = custom_before_model.before_model({"messages": [HumanMessage("Hello")]}, None)
    assert result == {"jump_to": "end"}


def test_after_model_decorator() -> None:
    """Test after_model decorator with all configuration options."""

    @after_model(
        state_schema=CustomState,
        tools=[test_tool],
        can_jump_to=["model", "end"],
        name="CustomAfterModel",
    )
    def custom_after_model(state: CustomState, runtime: Runtime) -> dict[str, Any]:
        return {"jump_to": "model"}

    # Verify all options were applied
    assert isinstance(custom_after_model, AgentMiddleware)
    assert custom_after_model.state_schema == CustomState
    assert custom_after_model.tools == [test_tool]
    assert getattr(custom_after_model.__class__.after_model, "__can_jump_to__", []) == [
        "model",
        "end",
    ]
    assert custom_after_model.__class__.__name__ == "CustomAfterModel"

    # Verify it works
    result = custom_after_model.after_model(
        {"messages": [HumanMessage("Hello"), AIMessage("Hi!")]}, None
    )
    assert result == {"jump_to": "model"}


def test_modify_model_request_decorator() -> None:
    """Test modify_model_request decorator with all configuration options."""

    @modify_model_request(state_schema=CustomState, tools=[test_tool], name="CustomModifyRequest")
    def custom_modify_request(
        request: ModelRequest, state: CustomState, runtime: Runtime
    ) -> ModelRequest:
        request.system_prompt = "Modified"
        return request

    # Verify all options were applied
    assert isinstance(custom_modify_request, AgentMiddleware)
    assert custom_modify_request.state_schema == CustomState
    assert custom_modify_request.tools == [test_tool]
    assert custom_modify_request.__class__.__name__ == "CustomModifyRequest"

    # Verify it works
    original_request = ModelRequest(
        model="test-model",
        system_prompt="Original",
        messages=[HumanMessage("Hello")],
        tool_choice=None,
        tools=[],
        response_format=None,
    )
    result = custom_modify_request.modify_model_request(
        original_request, {"messages": [HumanMessage("Hello")]}, None
    )
    assert result.system_prompt == "Modified"


def test_all_decorators_integration() -> None:
    """Test all three decorators working together in an agent."""
    call_order = []

    @before_model
    def track_before(state: AgentState, runtime: Runtime) -> None:
        call_order.append("before")
        return None

    @modify_model_request
    def track_modify(request: ModelRequest, state: AgentState, runtime: Runtime) -> ModelRequest:
        call_order.append("modify")
        return request

    @after_model
    def track_after(state: AgentState, runtime: Runtime) -> None:
        call_order.append("after")
        return None

    agent = create_agent(
        model=FakeToolCallingModel(), middleware=[track_before, track_modify, track_after]
    )
    agent = agent.compile()
    agent.invoke({"messages": [HumanMessage("Hello")]})

    assert call_order == ["before", "modify", "after"]


def test_decorators_use_function_names_as_default() -> None:
    """Test that decorators use function names as default middleware names."""

    @before_model
    def my_before_hook(state: AgentState, runtime: Runtime) -> None:
        return None

    @modify_model_request
    def my_modify_hook(request: ModelRequest, state: AgentState, runtime: Runtime) -> ModelRequest:
        return request

    @after_model
    def my_after_hook(state: AgentState, runtime: Runtime) -> None:
        return None

    # Verify that function names are used as middleware class names
    assert my_before_hook.__class__.__name__ == "my_before_hook"
    assert my_modify_hook.__class__.__name__ == "my_modify_hook"
    assert my_after_hook.__class__.__name__ == "my_after_hook"


def test_hook_config_decorator_on_class_method() -> None:
    """Test hook_config decorator on AgentMiddleware class methods."""

    class JumpMiddleware(AgentMiddleware):
        @hook_config(can_jump_to=["end", "model"])
        def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
            if len(state["messages"]) > 5:
                return {"jump_to": "end"}
            return None

        @hook_config(can_jump_to=["tools"])
        def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
            return {"jump_to": "tools"}

    # Verify can_jump_to metadata is preserved
    assert getattr(JumpMiddleware.before_model, "__can_jump_to__", []) == ["end", "model"]
    assert getattr(JumpMiddleware.after_model, "__can_jump_to__", []) == ["tools"]


def test_can_jump_to_with_before_model_decorator() -> None:
    """Test can_jump_to parameter used with before_model decorator."""

    @before_model(can_jump_to=["end"])
    def conditional_before(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        if len(state["messages"]) > 3:
            return {"jump_to": "end"}
        return None

    # Verify middleware was created and has can_jump_to metadata
    assert isinstance(conditional_before, AgentMiddleware)
    assert getattr(conditional_before.__class__.before_model, "__can_jump_to__", []) == ["end"]


def test_can_jump_to_with_after_model_decorator() -> None:
    """Test can_jump_to parameter used with after_model decorator."""

    @after_model(can_jump_to=["model", "end"])
    def conditional_after(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        if state["messages"][-1].content == "retry":
            return {"jump_to": "model"}
        return None

    # Verify middleware was created and has can_jump_to metadata
    assert isinstance(conditional_after, AgentMiddleware)
    assert getattr(conditional_after.__class__.after_model, "__can_jump_to__", []) == [
        "model",
        "end",
    ]


def test_can_jump_to_integration() -> None:
    """Test can_jump_to parameter in a full agent."""
    calls = []

    @before_model(can_jump_to=["end"])
    def early_exit(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        calls.append("early_exit")
        if state["messages"][0].content == "exit":
            return {"jump_to": "end"}
        return None

    agent = create_agent(model=FakeToolCallingModel(), middleware=[early_exit])
    agent = agent.compile()

    # Test with early exit
    result = agent.invoke({"messages": [HumanMessage("exit")]})
    assert calls == ["early_exit"]
    assert len(result["messages"]) == 1

    # Test without early exit
    calls.clear()
    result = agent.invoke({"messages": [HumanMessage("hello")]})
    assert calls == ["early_exit"]
    assert len(result["messages"]) > 1


# Async Decorator Tests


def test_async_before_model_decorator() -> None:
    """Test before_model decorator with async function."""

    @before_model(state_schema=CustomState, tools=[test_tool], name="AsyncBeforeModel")
    async def async_before_model(state: CustomState, runtime: Runtime) -> dict[str, Any]:
        return {"custom_field": "async_value"}

    assert isinstance(async_before_model, AgentMiddleware)
    assert async_before_model.state_schema == CustomState
    assert async_before_model.tools == [test_tool]
    assert async_before_model.__class__.__name__ == "AsyncBeforeModel"


def test_async_after_model_decorator() -> None:
    """Test after_model decorator with async function."""

    @after_model(state_schema=CustomState, tools=[test_tool], name="AsyncAfterModel")
    async def async_after_model(state: CustomState, runtime: Runtime) -> dict[str, Any]:
        return {"custom_field": "async_value"}

    assert isinstance(async_after_model, AgentMiddleware)
    assert async_after_model.state_schema == CustomState
    assert async_after_model.tools == [test_tool]
    assert async_after_model.__class__.__name__ == "AsyncAfterModel"


def test_async_modify_model_request_decorator() -> None:
    """Test modify_model_request decorator with async function."""

    @modify_model_request(state_schema=CustomState, tools=[test_tool], name="AsyncModifyRequest")
    async def async_modify_request(
        request: ModelRequest, state: CustomState, runtime: Runtime
    ) -> ModelRequest:
        request.system_prompt = "Modified async"
        return request

    assert isinstance(async_modify_request, AgentMiddleware)
    assert async_modify_request.state_schema == CustomState
    assert async_modify_request.tools == [test_tool]
    assert async_modify_request.__class__.__name__ == "AsyncModifyRequest"


def test_mixed_sync_async_decorators() -> None:
    """Test decorators with both sync and async functions."""

    @before_model(name="MixedBeforeModel")
    def sync_before(state: AgentState, runtime: Runtime) -> None:
        return None

    @before_model(name="MixedBeforeModel")
    async def async_before(state: AgentState, runtime: Runtime) -> None:
        return None

    @modify_model_request(name="MixedModifyRequest")
    def sync_modify(request: ModelRequest, state: AgentState, runtime: Runtime) -> ModelRequest:
        return request

    @modify_model_request(name="MixedModifyRequest")
    async def async_modify(
        request: ModelRequest, state: AgentState, runtime: Runtime
    ) -> ModelRequest:
        return request

    # Both should create valid middleware instances
    assert isinstance(sync_before, AgentMiddleware)
    assert isinstance(async_before, AgentMiddleware)
    assert isinstance(sync_modify, AgentMiddleware)
    assert isinstance(async_modify, AgentMiddleware)


@pytest.mark.asyncio
async def test_async_decorators_integration() -> None:
    """Test async decorators working together in an agent."""
    call_order = []

    @before_model
    async def track_async_before(state: AgentState, runtime: Runtime) -> None:
        call_order.append("async_before")
        return None

    @modify_model_request
    async def track_async_modify(
        request: ModelRequest, state: AgentState, runtime: Runtime
    ) -> ModelRequest:
        call_order.append("async_modify")
        return request

    @after_model
    async def track_async_after(state: AgentState, runtime: Runtime) -> None:
        call_order.append("async_after")
        return None

    agent = create_agent(
        model=FakeToolCallingModel(),
        middleware=[track_async_before, track_async_modify, track_async_after],
    )
    agent = agent.compile()
    await agent.ainvoke({"messages": [HumanMessage("Hello")]})

    assert call_order == ["async_before", "async_modify", "async_after"]


@pytest.mark.asyncio
async def test_mixed_sync_async_decorators_integration() -> None:
    """Test mixed sync/async decorators working together in an agent."""
    call_order = []

    @before_model
    def track_sync_before(state: AgentState, runtime: Runtime) -> None:
        call_order.append("sync_before")
        return None

    @before_model
    async def track_async_before(state: AgentState, runtime: Runtime) -> None:
        call_order.append("async_before")
        return None

    @modify_model_request
    def track_sync_modify(
        request: ModelRequest, state: AgentState, runtime: Runtime
    ) -> ModelRequest:
        call_order.append("sync_modify")
        return request

    @modify_model_request
    async def track_async_modify(
        request: ModelRequest, state: AgentState, runtime: Runtime
    ) -> ModelRequest:
        call_order.append("async_modify")
        return request

    @after_model
    async def track_async_after(state: AgentState, runtime: Runtime) -> None:
        call_order.append("async_after")
        return None

    @after_model
    def track_sync_after(state: AgentState, runtime: Runtime) -> None:
        call_order.append("sync_after")
        return None

    agent = create_agent(
        model=FakeToolCallingModel(),
        middleware=[
            track_sync_before,
            track_async_before,
            track_sync_modify,
            track_async_modify,
            track_async_after,
            track_sync_after,
        ],
    )
    agent = agent.compile()
    await agent.ainvoke({"messages": [HumanMessage("Hello")]})

    assert call_order == [
        "sync_before",
        "async_before",
        "sync_modify",
        "async_modify",
        "sync_after",
        "async_after",
    ]


def test_async_before_model_preserves_can_jump_to() -> None:
    """Test that can_jump_to metadata is preserved for async before_model functions."""

    @before_model(can_jump_to=["end"])
    async def async_conditional_before(
        state: AgentState, runtime: Runtime
    ) -> dict[str, Any] | None:
        if len(state["messages"]) > 3:
            return {"jump_to": "end"}
        return None

    # Verify middleware was created and has can_jump_to metadata
    assert isinstance(async_conditional_before, AgentMiddleware)
    assert getattr(async_conditional_before.__class__.abefore_model, "__can_jump_to__", []) == [
        "end"
    ]


def test_async_after_model_preserves_can_jump_to() -> None:
    """Test that can_jump_to metadata is preserved for async after_model functions."""

    @after_model(can_jump_to=["model", "end"])
    async def async_conditional_after(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        if state["messages"][-1].content == "retry":
            return {"jump_to": "model"}
        return None

    # Verify middleware was created and has can_jump_to metadata
    assert isinstance(async_conditional_after, AgentMiddleware)
    assert getattr(async_conditional_after.__class__.aafter_model, "__can_jump_to__", []) == [
        "model",
        "end",
    ]


@pytest.mark.asyncio
async def test_async_can_jump_to_integration() -> None:
    """Test can_jump_to parameter in a full agent with async middleware."""
    calls = []

    @before_model(can_jump_to=["end"])
    async def async_early_exit(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        calls.append("async_early_exit")
        if state["messages"][0].content == "exit":
            return {"jump_to": "end"}
        return None

    agent = create_agent(model=FakeToolCallingModel(), middleware=[async_early_exit])
    agent = agent.compile()

    # Test with early exit
    result = await agent.ainvoke({"messages": [HumanMessage("exit")]})
    assert calls == ["async_early_exit"]
    assert len(result["messages"]) == 1

    # Test without early exit
    calls.clear()
    result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
    assert calls == ["async_early_exit"]
    assert len(result["messages"]) > 1


def test_get_can_jump_to_no_false_positives() -> None:
    """Test that _get_can_jump_to doesn't return false positives for base class methods."""

    # Middleware with no overridden methods should return empty list
    class EmptyMiddleware(AgentMiddleware):
        pass

    empty_middleware = EmptyMiddleware()
    empty_middleware.tools = []

    # Should not return any jump destinations for base class methods
    assert _get_can_jump_to(empty_middleware, "before_model") == []
    assert _get_can_jump_to(empty_middleware, "after_model") == []


def test_get_can_jump_to_only_overridden_methods() -> None:
    """Test that _get_can_jump_to only checks overridden methods."""

    # Middleware with only sync method overridden
    class SyncOnlyMiddleware(AgentMiddleware):
        @hook_config(can_jump_to=["end"])
        def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
            return None

    sync_middleware = SyncOnlyMiddleware()
    sync_middleware.tools = []

    # Should return can_jump_to from overridden sync method
    assert _get_can_jump_to(sync_middleware, "before_model") == ["end"]

    # Middleware with only async method overridden
    class AsyncOnlyMiddleware(AgentMiddleware):
        @hook_config(can_jump_to=["model"])
        async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
            return None

    async_middleware = AsyncOnlyMiddleware()
    async_middleware.tools = []

    # Should return can_jump_to from overridden async method
    assert _get_can_jump_to(async_middleware, "after_model") == ["model"]


def test_async_middleware_with_can_jump_to_graph_snapshot(snapshot: SnapshotAssertion) -> None:
    """Test that async middleware with can_jump_to creates correct graph structure with conditional edges."""

    # Test 1: Async before_model with can_jump_to
    @before_model(can_jump_to=["end"])
    async def async_before_with_jump(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        if len(state["messages"]) > 5:
            return {"jump_to": "end"}
        return None

    agent_async_before = create_agent(
        model=FakeToolCallingModel(), middleware=[async_before_with_jump]
    )

    assert agent_async_before.compile().get_graph().draw_mermaid() == snapshot

    # Test 2: Async after_model with can_jump_to
    @after_model(can_jump_to=["model", "end"])
    async def async_after_with_jump(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        if state["messages"][-1].content == "retry":
            return {"jump_to": "model"}
        return None

    agent_async_after = create_agent(
        model=FakeToolCallingModel(), middleware=[async_after_with_jump]
    )

    assert agent_async_after.compile().get_graph().draw_mermaid() == snapshot

    # Test 3: Multiple async middleware with can_jump_to
    @before_model(can_jump_to=["end"])
    async def async_before_early_exit(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        return None

    @after_model(can_jump_to=["model"])
    async def async_after_retry(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        return None

    agent_multiple_async = create_agent(
        model=FakeToolCallingModel(),
        middleware=[async_before_early_exit, async_after_retry],
    )

    assert agent_multiple_async.compile().get_graph().draw_mermaid() == snapshot

    # Test 4: Mixed sync and async middleware with can_jump_to
    @before_model(can_jump_to=["end"])
    def sync_before_with_jump(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        return None

    @after_model(can_jump_to=["model", "end"])
    async def async_after_with_jumps(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        return None

    agent_mixed = create_agent(
        model=FakeToolCallingModel(),
        middleware=[sync_before_with_jump, async_after_with_jumps],
    )

    assert agent_mixed.compile().get_graph().draw_mermaid() == snapshot
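Taken together, the decorator tests above suggest the following usage pattern. This is a hedged sketch based only on the behavior asserted in this (removed) file; the exact exports may differ in the refactored package:

# Sketch: a before_model hook that can short-circuit the agent, as exercised above.
@before_model(can_jump_to=["end"])
def stop_when_long(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    # Jump straight to the end once the conversation grows past five messages.
    if len(state["messages"]) > 5:
        return {"jump_to": "end"}
    return None

agent = create_agent(model=FakeToolCallingModel(), middleware=[stop_when_long]).compile()
agent.invoke({"messages": [HumanMessage("Hello")]})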
@@ -1,58 +0,0 @@
from collections.abc import Callable
from typing import Union

import pytest
from pydantic import BaseModel
from syrupy.assertion import SnapshotAssertion

from langchain.agents import create_agent

from .model import FakeToolCallingModel

model = FakeToolCallingModel()


def tool() -> None:
    """Testing tool."""


def pre_model_hook() -> None:
    """Pre-model hook."""


def post_model_hook() -> None:
    """Post-model hook."""


class ResponseFormat(BaseModel):
    """Response format for the agent."""

    result: str


@pytest.mark.parametrize("tools", [[], [tool]])
@pytest.mark.parametrize("pre_model_hook", [None, pre_model_hook])
@pytest.mark.parametrize("post_model_hook", [None, post_model_hook])
def test_react_agent_graph_structure(
    snapshot: SnapshotAssertion,
    tools: list[Callable],
    pre_model_hook: Union[Callable, None],
    post_model_hook: Union[Callable, None],
) -> None:
    agent = create_agent(
        model,
        tools=tools,
        pre_model_hook=pre_model_hook,
        post_model_hook=post_model_hook,
    )
    try:
        assert agent.get_graph().draw_mermaid(with_styles=False) == snapshot
    except Exception as e:
        msg = (
            "The graph structure has changed. Please update the snapshot. "
            "Configuration used:\n"
            f"tools: {tools}, "
            f"pre_model_hook: {pre_model_hook}, "
            f"post_model_hook: {post_model_hook}, "
        )
        raise ValueError(msg) from e
@@ -7,6 +7,7 @@ from typing import Union
from langchain_core.messages import HumanMessage
from langchain.agents import create_agent
from langchain_core.tools import tool
from langchain.agents.structured_output import (
    MultipleStructuredOutputsError,
    ProviderStrategy,
@@ -74,12 +75,14 @@ location_json_schema = {
}


@tool
def get_weather() -> str:
    """Get the weather."""

    return "The weather is sunny and 75°F."


@tool
def get_location() -> str:
    """Get the current location."""

@@ -1,147 +0,0 @@
from __future__ import annotations

import pytest

# Skip this test since langgraph.prebuilt.responses is not available
pytest.skip("langgraph.prebuilt.responses not available", allow_module_level=True)

try:
    from langchain_openai import ChatOpenAI
except ImportError:
    skip_openai_integration_tests = True
else:
    skip_openai_integration_tests = False

AGENT_PROMPT = "You are an HR assistant."


class ToolCalls(BaseSchema):
    get_employee_role: int
    get_employee_department: int


class AssertionByInvocation(BaseSchema):
    prompt: str
    tools_with_expected_calls: ToolCalls
    expected_last_message: str
    expected_structured_response: Optional[Dict[str, Any]]
    llm_request_count: int


class TestCase(BaseSchema):
    name: str
    response_format: Union[Dict[str, Any], List[Dict[str, Any]]]
    assertions_by_invocation: List[AssertionByInvocation]


class Employee(BaseModel):
    name: str
    role: str
    department: str


EMPLOYEES: list[Employee] = [
    Employee(name="Sabine", role="Developer", department="IT"),
    Employee(name="Henrik", role="Product Manager", department="IT"),
    Employee(name="Jessica", role="HR", department="People"),
]

TEST_CASES = load_spec("responses", as_model=TestCase)


def _make_tool(fn, *, name: str, description: str):
    mock = MagicMock(side_effect=lambda *, name: fn(name=name))
    InputModel = create_model(f"{name}_input", name=(str, ...))

    @tool(name, description=description, args_schema=InputModel)
    def _wrapped(name: str):
        return mock(name=name)

    return {"tool": _wrapped, "mock": mock}


@pytest.mark.skipif(skip_openai_integration_tests, reason="OpenAI integration tests are disabled.")
@pytest.mark.parametrize("case", TEST_CASES, ids=[c.name for c in TEST_CASES])
def test_responses_integration_matrix(case: TestCase) -> None:
    if case.name == "asking for information that does not fit into the response format":
        pytest.xfail(
            "currently failing due to undefined behavior when model cannot conform to any of the structured response formats."
        )

    def get_employee_role(*, name: str) -> Optional[str]:
        for e in EMPLOYEES:
            if e.name == name:
                return e.role
        return None

    def get_employee_department(*, name: str) -> Optional[str]:
        for e in EMPLOYEES:
            if e.name == name:
                return e.department
        return None

    role_tool = _make_tool(
        get_employee_role,
        name="get_employee_role",
        description="Get the employee role by name",
    )
    dept_tool = _make_tool(
        get_employee_department,
        name="get_employee_department",
        description="Get the employee department by name",
    )

    response_format_spec = case.response_format
    if isinstance(response_format_spec, dict):
        response_format_spec = [response_format_spec]
    # Unwrap nested schema objects
    response_format_spec = [item.get("schema", item) for item in response_format_spec]
    if len(response_format_spec) == 1:
        tool_output = ToolStrategy(response_format_spec[0])
    else:
        tool_output = ToolStrategy({"oneOf": response_format_spec})

    llm_request_count = 0

    for assertion in case.assertions_by_invocation:

        def on_request(request: httpx.Request) -> None:
            nonlocal llm_request_count
            llm_request_count += 1

        http_client = httpx.Client(
            event_hooks={"request": [on_request]},
        )

        model = ChatOpenAI(
            model="gpt-4o",
            temperature=0,
            http_client=http_client,
        )

        agent = create_agent(
            model,
            tools=[role_tool["tool"], dept_tool["tool"]],
            prompt=AGENT_PROMPT,
            response_format=tool_output,
        )

        result = agent.invoke({"messages": [HumanMessage(assertion.prompt)]})

        # Count tool calls
        assert role_tool["mock"].call_count == assertion.tools_with_expected_calls.get_employee_role
        assert (
            dept_tool["mock"].call_count
            == assertion.tools_with_expected_calls.get_employee_department
        )

        # Count LLM calls
        assert llm_request_count == assertion.llm_request_count

        # Check last message content
        last_message = result["messages"][-1]
        assert last_message.content == assertion.expected_last_message

        # Check structured response
        structured_response_json = result["structured_response"]
        assert structured_response_json == assertion.expected_structured_response
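The matrix above drives the structured-output path through ToolStrategy built from raw JSON-schema dicts; as a rough sketch of what the fixture constructs for single- and multi-schema cases (the schemas themselves are illustrative assumptions):

# Sketch: wrapping one or several JSON schemas in a ToolStrategy, mirroring the branch above.
single = ToolStrategy({"type": "object", "properties": {"result": {"type": "string"}}})
either = ToolStrategy(
    {
        "oneOf": [
            {"type": "object", "properties": {"role": {"type": "string"}}},
            {"type": "object", "properties": {"department": {"type": "string"}}},
        ]
    }
)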
@@ -1,107 +0,0 @@
from __future__ import annotations

import pytest

# Skip this test since langgraph.prebuilt.responses is not available
pytest.skip("langgraph.prebuilt.responses not available", allow_module_level=True)

try:
    from langchain_openai import ChatOpenAI
except ImportError:
    skip_openai_integration_tests = True
else:
    skip_openai_integration_tests = False

AGENT_PROMPT = """
You are a strict polling bot.

- Only use the "poll_job" tool until it returns { status: "succeeded" }.
- If status is "pending", call the tool again. Do not produce a final answer.
- When it is "succeeded", return exactly: "Attempts: <number>" with no extra text.
"""


class TestCase(BaseSchema):
    name: str
    return_direct: bool
    response_format: Optional[Dict[str, Any]]
    expected_tool_calls: int
    expected_last_message: str
    expected_structured_response: Optional[Dict[str, Any]]


TEST_CASES = load_spec("return_direct", as_model=TestCase)


def _make_tool(return_direct: bool):
    attempts = 0

    def _side_effect():
        nonlocal attempts
        attempts += 1
        return {
            "status": "succeeded" if attempts >= 10 else "pending",
            "attempts": attempts,
        }

    mock = MagicMock(side_effect=_side_effect)

    @tool(
        "pollJob",
        description=(
            "Check the status of a long-running job. "
            "Returns { status: 'pending' | 'succeeded', attempts: number }."
        ),
        return_direct=return_direct,
    )
    def _wrapped():
        return mock()

    return {"tool": _wrapped, "mock": mock}


@pytest.mark.skipif(skip_openai_integration_tests, reason="OpenAI integration tests are disabled.")
@pytest.mark.parametrize("case", TEST_CASES, ids=[c.name for c in TEST_CASES])
def test_return_direct_integration_matrix(case: TestCase) -> None:
    poll_tool = _make_tool(case.return_direct)

    model = ChatOpenAI(
        model="gpt-4o",
        temperature=0,
    )

    if case.response_format:
        agent = create_agent(
            model,
            tools=[poll_tool["tool"]],
            prompt=AGENT_PROMPT,
            response_format=ToolStrategy(case.response_format),
        )
    else:
        agent = create_agent(
            model,
            tools=[poll_tool["tool"]],
            prompt=AGENT_PROMPT,
        )

    result = agent.invoke(
        {
            "messages": [
                HumanMessage("Poll the job until it's done and tell me how many attempts it took.")
            ]
        }
    )

    # Count tool calls
    assert poll_tool["mock"].call_count == case.expected_tool_calls

    # Check last message content
    last_message = result["messages"][-1]
    assert last_message.content == case.expected_last_message

    # Check structured response
    if case.expected_structured_response is not None:
        structured_response_json = result["structured_response"]
        assert structured_response_json == case.expected_structured_response
    else:
        assert "structured_response" not in result
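The polling matrix above hinges on return_direct: when it is set, the agent stops looping as soon as the tool returns. A minimal sketch of the tool shape the fixture builds (tool name and payload copied from the fixture, the rest illustrative):

# Sketch: a return_direct tool ends the agent loop with its own output.
@tool("pollJob", description="Check the status of a long-running job.", return_direct=True)
def poll_job() -> dict:
    return {"status": "pending", "attempts": 1}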
328 libs/langchain_v1/tests/unit_tests/agents/test_specifications.py (new file)
@@ -0,0 +1,328 @@
|
||||
"""Combined tests for agent specifications."""
|
||||
|
||||
import pytest
|
||||
|
||||
# Skip these tests since langgraph.prebuilt.responses is not available
|
||||
pytest.skip("langgraph.prebuilt.responses not available", allow_module_level=True)
|
||||
|
||||
try:
|
||||
from langchain_openai import ChatOpenAI
|
||||
except ImportError:
|
||||
skip_openai_integration_tests = True
|
||||
else:
|
||||
skip_openai_integration_tests = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.structured_output import ToolStrategy
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Import specification loading utilities
|
||||
try:
|
||||
from .specifications.responses import load_spec
|
||||
from .specifications.return_direct import load_spec as load_return_direct_spec
|
||||
except ImportError:
|
||||
# Fallback if specifications are not available
|
||||
def load_spec(name, as_model=None):
|
||||
return []
|
||||
|
||||
def load_return_direct_spec(name, as_model=None):
|
||||
return []
|
||||
|
||||
|
||||
# Test data models for responses specification
|
||||
class ToolCalls(BaseModel):
|
||||
get_employee_role: int
|
||||
get_employee_department: int
|
||||
|
||||
|
||||
class AssertionByInvocation(BaseModel):
|
||||
prompt: str
|
||||
tools_with_expected_calls: ToolCalls
|
||||
expected_last_message: str
|
||||
expected_structured_response: Optional[Dict[str, Any]]
|
||||
llm_request_count: int
|
||||
|
||||
|
||||
class TestCase(BaseModel):
|
||||
name: str
|
||||
response_format: Union[Dict[str, Any], List[Dict[str, Any]]]
|
||||
assertions_by_invocation: List[AssertionByInvocation]
|
||||
|
||||
|
||||
class Employee(BaseModel):
|
||||
name: str
|
||||
role: str
|
||||
department: str
|
||||
|
||||
|
||||
EMPLOYEES: list[Employee] = [
|
||||
Employee(name="Sabine", role="Developer", department="IT"),
|
||||
Employee(name="Henrik", role="Product Manager", department="IT"),
|
||||
Employee(name="Jessica", role="HR", department="People"),
|
||||
]
|
||||
|
||||
TEST_CASES = load_spec("responses", as_model=TestCase)
|
||||
|
||||
|
||||
# Test data models for return_direct specification
|
||||
class ReturnDirectTestCase(BaseModel):
|
||||
name: str
|
||||
return_direct: bool
|
||||
response_format: Optional[Dict[str, Any]]
|
||||
expected_tool_calls: int
|
||||
expected_last_message: str
|
||||
expected_structured_response: Optional[Dict[str, Any]]
|
||||
|
||||
|
||||
RETURN_DIRECT_TEST_CASES = load_return_direct_spec("return_direct", as_model=ReturnDirectTestCase)
|
||||
|
||||
|
||||
# Test tools
|
||||
@tool
|
||||
def get_employee_role(employee_name: str) -> str:
|
||||
"""Get the role of an employee."""
|
||||
for emp in EMPLOYEES:
|
||||
if emp.name == employee_name:
|
||||
return emp.role
|
||||
return "Employee not found"
|
||||
|
||||
|
||||
@tool
|
||||
def get_employee_department(employee_name: str) -> str:
|
||||
"""Get the department of an employee."""
|
||||
for emp in EMPLOYEES:
|
||||
if emp.name == employee_name:
|
||||
return emp.department
|
||||
return "Employee not found"
|
||||
|
||||
|
||||
@tool
|
||||
def poll_job() -> Dict[str, Any]:
|
||||
"""Poll a job status."""
|
||||
# This will be mocked in tests
|
||||
return {"status": "pending", "attempts": 1}
|
||||
|
||||
|
||||
# Responses specification tests
|
||||
class TestResponsesSpecification:
|
||||
"""Test responses specification functionality."""
|
||||
|
||||
def test_responses_specification_loading(self) -> None:
|
||||
"""Test that responses specification can be loaded."""
|
||||
assert isinstance(TEST_CASES, list)
|
||||
# If specifications are available, we should have test cases
|
||||
# If not, the list will be empty due to the fallback
|
||||
|
||||
@pytest.mark.skipif(skip_openai_integration_tests, reason="OpenAI not available")
|
||||
def test_responses_specification_with_openai(self) -> None:
|
||||
"""Test responses specification with OpenAI model."""
|
||||
if not TEST_CASES:
|
||||
pytest.skip("No test cases available")
|
||||
|
||||
# This would run the actual specification tests if available
|
||||
# For now, just verify the structure
|
||||
for test_case in TEST_CASES:
|
||||
assert hasattr(test_case, "name")
|
||||
assert hasattr(test_case, "response_format")
|
||||
assert hasattr(test_case, "assertions_by_invocation")
|
||||
|
||||
|
||||
# Return direct specification tests
|
||||
class TestReturnDirectSpecification:
|
||||
"""Test return direct specification functionality."""
|
||||
|
||||
def test_return_direct_specification_loading(self) -> None:
|
||||
"""Test that return direct specification can be loaded."""
|
||||
assert isinstance(RETURN_DIRECT_TEST_CASES, list)
|
||||
|
||||
@pytest.mark.skipif(skip_openai_integration_tests, reason="OpenAI not available")
|
||||
def test_return_direct_specification_with_openai(self) -> None:
|
||||
"""Test return direct specification with OpenAI model."""
|
||||
if not RETURN_DIRECT_TEST_CASES:
|
||||
pytest.skip("No test cases available")
|
||||
|
||||
# This would run the actual specification tests if available
|
||||
# For now, just verify the structure
|
||||
for test_case in RETURN_DIRECT_TEST_CASES:
|
||||
assert hasattr(test_case, "name")
|
||||
assert hasattr(test_case, "return_direct")
|
||||
assert hasattr(test_case, "expected_tool_calls")
|
||||
|
||||
|
||||
# Tool strategy tests
|
||||
class TestToolStrategy:
|
||||
"""Test ToolStrategy functionality."""
|
||||
|
||||
def test_tool_strategy_basic_creation(self) -> None:
|
||||
"""Test basic ToolStrategy creation."""
|
||||
strategy = ToolStrategy(schema=Employee)
|
||||
assert strategy.schema == Employee
|
||||
assert strategy.tool_message_content is None
|
||||
assert len(strategy.schema_specs) == 1
|
||||
|
||||
def test_tool_strategy_with_tool_message_content(self) -> None:
|
||||
"""Test ToolStrategy with tool message content."""
|
||||
strategy = ToolStrategy(schema=Employee, tool_message_content="custom message")
|
||||
assert strategy.schema == Employee
|
||||
assert strategy.tool_message_content == "custom message"
|
||||
|
||||
def test_tool_strategy_with_union_schema(self) -> None:
|
||||
"""Test ToolStrategy with Union schema."""
|
||||
|
||||
class CustomModel(BaseModel):
|
||||
value: float
|
||||
description: str
|
||||
|
||||
strategy = ToolStrategy(schema=Union[Employee, CustomModel])
|
||||
assert len(strategy.schema_specs) == 2
|
||||
assert strategy.schema_specs[0].schema == Employee
|
||||
assert strategy.schema_specs[1].schema == CustomModel
|
||||
|
||||
|
||||
# Agent with specifications tests
|
||||
class TestAgentWithSpecifications:
|
||||
"""Test agents with various specifications."""
|
||||
|
||||
def test_agent_with_employee_schema(self) -> None:
|
||||
"""Test agent with employee schema."""
|
||||
|
||||
# Mock model for testing
|
||||
class MockModel:
|
||||
def invoke(self, messages, **kwargs):
|
||||
return HumanMessage(content="Mock response")
|
||||
|
||||
agent = create_agent(
|
||||
model=MockModel(),
|
||||
tools=[get_employee_role, get_employee_department],
|
||||
response_format=ToolStrategy(schema=Employee),
|
||||
)
|
||||
|
||||
# Test that agent can be created
|
||||
assert agent is not None
|
||||
|
||||
def test_agent_with_polling_tool(self) -> None:
|
||||
"""Test agent with polling tool."""
|
||||
# Mock the polling tool
|
||||
mock_poll = MagicMock()
|
||||
mock_poll.side_effect = [
|
||||
{"status": "pending", "attempts": 1},
|
||||
{"status": "pending", "attempts": 2},
|
||||
{"status": "succeeded", "attempts": 3},
|
||||
]
|
||||
|
||||
@tool
|
||||
def mock_poll_job() -> Dict[str, Any]:
|
||||
"""Mock polling tool."""
|
||||
return mock_poll()
|
||||
|
||||
class MockModel:
|
||||
def invoke(self, messages, **kwargs):
|
||||
return HumanMessage(content="Mock response")
|
||||
|
||||
agent = create_agent(
|
||||
model=MockModel(),
|
||||
tools=[mock_poll_job],
|
||||
)
|
||||
|
||||
# Test that agent can be created
|
||||
assert agent is not None
|
||||
|
||||
def test_agent_with_return_direct_tool(self) -> None:
|
||||
"""Test agent with return_direct tool."""
|
||||
|
||||
@tool
|
||||
def return_direct_tool(input: str) -> str:
|
||||
"""Tool that returns directly."""
|
||||
return f"Direct result: {input}"
|
||||
|
||||
class MockModel:
|
||||
def invoke(self, messages, **kwargs):
|
||||
return HumanMessage(content="Mock response")
|
||||
|
||||
agent = create_agent(
|
||||
model=MockModel(),
|
||||
tools=[return_direct_tool],
|
||||
)
|
||||
|
||||
# Test that agent can be created
|
||||
assert agent is not None
|
||||
|
||||
|
||||
# Specification validation tests
|
||||
class TestSpecificationValidation:
|
||||
"""Test specification validation."""
|
||||
|
||||
def test_employee_schema_validation(self) -> None:
|
||||
"""Test employee schema validation."""
|
||||
# Valid employee
|
||||
emp = Employee(name="Test", role="Developer", department="IT")
|
||||
assert emp.name == "Test"
|
||||
assert emp.role == "Developer"
|
||||
assert emp.department == "IT"
|
||||
|
||||
# Invalid employee (missing required fields)
|
||||
with pytest.raises(Exception):
|
||||
Employee(name="Test") # Missing role and department
|
||||
|
||||
def test_tool_calls_schema_validation(self) -> None:
|
||||
"""Test tool calls schema validation."""
|
||||
tool_calls = ToolCalls(get_employee_role=1, get_employee_department=2)
|
||||
assert tool_calls.get_employee_role == 1
|
||||
assert tool_calls.get_employee_department == 2
|
||||
|
||||
def test_assertion_schema_validation(self) -> None:
|
||||
"""Test assertion schema validation."""
|
||||
tool_calls = ToolCalls(get_employee_role=1, get_employee_department=2)
|
||||
assertion = AssertionByInvocation(
|
||||
prompt="Test prompt",
|
||||
tools_with_expected_calls=tool_calls,
|
||||
expected_last_message="Expected message",
|
||||
expected_structured_response={"key": "value"},
|
||||
llm_request_count=1,
|
||||
)
|
||||
assert assertion.prompt == "Test prompt"
|
||||
assert assertion.llm_request_count == 1
|
||||
|
||||
|
||||
# Integration tests (when specifications are available)
|
||||
class TestSpecificationIntegration:
|
||||
"""Test specification integration."""
|
||||
|
||||
def test_specification_file_loading(self) -> None:
|
||||
"""Test that specification files can be loaded."""
|
||||
# This test verifies that the specification loading mechanism works
|
||||
# even if the actual specification files are not available
|
||||
try:
|
||||
responses_spec = load_spec("responses")
|
||||
return_direct_spec = load_return_direct_spec("return_direct")
|
||||
assert isinstance(responses_spec, list)
|
||||
assert isinstance(return_direct_spec, list)
|
||||
except Exception:
|
||||
# If specifications are not available, that's okay for this test
|
||||
pass
|
||||
|
||||
def test_specification_with_mock_data(self) -> None:
|
||||
"""Test specifications with mock data."""
|
||||
# Create mock test cases
|
||||
mock_tool_calls = ToolCalls(get_employee_role=1, get_employee_department=1)
|
||||
mock_assertion = AssertionByInvocation(
|
||||
prompt="Test prompt",
|
||||
tools_with_expected_calls=mock_tool_calls,
|
||||
expected_last_message="Expected message",
|
||||
expected_structured_response=None,
|
||||
llm_request_count=1,
|
||||
)
|
||||
mock_test_case = TestCase(
|
||||
name="Mock test",
|
||||
response_format={"type": "json_object"},
|
||||
assertions_by_invocation=[mock_assertion],
|
||||
)
|
||||
|
||||
# Verify mock data structure
|
||||
assert mock_test_case.name == "Mock test"
|
||||
assert len(mock_test_case.assertions_by_invocation) == 1
|
||||
assert mock_test_case.assertions_by_invocation[0].prompt == "Test prompt"
|
||||
@@ -7,7 +7,6 @@ from typing import (
    Any,
    NoReturn,
    TypeVar,
    Union,
)

import pytest
@@ -33,14 +32,13 @@ from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import TypedDict

from langchain.tools import (
    ToolNode,
    InjectedState,
    InjectedStore,
    ToolNode,
)
from langchain.tools.tool_node import TOOL_CALL_ERROR_TEMPLATE, ToolInvocationError, tools_condition

from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
from .model import FakeToolCallingModel
from tests.unit_tests.agents.messages import _AnyIdHumanMessage, _AnyIdToolMessage
from tests.unit_tests.agents.model import FakeToolCallingModel

pytestmark = pytest.mark.anyio
@@ -1374,7 +1372,7 @@ def test_tool_node_inject_store() -> None:
def test_tool_node_ensure_utf8() -> None:
    @dec_tool
    def get_day_list(days: list[str]) -> list[str]:
        """choose days"""
        """Choose days"""
        return days

    data = ["星期一", "水曜日", "목요일", "Friday"]