From 18db07c2928847844dbc23ed0abdf51900bbd492 Mon Sep 17 00:00:00 2001
From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com>
Date: Wed, 27 Aug 2025 13:32:21 -0400
Subject: [PATCH] feat(langchain): revamped `create_react_agent` (#32705)
Adding `create_react_agent` and introducing `langchain.agents`!
## Enhanced Structured Output
`create_react_agent` supports coercion of outputs to structured data
types like `pydantic` models, dataclasses, typed dicts, or JSON schemas
specifications.
### Structural Changes
In langgraph < 1.0, `create_react_agent` implemented support for
structured output via an additional LLM call to the model after the
standard model / tool calling loop finished. This introduced extra
expense and was unnecessary.
This new version implements structured output support in the main loop,
allowing a model to choose between calling tools or generating
structured output (or both).
The same basic pattern for structured output generation works:
```py
from langchain.agents import create_react_agent
from langchain_core.messages import HumanMessage
from pydantic import BaseModel
class Weather(BaseModel):
temperature: float
condition: str
def weather_tool(city: str) -> str:
"""Get the weather for a city."""
return f"it's sunny and 70 degrees in {city}"
agent = create_react_agent("openai:gpt-4o-mini", tools=[weather_tool], response_format=Weather)
print(repr(result["structured_response"]))
#> Weather(temperature=70.0, condition='sunny')
```
### Advanced Configuration
The new API exposes two ways to configure how structured output is
generated. Under the hood, LangChain will attempt to pick the best
approach if not explicitly specified. That is, if provider native
support is available for a given model, that takes priority over
artificial tool calling.
1. Artificial tool calling (the default for most models)
LangChain generates a tool (or tools) under the hood that match the
schema of your response format. When the model calls those tools,
LangChain coerces the args to the desired format. Note, LangChain does
not validate outputs adhering to JSON schema specifications.
Extended example
```py
from langchain.agents import create_react_agent
from langchain_core.messages import HumanMessage
from langchain.agents.structured_output import ToolStrategy
from pydantic import BaseModel
class Weather(BaseModel):
temperature: float
condition: str
def weather_tool(city: str) -> str:
"""Get the weather for a city."""
return f"it's sunny and 70 degrees in {city}"
agent = create_react_agent(
"openai:gpt-4o-mini",
tools=[weather_tool],
response_format=ToolStrategy(
schema=Weather, tool_message_content="Final Weather result generated"
),
)
result = agent.invoke({"messages": [HumanMessage("What's the weather in Tokyo?")]})
for message in result["messages"]:
message.pretty_print()
"""
================================ Human Message =================================
What's the weather in Tokyo?
================================== Ai Message ==================================
Tool Calls:
weather_tool (call_Gg933BMHMwck50Q39dtBjXm7)
Call ID: call_Gg933BMHMwck50Q39dtBjXm7
Args:
city: Tokyo
================================= Tool Message =================================
Name: weather_tool
it's sunny and 70 degrees in Tokyo
================================== Ai Message ==================================
Tool Calls:
Weather (call_9xOkYUM7PuEXl9DQq9sWGv5l)
Call ID: call_9xOkYUM7PuEXl9DQq9sWGv5l
Args:
temperature: 70
condition: sunny
================================= Tool Message =================================
Name: Weather
Final Weather result generated
"""
print(repr(result["structured_response"]))
#> Weather(temperature=70.0, condition='sunny')
```
2. Provider implementations (limited to OpenAI, Groq)
Some providers support structured output generating directly. For those
cases, we offer the `ProviderStrategy` hint:
Extended example
```py
from langchain.agents import create_react_agent
from langchain_core.messages import HumanMessage
from langchain.agents.structured_output import ProviderStrategy
from pydantic import BaseModel
class Weather(BaseModel):
temperature: float
condition: str
def weather_tool(city: str) -> str:
"""Get the weather for a city."""
return f"it's sunny and 70 degrees in {city}"
agent = create_react_agent(
"openai:gpt-4o-mini",
tools=[weather_tool],
response_format=ProviderStrategy(Weather),
)
result = agent.invoke({"messages": [HumanMessage("What's the weather in Tokyo?")]})
for message in result["messages"]:
message.pretty_print()
"""
================================ Human Message =================================
What's the weather in Tokyo?
================================== Ai Message ==================================
Tool Calls:
weather_tool (call_OFJq1FngIXS6cvjWv5nfSFZp)
Call ID: call_OFJq1FngIXS6cvjWv5nfSFZp
Args:
city: Tokyo
================================= Tool Message =================================
Name: weather_tool
it's sunny and 70 degrees in Tokyo
================================== Ai Message ==================================
{"temperature":70,"condition":"sunny"}
Weather(temperature=70.0, condition='sunny')
"""
print(repr(result["structured_response"]))
#> Weather(temperature=70.0, condition='sunny')
```
Note! The final tool message has the custom content provided by the dev.
Prompted output was previously supported and is no longer supported via
the `response_format` argument to `create_react_agent`. If there's
significant demand for this, we'd be happy to engineer a solution.
## Error Handling
`create_react_agent` now exposes an API for managing errors associated
with structured output generation. There are two common problems with
structured output generation (w/ artificial tool calling):
1. **Parsing error** -- the model generates data that doesn't match the
desired structure for the output
2. **Multiple tool calls error** -- the model generates 2 or more tool
calls associated with structured output schemas
A developer can control the desired behavior for this via the
`handle_errors` arg to `ToolStrategy`.
Extended example
```py
from langchain_core.messages import HumanMessage
from pydantic import BaseModel
from langchain.agents import create_react_agent
from langchain.agents.structured_output import StructuredOutputValidationError, ToolStrategy
class Weather(BaseModel):
temperature: float
condition: str
def weather_tool(city: str) -> str:
"""Get the weather for a city."""
return f"it's sunny and 70 degrees in {city}"
def handle_validation_error(error: Exception) -> str:
if isinstance(error, StructuredOutputValidationError):
return (
f"Please call the {error.tool_name} call again with the correct arguments. "
f"Your mistake was: {error.source}"
)
raise error
agent = create_react_agent(
"openai:gpt-5",
tools=[weather_tool],
response_format=ToolStrategy(
schema=Weather,
handle_errors=handle_validation_error,
),
)
```
## Error Handling for Tool Calling
Tools fail for two main reasons:
1. **Invocation failure** -- the args generated by the model for the
tool are incorrect (missing, incompatible data types, etc)
2. **Execution failure** -- the tool execution itself fails due to a
developer error, network error, or some other exception.
By default, when tool **invocation** fails, the react agent will return
an artificial `ToolMessage` to the model asking it to correct its
mistakes and retry.
Now, when tool **execution** fails, the react agent raises the
`ToolException` by default instead of asking the model to retry. This
helps to avoid looping that should be avoided due to the aforementioned
issues.
Developers can configure their desired behavior for retries / error
handling via the `handle_tool_errors` arg to `ToolNode`.
## Pre-Bound Models
`create_react_agent` no longer supports inputs to `model` that have been
pre-bound w/ tools or other configuration. To properly support
structured output generation, the agent itself needs the power to bind
tools + structured output kwargs.
This also makes the devx cleaner - it's always expected that `model` is
an instance of `BaseChatModel` (or `str` that we coerce into a chat
model instance).
Dynamic model functions can return a pre-bound model **IF** structured
output is not also used. Dynamic model functions can then bind tools /
structured output logic.
## Import Changes
Users should now use `create_react_agent` from `langchain.agents`
instead of `langgraph.prebuilts`.
Other imports have a similar migration path, `ToolNode` and `AgentState`
for example.
* `chat_agent_executor.py` -> `react_agent.py`
Some notes:
1. Disabled blockbuster + some linting in `langchain/agents` -- beyond
ideal, but necessary to get this across the line for the alpha. We
should re-enable before official release.
---
libs/langchain_v1/Makefile | 36 +-
.../langchain_v1/langchain/agents/__init__.py | 10 +
.../langchain/agents/_internal/__init__.py | 1 +
.../langchain/agents/_internal/_typing.py | 13 +
.../langchain/agents/interrupt.py | 92 +
.../langchain/agents/react_agent.py | 1175 ++++++++++++
.../langchain/agents/structured_output.py | 403 ++++
.../langchain/agents/tool_node.py | 1174 ++++++++++++
.../langchain/chat_models/base.py | 4 +-
libs/langchain_v1/pyproject.toml | 21 +-
.../integration_tests/agents/__init__.py | 1 +
.../agents/test_response_format.py | 79 +
.../tests/unit_tests/agents/__init__.py | 0
.../__snapshots__/test_react_agent_graph.ambr | 83 +
.../tests/unit_tests/agents/any_str.py | 18 +
.../unit_tests/agents/compose-postgres.yml | 17 +
.../tests/unit_tests/agents/compose-redis.yml | 16 +
.../tests/unit_tests/agents/conftest.py | 194 ++
.../agents/conftest_checkpointer.py | 64 +
.../tests/unit_tests/agents/conftest_store.py | 58 +
.../tests/unit_tests/agents/memory_assert.py | 56 +
.../tests/unit_tests/agents/messages.py | 28 +
.../tests/unit_tests/agents/model.py | 111 ++
.../agents/specifications/responses.json | 87 +
.../agents/specifications/return_direct.json | 48 +
.../unit_tests/agents/test_react_agent.py | 1647 +++++++++++++++++
.../agents/test_react_agent_graph.py | 58 +
.../unit_tests/agents/test_response_format.py | 710 +++++++
.../tests/unit_tests/agents/test_responses.py | 140 ++
.../unit_tests/agents/test_responses_spec.py | 147 ++
.../agents/test_return_direct_spec.py | 107 ++
.../tests/unit_tests/agents/test_tool_node.py | 1482 +++++++++++++++
.../tests/unit_tests/agents/utils.py | 21 +
.../langchain_v1/tests/unit_tests/conftest.py | 35 +-
libs/langchain_v1/uv.lock | 36 +-
35 files changed, 8103 insertions(+), 69 deletions(-)
create mode 100644 libs/langchain_v1/langchain/agents/__init__.py
create mode 100644 libs/langchain_v1/langchain/agents/_internal/__init__.py
create mode 100644 libs/langchain_v1/langchain/agents/_internal/_typing.py
create mode 100644 libs/langchain_v1/langchain/agents/interrupt.py
create mode 100644 libs/langchain_v1/langchain/agents/react_agent.py
create mode 100644 libs/langchain_v1/langchain/agents/structured_output.py
create mode 100644 libs/langchain_v1/langchain/agents/tool_node.py
create mode 100644 libs/langchain_v1/tests/integration_tests/agents/__init__.py
create mode 100644 libs/langchain_v1/tests/integration_tests/agents/test_response_format.py
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/__init__.py
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/__snapshots__/test_react_agent_graph.ambr
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/any_str.py
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/compose-postgres.yml
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/compose-redis.yml
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/conftest.py
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/conftest_checkpointer.py
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/conftest_store.py
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/memory_assert.py
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/messages.py
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/model.py
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/specifications/responses.json
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/specifications/return_direct.json
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/test_react_agent.py
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/test_react_agent_graph.py
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/test_response_format.py
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/test_responses.py
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/test_responses_spec.py
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/test_return_direct_spec.py
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py
create mode 100644 libs/langchain_v1/tests/unit_tests/agents/utils.py
diff --git a/libs/langchain_v1/Makefile b/libs/langchain_v1/Makefile
index 885a56cf7b6..bb07973f2be 100644
--- a/libs/langchain_v1/Makefile
+++ b/libs/langchain_v1/Makefile
@@ -1,4 +1,4 @@
-.PHONY: all clean docs_build docs_clean docs_linkcheck api_docs_build api_docs_clean api_docs_linkcheck format lint test tests test_watch integration_tests help extended_tests
+.PHONY: all clean docs_build docs_clean docs_linkcheck api_docs_build api_docs_clean api_docs_linkcheck format lint test tests test_watch integration_tests help extended_tests start_services stop_services
# Default target executed when no arguments are given to make.
all: help
@@ -7,6 +7,12 @@ all: help
# TESTING AND COVERAGE
######################
+start_services:
+ docker compose -f tests/unit_tests/agents/compose-postgres.yml -f tests/unit_tests/agents/compose-redis.yml up -V --force-recreate --wait --remove-orphans
+
+stop_services:
+ docker compose -f tests/unit_tests/agents/compose-postgres.yml -f tests/unit_tests/agents/compose-redis.yml down -v
+
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
@@ -21,17 +27,32 @@ coverage:
--cov-report term-missing:skip-covered \
$(TEST_FILE)
-test tests:
- uv run --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE)
+test:
+ make start_services && LANGGRAPH_TEST_FAST=0 uv run --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) --cov-report term-missing:skip-covered; \
+ EXIT_CODE=$$?; \
+ make stop_services; \
+ exit $$EXIT_CODE
+
+test_fast:
+ LANGGRAPH_TEST_FAST=1 uv run --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE)
extended_tests:
- uv run --group test pytest --disable-socket --allow-unix-socket --only-extended tests/unit_tests
+ make start_services && LANGGRAPH_TEST_FAST=0 uv run --group test pytest --disable-socket --allow-unix-socket --only-extended tests/unit_tests; \
+ EXIT_CODE=$$?; \
+ make stop_services; \
+ exit $$EXIT_CODE
test_watch:
- uv run --group test ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --disable-warnings tests/unit_tests
+ make start_services && LANGGRAPH_TEST_FAST=0 uv run --group test ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --disable-warnings tests/unit_tests; \
+ EXIT_CODE=$$?; \
+ make stop_services; \
+ exit $$EXIT_CODE
test_watch_extended:
- uv run --group test ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --only-extended tests/unit_tests
+ make start_services && LANGGRAPH_TEST_FAST=0 uv run --group test ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --only-extended tests/unit_tests; \
+ EXIT_CODE=$$?; \
+ make stop_services; \
+ exit $$EXIT_CODE
integration_tests:
uv run --group test --group test_integration pytest tests/integration_tests
@@ -87,7 +108,8 @@ help:
@echo 'spell_fix - run codespell on the project and fix the errors'
@echo '-- TESTS --'
@echo 'coverage - run unit tests and generate coverage report'
- @echo 'test - run unit tests'
+ @echo 'test - run unit tests with all services'
+ @echo 'test_fast - run unit tests with in-memory services only'
@echo 'tests - run unit tests (alias for "make test")'
@echo 'test TEST_FILE= - run all tests in file'
@echo 'extended_tests - run only extended unit tests'
diff --git a/libs/langchain_v1/langchain/agents/__init__.py b/libs/langchain_v1/langchain/agents/__init__.py
new file mode 100644
index 00000000000..f3c10ee29e6
--- /dev/null
+++ b/libs/langchain_v1/langchain/agents/__init__.py
@@ -0,0 +1,10 @@
+"""langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools."""
+
+from langchain.agents.react_agent import AgentState, create_react_agent
+from langchain.agents.tool_node import ToolNode
+
+__all__ = [
+ "AgentState",
+ "ToolNode",
+ "create_react_agent",
+]
diff --git a/libs/langchain_v1/langchain/agents/_internal/__init__.py b/libs/langchain_v1/langchain/agents/_internal/__init__.py
new file mode 100644
index 00000000000..b4ff9ac2baa
--- /dev/null
+++ b/libs/langchain_v1/langchain/agents/_internal/__init__.py
@@ -0,0 +1 @@
+"""Internal utilities for agents."""
diff --git a/libs/langchain_v1/langchain/agents/_internal/_typing.py b/libs/langchain_v1/langchain/agents/_internal/_typing.py
new file mode 100644
index 00000000000..36989c6ce22
--- /dev/null
+++ b/libs/langchain_v1/langchain/agents/_internal/_typing.py
@@ -0,0 +1,13 @@
+"""Typing utilities for agents."""
+
+from __future__ import annotations
+
+from collections.abc import Awaitable, Callable
+from typing import TypeVar, Union
+
+from typing_extensions import ParamSpec
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+SyncOrAsync = Callable[P, Union[R, Awaitable[R]]]
diff --git a/libs/langchain_v1/langchain/agents/interrupt.py b/libs/langchain_v1/langchain/agents/interrupt.py
new file mode 100644
index 00000000000..fda69363774
--- /dev/null
+++ b/libs/langchain_v1/langchain/agents/interrupt.py
@@ -0,0 +1,92 @@
+"""Interrupt types to use with agent inbox like setups."""
+
+from typing import Literal, Union
+
+from typing_extensions import TypedDict
+
+
+class HumanInterruptConfig(TypedDict):
+ """Configuration that defines what actions are allowed for a human interrupt.
+
+ This controls the available interaction options when the graph is paused for human input.
+
+ Attributes:
+ allow_ignore: Whether the human can choose to ignore/skip the current step
+ allow_respond: Whether the human can provide a text response/feedback
+ allow_edit: Whether the human can edit the provided content/state
+ allow_accept: Whether the human can accept/approve the current state
+ """
+
+ allow_ignore: bool
+ allow_respond: bool
+ allow_edit: bool
+ allow_accept: bool
+
+
+class ActionRequest(TypedDict):
+ """Represents a request for human action within the graph execution.
+
+ Contains the action type and any associated arguments needed for the action.
+
+ Attributes:
+ action: The type or name of action being requested (e.g., "Approve XYZ action")
+ args: Key-value pairs of arguments needed for the action
+ """
+
+ action: str
+ args: dict
+
+
+class HumanInterrupt(TypedDict):
+ """Represents an interrupt triggered by the graph that requires human intervention.
+
+ This is passed to the `interrupt` function when execution is paused for human input.
+
+ Attributes:
+ action_request: The specific action being requested from the human
+ config: Configuration defining what actions are allowed
+ description: Optional detailed description of what input is needed
+
+ Example:
+ ```python
+ # Extract a tool call from the state and create an interrupt request
+ request = HumanInterrupt(
+ action_request=ActionRequest(
+ action="run_command", # The action being requested
+ args={"command": "ls", "args": ["-l"]} # Arguments for the action
+ ),
+ config=HumanInterruptConfig(
+ allow_ignore=True, # Allow skipping this step
+ allow_respond=True, # Allow text feedback
+ allow_edit=False, # Don't allow editing
+ allow_accept=True # Allow direct acceptance
+ ),
+ description="Please review the command before execution"
+ )
+ # Send the interrupt request and get the response
+ response = interrupt([request])[0]
+ ```
+ """
+
+ action_request: ActionRequest
+ config: HumanInterruptConfig
+ description: str | None
+
+
+class HumanResponse(TypedDict):
+ """The response provided by a human to an interrupt, which is returned when graph execution resumes.
+
+ Attributes:
+ type: The type of response:
+ - "accept": Approves the current state without changes
+ - "ignore": Skips/ignores the current step
+ - "response": Provides text feedback or instructions
+ - "edit": Modifies the current state/content
+ args: The response payload:
+ - None: For ignore/accept actions
+ - str: For text responses
+ - ActionRequest: For edit actions with updated content
+ """
+
+ type: Literal["accept", "ignore", "response", "edit"]
+ args: Union[None, str, ActionRequest]
diff --git a/libs/langchain_v1/langchain/agents/react_agent.py b/libs/langchain_v1/langchain/agents/react_agent.py
new file mode 100644
index 00000000000..045d92e97dc
--- /dev/null
+++ b/libs/langchain_v1/langchain/agents/react_agent.py
@@ -0,0 +1,1175 @@
+from __future__ import annotations
+
+import inspect
+from collections.abc import Awaitable, Callable, Sequence
+from dataclasses import asdict, is_dataclass
+from typing import (
+ TYPE_CHECKING,
+ Annotated,
+ Any,
+ Generic,
+ Literal,
+ Union,
+ cast,
+ get_type_hints,
+)
+from warnings import warn
+
+from langchain_core.language_models import (
+ BaseChatModel,
+ LanguageModelInput,
+ LanguageModelLike,
+)
+from langchain_core.messages import (
+ AIMessage,
+ AnyMessage,
+ BaseMessage,
+ SystemMessage,
+ ToolCall,
+ ToolMessage,
+)
+from langchain_core.runnables import (
+ Runnable,
+ RunnableConfig,
+)
+from langgraph._internal._runnable import RunnableCallable, RunnableLike
+from langgraph._internal._typing import MISSING
+from langgraph.errors import ErrorCode, create_error_message
+from langgraph.graph import END, StateGraph
+from langgraph.graph.message import add_messages
+from langgraph.managed import RemainingSteps # noqa: TC002
+from langgraph.types import Checkpointer, Command, Send
+from langgraph.typing import ContextT, StateT
+from pydantic import BaseModel
+from typing_extensions import NotRequired, TypedDict, TypeVar
+
+from langchain.agents.structured_output import (
+ MultipleStructuredOutputsError,
+ OutputToolBinding,
+ ProviderStrategy,
+ ProviderStrategyBinding,
+ ResponseFormat,
+ StructuredOutputValidationError,
+ ToolStrategy,
+)
+from langchain.agents.tool_node import ToolNode
+from langchain.chat_models import init_chat_model
+
+if TYPE_CHECKING:
+ from langchain_core.tools import BaseTool
+ from langgraph.graph.state import CompiledStateGraph
+ from langgraph.runtime import Runtime
+ from langgraph.store.base import BaseStore
+
+ from langchain.agents._internal._typing import (
+ SyncOrAsync,
+ )
+
+StructuredResponseT = TypeVar("StructuredResponseT", default=None)
+
+STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
+
+
+class AgentState(TypedDict):
+ """The state of the agent."""
+
+ messages: Annotated[Sequence[BaseMessage], add_messages]
+
+ remaining_steps: NotRequired[RemainingSteps]
+
+
+class AgentStatePydantic(BaseModel):
+ """The state of the agent."""
+
+ messages: Annotated[Sequence[BaseMessage], add_messages]
+
+ remaining_steps: RemainingSteps = 25
+
+
+class AgentStateWithStructuredResponse(AgentState, Generic[StructuredResponseT]):
+ """The state of the agent with a structured response."""
+
+ structured_response: StructuredResponseT
+
+
+class AgentStateWithStructuredResponsePydantic(AgentStatePydantic, Generic[StructuredResponseT]):
+ """The state of the agent with a structured response."""
+
+ structured_response: StructuredResponseT
+
+
+PROMPT_RUNNABLE_NAME = "Prompt"
+
+Prompt = Union[
+ SystemMessage,
+ str,
+ Callable[[StateT], LanguageModelInput],
+ Runnable[StateT, LanguageModelInput],
+]
+
+
+def _get_state_value(state: StateT, key: str, default: Any = None) -> Any:
+ return state.get(key, default) if isinstance(state, dict) else getattr(state, key, default)
+
+
+def _get_prompt_runnable(prompt: Prompt | None) -> Runnable:
+ prompt_runnable: Runnable
+ if prompt is None:
+ prompt_runnable = RunnableCallable(
+ lambda state: _get_state_value(state, "messages"), name=PROMPT_RUNNABLE_NAME
+ )
+ elif isinstance(prompt, str):
+ _system_message: BaseMessage = SystemMessage(content=prompt)
+ prompt_runnable = RunnableCallable(
+ lambda state: [_system_message, *_get_state_value(state, "messages")],
+ name=PROMPT_RUNNABLE_NAME,
+ )
+ elif isinstance(prompt, SystemMessage):
+ prompt_runnable = RunnableCallable(
+ lambda state: [prompt, *_get_state_value(state, "messages")],
+ name=PROMPT_RUNNABLE_NAME,
+ )
+ elif inspect.iscoroutinefunction(prompt):
+ prompt_runnable = RunnableCallable(
+ None,
+ prompt,
+ name=PROMPT_RUNNABLE_NAME,
+ )
+ elif callable(prompt):
+ prompt_runnable = RunnableCallable(
+ prompt,
+ name=PROMPT_RUNNABLE_NAME,
+ )
+ elif isinstance(prompt, Runnable):
+ prompt_runnable = prompt
+ else:
+ msg = f"Got unexpected type for `prompt`: {type(prompt)}"
+ raise ValueError(msg)
+
+ return prompt_runnable
+
+
+def _validate_chat_history(
+ messages: Sequence[BaseMessage],
+) -> None:
+ """Validate that all tool calls in AIMessages have a corresponding ToolMessage."""
+ all_tool_calls = [
+ tool_call
+ for message in messages
+ if isinstance(message, AIMessage)
+ for tool_call in message.tool_calls
+ ]
+ tool_call_ids_with_results = {
+ message.tool_call_id for message in messages if isinstance(message, ToolMessage)
+ }
+ tool_calls_without_results = [
+ tool_call
+ for tool_call in all_tool_calls
+ if tool_call["id"] not in tool_call_ids_with_results
+ ]
+ if not tool_calls_without_results:
+ return
+
+ error_message = create_error_message(
+ message="Found AIMessages with tool_calls that do not have a corresponding ToolMessage. "
+ f"Here are the first few of those tool calls: {tool_calls_without_results[:3]}.\n\n"
+ "Every tool call (LLM requesting to call a tool) in the message history MUST have a corresponding ToolMessage "
+ "(result of a tool invocation to return to the LLM) - this is required by most LLM providers.",
+ error_code=ErrorCode.INVALID_CHAT_HISTORY,
+ )
+ raise ValueError(error_message)
+
+
+class _AgentBuilder(Generic[StateT, ContextT, StructuredResponseT]):
+ """Internal builder class for constructing and agent."""
+
+ def __init__(
+ self,
+ model: Union[
+ str,
+ BaseChatModel,
+ SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel],
+ ],
+ tools: Union[Sequence[Union[BaseTool, Callable, dict[str, Any]]], ToolNode],
+ *,
+ prompt: Prompt | None = None,
+ response_format: ResponseFormat[StructuredResponseT] | None = None,
+ pre_model_hook: RunnableLike | None = None,
+ post_model_hook: RunnableLike | None = None,
+ state_schema: type[StateT] | None = None,
+ context_schema: type[ContextT] | None = None,
+ version: Literal["v1", "v2"] = "v2",
+ name: str | None = None,
+ store: BaseStore | None = None,
+ ) -> None:
+ self.model = model
+ self.tools = tools
+ self.prompt = prompt
+ self.response_format = response_format
+ self.pre_model_hook = pre_model_hook
+ self.post_model_hook = post_model_hook
+ self.state_schema = state_schema
+ self.context_schema = context_schema
+ self.version = version
+ self.name = name
+ self.store = store
+
+ if isinstance(model, Runnable) and not isinstance(model, BaseChatModel):
+ msg = (
+ "Expected `model` to be a BaseChatModel or a string, got {type(model)}."
+ "The `model` parameter should not have pre-bound tools, simply pass the model and tools separately."
+ )
+ raise ValueError(msg)
+
+ self._setup_tools()
+ self._setup_state_schema()
+ self._setup_structured_output()
+ self._setup_model()
+
+ def _setup_tools(self) -> None:
+ """Setup tool-related attributes."""
+ if isinstance(self.tools, ToolNode):
+ self._tool_classes = list(self.tools.tools_by_name.values())
+ self._tool_node = self.tools
+ self._llm_builtin_tools = []
+ else:
+ self._llm_builtin_tools = [t for t in self.tools if isinstance(t, dict)]
+ self._tool_node = ToolNode([t for t in self.tools if not isinstance(t, dict)])
+ self._tool_classes = list(self._tool_node.tools_by_name.values())
+
+ self._should_return_direct = {t.name for t in self._tool_classes if t.return_direct}
+ self._tool_calling_enabled = len(self._tool_classes) > 0
+
+ def _setup_structured_output(self) -> None:
+ """Set up structured output tracking for "tools" and "native" strategies.
+
+ "tools" strategy for structured output:
+ 1. Converting response format schemas to LangChain tools
+ 2. Creating metadata for proper response reconstruction
+ 3. Handling both Pydantic models and dict schemas
+
+ "native" strategy for structured output:
+ 1. Capturing the schema reference for later parsing
+ 2. Binding provider-native response_format kwargs at model bind time
+ 3. Parsing provider-enforced structured output directly into the schema
+ """
+ self.structured_output_tools: dict[str, OutputToolBinding[StructuredResponseT]] = {}
+ self.native_output_binding: ProviderStrategyBinding[StructuredResponseT] | None = None
+
+ if self.response_format is not None:
+ response_format = self.response_format
+
+ if isinstance(response_format, ToolStrategy):
+ # check if response_format.schema is a union
+ for response_schema in response_format.schema_specs:
+ structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
+ self.structured_output_tools[structured_tool_info.tool.name] = (
+ structured_tool_info
+ )
+ elif isinstance(response_format, ProviderStrategy):
+ # Use native strategy - create ProviderStrategyBinding for parsing
+ self.native_output_binding = ProviderStrategyBinding.from_schema_spec(
+ response_format.schema_spec
+ )
+ else:
+ # This shouldn't happen with the new ResponseFormat type, but keeping for safety
+ msg = (
+ f"Unsupported response_format type: {type(response_format)}. "
+ f"Expected ToolStrategy."
+ )
+ raise ValueError(msg)
+
+ def _setup_state_schema(self) -> None:
+ """Setup state schema with validation."""
+ if self.state_schema is not None:
+ required_keys = {"messages", "remaining_steps"}
+ if self.response_format is not None:
+ required_keys.add("structured_response")
+
+ schema_keys = set(get_type_hints(self.state_schema))
+ if missing_keys := required_keys - schema_keys:
+ msg = f"Missing required key(s) {missing_keys} in state_schema"
+ raise ValueError(msg)
+
+ self._final_state_schema = self.state_schema
+ else:
+ self._final_state_schema = (
+ AgentStateWithStructuredResponse # type: ignore[assignment]
+ if self.response_format is not None
+ else AgentState
+ )
+
+ def _handle_structured_response_tool_calls(self, response: AIMessage) -> Command | None:
+ """Handle tool calls that match structured output tools using the tools strategy.
+
+ Args:
+ response: The AI message containing potential tool calls
+
+ Returns:
+ Command with structured response update if found, None otherwise
+
+ Raises:
+ MultipleStructuredOutputsError: If multiple structured responses are returned and error handling is disabled
+ StructuredOutputParsingError: If parsing fails and error handling is disabled
+ """
+ if not isinstance(self.response_format, ToolStrategy) or not response.tool_calls:
+ return None
+
+ structured_tool_calls = [
+ tool_call
+ for tool_call in response.tool_calls
+ if tool_call["name"] in self.structured_output_tools
+ ]
+
+ if not structured_tool_calls:
+ return None
+
+ if len(structured_tool_calls) > 1:
+ return self._handle_multiple_structured_outputs(response, structured_tool_calls)
+
+ return self._handle_single_structured_output(response, structured_tool_calls[0])
+
+ def _handle_multiple_structured_outputs(
+ self,
+ response: AIMessage,
+ structured_tool_calls: list[ToolCall],
+ ) -> Command:
+ """Handle multiple structured output tool calls."""
+ tool_names = [tool_call["name"] for tool_call in structured_tool_calls]
+ exception = MultipleStructuredOutputsError(tool_names)
+
+ should_retry, error_message = self._handle_structured_output_error(exception)
+
+ if not should_retry:
+ raise exception
+
+ tool_messages = [
+ ToolMessage(
+ content=error_message,
+ tool_call_id=tool_call["id"],
+ name=tool_call["name"],
+ )
+ for tool_call in structured_tool_calls
+ ]
+
+ return Command(
+ update={"messages": [response, *tool_messages]},
+ goto="agent",
+ )
+
+ def _handle_single_structured_output(
+ self,
+ response: AIMessage,
+ tool_call: Any,
+ ) -> Command:
+ """Handle a single structured output tool call."""
+ structured_tool_binding = self.structured_output_tools[tool_call["name"]]
+
+ try:
+ structured_response = structured_tool_binding.parse(tool_call["args"])
+
+ if isinstance(structured_response, BaseModel):
+ structured_response_dict = structured_response.model_dump()
+ elif is_dataclass(structured_response):
+ structured_response_dict = asdict(structured_response) # type: ignore[arg-type]
+ else:
+ structured_response_dict = cast("dict", structured_response)
+
+ tool_message_content = (
+ self.response_format.tool_message_content
+ if isinstance(self.response_format, ToolStrategy)
+ and self.response_format.tool_message_content
+ else f"Returning structured response: {structured_response_dict}"
+ )
+
+ return Command(
+ update={
+ "messages": [
+ response,
+ ToolMessage(
+ content=tool_message_content,
+ tool_call_id=tool_call["id"],
+ name=tool_call["name"],
+ ),
+ ],
+ "structured_response": structured_response,
+ }
+ )
+ except Exception as exc: # noqa: BLE001
+ exception = StructuredOutputValidationError(tool_call["name"], exc)
+
+ should_retry, error_message = self._handle_structured_output_error(exception)
+
+ if not should_retry:
+ raise exception
+
+ return Command(
+ update={
+ "messages": [
+ response,
+ ToolMessage(
+ content=error_message,
+ tool_call_id=tool_call["id"],
+ name=tool_call["name"],
+ ),
+ ],
+ },
+ goto="agent",
+ )
+
+ def _handle_structured_output_error(
+ self,
+ exception: Exception,
+ ) -> tuple[bool, str]:
+ """Handle structured output error.
+
+ Returns (should_retry, retry_tool_message).
+ """
+ handle_errors = cast("ToolStrategy", self.response_format).handle_errors
+
+ if handle_errors is False:
+ return False, ""
+ if handle_errors is True:
+ return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
+ if isinstance(handle_errors, str):
+ return True, handle_errors
+ if isinstance(handle_errors, type) and issubclass(handle_errors, Exception):
+ if isinstance(exception, handle_errors):
+ return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
+ return False, ""
+ if isinstance(handle_errors, tuple):
+ if any(isinstance(exception, exc_type) for exc_type in handle_errors):
+ return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
+ return False, ""
+ if callable(handle_errors):
+ return True, handle_errors(exception) # type: ignore[call-arg, return-value]
+ return False, ""
+
+ def _apply_native_output_binding(self, model: LanguageModelLike) -> LanguageModelLike:
+ """If native output is configured, bind provider-native kwargs onto the model."""
+ if not isinstance(self.response_format, ProviderStrategy):
+ return model
+ kwargs = self.response_format.to_model_kwargs()
+ return model.bind(**kwargs)
+
+ def _handle_structured_response_native(self, response: AIMessage) -> Command | None:
+ """If native output is configured and there are no tool calls, parse using ProviderStrategyBinding."""
+ if self.native_output_binding is None:
+ return None
+ if response.tool_calls:
+ # if the model chooses to call tools, we let the normal flow handle it
+ return None
+
+ structured_response = self.native_output_binding.parse(response)
+
+ return Command(update={"messages": [response], "structured_response": structured_response})
+
+ def _setup_model(self) -> None:
+ """Setup model-related attributes."""
+ self._is_dynamic_model = not isinstance(self.model, (str, Runnable)) and callable(
+ self.model
+ )
+ self._is_async_dynamic_model = self._is_dynamic_model and inspect.iscoroutinefunction(
+ self.model
+ )
+
+ if not self._is_dynamic_model:
+ model = self.model
+ if isinstance(model, str):
+ model = init_chat_model(model)
+
+ # Collect all tools: regular tools + structured output tools
+ structured_output_tools = list(self.structured_output_tools.values())
+ all_tools = (
+ self._tool_classes
+ + self._llm_builtin_tools
+ + [info.tool for info in structured_output_tools]
+ )
+
+ if len(all_tools) > 0:
+ # Check if we need to force tool use for structured output
+ tool_choice = None
+ if self.response_format is not None and isinstance(
+ self.response_format, ToolStrategy
+ ):
+ tool_choice = "any"
+
+ if tool_choice:
+ model = cast("BaseChatModel", model).bind_tools( # type: ignore[assignment]
+ all_tools, tool_choice=tool_choice
+ )
+ # If native output is configured, bind tools with strict=True. Required for OpenAI.
+ elif isinstance(self.response_format, ProviderStrategy):
+ model = cast("BaseChatModel", model).bind_tools( # type: ignore[assignment]
+ all_tools, strict=True
+ )
+ else:
+ model = cast("BaseChatModel", model).bind_tools(all_tools) # type: ignore[assignment]
+
+ # bind native structured-output kwargs
+ model = self._apply_native_output_binding(model) # type: ignore[assignment, arg-type]
+
+ # Extract just the model part for direct invocation
+ self._static_model: Runnable | None = model # type: ignore[assignment]
+ else:
+ self._static_model = None
+
+ def _resolve_model(self, state: StateT, runtime: Runtime[ContextT]) -> LanguageModelLike:
+ """Resolve the model to use, handling both static and dynamic models."""
+ if self._is_dynamic_model:
+ dynamic_model = self.model(state, runtime) # type: ignore[operator, arg-type]
+ return self._apply_native_output_binding(dynamic_model) # type: ignore[arg-type]
+ return self._static_model # type: ignore[return-value]
+
+ async def _aresolve_model(self, state: StateT, runtime: Runtime[ContextT]) -> LanguageModelLike:
+ """Async resolve the model to use, handling both static and dynamic models."""
+ if self._is_async_dynamic_model:
+ dynamic_model = cast(
+ "Callable[[StateT, Runtime[ContextT]], Awaitable[BaseChatModel]]",
+ self.model,
+ )
+ return await dynamic_model(state, runtime)
+ if self._is_dynamic_model:
+ dynamic_model = self.model(state, runtime) # type: ignore[arg-type, assignment, operator]
+ return self._apply_native_output_binding(dynamic_model) # type: ignore[arg-type]
+ return self._static_model # type: ignore[return-value]
+
+ def create_model_node(self) -> RunnableCallable:
+ """Create the 'agent' node that calls the LLM."""
+
+ def _get_model_input_state(state: StateT) -> StateT:
+ if self.pre_model_hook is not None:
+ messages = _get_state_value(state, "llm_input_messages") or _get_state_value(
+ state, "messages"
+ )
+ error_msg = (
+ f"Expected input to call_model to have 'llm_input_messages' "
+ f"or 'messages' key, but got {state}"
+ )
+ else:
+ messages = _get_state_value(state, "messages")
+ error_msg = f"Expected input to call_model to have 'messages' key, but got {state}"
+
+ if messages is None:
+ raise ValueError(error_msg)
+
+ _validate_chat_history(messages)
+
+ if isinstance(self._final_state_schema, type) and issubclass(
+ self._final_state_schema, BaseModel
+ ):
+ # we're passing messages under `messages` key, as this
+ # is expected by the prompt
+ state.messages = messages # type: ignore[union-attr]
+ else:
+ state["messages"] = messages # type: ignore[index]
+ return state
+
+ def _are_more_steps_needed(state: StateT, response: BaseMessage) -> bool:
+ has_tool_calls = isinstance(response, AIMessage) and response.tool_calls
+ all_tools_return_direct = (
+ all(call["name"] in self._should_return_direct for call in response.tool_calls)
+ if isinstance(response, AIMessage)
+ else False
+ )
+ remaining_steps = _get_state_value(state, "remaining_steps", None)
+ return (
+ remaining_steps is not None # type: ignore[return-value]
+ and (
+ (remaining_steps < 1 and all_tools_return_direct)
+ or (remaining_steps < 2 and has_tool_calls)
+ )
+ )
+
+ def call_model(
+ state: StateT, runtime: Runtime[ContextT], config: RunnableConfig
+ ) -> dict[str, Any] | Command:
+ """Call the model with the current state and return the response."""
+ if self._is_async_dynamic_model:
+ msg = (
+ "Async model callable provided but agent invoked synchronously. "
+ "Use agent.ainvoke() or agent.astream(), or provide a sync model callable."
+ )
+ raise RuntimeError(msg)
+
+ model_input = _get_model_input_state(state)
+ model = self._resolve_model(state, runtime)
+
+ # Get prompt runnable and invoke it first to prepare messages
+ prompt_runnable = _get_prompt_runnable(self.prompt)
+ prepared_messages = prompt_runnable.invoke(model_input, config)
+
+ # Then invoke the model with the prepared messages
+ response = cast("AIMessage", model.invoke(prepared_messages, config))
+ response.name = self.name
+
+ if _are_more_steps_needed(state, response):
+ return {
+ "messages": [
+ AIMessage(
+ id=response.id,
+ content="Sorry, need more steps to process this request.",
+ )
+ ]
+ }
+
+ # Check if any tool calls match structured output tools
+ structured_command = self._handle_structured_response_tool_calls(response)
+ if structured_command:
+ return structured_command
+
+ # Native structured output
+ native_command = self._handle_structured_response_native(response)
+ if native_command:
+ return native_command
+
+ return {"messages": [response]}
+
+ async def acall_model(
+ state: StateT, runtime: Runtime[ContextT], config: RunnableConfig
+ ) -> dict[str, Any] | Command:
+ """Call the model with the current state and return the response."""
+ model_input = _get_model_input_state(state)
+
+ model = await self._aresolve_model(state, runtime)
+
+ # Get prompt runnable and invoke it first to prepare messages
+ prompt_runnable = _get_prompt_runnable(self.prompt)
+ prepared_messages = await prompt_runnable.ainvoke(model_input, config)
+
+ # Then invoke the model with the prepared messages
+ response = cast(
+ "AIMessage",
+ await model.ainvoke(prepared_messages, config),
+ )
+ response.name = self.name
+ if _are_more_steps_needed(state, response):
+ return {
+ "messages": [
+ AIMessage(
+ id=response.id,
+ content="Sorry, need more steps to process this request.",
+ )
+ ]
+ }
+
+ # Check if any tool calls match structured output tools
+ structured_command = self._handle_structured_response_tool_calls(response)
+ if structured_command:
+ return structured_command
+
+ # Native structured output
+ native_command = self._handle_structured_response_native(response)
+ if native_command:
+ return native_command
+
+ return {"messages": [response]}
+
+ return RunnableCallable(call_model, acall_model)
+
+ def _get_input_schema(self) -> type[StateT]:
+ """Get input schema for model node."""
+ if self.pre_model_hook is not None:
+ if isinstance(self._final_state_schema, type) and issubclass(
+ self._final_state_schema, BaseModel
+ ):
+ from pydantic import create_model
+
+ return create_model(
+ "CallModelInputSchema",
+ llm_input_messages=(list[AnyMessage], ...),
+ __base__=self._final_state_schema,
+ )
+
+ class CallModelInputSchema(self._final_state_schema): # type: ignore[name-defined, misc]
+ llm_input_messages: list[AnyMessage]
+
+ return CallModelInputSchema
+ return self._final_state_schema
+
+ def create_model_router(self) -> Callable[[StateT], Union[str, list[Send]]]:
+ """Create routing function for model node conditional edges."""
+
+ def should_continue(state: StateT) -> Union[str, list[Send]]:
+ messages = _get_state_value(state, "messages")
+ last_message = messages[-1]
+
+ # Check if the last message is a ToolMessage from a structured tool.
+ # This condition exists to support structured output via tools.
+ # Once a tool has been called for structured output, we skip
+ # tool execution and go to END (if there is no post_model_hook).
+ if (
+ isinstance(last_message, ToolMessage)
+ and last_message.name in self.structured_output_tools
+ ):
+ return END
+
+ if isinstance(last_message, ToolMessage):
+ return END
+
+ if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
+ if self.post_model_hook is not None:
+ return "post_model_hook"
+ return END
+ if self.version == "v1":
+ return "tools"
+ if self.version == "v2":
+ if self.post_model_hook is not None:
+ return "post_model_hook"
+ tool_calls = [
+ self._tool_node.inject_tool_args(call, state, self.store) # type: ignore[arg-type]
+ for call in last_message.tool_calls
+ ]
+ return [Send("tools", [tool_call]) for tool_call in tool_calls]
+ return None
+
+ return should_continue
+
+ def create_post_model_hook_router(
+ self,
+ ) -> Callable[[StateT], Union[str, list[Send]]]:
+ """Create a routing function for post_model_hook node conditional edges."""
+
+ def post_model_hook_router(state: StateT) -> Union[str, list[Send]]:
+ messages = _get_state_value(state, "messages")
+
+ # Check if the last message is a ToolMessage from a structured tool.
+ # This condition exists to support structured output via tools.
+ # Once a tool has been called for structured output, we skip
+ # tool execution and go to END (if there is no post_model_hook).
+ last_message = messages[-1]
+ if (
+ isinstance(last_message, ToolMessage)
+ and last_message.name in self.structured_output_tools
+ ):
+ return END
+
+ tool_messages = [m.tool_call_id for m in messages if isinstance(m, ToolMessage)]
+ last_ai_message = next(m for m in reversed(messages) if isinstance(m, AIMessage))
+ pending_tool_calls = [
+ c for c in last_ai_message.tool_calls if c["id"] not in tool_messages
+ ]
+
+ if pending_tool_calls:
+ pending_tool_calls = [
+ self._tool_node.inject_tool_args(call, state, self.store) # type: ignore[arg-type]
+ for call in pending_tool_calls
+ ]
+ return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]
+ if isinstance(messages[-1], ToolMessage):
+ return self._get_entry_point()
+ return END
+
+ return post_model_hook_router
+
+ def create_tools_router(self) -> Callable[[StateT], str] | None:
+ """Create a routing function for tools node conditional edges."""
+ if not self._should_return_direct:
+ return None
+
+ def route_tool_responses(state: StateT) -> str:
+ messages = _get_state_value(state, "messages")
+ for m in reversed(messages):
+ if not isinstance(m, ToolMessage):
+ break
+ if m.name in self._should_return_direct:
+ return END
+
+ if (
+ isinstance(m, AIMessage)
+ and m.tool_calls
+ and any(call["name"] in self._should_return_direct for call in m.tool_calls)
+ ):
+ return END
+
+ return self._get_entry_point()
+
+ return route_tool_responses
+
+ def _get_entry_point(self) -> str:
+ """Get the workflow entry point."""
+ return "pre_model_hook" if self.pre_model_hook else "agent"
+
+ def _get_model_paths(self) -> list[str]:
+ """Get possible edge destinations from model node."""
+ paths = []
+ if self._tool_calling_enabled:
+ paths.append("tools")
+ if self.post_model_hook:
+ paths.append("post_model_hook")
+ else:
+ paths.append(END)
+
+ return paths
+
+ def _get_post_model_hook_paths(self) -> list[str]:
+ """Get possible edge destinations from post_model_hook node."""
+ paths = []
+ if self._tool_calling_enabled:
+ paths = [self._get_entry_point(), "tools"]
+ paths.append(END)
+ return paths
+
+ def build(self) -> StateGraph[StateT, ContextT]:
+ """Build the agent workflow graph (uncompiled)."""
+ workflow = StateGraph(
+ state_schema=self._final_state_schema,
+ context_schema=self.context_schema,
+ )
+
+ # Set entry point
+ workflow.set_entry_point(self._get_entry_point())
+
+ # Add nodes
+ workflow.add_node("agent", self.create_model_node(), input_schema=self._get_input_schema())
+
+ if self._tool_calling_enabled:
+ workflow.add_node("tools", self._tool_node)
+
+ if self.pre_model_hook:
+ workflow.add_node("pre_model_hook", self.pre_model_hook) # type: ignore[arg-type]
+
+ if self.post_model_hook:
+ workflow.add_node("post_model_hook", self.post_model_hook) # type: ignore[arg-type]
+
+ # Add edges
+ if self.pre_model_hook:
+ workflow.add_edge("pre_model_hook", "agent")
+
+ if self.post_model_hook:
+ workflow.add_edge("agent", "post_model_hook")
+ post_hook_paths = self._get_post_model_hook_paths()
+ if len(post_hook_paths) == 1:
+ # No need for a conditional edge if there's only one path
+ workflow.add_edge("post_model_hook", post_hook_paths[0])
+ else:
+ workflow.add_conditional_edges(
+ "post_model_hook",
+ self.create_post_model_hook_router(),
+ path_map=post_hook_paths,
+ )
+ else:
+ model_paths = self._get_model_paths()
+ if len(model_paths) == 1:
+ # No need for a conditional edge if there's only one path
+ workflow.add_edge("agent", model_paths[0])
+ else:
+ workflow.add_conditional_edges(
+ "agent",
+ self.create_model_router(),
+ path_map=model_paths,
+ )
+
+ if self._tool_calling_enabled:
+ # In some cases, tools can return directly. In these cases
+ # we add a conditional edge from the tools node to the END node
+ # instead of going to the entry point.
+ tools_router = self.create_tools_router()
+ if tools_router:
+ workflow.add_conditional_edges(
+ "tools",
+ tools_router,
+ path_map=[self._get_entry_point(), END],
+ )
+ else:
+ workflow.add_edge("tools", self._get_entry_point())
+
+ return workflow
+
+
+def _supports_native_structured_output(
+ model: Union[str, BaseChatModel, SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel]],
+) -> bool:
+ """Check if a model supports native structured output.
+
+ TODO: replace with more robust model profiles.
+ """
+ model_name: str | None = None
+ if isinstance(model, str):
+ model_name = model
+ elif isinstance(model, BaseChatModel):
+ model_name = getattr(model, "model_name", None)
+
+ return (
+ "grok" in model_name.lower()
+ or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
+ if model_name
+ else False
+ )
+
+
+def create_react_agent( # noqa: D417
+ model: Union[
+ str,
+ BaseChatModel,
+ SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel],
+ ],
+ tools: Union[Sequence[Union[BaseTool, Callable, dict[str, Any]]], ToolNode],
+ *,
+ prompt: Prompt | None = None,
+ response_format: Union[
+ ToolStrategy[StructuredResponseT],
+ ProviderStrategy[StructuredResponseT],
+ type[StructuredResponseT],
+ ]
+ | None = None,
+ pre_model_hook: RunnableLike | None = None,
+ post_model_hook: RunnableLike | None = None,
+ state_schema: type[StateT] | None = None,
+ context_schema: type[ContextT] | None = None,
+ checkpointer: Checkpointer | None = None,
+ store: BaseStore | None = None,
+ interrupt_before: list[str] | None = None,
+ interrupt_after: list[str] | None = None,
+ debug: bool = False,
+ version: Literal["v1", "v2"] = "v2",
+ name: str | None = None,
+ **deprecated_kwargs: Any,
+) -> CompiledStateGraph[StateT, ContextT]:
+ """Creates an agent graph that calls tools in a loop until a stopping condition is met.
+
+ For more details on using `create_react_agent`, visit [Agents](https://langchain-ai.github.io/langgraph/agents/overview/) documentation.
+
+ Args:
+ model: The language model for the agent. Supports static and dynamic
+ model selection.
+
+ - **Static model**: A chat model instance (e.g., `ChatOpenAI()`) or
+ string identifier (e.g., `"openai:gpt-4"`)
+ - **Dynamic model**: A callable with signature
+ `(state, runtime) -> BaseChatModel` that returns different models
+ based on runtime context
+ If the model has tools bound via `.bind_tools()` or other configurations,
+ the return type should be a Runnable[LanguageModelInput, BaseMessage]
+ Coroutines are also supported, allowing for asynchronous model selection.
+
+ Dynamic functions receive graph state and runtime, enabling
+ context-dependent model selection. Must return a `BaseChatModel`
+ instance. For tool calling, bind tools using `.bind_tools()`.
+ Bound tools must be a subset of the `tools` parameter.
+
+ Dynamic model example:
+ ```python
+ from dataclasses import dataclass
+
+ @dataclass
+ class ModelContext:
+ model_name: str = "gpt-3.5-turbo"
+
+ # Instantiate models globally
+ gpt4_model = ChatOpenAI(model="gpt-4")
+ gpt35_model = ChatOpenAI(model="gpt-3.5-turbo")
+
+ def select_model(state: AgentState, runtime: Runtime[ModelContext]) -> ChatOpenAI:
+ model_name = runtime.context.model_name
+ model = gpt4_model if model_name == "gpt-4" else gpt35_model
+ return model.bind_tools(tools)
+ ```
+
+ !!! note "Dynamic Model Requirements"
+ Ensure returned models have appropriate tools bound via
+ `.bind_tools()` and support required functionality. Bound tools
+ must be a subset of those specified in the `tools` parameter.
+
+ tools: A list of tools or a ToolNode instance.
+ If an empty list is provided, the agent will consist of a single LLM node without tool calling.
+ prompt: An optional prompt for the LLM. Can take a few different forms:
+
+ - str: This is converted to a SystemMessage and added to the beginning of the list of messages in state["messages"].
+ - SystemMessage: this is added to the beginning of the list of messages in state["messages"].
+ - Callable: This function should take in full graph state and the output is then passed to the language model.
+ - Runnable: This runnable should take in full graph state and the output is then passed to the language model.
+
+ response_format: An optional UsingToolStrategy configuration for structured responses.
+
+ If provided, the agent will handle structured output via tool calls during the normal conversation flow.
+ When the model calls a structured output tool, the response will be captured and returned in the 'structured_response' state key.
+ If not provided, `structured_response` will not be present in the output state.
+
+ The UsingToolStrategy should contain:
+ - schemas: A sequence of ResponseSchema objects that define the structured output format
+ - tool_choice: Either "required" or "auto" to control when structured output is used
+
+ Each ResponseSchema contains:
+ - schema: A Pydantic model that defines the structure
+ - name: Optional custom name for the tool (defaults to model name)
+ - description: Optional custom description (defaults to model docstring)
+ - strict: Whether to enforce strict validation
+
+ !!! Important
+ `response_format` requires the model to support tool calling
+
+ !!! Note
+ Structured responses are handled directly in the model call node via tool calls, eliminating the need for separate structured response nodes.
+
+ pre_model_hook: An optional node to add before the `agent` node (i.e., the node that calls the LLM).
+ Useful for managing long message histories (e.g., message trimming, summarization, etc.).
+ Pre-model hook must be a callable or a runnable that takes in current graph state and returns a state update in the form of
+ ```python
+ # At least one of `messages` or `llm_input_messages` MUST be provided
+ {
+ # If provided, will UPDATE the `messages` in the state
+ "messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), ...],
+ # If provided, will be used as the input to the LLM,
+ # and will NOT UPDATE `messages` in the state
+ "llm_input_messages": [...],
+ # Any other state keys that need to be propagated
+ ...
+ }
+ ```
+
+ !!! Important
+ At least one of `messages` or `llm_input_messages` MUST be provided and will be used as an input to the `agent` node.
+ The rest of the keys will be added to the graph state.
+
+ !!! Warning
+ If you are returning `messages` in the pre-model hook, you should OVERWRITE the `messages` key by doing the following:
+
+ ```python
+ {
+ "messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), *new_messages]
+ ...
+ }
+ ```
+ post_model_hook: An optional node to add after the `agent` node (i.e., the node that calls the LLM).
+ Useful for implementing human-in-the-loop, guardrails, validation, or other post-processing.
+ Post-model hook must be a callable or a runnable that takes in current graph state and returns a state update.
+
+ !!! Note
+ Only available with `version="v2"`.
+ state_schema: An optional state schema that defines graph state.
+ Must have `messages` and `remaining_steps` keys.
+ Defaults to `AgentState` that defines those two keys.
+ context_schema: An optional schema for runtime context.
+ checkpointer: An optional checkpoint saver object. This is used for persisting
+ the state of the graph (e.g., as chat memory) for a single thread (e.g., a single conversation).
+ store: An optional store object. This is used for persisting data
+ across multiple threads (e.g., multiple conversations / users).
+ interrupt_before: An optional list of node names to interrupt before.
+ Should be one of the following: "agent", "tools".
+ This is useful if you want to add a user confirmation or other interrupt before taking an action.
+ interrupt_after: An optional list of node names to interrupt after.
+ Should be one of the following: "agent", "tools".
+ This is useful if you want to return directly or run additional processing on an output.
+ debug: A flag indicating whether to enable debug mode.
+ version: Determines the version of the graph to create.
+ Can be one of:
+
+ - `"v1"`: The tool node processes a single message. All tool
+ calls in the message are executed in parallel within the tool node.
+ - `"v2"`: The tool node processes a tool call.
+ Tool calls are distributed across multiple instances of the tool
+ node using the [Send](https://langchain-ai.github.io/langgraph/concepts/low_level/#send)
+ API.
+ name: An optional name for the CompiledStateGraph.
+ This name will be automatically used when adding ReAct agent graph to another graph as a subgraph node -
+ particularly useful for building multi-agent systems.
+
+ !!! warning "`config_schema` Deprecated"
+ The `config_schema` parameter is deprecated in v0.6.0 and support will be removed in v2.0.0.
+ Please use `context_schema` instead to specify the schema for run-scoped context.
+
+
+ Returns:
+ A compiled LangChain runnable that can be used for chat interactions.
+
+ The "agent" node calls the language model with the messages list (after applying the prompt).
+ If the resulting AIMessage contains `tool_calls`, the graph will then call the ["tools"][langgraph.prebuilt.tool_node.ToolNode].
+ The "tools" node executes the tools (1 tool per `tool_call`) and adds the responses to the messages list
+ as `ToolMessage` objects. The agent node then calls the language model again.
+ The process repeats until no more `tool_calls` are present in the response.
+ The agent then returns the full list of messages as a dictionary containing the key "messages".
+
+ ``` mermaid
+ sequenceDiagram
+ participant U as User
+ participant A as LLM
+ participant T as Tools
+ U->>A: Initial input
+ Note over A: Prompt + LLM
+ loop while tool_calls present
+ A->>T: Execute tools
+ T-->>A: ToolMessage for each tool_calls
+ end
+ A->>U: Return final state
+ ```
+
+ Example:
+ ```python
+ from langchain.agents import create_react_agent
+
+ def check_weather(location: str) -> str:
+ '''Return the weather forecast for the specified location.'''
+ return f"It's always sunny in {location}"
+
+ graph = create_react_agent(
+ "anthropic:claude-3-7-sonnet-latest",
+ tools=[check_weather],
+ prompt="You are a helpful assistant",
+ )
+ inputs = {"messages": [{"role": "user", "content": "what is the weather in sf"}]}
+ for chunk in graph.stream(inputs, stream_mode="updates"):
+ print(chunk)
+ ```
+ """
+ # Handle deprecated config_schema parameter
+ if (config_schema := deprecated_kwargs.pop("config_schema", MISSING)) is not MISSING:
+ warn(
+ "`config_schema` is deprecated and will be removed. Please use `context_schema` instead.",
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
+ if context_schema is None:
+ context_schema = config_schema
+
+ if len(deprecated_kwargs) > 0:
+ msg = f"create_react_agent() got unexpected keyword arguments: {deprecated_kwargs}"
+ raise TypeError(msg)
+
+ if response_format and not isinstance(response_format, (ToolStrategy, ProviderStrategy)):
+ if _supports_native_structured_output(model):
+ response_format = ProviderStrategy(
+ schema=response_format,
+ )
+ else:
+ response_format = ToolStrategy(
+ schema=response_format,
+ )
+ elif isinstance(response_format, tuple) and len(response_format) == 2:
+ msg = "Passing a 2-tuple as response_format is no longer supported. "
+ raise ValueError(msg)
+
+ # Create and configure the agent builder
+ builder = _AgentBuilder(
+ model=model,
+ tools=tools,
+ prompt=prompt,
+ response_format=cast("Union[ResponseFormat[StructuredResponseT], None]", response_format),
+ pre_model_hook=pre_model_hook,
+ post_model_hook=post_model_hook,
+ state_schema=state_schema,
+ context_schema=context_schema,
+ version=version,
+ name=name,
+ store=store,
+ )
+
+ # Build and compile the workflow
+ workflow = builder.build()
+ return workflow.compile( # type: ignore[return-value]
+ checkpointer=checkpointer,
+ store=store,
+ interrupt_before=interrupt_before,
+ interrupt_after=interrupt_after,
+ debug=debug,
+ name=name,
+ )
+
+
+__all__ = [
+ "AgentState",
+ "AgentStatePydantic",
+ "AgentStateWithStructuredResponse",
+ "AgentStateWithStructuredResponsePydantic",
+ "create_react_agent",
+]
diff --git a/libs/langchain_v1/langchain/agents/structured_output.py b/libs/langchain_v1/langchain/agents/structured_output.py
new file mode 100644
index 00000000000..dbbd3eef609
--- /dev/null
+++ b/libs/langchain_v1/langchain/agents/structured_output.py
@@ -0,0 +1,403 @@
+"""Types for setting agent response formats."""
+
+from __future__ import annotations
+
+import uuid
+from dataclasses import dataclass, is_dataclass
+from types import UnionType
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Generic,
+ Literal,
+ TypeVar,
+ Union,
+ get_args,
+ get_origin,
+)
+
+from langchain_core.tools import BaseTool, StructuredTool
+from pydantic import BaseModel, TypeAdapter
+from typing_extensions import Self, is_typeddict
+
+if TYPE_CHECKING:
+ from collections.abc import Callable, Iterable
+
+ from langchain_core.messages import AIMessage
+
+# Supported schema types: Pydantic models, dataclasses, TypedDict, JSON schema dicts
+SchemaT = TypeVar("SchemaT")
+
+SchemaKind = Literal["pydantic", "dataclass", "typeddict", "json_schema"]
+
+
+class StructuredOutputError(Exception):
+ """Base class for structured output errors."""
+
+
+class MultipleStructuredOutputsError(StructuredOutputError):
+ """Raised when model returns multiple structured output tool calls when only one is expected."""
+
+ def __init__(self, tool_names: list[str]) -> None:
+ """Initialize MultipleStructuredOutputsError.
+
+ Args:
+ tool_names: The names of the tools called for structured output.
+ """
+ self.tool_names = tool_names
+
+ super().__init__(
+ f"Model incorrectly returned multiple structured responses ({', '.join(tool_names)}) when only one is expected."
+ )
+
+
+class StructuredOutputValidationError(StructuredOutputError):
+ """Raised when structured output tool call arguments fail to parse according to the schema."""
+
+ def __init__(self, tool_name: str, source: Exception) -> None:
+ """Initialize StructuredOutputValidationError.
+
+ Args:
+ tool_name: The name of the tool that failed.
+ source: The exception that occurred.
+ """
+ self.tool_name = tool_name
+ self.source = source
+ super().__init__(f"Failed to parse structured output for tool '{tool_name}': {source}.")
+
+
+def _parse_with_schema(
+ schema: Union[type[SchemaT], dict], schema_kind: SchemaKind, data: dict[str, Any]
+) -> Any:
+ """Parse data using for any supported schema type.
+
+ Args:
+ schema: The schema type (Pydantic model, dataclass, or TypedDict)
+ schema_kind: One of "pydantic", "dataclass", "typeddict", or "json_schema"
+ data: The data to parse
+
+ Returns:
+ The parsed instance according to the schema type
+
+ Raises:
+ ValueError: If parsing fails
+ """
+ if schema_kind == "json_schema":
+ return data
+ try:
+ adapter: TypeAdapter[SchemaT] = TypeAdapter(schema)
+ return adapter.validate_python(data)
+ except Exception as e:
+ schema_name = getattr(schema, "__name__", str(schema))
+ msg = f"Failed to parse data to {schema_name}: {e}"
+ raise ValueError(msg) from e
+
+
+@dataclass(init=False)
+class _SchemaSpec(Generic[SchemaT]):
+ """Describes a structured output schema."""
+
+ schema: type[SchemaT]
+ """The schema for the response, can be a Pydantic model, dataclass, TypedDict, or JSON schema dict."""
+
+ name: str
+ """Name of the schema, used for tool calling.
+
+ If not provided, the name will be the model name or "response_format" if it's a JSON schema.
+ """
+
+ description: str
+ """Custom description of the schema.
+
+ If not provided, provided will use the model's docstring.
+ """
+
+ schema_kind: SchemaKind
+ """The kind of schema."""
+
+ json_schema: dict[str, Any]
+ """JSON schema associated with the schema."""
+
+ strict: bool = False
+ """Whether to enforce strict validation of the schema."""
+
+ def __init__(
+ self,
+ schema: type[SchemaT],
+ *,
+ name: str | None = None,
+ description: str | None = None,
+ strict: bool = False,
+ ) -> None:
+ """Initialize SchemaSpec with schema and optional parameters."""
+ self.schema = schema
+
+ if name:
+ self.name = name
+ elif isinstance(schema, dict):
+ self.name = str(schema.get("title", f"response_format_{str(uuid.uuid4())[:4]}"))
+ else:
+ self.name = str(getattr(schema, "__name__", f"response_format_{str(uuid.uuid4())[:4]}"))
+
+ self.description = description or (
+ schema.get("description", "")
+ if isinstance(schema, dict)
+ else getattr(schema, "__doc__", None) or ""
+ )
+
+ self.strict = strict
+
+ if isinstance(schema, dict):
+ self.schema_kind = "json_schema"
+ self.json_schema = schema
+ elif isinstance(schema, type) and issubclass(schema, BaseModel):
+ self.schema_kind = "pydantic"
+ self.json_schema = schema.model_json_schema()
+ elif is_dataclass(schema):
+ self.schema_kind = "dataclass"
+ self.json_schema = TypeAdapter(schema).json_schema()
+ elif is_typeddict(schema):
+ self.schema_kind = "typeddict"
+ self.json_schema = TypeAdapter(schema).json_schema()
+ else:
+ msg = (
+ f"Unsupported schema type: {type(schema)}. "
+ f"Supported types: Pydantic models, dataclasses, TypedDicts, and JSON schema dicts."
+ )
+ raise ValueError(msg)
+
+
+@dataclass(init=False)
+class ToolStrategy(Generic[SchemaT]):
+ """Use a tool calling strategy for model responses."""
+
+ schema: type[SchemaT]
+ """Schema for the tool calls."""
+
+ schema_specs: list[_SchemaSpec[SchemaT]]
+ """Schema specs for the tool calls."""
+
+ tool_message_content: str | None
+ """The content of the tool message to be returned when the model calls an artificial structured output tool."""
+
+ handle_errors: Union[
+ bool,
+ str,
+ type[Exception],
+ tuple[type[Exception], ...],
+ Callable[[Exception], str],
+ ]
+ """Error handling strategy for structured output via ToolStrategy. Default is True.
+
+ - True: Catch all errors with default error template
+ - str: Catch all errors with this custom message
+ - type[Exception]: Only catch this exception type with default message
+ - tuple[type[Exception], ...]: Only catch these exception types with default message
+ - Callable[[Exception], str]: Custom function that returns error message
+ - False: No retry, let exceptions propagate
+ """
+
+ def __init__(
+ self,
+ schema: type[SchemaT],
+ *,
+ tool_message_content: str | None = None,
+ handle_errors: Union[
+ bool,
+ str,
+ type[Exception],
+ tuple[type[Exception], ...],
+ Callable[[Exception], str],
+ ] = True,
+ ) -> None:
+ """Initialize ToolStrategy with schemas, tool message content, and error handling strategy."""
+ self.schema = schema
+ self.tool_message_content = tool_message_content
+ self.handle_errors = handle_errors
+
+ def _iter_variants(schema: Any) -> Iterable[Any]:
+ """Yield leaf variants from Union and JSON Schema oneOf."""
+ if get_origin(schema) in (UnionType, Union):
+ for arg in get_args(schema):
+ yield from _iter_variants(arg)
+ return
+
+ if isinstance(schema, dict) and "oneOf" in schema:
+ for sub in schema.get("oneOf", []):
+ yield from _iter_variants(sub)
+ return
+
+ yield schema
+
+ self.schema_specs = [_SchemaSpec(s) for s in _iter_variants(schema)]
+
+
+@dataclass(init=False)
+class ProviderStrategy(Generic[SchemaT]):
+ """Use the model provider's native structured output method."""
+
+ schema: type[SchemaT]
+ """Schema for native mode."""
+
+ schema_spec: _SchemaSpec[SchemaT]
+ """Schema spec for native mode."""
+
+ def __init__(
+ self,
+ schema: type[SchemaT],
+ ) -> None:
+ """Initialize ProviderStrategy with schema."""
+ self.schema = schema
+ self.schema_spec = _SchemaSpec(schema)
+
+ def to_model_kwargs(self) -> dict[str, Any]:
+ """Convert to kwargs to bind to a model to force structured output."""
+ # OpenAI:
+ # - see https://platform.openai.com/docs/guides/structured-outputs
+ response_format = {
+ "type": "json_schema",
+ "json_schema": {
+ "name": self.schema_spec.name,
+ "schema": self.schema_spec.json_schema,
+ },
+ }
+ return {"response_format": response_format}
+
+
+@dataclass
+class OutputToolBinding(Generic[SchemaT]):
+ """Information for tracking structured output tool metadata.
+
+ This contains all necessary information to handle structured responses
+ generated via tool calls, including the original schema, its type classification,
+ and the corresponding tool implementation used by the tools strategy.
+ """
+
+ schema: type[SchemaT]
+ """The original schema provided for structured output (Pydantic model, dataclass, TypedDict, or JSON schema dict)."""
+
+ schema_kind: SchemaKind
+ """Classification of the schema type for proper response construction."""
+
+ tool: BaseTool
+ """LangChain tool instance created from the schema for model binding."""
+
+ @classmethod
+ def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self:
+ """Create an OutputToolBinding instance from a SchemaSpec.
+
+ Args:
+ schema_spec: The SchemaSpec to convert
+
+ Returns:
+ An OutputToolBinding instance with the appropriate tool created
+ """
+ return cls(
+ schema=schema_spec.schema,
+ schema_kind=schema_spec.schema_kind,
+ tool=StructuredTool(
+ args_schema=schema_spec.json_schema,
+ name=schema_spec.name,
+ description=schema_spec.description,
+ ),
+ )
+
+ def parse(self, tool_args: dict[str, Any]) -> SchemaT:
+ """Parse tool arguments according to the schema.
+
+ Args:
+ tool_args: The arguments from the tool call
+
+ Returns:
+ The parsed response according to the schema type
+
+ Raises:
+ ValueError: If parsing fails
+ """
+ return _parse_with_schema(self.schema, self.schema_kind, tool_args)
+
+
+@dataclass
+class ProviderStrategyBinding(Generic[SchemaT]):
+ """Information for tracking native structured output metadata.
+
+ This contains all necessary information to handle structured responses
+ generated via native provider output, including the original schema,
+ its type classification, and parsing logic for provider-enforced JSON.
+ """
+
+ schema: type[SchemaT]
+ """The original schema provided for structured output (Pydantic model, dataclass, TypedDict, or JSON schema dict)."""
+
+ schema_kind: SchemaKind
+ """Classification of the schema type for proper response construction."""
+
+ @classmethod
+ def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self:
+ """Create a ProviderStrategyBinding instance from a SchemaSpec.
+
+ Args:
+ schema_spec: The SchemaSpec to convert
+
+ Returns:
+ A ProviderStrategyBinding instance for parsing native structured output
+ """
+ return cls(
+ schema=schema_spec.schema,
+ schema_kind=schema_spec.schema_kind,
+ )
+
+ def parse(self, response: AIMessage) -> SchemaT:
+ """Parse AIMessage content according to the schema.
+
+ Args:
+ response: The AI message containing the structured output
+
+ Returns:
+ The parsed response according to the schema
+
+ Raises:
+ ValueError: If text extraction, JSON parsing or schema validation fails
+ """
+ # Extract text content from AIMessage and parse as JSON
+ raw_text = self._extract_text_content_from_message(response)
+
+ import json
+
+ try:
+ data = json.loads(raw_text)
+ except Exception as e:
+ schema_name = getattr(self.schema, "__name__", "response_format")
+ msg = f"Native structured output expected valid JSON for {schema_name}, but parsing failed: {e}."
+ raise ValueError(msg) from e
+
+ # Parse according to schema
+ return _parse_with_schema(self.schema, self.schema_kind, data)
+
+ def _extract_text_content_from_message(self, message: AIMessage) -> str:
+ """Extract text content from an AIMessage.
+
+ Args:
+ message: The AI message to extract text from
+
+ Returns:
+ The extracted text content
+ """
+ content = message.content
+ if isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ parts: list[str] = []
+ for c in content:
+ if isinstance(c, dict):
+ if c.get("type") == "text" and "text" in c:
+ parts.append(str(c["text"]))
+ elif "content" in c and isinstance(c["content"], str):
+ parts.append(c["content"])
+ else:
+ parts.append(str(c))
+ return "".join(parts)
+ return str(content)
+
+
+ResponseFormat = Union[ToolStrategy[SchemaT], ProviderStrategy[SchemaT]]
diff --git a/libs/langchain_v1/langchain/agents/tool_node.py b/libs/langchain_v1/langchain/agents/tool_node.py
new file mode 100644
index 00000000000..185a74a291d
--- /dev/null
+++ b/libs/langchain_v1/langchain/agents/tool_node.py
@@ -0,0 +1,1174 @@
+"""Tool execution node for LangGraph workflows.
+
+This module provides prebuilt functionality for executing tools in LangGraph.
+
+Tools are functions that models can call to interact with external systems,
+APIs, databases, or perform computations.
+
+The module implements several key design patterns:
+- Parallel execution of multiple tool calls for efficiency
+- Robust error handling with customizable error messages
+- State injection for tools that need access to graph state
+- Store injection for tools that need persistent storage
+- Command-based state updates for advanced control flow
+
+Key Components:
+ ToolNode: Main class for executing tools in LangGraph workflows
+ InjectedState: Annotation for injecting graph state into tools
+ InjectedStore: Annotation for injecting persistent store into tools
+ tools_condition: Utility function for conditional routing based on tool calls
+
+Typical Usage:
+ ```python
+ from langchain_core.tools import tool
+ from langchain.agents import ToolNode
+
+ @tool
+ def my_tool(x: int) -> str:
+ return f"Result: {x}"
+
+ tool_node = ToolNode([my_tool])
+ ```
+"""
+
+from __future__ import annotations
+
+import asyncio
+import inspect
+import json
+from copy import copy, deepcopy
+from dataclasses import replace
+from typing import (
+ TYPE_CHECKING,
+ Annotated,
+ Any,
+ Literal,
+ Optional,
+ Union,
+ cast,
+ get_args,
+ get_origin,
+ get_type_hints,
+)
+
+from langchain_core.messages import (
+ AIMessage,
+ AnyMessage,
+ RemoveMessage,
+ ToolCall,
+ ToolMessage,
+ convert_to_messages,
+)
+from langchain_core.runnables.config import (
+ get_config_list,
+ get_executor_for_config,
+)
+from langchain_core.tools import BaseTool, InjectedToolArg
+from langchain_core.tools import tool as create_tool
+from langchain_core.tools.base import (
+ TOOL_MESSAGE_BLOCK_TYPES,
+ get_all_basemodel_annotations,
+)
+from langgraph._internal._runnable import RunnableCallable
+from langgraph.errors import GraphBubbleUp
+from langgraph.graph.message import REMOVE_ALL_MESSAGES
+from langgraph.types import Command, Send
+from pydantic import BaseModel, ValidationError
+
+if TYPE_CHECKING:
+ from collections.abc import Callable, Sequence
+
+ from langchain_core.runnables import RunnableConfig
+ from langgraph.store.base import BaseStore
+
+INVALID_TOOL_NAME_ERROR_TEMPLATE = (
+ "Error: {requested_tool} is not a valid tool, try one of [{available_tools}]."
+)
+TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
+TOOL_EXECUTION_ERROR_TEMPLATE = "Error executing tool '{tool_name}' with kwargs {tool_kwargs} with error:\n {error}\n Please fix the error and try again."
+TOOL_INVOCATION_ERROR_TEMPLATE = "Error invoking tool '{tool_name}' with kwargs {tool_kwargs} with error:\n {error}\n Please fix the error and try again."
+
+
+def msg_content_output(output: Any) -> Union[str, list[dict]]:
+ """Convert tool output to valid message content format.
+
+ LangChain ToolMessages accept either string content or a list of content blocks.
+ This function ensures tool outputs are properly formatted for message consumption
+ by attempting to preserve structured data when possible, falling back to JSON
+ serialization or string conversion.
+
+ Args:
+ output: The raw output from a tool execution. Can be any type.
+
+ Returns:
+ Either a string representation of the output or a list of content blocks
+ if the output is already in the correct format for structured content.
+
+ Note:
+ This function prioritizes backward compatibility by defaulting to JSON
+ serialization rather than supporting all possible message content formats.
+ """
+ if isinstance(output, str) or (
+ isinstance(output, list)
+ and all(isinstance(x, dict) and x.get("type") in TOOL_MESSAGE_BLOCK_TYPES for x in output)
+ ):
+ return output
+ # Technically a list of strings is also valid message content, but it's
+ # not currently well tested that all chat models support this.
+ # And for backwards compatibility we want to make sure we don't break
+ # any existing ToolNode usage.
+ try:
+ return json.dumps(output, ensure_ascii=False)
+ except Exception: # noqa: BLE001
+ return str(output)
+
+
+class ToolInvocationError(Exception):
+ """Exception raised when a tool invocation fails due to invalid arguments."""
+
+ def __init__(
+ self, tool_name: str, source: ValidationError, tool_kwargs: dict[str, Any]
+ ) -> None:
+ """Initialize the ToolInvocationError.
+
+ Args:
+ tool_name: The name of the tool that failed.
+ source: The exception that occurred.
+ tool_kwargs: The keyword arguments that were passed to the tool.
+ """
+ self.message = TOOL_INVOCATION_ERROR_TEMPLATE.format(
+ tool_name=tool_name, tool_kwargs=tool_kwargs, error=source
+ )
+ self.tool_name = tool_name
+ self.tool_kwargs = tool_kwargs
+ self.source = source
+ super().__init__(self.message)
+
+
+def _default_handle_tool_errors(e: Exception) -> str:
+ """Default error handler for tool errors.
+
+ If the tool is a tool invocation error, return its message.
+ Otherwise, raise the error.
+ """
+ if isinstance(e, ToolInvocationError):
+ return e.message
+ raise e
+
+
+def _handle_tool_error(
+ e: Exception,
+ *,
+ flag: Union[
+ bool,
+ str,
+ Callable[..., str],
+ type[Exception],
+ tuple[type[Exception], ...],
+ ],
+) -> str:
+ """Generate error message content based on exception handling configuration.
+
+ This function centralizes error message generation logic, supporting different
+ error handling strategies configured via the ToolNode's handle_tool_errors
+ parameter.
+
+ Args:
+ e: The exception that occurred during tool execution.
+ flag: Configuration for how to handle the error. Can be:
+ - bool: If True, use default error template
+ - str: Use this string as the error message
+ - Callable: Call this function with the exception to get error message
+ - tuple: Not used in this context (handled by caller)
+
+ Returns:
+ A string containing the error message to include in the ToolMessage.
+
+ Raises:
+ ValueError: If flag is not one of the supported types.
+
+ Note:
+ The tuple case is handled by the caller through exception type checking,
+ not by this function directly.
+ """
+ if isinstance(flag, (bool, tuple)) or (isinstance(flag, type) and issubclass(flag, Exception)):
+ content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))
+ elif isinstance(flag, str):
+ content = flag
+ elif callable(flag):
+ content = flag(e) # type: ignore [assignment, call-arg]
+ else:
+ msg = (
+ f"Got unexpected type of `handle_tool_error`. Expected bool, str "
+ f"or callable. Received: {flag}"
+ )
+ raise ValueError(msg)
+ return content
+
+
+def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception], ...]:
+ """Infer exception types handled by a custom error handler function.
+
+ This function analyzes the type annotations of a custom error handler to determine
+ which exception types it's designed to handle. This enables type-safe error handling
+ where only specific exceptions are caught and processed by the handler.
+
+ Args:
+ handler: A callable that takes an exception and returns an error message string.
+ The first parameter (after self/cls if present) should be type-annotated
+ with the exception type(s) to handle.
+
+ Returns:
+ A tuple of exception types that the handler can process. Returns (Exception,)
+ if no specific type information is available for backward compatibility.
+
+ Raises:
+ ValueError: If the handler's annotation contains non-Exception types or
+ if Union types contain non-Exception types.
+
+ Note:
+ This function supports both single exception types and Union types for
+ handlers that need to handle multiple exception types differently.
+ """
+ sig = inspect.signature(handler)
+ params = list(sig.parameters.values())
+ if params:
+ # If it's a method, the first argument is typically 'self' or 'cls'
+ if params[0].name in ["self", "cls"] and len(params) == 2:
+ first_param = params[1]
+ else:
+ first_param = params[0]
+
+ type_hints = get_type_hints(handler)
+ if first_param.name in type_hints:
+ origin = get_origin(first_param.annotation)
+ if origin is Union:
+ args = get_args(first_param.annotation)
+ if all(issubclass(arg, Exception) for arg in args):
+ return tuple(args)
+ msg = (
+ "All types in the error handler error annotation must be "
+ "Exception types. For example, "
+ "`def custom_handler(e: Union[ValueError, TypeError])`. "
+ f"Got '{first_param.annotation}' instead."
+ )
+ raise ValueError(msg)
+
+ exception_type = type_hints[first_param.name]
+ if Exception in exception_type.__mro__:
+ return (exception_type,)
+ msg = (
+ f"Arbitrary types are not supported in the error handler "
+ f"signature. Please annotate the error with either a "
+ f"specific Exception type or a union of Exception types. "
+ "For example, `def custom_handler(e: ValueError)` or "
+ "`def custom_handler(e: Union[ValueError, TypeError])`. "
+ f"Got '{exception_type}' instead."
+ )
+ raise ValueError(msg)
+
+ # If no type information is available, return (Exception,)
+ # for backwards compatibility.
+ return (Exception,)
+
+
+class ToolNode(RunnableCallable):
+ """A node for executing tools in LangGraph workflows.
+
+ Handles tool execution patterns including function calls, state injection,
+ persistent storage, and control flow. Manages parallel execution,
+ error handling.
+
+ Input Formats:
+ 1. Graph state with `messages` key that has a list of messages:
+ - Common representation for agentic workflows
+ - Supports custom messages key via ``messages_key`` parameter
+
+ 2. **Message List**: ``[AIMessage(..., tool_calls=[...])]``
+ - List of messages with tool calls in the last AIMessage
+
+ 3. **Direct Tool Calls**: ``[{"name": "tool", "args": {...}, "id": "1", "type": "tool_call"}]``
+ - Bypasses message parsing for direct tool execution
+ - For programmatic tool invocation and testing
+
+ Output Formats:
+ Output format depends on input type and tool behavior:
+
+ **For Regular tools**:
+ - Dict input → ``{"messages": [ToolMessage(...)]}``
+ - List input → ``[ToolMessage(...)]``
+
+ **For Command tools**:
+ - Returns ``[Command(...)]`` or mixed list with regular tool outputs
+ - Commands can update state, trigger navigation, or send messages
+
+ Args:
+ tools: A sequence of tools that can be invoked by this node. Supports:
+ - **BaseTool instances**: Tools with schemas and metadata
+ - **Plain functions**: Automatically converted to tools with inferred schemas
+ name: The name identifier for this node in the graph. Used for debugging
+ and visualization. Defaults to "tools".
+ tags: Optional metadata tags to associate with the node for filtering
+ and organization. Defaults to None.
+ handle_tool_errors: Configuration for error handling during tool execution.
+ Supports multiple strategies:
+
+ - **True**: Catch all errors and return a ToolMessage with the default
+ error template containing the exception details.
+ - **str**: Catch all errors and return a ToolMessage with this custom
+ error message string.
+ - **type[Exception]**: Only catch exceptions with the specified type and return the default error message for it.
+ - **tuple[type[Exception], ...]**: Only catch exceptions with the specified
+ types and return default error messages for them.
+ - **Callable[..., str]**: Catch exceptions matching the callable's signature
+ and return the string result of calling it with the exception.
+ - **False**: Disable error handling entirely, allowing exceptions to
+ propagate.
+
+ Defaults to a callable that:
+ - catches tool invocation errors (due to invalid arguments provided by the model) and returns a descriptive error message
+ - ignores tool execution errors (they will be re-raised)
+
+ messages_key: The key in the state dictionary that contains the message list.
+ This same key will be used for the output ToolMessages.
+ Defaults to "messages".
+ Allows custom state schemas with different message field names.
+
+ Examples:
+ Basic usage:
+
+ ```python
+ from langchain.agents import ToolNode
+ from langchain_core.tools import tool
+
+ @tool
+ def calculator(a: int, b: int) -> int:
+ \"\"\"Add two numbers.\"\"\"
+ return a + b
+
+ tool_node = ToolNode([calculator])
+ ```
+
+ State injection:
+
+ ```python
+ from typing_extensions import Annotated
+ from langgraph.agents.tool_node import InjectedState
+
+ @tool
+ def context_tool(query: str, state: Annotated[dict, InjectedState]) -> str:
+ \"\"\"Some tool that uses state.\"\"\"
+ return f"Query: {query}, Messages: {len(state['messages'])}"
+
+ tool_node = ToolNode([context_tool])
+ ```
+
+ Error handling:
+
+ ```python
+ def handle_errors(e: ValueError) -> str:
+ return "Invalid input provided"
+
+ tool_node = ToolNode([my_tool], handle_tool_errors=handle_errors)
+ ```
+ """
+
+ name: str = "tools"
+
+ def __init__(
+ self,
+ tools: Sequence[Union[BaseTool, Callable]],
+ *,
+ name: str = "tools",
+ tags: list[str] | None = None,
+ handle_tool_errors: Union[
+ bool, str, Callable[..., str], type[Exception], tuple[type[Exception], ...]
+ ] = _default_handle_tool_errors,
+ messages_key: str = "messages",
+ ) -> None:
+ """Initialize the ToolNode with the provided tools and configuration.
+
+ Args:
+ tools: Sequence of tools to make available for execution.
+ name: Node name for graph identification.
+ tags: Optional metadata tags.
+ handle_tool_errors: Error handling configuration.
+ messages_key: State key containing messages.
+ """
+ super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False)
+ self._tools_by_name: dict[str, BaseTool] = {}
+ self._tool_to_state_args: dict[str, dict[str, str | None]] = {}
+ self._tool_to_store_arg: dict[str, str | None] = {}
+ self._handle_tool_errors = handle_tool_errors
+ self._messages_key = messages_key
+ for tool in tools:
+ if not isinstance(tool, BaseTool):
+ tool_ = create_tool(cast("type[BaseTool]", tool))
+ else:
+ tool_ = tool
+ self._tools_by_name[tool_.name] = tool_
+ self._tool_to_state_args[tool_.name] = _get_state_args(tool_)
+ self._tool_to_store_arg[tool_.name] = _get_store_arg(tool_)
+
+ @property
+ def tools_by_name(self) -> dict[str, BaseTool]:
+ """Mapping from tool name to BaseTool instance."""
+ return self._tools_by_name
+
+ def _func(
+ self,
+ input: Union[
+ list[AnyMessage],
+ dict[str, Any],
+ BaseModel,
+ ],
+ config: RunnableConfig,
+ *,
+ store: Optional[BaseStore], # noqa: UP045
+ ) -> Any:
+ tool_calls, input_type = self._parse_input(input, store)
+ config_list = get_config_list(config, len(tool_calls))
+ input_types = [input_type] * len(tool_calls)
+ with get_executor_for_config(config) as executor:
+ outputs = [*executor.map(self._run_one, tool_calls, input_types, config_list)]
+
+ return self._combine_tool_outputs(outputs, input_type)
+
+ async def _afunc(
+ self,
+ input: Union[
+ list[AnyMessage],
+ dict[str, Any],
+ BaseModel,
+ ],
+ config: RunnableConfig,
+ *,
+ store: Optional[BaseStore], # noqa: UP045
+ ) -> Any:
+ tool_calls, input_type = self._parse_input(input, store)
+ outputs = await asyncio.gather(
+ *(self._arun_one(call, input_type, config) for call in tool_calls)
+ )
+
+ return self._combine_tool_outputs(outputs, input_type)
+
+ def _combine_tool_outputs(
+ self,
+ outputs: list[Union[ToolMessage, Command]],
+ input_type: Literal["list", "dict", "tool_calls"],
+ ) -> list[Union[Command, list[ToolMessage], dict[str, list[ToolMessage]]]]:
+ # preserve existing behavior for non-command tool outputs for backwards
+ # compatibility
+ if not any(isinstance(output, Command) for output in outputs):
+ # TypedDict, pydantic, dataclass, etc. should all be able to load from dict
+ return outputs if input_type == "list" else {self._messages_key: outputs} # type: ignore[return-value, return-value]
+
+ # LangGraph will automatically handle list of Command and non-command node
+ # updates
+ combined_outputs: list[Command | list[ToolMessage] | dict[str, list[ToolMessage]]] = []
+
+ # combine all parent commands with goto into a single parent command
+ parent_command: Command | None = None
+ for output in outputs:
+ if isinstance(output, Command):
+ if (
+ output.graph is Command.PARENT
+ and isinstance(output.goto, list)
+ and all(isinstance(send, Send) for send in output.goto)
+ ):
+ if parent_command:
+ parent_command = replace(
+ parent_command,
+ goto=cast("list[Send]", parent_command.goto) + output.goto,
+ )
+ else:
+ parent_command = Command(graph=Command.PARENT, goto=output.goto)
+ else:
+ combined_outputs.append(output)
+ else:
+ combined_outputs.append(
+ [output] if input_type == "list" else {self._messages_key: [output]}
+ )
+
+ if parent_command:
+ combined_outputs.append(parent_command)
+ return combined_outputs
+
+ def _run_one(
+ self,
+ call: ToolCall,
+ input_type: Literal["list", "dict", "tool_calls"],
+ config: RunnableConfig,
+ ) -> Union[ToolMessage, Command]:
+ """Run a single tool call synchronously."""
+ if invalid_tool_message := self._validate_tool_call(call):
+ return invalid_tool_message
+
+ try:
+ call_args = {**call, "type": "tool_call"}
+ tool = self.tools_by_name[call["name"]]
+
+ try:
+ response = tool.invoke(call_args, config)
+ except ValidationError as exc:
+ raise ToolInvocationError(call["name"], exc, call["args"]) from exc
+
+ # GraphInterrupt is a special exception that will always be raised.
+ # It can be triggered in the following scenarios,
+ # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation most commonly:
+ # (1) a GraphInterrupt is raised inside a tool
+ # (2) a GraphInterrupt is raised inside a graph node for a graph called as a tool
+ # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool
+ # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
+ except GraphBubbleUp:
+ raise
+ except Exception as e:
+ handled_types: tuple[type[Exception], ...]
+ if isinstance(self._handle_tool_errors, type) and issubclass(
+ self._handle_tool_errors, Exception
+ ):
+ handled_types = (self._handle_tool_errors,)
+ elif isinstance(self._handle_tool_errors, tuple):
+ handled_types = self._handle_tool_errors
+ elif callable(self._handle_tool_errors) and not isinstance(
+ self._handle_tool_errors, type
+ ):
+ handled_types = _infer_handled_types(self._handle_tool_errors)
+ else:
+ # default behavior is catching all exceptions
+ handled_types = (Exception,)
+
+ # Unhandled
+ if not self._handle_tool_errors or not isinstance(e, handled_types):
+ raise
+ # Handled
+ content = _handle_tool_error(e, flag=self._handle_tool_errors)
+ return ToolMessage(
+ content=content,
+ name=call["name"],
+ tool_call_id=call["id"],
+ status="error",
+ )
+
+ if isinstance(response, Command):
+ return self._validate_tool_command(response, call, input_type)
+ if isinstance(response, ToolMessage):
+ response.content = cast("Union[str, list]", msg_content_output(response.content))
+ return response
+ msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
+ raise TypeError(msg)
+
+ async def _arun_one(
+ self,
+ call: ToolCall,
+ input_type: Literal["list", "dict", "tool_calls"],
+ config: RunnableConfig,
+ ) -> Union[ToolMessage, Command]:
+ """Run a single tool call asynchronously."""
+ if invalid_tool_message := self._validate_tool_call(call):
+ return invalid_tool_message
+
+ try:
+ call_args = {**call, "type": "tool_call"}
+ tool = self.tools_by_name[call["name"]]
+
+ try:
+ response = await tool.ainvoke(call_args, config)
+ except ValidationError as exc:
+ raise ToolInvocationError(call["name"], exc, call["args"]) from exc
+
+ # GraphInterrupt is a special exception that will always be raised.
+ # It can be triggered in the following scenarios,
+ # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation most commonly:
+ # (1) a GraphInterrupt is raised inside a tool
+ # (2) a GraphInterrupt is raised inside a graph node for a graph called as a tool
+ # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool
+ # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
+ except GraphBubbleUp:
+ raise
+ except Exception as e:
+ handled_types: tuple[type[Exception], ...]
+ if isinstance(self._handle_tool_errors, type) and issubclass(
+ self._handle_tool_errors, Exception
+ ):
+ handled_types = (self._handle_tool_errors,)
+ elif isinstance(self._handle_tool_errors, tuple):
+ handled_types = self._handle_tool_errors
+ elif callable(self._handle_tool_errors) and not isinstance(
+ self._handle_tool_errors, type
+ ):
+ handled_types = _infer_handled_types(self._handle_tool_errors)
+ else:
+ # default behavior is catching all exceptions
+ handled_types = (Exception,)
+
+ # Unhandled
+ if not self._handle_tool_errors or not isinstance(e, handled_types):
+ raise
+ # Handled
+ content = _handle_tool_error(e, flag=self._handle_tool_errors)
+
+ return ToolMessage(
+ content=content,
+ name=call["name"],
+ tool_call_id=call["id"],
+ status="error",
+ )
+
+ if isinstance(response, Command):
+ return self._validate_tool_command(response, call, input_type)
+ if isinstance(response, ToolMessage):
+ response.content = cast("Union[str, list]", msg_content_output(response.content))
+ return response
+ msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
+ raise TypeError(msg)
+
+ def _parse_input(
+ self,
+ input: Union[
+ list[AnyMessage],
+ dict[str, Any],
+ BaseModel,
+ ],
+ store: BaseStore | None,
+ ) -> tuple[list[ToolCall], Literal["list", "dict", "tool_calls"]]:
+ input_type: Literal["list", "dict", "tool_calls"]
+ if isinstance(input, list):
+ if isinstance(input[-1], dict) and input[-1].get("type") == "tool_call":
+ input_type = "tool_calls"
+ tool_calls = cast("list[ToolCall]", input)
+ return tool_calls, input_type
+ input_type = "list"
+ messages = input
+ elif isinstance(input, dict) and (messages := input.get(self._messages_key, [])):
+ input_type = "dict"
+ elif messages := getattr(input, self._messages_key, []):
+ # Assume dataclass-like state that can coerce from dict
+ input_type = "dict"
+ else:
+ msg = "No message found in input"
+ raise ValueError(msg)
+
+ try:
+ latest_ai_message = next(m for m in reversed(messages) if isinstance(m, AIMessage))
+ except StopIteration:
+ msg = "No AIMessage found in input"
+ raise ValueError(msg)
+
+ tool_calls = [
+ self.inject_tool_args(call, input, store) for call in latest_ai_message.tool_calls
+ ]
+ return tool_calls, input_type
+
+ def _validate_tool_call(self, call: ToolCall) -> ToolMessage | None:
+ requested_tool = call["name"]
+ if requested_tool not in self.tools_by_name:
+ all_tool_names = list(self.tools_by_name.keys())
+ content = INVALID_TOOL_NAME_ERROR_TEMPLATE.format(
+ requested_tool=requested_tool,
+ available_tools=", ".join(all_tool_names),
+ )
+ return ToolMessage(
+ content, name=requested_tool, tool_call_id=call["id"], status="error"
+ )
+ return None
+
+ def _inject_state(
+ self,
+ tool_call: ToolCall,
+ input: Union[
+ list[AnyMessage],
+ dict[str, Any],
+ BaseModel,
+ ],
+ ) -> ToolCall:
+ state_args = self._tool_to_state_args[tool_call["name"]]
+ if state_args and isinstance(input, list):
+ required_fields = list(state_args.values())
+ if (
+ len(required_fields) == 1 and required_fields[0] == self._messages_key
+ ) or required_fields[0] is None:
+ input = {self._messages_key: input}
+ else:
+ err_msg = (
+ f"Invalid input to ToolNode. Tool {tool_call['name']} requires "
+ f"graph state dict as input."
+ )
+ if any(state_field for state_field in state_args.values()):
+ required_fields_str = ", ".join(f for f in required_fields if f)
+ err_msg += f" State should contain fields {required_fields_str}."
+ raise ValueError(err_msg)
+
+ if isinstance(input, dict):
+ tool_state_args = {
+ tool_arg: input[state_field] if state_field else input
+ for tool_arg, state_field in state_args.items()
+ }
+ else:
+ tool_state_args = {
+ tool_arg: getattr(input, state_field) if state_field else input
+ for tool_arg, state_field in state_args.items()
+ }
+
+ tool_call["args"] = {
+ **tool_call["args"],
+ **tool_state_args,
+ }
+ return tool_call
+
+ def _inject_store(self, tool_call: ToolCall, store: BaseStore | None) -> ToolCall:
+ store_arg = self._tool_to_store_arg[tool_call["name"]]
+ if not store_arg:
+ return tool_call
+
+ if store is None:
+ msg = (
+ "Cannot inject store into tools with InjectedStore annotations - "
+ "please compile your graph with a store."
+ )
+ raise ValueError(msg)
+
+ tool_call["args"] = {
+ **tool_call["args"],
+ store_arg: store,
+ }
+ return tool_call
+
+ def inject_tool_args(
+ self,
+ tool_call: ToolCall,
+ input: Union[
+ list[AnyMessage],
+ dict[str, Any],
+ BaseModel,
+ ],
+ store: BaseStore | None,
+ ) -> ToolCall:
+ """Inject graph state and store into tool call arguments.
+
+ This method enables tools to access graph context that should not be controlled
+ by the model. Tools can declare dependencies on graph state or persistent storage
+ using InjectedState and InjectedStore annotations. This method automatically
+ identifies these dependencies and injects the appropriate values.
+
+ The injection process preserves the original tool call structure while adding
+ the necessary context arguments. This allows tools to be both model-callable
+ and context-aware without exposing internal state management to the model.
+
+ Args:
+ tool_call: The tool call dictionary to augment with injected arguments.
+ Must contain 'name', 'args', 'id', and 'type' fields.
+ input: The current graph state to inject into tools requiring state access.
+ Can be a message list, state dictionary, or BaseModel instance.
+ store: The persistent store instance to inject into tools requiring storage.
+ Will be None if no store is configured for the graph.
+
+ Returns:
+ A new ToolCall dictionary with the same structure as the input but with
+ additional arguments injected based on the tool's annotation requirements.
+
+ Raises:
+ ValueError: If a tool requires store injection but no store is provided,
+ or if state injection requirements cannot be satisfied.
+
+ Note:
+ This method is automatically called during tool execution but can also
+ be used manually when working with the Send API or custom routing logic.
+ The injection is performed on a copy of the tool call to avoid mutating
+ the original.
+ """
+ if tool_call["name"] not in self.tools_by_name:
+ return tool_call
+
+ tool_call_copy: ToolCall = copy(tool_call)
+ tool_call_with_state = self._inject_state(tool_call_copy, input)
+ return self._inject_store(tool_call_with_state, store)
+
+ def _validate_tool_command(
+ self,
+ command: Command,
+ call: ToolCall,
+ input_type: Literal["list", "dict", "tool_calls"],
+ ) -> Command:
+ if isinstance(command.update, dict):
+ # input type is dict when ToolNode is invoked with a dict input (e.g. {"messages": [AIMessage(..., tool_calls=[...])]})
+ if input_type not in ("dict", "tool_calls"):
+ msg = (
+ f"Tools can provide a dict in Command.update only when using dict with '{self._messages_key}' key as ToolNode input, "
+ f"got: {command.update} for tool '{call['name']}'"
+ )
+ raise ValueError(msg)
+
+ updated_command = deepcopy(command)
+ state_update = cast("dict[str, Any]", updated_command.update) or {}
+ messages_update = state_update.get(self._messages_key, [])
+ elif isinstance(command.update, list):
+ # Input type is list when ToolNode is invoked with a list input (e.g. [AIMessage(..., tool_calls=[...])])
+ if input_type != "list":
+ msg = (
+ f"Tools can provide a list of messages in Command.update only when using list of messages as ToolNode input, "
+ f"got: {command.update} for tool '{call['name']}'"
+ )
+ raise ValueError(msg)
+
+ updated_command = deepcopy(command)
+ messages_update = updated_command.update
+ else:
+ return command
+
+ # convert to message objects if updates are in a dict format
+ messages_update = convert_to_messages(messages_update)
+
+ # no validation needed if all messages are being removed
+ if messages_update == [RemoveMessage(id=REMOVE_ALL_MESSAGES)]:
+ return updated_command
+
+ has_matching_tool_message = False
+ for message in messages_update:
+ if not isinstance(message, ToolMessage):
+ continue
+
+ if message.tool_call_id == call["id"]:
+ message.name = call["name"]
+ has_matching_tool_message = True
+
+ # validate that we always have a ToolMessage matching the tool call in
+ # Command.update if command is sent to the CURRENT graph
+ if updated_command.graph is None and not has_matching_tool_message:
+ example_update = (
+ '`Command(update={"messages": [ToolMessage("Success", tool_call_id=tool_call_id), ...]}, ...)`'
+ if input_type == "dict"
+ else '`Command(update=[ToolMessage("Success", tool_call_id=tool_call_id), ...], ...)`'
+ )
+ msg = (
+ f"Expected to have a matching ToolMessage in Command.update for tool '{call['name']}', got: {messages_update}. "
+ "Every tool call (LLM requesting to call a tool) in the message history MUST have a corresponding ToolMessage. "
+ f"You can fix it by modifying the tool to return {example_update}."
+ )
+ raise ValueError(msg)
+ return updated_command
+
+
+def tools_condition(
+ state: Union[list[AnyMessage], dict[str, Any], BaseModel],
+ messages_key: str = "messages",
+) -> Literal["tools", "__end__"]:
+ """Conditional routing function for tool-calling workflows.
+
+ This utility function implements the standard conditional logic for ReAct-style
+ agents: if the last AI message contains tool calls, route to the tool execution
+ node; otherwise, end the workflow. This pattern is fundamental to most tool-calling
+ agent architectures.
+
+ The function handles multiple state formats commonly used in LangGraph applications,
+ making it flexible for different graph designs while maintaining consistent behavior.
+
+ Args:
+ state: The current graph state to examine for tool calls. Supported formats:
+ - Dictionary containing a messages key (for StateGraph)
+ - BaseModel instance with a messages attribute
+ messages_key: The key or attribute name containing the message list in the state.
+ This allows customization for graphs using different state schemas.
+ Defaults to "messages".
+
+ Returns:
+ Either "tools" if tool calls are present in the last AI message, or "__end__"
+ to terminate the workflow. These are the standard routing destinations for
+ tool-calling conditional edges.
+
+ Raises:
+ ValueError: If no messages can be found in the provided state format.
+
+ Example:
+ Basic usage in a ReAct agent:
+
+ ```python
+ from langgraph.graph import StateGraph
+ from langgraph.agents.tool_node import ToolNode, tools_condition
+ from typing_extensions import TypedDict
+
+ class State(TypedDict):
+ messages: list
+
+ graph = StateGraph(State)
+ graph.add_node("llm", call_model)
+ graph.add_node("tools", ToolNode([my_tool]))
+ graph.add_conditional_edges(
+ "llm",
+ tools_condition, # Routes to "tools" or "__end__"
+ {"tools": "tools", "__end__": "__end__"}
+ )
+ ```
+
+ Custom messages key:
+
+ ```python
+ def custom_condition(state):
+ return tools_condition(state, messages_key="chat_history")
+ ```
+
+ Note:
+ This function is designed to work seamlessly with ToolNode and standard
+ LangGraph patterns. It expects the last message to be an AIMessage when
+ tool calls are present, which is the standard output format for tool-calling
+ language models.
+ """
+ if isinstance(state, list):
+ ai_message = state[-1]
+ elif (isinstance(state, dict) and (messages := state.get(messages_key, []))) or (
+ messages := getattr(state, messages_key, [])
+ ):
+ ai_message = messages[-1]
+ else:
+ msg = f"No messages found in input state to tool_edge: {state}"
+ raise ValueError(msg)
+ if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
+ return "tools"
+ return "__end__"
+
+
+class InjectedState(InjectedToolArg):
+ """Annotation for injecting graph state into tool arguments.
+
+ This annotation enables tools to access graph state without exposing state
+ management details to the language model. Tools annotated with InjectedState
+ receive state data automatically during execution while remaining invisible
+ to the model's tool-calling interface.
+
+ Args:
+ field: Optional key to extract from the state dictionary. If None, the entire
+ state is injected. If specified, only that field's value is injected.
+ This allows tools to request specific state components rather than
+ processing the full state structure.
+
+ Example:
+ ```python
+ from typing import List
+ from typing_extensions import Annotated, TypedDict
+
+ from langchain_core.messages import BaseMessage, AIMessage
+ from langchain_core.tools import tool
+
+ from langgraph.agents.tool_node import InjectedState, ToolNode
+
+
+ class AgentState(TypedDict):
+ messages: List[BaseMessage]
+ foo: str
+
+ @tool
+ def state_tool(x: int, state: Annotated[dict, InjectedState]) -> str:
+ '''Do something with state.'''
+ if len(state["messages"]) > 2:
+ return state["foo"] + str(x)
+ else:
+ return "not enough messages"
+
+ @tool
+ def foo_tool(x: int, foo: Annotated[str, InjectedState("foo")]) -> str:
+ '''Do something else with state.'''
+ return foo + str(x + 1)
+
+ node = ToolNode([state_tool, foo_tool])
+
+ tool_call1 = {"name": "state_tool", "args": {"x": 1}, "id": "1", "type": "tool_call"}
+ tool_call2 = {"name": "foo_tool", "args": {"x": 1}, "id": "2", "type": "tool_call"}
+ state = {
+ "messages": [AIMessage("", tool_calls=[tool_call1, tool_call2])],
+ "foo": "bar",
+ }
+ node.invoke(state)
+ ```
+
+ ```pycon
+ [
+ ToolMessage(content='not enough messages', name='state_tool', tool_call_id='1'),
+ ToolMessage(content='bar2', name='foo_tool', tool_call_id='2')
+ ]
+ ```
+
+ Note:
+ - InjectedState arguments are automatically excluded from tool schemas
+ presented to language models
+ - ToolNode handles the injection process during execution
+ - Tools can mix regular arguments (controlled by the model) with injected
+ arguments (controlled by the system)
+ - State injection occurs after the model generates tool calls but before
+ tool execution
+ """
+
+ def __init__(self, field: str | None = None) -> None:
+ """Initialize the InjectedState annotation."""
+ self.field = field
+
+
+class InjectedStore(InjectedToolArg):
+ """Annotation for injecting persistent store into tool arguments.
+
+ This annotation enables tools to access LangGraph's persistent storage system
+ without exposing storage details to the language model. Tools annotated with
+ InjectedStore receive the store instance automatically during execution while
+ remaining invisible to the model's tool-calling interface.
+
+ The store provides persistent, cross-session data storage that tools can use
+ for maintaining context, user preferences, or any other data that needs to
+ persist beyond individual workflow executions.
+
+ !!! Warning
+ `InjectedStore` annotation requires `langchain-core >= 0.3.8`
+
+ Example:
+ ```python
+ from typing_extensions import Annotated
+ from langchain_core.tools import tool
+ from langgraph.store.memory import InMemoryStore
+ from langgraph.agents.tool_node import InjectedStore, ToolNode
+
+ @tool
+ def save_preference(
+ key: str,
+ value: str,
+ store: Annotated[Any, InjectedStore()]
+ ) -> str:
+ \"\"\"Save user preference to persistent storage.\"\"\"
+ store.put(("preferences",), key, value)
+ return f"Saved {key} = {value}"
+
+ @tool
+ def get_preference(
+ key: str,
+ store: Annotated[Any, InjectedStore()]
+ ) -> str:
+ \"\"\"Retrieve user preference from persistent storage.\"\"\"
+ result = store.get(("preferences",), key)
+ return result.value if result else "Not found"
+ ```
+
+ Usage with ToolNode and graph compilation:
+
+ ```python
+ from langgraph.graph import StateGraph
+ from langgraph.store.memory import InMemoryStore
+
+ store = InMemoryStore()
+ tool_node = ToolNode([save_preference, get_preference])
+
+ graph = StateGraph(State)
+ graph.add_node("tools", tool_node)
+ compiled_graph = graph.compile(store=store) # Store is injected automatically
+ ```
+
+ Cross-session persistence:
+
+ ```python
+ # First session
+ result1 = graph.invoke({"messages": [HumanMessage("Save my favorite color as blue")]})
+
+ # Later session - data persists
+ result2 = graph.invoke({"messages": [HumanMessage("What's my favorite color?")]})
+ ```
+
+ Note:
+ - InjectedStore arguments are automatically excluded from tool schemas
+ presented to language models
+ - The store instance is automatically injected by ToolNode during execution
+ - Tools can access namespaced storage using the store's get/put methods
+ - Store injection requires the graph to be compiled with a store instance
+ - Multiple tools can share the same store instance for data consistency
+ """
+
+
+def _is_injection(type_arg: Any, injection_type: type[Union[InjectedState, InjectedStore]]) -> bool:
+ """Check if a type argument represents an injection annotation.
+
+ This utility function determines whether a type annotation indicates that
+ an argument should be injected with state or store data. It handles both
+ direct annotations and nested annotations within Union or Annotated types.
+
+ Args:
+ type_arg: The type argument to check for injection annotations.
+ injection_type: The injection type to look for (InjectedState or InjectedStore).
+
+ Returns:
+ True if the type argument contains the specified injection annotation.
+ """
+ if isinstance(type_arg, injection_type) or (
+ isinstance(type_arg, type) and issubclass(type_arg, injection_type)
+ ):
+ return True
+ origin_ = get_origin(type_arg)
+ if origin_ is Union or origin_ is Annotated:
+ return any(_is_injection(ta, injection_type) for ta in get_args(type_arg))
+ return False
+
+
+def _get_state_args(tool: BaseTool) -> dict[str, str | None]:
+ """Extract state injection mappings from tool annotations.
+
+ This function analyzes a tool's input schema to identify arguments that should
+ be injected with graph state. It processes InjectedState annotations to build
+ a mapping of tool argument names to state field names.
+
+ Args:
+ tool: The tool to analyze for state injection requirements.
+
+ Returns:
+ A dictionary mapping tool argument names to state field names. If a field
+ name is None, the entire state should be injected for that argument.
+ """
+ full_schema = tool.get_input_schema()
+ tool_args_to_state_fields: dict = {}
+
+ for name, type_ in get_all_basemodel_annotations(full_schema).items():
+ injections = [
+ type_arg for type_arg in get_args(type_) if _is_injection(type_arg, InjectedState)
+ ]
+ if len(injections) > 1:
+ msg = (
+ "A tool argument should not be annotated with InjectedState more than "
+ f"once. Received arg {name} with annotations {injections}."
+ )
+ raise ValueError(msg)
+ if len(injections) == 1:
+ injection = injections[0]
+ if isinstance(injection, InjectedState) and injection.field:
+ tool_args_to_state_fields[name] = injection.field
+ else:
+ tool_args_to_state_fields[name] = None
+ else:
+ pass
+ return tool_args_to_state_fields
+
+
+def _get_store_arg(tool: BaseTool) -> str | None:
+ """Extract store injection argument from tool annotations.
+
+ This function analyzes a tool's input schema to identify the argument that
+ should be injected with the graph store. Only one store argument is supported
+ per tool.
+
+ Args:
+ tool: The tool to analyze for store injection requirements.
+
+ Returns:
+ The name of the argument that should receive the store injection, or None
+ if no store injection is required.
+
+ Raises:
+ ValueError: If a tool argument has multiple InjectedStore annotations.
+ """
+ full_schema = tool.get_input_schema()
+ for name, type_ in get_all_basemodel_annotations(full_schema).items():
+ injections = [
+ type_arg for type_arg in get_args(type_) if _is_injection(type_arg, InjectedStore)
+ ]
+ if len(injections) > 1:
+ msg = (
+ "A tool argument should not be annotated with InjectedStore more than "
+ f"once. Received arg {name} with annotations {injections}."
+ )
+ raise ValueError(msg)
+ if len(injections) == 1:
+ return name
+
+ return None
diff --git a/libs/langchain_v1/langchain/chat_models/base.py b/libs/langchain_v1/langchain/chat_models/base.py
index f222aa1b86c..f9ae98e4368 100644
--- a/libs/langchain_v1/langchain/chat_models/base.py
+++ b/libs/langchain_v1/langchain/chat_models/base.py
@@ -13,7 +13,7 @@ from typing import (
)
from langchain_core.language_models import BaseChatModel, LanguageModelInput
-from langchain_core.messages import AnyMessage, BaseMessage
+from langchain_core.messages import AIMessage, AnyMessage
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
from typing_extensions import override
@@ -915,7 +915,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
self,
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
- ) -> Runnable[LanguageModelInput, BaseMessage]:
+ ) -> Runnable[LanguageModelInput, AIMessage]:
return self.__getattr__("bind_tools")(tools, **kwargs)
# Explicitly added to satisfy downstream linters.
diff --git a/libs/langchain_v1/pyproject.toml b/libs/langchain_v1/pyproject.toml
index b05002d4c2f..f471a309b5c 100644
--- a/libs/langchain_v1/pyproject.toml
+++ b/libs/langchain_v1/pyproject.toml
@@ -52,11 +52,11 @@ test = [
"pytest-socket>=0.6.0",
"syrupy>=4.0.2",
"pytest-xdist>=3.6.1",
- "blockbuster>=1.5.18",
"langchain-tests",
"langchain-text-splitters",
"langchain-openai",
"toml>=0.10.2",
+ "pytest-mock"
]
codespell = ["codespell<3.0.0,>=2.2.0"]
lint = [
@@ -94,6 +94,7 @@ strict_bytes = "True"
ignore_missing_imports = "True"
enable_error_code = "deprecated"
report_deprecated_as_note = "True"
+exclude = ["tests/unit_tests/agents/*", "tests/integration_tests/agents/*"]
# TODO: activate for 'strict' checking
disallow_untyped_calls = "False"
@@ -135,14 +136,24 @@ flake8-annotations.allow-star-arg-any = true
"D", # Documentation rules
"PLC0415", # Imports should be at the top. Not always desirable for tests
]
+"langchain/agents/*" = [
+ "ANN401", # we use Any right now, need to narrow
+ "E501", # line too long, needs to fix
+ "A002", # input is shadowing builtin
+ "A001", # input is shadowing builtin
+ "B904", # use from for exceptions
+ "PLR2004", # magic values are fine for this case
+ "C901", # too complex
+ "TRY004", # type error exception
+ "PLR0912", # too many branches
+ "PLR0911", # too many return statements
+]
+"tests/unit_tests/agents/*" = ["ALL"]
+"tests/integration_tests/agents/*" = ["ALL"]
[tool.ruff.lint.extend-per-file-ignores]
"scripts/check_imports.py" = ["ALL"]
-"langchain/globals.py" = [
- "PLW"
-]
-
"langchain/chat_models/base.py" = [
"ANN",
"C901",
diff --git a/libs/langchain_v1/tests/integration_tests/agents/__init__.py b/libs/langchain_v1/tests/integration_tests/agents/__init__.py
new file mode 100644
index 00000000000..6c212a0b69d
--- /dev/null
+++ b/libs/langchain_v1/tests/integration_tests/agents/__init__.py
@@ -0,0 +1 @@
+"""Integration tests for the agents module."""
diff --git a/libs/langchain_v1/tests/integration_tests/agents/test_response_format.py b/libs/langchain_v1/tests/integration_tests/agents/test_response_format.py
new file mode 100644
index 00000000000..6836cc692cd
--- /dev/null
+++ b/libs/langchain_v1/tests/integration_tests/agents/test_response_format.py
@@ -0,0 +1,79 @@
+import pytest
+from langchain_core.messages import HumanMessage
+from pydantic import BaseModel, Field
+
+from langchain.agents import create_react_agent
+from langchain.agents.structured_output import ToolStrategy
+
+
+class WeatherBaseModel(BaseModel):
+ """Weather response."""
+
+ temperature: float = Field(description="The temperature in fahrenheit")
+ condition: str = Field(description="Weather condition")
+
+
+def get_weather(city: str) -> str: # noqa: ARG001
+ """Get the weather for a city."""
+ return "The weather is sunny and 75°F."
+
+
+@pytest.mark.requires("langchain_openai")
+def test_inference_to_native_output() -> None:
+ """Test that native output is inferred when a model supports it."""
+ from langchain_openai import ChatOpenAI
+
+ model = ChatOpenAI(model="gpt-5")
+ agent = create_react_agent(
+ model,
+ prompt=(
+ "You are a helpful weather assistant. Please call the get_weather tool, "
+ "then use the WeatherReport tool to generate the final response."
+ ),
+ tools=[get_weather],
+ response_format=WeatherBaseModel,
+ )
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ assert isinstance(response["structured_response"], WeatherBaseModel)
+ assert response["structured_response"].temperature == 75.0
+ assert response["structured_response"].condition.lower() == "sunny"
+ assert len(response["messages"]) == 4
+
+ assert [m.type for m in response["messages"]] == [
+ "human", # "What's the weather?"
+ "ai", # "What's the weather?"
+ "tool", # "The weather is sunny and 75°F."
+ "ai", # structured response
+ ]
+
+
+@pytest.mark.requires("langchain_openai")
+def test_inference_to_tool_output() -> None:
+ """Test that tool output is inferred when a model supports it."""
+ from langchain_openai import ChatOpenAI
+
+ model = ChatOpenAI(model="gpt-4")
+ agent = create_react_agent(
+ model,
+ prompt=(
+ "You are a helpful weather assistant. Please call the get_weather tool, "
+ "then use the WeatherReport tool to generate the final response."
+ ),
+ tools=[get_weather],
+ response_format=ToolStrategy(WeatherBaseModel),
+ )
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ assert isinstance(response["structured_response"], WeatherBaseModel)
+ assert response["structured_response"].temperature == 75.0
+ assert response["structured_response"].condition.lower() == "sunny"
+ assert len(response["messages"]) == 5
+
+ assert [m.type for m in response["messages"]] == [
+ "human", # "What's the weather?"
+ "ai", # "What's the weather?"
+ "tool", # "The weather is sunny and 75°F."
+ "ai", # structured response
+ "tool", # artificial tool message
+ ]
diff --git a/libs/langchain_v1/tests/unit_tests/agents/__init__.py b/libs/langchain_v1/tests/unit_tests/agents/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/libs/langchain_v1/tests/unit_tests/agents/__snapshots__/test_react_agent_graph.ambr b/libs/langchain_v1/tests/unit_tests/agents/__snapshots__/test_react_agent_graph.ambr
new file mode 100644
index 00000000000..49e40649e57
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/__snapshots__/test_react_agent_graph.ambr
@@ -0,0 +1,83 @@
+# serializer version: 1
+# name: test_react_agent_graph_structure[None-None-tools0]
+ '''
+ graph TD;
+ __start__ --> agent;
+ agent --> __end__;
+
+ '''
+# ---
+# name: test_react_agent_graph_structure[None-None-tools1]
+ '''
+ graph TD;
+ __start__ --> agent;
+ agent -.-> __end__;
+ agent -.-> tools;
+ tools --> agent;
+
+ '''
+# ---
+# name: test_react_agent_graph_structure[None-pre_model_hook-tools0]
+ '''
+ graph TD;
+ __start__ --> pre_model_hook;
+ pre_model_hook --> agent;
+ agent --> __end__;
+
+ '''
+# ---
+# name: test_react_agent_graph_structure[None-pre_model_hook-tools1]
+ '''
+ graph TD;
+ __start__ --> pre_model_hook;
+ agent -.-> __end__;
+ agent -.-> tools;
+ pre_model_hook --> agent;
+ tools --> pre_model_hook;
+
+ '''
+# ---
+# name: test_react_agent_graph_structure[post_model_hook-None-tools0]
+ '''
+ graph TD;
+ __start__ --> agent;
+ agent --> post_model_hook;
+ post_model_hook --> __end__;
+
+ '''
+# ---
+# name: test_react_agent_graph_structure[post_model_hook-None-tools1]
+ '''
+ graph TD;
+ __start__ --> agent;
+ agent --> post_model_hook;
+ post_model_hook -.-> __end__;
+ post_model_hook -.-> agent;
+ post_model_hook -.-> tools;
+ tools --> agent;
+
+ '''
+# ---
+# name: test_react_agent_graph_structure[post_model_hook-pre_model_hook-tools0]
+ '''
+ graph TD;
+ __start__ --> pre_model_hook;
+ agent --> post_model_hook;
+ pre_model_hook --> agent;
+ post_model_hook --> __end__;
+
+ '''
+# ---
+# name: test_react_agent_graph_structure[post_model_hook-pre_model_hook-tools1]
+ '''
+ graph TD;
+ __start__ --> pre_model_hook;
+ agent --> post_model_hook;
+ post_model_hook -.-> __end__;
+ post_model_hook -.-> pre_model_hook;
+ post_model_hook -.-> tools;
+ pre_model_hook --> agent;
+ tools --> pre_model_hook;
+
+ '''
+# ---
diff --git a/libs/langchain_v1/tests/unit_tests/agents/any_str.py b/libs/langchain_v1/tests/unit_tests/agents/any_str.py
new file mode 100644
index 00000000000..790324322dc
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/any_str.py
@@ -0,0 +1,18 @@
+import re
+from typing import Union
+
+
+class AnyStr(str):
+ def __init__(self, prefix: Union[str, re.Pattern] = "") -> None:
+ super().__init__()
+ self.prefix = prefix
+
+ def __eq__(self, other: object) -> bool:
+ return isinstance(other, str) and (
+ other.startswith(self.prefix)
+ if isinstance(self.prefix, str)
+ else self.prefix.match(other)
+ )
+
+ def __hash__(self) -> int:
+ return hash((str(self), self.prefix))
diff --git a/libs/langchain_v1/tests/unit_tests/agents/compose-postgres.yml b/libs/langchain_v1/tests/unit_tests/agents/compose-postgres.yml
new file mode 100644
index 00000000000..221b35dafe4
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/compose-postgres.yml
@@ -0,0 +1,17 @@
+name: langgraph-tests
+services:
+ postgres-test:
+ image: postgres:16
+ ports:
+ - "5442:5432"
+ environment:
+ POSTGRES_DB: postgres
+ POSTGRES_USER: postgres
+ POSTGRES_PASSWORD: postgres
+ healthcheck:
+ test: pg_isready -U postgres
+ start_period: 10s
+ timeout: 1s
+ retries: 5
+ interval: 60s
+ start_interval: 1s
diff --git a/libs/langchain_v1/tests/unit_tests/agents/compose-redis.yml b/libs/langchain_v1/tests/unit_tests/agents/compose-redis.yml
new file mode 100644
index 00000000000..18862fd28dd
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/compose-redis.yml
@@ -0,0 +1,16 @@
+name: langgraph-tests-redis
+services:
+ redis-test:
+ image: redis:7-alpine
+ ports:
+ - "6379:6379"
+ command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
+ healthcheck:
+ test: redis-cli ping
+ start_period: 10s
+ timeout: 1s
+ retries: 5
+ interval: 5s
+ start_interval: 1s
+ tmpfs:
+ - /data # Use tmpfs for faster testing
diff --git a/libs/langchain_v1/tests/unit_tests/agents/conftest.py b/libs/langchain_v1/tests/unit_tests/agents/conftest.py
new file mode 100644
index 00000000000..7e134f68271
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/conftest.py
@@ -0,0 +1,194 @@
+import os
+from collections.abc import AsyncIterator, Iterator
+from uuid import UUID
+
+import pytest
+from langgraph.checkpoint.base import BaseCheckpointSaver
+from langgraph.store.base import BaseStore
+from pytest_mock import MockerFixture
+
+from .conftest_checkpointer import (
+ _checkpointer_memory,
+ _checkpointer_postgres,
+ _checkpointer_postgres_aio,
+ _checkpointer_postgres_aio_pipe,
+ _checkpointer_postgres_aio_pool,
+ _checkpointer_postgres_pipe,
+ _checkpointer_postgres_pool,
+ _checkpointer_sqlite,
+ _checkpointer_sqlite_aio,
+)
+from .conftest_store import (
+ _store_memory,
+ _store_postgres,
+ _store_postgres_aio,
+ _store_postgres_aio_pipe,
+ _store_postgres_aio_pool,
+ _store_postgres_pipe,
+ _store_postgres_pool,
+)
+
+# Global variables for checkpointer and store configurations
+FAST_MODE = os.getenv("LANGGRAPH_TEST_FAST", "true").lower() in ("true", "1", "yes")
+
+SYNC_CHECKPOINTER_PARAMS = (
+ ["memory"]
+ if FAST_MODE
+ else [
+ "memory",
+ "sqlite",
+ "postgres",
+ "postgres_pipe",
+ "postgres_pool",
+ ]
+)
+
+ASYNC_CHECKPOINTER_PARAMS = (
+ ["memory"]
+ if FAST_MODE
+ else [
+ "memory",
+ "sqlite_aio",
+ "postgres_aio",
+ "postgres_aio_pipe",
+ "postgres_aio_pool",
+ ]
+)
+
+SYNC_STORE_PARAMS = (
+ ["in_memory"]
+ if FAST_MODE
+ else [
+ "in_memory",
+ "postgres",
+ "postgres_pipe",
+ "postgres_pool",
+ ]
+)
+
+ASYNC_STORE_PARAMS = (
+ ["in_memory"]
+ if FAST_MODE
+ else [
+ "in_memory",
+ "postgres_aio",
+ "postgres_aio_pipe",
+ "postgres_aio_pool",
+ ]
+)
+
+
+@pytest.fixture
+def anyio_backend() -> str:
+ return "asyncio"
+
+
+@pytest.fixture
+def deterministic_uuids(mocker: MockerFixture) -> MockerFixture:
+ side_effect = (UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000))
+ return mocker.patch("uuid.uuid4", side_effect=side_effect)
+
+
+# checkpointer fixtures
+
+
+@pytest.fixture(
+ params=SYNC_STORE_PARAMS,
+)
+def sync_store(request: pytest.FixtureRequest) -> Iterator[BaseStore]:
+ store_name = request.param
+ if store_name is None:
+ yield None
+ elif store_name == "in_memory":
+ with _store_memory() as store:
+ yield store
+ elif store_name == "postgres":
+ with _store_postgres() as store:
+ yield store
+ elif store_name == "postgres_pipe":
+ with _store_postgres_pipe() as store:
+ yield store
+ elif store_name == "postgres_pool":
+ with _store_postgres_pool() as store:
+ yield store
+ else:
+ msg = f"Unknown store {store_name}"
+ raise NotImplementedError(msg)
+
+
+@pytest.fixture(
+ params=ASYNC_STORE_PARAMS,
+)
+async def async_store(request: pytest.FixtureRequest) -> AsyncIterator[BaseStore]:
+ store_name = request.param
+ if store_name is None:
+ yield None
+ elif store_name == "in_memory":
+ with _store_memory() as store:
+ yield store
+ elif store_name == "postgres_aio":
+ async with _store_postgres_aio() as store:
+ yield store
+ elif store_name == "postgres_aio_pipe":
+ async with _store_postgres_aio_pipe() as store:
+ yield store
+ elif store_name == "postgres_aio_pool":
+ async with _store_postgres_aio_pool() as store:
+ yield store
+ else:
+ msg = f"Unknown store {store_name}"
+ raise NotImplementedError(msg)
+
+
+@pytest.fixture(
+ params=SYNC_CHECKPOINTER_PARAMS,
+)
+def sync_checkpointer(
+ request: pytest.FixtureRequest,
+) -> Iterator[BaseCheckpointSaver]:
+ checkpointer_name = request.param
+ if checkpointer_name == "memory":
+ with _checkpointer_memory() as checkpointer:
+ yield checkpointer
+ elif checkpointer_name == "sqlite":
+ with _checkpointer_sqlite() as checkpointer:
+ yield checkpointer
+ elif checkpointer_name == "postgres":
+ with _checkpointer_postgres() as checkpointer:
+ yield checkpointer
+ elif checkpointer_name == "postgres_pipe":
+ with _checkpointer_postgres_pipe() as checkpointer:
+ yield checkpointer
+ elif checkpointer_name == "postgres_pool":
+ with _checkpointer_postgres_pool() as checkpointer:
+ yield checkpointer
+ else:
+ msg = f"Unknown checkpointer: {checkpointer_name}"
+ raise NotImplementedError(msg)
+
+
+@pytest.fixture(
+ params=ASYNC_CHECKPOINTER_PARAMS,
+)
+async def async_checkpointer(
+ request: pytest.FixtureRequest,
+) -> AsyncIterator[BaseCheckpointSaver]:
+ checkpointer_name = request.param
+ if checkpointer_name == "memory":
+ with _checkpointer_memory() as checkpointer:
+ yield checkpointer
+ elif checkpointer_name == "sqlite_aio":
+ async with _checkpointer_sqlite_aio() as checkpointer:
+ yield checkpointer
+ elif checkpointer_name == "postgres_aio":
+ async with _checkpointer_postgres_aio() as checkpointer:
+ yield checkpointer
+ elif checkpointer_name == "postgres_aio_pipe":
+ async with _checkpointer_postgres_aio_pipe() as checkpointer:
+ yield checkpointer
+ elif checkpointer_name == "postgres_aio_pool":
+ async with _checkpointer_postgres_aio_pool() as checkpointer:
+ yield checkpointer
+ else:
+ msg = f"Unknown checkpointer: {checkpointer_name}"
+ raise NotImplementedError(msg)
diff --git a/libs/langchain_v1/tests/unit_tests/agents/conftest_checkpointer.py b/libs/langchain_v1/tests/unit_tests/agents/conftest_checkpointer.py
new file mode 100644
index 00000000000..4505fb28ee3
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/conftest_checkpointer.py
@@ -0,0 +1,64 @@
+from contextlib import asynccontextmanager, contextmanager
+
+from .memory_assert import (
+ MemorySaverAssertImmutable,
+)
+
+
+@contextmanager
+def _checkpointer_memory():
+ yield MemorySaverAssertImmutable()
+
+
+@asynccontextmanager
+async def _checkpointer_memory_aio():
+ yield MemorySaverAssertImmutable()
+
+
+# Placeholder functions for other checkpointer types that aren't available
+@contextmanager
+def _checkpointer_sqlite():
+ # Fallback to memory for now
+ yield MemorySaverAssertImmutable()
+
+
+@contextmanager
+def _checkpointer_postgres():
+ # Fallback to memory for now
+ yield MemorySaverAssertImmutable()
+
+
+@contextmanager
+def _checkpointer_postgres_pipe():
+ # Fallback to memory for now
+ yield MemorySaverAssertImmutable()
+
+
+@contextmanager
+def _checkpointer_postgres_pool():
+ # Fallback to memory for now
+ yield MemorySaverAssertImmutable()
+
+
+@asynccontextmanager
+async def _checkpointer_sqlite_aio():
+ # Fallback to memory for now
+ yield MemorySaverAssertImmutable()
+
+
+@asynccontextmanager
+async def _checkpointer_postgres_aio():
+ # Fallback to memory for now
+ yield MemorySaverAssertImmutable()
+
+
+@asynccontextmanager
+async def _checkpointer_postgres_aio_pipe():
+ # Fallback to memory for now
+ yield MemorySaverAssertImmutable()
+
+
+@asynccontextmanager
+async def _checkpointer_postgres_aio_pool():
+ # Fallback to memory for now
+ yield MemorySaverAssertImmutable()
diff --git a/libs/langchain_v1/tests/unit_tests/agents/conftest_store.py b/libs/langchain_v1/tests/unit_tests/agents/conftest_store.py
new file mode 100644
index 00000000000..fb0eb4e2fba
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/conftest_store.py
@@ -0,0 +1,58 @@
+from contextlib import asynccontextmanager, contextmanager
+
+from langgraph.store.memory import InMemoryStore
+
+
+@contextmanager
+def _store_memory():
+ store = InMemoryStore()
+ yield store
+
+
+@asynccontextmanager
+async def _store_memory_aio():
+ store = InMemoryStore()
+ yield store
+
+
+# Placeholder functions for other store types that aren't available
+@contextmanager
+def _store_postgres():
+ # Fallback to memory for now
+ store = InMemoryStore()
+ yield store
+
+
+@contextmanager
+def _store_postgres_pipe():
+ # Fallback to memory for now
+ store = InMemoryStore()
+ yield store
+
+
+@contextmanager
+def _store_postgres_pool():
+ # Fallback to memory for now
+ store = InMemoryStore()
+ yield store
+
+
+@asynccontextmanager
+async def _store_postgres_aio():
+ # Fallback to memory for now
+ store = InMemoryStore()
+ yield store
+
+
+@asynccontextmanager
+async def _store_postgres_aio_pipe():
+ # Fallback to memory for now
+ store = InMemoryStore()
+ yield store
+
+
+@asynccontextmanager
+async def _store_postgres_aio_pool():
+ # Fallback to memory for now
+ store = InMemoryStore()
+ yield store
diff --git a/libs/langchain_v1/tests/unit_tests/agents/memory_assert.py b/libs/langchain_v1/tests/unit_tests/agents/memory_assert.py
new file mode 100644
index 00000000000..99d66bbe026
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/memory_assert.py
@@ -0,0 +1,56 @@
+import os
+import tempfile
+from collections import defaultdict
+from functools import partial
+
+from langgraph.checkpoint.base import (
+ ChannelVersions,
+ Checkpoint,
+ CheckpointMetadata,
+ SerializerProtocol,
+)
+from langgraph.checkpoint.memory import InMemorySaver, PersistentDict
+from langgraph.pregel._checkpoint import copy_checkpoint
+
+
+class MemorySaverAssertImmutable(InMemorySaver):
+ storage_for_copies: defaultdict[str, dict[str, dict[str, Checkpoint]]]
+
+ def __init__(
+ self,
+ *,
+ serde: SerializerProtocol | None = None,
+ put_sleep: float | None = None,
+ ) -> None:
+ _, filename = tempfile.mkstemp()
+ super().__init__(serde=serde, factory=partial(PersistentDict, filename=filename))
+ self.storage_for_copies = defaultdict(lambda: defaultdict(dict))
+ self.put_sleep = put_sleep
+ self.stack.callback(os.remove, filename)
+
+ def put(
+ self,
+ config: dict,
+ checkpoint: Checkpoint,
+ metadata: CheckpointMetadata,
+ new_versions: ChannelVersions,
+ ) -> None:
+ if self.put_sleep:
+ import time
+
+ time.sleep(self.put_sleep)
+ # assert checkpoint hasn't been modified since last written
+ thread_id = config["configurable"]["thread_id"]
+ checkpoint_ns = config["configurable"]["checkpoint_ns"]
+ if saved := super().get(config):
+ assert (
+ self.serde.loads_typed(
+ self.storage_for_copies[thread_id][checkpoint_ns][saved["id"]]
+ )
+ == saved
+ )
+ self.storage_for_copies[thread_id][checkpoint_ns][checkpoint["id"]] = (
+ self.serde.dumps_typed(copy_checkpoint(checkpoint))
+ )
+ # call super to write checkpoint
+ return super().put(config, checkpoint, metadata, new_versions)
diff --git a/libs/langchain_v1/tests/unit_tests/agents/messages.py b/libs/langchain_v1/tests/unit_tests/agents/messages.py
new file mode 100644
index 00000000000..285c2941fef
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/messages.py
@@ -0,0 +1,28 @@
+"""Redefined messages as a work-around for pydantic issue with AnyStr.
+
+The code below creates version of pydantic models
+that will work in unit tests with AnyStr as id field
+Please note that the `id` field is assigned AFTER the model is created
+to workaround an issue with pydantic ignoring the __eq__ method on
+subclassed strings.
+"""
+
+from typing import Any
+
+from langchain_core.messages import HumanMessage, ToolMessage
+
+from .any_str import AnyStr
+
+
+def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
+ """Create a human message with an any id field."""
+ message = HumanMessage(**kwargs)
+ message.id = AnyStr()
+ return message
+
+
+def _AnyIdToolMessage(**kwargs: Any) -> ToolMessage:
+ """Create a tool message with an any id field."""
+ message = ToolMessage(**kwargs)
+ message.id = AnyStr()
+ return message
diff --git a/libs/langchain_v1/tests/unit_tests/agents/model.py b/libs/langchain_v1/tests/unit_tests/agents/model.py
new file mode 100644
index 00000000000..70cbbbafe06
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/model.py
@@ -0,0 +1,111 @@
+import json
+from collections.abc import Callable, Sequence
+from dataclasses import asdict, is_dataclass
+from typing import (
+ Any,
+ Generic,
+ Literal,
+ TypeVar,
+ Union,
+)
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import BaseChatModel, LanguageModelInput
+from langchain_core.messages import (
+ AIMessage,
+ BaseMessage,
+ ToolCall,
+)
+from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.runnables import Runnable
+from langchain_core.tools import BaseTool
+from pydantic import BaseModel
+
+StructuredResponseT = TypeVar("StructuredResponseT")
+
+
+class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
+ tool_calls: Union[list[list[ToolCall]], list[list[dict]]] | None = None
+ structured_response: StructuredResponseT | None = None
+ index: int = 0
+ tool_style: Literal["openai", "anthropic"] = "openai"
+
+ def _generate(
+ self,
+ messages: list[BaseMessage],
+ stop: list[str] | None = None,
+ run_manager: CallbackManagerForLLMRun | None = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ """Top Level call"""
+ rf = kwargs.get("response_format")
+ is_native = isinstance(rf, dict) and rf.get("type") == "json_schema"
+
+ if self.tool_calls:
+ if is_native:
+ tool_calls = (
+ self.tool_calls[self.index] if self.index < len(self.tool_calls) else []
+ )
+ else:
+ tool_calls = self.tool_calls[self.index % len(self.tool_calls)]
+ else:
+ tool_calls = []
+
+ if is_native and not tool_calls:
+ if isinstance(self.structured_response, BaseModel):
+ content_obj = self.structured_response.model_dump()
+ elif is_dataclass(self.structured_response):
+ content_obj = asdict(self.structured_response)
+ elif isinstance(self.structured_response, dict):
+ content_obj = self.structured_response
+ message = AIMessage(content=json.dumps(content_obj), id=str(self.index))
+ else:
+ messages_string = "-".join([m.content for m in messages])
+ message = AIMessage(
+ content=messages_string,
+ id=str(self.index),
+ tool_calls=tool_calls.copy(),
+ )
+ self.index += 1
+ return ChatResult(generations=[ChatGeneration(message=message)])
+
+ @property
+ def _llm_type(self) -> str:
+ return "fake-tool-call-model"
+
+ def bind_tools(
+ self,
+ tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
+ **kwargs: Any,
+ ) -> Runnable[LanguageModelInput, BaseMessage]:
+ if len(tools) == 0:
+ msg = "Must provide at least one tool"
+ raise ValueError(msg)
+
+ tool_dicts = []
+ for tool in tools:
+ if isinstance(tool, dict):
+ tool_dicts.append(tool)
+ continue
+ if not isinstance(tool, BaseTool):
+ msg = "Only BaseTool and dict is supported by FakeToolCallingModel.bind_tools"
+ raise TypeError(msg)
+
+ # NOTE: this is a simplified tool spec for testing purposes only
+ if self.tool_style == "openai":
+ tool_dicts.append(
+ {
+ "type": "function",
+ "function": {
+ "name": tool.name,
+ },
+ }
+ )
+ elif self.tool_style == "anthropic":
+ tool_dicts.append(
+ {
+ "name": tool.name,
+ }
+ )
+
+ return self.bind(tools=tool_dicts)
diff --git a/libs/langchain_v1/tests/unit_tests/agents/specifications/responses.json b/libs/langchain_v1/tests/unit_tests/agents/specifications/responses.json
new file mode 100644
index 00000000000..2755ef9756c
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/specifications/responses.json
@@ -0,0 +1,87 @@
+[
+ {
+ "name": "updated structured response",
+ "responseFormat": [
+ {
+ "title": "role_schema_structured_output",
+ "type": "object",
+ "properties": {
+ "name": { "type": "string" },
+ "role": { "type": "string" }
+ },
+ "required": ["name", "role"]
+ },
+ {
+ "title": "department_schema_structured_output",
+ "type": "object",
+ "properties": {
+ "name": { "type": "string" },
+ "department": { "type": "string" }
+ },
+ "required": ["name", "department"]
+ }
+ ],
+ "assertionsByInvocation": [
+ {
+ "prompt": "What is the role of Sabine?",
+ "toolsWithExpectedCalls": {
+ "getEmployeeRole": 1,
+ "getEmployeeDepartment": 0
+ },
+ "expectedLastMessage": "Returning structured response: {'name': 'Sabine', 'role': 'Developer'}",
+ "expectedStructuredResponse": { "name": "Sabine", "role": "Developer" },
+ "llmRequestCount": 2
+ },
+ {
+ "prompt": "In which department does Henrik work?",
+ "toolsWithExpectedCalls": {
+ "getEmployeeRole": 1,
+ "getEmployeeDepartment": 1
+ },
+ "expectedLastMessage": "Returning structured response: {'name': 'Henrik', 'department': 'IT'}",
+ "expectedStructuredResponse": { "name": "Henrik", "department": "IT" },
+ "llmRequestCount": 4
+ }
+ ]
+ },
+ {
+ "name": "asking for information that does not fit into the response format",
+ "responseFormat": [
+ {
+ "schema": {
+ "type": "object",
+ "properties": {
+ "name": { "type": "string" },
+ "role": { "type": "string" }
+ },
+ "required": ["name", "role"]
+ }
+ },
+ {
+ "schema": {
+ "type": "object",
+ "properties": {
+ "name": { "type": "string" },
+ "department": { "type": "string" }
+ },
+ "required": ["name", "department"]
+ }
+ }
+ ],
+ "assertionsByInvocation": [
+ {
+ "prompt": "How much does Saskia earn?",
+ "toolsWithExpectedCalls": {
+ "getEmployeeRole": 1,
+ "getEmployeeDepartment": 0
+ },
+ "expectedLastMessage": "Returning structured response: {'name': 'Saskia', 'role': 'Software Engineer'}",
+ "expectedStructuredResponse": {
+ "name": "Saskia",
+ "role": "Software Engineer"
+ },
+ "llmRequestCount": 2
+ }
+ ]
+ }
+ ]
\ No newline at end of file
diff --git a/libs/langchain_v1/tests/unit_tests/agents/specifications/return_direct.json b/libs/langchain_v1/tests/unit_tests/agents/specifications/return_direct.json
new file mode 100644
index 00000000000..e4cc5947ce0
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/specifications/return_direct.json
@@ -0,0 +1,48 @@
+[
+ {
+ "name": "Scenario: NO return_direct, NO response_format",
+ "returnDirect": false,
+ "responseFormat": null,
+ "expectedToolCalls": 10,
+ "expectedLastMessage": "Attempts: 10",
+ "expectedStructuredResponse": null
+ },
+ {
+ "name": "Scenario: NO return_direct, YES response_format",
+ "returnDirect": false,
+ "responseFormat": {
+ "type": "object",
+ "properties": {
+ "attempts": { "type": "number" },
+ "succeeded": { "type": "boolean" }
+ },
+ "required": ["attempts", "succeeded"]
+ },
+ "expectedToolCalls": 10,
+ "expectedLastMessage": "Returning structured response: {'attempts': 10, 'succeeded': True}",
+ "expectedStructuredResponse": { "attempts": 10, "succeeded": true }
+ },
+ {
+ "name": "Scenario: YES return_direct, NO response_format",
+ "returnDirect": true,
+ "responseFormat": null,
+ "expectedToolCalls": 1,
+ "expectedLastMessage": "{\"status\": \"pending\", \"attempts\": 1}",
+ "expectedStructuredResponse": null
+ },
+ {
+ "name": "Scenario: YES return_direct, YES response_format",
+ "returnDirect": true,
+ "responseFormat": {
+ "type": "object",
+ "properties": {
+ "attempts": { "type": "number" },
+ "succeeded": { "type": "boolean" }
+ },
+ "required": ["attempts", "succeeded"]
+ },
+ "expectedToolCalls": 1,
+ "expectedLastMessage": "{\"status\": \"pending\", \"attempts\": 1}",
+ "expectedStructuredResponse": null
+ }
+]
\ No newline at end of file
diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_react_agent.py b/libs/langchain_v1/tests/unit_tests/agents/test_react_agent.py
new file mode 100644
index 00000000000..a42d019d27c
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/test_react_agent.py
@@ -0,0 +1,1647 @@
+import dataclasses
+import inspect
+from typing import (
+ Annotated,
+ Union,
+)
+
+import pytest
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import (
+ AIMessage,
+ HumanMessage,
+ MessageLikeRepresentation,
+ RemoveMessage,
+ SystemMessage,
+ ToolCall,
+ ToolMessage,
+)
+from langchain_core.runnables import RunnableConfig, RunnableLambda
+from langchain_core.tools import BaseTool, InjectedToolCallId, ToolException
+from langchain_core.tools import tool as dec_tool
+from langgraph.checkpoint.base import BaseCheckpointSaver
+from langgraph.graph import START, MessagesState, StateGraph
+from langgraph.graph.message import REMOVE_ALL_MESSAGES
+from langgraph.runtime import Runtime
+from langgraph.store.base import BaseStore
+from langgraph.store.memory import InMemoryStore
+from langgraph.types import Command, Interrupt, interrupt
+from pydantic import BaseModel, Field
+from typing_extensions import TypedDict
+
+from langchain.agents import (
+ AgentState,
+ ToolNode,
+ create_react_agent,
+)
+from langchain.agents.react_agent import _validate_chat_history
+from langchain.agents.tool_node import (
+ InjectedState,
+ InjectedStore,
+ _get_state_args,
+ _infer_handled_types,
+)
+
+from .any_str import AnyStr
+from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
+from .model import FakeToolCallingModel
+
+pytestmark = pytest.mark.anyio
+
+
+def test_no_prompt(sync_checkpointer: BaseCheckpointSaver) -> None:
+ model = FakeToolCallingModel()
+
+ agent = create_react_agent(
+ model,
+ [],
+ checkpointer=sync_checkpointer,
+ )
+ inputs = [HumanMessage("hi?")]
+ thread = {"configurable": {"thread_id": "123"}}
+ response = agent.invoke({"messages": inputs}, thread, debug=True)
+ expected_response = {"messages": [*inputs, AIMessage(content="hi?", id="0")]}
+ assert response == expected_response
+
+ saved = sync_checkpointer.get_tuple(thread)
+ assert saved is not None
+ assert saved.checkpoint["channel_values"] == {
+ "messages": [
+ _AnyIdHumanMessage(content="hi?"),
+ AIMessage(content="hi?", id="0"),
+ ],
+ }
+ assert saved.metadata == {
+ "parents": {},
+ "source": "loop",
+ "step": 1,
+ }
+ assert saved.pending_writes == []
+
+
+async def test_no_prompt_async(async_checkpointer: BaseCheckpointSaver) -> None:
+ model = FakeToolCallingModel()
+
+ agent = create_react_agent(model, [], checkpointer=async_checkpointer)
+ inputs = [HumanMessage("hi?")]
+ thread = {"configurable": {"thread_id": "123"}}
+ response = await agent.ainvoke({"messages": inputs}, thread, debug=True)
+ expected_response = {"messages": [*inputs, AIMessage(content="hi?", id="0")]}
+ assert response == expected_response
+
+ saved = await async_checkpointer.aget_tuple(thread)
+ assert saved is not None
+ assert saved.checkpoint["channel_values"] == {
+ "messages": [
+ _AnyIdHumanMessage(content="hi?"),
+ AIMessage(content="hi?", id="0"),
+ ],
+ }
+ assert saved.metadata == {
+ "parents": {},
+ "source": "loop",
+ "step": 1,
+ }
+ assert saved.pending_writes == []
+
+
+def test_system_message_prompt() -> None:
+ prompt = SystemMessage(content="Foo")
+ agent = create_react_agent(FakeToolCallingModel(), [], prompt=prompt)
+ inputs = [HumanMessage("hi?")]
+ response = agent.invoke({"messages": inputs})
+ expected_response = {"messages": [*inputs, AIMessage(content="Foo-hi?", id="0", tool_calls=[])]}
+ assert response == expected_response
+
+
+def test_string_prompt() -> None:
+ prompt = "Foo"
+ agent = create_react_agent(FakeToolCallingModel(), [], prompt=prompt)
+ inputs = [HumanMessage("hi?")]
+ response = agent.invoke({"messages": inputs})
+ expected_response = {"messages": [*inputs, AIMessage(content="Foo-hi?", id="0", tool_calls=[])]}
+ assert response == expected_response
+
+
+def test_callable_prompt() -> None:
+ def prompt(state):
+ modified_message = f"Bar {state['messages'][-1].content}"
+ return [HumanMessage(content=modified_message)]
+
+ agent = create_react_agent(FakeToolCallingModel(), [], prompt=prompt)
+ inputs = [HumanMessage("hi?")]
+ response = agent.invoke({"messages": inputs})
+ expected_response = {"messages": [*inputs, AIMessage(content="Bar hi?", id="0")]}
+ assert response == expected_response
+
+
+async def test_callable_prompt_async() -> None:
+ async def prompt(state):
+ modified_message = f"Bar {state['messages'][-1].content}"
+ return [HumanMessage(content=modified_message)]
+
+ agent = create_react_agent(FakeToolCallingModel(), [], prompt=prompt)
+ inputs = [HumanMessage("hi?")]
+ response = await agent.ainvoke({"messages": inputs})
+ expected_response = {"messages": [*inputs, AIMessage(content="Bar hi?", id="0")]}
+ assert response == expected_response
+
+
+def test_runnable_prompt() -> None:
+ prompt = RunnableLambda(
+ lambda state: [HumanMessage(content=f"Baz {state['messages'][-1].content}")]
+ )
+
+ agent = create_react_agent(FakeToolCallingModel(), [], prompt=prompt)
+ inputs = [HumanMessage("hi?")]
+ response = agent.invoke({"messages": inputs})
+ expected_response = {"messages": [*inputs, AIMessage(content="Baz hi?", id="0")]}
+ assert response == expected_response
+
+
+def test_prompt_with_store() -> None:
+ def add(a: int, b: int):
+ """Adds a and b"""
+ return a + b
+
+ in_memory_store = InMemoryStore()
+ in_memory_store.put(("memories", "1"), "user_name", {"data": "User name is Alice"})
+ in_memory_store.put(("memories", "2"), "user_name", {"data": "User name is Bob"})
+
+ def prompt(state, config, *, store):
+ user_id = config["configurable"]["user_id"]
+ system_str = store.get(("memories", user_id), "user_name").value["data"]
+ return [SystemMessage(system_str)] + state["messages"]
+
+ def prompt_no_store(state, config):
+ return SystemMessage("foo") + state["messages"]
+
+ model = FakeToolCallingModel()
+
+ # test state modifier that uses store works
+ agent = create_react_agent(
+ model,
+ [add],
+ prompt=prompt,
+ store=in_memory_store,
+ )
+ response = agent.invoke({"messages": [("user", "hi")]}, {"configurable": {"user_id": "1"}})
+ assert response["messages"][-1].content == "User name is Alice-hi"
+
+ # test state modifier that doesn't use store works
+ agent = create_react_agent(
+ model,
+ [add],
+ prompt=prompt_no_store,
+ store=in_memory_store,
+ )
+ response = agent.invoke({"messages": [("user", "hi")]}, {"configurable": {"user_id": "2"}})
+ assert response["messages"][-1].content == "foo-hi"
+
+
+async def test_prompt_with_store_async() -> None:
+ async def add(a: int, b: int):
+ """Adds a and b"""
+ return a + b
+
+ in_memory_store = InMemoryStore()
+ await in_memory_store.aput(("memories", "1"), "user_name", {"data": "User name is Alice"})
+ await in_memory_store.aput(("memories", "2"), "user_name", {"data": "User name is Bob"})
+
+ async def prompt(state, config, *, store):
+ user_id = config["configurable"]["user_id"]
+ system_str = (await store.aget(("memories", user_id), "user_name")).value["data"]
+ return [SystemMessage(system_str)] + state["messages"]
+
+ async def prompt_no_store(state, config):
+ return SystemMessage("foo") + state["messages"]
+
+ model = FakeToolCallingModel()
+
+ # test state modifier that uses store works
+ agent = create_react_agent(model, [add], prompt=prompt, store=in_memory_store)
+ response = await agent.ainvoke(
+ {"messages": [("user", "hi")]}, {"configurable": {"user_id": "1"}}
+ )
+ assert response["messages"][-1].content == "User name is Alice-hi"
+
+ # test state modifier that doesn't use store works
+ agent = create_react_agent(model, [add], prompt=prompt_no_store, store=in_memory_store)
+ response = await agent.ainvoke(
+ {"messages": [("user", "hi")]}, {"configurable": {"user_id": "2"}}
+ )
+ assert response["messages"][-1].content == "foo-hi"
+
+
+@pytest.mark.parametrize("tool_style", ["openai", "anthropic"])
+@pytest.mark.parametrize("include_builtin", [True, False])
+def test_model_with_tools(tool_style: str, include_builtin: bool) -> None:
+ model = FakeToolCallingModel(tool_style=tool_style)
+
+ @dec_tool
+ def tool1(some_val: int) -> str:
+ """Tool 1 docstring."""
+ return f"Tool 1: {some_val}"
+
+ @dec_tool
+ def tool2(some_val: int) -> str:
+ """Tool 2 docstring."""
+ return f"Tool 2: {some_val}"
+
+ tools: list[BaseTool | dict] = [tool1, tool2]
+ if include_builtin:
+ tools.append(
+ {
+ "type": "mcp",
+ "server_label": "atest_sever",
+ "server_url": "https://some.mcp.somewhere.com/sse",
+ "headers": {"foo": "bar"},
+ "allowed_tools": [
+ "mcp_tool_1",
+ "set_active_account",
+ "get_url_markdown",
+ "get_url_screenshot",
+ ],
+ "require_approval": "never",
+ }
+ )
+ # check valid agent constructor
+ with pytest.raises(ValueError):
+ create_react_agent(
+ model.bind_tools(tools),
+ tools,
+ )
+
+
+def test__validate_messages() -> None:
+ # empty input
+ _validate_chat_history([])
+
+ # single human message
+ _validate_chat_history(
+ [
+ HumanMessage(content="What's the weather?"),
+ ]
+ )
+
+ # human + AI
+ _validate_chat_history(
+ [
+ HumanMessage(content="What's the weather?"),
+ AIMessage(content="The weather is sunny and 75°F."),
+ ]
+ )
+
+ # Answered tool calls
+ _validate_chat_history(
+ [
+ HumanMessage(content="What's the weather?"),
+ AIMessage(
+ content="Let me check that for you.",
+ tool_calls=[{"id": "call1", "name": "get_weather", "args": {}}],
+ ),
+ ToolMessage(content="Sunny, 75°F", tool_call_id="call1"),
+ AIMessage(content="The weather is sunny and 75°F."),
+ ]
+ )
+
+ # Unanswered tool calls
+ with pytest.raises(ValueError):
+ _validate_chat_history(
+ [
+ AIMessage(
+ content="I'll check that for you.",
+ tool_calls=[
+ {"id": "call1", "name": "get_weather", "args": {}},
+ {"id": "call2", "name": "get_time", "args": {}},
+ ],
+ )
+ ]
+ )
+
+ with pytest.raises(ValueError):
+ _validate_chat_history(
+ [
+ HumanMessage(content="What's the weather and time?"),
+ AIMessage(
+ content="I'll check that for you.",
+ tool_calls=[
+ {"id": "call1", "name": "get_weather", "args": {}},
+ {"id": "call2", "name": "get_time", "args": {}},
+ ],
+ ),
+ ToolMessage(content="Sunny, 75°F", tool_call_id="call1"),
+ AIMessage(content="The weather is sunny and 75°F. Let me check the time."),
+ ]
+ )
+
+
+def test__infer_handled_types() -> None:
+ def handle(e) -> str: # type: ignore
+ return ""
+
+ def handle2(e: Exception) -> str:
+ return ""
+
+ def handle3(e: Union[ValueError, ToolException]) -> str:
+ return ""
+
+ class Handler:
+ def handle(self, e: ValueError) -> str:
+ return ""
+
+ handle4 = Handler().handle
+
+ def handle5(e: Union[Union[TypeError, ValueError], ToolException]) -> str:
+ return ""
+
+ expected: tuple = (Exception,)
+ actual = _infer_handled_types(handle)
+ assert expected == actual
+
+ expected = (Exception,)
+ actual = _infer_handled_types(handle2)
+ assert expected == actual
+
+ expected = (ValueError, ToolException)
+ actual = _infer_handled_types(handle3)
+ assert expected == actual
+
+ expected = (ValueError,)
+ actual = _infer_handled_types(handle4)
+ assert expected == actual
+
+ expected = (TypeError, ValueError, ToolException)
+ actual = _infer_handled_types(handle5)
+ assert expected == actual
+
+ with pytest.raises(ValueError):
+
+ def handler(e: str) -> str:
+ return ""
+
+ _infer_handled_types(handler)
+
+ with pytest.raises(ValueError):
+
+ def handler(e: list[Exception]) -> str:
+ return ""
+
+ _infer_handled_types(handler)
+
+ with pytest.raises(ValueError):
+
+ def handler(e: Union[str, int]) -> str:
+ return ""
+
+ _infer_handled_types(handler)
+
+
+def test_react_agent_with_structured_response() -> None:
+ class WeatherResponse(BaseModel):
+ temperature: float = Field(description="The temperature in fahrenheit")
+
+ tool_calls = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ [{"name": "WeatherResponse", "id": "2", "args": {"temperature": 75}}],
+ ]
+
+ def get_weather() -> str:
+ """Get the weather"""
+ return "The weather is sunny and 75°F."
+
+ expected_structured_response = WeatherResponse(temperature=75)
+ model = FakeToolCallingModel[WeatherResponse](
+ tool_calls=tool_calls, structured_response=expected_structured_response
+ )
+ agent = create_react_agent(
+ model,
+ [get_weather],
+ response_format=WeatherResponse,
+ )
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+ assert response["structured_response"] == expected_structured_response
+ assert len(response["messages"]) == 5
+
+ # Check message types in message history
+ msg_types = [m.type for m in response["messages"]]
+ assert msg_types == [
+ "human", # "What's the weather?"
+ "ai", # "What's the weather?"
+ "tool", # "The weather is sunny and 75°F."
+ "ai", # structured response
+ "tool", # artificial tool message
+ ]
+
+ assert [m.content for m in response["messages"]] == [
+ "What's the weather?",
+ "What's the weather?",
+ "The weather is sunny and 75°F.",
+ "What's the weather?-What's the weather?-The weather is sunny and 75°F.",
+ "Returning structured response: {'temperature': 75.0}",
+ ]
+
+
+class CustomState(AgentState):
+ user_name: str
+
+
+def test_react_agent_update_state(
+ sync_checkpointer: BaseCheckpointSaver,
+) -> None:
+ @dec_tool
+ def get_user_name(tool_call_id: Annotated[str, InjectedToolCallId]):
+ """Retrieve user name"""
+ user_name = interrupt("Please provider user name:")
+ return Command(
+ update={
+ "user_name": user_name,
+ "messages": [
+ ToolMessage("Successfully retrieved user name", tool_call_id=tool_call_id)
+ ],
+ }
+ )
+
+ def prompt(state: CustomState):
+ user_name = state.get("user_name")
+ if user_name is None:
+ return state["messages"]
+
+ system_msg = f"User name is {user_name}"
+ return [{"role": "system", "content": system_msg}] + state["messages"]
+
+ tool_calls = [[{"args": {}, "id": "1", "name": "get_user_name"}]]
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+ agent = create_react_agent(
+ model,
+ [get_user_name],
+ state_schema=CustomState,
+ prompt=prompt,
+ checkpointer=sync_checkpointer,
+ )
+ config = {"configurable": {"thread_id": "1"}}
+ # Run until interrupted
+ agent.invoke({"messages": [("user", "what's my name")]}, config)
+ # supply the value for the interrupt
+ response = agent.invoke(Command(resume="Archibald"), config)
+ # confirm that the state was updated
+ assert response["user_name"] == "Archibald"
+ assert len(response["messages"]) == 4
+ tool_message: ToolMessage = response["messages"][-2]
+ assert tool_message.content == "Successfully retrieved user name"
+ assert tool_message.tool_call_id == "1"
+ assert tool_message.name == "get_user_name"
+
+
+def test_react_agent_parallel_tool_calls(
+ sync_checkpointer: BaseCheckpointSaver,
+) -> None:
+ human_assistance_execution_count = 0
+
+ @dec_tool
+ def human_assistance(query: str) -> str:
+ """Request assistance from a human."""
+ nonlocal human_assistance_execution_count
+ human_response = interrupt({"query": query})
+ human_assistance_execution_count += 1
+ return human_response["data"]
+
+ get_weather_execution_count = 0
+
+ @dec_tool
+ def get_weather(location: str) -> str:
+ """Use this tool to get the weather."""
+ nonlocal get_weather_execution_count
+ get_weather_execution_count += 1
+ return "It's sunny!"
+
+ tool_calls = [
+ [
+ {"args": {"location": "sf"}, "id": "1", "name": "get_weather"},
+ {"args": {"query": "request help"}, "id": "2", "name": "human_assistance"},
+ ],
+ [],
+ ]
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+ agent = create_react_agent(
+ model,
+ [human_assistance, get_weather],
+ checkpointer=sync_checkpointer,
+ )
+ config = {"configurable": {"thread_id": "1"}}
+ query = "Get user assistance and also check the weather"
+ message_types = []
+ for event in agent.stream({"messages": [("user", query)]}, config, stream_mode="values"):
+ if messages := event.get("messages"):
+ message_types.append([m.type for m in messages])
+
+ assert message_types == [
+ ["human"],
+ ["human", "ai"],
+ ["human", "ai", "tool"],
+ ]
+
+ # Resume
+ message_types = []
+ for event in agent.stream(Command(resume={"data": "Hello"}), config, stream_mode="values"):
+ if messages := event.get("messages"):
+ message_types.append([m.type for m in messages])
+
+ assert message_types == [
+ ["human", "ai"],
+ ["human", "ai", "tool", "tool"],
+ ["human", "ai", "tool", "tool", "ai"],
+ ]
+
+ assert human_assistance_execution_count == 1
+ assert get_weather_execution_count == 1
+
+
+class AgentStateExtraKey(AgentState):
+ foo: int
+
+
+def test_create_react_agent_inject_vars() -> None:
+ """Test that the agent can inject state and store into tool functions."""
+ store = InMemoryStore()
+ namespace = ("test",)
+ store.put(namespace, "test_key", {"bar": 3})
+
+ def tool1(
+ some_val: int,
+ state: Annotated[dict, InjectedState],
+ store: Annotated[BaseStore, InjectedStore()],
+ ) -> str:
+ """Tool 1 docstring."""
+ store_val = store.get(namespace, "test_key").value["bar"]
+ return some_val + state["foo"] + store_val
+
+ tool_call = {
+ "name": "tool1",
+ "args": {"some_val": 1},
+ "id": "some 0",
+ "type": "tool_call",
+ }
+ model = FakeToolCallingModel(tool_calls=[[tool_call], []])
+ agent = create_react_agent(
+ model,
+ ToolNode([tool1], handle_tool_errors=False),
+ state_schema=AgentStateExtraKey,
+ store=store,
+ )
+ result = agent.invoke({"messages": [{"role": "user", "content": "hi"}], "foo": 2})
+ assert result["messages"] == [
+ _AnyIdHumanMessage(content="hi"),
+ AIMessage(content="hi", tool_calls=[tool_call], id="0"),
+ _AnyIdToolMessage(content="6", name="tool1", tool_call_id="some 0"),
+ AIMessage("hi-hi-6", id="1"),
+ ]
+ assert result["foo"] == 2
+
+
+async def test_return_direct() -> None:
+ @dec_tool(return_direct=True)
+ def tool_return_direct(input: str) -> str:
+ """A tool that returns directly."""
+ return f"Direct result: {input}"
+
+ @dec_tool
+ def tool_normal(input: str) -> str:
+ """A normal tool."""
+ return f"Normal result: {input}"
+
+ first_tool_call = [
+ ToolCall(
+ name="tool_return_direct",
+ args={"input": "Test direct"},
+ id="1",
+ ),
+ ]
+ expected_ai = AIMessage(
+ content="Test direct",
+ id="0",
+ tool_calls=first_tool_call,
+ )
+ model = FakeToolCallingModel(tool_calls=[first_tool_call, []])
+ agent = create_react_agent(
+ model,
+ [tool_return_direct, tool_normal],
+ )
+
+ # Test direct return for tool_return_direct
+ result = agent.invoke({"messages": [HumanMessage(content="Test direct", id="hum0")]})
+ assert result["messages"] == [
+ HumanMessage(content="Test direct", id="hum0"),
+ expected_ai,
+ ToolMessage(
+ content="Direct result: Test direct",
+ name="tool_return_direct",
+ tool_call_id="1",
+ id=result["messages"][2].id,
+ ),
+ ]
+ second_tool_call = [
+ ToolCall(
+ name="tool_normal",
+ args={"input": "Test normal"},
+ id="2",
+ ),
+ ]
+ model = FakeToolCallingModel(tool_calls=[second_tool_call, []])
+ agent = create_react_agent(model, [tool_return_direct, tool_normal])
+ result = agent.invoke({"messages": [HumanMessage(content="Test normal", id="hum1")]})
+ assert result["messages"] == [
+ HumanMessage(content="Test normal", id="hum1"),
+ AIMessage(content="Test normal", id="0", tool_calls=second_tool_call),
+ ToolMessage(
+ content="Normal result: Test normal",
+ name="tool_normal",
+ tool_call_id="2",
+ id=result["messages"][2].id,
+ ),
+ AIMessage(content="Test normal-Test normal-Normal result: Test normal", id="1"),
+ ]
+
+ both_tool_calls = [
+ ToolCall(
+ name="tool_return_direct",
+ args={"input": "Test both direct"},
+ id="3",
+ ),
+ ToolCall(
+ name="tool_normal",
+ args={"input": "Test both normal"},
+ id="4",
+ ),
+ ]
+ model = FakeToolCallingModel(tool_calls=[both_tool_calls, []])
+ agent = create_react_agent(model, [tool_return_direct, tool_normal])
+ result = agent.invoke({"messages": [HumanMessage(content="Test both", id="hum2")]})
+ assert result["messages"] == [
+ HumanMessage(content="Test both", id="hum2"),
+ AIMessage(content="Test both", id="0", tool_calls=both_tool_calls),
+ ToolMessage(
+ content="Direct result: Test both direct",
+ name="tool_return_direct",
+ tool_call_id="3",
+ id=result["messages"][2].id,
+ ),
+ ToolMessage(
+ content="Normal result: Test both normal",
+ name="tool_normal",
+ tool_call_id="4",
+ id=result["messages"][3].id,
+ ),
+ ]
+
+
+def test__get_state_args() -> None:
+ class Schema1(BaseModel):
+ a: Annotated[str, InjectedState]
+
+ class Schema2(Schema1):
+ b: Annotated[int, InjectedState("bar")]
+
+ @dec_tool(args_schema=Schema2)
+ def foo(a: str, b: int) -> float:
+ """return"""
+ return 0.0
+
+ assert _get_state_args(foo) == {"a": None, "b": "bar"}
+
+
+def test_inspect_react() -> None:
+ model = FakeToolCallingModel(tool_calls=[])
+ agent = create_react_agent(model, [])
+ inspect.getclosurevars(agent.nodes["agent"].bound.func)
+
+
+def test_react_with_subgraph_tools(
+ sync_checkpointer: BaseCheckpointSaver,
+) -> None:
+ class State(TypedDict):
+ a: int
+ b: int
+
+ class Output(TypedDict):
+ result: int
+
+ # Define the subgraphs
+ def add(state):
+ return {"result": state["a"] + state["b"]}
+
+ add_subgraph = (
+ StateGraph(State, output_schema=Output).add_node(add).add_edge(START, "add").compile()
+ )
+
+ def multiply(state):
+ return {"result": state["a"] * state["b"]}
+
+ multiply_subgraph = (
+ StateGraph(State, output_schema=Output)
+ .add_node(multiply)
+ .add_edge(START, "multiply")
+ .compile()
+ )
+
+ multiply_subgraph.invoke({"a": 2, "b": 3})
+
+ # Add subgraphs as tools
+
+ def addition(a: int, b: int):
+ """Add two numbers"""
+ return add_subgraph.invoke({"a": a, "b": b})["result"]
+
+ def multiplication(a: int, b: int):
+ """Multiply two numbers"""
+ return multiply_subgraph.invoke({"a": a, "b": b})["result"]
+
+ model = FakeToolCallingModel(
+ tool_calls=[
+ [
+ {"args": {"a": 2, "b": 3}, "id": "1", "name": "addition"},
+ {"args": {"a": 2, "b": 3}, "id": "2", "name": "multiplication"},
+ ],
+ [],
+ ]
+ )
+ tool_node = ToolNode([addition, multiplication], handle_tool_errors=False)
+ agent = create_react_agent(
+ model,
+ tool_node,
+ checkpointer=sync_checkpointer,
+ )
+ result = agent.invoke(
+ {"messages": [HumanMessage(content="What's 2 + 3 and 2 * 3?")]},
+ config={"configurable": {"thread_id": "1"}},
+ )
+ assert result["messages"] == [
+ _AnyIdHumanMessage(content="What's 2 + 3 and 2 * 3?"),
+ AIMessage(
+ content="What's 2 + 3 and 2 * 3?",
+ id="0",
+ tool_calls=[
+ ToolCall(name="addition", args={"a": 2, "b": 3}, id="1"),
+ ToolCall(name="multiplication", args={"a": 2, "b": 3}, id="2"),
+ ],
+ ),
+ ToolMessage(content="5", name="addition", tool_call_id="1", id=result["messages"][2].id),
+ ToolMessage(
+ content="6",
+ name="multiplication",
+ tool_call_id="2",
+ id=result["messages"][3].id,
+ ),
+ AIMessage(content="What's 2 + 3 and 2 * 3?-What's 2 + 3 and 2 * 3?-5-6", id="1"),
+ ]
+
+
+def test_react_agent_subgraph_streaming_sync() -> None:
+ """Test React agent streaming when used as a subgraph node sync version"""
+
+ @dec_tool
+ def get_weather(city: str) -> str:
+ """Get the weather of a city."""
+ return f"The weather of {city} is sunny."
+
+ # Create a React agent
+ model = FakeToolCallingModel(
+ tool_calls=[
+ [{"args": {"city": "Tokyo"}, "id": "1", "name": "get_weather"}],
+ [],
+ ]
+ )
+
+ agent = create_react_agent(
+ model,
+ tools=[get_weather],
+ prompt="You are a helpful travel assistant.",
+ )
+
+ # Create a subgraph that uses the React agent as a node
+ def react_agent_node(state: MessagesState, config: RunnableConfig) -> MessagesState:
+ """Node that runs the React agent and collects streaming output."""
+ collected_content = ""
+
+ # Stream the agent output and collect content
+ for msg_chunk, _msg_metadata in agent.stream(
+ {"messages": [("user", state["messages"][-1].content)]},
+ config,
+ stream_mode="messages",
+ ):
+ if hasattr(msg_chunk, "content") and msg_chunk.content:
+ collected_content += msg_chunk.content
+
+ return {"messages": [("assistant", collected_content)]}
+
+ # Create the main workflow with the React agent as a subgraph node
+ workflow = StateGraph(MessagesState)
+ workflow.add_node("react_agent", react_agent_node)
+ workflow.add_edge(START, "react_agent")
+ workflow.add_edge("react_agent", "__end__")
+ compiled_workflow = workflow.compile()
+
+ # Test the streaming functionality
+ result = compiled_workflow.invoke({"messages": [("user", "What is the weather in Tokyo?")]})
+
+ # Verify the result contains expected structure
+ assert len(result["messages"]) == 2
+ assert result["messages"][0].content == "What is the weather in Tokyo?"
+ assert "assistant" in str(result["messages"][1])
+
+ # Test streaming with subgraphs = True
+ result = compiled_workflow.invoke(
+ {"messages": [("user", "What is the weather in Tokyo?")]},
+ subgraphs=True,
+ )
+ assert len(result["messages"]) == 2
+
+ events = []
+ for event in compiled_workflow.stream(
+ {"messages": [("user", "What is the weather in Tokyo?")]},
+ stream_mode="messages",
+ subgraphs=False,
+ ):
+ events.append(event)
+
+ assert len(events) == 0
+
+ events = []
+ for event in compiled_workflow.stream(
+ {"messages": [("user", "What is the weather in Tokyo?")]},
+ stream_mode="messages",
+ subgraphs=True,
+ ):
+ events.append(event)
+
+ assert len(events) == 3
+ namespace, (msg, metadata) = events[0]
+ # FakeToolCallingModel returns a single AIMessage with tool calls
+ # The content of the AIMessage reflects the input message
+ assert msg.content.startswith("You are a helpful travel assistant")
+ namespace, (msg, metadata) = events[1] # ToolMessage
+ assert msg.content.startswith("The weather of Tokyo is sunny.")
+
+
+async def test_react_agent_subgraph_streaming() -> None:
+ """Test React agent streaming when used as a subgraph node."""
+
+ @dec_tool
+ def get_weather(city: str) -> str:
+ """Get the weather of a city."""
+ return f"The weather of {city} is sunny."
+
+ # Create a React agent
+ model = FakeToolCallingModel(
+ tool_calls=[
+ [{"args": {"city": "Tokyo"}, "id": "1", "name": "get_weather"}],
+ [],
+ ]
+ )
+
+ agent = create_react_agent(
+ model,
+ tools=[get_weather],
+ prompt="You are a helpful travel assistant.",
+ )
+
+ # Create a subgraph that uses the React agent as a node
+ async def react_agent_node(state: MessagesState, config: RunnableConfig) -> MessagesState:
+ """Node that runs the React agent and collects streaming output."""
+ collected_content = ""
+
+ # Stream the agent output and collect content
+ async for msg_chunk, _msg_metadata in agent.astream(
+ {"messages": [("user", state["messages"][-1].content)]},
+ config,
+ stream_mode="messages",
+ ):
+ if hasattr(msg_chunk, "content") and msg_chunk.content:
+ collected_content += msg_chunk.content
+
+ return {"messages": [("assistant", collected_content)]}
+
+ # Create the main workflow with the React agent as a subgraph node
+ workflow = StateGraph(MessagesState)
+ workflow.add_node("react_agent", react_agent_node)
+ workflow.add_edge(START, "react_agent")
+ workflow.add_edge("react_agent", "__end__")
+ compiled_workflow = workflow.compile()
+
+ # Test the streaming functionality
+ result = await compiled_workflow.ainvoke(
+ {"messages": [("user", "What is the weather in Tokyo?")]}
+ )
+
+ # Verify the result contains expected structure
+ assert len(result["messages"]) == 2
+ assert result["messages"][0].content == "What is the weather in Tokyo?"
+ assert "assistant" in str(result["messages"][1])
+
+ # Test streaming with subgraphs = True
+ result = await compiled_workflow.ainvoke(
+ {"messages": [("user", "What is the weather in Tokyo?")]},
+ subgraphs=True,
+ )
+ assert len(result["messages"]) == 2
+
+ events = []
+ async for event in compiled_workflow.astream(
+ {"messages": [("user", "What is the weather in Tokyo?")]},
+ stream_mode="messages",
+ subgraphs=False,
+ ):
+ events.append(event)
+
+ assert len(events) == 0
+
+ events = []
+ async for event in compiled_workflow.astream(
+ {"messages": [("user", "What is the weather in Tokyo?")]},
+ stream_mode="messages",
+ subgraphs=True,
+ ):
+ events.append(event)
+
+ assert len(events) == 3
+ namespace, (msg, metadata) = events[0]
+ # FakeToolCallingModel returns a single AIMessage with tool calls
+ # The content of the AIMessage reflects the input message
+ assert msg.content.startswith("You are a helpful travel assistant")
+ namespace, (msg, metadata) = events[1] # ToolMessage
+ assert msg.content.startswith("The weather of Tokyo is sunny.")
+
+
+def test_tool_node_node_interrupt(
+ sync_checkpointer: BaseCheckpointSaver,
+) -> None:
+ def tool_normal(some_val: int) -> str:
+ """Tool docstring."""
+ return "normal"
+
+ def tool_interrupt(some_val: int) -> str:
+ """Tool docstring."""
+ return interrupt("provide value for foo")
+
+ # test inside react agent
+ model = FakeToolCallingModel(
+ tool_calls=[
+ [
+ ToolCall(name="tool_interrupt", args={"some_val": 0}, id="1"),
+ ToolCall(name="tool_normal", args={"some_val": 1}, id="2"),
+ ],
+ [],
+ ]
+ )
+ config = {"configurable": {"thread_id": "1"}}
+ agent = create_react_agent(
+ model,
+ [tool_interrupt, tool_normal],
+ checkpointer=sync_checkpointer,
+ )
+ result = agent.invoke({"messages": [HumanMessage("hi?")]}, config)
+ expected_messages = [
+ _AnyIdHumanMessage(content="hi?"),
+ AIMessage(
+ content="hi?",
+ id="0",
+ tool_calls=[
+ {
+ "name": "tool_interrupt",
+ "args": {"some_val": 0},
+ "id": "1",
+ "type": "tool_call",
+ },
+ {
+ "name": "tool_normal",
+ "args": {"some_val": 1},
+ "id": "2",
+ "type": "tool_call",
+ },
+ ],
+ ),
+ _AnyIdToolMessage(content="normal", name="tool_normal", tool_call_id="2"),
+ ]
+ assert result["messages"] == expected_messages
+
+ state = agent.get_state(config)
+ assert state.next == ("tools",)
+ task = state.tasks[0]
+ assert task.name == "tools"
+ assert task.interrupts == (
+ Interrupt(
+ value="provide value for foo",
+ id=AnyStr(),
+ ),
+ )
+
+
+def test_dynamic_model_basic() -> None:
+ """Test basic dynamic model functionality."""
+
+ def dynamic_model(state, runtime: Runtime):
+ # Return different models based on state
+ if "urgent" in state["messages"][-1].content:
+ return FakeToolCallingModel(tool_calls=[])
+ return FakeToolCallingModel(tool_calls=[])
+
+ agent = create_react_agent(dynamic_model, [])
+
+ result = agent.invoke({"messages": [HumanMessage("hello")]})
+ assert len(result["messages"]) == 2
+ assert result["messages"][-1].content == "hello"
+
+ result = agent.invoke({"messages": [HumanMessage("urgent help")]})
+ assert len(result["messages"]) == 2
+ assert result["messages"][-1].content == "urgent help"
+
+
+def test_dynamic_model_with_tools() -> None:
+ """Test dynamic model with tool calling."""
+
+ @dec_tool
+ def basic_tool(x: int) -> str:
+ """Basic tool."""
+ return f"basic: {x}"
+
+ @dec_tool
+ def advanced_tool(x: int) -> str:
+ """Advanced tool."""
+ return f"advanced: {x}"
+
+ def dynamic_model(state: dict, runtime: Runtime) -> BaseChatModel:
+ # Return model with different behaviors based on message content
+ if "advanced" in state["messages"][-1].content:
+ return FakeToolCallingModel(
+ tool_calls=[
+ [{"args": {"x": 1}, "id": "1", "name": "advanced_tool"}],
+ [],
+ ]
+ )
+ return FakeToolCallingModel(
+ tool_calls=[[{"args": {"x": 1}, "id": "1", "name": "basic_tool"}], []]
+ )
+
+ agent = create_react_agent(dynamic_model, [basic_tool, advanced_tool])
+
+ # Test basic tool usage
+ result = agent.invoke({"messages": [HumanMessage("basic request")]})
+ assert len(result["messages"]) == 3
+ tool_message = result["messages"][-1]
+ assert tool_message.content == "basic: 1"
+ assert tool_message.name == "basic_tool"
+
+ # Test advanced tool usage
+ result = agent.invoke({"messages": [HumanMessage("advanced request")]})
+ assert len(result["messages"]) == 3
+ tool_message = result["messages"][-1]
+ assert tool_message.content == "advanced: 1"
+ assert tool_message.name == "advanced_tool"
+
+
+@dataclasses.dataclass
+class Context:
+ user_id: str
+
+
+def test_dynamic_model_with_context() -> None:
+ """Test dynamic model using config parameters."""
+
+ def dynamic_model(state, runtime: Runtime[Context]):
+ # Use context to determine model behavior
+ user_id = runtime.context.user_id
+ if user_id == "user_premium":
+ return FakeToolCallingModel(tool_calls=[])
+ return FakeToolCallingModel(tool_calls=[])
+
+ agent = create_react_agent(dynamic_model, [], context_schema=Context)
+
+ # Test with basic user
+ result = agent.invoke(
+ {"messages": [HumanMessage("hello")]},
+ context=Context(user_id="user_basic"),
+ )
+ assert len(result["messages"]) == 2
+
+ # Test with premium user
+ result = agent.invoke(
+ {"messages": [HumanMessage("hello")]},
+ context=Context(user_id="user_premium"),
+ )
+ assert len(result["messages"]) == 2
+
+
+def test_dynamic_model_with_state_schema() -> None:
+ """Test dynamic model with custom state schema."""
+
+ class CustomDynamicState(AgentState):
+ model_preference: str = "default"
+
+ def dynamic_model(state: CustomDynamicState, runtime: Runtime) -> BaseChatModel:
+ # Use custom state field to determine model
+ if state.get("model_preference") == "advanced":
+ return FakeToolCallingModel(tool_calls=[])
+ return FakeToolCallingModel(tool_calls=[])
+
+ agent = create_react_agent(dynamic_model, [], state_schema=CustomDynamicState)
+
+ result = agent.invoke({"messages": [HumanMessage("hello")], "model_preference": "advanced"})
+ assert len(result["messages"]) == 2
+ assert result["model_preference"] == "advanced"
+
+
+def test_dynamic_model_with_prompt() -> None:
+ """Test dynamic model with different prompt types."""
+
+ def dynamic_model(state: AgentState, runtime: Runtime) -> BaseChatModel:
+ return FakeToolCallingModel(tool_calls=[])
+
+ # Test with string prompt
+ agent = create_react_agent(dynamic_model, [], prompt="system_msg")
+ result = agent.invoke({"messages": [HumanMessage("human_msg")]})
+ assert result["messages"][-1].content == "system_msg-human_msg"
+
+ # Test with callable prompt
+ def dynamic_prompt(state: AgentState) -> list[MessageLikeRepresentation]:
+ """Generate a dynamic system message based on state."""
+ return [{"role": "system", "content": "system_msg"}, *list(state["messages"])]
+
+ agent = create_react_agent(dynamic_model, [], prompt=dynamic_prompt)
+ result = agent.invoke({"messages": [HumanMessage("human_msg")]})
+ assert result["messages"][-1].content == "system_msg-human_msg"
+
+
+async def test_dynamic_model_async() -> None:
+ """Test dynamic model with async operations."""
+
+ def dynamic_model(state: AgentState, runtime: Runtime) -> BaseChatModel:
+ return FakeToolCallingModel(tool_calls=[])
+
+ agent = create_react_agent(dynamic_model, [])
+
+ result = await agent.ainvoke({"messages": [HumanMessage("hello async")]})
+ assert len(result["messages"]) == 2
+ assert result["messages"][-1].content == "hello async"
+
+
+def test_dynamic_model_with_structured_response() -> None:
+ """Test dynamic model with structured response format."""
+
+ class TestResponse(BaseModel):
+ message: str
+ confidence: float
+
+ def dynamic_model(state, runtime: Runtime):
+ return FakeToolCallingModel(
+ tool_calls=[
+ [
+ ToolCall(
+ name="TestResponse",
+ args={"message": "dynamic response", "confidence": 0.9},
+ id="1",
+ type="tool_call",
+ )
+ ]
+ ],
+ )
+
+ agent = create_react_agent(dynamic_model, [], response_format=TestResponse)
+
+ result = agent.invoke({"messages": [HumanMessage("hello")]})
+ assert "structured_response" in result
+ assert result["structured_response"].message == "dynamic response"
+ assert result["structured_response"].confidence == 0.9
+
+
+def test_dynamic_model_with_checkpointer(sync_checkpointer) -> None:
+ """Test dynamic model with checkpointer."""
+ call_count = 0
+
+ def dynamic_model(state: AgentState, runtime: Runtime) -> BaseChatModel:
+ nonlocal call_count
+ call_count += 1
+ return FakeToolCallingModel(
+ tool_calls=[],
+ # Incrementing the call count as it is used to assign an id
+ # to the AIMessage.
+ # The default reducer semantics are to overwrite an existing message
+ # with the new one if the id matches.
+ index=call_count,
+ )
+
+ agent = create_react_agent(dynamic_model, [], checkpointer=sync_checkpointer)
+ config = {"configurable": {"thread_id": "test_dynamic"}}
+
+ # First call
+ result1 = agent.invoke({"messages": [HumanMessage("hello")]}, config)
+ assert len(result1["messages"]) == 2 # Human + AI message
+
+ # Second call - should load from checkpoint
+ result2 = agent.invoke({"messages": [HumanMessage("world")]}, config)
+ assert len(result2["messages"]) == 4
+
+ # Dynamic model should be called each time
+ assert call_count >= 2
+
+
+def test_dynamic_model_state_dependent_tools() -> None:
+ """Test dynamic model that changes available tools based on state."""
+
+ @dec_tool
+ def tool_a(x: int) -> str:
+ """Tool A."""
+ return f"A: {x}"
+
+ @dec_tool
+ def tool_b(x: int) -> str:
+ """Tool B."""
+ return f"B: {x}"
+
+ def dynamic_model(state, runtime: Runtime):
+ # Switch tools based on message history
+ if any("use_b" in msg.content for msg in state["messages"]):
+ return FakeToolCallingModel(
+ tool_calls=[[{"args": {"x": 2}, "id": "1", "name": "tool_b"}], []]
+ )
+ return FakeToolCallingModel(
+ tool_calls=[[{"args": {"x": 1}, "id": "1", "name": "tool_a"}], []]
+ )
+
+ agent = create_react_agent(dynamic_model, [tool_a, tool_b])
+
+ # Ask to use tool B
+ result = agent.invoke({"messages": [HumanMessage("use_b please")]})
+ last_message = result["messages"][-1]
+ assert isinstance(last_message, ToolMessage)
+ assert last_message.content == "B: 2"
+
+ # Ask to use tool A
+ result = agent.invoke({"messages": [HumanMessage("hello")]})
+ last_message = result["messages"][-1]
+ assert isinstance(last_message, ToolMessage)
+ assert last_message.content == "A: 1"
+
+
+def test_dynamic_model_error_handling() -> None:
+ """Test error handling in dynamic model."""
+
+ def failing_dynamic_model(state, runtime: Runtime):
+ if "fail" in state["messages"][-1].content:
+ msg = "Dynamic model failed"
+ raise ValueError(msg)
+ return FakeToolCallingModel(tool_calls=[])
+
+ agent = create_react_agent(failing_dynamic_model, [])
+
+ # Normal operation should work
+ result = agent.invoke({"messages": [HumanMessage("hello")]})
+ assert len(result["messages"]) == 2
+
+ # Should propagate the error
+ with pytest.raises(ValueError, match="Dynamic model failed"):
+ agent.invoke({"messages": [HumanMessage("fail now")]})
+
+
+def test_dynamic_model_vs_static_model_behavior() -> None:
+ """Test that dynamic and static models produce equivalent results when configured the same."""
+ # Static model
+ static_model = FakeToolCallingModel(tool_calls=[])
+ static_agent = create_react_agent(static_model, [])
+
+ # Dynamic model returning the same model
+ def dynamic_model(state, runtime: Runtime):
+ return FakeToolCallingModel(tool_calls=[])
+
+ dynamic_agent = create_react_agent(dynamic_model, [])
+
+ input_msg = {"messages": [HumanMessage("test message")]}
+
+ static_result = static_agent.invoke(input_msg)
+ dynamic_result = dynamic_agent.invoke(input_msg)
+
+ # Results should be equivalent (content-wise, IDs may differ)
+ assert len(static_result["messages"]) == len(dynamic_result["messages"])
+ assert static_result["messages"][0].content == dynamic_result["messages"][0].content
+ assert static_result["messages"][1].content == dynamic_result["messages"][1].content
+
+
+def test_dynamic_model_receives_correct_state() -> None:
+ """Test that the dynamic model function receives the correct state, not the model input."""
+ received_states = []
+
+ class CustomAgentState(AgentState):
+ custom_field: str
+
+ def dynamic_model(state, runtime: Runtime) -> BaseChatModel:
+ # Capture the state that's passed to the dynamic model function
+ received_states.append(state)
+ return FakeToolCallingModel(tool_calls=[])
+
+ agent = create_react_agent(dynamic_model, [], state_schema=CustomAgentState)
+
+ # Test with initial state
+ input_state = {"messages": [HumanMessage("hello")], "custom_field": "test_value"}
+ agent.invoke(input_state)
+
+ # The dynamic model function should receive the original state, not the processed model input
+ assert len(received_states) == 1
+ received_state = received_states[0]
+
+ # Should have the custom field from original state
+ assert "custom_field" in received_state
+ assert received_state["custom_field"] == "test_value"
+
+ # Should have the original messages
+ assert len(received_state["messages"]) == 1
+ assert received_state["messages"][0].content == "hello"
+
+
+async def test_dynamic_model_receives_correct_state_async() -> None:
+ """Test that the async dynamic model function receives the correct state, not the model input."""
+ received_states = []
+
+ class CustomAgentStateAsync(AgentState):
+ custom_field: str
+
+ def dynamic_model(state, runtime: Runtime):
+ # Capture the state that's passed to the dynamic model function
+ received_states.append(state)
+ return FakeToolCallingModel(tool_calls=[])
+
+ agent = create_react_agent(dynamic_model, [], state_schema=CustomAgentStateAsync)
+
+ # Test with initial state
+ input_state = {
+ "messages": [HumanMessage("hello async")],
+ "custom_field": "test_value_async",
+ }
+ await agent.ainvoke(input_state)
+
+ # The dynamic model function should receive the original state, not the processed model input
+ assert len(received_states) == 1
+ received_state = received_states[0]
+
+ # Should have the custom field from original state
+ assert "custom_field" in received_state
+ assert received_state["custom_field"] == "test_value_async"
+
+ # Should have the original messages
+ assert len(received_state["messages"]) == 1
+ assert received_state["messages"][0].content == "hello async"
+
+
+def test_pre_model_hook() -> None:
+ model = FakeToolCallingModel(tool_calls=[])
+
+ # Test `llm_input_messages`
+ def pre_model_hook(state: AgentState):
+ return {"llm_input_messages": [HumanMessage("Hello!")]}
+
+ agent = create_react_agent(model, [], pre_model_hook=pre_model_hook)
+ assert "pre_model_hook" in agent.nodes
+ result = agent.invoke({"messages": [HumanMessage("hi?")]})
+ assert result == {
+ "messages": [
+ _AnyIdHumanMessage(content="hi?"),
+ AIMessage(content="Hello!", id="0"),
+ ]
+ }
+
+ # Test `messages`
+ def pre_model_hook(state: AgentState):
+ return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), HumanMessage("Hello!")]}
+
+ agent = create_react_agent(model, [], pre_model_hook=pre_model_hook)
+ result = agent.invoke({"messages": [HumanMessage("hi?")]})
+ assert result == {
+ "messages": [
+ _AnyIdHumanMessage(content="Hello!"),
+ AIMessage(content="Hello!", id="1"),
+ ]
+ }
+
+
+def test_post_model_hook() -> None:
+ class FlagState(AgentState):
+ flag: bool
+
+ model = FakeToolCallingModel(tool_calls=[])
+
+ def post_model_hook(state: FlagState) -> dict[str, bool]:
+ return {"flag": True}
+
+ pmh_agent = create_react_agent(
+ model, [], post_model_hook=post_model_hook, state_schema=FlagState
+ )
+
+ assert "post_model_hook" in pmh_agent.nodes
+
+ result = pmh_agent.invoke({"messages": [HumanMessage("hi?")], "flag": False})
+ assert result["flag"] is True
+
+ events = list(pmh_agent.stream({"messages": [HumanMessage("hi?")], "flag": False}))
+ assert events == [
+ {
+ "agent": {
+ "messages": [
+ AIMessage(
+ content="hi?",
+ additional_kwargs={},
+ response_metadata={},
+ id="1",
+ )
+ ]
+ }
+ },
+ {"post_model_hook": {"flag": True}},
+ ]
+
+
+def test_post_model_hook_with_structured_output() -> None:
+ class WeatherResponse(BaseModel):
+ temperature: float = Field(description="The temperature in fahrenheit")
+
+ tool_calls: list[list[ToolCall]] = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ [{"args": {"temperature": 75}, "id": "2", "name": "WeatherResponse"}],
+ ]
+
+ def get_weather() -> str:
+ """Get the weather"""
+ return "The weather is sunny and 75°F."
+
+ expected_structured_response = WeatherResponse(temperature=75)
+
+ class State(AgentState):
+ flag: bool
+ structured_response: WeatherResponse
+
+ def post_model_hook(state: State) -> Union[dict[str, bool], Command]:
+ return {"flag": True}
+
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+ agent = create_react_agent(
+ model,
+ [get_weather],
+ response_format=WeatherResponse,
+ post_model_hook=post_model_hook,
+ state_schema=State,
+ )
+
+ assert "post_model_hook" in agent.nodes
+
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")], "flag": False})
+ assert response["flag"] is True
+ assert response["structured_response"] == expected_structured_response
+
+ # Reset the state of the model
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+ agent = create_react_agent(
+ model,
+ [get_weather],
+ response_format=WeatherResponse,
+ post_model_hook=post_model_hook,
+ state_schema=State,
+ )
+
+ events = list(agent.stream({"messages": [HumanMessage("What's the weather?")], "flag": False}))
+ assert events == [
+ {
+ "agent": {
+ "messages": [
+ AIMessage(
+ content="What's the weather?",
+ additional_kwargs={},
+ response_metadata={},
+ id="0",
+ tool_calls=[
+ {
+ "name": "get_weather",
+ "args": {},
+ "id": "1",
+ "type": "tool_call",
+ }
+ ],
+ )
+ ]
+ }
+ },
+ {"post_model_hook": {"flag": True}},
+ {
+ "tools": {
+ "messages": [
+ _AnyIdToolMessage(
+ content="The weather is sunny and 75°F.",
+ name="get_weather",
+ tool_call_id="1",
+ )
+ ]
+ }
+ },
+ {
+ "agent": {
+ "messages": [
+ AIMessage(
+ content="What's the weather?-What's the weather?-The weather is sunny and 75°F.",
+ additional_kwargs={},
+ response_metadata={},
+ id="1",
+ tool_calls=[
+ {
+ "name": "WeatherResponse",
+ "args": {"temperature": 75},
+ "id": "2",
+ "type": "tool_call",
+ }
+ ],
+ ),
+ _AnyIdToolMessage(
+ content="Returning structured response: {'temperature': 75.0}",
+ name="WeatherResponse",
+ tool_call_id="2",
+ ),
+ ],
+ "structured_response": WeatherResponse(temperature=75.0),
+ }
+ },
+ {"post_model_hook": {"flag": True}},
+ ]
+
+
+def test_create_react_agent_inject_vars_with_post_model_hook() -> None:
+ store = InMemoryStore()
+ namespace = ("test",)
+ store.put(namespace, "test_key", {"bar": 3})
+
+ def tool1(
+ some_val: int,
+ state: Annotated[dict, InjectedState],
+ store: Annotated[BaseStore, InjectedStore()],
+ ) -> str:
+ """Tool 1 docstring."""
+ store_val = store.get(namespace, "test_key").value["bar"]
+ return some_val + state["foo"] + store_val
+
+ tool_call = {
+ "name": "tool1",
+ "args": {"some_val": 1},
+ "id": "some 0",
+ "type": "tool_call",
+ }
+
+ def post_model_hook(state: dict) -> dict:
+ """Post model hook is injecting a new foo key."""
+ return {"foo": 2}
+
+ model = FakeToolCallingModel(tool_calls=[[tool_call], []])
+ agent = create_react_agent(
+ model,
+ ToolNode([tool1], handle_tool_errors=False),
+ state_schema=AgentStateExtraKey,
+ store=store,
+ post_model_hook=post_model_hook,
+ )
+ input_message = HumanMessage("hi")
+ result = agent.invoke({"messages": [input_message], "foo": 2})
+ assert result["messages"] == [
+ input_message,
+ AIMessage(content="hi", tool_calls=[tool_call], id="0"),
+ _AnyIdToolMessage(content="6", name="tool1", tool_call_id="some 0"),
+ AIMessage("hi-hi-6", id="1"),
+ ]
+ assert result["foo"] == 2
+
+
+def test_response_format_using_tool_choice() -> None:
+ """Test response format using tool choice."""
+
+ class WeatherResponse(BaseModel):
+ temperature: float = Field(description="The temperature in fahrenheit")
+
+ tool_calls: list[list[ToolCall]] = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ [{"args": {"temperature": "75"}, "id": "2", "name": "WeatherResponse"}],
+ ]
+
+ def get_weather() -> str:
+ """Get the weather"""
+ return "The weather is sunny and 75°F."
+
+ expected_structured_response = WeatherResponse(temperature=75)
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+ agent = create_react_agent(
+ model,
+ [get_weather],
+ response_format=WeatherResponse,
+ )
+ response = agent.invoke(
+ {
+ "messages": [
+ {
+ "role": "user",
+ "content": "What's the weather?",
+ }
+ ]
+ }
+ )
+ assert response.get("structured_response") == expected_structured_response
diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_react_agent_graph.py b/libs/langchain_v1/tests/unit_tests/agents/test_react_agent_graph.py
new file mode 100644
index 00000000000..b978306c652
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/test_react_agent_graph.py
@@ -0,0 +1,58 @@
+from collections.abc import Callable
+from typing import Union
+
+import pytest
+from pydantic import BaseModel
+from syrupy.assertion import SnapshotAssertion
+
+from langchain.agents import create_react_agent
+
+from .model import FakeToolCallingModel
+
+model = FakeToolCallingModel()
+
+
+def tool() -> None:
+ """Testing tool."""
+
+
+def pre_model_hook() -> None:
+ """Pre-model hook."""
+
+
+def post_model_hook() -> None:
+ """Post-model hook."""
+
+
+class ResponseFormat(BaseModel):
+ """Response format for the agent."""
+
+ result: str
+
+
+@pytest.mark.parametrize("tools", [[], [tool]])
+@pytest.mark.parametrize("pre_model_hook", [None, pre_model_hook])
+@pytest.mark.parametrize("post_model_hook", [None, post_model_hook])
+def test_react_agent_graph_structure(
+ snapshot: SnapshotAssertion,
+ tools: list[Callable],
+ pre_model_hook: Union[Callable, None],
+ post_model_hook: Union[Callable, None],
+) -> None:
+ agent = create_react_agent(
+ model,
+ tools=tools,
+ pre_model_hook=pre_model_hook,
+ post_model_hook=post_model_hook,
+ )
+ try:
+ assert agent.get_graph().draw_mermaid(with_styles=False) == snapshot
+ except Exception as e:
+ msg = (
+ "The graph structure has changed. Please update the snapshot."
+ "Configuration used:\n"
+ f"tools: {tools}, "
+ f"pre_model_hook: {pre_model_hook}, "
+ f"post_model_hook: {post_model_hook}, "
+ )
+ raise ValueError(msg) from e
diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py b/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py
new file mode 100644
index 00000000000..06cf8572b2c
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py
@@ -0,0 +1,710 @@
+"""Test suite for create_react_agent with structured output response_format permutations."""
+
+import pytest
+
+from dataclasses import dataclass
+from typing import Union
+
+from langchain_core.messages import HumanMessage
+from langchain.agents import create_react_agent
+from langchain.agents.structured_output import (
+ MultipleStructuredOutputsError,
+ ProviderStrategy,
+ StructuredOutputValidationError,
+ ToolStrategy,
+)
+from pydantic import BaseModel, Field
+from typing_extensions import TypedDict
+
+from tests.unit_tests.agents.model import FakeToolCallingModel
+
+
+# Test data models
+class WeatherBaseModel(BaseModel):
+ """Weather response."""
+
+ temperature: float = Field(description="The temperature in fahrenheit")
+ condition: str = Field(description="Weather condition")
+
+
+@dataclass
+class WeatherDataclass:
+ """Weather response."""
+
+ temperature: float
+ condition: str
+
+
+class WeatherTypedDict(TypedDict):
+ """Weather response."""
+
+ temperature: float
+ condition: str
+
+
+weather_json_schema = {
+ "type": "object",
+ "properties": {
+ "temperature": {"type": "number", "description": "Temperature in fahrenheit"},
+ "condition": {"type": "string", "description": "Weather condition"},
+ },
+ "title": "weather_schema",
+ "required": ["temperature", "condition"],
+}
+
+
+class LocationResponse(BaseModel):
+ city: str = Field(description="The city name")
+ country: str = Field(description="The country name")
+
+
+class LocationTypedDict(TypedDict):
+ city: str
+ country: str
+
+
+location_json_schema = {
+ "type": "object",
+ "properties": {
+ "city": {"type": "string", "description": "The city name"},
+ "country": {"type": "string", "description": "The country name"},
+ },
+ "title": "location_schema",
+ "required": ["city", "country"],
+}
+
+
+def get_weather() -> str:
+ """Get the weather."""
+
+ return "The weather is sunny and 75°F."
+
+
+def get_location() -> str:
+ """Get the current location."""
+
+ return "You are in New York, USA."
+
+
+# Standardized test data
+WEATHER_DATA = {"temperature": 75.0, "condition": "sunny"}
+LOCATION_DATA = {"city": "New York", "country": "USA"}
+
+# Standardized expected responses
+EXPECTED_WEATHER_PYDANTIC = WeatherBaseModel(**WEATHER_DATA)
+EXPECTED_WEATHER_DATACLASS = WeatherDataclass(**WEATHER_DATA)
+EXPECTED_WEATHER_DICT: WeatherTypedDict = {"temperature": 75.0, "condition": "sunny"}
+EXPECTED_LOCATION = LocationResponse(**LOCATION_DATA)
+EXPECTED_LOCATION_DICT: LocationTypedDict = {"city": "New York", "country": "USA"}
+
+
+class TestResponseFormatAsModel:
+ def test_pydantic_model(self) -> None:
+ """Test response_format as Pydantic model."""
+ tool_calls = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ [
+ {
+ "name": "WeatherBaseModel",
+ "id": "2",
+ "args": WEATHER_DATA,
+ }
+ ],
+ ]
+
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+
+ agent = create_react_agent(model, [get_weather], response_format=WeatherBaseModel)
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
+ assert len(response["messages"]) == 5
+
+ def test_dataclass(self) -> None:
+ """Test response_format as dataclass."""
+ tool_calls = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ [
+ {
+ "name": "WeatherDataclass",
+ "id": "2",
+ "args": WEATHER_DATA,
+ }
+ ],
+ ]
+
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+
+ agent = create_react_agent(model, [get_weather], response_format=WeatherDataclass)
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ assert response["structured_response"] == EXPECTED_WEATHER_DATACLASS
+ assert len(response["messages"]) == 5
+
+ def test_typed_dict(self) -> None:
+ """Test response_format as TypedDict."""
+ tool_calls = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ [
+ {
+ "name": "WeatherTypedDict",
+ "id": "2",
+ "args": WEATHER_DATA,
+ }
+ ],
+ ]
+
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+
+ agent = create_react_agent(model, [get_weather], response_format=WeatherTypedDict)
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ assert response["structured_response"] == EXPECTED_WEATHER_DICT
+ assert len(response["messages"]) == 5
+
+ def test_json_schema(self) -> None:
+ """Test response_format as JSON schema."""
+ tool_calls = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ [
+ {
+ "name": "weather_schema",
+ "id": "2",
+ "args": WEATHER_DATA,
+ }
+ ],
+ ]
+
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+
+ agent = create_react_agent(model, [get_weather], response_format=weather_json_schema)
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ assert response["structured_response"] == EXPECTED_WEATHER_DICT
+ assert len(response["messages"]) == 5
+
+
+class TestResponseFormatAsToolStrategy:
+ def test_pydantic_model(self) -> None:
+ """Test response_format as ToolStrategy with Pydantic model."""
+ tool_calls = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ [
+ {
+ "name": "WeatherBaseModel",
+ "id": "2",
+ "args": WEATHER_DATA,
+ }
+ ],
+ ]
+
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+
+ agent = create_react_agent(
+ model, [get_weather], response_format=ToolStrategy(WeatherBaseModel)
+ )
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
+ assert len(response["messages"]) == 5
+
+ def test_dataclass(self) -> None:
+ """Test response_format as ToolStrategy with dataclass."""
+ tool_calls = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ [
+ {
+ "name": "WeatherDataclass",
+ "id": "2",
+ "args": WEATHER_DATA,
+ }
+ ],
+ ]
+
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+
+ agent = create_react_agent(
+ model, [get_weather], response_format=ToolStrategy(WeatherDataclass)
+ )
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ assert response["structured_response"] == EXPECTED_WEATHER_DATACLASS
+ assert len(response["messages"]) == 5
+
+ def test_typed_dict(self) -> None:
+ """Test response_format as ToolStrategy with TypedDict."""
+ tool_calls = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ [
+ {
+ "name": "WeatherTypedDict",
+ "id": "2",
+ "args": WEATHER_DATA,
+ }
+ ],
+ ]
+
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+
+ agent = create_react_agent(
+ model, [get_weather], response_format=ToolStrategy(WeatherTypedDict)
+ )
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ assert response["structured_response"] == EXPECTED_WEATHER_DICT
+ assert len(response["messages"]) == 5
+
+ def test_json_schema(self) -> None:
+ """Test response_format as ToolStrategy with JSON schema."""
+ tool_calls = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ [
+ {
+ "name": "weather_schema",
+ "id": "2",
+ "args": WEATHER_DATA,
+ }
+ ],
+ ]
+
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+
+ agent = create_react_agent(
+ model, [get_weather], response_format=ToolStrategy(weather_json_schema)
+ )
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ assert response["structured_response"] == EXPECTED_WEATHER_DICT
+ assert len(response["messages"]) == 5
+
+ def test_union_of_json_schemas(self) -> None:
+ """Test response_format as ToolStrategy with union of JSON schemas."""
+ tool_calls = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ [
+ {
+ "name": "weather_schema",
+ "id": "2",
+ "args": WEATHER_DATA,
+ }
+ ],
+ ]
+
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+
+ agent = create_react_agent(
+ model,
+ [get_weather, get_location],
+ response_format=ToolStrategy({"oneOf": [weather_json_schema, location_json_schema]}),
+ )
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ assert response["structured_response"] == EXPECTED_WEATHER_DICT
+ assert len(response["messages"]) == 5
+
+ # Test with LocationResponse
+ tool_calls_location = [
+ [{"args": {}, "id": "1", "name": "get_location"}],
+ [
+ {
+ "name": "location_schema",
+ "id": "2",
+ "args": LOCATION_DATA,
+ }
+ ],
+ ]
+
+ model_location = FakeToolCallingModel(tool_calls=tool_calls_location)
+
+ agent_location = create_react_agent(
+ model_location,
+ [get_weather, get_location],
+ response_format=ToolStrategy({"oneOf": [weather_json_schema, location_json_schema]}),
+ )
+ response_location = agent_location.invoke({"messages": [HumanMessage("Where am I?")]})
+
+ assert response_location["structured_response"] == EXPECTED_LOCATION_DICT
+ assert len(response_location["messages"]) == 5
+
+ def test_union_of_types(self) -> None:
+ """Test response_format as ToolStrategy with Union of various types."""
+ # Test with WeatherBaseModel
+ tool_calls = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ [
+ {
+ "name": "WeatherBaseModel",
+ "id": "2",
+ "args": WEATHER_DATA,
+ }
+ ],
+ ]
+
+ model = FakeToolCallingModel[Union[WeatherBaseModel, LocationResponse]](
+ tool_calls=tool_calls
+ )
+
+ agent = create_react_agent(
+ model,
+ [get_weather, get_location],
+ response_format=ToolStrategy(Union[WeatherBaseModel, LocationResponse]),
+ )
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
+ assert len(response["messages"]) == 5
+
+ # Test with LocationResponse
+ tool_calls_location = [
+ [{"args": {}, "id": "1", "name": "get_location"}],
+ [
+ {
+ "name": "LocationResponse",
+ "id": "2",
+ "args": LOCATION_DATA,
+ }
+ ],
+ ]
+
+ model_location = FakeToolCallingModel(tool_calls=tool_calls_location)
+
+ agent_location = create_react_agent(
+ model_location,
+ [get_weather, get_location],
+ response_format=ToolStrategy(Union[WeatherBaseModel, LocationResponse]),
+ )
+ response_location = agent_location.invoke({"messages": [HumanMessage("Where am I?")]})
+
+ assert response_location["structured_response"] == EXPECTED_LOCATION
+ assert len(response_location["messages"]) == 5
+
+ def test_multiple_structured_outputs_error_without_retry(self) -> None:
+ """Test that MultipleStructuredOutputsError is raised when model returns multiple structured tool calls without retry."""
+ tool_calls = [
+ [
+ {
+ "name": "WeatherBaseModel",
+ "id": "1",
+ "args": WEATHER_DATA,
+ },
+ {
+ "name": "LocationResponse",
+ "id": "2",
+ "args": LOCATION_DATA,
+ },
+ ],
+ ]
+
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+
+ agent = create_react_agent(
+ model,
+ [],
+ response_format=ToolStrategy(
+ Union[WeatherBaseModel, LocationResponse],
+ handle_errors=False,
+ ),
+ )
+
+ with pytest.raises(
+ MultipleStructuredOutputsError,
+ match=".*WeatherBaseModel.*LocationResponse.*",
+ ):
+ agent.invoke({"messages": [HumanMessage("Give me weather and location")]})
+
+ def test_multiple_structured_outputs_with_retry(self) -> None:
+ """Test that retry handles multiple structured output tool calls."""
+ tool_calls = [
+ [
+ {
+ "name": "WeatherBaseModel",
+ "id": "1",
+ "args": WEATHER_DATA,
+ },
+ {
+ "name": "LocationResponse",
+ "id": "2",
+ "args": LOCATION_DATA,
+ },
+ ],
+ [
+ {
+ "name": "WeatherBaseModel",
+ "id": "3",
+ "args": WEATHER_DATA,
+ },
+ ],
+ ]
+
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+
+ agent = create_react_agent(
+ model,
+ [],
+ response_format=ToolStrategy(
+ Union[WeatherBaseModel, LocationResponse],
+ handle_errors=True,
+ ),
+ )
+
+ response = agent.invoke({"messages": [HumanMessage("Give me weather")]})
+
+ # HumanMessage, AIMessage, ToolMessage, ToolMessage, AI, ToolMessage
+ assert len(response["messages"]) == 6
+ assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
+
+ def test_structured_output_parsing_error_without_retry(self) -> None:
+ """Test that StructuredOutputParsingError is raised when tool args fail to parse without retry."""
+ tool_calls = [
+ [
+ {
+ "name": "WeatherBaseModel",
+ "id": "1",
+ "args": {"invalid": "data"},
+ },
+ ],
+ ]
+
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+
+ agent = create_react_agent(
+ model,
+ [],
+ response_format=ToolStrategy(
+ WeatherBaseModel,
+ handle_errors=False,
+ ),
+ )
+
+ with pytest.raises(
+ StructuredOutputValidationError,
+ match=".*WeatherBaseModel.*",
+ ):
+ agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ def test_structured_output_parsing_error_with_retry(self) -> None:
+ """Test that retry handles parsing errors for structured output."""
+ tool_calls = [
+ [
+ {
+ "name": "WeatherBaseModel",
+ "id": "1",
+ "args": {"invalid": "data"},
+ },
+ ],
+ [
+ {
+ "name": "WeatherBaseModel",
+ "id": "2",
+ "args": WEATHER_DATA,
+ },
+ ],
+ ]
+
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+
+ agent = create_react_agent(
+ model,
+ [],
+ response_format=ToolStrategy(
+ WeatherBaseModel,
+ handle_errors=(StructuredOutputValidationError,),
+ ),
+ )
+
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ # HumanMessage, AIMessage, ToolMessage, AIMessage, ToolMessage
+ assert len(response["messages"]) == 5
+ assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
+
+ def test_retry_with_custom_function(self) -> None:
+ """Test retry with custom message generation."""
+ tool_calls = [
+ [
+ {
+ "name": "WeatherBaseModel",
+ "id": "1",
+ "args": WEATHER_DATA,
+ },
+ {
+ "name": "LocationResponse",
+ "id": "2",
+ "args": LOCATION_DATA,
+ },
+ ],
+ [
+ {
+ "name": "WeatherBaseModel",
+ "id": "3",
+ "args": WEATHER_DATA,
+ },
+ ],
+ ]
+
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+
+ def custom_message(exception: Exception) -> str:
+ if isinstance(exception, MultipleStructuredOutputsError):
+ return "Custom error: Multiple outputs not allowed"
+ return "Custom error"
+
+ agent = create_react_agent(
+ model,
+ [],
+ response_format=ToolStrategy(
+ Union[WeatherBaseModel, LocationResponse],
+ handle_errors=custom_message,
+ ),
+ )
+
+ response = agent.invoke({"messages": [HumanMessage("Give me weather")]})
+
+ # HumanMessage, AIMessage, ToolMessage, ToolMessage, AI, ToolMessage
+ assert len(response["messages"]) == 6
+ assert response["messages"][2].content == "Custom error: Multiple outputs not allowed"
+ assert response["messages"][3].content == "Custom error: Multiple outputs not allowed"
+ assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
+
+ def test_retry_with_custom_string_message(self) -> None:
+ """Test retry with custom static string message."""
+ tool_calls = [
+ [
+ {
+ "name": "WeatherBaseModel",
+ "id": "1",
+ "args": {"invalid": "data"},
+ },
+ ],
+ [
+ {
+ "name": "WeatherBaseModel",
+ "id": "2",
+ "args": WEATHER_DATA,
+ },
+ ],
+ ]
+
+ model = FakeToolCallingModel(tool_calls=tool_calls)
+
+ agent = create_react_agent(
+ model,
+ [],
+ response_format=ToolStrategy(
+ WeatherBaseModel,
+ handle_errors="Please provide valid weather data with temperature and condition.",
+ ),
+ )
+
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ assert len(response["messages"]) == 5
+ assert (
+ response["messages"][2].content
+ == "Please provide valid weather data with temperature and condition."
+ )
+ assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
+
+
+class TestResponseFormatAsProviderStrategy:
+ def test_pydantic_model(self) -> None:
+ """Test response_format as ProviderStrategy with Pydantic model."""
+ tool_calls = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ ]
+
+ model = FakeToolCallingModel[WeatherBaseModel](
+ tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC
+ )
+
+ agent = create_react_agent(
+ model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel)
+ )
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
+ assert len(response["messages"]) == 4
+
+ def test_dataclass(self) -> None:
+ """Test response_format as ProviderStrategy with dataclass."""
+ tool_calls = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ ]
+
+ model = FakeToolCallingModel[WeatherDataclass](
+ tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DATACLASS
+ )
+
+ agent = create_react_agent(
+ model, [get_weather], response_format=ProviderStrategy(WeatherDataclass)
+ )
+ response = agent.invoke(
+ {"messages": [HumanMessage("What's the weather?")]},
+ )
+
+ assert response["structured_response"] == EXPECTED_WEATHER_DATACLASS
+ assert len(response["messages"]) == 4
+
+ def test_typed_dict(self) -> None:
+ """Test response_format as ProviderStrategy with TypedDict."""
+ tool_calls = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ ]
+
+ model = FakeToolCallingModel[WeatherTypedDict](
+ tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
+ )
+
+ agent = create_react_agent(
+ model, [get_weather], response_format=ProviderStrategy(WeatherTypedDict)
+ )
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ assert response["structured_response"] == EXPECTED_WEATHER_DICT
+ assert len(response["messages"]) == 4
+
+ def test_json_schema(self) -> None:
+ """Test response_format as ProviderStrategy with JSON schema."""
+ tool_calls = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ ]
+
+ model = FakeToolCallingModel[dict](
+ tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
+ )
+
+ agent = create_react_agent(
+ model, [get_weather], response_format=ProviderStrategy(weather_json_schema)
+ )
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ assert response["structured_response"] == EXPECTED_WEATHER_DICT
+ assert len(response["messages"]) == 4
+
+
+def test_union_of_types() -> None:
+ """Test response_format as ProviderStrategy with Union (if supported)."""
+ tool_calls = [
+ [{"args": {}, "id": "1", "name": "get_weather"}],
+ [
+ {
+ "name": "WeatherBaseModel",
+ "id": "2",
+ "args": WEATHER_DATA,
+ }
+ ],
+ ]
+
+ model = FakeToolCallingModel[Union[WeatherBaseModel, LocationResponse]](
+ tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC
+ )
+
+ agent = create_react_agent(
+ model,
+ [get_weather, get_location],
+ response_format=ToolStrategy(Union[WeatherBaseModel, LocationResponse]),
+ )
+ response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+ assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
+ assert len(response["messages"]) == 5
diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_responses.py b/libs/langchain_v1/tests/unit_tests/agents/test_responses.py
new file mode 100644
index 00000000000..6f21e4b21c5
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/test_responses.py
@@ -0,0 +1,140 @@
+"""Unit tests for langgraph.prebuilt.responses module."""
+
+import pytest
+
+# Skip this test since langgraph.prebuilt.responses is not available
+pytest.skip("langgraph.prebuilt.responses not available", allow_module_level=True)
+
+
+class _TestModel(BaseModel):
+ """A test model for structured output."""
+
+ name: str
+ age: int
+ email: str = "default@example.com"
+
+
+class CustomModel(BaseModel):
+ """Custom model with a custom docstring."""
+
+ value: float
+ description: str
+
+
+class EmptyDocModel(BaseModel):
+ # No custom docstring, should have no description in tool
+ data: str
+
+
+class TestUsingToolStrategy:
+ """Test UsingToolStrategy dataclass."""
+
+ def test_basic_creation(self) -> None:
+ """Test basic UsingToolStrategy creation."""
+ strategy = ToolStrategy(schema=_TestModel)
+ assert strategy.schema == _TestModel
+ assert strategy.tool_message_content is None
+ assert len(strategy.schema_specs) == 1
+
+ def test_multiple_schemas(self) -> None:
+ """Test UsingToolStrategy with multiple schemas."""
+ strategy = ToolStrategy(schema=Union[_TestModel, CustomModel])
+ assert len(strategy.schema_specs) == 2
+ assert strategy.schema_specs[0].schema == _TestModel
+ assert strategy.schema_specs[1].schema == CustomModel
+
+ def test_schema_with_tool_message_content(self) -> None:
+ """Test UsingToolStrategy with tool message content."""
+ strategy = ToolStrategy(schema=_TestModel, tool_message_content="custom message")
+ assert strategy.schema == _TestModel
+ assert strategy.tool_message_content == "custom message"
+ assert len(strategy.schema_specs) == 1
+
+
+class TestOutputToolBinding:
+ """Test OutputToolBinding dataclass and its methods."""
+
+ def test_from_schema_spec_basic(self) -> None:
+ """Test basic OutputToolBinding creation from SchemaSpec."""
+ schema_spec = _SchemaSpec(schema=_TestModel)
+ tool_binding = OutputToolBinding.from_schema_spec(schema_spec)
+
+ assert tool_binding.schema == _TestModel
+ assert tool_binding.schema_kind == "pydantic"
+ assert tool_binding.tool is not None
+ assert tool_binding.tool.name == "_TestModel"
+
+ def test_from_schema_spec_with_custom_name(self) -> None:
+ """Test OutputToolBinding creation with custom name."""
+ schema_spec = _SchemaSpec(schema=_TestModel, name="custom_tool_name")
+ tool_binding = OutputToolBinding.from_schema_spec(schema_spec)
+ assert tool_binding.tool.name == "custom_tool_name"
+
+ def test_from_schema_spec_with_custom_description(self) -> None:
+ """Test OutputToolBinding creation with custom description."""
+ schema_spec = _SchemaSpec(schema=_TestModel, description="Custom tool description")
+ tool_binding = OutputToolBinding.from_schema_spec(schema_spec)
+
+ assert tool_binding.tool.description == "Custom tool description"
+
+ def test_from_schema_spec_with_model_docstring(self) -> None:
+ """Test OutputToolBinding creation using model docstring as description."""
+ schema_spec = _SchemaSpec(schema=CustomModel)
+ tool_binding = OutputToolBinding.from_schema_spec(schema_spec)
+
+ assert tool_binding.tool.description == "Custom model with a custom docstring."
+
+ @pytest.mark.skip(reason="Need to fix bug in langchain-core for inheritance of doc-strings.")
+ def test_from_schema_spec_empty_docstring(self) -> None:
+ """Test OutputToolBinding creation with model that has default docstring."""
+
+ # Create a model with the same docstring as BaseModel
+ class DefaultDocModel(BaseModel):
+ # This should have the same docstring as BaseModel
+ pass
+
+ schema_spec = _SchemaSpec(schema=DefaultDocModel)
+ tool_binding = OutputToolBinding.from_schema_spec(schema_spec)
+
+ # Should use empty description when model has default BaseModel docstring
+ assert tool_binding.tool.description == ""
+
+ def test_parse_payload_pydantic_success(self) -> None:
+ """Test successful parsing for Pydantic model."""
+ schema_spec = _SchemaSpec(schema=_TestModel)
+ tool_binding = OutputToolBinding.from_schema_spec(schema_spec)
+
+ tool_args = {"name": "John", "age": 30}
+ result = tool_binding.parse(tool_args)
+
+ assert isinstance(result, _TestModel)
+ assert result.name == "John"
+ assert result.age == 30
+ assert result.email == "default@example.com" # default value
+
+ def test_parse_payload_pydantic_validation_error(self) -> None:
+ """Test parsing failure for invalid Pydantic data."""
+ schema_spec = _SchemaSpec(schema=_TestModel)
+ tool_binding = OutputToolBinding.from_schema_spec(schema_spec)
+
+ # Missing required field 'name'
+ tool_args = {"age": 30}
+
+ with pytest.raises(ValueError, match="Failed to parse data to _TestModel"):
+ tool_binding.parse(tool_args)
+
+
+class TestEdgeCases:
+ """Test edge cases and error conditions."""
+
+ def test_empty_schemas_list(self) -> None:
+ """Test UsingToolStrategy with empty schemas list."""
+ strategy = ToolStrategy(EmptyDocModel)
+ assert len(strategy.schema_specs) == 1
+
+ @pytest.mark.skip(reason="Need to fix bug in langchain-core for inheritance of doc-strings.")
+ def test_base_model_doc_constant(self) -> None:
+ """Test that BASE_MODEL_DOC constant is set correctly."""
+ binding = OutputToolBinding.from_schema_spec(_SchemaSpec(EmptyDocModel))
+ assert binding.tool.name == "EmptyDocModel"
+ assert binding.tool.description[:5] == "" # Should be empty for default docstring
diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_responses_spec.py b/libs/langchain_v1/tests/unit_tests/agents/test_responses_spec.py
new file mode 100644
index 00000000000..ed1d67c3fc2
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/test_responses_spec.py
@@ -0,0 +1,147 @@
+from __future__ import annotations
+
+import pytest
+
+# Skip this test since langgraph.prebuilt.responses is not available
+pytest.skip("langgraph.prebuilt.responses not available", allow_module_level=True)
+
+try:
+ from langchain_openai import ChatOpenAI
+except ImportError:
+ skip_openai_integration_tests = True
+else:
+ skip_openai_integration_tests = False
+
+AGENT_PROMPT = "You are an HR assistant."
+
+
+class ToolCalls(BaseSchema):
+ get_employee_role: int
+ get_employee_department: int
+
+
+class AssertionByInvocation(BaseSchema):
+ prompt: str
+ tools_with_expected_calls: ToolCalls
+ expected_last_message: str
+ expected_structured_response: Optional[Dict[str, Any]]
+ llm_request_count: int
+
+
+class TestCase(BaseSchema):
+ name: str
+ response_format: Union[Dict[str, Any], List[Dict[str, Any]]]
+ assertions_by_invocation: List[AssertionByInvocation]
+
+
+class Employee(BaseModel):
+ name: str
+ role: str
+ department: str
+
+
+EMPLOYEES: list[Employee] = [
+ Employee(name="Sabine", role="Developer", department="IT"),
+ Employee(name="Henrik", role="Product Manager", department="IT"),
+ Employee(name="Jessica", role="HR", department="People"),
+]
+
+TEST_CASES = load_spec("responses", as_model=TestCase)
+
+
+def _make_tool(fn, *, name: str, description: str):
+ mock = MagicMock(side_effect=lambda *, name: fn(name=name))
+ InputModel = create_model(f"{name}_input", name=(str, ...))
+
+ @tool(name, description=description, args_schema=InputModel)
+ def _wrapped(name: str):
+ return mock(name=name)
+
+ return {"tool": _wrapped, "mock": mock}
+
+
+@pytest.mark.skipif(skip_openai_integration_tests, reason="OpenAI integration tests are disabled.")
+@pytest.mark.parametrize("case", TEST_CASES, ids=[c.name for c in TEST_CASES])
+def test_responses_integration_matrix(case: TestCase) -> None:
+ if case.name == "asking for information that does not fit into the response format":
+ pytest.xfail(
+ "currently failing due to undefined behavior when model cannot conform to any of the structured response formats."
+ )
+
+ def get_employee_role(*, name: str) -> Optional[str]:
+ for e in EMPLOYEES:
+ if e.name == name:
+ return e.role
+ return None
+
+ def get_employee_department(*, name: str) -> Optional[str]:
+ for e in EMPLOYEES:
+ if e.name == name:
+ return e.department
+ return None
+
+ role_tool = _make_tool(
+ get_employee_role,
+ name="get_employee_role",
+ description="Get the employee role by name",
+ )
+ dept_tool = _make_tool(
+ get_employee_department,
+ name="get_employee_department",
+ description="Get the employee department by name",
+ )
+
+ response_format_spec = case.response_format
+ if isinstance(response_format_spec, dict):
+ response_format_spec = [response_format_spec]
+ # Unwrap nested schema objects
+ response_format_spec = [item.get("schema", item) for item in response_format_spec]
+ if len(response_format_spec) == 1:
+ tool_output = ToolStrategy(response_format_spec[0])
+ else:
+ tool_output = ToolStrategy({"oneOf": response_format_spec})
+
+ llm_request_count = 0
+
+ for assertion in case.assertions_by_invocation:
+
+ def on_request(request: httpx.Request) -> None:
+ nonlocal llm_request_count
+ llm_request_count += 1
+
+ http_client = httpx.Client(
+ event_hooks={"request": [on_request]},
+ )
+
+ model = ChatOpenAI(
+ model="gpt-4o",
+ temperature=0,
+ http_client=http_client,
+ )
+
+ agent = create_react_agent(
+ model,
+ tools=[role_tool["tool"], dept_tool["tool"]],
+ prompt=AGENT_PROMPT,
+ response_format=tool_output,
+ )
+
+ result = agent.invoke({"messages": [HumanMessage(assertion.prompt)]})
+
+ # Count tool calls
+ assert role_tool["mock"].call_count == assertion.tools_with_expected_calls.get_employee_role
+ assert (
+ dept_tool["mock"].call_count
+ == assertion.tools_with_expected_calls.get_employee_department
+ )
+
+ # Count LLM calls
+ assert llm_request_count == assertion.llm_request_count
+
+ # Check last message content
+ last_message = result["messages"][-1]
+ assert last_message.content == assertion.expected_last_message
+
+ # Check structured response
+ structured_response_json = result["structured_response"]
+ assert structured_response_json == assertion.expected_structured_response
diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_return_direct_spec.py b/libs/langchain_v1/tests/unit_tests/agents/test_return_direct_spec.py
new file mode 100644
index 00000000000..d87c22bd283
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/test_return_direct_spec.py
@@ -0,0 +1,107 @@
+from __future__ import annotations
+
+import pytest
+
+# Skip this test since langgraph.prebuilt.responses is not available
+pytest.skip("langgraph.prebuilt.responses not available", allow_module_level=True)
+
+try:
+ from langchain_openai import ChatOpenAI
+except ImportError:
+ skip_openai_integration_tests = True
+else:
+ skip_openai_integration_tests = False
+
+AGENT_PROMPT = """
+You are a strict polling bot.
+
+- Only use the "poll_job" tool until it returns { status: "succeeded" }.
+- If status is "pending", call the tool again. Do not produce a final answer.
+- When it is "succeeded", return exactly: "Attempts: " with no extra text.
+"""
+
+
+class TestCase(BaseSchema):
+ name: str
+ return_direct: bool
+ response_format: Optional[Dict[str, Any]]
+ expected_tool_calls: int
+ expected_last_message: str
+ expected_structured_response: Optional[Dict[str, Any]]
+
+
+TEST_CASES = load_spec("return_direct", as_model=TestCase)
+
+
+def _make_tool(return_direct: bool):
+ attempts = 0
+
+ def _side_effect():
+ nonlocal attempts
+ attempts += 1
+ return {
+ "status": "succeeded" if attempts >= 10 else "pending",
+ "attempts": attempts,
+ }
+
+ mock = MagicMock(side_effect=_side_effect)
+
+ @tool(
+ "pollJob",
+ description=(
+ "Check the status of a long-running job. "
+ "Returns { status: 'pending' | 'succeeded', attempts: number }."
+ ),
+ return_direct=return_direct,
+ )
+ def _wrapped():
+ return mock()
+
+ return {"tool": _wrapped, "mock": mock}
+
+
+@pytest.mark.skipif(skip_openai_integration_tests, reason="OpenAI integration tests are disabled.")
+@pytest.mark.parametrize("case", TEST_CASES, ids=[c.name for c in TEST_CASES])
+def test_return_direct_integration_matrix(case: TestCase) -> None:
+ poll_tool = _make_tool(case.return_direct)
+
+ model = ChatOpenAI(
+ model="gpt-4o",
+ temperature=0,
+ )
+
+ if case.response_format:
+ agent = create_react_agent(
+ model,
+ tools=[poll_tool["tool"]],
+ prompt=AGENT_PROMPT,
+ response_format=ToolStrategy(case.response_format),
+ )
+ else:
+ agent = create_react_agent(
+ model,
+ tools=[poll_tool["tool"]],
+ prompt=AGENT_PROMPT,
+ )
+
+ result = agent.invoke(
+ {
+ "messages": [
+ HumanMessage("Poll the job until it's done and tell me how many attempts it took.")
+ ]
+ }
+ )
+
+ # Count tool calls
+ assert poll_tool["mock"].call_count == case.expected_tool_calls
+
+ # Check last message content
+ last_message = result["messages"][-1]
+ assert last_message.content == case.expected_last_message
+
+ # Check structured response
+ if case.expected_structured_response is not None:
+ structured_response_json = result["structured_response"]
+ assert structured_response_json == case.expected_structured_response
+ else:
+ assert "structured_response" not in result
diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py b/libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py
new file mode 100644
index 00000000000..8581cf3c7cf
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py
@@ -0,0 +1,1482 @@
+import contextlib
+import dataclasses
+import json
+from functools import partial
+from typing import (
+ Annotated,
+ Any,
+ NoReturn,
+ TypeVar,
+ Union,
+)
+
+import pytest
+from langchain_core.messages import (
+ AIMessage,
+ AnyMessage,
+ HumanMessage,
+ RemoveMessage,
+ ToolCall,
+ ToolMessage,
+)
+from langchain_core.tools import BaseTool, ToolException
+from langchain_core.tools import tool as dec_tool
+from langgraph.config import get_stream_writer
+from langgraph.errors import GraphBubbleUp, GraphInterrupt
+from langgraph.graph import START, MessagesState, StateGraph
+from langgraph.graph.message import REMOVE_ALL_MESSAGES, add_messages
+from langgraph.store.base import BaseStore
+from langgraph.store.memory import InMemoryStore
+from langgraph.types import Command, Send
+from pydantic import BaseModel
+from pydantic.v1 import BaseModel as BaseModelV1
+from typing_extensions import TypedDict
+
+from langchain.agents import (
+ ToolNode,
+)
+from langchain.agents.tool_node import (
+ TOOL_CALL_ERROR_TEMPLATE,
+ InjectedState,
+ InjectedStore,
+ ToolInvocationError,
+ tools_condition,
+)
+
+from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
+from .model import FakeToolCallingModel
+
+pytestmark = pytest.mark.anyio
+
+
+def tool1(some_val: int, some_other_val: str) -> str:
+ """Tool 1 docstring."""
+ if some_val == 0:
+ msg = "Test error"
+ raise ValueError(msg)
+ return f"{some_val} - {some_other_val}"
+
+
+async def tool2(some_val: int, some_other_val: str) -> str:
+ """Tool 2 docstring."""
+ if some_val == 0:
+ msg = "Test error"
+ raise ToolException(msg)
+ return f"tool2: {some_val} - {some_other_val}"
+
+
+async def tool3(some_val: int, some_other_val: str) -> str:
+ """Tool 3 docstring."""
+ return [
+ {"key_1": some_val, "key_2": "foo"},
+ {"key_1": some_other_val, "key_2": "baz"},
+ ]
+
+
+async def tool4(some_val: int, some_other_val: str) -> str:
+ """Tool 4 docstring."""
+ return [
+ {"type": "image_url", "image_url": {"url": "abdc"}},
+ ]
+
+
+@dec_tool
+def tool5(some_val: int) -> NoReturn:
+ """Tool 5 docstring."""
+ msg = "Test error"
+ raise ToolException(msg)
+
+
+tool5.handle_tool_error = "foo"
+
+
+async def test_tool_node() -> None:
+ """Test tool node."""
+ result = ToolNode([tool1]).invoke(
+ {
+ "messages": [
+ AIMessage(
+ "hi?",
+ tool_calls=[
+ {
+ "name": "tool1",
+ "args": {"some_val": 1, "some_other_val": "foo"},
+ "id": "some 0",
+ }
+ ],
+ )
+ ]
+ }
+ )
+
+ tool_message: ToolMessage = result["messages"][-1]
+ assert tool_message.type == "tool"
+ assert tool_message.content == "1 - foo"
+ assert tool_message.tool_call_id == "some 0"
+
+ result2 = await ToolNode([tool2]).ainvoke(
+ {
+ "messages": [
+ AIMessage(
+ "hi?",
+ tool_calls=[
+ {
+ "name": "tool2",
+ "args": {"some_val": 2, "some_other_val": "bar"},
+ "id": "some 1",
+ }
+ ],
+ )
+ ]
+ }
+ )
+
+ tool_message: ToolMessage = result2["messages"][-1]
+ assert tool_message.type == "tool"
+ assert tool_message.content == "tool2: 2 - bar"
+
+ # list of dicts tool content
+ result3 = await ToolNode([tool3]).ainvoke(
+ {
+ "messages": [
+ AIMessage(
+ "hi?",
+ tool_calls=[
+ {
+ "name": "tool3",
+ "args": {"some_val": 2, "some_other_val": "bar"},
+ "id": "some 2",
+ }
+ ],
+ )
+ ]
+ }
+ )
+ tool_message: ToolMessage = result3["messages"][-1]
+ assert tool_message.type == "tool"
+ assert (
+ tool_message.content == '[{"key_1": 2, "key_2": "foo"}, {"key_1": "bar", "key_2": "baz"}]'
+ )
+ assert tool_message.tool_call_id == "some 2"
+
+ # list of content blocks tool content
+ result4 = await ToolNode([tool4]).ainvoke(
+ {
+ "messages": [
+ AIMessage(
+ "hi?",
+ tool_calls=[
+ {
+ "name": "tool4",
+ "args": {"some_val": 2, "some_other_val": "bar"},
+ "id": "some 3",
+ }
+ ],
+ )
+ ]
+ }
+ )
+ tool_message: ToolMessage = result4["messages"][-1]
+ assert tool_message.type == "tool"
+ assert tool_message.content == [{"type": "image_url", "image_url": {"url": "abdc"}}]
+ assert tool_message.tool_call_id == "some 3"
+
+
+async def test_tool_node_tool_call_input() -> None:
+ # Single tool call
+ tool_call_1 = {
+ "name": "tool1",
+ "args": {"some_val": 1, "some_other_val": "foo"},
+ "id": "some 0",
+ "type": "tool_call",
+ }
+ result = ToolNode([tool1]).invoke([tool_call_1])
+ assert result["messages"] == [
+ ToolMessage(content="1 - foo", tool_call_id="some 0", name="tool1"),
+ ]
+
+ # Multiple tool calls
+ tool_call_2 = {
+ "name": "tool1",
+ "args": {"some_val": 2, "some_other_val": "bar"},
+ "id": "some 1",
+ "type": "tool_call",
+ }
+ result = ToolNode([tool1]).invoke([tool_call_1, tool_call_2])
+ assert result["messages"] == [
+ ToolMessage(content="1 - foo", tool_call_id="some 0", name="tool1"),
+ ToolMessage(content="2 - bar", tool_call_id="some 1", name="tool1"),
+ ]
+
+ # Test with unknown tool
+ tool_call_3 = tool_call_1.copy()
+ tool_call_3["name"] = "tool2"
+ result = ToolNode([tool1]).invoke([tool_call_1, tool_call_3])
+ assert result["messages"] == [
+ ToolMessage(content="1 - foo", tool_call_id="some 0", name="tool1"),
+ ToolMessage(
+ content="Error: tool2 is not a valid tool, try one of [tool1].",
+ name="tool2",
+ tool_call_id="some 0",
+ status="error",
+ ),
+ ]
+
+
+def test_tool_node_error_handling_default_invocation() -> None:
+ tn = ToolNode([tool1])
+ result = tn.invoke(
+ {
+ "messages": [
+ AIMessage(
+ "hi?",
+ tool_calls=[
+ {
+ "name": "tool1",
+ "args": {"invalid": 0, "args": "foo"},
+ "id": "some id",
+ },
+ ],
+ )
+ ]
+ }
+ )
+
+ assert all(m.type == "tool" for m in result["messages"])
+ assert all(m.status == "error" for m in result["messages"])
+ assert (
+ "Error invoking tool 'tool1' with kwargs {'invalid': 0, 'args': 'foo'} with error:\n"
+ in result["messages"][0].content
+ )
+
+
+def test_tool_node_error_handling_default_exception() -> None:
+ tn = ToolNode([tool1])
+ with pytest.raises(ValueError):
+ tn.invoke(
+ {
+ "messages": [
+ AIMessage(
+ "hi?",
+ tool_calls=[
+ {
+ "name": "tool1",
+ "args": {"some_val": 0, "some_other_val": "foo"},
+ "id": "some id",
+ },
+ ],
+ )
+ ]
+ }
+ )
+
+
+async def test_tool_node_error_handling() -> None:
+ def handle_all(e: Union[ValueError, ToolException, ToolInvocationError]):
+ return TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))
+
+ # test catching all exceptions, via:
+ # - handle_tool_errors = True
+ # - passing a tuple of all exceptions
+ # - passing a callable with all exceptions in the signature
+ for handle_tool_errors in (
+ True,
+ (ValueError, ToolException, ToolInvocationError),
+ handle_all,
+ ):
+ result_error = await ToolNode(
+ [tool1, tool2, tool3], handle_tool_errors=handle_tool_errors
+ ).ainvoke(
+ {
+ "messages": [
+ AIMessage(
+ "hi?",
+ tool_calls=[
+ {
+ "name": "tool1",
+ "args": {"some_val": 0, "some_other_val": "foo"},
+ "id": "some id",
+ },
+ {
+ "name": "tool2",
+ "args": {"some_val": 0, "some_other_val": "bar"},
+ "id": "some other id",
+ },
+ {
+ "name": "tool3",
+ "args": {"some_val": 0},
+ "id": "another id",
+ },
+ ],
+ )
+ ]
+ }
+ )
+
+ assert all(m.type == "tool" for m in result_error["messages"])
+ assert all(m.status == "error" for m in result_error["messages"])
+ assert (
+ result_error["messages"][0].content
+ == f"Error: {ValueError('Test error')!r}\n Please fix your mistakes."
+ )
+ assert (
+ result_error["messages"][1].content
+ == f"Error: {ToolException('Test error')!r}\n Please fix your mistakes."
+ )
+ assert (
+ "ValidationError" in result_error["messages"][2].content
+ or "validation error" in result_error["messages"][2].content
+ )
+
+ assert result_error["messages"][0].tool_call_id == "some id"
+ assert result_error["messages"][1].tool_call_id == "some other id"
+ assert result_error["messages"][2].tool_call_id == "another id"
+
+
+async def test_tool_node_error_handling_callable() -> None:
+ def handle_value_error(e: ValueError) -> str:
+ return "Value error"
+
+ def handle_tool_exception(e: ToolException) -> str:
+ return "Tool exception"
+
+ for handle_tool_errors in ("Value error", handle_value_error):
+ result_error = await ToolNode([tool1], handle_tool_errors=handle_tool_errors).ainvoke(
+ {
+ "messages": [
+ AIMessage(
+ "hi?",
+ tool_calls=[
+ {
+ "name": "tool1",
+ "args": {"some_val": 0, "some_other_val": "foo"},
+ "id": "some id",
+ },
+ ],
+ )
+ ]
+ }
+ )
+ tool_message: ToolMessage = result_error["messages"][-1]
+ assert tool_message.type == "tool"
+ assert tool_message.status == "error"
+ assert tool_message.content == "Value error"
+
+ # test raising for an unhandled exception, via:
+ # - passing a tuple of all exceptions
+ # - passing a callable with all exceptions in the signature
+ for handle_tool_errors in ((ValueError,), handle_value_error):
+ with pytest.raises(ToolException) as exc_info:
+ await ToolNode([tool1, tool2], handle_tool_errors=handle_tool_errors).ainvoke(
+ {
+ "messages": [
+ AIMessage(
+ "hi?",
+ tool_calls=[
+ {
+ "name": "tool1",
+ "args": {"some_val": 0, "some_other_val": "foo"},
+ "id": "some id",
+ },
+ {
+ "name": "tool2",
+ "args": {"some_val": 0, "some_other_val": "bar"},
+ "id": "some other id",
+ },
+ ],
+ )
+ ]
+ }
+ )
+ assert str(exc_info.value) == "Test error"
+
+ for handle_tool_errors in ((ToolException,), handle_tool_exception):
+ with pytest.raises(ValueError) as exc_info:
+ await ToolNode([tool1, tool2], handle_tool_errors=handle_tool_errors).ainvoke(
+ {
+ "messages": [
+ AIMessage(
+ "hi?",
+ tool_calls=[
+ {
+ "name": "tool1",
+ "args": {"some_val": 0, "some_other_val": "foo"},
+ "id": "some id",
+ },
+ {
+ "name": "tool2",
+ "args": {"some_val": 0, "some_other_val": "bar"},
+ "id": "some other id",
+ },
+ ],
+ )
+ ]
+ }
+ )
+ assert str(exc_info.value) == "Test error"
+
+
+async def test_tool_node_handle_tool_errors_false() -> None:
+ with pytest.raises(ValueError) as exc_info:
+ ToolNode([tool1], handle_tool_errors=False).invoke(
+ {
+ "messages": [
+ AIMessage(
+ "hi?",
+ tool_calls=[
+ {
+ "name": "tool1",
+ "args": {"some_val": 0, "some_other_val": "foo"},
+ "id": "some id",
+ }
+ ],
+ )
+ ]
+ }
+ )
+
+ assert str(exc_info.value) == "Test error"
+
+ with pytest.raises(ToolException):
+ await ToolNode([tool2], handle_tool_errors=False).ainvoke(
+ {
+ "messages": [
+ AIMessage(
+ "hi?",
+ tool_calls=[
+ {
+ "name": "tool2",
+ "args": {"some_val": 0, "some_other_val": "bar"},
+ "id": "some id",
+ }
+ ],
+ )
+ ]
+ }
+ )
+
+ assert str(exc_info.value) == "Test error"
+
+ # test validation errors get raised if handle_tool_errors is False
+ with pytest.raises(ToolInvocationError):
+ ToolNode([tool1], handle_tool_errors=False).invoke(
+ {
+ "messages": [
+ AIMessage(
+ "hi?",
+ tool_calls=[
+ {
+ "name": "tool1",
+ "args": {"some_val": 0},
+ "id": "some id",
+ }
+ ],
+ )
+ ]
+ }
+ )
+
+
+def test_tool_node_individual_tool_error_handling() -> None:
+ # test error handling on individual tools (and that it overrides overall error handling!)
+ result_individual_tool_error_handler = ToolNode([tool5], handle_tool_errors="bar").invoke(
+ {
+ "messages": [
+ AIMessage(
+ "hi?",
+ tool_calls=[
+ {
+ "name": "tool5",
+ "args": {"some_val": 0},
+ "id": "some 0",
+ }
+ ],
+ )
+ ]
+ }
+ )
+
+ tool_message: ToolMessage = result_individual_tool_error_handler["messages"][-1]
+ assert tool_message.type == "tool"
+ assert tool_message.status == "error"
+ assert tool_message.content == "foo"
+ assert tool_message.tool_call_id == "some 0"
+
+
+def test_tool_node_incorrect_tool_name() -> None:
+ result_incorrect_name = ToolNode([tool1, tool2]).invoke(
+ {
+ "messages": [
+ AIMessage(
+ "hi?",
+ tool_calls=[
+ {
+ "name": "tool3",
+ "args": {"some_val": 1, "some_other_val": "foo"},
+ "id": "some 0",
+ }
+ ],
+ )
+ ]
+ }
+ )
+
+ tool_message: ToolMessage = result_incorrect_name["messages"][-1]
+ assert tool_message.type == "tool"
+ assert tool_message.status == "error"
+ assert tool_message.content == "Error: tool3 is not a valid tool, try one of [tool1, tool2]."
+ assert tool_message.tool_call_id == "some 0"
+
+
+def test_tool_node_node_interrupt() -> None:
+ def tool_interrupt(some_val: int) -> None:
+ """Tool docstring."""
+ msg = "foo"
+ raise GraphBubbleUp(msg)
+
+ def handle(e: GraphInterrupt) -> str:
+ return "handled"
+
+ for handle_tool_errors in (True, (GraphBubbleUp,), "handled", handle, False):
+ node = ToolNode([tool_interrupt], handle_tool_errors=handle_tool_errors)
+ with pytest.raises(GraphBubbleUp) as exc_info:
+ node.invoke(
+ {
+ "messages": [
+ AIMessage(
+ "hi?",
+ tool_calls=[
+ {
+ "name": "tool_interrupt",
+ "args": {"some_val": 0},
+ "id": "some 0",
+ }
+ ],
+ )
+ ]
+ }
+ )
+ assert exc_info.value == "foo"
+
+
+@pytest.mark.parametrize("input_type", ["dict", "tool_calls"])
+async def test_tool_node_command(input_type: str) -> None:
+ from langchain_core.tools.base import InjectedToolCallId
+
+ @dec_tool
+ def transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]):
+ """Transfer to Bob"""
+ return Command(
+ update={
+ "messages": [ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)]
+ },
+ goto="bob",
+ graph=Command.PARENT,
+ )
+
+ @dec_tool
+ async def async_transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]):
+ """Transfer to Bob"""
+ return Command(
+ update={
+ "messages": [ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)]
+ },
+ goto="bob",
+ graph=Command.PARENT,
+ )
+
+ class CustomToolSchema(BaseModel):
+ tool_call_id: Annotated[str, InjectedToolCallId]
+
+ class MyCustomTool(BaseTool):
+ def _run(*args: Any, **kwargs: Any):
+ return Command(
+ update={
+ "messages": [
+ ToolMessage(
+ content="Transferred to Bob",
+ tool_call_id=kwargs["tool_call_id"],
+ )
+ ]
+ },
+ goto="bob",
+ graph=Command.PARENT,
+ )
+
+ async def _arun(*args: Any, **kwargs: Any):
+ return Command(
+ update={
+ "messages": [
+ ToolMessage(
+ content="Transferred to Bob",
+ tool_call_id=kwargs["tool_call_id"],
+ )
+ ]
+ },
+ goto="bob",
+ graph=Command.PARENT,
+ )
+
+ custom_tool = MyCustomTool(
+ name="custom_transfer_to_bob",
+ description="Transfer to bob",
+ args_schema=CustomToolSchema,
+ )
+ async_custom_tool = MyCustomTool(
+ name="async_custom_transfer_to_bob",
+ description="Transfer to bob",
+ args_schema=CustomToolSchema,
+ )
+
+ # test mixing regular tools and tools returning commands
+ def add(a: int, b: int) -> int:
+ """Add two numbers"""
+ return a + b
+
+ tool_calls = [
+ {"args": {"a": 1, "b": 2}, "id": "1", "name": "add", "type": "tool_call"},
+ {"args": {}, "id": "2", "name": "transfer_to_bob", "type": "tool_call"},
+ ]
+ if input_type == "dict":
+ input_ = {"messages": [AIMessage("", tool_calls=tool_calls)]}
+ elif input_type == "tool_calls":
+ input_ = tool_calls
+ result = ToolNode([add, transfer_to_bob]).invoke(input_)
+
+ assert result == [
+ {
+ "messages": [
+ ToolMessage(
+ content="3",
+ tool_call_id="1",
+ name="add",
+ )
+ ]
+ },
+ Command(
+ update={
+ "messages": [
+ ToolMessage(
+ content="Transferred to Bob",
+ tool_call_id="2",
+ name="transfer_to_bob",
+ )
+ ]
+ },
+ goto="bob",
+ graph=Command.PARENT,
+ ),
+ ]
+
+ # test tools returning commands
+
+ # test sync tools
+ for tool in [transfer_to_bob, custom_tool]:
+ result = ToolNode([tool]).invoke(
+ {"messages": [AIMessage("", tool_calls=[{"args": {}, "id": "1", "name": tool.name}])]}
+ )
+ assert result == [
+ Command(
+ update={
+ "messages": [
+ ToolMessage(
+ content="Transferred to Bob",
+ tool_call_id="1",
+ name=tool.name,
+ )
+ ]
+ },
+ goto="bob",
+ graph=Command.PARENT,
+ )
+ ]
+
+ # test async tools
+ for tool in [async_transfer_to_bob, async_custom_tool]:
+ result = await ToolNode([tool]).ainvoke(
+ {"messages": [AIMessage("", tool_calls=[{"args": {}, "id": "1", "name": tool.name}])]}
+ )
+ assert result == [
+ Command(
+ update={
+ "messages": [
+ ToolMessage(
+ content="Transferred to Bob",
+ tool_call_id="1",
+ name=tool.name,
+ )
+ ]
+ },
+ goto="bob",
+ graph=Command.PARENT,
+ )
+ ]
+
+ # test multiple commands
+ result = ToolNode([transfer_to_bob, custom_tool]).invoke(
+ {
+ "messages": [
+ AIMessage(
+ "",
+ tool_calls=[
+ {"args": {}, "id": "1", "name": "transfer_to_bob"},
+ {"args": {}, "id": "2", "name": "custom_transfer_to_bob"},
+ ],
+ )
+ ]
+ }
+ )
+ assert result == [
+ Command(
+ update={
+ "messages": [
+ ToolMessage(
+ content="Transferred to Bob",
+ tool_call_id="1",
+ name="transfer_to_bob",
+ )
+ ]
+ },
+ goto="bob",
+ graph=Command.PARENT,
+ ),
+ Command(
+ update={
+ "messages": [
+ ToolMessage(
+ content="Transferred to Bob",
+ tool_call_id="2",
+ name="custom_transfer_to_bob",
+ )
+ ]
+ },
+ goto="bob",
+ graph=Command.PARENT,
+ ),
+ ]
+
+ # test validation (mismatch between input type and command.update type)
+ with pytest.raises(ValueError):
+
+ @dec_tool
+ def list_update_tool(tool_call_id: Annotated[str, InjectedToolCallId]):
+ """My tool"""
+ return Command(update=[ToolMessage(content="foo", tool_call_id=tool_call_id)])
+
+ ToolNode([list_update_tool]).invoke(
+ {
+ "messages": [
+ AIMessage(
+ "",
+ tool_calls=[{"args": {}, "id": "1", "name": "list_update_tool"}],
+ )
+ ]
+ }
+ )
+
+ # test validation (missing tool message in the update for current graph)
+ with pytest.raises(ValueError):
+
+ @dec_tool
+ def no_update_tool():
+ """My tool"""
+ return Command(update={"messages": []})
+
+ ToolNode([no_update_tool]).invoke(
+ {
+ "messages": [
+ AIMessage(
+ "",
+ tool_calls=[{"args": {}, "id": "1", "name": "no_update_tool"}],
+ )
+ ]
+ }
+ )
+
+ # test validation (tool message with a wrong tool call ID)
+ with pytest.raises(ValueError):
+
+ @dec_tool
+ def mismatching_tool_call_id_tool():
+ """My tool"""
+ return Command(update={"messages": [ToolMessage(content="foo", tool_call_id="2")]})
+
+ ToolNode([mismatching_tool_call_id_tool]).invoke(
+ {
+ "messages": [
+ AIMessage(
+ "",
+ tool_calls=[
+ {
+ "args": {},
+ "id": "1",
+ "name": "mismatching_tool_call_id_tool",
+ }
+ ],
+ )
+ ]
+ }
+ )
+
+ # test validation (missing tool message in the update for parent graph is OK)
+ @dec_tool
+ def node_update_parent_tool():
+ """No update"""
+ return Command(update={"messages": []}, graph=Command.PARENT)
+
+ assert ToolNode([node_update_parent_tool]).invoke(
+ {
+ "messages": [
+ AIMessage(
+ "",
+ tool_calls=[{"args": {}, "id": "1", "name": "node_update_parent_tool"}],
+ )
+ ]
+ }
+ ) == [Command(update={"messages": []}, graph=Command.PARENT)]
+
+
+async def test_tool_node_command_list_input() -> None:
+ from langchain_core.tools.base import InjectedToolCallId
+
+ @dec_tool
+ def transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]):
+ """Transfer to Bob"""
+ return Command(
+ update=[ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)],
+ goto="bob",
+ graph=Command.PARENT,
+ )
+
+ @dec_tool
+ async def async_transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]):
+ """Transfer to Bob"""
+ return Command(
+ update=[ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)],
+ goto="bob",
+ graph=Command.PARENT,
+ )
+
+ class CustomToolSchema(BaseModel):
+ tool_call_id: Annotated[str, InjectedToolCallId]
+
+ class MyCustomTool(BaseTool):
+ def _run(*args: Any, **kwargs: Any):
+ return Command(
+ update=[
+ ToolMessage(
+ content="Transferred to Bob",
+ tool_call_id=kwargs["tool_call_id"],
+ )
+ ],
+ goto="bob",
+ graph=Command.PARENT,
+ )
+
+ async def _arun(*args: Any, **kwargs: Any):
+ return Command(
+ update=[
+ ToolMessage(
+ content="Transferred to Bob",
+ tool_call_id=kwargs["tool_call_id"],
+ )
+ ],
+ goto="bob",
+ graph=Command.PARENT,
+ )
+
+ custom_tool = MyCustomTool(
+ name="custom_transfer_to_bob",
+ description="Transfer to bob",
+ args_schema=CustomToolSchema,
+ )
+ async_custom_tool = MyCustomTool(
+ name="async_custom_transfer_to_bob",
+ description="Transfer to bob",
+ args_schema=CustomToolSchema,
+ )
+
+ # test mixing regular tools and tools returning commands
+ def add(a: int, b: int) -> int:
+ """Add two numbers"""
+ return a + b
+
+ result = ToolNode([add, transfer_to_bob]).invoke(
+ [
+ AIMessage(
+ "",
+ tool_calls=[
+ {"args": {"a": 1, "b": 2}, "id": "1", "name": "add"},
+ {"args": {}, "id": "2", "name": "transfer_to_bob"},
+ ],
+ )
+ ]
+ )
+
+ assert result == [
+ [
+ ToolMessage(
+ content="3",
+ tool_call_id="1",
+ name="add",
+ )
+ ],
+ Command(
+ update=[
+ ToolMessage(
+ content="Transferred to Bob",
+ tool_call_id="2",
+ name="transfer_to_bob",
+ )
+ ],
+ goto="bob",
+ graph=Command.PARENT,
+ ),
+ ]
+
+ # test tools returning commands
+
+ # test sync tools
+ for tool in [transfer_to_bob, custom_tool]:
+ result = ToolNode([tool]).invoke(
+ [AIMessage("", tool_calls=[{"args": {}, "id": "1", "name": tool.name}])]
+ )
+ assert result == [
+ Command(
+ update=[
+ ToolMessage(
+ content="Transferred to Bob",
+ tool_call_id="1",
+ name=tool.name,
+ )
+ ],
+ goto="bob",
+ graph=Command.PARENT,
+ )
+ ]
+
+ # test async tools
+ for tool in [async_transfer_to_bob, async_custom_tool]:
+ result = await ToolNode([tool]).ainvoke(
+ [AIMessage("", tool_calls=[{"args": {}, "id": "1", "name": tool.name}])]
+ )
+ assert result == [
+ Command(
+ update=[
+ ToolMessage(
+ content="Transferred to Bob",
+ tool_call_id="1",
+ name=tool.name,
+ )
+ ],
+ goto="bob",
+ graph=Command.PARENT,
+ )
+ ]
+
+ # test multiple commands
+ result = ToolNode([transfer_to_bob, custom_tool]).invoke(
+ [
+ AIMessage(
+ "",
+ tool_calls=[
+ {"args": {}, "id": "1", "name": "transfer_to_bob"},
+ {"args": {}, "id": "2", "name": "custom_transfer_to_bob"},
+ ],
+ )
+ ]
+ )
+ assert result == [
+ Command(
+ update=[
+ ToolMessage(
+ content="Transferred to Bob",
+ tool_call_id="1",
+ name="transfer_to_bob",
+ )
+ ],
+ goto="bob",
+ graph=Command.PARENT,
+ ),
+ Command(
+ update=[
+ ToolMessage(
+ content="Transferred to Bob",
+ tool_call_id="2",
+ name="custom_transfer_to_bob",
+ )
+ ],
+ goto="bob",
+ graph=Command.PARENT,
+ ),
+ ]
+
+ # test validation (mismatch between input type and command.update type)
+ with pytest.raises(ValueError):
+
+ @dec_tool
+ def list_update_tool(tool_call_id: Annotated[str, InjectedToolCallId]):
+ """My tool"""
+ return Command(
+ update={"messages": [ToolMessage(content="foo", tool_call_id=tool_call_id)]}
+ )
+
+ ToolNode([list_update_tool]).invoke(
+ [
+ AIMessage(
+ "",
+ tool_calls=[{"args": {}, "id": "1", "name": "list_update_tool"}],
+ )
+ ]
+ )
+
+ # test validation (missing tool message in the update for current graph)
+ with pytest.raises(ValueError):
+
+ @dec_tool
+ def no_update_tool():
+ """My tool"""
+ return Command(update=[])
+
+ ToolNode([no_update_tool]).invoke(
+ [
+ AIMessage(
+ "",
+ tool_calls=[{"args": {}, "id": "1", "name": "no_update_tool"}],
+ )
+ ]
+ )
+
+ # test validation (tool message with a wrong tool call ID)
+ with pytest.raises(ValueError):
+
+ @dec_tool
+ def mismatching_tool_call_id_tool():
+ """My tool"""
+ return Command(update=[ToolMessage(content="foo", tool_call_id="2")])
+
+ ToolNode([mismatching_tool_call_id_tool]).invoke(
+ [
+ AIMessage(
+ "",
+ tool_calls=[{"args": {}, "id": "1", "name": "mismatching_tool_call_id_tool"}],
+ )
+ ]
+ )
+
+ # test validation (missing tool message in the update for parent graph is OK)
+ @dec_tool
+ def node_update_parent_tool():
+ """No update"""
+ return Command(update=[], graph=Command.PARENT)
+
+ assert ToolNode([node_update_parent_tool]).invoke(
+ [
+ AIMessage(
+ "",
+ tool_calls=[{"args": {}, "id": "1", "name": "node_update_parent_tool"}],
+ )
+ ]
+ ) == [Command(update=[], graph=Command.PARENT)]
+
+
+def test_tool_node_parent_command_with_send() -> None:
+ from langchain_core.tools.base import InjectedToolCallId
+
+ @dec_tool
+ def transfer_to_alice(tool_call_id: Annotated[str, InjectedToolCallId]):
+ """Transfer to Alice"""
+ return Command(
+ goto=[
+ Send(
+ "alice",
+ {
+ "messages": [
+ ToolMessage(
+ content="Transferred to Alice",
+ name="transfer_to_alice",
+ tool_call_id=tool_call_id,
+ )
+ ]
+ },
+ )
+ ],
+ graph=Command.PARENT,
+ )
+
+ @dec_tool
+ def transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]):
+ """Transfer to Bob"""
+ return Command(
+ goto=[
+ Send(
+ "bob",
+ {
+ "messages": [
+ ToolMessage(
+ content="Transferred to Bob",
+ name="transfer_to_bob",
+ tool_call_id=tool_call_id,
+ )
+ ]
+ },
+ )
+ ],
+ graph=Command.PARENT,
+ )
+
+ tool_calls = [
+ {"args": {}, "id": "1", "name": "transfer_to_alice", "type": "tool_call"},
+ {"args": {}, "id": "2", "name": "transfer_to_bob", "type": "tool_call"},
+ ]
+
+ result = ToolNode([transfer_to_alice, transfer_to_bob]).invoke(
+ [AIMessage("", tool_calls=tool_calls)]
+ )
+
+ assert result == [
+ Command(
+ goto=[
+ Send(
+ "alice",
+ {
+ "messages": [
+ ToolMessage(
+ content="Transferred to Alice",
+ name="transfer_to_alice",
+ tool_call_id="1",
+ )
+ ]
+ },
+ ),
+ Send(
+ "bob",
+ {
+ "messages": [
+ ToolMessage(
+ content="Transferred to Bob",
+ name="transfer_to_bob",
+ tool_call_id="2",
+ )
+ ]
+ },
+ ),
+ ],
+ graph=Command.PARENT,
+ )
+ ]
+
+
+async def test_tool_node_command_remove_all_messages() -> None:
+ from langchain_core.tools.base import InjectedToolCallId
+
+ @dec_tool
+ def remove_all_messages_tool(tool_call_id: Annotated[str, InjectedToolCallId]):
+ """A tool that removes all messages."""
+ return Command(update={"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)]})
+
+ tool_node = ToolNode([remove_all_messages_tool])
+ tool_call = {
+ "name": "remove_all_messages_tool",
+ "args": {},
+ "id": "tool_call_123",
+ }
+ result = await tool_node.ainvoke({"messages": [AIMessage(content="", tool_calls=[tool_call])]})
+
+ assert isinstance(result, list)
+ assert len(result) == 1
+ command = result[0]
+ assert isinstance(command, Command)
+ assert command.update == {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)]}
+
+
+class _InjectStateSchema(TypedDict):
+ messages: list
+ foo: str
+
+
+class _InjectedStatePydanticSchema(BaseModelV1):
+ messages: list
+ foo: str
+
+
+class _InjectedStatePydanticV2Schema(BaseModel):
+ messages: list
+ foo: str
+
+
+@dataclasses.dataclass
+class _InjectedStateDataclassSchema:
+ messages: list
+ foo: str
+
+
+T = TypeVar("T")
+
+
+@pytest.mark.parametrize(
+ "schema_",
+ [
+ _InjectStateSchema,
+ _InjectedStatePydanticSchema,
+ _InjectedStatePydanticV2Schema,
+ _InjectedStateDataclassSchema,
+ ],
+)
+def test_tool_node_inject_state(schema_: type[T]) -> None:
+ def tool1(some_val: int, state: Annotated[T, InjectedState]) -> str:
+ """Tool 1 docstring."""
+ if isinstance(state, dict):
+ return state["foo"]
+ return state.foo
+
+ def tool2(some_val: int, state: Annotated[T, InjectedState()]) -> str:
+ """Tool 2 docstring."""
+ if isinstance(state, dict):
+ return state["foo"]
+ return state.foo
+
+ def tool3(
+ some_val: int,
+ foo: Annotated[str, InjectedState("foo")],
+ msgs: Annotated[list[AnyMessage], InjectedState("messages")],
+ ) -> str:
+ """Tool 1 docstring."""
+ return foo
+
+ def tool4(some_val: int, msgs: Annotated[list[AnyMessage], InjectedState("messages")]) -> str:
+ """Tool 1 docstring."""
+ return msgs[0].content
+
+ node = ToolNode([tool1, tool2, tool3, tool4], handle_tool_errors=True)
+ for tool_name in ("tool1", "tool2", "tool3"):
+ tool_call = {
+ "name": tool_name,
+ "args": {"some_val": 1},
+ "id": "some 0",
+ "type": "tool_call",
+ }
+ msg = AIMessage("hi?", tool_calls=[tool_call])
+ result = node.invoke(schema_(messages=[msg], foo="bar"))
+ tool_message = result["messages"][-1]
+ assert tool_message.content == "bar", f"Failed for tool={tool_name}"
+
+ if tool_name == "tool3":
+ failure_input = None
+ with contextlib.suppress(Exception):
+ failure_input = schema_(messages=[msg], notfoo="bar")
+ if failure_input is not None:
+ with pytest.raises(KeyError):
+ node.invoke(failure_input)
+
+ with pytest.raises(ValueError):
+ node.invoke([msg])
+ else:
+ failure_input = None
+ try:
+ failure_input = schema_(messages=[msg], notfoo="bar")
+ except Exception:
+ # We'd get a validation error from pydantic state and wouldn't make it to the node
+ # anyway
+ pass
+ if failure_input is not None:
+ messages_ = node.invoke(failure_input)
+ tool_message = messages_["messages"][-1]
+ assert "KeyError" in tool_message.content
+ tool_message = node.invoke([msg])[-1]
+ assert "KeyError" in tool_message.content
+
+ tool_call = {
+ "name": "tool4",
+ "args": {"some_val": 1},
+ "id": "some 0",
+ "type": "tool_call",
+ }
+ msg = AIMessage("hi?", tool_calls=[tool_call])
+ result = node.invoke(schema_(messages=[msg], foo=""))
+ tool_message = result["messages"][-1]
+ assert tool_message.content == "hi?"
+
+ result = node.invoke([msg])
+ tool_message = result[-1]
+ assert tool_message.content == "hi?"
+
+
+def test_tool_node_inject_store() -> None:
+ store = InMemoryStore()
+ namespace = ("test",)
+
+ def tool1(some_val: int, store: Annotated[BaseStore, InjectedStore()]) -> str:
+ """Tool 1 docstring."""
+ store_val = store.get(namespace, "test_key").value["foo"]
+ return f"Some val: {some_val}, store val: {store_val}"
+
+ def tool2(some_val: int, store: Annotated[BaseStore, InjectedStore()]) -> str:
+ """Tool 2 docstring."""
+ store_val = store.get(namespace, "test_key").value["foo"]
+ return f"Some val: {some_val}, store val: {store_val}"
+
+ def tool3(
+ some_val: int,
+ bar: Annotated[str, InjectedState("bar")],
+ store: Annotated[BaseStore, InjectedStore()],
+ ) -> str:
+ """Tool 3 docstring."""
+ store_val = store.get(namespace, "test_key").value["foo"]
+ return f"Some val: {some_val}, store val: {store_val}, state val: {bar}"
+
+ node = ToolNode([tool1, tool2, tool3], handle_tool_errors=True)
+ store.put(namespace, "test_key", {"foo": "bar"})
+
+ class State(MessagesState):
+ bar: str
+
+ builder = StateGraph(State)
+ builder.add_node("tools", node)
+ builder.add_edge(START, "tools")
+ graph = builder.compile(store=store)
+
+ for tool_name in ("tool1", "tool2"):
+ tool_call = {
+ "name": tool_name,
+ "args": {"some_val": 1},
+ "id": "some 0",
+ "type": "tool_call",
+ }
+ msg = AIMessage("hi?", tool_calls=[tool_call])
+ node_result = node.invoke({"messages": [msg]}, store=store)
+ graph_result = graph.invoke({"messages": [msg]})
+ for result in (node_result, graph_result):
+ result["messages"][-1]
+ tool_message = result["messages"][-1]
+ assert tool_message.content == "Some val: 1, store val: bar", (
+ f"Failed for tool={tool_name}"
+ )
+
+ tool_call = {
+ "name": "tool3",
+ "args": {"some_val": 1},
+ "id": "some 0",
+ "type": "tool_call",
+ }
+ msg = AIMessage("hi?", tool_calls=[tool_call])
+ node_result = node.invoke({"messages": [msg], "bar": "baz"}, store=store)
+ graph_result = graph.invoke({"messages": [msg], "bar": "baz"})
+ for result in (node_result, graph_result):
+ result["messages"][-1]
+ tool_message = result["messages"][-1]
+ assert tool_message.content == "Some val: 1, store val: bar, state val: baz", (
+ f"Failed for tool={tool_name}"
+ )
+
+ # test injected store without passing store to compiled graph
+ failing_graph = builder.compile()
+ with pytest.raises(ValueError):
+ failing_graph.invoke({"messages": [msg], "bar": "baz"})
+
+
+def test_tool_node_ensure_utf8() -> None:
+ @dec_tool
+ def get_day_list(days: list[str]) -> list[str]:
+ """choose days"""
+ return days
+
+ data = ["星期一", "水曜日", "목요일", "Friday"]
+ tools = [get_day_list]
+ tool_calls = [ToolCall(name=get_day_list.name, args={"days": data}, id="test_id")]
+ outputs: list[ToolMessage] = ToolNode(tools).invoke(
+ [AIMessage(content="", tool_calls=tool_calls)]
+ )
+ assert outputs[0].content == json.dumps(data, ensure_ascii=False)
+
+
+def test_tool_node_messages_key() -> None:
+ @dec_tool
+ def add(a: int, b: int) -> int:
+ """Adds a and b."""
+ return a + b
+
+ model = FakeToolCallingModel(
+ tool_calls=[[ToolCall(name=add.name, args={"a": 1, "b": 2}, id="test_id")]]
+ )
+
+ class State(TypedDict):
+ subgraph_messages: Annotated[list[AnyMessage], add_messages]
+
+ def call_model(state: State) -> dict[str, Any]:
+ response = model.invoke(state["subgraph_messages"])
+ model.tool_calls = []
+ return {"subgraph_messages": response}
+
+ builder = StateGraph(State)
+ builder.add_node("agent", call_model)
+ builder.add_node("tools", ToolNode([add], messages_key="subgraph_messages"))
+ builder.add_conditional_edges(
+ "agent", partial(tools_condition, messages_key="subgraph_messages")
+ )
+ builder.add_edge(START, "agent")
+ builder.add_edge("tools", "agent")
+
+ graph = builder.compile()
+ result = graph.invoke({"subgraph_messages": [HumanMessage(content="hi")]})
+ assert result["subgraph_messages"] == [
+ _AnyIdHumanMessage(content="hi"),
+ AIMessage(
+ content="hi",
+ id="0",
+ tool_calls=[ToolCall(name=add.name, args={"a": 1, "b": 2}, id="test_id")],
+ ),
+ _AnyIdToolMessage(content="3", name=add.name, tool_call_id="test_id"),
+ AIMessage(content="hi-hi-3", id="1"),
+ ]
+
+
+def test_tool_node_stream_writer() -> None:
+ @dec_tool
+ def streaming_tool(x: int) -> str:
+ """Do something with writer."""
+ my_writer = get_stream_writer()
+ for value in ["foo", "bar", "baz"]:
+ my_writer({"custom_tool_value": value})
+
+ return x
+
+ tool_node = ToolNode([streaming_tool])
+ graph = (
+ StateGraph(MessagesState).add_node("tools", tool_node).add_edge(START, "tools").compile()
+ )
+
+ tool_call = {
+ "name": "streaming_tool",
+ "args": {"x": 1},
+ "id": "1",
+ "type": "tool_call",
+ }
+ inputs = {
+ "messages": [AIMessage("", tool_calls=[tool_call])],
+ }
+
+ assert list(graph.stream(inputs, stream_mode="custom")) == [
+ {"custom_tool_value": "foo"},
+ {"custom_tool_value": "bar"},
+ {"custom_tool_value": "baz"},
+ ]
+ assert list(graph.stream(inputs, stream_mode=["custom", "updates"])) == [
+ ("custom", {"custom_tool_value": "foo"}),
+ ("custom", {"custom_tool_value": "bar"}),
+ ("custom", {"custom_tool_value": "baz"}),
+ (
+ "updates",
+ {
+ "tools": {
+ "messages": [
+ _AnyIdToolMessage(
+ content="1",
+ name="streaming_tool",
+ tool_call_id="1",
+ ),
+ ],
+ },
+ },
+ ),
+ ]
diff --git a/libs/langchain_v1/tests/unit_tests/agents/utils.py b/libs/langchain_v1/tests/unit_tests/agents/utils.py
new file mode 100644
index 00000000000..21a9d438af5
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/utils.py
@@ -0,0 +1,21 @@
+import json
+from pathlib import Path
+
+from pydantic import BaseModel, ConfigDict
+from pydantic.alias_generators import to_camel
+
+
+class BaseSchema(BaseModel):
+ model_config = ConfigDict(
+ alias_generator=to_camel,
+ populate_by_name=True,
+ from_attributes=True,
+ )
+
+
+def load_spec(spec_name: str, as_model: type[BaseModel]) -> list[BaseModel]:
+ with (Path(__file__).parent / "specifications" / f"{spec_name}.json").open(
+ "r", encoding="utf-8"
+ ) as f:
+ data = json.load(f)
+ return [as_model(**item) for item in data]
diff --git a/libs/langchain_v1/tests/unit_tests/conftest.py b/libs/langchain_v1/tests/unit_tests/conftest.py
index 6656529a4d1..46db1647781 100644
--- a/libs/langchain_v1/tests/unit_tests/conftest.py
+++ b/libs/langchain_v1/tests/unit_tests/conftest.py
@@ -1,42 +1,9 @@
"""Configuration for unit tests."""
-from collections.abc import Iterator, Sequence
+from collections.abc import Sequence
from importlib import util
import pytest
-from blockbuster import blockbuster_ctx
-
-
-@pytest.fixture(autouse=True)
-def blockbuster() -> Iterator[None]:
- with blockbuster_ctx("langchain") as bb:
- bb.functions["io.TextIOWrapper.read"].can_block_in(
- "langchain/__init__.py",
- "",
- )
-
- for func in ["os.stat", "os.path.abspath"]:
- (
- bb.functions[func]
- .can_block_in("langchain_core/runnables/base.py", "__repr__")
- .can_block_in(
- "langchain_core/beta/runnables/context.py",
- "aconfig_with_context",
- )
- )
-
- for func in ["os.stat", "io.TextIOWrapper.read"]:
- bb.functions[func].can_block_in(
- "langsmith/client.py",
- "_default_retry_config",
- )
-
- for bb_function in bb.functions.values():
- bb_function.can_block_in(
- "freezegun/api.py",
- "_get_cached_module_attributes",
- )
- yield
def pytest_addoption(parser: pytest.Parser) -> None:
diff --git a/libs/langchain_v1/uv.lock b/libs/langchain_v1/uv.lock
index fbb87832bac..2daff5ca14b 100644
--- a/libs/langchain_v1/uv.lock
+++ b/libs/langchain_v1/uv.lock
@@ -264,18 +264,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" },
]
-[[package]]
-name = "blockbuster"
-version = "1.5.25"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "forbiddenfruit", marker = "implementation_name == 'cpython'" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/7f/bc/57c49465decaeeedd58ce2d970b4cdfd93a74ba9993abff2dc498a31c283/blockbuster-1.5.25.tar.gz", hash = "sha256:b72f1d2aefdeecd2a820ddf1e1c8593bf00b96e9fdc4cd2199ebafd06f7cb8f0", size = 36058, upload-time = "2025-07-14T16:00:20.766Z" }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/0b/01/dccc277c014f171f61a6047bb22c684e16c7f2db6bb5c8cce1feaf41ec55/blockbuster-1.5.25-py3-none-any.whl", hash = "sha256:cb06229762273e0f5f3accdaed3d2c5a3b61b055e38843de202311ede21bb0f5", size = 13196, upload-time = "2025-07-14T16:00:19.396Z" },
-]
-
[[package]]
name = "boto3"
version = "1.40.18"
@@ -724,7 +712,7 @@ name = "exceptiongroup"
version = "1.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "typing-extensions", marker = "python_full_version < '3.12'" },
+ { name = "typing-extensions", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" }
wheels = [
@@ -774,12 +762,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/5f/e7/319a4ce37bed682741bc8ebbb84b7983da3d8cd7ac069d86b52a37d79f2e/fireworks_ai-0.15.15-py3-none-any.whl", hash = "sha256:1047b8e575a536898a827b089b0022c1fab207940f9773b90fa357ebf942f5c9", size = 112831, upload-time = "2025-06-20T21:11:26.701Z" },
]
-[[package]]
-name = "forbiddenfruit"
-version = "0.1.4"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/e6/79/d4f20e91327c98096d605646bdc6a5ffedae820f38d378d3515c42ec5e60/forbiddenfruit-0.1.4.tar.gz", hash = "sha256:e3f7e66561a29ae129aac139a85d610dbf3dd896128187ed5454b6421f624253", size = 43756, upload-time = "2021-01-16T21:03:35.401Z" }
-
[[package]]
name = "frozenlist"
version = "1.7.0"
@@ -1514,13 +1496,13 @@ lint = [
{ name = "ruff" },
]
test = [
- { name = "blockbuster" },
{ name = "langchain-openai" },
{ name = "langchain-tests" },
{ name = "langchain-text-splitters" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "pytest-cov" },
+ { name = "pytest-mock" },
{ name = "pytest-socket" },
{ name = "pytest-watcher" },
{ name = "pytest-xdist" },
@@ -1571,13 +1553,13 @@ lint = [
{ name = "ruff", specifier = ">=0.12.2" },
]
test = [
- { name = "blockbuster", specifier = ">=1.5.18" },
{ name = "langchain-openai", editable = "../partners/openai" },
{ name = "langchain-tests", editable = "../standard-tests" },
{ name = "langchain-text-splitters", editable = "../text-splitters" },
{ name = "pytest", specifier = ">=8,<9" },
{ name = "pytest-asyncio", specifier = ">=0.23.2" },
{ name = "pytest-cov", specifier = ">=4.0.0" },
+ { name = "pytest-mock" },
{ name = "pytest-socket", specifier = ">=0.6.0" },
{ name = "pytest-watcher", specifier = ">=0.2.6" },
{ name = "pytest-xdist", specifier = ">=3.6.1" },
@@ -3169,6 +3151,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/16/4ea354101abb1287856baa4af2732be351c7bee728065aed451b678153fd/pytest_cov-6.2.1-py3-none-any.whl", hash = "sha256:f5bc4c23f42f1cdd23c70b1dab1bbaef4fc505ba950d53e0081d0730dd7e86d5", size = 24644, upload-time = "2025-06-12T10:47:45.932Z" },
]
+[[package]]
+name = "pytest-mock"
+version = "3.14.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "pytest" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/71/28/67172c96ba684058a4d24ffe144d64783d2a270d0af0d9e792737bddc75c/pytest_mock-3.14.1.tar.gz", hash = "sha256:159e9edac4c451ce77a5cdb9fc5d1100708d2dd4ba3c3df572f14097351af80e", size = 33241, upload-time = "2025-05-26T13:58:45.167Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b2/05/77b60e520511c53d1c1ca75f1930c7dd8e971d0c4379b7f4b3f9644685ba/pytest_mock-3.14.1-py3-none-any.whl", hash = "sha256:178aefcd11307d874b4cd3100344e7e2d888d9791a6a1d9bfe90fbc1b74fd1d0", size = 9923, upload-time = "2025-05-26T13:58:43.487Z" },
+]
+
[[package]]
name = "pytest-recording"
version = "0.13.4"