Compare commits


12 Commits

Author SHA1 Message Date
Sydney Runkle  9f6a29c44d  minimal changes  2025-10-03 12:11:05 -04:00
Sydney Runkle  ec1bd054f1  tool node fixes  2025-10-03 12:09:19 -04:00
Sydney Runkle  f0139330f7  refactoring tests  2025-10-03 11:59:22 -04:00
Sydney Runkle  984e1e984f  Merge branch 'master' into sr/create-agent-api  2025-10-03 11:15:26 -04:00
Sydney Runkle  e529445b29  no printing  2025-10-03 11:12:57 -04:00
Sydney Runkle  5b972fb0e9  fixes  2025-10-03 11:11:52 -04:00
Sydney Runkle  b8a20329d8  fixing up tests  2025-10-03 10:39:48 -04:00
Sydney Runkle  54e507f331  more fixes  2025-10-03 10:32:22 -04:00
Sydney Runkle  6001543093  fix tests + support builtins  2025-10-03 10:31:22 -04:00
Sydney Runkle  894ffa0be5  minor fixes  2025-10-03 10:04:47 -04:00
Sydney Runkle  56c93fa82f  adding back some tests  2025-10-03 09:31:42 -04:00
Sydney Runkle  3743d596c7  beginnings of a refactor  2025-10-03 09:06:47 -04:00
27 changed files with 3043 additions and 6047 deletions

View File

@@ -1,6 +1,7 @@
"""langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools."""
from langchain.agents.react_agent import AgentState, create_agent
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentState
__all__ = [
"AgentState",

View File

@@ -1,4 +1,4 @@
"""Middleware agent implementation."""
"""Agent factory implementation."""
import itertools
from collections.abc import Callable, Sequence
@@ -9,9 +9,10 @@ from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMe
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langgraph.constants import END, START
from langgraph.graph.state import StateGraph
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.runtime import Runtime
from langgraph.types import Send
from langgraph.store.base import BaseStore
from langgraph.types import Checkpointer, Send
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
@@ -167,17 +168,142 @@ def _handle_structured_output_error(
def create_agent( # noqa: PLR0915
*,
model: str | BaseChatModel,
tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
*,
system_prompt: str | None = None,
middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (),
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
context_schema: type[ContextT] | None = None,
) -> StateGraph[
checkpointer: Checkpointer | None = None,
store: BaseStore | None = None,
interrupt_before: list[str] | None = None,
interrupt_after: list[str] | None = None,
debug: bool = False,
cache: Any = None,
name: str | None = None,
) -> CompiledStateGraph[
AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
]:
"""Create a middleware agent graph."""
"""Creates an agent graph that calls tools in a loop until a stopping condition is met.
For more details on using `create_agent`,
see the [Agents](https://langchain-ai.github.io/langgraph/agents/overview/) documentation.
Args:
model: The language model for the agent. Supports static model
selection with string identifier (e.g., `"openai:gpt-4"`) or
chat model instance (e.g., `ChatOpenAI()`).
tools: A list of tools or a ToolNode instance.
If an empty list is provided, the agent will consist of a single LLM node
without tool calling.
system_prompt: An optional system prompt for the LLM. This is converted to a
SystemMessage and added to the beginning of the list of messages in
state["messages"].
middleware: An optional sequence of AgentMiddleware instances for customizing
agent behavior. Middleware can:
- **Intercept execution** with before_model/after_model hooks
- **Modify model requests** with modify_model_request
- **Add custom state** via state_schema
- **Control flow** with jump_to (using @hook_config decorator)
See AgentMiddleware documentation for implementation details.
response_format: An optional configuration for structured responses.
If provided, the agent will handle structured output via tool calls
during the normal conversation flow.
When the model calls a structured output tool, the response will be captured
and returned in the 'structured_response' state key.
If not provided, `structured_response` will not be present in the output state.
Can be one of:
- **Pydantic model class**: Automatically converted to ToolStrategy or
ProviderStrategy based on model capabilities
- **ToolStrategy**: Uses tool calling for structured output with error handling
- **ProviderStrategy**: Uses provider-native structured output (OpenAI, etc.)
!!! important
`response_format` requires the model to support tool calling
!!! note
Structured responses are handled directly in the model call node via
tool calls, eliminating the need for separate structured response nodes.
context_schema: An optional schema for runtime context.
checkpointer: An optional checkpoint saver object. This is used for persisting
the state of the graph (e.g., as chat memory) for a single thread
(e.g., a single conversation).
store: An optional store object. This is used for persisting data
across multiple threads (e.g., multiple conversations / users).
interrupt_before: An optional list of node names to interrupt before.
Should be one of the following: "model_request", "tools".
This is useful if you want to add a user confirmation or other interrupt
before taking an action.
interrupt_after: An optional list of node names to interrupt after.
Should be one of the following: "model_request", "tools".
This is useful if you want to return directly or run additional processing on an output.
debug: A flag indicating whether to enable debug mode.
cache: An optional cache object for caching LLM responses.
name: An optional name for the compiled graph.
Returns:
A compiled LangGraph agent that can be used for chat interactions.
The "model_request" node calls the language model with the messages list
(after applying the system prompt if provided).
If the resulting AIMessage contains `tool_calls`,
the graph will then call the ["tools"][langgraph.prebuilt.tool_node.ToolNode].
The "tools" node executes the tools (1 tool per `tool_call`)
and adds the responses to the messages list as `ToolMessage` objects.
The model_request node then calls the language model again.
The process repeats until no more `tool_calls` are present in the response.
The agent then returns the full list of messages as a dictionary containing the key "messages".
``` mermaid
sequenceDiagram
participant U as User
participant A as LLM
participant T as Tools
U->>A: Initial input
Note over A: System Prompt + LLM
loop while tool_calls present
A->>T: Execute tools
T-->>A: ToolMessage for each tool_call
end
A->>U: Return final state
```
Example:
```python
from langchain.agents import create_agent
def check_weather(location: str) -> str:
'''Return the weather forecast for the specified location.'''
return f"It's always sunny in {location}"
graph = create_agent(
model="anthropic:claude-3-7-sonnet-latest",
tools=[check_weather],
system_prompt="You are a helpful assistant",
)
inputs = {"messages": [{"role": "user", "content": "what is the weather in sf"}]}
for chunk in graph.stream(inputs, stream_mode="updates"):
print(chunk)
```
"""
# init chat model
if isinstance(model, str):
model = init_chat_model(model)
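The `response_format` options documented above accept a plain Pydantic model class; a minimal sketch of that path, where the `Weather` schema and the `"openai:gpt-4"` identifier are placeholders rather than anything from this diff:

```python
# Sketch only: structured output via a Pydantic model class, which the
# docstring above says is converted to a ToolStrategy/ProviderStrategy.
from pydantic import BaseModel

from langchain.agents import create_agent


class Weather(BaseModel):
    temperature_c: float
    conditions: str


agent = create_agent(
    model="openai:gpt-4",     # any tool-calling model identifier
    response_format=Weather,  # structured result captured in state
)

result = agent.invoke({"messages": [{"role": "user", "content": "Weather in SF?"}]})
print(result["structured_response"])  # expected to be a Weather instance
```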
@@ -368,26 +494,8 @@ def create_agent( # noqa: PLR0915
def _get_bound_model(request: ModelRequest) -> Runnable:
"""Get the model with appropriate tool bindings."""
# Get actual tool objects from tool names
tools_by_name = {t.name: t for t in default_tools}
unknown_tools = [name for name in request.tools if name not in tools_by_name]
if unknown_tools:
available_tools = sorted(tools_by_name.keys())
msg = (
f"Middleware returned unknown tool names: {unknown_tools}\n\n"
f"Available tools: {available_tools}\n\n"
"To fix this issue:\n"
"1. Ensure the tools are passed to create_agent() via "
"the 'tools' parameter\n"
"2. If using custom middleware with tools, ensure "
"they're registered via middleware.tools attribute\n"
"3. Verify that tool names in ModelRequest.tools match "
"the actual tool.name values"
)
raise ValueError(msg)
requested_tools = [tools_by_name[name] for name in request.tools]
# request.tools contains BaseTool | dict objects
requested_tools = request.tools
if isinstance(response_format, ProviderStrategy):
# Use native structured output
@@ -411,7 +519,7 @@ def create_agent( # noqa: PLR0915
"""Sync model request handler with sequential middleware processing."""
request = ModelRequest(
model=model,
tools=[t.name for t in default_tools],
tools=default_tools,
system_prompt=system_prompt,
response_format=response_format,
messages=state["messages"],
@@ -438,17 +546,19 @@ def create_agent( # noqa: PLR0915
messages = [SystemMessage(request.system_prompt), *messages]
output = model_.invoke(messages)
return {
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
**_handle_model_output(output),
}
result = _handle_model_output(output)
# Always add call counts
result["thread_model_call_count"] = state.get("thread_model_call_count", 0) + 1
result["run_model_call_count"] = state.get("run_model_call_count", 0) + 1
return result
async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Async model request handler with sequential middleware processing."""
request = ModelRequest(
model=model,
tools=[t.name for t in default_tools],
tools=default_tools,
system_prompt=system_prompt,
response_format=response_format,
messages=state["messages"],
@@ -466,11 +576,13 @@ def create_agent( # noqa: PLR0915
messages = [SystemMessage(request.system_prompt), *messages]
output = await model_.ainvoke(messages)
return {
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
**_handle_model_output(output),
}
result = _handle_model_output(output)
# Always add call counts
result["thread_model_call_count"] = state.get("thread_model_call_count", 0) + 1
result["run_model_call_count"] = state.get("run_model_call_count", 0) + 1
return result
# Use sync or async based on model capabilities
from langgraph._internal._runnable import RunnableCallable
@@ -595,7 +707,15 @@ def create_agent( # noqa: PLR0915
can_jump_to=_get_can_jump_to(m1, "after_model"),
)
return graph
return graph.compile(
checkpointer=checkpointer,
store=store,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
debug=debug,
cache=cache,
name=name,
)
def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
@@ -626,8 +746,8 @@ def _fetch_last_ai_and_tool_messages(
def _make_model_to_tools_edge(
first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
) -> Callable[[dict[str, Any]], str | list[Send] | None]:
def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None:
) -> Callable[[dict[str, Any], Runtime], str | list[Send] | None]:
def model_to_tools(state: dict[str, Any], runtime: Runtime) -> str | list[Send] | None:
# 1. if there's an explicit jump_to in the state, use it
if jump_to := state.get("jump_to"):
return _resolve_jump(jump_to, first_node)
@@ -646,14 +766,19 @@ def _make_model_to_tools_edge(
if c["id"] not in tool_message_ids and c["name"] not in structured_output_tools
]
# 3. if there are pending tool calls, jump to the tool node
# 3. if there are pending (non-structured) tool calls, jump to the tool node
if pending_tool_calls:
pending_tool_calls = [
tool_node.inject_tool_args(call, state, None) for call in pending_tool_calls
tool_node.inject_tool_args(call, state, runtime.store)
for call in pending_tool_calls
]
return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]
# 4. AIMessage has tool calls, but there are no pending tool calls
# 4. if we have a structured response and no pending tool calls, we're done
if "structured_response" in state:
return END
# 5. AIMessage has tool calls, but there are no pending tool calls
# which suggests the injection of artificial tool messages. jump to the first node
return first_node
@@ -666,7 +791,8 @@ def _make_tools_to_model_edge(
def tools_to_model(state: dict[str, Any]) -> str | None:
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
if all(
# Check if any tool call has return_direct=True
if any(
tool_node.tools_by_name[c["name"]].return_direct
for c in last_ai_message.tool_calls
if c["name"] in tool_node.tools_by_name
@@ -692,7 +818,6 @@ def _add_middleware_edge(
Args:
graph: The graph to add the edge to.
method: The method to call for the middleware node.
name: The name of the middleware node.
default_destination: The default destination for the edge.
model_destination: The destination for the edge to the model.
@@ -716,3 +841,6 @@ def _add_middleware_edge(
else:
graph.add_edge(name, default_destination)
__all__ = ["create_agent"]
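Since `create_agent` now returns a compiled graph, the persistence and interrupt parameters documented in the docstring are forwarded straight to `graph.compile`. A minimal sketch of how they combine; the thread id and the resume-by-invoking-with-`None` step follow the usual LangGraph checkpointing pattern and are assumptions here, not code from this diff:

```python
# Sketch only: thread-level persistence plus a pause before tool execution.
from langchain.agents import create_agent
from langgraph.checkpoint.memory import InMemorySaver


def check_weather(location: str) -> str:
    """Return the weather forecast for the specified location."""
    return f"It's always sunny in {location}"


agent = create_agent(
    model="anthropic:claude-3-7-sonnet-latest",
    tools=[check_weather],
    checkpointer=InMemorySaver(),  # persists state per thread
    interrupt_before=["tools"],    # pause before the "tools" node
)

config = {"configurable": {"thread_id": "thread-1"}}
agent.invoke({"messages": [{"role": "user", "content": "what is the weather in sf"}]}, config)
# The run stops before the "tools" node; invoking again with None on the same
# thread resumes from the saved checkpoint.
agent.invoke(None, config)
```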

View File

@@ -62,7 +62,7 @@ class ModelRequest:
system_prompt: str | None
messages: list[AnyMessage] # excluding system prompt
tool_choice: Any | None
tools: list[str]
tools: list[BaseTool | dict[str, Any]]
response_format: ResponseFormat | None
model_settings: dict[str, Any] = field(default_factory=dict)
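With `ModelRequest.tools` now carrying the tool objects themselves instead of names, middleware can inspect or filter the bound tools directly. A minimal sketch using the `modify_model_request` decorator; the allow-list and the `check_weather` name are made-up examples:

```python
# Sketch only: drop every tool except an allow-listed one before the model call.
from langchain.agents.middleware.types import AgentState, ModelRequest, modify_model_request
from langgraph.runtime import Runtime


@modify_model_request
def only_allowed_tools(request: ModelRequest, state: AgentState, runtime: Runtime) -> ModelRequest:
    allowed = {"check_weather"}  # hypothetical allow-list
    request.tools = [
        t
        for t in request.tools
        # keep provider-native dict tools untouched; filter BaseTool objects by name
        if isinstance(t, dict) or t.name in allowed
    ]
    return request
```

Because the request already holds the tool objects, the middleware needs no name-to-object lookup, which is why the unknown-tool error path above could be removed.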

File diff suppressed because it is too large

View File

@@ -85,14 +85,14 @@ langchain-text-splitters = { path = "../text-splitters", editable = true }
langchain-openai = { path = "../partners/openai", editable = true }
[tool.ruff]
exclude = ["tests/integration_tests/examples/non-utf8-encoding.py"]
exclude = ["tests/integration_tests/examples/non-utf8-encoding.py", "tests/unit_tests/tools/test_tool_node.py"]
line-length = 100
[tool.mypy]
strict = true
ignore_missing_imports = true
enable_error_code = "deprecated"
exclude = ["tests/unit_tests/agents/*", "tests/integration_tests/agents/*"]
exclude = ["tests/unit_tests/agents/*", "tests/integration_tests/agents/*", "tests/unit_tests/tools/test_tool_node.py"]
# TODO: activate for 'strict' checking
disallow_any_generics = false

View File

@@ -0,0 +1 @@
"""Middleware tests for agents."""

View File

@@ -0,0 +1,230 @@
"""Tests for ModelCallLimitMiddleware."""
import pytest
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain.agents.middleware.call_tracking import (
ModelCallLimitMiddleware,
ModelCallLimitExceededError,
)
from langchain.agents.factory import create_agent
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from ..model import FakeToolCallingModel
@tool
def simple_tool(input: str) -> str:
"""A simple tool"""
return input
def test_middleware_unit_functionality():
"""Test that the middleware works as expected in isolation."""
# Test with end behavior
middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1)
# Mock runtime (not used in current implementation)
runtime = None
# Test when limits are not exceeded
state = {"thread_model_call_count": 0, "run_model_call_count": 0}
result = middleware.before_model(state, runtime)
assert result is None
# Test when thread limit is exceeded
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
result = middleware.before_model(state, runtime)
assert result is not None
assert result["jump_to"] == "end"
assert "messages" in result
assert len(result["messages"]) == 1
assert "thread limit (2/2)" in result["messages"][0].content
# Test when run limit is exceeded
state = {"thread_model_call_count": 1, "run_model_call_count": 1}
result = middleware.before_model(state, runtime)
assert result is not None
assert result["jump_to"] == "end"
assert "messages" in result
assert len(result["messages"]) == 1
assert "run limit (1/1)" in result["messages"][0].content
# Test with error behavior
middleware_exception = ModelCallLimitMiddleware(
thread_limit=2, run_limit=1, exit_behavior="error"
)
# Test exception when thread limit exceeded
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware_exception.before_model(state, runtime)
assert "thread limit (2/2)" in str(exc_info.value)
# Test exception when run limit exceeded
state = {"thread_model_call_count": 1, "run_model_call_count": 1}
with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware_exception.before_model(state, runtime)
assert "run limit (1/1)" in str(exc_info.value)
def test_thread_limit_with_create_agent():
"""Test that thread limits work correctly with create_agent."""
# Create a model that will make 2 calls
model = FakeToolCallingModel(
tool_calls=[
[{"name": "simple_tool", "args": {"input": "test"}, "id": "1"}],
[], # No tool calls on second call
]
)
# Set thread limit to 1 (should be exceeded after 1 call)
agent = create_agent(
model=model,
tools=[simple_tool],
middleware=[ModelCallLimitMiddleware(thread_limit=1)],
checkpointer=InMemorySaver(),
)
# This should hit the thread limit after the first model call
result = agent.invoke(
{"messages": [HumanMessage("Hello")]}, {"configurable": {"thread_id": "thread1"}}
)
assert "messages" in result
# The agent should have made 1 model call then jumped to end with limit exceeded message
# So we should have: Human + AI + Tool + Limit exceeded AI message
assert len(result["messages"]) == 4 # Human + AI + Tool + Limit AI
assert isinstance(result["messages"][0], HumanMessage)
assert isinstance(result["messages"][1], AIMessage)
assert isinstance(result["messages"][2], ToolMessage)
# Limit exceeded message
assert "thread limit" in result["messages"][3].content
# Second invocation should also hit the limit
result2 = agent.invoke(
{"messages": [HumanMessage("Hello again")]}, {"configurable": {"thread_id": "thread1"}}
)
assert len(result2["messages"]) == 2 # Human + Limit exceeded AI
assert isinstance(result2["messages"][0], HumanMessage)
# Limit exceeded message
assert "thread limit" in result2["messages"][1].content
def test_run_limit_with_create_agent():
"""Test that run limits work correctly with create_agent."""
# Create a model that will make 2 calls
model = FakeToolCallingModel(
tool_calls=[
[{"name": "simple_tool", "args": {"input": "test"}, "id": "1"}],
[], # No tool calls on second call
]
)
# Set run limit to 1 (should be exceeded after 1 call)
agent = create_agent(
model=model,
tools=[simple_tool],
middleware=[ModelCallLimitMiddleware(run_limit=1)],
checkpointer=InMemorySaver(),
)
# This should hit the run limit after the first model call
result = agent.invoke(
{"messages": [HumanMessage("Hello")]}, {"configurable": {"thread_id": "thread1"}}
)
assert "messages" in result
# The agent should have made 1 model call then jumped to end with limit exceeded message
# So we should have: Human + AI + Tool + Limit exceeded AI message
assert len(result["messages"]) == 4 # Human + AI + Tool + Limit AI
assert isinstance(result["messages"][0], HumanMessage)
assert isinstance(result["messages"][1], AIMessage)
assert isinstance(result["messages"][2], ToolMessage)
# Limit exceeded message
assert "run limit" in result["messages"][3].content
def test_middleware_initialization_validation():
"""Test that middleware initialization validates parameters correctly."""
# Test that at least one limit must be specified
with pytest.raises(ValueError, match="At least one limit must be specified"):
ModelCallLimitMiddleware()
# Test invalid exit behavior
with pytest.raises(ValueError, match="Invalid exit_behavior"):
ModelCallLimitMiddleware(thread_limit=5, exit_behavior="invalid")
# Test valid initialization
middleware = ModelCallLimitMiddleware(thread_limit=5, run_limit=3)
assert middleware.thread_limit == 5
assert middleware.run_limit == 3
assert middleware.exit_behavior == "end"
# Test with only thread limit
middleware = ModelCallLimitMiddleware(thread_limit=5)
assert middleware.thread_limit == 5
assert middleware.run_limit is None
# Test with only run limit
middleware = ModelCallLimitMiddleware(run_limit=3)
assert middleware.thread_limit is None
assert middleware.run_limit == 3
def test_exception_error_message():
"""Test that the exception provides clear error messages."""
middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1, exit_behavior="error")
# Test thread limit exceeded
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware.before_model(state, None)
error_msg = str(exc_info.value)
assert "Model call limits exceeded" in error_msg
assert "thread limit (2/2)" in error_msg
# Test run limit exceeded
state = {"thread_model_call_count": 0, "run_model_call_count": 1}
with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware.before_model(state, None)
error_msg = str(exc_info.value)
assert "Model call limits exceeded" in error_msg
assert "run limit (1/1)" in error_msg
# Test both limits exceeded
state = {"thread_model_call_count": 2, "run_model_call_count": 1}
with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware.before_model(state, None)
error_msg = str(exc_info.value)
assert "Model call limits exceeded" in error_msg
assert "thread limit (2/2)" in error_msg
assert "run limit (1/1)" in error_msg
def test_run_limit_resets_between_invocations() -> None:
"""Test that run limits reset between invocations."""
# First: No tool calls per invocation, so model does not increment call counts internally
middleware = ModelCallLimitMiddleware(thread_limit=3, run_limit=1, exit_behavior="error")
model = FakeToolCallingModel(
tool_calls=[[], [], [], []]
) # No tool calls, so only one model call per run
agent = create_agent(model=model, middleware=[middleware], checkpointer=InMemorySaver())
thread_config = {"configurable": {"thread_id": "test_thread"}}
agent.invoke({"messages": [HumanMessage("Hello")]}, thread_config)
agent.invoke({"messages": [HumanMessage("Hello again")]}, thread_config)
agent.invoke({"messages": [HumanMessage("Hello third")]}, thread_config)
# Fourth run: should raise, thread_model_call_count == 3 (limit)
with pytest.raises(ModelCallLimitExceededError) as exc_info:
agent.invoke({"messages": [HumanMessage("Hello fourth")]}, thread_config)
error_msg = str(exc_info.value)
assert "thread limit (3/3)" in error_msg

View File

@@ -0,0 +1,346 @@
"""Tests for middleware decorators: before_model, after_model, and modify_model_request."""
import pytest
from typing import Any
from typing_extensions import NotRequired
from syrupy.assertion import SnapshotAssertion
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.tools import tool
from langgraph.runtime import Runtime
from langgraph.types import Command
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelRequest,
before_model,
after_model,
modify_model_request,
hook_config,
)
from langchain.agents.factory import create_agent, _get_can_jump_to
from ..model import FakeToolCallingModel
class CustomState(AgentState):
"""Custom state schema for testing."""
custom_field: NotRequired[str]
@tool
def test_tool(input: str) -> str:
"""A test tool for middleware testing."""
return f"Tool result: {input}"
def test_before_model_decorator() -> None:
"""Test before_model decorator with all configuration options."""
@before_model(
state_schema=CustomState, tools=[test_tool], can_jump_to=["end"], name="CustomBeforeModel"
)
def custom_before_model(state: CustomState, runtime: Runtime) -> dict[str, Any]:
return {"jump_to": "end"}
assert isinstance(custom_before_model, AgentMiddleware)
assert custom_before_model.state_schema == CustomState
assert custom_before_model.tools == [test_tool]
assert getattr(custom_before_model.__class__.before_model, "__can_jump_to__", []) == ["end"]
def test_after_model_decorator() -> None:
"""Test after_model decorator with all configuration options."""
@after_model(
state_schema=CustomState, tools=[test_tool], can_jump_to=["end"], name="CustomAfterModel"
)
def custom_after_model(state: CustomState, runtime: Runtime) -> dict[str, Any]:
return {"jump_to": "end"}
assert isinstance(custom_after_model, AgentMiddleware)
assert custom_after_model.state_schema == CustomState
assert custom_after_model.tools == [test_tool]
assert getattr(custom_after_model.__class__.after_model, "__can_jump_to__", []) == ["end"]
def test_modify_model_request_decorator() -> None:
"""Test modify_model_request decorator with all configuration options."""
@modify_model_request(state_schema=CustomState, tools=[test_tool], name="CustomModifyRequest")
def custom_modify_request(
request: ModelRequest, state: CustomState, runtime: Runtime
) -> ModelRequest:
return request
assert isinstance(custom_modify_request, AgentMiddleware)
assert custom_modify_request.state_schema == CustomState
assert custom_modify_request.tools == [test_tool]
def test_decorator_with_minimal_config() -> None:
"""Test decorators with minimal configuration."""
@before_model
def minimal_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any]:
return {}
@after_model
def minimal_after_model(state: AgentState, runtime: Runtime) -> dict[str, Any]:
return {}
@modify_model_request
def minimal_modify_request(
request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
return request
assert isinstance(minimal_before_model, AgentMiddleware)
assert isinstance(minimal_after_model, AgentMiddleware)
assert isinstance(minimal_modify_request, AgentMiddleware)
def test_decorator_with_jump_to() -> None:
"""Test decorators with can_jump_to configuration."""
@before_model(can_jump_to=["end", "custom_node"])
def jump_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any]:
return {"jump_to": "end"}
@after_model(can_jump_to=["end", "custom_node"])
def jump_after_model(state: AgentState, runtime: Runtime) -> dict[str, Any]:
return {"jump_to": "end"}
assert getattr(jump_before_model.__class__.before_model, "__can_jump_to__", []) == [
"end",
"custom_node",
]
assert getattr(jump_after_model.__class__.after_model, "__can_jump_to__", []) == [
"end",
"custom_node",
]
def test_decorator_with_tools() -> None:
"""Test decorators with tools configuration."""
@tool
def tool1(input: str) -> str:
"""Tool 1."""
return "result1"
@tool
def tool2(input: str) -> str:
"""Tool 2."""
return "result2"
@before_model(tools=[tool1, tool2])
def tools_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any]:
return {}
assert tools_before_model.tools == [tool1, tool2]
def test_decorator_with_custom_state() -> None:
"""Test decorators with custom state schema."""
class MyCustomState(AgentState):
"""Custom state for testing."""
my_field: NotRequired[int]
@before_model(state_schema=MyCustomState)
def custom_state_before_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
return {}
assert custom_state_before_model.state_schema == MyCustomState
def test_decorator_with_name() -> None:
"""Test decorators with custom name."""
@before_model(name="MyCustomMiddleware")
def named_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any]:
return {}
assert named_before_model.__class__.__name__ == "MyCustomMiddleware"
def test_decorator_combined_config() -> None:
"""Test decorators with combined configuration options."""
@tool
def combined_tool(input: str) -> str:
"""Combined tool."""
return "combined"
class CombinedState(AgentState):
"""Combined state."""
combined_field: NotRequired[str]
@before_model(
state_schema=CombinedState,
tools=[combined_tool],
can_jump_to=["end", "custom"],
name="CombinedMiddleware",
)
def combined_before_model(state: CombinedState, runtime: Runtime) -> dict[str, Any]:
return {"jump_to": "end"}
assert isinstance(combined_before_model, AgentMiddleware)
assert combined_before_model.state_schema == CombinedState
assert combined_before_model.tools == [combined_tool]
assert getattr(combined_before_model.__class__.before_model, "__can_jump_to__", []) == [
"end",
"custom",
]
assert combined_before_model.__class__.__name__ == "CombinedMiddleware"
def test_get_can_jump_to() -> None:
"""Test _get_can_jump_to utility function."""
@before_model(can_jump_to=["end", "custom"])
def test_middleware(state: AgentState, runtime: Runtime) -> dict[str, Any]:
return {}
can_jump_to = _get_can_jump_to(test_middleware)
assert can_jump_to == ["end", "custom"]
# Test with no can_jump_to
@before_model
def no_jump_middleware(state: AgentState, runtime: Runtime) -> dict[str, Any]:
return {}
can_jump_to = _get_can_jump_to(no_jump_middleware)
assert can_jump_to == []
def test_decorator_functionality() -> None:
"""Test that decorated functions work correctly in agents."""
@before_model
def test_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any]:
return {"custom_field": "test_value"}
@after_model
def test_after_model(state: AgentState, runtime: Runtime) -> dict[str, Any]:
return {"processed": True}
@modify_model_request
def test_modify_request(
request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
# Add custom header to model settings
request.model_settings["custom_header"] = "test"
return request
# Create agent with decorated middleware
agent = create_agent(
model=FakeToolCallingModel(),
tools=[test_tool],
middleware=[test_before_model, test_after_model, test_modify_request],
)
# Test that agent can be invoked
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result
assert "custom_field" in result
assert result["custom_field"] == "test_value"
assert "processed" in result
assert result["processed"] is True
def test_decorator_with_jump_functionality() -> None:
"""Test decorators with jump functionality."""
@before_model(can_jump_to=["end"])
def jump_middleware(state: AgentState, runtime: Runtime) -> dict[str, Any]:
return {"jump_to": "end"}
agent = create_agent(
model=FakeToolCallingModel(),
tools=[test_tool],
middleware=[jump_middleware],
)
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result
# Should have jumped to end, so minimal messages
assert len(result["messages"]) <= 2
def test_decorator_with_tools_in_agent() -> None:
"""Test that decorators with tools work correctly in agents."""
@tool
def decorator_tool(input: str) -> str:
"""Tool from decorator."""
return f"Decorator tool result: {input}"
@before_model(tools=[decorator_tool])
def tools_middleware(state: AgentState, runtime: Runtime) -> dict[str, Any]:
return {}
agent = create_agent(
model=FakeToolCallingModel(),
tools=[test_tool],
middleware=[tools_middleware],
)
# The agent should have access to both tools
result = agent.invoke({"messages": [HumanMessage("Use both tools")]})
assert "messages" in result
def test_decorator_error_handling() -> None:
"""Test decorator error handling."""
@before_model
def error_middleware(state: AgentState, runtime: Runtime) -> dict[str, Any]:
raise ValueError("Test error")
agent = create_agent(
model=FakeToolCallingModel(),
tools=[test_tool],
middleware=[error_middleware],
)
# Should handle the error gracefully
with pytest.raises(ValueError, match="Test error"):
agent.invoke({"messages": [HumanMessage("Hello")]})
def test_decorator_with_hook_config() -> None:
"""Test decorators with hook_config."""
@hook_config
def my_hook_config():
return {"custom_config": "test"}
@before_model
def config_middleware(state: AgentState, runtime: Runtime) -> dict[str, Any]:
return {}
# Test that hook_config can be applied
assert my_hook_config() == {"custom_config": "test"}
def test_decorator_snapshot_compatibility(snapshot: SnapshotAssertion) -> None:
"""Test that decorators produce consistent snapshots."""
@before_model(name="SnapshotMiddleware")
def snapshot_middleware(state: AgentState, runtime: Runtime) -> dict[str, Any]:
return {"snapshot_test": True}
agent = create_agent(
model=FakeToolCallingModel(),
tools=[test_tool],
middleware=[snapshot_middleware],
)
result = agent.invoke({"messages": [HumanMessage("Snapshot test")]})
assert result == snapshot

View File

@@ -0,0 +1,514 @@
"""Tests for HumanInTheLoopMiddleware."""
import pytest
from typing import Any
from unittest.mock import patch
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain.agents.middleware.human_in_the_loop import (
ActionRequest,
HumanInTheLoopMiddleware,
)
from ..model import FakeToolCallingModel
def test_human_in_the_loop_middleware_initialization() -> None:
"""Test HumanInTheLoopMiddleware initialization."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
},
description_prefix="Custom prefix",
)
assert middleware.interrupt_on == {
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
}
assert middleware.description_prefix == "Custom prefix"
def test_human_in_the_loop_middleware_no_interrupts_needed() -> None:
"""Test HumanInTheLoopMiddleware when no interrupts are needed."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
}
)
# Test with no messages
state: dict[str, Any] = {"messages": []}
result = middleware.after_model(state, None)
assert result is None
# Test with message but no tool calls
state = {"messages": [HumanMessage(content="Hello"), AIMessage(content="Hi there")]}
result = middleware.after_model(state, None)
assert result is None
# Test with tool calls that don't require interrupts
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "other_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
result = middleware.after_model(state, None)
assert result is None
def test_human_in_the_loop_middleware_single_tool_accept() -> None:
"""Test HumanInTheLoopMiddleware with single tool accept response."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_accept(requests):
return [{"type": "accept", "args": None}]
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
assert result["messages"][0] == ai_message
assert result["messages"][0].tool_calls == ai_message.tool_calls
state["messages"].append(
ToolMessage(content="Tool message", name="test_tool", tool_call_id="1")
)
state["messages"].append(AIMessage(content="test_tool called with result: Tool message"))
result = middleware.after_model(state, None)
# No interrupts needed
assert result is None
def test_human_in_the_loop_middleware_single_tool_edit() -> None:
"""Test HumanInTheLoopMiddleware with single tool edit response."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_edit(requests):
return [
{
"type": "edit",
"args": ActionRequest(
action="test_tool",
args={"input": "edited"},
),
}
]
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
assert result["messages"][0].tool_calls[0]["id"] == "1" # ID should be preserved
def test_human_in_the_loop_middleware_single_tool_response() -> None:
"""Test HumanInTheLoopMiddleware with single tool response with custom message."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_response(requests):
return [{"type": "response", "args": "Custom response message"}]
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_response
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 2
assert isinstance(result["messages"][0], AIMessage)
assert isinstance(result["messages"][1], ToolMessage)
assert result["messages"][1].content == "Custom response message"
assert result["messages"][1].name == "test_tool"
assert result["messages"][1].tool_call_id == "1"
def test_human_in_the_loop_middleware_multiple_tools_mixed_responses() -> None:
"""Test HumanInTheLoopMiddleware with multiple tools and mixed response types."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"get_forecast": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
"get_temperature": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
}
)
ai_message = AIMessage(
content="I'll help you with weather",
tool_calls=[
{"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
{"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
],
)
state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}
def mock_mixed_responses(requests):
return [
{"type": "accept", "args": None},
{"type": "response", "args": "User rejected this tool call"},
]
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_mixed_responses
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert (
len(result["messages"]) == 2
) # AI message with accepted tool call + tool message for rejected
# First message should be the AI message with both tool calls
updated_ai_message = result["messages"][0]
assert len(updated_ai_message.tool_calls) == 2 # Both tool calls remain
assert updated_ai_message.tool_calls[0]["name"] == "get_forecast" # Accepted
assert updated_ai_message.tool_calls[1]["name"] == "get_temperature" # Got response
# Second message should be the tool message for the rejected tool call
tool_message = result["messages"][1]
assert isinstance(tool_message, ToolMessage)
assert tool_message.content == "User rejected this tool call"
assert tool_message.name == "get_temperature"
def test_human_in_the_loop_middleware_multiple_tools_edit_responses() -> None:
"""Test HumanInTheLoopMiddleware with multiple tools and edit responses."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"get_forecast": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
"get_temperature": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
}
)
ai_message = AIMessage(
content="I'll help you with weather",
tool_calls=[
{"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
{"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
],
)
state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}
def mock_edit_responses(requests):
return [
{
"type": "edit",
"args": ActionRequest(
action="get_forecast",
args={"location": "New York"},
),
},
{
"type": "edit",
"args": ActionRequest(
action="get_temperature",
args={"location": "New York"},
),
},
]
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit_responses
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
updated_ai_message = result["messages"][0]
assert updated_ai_message.tool_calls[0]["args"] == {"location": "New York"}
assert updated_ai_message.tool_calls[0]["id"] == "1" # ID preserved
assert updated_ai_message.tool_calls[1]["args"] == {"location": "New York"}
assert updated_ai_message.tool_calls[1]["id"] == "2" # ID preserved
def test_human_in_the_loop_middleware_edit_with_modified_args() -> None:
"""Test HumanInTheLoopMiddleware with edit action that includes modified args."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_edit_with_args(requests):
return [
{
"type": "edit",
"args": ActionRequest(
action="test_tool",
args={"input": "modified"},
),
}
]
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
side_effect=mock_edit_with_args,
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
# Should have modified args
updated_ai_message = result["messages"][0]
assert updated_ai_message.tool_calls[0]["args"] == {"input": "modified"}
assert updated_ai_message.tool_calls[0]["id"] == "1" # ID preserved
def test_human_in_the_loop_middleware_unknown_response_type() -> None:
"""Test HumanInTheLoopMiddleware with unknown response type."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_unknown(requests):
return [{"type": "unknown", "args": None}]
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_unknown):
with pytest.raises(
ValueError,
match=r"Unexpected human response: {'type': 'unknown', 'args': None}. Response action 'unknown' is not allowed for tool 'test_tool'. Expected one of \['accept', 'edit', 'response'\] based on the tool's configuration.",
):
middleware.after_model(state, None)
def test_human_in_the_loop_middleware_disallowed_action() -> None:
"""Test HumanInTheLoopMiddleware with action not allowed by tool config."""
# edit is not allowed by tool config
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"test_tool": {"allow_respond": True, "allow_edit": False, "allow_accept": True}
}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_disallowed_action(requests):
return [
{
"type": "edit",
"args": ActionRequest(
action="test_tool",
args={"input": "modified"},
),
}
]
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
side_effect=mock_disallowed_action,
):
with pytest.raises(
ValueError,
match=r"Unexpected human response: {'type': 'edit', 'args': {'action': 'test_tool', 'args': {'input': 'modified'}}}. Response action 'edit' is not allowed for tool 'test_tool'. Expected one of \['accept', 'response'\] based on the tool's configuration.",
):
middleware.after_model(state, None)
def test_human_in_the_loop_middleware_mixed_auto_approved_and_interrupt() -> None:
"""Test HumanInTheLoopMiddleware with mix of auto-approved and interrupt tools."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"interrupt_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[
{"name": "auto_tool", "args": {"input": "auto"}, "id": "1"},
{"name": "interrupt_tool", "args": {"input": "interrupt"}, "id": "2"},
],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_accept(requests):
return [{"type": "accept", "args": None}]
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
updated_ai_message = result["messages"][0]
# Should have both tools: auto-approved first, then interrupt tool
assert len(updated_ai_message.tool_calls) == 2
assert updated_ai_message.tool_calls[0]["name"] == "auto_tool"
assert updated_ai_message.tool_calls[1]["name"] == "interrupt_tool"
def test_human_in_the_loop_middleware_interrupt_request_structure() -> None:
"""Test that interrupt requests are structured correctly."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
},
description_prefix="Custom prefix",
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test", "location": "SF"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
captured_requests = []
def mock_capture_requests(requests):
captured_requests.extend(requests)
return [{"type": "accept", "args": None}]
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_capture_requests
):
middleware.after_model(state, None)
assert len(captured_requests) == 1
request = captured_requests[0]
assert "action_request" in request
assert "config" in request
assert "description" in request
assert request["action_request"]["action"] == "test_tool"
assert request["action_request"]["args"] == {"input": "test", "location": "SF"}
expected_config = {"allow_accept": True, "allow_edit": True, "allow_respond": True}
assert request["config"] == expected_config
assert "Custom prefix" in request["description"]
assert "Tool: test_tool" in request["description"]
assert "Args: {'input': 'test', 'location': 'SF'}" in request["description"]
def test_human_in_the_loop_middleware_boolean_configs() -> None:
"""Test HITL middleware with boolean tool configs."""
middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": True})
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
# Test accept
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value=[{"type": "accept", "args": None}],
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
assert result["messages"][0].tool_calls == ai_message.tool_calls
# Test edit
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value=[
{
"type": "edit",
"args": ActionRequest(
action="test_tool",
args={"input": "edited"},
),
}
],
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": False})
result = middleware.after_model(state, None)
# No interruption should occur
assert result is None
def test_human_in_the_loop_middleware_sequence_mismatch() -> None:
"""Test that sequence mismatch in resume raises an error."""
middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": True})
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
# Test with too few responses
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value=[], # No responses for 1 tool call
):
with pytest.raises(
ValueError,
match=r"Number of human responses \(0\) does not match number of hanging tool calls \(1\)\.",
):
middleware.after_model(state, None)
# Test with too many responses
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value=[
{"type": "accept", "args": None},
{"type": "accept", "args": None},
], # 2 responses for 1 tool call
):
with pytest.raises(
ValueError,
match=r"Number of human responses \(2\) does not match number of hanging tool calls \(1\)\.",
):
middleware.after_model(state, None)

View File

@@ -0,0 +1,232 @@
"""Tests for PlanningMiddleware."""
import pytest
from langchain_core.messages import HumanMessage
from langchain.agents.middleware.planning import (
PlanningMiddleware,
PlanningState,
WRITE_TODOS_SYSTEM_PROMPT,
write_todos,
WRITE_TODOS_TOOL_DESCRIPTION,
)
from langchain.agents.middleware.types import ModelRequest
from langchain.agents.factory import create_agent
from ..model import FakeToolCallingModel
def test_planning_middleware_initialization() -> None:
"""Test that PlanningMiddleware initializes correctly."""
middleware = PlanningMiddleware()
assert middleware.state_schema == PlanningState
assert len(middleware.tools) == 1
assert middleware.tools[0].name == "write_todos"
@pytest.mark.parametrize(
"original_prompt,expected_prompt_prefix",
[
("Original prompt", "Original prompt\n\n## `write_todos`"),
(None, "## `write_todos`"),
],
)
def test_planning_middleware_modify_model_request(original_prompt, expected_prompt_prefix) -> None:
"""Test that modify_model_request handles system prompts correctly."""
middleware = PlanningMiddleware()
model = FakeToolCallingModel()
request = ModelRequest(
model=model,
system_prompt=original_prompt,
messages=[HumanMessage(content="Hello")],
tool_choice=None,
tools=[],
response_format=None,
model_settings={},
)
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
modified_request = middleware.modify_model_request(request, state, None)
assert modified_request.system_prompt.startswith(expected_prompt_prefix)
@pytest.mark.parametrize(
"todos,expected_message",
[
([], "Updated todo list to []"),
(
[{"content": "Task 1", "status": "pending"}],
"Updated todo list to [{'content': 'Task 1', 'status': 'pending'}]",
),
(
[
{"content": "Task 1", "status": "pending"},
{"content": "Task 2", "status": "in_progress"},
],
"Updated todo list to [{'content': 'Task 1', 'status': 'pending'}, {'content': 'Task 2', 'status': 'in_progress'}]",
),
(
[
{"content": "Task 1", "status": "pending"},
{"content": "Task 2", "status": "in_progress"},
{"content": "Task 3", "status": "completed"},
],
"Updated todo list to [{'content': 'Task 1', 'status': 'pending'}, {'content': 'Task 2', 'status': 'in_progress'}, {'content': 'Task 3', 'status': 'completed'}]",
),
],
)
def test_planning_middleware_write_todos_tool_execution(todos, expected_message) -> None:
"""Test that the write_todos tool executes correctly."""
tool_call = {
"args": {"todos": todos},
"name": "write_todos",
"type": "tool_call",
"id": "test_call",
}
result = write_todos.invoke(tool_call)
assert result.update["todos"] == todos
assert result.update["messages"][0].content == expected_message
@pytest.mark.parametrize(
"invalid_todos",
[
[{"content": "Task 1", "status": "invalid_status"}],
[{"status": "pending"}],
],
)
def test_planning_middleware_write_todos_tool_validation_errors(invalid_todos) -> None:
"""Test that the write_todos tool rejects invalid input."""
tool_call = {
"args": {"todos": invalid_todos},
"name": "write_todos",
"type": "tool_call",
"id": "test_call",
}
with pytest.raises(Exception):
write_todos.invoke(tool_call)
def test_planning_middleware_agent_creation_with_middleware() -> None:
"""Test that an agent can be created with the planning middleware."""
model = FakeToolCallingModel(
tool_calls=[
[
{
"args": {"todos": [{"content": "Task 1", "status": "pending"}]},
"name": "write_todos",
"type": "tool_call",
"id": "test_call",
}
],
[
{
"args": {"todos": [{"content": "Task 1", "status": "in_progress"}]},
"name": "write_todos",
"type": "tool_call",
"id": "test_call",
}
],
[
{
"args": {"todos": [{"content": "Task 1", "status": "completed"}]},
"name": "write_todos",
"type": "tool_call",
"id": "test_call",
}
],
[],
]
)
middleware = PlanningMiddleware()
agent = create_agent(model=model, middleware=[middleware])
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert result["todos"] == [{"content": "Task 1", "status": "completed"}]
# human message (1)
# ai message (2) - initial todo
# tool message (3)
# ai message (4) - updated todo
# tool message (5)
# ai message (6) - complete todo
# tool message (7)
# ai message (8) - no tool calls
assert len(result["messages"]) == 8
def test_planning_middleware_custom_system_prompt() -> None:
"""Test that PlanningMiddleware can be initialized with custom system prompt."""
custom_system_prompt = "Custom todo system prompt for testing"
middleware = PlanningMiddleware(system_prompt=custom_system_prompt)
model = FakeToolCallingModel()
request = ModelRequest(
model=model,
system_prompt="Original prompt",
messages=[HumanMessage(content="Hello")],
tool_choice=None,
tools=[],
response_format=None,
model_settings={},
)
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
modified_request = middleware.modify_model_request(request, state, None)
assert modified_request.system_prompt == f"Original prompt\n\n{custom_system_prompt}"
def test_planning_middleware_custom_tool_description() -> None:
"""Test that PlanningMiddleware can be initialized with custom tool description."""
custom_tool_description = "Custom tool description for testing"
middleware = PlanningMiddleware(tool_description=custom_tool_description)
assert len(middleware.tools) == 1
tool = middleware.tools[0]
assert tool.description == custom_tool_description
def test_planning_middleware_custom_system_prompt_and_tool_description() -> None:
"""Test that PlanningMiddleware can be initialized with both custom prompts."""
custom_system_prompt = "Custom system prompt"
custom_tool_description = "Custom tool description"
middleware = PlanningMiddleware(
system_prompt=custom_system_prompt,
tool_description=custom_tool_description,
)
# Verify system prompt
model = FakeToolCallingModel()
request = ModelRequest(
model=model,
system_prompt=None,
messages=[HumanMessage(content="Hello")],
tool_choice=None,
tools=[],
response_format=None,
model_settings={},
)
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
modified_request = middleware.modify_model_request(request, state, None)
assert modified_request.system_prompt == custom_system_prompt
# Verify tool description
assert len(middleware.tools) == 1
tool = middleware.tools[0]
assert tool.description == custom_tool_description
def test_planning_middleware_default_prompts() -> None:
"""Test that PlanningMiddleware uses default prompts when none provided."""
middleware = PlanningMiddleware()
# Verify default system prompt
assert middleware.system_prompt == WRITE_TODOS_SYSTEM_PROMPT
# Verify default tool description
assert middleware.tool_description == WRITE_TODOS_TOOL_DESCRIPTION
assert len(middleware.tools) == 1
tool = middleware.tools[0]
assert tool.description == WRITE_TODOS_TOOL_DESCRIPTION

View File

@@ -0,0 +1,109 @@
"""Tests for AnthropicPromptCachingMiddleware."""
import warnings
from types import ModuleType
from unittest.mock import patch
import pytest
from langchain_core.messages import HumanMessage
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
from langchain.agents.middleware.types import ModelRequest
from ..model import FakeToolCallingModel
def test_anthropic_prompt_caching_middleware_initialization() -> None:
"""Test AnthropicPromptCachingMiddleware initialization."""
# Test with custom values
middleware = AnthropicPromptCachingMiddleware(
type="ephemeral", ttl="1h", min_messages_to_cache=5
)
assert middleware.type == "ephemeral"
assert middleware.ttl == "1h"
assert middleware.min_messages_to_cache == 5
# Test with default values
middleware = AnthropicPromptCachingMiddleware()
assert middleware.type == "ephemeral"
assert middleware.ttl == "5m"
assert middleware.min_messages_to_cache == 0
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
model_settings={},
)
assert middleware.modify_model_request(fake_request, {}, None).model_settings == {
"cache_control": {"type": "ephemeral", "ttl": "5m"}
}
def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
"""Test AnthropicPromptCachingMiddleware with unsupported model."""
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
model_settings={},
)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
with pytest.raises(
ValueError,
match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic.",
):
middleware.modify_model_request(fake_request, {}, None)
langchain_anthropic = ModuleType("langchain_anthropic")
class MockChatAnthropic:
pass
langchain_anthropic.ChatAnthropic = MockChatAnthropic
with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
with pytest.raises(
ValueError,
match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of",
):
middleware.modify_model_request(fake_request, {}, None)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
with warnings.catch_warnings(record=True) as w:
result = middleware.modify_model_request(fake_request, {}, None)
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic."
in str(w[-1].message)
)
assert result == fake_request
with warnings.catch_warnings(record=True) as w:
with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
result = middleware.modify_model_request(fake_request, {}, None)
assert result is fake_request
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of"
in str(w[-1].message)
)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
result = middleware.modify_model_request(fake_request, {}, None)
assert result is fake_request
with patch.dict("sys.modules", {"langchain_anthropic": {"ChatAnthropic": object()}}):
result = middleware.modify_model_request(fake_request, {}, None)
assert result is fake_request
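For context, a sketch of this middleware attached to a real Anthropic model (model name and ttl are illustrative; langchain-anthropic must be installed):
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage
from langchain.agents import create_agent
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware

agent = create_agent(
    model=ChatAnthropic(model="claude-sonnet-4-20250514"),  # illustrative model name
    tools=[],
    system_prompt="A long, stable system prompt worth caching...",
    middleware=[AnthropicPromptCachingMiddleware(ttl="5m", min_messages_to_cache=2)],
)
agent.invoke({"messages": [HumanMessage("Hello")]})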

View File

@@ -0,0 +1,230 @@
"""Tests for SummarizationMiddleware."""
from unittest.mock import patch
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langchain.agents.middleware.summarization import SummarizationMiddleware
from ..model import FakeToolCallingModel
def test_summarization_middleware_initialization() -> None:
"""Test SummarizationMiddleware initialization."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(
model=model,
max_tokens_before_summary=1000,
messages_to_keep=10,
summary_prompt="Custom prompt: {messages}",
summary_prefix="Custom prefix:",
)
assert middleware.model == model
assert middleware.max_tokens_before_summary == 1000
assert middleware.messages_to_keep == 10
assert middleware.summary_prompt == "Custom prompt: {messages}"
assert middleware.summary_prefix == "Custom prefix:"
# Test with string model
with patch(
"langchain.agents.middleware.summarization.init_chat_model",
return_value=FakeToolCallingModel(),
):
middleware = SummarizationMiddleware(model="fake-model")
assert isinstance(middleware.model, FakeToolCallingModel)
def test_summarization_middleware_no_summarization_cases() -> None:
"""Test SummarizationMiddleware when summarization is not needed or disabled."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(model=model, max_tokens_before_summary=1000)
# Test when summarization is disabled
middleware_disabled = SummarizationMiddleware(model=model, max_tokens_before_summary=None)
state = {"messages": [HumanMessage(content="Hello"), AIMessage(content="Hi")]}
result = middleware_disabled.before_model(state, None)
assert result is None
# Test when token count is below threshold
def mock_token_counter(messages):
return 500 # Below threshold
middleware.token_counter = mock_token_counter
result = middleware.before_model(state, None)
assert result is None
def test_summarization_middleware_helper_methods() -> None:
"""Test SummarizationMiddleware helper methods."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(model=model, max_tokens_before_summary=1000)
# Test message ID assignment
messages = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
middleware._ensure_message_ids(messages)
for msg in messages:
assert msg.id is not None
# Test message partitioning
messages = [
HumanMessage(content="1"),
HumanMessage(content="2"),
HumanMessage(content="3"),
HumanMessage(content="4"),
HumanMessage(content="5"),
]
to_summarize, preserved = middleware._partition_messages(messages, 2)
assert len(to_summarize) == 2
assert len(preserved) == 3
assert to_summarize == messages[:2]
assert preserved == messages[2:]
# Test summary message building
summary = "This is a test summary"
new_messages = middleware._build_new_messages(summary)
assert len(new_messages) == 1
assert isinstance(new_messages[0], HumanMessage)
assert "Here is a summary of the conversation to date:" in new_messages[0].content
assert summary in new_messages[0].content
# Test tool call detection
ai_message_no_tools = AIMessage(content="Hello")
assert not middleware._has_tool_calls(ai_message_no_tools)
ai_message_with_tools = AIMessage(
content="Hello", tool_calls=[{"name": "test", "args": {}, "id": "1"}]
)
assert middleware._has_tool_calls(ai_message_with_tools)
human_message = HumanMessage(content="Hello")
assert not middleware._has_tool_calls(human_message)
def test_summarization_middleware_tool_call_safety() -> None:
"""Test SummarizationMiddleware tool call safety logic."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(
model=model, max_tokens_before_summary=1000, messages_to_keep=3
)
# Test safe cutoff point detection with tool calls
messages = [
HumanMessage(content="1"),
AIMessage(content="2", tool_calls=[{"name": "test", "args": {}, "id": "1"}]),
ToolMessage(content="3", tool_call_id="1"),
HumanMessage(content="4"),
]
# Safe cutoff (doesn't separate AI/Tool pair)
is_safe = middleware._is_safe_cutoff_point(messages, 0)
assert is_safe is True
# Unsafe cutoff (separates AI/Tool pair)
is_safe = middleware._is_safe_cutoff_point(messages, 2)
assert is_safe is False
# Test tool call ID extraction
ids = middleware._extract_tool_call_ids(messages[1])
assert ids == {"1"}
def test_summarization_middleware_summary_creation() -> None:
"""Test SummarizationMiddleware summary creation."""
class MockModel(BaseChatModel):
def invoke(self, prompt):
return AIMessage(content="Generated summary")
def _generate(self, messages, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
def _llm_type(self):
return "mock"
middleware = SummarizationMiddleware(model=MockModel(), max_tokens_before_summary=1000)
# Test normal summary creation
messages = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
summary = middleware._create_summary(messages)
assert summary == "Generated summary"
# Test empty messages
summary = middleware._create_summary([])
assert summary == "No previous conversation history."
# Test error handling
class ErrorModel(BaseChatModel):
def invoke(self, prompt):
raise Exception("Model error")
def _generate(self, messages, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
def _llm_type(self):
return "mock"
middleware_error = SummarizationMiddleware(model=ErrorModel(), max_tokens_before_summary=1000)
summary = middleware_error._create_summary(messages)
assert "Error generating summary: Model error" in summary
def test_summarization_middleware_full_workflow() -> None:
"""Test SummarizationMiddleware complete summarization workflow."""
class MockModel(BaseChatModel):
def invoke(self, prompt):
return AIMessage(content="Generated summary")
def _generate(self, messages, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
def _llm_type(self):
return "mock"
middleware = SummarizationMiddleware(
model=MockModel(), max_tokens_before_summary=1000, messages_to_keep=2
)
# Mock high token count to trigger summarization
def mock_token_counter(messages):
return 1500 # Above threshold
middleware.token_counter = mock_token_counter
messages = [
HumanMessage(content="1"),
HumanMessage(content="2"),
HumanMessage(content="3"),
HumanMessage(content="4"),
HumanMessage(content="5"),
]
state = {"messages": messages}
result = middleware.before_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) > 0
# Should have RemoveMessage for cleanup
assert isinstance(result["messages"][0], RemoveMessage)
assert result["messages"][0].id == REMOVE_ALL_MESSAGES
# Should have summary message
summary_message = None
for msg in result["messages"]:
if isinstance(msg, HumanMessage) and "summary of the conversation" in msg.content:
summary_message = msg
break
assert summary_message is not None
assert "Generated summary" in summary_message.content

View File

@@ -1,17 +1,18 @@
"""Test Middleware handling of tools in agents."""
"""Tests for middleware handling of tools in agents."""
import pytest
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
from langchain.agents.middleware_agent import create_agent
from langchain.agents.factory import create_agent
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.tools import tool
from .model import FakeToolCallingModel
from langgraph.runtime import Runtime
from ..model import FakeToolCallingModel
def test_model_request_tools_are_strings() -> None:
"""Test that ModelRequest.tools contains tool names as strings, not tool objects."""
def test_model_request_tools_are_objects() -> None:
"""Test that ModelRequest.tools contains tool objects (BaseTool | dict)."""
captured_requests: list[ModelRequest] = []
@tool
@@ -36,23 +37,23 @@ def test_model_request_tools_are_strings() -> None:
tools=[search_tool, calculator],
system_prompt="You are a helpful assistant.",
middleware=[RequestCapturingMiddleware()],
).compile()
)
agent.invoke({"messages": [HumanMessage("Hello")]})
# Verify that at least one request was captured
assert len(captured_requests) > 0
# Check that tools in the request are strings (tool names)
# Check that tools in the request are tool objects
request = captured_requests[0]
assert isinstance(request.tools, list)
assert len(request.tools) == 2
assert all(isinstance(tool_name, str) for tool_name in request.tools)
assert set(request.tools) == {"search_tool", "calculator"}
tool_names = {t.name for t in request.tools}
assert tool_names == {"search_tool", "calculator"}
def test_middleware_can_modify_tool_names() -> None:
"""Test that middleware can modify the list of tool names in ModelRequest."""
def test_middleware_can_modify_tools() -> None:
"""Test that middleware can modify the list of tools in ModelRequest."""
@tool
def tool_a(input: str) -> str:
@@ -74,7 +75,7 @@ def test_middleware_can_modify_tool_names() -> None:
self, request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
# Only allow tool_a and tool_b
request.tools = ["tool_a", "tool_b"]
request.tools = [t for t in request.tools if t.name in ["tool_a", "tool_b"]]
return request
# Model will try to call tool_a
@@ -87,7 +88,7 @@ def test_middleware_can_modify_tool_names() -> None:
tools=[tool_a, tool_b, tool_c],
system_prompt="You are a helpful assistant.",
middleware=[ToolFilteringMiddleware()],
).compile()
)
result = agent.invoke({"messages": [HumanMessage("Use tool_a")]})
@@ -98,88 +99,114 @@ def test_middleware_can_modify_tool_names() -> None:
assert tool_messages[0].name == "tool_a"
def test_unknown_tool_name_raises_error() -> None:
"""Test that using an unknown tool name in ModelRequest raises a clear error."""
def test_middleware_can_add_custom_tools() -> None:
"""Test that middleware can add custom tool objects to ModelRequest."""
@tool
def known_tool(input: str) -> str:
"""A known tool."""
return "result"
def original_tool(input: str) -> str:
"""Original tool."""
return "original"
class BadMiddleware(AgentMiddleware):
@tool
def middleware_tool(input: str) -> str:
"""Middleware-added tool."""
return "middleware"
class ToolAddingMiddleware(AgentMiddleware):
def modify_model_request(
self, request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
# Add an unknown tool name
request.tools = ["known_tool", "unknown_tool"]
# Add middleware tool to the request
request.tools = request.tools + [middleware_tool]
return request
agent = create_agent(
model=FakeToolCallingModel(),
tools=[known_tool],
system_prompt="You are a helpful assistant.",
middleware=[BadMiddleware()],
).compile()
with pytest.raises(ValueError, match="Middleware returned unknown tool names"):
agent.invoke({"messages": [HumanMessage("Hello")]})
def test_middleware_can_add_and_remove_tools() -> None:
"""Test that middleware can dynamically add/remove tools based on state."""
@tool
def search(query: str) -> str:
"""Search for information."""
return f"Search results for: {query}"
@tool
def admin_tool(command: str) -> str:
"""Admin-only tool."""
return f"Admin: {command}"
class AdminState(AgentState):
is_admin: bool
class ConditionalToolMiddleware(AgentMiddleware[AdminState]):
state_schema = AdminState
def modify_model_request(
self, request: ModelRequest, state: AdminState, runtime: Runtime
) -> ModelRequest:
# Remove admin_tool if not admin
if not state.get("is_admin", False):
request.tools = [name for name in request.tools if name != "admin_tool"]
return request
model = FakeToolCallingModel()
# Model will try to call middleware_tool
model = FakeToolCallingModel(
tool_calls=[[{"args": {"input": "test"}, "id": "1", "name": "middleware_tool"}], []]
)
agent = create_agent(
model=model,
tools=[search, admin_tool],
tools=[original_tool],
system_prompt="You are a helpful assistant.",
middleware=[ToolAddingMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage("Use middleware_tool")]})
# Verify that the middleware tool was executed successfully
messages = result["messages"]
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
assert len(tool_messages) == 1
assert tool_messages[0].name == "middleware_tool"
assert tool_messages[0].content == "middleware"
def test_middleware_can_conditionally_add_tools() -> None:
"""Test that middleware can conditionally add tools based on state."""
@tool
def admin_tool(input: str) -> str:
"""Admin-only tool."""
return "admin_result"
class AdminState(AgentState):
is_admin: bool = False
class ConditionalToolMiddleware(AgentMiddleware[AdminState]):
def modify_model_request(
self, request: ModelRequest, state: AdminState, runtime: Runtime
) -> ModelRequest:
# Only add admin tool if user is admin
if state.get("is_admin", False):
request.tools = request.tools + [admin_tool]
return request
# Model will try to call admin_tool when admin
model = FakeToolCallingModel(
tool_calls=[[{"args": {"input": "test"}, "id": "1", "name": "admin_tool"}], []]
)
agent = create_agent(
model=model,
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[ConditionalToolMiddleware()],
).compile()
)
# Test non-admin user - should not have access to admin_tool
# We can't directly inspect the bound model, but we can verify the agent runs
result = agent.invoke({"messages": [HumanMessage("Hello")], "is_admin": False})
assert "messages" in result
# Test with admin user
result = agent.invoke({"messages": [HumanMessage("Use admin tool")], "is_admin": True})
messages = result["messages"]
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
assert len(tool_messages) == 1
assert tool_messages[0].name == "admin_tool"
# Test admin user - should have access to all tools
result = agent.invoke({"messages": [HumanMessage("Hello")], "is_admin": True})
assert "messages" in result
# Test with non-admin user (should not have admin tool available)
model_no_admin = FakeToolCallingModel(tool_calls=[[], []]) # No tool calls
agent_no_admin = create_agent(
model=model_no_admin,
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[ConditionalToolMiddleware()],
)
result_no_admin = agent_no_admin.invoke(
{"messages": [HumanMessage("Hello")], "is_admin": False}
)
messages_no_admin = result_no_admin["messages"]
tool_messages_no_admin = [m for m in messages_no_admin if isinstance(m, ToolMessage)]
assert len(tool_messages_no_admin) == 0
def test_empty_tools_list_is_valid() -> None:
"""Test that middleware can set tools to an empty list."""
"""Test that middleware can set an empty tools list."""
@tool
def some_tool(input: str) -> str:
"""Some tool."""
return "result"
def test_tool(input: str) -> str:
"""Test tool."""
return "test"
class NoToolsMiddleware(AgentMiddleware):
class EmptyToolsMiddleware(AgentMiddleware):
def modify_model_request(
self, request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
@@ -187,17 +214,18 @@ def test_empty_tools_list_is_valid() -> None:
request.tools = []
return request
model = FakeToolCallingModel()
# Model should not make any tool calls
model = FakeToolCallingModel(tool_calls=[[], []]) # No tool calls
agent = create_agent(
model=model,
tools=[some_tool],
tools=[test_tool],
system_prompt="You are a helpful assistant.",
middleware=[NoToolsMiddleware()],
).compile()
middleware=[EmptyToolsMiddleware()],
)
# Should run without error even with no tools
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result
@@ -224,20 +252,20 @@ def test_tools_preserved_across_multiple_middleware() -> None:
def modify_model_request(
self, request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
modification_order.append(request.tools.copy())
modification_order.append([t.name for t in request.tools])
# Remove tool_c
request.tools = [name for name in request.tools if name != "tool_c"]
request.tools = [t for t in request.tools if t.name != "tool_c"]
return request
class SecondMiddleware(AgentMiddleware):
def modify_model_request(
self, request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
modification_order.append(request.tools.copy())
modification_order.append([t.name for t in request.tools])
# Should not see tool_c here
assert "tool_c" not in request.tools
assert "tool_c" not in [t.name for t in request.tools]
# Remove tool_b
request.tools = [name for name in request.tools if name != "tool_b"]
request.tools = [t for t in request.tools if t.name != "tool_b"]
return request
agent = create_agent(
@@ -245,7 +273,7 @@ def test_tools_preserved_across_multiple_middleware() -> None:
tools=[tool_a, tool_b, tool_c],
system_prompt="You are a helpful assistant.",
middleware=[FirstMiddleware(), SecondMiddleware()],
).compile()
)
agent.invoke({"messages": [HumanMessage("Hello")]})
@@ -260,37 +288,38 @@ def test_tools_preserved_across_multiple_middleware() -> None:
def test_middleware_with_additional_tools() -> None:
"""Test middleware that provides additional tools via tools attribute."""
@tool
def base_tool(input: str) -> str:
"""Base tool."""
return "base"
@tool
def middleware_tool(input: str) -> str:
"""Tool provided by middleware."""
return "middleware"
return f"Middleware tool result: {input}"
class ToolProvidingMiddleware(AgentMiddleware):
tools = [middleware_tool]
def __init__(self):
super().__init__()
self.tools = [middleware_tool]
# Model calls the middleware-provided tool
def modify_model_request(
self, request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
# Add middleware tools to the request
request.tools = request.tools + self.tools
return request
# Model will try to call middleware_tool
model = FakeToolCallingModel(
tool_calls=[
[{"args": {"input": "test"}, "id": "1", "name": "middleware_tool"}],
[],
]
tool_calls=[[{"args": {"input": "test"}, "id": "1", "name": "middleware_tool"}], []]
)
agent = create_agent(
model=model,
tools=[base_tool],
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[ToolProvidingMiddleware()],
).compile()
)
result = agent.invoke({"messages": [HumanMessage("Use middleware tool")]})
# Verify that the middleware tool was executed
# Verify that the middleware tool was executed successfully
messages = result["messages"]
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
assert len(tool_messages) == 1
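To summarize the API change these tests pin down: ModelRequest.tools now carries tool objects (BaseTool | dict) rather than name strings, so middleware filters and extends the list by object. A minimal sketch; the tool and middleware names are illustrative:
from langchain_core.tools import tool
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest


@tool
def web_search(query: str) -> str:
    """Illustrative search tool."""
    return f"results for {query}"


class SearchOnlyMiddleware(AgentMiddleware):
    def modify_model_request(self, request: ModelRequest, state: AgentState, runtime) -> ModelRequest:
        # Filter existing tools by name, and append a new tool object directly.
        request.tools = [t for t in request.tools if t.name != "admin_tool"] + [web_search]
        return request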

View File

@@ -108,4 +108,4 @@ class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
}
)
return self.bind(tools=tool_dicts)
return self.bind(tools=tool_dicts, **kwargs)

View File

@@ -0,0 +1,381 @@
"""Tests for basic agent functionality."""
import pytest
from langchain_core.messages import (
AIMessage,
HumanMessage,
ToolCall,
ToolMessage,
)
from typing import Any
from langchain_core.tools import tool as dec_tool, InjectedToolCallId, ToolException
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore
from langgraph.types import Command
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from langchain.agents import create_agent
from langchain.agents.middleware.types import AgentMiddleware, AgentState
from langchain.agents.structured_output import ToolStrategy
from langchain.tools import (
InjectedState,
InjectedStore,
ToolNode,
)
from langchain.tools.tool_node import (
_get_state_args,
_infer_handled_types,
)
from .any_str import AnyStr
from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
from .model import FakeToolCallingModel
pytestmark = pytest.mark.anyio
def test_no_prompt(sync_checkpointer: BaseCheckpointSaver) -> None:
"""Test agent with no system prompt."""
model = FakeToolCallingModel()
agent = create_agent(
model,
[],
checkpointer=sync_checkpointer,
)
inputs = [HumanMessage("hi?")]
thread = {"configurable": {"thread_id": "123"}}
response = agent.invoke({"messages": inputs}, thread, debug=True)
expected_response = {"messages": [*inputs, AIMessage(content="hi?", id="0")]}
assert response == expected_response
saved = sync_checkpointer.get_tuple(thread)
assert saved is not None
checkpoint_values = saved.checkpoint["channel_values"]
assert checkpoint_values["messages"] == [
_AnyIdHumanMessage(content="hi?"),
AIMessage(content="hi?", id="0"),
]
assert checkpoint_values["thread_model_call_count"] == 1
assert saved.metadata == {
"parents": {},
"source": "loop",
"step": 1,
}
assert saved.pending_writes == []
async def test_no_prompt_async(async_checkpointer: BaseCheckpointSaver) -> None:
"""Test agent with no system prompt (async)."""
model = FakeToolCallingModel()
agent = create_agent(model, [], checkpointer=async_checkpointer)
inputs = [HumanMessage("hi?")]
thread = {"configurable": {"thread_id": "123"}}
response = await agent.ainvoke({"messages": inputs}, thread, debug=True)
expected_response = {"messages": [*inputs, AIMessage(content="hi?", id="0")]}
assert response == expected_response
saved = await async_checkpointer.aget_tuple(thread)
assert saved is not None
checkpoint_values = saved.checkpoint["channel_values"]
assert checkpoint_values["messages"] == [
_AnyIdHumanMessage(content="hi?"),
AIMessage(content="hi?", id="0"),
]
assert checkpoint_values["thread_model_call_count"] == 1
assert saved.metadata == {
"parents": {},
"source": "loop",
"step": 1,
}
assert saved.pending_writes == []
def test_system_message_prompt() -> None:
"""Test agent with system message prompt."""
system_prompt = "Foo"
model = FakeToolCallingModel()
agent = create_agent(model, [], system_prompt=system_prompt)
inputs = [HumanMessage("hi?")]
response = agent.invoke({"messages": inputs}, debug=True)
expected_response = {"messages": [*inputs, AIMessage(content="hi?", id="0")]}
assert response == expected_response
def test_system_message_prompt_with_tools() -> None:
"""Test agent with system message prompt and tools."""
system_prompt = "You are a helpful assistant."
@dec_tool
def search_tool(query: str) -> str:
"""Search for information."""
return f"Results for: {query}"
model = FakeToolCallingModel(
tool_calls=[[{"args": {"query": "test"}, "id": "1", "name": "search_tool"}], []]
)
agent = create_agent(model, [search_tool], system_prompt=system_prompt)
inputs = [HumanMessage("Search for something")]
response = agent.invoke({"messages": inputs}, debug=True)
assert "messages" in response
messages = response["messages"]
assert len(messages) >= 2 # Human message + AI message
assert isinstance(messages[0], HumanMessage)
assert isinstance(messages[1], AIMessage)
def test_agent_with_tools() -> None:
"""Test agent with tools."""
@dec_tool
def calculator(expression: str) -> str:
"""Calculate a mathematical expression."""
return f"Result: {expression}"
@dec_tool
def search_tool(query: str) -> str:
"""Search for information."""
return f"Results for: {query}"
model = FakeToolCallingModel(
tool_calls=[[{"args": {"expression": "2+2"}, "id": "1", "name": "calculator"}], []]
)
agent = create_agent(model, [calculator, search_tool])
inputs = [HumanMessage("Calculate 2+2")]
response = agent.invoke({"messages": inputs}, debug=True)
assert "messages" in response
messages = response["messages"]
assert len(messages) >= 3 # Human + AI + Tool message
assert isinstance(messages[0], HumanMessage)
assert isinstance(messages[1], AIMessage)
assert isinstance(messages[2], ToolMessage)
def test_agent_with_structured_output() -> None:
"""Test agent with structured output."""
class WeatherResponse(BaseModel):
temperature: float = Field(description="Temperature in Fahrenheit")
condition: str = Field(description="Weather condition")
model = FakeToolCallingModel()
agent = create_agent(
model,
[],
response_format=ToolStrategy(schema=WeatherResponse),
)
inputs = [HumanMessage("What's the weather like?")]
response = agent.invoke({"messages": inputs}, debug=True)
assert "messages" in response
messages = response["messages"]
assert len(messages) >= 2
assert isinstance(messages[0], HumanMessage)
assert isinstance(messages[1], AIMessage)
def test_agent_with_injected_tools() -> None:
"""Test agent with tools that use injected state."""
@dec_tool
def state_tool(input: str, state: Annotated[dict, InjectedState]) -> str:
"""Tool that uses injected state."""
return f"State: {state.get('custom_field', 'none')}"
@dec_tool
def store_tool(input: str, store: Annotated[BaseStore, InjectedStore]) -> str:
"""Tool that uses injected store."""
return f"Store: {type(store).__name__}"
model = FakeToolCallingModel(
tool_calls=[[{"args": {"input": "test"}, "id": "1", "name": "state_tool"}], []]
)
agent = create_agent(
model,
[state_tool, store_tool],
store=InMemoryStore(),
)
inputs = [HumanMessage("Use state tool")]
response = agent.invoke({"messages": inputs, "custom_field": "test_value"}, debug=True)
assert "messages" in response
messages = response["messages"]
assert len(messages) >= 3 # Human + AI + Tool message
assert isinstance(messages[2], ToolMessage)
assert "test_value" in messages[2].content
def test_agent_with_tool_exception() -> None:
"""Test agent handling tool exceptions."""
@dec_tool
def error_tool(input: str) -> str:
"""Tool that raises an exception."""
raise ToolException("Tool error occurred")
model = FakeToolCallingModel(
tool_calls=[[{"args": {"input": "test"}, "id": "1", "name": "error_tool"}], []]
)
agent = create_agent(model, [error_tool])
inputs = [HumanMessage("Use error tool")]
response = agent.invoke({"messages": inputs}, debug=True)
assert "messages" in response
messages = response["messages"]
assert len(messages) >= 3 # Human + AI + Tool message
assert isinstance(messages[2], ToolMessage)
assert "Tool error occurred" in messages[2].content
def test_agent_with_multiple_tool_calls() -> None:
"""Test agent with multiple tool calls in one response."""
@dec_tool
def tool1(input: str) -> str:
"""First tool."""
return f"Tool1: {input}"
@dec_tool
def tool2(input: str) -> str:
"""Second tool."""
return f"Tool2: {input}"
model = FakeToolCallingModel(
tool_calls=[
[
{"args": {"input": "test1"}, "id": "1", "name": "tool1"},
{"args": {"input": "test2"}, "id": "2", "name": "tool2"},
],
[],
]
)
agent = create_agent(model, [tool1, tool2])
inputs = [HumanMessage("Use both tools")]
response = agent.invoke({"messages": inputs}, debug=True)
assert "messages" in response
messages = response["messages"]
assert len(messages) >= 4 # Human + AI + 2 Tool messages
assert isinstance(messages[1], AIMessage)
assert len(messages[1].tool_calls) == 2
assert isinstance(messages[2], ToolMessage)
assert isinstance(messages[3], ToolMessage)
def test_agent_with_custom_middleware() -> None:
"""Test agent with custom middleware."""
class CustomMiddleware(AgentMiddleware[AgentState]):
def before_model(self, state: AgentState, runtime) -> dict[str, Any]:
return {"custom_field": "middleware_value"}
model = FakeToolCallingModel()
agent = create_agent(model, [], middleware=[CustomMiddleware()])
inputs = [HumanMessage("Hello")]
response = agent.invoke({"messages": inputs}, debug=True)
assert "messages" in response
assert "custom_field" in response
assert response["custom_field"] == "middleware_value"
def test_agent_with_checkpointer() -> None:
"""Test agent with checkpointer for state persistence."""
model = FakeToolCallingModel()
agent = create_agent(model, [], checkpointer=InMemorySaver())
inputs = [HumanMessage("Hello")]
thread = {"configurable": {"thread_id": "test_thread"}}
# First invocation
response1 = agent.invoke({"messages": inputs}, thread, debug=True)
assert "messages" in response1
# Second invocation in same thread
inputs2 = [HumanMessage("Hello again")]
response2 = agent.invoke({"messages": inputs2}, thread, debug=True)
assert "messages" in response2
# Should have conversation history
messages = response2["messages"]
assert len(messages) >= 2 # Should have previous messages
def test_agent_with_store() -> None:
"""Test agent with store for persistent data."""
@dec_tool
def store_tool(input: str, store: Annotated[BaseStore, InjectedStore]) -> str:
"""Tool that uses store."""
store.put("test_key", "test_value")
return "Stored value"
model = FakeToolCallingModel(
tool_calls=[[{"args": {"input": "test"}, "id": "1", "name": "store_tool"}], []]
)
store = InMemoryStore()
agent = create_agent(model, [store_tool], store=store)
inputs = [HumanMessage("Use store tool")]
response = agent.invoke({"messages": inputs}, debug=True)
assert "messages" in response
# Verify store was used
stored_item = store.get(("test",), "test_key")
assert stored_item is not None
assert stored_item.value == {"data": "test_value"}
def test_agent_debug_mode() -> None:
"""Test agent in debug mode."""
model = FakeToolCallingModel()
agent = create_agent(model, [])
inputs = [HumanMessage("Hello")]
response = agent.invoke({"messages": inputs}, debug=True)
assert "messages" in response
# Debug mode should provide additional information
assert isinstance(response, dict)
def test_agent_with_empty_tools() -> None:
"""Test agent with empty tools list."""
model = FakeToolCallingModel()
agent = create_agent(model, [])
inputs = [HumanMessage("Hello")]
response = agent.invoke({"messages": inputs}, debug=True)
assert "messages" in response
messages = response["messages"]
assert len(messages) >= 2 # Human + AI message
assert isinstance(messages[0], HumanMessage)
assert isinstance(messages[1], AIMessage)
def test_agent_with_none_system_prompt() -> None:
"""Test agent with None system prompt."""
model = FakeToolCallingModel()
agent = create_agent(model, [], system_prompt=None)
inputs = [HumanMessage("Hello")]
response = agent.invoke({"messages": inputs}, debug=True)
assert "messages" in response
messages = response["messages"]
assert len(messages) >= 2
assert isinstance(messages[0], HumanMessage)
assert isinstance(messages[1], AIMessage)
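As these tests rely on, create_agent now returns a compiled graph that can be invoked directly, with no separate .compile() step. A short persistence sketch; the model string and thread id are illustrative:
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents import create_agent

agent = create_agent("openai:gpt-4o", [], checkpointer=InMemorySaver())
config = {"configurable": {"thread_id": "demo-thread"}}  # illustrative thread id

# Both turns share the same thread, so the second call sees the first exchange.
agent.invoke({"messages": [HumanMessage("Remember that my name is Ada.")]}, config)
followup = agent.invoke({"messages": [HumanMessage("What is my name?")]}, config)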

View File

@@ -0,0 +1,356 @@
"""Tests for edge cases and error conditions in agents."""
import pytest
from typing import Any
from unittest.mock import Mock, patch
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain_core.tools import tool
from langchain_core.exceptions import LangChainException
from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents import create_agent
from langchain.agents.middleware.types import AgentMiddleware, AgentState
from langchain.agents.structured_output import ToolStrategy
from pydantic import BaseModel, Field
from .model import FakeToolCallingModel
class TestEdgeCases:
"""Test edge cases and error conditions."""
def test_agent_with_invalid_model_string(self) -> None:
"""Test agent creation with invalid model string."""
with pytest.raises(Exception): # Should raise some kind of error
create_agent("invalid:model", [])
def test_agent_with_none_tools(self) -> None:
"""Test agent with None tools."""
model = FakeToolCallingModel()
agent = create_agent(model, None)
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result
def test_agent_with_empty_tools_list(self) -> None:
"""Test agent with empty tools list."""
model = FakeToolCallingModel()
agent = create_agent(model, [])
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result
def test_agent_with_duplicate_middleware(self) -> None:
"""Test agent with duplicate middleware instances."""
class TestMiddleware(AgentMiddleware):
pass
middleware1 = TestMiddleware()
middleware2 = TestMiddleware() # Same class, different instance
model = FakeToolCallingModel()
# Should raise an error about duplicate middleware
with pytest.raises(AssertionError, match="Please remove duplicate middleware instances"):
create_agent(model, [], middleware=[middleware1, middleware2])
def test_agent_with_middleware_error(self) -> None:
"""Test agent with middleware that raises an error."""
class ErrorMiddleware(AgentMiddleware):
def before_model(self, state: AgentState, runtime) -> dict[str, Any]:
raise ValueError("Middleware error")
model = FakeToolCallingModel()
agent = create_agent(model, [], middleware=[ErrorMiddleware()])
with pytest.raises(ValueError, match="Middleware error"):
agent.invoke({"messages": [HumanMessage("Hello")]})
def test_agent_with_invalid_structured_output(self) -> None:
"""Test agent with invalid structured output configuration."""
class InvalidSchema:
pass # Not a valid Pydantic model
model = FakeToolCallingModel()
# An invalid schema (not a Pydantic model) should be rejected
with pytest.raises(Exception):
create_agent(model, [], response_format=InvalidSchema)
def test_agent_with_tool_exception(self) -> None:
"""Test agent handling tool exceptions."""
@tool
def error_tool(input: str) -> str:
"""Tool that always raises an exception."""
raise Exception("Tool execution failed")
model = FakeToolCallingModel(
tool_calls=[[{"args": {"input": "test"}, "id": "1", "name": "error_tool"}], []]
)
agent = create_agent(model, [error_tool])
result = agent.invoke({"messages": [HumanMessage("Use error tool")]})
# Should handle tool exception gracefully
assert "messages" in result
messages = result["messages"]
assert len(messages) >= 3 # Human + AI + Tool message
assert isinstance(messages[2], ToolMessage)
assert "Tool execution failed" in messages[2].content
def test_agent_with_malformed_tool_call(self) -> None:
"""Test agent with malformed tool call from model."""
model = FakeToolCallingModel(
tool_calls=[[{"invalid": "tool_call"}], []] # Malformed tool call
)
agent = create_agent(model, [])
# A malformed tool call should surface as an error
with pytest.raises(Exception):
agent.invoke({"messages": [HumanMessage("Hello")]})
def test_agent_with_empty_messages(self) -> None:
"""Test agent with empty messages list."""
model = FakeToolCallingModel()
agent = create_agent(model, [])
result = agent.invoke({"messages": []})
assert "messages" in result
def test_agent_with_none_messages(self) -> None:
"""Test agent with None messages."""
model = FakeToolCallingModel()
agent = create_agent(model, [])
with pytest.raises(Exception):
agent.invoke({"messages": None})
def test_agent_with_invalid_state(self) -> None:
"""Test agent with invalid state structure."""
model = FakeToolCallingModel()
agent = create_agent(model, [])
with pytest.raises(Exception):
agent.invoke({"invalid_key": "value"})
def test_agent_with_large_message_history(self) -> None:
"""Test agent with large message history."""
model = FakeToolCallingModel()
agent = create_agent(model, [])
# Create a large message history
messages = [HumanMessage(f"Message {i}") for i in range(100)]
result = agent.invoke({"messages": messages})
assert "messages" in result
def test_agent_with_special_characters(self) -> None:
"""Test agent with special characters in messages."""
model = FakeToolCallingModel()
agent = create_agent(model, [])
special_message = HumanMessage("Hello! @#$%^&*()_+-=[]{}|;':\",./<>?")
result = agent.invoke({"messages": [special_message]})
assert "messages" in result
def test_agent_with_unicode_messages(self) -> None:
"""Test agent with unicode messages."""
model = FakeToolCallingModel()
agent = create_agent(model, [])
unicode_message = HumanMessage("Hello 世界! 🌍")
result = agent.invoke({"messages": [unicode_message]})
assert "messages" in result
def test_agent_with_very_long_message(self) -> None:
"""Test agent with very long message."""
model = FakeToolCallingModel()
agent = create_agent(model, [])
long_message = HumanMessage("A" * 10000) # 10k character message
result = agent.invoke({"messages": [long_message]})
assert "messages" in result
def test_agent_with_multiple_system_prompts(self) -> None:
"""Test agent behavior with multiple system prompts."""
model = FakeToolCallingModel()
agent = create_agent(model, [], system_prompt="First prompt")
# Add system message to input
messages = [HumanMessage("System: Second prompt"), HumanMessage("Hello")]
result = agent.invoke({"messages": messages})
assert "messages" in result
def test_agent_with_tool_node_instead_of_list(self) -> None:
"""Test agent with ToolNode instead of tools list."""
from langchain.tools import ToolNode
@tool
def test_tool(input: str) -> str:
"""Test tool."""
return f"Result: {input}"
tool_node = ToolNode([test_tool])
model = FakeToolCallingModel()
agent = create_agent(model, tool_node)
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result
def test_agent_with_interrupt_before(self) -> None:
"""Test agent with interrupt_before configuration."""
model = FakeToolCallingModel()
agent = create_agent(model, [], interrupt_before=["model_request"])
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result
def test_agent_with_interrupt_after(self) -> None:
"""Test agent with interrupt_after configuration."""
model = FakeToolCallingModel()
agent = create_agent(model, [], interrupt_after=["model_request"])
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result
def test_agent_with_debug_mode(self) -> None:
"""Test agent with debug mode enabled."""
model = FakeToolCallingModel()
agent = create_agent(model, [], debug=True)
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result
def test_agent_with_cache(self) -> None:
"""Test agent with cache configuration."""
model = FakeToolCallingModel()
cache = {} # Simple dict cache
agent = create_agent(model, [], cache=cache)
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result
def test_agent_with_custom_name(self) -> None:
"""Test agent with custom name."""
model = FakeToolCallingModel()
agent = create_agent(model, [], name="CustomAgent")
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result
def test_agent_with_context_schema(self) -> None:
"""Test agent with custom context schema."""
from typing_extensions import TypedDict
class CustomContext(TypedDict):
user_id: str
session_id: str
model = FakeToolCallingModel()
agent = create_agent(model, [], context_schema=CustomContext)
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result
def test_agent_with_checkpointer_and_store(self) -> None:
"""Test agent with both checkpointer and store."""
from langgraph.store.memory import InMemoryStore
model = FakeToolCallingModel()
checkpointer = InMemorySaver()
store = InMemoryStore()
agent = create_agent(model, [], checkpointer=checkpointer, store=store)
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result
def test_agent_with_middleware_modifying_request(self) -> None:
"""Test agent with middleware that modifies model request."""
class RequestModifyingMiddleware(AgentMiddleware):
def modify_model_request(self, request, state, runtime):
request.model_settings["custom_setting"] = "modified"
return request
model = FakeToolCallingModel()
agent = create_agent(model, [], middleware=[RequestModifyingMiddleware()])
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result
def test_agent_with_middleware_jumping_to_end(self) -> None:
"""Test agent with middleware that jumps to end."""
class JumpToEndMiddleware(AgentMiddleware):
def before_model(self, state: AgentState, runtime):
return {"jump_to": "end"}
model = FakeToolCallingModel()
agent = create_agent(model, [], middleware=[JumpToEndMiddleware()])
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result
# Should have minimal messages due to jumping to end
assert len(result["messages"]) <= 2
def test_agent_with_structured_output_error(self) -> None:
"""Test agent with structured output validation error."""
class WeatherResponse(BaseModel):
temperature: float = Field(description="Temperature in Fahrenheit")
condition: str = Field(description="Weather condition")
# Mock model that returns invalid structured output
class InvalidStructuredModel(FakeToolCallingModel):
def invoke(self, messages, **kwargs):
# Return a message that doesn't match the schema
return AIMessage(content="Invalid response")
model = InvalidStructuredModel()
agent = create_agent(model, [], response_format=ToolStrategy(schema=WeatherResponse))
# Should handle structured output errors gracefully
result = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
assert "messages" in result
def test_agent_with_concurrent_invocations(self) -> None:
"""Test agent with concurrent invocations."""
import asyncio
model = FakeToolCallingModel()
agent = create_agent(model, [])
async def run_concurrent():
tasks = [agent.ainvoke({"messages": [HumanMessage(f"Message {i}")]}) for i in range(5)]
results = await asyncio.gather(*tasks)
return results
# Run concurrent invocations
results = asyncio.run(run_concurrent())
# All should succeed
assert len(results) == 5
for result in results:
assert "messages" in result
def test_agent_with_streaming(self) -> None:
"""Test agent streaming functionality."""
model = FakeToolCallingModel()
agent = create_agent(model, [])
# Test streaming
stream = agent.stream({"messages": [HumanMessage("Hello")]})
chunks = list(stream)
# Should have at least one chunk
assert len(chunks) > 0
# Each chunk should be a dict
for chunk in chunks:
assert isinstance(chunk, dict)
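A brief sketch of the ToolStrategy structured-output path these edge cases touch; the schema and model string are illustrative:
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field
from langchain.agents import create_agent
from langchain.agents.structured_output import ToolStrategy


class WeatherReport(BaseModel):
    temperature: float = Field(description="Temperature in Fahrenheit")
    condition: str = Field(description="Weather condition")


agent = create_agent(
    "openai:gpt-4o",
    tools=[],
    response_format=ToolStrategy(schema=WeatherReport),
)
result = agent.invoke({"messages": [HumanMessage("What's the weather in Paris?")]})
# The validated WeatherReport is surfaced on the returned state alongside "messages"
# (the exact key depends on the agent's public state schema).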

View File

@@ -1,574 +0,0 @@
"""Consolidated tests for middleware decorators: before_model, after_model, and modify_model_request."""
import pytest
from typing import Any
from typing_extensions import NotRequired
from syrupy.assertion import SnapshotAssertion
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.tools import tool
from langgraph.runtime import Runtime
from langgraph.types import Command
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelRequest,
before_model,
after_model,
modify_model_request,
hook_config,
)
from langchain.agents.middleware_agent import create_agent, _get_can_jump_to
from .model import FakeToolCallingModel
class CustomState(AgentState):
"""Custom state schema for testing."""
custom_field: NotRequired[str]
@tool
def test_tool(input: str) -> str:
"""A test tool for middleware testing."""
return f"Tool result: {input}"
def test_before_model_decorator() -> None:
"""Test before_model decorator with all configuration options."""
@before_model(
state_schema=CustomState, tools=[test_tool], can_jump_to=["end"], name="CustomBeforeModel"
)
def custom_before_model(state: CustomState, runtime: Runtime) -> dict[str, Any]:
return {"jump_to": "end"}
assert isinstance(custom_before_model, AgentMiddleware)
assert custom_before_model.state_schema == CustomState
assert custom_before_model.tools == [test_tool]
assert getattr(custom_before_model.__class__.before_model, "__can_jump_to__", []) == ["end"]
assert custom_before_model.__class__.__name__ == "CustomBeforeModel"
result = custom_before_model.before_model({"messages": [HumanMessage("Hello")]}, None)
assert result == {"jump_to": "end"}
def test_after_model_decorator() -> None:
"""Test after_model decorator with all configuration options."""
@after_model(
state_schema=CustomState,
tools=[test_tool],
can_jump_to=["model", "end"],
name="CustomAfterModel",
)
def custom_after_model(state: CustomState, runtime: Runtime) -> dict[str, Any]:
return {"jump_to": "model"}
# Verify all options were applied
assert isinstance(custom_after_model, AgentMiddleware)
assert custom_after_model.state_schema == CustomState
assert custom_after_model.tools == [test_tool]
assert getattr(custom_after_model.__class__.after_model, "__can_jump_to__", []) == [
"model",
"end",
]
assert custom_after_model.__class__.__name__ == "CustomAfterModel"
# Verify it works
result = custom_after_model.after_model(
{"messages": [HumanMessage("Hello"), AIMessage("Hi!")]}, None
)
assert result == {"jump_to": "model"}
def test_modify_model_request_decorator() -> None:
"""Test modify_model_request decorator with all configuration options."""
@modify_model_request(state_schema=CustomState, tools=[test_tool], name="CustomModifyRequest")
def custom_modify_request(
request: ModelRequest, state: CustomState, runtime: Runtime
) -> ModelRequest:
request.system_prompt = "Modified"
return request
# Verify all options were applied
assert isinstance(custom_modify_request, AgentMiddleware)
assert custom_modify_request.state_schema == CustomState
assert custom_modify_request.tools == [test_tool]
assert custom_modify_request.__class__.__name__ == "CustomModifyRequest"
# Verify it works
original_request = ModelRequest(
model="test-model",
system_prompt="Original",
messages=[HumanMessage("Hello")],
tool_choice=None,
tools=[],
response_format=None,
)
result = custom_modify_request.modify_model_request(
original_request, {"messages": [HumanMessage("Hello")]}, None
)
assert result.system_prompt == "Modified"
def test_all_decorators_integration() -> None:
"""Test all three decorators working together in an agent."""
call_order = []
@before_model
def track_before(state: AgentState, runtime: Runtime) -> None:
call_order.append("before")
return None
@modify_model_request
def track_modify(request: ModelRequest, state: AgentState, runtime: Runtime) -> ModelRequest:
call_order.append("modify")
return request
@after_model
def track_after(state: AgentState, runtime: Runtime) -> None:
call_order.append("after")
return None
agent = create_agent(
model=FakeToolCallingModel(), middleware=[track_before, track_modify, track_after]
)
agent = agent.compile()
agent.invoke({"messages": [HumanMessage("Hello")]})
assert call_order == ["before", "modify", "after"]
def test_decorators_use_function_names_as_default() -> None:
"""Test that decorators use function names as default middleware names."""
@before_model
def my_before_hook(state: AgentState, runtime: Runtime) -> None:
return None
@modify_model_request
def my_modify_hook(request: ModelRequest, state: AgentState, runtime: Runtime) -> ModelRequest:
return request
@after_model
def my_after_hook(state: AgentState, runtime: Runtime) -> None:
return None
# Verify that function names are used as middleware class names
assert my_before_hook.__class__.__name__ == "my_before_hook"
assert my_modify_hook.__class__.__name__ == "my_modify_hook"
assert my_after_hook.__class__.__name__ == "my_after_hook"
def test_hook_config_decorator_on_class_method() -> None:
"""Test hook_config decorator on AgentMiddleware class methods."""
class JumpMiddleware(AgentMiddleware):
@hook_config(can_jump_to=["end", "model"])
def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
if len(state["messages"]) > 5:
return {"jump_to": "end"}
return None
@hook_config(can_jump_to=["tools"])
def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
return {"jump_to": "tools"}
# Verify can_jump_to metadata is preserved
assert getattr(JumpMiddleware.before_model, "__can_jump_to__", []) == ["end", "model"]
assert getattr(JumpMiddleware.after_model, "__can_jump_to__", []) == ["tools"]
def test_can_jump_to_with_before_model_decorator() -> None:
"""Test can_jump_to parameter used with before_model decorator."""
@before_model(can_jump_to=["end"])
def conditional_before(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
if len(state["messages"]) > 3:
return {"jump_to": "end"}
return None
# Verify middleware was created and has can_jump_to metadata
assert isinstance(conditional_before, AgentMiddleware)
assert getattr(conditional_before.__class__.before_model, "__can_jump_to__", []) == ["end"]
def test_can_jump_to_with_after_model_decorator() -> None:
"""Test can_jump_to parameter used with after_model decorator."""
@after_model(can_jump_to=["model", "end"])
def conditional_after(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
if state["messages"][-1].content == "retry":
return {"jump_to": "model"}
return None
# Verify middleware was created and has can_jump_to metadata
assert isinstance(conditional_after, AgentMiddleware)
assert getattr(conditional_after.__class__.after_model, "__can_jump_to__", []) == [
"model",
"end",
]
def test_can_jump_to_integration() -> None:
"""Test can_jump_to parameter in a full agent."""
calls = []
@before_model(can_jump_to=["end"])
def early_exit(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
calls.append("early_exit")
if state["messages"][0].content == "exit":
return {"jump_to": "end"}
return None
agent = create_agent(model=FakeToolCallingModel(), middleware=[early_exit])
agent = agent.compile()
# Test with early exit
result = agent.invoke({"messages": [HumanMessage("exit")]})
assert calls == ["early_exit"]
assert len(result["messages"]) == 1
# Test without early exit
calls.clear()
result = agent.invoke({"messages": [HumanMessage("hello")]})
assert calls == ["early_exit"]
assert len(result["messages"]) > 1
# Async Decorator Tests
def test_async_before_model_decorator() -> None:
"""Test before_model decorator with async function."""
@before_model(state_schema=CustomState, tools=[test_tool], name="AsyncBeforeModel")
async def async_before_model(state: CustomState, runtime: Runtime) -> dict[str, Any]:
return {"custom_field": "async_value"}
assert isinstance(async_before_model, AgentMiddleware)
assert async_before_model.state_schema == CustomState
assert async_before_model.tools == [test_tool]
assert async_before_model.__class__.__name__ == "AsyncBeforeModel"
def test_async_after_model_decorator() -> None:
"""Test after_model decorator with async function."""
@after_model(state_schema=CustomState, tools=[test_tool], name="AsyncAfterModel")
async def async_after_model(state: CustomState, runtime: Runtime) -> dict[str, Any]:
return {"custom_field": "async_value"}
assert isinstance(async_after_model, AgentMiddleware)
assert async_after_model.state_schema == CustomState
assert async_after_model.tools == [test_tool]
assert async_after_model.__class__.__name__ == "AsyncAfterModel"
def test_async_modify_model_request_decorator() -> None:
"""Test modify_model_request decorator with async function."""
@modify_model_request(state_schema=CustomState, tools=[test_tool], name="AsyncModifyRequest")
async def async_modify_request(
request: ModelRequest, state: CustomState, runtime: Runtime
) -> ModelRequest:
request.system_prompt = "Modified async"
return request
assert isinstance(async_modify_request, AgentMiddleware)
assert async_modify_request.state_schema == CustomState
assert async_modify_request.tools == [test_tool]
assert async_modify_request.__class__.__name__ == "AsyncModifyRequest"
def test_mixed_sync_async_decorators() -> None:
"""Test decorators with both sync and async functions."""
@before_model(name="MixedBeforeModel")
def sync_before(state: AgentState, runtime: Runtime) -> None:
return None
@before_model(name="MixedBeforeModel")
async def async_before(state: AgentState, runtime: Runtime) -> None:
return None
@modify_model_request(name="MixedModifyRequest")
def sync_modify(request: ModelRequest, state: AgentState, runtime: Runtime) -> ModelRequest:
return request
@modify_model_request(name="MixedModifyRequest")
async def async_modify(
request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
return request
# Both should create valid middleware instances
assert isinstance(sync_before, AgentMiddleware)
assert isinstance(async_before, AgentMiddleware)
assert isinstance(sync_modify, AgentMiddleware)
assert isinstance(async_modify, AgentMiddleware)
@pytest.mark.asyncio
async def test_async_decorators_integration() -> None:
"""Test async decorators working together in an agent."""
call_order = []
@before_model
async def track_async_before(state: AgentState, runtime: Runtime) -> None:
call_order.append("async_before")
return None
@modify_model_request
async def track_async_modify(
request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
call_order.append("async_modify")
return request
@after_model
async def track_async_after(state: AgentState, runtime: Runtime) -> None:
call_order.append("async_after")
return None
agent = create_agent(
model=FakeToolCallingModel(),
middleware=[track_async_before, track_async_modify, track_async_after],
)
agent = agent.compile()
await agent.ainvoke({"messages": [HumanMessage("Hello")]})
assert call_order == ["async_before", "async_modify", "async_after"]
@pytest.mark.asyncio
async def test_mixed_sync_async_decorators_integration() -> None:
"""Test mixed sync/async decorators working together in an agent."""
call_order = []
@before_model
def track_sync_before(state: AgentState, runtime: Runtime) -> None:
call_order.append("sync_before")
return None
@before_model
async def track_async_before(state: AgentState, runtime: Runtime) -> None:
call_order.append("async_before")
return None
@modify_model_request
def track_sync_modify(
request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
call_order.append("sync_modify")
return request
@modify_model_request
async def track_async_modify(
request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
call_order.append("async_modify")
return request
@after_model
async def track_async_after(state: AgentState, runtime: Runtime) -> None:
call_order.append("async_after")
return None
@after_model
def track_sync_after(state: AgentState, runtime: Runtime) -> None:
call_order.append("sync_after")
return None
agent = create_agent(
model=FakeToolCallingModel(),
middleware=[
track_sync_before,
track_async_before,
track_sync_modify,
track_async_modify,
track_async_after,
track_sync_after,
],
)
agent = agent.compile()
await agent.ainvoke({"messages": [HumanMessage("Hello")]})
assert call_order == [
"sync_before",
"async_before",
"sync_modify",
"async_modify",
"sync_after",
"async_after",
]
def test_async_before_model_preserves_can_jump_to() -> None:
"""Test that can_jump_to metadata is preserved for async before_model functions."""
@before_model(can_jump_to=["end"])
async def async_conditional_before(
state: AgentState, runtime: Runtime
) -> dict[str, Any] | None:
if len(state["messages"]) > 3:
return {"jump_to": "end"}
return None
# Verify middleware was created and has can_jump_to metadata
assert isinstance(async_conditional_before, AgentMiddleware)
assert getattr(async_conditional_before.__class__.abefore_model, "__can_jump_to__", []) == [
"end"
]
def test_async_after_model_preserves_can_jump_to() -> None:
"""Test that can_jump_to metadata is preserved for async after_model functions."""
@after_model(can_jump_to=["model", "end"])
async def async_conditional_after(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
if state["messages"][-1].content == "retry":
return {"jump_to": "model"}
return None
# Verify middleware was created and has can_jump_to metadata
assert isinstance(async_conditional_after, AgentMiddleware)
assert getattr(async_conditional_after.__class__.aafter_model, "__can_jump_to__", []) == [
"model",
"end",
]
@pytest.mark.asyncio
async def test_async_can_jump_to_integration() -> None:
"""Test can_jump_to parameter in a full agent with async middleware."""
calls = []
@before_model(can_jump_to=["end"])
async def async_early_exit(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
calls.append("async_early_exit")
if state["messages"][0].content == "exit":
return {"jump_to": "end"}
return None
agent = create_agent(model=FakeToolCallingModel(), middleware=[async_early_exit])
agent = agent.compile()
# Test with early exit
result = await agent.ainvoke({"messages": [HumanMessage("exit")]})
assert calls == ["async_early_exit"]
assert len(result["messages"]) == 1
# Test without early exit
calls.clear()
result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
assert calls == ["async_early_exit"]
assert len(result["messages"]) > 1
def test_get_can_jump_to_no_false_positives() -> None:
"""Test that _get_can_jump_to doesn't return false positives for base class methods."""
# Middleware with no overridden methods should return empty list
class EmptyMiddleware(AgentMiddleware):
pass
empty_middleware = EmptyMiddleware()
empty_middleware.tools = []
# Should not return any jump destinations for base class methods
assert _get_can_jump_to(empty_middleware, "before_model") == []
assert _get_can_jump_to(empty_middleware, "after_model") == []
def test_get_can_jump_to_only_overridden_methods() -> None:
"""Test that _get_can_jump_to only checks overridden methods."""
# Middleware with only sync method overridden
class SyncOnlyMiddleware(AgentMiddleware):
@hook_config(can_jump_to=["end"])
def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
return None
sync_middleware = SyncOnlyMiddleware()
sync_middleware.tools = []
# Should return can_jump_to from overridden sync method
assert _get_can_jump_to(sync_middleware, "before_model") == ["end"]
# Middleware with only async method overridden
class AsyncOnlyMiddleware(AgentMiddleware):
@hook_config(can_jump_to=["model"])
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
return None
async_middleware = AsyncOnlyMiddleware()
async_middleware.tools = []
# Should return can_jump_to from overridden async method
assert _get_can_jump_to(async_middleware, "after_model") == ["model"]
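
For readers of these assertions, a rough sketch of the behavior they pin down follows. It assumes the helper simply reads __can_jump_to__ off whichever hook (sync first, then async) the middleware class actually overrides; the real _get_can_jump_to may be implemented differently.

def _get_can_jump_to_sketch(middleware: AgentMiddleware, hook_name: str) -> list[str]:
    """Illustrative only: mirrors the behavior asserted in the tests above."""
    for attr in (hook_name, f"a{hook_name}"):
        override = getattr(type(middleware), attr, None)
        base = getattr(AgentMiddleware, attr, None)
        if override is not None and override is not base:
            # Only hooks overridden on the subclass contribute jump destinations.
            return list(getattr(override, "__can_jump_to__", []))
    return []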
def test_async_middleware_with_can_jump_to_graph_snapshot(snapshot: SnapshotAssertion) -> None:
"""Test that async middleware with can_jump_to creates correct graph structure with conditional edges."""
# Test 1: Async before_model with can_jump_to
@before_model(can_jump_to=["end"])
async def async_before_with_jump(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
if len(state["messages"]) > 5:
return {"jump_to": "end"}
return None
agent_async_before = create_agent(
model=FakeToolCallingModel(), middleware=[async_before_with_jump]
)
assert agent_async_before.compile().get_graph().draw_mermaid() == snapshot
# Test 2: Async after_model with can_jump_to
@after_model(can_jump_to=["model", "end"])
async def async_after_with_jump(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
if state["messages"][-1].content == "retry":
return {"jump_to": "model"}
return None
agent_async_after = create_agent(
model=FakeToolCallingModel(), middleware=[async_after_with_jump]
)
assert agent_async_after.compile().get_graph().draw_mermaid() == snapshot
# Test 3: Multiple async middleware with can_jump_to
@before_model(can_jump_to=["end"])
async def async_before_early_exit(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
return None
@after_model(can_jump_to=["model"])
async def async_after_retry(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
return None
agent_multiple_async = create_agent(
model=FakeToolCallingModel(),
middleware=[async_before_early_exit, async_after_retry],
)
assert agent_multiple_async.compile().get_graph().draw_mermaid() == snapshot
# Test 4: Mixed sync and async middleware with can_jump_to
@before_model(can_jump_to=["end"])
def sync_before_with_jump(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
return None
@after_model(can_jump_to=["model", "end"])
async def async_after_with_jumps(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
return None
agent_mixed = create_agent(
model=FakeToolCallingModel(),
middleware=[sync_before_with_jump, async_after_with_jumps],
)
assert agent_mixed.compile().get_graph().draw_mermaid() == snapshot
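
When one of these snapshots drifts, the quickest way to see what changed is to render the mermaid source directly, reusing any middleware defined above and the same compile()/get_graph() calls the test already makes:

agent = create_agent(model=FakeToolCallingModel(), middleware=[async_before_with_jump])
print(agent.compile().get_graph().draw_mermaid())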

File diff suppressed because it is too large.

View File

@@ -1,58 +0,0 @@
from collections.abc import Callable
from typing import Union
import pytest
from pydantic import BaseModel
from syrupy.assertion import SnapshotAssertion
from langchain.agents import create_agent
from .model import FakeToolCallingModel
model = FakeToolCallingModel()
def tool() -> None:
"""Testing tool."""
def pre_model_hook() -> None:
"""Pre-model hook."""
def post_model_hook() -> None:
"""Post-model hook."""
class ResponseFormat(BaseModel):
"""Response format for the agent."""
result: str
@pytest.mark.parametrize("tools", [[], [tool]])
@pytest.mark.parametrize("pre_model_hook", [None, pre_model_hook])
@pytest.mark.parametrize("post_model_hook", [None, post_model_hook])
def test_react_agent_graph_structure(
snapshot: SnapshotAssertion,
tools: list[Callable],
pre_model_hook: Union[Callable, None],
post_model_hook: Union[Callable, None],
) -> None:
agent = create_agent(
model,
tools=tools,
pre_model_hook=pre_model_hook,
post_model_hook=post_model_hook,
)
try:
assert agent.get_graph().draw_mermaid(with_styles=False) == snapshot
except Exception as e:
msg = (
"The graph structure has changed. Please update the snapshot."
"Configuration used:\n"
f"tools: {tools}, "
f"pre_model_hook: {pre_model_hook}, "
f"post_model_hook: {post_model_hook}, "
)
raise ValueError(msg) from e

View File

@@ -7,6 +7,7 @@ from typing import Union
from langchain_core.messages import HumanMessage
from langchain.agents import create_agent
from langchain_core.tools import tool
from langchain.agents.structured_output import (
MultipleStructuredOutputsError,
ProviderStrategy,
@@ -74,12 +75,14 @@ location_json_schema = {
}
@tool
def get_weather() -> str:
"""Get the weather."""
return "The weather is sunny and 75°F."
@tool
def get_location() -> str:
"""Get the current location."""

View File

@@ -1,147 +0,0 @@
from __future__ import annotations
import pytest
# Skip this test since langgraph.prebuilt.responses is not available
pytest.skip("langgraph.prebuilt.responses not available", allow_module_level=True)
try:
from langchain_openai import ChatOpenAI
except ImportError:
skip_openai_integration_tests = True
else:
skip_openai_integration_tests = False
AGENT_PROMPT = "You are an HR assistant."
class ToolCalls(BaseSchema):
get_employee_role: int
get_employee_department: int
class AssertionByInvocation(BaseSchema):
prompt: str
tools_with_expected_calls: ToolCalls
expected_last_message: str
expected_structured_response: Optional[Dict[str, Any]]
llm_request_count: int
class TestCase(BaseSchema):
name: str
response_format: Union[Dict[str, Any], List[Dict[str, Any]]]
assertions_by_invocation: List[AssertionByInvocation]
class Employee(BaseModel):
name: str
role: str
department: str
EMPLOYEES: list[Employee] = [
Employee(name="Sabine", role="Developer", department="IT"),
Employee(name="Henrik", role="Product Manager", department="IT"),
Employee(name="Jessica", role="HR", department="People"),
]
TEST_CASES = load_spec("responses", as_model=TestCase)
def _make_tool(fn, *, name: str, description: str):
mock = MagicMock(side_effect=lambda *, name: fn(name=name))
InputModel = create_model(f"{name}_input", name=(str, ...))
@tool(name, description=description, args_schema=InputModel)
def _wrapped(name: str):
return mock(name=name)
return {"tool": _wrapped, "mock": mock}
@pytest.mark.skipif(skip_openai_integration_tests, reason="OpenAI integration tests are disabled.")
@pytest.mark.parametrize("case", TEST_CASES, ids=[c.name for c in TEST_CASES])
def test_responses_integration_matrix(case: TestCase) -> None:
if case.name == "asking for information that does not fit into the response format":
pytest.xfail(
"currently failing due to undefined behavior when model cannot conform to any of the structured response formats."
)
def get_employee_role(*, name: str) -> Optional[str]:
for e in EMPLOYEES:
if e.name == name:
return e.role
return None
def get_employee_department(*, name: str) -> Optional[str]:
for e in EMPLOYEES:
if e.name == name:
return e.department
return None
role_tool = _make_tool(
get_employee_role,
name="get_employee_role",
description="Get the employee role by name",
)
dept_tool = _make_tool(
get_employee_department,
name="get_employee_department",
description="Get the employee department by name",
)
response_format_spec = case.response_format
if isinstance(response_format_spec, dict):
response_format_spec = [response_format_spec]
# Unwrap nested schema objects
response_format_spec = [item.get("schema", item) for item in response_format_spec]
if len(response_format_spec) == 1:
tool_output = ToolStrategy(response_format_spec[0])
else:
tool_output = ToolStrategy({"oneOf": response_format_spec})
llm_request_count = 0
for assertion in case.assertions_by_invocation:
def on_request(request: httpx.Request) -> None:
nonlocal llm_request_count
llm_request_count += 1
http_client = httpx.Client(
event_hooks={"request": [on_request]},
)
model = ChatOpenAI(
model="gpt-4o",
temperature=0,
http_client=http_client,
)
agent = create_agent(
model,
tools=[role_tool["tool"], dept_tool["tool"]],
prompt=AGENT_PROMPT,
response_format=tool_output,
)
result = agent.invoke({"messages": [HumanMessage(assertion.prompt)]})
# Count tool calls
assert role_tool["mock"].call_count == assertion.tools_with_expected_calls.get_employee_role
assert (
dept_tool["mock"].call_count
== assertion.tools_with_expected_calls.get_employee_department
)
# Count LLM calls
assert llm_request_count == assertion.llm_request_count
# Check last message content
last_message = result["messages"][-1]
assert last_message.content == assertion.expected_last_message
# Check structured response
structured_response_json = result["structured_response"]
assert structured_response_json == assertion.expected_structured_response

View File

@@ -1,107 +0,0 @@
from __future__ import annotations
import pytest
# Skip this test since langgraph.prebuilt.responses is not available
pytest.skip("langgraph.prebuilt.responses not available", allow_module_level=True)
try:
from langchain_openai import ChatOpenAI
except ImportError:
skip_openai_integration_tests = True
else:
skip_openai_integration_tests = False
AGENT_PROMPT = """
You are a strict polling bot.
- Only use the "poll_job" tool until it returns { status: "succeeded" }.
- If status is "pending", call the tool again. Do not produce a final answer.
- When it is "succeeded", return exactly: "Attempts: <number>" with no extra text.
"""
class TestCase(BaseSchema):
name: str
return_direct: bool
response_format: Optional[Dict[str, Any]]
expected_tool_calls: int
expected_last_message: str
expected_structured_response: Optional[Dict[str, Any]]
TEST_CASES = load_spec("return_direct", as_model=TestCase)
def _make_tool(return_direct: bool):
attempts = 0
def _side_effect():
nonlocal attempts
attempts += 1
return {
"status": "succeeded" if attempts >= 10 else "pending",
"attempts": attempts,
}
mock = MagicMock(side_effect=_side_effect)
@tool(
"pollJob",
description=(
"Check the status of a long-running job. "
"Returns { status: 'pending' | 'succeeded', attempts: number }."
),
return_direct=return_direct,
)
def _wrapped():
return mock()
return {"tool": _wrapped, "mock": mock}
@pytest.mark.skipif(skip_openai_integration_tests, reason="OpenAI integration tests are disabled.")
@pytest.mark.parametrize("case", TEST_CASES, ids=[c.name for c in TEST_CASES])
def test_return_direct_integration_matrix(case: TestCase) -> None:
poll_tool = _make_tool(case.return_direct)
model = ChatOpenAI(
model="gpt-4o",
temperature=0,
)
if case.response_format:
agent = create_agent(
model,
tools=[poll_tool["tool"]],
prompt=AGENT_PROMPT,
response_format=ToolStrategy(case.response_format),
)
else:
agent = create_agent(
model,
tools=[poll_tool["tool"]],
prompt=AGENT_PROMPT,
)
result = agent.invoke(
{
"messages": [
HumanMessage("Poll the job until it's done and tell me how many attempts it took.")
]
}
)
# Count tool calls
assert poll_tool["mock"].call_count == case.expected_tool_calls
# Check last message content
last_message = result["messages"][-1]
assert last_message.content == case.expected_last_message
# Check structured response
if case.expected_structured_response is not None:
structured_response_json = result["structured_response"]
assert structured_response_json == case.expected_structured_response
else:
assert "structured_response" not in result

View File

@@ -0,0 +1,328 @@
"""Combined tests for agent specifications."""
import pytest
# Skip these tests since langgraph.prebuilt.responses is not available
pytest.skip("langgraph.prebuilt.responses not available", allow_module_level=True)
try:
from langchain_openai import ChatOpenAI
except ImportError:
skip_openai_integration_tests = True
else:
skip_openai_integration_tests = False
from pydantic import BaseModel, Field
from typing import Dict, Any, List, Optional, Union
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
from langchain.agents import create_agent
from langchain.agents.structured_output import ToolStrategy
from unittest.mock import MagicMock
# Import specification loading utilities
try:
from .specifications.responses import load_spec
from .specifications.return_direct import load_spec as load_return_direct_spec
except ImportError:
# Fallback if specifications are not available
def load_spec(name, as_model=None):
return []
def load_return_direct_spec(name, as_model=None):
return []
# Test data models for responses specification
class ToolCalls(BaseModel):
get_employee_role: int
get_employee_department: int
class AssertionByInvocation(BaseModel):
prompt: str
tools_with_expected_calls: ToolCalls
expected_last_message: str
expected_structured_response: Optional[Dict[str, Any]]
llm_request_count: int
class TestCase(BaseModel):
name: str
response_format: Union[Dict[str, Any], List[Dict[str, Any]]]
assertions_by_invocation: List[AssertionByInvocation]
class Employee(BaseModel):
name: str
role: str
department: str
EMPLOYEES: list[Employee] = [
Employee(name="Sabine", role="Developer", department="IT"),
Employee(name="Henrik", role="Product Manager", department="IT"),
Employee(name="Jessica", role="HR", department="People"),
]
TEST_CASES = load_spec("responses", as_model=TestCase)
# Test data models for return_direct specification
class ReturnDirectTestCase(BaseModel):
name: str
return_direct: bool
response_format: Optional[Dict[str, Any]]
expected_tool_calls: int
expected_last_message: str
expected_structured_response: Optional[Dict[str, Any]]
RETURN_DIRECT_TEST_CASES = load_return_direct_spec("return_direct", as_model=ReturnDirectTestCase)
# Test tools
@tool
def get_employee_role(employee_name: str) -> str:
"""Get the role of an employee."""
for emp in EMPLOYEES:
if emp.name == employee_name:
return emp.role
return "Employee not found"
@tool
def get_employee_department(employee_name: str) -> str:
"""Get the department of an employee."""
for emp in EMPLOYEES:
if emp.name == employee_name:
return emp.department
return "Employee not found"
@tool
def poll_job() -> Dict[str, Any]:
"""Poll a job status."""
# This will be mocked in tests
return {"status": "pending", "attempts": 1}
# Responses specification tests
class TestResponsesSpecification:
"""Test responses specification functionality."""
def test_responses_specification_loading(self) -> None:
"""Test that responses specification can be loaded."""
assert isinstance(TEST_CASES, list)
# If specifications are available, we should have test cases
# If not, the list will be empty due to the fallback
@pytest.mark.skipif(skip_openai_integration_tests, reason="OpenAI not available")
def test_responses_specification_with_openai(self) -> None:
"""Test responses specification with OpenAI model."""
if not TEST_CASES:
pytest.skip("No test cases available")
# This would run the actual specification tests if available
# For now, just verify the structure
for test_case in TEST_CASES:
assert hasattr(test_case, "name")
assert hasattr(test_case, "response_format")
assert hasattr(test_case, "assertions_by_invocation")
# Return direct specification tests
class TestReturnDirectSpecification:
"""Test return direct specification functionality."""
def test_return_direct_specification_loading(self) -> None:
"""Test that return direct specification can be loaded."""
assert isinstance(RETURN_DIRECT_TEST_CASES, list)
@pytest.mark.skipif(skip_openai_integration_tests, reason="OpenAI not available")
def test_return_direct_specification_with_openai(self) -> None:
"""Test return direct specification with OpenAI model."""
if not RETURN_DIRECT_TEST_CASES:
pytest.skip("No test cases available")
# This would run the actual specification tests if available
# For now, just verify the structure
for test_case in RETURN_DIRECT_TEST_CASES:
assert hasattr(test_case, "name")
assert hasattr(test_case, "return_direct")
assert hasattr(test_case, "expected_tool_calls")
# Tool strategy tests
class TestToolStrategy:
"""Test ToolStrategy functionality."""
def test_tool_strategy_basic_creation(self) -> None:
"""Test basic ToolStrategy creation."""
strategy = ToolStrategy(schema=Employee)
assert strategy.schema == Employee
assert strategy.tool_message_content is None
assert len(strategy.schema_specs) == 1
def test_tool_strategy_with_tool_message_content(self) -> None:
"""Test ToolStrategy with tool message content."""
strategy = ToolStrategy(schema=Employee, tool_message_content="custom message")
assert strategy.schema == Employee
assert strategy.tool_message_content == "custom message"
def test_tool_strategy_with_union_schema(self) -> None:
"""Test ToolStrategy with Union schema."""
class CustomModel(BaseModel):
value: float
description: str
strategy = ToolStrategy(schema=Union[Employee, CustomModel])
assert len(strategy.schema_specs) == 2
assert strategy.schema_specs[0].schema == Employee
assert strategy.schema_specs[1].schema == CustomModel
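
Tying these unit checks back to usage: a minimal sketch of passing a ToolStrategy to create_agent as the response_format, reusing the schema and tools defined in this file (the model identifier and the prompt text are illustrative only):

agent = create_agent(
    model="openai:gpt-4o",  # illustrative; any supported chat model id or instance should work
    tools=[get_employee_role, get_employee_department],
    response_format=ToolStrategy(schema=Employee),
)
result = agent.invoke({"messages": [HumanMessage("What role does Sabine have?")]})
print(result["structured_response"])  # expected to match the Employee schema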
# Agent with specifications tests
class TestAgentWithSpecifications:
"""Test agents with various specifications."""
def test_agent_with_employee_schema(self) -> None:
"""Test agent with employee schema."""
# Mock model for testing
class MockModel:
def invoke(self, messages, **kwargs):
return HumanMessage(content="Mock response")
agent = create_agent(
model=MockModel(),
tools=[get_employee_role, get_employee_department],
response_format=ToolStrategy(schema=Employee),
)
# Test that agent can be created
assert agent is not None
def test_agent_with_polling_tool(self) -> None:
"""Test agent with polling tool."""
# Mock the polling tool
mock_poll = MagicMock()
mock_poll.side_effect = [
{"status": "pending", "attempts": 1},
{"status": "pending", "attempts": 2},
{"status": "succeeded", "attempts": 3},
]
@tool
def mock_poll_job() -> Dict[str, Any]:
"""Mock polling tool."""
return mock_poll()
class MockModel:
def invoke(self, messages, **kwargs):
return HumanMessage(content="Mock response")
agent = create_agent(
model=MockModel(),
tools=[mock_poll_job],
)
# Test that agent can be created
assert agent is not None
def test_agent_with_return_direct_tool(self) -> None:
"""Test agent with return_direct tool."""
@tool
def return_direct_tool(input: str) -> str:
"""Tool that returns directly."""
return f"Direct result: {input}"
class MockModel:
def invoke(self, messages, **kwargs):
return HumanMessage(content="Mock response")
agent = create_agent(
model=MockModel(),
tools=[return_direct_tool],
)
# Test that agent can be created
assert agent is not None
# Specification validation tests
class TestSpecificationValidation:
"""Test specification validation."""
def test_employee_schema_validation(self) -> None:
"""Test employee schema validation."""
# Valid employee
emp = Employee(name="Test", role="Developer", department="IT")
assert emp.name == "Test"
assert emp.role == "Developer"
assert emp.department == "IT"
# Invalid employee (missing required fields)
with pytest.raises(Exception):
Employee(name="Test") # Missing role and department
def test_tool_calls_schema_validation(self) -> None:
"""Test tool calls schema validation."""
tool_calls = ToolCalls(get_employee_role=1, get_employee_department=2)
assert tool_calls.get_employee_role == 1
assert tool_calls.get_employee_department == 2
def test_assertion_schema_validation(self) -> None:
"""Test assertion schema validation."""
tool_calls = ToolCalls(get_employee_role=1, get_employee_department=2)
assertion = AssertionByInvocation(
prompt="Test prompt",
tools_with_expected_calls=tool_calls,
expected_last_message="Expected message",
expected_structured_response={"key": "value"},
llm_request_count=1,
)
assert assertion.prompt == "Test prompt"
assert assertion.llm_request_count == 1
# Integration tests (when specifications are available)
class TestSpecificationIntegration:
"""Test specification integration."""
def test_specification_file_loading(self) -> None:
"""Test that specification files can be loaded."""
# This test verifies that the specification loading mechanism works
# even if the actual specification files are not available
try:
responses_spec = load_spec("responses")
return_direct_spec = load_return_direct_spec("return_direct")
assert isinstance(responses_spec, list)
assert isinstance(return_direct_spec, list)
except Exception:
# If specifications are not available, that's okay for this test
pass
def test_specification_with_mock_data(self) -> None:
"""Test specifications with mock data."""
# Create mock test cases
mock_tool_calls = ToolCalls(get_employee_role=1, get_employee_department=1)
mock_assertion = AssertionByInvocation(
prompt="Test prompt",
tools_with_expected_calls=mock_tool_calls,
expected_last_message="Expected message",
expected_structured_response=None,
llm_request_count=1,
)
mock_test_case = TestCase(
name="Mock test",
response_format={"type": "json_object"},
assertions_by_invocation=[mock_assertion],
)
# Verify mock data structure
assert mock_test_case.name == "Mock test"
assert len(mock_test_case.assertions_by_invocation) == 1
assert mock_test_case.assertions_by_invocation[0].prompt == "Test prompt"

View File

@@ -7,7 +7,6 @@ from typing import (
Any,
NoReturn,
TypeVar,
Union,
)
import pytest
@@ -33,14 +32,13 @@ from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import TypedDict
from langchain.tools import (
ToolNode,
InjectedState,
InjectedStore,
ToolNode,
)
from langchain.tools.tool_node import TOOL_CALL_ERROR_TEMPLATE, ToolInvocationError, tools_condition
from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
from .model import FakeToolCallingModel
from tests.unit_tests.agents.messages import _AnyIdHumanMessage, _AnyIdToolMessage
from tests.unit_tests.agents.model import FakeToolCallingModel
pytestmark = pytest.mark.anyio
@@ -1374,7 +1372,7 @@ def test_tool_node_inject_store() -> None:
def test_tool_node_ensure_utf8() -> None:
@dec_tool
def get_day_list(days: list[str]) -> list[str]:
"""choose days"""
"""Choose days"""
return days
data = ["星期一", "水曜日", "목요일", "Friday"]