Compare commits

...

9 Commits

Author SHA1 Message Date
Sydney Runkle
984e1e984f Merge branch 'master' into sr/create-agent-api 2025-10-03 11:15:26 -04:00
Sydney Runkle
e529445b29 no printing 2025-10-03 11:12:57 -04:00
Sydney Runkle
5b972fb0e9 fixes 2025-10-03 11:11:52 -04:00
Sydney Runkle
b8a20329d8 fixing up tests 2025-10-03 10:39:48 -04:00
Sydney Runkle
54e507f331 more fixes 2025-10-03 10:32:22 -04:00
Sydney Runkle
6001543093 fix tests + support builtins 2025-10-03 10:31:22 -04:00
Sydney Runkle
894ffa0be5 minor fixes 2025-10-03 10:04:47 -04:00
Sydney Runkle
56c93fa82f adding back some tests 2025-10-03 09:31:42 -04:00
Sydney Runkle
3743d596c7 beginnings of a refactor 2025-10-03 09:06:47 -04:00
13 changed files with 1032 additions and 3067 deletions

View File

@@ -1,6 +1,7 @@
"""langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools."""
from langchain.agents.react_agent import AgentState, create_agent
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentState
__all__ = [
"AgentState",

View File

@@ -1,4 +1,4 @@
"""Middleware agent implementation."""
"""Agent factory implementation."""
import itertools
from collections.abc import Callable, Sequence
@@ -9,9 +9,10 @@ from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMe
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langgraph.constants import END, START
from langgraph.graph.state import StateGraph
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.runtime import Runtime
from langgraph.types import Send
from langgraph.store.base import BaseStore
from langgraph.types import Checkpointer, Send
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
@@ -167,17 +168,142 @@ def _handle_structured_output_error(
def create_agent( # noqa: PLR0915
*,
model: str | BaseChatModel,
tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
*,
system_prompt: str | None = None,
middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (),
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
context_schema: type[ContextT] | None = None,
) -> StateGraph[
checkpointer: Checkpointer | None = None,
store: BaseStore | None = None,
interrupt_before: list[str] | None = None,
interrupt_after: list[str] | None = None,
debug: bool = False,
cache: Any = None,
name: str | None = None,
) -> CompiledStateGraph[
AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
]:
"""Create a middleware agent graph."""
"""Creates an agent graph that calls tools in a loop until a stopping condition is met.
For more details on using `create_agent`,
visit [Agents](https://langchain-ai.github.io/langgraph/agents/overview/) documentation.
Args:
model: The language model for the agent. Supports static model
selection with string identifier (e.g., `"openai:gpt-4"`) or
chat model instance (e.g., `ChatOpenAI()`).
tools: A list of tools or a ToolNode instance.
If an empty list is provided, the agent will consist of a single LLM node
without tool calling.
system_prompt: An optional system prompt for the LLM. This is converted to a
SystemMessage and added to the beginning of the list of messages in
state["messages"].
middleware: An optional sequence of AgentMiddleware instances for customizing
agent behavior. Middleware can:
- **Intercept execution** with before_model/after_model hooks
- **Modify model requests** with modify_model_request
- **Add custom state** via state_schema
- **Control flow** with jump_to (using @hook_config decorator)
See AgentMiddleware documentation for implementation details.
response_format: An optional configuration for structured responses.
If provided, the agent will handle structured output via tool calls
during the normal conversation flow.
When the model calls a structured output tool, the response will be captured
and returned in the 'structured_response' state key.
If not provided, `structured_response` will not be present in the output state.
Can be one of:
- **Pydantic model class**: Automatically converted to ToolStrategy or
ProviderStrategy based on model capabilities
- **ToolStrategy**: Uses tool calling for structured output with error handling
- **ProviderStrategy**: Uses provider-native structured output (OpenAI, etc.)
!!! important
`response_format` requires the model to support tool calling
!!! note
Structured responses are handled directly in the model call node via
tool calls, eliminating the need for separate structured response nodes.
context_schema: An optional schema for runtime context.
checkpointer: An optional checkpoint saver object. This is used for persisting
the state of the graph (e.g., as chat memory) for a single thread
(e.g., a single conversation).
store: An optional store object. This is used for persisting data
across multiple threads (e.g., multiple conversations / users).
interrupt_before: An optional list of node names to interrupt before.
Should be one of the following: "model_request", "tools".
This is useful if you want to add a user confirmation or other interrupt
before taking an action.
interrupt_after: An optional list of node names to interrupt after.
Should be one of the following: "model_request", "tools".
This is useful if you want to return directly or run additional processing on an output.
debug: A flag indicating whether to enable debug mode.
cache: An optional cache object for caching LLM responses.
name: An optional name for the compiled graph.
Returns:
A compiled LangGraph agent that can be used for chat interactions.
The "model_request" node calls the language model with the messages list
(after applying the system prompt if provided).
If the resulting AIMessage contains `tool_calls`,
the graph will then call the ["tools"][langgraph.prebuilt.tool_node.ToolNode].
The "tools" node executes the tools (1 tool per `tool_call`)
and adds the responses to the messages list as `ToolMessage` objects.
The model_request node then calls the language model again.
The process repeats until no more `tool_calls` are present in the response.
The agent then returns the full list of messages as a dictionary containing the key "messages".
``` mermaid
sequenceDiagram
participant U as User
participant A as LLM
participant T as Tools
U->>A: Initial input
Note over A: System Prompt + LLM
loop while tool_calls present
A->>T: Execute tools
T-->>A: ToolMessage for each tool_calls
end
A->>U: Return final state
```
Example:
```python
from langchain.agents import create_agent
def check_weather(location: str) -> str:
'''Return the weather forecast for the specified location.'''
return f"It's always sunny in {location}"
graph = create_agent(
model="anthropic:claude-3-7-sonnet-latest",
tools=[check_weather],
system_prompt="You are a helpful assistant",
)
inputs = {"messages": [{"role": "user", "content": "what is the weather in sf"}]}
for chunk in graph.stream(inputs, stream_mode="updates"):
print(chunk)
```
"""
# init chat model
if isinstance(model, str):
model = init_chat_model(model)
@@ -368,26 +494,8 @@ def create_agent( # noqa: PLR0915
def _get_bound_model(request: ModelRequest) -> Runnable:
"""Get the model with appropriate tool bindings."""
# Get actual tool objects from tool names
tools_by_name = {t.name: t for t in default_tools}
unknown_tools = [name for name in request.tools if name not in tools_by_name]
if unknown_tools:
available_tools = sorted(tools_by_name.keys())
msg = (
f"Middleware returned unknown tool names: {unknown_tools}\n\n"
f"Available tools: {available_tools}\n\n"
"To fix this issue:\n"
"1. Ensure the tools are passed to create_agent() via "
"the 'tools' parameter\n"
"2. If using custom middleware with tools, ensure "
"they're registered via middleware.tools attribute\n"
"3. Verify that tool names in ModelRequest.tools match "
"the actual tool.name values"
)
raise ValueError(msg)
requested_tools = [tools_by_name[name] for name in request.tools]
# request.tools contains BaseTool | dict objects
requested_tools = request.tools
if isinstance(response_format, ProviderStrategy):
# Use native structured output
@@ -411,7 +519,7 @@ def create_agent( # noqa: PLR0915
"""Sync model request handler with sequential middleware processing."""
request = ModelRequest(
model=model,
tools=[t.name for t in default_tools],
tools=default_tools,
system_prompt=system_prompt,
response_format=response_format,
messages=state["messages"],
@@ -438,17 +546,19 @@ def create_agent( # noqa: PLR0915
messages = [SystemMessage(request.system_prompt), *messages]
output = model_.invoke(messages)
return {
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
**_handle_model_output(output),
}
result = _handle_model_output(output)
# Always add call counts
result["thread_model_call_count"] = state.get("thread_model_call_count", 0) + 1
result["run_model_call_count"] = state.get("run_model_call_count", 0) + 1
return result
async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Async model request handler with sequential middleware processing."""
request = ModelRequest(
model=model,
tools=[t.name for t in default_tools],
tools=default_tools,
system_prompt=system_prompt,
response_format=response_format,
messages=state["messages"],
@@ -466,11 +576,13 @@ def create_agent( # noqa: PLR0915
messages = [SystemMessage(request.system_prompt), *messages]
output = await model_.ainvoke(messages)
return {
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
**_handle_model_output(output),
}
result = _handle_model_output(output)
# Always add call counts
result["thread_model_call_count"] = state.get("thread_model_call_count", 0) + 1
result["run_model_call_count"] = state.get("run_model_call_count", 0) + 1
return result
# Use sync or async based on model capabilities
from langgraph._internal._runnable import RunnableCallable
@@ -595,7 +707,15 @@ def create_agent( # noqa: PLR0915
can_jump_to=_get_can_jump_to(m1, "after_model"),
)
return graph
return graph.compile(
checkpointer=checkpointer,
store=store,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
debug=debug,
cache=cache,
name=name,
)
def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
@@ -626,8 +746,8 @@ def _fetch_last_ai_and_tool_messages(
def _make_model_to_tools_edge(
first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
) -> Callable[[dict[str, Any]], str | list[Send] | None]:
def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None:
) -> Callable[[dict[str, Any], Runtime], str | list[Send] | None]:
def model_to_tools(state: dict[str, Any], runtime: Runtime) -> str | list[Send] | None:
# 1. if there's an explicit jump_to in the state, use it
if jump_to := state.get("jump_to"):
return _resolve_jump(jump_to, first_node)
@@ -646,14 +766,19 @@ def _make_model_to_tools_edge(
if c["id"] not in tool_message_ids and c["name"] not in structured_output_tools
]
# 3. if there are pending tool calls, jump to the tool node
# 3. if there are pending (non-structured) tool calls, jump to the tool node
if pending_tool_calls:
pending_tool_calls = [
tool_node.inject_tool_args(call, state, None) for call in pending_tool_calls
tool_node.inject_tool_args(call, state, runtime.store)
for call in pending_tool_calls
]
return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]
# 4. AIMessage has tool calls, but there are no pending tool calls
# 4. if we have a structured response and no pending tool calls, we're done
if "structured_response" in state:
return END
# 5. AIMessage has tool calls, but there are no pending tool calls
# which suggests the injection of artificial tool messages. jump to the first node
return first_node
@@ -666,7 +791,8 @@ def _make_tools_to_model_edge(
def tools_to_model(state: dict[str, Any]) -> str | None:
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
if all(
# Check if any tool call has return_direct=True
if any(
tool_node.tools_by_name[c["name"]].return_direct
for c in last_ai_message.tool_calls
if c["name"] in tool_node.tools_by_name
@@ -692,7 +818,6 @@ def _add_middleware_edge(
Args:
graph: The graph to add the edge to.
method: The method to call for the middleware node.
name: The name of the middleware node.
default_destination: The default destination for the edge.
model_destination: The destination for the edge to the model.
@@ -716,3 +841,6 @@ def _add_middleware_edge(
else:
graph.add_edge(name, default_destination)
__all__ = ["create_agent"]

View File

@@ -62,7 +62,7 @@ class ModelRequest:
system_prompt: str | None
messages: list[AnyMessage] # excluding system prompt
tool_choice: Any | None
tools: list[str]
tools: list[BaseTool | dict[str, Any]]
response_format: ResponseFormat | None
model_settings: dict[str, Any] = field(default_factory=dict)

File diff suppressed because it is too large Load Diff

View File

@@ -108,4 +108,4 @@ class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
}
)
return self.bind(tools=tool_dicts)
return self.bind(tools=tool_dicts, **kwargs)

View File

@@ -0,0 +1,767 @@
"""Tests for create_agent functionality."""
import pytest
from langchain_core.messages import (
AIMessage,
HumanMessage,
ToolCall,
ToolMessage,
)
from langchain_core.tools import tool as dec_tool, InjectedToolCallId, ToolException
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore
from langgraph.types import Command
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from langchain.agents import (
create_agent,
)
from langchain.agents.middleware.types import AgentMiddleware, AgentState
from langchain.agents.structured_output import ToolStrategy
from langchain.tools import (
InjectedState,
InjectedStore,
ToolNode,
)
from langchain.tools.tool_node import (
_get_state_args,
_infer_handled_types,
)
from .any_str import AnyStr
from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
from .model import FakeToolCallingModel
pytestmark = pytest.mark.anyio
def test_no_prompt(sync_checkpointer: BaseCheckpointSaver) -> None:
"""Test agent with no system prompt."""
model = FakeToolCallingModel()
agent = create_agent(
model,
[],
checkpointer=sync_checkpointer,
)
inputs = [HumanMessage("hi?")]
thread = {"configurable": {"thread_id": "123"}}
response = agent.invoke({"messages": inputs}, thread, debug=True)
expected_response = {"messages": [*inputs, AIMessage(content="hi?", id="0")]}
assert response == expected_response
saved = sync_checkpointer.get_tuple(thread)
assert saved is not None
checkpoint_values = saved.checkpoint["channel_values"]
assert checkpoint_values["messages"] == [
_AnyIdHumanMessage(content="hi?"),
AIMessage(content="hi?", id="0"),
]
assert checkpoint_values["thread_model_call_count"] == 1
assert saved.metadata == {
"parents": {},
"source": "loop",
"step": 1,
}
assert saved.pending_writes == []
async def test_no_prompt_async(async_checkpointer: BaseCheckpointSaver) -> None:
"""Test agent with no system prompt (async)."""
model = FakeToolCallingModel()
agent = create_agent(model, [], checkpointer=async_checkpointer)
inputs = [HumanMessage("hi?")]
thread = {"configurable": {"thread_id": "123"}}
response = await agent.ainvoke({"messages": inputs}, thread, debug=True)
expected_response = {"messages": [*inputs, AIMessage(content="hi?", id="0")]}
assert response == expected_response
saved = await async_checkpointer.aget_tuple(thread)
assert saved is not None
checkpoint_values = saved.checkpoint["channel_values"]
assert checkpoint_values["messages"] == [
_AnyIdHumanMessage(content="hi?"),
AIMessage(content="hi?", id="0"),
]
assert checkpoint_values["thread_model_call_count"] == 1
assert saved.metadata == {
"parents": {},
"source": "loop",
"step": 1,
}
assert saved.pending_writes == []
def test_system_message_prompt() -> None:
"""Test agent with system message prompt."""
system_prompt = "Foo"
agent = create_agent(FakeToolCallingModel(), [], system_prompt=system_prompt)
inputs = [HumanMessage("hi?")]
response = agent.invoke({"messages": inputs})
expected_response = {"messages": [*inputs, AIMessage(content="Foo-hi?", id="0", tool_calls=[])]}
assert response == expected_response
def test_string_prompt() -> None:
"""Test agent with string system prompt."""
system_prompt = "Foo"
agent = create_agent(FakeToolCallingModel(), [], system_prompt=system_prompt)
inputs = [HumanMessage("hi?")]
response = agent.invoke({"messages": inputs})
expected_response = {"messages": [*inputs, AIMessage(content="Foo-hi?", id="0", tool_calls=[])]}
assert response == expected_response
def test__infer_handled_types() -> None:
"""Test type inference for exception handlers."""
def handle(e) -> str: # type: ignore[no-untyped-def]
return ""
def handle2(e: Exception) -> str:
return ""
def handle3(e: ValueError | ToolException) -> str: # noqa: UP007
return ""
def handle4(e: ValueError | ToolException) -> str: # noqa: UP007
return ""
class Handler:
def handle(self, e: ValueError) -> str:
return ""
handle5 = Handler().handle
def handle6(e: TypeError | ValueError | ToolException) -> str: # noqa: UP007
return ""
expected: tuple = (Exception,)
actual = _infer_handled_types(handle)
assert expected == actual
expected = (Exception,)
actual = _infer_handled_types(handle2)
assert expected == actual
expected = (ValueError, ToolException)
actual = _infer_handled_types(handle3)
assert expected == actual
expected = (ValueError, ToolException)
actual = _infer_handled_types(handle4)
assert expected == actual
expected = (ValueError,)
actual = _infer_handled_types(handle5)
assert expected == actual
expected = (TypeError, ValueError, ToolException)
actual = _infer_handled_types(handle6)
assert expected == actual
with pytest.raises(ValueError):
def handler(e: str) -> str:
return ""
_infer_handled_types(handler)
with pytest.raises(ValueError):
def handler(e: list[Exception]) -> str:
return ""
_infer_handled_types(handler)
with pytest.raises(ValueError):
def handler(e: str | int) -> str: # noqa: UP007
return ""
_infer_handled_types(handler)
def test_agent_with_structured_response() -> None:
"""Test agent with structured response format."""
class WeatherResponse(BaseModel):
temperature: float = Field(description="The temperature in fahrenheit")
tool_calls = [
[{"args": {}, "id": "1", "name": "get_weather"}],
[{"name": "WeatherResponse", "id": "2", "args": {"temperature": 75}}],
]
@dec_tool
def get_weather() -> str:
"""Get the weather"""
return "The weather is sunny and 75°F."
expected_structured_response = WeatherResponse(temperature=75)
model = FakeToolCallingModel[WeatherResponse](
tool_calls=tool_calls, structured_response=expected_structured_response
)
agent = create_agent(
model,
[get_weather],
response_format=WeatherResponse,
)
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
assert response["structured_response"] == expected_structured_response
assert len(response["messages"]) == 5
# Check message types in message history
msg_types = [m.type for m in response["messages"]]
assert msg_types == [
"human", # "What's the weather?"
"ai", # "What's the weather?"
"tool", # "The weather is sunny and 75°F."
"ai", # structured response
"tool", # artificial tool message
]
# Check last message contains "Returning structured response"
assert "Returning structured response" in response["messages"][-1].content
assert "temperature=75.0" in response["messages"][-1].content
class CustomState(AgentState):
"""Custom state for testing."""
user_name: str
def test_agent_update_state(
sync_checkpointer: BaseCheckpointSaver,
) -> None:
"""Test agent that updates custom state."""
@dec_tool
def get_user_name(tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
"""Retrieve user name"""
user_name = "interrupt" # Simplified - in real test this would use interrupt()
return Command(
update={
"user_name": user_name,
"messages": [
ToolMessage("Successfully retrieved user name", tool_call_id=tool_call_id)
],
}
)
class CustomMiddleware(AgentMiddleware[CustomState]):
"""Middleware to add custom state."""
state_schema = CustomState
def modify_model_request(self, request, state, runtime):
"""Add user name to system prompt if available."""
user_name = state.get("user_name")
if user_name:
system_msg = f"User name is {user_name}"
if request.system_prompt:
request.system_prompt = f"{request.system_prompt}\n\n{system_msg}"
else:
request.system_prompt = system_msg
return request
tool_calls = [[{"args": {}, "id": "1", "name": "get_user_name"}], []]
model = FakeToolCallingModel(tool_calls=tool_calls)
agent = create_agent(
model,
[get_user_name],
middleware=[CustomMiddleware()],
checkpointer=sync_checkpointer,
)
config = {"configurable": {"thread_id": "1"}}
# Run until tool is called
response = agent.invoke({"messages": [("user", "what's my name")]}, config)
# confirm that the state was updated
assert response["user_name"] == "interrupt"
assert len(response["messages"]) == 4
tool_message: ToolMessage = response["messages"][-2]
assert tool_message.content == "Successfully retrieved user name"
assert tool_message.tool_call_id == "1"
assert tool_message.name == "get_user_name"
def test_agent_parallel_tool_calls(
sync_checkpointer: BaseCheckpointSaver,
) -> None:
"""Test agent with parallel tool calls."""
human_assistance_execution_count = 0
@dec_tool
def human_assistance(query: str) -> str:
"""Request assistance from a human."""
nonlocal human_assistance_execution_count
human_assistance_execution_count += 1
return "Hello"
get_weather_execution_count = 0
@dec_tool
def get_weather(location: str) -> str:
"""Use this tool to get the weather."""
nonlocal get_weather_execution_count
get_weather_execution_count += 1
return "It's sunny!"
tool_calls = [
[
{"args": {"location": "sf"}, "id": "1", "name": "get_weather"},
{"args": {"query": "request help"}, "id": "2", "name": "human_assistance"},
],
[],
]
model = FakeToolCallingModel(tool_calls=tool_calls)
agent = create_agent(
model,
[human_assistance, get_weather],
checkpointer=sync_checkpointer,
)
config = {"configurable": {"thread_id": "1"}}
query = "Get user assistance and also check the weather"
message_types = []
for event in agent.stream({"messages": [("user", query)]}, config, stream_mode="values"):
if messages := event.get("messages"):
message_types.append([m.type for m in messages])
assert message_types == [
["human"],
["human", "ai"],
["human", "ai", "tool", "tool"],
["human", "ai", "tool", "tool", "ai"],
]
assert human_assistance_execution_count == 1
assert get_weather_execution_count == 1
class AgentStateExtraKey(AgentState):
"""Agent state with extra key."""
foo: int
def test_create_agent_inject_vars() -> None:
"""Test that the agent can inject state and store into tool functions."""
store = InMemoryStore()
namespace = ("test",)
store.put(namespace, "test_key", {"bar": 3})
class CustomMiddleware(AgentMiddleware[AgentStateExtraKey]):
"""Middleware to add custom state."""
state_schema = AgentStateExtraKey
def tool1(
some_val: int,
state: Annotated[dict, InjectedState],
store: Annotated[BaseStore, InjectedStore()],
) -> int:
"""Tool 1 docstring."""
store_val = store.get(namespace, "test_key").value["bar"]
return some_val + state["foo"] + store_val
tool_call = {
"name": "tool1",
"args": {"some_val": 1},
"id": "some 0",
"type": "tool_call",
}
model = FakeToolCallingModel(tool_calls=[[tool_call], []])
agent = create_agent(
model,
ToolNode([tool1], handle_tool_errors=False),
middleware=[CustomMiddleware()],
store=store,
)
result = agent.invoke({"messages": [{"role": "user", "content": "hi"}], "foo": 2})
assert result["messages"] == [
_AnyIdHumanMessage(content="hi"),
AIMessage(content="hi", tool_calls=[tool_call], id="0"),
_AnyIdToolMessage(content="6", name="tool1", tool_call_id="some 0"),
AIMessage("hi-hi-6", id="1"),
]
assert result["foo"] == 2
async def test_return_direct() -> None:
"""Test tools with return_direct flag."""
@dec_tool(return_direct=True)
def tool_return_direct(input: str) -> str:
"""A tool that returns directly."""
return f"Direct result: {input}"
@dec_tool
def tool_normal(input: str) -> str:
"""A normal tool."""
return f"Normal result: {input}"
first_tool_call = [
ToolCall(
name="tool_return_direct",
args={"input": "Test direct"},
id="1",
),
]
expected_ai = AIMessage(
content="Test direct",
id="0",
tool_calls=first_tool_call,
)
model = FakeToolCallingModel(tool_calls=[first_tool_call, []])
agent = create_agent(
model,
[tool_return_direct, tool_normal],
)
# Test direct return for tool_return_direct
result = agent.invoke({"messages": [HumanMessage(content="Test direct", id="hum0")]})
assert result["messages"] == [
HumanMessage(content="Test direct", id="hum0"),
expected_ai,
ToolMessage(
content="Direct result: Test direct",
name="tool_return_direct",
tool_call_id="1",
id=result["messages"][2].id,
),
]
second_tool_call = [
ToolCall(
name="tool_normal",
args={"input": "Test normal"},
id="2",
),
]
model = FakeToolCallingModel(tool_calls=[second_tool_call, []])
agent = create_agent(model, [tool_return_direct, tool_normal])
result = agent.invoke({"messages": [HumanMessage(content="Test normal", id="hum1")]})
assert result["messages"] == [
HumanMessage(content="Test normal", id="hum1"),
AIMessage(content="Test normal", id="0", tool_calls=second_tool_call),
ToolMessage(
content="Normal result: Test normal",
name="tool_normal",
tool_call_id="2",
id=result["messages"][2].id,
),
AIMessage(content="Test normal-Test normal-Normal result: Test normal", id="1"),
]
both_tool_calls = [
ToolCall(
name="tool_return_direct",
args={"input": "Test both direct"},
id="3",
),
ToolCall(
name="tool_normal",
args={"input": "Test both normal"},
id="4",
),
]
model = FakeToolCallingModel(tool_calls=[both_tool_calls, []])
agent = create_agent(model, [tool_return_direct, tool_normal])
result = agent.invoke({"messages": [HumanMessage(content="Test both", id="hum2")]})
assert result["messages"] == [
HumanMessage(content="Test both", id="hum2"),
AIMessage(content="Test both", id="0", tool_calls=both_tool_calls),
ToolMessage(
content="Direct result: Test both direct",
name="tool_return_direct",
tool_call_id="3",
id=result["messages"][2].id,
),
ToolMessage(
content="Normal result: Test both normal",
name="tool_normal",
tool_call_id="4",
id=result["messages"][3].id,
),
]
def test__get_state_args() -> None:
"""Test state argument extraction from tools."""
class Schema1(BaseModel):
a: Annotated[str, InjectedState]
class Schema2(Schema1):
b: Annotated[int, InjectedState("bar")]
@dec_tool(args_schema=Schema2)
def foo(a: str, b: int) -> float:
"""return"""
return 0.0
assert _get_state_args(foo) == {"a": None, "b": "bar"}
def test_agent_with_subgraph_tools(
sync_checkpointer: BaseCheckpointSaver,
) -> None:
"""Test agent with subgraph tools."""
from langgraph.graph import START, StateGraph
from typing_extensions import TypedDict
class State(TypedDict):
a: int
b: int
class Output(TypedDict):
result: int
# Define the subgraphs
def add(state):
return {"result": state["a"] + state["b"]}
add_subgraph = (
StateGraph(State, output_schema=Output).add_node(add).add_edge(START, "add").compile()
)
def multiply(state):
return {"result": state["a"] * state["b"]}
multiply_subgraph = (
StateGraph(State, output_schema=Output)
.add_node(multiply)
.add_edge(START, "multiply")
.compile()
)
# Add subgraphs as tools
def addition(a: int, b: int) -> int:
"""Add two numbers"""
return add_subgraph.invoke({"a": a, "b": b})["result"]
def multiplication(a: int, b: int) -> int:
"""Multiply two numbers"""
return multiply_subgraph.invoke({"a": a, "b": b})["result"]
model = FakeToolCallingModel(
tool_calls=[
[
{"args": {"a": 2, "b": 3}, "id": "1", "name": "addition"},
{"args": {"a": 2, "b": 3}, "id": "2", "name": "multiplication"},
],
[],
]
)
tool_node = ToolNode([addition, multiplication], handle_tool_errors=False)
agent = create_agent(
model,
tool_node,
checkpointer=sync_checkpointer,
)
result = agent.invoke(
{"messages": [HumanMessage(content="What's 2 + 3 and 2 * 3?")]},
config={"configurable": {"thread_id": "1"}},
)
assert result["messages"] == [
_AnyIdHumanMessage(content="What's 2 + 3 and 2 * 3?"),
AIMessage(
content="What's 2 + 3 and 2 * 3?",
id="0",
tool_calls=[
ToolCall(name="addition", args={"a": 2, "b": 3}, id="1"),
ToolCall(name="multiplication", args={"a": 2, "b": 3}, id="2"),
],
),
ToolMessage(content="5", name="addition", tool_call_id="1", id=result["messages"][2].id),
ToolMessage(
content="6",
name="multiplication",
tool_call_id="2",
id=result["messages"][3].id,
),
AIMessage(content="What's 2 + 3 and 2 * 3?-What's 2 + 3 and 2 * 3?-5-6", id="1"),
]
def test_agent_subgraph_streaming_sync() -> None:
"""Test React agent streaming when used as a subgraph node sync version."""
from langchain_core.runnables import RunnableConfig
from langgraph.graph import START, MessagesState, StateGraph
@dec_tool
def get_weather(city: str) -> str:
"""Get the weather of a city."""
return f"The weather of {city} is sunny."
# Create a React agent
model = FakeToolCallingModel(
tool_calls=[
[{"args": {"city": "Tokyo"}, "id": "1", "name": "get_weather"}],
[],
]
)
agent = create_agent(
model,
tools=[get_weather],
system_prompt="You are a helpful travel assistant.",
)
# Create a subgraph that uses the React agent as a node
def react_agent_node(state: MessagesState, config: RunnableConfig) -> MessagesState:
"""Node that runs the React agent and collects streaming output."""
collected_content = ""
# Stream the agent output and collect content
for msg_chunk, _msg_metadata in agent.stream(
{"messages": [("user", state["messages"][-1].content)]},
config,
stream_mode="messages",
):
if hasattr(msg_chunk, "content") and msg_chunk.content:
collected_content += msg_chunk.content
return {"messages": [("assistant", collected_content)]}
# Create the main workflow with the React agent as a subgraph node
workflow = StateGraph(MessagesState)
workflow.add_node("react_agent", react_agent_node)
workflow.add_edge(START, "react_agent")
workflow.add_edge("react_agent", "__end__")
compiled_workflow = workflow.compile()
# Test the streaming functionality
result = compiled_workflow.invoke({"messages": [("user", "What is the weather in Tokyo?")]})
# Verify the result contains expected structure
assert len(result["messages"]) == 2
assert result["messages"][0].content == "What is the weather in Tokyo?"
assert "assistant" in str(result["messages"][1])
# Test streaming with subgraphs = True
result = compiled_workflow.invoke(
{"messages": [("user", "What is the weather in Tokyo?")]},
subgraphs=True,
)
assert len(result["messages"]) == 2
events = []
for event in compiled_workflow.stream(
{"messages": [("user", "What is the weather in Tokyo?")]},
stream_mode="messages",
subgraphs=False,
):
events.append(event)
assert len(events) == 0
events = []
for event in compiled_workflow.stream(
{"messages": [("user", "What is the weather in Tokyo?")]},
stream_mode="messages",
subgraphs=True,
):
events.append(event)
assert len(events) == 3
namespace, (msg, metadata) = events[0]
# FakeToolCallingModel returns a single AIMessage with tool calls
# The content of the AIMessage reflects the input message
assert msg.content.startswith("You are a helpful travel assistant")
namespace, (msg, metadata) = events[1] # ToolMessage
assert msg.content.startswith("The weather of Tokyo is sunny.")
async def test_agent_subgraph_streaming() -> None:
"""Test React agent streaming when used as a subgraph node (async)."""
from langchain_core.runnables import RunnableConfig
from langgraph.graph import START, MessagesState, StateGraph
@dec_tool
def get_weather(city: str) -> str:
"""Get the weather of a city."""
return f"The weather of {city} is sunny."
# Create a React agent
model = FakeToolCallingModel(
tool_calls=[
[{"args": {"city": "Tokyo"}, "id": "1", "name": "get_weather"}],
[],
]
)
agent = create_agent(
model,
tools=[get_weather],
system_prompt="You are a helpful travel assistant.",
)
# Create a subgraph that uses the React agent as a node
async def react_agent_node(state: MessagesState, config: RunnableConfig) -> MessagesState:
"""Node that runs the React agent and collects streaming output."""
collected_content = ""
# Stream the agent output and collect content
async for msg_chunk, _msg_metadata in agent.astream(
{"messages": [("user", state["messages"][-1].content)]},
config,
stream_mode="messages",
):
if hasattr(msg_chunk, "content") and msg_chunk.content:
collected_content += msg_chunk.content
return {"messages": [("assistant", collected_content)]}
# Create the main workflow with the React agent as a subgraph node
workflow = StateGraph(MessagesState)
workflow.add_node("react_agent", react_agent_node)
workflow.add_edge(START, "react_agent")
workflow.add_edge("react_agent", "__end__")
compiled_workflow = workflow.compile()
# Test the streaming functionality
result = await compiled_workflow.ainvoke(
{"messages": [("user", "What is the weather in Tokyo?")]}
)
# Verify the result contains expected structure
assert len(result["messages"]) == 2
assert result["messages"][0].content == "What is the weather in Tokyo?"
assert "assistant" in str(result["messages"][1])
# Test streaming with subgraphs = True
result = await compiled_workflow.ainvoke(
{"messages": [("user", "What is the weather in Tokyo?")]},
subgraphs=True,
)
assert len(result["messages"]) == 2
events = []
async for event in compiled_workflow.astream(
{"messages": [("user", "What is the weather in Tokyo?")]},
stream_mode="messages",
subgraphs=False,
):
events.append(event)
assert len(events) == 0
events = []
async for event in compiled_workflow.astream(
{"messages": [("user", "What is the weather in Tokyo?")]},
stream_mode="messages",
subgraphs=True,
):
events.append(event)
assert len(events) == 3
namespace, (msg, metadata) = events[0]
# FakeToolCallingModel returns a single AIMessage with tool calls
# The content of the AIMessage reflects the input message
assert msg.content.startswith("You are a helpful travel assistant")
namespace, (msg, metadata) = events[1] # ToolMessage
assert msg.content.startswith("The weather of Tokyo is sunny.")

View File

@@ -51,7 +51,7 @@ from langchain.agents.middleware.types import (
OmitFromOutput,
PrivateStateAttr,
)
from langchain.agents.middleware_agent import create_agent
from langchain.agents.factory import create_agent
from langchain.agents.structured_output import ToolStrategy
from langchain.tools import InjectedState
@@ -133,7 +133,7 @@ def test_create_agent_diagram(
system_prompt="You are a helpful assistant.",
)
assert agent_zero.compile().get_graph().draw_mermaid() == snapshot
assert agent_zero.get_graph().draw_mermaid() == snapshot
agent_one = create_agent(
model=FakeToolCallingModel(),
@@ -142,7 +142,7 @@ def test_create_agent_diagram(
middleware=[NoopOne()],
)
assert agent_one.compile().get_graph().draw_mermaid() == snapshot
assert agent_one.get_graph().draw_mermaid() == snapshot
agent_two = create_agent(
model=FakeToolCallingModel(),
@@ -151,7 +151,7 @@ def test_create_agent_diagram(
middleware=[NoopOne(), NoopTwo()],
)
assert agent_two.compile().get_graph().draw_mermaid() == snapshot
assert agent_two.get_graph().draw_mermaid() == snapshot
agent_three = create_agent(
model=FakeToolCallingModel(),
@@ -160,7 +160,7 @@ def test_create_agent_diagram(
middleware=[NoopOne(), NoopTwo(), NoopThree()],
)
assert agent_three.compile().get_graph().draw_mermaid() == snapshot
assert agent_three.get_graph().draw_mermaid() == snapshot
agent_four = create_agent(
model=FakeToolCallingModel(),
@@ -169,7 +169,7 @@ def test_create_agent_diagram(
middleware=[NoopFour()],
)
assert agent_four.compile().get_graph().draw_mermaid() == snapshot
assert agent_four.get_graph().draw_mermaid() == snapshot
agent_five = create_agent(
model=FakeToolCallingModel(),
@@ -178,7 +178,7 @@ def test_create_agent_diagram(
middleware=[NoopFour(), NoopFive()],
)
assert agent_five.compile().get_graph().draw_mermaid() == snapshot
assert agent_five.get_graph().draw_mermaid() == snapshot
agent_six = create_agent(
model=FakeToolCallingModel(),
@@ -187,7 +187,7 @@ def test_create_agent_diagram(
middleware=[NoopFour(), NoopFive(), NoopSix()],
)
assert agent_six.compile().get_graph().draw_mermaid() == snapshot
assert agent_six.get_graph().draw_mermaid() == snapshot
agent_seven = create_agent(
model=FakeToolCallingModel(),
@@ -196,7 +196,7 @@ def test_create_agent_diagram(
middleware=[NoopSeven()],
)
assert agent_seven.compile().get_graph().draw_mermaid() == snapshot
assert agent_seven.get_graph().draw_mermaid() == snapshot
agent_eight = create_agent(
model=FakeToolCallingModel(),
@@ -205,7 +205,7 @@ def test_create_agent_diagram(
middleware=[NoopSeven(), NoopEight()],
)
assert agent_eight.compile().get_graph().draw_mermaid() == snapshot
assert agent_eight.get_graph().draw_mermaid() == snapshot
agent_nine = create_agent(
model=FakeToolCallingModel(),
@@ -214,7 +214,7 @@ def test_create_agent_diagram(
middleware=[NoopSeven(), NoopEight(), NoopNine()],
)
assert agent_nine.compile().get_graph().draw_mermaid() == snapshot
assert agent_nine.get_graph().draw_mermaid() == snapshot
agent_ten = create_agent(
model=FakeToolCallingModel(),
@@ -223,7 +223,7 @@ def test_create_agent_diagram(
middleware=[NoopTen()],
)
assert agent_ten.compile().get_graph().draw_mermaid() == snapshot
assert agent_ten.get_graph().draw_mermaid() == snapshot
agent_eleven = create_agent(
model=FakeToolCallingModel(),
@@ -232,7 +232,7 @@ def test_create_agent_diagram(
middleware=[NoopTen(), NoopEleven()],
)
assert agent_eleven.compile().get_graph().draw_mermaid() == snapshot
assert agent_eleven.get_graph().draw_mermaid() == snapshot
def test_create_agent_invoke(
@@ -281,7 +281,8 @@ def test_create_agent_invoke(
tools=[my_tool],
system_prompt="You are a helpful assistant.",
middleware=[NoopSeven(), NoopEight()],
).compile(checkpointer=sync_checkpointer)
checkpointer=sync_checkpointer,
)
thread1 = {"configurable": {"thread_id": "1"}}
assert agent_one.invoke({"messages": ["hello"]}, thread1) == {
@@ -370,7 +371,8 @@ def test_create_agent_jump(
tools=[my_tool],
system_prompt="You are a helpful assistant.",
middleware=[NoopSeven(), NoopEight()],
).compile(checkpointer=sync_checkpointer)
checkpointer=sync_checkpointer,
)
if isinstance(sync_checkpointer, InMemorySaver):
assert agent_one.get_graph().draw_mermaid() == snapshot
@@ -1214,14 +1216,13 @@ def test_modify_model_request() -> None:
request.messages.append(HumanMessage("remember to be nice!"))
return request
builder = create_agent(
agent = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[ModifyMiddleware()],
)
agent = builder.compile()
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert result["messages"][0].content == "Hello"
assert result["messages"][1].content == "remember to be nice!"
@@ -1268,9 +1269,8 @@ def test_tools_to_model_edge_with_structured_and_regular_tool_calls():
response_format=ToolStrategy(schema=WeatherResponse),
)
# Compile and invoke the agent
compiled_agent = agent.compile()
result = compiled_agent.invoke(
# Invoke the agent
result = agent.invoke(
{"messages": [HumanMessage("What's the weather and help me with a query?")]}
)
@@ -1323,7 +1323,7 @@ def test_public_private_state_for_custom_middleware() -> None:
return {"omit_input": "test", "omit_output": "test", "private_state": "test"}
agent = create_agent(model=FakeToolCallingModel(), middleware=[CustomMiddleware()])
agent = agent.compile()
# agent is already compiled
result = agent.invoke(
{
"messages": [HumanMessage("Hello")],
@@ -1358,7 +1358,7 @@ def test_runtime_injected_into_middleware() -> None:
middleware = CustomMiddleware()
agent = create_agent(model=FakeToolCallingModel(), middleware=[CustomMiddleware()])
agent = agent.compile()
# agent is already compiled
agent.invoke({"messages": [HumanMessage("Hello")]})
@@ -1389,7 +1389,7 @@ def test_injected_state_in_middleware_agent() -> None:
tools=[test_state],
system_prompt="You are a helpful assistant.",
middleware=[TestMiddleware()],
).compile()
)
result = agent.invoke(
{"test_state": "I love pizza", "messages": [HumanMessage("Call the test state tool")]}
@@ -1419,7 +1419,7 @@ def test_jump_to_is_ephemeral() -> None:
return {"jump_to": "model"}
agent = create_agent(model=FakeToolCallingModel(), middleware=[MyMiddleware()])
agent = agent.compile()
# agent is already compiled
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "jump_to" not in result
@@ -1550,7 +1550,7 @@ def test_planning_middleware_agent_creation_with_middleware() -> None:
)
middleware = PlanningMiddleware()
agent = create_agent(model=model, middleware=[middleware])
agent = agent.compile()
# agent is already compiled
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert result["todos"] == [{"content": "Task 1", "status": "completed"}]
@@ -1661,7 +1661,7 @@ def test_planning_middleware_custom_system_prompt() -> None:
)
agent = create_agent(model=model, middleware=[middleware])
agent = agent.compile()
# agent is already compiled
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert result["todos"] == [{"content": "Custom task", "status": "pending"}]
@@ -1735,7 +1735,8 @@ def test_thread_limit_with_create_agent():
model=model,
tools=[simple_tool],
middleware=[ModelCallLimitMiddleware(thread_limit=1)],
).compile(checkpointer=InMemorySaver())
checkpointer=InMemorySaver(),
)
# First invocation should work - 1 model call, within thread limit
result = agent.invoke(
@@ -1778,7 +1779,8 @@ def test_run_limit_with_create_agent():
model=model,
tools=[simple_tool],
middleware=[ModelCallLimitMiddleware(run_limit=1)],
).compile(checkpointer=InMemorySaver())
checkpointer=InMemorySaver(),
)
# This should hit the run limit after the first model call
result = agent.invoke(
@@ -1865,8 +1867,7 @@ def test_run_limit_resets_between_invocations() -> None:
tool_calls=[[], [], [], []]
) # No tool calls, so only model call per run
agent = create_agent(model=model, middleware=[middleware])
agent = agent.compile(checkpointer=InMemorySaver())
agent = create_agent(model=model, middleware=[middleware], checkpointer=InMemorySaver())
thread_config = {"configurable": {"thread_id": "test_thread"}}
agent.invoke({"messages": [HumanMessage("Hello")]}, thread_config)
@@ -1913,7 +1914,7 @@ async def test_create_agent_async_invoke() -> None:
tools=[my_tool],
system_prompt="You are a helpful assistant.",
middleware=[AsyncMiddleware()],
).compile()
)
result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
@@ -1969,7 +1970,7 @@ async def test_create_agent_async_invoke_multiple_middleware() -> None:
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[AsyncMiddlewareOne(), AsyncMiddlewareTwo()],
).compile()
)
result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
@@ -2010,7 +2011,7 @@ async def test_create_agent_async_jump() -> None:
tools=[my_tool],
system_prompt="You are a helpful assistant.",
middleware=[AsyncMiddlewareOne(), AsyncMiddlewareTwo()],
).compile()
)
result = await agent.ainvoke({"messages": []})
@@ -2049,7 +2050,7 @@ async def test_create_agent_mixed_sync_async_middleware() -> None:
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[SyncMiddleware(), AsyncMiddleware()],
).compile()
)
result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
@@ -2076,7 +2077,7 @@ def test_create_agent_sync_invoke_with_only_async_middleware_raises_error() -> N
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[AsyncOnlyMiddleware()],
).compile()
)
with pytest.raises(
TypeError,
@@ -2109,7 +2110,7 @@ def test_create_agent_sync_invoke_with_mixed_middleware() -> None:
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[MixedMiddleware()],
).compile()
)
result = agent.invoke({"messages": [HumanMessage("hello")]})

View File

@@ -19,7 +19,7 @@ from langchain.agents.middleware.types import (
modify_model_request,
hook_config,
)
from langchain.agents.middleware_agent import create_agent, _get_can_jump_to
from langchain.agents.factory import create_agent, _get_can_jump_to
from .model import FakeToolCallingModel
@@ -136,7 +136,7 @@ def test_all_decorators_integration() -> None:
agent = create_agent(
model=FakeToolCallingModel(), middleware=[track_before, track_modify, track_after]
)
agent = agent.compile()
# agent is already compiled
agent.invoke({"messages": [HumanMessage("Hello")]})
assert call_order == ["before", "modify", "after"]
@@ -225,7 +225,7 @@ def test_can_jump_to_integration() -> None:
return None
agent = create_agent(model=FakeToolCallingModel(), middleware=[early_exit])
agent = agent.compile()
# agent is already compiled
# Test with early exit
result = agent.invoke({"messages": [HumanMessage("exit")]})
@@ -338,7 +338,7 @@ async def test_async_decorators_integration() -> None:
model=FakeToolCallingModel(),
middleware=[track_async_before, track_async_modify, track_async_after],
)
agent = agent.compile()
# agent is already compiled
await agent.ainvoke({"messages": [HumanMessage("Hello")]})
assert call_order == ["async_before", "async_modify", "async_after"]
@@ -394,7 +394,7 @@ async def test_mixed_sync_async_decorators_integration() -> None:
track_sync_after,
],
)
agent = agent.compile()
# agent is already compiled
await agent.ainvoke({"messages": [HumanMessage("Hello")]})
assert call_order == [
@@ -455,7 +455,7 @@ async def test_async_can_jump_to_integration() -> None:
return None
agent = create_agent(model=FakeToolCallingModel(), middleware=[async_early_exit])
agent = agent.compile()
# agent is already compiled
# Test with early exit
result = await agent.ainvoke({"messages": [HumanMessage("exit")]})
@@ -526,7 +526,7 @@ def test_async_middleware_with_can_jump_to_graph_snapshot(snapshot: SnapshotAsse
model=FakeToolCallingModel(), middleware=[async_before_with_jump]
)
assert agent_async_before.compile().get_graph().draw_mermaid() == snapshot
assert agent_async_before.get_graph().draw_mermaid() == snapshot
# Test 2: Async after_model with can_jump_to
@after_model(can_jump_to=["model", "end"])
@@ -539,7 +539,7 @@ def test_async_middleware_with_can_jump_to_graph_snapshot(snapshot: SnapshotAsse
model=FakeToolCallingModel(), middleware=[async_after_with_jump]
)
assert agent_async_after.compile().get_graph().draw_mermaid() == snapshot
assert agent_async_after.get_graph().draw_mermaid() == snapshot
# Test 3: Multiple async middleware with can_jump_to
@before_model(can_jump_to=["end"])
@@ -555,7 +555,7 @@ def test_async_middleware_with_can_jump_to_graph_snapshot(snapshot: SnapshotAsse
middleware=[async_before_early_exit, async_after_retry],
)
assert agent_multiple_async.compile().get_graph().draw_mermaid() == snapshot
assert agent_multiple_async.get_graph().draw_mermaid() == snapshot
# Test 4: Mixed sync and async middleware with can_jump_to
@before_model(can_jump_to=["end"])
@@ -571,4 +571,4 @@ def test_async_middleware_with_can_jump_to_graph_snapshot(snapshot: SnapshotAsse
middleware=[sync_before_with_jump, async_after_with_jumps],
)
assert agent_mixed.compile().get_graph().draw_mermaid() == snapshot
assert agent_mixed.get_graph().draw_mermaid() == snapshot

View File

@@ -3,15 +3,15 @@
import pytest
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
from langchain.agents.middleware_agent import create_agent
from langchain.agents.factory import create_agent
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.tools import tool
from .model import FakeToolCallingModel
from langgraph.runtime import Runtime
def test_model_request_tools_are_strings() -> None:
"""Test that ModelRequest.tools contains tool names as strings, not tool objects."""
def test_model_request_tools_are_objects() -> None:
"""Test that ModelRequest.tools contains tool objects (BaseTool | dict)."""
captured_requests: list[ModelRequest] = []
@tool
@@ -36,23 +36,23 @@ def test_model_request_tools_are_strings() -> None:
tools=[search_tool, calculator],
system_prompt="You are a helpful assistant.",
middleware=[RequestCapturingMiddleware()],
).compile()
)
agent.invoke({"messages": [HumanMessage("Hello")]})
# Verify that at least one request was captured
assert len(captured_requests) > 0
# Check that tools in the request are strings (tool names)
# Check that tools in the request are tool objects
request = captured_requests[0]
assert isinstance(request.tools, list)
assert len(request.tools) == 2
assert all(isinstance(tool_name, str) for tool_name in request.tools)
assert set(request.tools) == {"search_tool", "calculator"}
tool_names = {t.name for t in request.tools}
assert tool_names == {"search_tool", "calculator"}
def test_middleware_can_modify_tool_names() -> None:
"""Test that middleware can modify the list of tool names in ModelRequest."""
def test_middleware_can_modify_tools() -> None:
"""Test that middleware can modify the list of tools in ModelRequest."""
@tool
def tool_a(input: str) -> str:
@@ -74,7 +74,7 @@ def test_middleware_can_modify_tool_names() -> None:
self, request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
# Only allow tool_a and tool_b
request.tools = ["tool_a", "tool_b"]
request.tools = [t for t in request.tools if t.name in ["tool_a", "tool_b"]]
return request
# Model will try to call tool_a
@@ -87,7 +87,7 @@ def test_middleware_can_modify_tool_names() -> None:
tools=[tool_a, tool_b, tool_c],
system_prompt="You are a helpful assistant.",
middleware=[ToolFilteringMiddleware()],
).compile()
)
result = agent.invoke({"messages": [HumanMessage("Use tool_a")]})
@@ -98,31 +98,37 @@ def test_middleware_can_modify_tool_names() -> None:
assert tool_messages[0].name == "tool_a"
def test_unknown_tool_name_raises_error() -> None:
"""Test that using an unknown tool name in ModelRequest raises a clear error."""
def test_middleware_can_add_custom_tools() -> None:
"""Test that middleware can add custom tool objects to ModelRequest."""
@tool
def known_tool(input: str) -> str:
"""A known tool."""
return "result"
class BadMiddleware(AgentMiddleware):
@tool
def custom_tool(input: str) -> str:
"""A custom tool added by middleware."""
return "custom result"
class ToolAddingMiddleware(AgentMiddleware):
def modify_model_request(
self, request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
# Add an unknown tool name
request.tools = ["known_tool", "unknown_tool"]
# Add a custom tool
request.tools = list(request.tools) + [custom_tool]
return request
agent = create_agent(
model=FakeToolCallingModel(),
tools=[known_tool],
system_prompt="You are a helpful assistant.",
middleware=[BadMiddleware()],
).compile()
middleware=[ToolAddingMiddleware()],
)
with pytest.raises(ValueError, match="Middleware returned unknown tool names"):
agent.invoke({"messages": [HumanMessage("Hello")]})
# Should work fine with custom tools added
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "messages" in result
def test_middleware_can_add_and_remove_tools() -> None:
@@ -149,7 +155,7 @@ def test_middleware_can_add_and_remove_tools() -> None:
) -> ModelRequest:
# Remove admin_tool if not admin
if not state.get("is_admin", False):
request.tools = [name for name in request.tools if name != "admin_tool"]
request.tools = [t for t in request.tools if t.name != "admin_tool"]
return request
model = FakeToolCallingModel()
@@ -159,7 +165,7 @@ def test_middleware_can_add_and_remove_tools() -> None:
tools=[search, admin_tool],
system_prompt="You are a helpful assistant.",
middleware=[ConditionalToolMiddleware()],
).compile()
)
# Test non-admin user - should not have access to admin_tool
# We can't directly inspect the bound model, but we can verify the agent runs
@@ -194,7 +200,7 @@ def test_empty_tools_list_is_valid() -> None:
tools=[some_tool],
system_prompt="You are a helpful assistant.",
middleware=[NoToolsMiddleware()],
).compile()
)
# Should run without error even with no tools
result = agent.invoke({"messages": [HumanMessage("Hello")]})
@@ -224,20 +230,20 @@ def test_tools_preserved_across_multiple_middleware() -> None:
def modify_model_request(
self, request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
modification_order.append(request.tools.copy())
modification_order.append([t.name for t in request.tools])
# Remove tool_c
request.tools = [name for name in request.tools if name != "tool_c"]
request.tools = [t for t in request.tools if t.name != "tool_c"]
return request
class SecondMiddleware(AgentMiddleware):
def modify_model_request(
self, request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
modification_order.append(request.tools.copy())
modification_order.append([t.name for t in request.tools])
# Should not see tool_c here
assert "tool_c" not in request.tools
assert "tool_c" not in [t.name for t in request.tools]
# Remove tool_b
request.tools = [name for name in request.tools if name != "tool_b"]
request.tools = [t for t in request.tools if t.name != "tool_b"]
return request
agent = create_agent(
@@ -245,7 +251,7 @@ def test_tools_preserved_across_multiple_middleware() -> None:
tools=[tool_a, tool_b, tool_c],
system_prompt="You are a helpful assistant.",
middleware=[FirstMiddleware(), SecondMiddleware()],
).compile()
)
agent.invoke({"messages": [HumanMessage("Hello")]})
@@ -286,7 +292,7 @@ def test_middleware_with_additional_tools() -> None:
tools=[base_tool],
system_prompt="You are a helpful assistant.",
middleware=[ToolProvidingMiddleware()],
).compile()
)
result = agent.invoke({"messages": [HumanMessage("Use middleware tool")]})

File diff suppressed because it is too large Load Diff

View File

@@ -1,58 +0,0 @@
from collections.abc import Callable
from typing import Union
import pytest
from pydantic import BaseModel
from syrupy.assertion import SnapshotAssertion
from langchain.agents import create_agent
from .model import FakeToolCallingModel
model = FakeToolCallingModel()
def tool() -> None:
"""Testing tool."""
def pre_model_hook() -> None:
"""Pre-model hook."""
def post_model_hook() -> None:
"""Post-model hook."""
class ResponseFormat(BaseModel):
"""Response format for the agent."""
result: str
@pytest.mark.parametrize("tools", [[], [tool]])
@pytest.mark.parametrize("pre_model_hook", [None, pre_model_hook])
@pytest.mark.parametrize("post_model_hook", [None, post_model_hook])
def test_react_agent_graph_structure(
snapshot: SnapshotAssertion,
tools: list[Callable],
pre_model_hook: Union[Callable, None],
post_model_hook: Union[Callable, None],
) -> None:
agent = create_agent(
model,
tools=tools,
pre_model_hook=pre_model_hook,
post_model_hook=post_model_hook,
)
try:
assert agent.get_graph().draw_mermaid(with_styles=False) == snapshot
except Exception as e:
msg = (
"The graph structure has changed. Please update the snapshot."
"Configuration used:\n"
f"tools: {tools}, "
f"pre_model_hook: {pre_model_hook}, "
f"post_model_hook: {post_model_hook}, "
)
raise ValueError(msg) from e

View File

@@ -7,6 +7,7 @@ from typing import Union
from langchain_core.messages import HumanMessage
from langchain.agents import create_agent
from langchain_core.tools import tool
from langchain.agents.structured_output import (
MultipleStructuredOutputsError,
ProviderStrategy,
@@ -74,12 +75,14 @@ location_json_schema = {
}
@tool
def get_weather() -> str:
"""Get the weather."""
return "The weather is sunny and 75°F."
@tool
def get_location() -> str:
"""Get the current location."""

View File

@@ -36,6 +36,7 @@ from langchain.tools import (
ToolNode,
InjectedState,
InjectedStore,
ToolNode,
)
from langchain.tools.tool_node import TOOL_CALL_ERROR_TEMPLATE, ToolInvocationError, tools_condition