feat: support state updates from wrap_model_call with command(s) (#35033)

Alternative to https://github.com/langchain-ai/langchain/pull/35024.
Paving the way for summarization in `wrap_model_call` (which requires
state updates).

---

Add `ExtendedModelResponse` dataclass that allows `wrap_model_call`
middleware to return a `Command` alongside the model response for
additional state updates.

```py
@dataclass
class ExtendedModelResponse(Generic[ResponseT]):
    model_response: ModelResponse[ResponseT]
    command: Command
```

## Motivation

Previously, `wrap_model_call` middleware could only return a
`ModelResponse` or `AIMessage` — there was no way to inject additional
state updates (e.g. custom state fields) from the model call middleware
layer. `ExtendedModelResponse` fills this gap by accepting an optional
`Command`.

This feature is needed by the summarization middleware, which needs to
track summarization trigger points calculated during `wrap_model_call`.

## Why `Command` instead of a plain `state_update` dict?

We chose `Command` rather than the raw `state_update: dict` approach
from the earlier iteration because `Command` is the established
LangGraph primitive for state updates from nodes. Using `Command` means:

- State updates flow through the graph's reducers (e.g. `add_messages`)
rather than being merged as raw dicts. This makes message updates
additive alongside the model response instead of replacing them.
- Consistency with `wrap_tool_call`, which already returns `Command`.
- Future-proof: as `Command` gains new capabilities (e.g. `goto`,
`send`), middleware can leverage them without API changes.
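The reducer behavior can be illustrated with a minimal sketch. This is plain Python standing in for LangGraph's channel machinery: `add_messages` here is a simplified stand-in that just appends (the real reducer also deduplicates by message `id`), and `apply_commands` is a hypothetical helper, not a LangGraph API.

```python
# Sketch: why routing updates through reducers matters.
# `add_messages` is a simplified stand-in for LangGraph's reducer:
# it appends new messages instead of replacing the list.

def add_messages(existing: list, update: list) -> list:
    return existing + update

def apply_commands(state: dict, commands: list[dict], reducers: dict) -> dict:
    # Each update is routed through the field's reducer if one exists;
    # otherwise the raw value overwrites (last writer wins).
    for update in commands:
        for key, value in update.items():
            if key in reducers:
                state[key] = reducers[key](state.get(key, []), value)
            else:
                state[key] = value
    return state

state = {"messages": ["human: hi"]}
commands = [
    {"messages": ["ai: hello"]},             # model response update
    {"messages": ["system: summary note"]},  # middleware Command update
]
state = apply_commands(state, commands, reducers={"messages": add_messages})
# Messages are additive, not replaced:
# state["messages"] == ["human: hi", "ai: hello", "system: summary note"]
```

With a raw dict merge, the second update would have overwritten `messages` entirely; routing through the reducer keeps every layer's messages.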

## Why keep `model_response` separate instead of using `Command`
directly?

The model node needs to distinguish the model's actual response
(messages + structured output) from supplementary middleware state
updates. If middleware returned only a `Command`, there would be no
clean way to extract the `ModelResponse` for structured output handling,
response validation, and the core model-to-tools routing logic. Keeping
`model_response` explicit preserves a clear boundary between "what the
model said" and "what middleware wants to update."

Also, to avoid a breaking change, the `handler` passed to
`wrap_model_call` must always return a `ModelResponse`. There is no easy
way to preserve that contract if the response is folded into a `Command`.

A further benefit of the `ExtendedModelResponse` structure is that it is
extensible: additional metadata fields can be added later without
changing the return contract.

## Composition

When multiple middleware layers return `ExtendedModelResponse`, their
commands compose naturally:

- **Inner commands propagate outward:** At composition boundaries,
`ExtendedModelResponse` is unwrapped to its underlying `ModelResponse`
so outer middleware always sees a plain `ModelResponse` from
`handler()`. The inner command is captured and accumulated.
- **Commands are applied through reducers:** Each `Command` becomes a
separate state update applied through the graph's reducers. For
messages, this means they're additive (via `add_messages`), not
replacing.
- **Outer wins on conflicts:** For non-reducer state fields, commands
are applied inner-first then outer, so the outermost middleware's value
takes precedence on conflicting keys.
- **Retry-safe:** When outer middleware retries by calling `handler()`
again, accumulated inner commands are cleared and re-collected from the
fresh call.

```python
class Outer(AgentMiddleware):
    def wrap_model_call(self, request, handler):
        response = handler(request)  # sees ModelResponse, not ExtendedModelResponse
        return ExtendedModelResponse(
            model_response=response,
            command=Command(update={"outer_key": "val"}),
        )

class Inner(AgentMiddleware):
    def wrap_model_call(self, request, handler):
        response = handler(request)
        return ExtendedModelResponse(
            model_response=response,
            command=Command(update={"inner_key": "val"}),
        )

# Final state merges both commands: {"inner_key": "val", "outer_key": "val"}
```

## Backwards compatibility

Fully backwards compatible. The `ModelCallResult` type alias is widened
from `ModelResponse | AIMessage` to `ModelResponse | AIMessage |
ExtendedModelResponse`, but existing middleware returning
`ModelResponse` or `AIMessage` continues to work identically.

## Internals

- `model_node` / `amodel_node` now return `list[Command]` instead of
`dict[str, Any]`
- `_build_commands` converts the model response + accumulated middleware
commands into a list of `Command` objects for LangGraph
- `_ComposedExtendedModelResponse` is the internal type that accumulates
commands across layers during composition
Commit 8767a462ca by Sydney Runkle, committed via GitHub on
2026-02-06 07:28:04 -05:00 (parent 273d282a29).
9 changed files with 1263 additions and 181 deletions


@@ -270,6 +270,7 @@ def test_configurable_with_default() -> None:
"default_headers": None,
"model_kwargs": {},
"reuse_last_container": None,
"inference_geo": None,
"streaming": False,
"stream_usage": True,
"output_version": None,


@@ -3,10 +3,12 @@
from __future__ import annotations
import itertools
from dataclasses import dataclass, field
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Generic,
cast,
get_args,
get_origin,
@@ -27,6 +29,7 @@ from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ExtendedModelResponse,
JumpTo,
ModelRequest,
ModelResponse,
@@ -49,6 +52,23 @@ from langchain.agents.structured_output import (
)
from langchain.chat_models import init_chat_model
@dataclass
class _ComposedExtendedModelResponse(Generic[ResponseT]):
"""Internal result from composed ``wrap_model_call`` middleware.
Unlike ``ExtendedModelResponse`` (user-facing, single command), this holds the
full list of commands accumulated across all middleware layers during
composition.
"""
model_response: ModelResponse[ResponseT]
"""The underlying model response."""
commands: list[Command[Any]] = field(default_factory=list)
"""Commands accumulated from all middleware layers (inner-first, then outer)."""
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Sequence
@@ -61,6 +81,27 @@ if TYPE_CHECKING:
from langchain.agents.middleware.types import ToolCallRequest, ToolCallWrapper
_ModelCallHandler = Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
ModelResponse | AIMessage | ExtendedModelResponse,
]
_ComposedModelCallHandler = Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
_ComposedExtendedModelResponse,
]
_AsyncModelCallHandler = Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
Awaitable[ModelResponse | AIMessage | ExtendedModelResponse],
]
_ComposedAsyncModelCallHandler = Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
Awaitable[_ComposedExtendedModelResponse],
]
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
DYNAMIC_TOOL_ERROR_TEMPLATE = """
@@ -102,31 +143,72 @@ FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [
]
def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
"""Normalize middleware return value to ModelResponse."""
def _normalize_to_model_response(
result: ModelResponse | AIMessage | ExtendedModelResponse,
) -> ModelResponse:
"""Normalize middleware return value to ModelResponse.
At inner composition boundaries, ``ExtendedModelResponse`` is unwrapped to its
underlying ``ModelResponse`` so that inner middleware always sees ``ModelResponse``
from the handler.
"""
if isinstance(result, AIMessage):
return ModelResponse(result=[result], structured_response=None)
if isinstance(result, ExtendedModelResponse):
return result.model_response
return result
def _build_commands(
model_response: ModelResponse,
middleware_commands: list[Command[Any]] | None = None,
) -> list[Command[Any]]:
"""Build a list of Commands from a model response and middleware commands.
The first Command contains the model response state (messages and optional
structured_response). Middleware commands are appended as-is.
Args:
model_response: The model response containing messages and optional
structured output.
middleware_commands: Commands accumulated from middleware layers during
composition (inner-first ordering).
Returns:
List of ``Command`` objects ready to be returned from a model node.
"""
state: dict[str, Any] = {"messages": model_response.result}
if model_response.structured_response is not None:
state["structured_response"] = model_response.structured_response
for cmd in middleware_commands or []:
if cmd.goto:
msg = (
"Command goto is not yet supported in wrap_model_call middleware. "
"Use the jump_to state field with before_model/after_model hooks instead."
)
raise NotImplementedError(msg)
if cmd.resume:
msg = "Command resume is not yet supported in wrap_model_call middleware."
raise NotImplementedError(msg)
if cmd.graph:
msg = "Command graph is not yet supported in wrap_model_call middleware."
raise NotImplementedError(msg)
commands: list[Command[Any]] = [Command(update=state)]
commands.extend(middleware_commands or [])
return commands
def _chain_model_call_handlers(
handlers: Sequence[
Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
ModelResponse | AIMessage,
]
],
) -> (
Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
ModelResponse,
]
| None
):
"""Compose multiple `wrap_model_call` handlers into single middleware stack.
handlers: Sequence[_ModelCallHandler[ContextT]],
) -> _ComposedModelCallHandler[ContextT] | None:
"""Compose multiple ``wrap_model_call`` handlers into single middleware stack.
Composes handlers so first in list becomes outermost layer. Each handler receives a
handler callback to execute inner layers.
handler callback to execute inner layers. Commands from each layer are accumulated
into a list (inner-first, then outer) without merging.
Args:
handlers: List of handlers.
@@ -134,110 +216,90 @@ def _chain_model_call_handlers(
First handler wraps all others.
Returns:
Composed handler, or `None` if handlers empty.
Example:
```python
# handlers=[auth, retry] means: auth wraps retry
# Flow: auth calls retry, retry calls base handler
def auth(req, state, runtime, handler):
try:
return handler(req)
except UnauthorizedError:
refresh_token()
return handler(req)
def retry(req, state, runtime, handler):
for attempt in range(3):
try:
return handler(req)
except Exception:
if attempt == 2:
raise
handler = _chain_model_call_handlers([auth, retry])
```
Composed handler returning ``_ComposedExtendedModelResponse``,
or ``None`` if handlers empty.
"""
if not handlers:
return None
def _to_composed_result(
result: ModelResponse | AIMessage | ExtendedModelResponse | _ComposedExtendedModelResponse,
extra_commands: list[Command[Any]] | None = None,
) -> _ComposedExtendedModelResponse:
"""Normalize any handler result to _ComposedExtendedModelResponse."""
commands: list[Command[Any]] = list(extra_commands or [])
if isinstance(result, _ComposedExtendedModelResponse):
commands.extend(result.commands)
model_response = result.model_response
elif isinstance(result, ExtendedModelResponse):
model_response = result.model_response
if result.command is not None:
commands.append(result.command)
else:
model_response = _normalize_to_model_response(result)
return _ComposedExtendedModelResponse(model_response=model_response, commands=commands)
if len(handlers) == 1:
# Single handler - wrap to normalize output
single_handler = handlers[0]
def normalized_single(
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse],
) -> ModelResponse:
result = single_handler(request, handler)
return _normalize_to_model_response(result)
) -> _ComposedExtendedModelResponse:
return _to_composed_result(single_handler(request, handler))
return normalized_single
def compose_two(
outer: Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
ModelResponse | AIMessage,
],
inner: Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
ModelResponse | AIMessage,
],
) -> Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
ModelResponse,
]:
outer: _ModelCallHandler[ContextT] | _ComposedModelCallHandler[ContextT],
inner: _ModelCallHandler[ContextT] | _ComposedModelCallHandler[ContextT],
) -> _ComposedModelCallHandler[ContextT]:
"""Compose two handlers where outer wraps inner."""
def composed(
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse],
) -> ModelResponse:
# Create a wrapper that calls inner with the base handler and normalizes
) -> _ComposedExtendedModelResponse:
# Closure variable to capture inner's commands before normalizing
accumulated_commands: list[Command[Any]] = []
def inner_handler(req: ModelRequest[ContextT]) -> ModelResponse:
# Clear on each call for retry safety
accumulated_commands.clear()
inner_result = inner(req, handler)
if isinstance(inner_result, _ComposedExtendedModelResponse):
accumulated_commands.extend(inner_result.commands)
return inner_result.model_response
if isinstance(inner_result, ExtendedModelResponse):
if inner_result.command is not None:
accumulated_commands.append(inner_result.command)
return inner_result.model_response
return _normalize_to_model_response(inner_result)
# Call outer with the wrapped inner as its handler and normalize
outer_result = outer(request, inner_handler)
return _normalize_to_model_response(outer_result)
return _to_composed_result(
outer_result,
extra_commands=accumulated_commands or None,
)
return composed
# Compose right-to-left: outer(inner(innermost(handler)))
result = handlers[-1]
for handler in reversed(handlers[:-1]):
result = compose_two(handler, result)
composed_handler = compose_two(handlers[-2], handlers[-1])
for h in reversed(handlers[:-2]):
composed_handler = compose_two(h, composed_handler)
# Wrap to ensure final return type is exactly ModelResponse
def final_normalized(
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse],
) -> ModelResponse:
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
final_result = result(request, handler)
return _normalize_to_model_response(final_result)
return final_normalized
return composed_handler
def _chain_async_model_call_handlers(
handlers: Sequence[
Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
Awaitable[ModelResponse | AIMessage],
]
],
) -> (
Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
Awaitable[ModelResponse],
]
| None
):
"""Compose multiple async `wrap_model_call` handlers into single middleware stack.
handlers: Sequence[_AsyncModelCallHandler[ContextT]],
) -> _ComposedAsyncModelCallHandler[ContextT] | None:
"""Compose multiple async ``wrap_model_call`` handlers into single middleware stack.
Commands from each layer are accumulated into a list (inner-first, then outer)
without merging.
Args:
handlers: List of async handlers.
@@ -245,69 +307,81 @@ def _chain_async_model_call_handlers(
First handler wraps all others.
Returns:
Composed async handler, or `None` if handlers empty.
Composed async handler returning ``_ComposedExtendedModelResponse``,
or ``None`` if handlers empty.
"""
if not handlers:
return None
def _to_composed_result(
result: ModelResponse | AIMessage | ExtendedModelResponse | _ComposedExtendedModelResponse,
extra_commands: list[Command[Any]] | None = None,
) -> _ComposedExtendedModelResponse:
"""Normalize any handler result to _ComposedExtendedModelResponse."""
commands: list[Command[Any]] = list(extra_commands or [])
if isinstance(result, _ComposedExtendedModelResponse):
commands.extend(result.commands)
model_response = result.model_response
elif isinstance(result, ExtendedModelResponse):
model_response = result.model_response
if result.command is not None:
commands.append(result.command)
else:
model_response = _normalize_to_model_response(result)
return _ComposedExtendedModelResponse(model_response=model_response, commands=commands)
if len(handlers) == 1:
# Single handler - wrap to normalize output
single_handler = handlers[0]
async def normalized_single(
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]],
) -> ModelResponse:
result = await single_handler(request, handler)
return _normalize_to_model_response(result)
) -> _ComposedExtendedModelResponse:
return _to_composed_result(await single_handler(request, handler))
return normalized_single
def compose_two(
outer: Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
Awaitable[ModelResponse | AIMessage],
],
inner: Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
Awaitable[ModelResponse | AIMessage],
],
) -> Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
Awaitable[ModelResponse],
]:
outer: _AsyncModelCallHandler[ContextT] | _ComposedAsyncModelCallHandler[ContextT],
inner: _AsyncModelCallHandler[ContextT] | _ComposedAsyncModelCallHandler[ContextT],
) -> _ComposedAsyncModelCallHandler[ContextT]:
"""Compose two async handlers where outer wraps inner."""
async def composed(
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]],
) -> ModelResponse:
# Create a wrapper that calls inner with the base handler and normalizes
) -> _ComposedExtendedModelResponse:
# Closure variable to capture inner's commands before normalizing
accumulated_commands: list[Command[Any]] = []
async def inner_handler(req: ModelRequest[ContextT]) -> ModelResponse:
# Clear on each call for retry safety
accumulated_commands.clear()
inner_result = await inner(req, handler)
if isinstance(inner_result, _ComposedExtendedModelResponse):
accumulated_commands.extend(inner_result.commands)
return inner_result.model_response
if isinstance(inner_result, ExtendedModelResponse):
if inner_result.command is not None:
accumulated_commands.append(inner_result.command)
return inner_result.model_response
return _normalize_to_model_response(inner_result)
# Call outer with the wrapped inner as its handler and normalize
outer_result = await outer(request, inner_handler)
return _normalize_to_model_response(outer_result)
return _to_composed_result(
outer_result,
extra_commands=accumulated_commands or None,
)
return composed
# Compose right-to-left: outer(inner(innermost(handler)))
result = handlers[-1]
for handler in reversed(handlers[:-1]):
result = compose_two(handler, result)
composed_handler = compose_two(handlers[-2], handlers[-1])
for h in reversed(handlers[:-2]):
composed_handler = compose_two(h, composed_handler)
# Wrap to ensure final return type is exactly ModelResponse
async def final_normalized(
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]],
) -> ModelResponse:
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
final_result = await result(request, handler)
return _normalize_to_model_response(final_result)
return final_normalized
return composed_handler
def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
@@ -1165,7 +1239,7 @@ def create_agent(
structured_response=structured_response,
)
def model_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> dict[str, Any]:
def model_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> list[Command[Any]]:
"""Sync model request handler with sequential middleware processing."""
request = ModelRequest(
model=model,
@@ -1179,18 +1253,11 @@ def create_agent(
)
if wrap_model_call_handler is None:
# No handlers - execute directly
response = _execute_model_sync(request)
else:
# Call composed handler with base handler
response = wrap_model_call_handler(request, _execute_model_sync)
model_response = _execute_model_sync(request)
return _build_commands(model_response)
# Extract state updates from ModelResponse
state_updates = {"messages": response.result}
if response.structured_response is not None:
state_updates["structured_response"] = response.structured_response
return state_updates
result = wrap_model_call_handler(request, _execute_model_sync)
return _build_commands(result.model_response, result.commands)
async def _execute_model_async(request: ModelRequest[ContextT]) -> ModelResponse:
"""Execute model asynchronously and return response.
@@ -1220,7 +1287,7 @@ def create_agent(
structured_response=structured_response,
)
async def amodel_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> dict[str, Any]:
async def amodel_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> list[Command[Any]]:
"""Async model request handler with sequential middleware processing."""
request = ModelRequest(
model=model,
@@ -1234,18 +1301,11 @@ def create_agent(
)
if awrap_model_call_handler is None:
# No async handlers - execute directly
response = await _execute_model_async(request)
else:
# Call composed async handler with base handler
response = await awrap_model_call_handler(request, _execute_model_async)
model_response = await _execute_model_async(request)
return _build_commands(model_response)
# Extract state updates from ModelResponse
state_updates = {"messages": response.result}
if response.structured_response is not None:
state_updates["structured_response"] = response.structured_response
return state_updates
result = await awrap_model_call_handler(request, _execute_model_async)
return _build_commands(result.model_response, result.commands)
# Use sync or async based on model capabilities
graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False))


@@ -26,6 +26,8 @@ from langchain.agents.middleware.tool_selection import LLMToolSelectorMiddleware
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ExtendedModelResponse,
ModelCallResult,
ModelRequest,
ModelResponse,
ToolCallRequest,
@@ -46,6 +48,7 @@ __all__ = [
"CodexSandboxExecutionPolicy",
"ContextEditingMiddleware",
"DockerExecutionPolicy",
"ExtendedModelResponse",
"FilesystemFileSearchMiddleware",
"HostExecutionPolicy",
"HumanInTheLoopMiddleware",
@@ -53,6 +56,7 @@ __all__ = [
"LLMToolEmulator",
"LLMToolSelectorMiddleware",
"ModelCallLimitMiddleware",
"ModelCallResult",
"ModelFallbackMiddleware",
"ModelRequest",
"ModelResponse",


@@ -19,8 +19,6 @@ from typing import (
if TYPE_CHECKING:
from collections.abc import Awaitable
from langgraph.types import Command
# Needed as top level import for Pydantic schema generation on AgentState
import warnings
from typing import TypeAlias
@@ -42,6 +40,7 @@ if TYPE_CHECKING:
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.runtime import Runtime
from langgraph.types import Command
from langchain.agents.structured_output import ResponseFormat
@@ -49,6 +48,8 @@ __all__ = [
"AgentMiddleware",
"AgentState",
"ContextT",
"ExtendedModelResponse",
"ModelCallResult",
"ModelRequest",
"ModelResponse",
"OmitFromSchema",
@@ -285,14 +286,43 @@ class ModelResponse(Generic[ResponseT]):
"""Parsed structured output if `response_format` was specified, `None` otherwise."""
# Type alias for middleware return type - allows returning either full response or just AIMessage
ModelCallResult: TypeAlias = "ModelResponse[ResponseT] | AIMessage"
@dataclass
class ExtendedModelResponse(Generic[ResponseT]):
"""Model response with an optional 'Command' from 'wrap_model_call' middleware.
Use this to return a 'Command' alongside the model response from a
'wrap_model_call' handler. The command is applied as an additional state
update after the model node completes, using the graph's reducers (e.g.
'add_messages' for the 'messages' key).
Because each 'Command' is applied through the reducer, messages in the
command are **added alongside** the model response messages rather than
replacing them. For non-reducer state fields, later commands overwrite
earlier ones (outermost middleware wins over inner).
Type Parameters:
ResponseT: The type of the structured response. Defaults to 'Any' if not specified.
"""
model_response: ModelResponse[ResponseT]
"""The underlying model response."""
command: Command[Any] | None = None
"""Optional command to apply as an additional state update."""
ModelCallResult: TypeAlias = (
"ModelResponse[ResponseT] | AIMessage | ExtendedModelResponse[ResponseT]"
)
"""`TypeAlias` for model call handler return value.
Middleware can return either:
- `ModelResponse`: Full response with messages and optional structured output
- `AIMessage`: Simplified return for simple use cases
- `ExtendedModelResponse`: Response with an optional `Command` for additional state updates
`goto`, `resume`, and `graph` are not yet supported on these commands.
A `NotImplementedError` will be raised if you try to use them.
"""
@@ -449,7 +479,7 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]):
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> ModelResponse[ResponseT] | AIMessage:
) -> ModelResponse[ResponseT] | AIMessage | ExtendedModelResponse[ResponseT]:
"""Intercept and control model execution via handler callback.
Async version is `awrap_model_call`
@@ -544,7 +574,7 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]):
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> ModelResponse[ResponseT] | AIMessage:
) -> ModelResponse[ResponseT] | AIMessage | ExtendedModelResponse[ResponseT]:
"""Intercept and control async model execution via handler callback.
The handler callback executes the model request and returns a `ModelResponse`.


@@ -99,6 +99,7 @@ langchain-core = { path = "../core", editable = true }
langchain-tests = { path = "../standard-tests", editable = true }
langchain-text-splitters = { path = "../text-splitters", editable = true }
langchain-openai = { path = "../partners/openai", editable = true }
langchain-anthropic = { path = "../partners/anthropic", editable = true }
[tool.ruff]
line-length = 100


@@ -5,10 +5,11 @@ from typing import Any, TypedDict, cast
from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime
from langgraph.types import Command
from langchain.agents import AgentState
from langchain.agents.factory import _chain_model_call_handlers
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain.agents.factory import _chain_model_call_handlers, _ComposedExtendedModelResponse
from langchain.agents.middleware.types import ExtendedModelResponse, ModelRequest, ModelResponse
def create_test_request(**kwargs: Any) -> ModelRequest:
@@ -88,9 +89,41 @@ class TestChainModelCallHandlers:
"inner-after",
"outer-after",
]
# Result is now ModelResponse
assert isinstance(result, ModelResponse)
assert result.result[0].content == "test"
# Outermost result is always _ComposedExtendedModelResponse
assert isinstance(result, _ComposedExtendedModelResponse)
assert result.model_response.result[0].content == "test"
def test_two_handlers_with_commands(self) -> None:
"""Test that commands from inner and outer are collected correctly."""
def outer(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ExtendedModelResponse:
response = handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(update={"outer_key": "outer_val"}),
)
def inner(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ExtendedModelResponse:
response = handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(update={"inner_key": "inner_val"}),
)
composed = _chain_model_call_handlers([outer, inner])
assert composed is not None
result = composed(create_test_request(), create_mock_base_handler())
assert isinstance(result, _ComposedExtendedModelResponse)
# Commands are collected: inner first, then outer
assert len(result.commands) == 2
assert result.commands[0].update == {"inner_key": "inner_val"}
assert result.commands[1].update == {"outer_key": "outer_val"}
def test_three_handlers_composition(self) -> None:
"""Test composition of three handlers."""
@@ -134,8 +167,8 @@ class TestChainModelCallHandlers:
"second-after",
"first-after",
]
assert isinstance(result, ModelResponse)
assert result.result[0].content == "test"
assert isinstance(result, _ComposedExtendedModelResponse)
assert result.model_response.result[0].content == "test"
def test_inner_handler_retry(self) -> None:
"""Test inner handler retrying before outer sees response."""
@@ -173,8 +206,8 @@ class TestChainModelCallHandlers:
result = composed(create_test_request(), mock_base_handler)
assert inner_attempts == [0, 1, 2]
assert isinstance(result, ModelResponse)
assert result.result[0].content == "success"
assert isinstance(result, _ComposedExtendedModelResponse)
assert result.model_response.result[0].content == "success"
def test_error_to_success_conversion(self) -> None:
"""Test handler converting error to success response."""
@@ -202,10 +235,10 @@ class TestChainModelCallHandlers:
result = composed(create_test_request(), mock_base_handler)
# AIMessage was automatically converted to ModelResponse
assert isinstance(result, ModelResponse)
assert result.result[0].content == "Fallback response"
assert result.structured_response is None
# AIMessage was automatically normalized into ExtendedModelResponse
assert isinstance(result, _ComposedExtendedModelResponse)
assert result.model_response.result[0].content == "Fallback response"
assert result.model_response.structured_response is None
def test_request_modification(self) -> None:
"""Test handlers modifying the request."""
@@ -231,8 +264,8 @@ class TestChainModelCallHandlers:
result = composed(create_test_request(), create_mock_base_handler(content="response"))
assert requests_seen == ["Added by outer"]
assert isinstance(result, ModelResponse)
assert result.result[0].content == "response"
assert isinstance(result, _ComposedExtendedModelResponse)
assert result.model_response.result[0].content == "response"
def test_composition_preserves_state_and_runtime(self) -> None:
"""Test that state and runtime are passed through composition."""
@@ -273,8 +306,8 @@ class TestChainModelCallHandlers:
# Both handlers should see same state and runtime
assert state_values == [("outer", test_state), ("inner", test_state)]
assert runtime_values == [("outer", test_runtime), ("inner", test_runtime)]
assert isinstance(result, ModelResponse)
assert result.result[0].content == "test"
assert isinstance(result, _ComposedExtendedModelResponse)
assert result.model_response.result[0].content == "test"
def test_multiple_yields_in_retry_loop(self) -> None:
"""Test handler that retries multiple times."""
@@ -312,5 +345,5 @@ class TestChainModelCallHandlers:
# Outer called once, inner retried so base handler called twice
assert call_count["value"] == 1
assert attempt["value"] == 2
assert isinstance(result, ModelResponse)
assert result.result[0].content == "ok"
assert isinstance(result, _ComposedExtendedModelResponse)
assert result.model_response.result[0].content == "ok"


@@ -0,0 +1,917 @@
"""Unit tests for ExtendedModelResponse command support in wrap_model_call.
Tests that wrap_model_call middleware can return ExtendedModelResponse to provide
a Command alongside the model response. Commands are applied as separate state
updates through graph reducers (e.g. add_messages for messages).
"""
from collections.abc import Awaitable, Callable
import pytest
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.errors import InvalidUpdateError
from langgraph.types import Command
from langchain.agents import AgentState, create_agent
from langchain.agents.middleware.types import (
AgentMiddleware,
ExtendedModelResponse,
ModelRequest,
ModelResponse,
wrap_model_call,
)
class TestBasicCommand:
"""Test basic ExtendedModelResponse functionality with Command."""
def test_command_messages_added_alongside_model_messages(self) -> None:
"""Command messages are added alongside model response messages (additive)."""
class AddMessagesMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
response = handler(request)
custom_msg = HumanMessage(content="Custom message", id="custom")
return ExtendedModelResponse(
model_response=response,
command=Command(update={"messages": [custom_msg]}),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
agent = create_agent(model=model, middleware=[AddMessagesMiddleware()])
result = agent.invoke({"messages": [HumanMessage(content="Hi")]})
# Both model response AND command messages appear (additive via add_messages)
messages = result["messages"]
assert len(messages) == 3
assert messages[0].content == "Hi"
assert messages[1].content == "Hello!"
assert messages[2].content == "Custom message"
def test_command_with_extra_messages_and_model_response(self) -> None:
"""Middleware can add extra messages via command alongside model messages."""
class ExtraMessagesMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
response = handler(request)
summary = HumanMessage(content="Summary", id="summary")
return ExtendedModelResponse(
model_response=response,
command=Command(update={"messages": [summary]}),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
agent = create_agent(model=model, middleware=[ExtraMessagesMiddleware()])
result = agent.invoke({"messages": [HumanMessage(content="Hi")]})
messages = result["messages"]
assert len(messages) == 3
assert messages[0].content == "Hi"
assert messages[1].content == "Hello!"
assert messages[2].content == "Summary"
def test_command_structured_response_conflicts_with_model_response(self) -> None:
"""Command and model response both setting structured_response raises."""
class OverrideMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
response = handler(request)
response_with_structured = ModelResponse(
result=response.result,
structured_response={"from": "model"},
)
return ExtendedModelResponse(
model_response=response_with_structured,
command=Command(
update={
"structured_response": {"from": "command"},
}
),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Model msg")]))
agent = create_agent(model=model, middleware=[OverrideMiddleware()])
# Two Commands both setting structured_response (a LastValue channel)
# in the same step raises InvalidUpdateError
with pytest.raises(InvalidUpdateError):
agent.invoke({"messages": [HumanMessage("Hi")]})
def test_command_with_custom_state_field(self) -> None:
"""When command updates a custom field, model response messages are preserved."""
class CustomFieldMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
response = handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(update={"custom_key": "custom_value"}),
)
class CustomState(AgentState):
custom_key: str
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
agent = create_agent(
model=model,
middleware=[CustomFieldMiddleware()],
state_schema=CustomState,
)
result = agent.invoke({"messages": [HumanMessage("Hi")]})
assert result["messages"][-1].content == "Hello"
class TestCustomStateField:
"""Test ExtendedModelResponse with custom state fields defined via state_schema."""
def test_custom_field_via_state_schema(self) -> None:
"""Middleware updates a custom state field via ExtendedModelResponse."""
class MyState(AgentState):
summary: str
class SummaryMiddleware(AgentMiddleware):
state_schema = MyState # type: ignore[assignment]
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
response = handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(update={"summary": "conversation summarized"}),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
agent = create_agent(model=model, middleware=[SummaryMiddleware()])
result = agent.invoke({"messages": [HumanMessage("Hi")]})
assert result["messages"][-1].content == "Hello"
def test_no_command(self) -> None:
"""ExtendedModelResponse with no command works like ModelResponse."""
class NoCommandMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
response = handler(request)
return ExtendedModelResponse(
model_response=response,
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
agent = create_agent(model=model, middleware=[NoCommandMiddleware()])
result = agent.invoke({"messages": [HumanMessage("Hi")]})
assert len(result["messages"]) == 2
assert result["messages"][1].content == "Hello"
class TestBackwardsCompatibility:
"""Test that existing ModelResponse and AIMessage returns still work."""
def test_model_response_return_unchanged(self) -> None:
"""Existing middleware returning ModelResponse works identically."""
class PassthroughMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
return handler(request)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
agent = create_agent(model=model, middleware=[PassthroughMiddleware()])
result = agent.invoke({"messages": [HumanMessage("Hi")]})
assert len(result["messages"]) == 2
assert result["messages"][1].content == "Hello"
def test_ai_message_return_unchanged(self) -> None:
"""Existing middleware returning AIMessage works identically."""
class ShortCircuitMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> AIMessage:
return AIMessage(content="Short-circuited")
model = GenericFakeChatModel(messages=iter([AIMessage(content="Should not appear")]))
agent = create_agent(model=model, middleware=[ShortCircuitMiddleware()])
result = agent.invoke({"messages": [HumanMessage("Hi")]})
assert len(result["messages"]) == 2
assert result["messages"][1].content == "Short-circuited"
def test_no_middleware_unchanged(self) -> None:
"""Agent without middleware works identically."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
agent = create_agent(model=model)
result = agent.invoke({"messages": [HumanMessage("Hi")]})
assert len(result["messages"]) == 2
assert result["messages"][1].content == "Hello"
class TestAsyncExtendedModelResponse:
"""Test async variant of ExtendedModelResponse."""
async def test_async_command_adds_messages(self) -> None:
"""awrap_model_call command adds messages alongside model response."""
class AsyncAddMiddleware(AgentMiddleware):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ExtendedModelResponse:
response = await handler(request)
custom = HumanMessage(content="Async custom", id="async-custom")
return ExtendedModelResponse(
model_response=response,
command=Command(update={"messages": [custom]}),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Async hello!")]))
agent = create_agent(model=model, middleware=[AsyncAddMiddleware()])
result = await agent.ainvoke({"messages": [HumanMessage(content="Hi")]})
# Both model response and command messages are present (additive)
messages = result["messages"]
assert len(messages) == 3
assert messages[0].content == "Hi"
assert messages[1].content == "Async hello!"
assert messages[2].content == "Async custom"
async def test_async_decorator_command(self) -> None:
"""@wrap_model_call async decorator returns ExtendedModelResponse with command."""
@wrap_model_call
async def command_middleware(
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ExtendedModelResponse:
response = await handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(
update={
"messages": [
HumanMessage(content="Decorator msg", id="dec"),
]
}
),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Async response")]))
agent = create_agent(model=model, middleware=[command_middleware])
result = await agent.ainvoke({"messages": [HumanMessage(content="Hi")]})
messages = result["messages"]
assert len(messages) == 3
assert messages[1].content == "Async response"
assert messages[2].content == "Decorator msg"
class TestComposition:
"""Test ExtendedModelResponse with composed middleware.
Key semantics: Commands are collected inner-first, then outer.
For reducer fields (messages), all Commands are additive.
For non-reducer (LastValue) fields, two Commands writing the same key
in the same step raise InvalidUpdateError.
"""
def test_outer_command_messages_added_alongside_model(self) -> None:
"""Outer middleware's command messages are added alongside model messages."""
execution_order: list[str] = []
class OuterMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
execution_order.append("outer-before")
response = handler(request)
execution_order.append("outer-after")
return ExtendedModelResponse(
model_response=response,
command=Command(
update={"messages": [HumanMessage(content="Outer msg", id="outer-msg")]}
),
)
class InnerMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
execution_order.append("inner-before")
response = handler(request)
execution_order.append("inner-after")
return response
model = GenericFakeChatModel(messages=iter([AIMessage(content="Composed")]))
agent = create_agent(
model=model,
middleware=[OuterMiddleware(), InnerMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage("Hi")]})
# Execution order: outer wraps inner
assert execution_order == [
"outer-before",
"inner-before",
"inner-after",
"outer-after",
]
# Model messages + outer command messages (additive)
messages = result["messages"]
assert len(messages) == 3
assert messages[0].content == "Hi"
assert messages[1].content == "Composed"
assert messages[2].content == "Outer msg"
def test_inner_command_propagated_through_composition(self) -> None:
"""Inner middleware's ExtendedModelResponse command is propagated.
When inner middleware returns ExtendedModelResponse, its command is
captured before normalizing to ModelResponse at the composition boundary
and collected into the final result.
"""
class OuterMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
# Outer sees a ModelResponse from handler (inner's ExtendedModelResponse
# was normalized at the composition boundary)
response = handler(request)
assert isinstance(response, ModelResponse)
return response
class InnerMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
response = handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(
update={
"messages": [
HumanMessage(content="Inner msg", id="inner"),
]
}
),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
agent = create_agent(
model=model,
middleware=[OuterMiddleware(), InnerMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage("Hi")]})
# Model messages + inner command messages (additive)
messages = result["messages"]
assert len(messages) == 3
assert messages[0].content == "Hi"
assert messages[1].content == "Hello"
assert messages[2].content == "Inner msg"
def test_non_reducer_key_conflict_raises(self) -> None:
"""Multiple Commands setting the same non-reducer key raises.
LastValue channels (like custom_key) can only receive one value per
step. Inner and outer both setting the same key is an error.
"""
class MyState(AgentState):
custom_key: str
class OuterMiddleware(AgentMiddleware):
state_schema = MyState # type: ignore[assignment]
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
response = handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(
update={
"messages": [HumanMessage(content="Outer msg", id="outer")],
"custom_key": "outer_value",
}
),
)
class InnerMiddleware(AgentMiddleware):
state_schema = MyState # type: ignore[assignment]
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
response = handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(
update={
"messages": [HumanMessage(content="Inner msg", id="inner")],
"custom_key": "inner_value",
}
),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
agent = create_agent(
model=model,
middleware=[OuterMiddleware(), InnerMiddleware()],
)
# Two Commands both setting custom_key (a LastValue channel)
# in the same step raises InvalidUpdateError
with pytest.raises(InvalidUpdateError):
agent.invoke({"messages": [HumanMessage("Hi")]})
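The additive-vs-conflict semantics exercised by the two tests above can be modeled with a small stdlib sketch (`apply_commands` and this `InvalidUpdateError` are illustrative stand-ins, assuming `add_messages` behaves like list concatenation and LastValue channels accept at most one write per step):

```python
class InvalidUpdateError(Exception):
    """Stand-in for langgraph.errors.InvalidUpdateError."""


def apply_commands(state: dict, commands: list[dict]) -> dict:
    """Apply command updates in collection order (inner-first, then outer).

    'messages' uses an additive reducer; every other key is modeled as a
    LastValue channel that rejects a second write in the same step.
    """
    written: set[str] = set()
    for cmd in commands:
        for key, value in cmd.items():
            if key == "messages":
                # Reducer field: concatenate, never overwrite.
                state["messages"] = state.get("messages", []) + value
            else:
                if key in written:
                    raise InvalidUpdateError(f"multiple writes to {key!r}")
                written.add(key)
                state[key] = value
    return state


# Reducer field: inner and outer command messages are both kept.
state = apply_commands(
    {"messages": ["Hi", "Hello"]},
    [{"messages": ["Inner msg"]}, {"messages": ["Outer msg"]}],
)
assert state["messages"] == ["Hi", "Hello", "Inner msg", "Outer msg"]

# LastValue field: two commands writing the same key in one step conflict.
try:
    apply_commands({}, [{"custom_key": "inner_value"}, {"custom_key": "outer_value"}])
except InvalidUpdateError:
    pass
else:
    raise AssertionError("expected InvalidUpdateError")
```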
def test_inner_state_preserved_when_outer_has_no_conflict(self) -> None:
"""Inner's command keys are preserved when outer doesn't conflict."""
class MyState(AgentState):
inner_key: str
outer_key: str
class OuterMiddleware(AgentMiddleware):
state_schema = MyState # type: ignore[assignment]
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
response = handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(update={"outer_key": "from_outer"}),
)
class InnerMiddleware(AgentMiddleware):
state_schema = MyState # type: ignore[assignment]
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
response = handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(update={"inner_key": "from_inner"}),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
agent = create_agent(
model=model,
middleware=[OuterMiddleware(), InnerMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage("Hi")]})
# Both keys survive since there's no conflict
messages = result["messages"]
assert messages[-1].content == "Hello"
def test_inner_command_retry_safe(self) -> None:
"""When outer retries, only the last inner command is used."""
call_count = 0
class MyState(AgentState):
attempt: str
class OuterMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
# Call handler twice (simulating retry)
handler(request)
return handler(request)
class InnerMiddleware(AgentMiddleware):
state_schema = MyState # type: ignore[assignment]
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
nonlocal call_count
call_count += 1
response = handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(update={"attempt": f"attempt_{call_count}"}),
)
model = GenericFakeChatModel(
messages=iter([AIMessage(content="First"), AIMessage(content="Second")])
)
agent = create_agent(
model=model,
middleware=[OuterMiddleware(), InnerMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage("Hi")]})
# Only the last retry's inner state should survive
messages = result["messages"]
assert messages[-1].content == "Second"
def test_decorator_returns_wrap_result(self) -> None:
"""@wrap_model_call decorator can return ExtendedModelResponse with command."""
@wrap_model_call
def command_middleware(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
response = handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(
update={
"messages": [
HumanMessage(content="From decorator", id="dec"),
]
}
),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Model response")]))
agent = create_agent(model=model, middleware=[command_middleware])
result = agent.invoke({"messages": [HumanMessage("Hi")]})
messages = result["messages"]
assert len(messages) == 3
assert messages[1].content == "Model response"
assert messages[2].content == "From decorator"
def test_structured_response_preserved(self) -> None:
"""ExtendedModelResponse preserves structured_response from ModelResponse."""
class StructuredMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
response = handler(request)
response_with_structured = ModelResponse(
result=response.result,
structured_response={"key": "value"},
)
return ExtendedModelResponse(
model_response=response_with_structured,
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
agent = create_agent(model=model, middleware=[StructuredMiddleware()])
result = agent.invoke({"messages": [HumanMessage("Hi")]})
assert result.get("structured_response") == {"key": "value"}
messages = result["messages"]
assert len(messages) == 2
assert messages[1].content == "Hello"
class TestAsyncComposition:
"""Test async ExtendedModelResponse propagation through composed middleware."""
async def test_async_inner_command_propagated(self) -> None:
"""Async: inner middleware's ExtendedModelResponse command is propagated."""
class OuterMiddleware(AgentMiddleware):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse:
response = await handler(request)
assert isinstance(response, ModelResponse)
return response
class InnerMiddleware(AgentMiddleware):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ExtendedModelResponse:
response = await handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(
update={
"messages": [
HumanMessage(content="Inner msg", id="inner"),
]
}
),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
agent = create_agent(
model=model,
middleware=[OuterMiddleware(), InnerMiddleware()],
)
result = await agent.ainvoke({"messages": [HumanMessage("Hi")]})
# Model messages + inner command messages (additive)
messages = result["messages"]
assert len(messages) == 3
assert messages[0].content == "Hi"
assert messages[1].content == "Hello"
assert messages[2].content == "Inner msg"
async def test_async_both_commands_additive_messages(self) -> None:
"""Async: both inner and outer command messages are added alongside model."""
class OuterMiddleware(AgentMiddleware):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ExtendedModelResponse:
response = await handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(
update={"messages": [HumanMessage(content="Outer msg", id="outer")]}
),
)
class InnerMiddleware(AgentMiddleware):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ExtendedModelResponse:
response = await handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(
update={"messages": [HumanMessage(content="Inner msg", id="inner")]}
),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
agent = create_agent(
model=model,
middleware=[OuterMiddleware(), InnerMiddleware()],
)
result = await agent.ainvoke({"messages": [HumanMessage("Hi")]})
# All messages additive: model + inner + outer
messages = result["messages"]
assert len(messages) == 4
assert messages[0].content == "Hi"
assert messages[1].content == "Hello"
assert messages[2].content == "Inner msg"
assert messages[3].content == "Outer msg"
async def test_async_inner_command_retry_safe(self) -> None:
"""Async: when outer retries, only last inner command is used."""
call_count = 0
class MyState(AgentState):
attempt: str
class OuterMiddleware(AgentMiddleware):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse:
# Call handler twice (simulating retry)
await handler(request)
return await handler(request)
class InnerMiddleware(AgentMiddleware):
state_schema = MyState # type: ignore[assignment]
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ExtendedModelResponse:
nonlocal call_count
call_count += 1
response = await handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(update={"attempt": f"attempt_{call_count}"}),
)
model = GenericFakeChatModel(
messages=iter([AIMessage(content="First"), AIMessage(content="Second")])
)
agent = create_agent(
model=model,
middleware=[OuterMiddleware(), InnerMiddleware()],
)
result = await agent.ainvoke({"messages": [HumanMessage("Hi")]})
messages = result["messages"]
assert any(m.content == "Second" for m in messages)
class TestCommandGotoDisallowed:
"""Test that Command goto raises NotImplementedError in wrap_model_call."""
def test_command_goto_raises_not_implemented(self) -> None:
"""Command with goto in wrap_model_call raises NotImplementedError."""
class GotoMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
response = handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(goto="__end__"),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
agent = create_agent(model=model, middleware=[GotoMiddleware()])
with pytest.raises(NotImplementedError, match="Command goto is not yet supported"):
agent.invoke({"messages": [HumanMessage(content="Hi")]})
async def test_async_command_goto_raises_not_implemented(self) -> None:
"""Async: Command with goto in wrap_model_call raises NotImplementedError."""
class AsyncGotoMiddleware(AgentMiddleware):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ExtendedModelResponse:
response = await handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(goto="tools"),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
agent = create_agent(model=model, middleware=[AsyncGotoMiddleware()])
with pytest.raises(NotImplementedError, match="Command goto is not yet supported"):
await agent.ainvoke({"messages": [HumanMessage(content="Hi")]})
class TestCommandResumeDisallowed:
"""Test that Command resume raises NotImplementedError in wrap_model_call."""
def test_command_resume_raises_not_implemented(self) -> None:
"""Command with resume in wrap_model_call raises NotImplementedError."""
class ResumeMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
response = handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(resume="some_value"),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
agent = create_agent(model=model, middleware=[ResumeMiddleware()])
with pytest.raises(NotImplementedError, match="Command resume is not yet supported"):
agent.invoke({"messages": [HumanMessage(content="Hi")]})
async def test_async_command_resume_raises_not_implemented(self) -> None:
"""Async: Command with resume in wrap_model_call raises NotImplementedError."""
class AsyncResumeMiddleware(AgentMiddleware):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ExtendedModelResponse:
response = await handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(resume="some_value"),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
agent = create_agent(model=model, middleware=[AsyncResumeMiddleware()])
with pytest.raises(NotImplementedError, match="Command resume is not yet supported"):
await agent.ainvoke({"messages": [HumanMessage(content="Hi")]})
class TestCommandGraphDisallowed:
"""Test that Command graph raises NotImplementedError in wrap_model_call."""
def test_command_graph_raises_not_implemented(self) -> None:
"""Command with graph in wrap_model_call raises NotImplementedError."""
class GraphMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
response = handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(graph=Command.PARENT, update={"messages": []}),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
agent = create_agent(model=model, middleware=[GraphMiddleware()])
with pytest.raises(NotImplementedError, match="Command graph is not yet supported"):
agent.invoke({"messages": [HumanMessage(content="Hi")]})
async def test_async_command_graph_raises_not_implemented(self) -> None:
"""Async: Command with graph in wrap_model_call raises NotImplementedError."""
class AsyncGraphMiddleware(AgentMiddleware):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ExtendedModelResponse:
response = await handler(request)
return ExtendedModelResponse(
model_response=response,
command=Command(graph=Command.PARENT, update={"messages": []}),
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
agent = create_agent(model=model, middleware=[AsyncGraphMiddleware()])
with pytest.raises(NotImplementedError, match="Command graph is not yet supported"):
await agent.ainvoke({"messages": [HumanMessage(content="Hi")]})


@@ -313,6 +313,7 @@ def test_configurable_with_default() -> None:
"default_headers": None,
"model_kwargs": {},
"reuse_last_container": None,
+"inference_geo": None,
"streaming": False,
"stream_usage": True,
"output_version": None,


@@ -173,7 +173,7 @@ wheels = [
[[package]]
name = "anthropic"
-version = "0.72.1"
+version = "0.78.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@@ -185,9 +185,9 @@ dependencies = [
{ name = "sniffio" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/dd/f3/feb750a21461090ecf48bbebcaa261cd09003cc1d14e2fa9643ad59edd4d/anthropic-0.72.1.tar.gz", hash = "sha256:a6d1d660e1f4af91dddc732f340786d19acaffa1ae8e69442e56be5fa6539d51", size = 415395, upload-time = "2025-11-11T16:53:29.001Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/ec/51/32849a48f9b1cfe80a508fd269b20bd8f0b1357c70ba092890fde5a6a10b/anthropic-0.78.0.tar.gz", hash = "sha256:55fd978ab9b049c61857463f4c4e9e092b24f892519c6d8078cee1713d8af06e", size = 509136, upload-time = "2026-02-05T17:52:04.986Z" }
wheels = [
-{ url = "https://files.pythonhosted.org/packages/51/05/d9d45edad1aa28330cea09a3b35e1590f7279f91bb5ab5237c70a0884ea3/anthropic-0.72.1-py3-none-any.whl", hash = "sha256:81e73cca55e8924776c8c4418003defe6bf9eaf0cd92beb94c8dbf537b95316f", size = 357373, upload-time = "2025-11-11T16:53:27.438Z" },
+{ url = "https://files.pythonhosted.org/packages/3b/03/2f50931a942e5e13f80e24d83406714672c57964be593fc046d81369335b/anthropic-0.78.0-py3-none-any.whl", hash = "sha256:2a9887d2e99d1b0f9fe08857a1e9fe5d2d4030455dbf9ac65aab052e2efaeac4", size = 405485, upload-time = "2026-02-05T17:52:03.674Z" },
]
[[package]]
@@ -1975,7 +1975,7 @@ typing = [
[package.metadata]
requires-dist = [
-{ name = "langchain-anthropic", marker = "extra == 'anthropic'" },
+{ name = "langchain-anthropic", marker = "extra == 'anthropic'", editable = "../partners/anthropic" },
{ name = "langchain-aws", marker = "extra == 'aws'" },
{ name = "langchain-azure-ai", marker = "extra == 'azure-ai'" },
{ name = "langchain-community", marker = "extra == 'community'" },
@@ -2028,16 +2028,51 @@ typing = [
[[package]]
name = "langchain-anthropic"
-version = "1.0.3"
-source = { registry = "https://pypi.org/simple" }
+version = "1.3.1"
+source = { editable = "../partners/anthropic" }
dependencies = [
{ name = "anthropic" },
{ name = "langchain-core" },
{ name = "pydantic" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/92/6b/aaa770beea6f4ed4c3f5c75fd6d80ed5c82708aec15318c06d9379dd3543/langchain_anthropic-1.0.3.tar.gz", hash = "sha256:91083c5df82634602f6772989918108d9448fa0b7499a11434687198f5bf9aef", size = 680336, upload-time = "2025-11-12T15:58:58.54Z" }
-wheels = [
-{ url = "https://files.pythonhosted.org/packages/1e/3d/0499eeb10d333ea79e1c156b84d075e96cfde2fbc6a9ec9cbfa50ac3e47e/langchain_anthropic-1.0.3-py3-none-any.whl", hash = "sha256:0d4106111d57e19988e2976fb6c6e59b0c47ca7afb0f6a2f888362006f871871", size = 46608, upload-time = "2025-11-12T15:58:57.28Z" },
-]
+[package.metadata]
+requires-dist = [
+{ name = "anthropic", specifier = ">=0.75.0,<1.0.0" },
+{ name = "langchain-core", editable = "../core" },
+{ name = "pydantic", specifier = ">=2.7.4,<3.0.0" },
+]
+[package.metadata.requires-dev]
+dev = [{ name = "langchain-core", editable = "../core" }]
+lint = [{ name = "ruff", specifier = ">=0.13.1,<0.14.0" }]
+test = [
+{ name = "blockbuster", specifier = ">=1.5.5,<1.6" },
+{ name = "defusedxml", specifier = ">=0.7.1,<1.0.0" },
+{ name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
+{ name = "langchain", editable = "." },
+{ name = "langchain-core", editable = "../core" },
+{ name = "langchain-tests", editable = "../standard-tests" },
+{ name = "langgraph-prebuilt", specifier = ">=0.7.0a2" },
+{ name = "pytest", specifier = ">=7.3.0,<8.0.0" },
+{ name = "pytest-asyncio", specifier = ">=0.21.1,<1.0.0" },
+{ name = "pytest-mock", specifier = ">=3.10.0,<4.0.0" },
+{ name = "pytest-retry", specifier = ">=1.7.0,<1.8.0" },
+{ name = "pytest-socket", specifier = ">=0.7.0,<1.0.0" },
+{ name = "pytest-timeout", specifier = ">=2.3.1,<3.0.0" },
+{ name = "pytest-watcher", specifier = ">=0.3.4,<1.0.0" },
+{ name = "pytest-xdist", specifier = ">=3.8.0,<4.0.0" },
+{ name = "syrupy", specifier = ">=4.0.2,<5.0.0" },
+{ name = "vcrpy", specifier = ">=8.0.0,<9.0.0" },
+]
+test-integration = [
+{ name = "langchain-core", editable = "../core" },
+{ name = "requests", specifier = ">=2.32.3,<3.0.0" },
+]
+typing = [
+{ name = "langchain-core", editable = "../core" },
+{ name = "mypy", specifier = ">=1.17.1,<2.0.0" },
+{ name = "types-requests", specifier = ">=2.31.0,<3.0.0" },
+]
[[package]]