diff --git a/libs/langchain/tests/unit_tests/chat_models/test_base.py b/libs/langchain/tests/unit_tests/chat_models/test_base.py index cd91955f29b..46787618e4d 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_base.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_base.py @@ -270,6 +270,7 @@ def test_configurable_with_default() -> None: "default_headers": None, "model_kwargs": {}, "reuse_last_container": None, + "inference_geo": None, "streaming": False, "stream_usage": True, "output_version": None, diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py index 314d9b6ee71..762e1c55b4c 100644 --- a/libs/langchain_v1/langchain/agents/factory.py +++ b/libs/langchain_v1/langchain/agents/factory.py @@ -3,10 +3,12 @@ from __future__ import annotations import itertools +from dataclasses import dataclass, field from typing import ( TYPE_CHECKING, Annotated, Any, + Generic, cast, get_args, get_origin, @@ -27,6 +29,7 @@ from langchain.agents.middleware.types import ( AgentMiddleware, AgentState, ContextT, + ExtendedModelResponse, JumpTo, ModelRequest, ModelResponse, @@ -49,6 +52,23 @@ from langchain.agents.structured_output import ( ) from langchain.chat_models import init_chat_model + +@dataclass +class _ComposedExtendedModelResponse(Generic[ResponseT]): + """Internal result from composed ``wrap_model_call`` middleware. + + Unlike ``ExtendedModelResponse`` (user-facing, single command), this holds the + full list of commands accumulated across all middleware layers during + composition. + """ + + model_response: ModelResponse[ResponseT] + """The underlying model response.""" + + commands: list[Command[Any]] = field(default_factory=list) + """Commands accumulated from all middleware layers (inner-first, then outer).""" + + if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Sequence @@ -61,6 +81,27 @@ if TYPE_CHECKING: from langchain.agents.middleware.types import ToolCallRequest, ToolCallWrapper + _ModelCallHandler = Callable[ + [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]], + ModelResponse | AIMessage | ExtendedModelResponse, + ] + + _ComposedModelCallHandler = Callable[ + [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]], + _ComposedExtendedModelResponse, + ] + + _AsyncModelCallHandler = Callable[ + [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]], + Awaitable[ModelResponse | AIMessage | ExtendedModelResponse], + ] + + _ComposedAsyncModelCallHandler = Callable[ + [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]], + Awaitable[_ComposedExtendedModelResponse], + ] + + STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes." DYNAMIC_TOOL_ERROR_TEMPLATE = """ @@ -102,31 +143,72 @@ FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [ ] -def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse: - """Normalize middleware return value to ModelResponse.""" +def _normalize_to_model_response( + result: ModelResponse | AIMessage | ExtendedModelResponse, +) -> ModelResponse: + """Normalize middleware return value to ModelResponse. + + At inner composition boundaries, ``ExtendedModelResponse`` is unwrapped to its + underlying ``ModelResponse`` so that inner middleware always sees ``ModelResponse`` + from the handler. 
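+
+    Example (a sketch; the message content is illustrative):
+
+        ```python
+        normalized = _normalize_to_model_response(AIMessage(content="hi"))
+        assert isinstance(normalized, ModelResponse)
+        assert normalized.result[0].content == "hi"
+        assert normalized.structured_response is None
+        ```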
+ """ if isinstance(result, AIMessage): return ModelResponse(result=[result], structured_response=None) + if isinstance(result, ExtendedModelResponse): + return result.model_response return result +def _build_commands( + model_response: ModelResponse, + middleware_commands: list[Command[Any]] | None = None, +) -> list[Command[Any]]: + """Build a list of Commands from a model response and middleware commands. + + The first Command contains the model response state (messages and optional + structured_response). Middleware commands are appended as-is. + + Args: + model_response: The model response containing messages and optional + structured output. + middleware_commands: Commands accumulated from middleware layers during + composition (inner-first ordering). + + Returns: + List of ``Command`` objects ready to be returned from a model node. + """ + state: dict[str, Any] = {"messages": model_response.result} + + if model_response.structured_response is not None: + state["structured_response"] = model_response.structured_response + + for cmd in middleware_commands or []: + if cmd.goto: + msg = ( + "Command goto is not yet supported in wrap_model_call middleware. " + "Use the jump_to state field with before_model/after_model hooks instead." + ) + raise NotImplementedError(msg) + if cmd.resume: + msg = "Command resume is not yet supported in wrap_model_call middleware." + raise NotImplementedError(msg) + if cmd.graph: + msg = "Command graph is not yet supported in wrap_model_call middleware." + raise NotImplementedError(msg) + + commands: list[Command[Any]] = [Command(update=state)] + commands.extend(middleware_commands or []) + return commands + + def _chain_model_call_handlers( - handlers: Sequence[ - Callable[ - [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]], - ModelResponse | AIMessage, - ] - ], -) -> ( - Callable[ - [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]], - ModelResponse, - ] - | None -): - """Compose multiple `wrap_model_call` handlers into single middleware stack. + handlers: Sequence[_ModelCallHandler[ContextT]], +) -> _ComposedModelCallHandler[ContextT] | None: + """Compose multiple ``wrap_model_call`` handlers into single middleware stack. Composes handlers so first in list becomes outermost layer. Each handler receives a - handler callback to execute inner layers. + handler callback to execute inner layers. Commands from each layer are accumulated + into a list (inner-first, then outer) without merging. Args: handlers: List of handlers. @@ -134,110 +216,90 @@ def _chain_model_call_handlers( First handler wraps all others. Returns: - Composed handler, or `None` if handlers empty. - - Example: - ```python - # handlers=[auth, retry] means: auth wraps retry - # Flow: auth calls retry, retry calls base handler - def auth(req, state, runtime, handler): - try: - return handler(req) - except UnauthorizedError: - refresh_token() - return handler(req) - - - def retry(req, state, runtime, handler): - for attempt in range(3): - try: - return handler(req) - except Exception: - if attempt == 2: - raise - - - handler = _chain_model_call_handlers([auth, retry]) - ``` + Composed handler returning ``_ComposedExtendedModelResponse``, + or ``None`` if handlers empty. 
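+
+    Example:
+        A sketch using the two-argument handler signature above
+        (``UnauthorizedError`` and ``refresh_token`` are placeholders);
+        ``handlers=[auth, retry]`` means ``auth`` wraps ``retry``:
+
+        ```python
+        def auth(request, handler):
+            try:
+                return handler(request)
+            except UnauthorizedError:
+                refresh_token()
+                return handler(request)
+
+
+        def retry(request, handler):
+            for attempt in range(3):
+                try:
+                    return handler(request)
+                except Exception:
+                    if attempt == 2:
+                        raise
+
+
+        composed = _chain_model_call_handlers([auth, retry])
+        ```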
""" if not handlers: return None + def _to_composed_result( + result: ModelResponse | AIMessage | ExtendedModelResponse | _ComposedExtendedModelResponse, + extra_commands: list[Command[Any]] | None = None, + ) -> _ComposedExtendedModelResponse: + """Normalize any handler result to _ComposedExtendedModelResponse.""" + commands: list[Command[Any]] = list(extra_commands or []) + if isinstance(result, _ComposedExtendedModelResponse): + commands.extend(result.commands) + model_response = result.model_response + elif isinstance(result, ExtendedModelResponse): + model_response = result.model_response + if result.command is not None: + commands.append(result.command) + else: + model_response = _normalize_to_model_response(result) + + return _ComposedExtendedModelResponse(model_response=model_response, commands=commands) + if len(handlers) == 1: - # Single handler - wrap to normalize output single_handler = handlers[0] def normalized_single( request: ModelRequest[ContextT], handler: Callable[[ModelRequest[ContextT]], ModelResponse], - ) -> ModelResponse: - result = single_handler(request, handler) - return _normalize_to_model_response(result) + ) -> _ComposedExtendedModelResponse: + return _to_composed_result(single_handler(request, handler)) return normalized_single def compose_two( - outer: Callable[ - [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]], - ModelResponse | AIMessage, - ], - inner: Callable[ - [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]], - ModelResponse | AIMessage, - ], - ) -> Callable[ - [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]], - ModelResponse, - ]: + outer: _ModelCallHandler[ContextT] | _ComposedModelCallHandler[ContextT], + inner: _ModelCallHandler[ContextT] | _ComposedModelCallHandler[ContextT], + ) -> _ComposedModelCallHandler[ContextT]: """Compose two handlers where outer wraps inner.""" def composed( request: ModelRequest[ContextT], handler: Callable[[ModelRequest[ContextT]], ModelResponse], - ) -> ModelResponse: - # Create a wrapper that calls inner with the base handler and normalizes + ) -> _ComposedExtendedModelResponse: + # Closure variable to capture inner's commands before normalizing + accumulated_commands: list[Command[Any]] = [] + def inner_handler(req: ModelRequest[ContextT]) -> ModelResponse: + # Clear on each call for retry safety + accumulated_commands.clear() inner_result = inner(req, handler) + if isinstance(inner_result, _ComposedExtendedModelResponse): + accumulated_commands.extend(inner_result.commands) + return inner_result.model_response + if isinstance(inner_result, ExtendedModelResponse): + if inner_result.command is not None: + accumulated_commands.append(inner_result.command) + return inner_result.model_response return _normalize_to_model_response(inner_result) - # Call outer with the wrapped inner as its handler and normalize outer_result = outer(request, inner_handler) - return _normalize_to_model_response(outer_result) + return _to_composed_result( + outer_result, + extra_commands=accumulated_commands or None, + ) return composed # Compose right-to-left: outer(inner(innermost(handler))) - result = handlers[-1] - for handler in reversed(handlers[:-1]): - result = compose_two(handler, result) + composed_handler = compose_two(handlers[-2], handlers[-1]) + for h in reversed(handlers[:-2]): + composed_handler = compose_two(h, composed_handler) - # Wrap to ensure final return type is exactly ModelResponse - def final_normalized( - request: 
ModelRequest[ContextT], - handler: Callable[[ModelRequest[ContextT]], ModelResponse], - ) -> ModelResponse: - # result here is typed as returning ModelResponse | AIMessage but compose_two normalizes - final_result = result(request, handler) - return _normalize_to_model_response(final_result) - - return final_normalized + return composed_handler def _chain_async_model_call_handlers( - handlers: Sequence[ - Callable[ - [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]], - Awaitable[ModelResponse | AIMessage], - ] - ], -) -> ( - Callable[ - [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]], - Awaitable[ModelResponse], - ] - | None -): - """Compose multiple async `wrap_model_call` handlers into single middleware stack. + handlers: Sequence[_AsyncModelCallHandler[ContextT]], +) -> _ComposedAsyncModelCallHandler[ContextT] | None: + """Compose multiple async ``wrap_model_call`` handlers into single middleware stack. + + Commands from each layer are accumulated into a list (inner-first, then outer) + without merging. Args: handlers: List of async handlers. @@ -245,69 +307,81 @@ def _chain_async_model_call_handlers( First handler wraps all others. Returns: - Composed async handler, or `None` if handlers empty. + Composed async handler returning ``_ComposedExtendedModelResponse``, + or ``None`` if handlers empty. """ if not handlers: return None + def _to_composed_result( + result: ModelResponse | AIMessage | ExtendedModelResponse | _ComposedExtendedModelResponse, + extra_commands: list[Command[Any]] | None = None, + ) -> _ComposedExtendedModelResponse: + """Normalize any handler result to _ComposedExtendedModelResponse.""" + commands: list[Command[Any]] = list(extra_commands or []) + if isinstance(result, _ComposedExtendedModelResponse): + commands.extend(result.commands) + model_response = result.model_response + elif isinstance(result, ExtendedModelResponse): + model_response = result.model_response + if result.command is not None: + commands.append(result.command) + else: + model_response = _normalize_to_model_response(result) + + return _ComposedExtendedModelResponse(model_response=model_response, commands=commands) + if len(handlers) == 1: - # Single handler - wrap to normalize output single_handler = handlers[0] async def normalized_single( request: ModelRequest[ContextT], handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]], - ) -> ModelResponse: - result = await single_handler(request, handler) - return _normalize_to_model_response(result) + ) -> _ComposedExtendedModelResponse: + return _to_composed_result(await single_handler(request, handler)) return normalized_single def compose_two( - outer: Callable[ - [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]], - Awaitable[ModelResponse | AIMessage], - ], - inner: Callable[ - [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]], - Awaitable[ModelResponse | AIMessage], - ], - ) -> Callable[ - [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]], - Awaitable[ModelResponse], - ]: + outer: _AsyncModelCallHandler[ContextT] | _ComposedAsyncModelCallHandler[ContextT], + inner: _AsyncModelCallHandler[ContextT] | _ComposedAsyncModelCallHandler[ContextT], + ) -> _ComposedAsyncModelCallHandler[ContextT]: """Compose two async handlers where outer wraps inner.""" async def composed( request: ModelRequest[ContextT], handler: Callable[[ModelRequest[ContextT]], 
Awaitable[ModelResponse]], - ) -> ModelResponse: - # Create a wrapper that calls inner with the base handler and normalizes + ) -> _ComposedExtendedModelResponse: + # Closure variable to capture inner's commands before normalizing + accumulated_commands: list[Command[Any]] = [] + async def inner_handler(req: ModelRequest[ContextT]) -> ModelResponse: + # Clear on each call for retry safety + accumulated_commands.clear() inner_result = await inner(req, handler) + if isinstance(inner_result, _ComposedExtendedModelResponse): + accumulated_commands.extend(inner_result.commands) + return inner_result.model_response + if isinstance(inner_result, ExtendedModelResponse): + if inner_result.command is not None: + accumulated_commands.append(inner_result.command) + return inner_result.model_response return _normalize_to_model_response(inner_result) - # Call outer with the wrapped inner as its handler and normalize outer_result = await outer(request, inner_handler) - return _normalize_to_model_response(outer_result) + return _to_composed_result( + outer_result, + extra_commands=accumulated_commands or None, + ) return composed # Compose right-to-left: outer(inner(innermost(handler))) - result = handlers[-1] - for handler in reversed(handlers[:-1]): - result = compose_two(handler, result) + composed_handler = compose_two(handlers[-2], handlers[-1]) + for h in reversed(handlers[:-2]): + composed_handler = compose_two(h, composed_handler) - # Wrap to ensure final return type is exactly ModelResponse - async def final_normalized( - request: ModelRequest[ContextT], - handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]], - ) -> ModelResponse: - # result here is typed as returning ModelResponse | AIMessage but compose_two normalizes - final_result = await result(request, handler) - return _normalize_to_model_response(final_result) - - return final_normalized + return composed_handler def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type: @@ -1165,7 +1239,7 @@ def create_agent( structured_response=structured_response, ) - def model_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> dict[str, Any]: + def model_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> list[Command[Any]]: """Sync model request handler with sequential middleware processing.""" request = ModelRequest( model=model, @@ -1179,18 +1253,11 @@ def create_agent( ) if wrap_model_call_handler is None: - # No handlers - execute directly - response = _execute_model_sync(request) - else: - # Call composed handler with base handler - response = wrap_model_call_handler(request, _execute_model_sync) + model_response = _execute_model_sync(request) + return _build_commands(model_response) - # Extract state updates from ModelResponse - state_updates = {"messages": response.result} - if response.structured_response is not None: - state_updates["structured_response"] = response.structured_response - - return state_updates + result = wrap_model_call_handler(request, _execute_model_sync) + return _build_commands(result.model_response, result.commands) async def _execute_model_async(request: ModelRequest[ContextT]) -> ModelResponse: """Execute model asynchronously and return response. 
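As context for the node changes above and below: the model node now returns a list of `Command`s instead of a plain state dict. A minimal sketch of what `_build_commands` produces, using only names defined in this diff (the message and key values are illustrative):

```python
from langchain_core.messages import AIMessage
from langgraph.types import Command

from langchain.agents.factory import _build_commands
from langchain.agents.middleware.types import ModelResponse

# One model response plus one middleware-supplied command.
response = ModelResponse(result=[AIMessage(content="hi")], structured_response=None)
extra = [Command(update={"custom_key": "custom_value"})]

commands = _build_commands(response, extra)

# The first Command carries the model output; middleware commands follow in
# order (inner-first, then outer), exactly as accumulated during composition.
assert commands[0].update == {"messages": response.result}
assert commands[1] is extra[0]
```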
@@ -1220,7 +1287,7 @@ def create_agent( structured_response=structured_response, ) - async def amodel_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> dict[str, Any]: + async def amodel_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> list[Command[Any]]: """Async model request handler with sequential middleware processing.""" request = ModelRequest( model=model, @@ -1234,18 +1301,11 @@ def create_agent( ) if awrap_model_call_handler is None: - # No async handlers - execute directly - response = await _execute_model_async(request) - else: - # Call composed async handler with base handler - response = await awrap_model_call_handler(request, _execute_model_async) + model_response = await _execute_model_async(request) + return _build_commands(model_response) - # Extract state updates from ModelResponse - state_updates = {"messages": response.result} - if response.structured_response is not None: - state_updates["structured_response"] = response.structured_response - - return state_updates + result = await awrap_model_call_handler(request, _execute_model_async) + return _build_commands(result.model_response, result.commands) # Use sync or async based on model capabilities graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False)) diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index ef2163564f2..838123e6d78 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -26,6 +26,8 @@ from langchain.agents.middleware.tool_selection import LLMToolSelectorMiddleware from langchain.agents.middleware.types import ( AgentMiddleware, AgentState, + ExtendedModelResponse, + ModelCallResult, ModelRequest, ModelResponse, ToolCallRequest, @@ -46,6 +48,7 @@ __all__ = [ "CodexSandboxExecutionPolicy", "ContextEditingMiddleware", "DockerExecutionPolicy", + "ExtendedModelResponse", "FilesystemFileSearchMiddleware", "HostExecutionPolicy", "HumanInTheLoopMiddleware", @@ -53,6 +56,7 @@ __all__ = [ "LLMToolEmulator", "LLMToolSelectorMiddleware", "ModelCallLimitMiddleware", + "ModelCallResult", "ModelFallbackMiddleware", "ModelRequest", "ModelResponse", diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index a48147a11ee..fbf3ce1b89f 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -19,8 +19,6 @@ from typing import ( if TYPE_CHECKING: from collections.abc import Awaitable - from langgraph.types import Command - # Needed as top level import for Pydantic schema generation on AgentState import warnings from typing import TypeAlias @@ -42,6 +40,7 @@ if TYPE_CHECKING: from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.tools import BaseTool from langgraph.runtime import Runtime + from langgraph.types import Command from langchain.agents.structured_output import ResponseFormat @@ -49,6 +48,8 @@ __all__ = [ "AgentMiddleware", "AgentState", "ContextT", + "ExtendedModelResponse", + "ModelCallResult", "ModelRequest", "ModelResponse", "OmitFromSchema", @@ -285,14 +286,43 @@ class ModelResponse(Generic[ResponseT]): """Parsed structured output if `response_format` was specified, `None` otherwise.""" -# Type alias for middleware return type - allows returning either full response or just AIMessage -ModelCallResult: TypeAlias = 
"ModelResponse[ResponseT] | AIMessage" +@dataclass +class ExtendedModelResponse(Generic[ResponseT]): + """Model response with an optional 'Command' from 'wrap_model_call' middleware. + + Use this to return a 'Command' alongside the model response from a + 'wrap_model_call' handler. The command is applied as an additional state + update after the model node completes, using the graph's reducers (e.g. + 'add_messages' for the 'messages' key). + + Because each 'Command' is applied through the reducer, messages in the + command are **added alongside** the model response messages rather than + replacing them. For non-reducer state fields, later commands overwrite + earlier ones (outermost middleware wins over inner). + + Type Parameters: + ResponseT: The type of the structured response. Defaults to 'Any' if not specified. + """ + + model_response: ModelResponse[ResponseT] + """The underlying model response.""" + + command: Command[Any] | None = None + """Optional command to apply as an additional state update.""" + + +ModelCallResult: TypeAlias = ( + "ModelResponse[ResponseT] | AIMessage | ExtendedModelResponse[ResponseT]" +) """`TypeAlias` for model call handler return value. Middleware can return either: - `ModelResponse`: Full response with messages and optional structured output - `AIMessage`: Simplified return for simple use cases +- `ExtendedModelResponse`: Response with an optional `Command` for additional state updates + `goto`, `resume`, and `graph` are not yet supported on these commands. + A `NotImplementedError` will be raised if you try to use them. """ @@ -449,7 +479,7 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]): self, request: ModelRequest[ContextT], handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], - ) -> ModelResponse[ResponseT] | AIMessage: + ) -> ModelResponse[ResponseT] | AIMessage | ExtendedModelResponse[ResponseT]: """Intercept and control model execution via handler callback. Async version is `awrap_model_call` @@ -544,7 +574,7 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]): self, request: ModelRequest[ContextT], handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], - ) -> ModelResponse[ResponseT] | AIMessage: + ) -> ModelResponse[ResponseT] | AIMessage | ExtendedModelResponse[ResponseT]: """Intercept and control async model execution via handler callback. The handler callback executes the model request and returns a `ModelResponse`. 
diff --git a/libs/langchain_v1/pyproject.toml b/libs/langchain_v1/pyproject.toml index 8b7bd1c9452..4bc2396e828 100644 --- a/libs/langchain_v1/pyproject.toml +++ b/libs/langchain_v1/pyproject.toml @@ -99,6 +99,7 @@ langchain-core = { path = "../core", editable = true } langchain-tests = { path = "../standard-tests", editable = true } langchain-text-splitters = { path = "../text-splitters", editable = true } langchain-openai = { path = "../partners/openai", editable = true } +langchain-anthropic = { path = "../partners/anthropic", editable = true } [tool.ruff] line-length = 100 diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_composition.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_composition.py index 24b933c7997..05b6b9a8de1 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_composition.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_composition.py @@ -5,10 +5,11 @@ from typing import Any, TypedDict, cast from langchain_core.messages import AIMessage from langgraph.runtime import Runtime +from langgraph.types import Command from langchain.agents import AgentState -from langchain.agents.factory import _chain_model_call_handlers -from langchain.agents.middleware.types import ModelRequest, ModelResponse +from langchain.agents.factory import _chain_model_call_handlers, _ComposedExtendedModelResponse +from langchain.agents.middleware.types import ExtendedModelResponse, ModelRequest, ModelResponse def create_test_request(**kwargs: Any) -> ModelRequest: @@ -88,9 +89,41 @@ class TestChainModelCallHandlers: "inner-after", "outer-after", ] - # Result is now ModelResponse - assert isinstance(result, ModelResponse) - assert result.result[0].content == "test" + # Outermost result is always _ComposedExtendedModelResponse + assert isinstance(result, _ComposedExtendedModelResponse) + assert result.model_response.result[0].content == "test" + + def test_two_handlers_with_commands(self) -> None: + """Test that commands from inner and outer are collected correctly.""" + + def outer( + request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ExtendedModelResponse: + response = handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command(update={"outer_key": "outer_val"}), + ) + + def inner( + request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ExtendedModelResponse: + response = handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command(update={"inner_key": "inner_val"}), + ) + + composed = _chain_model_call_handlers([outer, inner]) + assert composed is not None + + result = composed(create_test_request(), create_mock_base_handler()) + + assert isinstance(result, _ComposedExtendedModelResponse) + # Commands are collected: inner first, then outer + assert len(result.commands) == 2 + assert result.commands[0].update == {"inner_key": "inner_val"} + assert result.commands[1].update == {"outer_key": "outer_val"} def test_three_handlers_composition(self) -> None: """Test composition of three handlers.""" @@ -134,8 +167,8 @@ class TestChainModelCallHandlers: "second-after", "first-after", ] - assert isinstance(result, ModelResponse) - assert result.result[0].content == "test" + assert isinstance(result, _ComposedExtendedModelResponse) + assert result.model_response.result[0].content == "test" def test_inner_handler_retry(self) -> None: """Test inner handler retrying before outer sees response.""" @@ 
-173,8 +206,8 @@ class TestChainModelCallHandlers:
         result = composed(create_test_request(), mock_base_handler)
 
         assert inner_attempts == [0, 1, 2]
-        assert isinstance(result, ModelResponse)
-        assert result.result[0].content == "success"
+        assert isinstance(result, _ComposedExtendedModelResponse)
+        assert result.model_response.result[0].content == "success"
 
     def test_error_to_success_conversion(self) -> None:
         """Test handler converting error to success response."""
@@ -202,10 +235,10 @@ class TestChainModelCallHandlers:
 
         result = composed(create_test_request(), mock_base_handler)
 
-        # AIMessage was automatically converted to ModelResponse
-        assert isinstance(result, ModelResponse)
-        assert result.result[0].content == "Fallback response"
-        assert result.structured_response is None
+        # AIMessage was automatically normalized into _ComposedExtendedModelResponse
+        assert isinstance(result, _ComposedExtendedModelResponse)
+        assert result.model_response.result[0].content == "Fallback response"
+        assert result.model_response.structured_response is None
 
     def test_request_modification(self) -> None:
         """Test handlers modifying the request."""
@@ -231,8 +264,8 @@ class TestChainModelCallHandlers:
         result = composed(create_test_request(), create_mock_base_handler(content="response"))
 
         assert requests_seen == ["Added by outer"]
-        assert isinstance(result, ModelResponse)
-        assert result.result[0].content == "response"
+        assert isinstance(result, _ComposedExtendedModelResponse)
+        assert result.model_response.result[0].content == "response"
 
     def test_composition_preserves_state_and_runtime(self) -> None:
         """Test that state and runtime are passed through composition."""
@@ -273,8 +306,8 @@ class TestChainModelCallHandlers:
         # Both handlers should see same state and runtime
         assert state_values == [("outer", test_state), ("inner", test_state)]
         assert runtime_values == [("outer", test_runtime), ("inner", test_runtime)]
-        assert isinstance(result, ModelResponse)
-        assert result.result[0].content == "test"
+        assert isinstance(result, _ComposedExtendedModelResponse)
+        assert result.model_response.result[0].content == "test"
 
     def test_multiple_yields_in_retry_loop(self) -> None:
         """Test handler that retries multiple times."""
@@ -312,5 +345,5 @@ class TestChainModelCallHandlers:
         # Outer called once, inner retried so base handler called twice
         assert call_count["value"] == 1
         assert attempt["value"] == 2
-        assert isinstance(result, ModelResponse)
-        assert result.result[0].content == "ok"
+        assert isinstance(result, _ComposedExtendedModelResponse)
+        assert result.model_response.result[0].content == "ok"
diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call_state_update.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call_state_update.py
new file mode 100644
index 00000000000..085efa6885e
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call_state_update.py
@@ -0,0 +1,917 @@
+"""Unit tests for ExtendedModelResponse command support in wrap_model_call.
+
+Tests that wrap_model_call middleware can return ExtendedModelResponse to provide
+a Command alongside the model response. Commands are applied as separate state
+updates through graph reducers (e.g. add_messages for messages).
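+
+Ordering, as exercised by the composition tests below: with
+middleware=[Outer(), Inner()] and both attaching a messages Command, the final
+message order is model response first, then inner's command, then outer's.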
+""" + +from collections.abc import Awaitable, Callable + +import pytest +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel +from langchain_core.messages import AIMessage, HumanMessage +from langgraph.errors import InvalidUpdateError +from langgraph.types import Command + +from langchain.agents import AgentState, create_agent +from langchain.agents.middleware.types import ( + AgentMiddleware, + ExtendedModelResponse, + ModelRequest, + ModelResponse, + wrap_model_call, +) + + +class TestBasicCommand: + """Test basic ExtendedModelResponse functionality with Command.""" + + def test_command_messages_added_alongside_model_messages(self) -> None: + """Command messages are added alongside model response messages (additive).""" + + class AddMessagesMiddleware(AgentMiddleware): + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ExtendedModelResponse: + response = handler(request) + custom_msg = HumanMessage(content="Custom message", id="custom") + return ExtendedModelResponse( + model_response=response, + command=Command(update={"messages": [custom_msg]}), + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")])) + agent = create_agent(model=model, middleware=[AddMessagesMiddleware()]) + + result = agent.invoke({"messages": [HumanMessage(content="Hi")]}) + + # Both model response AND command messages appear (additive via add_messages) + messages = result["messages"] + assert len(messages) == 3 + assert messages[0].content == "Hi" + assert messages[1].content == "Hello!" + assert messages[2].content == "Custom message" + + def test_command_with_extra_messages_and_model_response(self) -> None: + """Middleware can add extra messages via command alongside model messages.""" + + class ExtraMessagesMiddleware(AgentMiddleware): + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ExtendedModelResponse: + response = handler(request) + summary = HumanMessage(content="Summary", id="summary") + return ExtendedModelResponse( + model_response=response, + command=Command(update={"messages": [summary]}), + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")])) + agent = create_agent(model=model, middleware=[ExtraMessagesMiddleware()]) + + result = agent.invoke({"messages": [HumanMessage(content="Hi")]}) + + messages = result["messages"] + assert len(messages) == 3 + assert messages[0].content == "Hi" + assert messages[1].content == "Hello!" 
+ assert messages[2].content == "Summary" + + def test_command_structured_response_conflicts_with_model_response(self) -> None: + """Command and model response both setting structured_response raises.""" + + class OverrideMiddleware(AgentMiddleware): + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ExtendedModelResponse: + response = handler(request) + response_with_structured = ModelResponse( + result=response.result, + structured_response={"from": "model"}, + ) + return ExtendedModelResponse( + model_response=response_with_structured, + command=Command( + update={ + "structured_response": {"from": "command"}, + } + ), + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Model msg")])) + agent = create_agent(model=model, middleware=[OverrideMiddleware()]) + + # Two Commands both setting structured_response (a LastValue channel) + # in the same step raises InvalidUpdateError + with pytest.raises(InvalidUpdateError): + agent.invoke({"messages": [HumanMessage("Hi")]}) + + def test_command_with_custom_state_field(self) -> None: + """When command updates a custom field, model response messages are preserved.""" + + class CustomFieldMiddleware(AgentMiddleware): + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ExtendedModelResponse: + response = handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command(update={"custom_key": "custom_value"}), + ) + + class CustomState(AgentState): + custom_key: str + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")])) + agent = create_agent( + model=model, + middleware=[CustomFieldMiddleware()], + state_schema=CustomState, + ) + + result = agent.invoke({"messages": [HumanMessage("Hi")]}) + + assert result["messages"][-1].content == "Hello" + + +class TestCustomStateField: + """Test ExtendedModelResponse with custom state fields defined via state_schema.""" + + def test_custom_field_via_state_schema(self) -> None: + """Middleware updates a custom state field via ExtendedModelResponse.""" + + class MyState(AgentState): + summary: str + + class SummaryMiddleware(AgentMiddleware): + state_schema = MyState # type: ignore[assignment] + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ExtendedModelResponse: + response = handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command(update={"summary": "conversation summarized"}), + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")])) + agent = create_agent(model=model, middleware=[SummaryMiddleware()]) + + result = agent.invoke({"messages": [HumanMessage("Hi")]}) + + assert result["messages"][-1].content == "Hello" + + def test_no_command(self) -> None: + """ExtendedModelResponse with no command works like ModelResponse.""" + + class NoCommandMiddleware(AgentMiddleware): + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ExtendedModelResponse: + response = handler(request) + return ExtendedModelResponse( + model_response=response, + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")])) + agent = create_agent(model=model, middleware=[NoCommandMiddleware()]) + + result = agent.invoke({"messages": [HumanMessage("Hi")]}) + + assert len(result["messages"]) == 2 + assert result["messages"][1].content == "Hello" 
+ + +class TestBackwardsCompatibility: + """Test that existing ModelResponse and AIMessage returns still work.""" + + def test_model_response_return_unchanged(self) -> None: + """Existing middleware returning ModelResponse works identically.""" + + class PassthroughMiddleware(AgentMiddleware): + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + return handler(request) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")])) + agent = create_agent(model=model, middleware=[PassthroughMiddleware()]) + + result = agent.invoke({"messages": [HumanMessage("Hi")]}) + + assert len(result["messages"]) == 2 + assert result["messages"][1].content == "Hello" + + def test_ai_message_return_unchanged(self) -> None: + """Existing middleware returning AIMessage works identically.""" + + class ShortCircuitMiddleware(AgentMiddleware): + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> AIMessage: + return AIMessage(content="Short-circuited") + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Should not appear")])) + agent = create_agent(model=model, middleware=[ShortCircuitMiddleware()]) + + result = agent.invoke({"messages": [HumanMessage("Hi")]}) + + assert len(result["messages"]) == 2 + assert result["messages"][1].content == "Short-circuited" + + def test_no_middleware_unchanged(self) -> None: + """Agent without middleware works identically.""" + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")])) + agent = create_agent(model=model) + + result = agent.invoke({"messages": [HumanMessage("Hi")]}) + + assert len(result["messages"]) == 2 + assert result["messages"][1].content == "Hello" + + +class TestAsyncExtendedModelResponse: + """Test async variant of ExtendedModelResponse.""" + + async def test_async_command_adds_messages(self) -> None: + """awrap_model_call command adds messages alongside model response.""" + + class AsyncAddMiddleware(AgentMiddleware): + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ExtendedModelResponse: + response = await handler(request) + custom = HumanMessage(content="Async custom", id="async-custom") + return ExtendedModelResponse( + model_response=response, + command=Command(update={"messages": [custom]}), + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Async hello!")])) + agent = create_agent(model=model, middleware=[AsyncAddMiddleware()]) + + result = await agent.ainvoke({"messages": [HumanMessage(content="Hi")]}) + + # Both model response and command messages are present (additive) + messages = result["messages"] + assert len(messages) == 3 + assert messages[0].content == "Hi" + assert messages[1].content == "Async hello!" 
+        assert messages[2].content == "Async custom"
+
+    async def test_async_decorator_command(self) -> None:
+        """@wrap_model_call async decorator returns ExtendedModelResponse with command."""
+
+        @wrap_model_call
+        async def command_middleware(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+        ) -> ExtendedModelResponse:
+            response = await handler(request)
+            return ExtendedModelResponse(
+                model_response=response,
+                command=Command(
+                    update={
+                        "messages": [
+                            HumanMessage(content="Decorator msg", id="dec"),
+                        ]
+                    }
+                ),
+            )
+
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Async response")]))
+        agent = create_agent(model=model, middleware=[command_middleware])
+
+        result = await agent.ainvoke({"messages": [HumanMessage(content="Hi")]})
+
+        messages = result["messages"]
+        assert len(messages) == 3
+        assert messages[1].content == "Async response"
+        assert messages[2].content == "Decorator msg"
+
+
+class TestComposition:
+    """Test ExtendedModelResponse with composed middleware.
+
+    Key semantics: Commands are collected inner-first, then outer.
+    For non-reducer fields, multiple Commands setting the same key raise
+    InvalidUpdateError; non-conflicting keys are all applied.
+    For reducer fields (messages), all Commands are additive.
+    """
+
+    def test_outer_command_messages_added_alongside_model(self) -> None:
+        """Outer middleware's command messages are added alongside model messages."""
+        execution_order: list[str] = []
+
+        class OuterMiddleware(AgentMiddleware):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ExtendedModelResponse:
+                execution_order.append("outer-before")
+                response = handler(request)
+                execution_order.append("outer-after")
+                return ExtendedModelResponse(
+                    model_response=response,
+                    command=Command(
+                        update={"messages": [HumanMessage(content="Outer msg", id="outer-msg")]}
+                    ),
+                )
+
+        class InnerMiddleware(AgentMiddleware):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelResponse:
+                execution_order.append("inner-before")
+                response = handler(request)
+                execution_order.append("inner-after")
+                return response
+
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Composed")]))
+        agent = create_agent(
+            model=model,
+            middleware=[OuterMiddleware(), InnerMiddleware()],
+        )
+
+        result = agent.invoke({"messages": [HumanMessage("Hi")]})
+
+        # Execution order: outer wraps inner
+        assert execution_order == [
+            "outer-before",
+            "inner-before",
+            "inner-after",
+            "outer-after",
+        ]
+
+        # Model messages + outer command messages (additive)
+        messages = result["messages"]
+        assert len(messages) == 3
+        assert messages[0].content == "Hi"
+        assert messages[1].content == "Composed"
+        assert messages[2].content == "Outer msg"
+
+    def test_inner_command_propagated_through_composition(self) -> None:
+        """Inner middleware's ExtendedModelResponse command is propagated.
+
+        When inner middleware returns ExtendedModelResponse, its command is
+        captured before normalizing to ModelResponse at the composition boundary
+        and collected into the final result.
+ """ + + class OuterMiddleware(AgentMiddleware): + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + # Outer sees a ModelResponse from handler (inner's ExtendedModelResponse + # was normalized at the composition boundary) + response = handler(request) + assert isinstance(response, ModelResponse) + return response + + class InnerMiddleware(AgentMiddleware): + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ExtendedModelResponse: + response = handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command( + update={ + "messages": [ + HumanMessage(content="Inner msg", id="inner"), + ] + } + ), + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")])) + agent = create_agent( + model=model, + middleware=[OuterMiddleware(), InnerMiddleware()], + ) + + result = agent.invoke({"messages": [HumanMessage("Hi")]}) + + # Model messages + inner command messages (additive) + messages = result["messages"] + assert len(messages) == 3 + assert messages[0].content == "Hi" + assert messages[1].content == "Hello" + assert messages[2].content == "Inner msg" + + def test_non_reducer_key_conflict_raises(self) -> None: + """Multiple Commands setting the same non-reducer key raises. + + LastValue channels (like custom_key) can only receive one value per + step. Inner and outer both setting the same key is an error. + """ + + class MyState(AgentState): + custom_key: str + + class OuterMiddleware(AgentMiddleware): + state_schema = MyState # type: ignore[assignment] + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ExtendedModelResponse: + response = handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command( + update={ + "messages": [HumanMessage(content="Outer msg", id="outer")], + "custom_key": "outer_value", + } + ), + ) + + class InnerMiddleware(AgentMiddleware): + state_schema = MyState # type: ignore[assignment] + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ExtendedModelResponse: + response = handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command( + update={ + "messages": [HumanMessage(content="Inner msg", id="inner")], + "custom_key": "inner_value", + } + ), + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")])) + agent = create_agent( + model=model, + middleware=[OuterMiddleware(), InnerMiddleware()], + ) + + # Two Commands both setting custom_key (a LastValue channel) + # in the same step raises InvalidUpdateError + with pytest.raises(InvalidUpdateError): + agent.invoke({"messages": [HumanMessage("Hi")]}) + + def test_inner_state_preserved_when_outer_has_no_conflict(self) -> None: + """Inner's command keys are preserved when outer doesn't conflict.""" + + class MyState(AgentState): + inner_key: str + outer_key: str + + class OuterMiddleware(AgentMiddleware): + state_schema = MyState # type: ignore[assignment] + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ExtendedModelResponse: + response = handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command(update={"outer_key": "from_outer"}), + ) + + class InnerMiddleware(AgentMiddleware): + state_schema = MyState # type: 
ignore[assignment] + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ExtendedModelResponse: + response = handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command(update={"inner_key": "from_inner"}), + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")])) + agent = create_agent( + model=model, + middleware=[OuterMiddleware(), InnerMiddleware()], + ) + + result = agent.invoke({"messages": [HumanMessage("Hi")]}) + + # Both keys survive since there's no conflict + messages = result["messages"] + assert messages[-1].content == "Hello" + + def test_inner_command_retry_safe(self) -> None: + """When outer retries, only the last inner command is used.""" + call_count = 0 + + class MyState(AgentState): + attempt: str + + class OuterMiddleware(AgentMiddleware): + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + # Call handler twice (simulating retry) + handler(request) + return handler(request) + + class InnerMiddleware(AgentMiddleware): + state_schema = MyState # type: ignore[assignment] + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ExtendedModelResponse: + nonlocal call_count + call_count += 1 + response = handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command(update={"attempt": f"attempt_{call_count}"}), + ) + + model = GenericFakeChatModel( + messages=iter([AIMessage(content="First"), AIMessage(content="Second")]) + ) + agent = create_agent( + model=model, + middleware=[OuterMiddleware(), InnerMiddleware()], + ) + + result = agent.invoke({"messages": [HumanMessage("Hi")]}) + + # Only the last retry's inner state should survive + messages = result["messages"] + assert messages[-1].content == "Second" + + def test_decorator_returns_wrap_result(self) -> None: + """@wrap_model_call decorator can return ExtendedModelResponse with command.""" + + @wrap_model_call + def command_middleware( + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ExtendedModelResponse: + response = handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command( + update={ + "messages": [ + HumanMessage(content="From decorator", id="dec"), + ] + } + ), + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Model response")])) + agent = create_agent(model=model, middleware=[command_middleware]) + + result = agent.invoke({"messages": [HumanMessage("Hi")]}) + + messages = result["messages"] + assert len(messages) == 3 + assert messages[1].content == "Model response" + assert messages[2].content == "From decorator" + + def test_structured_response_preserved(self) -> None: + """ExtendedModelResponse preserves structured_response from ModelResponse.""" + + class StructuredMiddleware(AgentMiddleware): + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ExtendedModelResponse: + response = handler(request) + response_with_structured = ModelResponse( + result=response.result, + structured_response={"key": "value"}, + ) + return ExtendedModelResponse( + model_response=response_with_structured, + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")])) + agent = create_agent(model=model, middleware=[StructuredMiddleware()]) + + result = agent.invoke({"messages": 
[HumanMessage("Hi")]}) + + assert result.get("structured_response") == {"key": "value"} + messages = result["messages"] + assert len(messages) == 2 + assert messages[1].content == "Hello" + + +class TestAsyncComposition: + """Test async ExtendedModelResponse propagation through composed middleware.""" + + async def test_async_inner_command_propagated(self) -> None: + """Async: inner middleware's ExtendedModelResponse command is propagated.""" + + class OuterMiddleware(AgentMiddleware): + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelResponse: + response = await handler(request) + assert isinstance(response, ModelResponse) + return response + + class InnerMiddleware(AgentMiddleware): + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ExtendedModelResponse: + response = await handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command( + update={ + "messages": [ + HumanMessage(content="Inner msg", id="inner"), + ] + } + ), + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")])) + agent = create_agent( + model=model, + middleware=[OuterMiddleware(), InnerMiddleware()], + ) + + result = await agent.ainvoke({"messages": [HumanMessage("Hi")]}) + + # Model messages + inner command messages (additive) + messages = result["messages"] + assert len(messages) == 3 + assert messages[0].content == "Hi" + assert messages[1].content == "Hello" + assert messages[2].content == "Inner msg" + + async def test_async_both_commands_additive_messages(self) -> None: + """Async: both inner and outer command messages are added alongside model.""" + + class OuterMiddleware(AgentMiddleware): + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ExtendedModelResponse: + response = await handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command( + update={"messages": [HumanMessage(content="Outer msg", id="outer")]} + ), + ) + + class InnerMiddleware(AgentMiddleware): + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ExtendedModelResponse: + response = await handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command( + update={"messages": [HumanMessage(content="Inner msg", id="inner")]} + ), + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")])) + agent = create_agent( + model=model, + middleware=[OuterMiddleware(), InnerMiddleware()], + ) + + result = await agent.ainvoke({"messages": [HumanMessage("Hi")]}) + + # All messages additive: model + inner + outer + messages = result["messages"] + assert len(messages) == 4 + assert messages[0].content == "Hi" + assert messages[1].content == "Hello" + assert messages[2].content == "Inner msg" + assert messages[3].content == "Outer msg" + + async def test_async_inner_command_retry_safe(self) -> None: + """Async: when outer retries, only last inner command is used.""" + call_count = 0 + + class MyState(AgentState): + attempt: str + + class OuterMiddleware(AgentMiddleware): + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelResponse: + # Call handler twice (simulating retry) + await handler(request) + return await 
handler(request) + + class InnerMiddleware(AgentMiddleware): + state_schema = MyState # type: ignore[assignment] + + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ExtendedModelResponse: + nonlocal call_count + call_count += 1 + response = await handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command(update={"attempt": f"attempt_{call_count}"}), + ) + + model = GenericFakeChatModel( + messages=iter([AIMessage(content="First"), AIMessage(content="Second")]) + ) + agent = create_agent( + model=model, + middleware=[OuterMiddleware(), InnerMiddleware()], + ) + + result = await agent.ainvoke({"messages": [HumanMessage("Hi")]}) + + messages = result["messages"] + assert any(m.content == "Second" for m in messages) + + +class TestCommandGotoDisallowed: + """Test that Command goto raises NotImplementedError in wrap_model_call.""" + + def test_command_goto_raises_not_implemented(self) -> None: + """Command with goto in wrap_model_call raises NotImplementedError.""" + + class GotoMiddleware(AgentMiddleware): + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ExtendedModelResponse: + response = handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command(goto="__end__"), + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")])) + agent = create_agent(model=model, middleware=[GotoMiddleware()]) + + with pytest.raises(NotImplementedError, match="Command goto is not yet supported"): + agent.invoke({"messages": [HumanMessage(content="Hi")]}) + + async def test_async_command_goto_raises_not_implemented(self) -> None: + """Async: Command with goto in wrap_model_call raises NotImplementedError.""" + + class AsyncGotoMiddleware(AgentMiddleware): + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ExtendedModelResponse: + response = await handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command(goto="tools"), + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")])) + agent = create_agent(model=model, middleware=[AsyncGotoMiddleware()]) + + with pytest.raises(NotImplementedError, match="Command goto is not yet supported"): + await agent.ainvoke({"messages": [HumanMessage(content="Hi")]}) + + +class TestCommandResumeDisallowed: + """Test that Command resume raises NotImplementedError in wrap_model_call.""" + + def test_command_resume_raises_not_implemented(self) -> None: + """Command with resume in wrap_model_call raises NotImplementedError.""" + + class ResumeMiddleware(AgentMiddleware): + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ExtendedModelResponse: + response = handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command(resume="some_value"), + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")])) + agent = create_agent(model=model, middleware=[ResumeMiddleware()]) + + with pytest.raises(NotImplementedError, match="Command resume is not yet supported"): + agent.invoke({"messages": [HumanMessage(content="Hi")]}) + + async def test_async_command_resume_raises_not_implemented(self) -> None: + """Async: Command with resume in wrap_model_call raises NotImplementedError.""" + + class 
AsyncResumeMiddleware(AgentMiddleware): + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ExtendedModelResponse: + response = await handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command(resume="some_value"), + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")])) + agent = create_agent(model=model, middleware=[AsyncResumeMiddleware()]) + + with pytest.raises(NotImplementedError, match="Command resume is not yet supported"): + await agent.ainvoke({"messages": [HumanMessage(content="Hi")]}) + + +class TestCommandGraphDisallowed: + """Test that Command graph raises NotImplementedError in wrap_model_call.""" + + def test_command_graph_raises_not_implemented(self) -> None: + """Command with graph in wrap_model_call raises NotImplementedError.""" + + class GraphMiddleware(AgentMiddleware): + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ExtendedModelResponse: + response = handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command(graph=Command.PARENT, update={"messages": []}), + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")])) + agent = create_agent(model=model, middleware=[GraphMiddleware()]) + + with pytest.raises(NotImplementedError, match="Command graph is not yet supported"): + agent.invoke({"messages": [HumanMessage(content="Hi")]}) + + async def test_async_command_graph_raises_not_implemented(self) -> None: + """Async: Command with graph in wrap_model_call raises NotImplementedError.""" + + class AsyncGraphMiddleware(AgentMiddleware): + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ExtendedModelResponse: + response = await handler(request) + return ExtendedModelResponse( + model_response=response, + command=Command(graph=Command.PARENT, update={"messages": []}), + ) + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")])) + agent = create_agent(model=model, middleware=[AsyncGraphMiddleware()]) + + with pytest.raises(NotImplementedError, match="Command graph is not yet supported"): + await agent.ainvoke({"messages": [HumanMessage(content="Hi")]}) diff --git a/libs/langchain_v1/tests/unit_tests/chat_models/test_chat_models.py b/libs/langchain_v1/tests/unit_tests/chat_models/test_chat_models.py index 9f9f341fd63..b2ac9345a8f 100644 --- a/libs/langchain_v1/tests/unit_tests/chat_models/test_chat_models.py +++ b/libs/langchain_v1/tests/unit_tests/chat_models/test_chat_models.py @@ -313,6 +313,7 @@ def test_configurable_with_default() -> None: "default_headers": None, "model_kwargs": {}, "reuse_last_container": None, + "inference_geo": None, "streaming": False, "stream_usage": True, "output_version": None, diff --git a/libs/langchain_v1/uv.lock b/libs/langchain_v1/uv.lock index fc92e0beab6..54a3e356d91 100644 --- a/libs/langchain_v1/uv.lock +++ b/libs/langchain_v1/uv.lock @@ -173,7 +173,7 @@ wheels = [ [[package]] name = "anthropic" -version = "0.72.1" +version = "0.78.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -185,9 +185,9 @@ dependencies = [ { name = "sniffio" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/dd/f3/feb750a21461090ecf48bbebcaa261cd09003cc1d14e2fa9643ad59edd4d/anthropic-0.72.1.tar.gz", hash = 
"sha256:a6d1d660e1f4af91dddc732f340786d19acaffa1ae8e69442e56be5fa6539d51", size = 415395, upload-time = "2025-11-11T16:53:29.001Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/51/32849a48f9b1cfe80a508fd269b20bd8f0b1357c70ba092890fde5a6a10b/anthropic-0.78.0.tar.gz", hash = "sha256:55fd978ab9b049c61857463f4c4e9e092b24f892519c6d8078cee1713d8af06e", size = 509136, upload-time = "2026-02-05T17:52:04.986Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/51/05/d9d45edad1aa28330cea09a3b35e1590f7279f91bb5ab5237c70a0884ea3/anthropic-0.72.1-py3-none-any.whl", hash = "sha256:81e73cca55e8924776c8c4418003defe6bf9eaf0cd92beb94c8dbf537b95316f", size = 357373, upload-time = "2025-11-11T16:53:27.438Z" }, + { url = "https://files.pythonhosted.org/packages/3b/03/2f50931a942e5e13f80e24d83406714672c57964be593fc046d81369335b/anthropic-0.78.0-py3-none-any.whl", hash = "sha256:2a9887d2e99d1b0f9fe08857a1e9fe5d2d4030455dbf9ac65aab052e2efaeac4", size = 405485, upload-time = "2026-02-05T17:52:03.674Z" }, ] [[package]] @@ -1975,7 +1975,7 @@ typing = [ [package.metadata] requires-dist = [ - { name = "langchain-anthropic", marker = "extra == 'anthropic'" }, + { name = "langchain-anthropic", marker = "extra == 'anthropic'", editable = "../partners/anthropic" }, { name = "langchain-aws", marker = "extra == 'aws'" }, { name = "langchain-azure-ai", marker = "extra == 'azure-ai'" }, { name = "langchain-community", marker = "extra == 'community'" }, @@ -2028,16 +2028,51 @@ typing = [ [[package]] name = "langchain-anthropic" -version = "1.0.3" -source = { registry = "https://pypi.org/simple" } +version = "1.3.1" +source = { editable = "../partners/anthropic" } dependencies = [ { name = "anthropic" }, { name = "langchain-core" }, { name = "pydantic" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/92/6b/aaa770beea6f4ed4c3f5c75fd6d80ed5c82708aec15318c06d9379dd3543/langchain_anthropic-1.0.3.tar.gz", hash = "sha256:91083c5df82634602f6772989918108d9448fa0b7499a11434687198f5bf9aef", size = 680336, upload-time = "2025-11-12T15:58:58.54Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/3d/0499eeb10d333ea79e1c156b84d075e96cfde2fbc6a9ec9cbfa50ac3e47e/langchain_anthropic-1.0.3-py3-none-any.whl", hash = "sha256:0d4106111d57e19988e2976fb6c6e59b0c47ca7afb0f6a2f888362006f871871", size = 46608, upload-time = "2025-11-12T15:58:57.28Z" }, + +[package.metadata] +requires-dist = [ + { name = "anthropic", specifier = ">=0.75.0,<1.0.0" }, + { name = "langchain-core", editable = "../core" }, + { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, +] + +[package.metadata.requires-dev] +dev = [{ name = "langchain-core", editable = "../core" }] +lint = [{ name = "ruff", specifier = ">=0.13.1,<0.14.0" }] +test = [ + { name = "blockbuster", specifier = ">=1.5.5,<1.6" }, + { name = "defusedxml", specifier = ">=0.7.1,<1.0.0" }, + { name = "freezegun", specifier = ">=1.2.2,<2.0.0" }, + { name = "langchain", editable = "." 
}, + { name = "langchain-core", editable = "../core" }, + { name = "langchain-tests", editable = "../standard-tests" }, + { name = "langgraph-prebuilt", specifier = ">=0.7.0a2" }, + { name = "pytest", specifier = ">=7.3.0,<8.0.0" }, + { name = "pytest-asyncio", specifier = ">=0.21.1,<1.0.0" }, + { name = "pytest-mock", specifier = ">=3.10.0,<4.0.0" }, + { name = "pytest-retry", specifier = ">=1.7.0,<1.8.0" }, + { name = "pytest-socket", specifier = ">=0.7.0,<1.0.0" }, + { name = "pytest-timeout", specifier = ">=2.3.1,<3.0.0" }, + { name = "pytest-watcher", specifier = ">=0.3.4,<1.0.0" }, + { name = "pytest-xdist", specifier = ">=3.8.0,<4.0.0" }, + { name = "syrupy", specifier = ">=4.0.2,<5.0.0" }, + { name = "vcrpy", specifier = ">=8.0.0,<9.0.0" }, +] +test-integration = [ + { name = "langchain-core", editable = "../core" }, + { name = "requests", specifier = ">=2.32.3,<3.0.0" }, +] +typing = [ + { name = "langchain-core", editable = "../core" }, + { name = "mypy", specifier = ">=1.17.1,<2.0.0" }, + { name = "types-requests", specifier = ">=2.31.0,<3.0.0" }, ] [[package]]