diff --git a/libs/langchain_v1/Makefile b/libs/langchain_v1/Makefile index 885a56cf7b6..bb07973f2be 100644 --- a/libs/langchain_v1/Makefile +++ b/libs/langchain_v1/Makefile @@ -1,4 +1,4 @@ -.PHONY: all clean docs_build docs_clean docs_linkcheck api_docs_build api_docs_clean api_docs_linkcheck format lint test tests test_watch integration_tests help extended_tests +.PHONY: all clean docs_build docs_clean docs_linkcheck api_docs_build api_docs_clean api_docs_linkcheck format lint test tests test_watch integration_tests help extended_tests start_services stop_services # Default target executed when no arguments are given to make. all: help @@ -7,6 +7,12 @@ all: help # TESTING AND COVERAGE ###################### +start_services: + docker compose -f tests/unit_tests/agents/compose-postgres.yml -f tests/unit_tests/agents/compose-redis.yml up -V --force-recreate --wait --remove-orphans + +stop_services: + docker compose -f tests/unit_tests/agents/compose-postgres.yml -f tests/unit_tests/agents/compose-redis.yml down -v + # Define a variable for the test file path. TEST_FILE ?= tests/unit_tests/ @@ -21,17 +27,32 @@ coverage: --cov-report term-missing:skip-covered \ $(TEST_FILE) -test tests: - uv run --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) +test: + make start_services && LANGGRAPH_TEST_FAST=0 uv run --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) --cov-report term-missing:skip-covered; \ + EXIT_CODE=$$?; \ + make stop_services; \ + exit $$EXIT_CODE + +test_fast: + LANGGRAPH_TEST_FAST=1 uv run --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) extended_tests: - uv run --group test pytest --disable-socket --allow-unix-socket --only-extended tests/unit_tests + make start_services && LANGGRAPH_TEST_FAST=0 uv run --group test pytest --disable-socket --allow-unix-socket --only-extended tests/unit_tests; \ + EXIT_CODE=$$?; \ + make stop_services; \ + exit $$EXIT_CODE test_watch: - uv run --group test ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --disable-warnings tests/unit_tests + make start_services && LANGGRAPH_TEST_FAST=0 uv run --group test ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --disable-warnings tests/unit_tests; \ + EXIT_CODE=$$?; \ + make stop_services; \ + exit $$EXIT_CODE test_watch_extended: - uv run --group test ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --only-extended tests/unit_tests + make start_services && LANGGRAPH_TEST_FAST=0 uv run --group test ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --only-extended tests/unit_tests; \ + EXIT_CODE=$$?; \ + make stop_services; \ + exit $$EXIT_CODE integration_tests: uv run --group test --group test_integration pytest tests/integration_tests @@ -87,7 +108,8 @@ help: @echo 'spell_fix - run codespell on the project and fix the errors' @echo '-- TESTS --' @echo 'coverage - run unit tests and generate coverage report' - @echo 'test - run unit tests' + @echo 'test - run unit tests with all services' + @echo 'test_fast - run unit tests with in-memory services only' @echo 'tests - run unit tests (alias for "make test")' @echo 'test TEST_FILE= - run all tests in file' @echo 'extended_tests - run only extended unit tests' diff --git a/libs/langchain_v1/langchain/agents/__init__.py b/libs/langchain_v1/langchain/agents/__init__.py new file mode 100644 index 00000000000..f3c10ee29e6 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/__init__.py @@ -0,0 +1,10 @@ +"""langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools.""" + +from langchain.agents.react_agent import AgentState, create_react_agent +from langchain.agents.tool_node import ToolNode + +__all__ = [ + "AgentState", + "ToolNode", + "create_react_agent", +] diff --git a/libs/langchain_v1/langchain/agents/_internal/__init__.py b/libs/langchain_v1/langchain/agents/_internal/__init__.py new file mode 100644 index 00000000000..b4ff9ac2baa --- /dev/null +++ b/libs/langchain_v1/langchain/agents/_internal/__init__.py @@ -0,0 +1 @@ +"""Internal utilities for agents.""" diff --git a/libs/langchain_v1/langchain/agents/_internal/_typing.py b/libs/langchain_v1/langchain/agents/_internal/_typing.py new file mode 100644 index 00000000000..36989c6ce22 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/_internal/_typing.py @@ -0,0 +1,13 @@ +"""Typing utilities for agents.""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import TypeVar, Union + +from typing_extensions import ParamSpec + +P = ParamSpec("P") +R = TypeVar("R") + +SyncOrAsync = Callable[P, Union[R, Awaitable[R]]] diff --git a/libs/langchain_v1/langchain/agents/interrupt.py b/libs/langchain_v1/langchain/agents/interrupt.py new file mode 100644 index 00000000000..fda69363774 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/interrupt.py @@ -0,0 +1,92 @@ +"""Interrupt types to use with agent inbox like setups.""" + +from typing import Literal, Union + +from typing_extensions import TypedDict + + +class HumanInterruptConfig(TypedDict): + """Configuration that defines what actions are allowed for a human interrupt. + + This controls the available interaction options when the graph is paused for human input. + + Attributes: + allow_ignore: Whether the human can choose to ignore/skip the current step + allow_respond: Whether the human can provide a text response/feedback + allow_edit: Whether the human can edit the provided content/state + allow_accept: Whether the human can accept/approve the current state + """ + + allow_ignore: bool + allow_respond: bool + allow_edit: bool + allow_accept: bool + + +class ActionRequest(TypedDict): + """Represents a request for human action within the graph execution. + + Contains the action type and any associated arguments needed for the action. + + Attributes: + action: The type or name of action being requested (e.g., "Approve XYZ action") + args: Key-value pairs of arguments needed for the action + """ + + action: str + args: dict + + +class HumanInterrupt(TypedDict): + """Represents an interrupt triggered by the graph that requires human intervention. + + This is passed to the `interrupt` function when execution is paused for human input. + + Attributes: + action_request: The specific action being requested from the human + config: Configuration defining what actions are allowed + description: Optional detailed description of what input is needed + + Example: + ```python + # Extract a tool call from the state and create an interrupt request + request = HumanInterrupt( + action_request=ActionRequest( + action="run_command", # The action being requested + args={"command": "ls", "args": ["-l"]} # Arguments for the action + ), + config=HumanInterruptConfig( + allow_ignore=True, # Allow skipping this step + allow_respond=True, # Allow text feedback + allow_edit=False, # Don't allow editing + allow_accept=True # Allow direct acceptance + ), + description="Please review the command before execution" + ) + # Send the interrupt request and get the response + response = interrupt([request])[0] + ``` + """ + + action_request: ActionRequest + config: HumanInterruptConfig + description: str | None + + +class HumanResponse(TypedDict): + """The response provided by a human to an interrupt, which is returned when graph execution resumes. + + Attributes: + type: The type of response: + - "accept": Approves the current state without changes + - "ignore": Skips/ignores the current step + - "response": Provides text feedback or instructions + - "edit": Modifies the current state/content + args: The response payload: + - None: For ignore/accept actions + - str: For text responses + - ActionRequest: For edit actions with updated content + """ + + type: Literal["accept", "ignore", "response", "edit"] + args: Union[None, str, ActionRequest] diff --git a/libs/langchain_v1/langchain/agents/react_agent.py b/libs/langchain_v1/langchain/agents/react_agent.py new file mode 100644 index 00000000000..045d92e97dc --- /dev/null +++ b/libs/langchain_v1/langchain/agents/react_agent.py @@ -0,0 +1,1175 @@ +from __future__ import annotations + +import inspect +from collections.abc import Awaitable, Callable, Sequence +from dataclasses import asdict, is_dataclass +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Generic, + Literal, + Union, + cast, + get_type_hints, +) +from warnings import warn + +from langchain_core.language_models import ( + BaseChatModel, + LanguageModelInput, + LanguageModelLike, +) +from langchain_core.messages import ( + AIMessage, + AnyMessage, + BaseMessage, + SystemMessage, + ToolCall, + ToolMessage, +) +from langchain_core.runnables import ( + Runnable, + RunnableConfig, +) +from langgraph._internal._runnable import RunnableCallable, RunnableLike +from langgraph._internal._typing import MISSING +from langgraph.errors import ErrorCode, create_error_message +from langgraph.graph import END, StateGraph +from langgraph.graph.message import add_messages +from langgraph.managed import RemainingSteps # noqa: TC002 +from langgraph.types import Checkpointer, Command, Send +from langgraph.typing import ContextT, StateT +from pydantic import BaseModel +from typing_extensions import NotRequired, TypedDict, TypeVar + +from langchain.agents.structured_output import ( + MultipleStructuredOutputsError, + OutputToolBinding, + ProviderStrategy, + ProviderStrategyBinding, + ResponseFormat, + StructuredOutputValidationError, + ToolStrategy, +) +from langchain.agents.tool_node import ToolNode +from langchain.chat_models import init_chat_model + +if TYPE_CHECKING: + from langchain_core.tools import BaseTool + from langgraph.graph.state import CompiledStateGraph + from langgraph.runtime import Runtime + from langgraph.store.base import BaseStore + + from langchain.agents._internal._typing import ( + SyncOrAsync, + ) + +StructuredResponseT = TypeVar("StructuredResponseT", default=None) + +STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes." + + +class AgentState(TypedDict): + """The state of the agent.""" + + messages: Annotated[Sequence[BaseMessage], add_messages] + + remaining_steps: NotRequired[RemainingSteps] + + +class AgentStatePydantic(BaseModel): + """The state of the agent.""" + + messages: Annotated[Sequence[BaseMessage], add_messages] + + remaining_steps: RemainingSteps = 25 + + +class AgentStateWithStructuredResponse(AgentState, Generic[StructuredResponseT]): + """The state of the agent with a structured response.""" + + structured_response: StructuredResponseT + + +class AgentStateWithStructuredResponsePydantic(AgentStatePydantic, Generic[StructuredResponseT]): + """The state of the agent with a structured response.""" + + structured_response: StructuredResponseT + + +PROMPT_RUNNABLE_NAME = "Prompt" + +Prompt = Union[ + SystemMessage, + str, + Callable[[StateT], LanguageModelInput], + Runnable[StateT, LanguageModelInput], +] + + +def _get_state_value(state: StateT, key: str, default: Any = None) -> Any: + return state.get(key, default) if isinstance(state, dict) else getattr(state, key, default) + + +def _get_prompt_runnable(prompt: Prompt | None) -> Runnable: + prompt_runnable: Runnable + if prompt is None: + prompt_runnable = RunnableCallable( + lambda state: _get_state_value(state, "messages"), name=PROMPT_RUNNABLE_NAME + ) + elif isinstance(prompt, str): + _system_message: BaseMessage = SystemMessage(content=prompt) + prompt_runnable = RunnableCallable( + lambda state: [_system_message, *_get_state_value(state, "messages")], + name=PROMPT_RUNNABLE_NAME, + ) + elif isinstance(prompt, SystemMessage): + prompt_runnable = RunnableCallable( + lambda state: [prompt, *_get_state_value(state, "messages")], + name=PROMPT_RUNNABLE_NAME, + ) + elif inspect.iscoroutinefunction(prompt): + prompt_runnable = RunnableCallable( + None, + prompt, + name=PROMPT_RUNNABLE_NAME, + ) + elif callable(prompt): + prompt_runnable = RunnableCallable( + prompt, + name=PROMPT_RUNNABLE_NAME, + ) + elif isinstance(prompt, Runnable): + prompt_runnable = prompt + else: + msg = f"Got unexpected type for `prompt`: {type(prompt)}" + raise ValueError(msg) + + return prompt_runnable + + +def _validate_chat_history( + messages: Sequence[BaseMessage], +) -> None: + """Validate that all tool calls in AIMessages have a corresponding ToolMessage.""" + all_tool_calls = [ + tool_call + for message in messages + if isinstance(message, AIMessage) + for tool_call in message.tool_calls + ] + tool_call_ids_with_results = { + message.tool_call_id for message in messages if isinstance(message, ToolMessage) + } + tool_calls_without_results = [ + tool_call + for tool_call in all_tool_calls + if tool_call["id"] not in tool_call_ids_with_results + ] + if not tool_calls_without_results: + return + + error_message = create_error_message( + message="Found AIMessages with tool_calls that do not have a corresponding ToolMessage. " + f"Here are the first few of those tool calls: {tool_calls_without_results[:3]}.\n\n" + "Every tool call (LLM requesting to call a tool) in the message history MUST have a corresponding ToolMessage " + "(result of a tool invocation to return to the LLM) - this is required by most LLM providers.", + error_code=ErrorCode.INVALID_CHAT_HISTORY, + ) + raise ValueError(error_message) + + +class _AgentBuilder(Generic[StateT, ContextT, StructuredResponseT]): + """Internal builder class for constructing and agent.""" + + def __init__( + self, + model: Union[ + str, + BaseChatModel, + SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel], + ], + tools: Union[Sequence[Union[BaseTool, Callable, dict[str, Any]]], ToolNode], + *, + prompt: Prompt | None = None, + response_format: ResponseFormat[StructuredResponseT] | None = None, + pre_model_hook: RunnableLike | None = None, + post_model_hook: RunnableLike | None = None, + state_schema: type[StateT] | None = None, + context_schema: type[ContextT] | None = None, + version: Literal["v1", "v2"] = "v2", + name: str | None = None, + store: BaseStore | None = None, + ) -> None: + self.model = model + self.tools = tools + self.prompt = prompt + self.response_format = response_format + self.pre_model_hook = pre_model_hook + self.post_model_hook = post_model_hook + self.state_schema = state_schema + self.context_schema = context_schema + self.version = version + self.name = name + self.store = store + + if isinstance(model, Runnable) and not isinstance(model, BaseChatModel): + msg = ( + "Expected `model` to be a BaseChatModel or a string, got {type(model)}." + "The `model` parameter should not have pre-bound tools, simply pass the model and tools separately." + ) + raise ValueError(msg) + + self._setup_tools() + self._setup_state_schema() + self._setup_structured_output() + self._setup_model() + + def _setup_tools(self) -> None: + """Setup tool-related attributes.""" + if isinstance(self.tools, ToolNode): + self._tool_classes = list(self.tools.tools_by_name.values()) + self._tool_node = self.tools + self._llm_builtin_tools = [] + else: + self._llm_builtin_tools = [t for t in self.tools if isinstance(t, dict)] + self._tool_node = ToolNode([t for t in self.tools if not isinstance(t, dict)]) + self._tool_classes = list(self._tool_node.tools_by_name.values()) + + self._should_return_direct = {t.name for t in self._tool_classes if t.return_direct} + self._tool_calling_enabled = len(self._tool_classes) > 0 + + def _setup_structured_output(self) -> None: + """Set up structured output tracking for "tools" and "native" strategies. + + "tools" strategy for structured output: + 1. Converting response format schemas to LangChain tools + 2. Creating metadata for proper response reconstruction + 3. Handling both Pydantic models and dict schemas + + "native" strategy for structured output: + 1. Capturing the schema reference for later parsing + 2. Binding provider-native response_format kwargs at model bind time + 3. Parsing provider-enforced structured output directly into the schema + """ + self.structured_output_tools: dict[str, OutputToolBinding[StructuredResponseT]] = {} + self.native_output_binding: ProviderStrategyBinding[StructuredResponseT] | None = None + + if self.response_format is not None: + response_format = self.response_format + + if isinstance(response_format, ToolStrategy): + # check if response_format.schema is a union + for response_schema in response_format.schema_specs: + structured_tool_info = OutputToolBinding.from_schema_spec(response_schema) + self.structured_output_tools[structured_tool_info.tool.name] = ( + structured_tool_info + ) + elif isinstance(response_format, ProviderStrategy): + # Use native strategy - create ProviderStrategyBinding for parsing + self.native_output_binding = ProviderStrategyBinding.from_schema_spec( + response_format.schema_spec + ) + else: + # This shouldn't happen with the new ResponseFormat type, but keeping for safety + msg = ( + f"Unsupported response_format type: {type(response_format)}. " + f"Expected ToolStrategy." + ) + raise ValueError(msg) + + def _setup_state_schema(self) -> None: + """Setup state schema with validation.""" + if self.state_schema is not None: + required_keys = {"messages", "remaining_steps"} + if self.response_format is not None: + required_keys.add("structured_response") + + schema_keys = set(get_type_hints(self.state_schema)) + if missing_keys := required_keys - schema_keys: + msg = f"Missing required key(s) {missing_keys} in state_schema" + raise ValueError(msg) + + self._final_state_schema = self.state_schema + else: + self._final_state_schema = ( + AgentStateWithStructuredResponse # type: ignore[assignment] + if self.response_format is not None + else AgentState + ) + + def _handle_structured_response_tool_calls(self, response: AIMessage) -> Command | None: + """Handle tool calls that match structured output tools using the tools strategy. + + Args: + response: The AI message containing potential tool calls + + Returns: + Command with structured response update if found, None otherwise + + Raises: + MultipleStructuredOutputsError: If multiple structured responses are returned and error handling is disabled + StructuredOutputParsingError: If parsing fails and error handling is disabled + """ + if not isinstance(self.response_format, ToolStrategy) or not response.tool_calls: + return None + + structured_tool_calls = [ + tool_call + for tool_call in response.tool_calls + if tool_call["name"] in self.structured_output_tools + ] + + if not structured_tool_calls: + return None + + if len(structured_tool_calls) > 1: + return self._handle_multiple_structured_outputs(response, structured_tool_calls) + + return self._handle_single_structured_output(response, structured_tool_calls[0]) + + def _handle_multiple_structured_outputs( + self, + response: AIMessage, + structured_tool_calls: list[ToolCall], + ) -> Command: + """Handle multiple structured output tool calls.""" + tool_names = [tool_call["name"] for tool_call in structured_tool_calls] + exception = MultipleStructuredOutputsError(tool_names) + + should_retry, error_message = self._handle_structured_output_error(exception) + + if not should_retry: + raise exception + + tool_messages = [ + ToolMessage( + content=error_message, + tool_call_id=tool_call["id"], + name=tool_call["name"], + ) + for tool_call in structured_tool_calls + ] + + return Command( + update={"messages": [response, *tool_messages]}, + goto="agent", + ) + + def _handle_single_structured_output( + self, + response: AIMessage, + tool_call: Any, + ) -> Command: + """Handle a single structured output tool call.""" + structured_tool_binding = self.structured_output_tools[tool_call["name"]] + + try: + structured_response = structured_tool_binding.parse(tool_call["args"]) + + if isinstance(structured_response, BaseModel): + structured_response_dict = structured_response.model_dump() + elif is_dataclass(structured_response): + structured_response_dict = asdict(structured_response) # type: ignore[arg-type] + else: + structured_response_dict = cast("dict", structured_response) + + tool_message_content = ( + self.response_format.tool_message_content + if isinstance(self.response_format, ToolStrategy) + and self.response_format.tool_message_content + else f"Returning structured response: {structured_response_dict}" + ) + + return Command( + update={ + "messages": [ + response, + ToolMessage( + content=tool_message_content, + tool_call_id=tool_call["id"], + name=tool_call["name"], + ), + ], + "structured_response": structured_response, + } + ) + except Exception as exc: # noqa: BLE001 + exception = StructuredOutputValidationError(tool_call["name"], exc) + + should_retry, error_message = self._handle_structured_output_error(exception) + + if not should_retry: + raise exception + + return Command( + update={ + "messages": [ + response, + ToolMessage( + content=error_message, + tool_call_id=tool_call["id"], + name=tool_call["name"], + ), + ], + }, + goto="agent", + ) + + def _handle_structured_output_error( + self, + exception: Exception, + ) -> tuple[bool, str]: + """Handle structured output error. + + Returns (should_retry, retry_tool_message). + """ + handle_errors = cast("ToolStrategy", self.response_format).handle_errors + + if handle_errors is False: + return False, "" + if handle_errors is True: + return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception)) + if isinstance(handle_errors, str): + return True, handle_errors + if isinstance(handle_errors, type) and issubclass(handle_errors, Exception): + if isinstance(exception, handle_errors): + return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception)) + return False, "" + if isinstance(handle_errors, tuple): + if any(isinstance(exception, exc_type) for exc_type in handle_errors): + return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception)) + return False, "" + if callable(handle_errors): + return True, handle_errors(exception) # type: ignore[call-arg, return-value] + return False, "" + + def _apply_native_output_binding(self, model: LanguageModelLike) -> LanguageModelLike: + """If native output is configured, bind provider-native kwargs onto the model.""" + if not isinstance(self.response_format, ProviderStrategy): + return model + kwargs = self.response_format.to_model_kwargs() + return model.bind(**kwargs) + + def _handle_structured_response_native(self, response: AIMessage) -> Command | None: + """If native output is configured and there are no tool calls, parse using ProviderStrategyBinding.""" + if self.native_output_binding is None: + return None + if response.tool_calls: + # if the model chooses to call tools, we let the normal flow handle it + return None + + structured_response = self.native_output_binding.parse(response) + + return Command(update={"messages": [response], "structured_response": structured_response}) + + def _setup_model(self) -> None: + """Setup model-related attributes.""" + self._is_dynamic_model = not isinstance(self.model, (str, Runnable)) and callable( + self.model + ) + self._is_async_dynamic_model = self._is_dynamic_model and inspect.iscoroutinefunction( + self.model + ) + + if not self._is_dynamic_model: + model = self.model + if isinstance(model, str): + model = init_chat_model(model) + + # Collect all tools: regular tools + structured output tools + structured_output_tools = list(self.structured_output_tools.values()) + all_tools = ( + self._tool_classes + + self._llm_builtin_tools + + [info.tool for info in structured_output_tools] + ) + + if len(all_tools) > 0: + # Check if we need to force tool use for structured output + tool_choice = None + if self.response_format is not None and isinstance( + self.response_format, ToolStrategy + ): + tool_choice = "any" + + if tool_choice: + model = cast("BaseChatModel", model).bind_tools( # type: ignore[assignment] + all_tools, tool_choice=tool_choice + ) + # If native output is configured, bind tools with strict=True. Required for OpenAI. + elif isinstance(self.response_format, ProviderStrategy): + model = cast("BaseChatModel", model).bind_tools( # type: ignore[assignment] + all_tools, strict=True + ) + else: + model = cast("BaseChatModel", model).bind_tools(all_tools) # type: ignore[assignment] + + # bind native structured-output kwargs + model = self._apply_native_output_binding(model) # type: ignore[assignment, arg-type] + + # Extract just the model part for direct invocation + self._static_model: Runnable | None = model # type: ignore[assignment] + else: + self._static_model = None + + def _resolve_model(self, state: StateT, runtime: Runtime[ContextT]) -> LanguageModelLike: + """Resolve the model to use, handling both static and dynamic models.""" + if self._is_dynamic_model: + dynamic_model = self.model(state, runtime) # type: ignore[operator, arg-type] + return self._apply_native_output_binding(dynamic_model) # type: ignore[arg-type] + return self._static_model # type: ignore[return-value] + + async def _aresolve_model(self, state: StateT, runtime: Runtime[ContextT]) -> LanguageModelLike: + """Async resolve the model to use, handling both static and dynamic models.""" + if self._is_async_dynamic_model: + dynamic_model = cast( + "Callable[[StateT, Runtime[ContextT]], Awaitable[BaseChatModel]]", + self.model, + ) + return await dynamic_model(state, runtime) + if self._is_dynamic_model: + dynamic_model = self.model(state, runtime) # type: ignore[arg-type, assignment, operator] + return self._apply_native_output_binding(dynamic_model) # type: ignore[arg-type] + return self._static_model # type: ignore[return-value] + + def create_model_node(self) -> RunnableCallable: + """Create the 'agent' node that calls the LLM.""" + + def _get_model_input_state(state: StateT) -> StateT: + if self.pre_model_hook is not None: + messages = _get_state_value(state, "llm_input_messages") or _get_state_value( + state, "messages" + ) + error_msg = ( + f"Expected input to call_model to have 'llm_input_messages' " + f"or 'messages' key, but got {state}" + ) + else: + messages = _get_state_value(state, "messages") + error_msg = f"Expected input to call_model to have 'messages' key, but got {state}" + + if messages is None: + raise ValueError(error_msg) + + _validate_chat_history(messages) + + if isinstance(self._final_state_schema, type) and issubclass( + self._final_state_schema, BaseModel + ): + # we're passing messages under `messages` key, as this + # is expected by the prompt + state.messages = messages # type: ignore[union-attr] + else: + state["messages"] = messages # type: ignore[index] + return state + + def _are_more_steps_needed(state: StateT, response: BaseMessage) -> bool: + has_tool_calls = isinstance(response, AIMessage) and response.tool_calls + all_tools_return_direct = ( + all(call["name"] in self._should_return_direct for call in response.tool_calls) + if isinstance(response, AIMessage) + else False + ) + remaining_steps = _get_state_value(state, "remaining_steps", None) + return ( + remaining_steps is not None # type: ignore[return-value] + and ( + (remaining_steps < 1 and all_tools_return_direct) + or (remaining_steps < 2 and has_tool_calls) + ) + ) + + def call_model( + state: StateT, runtime: Runtime[ContextT], config: RunnableConfig + ) -> dict[str, Any] | Command: + """Call the model with the current state and return the response.""" + if self._is_async_dynamic_model: + msg = ( + "Async model callable provided but agent invoked synchronously. " + "Use agent.ainvoke() or agent.astream(), or provide a sync model callable." + ) + raise RuntimeError(msg) + + model_input = _get_model_input_state(state) + model = self._resolve_model(state, runtime) + + # Get prompt runnable and invoke it first to prepare messages + prompt_runnable = _get_prompt_runnable(self.prompt) + prepared_messages = prompt_runnable.invoke(model_input, config) + + # Then invoke the model with the prepared messages + response = cast("AIMessage", model.invoke(prepared_messages, config)) + response.name = self.name + + if _are_more_steps_needed(state, response): + return { + "messages": [ + AIMessage( + id=response.id, + content="Sorry, need more steps to process this request.", + ) + ] + } + + # Check if any tool calls match structured output tools + structured_command = self._handle_structured_response_tool_calls(response) + if structured_command: + return structured_command + + # Native structured output + native_command = self._handle_structured_response_native(response) + if native_command: + return native_command + + return {"messages": [response]} + + async def acall_model( + state: StateT, runtime: Runtime[ContextT], config: RunnableConfig + ) -> dict[str, Any] | Command: + """Call the model with the current state and return the response.""" + model_input = _get_model_input_state(state) + + model = await self._aresolve_model(state, runtime) + + # Get prompt runnable and invoke it first to prepare messages + prompt_runnable = _get_prompt_runnable(self.prompt) + prepared_messages = await prompt_runnable.ainvoke(model_input, config) + + # Then invoke the model with the prepared messages + response = cast( + "AIMessage", + await model.ainvoke(prepared_messages, config), + ) + response.name = self.name + if _are_more_steps_needed(state, response): + return { + "messages": [ + AIMessage( + id=response.id, + content="Sorry, need more steps to process this request.", + ) + ] + } + + # Check if any tool calls match structured output tools + structured_command = self._handle_structured_response_tool_calls(response) + if structured_command: + return structured_command + + # Native structured output + native_command = self._handle_structured_response_native(response) + if native_command: + return native_command + + return {"messages": [response]} + + return RunnableCallable(call_model, acall_model) + + def _get_input_schema(self) -> type[StateT]: + """Get input schema for model node.""" + if self.pre_model_hook is not None: + if isinstance(self._final_state_schema, type) and issubclass( + self._final_state_schema, BaseModel + ): + from pydantic import create_model + + return create_model( + "CallModelInputSchema", + llm_input_messages=(list[AnyMessage], ...), + __base__=self._final_state_schema, + ) + + class CallModelInputSchema(self._final_state_schema): # type: ignore[name-defined, misc] + llm_input_messages: list[AnyMessage] + + return CallModelInputSchema + return self._final_state_schema + + def create_model_router(self) -> Callable[[StateT], Union[str, list[Send]]]: + """Create routing function for model node conditional edges.""" + + def should_continue(state: StateT) -> Union[str, list[Send]]: + messages = _get_state_value(state, "messages") + last_message = messages[-1] + + # Check if the last message is a ToolMessage from a structured tool. + # This condition exists to support structured output via tools. + # Once a tool has been called for structured output, we skip + # tool execution and go to END (if there is no post_model_hook). + if ( + isinstance(last_message, ToolMessage) + and last_message.name in self.structured_output_tools + ): + return END + + if isinstance(last_message, ToolMessage): + return END + + if not isinstance(last_message, AIMessage) or not last_message.tool_calls: + if self.post_model_hook is not None: + return "post_model_hook" + return END + if self.version == "v1": + return "tools" + if self.version == "v2": + if self.post_model_hook is not None: + return "post_model_hook" + tool_calls = [ + self._tool_node.inject_tool_args(call, state, self.store) # type: ignore[arg-type] + for call in last_message.tool_calls + ] + return [Send("tools", [tool_call]) for tool_call in tool_calls] + return None + + return should_continue + + def create_post_model_hook_router( + self, + ) -> Callable[[StateT], Union[str, list[Send]]]: + """Create a routing function for post_model_hook node conditional edges.""" + + def post_model_hook_router(state: StateT) -> Union[str, list[Send]]: + messages = _get_state_value(state, "messages") + + # Check if the last message is a ToolMessage from a structured tool. + # This condition exists to support structured output via tools. + # Once a tool has been called for structured output, we skip + # tool execution and go to END (if there is no post_model_hook). + last_message = messages[-1] + if ( + isinstance(last_message, ToolMessage) + and last_message.name in self.structured_output_tools + ): + return END + + tool_messages = [m.tool_call_id for m in messages if isinstance(m, ToolMessage)] + last_ai_message = next(m for m in reversed(messages) if isinstance(m, AIMessage)) + pending_tool_calls = [ + c for c in last_ai_message.tool_calls if c["id"] not in tool_messages + ] + + if pending_tool_calls: + pending_tool_calls = [ + self._tool_node.inject_tool_args(call, state, self.store) # type: ignore[arg-type] + for call in pending_tool_calls + ] + return [Send("tools", [tool_call]) for tool_call in pending_tool_calls] + if isinstance(messages[-1], ToolMessage): + return self._get_entry_point() + return END + + return post_model_hook_router + + def create_tools_router(self) -> Callable[[StateT], str] | None: + """Create a routing function for tools node conditional edges.""" + if not self._should_return_direct: + return None + + def route_tool_responses(state: StateT) -> str: + messages = _get_state_value(state, "messages") + for m in reversed(messages): + if not isinstance(m, ToolMessage): + break + if m.name in self._should_return_direct: + return END + + if ( + isinstance(m, AIMessage) + and m.tool_calls + and any(call["name"] in self._should_return_direct for call in m.tool_calls) + ): + return END + + return self._get_entry_point() + + return route_tool_responses + + def _get_entry_point(self) -> str: + """Get the workflow entry point.""" + return "pre_model_hook" if self.pre_model_hook else "agent" + + def _get_model_paths(self) -> list[str]: + """Get possible edge destinations from model node.""" + paths = [] + if self._tool_calling_enabled: + paths.append("tools") + if self.post_model_hook: + paths.append("post_model_hook") + else: + paths.append(END) + + return paths + + def _get_post_model_hook_paths(self) -> list[str]: + """Get possible edge destinations from post_model_hook node.""" + paths = [] + if self._tool_calling_enabled: + paths = [self._get_entry_point(), "tools"] + paths.append(END) + return paths + + def build(self) -> StateGraph[StateT, ContextT]: + """Build the agent workflow graph (uncompiled).""" + workflow = StateGraph( + state_schema=self._final_state_schema, + context_schema=self.context_schema, + ) + + # Set entry point + workflow.set_entry_point(self._get_entry_point()) + + # Add nodes + workflow.add_node("agent", self.create_model_node(), input_schema=self._get_input_schema()) + + if self._tool_calling_enabled: + workflow.add_node("tools", self._tool_node) + + if self.pre_model_hook: + workflow.add_node("pre_model_hook", self.pre_model_hook) # type: ignore[arg-type] + + if self.post_model_hook: + workflow.add_node("post_model_hook", self.post_model_hook) # type: ignore[arg-type] + + # Add edges + if self.pre_model_hook: + workflow.add_edge("pre_model_hook", "agent") + + if self.post_model_hook: + workflow.add_edge("agent", "post_model_hook") + post_hook_paths = self._get_post_model_hook_paths() + if len(post_hook_paths) == 1: + # No need for a conditional edge if there's only one path + workflow.add_edge("post_model_hook", post_hook_paths[0]) + else: + workflow.add_conditional_edges( + "post_model_hook", + self.create_post_model_hook_router(), + path_map=post_hook_paths, + ) + else: + model_paths = self._get_model_paths() + if len(model_paths) == 1: + # No need for a conditional edge if there's only one path + workflow.add_edge("agent", model_paths[0]) + else: + workflow.add_conditional_edges( + "agent", + self.create_model_router(), + path_map=model_paths, + ) + + if self._tool_calling_enabled: + # In some cases, tools can return directly. In these cases + # we add a conditional edge from the tools node to the END node + # instead of going to the entry point. + tools_router = self.create_tools_router() + if tools_router: + workflow.add_conditional_edges( + "tools", + tools_router, + path_map=[self._get_entry_point(), END], + ) + else: + workflow.add_edge("tools", self._get_entry_point()) + + return workflow + + +def _supports_native_structured_output( + model: Union[str, BaseChatModel, SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel]], +) -> bool: + """Check if a model supports native structured output. + + TODO: replace with more robust model profiles. + """ + model_name: str | None = None + if isinstance(model, str): + model_name = model + elif isinstance(model, BaseChatModel): + model_name = getattr(model, "model_name", None) + + return ( + "grok" in model_name.lower() + or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"]) + if model_name + else False + ) + + +def create_react_agent( # noqa: D417 + model: Union[ + str, + BaseChatModel, + SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel], + ], + tools: Union[Sequence[Union[BaseTool, Callable, dict[str, Any]]], ToolNode], + *, + prompt: Prompt | None = None, + response_format: Union[ + ToolStrategy[StructuredResponseT], + ProviderStrategy[StructuredResponseT], + type[StructuredResponseT], + ] + | None = None, + pre_model_hook: RunnableLike | None = None, + post_model_hook: RunnableLike | None = None, + state_schema: type[StateT] | None = None, + context_schema: type[ContextT] | None = None, + checkpointer: Checkpointer | None = None, + store: BaseStore | None = None, + interrupt_before: list[str] | None = None, + interrupt_after: list[str] | None = None, + debug: bool = False, + version: Literal["v1", "v2"] = "v2", + name: str | None = None, + **deprecated_kwargs: Any, +) -> CompiledStateGraph[StateT, ContextT]: + """Creates an agent graph that calls tools in a loop until a stopping condition is met. + + For more details on using `create_react_agent`, visit [Agents](https://langchain-ai.github.io/langgraph/agents/overview/) documentation. + + Args: + model: The language model for the agent. Supports static and dynamic + model selection. + + - **Static model**: A chat model instance (e.g., `ChatOpenAI()`) or + string identifier (e.g., `"openai:gpt-4"`) + - **Dynamic model**: A callable with signature + `(state, runtime) -> BaseChatModel` that returns different models + based on runtime context + If the model has tools bound via `.bind_tools()` or other configurations, + the return type should be a Runnable[LanguageModelInput, BaseMessage] + Coroutines are also supported, allowing for asynchronous model selection. + + Dynamic functions receive graph state and runtime, enabling + context-dependent model selection. Must return a `BaseChatModel` + instance. For tool calling, bind tools using `.bind_tools()`. + Bound tools must be a subset of the `tools` parameter. + + Dynamic model example: + ```python + from dataclasses import dataclass + + @dataclass + class ModelContext: + model_name: str = "gpt-3.5-turbo" + + # Instantiate models globally + gpt4_model = ChatOpenAI(model="gpt-4") + gpt35_model = ChatOpenAI(model="gpt-3.5-turbo") + + def select_model(state: AgentState, runtime: Runtime[ModelContext]) -> ChatOpenAI: + model_name = runtime.context.model_name + model = gpt4_model if model_name == "gpt-4" else gpt35_model + return model.bind_tools(tools) + ``` + + !!! note "Dynamic Model Requirements" + Ensure returned models have appropriate tools bound via + `.bind_tools()` and support required functionality. Bound tools + must be a subset of those specified in the `tools` parameter. + + tools: A list of tools or a ToolNode instance. + If an empty list is provided, the agent will consist of a single LLM node without tool calling. + prompt: An optional prompt for the LLM. Can take a few different forms: + + - str: This is converted to a SystemMessage and added to the beginning of the list of messages in state["messages"]. + - SystemMessage: this is added to the beginning of the list of messages in state["messages"]. + - Callable: This function should take in full graph state and the output is then passed to the language model. + - Runnable: This runnable should take in full graph state and the output is then passed to the language model. + + response_format: An optional UsingToolStrategy configuration for structured responses. + + If provided, the agent will handle structured output via tool calls during the normal conversation flow. + When the model calls a structured output tool, the response will be captured and returned in the 'structured_response' state key. + If not provided, `structured_response` will not be present in the output state. + + The UsingToolStrategy should contain: + - schemas: A sequence of ResponseSchema objects that define the structured output format + - tool_choice: Either "required" or "auto" to control when structured output is used + + Each ResponseSchema contains: + - schema: A Pydantic model that defines the structure + - name: Optional custom name for the tool (defaults to model name) + - description: Optional custom description (defaults to model docstring) + - strict: Whether to enforce strict validation + + !!! Important + `response_format` requires the model to support tool calling + + !!! Note + Structured responses are handled directly in the model call node via tool calls, eliminating the need for separate structured response nodes. + + pre_model_hook: An optional node to add before the `agent` node (i.e., the node that calls the LLM). + Useful for managing long message histories (e.g., message trimming, summarization, etc.). + Pre-model hook must be a callable or a runnable that takes in current graph state and returns a state update in the form of + ```python + # At least one of `messages` or `llm_input_messages` MUST be provided + { + # If provided, will UPDATE the `messages` in the state + "messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), ...], + # If provided, will be used as the input to the LLM, + # and will NOT UPDATE `messages` in the state + "llm_input_messages": [...], + # Any other state keys that need to be propagated + ... + } + ``` + + !!! Important + At least one of `messages` or `llm_input_messages` MUST be provided and will be used as an input to the `agent` node. + The rest of the keys will be added to the graph state. + + !!! Warning + If you are returning `messages` in the pre-model hook, you should OVERWRITE the `messages` key by doing the following: + + ```python + { + "messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), *new_messages] + ... + } + ``` + post_model_hook: An optional node to add after the `agent` node (i.e., the node that calls the LLM). + Useful for implementing human-in-the-loop, guardrails, validation, or other post-processing. + Post-model hook must be a callable or a runnable that takes in current graph state and returns a state update. + + !!! Note + Only available with `version="v2"`. + state_schema: An optional state schema that defines graph state. + Must have `messages` and `remaining_steps` keys. + Defaults to `AgentState` that defines those two keys. + context_schema: An optional schema for runtime context. + checkpointer: An optional checkpoint saver object. This is used for persisting + the state of the graph (e.g., as chat memory) for a single thread (e.g., a single conversation). + store: An optional store object. This is used for persisting data + across multiple threads (e.g., multiple conversations / users). + interrupt_before: An optional list of node names to interrupt before. + Should be one of the following: "agent", "tools". + This is useful if you want to add a user confirmation or other interrupt before taking an action. + interrupt_after: An optional list of node names to interrupt after. + Should be one of the following: "agent", "tools". + This is useful if you want to return directly or run additional processing on an output. + debug: A flag indicating whether to enable debug mode. + version: Determines the version of the graph to create. + Can be one of: + + - `"v1"`: The tool node processes a single message. All tool + calls in the message are executed in parallel within the tool node. + - `"v2"`: The tool node processes a tool call. + Tool calls are distributed across multiple instances of the tool + node using the [Send](https://langchain-ai.github.io/langgraph/concepts/low_level/#send) + API. + name: An optional name for the CompiledStateGraph. + This name will be automatically used when adding ReAct agent graph to another graph as a subgraph node - + particularly useful for building multi-agent systems. + + !!! warning "`config_schema` Deprecated" + The `config_schema` parameter is deprecated in v0.6.0 and support will be removed in v2.0.0. + Please use `context_schema` instead to specify the schema for run-scoped context. + + + Returns: + A compiled LangChain runnable that can be used for chat interactions. + + The "agent" node calls the language model with the messages list (after applying the prompt). + If the resulting AIMessage contains `tool_calls`, the graph will then call the ["tools"][langgraph.prebuilt.tool_node.ToolNode]. + The "tools" node executes the tools (1 tool per `tool_call`) and adds the responses to the messages list + as `ToolMessage` objects. The agent node then calls the language model again. + The process repeats until no more `tool_calls` are present in the response. + The agent then returns the full list of messages as a dictionary containing the key "messages". + + ``` mermaid + sequenceDiagram + participant U as User + participant A as LLM + participant T as Tools + U->>A: Initial input + Note over A: Prompt + LLM + loop while tool_calls present + A->>T: Execute tools + T-->>A: ToolMessage for each tool_calls + end + A->>U: Return final state + ``` + + Example: + ```python + from langchain.agents import create_react_agent + + def check_weather(location: str) -> str: + '''Return the weather forecast for the specified location.''' + return f"It's always sunny in {location}" + + graph = create_react_agent( + "anthropic:claude-3-7-sonnet-latest", + tools=[check_weather], + prompt="You are a helpful assistant", + ) + inputs = {"messages": [{"role": "user", "content": "what is the weather in sf"}]} + for chunk in graph.stream(inputs, stream_mode="updates"): + print(chunk) + ``` + """ + # Handle deprecated config_schema parameter + if (config_schema := deprecated_kwargs.pop("config_schema", MISSING)) is not MISSING: + warn( + "`config_schema` is deprecated and will be removed. Please use `context_schema` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + if context_schema is None: + context_schema = config_schema + + if len(deprecated_kwargs) > 0: + msg = f"create_react_agent() got unexpected keyword arguments: {deprecated_kwargs}" + raise TypeError(msg) + + if response_format and not isinstance(response_format, (ToolStrategy, ProviderStrategy)): + if _supports_native_structured_output(model): + response_format = ProviderStrategy( + schema=response_format, + ) + else: + response_format = ToolStrategy( + schema=response_format, + ) + elif isinstance(response_format, tuple) and len(response_format) == 2: + msg = "Passing a 2-tuple as response_format is no longer supported. " + raise ValueError(msg) + + # Create and configure the agent builder + builder = _AgentBuilder( + model=model, + tools=tools, + prompt=prompt, + response_format=cast("Union[ResponseFormat[StructuredResponseT], None]", response_format), + pre_model_hook=pre_model_hook, + post_model_hook=post_model_hook, + state_schema=state_schema, + context_schema=context_schema, + version=version, + name=name, + store=store, + ) + + # Build and compile the workflow + workflow = builder.build() + return workflow.compile( # type: ignore[return-value] + checkpointer=checkpointer, + store=store, + interrupt_before=interrupt_before, + interrupt_after=interrupt_after, + debug=debug, + name=name, + ) + + +__all__ = [ + "AgentState", + "AgentStatePydantic", + "AgentStateWithStructuredResponse", + "AgentStateWithStructuredResponsePydantic", + "create_react_agent", +] diff --git a/libs/langchain_v1/langchain/agents/structured_output.py b/libs/langchain_v1/langchain/agents/structured_output.py new file mode 100644 index 00000000000..dbbd3eef609 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/structured_output.py @@ -0,0 +1,403 @@ +"""Types for setting agent response formats.""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass, is_dataclass +from types import UnionType +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Literal, + TypeVar, + Union, + get_args, + get_origin, +) + +from langchain_core.tools import BaseTool, StructuredTool +from pydantic import BaseModel, TypeAdapter +from typing_extensions import Self, is_typeddict + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + + from langchain_core.messages import AIMessage + +# Supported schema types: Pydantic models, dataclasses, TypedDict, JSON schema dicts +SchemaT = TypeVar("SchemaT") + +SchemaKind = Literal["pydantic", "dataclass", "typeddict", "json_schema"] + + +class StructuredOutputError(Exception): + """Base class for structured output errors.""" + + +class MultipleStructuredOutputsError(StructuredOutputError): + """Raised when model returns multiple structured output tool calls when only one is expected.""" + + def __init__(self, tool_names: list[str]) -> None: + """Initialize MultipleStructuredOutputsError. + + Args: + tool_names: The names of the tools called for structured output. + """ + self.tool_names = tool_names + + super().__init__( + f"Model incorrectly returned multiple structured responses ({', '.join(tool_names)}) when only one is expected." + ) + + +class StructuredOutputValidationError(StructuredOutputError): + """Raised when structured output tool call arguments fail to parse according to the schema.""" + + def __init__(self, tool_name: str, source: Exception) -> None: + """Initialize StructuredOutputValidationError. + + Args: + tool_name: The name of the tool that failed. + source: The exception that occurred. + """ + self.tool_name = tool_name + self.source = source + super().__init__(f"Failed to parse structured output for tool '{tool_name}': {source}.") + + +def _parse_with_schema( + schema: Union[type[SchemaT], dict], schema_kind: SchemaKind, data: dict[str, Any] +) -> Any: + """Parse data using for any supported schema type. + + Args: + schema: The schema type (Pydantic model, dataclass, or TypedDict) + schema_kind: One of "pydantic", "dataclass", "typeddict", or "json_schema" + data: The data to parse + + Returns: + The parsed instance according to the schema type + + Raises: + ValueError: If parsing fails + """ + if schema_kind == "json_schema": + return data + try: + adapter: TypeAdapter[SchemaT] = TypeAdapter(schema) + return adapter.validate_python(data) + except Exception as e: + schema_name = getattr(schema, "__name__", str(schema)) + msg = f"Failed to parse data to {schema_name}: {e}" + raise ValueError(msg) from e + + +@dataclass(init=False) +class _SchemaSpec(Generic[SchemaT]): + """Describes a structured output schema.""" + + schema: type[SchemaT] + """The schema for the response, can be a Pydantic model, dataclass, TypedDict, or JSON schema dict.""" + + name: str + """Name of the schema, used for tool calling. + + If not provided, the name will be the model name or "response_format" if it's a JSON schema. + """ + + description: str + """Custom description of the schema. + + If not provided, provided will use the model's docstring. + """ + + schema_kind: SchemaKind + """The kind of schema.""" + + json_schema: dict[str, Any] + """JSON schema associated with the schema.""" + + strict: bool = False + """Whether to enforce strict validation of the schema.""" + + def __init__( + self, + schema: type[SchemaT], + *, + name: str | None = None, + description: str | None = None, + strict: bool = False, + ) -> None: + """Initialize SchemaSpec with schema and optional parameters.""" + self.schema = schema + + if name: + self.name = name + elif isinstance(schema, dict): + self.name = str(schema.get("title", f"response_format_{str(uuid.uuid4())[:4]}")) + else: + self.name = str(getattr(schema, "__name__", f"response_format_{str(uuid.uuid4())[:4]}")) + + self.description = description or ( + schema.get("description", "") + if isinstance(schema, dict) + else getattr(schema, "__doc__", None) or "" + ) + + self.strict = strict + + if isinstance(schema, dict): + self.schema_kind = "json_schema" + self.json_schema = schema + elif isinstance(schema, type) and issubclass(schema, BaseModel): + self.schema_kind = "pydantic" + self.json_schema = schema.model_json_schema() + elif is_dataclass(schema): + self.schema_kind = "dataclass" + self.json_schema = TypeAdapter(schema).json_schema() + elif is_typeddict(schema): + self.schema_kind = "typeddict" + self.json_schema = TypeAdapter(schema).json_schema() + else: + msg = ( + f"Unsupported schema type: {type(schema)}. " + f"Supported types: Pydantic models, dataclasses, TypedDicts, and JSON schema dicts." + ) + raise ValueError(msg) + + +@dataclass(init=False) +class ToolStrategy(Generic[SchemaT]): + """Use a tool calling strategy for model responses.""" + + schema: type[SchemaT] + """Schema for the tool calls.""" + + schema_specs: list[_SchemaSpec[SchemaT]] + """Schema specs for the tool calls.""" + + tool_message_content: str | None + """The content of the tool message to be returned when the model calls an artificial structured output tool.""" + + handle_errors: Union[ + bool, + str, + type[Exception], + tuple[type[Exception], ...], + Callable[[Exception], str], + ] + """Error handling strategy for structured output via ToolStrategy. Default is True. + + - True: Catch all errors with default error template + - str: Catch all errors with this custom message + - type[Exception]: Only catch this exception type with default message + - tuple[type[Exception], ...]: Only catch these exception types with default message + - Callable[[Exception], str]: Custom function that returns error message + - False: No retry, let exceptions propagate + """ + + def __init__( + self, + schema: type[SchemaT], + *, + tool_message_content: str | None = None, + handle_errors: Union[ + bool, + str, + type[Exception], + tuple[type[Exception], ...], + Callable[[Exception], str], + ] = True, + ) -> None: + """Initialize ToolStrategy with schemas, tool message content, and error handling strategy.""" + self.schema = schema + self.tool_message_content = tool_message_content + self.handle_errors = handle_errors + + def _iter_variants(schema: Any) -> Iterable[Any]: + """Yield leaf variants from Union and JSON Schema oneOf.""" + if get_origin(schema) in (UnionType, Union): + for arg in get_args(schema): + yield from _iter_variants(arg) + return + + if isinstance(schema, dict) and "oneOf" in schema: + for sub in schema.get("oneOf", []): + yield from _iter_variants(sub) + return + + yield schema + + self.schema_specs = [_SchemaSpec(s) for s in _iter_variants(schema)] + + +@dataclass(init=False) +class ProviderStrategy(Generic[SchemaT]): + """Use the model provider's native structured output method.""" + + schema: type[SchemaT] + """Schema for native mode.""" + + schema_spec: _SchemaSpec[SchemaT] + """Schema spec for native mode.""" + + def __init__( + self, + schema: type[SchemaT], + ) -> None: + """Initialize ProviderStrategy with schema.""" + self.schema = schema + self.schema_spec = _SchemaSpec(schema) + + def to_model_kwargs(self) -> dict[str, Any]: + """Convert to kwargs to bind to a model to force structured output.""" + # OpenAI: + # - see https://platform.openai.com/docs/guides/structured-outputs + response_format = { + "type": "json_schema", + "json_schema": { + "name": self.schema_spec.name, + "schema": self.schema_spec.json_schema, + }, + } + return {"response_format": response_format} + + +@dataclass +class OutputToolBinding(Generic[SchemaT]): + """Information for tracking structured output tool metadata. + + This contains all necessary information to handle structured responses + generated via tool calls, including the original schema, its type classification, + and the corresponding tool implementation used by the tools strategy. + """ + + schema: type[SchemaT] + """The original schema provided for structured output (Pydantic model, dataclass, TypedDict, or JSON schema dict).""" + + schema_kind: SchemaKind + """Classification of the schema type for proper response construction.""" + + tool: BaseTool + """LangChain tool instance created from the schema for model binding.""" + + @classmethod + def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self: + """Create an OutputToolBinding instance from a SchemaSpec. + + Args: + schema_spec: The SchemaSpec to convert + + Returns: + An OutputToolBinding instance with the appropriate tool created + """ + return cls( + schema=schema_spec.schema, + schema_kind=schema_spec.schema_kind, + tool=StructuredTool( + args_schema=schema_spec.json_schema, + name=schema_spec.name, + description=schema_spec.description, + ), + ) + + def parse(self, tool_args: dict[str, Any]) -> SchemaT: + """Parse tool arguments according to the schema. + + Args: + tool_args: The arguments from the tool call + + Returns: + The parsed response according to the schema type + + Raises: + ValueError: If parsing fails + """ + return _parse_with_schema(self.schema, self.schema_kind, tool_args) + + +@dataclass +class ProviderStrategyBinding(Generic[SchemaT]): + """Information for tracking native structured output metadata. + + This contains all necessary information to handle structured responses + generated via native provider output, including the original schema, + its type classification, and parsing logic for provider-enforced JSON. + """ + + schema: type[SchemaT] + """The original schema provided for structured output (Pydantic model, dataclass, TypedDict, or JSON schema dict).""" + + schema_kind: SchemaKind + """Classification of the schema type for proper response construction.""" + + @classmethod + def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self: + """Create a ProviderStrategyBinding instance from a SchemaSpec. + + Args: + schema_spec: The SchemaSpec to convert + + Returns: + A ProviderStrategyBinding instance for parsing native structured output + """ + return cls( + schema=schema_spec.schema, + schema_kind=schema_spec.schema_kind, + ) + + def parse(self, response: AIMessage) -> SchemaT: + """Parse AIMessage content according to the schema. + + Args: + response: The AI message containing the structured output + + Returns: + The parsed response according to the schema + + Raises: + ValueError: If text extraction, JSON parsing or schema validation fails + """ + # Extract text content from AIMessage and parse as JSON + raw_text = self._extract_text_content_from_message(response) + + import json + + try: + data = json.loads(raw_text) + except Exception as e: + schema_name = getattr(self.schema, "__name__", "response_format") + msg = f"Native structured output expected valid JSON for {schema_name}, but parsing failed: {e}." + raise ValueError(msg) from e + + # Parse according to schema + return _parse_with_schema(self.schema, self.schema_kind, data) + + def _extract_text_content_from_message(self, message: AIMessage) -> str: + """Extract text content from an AIMessage. + + Args: + message: The AI message to extract text from + + Returns: + The extracted text content + """ + content = message.content + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for c in content: + if isinstance(c, dict): + if c.get("type") == "text" and "text" in c: + parts.append(str(c["text"])) + elif "content" in c and isinstance(c["content"], str): + parts.append(c["content"]) + else: + parts.append(str(c)) + return "".join(parts) + return str(content) + + +ResponseFormat = Union[ToolStrategy[SchemaT], ProviderStrategy[SchemaT]] diff --git a/libs/langchain_v1/langchain/agents/tool_node.py b/libs/langchain_v1/langchain/agents/tool_node.py new file mode 100644 index 00000000000..185a74a291d --- /dev/null +++ b/libs/langchain_v1/langchain/agents/tool_node.py @@ -0,0 +1,1174 @@ +"""Tool execution node for LangGraph workflows. + +This module provides prebuilt functionality for executing tools in LangGraph. + +Tools are functions that models can call to interact with external systems, +APIs, databases, or perform computations. + +The module implements several key design patterns: +- Parallel execution of multiple tool calls for efficiency +- Robust error handling with customizable error messages +- State injection for tools that need access to graph state +- Store injection for tools that need persistent storage +- Command-based state updates for advanced control flow + +Key Components: + ToolNode: Main class for executing tools in LangGraph workflows + InjectedState: Annotation for injecting graph state into tools + InjectedStore: Annotation for injecting persistent store into tools + tools_condition: Utility function for conditional routing based on tool calls + +Typical Usage: + ```python + from langchain_core.tools import tool + from langchain.agents import ToolNode + + @tool + def my_tool(x: int) -> str: + return f"Result: {x}" + + tool_node = ToolNode([my_tool]) + ``` +""" + +from __future__ import annotations + +import asyncio +import inspect +import json +from copy import copy, deepcopy +from dataclasses import replace +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Literal, + Optional, + Union, + cast, + get_args, + get_origin, + get_type_hints, +) + +from langchain_core.messages import ( + AIMessage, + AnyMessage, + RemoveMessage, + ToolCall, + ToolMessage, + convert_to_messages, +) +from langchain_core.runnables.config import ( + get_config_list, + get_executor_for_config, +) +from langchain_core.tools import BaseTool, InjectedToolArg +from langchain_core.tools import tool as create_tool +from langchain_core.tools.base import ( + TOOL_MESSAGE_BLOCK_TYPES, + get_all_basemodel_annotations, +) +from langgraph._internal._runnable import RunnableCallable +from langgraph.errors import GraphBubbleUp +from langgraph.graph.message import REMOVE_ALL_MESSAGES +from langgraph.types import Command, Send +from pydantic import BaseModel, ValidationError + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from langchain_core.runnables import RunnableConfig + from langgraph.store.base import BaseStore + +INVALID_TOOL_NAME_ERROR_TEMPLATE = ( + "Error: {requested_tool} is not a valid tool, try one of [{available_tools}]." +) +TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes." +TOOL_EXECUTION_ERROR_TEMPLATE = "Error executing tool '{tool_name}' with kwargs {tool_kwargs} with error:\n {error}\n Please fix the error and try again." +TOOL_INVOCATION_ERROR_TEMPLATE = "Error invoking tool '{tool_name}' with kwargs {tool_kwargs} with error:\n {error}\n Please fix the error and try again." + + +def msg_content_output(output: Any) -> Union[str, list[dict]]: + """Convert tool output to valid message content format. + + LangChain ToolMessages accept either string content or a list of content blocks. + This function ensures tool outputs are properly formatted for message consumption + by attempting to preserve structured data when possible, falling back to JSON + serialization or string conversion. + + Args: + output: The raw output from a tool execution. Can be any type. + + Returns: + Either a string representation of the output or a list of content blocks + if the output is already in the correct format for structured content. + + Note: + This function prioritizes backward compatibility by defaulting to JSON + serialization rather than supporting all possible message content formats. + """ + if isinstance(output, str) or ( + isinstance(output, list) + and all(isinstance(x, dict) and x.get("type") in TOOL_MESSAGE_BLOCK_TYPES for x in output) + ): + return output + # Technically a list of strings is also valid message content, but it's + # not currently well tested that all chat models support this. + # And for backwards compatibility we want to make sure we don't break + # any existing ToolNode usage. + try: + return json.dumps(output, ensure_ascii=False) + except Exception: # noqa: BLE001 + return str(output) + + +class ToolInvocationError(Exception): + """Exception raised when a tool invocation fails due to invalid arguments.""" + + def __init__( + self, tool_name: str, source: ValidationError, tool_kwargs: dict[str, Any] + ) -> None: + """Initialize the ToolInvocationError. + + Args: + tool_name: The name of the tool that failed. + source: The exception that occurred. + tool_kwargs: The keyword arguments that were passed to the tool. + """ + self.message = TOOL_INVOCATION_ERROR_TEMPLATE.format( + tool_name=tool_name, tool_kwargs=tool_kwargs, error=source + ) + self.tool_name = tool_name + self.tool_kwargs = tool_kwargs + self.source = source + super().__init__(self.message) + + +def _default_handle_tool_errors(e: Exception) -> str: + """Default error handler for tool errors. + + If the tool is a tool invocation error, return its message. + Otherwise, raise the error. + """ + if isinstance(e, ToolInvocationError): + return e.message + raise e + + +def _handle_tool_error( + e: Exception, + *, + flag: Union[ + bool, + str, + Callable[..., str], + type[Exception], + tuple[type[Exception], ...], + ], +) -> str: + """Generate error message content based on exception handling configuration. + + This function centralizes error message generation logic, supporting different + error handling strategies configured via the ToolNode's handle_tool_errors + parameter. + + Args: + e: The exception that occurred during tool execution. + flag: Configuration for how to handle the error. Can be: + - bool: If True, use default error template + - str: Use this string as the error message + - Callable: Call this function with the exception to get error message + - tuple: Not used in this context (handled by caller) + + Returns: + A string containing the error message to include in the ToolMessage. + + Raises: + ValueError: If flag is not one of the supported types. + + Note: + The tuple case is handled by the caller through exception type checking, + not by this function directly. + """ + if isinstance(flag, (bool, tuple)) or (isinstance(flag, type) and issubclass(flag, Exception)): + content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e)) + elif isinstance(flag, str): + content = flag + elif callable(flag): + content = flag(e) # type: ignore [assignment, call-arg] + else: + msg = ( + f"Got unexpected type of `handle_tool_error`. Expected bool, str " + f"or callable. Received: {flag}" + ) + raise ValueError(msg) + return content + + +def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception], ...]: + """Infer exception types handled by a custom error handler function. + + This function analyzes the type annotations of a custom error handler to determine + which exception types it's designed to handle. This enables type-safe error handling + where only specific exceptions are caught and processed by the handler. + + Args: + handler: A callable that takes an exception and returns an error message string. + The first parameter (after self/cls if present) should be type-annotated + with the exception type(s) to handle. + + Returns: + A tuple of exception types that the handler can process. Returns (Exception,) + if no specific type information is available for backward compatibility. + + Raises: + ValueError: If the handler's annotation contains non-Exception types or + if Union types contain non-Exception types. + + Note: + This function supports both single exception types and Union types for + handlers that need to handle multiple exception types differently. + """ + sig = inspect.signature(handler) + params = list(sig.parameters.values()) + if params: + # If it's a method, the first argument is typically 'self' or 'cls' + if params[0].name in ["self", "cls"] and len(params) == 2: + first_param = params[1] + else: + first_param = params[0] + + type_hints = get_type_hints(handler) + if first_param.name in type_hints: + origin = get_origin(first_param.annotation) + if origin is Union: + args = get_args(first_param.annotation) + if all(issubclass(arg, Exception) for arg in args): + return tuple(args) + msg = ( + "All types in the error handler error annotation must be " + "Exception types. For example, " + "`def custom_handler(e: Union[ValueError, TypeError])`. " + f"Got '{first_param.annotation}' instead." + ) + raise ValueError(msg) + + exception_type = type_hints[first_param.name] + if Exception in exception_type.__mro__: + return (exception_type,) + msg = ( + f"Arbitrary types are not supported in the error handler " + f"signature. Please annotate the error with either a " + f"specific Exception type or a union of Exception types. " + "For example, `def custom_handler(e: ValueError)` or " + "`def custom_handler(e: Union[ValueError, TypeError])`. " + f"Got '{exception_type}' instead." + ) + raise ValueError(msg) + + # If no type information is available, return (Exception,) + # for backwards compatibility. + return (Exception,) + + +class ToolNode(RunnableCallable): + """A node for executing tools in LangGraph workflows. + + Handles tool execution patterns including function calls, state injection, + persistent storage, and control flow. Manages parallel execution, + error handling. + + Input Formats: + 1. Graph state with `messages` key that has a list of messages: + - Common representation for agentic workflows + - Supports custom messages key via ``messages_key`` parameter + + 2. **Message List**: ``[AIMessage(..., tool_calls=[...])]`` + - List of messages with tool calls in the last AIMessage + + 3. **Direct Tool Calls**: ``[{"name": "tool", "args": {...}, "id": "1", "type": "tool_call"}]`` + - Bypasses message parsing for direct tool execution + - For programmatic tool invocation and testing + + Output Formats: + Output format depends on input type and tool behavior: + + **For Regular tools**: + - Dict input → ``{"messages": [ToolMessage(...)]}`` + - List input → ``[ToolMessage(...)]`` + + **For Command tools**: + - Returns ``[Command(...)]`` or mixed list with regular tool outputs + - Commands can update state, trigger navigation, or send messages + + Args: + tools: A sequence of tools that can be invoked by this node. Supports: + - **BaseTool instances**: Tools with schemas and metadata + - **Plain functions**: Automatically converted to tools with inferred schemas + name: The name identifier for this node in the graph. Used for debugging + and visualization. Defaults to "tools". + tags: Optional metadata tags to associate with the node for filtering + and organization. Defaults to None. + handle_tool_errors: Configuration for error handling during tool execution. + Supports multiple strategies: + + - **True**: Catch all errors and return a ToolMessage with the default + error template containing the exception details. + - **str**: Catch all errors and return a ToolMessage with this custom + error message string. + - **type[Exception]**: Only catch exceptions with the specified type and return the default error message for it. + - **tuple[type[Exception], ...]**: Only catch exceptions with the specified + types and return default error messages for them. + - **Callable[..., str]**: Catch exceptions matching the callable's signature + and return the string result of calling it with the exception. + - **False**: Disable error handling entirely, allowing exceptions to + propagate. + + Defaults to a callable that: + - catches tool invocation errors (due to invalid arguments provided by the model) and returns a descriptive error message + - ignores tool execution errors (they will be re-raised) + + messages_key: The key in the state dictionary that contains the message list. + This same key will be used for the output ToolMessages. + Defaults to "messages". + Allows custom state schemas with different message field names. + + Examples: + Basic usage: + + ```python + from langchain.agents import ToolNode + from langchain_core.tools import tool + + @tool + def calculator(a: int, b: int) -> int: + \"\"\"Add two numbers.\"\"\" + return a + b + + tool_node = ToolNode([calculator]) + ``` + + State injection: + + ```python + from typing_extensions import Annotated + from langgraph.agents.tool_node import InjectedState + + @tool + def context_tool(query: str, state: Annotated[dict, InjectedState]) -> str: + \"\"\"Some tool that uses state.\"\"\" + return f"Query: {query}, Messages: {len(state['messages'])}" + + tool_node = ToolNode([context_tool]) + ``` + + Error handling: + + ```python + def handle_errors(e: ValueError) -> str: + return "Invalid input provided" + + tool_node = ToolNode([my_tool], handle_tool_errors=handle_errors) + ``` + """ + + name: str = "tools" + + def __init__( + self, + tools: Sequence[Union[BaseTool, Callable]], + *, + name: str = "tools", + tags: list[str] | None = None, + handle_tool_errors: Union[ + bool, str, Callable[..., str], type[Exception], tuple[type[Exception], ...] + ] = _default_handle_tool_errors, + messages_key: str = "messages", + ) -> None: + """Initialize the ToolNode with the provided tools and configuration. + + Args: + tools: Sequence of tools to make available for execution. + name: Node name for graph identification. + tags: Optional metadata tags. + handle_tool_errors: Error handling configuration. + messages_key: State key containing messages. + """ + super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False) + self._tools_by_name: dict[str, BaseTool] = {} + self._tool_to_state_args: dict[str, dict[str, str | None]] = {} + self._tool_to_store_arg: dict[str, str | None] = {} + self._handle_tool_errors = handle_tool_errors + self._messages_key = messages_key + for tool in tools: + if not isinstance(tool, BaseTool): + tool_ = create_tool(cast("type[BaseTool]", tool)) + else: + tool_ = tool + self._tools_by_name[tool_.name] = tool_ + self._tool_to_state_args[tool_.name] = _get_state_args(tool_) + self._tool_to_store_arg[tool_.name] = _get_store_arg(tool_) + + @property + def tools_by_name(self) -> dict[str, BaseTool]: + """Mapping from tool name to BaseTool instance.""" + return self._tools_by_name + + def _func( + self, + input: Union[ + list[AnyMessage], + dict[str, Any], + BaseModel, + ], + config: RunnableConfig, + *, + store: Optional[BaseStore], # noqa: UP045 + ) -> Any: + tool_calls, input_type = self._parse_input(input, store) + config_list = get_config_list(config, len(tool_calls)) + input_types = [input_type] * len(tool_calls) + with get_executor_for_config(config) as executor: + outputs = [*executor.map(self._run_one, tool_calls, input_types, config_list)] + + return self._combine_tool_outputs(outputs, input_type) + + async def _afunc( + self, + input: Union[ + list[AnyMessage], + dict[str, Any], + BaseModel, + ], + config: RunnableConfig, + *, + store: Optional[BaseStore], # noqa: UP045 + ) -> Any: + tool_calls, input_type = self._parse_input(input, store) + outputs = await asyncio.gather( + *(self._arun_one(call, input_type, config) for call in tool_calls) + ) + + return self._combine_tool_outputs(outputs, input_type) + + def _combine_tool_outputs( + self, + outputs: list[Union[ToolMessage, Command]], + input_type: Literal["list", "dict", "tool_calls"], + ) -> list[Union[Command, list[ToolMessage], dict[str, list[ToolMessage]]]]: + # preserve existing behavior for non-command tool outputs for backwards + # compatibility + if not any(isinstance(output, Command) for output in outputs): + # TypedDict, pydantic, dataclass, etc. should all be able to load from dict + return outputs if input_type == "list" else {self._messages_key: outputs} # type: ignore[return-value, return-value] + + # LangGraph will automatically handle list of Command and non-command node + # updates + combined_outputs: list[Command | list[ToolMessage] | dict[str, list[ToolMessage]]] = [] + + # combine all parent commands with goto into a single parent command + parent_command: Command | None = None + for output in outputs: + if isinstance(output, Command): + if ( + output.graph is Command.PARENT + and isinstance(output.goto, list) + and all(isinstance(send, Send) for send in output.goto) + ): + if parent_command: + parent_command = replace( + parent_command, + goto=cast("list[Send]", parent_command.goto) + output.goto, + ) + else: + parent_command = Command(graph=Command.PARENT, goto=output.goto) + else: + combined_outputs.append(output) + else: + combined_outputs.append( + [output] if input_type == "list" else {self._messages_key: [output]} + ) + + if parent_command: + combined_outputs.append(parent_command) + return combined_outputs + + def _run_one( + self, + call: ToolCall, + input_type: Literal["list", "dict", "tool_calls"], + config: RunnableConfig, + ) -> Union[ToolMessage, Command]: + """Run a single tool call synchronously.""" + if invalid_tool_message := self._validate_tool_call(call): + return invalid_tool_message + + try: + call_args = {**call, "type": "tool_call"} + tool = self.tools_by_name[call["name"]] + + try: + response = tool.invoke(call_args, config) + except ValidationError as exc: + raise ToolInvocationError(call["name"], exc, call["args"]) from exc + + # GraphInterrupt is a special exception that will always be raised. + # It can be triggered in the following scenarios, + # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation most commonly: + # (1) a GraphInterrupt is raised inside a tool + # (2) a GraphInterrupt is raised inside a graph node for a graph called as a tool + # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool + # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture) + except GraphBubbleUp: + raise + except Exception as e: + handled_types: tuple[type[Exception], ...] + if isinstance(self._handle_tool_errors, type) and issubclass( + self._handle_tool_errors, Exception + ): + handled_types = (self._handle_tool_errors,) + elif isinstance(self._handle_tool_errors, tuple): + handled_types = self._handle_tool_errors + elif callable(self._handle_tool_errors) and not isinstance( + self._handle_tool_errors, type + ): + handled_types = _infer_handled_types(self._handle_tool_errors) + else: + # default behavior is catching all exceptions + handled_types = (Exception,) + + # Unhandled + if not self._handle_tool_errors or not isinstance(e, handled_types): + raise + # Handled + content = _handle_tool_error(e, flag=self._handle_tool_errors) + return ToolMessage( + content=content, + name=call["name"], + tool_call_id=call["id"], + status="error", + ) + + if isinstance(response, Command): + return self._validate_tool_command(response, call, input_type) + if isinstance(response, ToolMessage): + response.content = cast("Union[str, list]", msg_content_output(response.content)) + return response + msg = f"Tool {call['name']} returned unexpected type: {type(response)}" + raise TypeError(msg) + + async def _arun_one( + self, + call: ToolCall, + input_type: Literal["list", "dict", "tool_calls"], + config: RunnableConfig, + ) -> Union[ToolMessage, Command]: + """Run a single tool call asynchronously.""" + if invalid_tool_message := self._validate_tool_call(call): + return invalid_tool_message + + try: + call_args = {**call, "type": "tool_call"} + tool = self.tools_by_name[call["name"]] + + try: + response = await tool.ainvoke(call_args, config) + except ValidationError as exc: + raise ToolInvocationError(call["name"], exc, call["args"]) from exc + + # GraphInterrupt is a special exception that will always be raised. + # It can be triggered in the following scenarios, + # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation most commonly: + # (1) a GraphInterrupt is raised inside a tool + # (2) a GraphInterrupt is raised inside a graph node for a graph called as a tool + # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool + # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture) + except GraphBubbleUp: + raise + except Exception as e: + handled_types: tuple[type[Exception], ...] + if isinstance(self._handle_tool_errors, type) and issubclass( + self._handle_tool_errors, Exception + ): + handled_types = (self._handle_tool_errors,) + elif isinstance(self._handle_tool_errors, tuple): + handled_types = self._handle_tool_errors + elif callable(self._handle_tool_errors) and not isinstance( + self._handle_tool_errors, type + ): + handled_types = _infer_handled_types(self._handle_tool_errors) + else: + # default behavior is catching all exceptions + handled_types = (Exception,) + + # Unhandled + if not self._handle_tool_errors or not isinstance(e, handled_types): + raise + # Handled + content = _handle_tool_error(e, flag=self._handle_tool_errors) + + return ToolMessage( + content=content, + name=call["name"], + tool_call_id=call["id"], + status="error", + ) + + if isinstance(response, Command): + return self._validate_tool_command(response, call, input_type) + if isinstance(response, ToolMessage): + response.content = cast("Union[str, list]", msg_content_output(response.content)) + return response + msg = f"Tool {call['name']} returned unexpected type: {type(response)}" + raise TypeError(msg) + + def _parse_input( + self, + input: Union[ + list[AnyMessage], + dict[str, Any], + BaseModel, + ], + store: BaseStore | None, + ) -> tuple[list[ToolCall], Literal["list", "dict", "tool_calls"]]: + input_type: Literal["list", "dict", "tool_calls"] + if isinstance(input, list): + if isinstance(input[-1], dict) and input[-1].get("type") == "tool_call": + input_type = "tool_calls" + tool_calls = cast("list[ToolCall]", input) + return tool_calls, input_type + input_type = "list" + messages = input + elif isinstance(input, dict) and (messages := input.get(self._messages_key, [])): + input_type = "dict" + elif messages := getattr(input, self._messages_key, []): + # Assume dataclass-like state that can coerce from dict + input_type = "dict" + else: + msg = "No message found in input" + raise ValueError(msg) + + try: + latest_ai_message = next(m for m in reversed(messages) if isinstance(m, AIMessage)) + except StopIteration: + msg = "No AIMessage found in input" + raise ValueError(msg) + + tool_calls = [ + self.inject_tool_args(call, input, store) for call in latest_ai_message.tool_calls + ] + return tool_calls, input_type + + def _validate_tool_call(self, call: ToolCall) -> ToolMessage | None: + requested_tool = call["name"] + if requested_tool not in self.tools_by_name: + all_tool_names = list(self.tools_by_name.keys()) + content = INVALID_TOOL_NAME_ERROR_TEMPLATE.format( + requested_tool=requested_tool, + available_tools=", ".join(all_tool_names), + ) + return ToolMessage( + content, name=requested_tool, tool_call_id=call["id"], status="error" + ) + return None + + def _inject_state( + self, + tool_call: ToolCall, + input: Union[ + list[AnyMessage], + dict[str, Any], + BaseModel, + ], + ) -> ToolCall: + state_args = self._tool_to_state_args[tool_call["name"]] + if state_args and isinstance(input, list): + required_fields = list(state_args.values()) + if ( + len(required_fields) == 1 and required_fields[0] == self._messages_key + ) or required_fields[0] is None: + input = {self._messages_key: input} + else: + err_msg = ( + f"Invalid input to ToolNode. Tool {tool_call['name']} requires " + f"graph state dict as input." + ) + if any(state_field for state_field in state_args.values()): + required_fields_str = ", ".join(f for f in required_fields if f) + err_msg += f" State should contain fields {required_fields_str}." + raise ValueError(err_msg) + + if isinstance(input, dict): + tool_state_args = { + tool_arg: input[state_field] if state_field else input + for tool_arg, state_field in state_args.items() + } + else: + tool_state_args = { + tool_arg: getattr(input, state_field) if state_field else input + for tool_arg, state_field in state_args.items() + } + + tool_call["args"] = { + **tool_call["args"], + **tool_state_args, + } + return tool_call + + def _inject_store(self, tool_call: ToolCall, store: BaseStore | None) -> ToolCall: + store_arg = self._tool_to_store_arg[tool_call["name"]] + if not store_arg: + return tool_call + + if store is None: + msg = ( + "Cannot inject store into tools with InjectedStore annotations - " + "please compile your graph with a store." + ) + raise ValueError(msg) + + tool_call["args"] = { + **tool_call["args"], + store_arg: store, + } + return tool_call + + def inject_tool_args( + self, + tool_call: ToolCall, + input: Union[ + list[AnyMessage], + dict[str, Any], + BaseModel, + ], + store: BaseStore | None, + ) -> ToolCall: + """Inject graph state and store into tool call arguments. + + This method enables tools to access graph context that should not be controlled + by the model. Tools can declare dependencies on graph state or persistent storage + using InjectedState and InjectedStore annotations. This method automatically + identifies these dependencies and injects the appropriate values. + + The injection process preserves the original tool call structure while adding + the necessary context arguments. This allows tools to be both model-callable + and context-aware without exposing internal state management to the model. + + Args: + tool_call: The tool call dictionary to augment with injected arguments. + Must contain 'name', 'args', 'id', and 'type' fields. + input: The current graph state to inject into tools requiring state access. + Can be a message list, state dictionary, or BaseModel instance. + store: The persistent store instance to inject into tools requiring storage. + Will be None if no store is configured for the graph. + + Returns: + A new ToolCall dictionary with the same structure as the input but with + additional arguments injected based on the tool's annotation requirements. + + Raises: + ValueError: If a tool requires store injection but no store is provided, + or if state injection requirements cannot be satisfied. + + Note: + This method is automatically called during tool execution but can also + be used manually when working with the Send API or custom routing logic. + The injection is performed on a copy of the tool call to avoid mutating + the original. + """ + if tool_call["name"] not in self.tools_by_name: + return tool_call + + tool_call_copy: ToolCall = copy(tool_call) + tool_call_with_state = self._inject_state(tool_call_copy, input) + return self._inject_store(tool_call_with_state, store) + + def _validate_tool_command( + self, + command: Command, + call: ToolCall, + input_type: Literal["list", "dict", "tool_calls"], + ) -> Command: + if isinstance(command.update, dict): + # input type is dict when ToolNode is invoked with a dict input (e.g. {"messages": [AIMessage(..., tool_calls=[...])]}) + if input_type not in ("dict", "tool_calls"): + msg = ( + f"Tools can provide a dict in Command.update only when using dict with '{self._messages_key}' key as ToolNode input, " + f"got: {command.update} for tool '{call['name']}'" + ) + raise ValueError(msg) + + updated_command = deepcopy(command) + state_update = cast("dict[str, Any]", updated_command.update) or {} + messages_update = state_update.get(self._messages_key, []) + elif isinstance(command.update, list): + # Input type is list when ToolNode is invoked with a list input (e.g. [AIMessage(..., tool_calls=[...])]) + if input_type != "list": + msg = ( + f"Tools can provide a list of messages in Command.update only when using list of messages as ToolNode input, " + f"got: {command.update} for tool '{call['name']}'" + ) + raise ValueError(msg) + + updated_command = deepcopy(command) + messages_update = updated_command.update + else: + return command + + # convert to message objects if updates are in a dict format + messages_update = convert_to_messages(messages_update) + + # no validation needed if all messages are being removed + if messages_update == [RemoveMessage(id=REMOVE_ALL_MESSAGES)]: + return updated_command + + has_matching_tool_message = False + for message in messages_update: + if not isinstance(message, ToolMessage): + continue + + if message.tool_call_id == call["id"]: + message.name = call["name"] + has_matching_tool_message = True + + # validate that we always have a ToolMessage matching the tool call in + # Command.update if command is sent to the CURRENT graph + if updated_command.graph is None and not has_matching_tool_message: + example_update = ( + '`Command(update={"messages": [ToolMessage("Success", tool_call_id=tool_call_id), ...]}, ...)`' + if input_type == "dict" + else '`Command(update=[ToolMessage("Success", tool_call_id=tool_call_id), ...], ...)`' + ) + msg = ( + f"Expected to have a matching ToolMessage in Command.update for tool '{call['name']}', got: {messages_update}. " + "Every tool call (LLM requesting to call a tool) in the message history MUST have a corresponding ToolMessage. " + f"You can fix it by modifying the tool to return {example_update}." + ) + raise ValueError(msg) + return updated_command + + +def tools_condition( + state: Union[list[AnyMessage], dict[str, Any], BaseModel], + messages_key: str = "messages", +) -> Literal["tools", "__end__"]: + """Conditional routing function for tool-calling workflows. + + This utility function implements the standard conditional logic for ReAct-style + agents: if the last AI message contains tool calls, route to the tool execution + node; otherwise, end the workflow. This pattern is fundamental to most tool-calling + agent architectures. + + The function handles multiple state formats commonly used in LangGraph applications, + making it flexible for different graph designs while maintaining consistent behavior. + + Args: + state: The current graph state to examine for tool calls. Supported formats: + - Dictionary containing a messages key (for StateGraph) + - BaseModel instance with a messages attribute + messages_key: The key or attribute name containing the message list in the state. + This allows customization for graphs using different state schemas. + Defaults to "messages". + + Returns: + Either "tools" if tool calls are present in the last AI message, or "__end__" + to terminate the workflow. These are the standard routing destinations for + tool-calling conditional edges. + + Raises: + ValueError: If no messages can be found in the provided state format. + + Example: + Basic usage in a ReAct agent: + + ```python + from langgraph.graph import StateGraph + from langgraph.agents.tool_node import ToolNode, tools_condition + from typing_extensions import TypedDict + + class State(TypedDict): + messages: list + + graph = StateGraph(State) + graph.add_node("llm", call_model) + graph.add_node("tools", ToolNode([my_tool])) + graph.add_conditional_edges( + "llm", + tools_condition, # Routes to "tools" or "__end__" + {"tools": "tools", "__end__": "__end__"} + ) + ``` + + Custom messages key: + + ```python + def custom_condition(state): + return tools_condition(state, messages_key="chat_history") + ``` + + Note: + This function is designed to work seamlessly with ToolNode and standard + LangGraph patterns. It expects the last message to be an AIMessage when + tool calls are present, which is the standard output format for tool-calling + language models. + """ + if isinstance(state, list): + ai_message = state[-1] + elif (isinstance(state, dict) and (messages := state.get(messages_key, []))) or ( + messages := getattr(state, messages_key, []) + ): + ai_message = messages[-1] + else: + msg = f"No messages found in input state to tool_edge: {state}" + raise ValueError(msg) + if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: + return "tools" + return "__end__" + + +class InjectedState(InjectedToolArg): + """Annotation for injecting graph state into tool arguments. + + This annotation enables tools to access graph state without exposing state + management details to the language model. Tools annotated with InjectedState + receive state data automatically during execution while remaining invisible + to the model's tool-calling interface. + + Args: + field: Optional key to extract from the state dictionary. If None, the entire + state is injected. If specified, only that field's value is injected. + This allows tools to request specific state components rather than + processing the full state structure. + + Example: + ```python + from typing import List + from typing_extensions import Annotated, TypedDict + + from langchain_core.messages import BaseMessage, AIMessage + from langchain_core.tools import tool + + from langgraph.agents.tool_node import InjectedState, ToolNode + + + class AgentState(TypedDict): + messages: List[BaseMessage] + foo: str + + @tool + def state_tool(x: int, state: Annotated[dict, InjectedState]) -> str: + '''Do something with state.''' + if len(state["messages"]) > 2: + return state["foo"] + str(x) + else: + return "not enough messages" + + @tool + def foo_tool(x: int, foo: Annotated[str, InjectedState("foo")]) -> str: + '''Do something else with state.''' + return foo + str(x + 1) + + node = ToolNode([state_tool, foo_tool]) + + tool_call1 = {"name": "state_tool", "args": {"x": 1}, "id": "1", "type": "tool_call"} + tool_call2 = {"name": "foo_tool", "args": {"x": 1}, "id": "2", "type": "tool_call"} + state = { + "messages": [AIMessage("", tool_calls=[tool_call1, tool_call2])], + "foo": "bar", + } + node.invoke(state) + ``` + + ```pycon + [ + ToolMessage(content='not enough messages', name='state_tool', tool_call_id='1'), + ToolMessage(content='bar2', name='foo_tool', tool_call_id='2') + ] + ``` + + Note: + - InjectedState arguments are automatically excluded from tool schemas + presented to language models + - ToolNode handles the injection process during execution + - Tools can mix regular arguments (controlled by the model) with injected + arguments (controlled by the system) + - State injection occurs after the model generates tool calls but before + tool execution + """ + + def __init__(self, field: str | None = None) -> None: + """Initialize the InjectedState annotation.""" + self.field = field + + +class InjectedStore(InjectedToolArg): + """Annotation for injecting persistent store into tool arguments. + + This annotation enables tools to access LangGraph's persistent storage system + without exposing storage details to the language model. Tools annotated with + InjectedStore receive the store instance automatically during execution while + remaining invisible to the model's tool-calling interface. + + The store provides persistent, cross-session data storage that tools can use + for maintaining context, user preferences, or any other data that needs to + persist beyond individual workflow executions. + + !!! Warning + `InjectedStore` annotation requires `langchain-core >= 0.3.8` + + Example: + ```python + from typing_extensions import Annotated + from langchain_core.tools import tool + from langgraph.store.memory import InMemoryStore + from langgraph.agents.tool_node import InjectedStore, ToolNode + + @tool + def save_preference( + key: str, + value: str, + store: Annotated[Any, InjectedStore()] + ) -> str: + \"\"\"Save user preference to persistent storage.\"\"\" + store.put(("preferences",), key, value) + return f"Saved {key} = {value}" + + @tool + def get_preference( + key: str, + store: Annotated[Any, InjectedStore()] + ) -> str: + \"\"\"Retrieve user preference from persistent storage.\"\"\" + result = store.get(("preferences",), key) + return result.value if result else "Not found" + ``` + + Usage with ToolNode and graph compilation: + + ```python + from langgraph.graph import StateGraph + from langgraph.store.memory import InMemoryStore + + store = InMemoryStore() + tool_node = ToolNode([save_preference, get_preference]) + + graph = StateGraph(State) + graph.add_node("tools", tool_node) + compiled_graph = graph.compile(store=store) # Store is injected automatically + ``` + + Cross-session persistence: + + ```python + # First session + result1 = graph.invoke({"messages": [HumanMessage("Save my favorite color as blue")]}) + + # Later session - data persists + result2 = graph.invoke({"messages": [HumanMessage("What's my favorite color?")]}) + ``` + + Note: + - InjectedStore arguments are automatically excluded from tool schemas + presented to language models + - The store instance is automatically injected by ToolNode during execution + - Tools can access namespaced storage using the store's get/put methods + - Store injection requires the graph to be compiled with a store instance + - Multiple tools can share the same store instance for data consistency + """ + + +def _is_injection(type_arg: Any, injection_type: type[Union[InjectedState, InjectedStore]]) -> bool: + """Check if a type argument represents an injection annotation. + + This utility function determines whether a type annotation indicates that + an argument should be injected with state or store data. It handles both + direct annotations and nested annotations within Union or Annotated types. + + Args: + type_arg: The type argument to check for injection annotations. + injection_type: The injection type to look for (InjectedState or InjectedStore). + + Returns: + True if the type argument contains the specified injection annotation. + """ + if isinstance(type_arg, injection_type) or ( + isinstance(type_arg, type) and issubclass(type_arg, injection_type) + ): + return True + origin_ = get_origin(type_arg) + if origin_ is Union or origin_ is Annotated: + return any(_is_injection(ta, injection_type) for ta in get_args(type_arg)) + return False + + +def _get_state_args(tool: BaseTool) -> dict[str, str | None]: + """Extract state injection mappings from tool annotations. + + This function analyzes a tool's input schema to identify arguments that should + be injected with graph state. It processes InjectedState annotations to build + a mapping of tool argument names to state field names. + + Args: + tool: The tool to analyze for state injection requirements. + + Returns: + A dictionary mapping tool argument names to state field names. If a field + name is None, the entire state should be injected for that argument. + """ + full_schema = tool.get_input_schema() + tool_args_to_state_fields: dict = {} + + for name, type_ in get_all_basemodel_annotations(full_schema).items(): + injections = [ + type_arg for type_arg in get_args(type_) if _is_injection(type_arg, InjectedState) + ] + if len(injections) > 1: + msg = ( + "A tool argument should not be annotated with InjectedState more than " + f"once. Received arg {name} with annotations {injections}." + ) + raise ValueError(msg) + if len(injections) == 1: + injection = injections[0] + if isinstance(injection, InjectedState) and injection.field: + tool_args_to_state_fields[name] = injection.field + else: + tool_args_to_state_fields[name] = None + else: + pass + return tool_args_to_state_fields + + +def _get_store_arg(tool: BaseTool) -> str | None: + """Extract store injection argument from tool annotations. + + This function analyzes a tool's input schema to identify the argument that + should be injected with the graph store. Only one store argument is supported + per tool. + + Args: + tool: The tool to analyze for store injection requirements. + + Returns: + The name of the argument that should receive the store injection, or None + if no store injection is required. + + Raises: + ValueError: If a tool argument has multiple InjectedStore annotations. + """ + full_schema = tool.get_input_schema() + for name, type_ in get_all_basemodel_annotations(full_schema).items(): + injections = [ + type_arg for type_arg in get_args(type_) if _is_injection(type_arg, InjectedStore) + ] + if len(injections) > 1: + msg = ( + "A tool argument should not be annotated with InjectedStore more than " + f"once. Received arg {name} with annotations {injections}." + ) + raise ValueError(msg) + if len(injections) == 1: + return name + + return None diff --git a/libs/langchain_v1/langchain/chat_models/base.py b/libs/langchain_v1/langchain/chat_models/base.py index f222aa1b86c..f9ae98e4368 100644 --- a/libs/langchain_v1/langchain/chat_models/base.py +++ b/libs/langchain_v1/langchain/chat_models/base.py @@ -13,7 +13,7 @@ from typing import ( ) from langchain_core.language_models import BaseChatModel, LanguageModelInput -from langchain_core.messages import AnyMessage, BaseMessage +from langchain_core.messages import AIMessage, AnyMessage from langchain_core.runnables import Runnable, RunnableConfig, ensure_config from typing_extensions import override @@ -915,7 +915,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]): self, tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], **kwargs: Any, - ) -> Runnable[LanguageModelInput, BaseMessage]: + ) -> Runnable[LanguageModelInput, AIMessage]: return self.__getattr__("bind_tools")(tools, **kwargs) # Explicitly added to satisfy downstream linters. diff --git a/libs/langchain_v1/pyproject.toml b/libs/langchain_v1/pyproject.toml index b05002d4c2f..f471a309b5c 100644 --- a/libs/langchain_v1/pyproject.toml +++ b/libs/langchain_v1/pyproject.toml @@ -52,11 +52,11 @@ test = [ "pytest-socket>=0.6.0", "syrupy>=4.0.2", "pytest-xdist>=3.6.1", - "blockbuster>=1.5.18", "langchain-tests", "langchain-text-splitters", "langchain-openai", "toml>=0.10.2", + "pytest-mock" ] codespell = ["codespell<3.0.0,>=2.2.0"] lint = [ @@ -94,6 +94,7 @@ strict_bytes = "True" ignore_missing_imports = "True" enable_error_code = "deprecated" report_deprecated_as_note = "True" +exclude = ["tests/unit_tests/agents/*", "tests/integration_tests/agents/*"] # TODO: activate for 'strict' checking disallow_untyped_calls = "False" @@ -135,14 +136,24 @@ flake8-annotations.allow-star-arg-any = true "D", # Documentation rules "PLC0415", # Imports should be at the top. Not always desirable for tests ] +"langchain/agents/*" = [ + "ANN401", # we use Any right now, need to narrow + "E501", # line too long, needs to fix + "A002", # input is shadowing builtin + "A001", # input is shadowing builtin + "B904", # use from for exceptions + "PLR2004", # magic values are fine for this case + "C901", # too complex + "TRY004", # type error exception + "PLR0912", # too many branches + "PLR0911", # too many return statements +] +"tests/unit_tests/agents/*" = ["ALL"] +"tests/integration_tests/agents/*" = ["ALL"] [tool.ruff.lint.extend-per-file-ignores] "scripts/check_imports.py" = ["ALL"] -"langchain/globals.py" = [ - "PLW" -] - "langchain/chat_models/base.py" = [ "ANN", "C901", diff --git a/libs/langchain_v1/tests/integration_tests/agents/__init__.py b/libs/langchain_v1/tests/integration_tests/agents/__init__.py new file mode 100644 index 00000000000..6c212a0b69d --- /dev/null +++ b/libs/langchain_v1/tests/integration_tests/agents/__init__.py @@ -0,0 +1 @@ +"""Integration tests for the agents module.""" diff --git a/libs/langchain_v1/tests/integration_tests/agents/test_response_format.py b/libs/langchain_v1/tests/integration_tests/agents/test_response_format.py new file mode 100644 index 00000000000..6836cc692cd --- /dev/null +++ b/libs/langchain_v1/tests/integration_tests/agents/test_response_format.py @@ -0,0 +1,79 @@ +import pytest +from langchain_core.messages import HumanMessage +from pydantic import BaseModel, Field + +from langchain.agents import create_react_agent +from langchain.agents.structured_output import ToolStrategy + + +class WeatherBaseModel(BaseModel): + """Weather response.""" + + temperature: float = Field(description="The temperature in fahrenheit") + condition: str = Field(description="Weather condition") + + +def get_weather(city: str) -> str: # noqa: ARG001 + """Get the weather for a city.""" + return "The weather is sunny and 75°F." + + +@pytest.mark.requires("langchain_openai") +def test_inference_to_native_output() -> None: + """Test that native output is inferred when a model supports it.""" + from langchain_openai import ChatOpenAI + + model = ChatOpenAI(model="gpt-5") + agent = create_react_agent( + model, + prompt=( + "You are a helpful weather assistant. Please call the get_weather tool, " + "then use the WeatherReport tool to generate the final response." + ), + tools=[get_weather], + response_format=WeatherBaseModel, + ) + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + assert isinstance(response["structured_response"], WeatherBaseModel) + assert response["structured_response"].temperature == 75.0 + assert response["structured_response"].condition.lower() == "sunny" + assert len(response["messages"]) == 4 + + assert [m.type for m in response["messages"]] == [ + "human", # "What's the weather?" + "ai", # "What's the weather?" + "tool", # "The weather is sunny and 75°F." + "ai", # structured response + ] + + +@pytest.mark.requires("langchain_openai") +def test_inference_to_tool_output() -> None: + """Test that tool output is inferred when a model supports it.""" + from langchain_openai import ChatOpenAI + + model = ChatOpenAI(model="gpt-4") + agent = create_react_agent( + model, + prompt=( + "You are a helpful weather assistant. Please call the get_weather tool, " + "then use the WeatherReport tool to generate the final response." + ), + tools=[get_weather], + response_format=ToolStrategy(WeatherBaseModel), + ) + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + assert isinstance(response["structured_response"], WeatherBaseModel) + assert response["structured_response"].temperature == 75.0 + assert response["structured_response"].condition.lower() == "sunny" + assert len(response["messages"]) == 5 + + assert [m.type for m in response["messages"]] == [ + "human", # "What's the weather?" + "ai", # "What's the weather?" + "tool", # "The weather is sunny and 75°F." + "ai", # structured response + "tool", # artificial tool message + ] diff --git a/libs/langchain_v1/tests/unit_tests/agents/__init__.py b/libs/langchain_v1/tests/unit_tests/agents/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/langchain_v1/tests/unit_tests/agents/__snapshots__/test_react_agent_graph.ambr b/libs/langchain_v1/tests/unit_tests/agents/__snapshots__/test_react_agent_graph.ambr new file mode 100644 index 00000000000..49e40649e57 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/__snapshots__/test_react_agent_graph.ambr @@ -0,0 +1,83 @@ +# serializer version: 1 +# name: test_react_agent_graph_structure[None-None-tools0] + ''' + graph TD; + __start__ --> agent; + agent --> __end__; + + ''' +# --- +# name: test_react_agent_graph_structure[None-None-tools1] + ''' + graph TD; + __start__ --> agent; + agent -.-> __end__; + agent -.-> tools; + tools --> agent; + + ''' +# --- +# name: test_react_agent_graph_structure[None-pre_model_hook-tools0] + ''' + graph TD; + __start__ --> pre_model_hook; + pre_model_hook --> agent; + agent --> __end__; + + ''' +# --- +# name: test_react_agent_graph_structure[None-pre_model_hook-tools1] + ''' + graph TD; + __start__ --> pre_model_hook; + agent -.-> __end__; + agent -.-> tools; + pre_model_hook --> agent; + tools --> pre_model_hook; + + ''' +# --- +# name: test_react_agent_graph_structure[post_model_hook-None-tools0] + ''' + graph TD; + __start__ --> agent; + agent --> post_model_hook; + post_model_hook --> __end__; + + ''' +# --- +# name: test_react_agent_graph_structure[post_model_hook-None-tools1] + ''' + graph TD; + __start__ --> agent; + agent --> post_model_hook; + post_model_hook -.-> __end__; + post_model_hook -.-> agent; + post_model_hook -.-> tools; + tools --> agent; + + ''' +# --- +# name: test_react_agent_graph_structure[post_model_hook-pre_model_hook-tools0] + ''' + graph TD; + __start__ --> pre_model_hook; + agent --> post_model_hook; + pre_model_hook --> agent; + post_model_hook --> __end__; + + ''' +# --- +# name: test_react_agent_graph_structure[post_model_hook-pre_model_hook-tools1] + ''' + graph TD; + __start__ --> pre_model_hook; + agent --> post_model_hook; + post_model_hook -.-> __end__; + post_model_hook -.-> pre_model_hook; + post_model_hook -.-> tools; + pre_model_hook --> agent; + tools --> pre_model_hook; + + ''' +# --- diff --git a/libs/langchain_v1/tests/unit_tests/agents/any_str.py b/libs/langchain_v1/tests/unit_tests/agents/any_str.py new file mode 100644 index 00000000000..790324322dc --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/any_str.py @@ -0,0 +1,18 @@ +import re +from typing import Union + + +class AnyStr(str): + def __init__(self, prefix: Union[str, re.Pattern] = "") -> None: + super().__init__() + self.prefix = prefix + + def __eq__(self, other: object) -> bool: + return isinstance(other, str) and ( + other.startswith(self.prefix) + if isinstance(self.prefix, str) + else self.prefix.match(other) + ) + + def __hash__(self) -> int: + return hash((str(self), self.prefix)) diff --git a/libs/langchain_v1/tests/unit_tests/agents/compose-postgres.yml b/libs/langchain_v1/tests/unit_tests/agents/compose-postgres.yml new file mode 100644 index 00000000000..221b35dafe4 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/compose-postgres.yml @@ -0,0 +1,17 @@ +name: langgraph-tests +services: + postgres-test: + image: postgres:16 + ports: + - "5442:5432" + environment: + POSTGRES_DB: postgres + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + healthcheck: + test: pg_isready -U postgres + start_period: 10s + timeout: 1s + retries: 5 + interval: 60s + start_interval: 1s diff --git a/libs/langchain_v1/tests/unit_tests/agents/compose-redis.yml b/libs/langchain_v1/tests/unit_tests/agents/compose-redis.yml new file mode 100644 index 00000000000..18862fd28dd --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/compose-redis.yml @@ -0,0 +1,16 @@ +name: langgraph-tests-redis +services: + redis-test: + image: redis:7-alpine + ports: + - "6379:6379" + command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru + healthcheck: + test: redis-cli ping + start_period: 10s + timeout: 1s + retries: 5 + interval: 5s + start_interval: 1s + tmpfs: + - /data # Use tmpfs for faster testing diff --git a/libs/langchain_v1/tests/unit_tests/agents/conftest.py b/libs/langchain_v1/tests/unit_tests/agents/conftest.py new file mode 100644 index 00000000000..7e134f68271 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/conftest.py @@ -0,0 +1,194 @@ +import os +from collections.abc import AsyncIterator, Iterator +from uuid import UUID + +import pytest +from langgraph.checkpoint.base import BaseCheckpointSaver +from langgraph.store.base import BaseStore +from pytest_mock import MockerFixture + +from .conftest_checkpointer import ( + _checkpointer_memory, + _checkpointer_postgres, + _checkpointer_postgres_aio, + _checkpointer_postgres_aio_pipe, + _checkpointer_postgres_aio_pool, + _checkpointer_postgres_pipe, + _checkpointer_postgres_pool, + _checkpointer_sqlite, + _checkpointer_sqlite_aio, +) +from .conftest_store import ( + _store_memory, + _store_postgres, + _store_postgres_aio, + _store_postgres_aio_pipe, + _store_postgres_aio_pool, + _store_postgres_pipe, + _store_postgres_pool, +) + +# Global variables for checkpointer and store configurations +FAST_MODE = os.getenv("LANGGRAPH_TEST_FAST", "true").lower() in ("true", "1", "yes") + +SYNC_CHECKPOINTER_PARAMS = ( + ["memory"] + if FAST_MODE + else [ + "memory", + "sqlite", + "postgres", + "postgres_pipe", + "postgres_pool", + ] +) + +ASYNC_CHECKPOINTER_PARAMS = ( + ["memory"] + if FAST_MODE + else [ + "memory", + "sqlite_aio", + "postgres_aio", + "postgres_aio_pipe", + "postgres_aio_pool", + ] +) + +SYNC_STORE_PARAMS = ( + ["in_memory"] + if FAST_MODE + else [ + "in_memory", + "postgres", + "postgres_pipe", + "postgres_pool", + ] +) + +ASYNC_STORE_PARAMS = ( + ["in_memory"] + if FAST_MODE + else [ + "in_memory", + "postgres_aio", + "postgres_aio_pipe", + "postgres_aio_pool", + ] +) + + +@pytest.fixture +def anyio_backend() -> str: + return "asyncio" + + +@pytest.fixture +def deterministic_uuids(mocker: MockerFixture) -> MockerFixture: + side_effect = (UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)) + return mocker.patch("uuid.uuid4", side_effect=side_effect) + + +# checkpointer fixtures + + +@pytest.fixture( + params=SYNC_STORE_PARAMS, +) +def sync_store(request: pytest.FixtureRequest) -> Iterator[BaseStore]: + store_name = request.param + if store_name is None: + yield None + elif store_name == "in_memory": + with _store_memory() as store: + yield store + elif store_name == "postgres": + with _store_postgres() as store: + yield store + elif store_name == "postgres_pipe": + with _store_postgres_pipe() as store: + yield store + elif store_name == "postgres_pool": + with _store_postgres_pool() as store: + yield store + else: + msg = f"Unknown store {store_name}" + raise NotImplementedError(msg) + + +@pytest.fixture( + params=ASYNC_STORE_PARAMS, +) +async def async_store(request: pytest.FixtureRequest) -> AsyncIterator[BaseStore]: + store_name = request.param + if store_name is None: + yield None + elif store_name == "in_memory": + with _store_memory() as store: + yield store + elif store_name == "postgres_aio": + async with _store_postgres_aio() as store: + yield store + elif store_name == "postgres_aio_pipe": + async with _store_postgres_aio_pipe() as store: + yield store + elif store_name == "postgres_aio_pool": + async with _store_postgres_aio_pool() as store: + yield store + else: + msg = f"Unknown store {store_name}" + raise NotImplementedError(msg) + + +@pytest.fixture( + params=SYNC_CHECKPOINTER_PARAMS, +) +def sync_checkpointer( + request: pytest.FixtureRequest, +) -> Iterator[BaseCheckpointSaver]: + checkpointer_name = request.param + if checkpointer_name == "memory": + with _checkpointer_memory() as checkpointer: + yield checkpointer + elif checkpointer_name == "sqlite": + with _checkpointer_sqlite() as checkpointer: + yield checkpointer + elif checkpointer_name == "postgres": + with _checkpointer_postgres() as checkpointer: + yield checkpointer + elif checkpointer_name == "postgres_pipe": + with _checkpointer_postgres_pipe() as checkpointer: + yield checkpointer + elif checkpointer_name == "postgres_pool": + with _checkpointer_postgres_pool() as checkpointer: + yield checkpointer + else: + msg = f"Unknown checkpointer: {checkpointer_name}" + raise NotImplementedError(msg) + + +@pytest.fixture( + params=ASYNC_CHECKPOINTER_PARAMS, +) +async def async_checkpointer( + request: pytest.FixtureRequest, +) -> AsyncIterator[BaseCheckpointSaver]: + checkpointer_name = request.param + if checkpointer_name == "memory": + with _checkpointer_memory() as checkpointer: + yield checkpointer + elif checkpointer_name == "sqlite_aio": + async with _checkpointer_sqlite_aio() as checkpointer: + yield checkpointer + elif checkpointer_name == "postgres_aio": + async with _checkpointer_postgres_aio() as checkpointer: + yield checkpointer + elif checkpointer_name == "postgres_aio_pipe": + async with _checkpointer_postgres_aio_pipe() as checkpointer: + yield checkpointer + elif checkpointer_name == "postgres_aio_pool": + async with _checkpointer_postgres_aio_pool() as checkpointer: + yield checkpointer + else: + msg = f"Unknown checkpointer: {checkpointer_name}" + raise NotImplementedError(msg) diff --git a/libs/langchain_v1/tests/unit_tests/agents/conftest_checkpointer.py b/libs/langchain_v1/tests/unit_tests/agents/conftest_checkpointer.py new file mode 100644 index 00000000000..4505fb28ee3 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/conftest_checkpointer.py @@ -0,0 +1,64 @@ +from contextlib import asynccontextmanager, contextmanager + +from .memory_assert import ( + MemorySaverAssertImmutable, +) + + +@contextmanager +def _checkpointer_memory(): + yield MemorySaverAssertImmutable() + + +@asynccontextmanager +async def _checkpointer_memory_aio(): + yield MemorySaverAssertImmutable() + + +# Placeholder functions for other checkpointer types that aren't available +@contextmanager +def _checkpointer_sqlite(): + # Fallback to memory for now + yield MemorySaverAssertImmutable() + + +@contextmanager +def _checkpointer_postgres(): + # Fallback to memory for now + yield MemorySaverAssertImmutable() + + +@contextmanager +def _checkpointer_postgres_pipe(): + # Fallback to memory for now + yield MemorySaverAssertImmutable() + + +@contextmanager +def _checkpointer_postgres_pool(): + # Fallback to memory for now + yield MemorySaverAssertImmutable() + + +@asynccontextmanager +async def _checkpointer_sqlite_aio(): + # Fallback to memory for now + yield MemorySaverAssertImmutable() + + +@asynccontextmanager +async def _checkpointer_postgres_aio(): + # Fallback to memory for now + yield MemorySaverAssertImmutable() + + +@asynccontextmanager +async def _checkpointer_postgres_aio_pipe(): + # Fallback to memory for now + yield MemorySaverAssertImmutable() + + +@asynccontextmanager +async def _checkpointer_postgres_aio_pool(): + # Fallback to memory for now + yield MemorySaverAssertImmutable() diff --git a/libs/langchain_v1/tests/unit_tests/agents/conftest_store.py b/libs/langchain_v1/tests/unit_tests/agents/conftest_store.py new file mode 100644 index 00000000000..fb0eb4e2fba --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/conftest_store.py @@ -0,0 +1,58 @@ +from contextlib import asynccontextmanager, contextmanager + +from langgraph.store.memory import InMemoryStore + + +@contextmanager +def _store_memory(): + store = InMemoryStore() + yield store + + +@asynccontextmanager +async def _store_memory_aio(): + store = InMemoryStore() + yield store + + +# Placeholder functions for other store types that aren't available +@contextmanager +def _store_postgres(): + # Fallback to memory for now + store = InMemoryStore() + yield store + + +@contextmanager +def _store_postgres_pipe(): + # Fallback to memory for now + store = InMemoryStore() + yield store + + +@contextmanager +def _store_postgres_pool(): + # Fallback to memory for now + store = InMemoryStore() + yield store + + +@asynccontextmanager +async def _store_postgres_aio(): + # Fallback to memory for now + store = InMemoryStore() + yield store + + +@asynccontextmanager +async def _store_postgres_aio_pipe(): + # Fallback to memory for now + store = InMemoryStore() + yield store + + +@asynccontextmanager +async def _store_postgres_aio_pool(): + # Fallback to memory for now + store = InMemoryStore() + yield store diff --git a/libs/langchain_v1/tests/unit_tests/agents/memory_assert.py b/libs/langchain_v1/tests/unit_tests/agents/memory_assert.py new file mode 100644 index 00000000000..99d66bbe026 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/memory_assert.py @@ -0,0 +1,56 @@ +import os +import tempfile +from collections import defaultdict +from functools import partial + +from langgraph.checkpoint.base import ( + ChannelVersions, + Checkpoint, + CheckpointMetadata, + SerializerProtocol, +) +from langgraph.checkpoint.memory import InMemorySaver, PersistentDict +from langgraph.pregel._checkpoint import copy_checkpoint + + +class MemorySaverAssertImmutable(InMemorySaver): + storage_for_copies: defaultdict[str, dict[str, dict[str, Checkpoint]]] + + def __init__( + self, + *, + serde: SerializerProtocol | None = None, + put_sleep: float | None = None, + ) -> None: + _, filename = tempfile.mkstemp() + super().__init__(serde=serde, factory=partial(PersistentDict, filename=filename)) + self.storage_for_copies = defaultdict(lambda: defaultdict(dict)) + self.put_sleep = put_sleep + self.stack.callback(os.remove, filename) + + def put( + self, + config: dict, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> None: + if self.put_sleep: + import time + + time.sleep(self.put_sleep) + # assert checkpoint hasn't been modified since last written + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"]["checkpoint_ns"] + if saved := super().get(config): + assert ( + self.serde.loads_typed( + self.storage_for_copies[thread_id][checkpoint_ns][saved["id"]] + ) + == saved + ) + self.storage_for_copies[thread_id][checkpoint_ns][checkpoint["id"]] = ( + self.serde.dumps_typed(copy_checkpoint(checkpoint)) + ) + # call super to write checkpoint + return super().put(config, checkpoint, metadata, new_versions) diff --git a/libs/langchain_v1/tests/unit_tests/agents/messages.py b/libs/langchain_v1/tests/unit_tests/agents/messages.py new file mode 100644 index 00000000000..285c2941fef --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/messages.py @@ -0,0 +1,28 @@ +"""Redefined messages as a work-around for pydantic issue with AnyStr. + +The code below creates version of pydantic models +that will work in unit tests with AnyStr as id field +Please note that the `id` field is assigned AFTER the model is created +to workaround an issue with pydantic ignoring the __eq__ method on +subclassed strings. +""" + +from typing import Any + +from langchain_core.messages import HumanMessage, ToolMessage + +from .any_str import AnyStr + + +def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage: + """Create a human message with an any id field.""" + message = HumanMessage(**kwargs) + message.id = AnyStr() + return message + + +def _AnyIdToolMessage(**kwargs: Any) -> ToolMessage: + """Create a tool message with an any id field.""" + message = ToolMessage(**kwargs) + message.id = AnyStr() + return message diff --git a/libs/langchain_v1/tests/unit_tests/agents/model.py b/libs/langchain_v1/tests/unit_tests/agents/model.py new file mode 100644 index 00000000000..70cbbbafe06 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/model.py @@ -0,0 +1,111 @@ +import json +from collections.abc import Callable, Sequence +from dataclasses import asdict, is_dataclass +from typing import ( + Any, + Generic, + Literal, + TypeVar, + Union, +) + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models import BaseChatModel, LanguageModelInput +from langchain_core.messages import ( + AIMessage, + BaseMessage, + ToolCall, +) +from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool +from pydantic import BaseModel + +StructuredResponseT = TypeVar("StructuredResponseT") + + +class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]): + tool_calls: Union[list[list[ToolCall]], list[list[dict]]] | None = None + structured_response: StructuredResponseT | None = None + index: int = 0 + tool_style: Literal["openai", "anthropic"] = "openai" + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + """Top Level call""" + rf = kwargs.get("response_format") + is_native = isinstance(rf, dict) and rf.get("type") == "json_schema" + + if self.tool_calls: + if is_native: + tool_calls = ( + self.tool_calls[self.index] if self.index < len(self.tool_calls) else [] + ) + else: + tool_calls = self.tool_calls[self.index % len(self.tool_calls)] + else: + tool_calls = [] + + if is_native and not tool_calls: + if isinstance(self.structured_response, BaseModel): + content_obj = self.structured_response.model_dump() + elif is_dataclass(self.structured_response): + content_obj = asdict(self.structured_response) + elif isinstance(self.structured_response, dict): + content_obj = self.structured_response + message = AIMessage(content=json.dumps(content_obj), id=str(self.index)) + else: + messages_string = "-".join([m.content for m in messages]) + message = AIMessage( + content=messages_string, + id=str(self.index), + tool_calls=tool_calls.copy(), + ) + self.index += 1 + return ChatResult(generations=[ChatGeneration(message=message)]) + + @property + def _llm_type(self) -> str: + return "fake-tool-call-model" + + def bind_tools( + self, + tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + if len(tools) == 0: + msg = "Must provide at least one tool" + raise ValueError(msg) + + tool_dicts = [] + for tool in tools: + if isinstance(tool, dict): + tool_dicts.append(tool) + continue + if not isinstance(tool, BaseTool): + msg = "Only BaseTool and dict is supported by FakeToolCallingModel.bind_tools" + raise TypeError(msg) + + # NOTE: this is a simplified tool spec for testing purposes only + if self.tool_style == "openai": + tool_dicts.append( + { + "type": "function", + "function": { + "name": tool.name, + }, + } + ) + elif self.tool_style == "anthropic": + tool_dicts.append( + { + "name": tool.name, + } + ) + + return self.bind(tools=tool_dicts) diff --git a/libs/langchain_v1/tests/unit_tests/agents/specifications/responses.json b/libs/langchain_v1/tests/unit_tests/agents/specifications/responses.json new file mode 100644 index 00000000000..2755ef9756c --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/specifications/responses.json @@ -0,0 +1,87 @@ +[ + { + "name": "updated structured response", + "responseFormat": [ + { + "title": "role_schema_structured_output", + "type": "object", + "properties": { + "name": { "type": "string" }, + "role": { "type": "string" } + }, + "required": ["name", "role"] + }, + { + "title": "department_schema_structured_output", + "type": "object", + "properties": { + "name": { "type": "string" }, + "department": { "type": "string" } + }, + "required": ["name", "department"] + } + ], + "assertionsByInvocation": [ + { + "prompt": "What is the role of Sabine?", + "toolsWithExpectedCalls": { + "getEmployeeRole": 1, + "getEmployeeDepartment": 0 + }, + "expectedLastMessage": "Returning structured response: {'name': 'Sabine', 'role': 'Developer'}", + "expectedStructuredResponse": { "name": "Sabine", "role": "Developer" }, + "llmRequestCount": 2 + }, + { + "prompt": "In which department does Henrik work?", + "toolsWithExpectedCalls": { + "getEmployeeRole": 1, + "getEmployeeDepartment": 1 + }, + "expectedLastMessage": "Returning structured response: {'name': 'Henrik', 'department': 'IT'}", + "expectedStructuredResponse": { "name": "Henrik", "department": "IT" }, + "llmRequestCount": 4 + } + ] + }, + { + "name": "asking for information that does not fit into the response format", + "responseFormat": [ + { + "schema": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "role": { "type": "string" } + }, + "required": ["name", "role"] + } + }, + { + "schema": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "department": { "type": "string" } + }, + "required": ["name", "department"] + } + } + ], + "assertionsByInvocation": [ + { + "prompt": "How much does Saskia earn?", + "toolsWithExpectedCalls": { + "getEmployeeRole": 1, + "getEmployeeDepartment": 0 + }, + "expectedLastMessage": "Returning structured response: {'name': 'Saskia', 'role': 'Software Engineer'}", + "expectedStructuredResponse": { + "name": "Saskia", + "role": "Software Engineer" + }, + "llmRequestCount": 2 + } + ] + } + ] \ No newline at end of file diff --git a/libs/langchain_v1/tests/unit_tests/agents/specifications/return_direct.json b/libs/langchain_v1/tests/unit_tests/agents/specifications/return_direct.json new file mode 100644 index 00000000000..e4cc5947ce0 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/specifications/return_direct.json @@ -0,0 +1,48 @@ +[ + { + "name": "Scenario: NO return_direct, NO response_format", + "returnDirect": false, + "responseFormat": null, + "expectedToolCalls": 10, + "expectedLastMessage": "Attempts: 10", + "expectedStructuredResponse": null + }, + { + "name": "Scenario: NO return_direct, YES response_format", + "returnDirect": false, + "responseFormat": { + "type": "object", + "properties": { + "attempts": { "type": "number" }, + "succeeded": { "type": "boolean" } + }, + "required": ["attempts", "succeeded"] + }, + "expectedToolCalls": 10, + "expectedLastMessage": "Returning structured response: {'attempts': 10, 'succeeded': True}", + "expectedStructuredResponse": { "attempts": 10, "succeeded": true } + }, + { + "name": "Scenario: YES return_direct, NO response_format", + "returnDirect": true, + "responseFormat": null, + "expectedToolCalls": 1, + "expectedLastMessage": "{\"status\": \"pending\", \"attempts\": 1}", + "expectedStructuredResponse": null + }, + { + "name": "Scenario: YES return_direct, YES response_format", + "returnDirect": true, + "responseFormat": { + "type": "object", + "properties": { + "attempts": { "type": "number" }, + "succeeded": { "type": "boolean" } + }, + "required": ["attempts", "succeeded"] + }, + "expectedToolCalls": 1, + "expectedLastMessage": "{\"status\": \"pending\", \"attempts\": 1}", + "expectedStructuredResponse": null + } +] \ No newline at end of file diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_react_agent.py b/libs/langchain_v1/tests/unit_tests/agents/test_react_agent.py new file mode 100644 index 00000000000..a42d019d27c --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/test_react_agent.py @@ -0,0 +1,1647 @@ +import dataclasses +import inspect +from typing import ( + Annotated, + Union, +) + +import pytest +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import ( + AIMessage, + HumanMessage, + MessageLikeRepresentation, + RemoveMessage, + SystemMessage, + ToolCall, + ToolMessage, +) +from langchain_core.runnables import RunnableConfig, RunnableLambda +from langchain_core.tools import BaseTool, InjectedToolCallId, ToolException +from langchain_core.tools import tool as dec_tool +from langgraph.checkpoint.base import BaseCheckpointSaver +from langgraph.graph import START, MessagesState, StateGraph +from langgraph.graph.message import REMOVE_ALL_MESSAGES +from langgraph.runtime import Runtime +from langgraph.store.base import BaseStore +from langgraph.store.memory import InMemoryStore +from langgraph.types import Command, Interrupt, interrupt +from pydantic import BaseModel, Field +from typing_extensions import TypedDict + +from langchain.agents import ( + AgentState, + ToolNode, + create_react_agent, +) +from langchain.agents.react_agent import _validate_chat_history +from langchain.agents.tool_node import ( + InjectedState, + InjectedStore, + _get_state_args, + _infer_handled_types, +) + +from .any_str import AnyStr +from .messages import _AnyIdHumanMessage, _AnyIdToolMessage +from .model import FakeToolCallingModel + +pytestmark = pytest.mark.anyio + + +def test_no_prompt(sync_checkpointer: BaseCheckpointSaver) -> None: + model = FakeToolCallingModel() + + agent = create_react_agent( + model, + [], + checkpointer=sync_checkpointer, + ) + inputs = [HumanMessage("hi?")] + thread = {"configurable": {"thread_id": "123"}} + response = agent.invoke({"messages": inputs}, thread, debug=True) + expected_response = {"messages": [*inputs, AIMessage(content="hi?", id="0")]} + assert response == expected_response + + saved = sync_checkpointer.get_tuple(thread) + assert saved is not None + assert saved.checkpoint["channel_values"] == { + "messages": [ + _AnyIdHumanMessage(content="hi?"), + AIMessage(content="hi?", id="0"), + ], + } + assert saved.metadata == { + "parents": {}, + "source": "loop", + "step": 1, + } + assert saved.pending_writes == [] + + +async def test_no_prompt_async(async_checkpointer: BaseCheckpointSaver) -> None: + model = FakeToolCallingModel() + + agent = create_react_agent(model, [], checkpointer=async_checkpointer) + inputs = [HumanMessage("hi?")] + thread = {"configurable": {"thread_id": "123"}} + response = await agent.ainvoke({"messages": inputs}, thread, debug=True) + expected_response = {"messages": [*inputs, AIMessage(content="hi?", id="0")]} + assert response == expected_response + + saved = await async_checkpointer.aget_tuple(thread) + assert saved is not None + assert saved.checkpoint["channel_values"] == { + "messages": [ + _AnyIdHumanMessage(content="hi?"), + AIMessage(content="hi?", id="0"), + ], + } + assert saved.metadata == { + "parents": {}, + "source": "loop", + "step": 1, + } + assert saved.pending_writes == [] + + +def test_system_message_prompt() -> None: + prompt = SystemMessage(content="Foo") + agent = create_react_agent(FakeToolCallingModel(), [], prompt=prompt) + inputs = [HumanMessage("hi?")] + response = agent.invoke({"messages": inputs}) + expected_response = {"messages": [*inputs, AIMessage(content="Foo-hi?", id="0", tool_calls=[])]} + assert response == expected_response + + +def test_string_prompt() -> None: + prompt = "Foo" + agent = create_react_agent(FakeToolCallingModel(), [], prompt=prompt) + inputs = [HumanMessage("hi?")] + response = agent.invoke({"messages": inputs}) + expected_response = {"messages": [*inputs, AIMessage(content="Foo-hi?", id="0", tool_calls=[])]} + assert response == expected_response + + +def test_callable_prompt() -> None: + def prompt(state): + modified_message = f"Bar {state['messages'][-1].content}" + return [HumanMessage(content=modified_message)] + + agent = create_react_agent(FakeToolCallingModel(), [], prompt=prompt) + inputs = [HumanMessage("hi?")] + response = agent.invoke({"messages": inputs}) + expected_response = {"messages": [*inputs, AIMessage(content="Bar hi?", id="0")]} + assert response == expected_response + + +async def test_callable_prompt_async() -> None: + async def prompt(state): + modified_message = f"Bar {state['messages'][-1].content}" + return [HumanMessage(content=modified_message)] + + agent = create_react_agent(FakeToolCallingModel(), [], prompt=prompt) + inputs = [HumanMessage("hi?")] + response = await agent.ainvoke({"messages": inputs}) + expected_response = {"messages": [*inputs, AIMessage(content="Bar hi?", id="0")]} + assert response == expected_response + + +def test_runnable_prompt() -> None: + prompt = RunnableLambda( + lambda state: [HumanMessage(content=f"Baz {state['messages'][-1].content}")] + ) + + agent = create_react_agent(FakeToolCallingModel(), [], prompt=prompt) + inputs = [HumanMessage("hi?")] + response = agent.invoke({"messages": inputs}) + expected_response = {"messages": [*inputs, AIMessage(content="Baz hi?", id="0")]} + assert response == expected_response + + +def test_prompt_with_store() -> None: + def add(a: int, b: int): + """Adds a and b""" + return a + b + + in_memory_store = InMemoryStore() + in_memory_store.put(("memories", "1"), "user_name", {"data": "User name is Alice"}) + in_memory_store.put(("memories", "2"), "user_name", {"data": "User name is Bob"}) + + def prompt(state, config, *, store): + user_id = config["configurable"]["user_id"] + system_str = store.get(("memories", user_id), "user_name").value["data"] + return [SystemMessage(system_str)] + state["messages"] + + def prompt_no_store(state, config): + return SystemMessage("foo") + state["messages"] + + model = FakeToolCallingModel() + + # test state modifier that uses store works + agent = create_react_agent( + model, + [add], + prompt=prompt, + store=in_memory_store, + ) + response = agent.invoke({"messages": [("user", "hi")]}, {"configurable": {"user_id": "1"}}) + assert response["messages"][-1].content == "User name is Alice-hi" + + # test state modifier that doesn't use store works + agent = create_react_agent( + model, + [add], + prompt=prompt_no_store, + store=in_memory_store, + ) + response = agent.invoke({"messages": [("user", "hi")]}, {"configurable": {"user_id": "2"}}) + assert response["messages"][-1].content == "foo-hi" + + +async def test_prompt_with_store_async() -> None: + async def add(a: int, b: int): + """Adds a and b""" + return a + b + + in_memory_store = InMemoryStore() + await in_memory_store.aput(("memories", "1"), "user_name", {"data": "User name is Alice"}) + await in_memory_store.aput(("memories", "2"), "user_name", {"data": "User name is Bob"}) + + async def prompt(state, config, *, store): + user_id = config["configurable"]["user_id"] + system_str = (await store.aget(("memories", user_id), "user_name")).value["data"] + return [SystemMessage(system_str)] + state["messages"] + + async def prompt_no_store(state, config): + return SystemMessage("foo") + state["messages"] + + model = FakeToolCallingModel() + + # test state modifier that uses store works + agent = create_react_agent(model, [add], prompt=prompt, store=in_memory_store) + response = await agent.ainvoke( + {"messages": [("user", "hi")]}, {"configurable": {"user_id": "1"}} + ) + assert response["messages"][-1].content == "User name is Alice-hi" + + # test state modifier that doesn't use store works + agent = create_react_agent(model, [add], prompt=prompt_no_store, store=in_memory_store) + response = await agent.ainvoke( + {"messages": [("user", "hi")]}, {"configurable": {"user_id": "2"}} + ) + assert response["messages"][-1].content == "foo-hi" + + +@pytest.mark.parametrize("tool_style", ["openai", "anthropic"]) +@pytest.mark.parametrize("include_builtin", [True, False]) +def test_model_with_tools(tool_style: str, include_builtin: bool) -> None: + model = FakeToolCallingModel(tool_style=tool_style) + + @dec_tool + def tool1(some_val: int) -> str: + """Tool 1 docstring.""" + return f"Tool 1: {some_val}" + + @dec_tool + def tool2(some_val: int) -> str: + """Tool 2 docstring.""" + return f"Tool 2: {some_val}" + + tools: list[BaseTool | dict] = [tool1, tool2] + if include_builtin: + tools.append( + { + "type": "mcp", + "server_label": "atest_sever", + "server_url": "https://some.mcp.somewhere.com/sse", + "headers": {"foo": "bar"}, + "allowed_tools": [ + "mcp_tool_1", + "set_active_account", + "get_url_markdown", + "get_url_screenshot", + ], + "require_approval": "never", + } + ) + # check valid agent constructor + with pytest.raises(ValueError): + create_react_agent( + model.bind_tools(tools), + tools, + ) + + +def test__validate_messages() -> None: + # empty input + _validate_chat_history([]) + + # single human message + _validate_chat_history( + [ + HumanMessage(content="What's the weather?"), + ] + ) + + # human + AI + _validate_chat_history( + [ + HumanMessage(content="What's the weather?"), + AIMessage(content="The weather is sunny and 75°F."), + ] + ) + + # Answered tool calls + _validate_chat_history( + [ + HumanMessage(content="What's the weather?"), + AIMessage( + content="Let me check that for you.", + tool_calls=[{"id": "call1", "name": "get_weather", "args": {}}], + ), + ToolMessage(content="Sunny, 75°F", tool_call_id="call1"), + AIMessage(content="The weather is sunny and 75°F."), + ] + ) + + # Unanswered tool calls + with pytest.raises(ValueError): + _validate_chat_history( + [ + AIMessage( + content="I'll check that for you.", + tool_calls=[ + {"id": "call1", "name": "get_weather", "args": {}}, + {"id": "call2", "name": "get_time", "args": {}}, + ], + ) + ] + ) + + with pytest.raises(ValueError): + _validate_chat_history( + [ + HumanMessage(content="What's the weather and time?"), + AIMessage( + content="I'll check that for you.", + tool_calls=[ + {"id": "call1", "name": "get_weather", "args": {}}, + {"id": "call2", "name": "get_time", "args": {}}, + ], + ), + ToolMessage(content="Sunny, 75°F", tool_call_id="call1"), + AIMessage(content="The weather is sunny and 75°F. Let me check the time."), + ] + ) + + +def test__infer_handled_types() -> None: + def handle(e) -> str: # type: ignore + return "" + + def handle2(e: Exception) -> str: + return "" + + def handle3(e: Union[ValueError, ToolException]) -> str: + return "" + + class Handler: + def handle(self, e: ValueError) -> str: + return "" + + handle4 = Handler().handle + + def handle5(e: Union[Union[TypeError, ValueError], ToolException]) -> str: + return "" + + expected: tuple = (Exception,) + actual = _infer_handled_types(handle) + assert expected == actual + + expected = (Exception,) + actual = _infer_handled_types(handle2) + assert expected == actual + + expected = (ValueError, ToolException) + actual = _infer_handled_types(handle3) + assert expected == actual + + expected = (ValueError,) + actual = _infer_handled_types(handle4) + assert expected == actual + + expected = (TypeError, ValueError, ToolException) + actual = _infer_handled_types(handle5) + assert expected == actual + + with pytest.raises(ValueError): + + def handler(e: str) -> str: + return "" + + _infer_handled_types(handler) + + with pytest.raises(ValueError): + + def handler(e: list[Exception]) -> str: + return "" + + _infer_handled_types(handler) + + with pytest.raises(ValueError): + + def handler(e: Union[str, int]) -> str: + return "" + + _infer_handled_types(handler) + + +def test_react_agent_with_structured_response() -> None: + class WeatherResponse(BaseModel): + temperature: float = Field(description="The temperature in fahrenheit") + + tool_calls = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + [{"name": "WeatherResponse", "id": "2", "args": {"temperature": 75}}], + ] + + def get_weather() -> str: + """Get the weather""" + return "The weather is sunny and 75°F." + + expected_structured_response = WeatherResponse(temperature=75) + model = FakeToolCallingModel[WeatherResponse]( + tool_calls=tool_calls, structured_response=expected_structured_response + ) + agent = create_react_agent( + model, + [get_weather], + response_format=WeatherResponse, + ) + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + assert response["structured_response"] == expected_structured_response + assert len(response["messages"]) == 5 + + # Check message types in message history + msg_types = [m.type for m in response["messages"]] + assert msg_types == [ + "human", # "What's the weather?" + "ai", # "What's the weather?" + "tool", # "The weather is sunny and 75°F." + "ai", # structured response + "tool", # artificial tool message + ] + + assert [m.content for m in response["messages"]] == [ + "What's the weather?", + "What's the weather?", + "The weather is sunny and 75°F.", + "What's the weather?-What's the weather?-The weather is sunny and 75°F.", + "Returning structured response: {'temperature': 75.0}", + ] + + +class CustomState(AgentState): + user_name: str + + +def test_react_agent_update_state( + sync_checkpointer: BaseCheckpointSaver, +) -> None: + @dec_tool + def get_user_name(tool_call_id: Annotated[str, InjectedToolCallId]): + """Retrieve user name""" + user_name = interrupt("Please provider user name:") + return Command( + update={ + "user_name": user_name, + "messages": [ + ToolMessage("Successfully retrieved user name", tool_call_id=tool_call_id) + ], + } + ) + + def prompt(state: CustomState): + user_name = state.get("user_name") + if user_name is None: + return state["messages"] + + system_msg = f"User name is {user_name}" + return [{"role": "system", "content": system_msg}] + state["messages"] + + tool_calls = [[{"args": {}, "id": "1", "name": "get_user_name"}]] + model = FakeToolCallingModel(tool_calls=tool_calls) + agent = create_react_agent( + model, + [get_user_name], + state_schema=CustomState, + prompt=prompt, + checkpointer=sync_checkpointer, + ) + config = {"configurable": {"thread_id": "1"}} + # Run until interrupted + agent.invoke({"messages": [("user", "what's my name")]}, config) + # supply the value for the interrupt + response = agent.invoke(Command(resume="Archibald"), config) + # confirm that the state was updated + assert response["user_name"] == "Archibald" + assert len(response["messages"]) == 4 + tool_message: ToolMessage = response["messages"][-2] + assert tool_message.content == "Successfully retrieved user name" + assert tool_message.tool_call_id == "1" + assert tool_message.name == "get_user_name" + + +def test_react_agent_parallel_tool_calls( + sync_checkpointer: BaseCheckpointSaver, +) -> None: + human_assistance_execution_count = 0 + + @dec_tool + def human_assistance(query: str) -> str: + """Request assistance from a human.""" + nonlocal human_assistance_execution_count + human_response = interrupt({"query": query}) + human_assistance_execution_count += 1 + return human_response["data"] + + get_weather_execution_count = 0 + + @dec_tool + def get_weather(location: str) -> str: + """Use this tool to get the weather.""" + nonlocal get_weather_execution_count + get_weather_execution_count += 1 + return "It's sunny!" + + tool_calls = [ + [ + {"args": {"location": "sf"}, "id": "1", "name": "get_weather"}, + {"args": {"query": "request help"}, "id": "2", "name": "human_assistance"}, + ], + [], + ] + model = FakeToolCallingModel(tool_calls=tool_calls) + agent = create_react_agent( + model, + [human_assistance, get_weather], + checkpointer=sync_checkpointer, + ) + config = {"configurable": {"thread_id": "1"}} + query = "Get user assistance and also check the weather" + message_types = [] + for event in agent.stream({"messages": [("user", query)]}, config, stream_mode="values"): + if messages := event.get("messages"): + message_types.append([m.type for m in messages]) + + assert message_types == [ + ["human"], + ["human", "ai"], + ["human", "ai", "tool"], + ] + + # Resume + message_types = [] + for event in agent.stream(Command(resume={"data": "Hello"}), config, stream_mode="values"): + if messages := event.get("messages"): + message_types.append([m.type for m in messages]) + + assert message_types == [ + ["human", "ai"], + ["human", "ai", "tool", "tool"], + ["human", "ai", "tool", "tool", "ai"], + ] + + assert human_assistance_execution_count == 1 + assert get_weather_execution_count == 1 + + +class AgentStateExtraKey(AgentState): + foo: int + + +def test_create_react_agent_inject_vars() -> None: + """Test that the agent can inject state and store into tool functions.""" + store = InMemoryStore() + namespace = ("test",) + store.put(namespace, "test_key", {"bar": 3}) + + def tool1( + some_val: int, + state: Annotated[dict, InjectedState], + store: Annotated[BaseStore, InjectedStore()], + ) -> str: + """Tool 1 docstring.""" + store_val = store.get(namespace, "test_key").value["bar"] + return some_val + state["foo"] + store_val + + tool_call = { + "name": "tool1", + "args": {"some_val": 1}, + "id": "some 0", + "type": "tool_call", + } + model = FakeToolCallingModel(tool_calls=[[tool_call], []]) + agent = create_react_agent( + model, + ToolNode([tool1], handle_tool_errors=False), + state_schema=AgentStateExtraKey, + store=store, + ) + result = agent.invoke({"messages": [{"role": "user", "content": "hi"}], "foo": 2}) + assert result["messages"] == [ + _AnyIdHumanMessage(content="hi"), + AIMessage(content="hi", tool_calls=[tool_call], id="0"), + _AnyIdToolMessage(content="6", name="tool1", tool_call_id="some 0"), + AIMessage("hi-hi-6", id="1"), + ] + assert result["foo"] == 2 + + +async def test_return_direct() -> None: + @dec_tool(return_direct=True) + def tool_return_direct(input: str) -> str: + """A tool that returns directly.""" + return f"Direct result: {input}" + + @dec_tool + def tool_normal(input: str) -> str: + """A normal tool.""" + return f"Normal result: {input}" + + first_tool_call = [ + ToolCall( + name="tool_return_direct", + args={"input": "Test direct"}, + id="1", + ), + ] + expected_ai = AIMessage( + content="Test direct", + id="0", + tool_calls=first_tool_call, + ) + model = FakeToolCallingModel(tool_calls=[first_tool_call, []]) + agent = create_react_agent( + model, + [tool_return_direct, tool_normal], + ) + + # Test direct return for tool_return_direct + result = agent.invoke({"messages": [HumanMessage(content="Test direct", id="hum0")]}) + assert result["messages"] == [ + HumanMessage(content="Test direct", id="hum0"), + expected_ai, + ToolMessage( + content="Direct result: Test direct", + name="tool_return_direct", + tool_call_id="1", + id=result["messages"][2].id, + ), + ] + second_tool_call = [ + ToolCall( + name="tool_normal", + args={"input": "Test normal"}, + id="2", + ), + ] + model = FakeToolCallingModel(tool_calls=[second_tool_call, []]) + agent = create_react_agent(model, [tool_return_direct, tool_normal]) + result = agent.invoke({"messages": [HumanMessage(content="Test normal", id="hum1")]}) + assert result["messages"] == [ + HumanMessage(content="Test normal", id="hum1"), + AIMessage(content="Test normal", id="0", tool_calls=second_tool_call), + ToolMessage( + content="Normal result: Test normal", + name="tool_normal", + tool_call_id="2", + id=result["messages"][2].id, + ), + AIMessage(content="Test normal-Test normal-Normal result: Test normal", id="1"), + ] + + both_tool_calls = [ + ToolCall( + name="tool_return_direct", + args={"input": "Test both direct"}, + id="3", + ), + ToolCall( + name="tool_normal", + args={"input": "Test both normal"}, + id="4", + ), + ] + model = FakeToolCallingModel(tool_calls=[both_tool_calls, []]) + agent = create_react_agent(model, [tool_return_direct, tool_normal]) + result = agent.invoke({"messages": [HumanMessage(content="Test both", id="hum2")]}) + assert result["messages"] == [ + HumanMessage(content="Test both", id="hum2"), + AIMessage(content="Test both", id="0", tool_calls=both_tool_calls), + ToolMessage( + content="Direct result: Test both direct", + name="tool_return_direct", + tool_call_id="3", + id=result["messages"][2].id, + ), + ToolMessage( + content="Normal result: Test both normal", + name="tool_normal", + tool_call_id="4", + id=result["messages"][3].id, + ), + ] + + +def test__get_state_args() -> None: + class Schema1(BaseModel): + a: Annotated[str, InjectedState] + + class Schema2(Schema1): + b: Annotated[int, InjectedState("bar")] + + @dec_tool(args_schema=Schema2) + def foo(a: str, b: int) -> float: + """return""" + return 0.0 + + assert _get_state_args(foo) == {"a": None, "b": "bar"} + + +def test_inspect_react() -> None: + model = FakeToolCallingModel(tool_calls=[]) + agent = create_react_agent(model, []) + inspect.getclosurevars(agent.nodes["agent"].bound.func) + + +def test_react_with_subgraph_tools( + sync_checkpointer: BaseCheckpointSaver, +) -> None: + class State(TypedDict): + a: int + b: int + + class Output(TypedDict): + result: int + + # Define the subgraphs + def add(state): + return {"result": state["a"] + state["b"]} + + add_subgraph = ( + StateGraph(State, output_schema=Output).add_node(add).add_edge(START, "add").compile() + ) + + def multiply(state): + return {"result": state["a"] * state["b"]} + + multiply_subgraph = ( + StateGraph(State, output_schema=Output) + .add_node(multiply) + .add_edge(START, "multiply") + .compile() + ) + + multiply_subgraph.invoke({"a": 2, "b": 3}) + + # Add subgraphs as tools + + def addition(a: int, b: int): + """Add two numbers""" + return add_subgraph.invoke({"a": a, "b": b})["result"] + + def multiplication(a: int, b: int): + """Multiply two numbers""" + return multiply_subgraph.invoke({"a": a, "b": b})["result"] + + model = FakeToolCallingModel( + tool_calls=[ + [ + {"args": {"a": 2, "b": 3}, "id": "1", "name": "addition"}, + {"args": {"a": 2, "b": 3}, "id": "2", "name": "multiplication"}, + ], + [], + ] + ) + tool_node = ToolNode([addition, multiplication], handle_tool_errors=False) + agent = create_react_agent( + model, + tool_node, + checkpointer=sync_checkpointer, + ) + result = agent.invoke( + {"messages": [HumanMessage(content="What's 2 + 3 and 2 * 3?")]}, + config={"configurable": {"thread_id": "1"}}, + ) + assert result["messages"] == [ + _AnyIdHumanMessage(content="What's 2 + 3 and 2 * 3?"), + AIMessage( + content="What's 2 + 3 and 2 * 3?", + id="0", + tool_calls=[ + ToolCall(name="addition", args={"a": 2, "b": 3}, id="1"), + ToolCall(name="multiplication", args={"a": 2, "b": 3}, id="2"), + ], + ), + ToolMessage(content="5", name="addition", tool_call_id="1", id=result["messages"][2].id), + ToolMessage( + content="6", + name="multiplication", + tool_call_id="2", + id=result["messages"][3].id, + ), + AIMessage(content="What's 2 + 3 and 2 * 3?-What's 2 + 3 and 2 * 3?-5-6", id="1"), + ] + + +def test_react_agent_subgraph_streaming_sync() -> None: + """Test React agent streaming when used as a subgraph node sync version""" + + @dec_tool + def get_weather(city: str) -> str: + """Get the weather of a city.""" + return f"The weather of {city} is sunny." + + # Create a React agent + model = FakeToolCallingModel( + tool_calls=[ + [{"args": {"city": "Tokyo"}, "id": "1", "name": "get_weather"}], + [], + ] + ) + + agent = create_react_agent( + model, + tools=[get_weather], + prompt="You are a helpful travel assistant.", + ) + + # Create a subgraph that uses the React agent as a node + def react_agent_node(state: MessagesState, config: RunnableConfig) -> MessagesState: + """Node that runs the React agent and collects streaming output.""" + collected_content = "" + + # Stream the agent output and collect content + for msg_chunk, _msg_metadata in agent.stream( + {"messages": [("user", state["messages"][-1].content)]}, + config, + stream_mode="messages", + ): + if hasattr(msg_chunk, "content") and msg_chunk.content: + collected_content += msg_chunk.content + + return {"messages": [("assistant", collected_content)]} + + # Create the main workflow with the React agent as a subgraph node + workflow = StateGraph(MessagesState) + workflow.add_node("react_agent", react_agent_node) + workflow.add_edge(START, "react_agent") + workflow.add_edge("react_agent", "__end__") + compiled_workflow = workflow.compile() + + # Test the streaming functionality + result = compiled_workflow.invoke({"messages": [("user", "What is the weather in Tokyo?")]}) + + # Verify the result contains expected structure + assert len(result["messages"]) == 2 + assert result["messages"][0].content == "What is the weather in Tokyo?" + assert "assistant" in str(result["messages"][1]) + + # Test streaming with subgraphs = True + result = compiled_workflow.invoke( + {"messages": [("user", "What is the weather in Tokyo?")]}, + subgraphs=True, + ) + assert len(result["messages"]) == 2 + + events = [] + for event in compiled_workflow.stream( + {"messages": [("user", "What is the weather in Tokyo?")]}, + stream_mode="messages", + subgraphs=False, + ): + events.append(event) + + assert len(events) == 0 + + events = [] + for event in compiled_workflow.stream( + {"messages": [("user", "What is the weather in Tokyo?")]}, + stream_mode="messages", + subgraphs=True, + ): + events.append(event) + + assert len(events) == 3 + namespace, (msg, metadata) = events[0] + # FakeToolCallingModel returns a single AIMessage with tool calls + # The content of the AIMessage reflects the input message + assert msg.content.startswith("You are a helpful travel assistant") + namespace, (msg, metadata) = events[1] # ToolMessage + assert msg.content.startswith("The weather of Tokyo is sunny.") + + +async def test_react_agent_subgraph_streaming() -> None: + """Test React agent streaming when used as a subgraph node.""" + + @dec_tool + def get_weather(city: str) -> str: + """Get the weather of a city.""" + return f"The weather of {city} is sunny." + + # Create a React agent + model = FakeToolCallingModel( + tool_calls=[ + [{"args": {"city": "Tokyo"}, "id": "1", "name": "get_weather"}], + [], + ] + ) + + agent = create_react_agent( + model, + tools=[get_weather], + prompt="You are a helpful travel assistant.", + ) + + # Create a subgraph that uses the React agent as a node + async def react_agent_node(state: MessagesState, config: RunnableConfig) -> MessagesState: + """Node that runs the React agent and collects streaming output.""" + collected_content = "" + + # Stream the agent output and collect content + async for msg_chunk, _msg_metadata in agent.astream( + {"messages": [("user", state["messages"][-1].content)]}, + config, + stream_mode="messages", + ): + if hasattr(msg_chunk, "content") and msg_chunk.content: + collected_content += msg_chunk.content + + return {"messages": [("assistant", collected_content)]} + + # Create the main workflow with the React agent as a subgraph node + workflow = StateGraph(MessagesState) + workflow.add_node("react_agent", react_agent_node) + workflow.add_edge(START, "react_agent") + workflow.add_edge("react_agent", "__end__") + compiled_workflow = workflow.compile() + + # Test the streaming functionality + result = await compiled_workflow.ainvoke( + {"messages": [("user", "What is the weather in Tokyo?")]} + ) + + # Verify the result contains expected structure + assert len(result["messages"]) == 2 + assert result["messages"][0].content == "What is the weather in Tokyo?" + assert "assistant" in str(result["messages"][1]) + + # Test streaming with subgraphs = True + result = await compiled_workflow.ainvoke( + {"messages": [("user", "What is the weather in Tokyo?")]}, + subgraphs=True, + ) + assert len(result["messages"]) == 2 + + events = [] + async for event in compiled_workflow.astream( + {"messages": [("user", "What is the weather in Tokyo?")]}, + stream_mode="messages", + subgraphs=False, + ): + events.append(event) + + assert len(events) == 0 + + events = [] + async for event in compiled_workflow.astream( + {"messages": [("user", "What is the weather in Tokyo?")]}, + stream_mode="messages", + subgraphs=True, + ): + events.append(event) + + assert len(events) == 3 + namespace, (msg, metadata) = events[0] + # FakeToolCallingModel returns a single AIMessage with tool calls + # The content of the AIMessage reflects the input message + assert msg.content.startswith("You are a helpful travel assistant") + namespace, (msg, metadata) = events[1] # ToolMessage + assert msg.content.startswith("The weather of Tokyo is sunny.") + + +def test_tool_node_node_interrupt( + sync_checkpointer: BaseCheckpointSaver, +) -> None: + def tool_normal(some_val: int) -> str: + """Tool docstring.""" + return "normal" + + def tool_interrupt(some_val: int) -> str: + """Tool docstring.""" + return interrupt("provide value for foo") + + # test inside react agent + model = FakeToolCallingModel( + tool_calls=[ + [ + ToolCall(name="tool_interrupt", args={"some_val": 0}, id="1"), + ToolCall(name="tool_normal", args={"some_val": 1}, id="2"), + ], + [], + ] + ) + config = {"configurable": {"thread_id": "1"}} + agent = create_react_agent( + model, + [tool_interrupt, tool_normal], + checkpointer=sync_checkpointer, + ) + result = agent.invoke({"messages": [HumanMessage("hi?")]}, config) + expected_messages = [ + _AnyIdHumanMessage(content="hi?"), + AIMessage( + content="hi?", + id="0", + tool_calls=[ + { + "name": "tool_interrupt", + "args": {"some_val": 0}, + "id": "1", + "type": "tool_call", + }, + { + "name": "tool_normal", + "args": {"some_val": 1}, + "id": "2", + "type": "tool_call", + }, + ], + ), + _AnyIdToolMessage(content="normal", name="tool_normal", tool_call_id="2"), + ] + assert result["messages"] == expected_messages + + state = agent.get_state(config) + assert state.next == ("tools",) + task = state.tasks[0] + assert task.name == "tools" + assert task.interrupts == ( + Interrupt( + value="provide value for foo", + id=AnyStr(), + ), + ) + + +def test_dynamic_model_basic() -> None: + """Test basic dynamic model functionality.""" + + def dynamic_model(state, runtime: Runtime): + # Return different models based on state + if "urgent" in state["messages"][-1].content: + return FakeToolCallingModel(tool_calls=[]) + return FakeToolCallingModel(tool_calls=[]) + + agent = create_react_agent(dynamic_model, []) + + result = agent.invoke({"messages": [HumanMessage("hello")]}) + assert len(result["messages"]) == 2 + assert result["messages"][-1].content == "hello" + + result = agent.invoke({"messages": [HumanMessage("urgent help")]}) + assert len(result["messages"]) == 2 + assert result["messages"][-1].content == "urgent help" + + +def test_dynamic_model_with_tools() -> None: + """Test dynamic model with tool calling.""" + + @dec_tool + def basic_tool(x: int) -> str: + """Basic tool.""" + return f"basic: {x}" + + @dec_tool + def advanced_tool(x: int) -> str: + """Advanced tool.""" + return f"advanced: {x}" + + def dynamic_model(state: dict, runtime: Runtime) -> BaseChatModel: + # Return model with different behaviors based on message content + if "advanced" in state["messages"][-1].content: + return FakeToolCallingModel( + tool_calls=[ + [{"args": {"x": 1}, "id": "1", "name": "advanced_tool"}], + [], + ] + ) + return FakeToolCallingModel( + tool_calls=[[{"args": {"x": 1}, "id": "1", "name": "basic_tool"}], []] + ) + + agent = create_react_agent(dynamic_model, [basic_tool, advanced_tool]) + + # Test basic tool usage + result = agent.invoke({"messages": [HumanMessage("basic request")]}) + assert len(result["messages"]) == 3 + tool_message = result["messages"][-1] + assert tool_message.content == "basic: 1" + assert tool_message.name == "basic_tool" + + # Test advanced tool usage + result = agent.invoke({"messages": [HumanMessage("advanced request")]}) + assert len(result["messages"]) == 3 + tool_message = result["messages"][-1] + assert tool_message.content == "advanced: 1" + assert tool_message.name == "advanced_tool" + + +@dataclasses.dataclass +class Context: + user_id: str + + +def test_dynamic_model_with_context() -> None: + """Test dynamic model using config parameters.""" + + def dynamic_model(state, runtime: Runtime[Context]): + # Use context to determine model behavior + user_id = runtime.context.user_id + if user_id == "user_premium": + return FakeToolCallingModel(tool_calls=[]) + return FakeToolCallingModel(tool_calls=[]) + + agent = create_react_agent(dynamic_model, [], context_schema=Context) + + # Test with basic user + result = agent.invoke( + {"messages": [HumanMessage("hello")]}, + context=Context(user_id="user_basic"), + ) + assert len(result["messages"]) == 2 + + # Test with premium user + result = agent.invoke( + {"messages": [HumanMessage("hello")]}, + context=Context(user_id="user_premium"), + ) + assert len(result["messages"]) == 2 + + +def test_dynamic_model_with_state_schema() -> None: + """Test dynamic model with custom state schema.""" + + class CustomDynamicState(AgentState): + model_preference: str = "default" + + def dynamic_model(state: CustomDynamicState, runtime: Runtime) -> BaseChatModel: + # Use custom state field to determine model + if state.get("model_preference") == "advanced": + return FakeToolCallingModel(tool_calls=[]) + return FakeToolCallingModel(tool_calls=[]) + + agent = create_react_agent(dynamic_model, [], state_schema=CustomDynamicState) + + result = agent.invoke({"messages": [HumanMessage("hello")], "model_preference": "advanced"}) + assert len(result["messages"]) == 2 + assert result["model_preference"] == "advanced" + + +def test_dynamic_model_with_prompt() -> None: + """Test dynamic model with different prompt types.""" + + def dynamic_model(state: AgentState, runtime: Runtime) -> BaseChatModel: + return FakeToolCallingModel(tool_calls=[]) + + # Test with string prompt + agent = create_react_agent(dynamic_model, [], prompt="system_msg") + result = agent.invoke({"messages": [HumanMessage("human_msg")]}) + assert result["messages"][-1].content == "system_msg-human_msg" + + # Test with callable prompt + def dynamic_prompt(state: AgentState) -> list[MessageLikeRepresentation]: + """Generate a dynamic system message based on state.""" + return [{"role": "system", "content": "system_msg"}, *list(state["messages"])] + + agent = create_react_agent(dynamic_model, [], prompt=dynamic_prompt) + result = agent.invoke({"messages": [HumanMessage("human_msg")]}) + assert result["messages"][-1].content == "system_msg-human_msg" + + +async def test_dynamic_model_async() -> None: + """Test dynamic model with async operations.""" + + def dynamic_model(state: AgentState, runtime: Runtime) -> BaseChatModel: + return FakeToolCallingModel(tool_calls=[]) + + agent = create_react_agent(dynamic_model, []) + + result = await agent.ainvoke({"messages": [HumanMessage("hello async")]}) + assert len(result["messages"]) == 2 + assert result["messages"][-1].content == "hello async" + + +def test_dynamic_model_with_structured_response() -> None: + """Test dynamic model with structured response format.""" + + class TestResponse(BaseModel): + message: str + confidence: float + + def dynamic_model(state, runtime: Runtime): + return FakeToolCallingModel( + tool_calls=[ + [ + ToolCall( + name="TestResponse", + args={"message": "dynamic response", "confidence": 0.9}, + id="1", + type="tool_call", + ) + ] + ], + ) + + agent = create_react_agent(dynamic_model, [], response_format=TestResponse) + + result = agent.invoke({"messages": [HumanMessage("hello")]}) + assert "structured_response" in result + assert result["structured_response"].message == "dynamic response" + assert result["structured_response"].confidence == 0.9 + + +def test_dynamic_model_with_checkpointer(sync_checkpointer) -> None: + """Test dynamic model with checkpointer.""" + call_count = 0 + + def dynamic_model(state: AgentState, runtime: Runtime) -> BaseChatModel: + nonlocal call_count + call_count += 1 + return FakeToolCallingModel( + tool_calls=[], + # Incrementing the call count as it is used to assign an id + # to the AIMessage. + # The default reducer semantics are to overwrite an existing message + # with the new one if the id matches. + index=call_count, + ) + + agent = create_react_agent(dynamic_model, [], checkpointer=sync_checkpointer) + config = {"configurable": {"thread_id": "test_dynamic"}} + + # First call + result1 = agent.invoke({"messages": [HumanMessage("hello")]}, config) + assert len(result1["messages"]) == 2 # Human + AI message + + # Second call - should load from checkpoint + result2 = agent.invoke({"messages": [HumanMessage("world")]}, config) + assert len(result2["messages"]) == 4 + + # Dynamic model should be called each time + assert call_count >= 2 + + +def test_dynamic_model_state_dependent_tools() -> None: + """Test dynamic model that changes available tools based on state.""" + + @dec_tool + def tool_a(x: int) -> str: + """Tool A.""" + return f"A: {x}" + + @dec_tool + def tool_b(x: int) -> str: + """Tool B.""" + return f"B: {x}" + + def dynamic_model(state, runtime: Runtime): + # Switch tools based on message history + if any("use_b" in msg.content for msg in state["messages"]): + return FakeToolCallingModel( + tool_calls=[[{"args": {"x": 2}, "id": "1", "name": "tool_b"}], []] + ) + return FakeToolCallingModel( + tool_calls=[[{"args": {"x": 1}, "id": "1", "name": "tool_a"}], []] + ) + + agent = create_react_agent(dynamic_model, [tool_a, tool_b]) + + # Ask to use tool B + result = agent.invoke({"messages": [HumanMessage("use_b please")]}) + last_message = result["messages"][-1] + assert isinstance(last_message, ToolMessage) + assert last_message.content == "B: 2" + + # Ask to use tool A + result = agent.invoke({"messages": [HumanMessage("hello")]}) + last_message = result["messages"][-1] + assert isinstance(last_message, ToolMessage) + assert last_message.content == "A: 1" + + +def test_dynamic_model_error_handling() -> None: + """Test error handling in dynamic model.""" + + def failing_dynamic_model(state, runtime: Runtime): + if "fail" in state["messages"][-1].content: + msg = "Dynamic model failed" + raise ValueError(msg) + return FakeToolCallingModel(tool_calls=[]) + + agent = create_react_agent(failing_dynamic_model, []) + + # Normal operation should work + result = agent.invoke({"messages": [HumanMessage("hello")]}) + assert len(result["messages"]) == 2 + + # Should propagate the error + with pytest.raises(ValueError, match="Dynamic model failed"): + agent.invoke({"messages": [HumanMessage("fail now")]}) + + +def test_dynamic_model_vs_static_model_behavior() -> None: + """Test that dynamic and static models produce equivalent results when configured the same.""" + # Static model + static_model = FakeToolCallingModel(tool_calls=[]) + static_agent = create_react_agent(static_model, []) + + # Dynamic model returning the same model + def dynamic_model(state, runtime: Runtime): + return FakeToolCallingModel(tool_calls=[]) + + dynamic_agent = create_react_agent(dynamic_model, []) + + input_msg = {"messages": [HumanMessage("test message")]} + + static_result = static_agent.invoke(input_msg) + dynamic_result = dynamic_agent.invoke(input_msg) + + # Results should be equivalent (content-wise, IDs may differ) + assert len(static_result["messages"]) == len(dynamic_result["messages"]) + assert static_result["messages"][0].content == dynamic_result["messages"][0].content + assert static_result["messages"][1].content == dynamic_result["messages"][1].content + + +def test_dynamic_model_receives_correct_state() -> None: + """Test that the dynamic model function receives the correct state, not the model input.""" + received_states = [] + + class CustomAgentState(AgentState): + custom_field: str + + def dynamic_model(state, runtime: Runtime) -> BaseChatModel: + # Capture the state that's passed to the dynamic model function + received_states.append(state) + return FakeToolCallingModel(tool_calls=[]) + + agent = create_react_agent(dynamic_model, [], state_schema=CustomAgentState) + + # Test with initial state + input_state = {"messages": [HumanMessage("hello")], "custom_field": "test_value"} + agent.invoke(input_state) + + # The dynamic model function should receive the original state, not the processed model input + assert len(received_states) == 1 + received_state = received_states[0] + + # Should have the custom field from original state + assert "custom_field" in received_state + assert received_state["custom_field"] == "test_value" + + # Should have the original messages + assert len(received_state["messages"]) == 1 + assert received_state["messages"][0].content == "hello" + + +async def test_dynamic_model_receives_correct_state_async() -> None: + """Test that the async dynamic model function receives the correct state, not the model input.""" + received_states = [] + + class CustomAgentStateAsync(AgentState): + custom_field: str + + def dynamic_model(state, runtime: Runtime): + # Capture the state that's passed to the dynamic model function + received_states.append(state) + return FakeToolCallingModel(tool_calls=[]) + + agent = create_react_agent(dynamic_model, [], state_schema=CustomAgentStateAsync) + + # Test with initial state + input_state = { + "messages": [HumanMessage("hello async")], + "custom_field": "test_value_async", + } + await agent.ainvoke(input_state) + + # The dynamic model function should receive the original state, not the processed model input + assert len(received_states) == 1 + received_state = received_states[0] + + # Should have the custom field from original state + assert "custom_field" in received_state + assert received_state["custom_field"] == "test_value_async" + + # Should have the original messages + assert len(received_state["messages"]) == 1 + assert received_state["messages"][0].content == "hello async" + + +def test_pre_model_hook() -> None: + model = FakeToolCallingModel(tool_calls=[]) + + # Test `llm_input_messages` + def pre_model_hook(state: AgentState): + return {"llm_input_messages": [HumanMessage("Hello!")]} + + agent = create_react_agent(model, [], pre_model_hook=pre_model_hook) + assert "pre_model_hook" in agent.nodes + result = agent.invoke({"messages": [HumanMessage("hi?")]}) + assert result == { + "messages": [ + _AnyIdHumanMessage(content="hi?"), + AIMessage(content="Hello!", id="0"), + ] + } + + # Test `messages` + def pre_model_hook(state: AgentState): + return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), HumanMessage("Hello!")]} + + agent = create_react_agent(model, [], pre_model_hook=pre_model_hook) + result = agent.invoke({"messages": [HumanMessage("hi?")]}) + assert result == { + "messages": [ + _AnyIdHumanMessage(content="Hello!"), + AIMessage(content="Hello!", id="1"), + ] + } + + +def test_post_model_hook() -> None: + class FlagState(AgentState): + flag: bool + + model = FakeToolCallingModel(tool_calls=[]) + + def post_model_hook(state: FlagState) -> dict[str, bool]: + return {"flag": True} + + pmh_agent = create_react_agent( + model, [], post_model_hook=post_model_hook, state_schema=FlagState + ) + + assert "post_model_hook" in pmh_agent.nodes + + result = pmh_agent.invoke({"messages": [HumanMessage("hi?")], "flag": False}) + assert result["flag"] is True + + events = list(pmh_agent.stream({"messages": [HumanMessage("hi?")], "flag": False})) + assert events == [ + { + "agent": { + "messages": [ + AIMessage( + content="hi?", + additional_kwargs={}, + response_metadata={}, + id="1", + ) + ] + } + }, + {"post_model_hook": {"flag": True}}, + ] + + +def test_post_model_hook_with_structured_output() -> None: + class WeatherResponse(BaseModel): + temperature: float = Field(description="The temperature in fahrenheit") + + tool_calls: list[list[ToolCall]] = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + [{"args": {"temperature": 75}, "id": "2", "name": "WeatherResponse"}], + ] + + def get_weather() -> str: + """Get the weather""" + return "The weather is sunny and 75°F." + + expected_structured_response = WeatherResponse(temperature=75) + + class State(AgentState): + flag: bool + structured_response: WeatherResponse + + def post_model_hook(state: State) -> Union[dict[str, bool], Command]: + return {"flag": True} + + model = FakeToolCallingModel(tool_calls=tool_calls) + agent = create_react_agent( + model, + [get_weather], + response_format=WeatherResponse, + post_model_hook=post_model_hook, + state_schema=State, + ) + + assert "post_model_hook" in agent.nodes + + response = agent.invoke({"messages": [HumanMessage("What's the weather?")], "flag": False}) + assert response["flag"] is True + assert response["structured_response"] == expected_structured_response + + # Reset the state of the model + model = FakeToolCallingModel(tool_calls=tool_calls) + agent = create_react_agent( + model, + [get_weather], + response_format=WeatherResponse, + post_model_hook=post_model_hook, + state_schema=State, + ) + + events = list(agent.stream({"messages": [HumanMessage("What's the weather?")], "flag": False})) + assert events == [ + { + "agent": { + "messages": [ + AIMessage( + content="What's the weather?", + additional_kwargs={}, + response_metadata={}, + id="0", + tool_calls=[ + { + "name": "get_weather", + "args": {}, + "id": "1", + "type": "tool_call", + } + ], + ) + ] + } + }, + {"post_model_hook": {"flag": True}}, + { + "tools": { + "messages": [ + _AnyIdToolMessage( + content="The weather is sunny and 75°F.", + name="get_weather", + tool_call_id="1", + ) + ] + } + }, + { + "agent": { + "messages": [ + AIMessage( + content="What's the weather?-What's the weather?-The weather is sunny and 75°F.", + additional_kwargs={}, + response_metadata={}, + id="1", + tool_calls=[ + { + "name": "WeatherResponse", + "args": {"temperature": 75}, + "id": "2", + "type": "tool_call", + } + ], + ), + _AnyIdToolMessage( + content="Returning structured response: {'temperature': 75.0}", + name="WeatherResponse", + tool_call_id="2", + ), + ], + "structured_response": WeatherResponse(temperature=75.0), + } + }, + {"post_model_hook": {"flag": True}}, + ] + + +def test_create_react_agent_inject_vars_with_post_model_hook() -> None: + store = InMemoryStore() + namespace = ("test",) + store.put(namespace, "test_key", {"bar": 3}) + + def tool1( + some_val: int, + state: Annotated[dict, InjectedState], + store: Annotated[BaseStore, InjectedStore()], + ) -> str: + """Tool 1 docstring.""" + store_val = store.get(namespace, "test_key").value["bar"] + return some_val + state["foo"] + store_val + + tool_call = { + "name": "tool1", + "args": {"some_val": 1}, + "id": "some 0", + "type": "tool_call", + } + + def post_model_hook(state: dict) -> dict: + """Post model hook is injecting a new foo key.""" + return {"foo": 2} + + model = FakeToolCallingModel(tool_calls=[[tool_call], []]) + agent = create_react_agent( + model, + ToolNode([tool1], handle_tool_errors=False), + state_schema=AgentStateExtraKey, + store=store, + post_model_hook=post_model_hook, + ) + input_message = HumanMessage("hi") + result = agent.invoke({"messages": [input_message], "foo": 2}) + assert result["messages"] == [ + input_message, + AIMessage(content="hi", tool_calls=[tool_call], id="0"), + _AnyIdToolMessage(content="6", name="tool1", tool_call_id="some 0"), + AIMessage("hi-hi-6", id="1"), + ] + assert result["foo"] == 2 + + +def test_response_format_using_tool_choice() -> None: + """Test response format using tool choice.""" + + class WeatherResponse(BaseModel): + temperature: float = Field(description="The temperature in fahrenheit") + + tool_calls: list[list[ToolCall]] = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + [{"args": {"temperature": "75"}, "id": "2", "name": "WeatherResponse"}], + ] + + def get_weather() -> str: + """Get the weather""" + return "The weather is sunny and 75°F." + + expected_structured_response = WeatherResponse(temperature=75) + model = FakeToolCallingModel(tool_calls=tool_calls) + agent = create_react_agent( + model, + [get_weather], + response_format=WeatherResponse, + ) + response = agent.invoke( + { + "messages": [ + { + "role": "user", + "content": "What's the weather?", + } + ] + } + ) + assert response.get("structured_response") == expected_structured_response diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_react_agent_graph.py b/libs/langchain_v1/tests/unit_tests/agents/test_react_agent_graph.py new file mode 100644 index 00000000000..b978306c652 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/test_react_agent_graph.py @@ -0,0 +1,58 @@ +from collections.abc import Callable +from typing import Union + +import pytest +from pydantic import BaseModel +from syrupy.assertion import SnapshotAssertion + +from langchain.agents import create_react_agent + +from .model import FakeToolCallingModel + +model = FakeToolCallingModel() + + +def tool() -> None: + """Testing tool.""" + + +def pre_model_hook() -> None: + """Pre-model hook.""" + + +def post_model_hook() -> None: + """Post-model hook.""" + + +class ResponseFormat(BaseModel): + """Response format for the agent.""" + + result: str + + +@pytest.mark.parametrize("tools", [[], [tool]]) +@pytest.mark.parametrize("pre_model_hook", [None, pre_model_hook]) +@pytest.mark.parametrize("post_model_hook", [None, post_model_hook]) +def test_react_agent_graph_structure( + snapshot: SnapshotAssertion, + tools: list[Callable], + pre_model_hook: Union[Callable, None], + post_model_hook: Union[Callable, None], +) -> None: + agent = create_react_agent( + model, + tools=tools, + pre_model_hook=pre_model_hook, + post_model_hook=post_model_hook, + ) + try: + assert agent.get_graph().draw_mermaid(with_styles=False) == snapshot + except Exception as e: + msg = ( + "The graph structure has changed. Please update the snapshot." + "Configuration used:\n" + f"tools: {tools}, " + f"pre_model_hook: {pre_model_hook}, " + f"post_model_hook: {post_model_hook}, " + ) + raise ValueError(msg) from e diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py b/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py new file mode 100644 index 00000000000..06cf8572b2c --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py @@ -0,0 +1,710 @@ +"""Test suite for create_react_agent with structured output response_format permutations.""" + +import pytest + +from dataclasses import dataclass +from typing import Union + +from langchain_core.messages import HumanMessage +from langchain.agents import create_react_agent +from langchain.agents.structured_output import ( + MultipleStructuredOutputsError, + ProviderStrategy, + StructuredOutputValidationError, + ToolStrategy, +) +from pydantic import BaseModel, Field +from typing_extensions import TypedDict + +from tests.unit_tests.agents.model import FakeToolCallingModel + + +# Test data models +class WeatherBaseModel(BaseModel): + """Weather response.""" + + temperature: float = Field(description="The temperature in fahrenheit") + condition: str = Field(description="Weather condition") + + +@dataclass +class WeatherDataclass: + """Weather response.""" + + temperature: float + condition: str + + +class WeatherTypedDict(TypedDict): + """Weather response.""" + + temperature: float + condition: str + + +weather_json_schema = { + "type": "object", + "properties": { + "temperature": {"type": "number", "description": "Temperature in fahrenheit"}, + "condition": {"type": "string", "description": "Weather condition"}, + }, + "title": "weather_schema", + "required": ["temperature", "condition"], +} + + +class LocationResponse(BaseModel): + city: str = Field(description="The city name") + country: str = Field(description="The country name") + + +class LocationTypedDict(TypedDict): + city: str + country: str + + +location_json_schema = { + "type": "object", + "properties": { + "city": {"type": "string", "description": "The city name"}, + "country": {"type": "string", "description": "The country name"}, + }, + "title": "location_schema", + "required": ["city", "country"], +} + + +def get_weather() -> str: + """Get the weather.""" + + return "The weather is sunny and 75°F." + + +def get_location() -> str: + """Get the current location.""" + + return "You are in New York, USA." + + +# Standardized test data +WEATHER_DATA = {"temperature": 75.0, "condition": "sunny"} +LOCATION_DATA = {"city": "New York", "country": "USA"} + +# Standardized expected responses +EXPECTED_WEATHER_PYDANTIC = WeatherBaseModel(**WEATHER_DATA) +EXPECTED_WEATHER_DATACLASS = WeatherDataclass(**WEATHER_DATA) +EXPECTED_WEATHER_DICT: WeatherTypedDict = {"temperature": 75.0, "condition": "sunny"} +EXPECTED_LOCATION = LocationResponse(**LOCATION_DATA) +EXPECTED_LOCATION_DICT: LocationTypedDict = {"city": "New York", "country": "USA"} + + +class TestResponseFormatAsModel: + def test_pydantic_model(self) -> None: + """Test response_format as Pydantic model.""" + tool_calls = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + [ + { + "name": "WeatherBaseModel", + "id": "2", + "args": WEATHER_DATA, + } + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + + agent = create_react_agent(model, [get_weather], response_format=WeatherBaseModel) + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC + assert len(response["messages"]) == 5 + + def test_dataclass(self) -> None: + """Test response_format as dataclass.""" + tool_calls = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + [ + { + "name": "WeatherDataclass", + "id": "2", + "args": WEATHER_DATA, + } + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + + agent = create_react_agent(model, [get_weather], response_format=WeatherDataclass) + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + assert response["structured_response"] == EXPECTED_WEATHER_DATACLASS + assert len(response["messages"]) == 5 + + def test_typed_dict(self) -> None: + """Test response_format as TypedDict.""" + tool_calls = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + [ + { + "name": "WeatherTypedDict", + "id": "2", + "args": WEATHER_DATA, + } + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + + agent = create_react_agent(model, [get_weather], response_format=WeatherTypedDict) + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + assert response["structured_response"] == EXPECTED_WEATHER_DICT + assert len(response["messages"]) == 5 + + def test_json_schema(self) -> None: + """Test response_format as JSON schema.""" + tool_calls = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + [ + { + "name": "weather_schema", + "id": "2", + "args": WEATHER_DATA, + } + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + + agent = create_react_agent(model, [get_weather], response_format=weather_json_schema) + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + assert response["structured_response"] == EXPECTED_WEATHER_DICT + assert len(response["messages"]) == 5 + + +class TestResponseFormatAsToolStrategy: + def test_pydantic_model(self) -> None: + """Test response_format as ToolStrategy with Pydantic model.""" + tool_calls = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + [ + { + "name": "WeatherBaseModel", + "id": "2", + "args": WEATHER_DATA, + } + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + + agent = create_react_agent( + model, [get_weather], response_format=ToolStrategy(WeatherBaseModel) + ) + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC + assert len(response["messages"]) == 5 + + def test_dataclass(self) -> None: + """Test response_format as ToolStrategy with dataclass.""" + tool_calls = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + [ + { + "name": "WeatherDataclass", + "id": "2", + "args": WEATHER_DATA, + } + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + + agent = create_react_agent( + model, [get_weather], response_format=ToolStrategy(WeatherDataclass) + ) + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + assert response["structured_response"] == EXPECTED_WEATHER_DATACLASS + assert len(response["messages"]) == 5 + + def test_typed_dict(self) -> None: + """Test response_format as ToolStrategy with TypedDict.""" + tool_calls = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + [ + { + "name": "WeatherTypedDict", + "id": "2", + "args": WEATHER_DATA, + } + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + + agent = create_react_agent( + model, [get_weather], response_format=ToolStrategy(WeatherTypedDict) + ) + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + assert response["structured_response"] == EXPECTED_WEATHER_DICT + assert len(response["messages"]) == 5 + + def test_json_schema(self) -> None: + """Test response_format as ToolStrategy with JSON schema.""" + tool_calls = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + [ + { + "name": "weather_schema", + "id": "2", + "args": WEATHER_DATA, + } + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + + agent = create_react_agent( + model, [get_weather], response_format=ToolStrategy(weather_json_schema) + ) + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + assert response["structured_response"] == EXPECTED_WEATHER_DICT + assert len(response["messages"]) == 5 + + def test_union_of_json_schemas(self) -> None: + """Test response_format as ToolStrategy with union of JSON schemas.""" + tool_calls = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + [ + { + "name": "weather_schema", + "id": "2", + "args": WEATHER_DATA, + } + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + + agent = create_react_agent( + model, + [get_weather, get_location], + response_format=ToolStrategy({"oneOf": [weather_json_schema, location_json_schema]}), + ) + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + assert response["structured_response"] == EXPECTED_WEATHER_DICT + assert len(response["messages"]) == 5 + + # Test with LocationResponse + tool_calls_location = [ + [{"args": {}, "id": "1", "name": "get_location"}], + [ + { + "name": "location_schema", + "id": "2", + "args": LOCATION_DATA, + } + ], + ] + + model_location = FakeToolCallingModel(tool_calls=tool_calls_location) + + agent_location = create_react_agent( + model_location, + [get_weather, get_location], + response_format=ToolStrategy({"oneOf": [weather_json_schema, location_json_schema]}), + ) + response_location = agent_location.invoke({"messages": [HumanMessage("Where am I?")]}) + + assert response_location["structured_response"] == EXPECTED_LOCATION_DICT + assert len(response_location["messages"]) == 5 + + def test_union_of_types(self) -> None: + """Test response_format as ToolStrategy with Union of various types.""" + # Test with WeatherBaseModel + tool_calls = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + [ + { + "name": "WeatherBaseModel", + "id": "2", + "args": WEATHER_DATA, + } + ], + ] + + model = FakeToolCallingModel[Union[WeatherBaseModel, LocationResponse]]( + tool_calls=tool_calls + ) + + agent = create_react_agent( + model, + [get_weather, get_location], + response_format=ToolStrategy(Union[WeatherBaseModel, LocationResponse]), + ) + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC + assert len(response["messages"]) == 5 + + # Test with LocationResponse + tool_calls_location = [ + [{"args": {}, "id": "1", "name": "get_location"}], + [ + { + "name": "LocationResponse", + "id": "2", + "args": LOCATION_DATA, + } + ], + ] + + model_location = FakeToolCallingModel(tool_calls=tool_calls_location) + + agent_location = create_react_agent( + model_location, + [get_weather, get_location], + response_format=ToolStrategy(Union[WeatherBaseModel, LocationResponse]), + ) + response_location = agent_location.invoke({"messages": [HumanMessage("Where am I?")]}) + + assert response_location["structured_response"] == EXPECTED_LOCATION + assert len(response_location["messages"]) == 5 + + def test_multiple_structured_outputs_error_without_retry(self) -> None: + """Test that MultipleStructuredOutputsError is raised when model returns multiple structured tool calls without retry.""" + tool_calls = [ + [ + { + "name": "WeatherBaseModel", + "id": "1", + "args": WEATHER_DATA, + }, + { + "name": "LocationResponse", + "id": "2", + "args": LOCATION_DATA, + }, + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + + agent = create_react_agent( + model, + [], + response_format=ToolStrategy( + Union[WeatherBaseModel, LocationResponse], + handle_errors=False, + ), + ) + + with pytest.raises( + MultipleStructuredOutputsError, + match=".*WeatherBaseModel.*LocationResponse.*", + ): + agent.invoke({"messages": [HumanMessage("Give me weather and location")]}) + + def test_multiple_structured_outputs_with_retry(self) -> None: + """Test that retry handles multiple structured output tool calls.""" + tool_calls = [ + [ + { + "name": "WeatherBaseModel", + "id": "1", + "args": WEATHER_DATA, + }, + { + "name": "LocationResponse", + "id": "2", + "args": LOCATION_DATA, + }, + ], + [ + { + "name": "WeatherBaseModel", + "id": "3", + "args": WEATHER_DATA, + }, + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + + agent = create_react_agent( + model, + [], + response_format=ToolStrategy( + Union[WeatherBaseModel, LocationResponse], + handle_errors=True, + ), + ) + + response = agent.invoke({"messages": [HumanMessage("Give me weather")]}) + + # HumanMessage, AIMessage, ToolMessage, ToolMessage, AI, ToolMessage + assert len(response["messages"]) == 6 + assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC + + def test_structured_output_parsing_error_without_retry(self) -> None: + """Test that StructuredOutputParsingError is raised when tool args fail to parse without retry.""" + tool_calls = [ + [ + { + "name": "WeatherBaseModel", + "id": "1", + "args": {"invalid": "data"}, + }, + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + + agent = create_react_agent( + model, + [], + response_format=ToolStrategy( + WeatherBaseModel, + handle_errors=False, + ), + ) + + with pytest.raises( + StructuredOutputValidationError, + match=".*WeatherBaseModel.*", + ): + agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + def test_structured_output_parsing_error_with_retry(self) -> None: + """Test that retry handles parsing errors for structured output.""" + tool_calls = [ + [ + { + "name": "WeatherBaseModel", + "id": "1", + "args": {"invalid": "data"}, + }, + ], + [ + { + "name": "WeatherBaseModel", + "id": "2", + "args": WEATHER_DATA, + }, + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + + agent = create_react_agent( + model, + [], + response_format=ToolStrategy( + WeatherBaseModel, + handle_errors=(StructuredOutputValidationError,), + ), + ) + + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + # HumanMessage, AIMessage, ToolMessage, AIMessage, ToolMessage + assert len(response["messages"]) == 5 + assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC + + def test_retry_with_custom_function(self) -> None: + """Test retry with custom message generation.""" + tool_calls = [ + [ + { + "name": "WeatherBaseModel", + "id": "1", + "args": WEATHER_DATA, + }, + { + "name": "LocationResponse", + "id": "2", + "args": LOCATION_DATA, + }, + ], + [ + { + "name": "WeatherBaseModel", + "id": "3", + "args": WEATHER_DATA, + }, + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + + def custom_message(exception: Exception) -> str: + if isinstance(exception, MultipleStructuredOutputsError): + return "Custom error: Multiple outputs not allowed" + return "Custom error" + + agent = create_react_agent( + model, + [], + response_format=ToolStrategy( + Union[WeatherBaseModel, LocationResponse], + handle_errors=custom_message, + ), + ) + + response = agent.invoke({"messages": [HumanMessage("Give me weather")]}) + + # HumanMessage, AIMessage, ToolMessage, ToolMessage, AI, ToolMessage + assert len(response["messages"]) == 6 + assert response["messages"][2].content == "Custom error: Multiple outputs not allowed" + assert response["messages"][3].content == "Custom error: Multiple outputs not allowed" + assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC + + def test_retry_with_custom_string_message(self) -> None: + """Test retry with custom static string message.""" + tool_calls = [ + [ + { + "name": "WeatherBaseModel", + "id": "1", + "args": {"invalid": "data"}, + }, + ], + [ + { + "name": "WeatherBaseModel", + "id": "2", + "args": WEATHER_DATA, + }, + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + + agent = create_react_agent( + model, + [], + response_format=ToolStrategy( + WeatherBaseModel, + handle_errors="Please provide valid weather data with temperature and condition.", + ), + ) + + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + assert len(response["messages"]) == 5 + assert ( + response["messages"][2].content + == "Please provide valid weather data with temperature and condition." + ) + assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC + + +class TestResponseFormatAsProviderStrategy: + def test_pydantic_model(self) -> None: + """Test response_format as ProviderStrategy with Pydantic model.""" + tool_calls = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + ] + + model = FakeToolCallingModel[WeatherBaseModel]( + tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC + ) + + agent = create_react_agent( + model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel) + ) + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC + assert len(response["messages"]) == 4 + + def test_dataclass(self) -> None: + """Test response_format as ProviderStrategy with dataclass.""" + tool_calls = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + ] + + model = FakeToolCallingModel[WeatherDataclass]( + tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DATACLASS + ) + + agent = create_react_agent( + model, [get_weather], response_format=ProviderStrategy(WeatherDataclass) + ) + response = agent.invoke( + {"messages": [HumanMessage("What's the weather?")]}, + ) + + assert response["structured_response"] == EXPECTED_WEATHER_DATACLASS + assert len(response["messages"]) == 4 + + def test_typed_dict(self) -> None: + """Test response_format as ProviderStrategy with TypedDict.""" + tool_calls = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + ] + + model = FakeToolCallingModel[WeatherTypedDict]( + tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT + ) + + agent = create_react_agent( + model, [get_weather], response_format=ProviderStrategy(WeatherTypedDict) + ) + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + assert response["structured_response"] == EXPECTED_WEATHER_DICT + assert len(response["messages"]) == 4 + + def test_json_schema(self) -> None: + """Test response_format as ProviderStrategy with JSON schema.""" + tool_calls = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + ] + + model = FakeToolCallingModel[dict]( + tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT + ) + + agent = create_react_agent( + model, [get_weather], response_format=ProviderStrategy(weather_json_schema) + ) + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + assert response["structured_response"] == EXPECTED_WEATHER_DICT + assert len(response["messages"]) == 4 + + +def test_union_of_types() -> None: + """Test response_format as ProviderStrategy with Union (if supported).""" + tool_calls = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + [ + { + "name": "WeatherBaseModel", + "id": "2", + "args": WEATHER_DATA, + } + ], + ] + + model = FakeToolCallingModel[Union[WeatherBaseModel, LocationResponse]]( + tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC + ) + + agent = create_react_agent( + model, + [get_weather, get_location], + response_format=ToolStrategy(Union[WeatherBaseModel, LocationResponse]), + ) + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC + assert len(response["messages"]) == 5 diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_responses.py b/libs/langchain_v1/tests/unit_tests/agents/test_responses.py new file mode 100644 index 00000000000..6f21e4b21c5 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/test_responses.py @@ -0,0 +1,140 @@ +"""Unit tests for langgraph.prebuilt.responses module.""" + +import pytest + +# Skip this test since langgraph.prebuilt.responses is not available +pytest.skip("langgraph.prebuilt.responses not available", allow_module_level=True) + + +class _TestModel(BaseModel): + """A test model for structured output.""" + + name: str + age: int + email: str = "default@example.com" + + +class CustomModel(BaseModel): + """Custom model with a custom docstring.""" + + value: float + description: str + + +class EmptyDocModel(BaseModel): + # No custom docstring, should have no description in tool + data: str + + +class TestUsingToolStrategy: + """Test UsingToolStrategy dataclass.""" + + def test_basic_creation(self) -> None: + """Test basic UsingToolStrategy creation.""" + strategy = ToolStrategy(schema=_TestModel) + assert strategy.schema == _TestModel + assert strategy.tool_message_content is None + assert len(strategy.schema_specs) == 1 + + def test_multiple_schemas(self) -> None: + """Test UsingToolStrategy with multiple schemas.""" + strategy = ToolStrategy(schema=Union[_TestModel, CustomModel]) + assert len(strategy.schema_specs) == 2 + assert strategy.schema_specs[0].schema == _TestModel + assert strategy.schema_specs[1].schema == CustomModel + + def test_schema_with_tool_message_content(self) -> None: + """Test UsingToolStrategy with tool message content.""" + strategy = ToolStrategy(schema=_TestModel, tool_message_content="custom message") + assert strategy.schema == _TestModel + assert strategy.tool_message_content == "custom message" + assert len(strategy.schema_specs) == 1 + + +class TestOutputToolBinding: + """Test OutputToolBinding dataclass and its methods.""" + + def test_from_schema_spec_basic(self) -> None: + """Test basic OutputToolBinding creation from SchemaSpec.""" + schema_spec = _SchemaSpec(schema=_TestModel) + tool_binding = OutputToolBinding.from_schema_spec(schema_spec) + + assert tool_binding.schema == _TestModel + assert tool_binding.schema_kind == "pydantic" + assert tool_binding.tool is not None + assert tool_binding.tool.name == "_TestModel" + + def test_from_schema_spec_with_custom_name(self) -> None: + """Test OutputToolBinding creation with custom name.""" + schema_spec = _SchemaSpec(schema=_TestModel, name="custom_tool_name") + tool_binding = OutputToolBinding.from_schema_spec(schema_spec) + assert tool_binding.tool.name == "custom_tool_name" + + def test_from_schema_spec_with_custom_description(self) -> None: + """Test OutputToolBinding creation with custom description.""" + schema_spec = _SchemaSpec(schema=_TestModel, description="Custom tool description") + tool_binding = OutputToolBinding.from_schema_spec(schema_spec) + + assert tool_binding.tool.description == "Custom tool description" + + def test_from_schema_spec_with_model_docstring(self) -> None: + """Test OutputToolBinding creation using model docstring as description.""" + schema_spec = _SchemaSpec(schema=CustomModel) + tool_binding = OutputToolBinding.from_schema_spec(schema_spec) + + assert tool_binding.tool.description == "Custom model with a custom docstring." + + @pytest.mark.skip(reason="Need to fix bug in langchain-core for inheritance of doc-strings.") + def test_from_schema_spec_empty_docstring(self) -> None: + """Test OutputToolBinding creation with model that has default docstring.""" + + # Create a model with the same docstring as BaseModel + class DefaultDocModel(BaseModel): + # This should have the same docstring as BaseModel + pass + + schema_spec = _SchemaSpec(schema=DefaultDocModel) + tool_binding = OutputToolBinding.from_schema_spec(schema_spec) + + # Should use empty description when model has default BaseModel docstring + assert tool_binding.tool.description == "" + + def test_parse_payload_pydantic_success(self) -> None: + """Test successful parsing for Pydantic model.""" + schema_spec = _SchemaSpec(schema=_TestModel) + tool_binding = OutputToolBinding.from_schema_spec(schema_spec) + + tool_args = {"name": "John", "age": 30} + result = tool_binding.parse(tool_args) + + assert isinstance(result, _TestModel) + assert result.name == "John" + assert result.age == 30 + assert result.email == "default@example.com" # default value + + def test_parse_payload_pydantic_validation_error(self) -> None: + """Test parsing failure for invalid Pydantic data.""" + schema_spec = _SchemaSpec(schema=_TestModel) + tool_binding = OutputToolBinding.from_schema_spec(schema_spec) + + # Missing required field 'name' + tool_args = {"age": 30} + + with pytest.raises(ValueError, match="Failed to parse data to _TestModel"): + tool_binding.parse(tool_args) + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_empty_schemas_list(self) -> None: + """Test UsingToolStrategy with empty schemas list.""" + strategy = ToolStrategy(EmptyDocModel) + assert len(strategy.schema_specs) == 1 + + @pytest.mark.skip(reason="Need to fix bug in langchain-core for inheritance of doc-strings.") + def test_base_model_doc_constant(self) -> None: + """Test that BASE_MODEL_DOC constant is set correctly.""" + binding = OutputToolBinding.from_schema_spec(_SchemaSpec(EmptyDocModel)) + assert binding.tool.name == "EmptyDocModel" + assert binding.tool.description[:5] == "" # Should be empty for default docstring diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_responses_spec.py b/libs/langchain_v1/tests/unit_tests/agents/test_responses_spec.py new file mode 100644 index 00000000000..ed1d67c3fc2 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/test_responses_spec.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import pytest + +# Skip this test since langgraph.prebuilt.responses is not available +pytest.skip("langgraph.prebuilt.responses not available", allow_module_level=True) + +try: + from langchain_openai import ChatOpenAI +except ImportError: + skip_openai_integration_tests = True +else: + skip_openai_integration_tests = False + +AGENT_PROMPT = "You are an HR assistant." + + +class ToolCalls(BaseSchema): + get_employee_role: int + get_employee_department: int + + +class AssertionByInvocation(BaseSchema): + prompt: str + tools_with_expected_calls: ToolCalls + expected_last_message: str + expected_structured_response: Optional[Dict[str, Any]] + llm_request_count: int + + +class TestCase(BaseSchema): + name: str + response_format: Union[Dict[str, Any], List[Dict[str, Any]]] + assertions_by_invocation: List[AssertionByInvocation] + + +class Employee(BaseModel): + name: str + role: str + department: str + + +EMPLOYEES: list[Employee] = [ + Employee(name="Sabine", role="Developer", department="IT"), + Employee(name="Henrik", role="Product Manager", department="IT"), + Employee(name="Jessica", role="HR", department="People"), +] + +TEST_CASES = load_spec("responses", as_model=TestCase) + + +def _make_tool(fn, *, name: str, description: str): + mock = MagicMock(side_effect=lambda *, name: fn(name=name)) + InputModel = create_model(f"{name}_input", name=(str, ...)) + + @tool(name, description=description, args_schema=InputModel) + def _wrapped(name: str): + return mock(name=name) + + return {"tool": _wrapped, "mock": mock} + + +@pytest.mark.skipif(skip_openai_integration_tests, reason="OpenAI integration tests are disabled.") +@pytest.mark.parametrize("case", TEST_CASES, ids=[c.name for c in TEST_CASES]) +def test_responses_integration_matrix(case: TestCase) -> None: + if case.name == "asking for information that does not fit into the response format": + pytest.xfail( + "currently failing due to undefined behavior when model cannot conform to any of the structured response formats." + ) + + def get_employee_role(*, name: str) -> Optional[str]: + for e in EMPLOYEES: + if e.name == name: + return e.role + return None + + def get_employee_department(*, name: str) -> Optional[str]: + for e in EMPLOYEES: + if e.name == name: + return e.department + return None + + role_tool = _make_tool( + get_employee_role, + name="get_employee_role", + description="Get the employee role by name", + ) + dept_tool = _make_tool( + get_employee_department, + name="get_employee_department", + description="Get the employee department by name", + ) + + response_format_spec = case.response_format + if isinstance(response_format_spec, dict): + response_format_spec = [response_format_spec] + # Unwrap nested schema objects + response_format_spec = [item.get("schema", item) for item in response_format_spec] + if len(response_format_spec) == 1: + tool_output = ToolStrategy(response_format_spec[0]) + else: + tool_output = ToolStrategy({"oneOf": response_format_spec}) + + llm_request_count = 0 + + for assertion in case.assertions_by_invocation: + + def on_request(request: httpx.Request) -> None: + nonlocal llm_request_count + llm_request_count += 1 + + http_client = httpx.Client( + event_hooks={"request": [on_request]}, + ) + + model = ChatOpenAI( + model="gpt-4o", + temperature=0, + http_client=http_client, + ) + + agent = create_react_agent( + model, + tools=[role_tool["tool"], dept_tool["tool"]], + prompt=AGENT_PROMPT, + response_format=tool_output, + ) + + result = agent.invoke({"messages": [HumanMessage(assertion.prompt)]}) + + # Count tool calls + assert role_tool["mock"].call_count == assertion.tools_with_expected_calls.get_employee_role + assert ( + dept_tool["mock"].call_count + == assertion.tools_with_expected_calls.get_employee_department + ) + + # Count LLM calls + assert llm_request_count == assertion.llm_request_count + + # Check last message content + last_message = result["messages"][-1] + assert last_message.content == assertion.expected_last_message + + # Check structured response + structured_response_json = result["structured_response"] + assert structured_response_json == assertion.expected_structured_response diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_return_direct_spec.py b/libs/langchain_v1/tests/unit_tests/agents/test_return_direct_spec.py new file mode 100644 index 00000000000..d87c22bd283 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/test_return_direct_spec.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import pytest + +# Skip this test since langgraph.prebuilt.responses is not available +pytest.skip("langgraph.prebuilt.responses not available", allow_module_level=True) + +try: + from langchain_openai import ChatOpenAI +except ImportError: + skip_openai_integration_tests = True +else: + skip_openai_integration_tests = False + +AGENT_PROMPT = """ +You are a strict polling bot. + +- Only use the "poll_job" tool until it returns { status: "succeeded" }. +- If status is "pending", call the tool again. Do not produce a final answer. +- When it is "succeeded", return exactly: "Attempts: " with no extra text. +""" + + +class TestCase(BaseSchema): + name: str + return_direct: bool + response_format: Optional[Dict[str, Any]] + expected_tool_calls: int + expected_last_message: str + expected_structured_response: Optional[Dict[str, Any]] + + +TEST_CASES = load_spec("return_direct", as_model=TestCase) + + +def _make_tool(return_direct: bool): + attempts = 0 + + def _side_effect(): + nonlocal attempts + attempts += 1 + return { + "status": "succeeded" if attempts >= 10 else "pending", + "attempts": attempts, + } + + mock = MagicMock(side_effect=_side_effect) + + @tool( + "pollJob", + description=( + "Check the status of a long-running job. " + "Returns { status: 'pending' | 'succeeded', attempts: number }." + ), + return_direct=return_direct, + ) + def _wrapped(): + return mock() + + return {"tool": _wrapped, "mock": mock} + + +@pytest.mark.skipif(skip_openai_integration_tests, reason="OpenAI integration tests are disabled.") +@pytest.mark.parametrize("case", TEST_CASES, ids=[c.name for c in TEST_CASES]) +def test_return_direct_integration_matrix(case: TestCase) -> None: + poll_tool = _make_tool(case.return_direct) + + model = ChatOpenAI( + model="gpt-4o", + temperature=0, + ) + + if case.response_format: + agent = create_react_agent( + model, + tools=[poll_tool["tool"]], + prompt=AGENT_PROMPT, + response_format=ToolStrategy(case.response_format), + ) + else: + agent = create_react_agent( + model, + tools=[poll_tool["tool"]], + prompt=AGENT_PROMPT, + ) + + result = agent.invoke( + { + "messages": [ + HumanMessage("Poll the job until it's done and tell me how many attempts it took.") + ] + } + ) + + # Count tool calls + assert poll_tool["mock"].call_count == case.expected_tool_calls + + # Check last message content + last_message = result["messages"][-1] + assert last_message.content == case.expected_last_message + + # Check structured response + if case.expected_structured_response is not None: + structured_response_json = result["structured_response"] + assert structured_response_json == case.expected_structured_response + else: + assert "structured_response" not in result diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py b/libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py new file mode 100644 index 00000000000..8581cf3c7cf --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py @@ -0,0 +1,1482 @@ +import contextlib +import dataclasses +import json +from functools import partial +from typing import ( + Annotated, + Any, + NoReturn, + TypeVar, + Union, +) + +import pytest +from langchain_core.messages import ( + AIMessage, + AnyMessage, + HumanMessage, + RemoveMessage, + ToolCall, + ToolMessage, +) +from langchain_core.tools import BaseTool, ToolException +from langchain_core.tools import tool as dec_tool +from langgraph.config import get_stream_writer +from langgraph.errors import GraphBubbleUp, GraphInterrupt +from langgraph.graph import START, MessagesState, StateGraph +from langgraph.graph.message import REMOVE_ALL_MESSAGES, add_messages +from langgraph.store.base import BaseStore +from langgraph.store.memory import InMemoryStore +from langgraph.types import Command, Send +from pydantic import BaseModel +from pydantic.v1 import BaseModel as BaseModelV1 +from typing_extensions import TypedDict + +from langchain.agents import ( + ToolNode, +) +from langchain.agents.tool_node import ( + TOOL_CALL_ERROR_TEMPLATE, + InjectedState, + InjectedStore, + ToolInvocationError, + tools_condition, +) + +from .messages import _AnyIdHumanMessage, _AnyIdToolMessage +from .model import FakeToolCallingModel + +pytestmark = pytest.mark.anyio + + +def tool1(some_val: int, some_other_val: str) -> str: + """Tool 1 docstring.""" + if some_val == 0: + msg = "Test error" + raise ValueError(msg) + return f"{some_val} - {some_other_val}" + + +async def tool2(some_val: int, some_other_val: str) -> str: + """Tool 2 docstring.""" + if some_val == 0: + msg = "Test error" + raise ToolException(msg) + return f"tool2: {some_val} - {some_other_val}" + + +async def tool3(some_val: int, some_other_val: str) -> str: + """Tool 3 docstring.""" + return [ + {"key_1": some_val, "key_2": "foo"}, + {"key_1": some_other_val, "key_2": "baz"}, + ] + + +async def tool4(some_val: int, some_other_val: str) -> str: + """Tool 4 docstring.""" + return [ + {"type": "image_url", "image_url": {"url": "abdc"}}, + ] + + +@dec_tool +def tool5(some_val: int) -> NoReturn: + """Tool 5 docstring.""" + msg = "Test error" + raise ToolException(msg) + + +tool5.handle_tool_error = "foo" + + +async def test_tool_node() -> None: + """Test tool node.""" + result = ToolNode([tool1]).invoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool1", + "args": {"some_val": 1, "some_other_val": "foo"}, + "id": "some 0", + } + ], + ) + ] + } + ) + + tool_message: ToolMessage = result["messages"][-1] + assert tool_message.type == "tool" + assert tool_message.content == "1 - foo" + assert tool_message.tool_call_id == "some 0" + + result2 = await ToolNode([tool2]).ainvoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool2", + "args": {"some_val": 2, "some_other_val": "bar"}, + "id": "some 1", + } + ], + ) + ] + } + ) + + tool_message: ToolMessage = result2["messages"][-1] + assert tool_message.type == "tool" + assert tool_message.content == "tool2: 2 - bar" + + # list of dicts tool content + result3 = await ToolNode([tool3]).ainvoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool3", + "args": {"some_val": 2, "some_other_val": "bar"}, + "id": "some 2", + } + ], + ) + ] + } + ) + tool_message: ToolMessage = result3["messages"][-1] + assert tool_message.type == "tool" + assert ( + tool_message.content == '[{"key_1": 2, "key_2": "foo"}, {"key_1": "bar", "key_2": "baz"}]' + ) + assert tool_message.tool_call_id == "some 2" + + # list of content blocks tool content + result4 = await ToolNode([tool4]).ainvoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool4", + "args": {"some_val": 2, "some_other_val": "bar"}, + "id": "some 3", + } + ], + ) + ] + } + ) + tool_message: ToolMessage = result4["messages"][-1] + assert tool_message.type == "tool" + assert tool_message.content == [{"type": "image_url", "image_url": {"url": "abdc"}}] + assert tool_message.tool_call_id == "some 3" + + +async def test_tool_node_tool_call_input() -> None: + # Single tool call + tool_call_1 = { + "name": "tool1", + "args": {"some_val": 1, "some_other_val": "foo"}, + "id": "some 0", + "type": "tool_call", + } + result = ToolNode([tool1]).invoke([tool_call_1]) + assert result["messages"] == [ + ToolMessage(content="1 - foo", tool_call_id="some 0", name="tool1"), + ] + + # Multiple tool calls + tool_call_2 = { + "name": "tool1", + "args": {"some_val": 2, "some_other_val": "bar"}, + "id": "some 1", + "type": "tool_call", + } + result = ToolNode([tool1]).invoke([tool_call_1, tool_call_2]) + assert result["messages"] == [ + ToolMessage(content="1 - foo", tool_call_id="some 0", name="tool1"), + ToolMessage(content="2 - bar", tool_call_id="some 1", name="tool1"), + ] + + # Test with unknown tool + tool_call_3 = tool_call_1.copy() + tool_call_3["name"] = "tool2" + result = ToolNode([tool1]).invoke([tool_call_1, tool_call_3]) + assert result["messages"] == [ + ToolMessage(content="1 - foo", tool_call_id="some 0", name="tool1"), + ToolMessage( + content="Error: tool2 is not a valid tool, try one of [tool1].", + name="tool2", + tool_call_id="some 0", + status="error", + ), + ] + + +def test_tool_node_error_handling_default_invocation() -> None: + tn = ToolNode([tool1]) + result = tn.invoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool1", + "args": {"invalid": 0, "args": "foo"}, + "id": "some id", + }, + ], + ) + ] + } + ) + + assert all(m.type == "tool" for m in result["messages"]) + assert all(m.status == "error" for m in result["messages"]) + assert ( + "Error invoking tool 'tool1' with kwargs {'invalid': 0, 'args': 'foo'} with error:\n" + in result["messages"][0].content + ) + + +def test_tool_node_error_handling_default_exception() -> None: + tn = ToolNode([tool1]) + with pytest.raises(ValueError): + tn.invoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool1", + "args": {"some_val": 0, "some_other_val": "foo"}, + "id": "some id", + }, + ], + ) + ] + } + ) + + +async def test_tool_node_error_handling() -> None: + def handle_all(e: Union[ValueError, ToolException, ToolInvocationError]): + return TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e)) + + # test catching all exceptions, via: + # - handle_tool_errors = True + # - passing a tuple of all exceptions + # - passing a callable with all exceptions in the signature + for handle_tool_errors in ( + True, + (ValueError, ToolException, ToolInvocationError), + handle_all, + ): + result_error = await ToolNode( + [tool1, tool2, tool3], handle_tool_errors=handle_tool_errors + ).ainvoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool1", + "args": {"some_val": 0, "some_other_val": "foo"}, + "id": "some id", + }, + { + "name": "tool2", + "args": {"some_val": 0, "some_other_val": "bar"}, + "id": "some other id", + }, + { + "name": "tool3", + "args": {"some_val": 0}, + "id": "another id", + }, + ], + ) + ] + } + ) + + assert all(m.type == "tool" for m in result_error["messages"]) + assert all(m.status == "error" for m in result_error["messages"]) + assert ( + result_error["messages"][0].content + == f"Error: {ValueError('Test error')!r}\n Please fix your mistakes." + ) + assert ( + result_error["messages"][1].content + == f"Error: {ToolException('Test error')!r}\n Please fix your mistakes." + ) + assert ( + "ValidationError" in result_error["messages"][2].content + or "validation error" in result_error["messages"][2].content + ) + + assert result_error["messages"][0].tool_call_id == "some id" + assert result_error["messages"][1].tool_call_id == "some other id" + assert result_error["messages"][2].tool_call_id == "another id" + + +async def test_tool_node_error_handling_callable() -> None: + def handle_value_error(e: ValueError) -> str: + return "Value error" + + def handle_tool_exception(e: ToolException) -> str: + return "Tool exception" + + for handle_tool_errors in ("Value error", handle_value_error): + result_error = await ToolNode([tool1], handle_tool_errors=handle_tool_errors).ainvoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool1", + "args": {"some_val": 0, "some_other_val": "foo"}, + "id": "some id", + }, + ], + ) + ] + } + ) + tool_message: ToolMessage = result_error["messages"][-1] + assert tool_message.type == "tool" + assert tool_message.status == "error" + assert tool_message.content == "Value error" + + # test raising for an unhandled exception, via: + # - passing a tuple of all exceptions + # - passing a callable with all exceptions in the signature + for handle_tool_errors in ((ValueError,), handle_value_error): + with pytest.raises(ToolException) as exc_info: + await ToolNode([tool1, tool2], handle_tool_errors=handle_tool_errors).ainvoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool1", + "args": {"some_val": 0, "some_other_val": "foo"}, + "id": "some id", + }, + { + "name": "tool2", + "args": {"some_val": 0, "some_other_val": "bar"}, + "id": "some other id", + }, + ], + ) + ] + } + ) + assert str(exc_info.value) == "Test error" + + for handle_tool_errors in ((ToolException,), handle_tool_exception): + with pytest.raises(ValueError) as exc_info: + await ToolNode([tool1, tool2], handle_tool_errors=handle_tool_errors).ainvoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool1", + "args": {"some_val": 0, "some_other_val": "foo"}, + "id": "some id", + }, + { + "name": "tool2", + "args": {"some_val": 0, "some_other_val": "bar"}, + "id": "some other id", + }, + ], + ) + ] + } + ) + assert str(exc_info.value) == "Test error" + + +async def test_tool_node_handle_tool_errors_false() -> None: + with pytest.raises(ValueError) as exc_info: + ToolNode([tool1], handle_tool_errors=False).invoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool1", + "args": {"some_val": 0, "some_other_val": "foo"}, + "id": "some id", + } + ], + ) + ] + } + ) + + assert str(exc_info.value) == "Test error" + + with pytest.raises(ToolException): + await ToolNode([tool2], handle_tool_errors=False).ainvoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool2", + "args": {"some_val": 0, "some_other_val": "bar"}, + "id": "some id", + } + ], + ) + ] + } + ) + + assert str(exc_info.value) == "Test error" + + # test validation errors get raised if handle_tool_errors is False + with pytest.raises(ToolInvocationError): + ToolNode([tool1], handle_tool_errors=False).invoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool1", + "args": {"some_val": 0}, + "id": "some id", + } + ], + ) + ] + } + ) + + +def test_tool_node_individual_tool_error_handling() -> None: + # test error handling on individual tools (and that it overrides overall error handling!) + result_individual_tool_error_handler = ToolNode([tool5], handle_tool_errors="bar").invoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool5", + "args": {"some_val": 0}, + "id": "some 0", + } + ], + ) + ] + } + ) + + tool_message: ToolMessage = result_individual_tool_error_handler["messages"][-1] + assert tool_message.type == "tool" + assert tool_message.status == "error" + assert tool_message.content == "foo" + assert tool_message.tool_call_id == "some 0" + + +def test_tool_node_incorrect_tool_name() -> None: + result_incorrect_name = ToolNode([tool1, tool2]).invoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool3", + "args": {"some_val": 1, "some_other_val": "foo"}, + "id": "some 0", + } + ], + ) + ] + } + ) + + tool_message: ToolMessage = result_incorrect_name["messages"][-1] + assert tool_message.type == "tool" + assert tool_message.status == "error" + assert tool_message.content == "Error: tool3 is not a valid tool, try one of [tool1, tool2]." + assert tool_message.tool_call_id == "some 0" + + +def test_tool_node_node_interrupt() -> None: + def tool_interrupt(some_val: int) -> None: + """Tool docstring.""" + msg = "foo" + raise GraphBubbleUp(msg) + + def handle(e: GraphInterrupt) -> str: + return "handled" + + for handle_tool_errors in (True, (GraphBubbleUp,), "handled", handle, False): + node = ToolNode([tool_interrupt], handle_tool_errors=handle_tool_errors) + with pytest.raises(GraphBubbleUp) as exc_info: + node.invoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool_interrupt", + "args": {"some_val": 0}, + "id": "some 0", + } + ], + ) + ] + } + ) + assert exc_info.value == "foo" + + +@pytest.mark.parametrize("input_type", ["dict", "tool_calls"]) +async def test_tool_node_command(input_type: str) -> None: + from langchain_core.tools.base import InjectedToolCallId + + @dec_tool + def transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): + """Transfer to Bob""" + return Command( + update={ + "messages": [ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)] + }, + goto="bob", + graph=Command.PARENT, + ) + + @dec_tool + async def async_transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): + """Transfer to Bob""" + return Command( + update={ + "messages": [ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)] + }, + goto="bob", + graph=Command.PARENT, + ) + + class CustomToolSchema(BaseModel): + tool_call_id: Annotated[str, InjectedToolCallId] + + class MyCustomTool(BaseTool): + def _run(*args: Any, **kwargs: Any): + return Command( + update={ + "messages": [ + ToolMessage( + content="Transferred to Bob", + tool_call_id=kwargs["tool_call_id"], + ) + ] + }, + goto="bob", + graph=Command.PARENT, + ) + + async def _arun(*args: Any, **kwargs: Any): + return Command( + update={ + "messages": [ + ToolMessage( + content="Transferred to Bob", + tool_call_id=kwargs["tool_call_id"], + ) + ] + }, + goto="bob", + graph=Command.PARENT, + ) + + custom_tool = MyCustomTool( + name="custom_transfer_to_bob", + description="Transfer to bob", + args_schema=CustomToolSchema, + ) + async_custom_tool = MyCustomTool( + name="async_custom_transfer_to_bob", + description="Transfer to bob", + args_schema=CustomToolSchema, + ) + + # test mixing regular tools and tools returning commands + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + tool_calls = [ + {"args": {"a": 1, "b": 2}, "id": "1", "name": "add", "type": "tool_call"}, + {"args": {}, "id": "2", "name": "transfer_to_bob", "type": "tool_call"}, + ] + if input_type == "dict": + input_ = {"messages": [AIMessage("", tool_calls=tool_calls)]} + elif input_type == "tool_calls": + input_ = tool_calls + result = ToolNode([add, transfer_to_bob]).invoke(input_) + + assert result == [ + { + "messages": [ + ToolMessage( + content="3", + tool_call_id="1", + name="add", + ) + ] + }, + Command( + update={ + "messages": [ + ToolMessage( + content="Transferred to Bob", + tool_call_id="2", + name="transfer_to_bob", + ) + ] + }, + goto="bob", + graph=Command.PARENT, + ), + ] + + # test tools returning commands + + # test sync tools + for tool in [transfer_to_bob, custom_tool]: + result = ToolNode([tool]).invoke( + {"messages": [AIMessage("", tool_calls=[{"args": {}, "id": "1", "name": tool.name}])]} + ) + assert result == [ + Command( + update={ + "messages": [ + ToolMessage( + content="Transferred to Bob", + tool_call_id="1", + name=tool.name, + ) + ] + }, + goto="bob", + graph=Command.PARENT, + ) + ] + + # test async tools + for tool in [async_transfer_to_bob, async_custom_tool]: + result = await ToolNode([tool]).ainvoke( + {"messages": [AIMessage("", tool_calls=[{"args": {}, "id": "1", "name": tool.name}])]} + ) + assert result == [ + Command( + update={ + "messages": [ + ToolMessage( + content="Transferred to Bob", + tool_call_id="1", + name=tool.name, + ) + ] + }, + goto="bob", + graph=Command.PARENT, + ) + ] + + # test multiple commands + result = ToolNode([transfer_to_bob, custom_tool]).invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[ + {"args": {}, "id": "1", "name": "transfer_to_bob"}, + {"args": {}, "id": "2", "name": "custom_transfer_to_bob"}, + ], + ) + ] + } + ) + assert result == [ + Command( + update={ + "messages": [ + ToolMessage( + content="Transferred to Bob", + tool_call_id="1", + name="transfer_to_bob", + ) + ] + }, + goto="bob", + graph=Command.PARENT, + ), + Command( + update={ + "messages": [ + ToolMessage( + content="Transferred to Bob", + tool_call_id="2", + name="custom_transfer_to_bob", + ) + ] + }, + goto="bob", + graph=Command.PARENT, + ), + ] + + # test validation (mismatch between input type and command.update type) + with pytest.raises(ValueError): + + @dec_tool + def list_update_tool(tool_call_id: Annotated[str, InjectedToolCallId]): + """My tool""" + return Command(update=[ToolMessage(content="foo", tool_call_id=tool_call_id)]) + + ToolNode([list_update_tool]).invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"args": {}, "id": "1", "name": "list_update_tool"}], + ) + ] + } + ) + + # test validation (missing tool message in the update for current graph) + with pytest.raises(ValueError): + + @dec_tool + def no_update_tool(): + """My tool""" + return Command(update={"messages": []}) + + ToolNode([no_update_tool]).invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"args": {}, "id": "1", "name": "no_update_tool"}], + ) + ] + } + ) + + # test validation (tool message with a wrong tool call ID) + with pytest.raises(ValueError): + + @dec_tool + def mismatching_tool_call_id_tool(): + """My tool""" + return Command(update={"messages": [ToolMessage(content="foo", tool_call_id="2")]}) + + ToolNode([mismatching_tool_call_id_tool]).invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[ + { + "args": {}, + "id": "1", + "name": "mismatching_tool_call_id_tool", + } + ], + ) + ] + } + ) + + # test validation (missing tool message in the update for parent graph is OK) + @dec_tool + def node_update_parent_tool(): + """No update""" + return Command(update={"messages": []}, graph=Command.PARENT) + + assert ToolNode([node_update_parent_tool]).invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"args": {}, "id": "1", "name": "node_update_parent_tool"}], + ) + ] + } + ) == [Command(update={"messages": []}, graph=Command.PARENT)] + + +async def test_tool_node_command_list_input() -> None: + from langchain_core.tools.base import InjectedToolCallId + + @dec_tool + def transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): + """Transfer to Bob""" + return Command( + update=[ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)], + goto="bob", + graph=Command.PARENT, + ) + + @dec_tool + async def async_transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): + """Transfer to Bob""" + return Command( + update=[ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)], + goto="bob", + graph=Command.PARENT, + ) + + class CustomToolSchema(BaseModel): + tool_call_id: Annotated[str, InjectedToolCallId] + + class MyCustomTool(BaseTool): + def _run(*args: Any, **kwargs: Any): + return Command( + update=[ + ToolMessage( + content="Transferred to Bob", + tool_call_id=kwargs["tool_call_id"], + ) + ], + goto="bob", + graph=Command.PARENT, + ) + + async def _arun(*args: Any, **kwargs: Any): + return Command( + update=[ + ToolMessage( + content="Transferred to Bob", + tool_call_id=kwargs["tool_call_id"], + ) + ], + goto="bob", + graph=Command.PARENT, + ) + + custom_tool = MyCustomTool( + name="custom_transfer_to_bob", + description="Transfer to bob", + args_schema=CustomToolSchema, + ) + async_custom_tool = MyCustomTool( + name="async_custom_transfer_to_bob", + description="Transfer to bob", + args_schema=CustomToolSchema, + ) + + # test mixing regular tools and tools returning commands + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + result = ToolNode([add, transfer_to_bob]).invoke( + [ + AIMessage( + "", + tool_calls=[ + {"args": {"a": 1, "b": 2}, "id": "1", "name": "add"}, + {"args": {}, "id": "2", "name": "transfer_to_bob"}, + ], + ) + ] + ) + + assert result == [ + [ + ToolMessage( + content="3", + tool_call_id="1", + name="add", + ) + ], + Command( + update=[ + ToolMessage( + content="Transferred to Bob", + tool_call_id="2", + name="transfer_to_bob", + ) + ], + goto="bob", + graph=Command.PARENT, + ), + ] + + # test tools returning commands + + # test sync tools + for tool in [transfer_to_bob, custom_tool]: + result = ToolNode([tool]).invoke( + [AIMessage("", tool_calls=[{"args": {}, "id": "1", "name": tool.name}])] + ) + assert result == [ + Command( + update=[ + ToolMessage( + content="Transferred to Bob", + tool_call_id="1", + name=tool.name, + ) + ], + goto="bob", + graph=Command.PARENT, + ) + ] + + # test async tools + for tool in [async_transfer_to_bob, async_custom_tool]: + result = await ToolNode([tool]).ainvoke( + [AIMessage("", tool_calls=[{"args": {}, "id": "1", "name": tool.name}])] + ) + assert result == [ + Command( + update=[ + ToolMessage( + content="Transferred to Bob", + tool_call_id="1", + name=tool.name, + ) + ], + goto="bob", + graph=Command.PARENT, + ) + ] + + # test multiple commands + result = ToolNode([transfer_to_bob, custom_tool]).invoke( + [ + AIMessage( + "", + tool_calls=[ + {"args": {}, "id": "1", "name": "transfer_to_bob"}, + {"args": {}, "id": "2", "name": "custom_transfer_to_bob"}, + ], + ) + ] + ) + assert result == [ + Command( + update=[ + ToolMessage( + content="Transferred to Bob", + tool_call_id="1", + name="transfer_to_bob", + ) + ], + goto="bob", + graph=Command.PARENT, + ), + Command( + update=[ + ToolMessage( + content="Transferred to Bob", + tool_call_id="2", + name="custom_transfer_to_bob", + ) + ], + goto="bob", + graph=Command.PARENT, + ), + ] + + # test validation (mismatch between input type and command.update type) + with pytest.raises(ValueError): + + @dec_tool + def list_update_tool(tool_call_id: Annotated[str, InjectedToolCallId]): + """My tool""" + return Command( + update={"messages": [ToolMessage(content="foo", tool_call_id=tool_call_id)]} + ) + + ToolNode([list_update_tool]).invoke( + [ + AIMessage( + "", + tool_calls=[{"args": {}, "id": "1", "name": "list_update_tool"}], + ) + ] + ) + + # test validation (missing tool message in the update for current graph) + with pytest.raises(ValueError): + + @dec_tool + def no_update_tool(): + """My tool""" + return Command(update=[]) + + ToolNode([no_update_tool]).invoke( + [ + AIMessage( + "", + tool_calls=[{"args": {}, "id": "1", "name": "no_update_tool"}], + ) + ] + ) + + # test validation (tool message with a wrong tool call ID) + with pytest.raises(ValueError): + + @dec_tool + def mismatching_tool_call_id_tool(): + """My tool""" + return Command(update=[ToolMessage(content="foo", tool_call_id="2")]) + + ToolNode([mismatching_tool_call_id_tool]).invoke( + [ + AIMessage( + "", + tool_calls=[{"args": {}, "id": "1", "name": "mismatching_tool_call_id_tool"}], + ) + ] + ) + + # test validation (missing tool message in the update for parent graph is OK) + @dec_tool + def node_update_parent_tool(): + """No update""" + return Command(update=[], graph=Command.PARENT) + + assert ToolNode([node_update_parent_tool]).invoke( + [ + AIMessage( + "", + tool_calls=[{"args": {}, "id": "1", "name": "node_update_parent_tool"}], + ) + ] + ) == [Command(update=[], graph=Command.PARENT)] + + +def test_tool_node_parent_command_with_send() -> None: + from langchain_core.tools.base import InjectedToolCallId + + @dec_tool + def transfer_to_alice(tool_call_id: Annotated[str, InjectedToolCallId]): + """Transfer to Alice""" + return Command( + goto=[ + Send( + "alice", + { + "messages": [ + ToolMessage( + content="Transferred to Alice", + name="transfer_to_alice", + tool_call_id=tool_call_id, + ) + ] + }, + ) + ], + graph=Command.PARENT, + ) + + @dec_tool + def transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): + """Transfer to Bob""" + return Command( + goto=[ + Send( + "bob", + { + "messages": [ + ToolMessage( + content="Transferred to Bob", + name="transfer_to_bob", + tool_call_id=tool_call_id, + ) + ] + }, + ) + ], + graph=Command.PARENT, + ) + + tool_calls = [ + {"args": {}, "id": "1", "name": "transfer_to_alice", "type": "tool_call"}, + {"args": {}, "id": "2", "name": "transfer_to_bob", "type": "tool_call"}, + ] + + result = ToolNode([transfer_to_alice, transfer_to_bob]).invoke( + [AIMessage("", tool_calls=tool_calls)] + ) + + assert result == [ + Command( + goto=[ + Send( + "alice", + { + "messages": [ + ToolMessage( + content="Transferred to Alice", + name="transfer_to_alice", + tool_call_id="1", + ) + ] + }, + ), + Send( + "bob", + { + "messages": [ + ToolMessage( + content="Transferred to Bob", + name="transfer_to_bob", + tool_call_id="2", + ) + ] + }, + ), + ], + graph=Command.PARENT, + ) + ] + + +async def test_tool_node_command_remove_all_messages() -> None: + from langchain_core.tools.base import InjectedToolCallId + + @dec_tool + def remove_all_messages_tool(tool_call_id: Annotated[str, InjectedToolCallId]): + """A tool that removes all messages.""" + return Command(update={"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)]}) + + tool_node = ToolNode([remove_all_messages_tool]) + tool_call = { + "name": "remove_all_messages_tool", + "args": {}, + "id": "tool_call_123", + } + result = await tool_node.ainvoke({"messages": [AIMessage(content="", tool_calls=[tool_call])]}) + + assert isinstance(result, list) + assert len(result) == 1 + command = result[0] + assert isinstance(command, Command) + assert command.update == {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)]} + + +class _InjectStateSchema(TypedDict): + messages: list + foo: str + + +class _InjectedStatePydanticSchema(BaseModelV1): + messages: list + foo: str + + +class _InjectedStatePydanticV2Schema(BaseModel): + messages: list + foo: str + + +@dataclasses.dataclass +class _InjectedStateDataclassSchema: + messages: list + foo: str + + +T = TypeVar("T") + + +@pytest.mark.parametrize( + "schema_", + [ + _InjectStateSchema, + _InjectedStatePydanticSchema, + _InjectedStatePydanticV2Schema, + _InjectedStateDataclassSchema, + ], +) +def test_tool_node_inject_state(schema_: type[T]) -> None: + def tool1(some_val: int, state: Annotated[T, InjectedState]) -> str: + """Tool 1 docstring.""" + if isinstance(state, dict): + return state["foo"] + return state.foo + + def tool2(some_val: int, state: Annotated[T, InjectedState()]) -> str: + """Tool 2 docstring.""" + if isinstance(state, dict): + return state["foo"] + return state.foo + + def tool3( + some_val: int, + foo: Annotated[str, InjectedState("foo")], + msgs: Annotated[list[AnyMessage], InjectedState("messages")], + ) -> str: + """Tool 1 docstring.""" + return foo + + def tool4(some_val: int, msgs: Annotated[list[AnyMessage], InjectedState("messages")]) -> str: + """Tool 1 docstring.""" + return msgs[0].content + + node = ToolNode([tool1, tool2, tool3, tool4], handle_tool_errors=True) + for tool_name in ("tool1", "tool2", "tool3"): + tool_call = { + "name": tool_name, + "args": {"some_val": 1}, + "id": "some 0", + "type": "tool_call", + } + msg = AIMessage("hi?", tool_calls=[tool_call]) + result = node.invoke(schema_(messages=[msg], foo="bar")) + tool_message = result["messages"][-1] + assert tool_message.content == "bar", f"Failed for tool={tool_name}" + + if tool_name == "tool3": + failure_input = None + with contextlib.suppress(Exception): + failure_input = schema_(messages=[msg], notfoo="bar") + if failure_input is not None: + with pytest.raises(KeyError): + node.invoke(failure_input) + + with pytest.raises(ValueError): + node.invoke([msg]) + else: + failure_input = None + try: + failure_input = schema_(messages=[msg], notfoo="bar") + except Exception: + # We'd get a validation error from pydantic state and wouldn't make it to the node + # anyway + pass + if failure_input is not None: + messages_ = node.invoke(failure_input) + tool_message = messages_["messages"][-1] + assert "KeyError" in tool_message.content + tool_message = node.invoke([msg])[-1] + assert "KeyError" in tool_message.content + + tool_call = { + "name": "tool4", + "args": {"some_val": 1}, + "id": "some 0", + "type": "tool_call", + } + msg = AIMessage("hi?", tool_calls=[tool_call]) + result = node.invoke(schema_(messages=[msg], foo="")) + tool_message = result["messages"][-1] + assert tool_message.content == "hi?" + + result = node.invoke([msg]) + tool_message = result[-1] + assert tool_message.content == "hi?" + + +def test_tool_node_inject_store() -> None: + store = InMemoryStore() + namespace = ("test",) + + def tool1(some_val: int, store: Annotated[BaseStore, InjectedStore()]) -> str: + """Tool 1 docstring.""" + store_val = store.get(namespace, "test_key").value["foo"] + return f"Some val: {some_val}, store val: {store_val}" + + def tool2(some_val: int, store: Annotated[BaseStore, InjectedStore()]) -> str: + """Tool 2 docstring.""" + store_val = store.get(namespace, "test_key").value["foo"] + return f"Some val: {some_val}, store val: {store_val}" + + def tool3( + some_val: int, + bar: Annotated[str, InjectedState("bar")], + store: Annotated[BaseStore, InjectedStore()], + ) -> str: + """Tool 3 docstring.""" + store_val = store.get(namespace, "test_key").value["foo"] + return f"Some val: {some_val}, store val: {store_val}, state val: {bar}" + + node = ToolNode([tool1, tool2, tool3], handle_tool_errors=True) + store.put(namespace, "test_key", {"foo": "bar"}) + + class State(MessagesState): + bar: str + + builder = StateGraph(State) + builder.add_node("tools", node) + builder.add_edge(START, "tools") + graph = builder.compile(store=store) + + for tool_name in ("tool1", "tool2"): + tool_call = { + "name": tool_name, + "args": {"some_val": 1}, + "id": "some 0", + "type": "tool_call", + } + msg = AIMessage("hi?", tool_calls=[tool_call]) + node_result = node.invoke({"messages": [msg]}, store=store) + graph_result = graph.invoke({"messages": [msg]}) + for result in (node_result, graph_result): + result["messages"][-1] + tool_message = result["messages"][-1] + assert tool_message.content == "Some val: 1, store val: bar", ( + f"Failed for tool={tool_name}" + ) + + tool_call = { + "name": "tool3", + "args": {"some_val": 1}, + "id": "some 0", + "type": "tool_call", + } + msg = AIMessage("hi?", tool_calls=[tool_call]) + node_result = node.invoke({"messages": [msg], "bar": "baz"}, store=store) + graph_result = graph.invoke({"messages": [msg], "bar": "baz"}) + for result in (node_result, graph_result): + result["messages"][-1] + tool_message = result["messages"][-1] + assert tool_message.content == "Some val: 1, store val: bar, state val: baz", ( + f"Failed for tool={tool_name}" + ) + + # test injected store without passing store to compiled graph + failing_graph = builder.compile() + with pytest.raises(ValueError): + failing_graph.invoke({"messages": [msg], "bar": "baz"}) + + +def test_tool_node_ensure_utf8() -> None: + @dec_tool + def get_day_list(days: list[str]) -> list[str]: + """choose days""" + return days + + data = ["星期一", "水曜日", "목요일", "Friday"] + tools = [get_day_list] + tool_calls = [ToolCall(name=get_day_list.name, args={"days": data}, id="test_id")] + outputs: list[ToolMessage] = ToolNode(tools).invoke( + [AIMessage(content="", tool_calls=tool_calls)] + ) + assert outputs[0].content == json.dumps(data, ensure_ascii=False) + + +def test_tool_node_messages_key() -> None: + @dec_tool + def add(a: int, b: int) -> int: + """Adds a and b.""" + return a + b + + model = FakeToolCallingModel( + tool_calls=[[ToolCall(name=add.name, args={"a": 1, "b": 2}, id="test_id")]] + ) + + class State(TypedDict): + subgraph_messages: Annotated[list[AnyMessage], add_messages] + + def call_model(state: State) -> dict[str, Any]: + response = model.invoke(state["subgraph_messages"]) + model.tool_calls = [] + return {"subgraph_messages": response} + + builder = StateGraph(State) + builder.add_node("agent", call_model) + builder.add_node("tools", ToolNode([add], messages_key="subgraph_messages")) + builder.add_conditional_edges( + "agent", partial(tools_condition, messages_key="subgraph_messages") + ) + builder.add_edge(START, "agent") + builder.add_edge("tools", "agent") + + graph = builder.compile() + result = graph.invoke({"subgraph_messages": [HumanMessage(content="hi")]}) + assert result["subgraph_messages"] == [ + _AnyIdHumanMessage(content="hi"), + AIMessage( + content="hi", + id="0", + tool_calls=[ToolCall(name=add.name, args={"a": 1, "b": 2}, id="test_id")], + ), + _AnyIdToolMessage(content="3", name=add.name, tool_call_id="test_id"), + AIMessage(content="hi-hi-3", id="1"), + ] + + +def test_tool_node_stream_writer() -> None: + @dec_tool + def streaming_tool(x: int) -> str: + """Do something with writer.""" + my_writer = get_stream_writer() + for value in ["foo", "bar", "baz"]: + my_writer({"custom_tool_value": value}) + + return x + + tool_node = ToolNode([streaming_tool]) + graph = ( + StateGraph(MessagesState).add_node("tools", tool_node).add_edge(START, "tools").compile() + ) + + tool_call = { + "name": "streaming_tool", + "args": {"x": 1}, + "id": "1", + "type": "tool_call", + } + inputs = { + "messages": [AIMessage("", tool_calls=[tool_call])], + } + + assert list(graph.stream(inputs, stream_mode="custom")) == [ + {"custom_tool_value": "foo"}, + {"custom_tool_value": "bar"}, + {"custom_tool_value": "baz"}, + ] + assert list(graph.stream(inputs, stream_mode=["custom", "updates"])) == [ + ("custom", {"custom_tool_value": "foo"}), + ("custom", {"custom_tool_value": "bar"}), + ("custom", {"custom_tool_value": "baz"}), + ( + "updates", + { + "tools": { + "messages": [ + _AnyIdToolMessage( + content="1", + name="streaming_tool", + tool_call_id="1", + ), + ], + }, + }, + ), + ] diff --git a/libs/langchain_v1/tests/unit_tests/agents/utils.py b/libs/langchain_v1/tests/unit_tests/agents/utils.py new file mode 100644 index 00000000000..21a9d438af5 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/utils.py @@ -0,0 +1,21 @@ +import json +from pathlib import Path + +from pydantic import BaseModel, ConfigDict +from pydantic.alias_generators import to_camel + + +class BaseSchema(BaseModel): + model_config = ConfigDict( + alias_generator=to_camel, + populate_by_name=True, + from_attributes=True, + ) + + +def load_spec(spec_name: str, as_model: type[BaseModel]) -> list[BaseModel]: + with (Path(__file__).parent / "specifications" / f"{spec_name}.json").open( + "r", encoding="utf-8" + ) as f: + data = json.load(f) + return [as_model(**item) for item in data] diff --git a/libs/langchain_v1/tests/unit_tests/conftest.py b/libs/langchain_v1/tests/unit_tests/conftest.py index 6656529a4d1..46db1647781 100644 --- a/libs/langchain_v1/tests/unit_tests/conftest.py +++ b/libs/langchain_v1/tests/unit_tests/conftest.py @@ -1,42 +1,9 @@ """Configuration for unit tests.""" -from collections.abc import Iterator, Sequence +from collections.abc import Sequence from importlib import util import pytest -from blockbuster import blockbuster_ctx - - -@pytest.fixture(autouse=True) -def blockbuster() -> Iterator[None]: - with blockbuster_ctx("langchain") as bb: - bb.functions["io.TextIOWrapper.read"].can_block_in( - "langchain/__init__.py", - "", - ) - - for func in ["os.stat", "os.path.abspath"]: - ( - bb.functions[func] - .can_block_in("langchain_core/runnables/base.py", "__repr__") - .can_block_in( - "langchain_core/beta/runnables/context.py", - "aconfig_with_context", - ) - ) - - for func in ["os.stat", "io.TextIOWrapper.read"]: - bb.functions[func].can_block_in( - "langsmith/client.py", - "_default_retry_config", - ) - - for bb_function in bb.functions.values(): - bb_function.can_block_in( - "freezegun/api.py", - "_get_cached_module_attributes", - ) - yield def pytest_addoption(parser: pytest.Parser) -> None: diff --git a/libs/langchain_v1/uv.lock b/libs/langchain_v1/uv.lock index fbb87832bac..2daff5ca14b 100644 --- a/libs/langchain_v1/uv.lock +++ b/libs/langchain_v1/uv.lock @@ -264,18 +264,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" }, ] -[[package]] -name = "blockbuster" -version = "1.5.25" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "forbiddenfruit", marker = "implementation_name == 'cpython'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/7f/bc/57c49465decaeeedd58ce2d970b4cdfd93a74ba9993abff2dc498a31c283/blockbuster-1.5.25.tar.gz", hash = "sha256:b72f1d2aefdeecd2a820ddf1e1c8593bf00b96e9fdc4cd2199ebafd06f7cb8f0", size = 36058, upload-time = "2025-07-14T16:00:20.766Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/0b/01/dccc277c014f171f61a6047bb22c684e16c7f2db6bb5c8cce1feaf41ec55/blockbuster-1.5.25-py3-none-any.whl", hash = "sha256:cb06229762273e0f5f3accdaed3d2c5a3b61b055e38843de202311ede21bb0f5", size = 13196, upload-time = "2025-07-14T16:00:19.396Z" }, -] - [[package]] name = "boto3" version = "1.40.18" @@ -724,7 +712,7 @@ name = "exceptiongroup" version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.12'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } wheels = [ @@ -774,12 +762,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5f/e7/319a4ce37bed682741bc8ebbb84b7983da3d8cd7ac069d86b52a37d79f2e/fireworks_ai-0.15.15-py3-none-any.whl", hash = "sha256:1047b8e575a536898a827b089b0022c1fab207940f9773b90fa357ebf942f5c9", size = 112831, upload-time = "2025-06-20T21:11:26.701Z" }, ] -[[package]] -name = "forbiddenfruit" -version = "0.1.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e6/79/d4f20e91327c98096d605646bdc6a5ffedae820f38d378d3515c42ec5e60/forbiddenfruit-0.1.4.tar.gz", hash = "sha256:e3f7e66561a29ae129aac139a85d610dbf3dd896128187ed5454b6421f624253", size = 43756, upload-time = "2021-01-16T21:03:35.401Z" } - [[package]] name = "frozenlist" version = "1.7.0" @@ -1514,13 +1496,13 @@ lint = [ { name = "ruff" }, ] test = [ - { name = "blockbuster" }, { name = "langchain-openai" }, { name = "langchain-tests" }, { name = "langchain-text-splitters" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, + { name = "pytest-mock" }, { name = "pytest-socket" }, { name = "pytest-watcher" }, { name = "pytest-xdist" }, @@ -1571,13 +1553,13 @@ lint = [ { name = "ruff", specifier = ">=0.12.2" }, ] test = [ - { name = "blockbuster", specifier = ">=1.5.18" }, { name = "langchain-openai", editable = "../partners/openai" }, { name = "langchain-tests", editable = "../standard-tests" }, { name = "langchain-text-splitters", editable = "../text-splitters" }, { name = "pytest", specifier = ">=8,<9" }, { name = "pytest-asyncio", specifier = ">=0.23.2" }, { name = "pytest-cov", specifier = ">=4.0.0" }, + { name = "pytest-mock" }, { name = "pytest-socket", specifier = ">=0.6.0" }, { name = "pytest-watcher", specifier = ">=0.2.6" }, { name = "pytest-xdist", specifier = ">=3.6.1" }, @@ -3169,6 +3151,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bc/16/4ea354101abb1287856baa4af2732be351c7bee728065aed451b678153fd/pytest_cov-6.2.1-py3-none-any.whl", hash = "sha256:f5bc4c23f42f1cdd23c70b1dab1bbaef4fc505ba950d53e0081d0730dd7e86d5", size = 24644, upload-time = "2025-06-12T10:47:45.932Z" }, ] +[[package]] +name = "pytest-mock" +version = "3.14.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/28/67172c96ba684058a4d24ffe144d64783d2a270d0af0d9e792737bddc75c/pytest_mock-3.14.1.tar.gz", hash = "sha256:159e9edac4c451ce77a5cdb9fc5d1100708d2dd4ba3c3df572f14097351af80e", size = 33241, upload-time = "2025-05-26T13:58:45.167Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b2/05/77b60e520511c53d1c1ca75f1930c7dd8e971d0c4379b7f4b3f9644685ba/pytest_mock-3.14.1-py3-none-any.whl", hash = "sha256:178aefcd11307d874b4cd3100344e7e2d888d9791a6a1d9bfe90fbc1b74fd1d0", size = 9923, upload-time = "2025-05-26T13:58:43.487Z" }, +] + [[package]] name = "pytest-recording" version = "0.13.4"