Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-09 10:41:52 +00:00)
Compare commits (9 commits): langchain=...sr/create-
| Author | SHA1 | Date |
|---|---|---|
| | 984e1e984f | |
| | e529445b29 | |
| | 5b972fb0e9 | |
| | b8a20329d8 | |
| | 54e507f331 | |
| | 6001543093 | |
| | 894ffa0be5 | |
| | 56c93fa82f | |
| | 3743d596c7 | |
@@ -1,6 +1,7 @@
 """langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools."""

-from langchain.agents.react_agent import AgentState, create_agent
+from langchain.agents.factory import create_agent
+from langchain.agents.middleware.types import AgentState

 __all__ = [
     "AgentState",
@@ -1,4 +1,4 @@
-"""Middleware agent implementation."""
+"""Agent factory implementation."""

 import itertools
 from collections.abc import Callable, Sequence
@@ -9,9 +9,10 @@ from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMe
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langgraph.constants import END, START
-from langgraph.graph.state import StateGraph
+from langgraph.graph.state import CompiledStateGraph, StateGraph
 from langgraph.runtime import Runtime
-from langgraph.types import Send
+from langgraph.store.base import BaseStore
+from langgraph.types import Checkpointer, Send
 from langgraph.typing import ContextT
 from typing_extensions import NotRequired, Required, TypedDict, TypeVar
@@ -167,17 +168,142 @@ def _handle_structured_output_error(


 def create_agent(  # noqa: PLR0915
-    *,
     model: str | BaseChatModel,
     tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
+    *,
     system_prompt: str | None = None,
     middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (),
     response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
     context_schema: type[ContextT] | None = None,
-) -> StateGraph[
+    checkpointer: Checkpointer | None = None,
+    store: BaseStore | None = None,
+    interrupt_before: list[str] | None = None,
+    interrupt_after: list[str] | None = None,
+    debug: bool = False,
+    cache: Any = None,
+    name: str | None = None,
+) -> CompiledStateGraph[
     AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
 ]:
-    """Create a middleware agent graph."""
+    """Creates an agent graph that calls tools in a loop until a stopping condition is met.
+
+    For more details on using `create_agent`,
+    visit [Agents](https://langchain-ai.github.io/langgraph/agents/overview/) documentation.
+
+    Args:
+        model: The language model for the agent. Supports static model
+            selection with string identifier (e.g., `"openai:gpt-4"`) or
+            chat model instance (e.g., `ChatOpenAI()`).
+
+        tools: A list of tools or a ToolNode instance.
+            If an empty list is provided, the agent will consist of a single LLM node
+            without tool calling.
+
+        system_prompt: An optional system prompt for the LLM. This is converted to a
+            SystemMessage and added to the beginning of the list of messages in
+            state["messages"].
+
+        middleware: An optional sequence of AgentMiddleware instances for customizing
+            agent behavior. Middleware can:
+
+            - **Intercept execution** with before_model/after_model hooks
+            - **Modify model requests** with modify_model_request
+            - **Add custom state** via state_schema
+            - **Control flow** with jump_to (using @hook_config decorator)
+
+            See AgentMiddleware documentation for implementation details.
+
+        response_format: An optional configuration for structured responses.
+
+            If provided, the agent will handle structured output via tool calls
+            during the normal conversation flow.
+            When the model calls a structured output tool, the response will be captured
+            and returned in the 'structured_response' state key.
+            If not provided, `structured_response` will not be present in the output state.
+
+            Can be one of:
+
+            - **Pydantic model class**: Automatically converted to ToolStrategy or
+              ProviderStrategy based on model capabilities
+            - **ToolStrategy**: Uses tool calling for structured output with error handling
+            - **ProviderStrategy**: Uses provider-native structured output (OpenAI, etc.)
+
+            !!! important
+                `response_format` requires the model to support tool calling
+
+            !!! note
+                Structured responses are handled directly in the model call node via
+                tool calls, eliminating the need for separate structured response nodes.
+
+        context_schema: An optional schema for runtime context.
+
+        checkpointer: An optional checkpoint saver object. This is used for persisting
+            the state of the graph (e.g., as chat memory) for a single thread
+            (e.g., a single conversation).
+
+        store: An optional store object. This is used for persisting data
+            across multiple threads (e.g., multiple conversations / users).
+
+        interrupt_before: An optional list of node names to interrupt before.
+            Should be one of the following: "model_request", "tools".
+            This is useful if you want to add a user confirmation or other interrupt
+            before taking an action.
+
+        interrupt_after: An optional list of node names to interrupt after.
+            Should be one of the following: "model_request", "tools".
+            This is useful if you want to return directly or run additional processing on an output.
+
+        debug: A flag indicating whether to enable debug mode.
+
+        cache: An optional cache object for caching LLM responses.
+
+        name: An optional name for the compiled graph.
+
+    Returns:
+        A compiled LangGraph agent that can be used for chat interactions.
+
+    The "model_request" node calls the language model with the messages list
+    (after applying the system prompt if provided).
+    If the resulting AIMessage contains `tool_calls`,
+    the graph will then call the ["tools"][langgraph.prebuilt.tool_node.ToolNode].
+    The "tools" node executes the tools (1 tool per `tool_call`)
+    and adds the responses to the messages list as `ToolMessage` objects.
+    The model_request node then calls the language model again.
+    The process repeats until no more `tool_calls` are present in the response.
+    The agent then returns the full list of messages as a dictionary containing the key "messages".
+
+    ```mermaid
+    sequenceDiagram
+        participant U as User
+        participant A as LLM
+        participant T as Tools
+        U->>A: Initial input
+        Note over A: System Prompt + LLM
+        loop while tool_calls present
+            A->>T: Execute tools
+            T-->>A: ToolMessage for each tool call
+        end
+        A->>U: Return final state
+    ```
+
+    Example:
+        ```python
+        from langchain.agents import create_agent
+
+        def check_weather(location: str) -> str:
+            '''Return the weather forecast for the specified location.'''
+            return f"It's always sunny in {location}"
+
+        graph = create_agent(
+            model="anthropic:claude-3-7-sonnet-latest",
+            tools=[check_weather],
+            system_prompt="You are a helpful assistant",
+        )
+        inputs = {"messages": [{"role": "user", "content": "what is the weather in sf"}]}
+        for chunk in graph.stream(inputs, stream_mode="updates"):
+            print(chunk)
+        ```
+    """
     # init chat model
     if isinstance(model, str):
         model = init_chat_model(model)
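The signature change above is the headline of this diff: `create_agent` now returns a `CompiledStateGraph`, so persistence and interrupts are passed as keyword arguments instead of through a separate `.compile()` call. A minimal sketch of the new surface (the model string and thread id are illustrative, not from the diff):

```python
from langgraph.checkpoint.memory import InMemorySaver

from langchain.agents import create_agent

# Compile-time options now ride along as keyword arguments.
agent = create_agent(
    "anthropic:claude-3-7-sonnet-latest",
    tools=[],
    checkpointer=InMemorySaver(),  # per-thread chat memory
    interrupt_before=["tools"],  # pause for confirmation before any tool runs
)

# Invoke with a thread id so the checkpointer can persist state.
config = {"configurable": {"thread_id": "demo-thread"}}
result = agent.invoke({"messages": [{"role": "user", "content": "hi"}]}, config)
```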
@@ -368,26 +494,8 @@ def create_agent(  # noqa: PLR0915

     def _get_bound_model(request: ModelRequest) -> Runnable:
         """Get the model with appropriate tool bindings."""
-        # Get actual tool objects from tool names
-        tools_by_name = {t.name: t for t in default_tools}
-
-        unknown_tools = [name for name in request.tools if name not in tools_by_name]
-        if unknown_tools:
-            available_tools = sorted(tools_by_name.keys())
-            msg = (
-                f"Middleware returned unknown tool names: {unknown_tools}\n\n"
-                f"Available tools: {available_tools}\n\n"
-                "To fix this issue:\n"
-                "1. Ensure the tools are passed to create_agent() via "
-                "the 'tools' parameter\n"
-                "2. If using custom middleware with tools, ensure "
-                "they're registered via middleware.tools attribute\n"
-                "3. Verify that tool names in ModelRequest.tools match "
-                "the actual tool.name values"
-            )
-            raise ValueError(msg)
-
-        requested_tools = [tools_by_name[name] for name in request.tools]
+        # request.tools contains BaseTool | dict objects
+        requested_tools = request.tools

         if isinstance(response_format, ProviderStrategy):
             # Use native structured output
@@ -411,7 +519,7 @@ def create_agent(  # noqa: PLR0915
         """Sync model request handler with sequential middleware processing."""
         request = ModelRequest(
             model=model,
-            tools=[t.name for t in default_tools],
+            tools=default_tools,
             system_prompt=system_prompt,
             response_format=response_format,
             messages=state["messages"],
@@ -438,17 +546,19 @@ def create_agent(  # noqa: PLR0915
             messages = [SystemMessage(request.system_prompt), *messages]

         output = model_.invoke(messages)
-        return {
-            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
-            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-            **_handle_model_output(output),
-        }
+        result = _handle_model_output(output)
+
+        # Always add call counts
+        result["thread_model_call_count"] = state.get("thread_model_call_count", 0) + 1
+        result["run_model_call_count"] = state.get("run_model_call_count", 0) + 1
+
+        return result

     async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Async model request handler with sequential middleware processing."""
         request = ModelRequest(
             model=model,
-            tools=[t.name for t in default_tools],
+            tools=default_tools,
             system_prompt=system_prompt,
             response_format=response_format,
             messages=state["messages"],
@@ -466,11 +576,13 @@ def create_agent(  # noqa: PLR0915
             messages = [SystemMessage(request.system_prompt), *messages]

         output = await model_.ainvoke(messages)
-        return {
-            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
-            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-            **_handle_model_output(output),
-        }
+        result = _handle_model_output(output)
+
+        # Always add call counts
+        result["thread_model_call_count"] = state.get("thread_model_call_count", 0) + 1
+        result["run_model_call_count"] = state.get("run_model_call_count", 0) + 1
+
+        return result

     # Use sync or async based on model capabilities
     from langgraph._internal._runnable import RunnableCallable
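Both handlers now fold `thread_model_call_count` and `run_model_call_count` into the returned state. A sketch of how these counters surface through a checkpointer (it mirrors the assertions in the new test file below; the model string and thread id are illustrative):

```python
from langgraph.checkpoint.memory import InMemorySaver

from langchain.agents import create_agent

checkpointer = InMemorySaver()
agent = create_agent(
    "anthropic:claude-3-7-sonnet-latest", tools=[], checkpointer=checkpointer
)

config = {"configurable": {"thread_id": "t1"}}
agent.invoke({"messages": [{"role": "user", "content": "hi"}]}, config)

# The counters live in the checkpointed state channels; assuming a single
# model call on this thread, both should read 1.
saved = checkpointer.get_tuple(config)
assert saved.checkpoint["channel_values"]["thread_model_call_count"] == 1
```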
@@ -595,7 +707,15 @@ def create_agent(  # noqa: PLR0915
                 can_jump_to=_get_can_jump_to(m1, "after_model"),
             )

-    return graph
+    return graph.compile(
+        checkpointer=checkpointer,
+        store=store,
+        interrupt_before=interrupt_before,
+        interrupt_after=interrupt_after,
+        debug=debug,
+        cache=cache,
+        name=name,
+    )


 def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
@@ -626,8 +746,8 @@ def _fetch_last_ai_and_tool_messages(

 def _make_model_to_tools_edge(
     first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
-) -> Callable[[dict[str, Any]], str | list[Send] | None]:
-    def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None:
+) -> Callable[[dict[str, Any], Runtime], str | list[Send] | None]:
+    def model_to_tools(state: dict[str, Any], runtime: Runtime) -> str | list[Send] | None:
         # 1. if there's an explicit jump_to in the state, use it
         if jump_to := state.get("jump_to"):
             return _resolve_jump(jump_to, first_node)
@@ -646,14 +766,19 @@ def _make_model_to_tools_edge(
             if c["id"] not in tool_message_ids and c["name"] not in structured_output_tools
         ]

-        # 3. if there are pending tool calls, jump to the tool node
+        # 3. if there are pending (non-structured) tool calls, jump to the tool node
         if pending_tool_calls:
             pending_tool_calls = [
-                tool_node.inject_tool_args(call, state, None) for call in pending_tool_calls
+                tool_node.inject_tool_args(call, state, runtime.store)
+                for call in pending_tool_calls
             ]
             return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]

-        # 4. AIMessage has tool calls, but there are no pending tool calls
+        # 4. if we have a structured response and no pending tool calls, we're done
+        if "structured_response" in state:
+            return END
+
+        # 5. AIMessage has tool calls, but there are no pending tool calls
         # which suggests the injection of artificial tool messages. jump to the first node
         return first_node
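For orientation, the edge above fans out with `Send`: each pending tool call becomes its own `Send`, so the `"tools"` node runs once per call, and a populated `structured_response` key now routes straight to `END`. A schematic sketch of that routing (not the actual edge function; the `tool_calls` state key here is a stand-in for the pending-call filter):

```python
from langgraph.constants import END
from langgraph.types import Send


def route_after_model(state: dict) -> str | list[Send]:
    """Schematic version of the model -> tools routing in this hunk."""
    pending = state.get("tool_calls", [])  # stand-in for the pending-call filter
    if pending:
        # One Send per call: the "tools" node executes each call independently.
        return [Send("tools", [call]) for call in pending]
    if "structured_response" in state:
        return END  # structured output already captured; nothing left to run
    return "model_request"
```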
@@ -666,7 +791,8 @@ def _make_tools_to_model_edge(
     def tools_to_model(state: dict[str, Any]) -> str | None:
         last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])

-        if all(
+        # Check if any tool call has return_direct=True
+        if any(
             tool_node.tools_by_name[c["name"]].return_direct
             for c in last_ai_message.tool_calls
             if c["name"] in tool_node.tools_by_name
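Note the `all` to `any` swap: a single `return_direct=True` tool call is now enough to end the run, where previously every call had to be direct. A tool opts in like this (hypothetical tool; the tests later in this diff exercise the same flag):

```python
from langchain_core.tools import tool


@tool(return_direct=True)
def lookup(query: str) -> str:
    """Answer the user directly, skipping the final model turn."""
    return f"Result for {query}"
```

If the model emits this call alongside ordinary tool calls, the agent now stops after the tool messages are appended.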
@@ -692,7 +818,6 @@ def _add_middleware_edge(

     Args:
         graph: The graph to add the edge to.
-        method: The method to call for the middleware node.
         name: The name of the middleware node.
         default_destination: The default destination for the edge.
         model_destination: The destination for the edge to the model.
@@ -716,3 +841,6 @@ def _add_middleware_edge(

     else:
         graph.add_edge(name, default_destination)
+
+
+__all__ = ["create_agent"]
@@ -62,7 +62,7 @@ class ModelRequest:
     system_prompt: str | None
     messages: list[AnyMessage]  # excluding system prompt
     tool_choice: Any | None
-    tools: list[str]
+    tools: list[BaseTool | dict[str, Any]]
     response_format: ResponseFormat | None
     model_settings: dict[str, Any] = field(default_factory=dict)
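Since `ModelRequest.tools` now carries tool objects instead of names, middleware filters on `t.name` rather than string membership. A sketch under that assumption (the class name is illustrative; the same pattern appears in the tests at the end of this diff):

```python
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
from langgraph.runtime import Runtime


class AllowlistToolsMiddleware(AgentMiddleware):
    """Drop every tool whose name is not in an allowlist."""

    def modify_model_request(
        self, request: ModelRequest, state: AgentState, runtime: Runtime
    ) -> ModelRequest:
        allowed = {"search_tool"}
        # request.tools holds BaseTool objects now, so filter on .name
        request.tools = [t for t in request.tools if t.name in allowed]
        return request
```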
File diff suppressed because it is too large
@@ -108,4 +108,4 @@ class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
             }
         )

-        return self.bind(tools=tool_dicts)
+        return self.bind(tools=tool_dicts, **kwargs)
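Forwarding `**kwargs` through `self.bind` means binding options are no longer silently dropped by the fake model. An illustrative call (assuming `bind_tools` accepts keyword arguments, as the hunk implies; `FakeToolCallingModel` is the test helper from this diff):

```python
# In a test, extra binding options now reach self.bind(...):
model = FakeToolCallingModel()
bound = model.bind_tools([], tool_choice="auto")  # previously tool_choice was lost
```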
libs/langchain_v1/tests/unit_tests/agents/test_create_agent.py (new file, 767 lines)
@@ -0,0 +1,767 @@
"""Tests for create_agent functionality."""

import pytest
from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    ToolCall,
    ToolMessage,
)
from langchain_core.tools import tool as dec_tool, InjectedToolCallId, ToolException
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore
from langgraph.types import Command
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from langchain.agents import (
    create_agent,
)
from langchain.agents.middleware.types import AgentMiddleware, AgentState
from langchain.agents.structured_output import ToolStrategy
from langchain.tools import (
    InjectedState,
    InjectedStore,
    ToolNode,
)
from langchain.tools.tool_node import (
    _get_state_args,
    _infer_handled_types,
)

from .any_str import AnyStr
from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
from .model import FakeToolCallingModel

pytestmark = pytest.mark.anyio
def test_no_prompt(sync_checkpointer: BaseCheckpointSaver) -> None:
    """Test agent with no system prompt."""
    model = FakeToolCallingModel()

    agent = create_agent(
        model,
        [],
        checkpointer=sync_checkpointer,
    )
    inputs = [HumanMessage("hi?")]
    thread = {"configurable": {"thread_id": "123"}}
    response = agent.invoke({"messages": inputs}, thread, debug=True)
    expected_response = {"messages": [*inputs, AIMessage(content="hi?", id="0")]}
    assert response == expected_response

    saved = sync_checkpointer.get_tuple(thread)
    assert saved is not None
    checkpoint_values = saved.checkpoint["channel_values"]
    assert checkpoint_values["messages"] == [
        _AnyIdHumanMessage(content="hi?"),
        AIMessage(content="hi?", id="0"),
    ]
    assert checkpoint_values["thread_model_call_count"] == 1
    assert saved.metadata == {
        "parents": {},
        "source": "loop",
        "step": 1,
    }
    assert saved.pending_writes == []
async def test_no_prompt_async(async_checkpointer: BaseCheckpointSaver) -> None:
    """Test agent with no system prompt (async)."""
    model = FakeToolCallingModel()

    agent = create_agent(model, [], checkpointer=async_checkpointer)
    inputs = [HumanMessage("hi?")]
    thread = {"configurable": {"thread_id": "123"}}
    response = await agent.ainvoke({"messages": inputs}, thread, debug=True)
    expected_response = {"messages": [*inputs, AIMessage(content="hi?", id="0")]}
    assert response == expected_response

    saved = await async_checkpointer.aget_tuple(thread)
    assert saved is not None
    checkpoint_values = saved.checkpoint["channel_values"]
    assert checkpoint_values["messages"] == [
        _AnyIdHumanMessage(content="hi?"),
        AIMessage(content="hi?", id="0"),
    ]
    assert checkpoint_values["thread_model_call_count"] == 1
    assert saved.metadata == {
        "parents": {},
        "source": "loop",
        "step": 1,
    }
    assert saved.pending_writes == []
def test_system_message_prompt() -> None:
    """Test agent with system message prompt."""
    system_prompt = "Foo"
    agent = create_agent(FakeToolCallingModel(), [], system_prompt=system_prompt)
    inputs = [HumanMessage("hi?")]
    response = agent.invoke({"messages": inputs})
    expected_response = {"messages": [*inputs, AIMessage(content="Foo-hi?", id="0", tool_calls=[])]}
    assert response == expected_response


def test_string_prompt() -> None:
    """Test agent with string system prompt."""
    system_prompt = "Foo"
    agent = create_agent(FakeToolCallingModel(), [], system_prompt=system_prompt)
    inputs = [HumanMessage("hi?")]
    response = agent.invoke({"messages": inputs})
    expected_response = {"messages": [*inputs, AIMessage(content="Foo-hi?", id="0", tool_calls=[])]}
    assert response == expected_response
def test__infer_handled_types() -> None:
    """Test type inference for exception handlers."""

    def handle(e) -> str:  # type: ignore[no-untyped-def]
        return ""

    def handle2(e: Exception) -> str:
        return ""

    def handle3(e: ValueError | ToolException) -> str:  # noqa: UP007
        return ""

    def handle4(e: ValueError | ToolException) -> str:  # noqa: UP007
        return ""

    class Handler:
        def handle(self, e: ValueError) -> str:
            return ""

    handle5 = Handler().handle

    def handle6(e: TypeError | ValueError | ToolException) -> str:  # noqa: UP007
        return ""

    expected: tuple = (Exception,)
    actual = _infer_handled_types(handle)
    assert expected == actual

    expected = (Exception,)
    actual = _infer_handled_types(handle2)
    assert expected == actual

    expected = (ValueError, ToolException)
    actual = _infer_handled_types(handle3)
    assert expected == actual

    expected = (ValueError, ToolException)
    actual = _infer_handled_types(handle4)
    assert expected == actual

    expected = (ValueError,)
    actual = _infer_handled_types(handle5)
    assert expected == actual

    expected = (TypeError, ValueError, ToolException)
    actual = _infer_handled_types(handle6)
    assert expected == actual

    with pytest.raises(ValueError):

        def handler(e: str) -> str:
            return ""

        _infer_handled_types(handler)

    with pytest.raises(ValueError):

        def handler(e: list[Exception]) -> str:
            return ""

        _infer_handled_types(handler)

    with pytest.raises(ValueError):

        def handler(e: str | int) -> str:  # noqa: UP007
            return ""

        _infer_handled_types(handler)
def test_agent_with_structured_response() -> None:
    """Test agent with structured response format."""

    class WeatherResponse(BaseModel):
        temperature: float = Field(description="The temperature in fahrenheit")

    tool_calls = [
        [{"args": {}, "id": "1", "name": "get_weather"}],
        [{"name": "WeatherResponse", "id": "2", "args": {"temperature": 75}}],
    ]

    @dec_tool
    def get_weather() -> str:
        """Get the weather"""
        return "The weather is sunny and 75°F."

    expected_structured_response = WeatherResponse(temperature=75)
    model = FakeToolCallingModel[WeatherResponse](
        tool_calls=tool_calls, structured_response=expected_structured_response
    )
    agent = create_agent(
        model,
        [get_weather],
        response_format=WeatherResponse,
    )
    response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
    assert response["structured_response"] == expected_structured_response
    assert len(response["messages"]) == 5

    # Check message types in message history
    msg_types = [m.type for m in response["messages"]]
    assert msg_types == [
        "human",  # "What's the weather?"
        "ai",  # "What's the weather?"
        "tool",  # "The weather is sunny and 75°F."
        "ai",  # structured response
        "tool",  # artificial tool message
    ]

    # Check last message contains "Returning structured response"
    assert "Returning structured response" in response["messages"][-1].content
    assert "temperature=75.0" in response["messages"][-1].content
class CustomState(AgentState):
    """Custom state for testing."""

    user_name: str


def test_agent_update_state(
    sync_checkpointer: BaseCheckpointSaver,
) -> None:
    """Test agent that updates custom state."""

    @dec_tool
    def get_user_name(tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
        """Retrieve user name"""
        user_name = "interrupt"  # Simplified - in real test this would use interrupt()
        return Command(
            update={
                "user_name": user_name,
                "messages": [
                    ToolMessage("Successfully retrieved user name", tool_call_id=tool_call_id)
                ],
            }
        )

    class CustomMiddleware(AgentMiddleware[CustomState]):
        """Middleware to add custom state."""

        state_schema = CustomState

        def modify_model_request(self, request, state, runtime):
            """Add user name to system prompt if available."""
            user_name = state.get("user_name")
            if user_name:
                system_msg = f"User name is {user_name}"
                if request.system_prompt:
                    request.system_prompt = f"{request.system_prompt}\n\n{system_msg}"
                else:
                    request.system_prompt = system_msg
            return request

    tool_calls = [[{"args": {}, "id": "1", "name": "get_user_name"}], []]
    model = FakeToolCallingModel(tool_calls=tool_calls)
    agent = create_agent(
        model,
        [get_user_name],
        middleware=[CustomMiddleware()],
        checkpointer=sync_checkpointer,
    )
    config = {"configurable": {"thread_id": "1"}}
    # Run until tool is called
    response = agent.invoke({"messages": [("user", "what's my name")]}, config)
    # confirm that the state was updated
    assert response["user_name"] == "interrupt"
    assert len(response["messages"]) == 4
    tool_message: ToolMessage = response["messages"][-2]
    assert tool_message.content == "Successfully retrieved user name"
    assert tool_message.tool_call_id == "1"
    assert tool_message.name == "get_user_name"
def test_agent_parallel_tool_calls(
    sync_checkpointer: BaseCheckpointSaver,
) -> None:
    """Test agent with parallel tool calls."""
    human_assistance_execution_count = 0

    @dec_tool
    def human_assistance(query: str) -> str:
        """Request assistance from a human."""
        nonlocal human_assistance_execution_count
        human_assistance_execution_count += 1
        return "Hello"

    get_weather_execution_count = 0

    @dec_tool
    def get_weather(location: str) -> str:
        """Use this tool to get the weather."""
        nonlocal get_weather_execution_count
        get_weather_execution_count += 1
        return "It's sunny!"

    tool_calls = [
        [
            {"args": {"location": "sf"}, "id": "1", "name": "get_weather"},
            {"args": {"query": "request help"}, "id": "2", "name": "human_assistance"},
        ],
        [],
    ]
    model = FakeToolCallingModel(tool_calls=tool_calls)
    agent = create_agent(
        model,
        [human_assistance, get_weather],
        checkpointer=sync_checkpointer,
    )
    config = {"configurable": {"thread_id": "1"}}
    query = "Get user assistance and also check the weather"
    message_types = []
    for event in agent.stream({"messages": [("user", query)]}, config, stream_mode="values"):
        if messages := event.get("messages"):
            message_types.append([m.type for m in messages])

    assert message_types == [
        ["human"],
        ["human", "ai"],
        ["human", "ai", "tool", "tool"],
        ["human", "ai", "tool", "tool", "ai"],
    ]

    assert human_assistance_execution_count == 1
    assert get_weather_execution_count == 1
class AgentStateExtraKey(AgentState):
    """Agent state with extra key."""

    foo: int


def test_create_agent_inject_vars() -> None:
    """Test that the agent can inject state and store into tool functions."""
    store = InMemoryStore()
    namespace = ("test",)
    store.put(namespace, "test_key", {"bar": 3})

    class CustomMiddleware(AgentMiddleware[AgentStateExtraKey]):
        """Middleware to add custom state."""

        state_schema = AgentStateExtraKey

    def tool1(
        some_val: int,
        state: Annotated[dict, InjectedState],
        store: Annotated[BaseStore, InjectedStore()],
    ) -> int:
        """Tool 1 docstring."""
        store_val = store.get(namespace, "test_key").value["bar"]
        return some_val + state["foo"] + store_val

    tool_call = {
        "name": "tool1",
        "args": {"some_val": 1},
        "id": "some 0",
        "type": "tool_call",
    }
    model = FakeToolCallingModel(tool_calls=[[tool_call], []])
    agent = create_agent(
        model,
        ToolNode([tool1], handle_tool_errors=False),
        middleware=[CustomMiddleware()],
        store=store,
    )
    result = agent.invoke({"messages": [{"role": "user", "content": "hi"}], "foo": 2})
    assert result["messages"] == [
        _AnyIdHumanMessage(content="hi"),
        AIMessage(content="hi", tool_calls=[tool_call], id="0"),
        _AnyIdToolMessage(content="6", name="tool1", tool_call_id="some 0"),
        AIMessage("hi-hi-6", id="1"),
    ]
    assert result["foo"] == 2
async def test_return_direct() -> None:
    """Test tools with return_direct flag."""

    @dec_tool(return_direct=True)
    def tool_return_direct(input: str) -> str:
        """A tool that returns directly."""
        return f"Direct result: {input}"

    @dec_tool
    def tool_normal(input: str) -> str:
        """A normal tool."""
        return f"Normal result: {input}"

    first_tool_call = [
        ToolCall(
            name="tool_return_direct",
            args={"input": "Test direct"},
            id="1",
        ),
    ]
    expected_ai = AIMessage(
        content="Test direct",
        id="0",
        tool_calls=first_tool_call,
    )
    model = FakeToolCallingModel(tool_calls=[first_tool_call, []])
    agent = create_agent(
        model,
        [tool_return_direct, tool_normal],
    )

    # Test direct return for tool_return_direct
    result = agent.invoke({"messages": [HumanMessage(content="Test direct", id="hum0")]})
    assert result["messages"] == [
        HumanMessage(content="Test direct", id="hum0"),
        expected_ai,
        ToolMessage(
            content="Direct result: Test direct",
            name="tool_return_direct",
            tool_call_id="1",
            id=result["messages"][2].id,
        ),
    ]
    second_tool_call = [
        ToolCall(
            name="tool_normal",
            args={"input": "Test normal"},
            id="2",
        ),
    ]
    model = FakeToolCallingModel(tool_calls=[second_tool_call, []])
    agent = create_agent(model, [tool_return_direct, tool_normal])
    result = agent.invoke({"messages": [HumanMessage(content="Test normal", id="hum1")]})
    assert result["messages"] == [
        HumanMessage(content="Test normal", id="hum1"),
        AIMessage(content="Test normal", id="0", tool_calls=second_tool_call),
        ToolMessage(
            content="Normal result: Test normal",
            name="tool_normal",
            tool_call_id="2",
            id=result["messages"][2].id,
        ),
        AIMessage(content="Test normal-Test normal-Normal result: Test normal", id="1"),
    ]

    both_tool_calls = [
        ToolCall(
            name="tool_return_direct",
            args={"input": "Test both direct"},
            id="3",
        ),
        ToolCall(
            name="tool_normal",
            args={"input": "Test both normal"},
            id="4",
        ),
    ]
    model = FakeToolCallingModel(tool_calls=[both_tool_calls, []])
    agent = create_agent(model, [tool_return_direct, tool_normal])
    result = agent.invoke({"messages": [HumanMessage(content="Test both", id="hum2")]})
    assert result["messages"] == [
        HumanMessage(content="Test both", id="hum2"),
        AIMessage(content="Test both", id="0", tool_calls=both_tool_calls),
        ToolMessage(
            content="Direct result: Test both direct",
            name="tool_return_direct",
            tool_call_id="3",
            id=result["messages"][2].id,
        ),
        ToolMessage(
            content="Normal result: Test both normal",
            name="tool_normal",
            tool_call_id="4",
            id=result["messages"][3].id,
        ),
    ]
def test__get_state_args() -> None:
    """Test state argument extraction from tools."""

    class Schema1(BaseModel):
        a: Annotated[str, InjectedState]

    class Schema2(Schema1):
        b: Annotated[int, InjectedState("bar")]

    @dec_tool(args_schema=Schema2)
    def foo(a: str, b: int) -> float:
        """return"""
        return 0.0

    assert _get_state_args(foo) == {"a": None, "b": "bar"}
def test_agent_with_subgraph_tools(
    sync_checkpointer: BaseCheckpointSaver,
) -> None:
    """Test agent with subgraph tools."""
    from langgraph.graph import START, StateGraph
    from typing_extensions import TypedDict

    class State(TypedDict):
        a: int
        b: int

    class Output(TypedDict):
        result: int

    # Define the subgraphs
    def add(state):
        return {"result": state["a"] + state["b"]}

    add_subgraph = (
        StateGraph(State, output_schema=Output).add_node(add).add_edge(START, "add").compile()
    )

    def multiply(state):
        return {"result": state["a"] * state["b"]}

    multiply_subgraph = (
        StateGraph(State, output_schema=Output)
        .add_node(multiply)
        .add_edge(START, "multiply")
        .compile()
    )

    # Add subgraphs as tools
    def addition(a: int, b: int) -> int:
        """Add two numbers"""
        return add_subgraph.invoke({"a": a, "b": b})["result"]

    def multiplication(a: int, b: int) -> int:
        """Multiply two numbers"""
        return multiply_subgraph.invoke({"a": a, "b": b})["result"]

    model = FakeToolCallingModel(
        tool_calls=[
            [
                {"args": {"a": 2, "b": 3}, "id": "1", "name": "addition"},
                {"args": {"a": 2, "b": 3}, "id": "2", "name": "multiplication"},
            ],
            [],
        ]
    )
    tool_node = ToolNode([addition, multiplication], handle_tool_errors=False)
    agent = create_agent(
        model,
        tool_node,
        checkpointer=sync_checkpointer,
    )
    result = agent.invoke(
        {"messages": [HumanMessage(content="What's 2 + 3 and 2 * 3?")]},
        config={"configurable": {"thread_id": "1"}},
    )
    assert result["messages"] == [
        _AnyIdHumanMessage(content="What's 2 + 3 and 2 * 3?"),
        AIMessage(
            content="What's 2 + 3 and 2 * 3?",
            id="0",
            tool_calls=[
                ToolCall(name="addition", args={"a": 2, "b": 3}, id="1"),
                ToolCall(name="multiplication", args={"a": 2, "b": 3}, id="2"),
            ],
        ),
        ToolMessage(content="5", name="addition", tool_call_id="1", id=result["messages"][2].id),
        ToolMessage(
            content="6",
            name="multiplication",
            tool_call_id="2",
            id=result["messages"][3].id,
        ),
        AIMessage(content="What's 2 + 3 and 2 * 3?-What's 2 + 3 and 2 * 3?-5-6", id="1"),
    ]
def test_agent_subgraph_streaming_sync() -> None:
    """Test React agent streaming when used as a subgraph node (sync version)."""
    from langchain_core.runnables import RunnableConfig
    from langgraph.graph import START, MessagesState, StateGraph

    @dec_tool
    def get_weather(city: str) -> str:
        """Get the weather of a city."""
        return f"The weather of {city} is sunny."

    # Create a React agent
    model = FakeToolCallingModel(
        tool_calls=[
            [{"args": {"city": "Tokyo"}, "id": "1", "name": "get_weather"}],
            [],
        ]
    )

    agent = create_agent(
        model,
        tools=[get_weather],
        system_prompt="You are a helpful travel assistant.",
    )

    # Create a subgraph that uses the React agent as a node
    def react_agent_node(state: MessagesState, config: RunnableConfig) -> MessagesState:
        """Node that runs the React agent and collects streaming output."""
        collected_content = ""

        # Stream the agent output and collect content
        for msg_chunk, _msg_metadata in agent.stream(
            {"messages": [("user", state["messages"][-1].content)]},
            config,
            stream_mode="messages",
        ):
            if hasattr(msg_chunk, "content") and msg_chunk.content:
                collected_content += msg_chunk.content

        return {"messages": [("assistant", collected_content)]}

    # Create the main workflow with the React agent as a subgraph node
    workflow = StateGraph(MessagesState)
    workflow.add_node("react_agent", react_agent_node)
    workflow.add_edge(START, "react_agent")
    workflow.add_edge("react_agent", "__end__")
    compiled_workflow = workflow.compile()

    # Test the streaming functionality
    result = compiled_workflow.invoke({"messages": [("user", "What is the weather in Tokyo?")]})

    # Verify the result contains expected structure
    assert len(result["messages"]) == 2
    assert result["messages"][0].content == "What is the weather in Tokyo?"
    assert "assistant" in str(result["messages"][1])

    # Test streaming with subgraphs = True
    result = compiled_workflow.invoke(
        {"messages": [("user", "What is the weather in Tokyo?")]},
        subgraphs=True,
    )
    assert len(result["messages"]) == 2

    events = []
    for event in compiled_workflow.stream(
        {"messages": [("user", "What is the weather in Tokyo?")]},
        stream_mode="messages",
        subgraphs=False,
    ):
        events.append(event)

    assert len(events) == 0

    events = []
    for event in compiled_workflow.stream(
        {"messages": [("user", "What is the weather in Tokyo?")]},
        stream_mode="messages",
        subgraphs=True,
    ):
        events.append(event)

    assert len(events) == 3
    namespace, (msg, metadata) = events[0]
    # FakeToolCallingModel returns a single AIMessage with tool calls
    # The content of the AIMessage reflects the input message
    assert msg.content.startswith("You are a helpful travel assistant")
    namespace, (msg, metadata) = events[1]  # ToolMessage
    assert msg.content.startswith("The weather of Tokyo is sunny.")
async def test_agent_subgraph_streaming() -> None:
    """Test React agent streaming when used as a subgraph node (async)."""
    from langchain_core.runnables import RunnableConfig
    from langgraph.graph import START, MessagesState, StateGraph

    @dec_tool
    def get_weather(city: str) -> str:
        """Get the weather of a city."""
        return f"The weather of {city} is sunny."

    # Create a React agent
    model = FakeToolCallingModel(
        tool_calls=[
            [{"args": {"city": "Tokyo"}, "id": "1", "name": "get_weather"}],
            [],
        ]
    )

    agent = create_agent(
        model,
        tools=[get_weather],
        system_prompt="You are a helpful travel assistant.",
    )

    # Create a subgraph that uses the React agent as a node
    async def react_agent_node(state: MessagesState, config: RunnableConfig) -> MessagesState:
        """Node that runs the React agent and collects streaming output."""
        collected_content = ""

        # Stream the agent output and collect content
        async for msg_chunk, _msg_metadata in agent.astream(
            {"messages": [("user", state["messages"][-1].content)]},
            config,
            stream_mode="messages",
        ):
            if hasattr(msg_chunk, "content") and msg_chunk.content:
                collected_content += msg_chunk.content

        return {"messages": [("assistant", collected_content)]}

    # Create the main workflow with the React agent as a subgraph node
    workflow = StateGraph(MessagesState)
    workflow.add_node("react_agent", react_agent_node)
    workflow.add_edge(START, "react_agent")
    workflow.add_edge("react_agent", "__end__")
    compiled_workflow = workflow.compile()

    # Test the streaming functionality
    result = await compiled_workflow.ainvoke(
        {"messages": [("user", "What is the weather in Tokyo?")]}
    )

    # Verify the result contains expected structure
    assert len(result["messages"]) == 2
    assert result["messages"][0].content == "What is the weather in Tokyo?"
    assert "assistant" in str(result["messages"][1])

    # Test streaming with subgraphs = True
    result = await compiled_workflow.ainvoke(
        {"messages": [("user", "What is the weather in Tokyo?")]},
        subgraphs=True,
    )
    assert len(result["messages"]) == 2

    events = []
    async for event in compiled_workflow.astream(
        {"messages": [("user", "What is the weather in Tokyo?")]},
        stream_mode="messages",
        subgraphs=False,
    ):
        events.append(event)

    assert len(events) == 0

    events = []
    async for event in compiled_workflow.astream(
        {"messages": [("user", "What is the weather in Tokyo?")]},
        stream_mode="messages",
        subgraphs=True,
    ):
        events.append(event)

    assert len(events) == 3
    namespace, (msg, metadata) = events[0]
    # FakeToolCallingModel returns a single AIMessage with tool calls
    # The content of the AIMessage reflects the input message
    assert msg.content.startswith("You are a helpful travel assistant")
    namespace, (msg, metadata) = events[1]  # ToolMessage
    assert msg.content.startswith("The weather of Tokyo is sunny.")
@@ -51,7 +51,7 @@ from langchain.agents.middleware.types import (
     OmitFromOutput,
     PrivateStateAttr,
 )
-from langchain.agents.middleware_agent import create_agent
+from langchain.agents.factory import create_agent
 from langchain.agents.structured_output import ToolStrategy
 from langchain.tools import InjectedState
@@ -133,7 +133,7 @@ def test_create_agent_diagram(
         system_prompt="You are a helpful assistant.",
     )

-    assert agent_zero.compile().get_graph().draw_mermaid() == snapshot
+    assert agent_zero.get_graph().draw_mermaid() == snapshot

     agent_one = create_agent(
         model=FakeToolCallingModel(),
@@ -142,7 +142,7 @@ def test_create_agent_diagram(
         middleware=[NoopOne()],
     )

-    assert agent_one.compile().get_graph().draw_mermaid() == snapshot
+    assert agent_one.get_graph().draw_mermaid() == snapshot

     agent_two = create_agent(
         model=FakeToolCallingModel(),
@@ -151,7 +151,7 @@ def test_create_agent_diagram(
         middleware=[NoopOne(), NoopTwo()],
     )

-    assert agent_two.compile().get_graph().draw_mermaid() == snapshot
+    assert agent_two.get_graph().draw_mermaid() == snapshot

     agent_three = create_agent(
         model=FakeToolCallingModel(),
@@ -160,7 +160,7 @@ def test_create_agent_diagram(
         middleware=[NoopOne(), NoopTwo(), NoopThree()],
     )

-    assert agent_three.compile().get_graph().draw_mermaid() == snapshot
+    assert agent_three.get_graph().draw_mermaid() == snapshot

     agent_four = create_agent(
         model=FakeToolCallingModel(),
@@ -169,7 +169,7 @@ def test_create_agent_diagram(
         middleware=[NoopFour()],
     )

-    assert agent_four.compile().get_graph().draw_mermaid() == snapshot
+    assert agent_four.get_graph().draw_mermaid() == snapshot

     agent_five = create_agent(
         model=FakeToolCallingModel(),
@@ -178,7 +178,7 @@ def test_create_agent_diagram(
         middleware=[NoopFour(), NoopFive()],
     )

-    assert agent_five.compile().get_graph().draw_mermaid() == snapshot
+    assert agent_five.get_graph().draw_mermaid() == snapshot

     agent_six = create_agent(
         model=FakeToolCallingModel(),
@@ -187,7 +187,7 @@ def test_create_agent_diagram(
         middleware=[NoopFour(), NoopFive(), NoopSix()],
    )

-    assert agent_six.compile().get_graph().draw_mermaid() == snapshot
+    assert agent_six.get_graph().draw_mermaid() == snapshot

     agent_seven = create_agent(
         model=FakeToolCallingModel(),
@@ -196,7 +196,7 @@ def test_create_agent_diagram(
         middleware=[NoopSeven()],
     )

-    assert agent_seven.compile().get_graph().draw_mermaid() == snapshot
+    assert agent_seven.get_graph().draw_mermaid() == snapshot

     agent_eight = create_agent(
         model=FakeToolCallingModel(),
@@ -205,7 +205,7 @@ def test_create_agent_diagram(
         middleware=[NoopSeven(), NoopEight()],
     )

-    assert agent_eight.compile().get_graph().draw_mermaid() == snapshot
+    assert agent_eight.get_graph().draw_mermaid() == snapshot

     agent_nine = create_agent(
         model=FakeToolCallingModel(),
@@ -214,7 +214,7 @@ def test_create_agent_diagram(
         middleware=[NoopSeven(), NoopEight(), NoopNine()],
     )

-    assert agent_nine.compile().get_graph().draw_mermaid() == snapshot
+    assert agent_nine.get_graph().draw_mermaid() == snapshot

     agent_ten = create_agent(
         model=FakeToolCallingModel(),
@@ -223,7 +223,7 @@ def test_create_agent_diagram(
         middleware=[NoopTen()],
     )

-    assert agent_ten.compile().get_graph().draw_mermaid() == snapshot
+    assert agent_ten.get_graph().draw_mermaid() == snapshot

     agent_eleven = create_agent(
         model=FakeToolCallingModel(),
@@ -232,7 +232,7 @@ def test_create_agent_diagram(
         middleware=[NoopTen(), NoopEleven()],
     )

-    assert agent_eleven.compile().get_graph().draw_mermaid() == snapshot
+    assert agent_eleven.get_graph().draw_mermaid() == snapshot


 def test_create_agent_invoke(
@@ -281,7 +281,8 @@ def test_create_agent_invoke(
         tools=[my_tool],
         system_prompt="You are a helpful assistant.",
         middleware=[NoopSeven(), NoopEight()],
-    ).compile(checkpointer=sync_checkpointer)
+        checkpointer=sync_checkpointer,
+    )

     thread1 = {"configurable": {"thread_id": "1"}}
     assert agent_one.invoke({"messages": ["hello"]}, thread1) == {
@@ -370,7 +371,8 @@ def test_create_agent_jump(
         tools=[my_tool],
         system_prompt="You are a helpful assistant.",
         middleware=[NoopSeven(), NoopEight()],
-    ).compile(checkpointer=sync_checkpointer)
+        checkpointer=sync_checkpointer,
+    )

     if isinstance(sync_checkpointer, InMemorySaver):
         assert agent_one.get_graph().draw_mermaid() == snapshot
@@ -1214,14 +1216,13 @@ def test_modify_model_request() -> None:
             request.messages.append(HumanMessage("remember to be nice!"))
             return request

-    builder = create_agent(
+    agent = create_agent(
         model=FakeToolCallingModel(),
         tools=[],
         system_prompt="You are a helpful assistant.",
         middleware=[ModifyMiddleware()],
     )

-    agent = builder.compile()
     result = agent.invoke({"messages": [HumanMessage("Hello")]})
     assert result["messages"][0].content == "Hello"
     assert result["messages"][1].content == "remember to be nice!"
@@ -1268,9 +1269,8 @@ def test_tools_to_model_edge_with_structured_and_regular_tool_calls():
         response_format=ToolStrategy(schema=WeatherResponse),
     )

-    # Compile and invoke the agent
-    compiled_agent = agent.compile()
-    result = compiled_agent.invoke(
+    # Invoke the agent
+    result = agent.invoke(
         {"messages": [HumanMessage("What's the weather and help me with a query?")]}
     )
@@ -1323,7 +1323,7 @@ def test_public_private_state_for_custom_middleware() -> None:
         return {"omit_input": "test", "omit_output": "test", "private_state": "test"}

     agent = create_agent(model=FakeToolCallingModel(), middleware=[CustomMiddleware()])
-    agent = agent.compile()
+    # agent is already compiled
     result = agent.invoke(
         {
             "messages": [HumanMessage("Hello")],
@@ -1358,7 +1358,7 @@ def test_runtime_injected_into_middleware() -> None:
     middleware = CustomMiddleware()

     agent = create_agent(model=FakeToolCallingModel(), middleware=[CustomMiddleware()])
-    agent = agent.compile()
+    # agent is already compiled
     agent.invoke({"messages": [HumanMessage("Hello")]})
@@ -1389,7 +1389,7 @@ def test_injected_state_in_middleware_agent() -> None:
         tools=[test_state],
         system_prompt="You are a helpful assistant.",
         middleware=[TestMiddleware()],
-    ).compile()
+    )

     result = agent.invoke(
         {"test_state": "I love pizza", "messages": [HumanMessage("Call the test state tool")]}
@@ -1419,7 +1419,7 @@ def test_jump_to_is_ephemeral() -> None:
         return {"jump_to": "model"}

     agent = create_agent(model=FakeToolCallingModel(), middleware=[MyMiddleware()])
-    agent = agent.compile()
+    # agent is already compiled
     result = agent.invoke({"messages": [HumanMessage("Hello")]})
     assert "jump_to" not in result
@@ -1550,7 +1550,7 @@ def test_planning_middleware_agent_creation_with_middleware() -> None:
     )
     middleware = PlanningMiddleware()
     agent = create_agent(model=model, middleware=[middleware])
-    agent = agent.compile()
+    # agent is already compiled

     result = agent.invoke({"messages": [HumanMessage("Hello")]})
     assert result["todos"] == [{"content": "Task 1", "status": "completed"}]
@@ -1661,7 +1661,7 @@ def test_planning_middleware_custom_system_prompt() -> None:
     )

     agent = create_agent(model=model, middleware=[middleware])
-    agent = agent.compile()
+    # agent is already compiled

     result = agent.invoke({"messages": [HumanMessage("Hello")]})
     assert result["todos"] == [{"content": "Custom task", "status": "pending"}]
@@ -1735,7 +1735,8 @@ def test_thread_limit_with_create_agent():
         model=model,
         tools=[simple_tool],
         middleware=[ModelCallLimitMiddleware(thread_limit=1)],
-    ).compile(checkpointer=InMemorySaver())
+        checkpointer=InMemorySaver(),
+    )

     # First invocation should work - 1 model call, within thread limit
     result = agent.invoke(
@@ -1778,7 +1779,8 @@ def test_run_limit_with_create_agent():
         model=model,
         tools=[simple_tool],
         middleware=[ModelCallLimitMiddleware(run_limit=1)],
-    ).compile(checkpointer=InMemorySaver())
+        checkpointer=InMemorySaver(),
+    )

     # This should hit the run limit after the first model call
     result = agent.invoke(
@@ -1865,8 +1867,7 @@ def test_run_limit_resets_between_invocations() -> None:
         tool_calls=[[], [], [], []]
     )  # No tool calls, so only model call per run

-    agent = create_agent(model=model, middleware=[middleware])
-    agent = agent.compile(checkpointer=InMemorySaver())
+    agent = create_agent(model=model, middleware=[middleware], checkpointer=InMemorySaver())

     thread_config = {"configurable": {"thread_id": "test_thread"}}
     agent.invoke({"messages": [HumanMessage("Hello")]}, thread_config)
@@ -1913,7 +1914,7 @@ async def test_create_agent_async_invoke() -> None:
         tools=[my_tool],
         system_prompt="You are a helpful assistant.",
         middleware=[AsyncMiddleware()],
-    ).compile()
+    )

     result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
@@ -1969,7 +1970,7 @@ async def test_create_agent_async_invoke_multiple_middleware() -> None:
         tools=[],
         system_prompt="You are a helpful assistant.",
         middleware=[AsyncMiddlewareOne(), AsyncMiddlewareTwo()],
-    ).compile()
+    )

     result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
@@ -2010,7 +2011,7 @@ async def test_create_agent_async_jump() -> None:
         tools=[my_tool],
         system_prompt="You are a helpful assistant.",
         middleware=[AsyncMiddlewareOne(), AsyncMiddlewareTwo()],
-    ).compile()
+    )

     result = await agent.ainvoke({"messages": []})
@@ -2049,7 +2050,7 @@ async def test_create_agent_mixed_sync_async_middleware() -> None:
         tools=[],
         system_prompt="You are a helpful assistant.",
         middleware=[SyncMiddleware(), AsyncMiddleware()],
-    ).compile()
+    )

     result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
@@ -2076,7 +2077,7 @@ def test_create_agent_sync_invoke_with_only_async_middleware_raises_error() -> N
         tools=[],
         system_prompt="You are a helpful assistant.",
         middleware=[AsyncOnlyMiddleware()],
-    ).compile()
+    )

     with pytest.raises(
         TypeError,
@@ -2109,7 +2110,7 @@ def test_create_agent_sync_invoke_with_mixed_middleware() -> None:
         tools=[],
         system_prompt="You are a helpful assistant.",
         middleware=[MixedMiddleware()],
-    ).compile()
+    )

     result = agent.invoke({"messages": [HumanMessage("hello")]})
@@ -19,7 +19,7 @@ from langchain.agents.middleware.types import (
     modify_model_request,
     hook_config,
 )
-from langchain.agents.middleware_agent import create_agent, _get_can_jump_to
+from langchain.agents.factory import create_agent, _get_can_jump_to
 from .model import FakeToolCallingModel
@@ -136,7 +136,7 @@ def test_all_decorators_integration() -> None:
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(), middleware=[track_before, track_modify, track_after]
|
||||
)
|
||||
agent = agent.compile()
|
||||
# agent is already compiled
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
|
||||
assert call_order == ["before", "modify", "after"]
|
||||
@@ -225,7 +225,7 @@ def test_can_jump_to_integration() -> None:
|
||||
return None
|
||||
|
||||
agent = create_agent(model=FakeToolCallingModel(), middleware=[early_exit])
|
||||
agent = agent.compile()
|
||||
# agent is already compiled
|
||||
|
||||
# Test with early exit
|
||||
result = agent.invoke({"messages": [HumanMessage("exit")]})
|
||||
@@ -338,7 +338,7 @@ async def test_async_decorators_integration() -> None:
|
||||
model=FakeToolCallingModel(),
|
||||
middleware=[track_async_before, track_async_modify, track_async_after],
|
||||
)
|
||||
agent = agent.compile()
|
||||
# agent is already compiled
|
||||
await agent.ainvoke({"messages": [HumanMessage("Hello")]})
|
||||
|
||||
assert call_order == ["async_before", "async_modify", "async_after"]
|
||||
@@ -394,7 +394,7 @@ async def test_mixed_sync_async_decorators_integration() -> None:
|
||||
track_sync_after,
|
||||
],
|
||||
)
|
||||
agent = agent.compile()
|
||||
# agent is already compiled
|
||||
await agent.ainvoke({"messages": [HumanMessage("Hello")]})
|
||||
|
||||
assert call_order == [
|
||||
@@ -455,7 +455,7 @@ async def test_async_can_jump_to_integration() -> None:
|
||||
return None
|
||||
|
||||
agent = create_agent(model=FakeToolCallingModel(), middleware=[async_early_exit])
|
||||
agent = agent.compile()
|
||||
# agent is already compiled
|
||||
|
||||
# Test with early exit
|
||||
result = await agent.ainvoke({"messages": [HumanMessage("exit")]})
|
||||
@@ -526,7 +526,7 @@ def test_async_middleware_with_can_jump_to_graph_snapshot(snapshot: SnapshotAsse
|
||||
model=FakeToolCallingModel(), middleware=[async_before_with_jump]
|
||||
)
|
||||
|
||||
assert agent_async_before.compile().get_graph().draw_mermaid() == snapshot
|
||||
assert agent_async_before.get_graph().draw_mermaid() == snapshot
|
||||
|
||||
# Test 2: Async after_model with can_jump_to
|
||||
@after_model(can_jump_to=["model", "end"])
|
||||
@@ -539,7 +539,7 @@ def test_async_middleware_with_can_jump_to_graph_snapshot(snapshot: SnapshotAsse
|
||||
model=FakeToolCallingModel(), middleware=[async_after_with_jump]
|
||||
)
|
||||
|
||||
assert agent_async_after.compile().get_graph().draw_mermaid() == snapshot
|
||||
assert agent_async_after.get_graph().draw_mermaid() == snapshot
|
||||
|
||||
# Test 3: Multiple async middleware with can_jump_to
|
||||
@before_model(can_jump_to=["end"])
|
||||
@@ -555,7 +555,7 @@ def test_async_middleware_with_can_jump_to_graph_snapshot(snapshot: SnapshotAsse
        middleware=[async_before_early_exit, async_after_retry],
    )

-    assert agent_multiple_async.compile().get_graph().draw_mermaid() == snapshot
+    assert agent_multiple_async.get_graph().draw_mermaid() == snapshot

    # Test 4: Mixed sync and async middleware with can_jump_to
    @before_model(can_jump_to=["end"])
@@ -571,4 +571,4 @@ def test_async_middleware_with_can_jump_to_graph_snapshot(snapshot: SnapshotAsse
        middleware=[sync_before_with_jump, async_after_with_jumps],
    )

-    assert agent_mixed.compile().get_graph().draw_mermaid() == snapshot
+    assert agent_mixed.get_graph().draw_mermaid() == snapshot

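For readers following these snapshot tests: the @before_model/@after_model decorators declare jump targets via can_jump_to, which is what changes the graph shape being snapshotted. A minimal sketch of the pattern — the hook body and return shape are assumptions, since the diff only shows the decorator form:

from langchain.agents.middleware.types import before_model  # import path assumed from the test imports

@before_model(can_jump_to=["end"])
def early_exit(state, runtime):
    # Returning None continues the normal model -> tools loop; a jump would
    # return a state update naming one of the declared destinations
    # (the exact shape is not shown in this diff).
    return None
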
@@ -3,15 +3,15 @@
import pytest

from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
-from langchain.agents.middleware_agent import create_agent
+from langchain.agents.factory import create_agent
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.tools import tool
from .model import FakeToolCallingModel
from langgraph.runtime import Runtime


-def test_model_request_tools_are_strings() -> None:
-    """Test that ModelRequest.tools contains tool names as strings, not tool objects."""
+def test_model_request_tools_are_objects() -> None:
+    """Test that ModelRequest.tools contains tool objects (BaseTool | dict)."""
    captured_requests: list[ModelRequest] = []

    @tool
@@ -36,23 +36,23 @@ def test_model_request_tools_are_strings() -> None:
        tools=[search_tool, calculator],
        system_prompt="You are a helpful assistant.",
        middleware=[RequestCapturingMiddleware()],
-    ).compile()
+    )

    agent.invoke({"messages": [HumanMessage("Hello")]})

    # Verify that at least one request was captured
    assert len(captured_requests) > 0

-    # Check that tools in the request are strings (tool names)
+    # Check that tools in the request are tool objects
    request = captured_requests[0]
    assert isinstance(request.tools, list)
    assert len(request.tools) == 2
-    assert all(isinstance(tool_name, str) for tool_name in request.tools)
-    assert set(request.tools) == {"search_tool", "calculator"}
+    tool_names = {t.name for t in request.tools}
+    assert tool_names == {"search_tool", "calculator"}


-def test_middleware_can_modify_tool_names() -> None:
-    """Test that middleware can modify the list of tool names in ModelRequest."""
+def test_middleware_can_modify_tools() -> None:
+    """Test that middleware can modify the list of tools in ModelRequest."""

    @tool
    def tool_a(input: str) -> str:
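These renamed tests document the underlying API shift: ModelRequest.tools now holds tool objects (BaseTool | dict) rather than name strings, so middleware can read tool metadata directly. A minimal, hypothetical sketch using the same hook signature as the tests:

from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
from langgraph.runtime import Runtime

class ToolLoggingMiddleware(AgentMiddleware):  # hypothetical middleware, for illustration only
    def modify_model_request(
        self, request: ModelRequest, state: AgentState, runtime: Runtime
    ) -> ModelRequest:
        # Entries are tool objects, so .name and .description are available directly.
        print([t.name for t in request.tools])
        return request
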
@@ -74,7 +74,7 @@ def test_middleware_can_modify_tool_names() -> None:
            self, request: ModelRequest, state: AgentState, runtime: Runtime
        ) -> ModelRequest:
            # Only allow tool_a and tool_b
-            request.tools = ["tool_a", "tool_b"]
+            request.tools = [t for t in request.tools if t.name in ["tool_a", "tool_b"]]
            return request

    # Model will try to call tool_a
@@ -87,7 +87,7 @@ def test_middleware_can_modify_tool_names() -> None:
        tools=[tool_a, tool_b, tool_c],
        system_prompt="You are a helpful assistant.",
        middleware=[ToolFilteringMiddleware()],
-    ).compile()
+    )

    result = agent.invoke({"messages": [HumanMessage("Use tool_a")]})

@@ -98,31 +98,37 @@ def test_middleware_can_modify_tool_names() -> None:
    assert tool_messages[0].name == "tool_a"


-def test_unknown_tool_name_raises_error() -> None:
-    """Test that using an unknown tool name in ModelRequest raises a clear error."""
+def test_middleware_can_add_custom_tools() -> None:
+    """Test that middleware can add custom tool objects to ModelRequest."""

    @tool
    def known_tool(input: str) -> str:
        """A known tool."""
        return "result"

-    class BadMiddleware(AgentMiddleware):
+    @tool
+    def custom_tool(input: str) -> str:
+        """A custom tool added by middleware."""
+        return "custom result"
+
+    class ToolAddingMiddleware(AgentMiddleware):
        def modify_model_request(
            self, request: ModelRequest, state: AgentState, runtime: Runtime
        ) -> ModelRequest:
-            # Add an unknown tool name
-            request.tools = ["known_tool", "unknown_tool"]
+            # Add a custom tool
+            request.tools = list(request.tools) + [custom_tool]
            return request

    agent = create_agent(
        model=FakeToolCallingModel(),
        tools=[known_tool],
        system_prompt="You are a helpful assistant.",
-        middleware=[BadMiddleware()],
-    ).compile()
+        middleware=[ToolAddingMiddleware()],
+    )

-    with pytest.raises(ValueError, match="Middleware returned unknown tool names"):
-        agent.invoke({"messages": [HumanMessage("Hello")]})
+    # Should work fine with custom tools added
+    result = agent.invoke({"messages": [HumanMessage("Hello")]})
+    assert "messages" in result


def test_middleware_can_add_and_remove_tools() -> None:
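Note what the rewritten test implies: because tools travel as objects, middleware can inject a tool that was never registered with create_agent, and the old "unknown tool names" ValueError path disappears. A hedged sketch combining removal and injection in one hook, reusing known_tool and custom_tool from the test above:

class SwapToolsMiddleware(AgentMiddleware):  # hypothetical middleware
    def modify_model_request(
        self, request: ModelRequest, state: AgentState, runtime: Runtime
    ) -> ModelRequest:
        # Drop one tool by name and append a replacement object in one pass.
        request.tools = [t for t in request.tools if t.name != "known_tool"]
        request.tools = [*request.tools, custom_tool]
        return request
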
@@ -149,7 +155,7 @@ def test_middleware_can_add_and_remove_tools() -> None:
        ) -> ModelRequest:
            # Remove admin_tool if not admin
            if not state.get("is_admin", False):
-                request.tools = [name for name in request.tools if name != "admin_tool"]
+                request.tools = [t for t in request.tools if t.name != "admin_tool"]
            return request

    model = FakeToolCallingModel()
@@ -159,7 +165,7 @@ def test_middleware_can_add_and_remove_tools() -> None:
        tools=[search, admin_tool],
        system_prompt="You are a helpful assistant.",
        middleware=[ConditionalToolMiddleware()],
-    ).compile()
+    )

    # Test non-admin user - should not have access to admin_tool
    # We can't directly inspect the bound model, but we can verify the agent runs
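Presumably the is_admin flag travels in the agent state here (the hook reads state.get("is_admin")), so a single compiled graph serves both roles. Illustrative invocations under that assumption:

# Non-admin run: ConditionalToolMiddleware strips admin_tool before the model call.
agent.invoke({"messages": [HumanMessage("search please")], "is_admin": False})

# Admin run: the full tool list, including admin_tool, stays bound.
agent.invoke({"messages": [HumanMessage("do admin things")], "is_admin": True})
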
@@ -194,7 +200,7 @@ def test_empty_tools_list_is_valid() -> None:
        tools=[some_tool],
        system_prompt="You are a helpful assistant.",
        middleware=[NoToolsMiddleware()],
-    ).compile()
+    )

    # Should run without error even with no tools
    result = agent.invoke({"messages": [HumanMessage("Hello")]})
@@ -224,20 +230,20 @@ def test_tools_preserved_across_multiple_middleware() -> None:
        def modify_model_request(
            self, request: ModelRequest, state: AgentState, runtime: Runtime
        ) -> ModelRequest:
-            modification_order.append(request.tools.copy())
+            modification_order.append([t.name for t in request.tools])
            # Remove tool_c
-            request.tools = [name for name in request.tools if name != "tool_c"]
+            request.tools = [t for t in request.tools if t.name != "tool_c"]
            return request

    class SecondMiddleware(AgentMiddleware):
        def modify_model_request(
            self, request: ModelRequest, state: AgentState, runtime: Runtime
        ) -> ModelRequest:
-            modification_order.append(request.tools.copy())
+            modification_order.append([t.name for t in request.tools])
            # Should not see tool_c here
-            assert "tool_c" not in request.tools
+            assert "tool_c" not in [t.name for t in request.tools]
            # Remove tool_b
-            request.tools = [name for name in request.tools if name != "tool_b"]
+            request.tools = [t for t in request.tools if t.name != "tool_b"]
            return request

    agent = create_agent(
@@ -245,7 +251,7 @@ def test_tools_preserved_across_multiple_middleware() -> None:
        tools=[tool_a, tool_b, tool_c],
        system_prompt="You are a helpful assistant.",
        middleware=[FirstMiddleware(), SecondMiddleware()],
-    ).compile()
+    )

    agent.invoke({"messages": [HumanMessage("Hello")]})

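The sequencing this test relies on: middleware run in list order, and each hook sees the previous hook's already-filtered tool list. From the hunks above, the captured names should come out as:

agent.invoke({"messages": [HumanMessage("Hello")]})
# FirstMiddleware sees all three tools; SecondMiddleware sees tool_c already removed.
assert modification_order == [
    ["tool_a", "tool_b", "tool_c"],
    ["tool_a", "tool_b"],
]
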
@@ -286,7 +292,7 @@ def test_middleware_with_additional_tools() -> None:
        tools=[base_tool],
        system_prompt="You are a helpful assistant.",
        middleware=[ToolProvidingMiddleware()],
-    ).compile()
+    )

    result = agent.invoke({"messages": [HumanMessage("Use middleware tool")]})

[File diff suppressed because it is too large]
@@ -1,58 +0,0 @@
-from collections.abc import Callable
-from typing import Union
-
-import pytest
-from pydantic import BaseModel
-from syrupy.assertion import SnapshotAssertion
-
-from langchain.agents import create_agent
-
-from .model import FakeToolCallingModel
-
-model = FakeToolCallingModel()
-
-
-def tool() -> None:
-    """Testing tool."""
-
-
-def pre_model_hook() -> None:
-    """Pre-model hook."""
-
-
-def post_model_hook() -> None:
-    """Post-model hook."""
-
-
-class ResponseFormat(BaseModel):
-    """Response format for the agent."""
-
-    result: str
-
-
-@pytest.mark.parametrize("tools", [[], [tool]])
-@pytest.mark.parametrize("pre_model_hook", [None, pre_model_hook])
-@pytest.mark.parametrize("post_model_hook", [None, post_model_hook])
-def test_react_agent_graph_structure(
-    snapshot: SnapshotAssertion,
-    tools: list[Callable],
-    pre_model_hook: Union[Callable, None],
-    post_model_hook: Union[Callable, None],
-) -> None:
-    agent = create_agent(
-        model,
-        tools=tools,
-        pre_model_hook=pre_model_hook,
-        post_model_hook=post_model_hook,
-    )
-    try:
-        assert agent.get_graph().draw_mermaid(with_styles=False) == snapshot
-    except Exception as e:
-        msg = (
-            "The graph structure has changed. Please update the snapshot."
-            "Configuration used:\n"
-            f"tools: {tools}, "
-            f"pre_model_hook: {pre_model_hook}, "
-            f"post_model_hook: {post_model_hook}, "
-        )
-        raise ValueError(msg) from e
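The deleted file above exercised the old pre_model_hook/post_model_hook parameters, which create_agent no longer accepts; that behavior now belongs to middleware. A rough migration sketch — the decorator import path and signatures are assumptions based on usages visible elsewhere in this diff:

from langchain.agents import create_agent
from langchain.agents.middleware.types import after_model, before_model

@before_model(can_jump_to=["end"])  # parameterized form as used in the tests above
def pre_hook(state, runtime):
    return None  # stands in for the old pre_model_hook

@after_model(can_jump_to=["end"])
def post_hook(state, runtime):
    return None  # stands in for the old post_model_hook

agent = create_agent(model="openai:gpt-4", middleware=[pre_hook, post_hook])
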
@@ -7,6 +7,7 @@ from typing import Union

from langchain_core.messages import HumanMessage
+from langchain.agents import create_agent
from langchain_core.tools import tool
from langchain.agents.structured_output import (
    MultipleStructuredOutputsError,
    ProviderStrategy,

@@ -74,12 +75,14 @@ location_json_schema = {
}


+@tool
def get_weather() -> str:
    """Get the weather."""

    return "The weather is sunny and 75°F."


+@tool
def get_location() -> str:
    """Get the current location."""

@@ -36,6 +36,7 @@ from langchain.tools import (
    ToolNode,
    InjectedState,
    InjectedStore,
+    ToolNode,
)
from langchain.tools.tool_node import TOOL_CALL_ERROR_TEMPLATE, ToolInvocationError, tools_condition