feat: support state updates from wrap_model_call with command(s) (#35033)
Alternative to https://github.com/langchain-ai/langchain/pull/35024, paving the way for summarization in `wrap_model_call` (which requires state updates).

---

Adds an `ExtendedModelResponse` dataclass that allows `wrap_model_call` middleware to return a `Command` alongside the model response for additional state updates.

```py
@dataclass
class ExtendedModelResponse(Generic[ResponseT]):
    model_response: ModelResponse[ResponseT]
    command: Command | None = None
```

## Motivation

Previously, `wrap_model_call` middleware could only return a `ModelResponse` or `AIMessage` — there was no way to inject additional state updates (e.g. custom state fields) from the model call middleware layer. `ExtendedModelResponse` fills this gap by accepting an optional `Command`.

This feature is needed by the summarization middleware, which must track summarization trigger points calculated during `wrap_model_call`.

## Why `Command` instead of a plain `state_update` dict?

We chose `Command` rather than the raw `state_update: dict` approach from the earlier iteration because `Command` is the established LangGraph primitive for state updates from nodes. Using `Command` means:

- State updates flow through the graph's reducers (e.g. `add_messages`) rather than being merged as raw dicts. This makes message updates additive alongside the model response instead of replacing it.
- Consistency with `wrap_tool_call`, which already returns `Command`.
- Future-proofing: as `Command` gains new capabilities (e.g. `goto`, `send`), middleware can leverage them without API changes.

## Why keep `model_response` separate instead of using `Command` directly?

The model node needs to distinguish the model's actual response (messages + structured output) from supplementary middleware state updates. If middleware returned only a `Command`, there would be no clean way to extract the `ModelResponse` for structured output handling, response validation, and the core model-to-tools routing logic. Keeping `model_response` explicit preserves a clear boundary between "what the model said" and "what middleware wants to update."

Also, to avoid a breaking change, the `handler` passed to `wrap_model_call` needs to always return a `ModelResponse`. There is no easy way to preserve that if we pump everything into a `Command`.

One nice property of the `ExtendedModelResponse` structure is that it is extensible if we want to add more metadata in the future.

## Composition

When multiple middleware layers return `ExtendedModelResponse`, their commands compose naturally (see the example below):

- **Inner commands propagate outward:** At composition boundaries, `ExtendedModelResponse` is unwrapped to its underlying `ModelResponse`, so outer middleware always sees a plain `ModelResponse` from `handler()`. The inner command is captured and accumulated.
- **Commands are applied through reducers:** Each `Command` becomes a separate state update applied through the graph's reducers. For messages, this means they're additive (via `add_messages`), not replacing.
- **Inner-first ordering:** For non-reducer state fields, commands are applied inner-first then outer. Note that two layers writing the same non-reducer (`LastValue`) key in the same step raise `InvalidUpdateError`, as covered by the composition tests below.
- **Retry-safe:** When outer middleware retries by calling `handler()` again, accumulated inner commands are cleared and re-collected from the fresh call.
```python
class Outer(AgentMiddleware):
    def wrap_model_call(self, request, handler):
        response = handler(request)  # sees ModelResponse, not ExtendedModelResponse
        return ExtendedModelResponse(
            model_response=response,
            command=Command(update={"outer_key": "val"}),
        )


class Inner(AgentMiddleware):
    def wrap_model_call(self, request, handler):
        response = handler(request)
        return ExtendedModelResponse(
            model_response=response,
            command=Command(update={"inner_key": "val"}),
        )


# Final state merges both commands: {"inner_key": "val", "outer_key": "val"}
```

## Backwards compatibility

Fully backwards compatible. The `ModelCallResult` type alias is widened from `ModelResponse | AIMessage` to `ModelResponse | AIMessage | ExtendedModelResponse`, but existing middleware returning `ModelResponse` or `AIMessage` continues to work identically.

## Internals

- `model_node` / `amodel_node` now return `list[Command]` instead of `dict[str, Any]`
- `_build_commands` converts the model response plus accumulated middleware commands into a list of `Command` objects for LangGraph
- `_ComposedExtendedModelResponse` is the internal type that accumulates commands across layers during composition
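For the motivating summarization case, a `wrap_model_call` middleware can record where it last triggered by pairing the model response with a `Command` that writes a custom state field. A minimal sketch of that pattern, built only from the `ExtendedModelResponse` API above; the `SummaryState` schema and `summary_trigger_index` field are hypothetical names (not part of this PR), and it assumes `ModelRequest.messages` carries the request's message list:

```python
from collections.abc import Callable

from langgraph.types import Command

from langchain.agents import AgentState
from langchain.agents.middleware.types import (
    AgentMiddleware,
    ExtendedModelResponse,
    ModelRequest,
    ModelResponse,
)


class SummaryState(AgentState):
    # Hypothetical custom field: index of the message at which
    # summarization last triggered.
    summary_trigger_index: int


class SummaryMarkerMiddleware(AgentMiddleware):
    state_schema = SummaryState  # type: ignore[assignment]

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ExtendedModelResponse:
        response = handler(request)
        # The trigger point computed during this call rides along as a Command
        # and is applied through the graph's reducers as a separate state update.
        return ExtendedModelResponse(
            model_response=response,
            command=Command(update={"summary_trigger_index": len(request.messages)}),
        )
```

Because the command flows through the normal channel machinery, the model's messages and the custom field update land in the same step without the middleware having to touch `messages` itself.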
@@ -270,6 +270,7 @@ def test_configurable_with_default() -> None:
        "default_headers": None,
        "model_kwargs": {},
        "reuse_last_container": None,
        "inference_geo": None,
        "streaming": False,
        "stream_usage": True,
        "output_version": None,

@@ -3,10 +3,12 @@

from __future__ import annotations

import itertools
from dataclasses import dataclass, field
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Generic,
    cast,
    get_args,
    get_origin,
@@ -27,6 +29,7 @@ from langchain.agents.middleware.types import (
    AgentMiddleware,
    AgentState,
    ContextT,
    ExtendedModelResponse,
    JumpTo,
    ModelRequest,
    ModelResponse,
@@ -49,6 +52,23 @@ from langchain.agents.structured_output import (
)
from langchain.chat_models import init_chat_model


@dataclass
class _ComposedExtendedModelResponse(Generic[ResponseT]):
    """Internal result from composed ``wrap_model_call`` middleware.

    Unlike ``ExtendedModelResponse`` (user-facing, single command), this holds the
    full list of commands accumulated across all middleware layers during
    composition.
    """

    model_response: ModelResponse[ResponseT]
    """The underlying model response."""

    commands: list[Command[Any]] = field(default_factory=list)
    """Commands accumulated from all middleware layers (inner-first, then outer)."""


if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable, Sequence

@@ -61,6 +81,27 @@ if TYPE_CHECKING:

    from langchain.agents.middleware.types import ToolCallRequest, ToolCallWrapper

    _ModelCallHandler = Callable[
        [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
        ModelResponse | AIMessage | ExtendedModelResponse,
    ]

    _ComposedModelCallHandler = Callable[
        [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
        _ComposedExtendedModelResponse,
    ]

    _AsyncModelCallHandler = Callable[
        [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
        Awaitable[ModelResponse | AIMessage | ExtendedModelResponse],
    ]

    _ComposedAsyncModelCallHandler = Callable[
        [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
        Awaitable[_ComposedExtendedModelResponse],
    ]


STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."

DYNAMIC_TOOL_ERROR_TEMPLATE = """
@@ -102,31 +143,72 @@ FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [
]


def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
    """Normalize middleware return value to ModelResponse."""
def _normalize_to_model_response(
    result: ModelResponse | AIMessage | ExtendedModelResponse,
) -> ModelResponse:
    """Normalize middleware return value to ModelResponse.

    At inner composition boundaries, ``ExtendedModelResponse`` is unwrapped to its
    underlying ``ModelResponse`` so that inner middleware always sees ``ModelResponse``
    from the handler.
    """
    if isinstance(result, AIMessage):
        return ModelResponse(result=[result], structured_response=None)
    if isinstance(result, ExtendedModelResponse):
        return result.model_response
    return result


def _build_commands(
    model_response: ModelResponse,
    middleware_commands: list[Command[Any]] | None = None,
) -> list[Command[Any]]:
    """Build a list of Commands from a model response and middleware commands.

    The first Command contains the model response state (messages and optional
    structured_response). Middleware commands are appended as-is.

    Args:
        model_response: The model response containing messages and optional
            structured output.
        middleware_commands: Commands accumulated from middleware layers during
            composition (inner-first ordering).

    Returns:
        List of ``Command`` objects ready to be returned from a model node.
    """
    state: dict[str, Any] = {"messages": model_response.result}

    if model_response.structured_response is not None:
        state["structured_response"] = model_response.structured_response

    for cmd in middleware_commands or []:
        if cmd.goto:
            msg = (
                "Command goto is not yet supported in wrap_model_call middleware. "
                "Use the jump_to state field with before_model/after_model hooks instead."
            )
            raise NotImplementedError(msg)
        if cmd.resume:
            msg = "Command resume is not yet supported in wrap_model_call middleware."
            raise NotImplementedError(msg)
        if cmd.graph:
            msg = "Command graph is not yet supported in wrap_model_call middleware."
            raise NotImplementedError(msg)

    commands: list[Command[Any]] = [Command(update=state)]
    commands.extend(middleware_commands or [])
    return commands

def _chain_model_call_handlers(
    handlers: Sequence[
        Callable[
            [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
            ModelResponse | AIMessage,
        ]
    ],
) -> (
    Callable[
        [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
        ModelResponse,
    ]
    | None
):
    """Compose multiple `wrap_model_call` handlers into single middleware stack.
    handlers: Sequence[_ModelCallHandler[ContextT]],
) -> _ComposedModelCallHandler[ContextT] | None:
    """Compose multiple ``wrap_model_call`` handlers into single middleware stack.

    Composes handlers so first in list becomes outermost layer. Each handler receives a
    handler callback to execute inner layers.
    handler callback to execute inner layers. Commands from each layer are accumulated
    into a list (inner-first, then outer) without merging.

    Args:
        handlers: List of handlers.
@@ -134,110 +216,90 @@ def _chain_model_call_handlers(
            First handler wraps all others.

    Returns:
        Composed handler, or `None` if handlers empty.

    Example:
        ```python
        # handlers=[auth, retry] means: auth wraps retry
        # Flow: auth calls retry, retry calls base handler
        def auth(req, state, runtime, handler):
            try:
                return handler(req)
            except UnauthorizedError:
                refresh_token()
                return handler(req)


        def retry(req, state, runtime, handler):
            for attempt in range(3):
                try:
                    return handler(req)
                except Exception:
                    if attempt == 2:
                        raise


        handler = _chain_model_call_handlers([auth, retry])
        ```
        Composed handler returning ``_ComposedExtendedModelResponse``,
        or ``None`` if handlers empty.
    """
    if not handlers:
        return None

    def _to_composed_result(
        result: ModelResponse | AIMessage | ExtendedModelResponse | _ComposedExtendedModelResponse,
        extra_commands: list[Command[Any]] | None = None,
    ) -> _ComposedExtendedModelResponse:
        """Normalize any handler result to _ComposedExtendedModelResponse."""
        commands: list[Command[Any]] = list(extra_commands or [])
        if isinstance(result, _ComposedExtendedModelResponse):
            commands.extend(result.commands)
            model_response = result.model_response
        elif isinstance(result, ExtendedModelResponse):
            model_response = result.model_response
            if result.command is not None:
                commands.append(result.command)
        else:
            model_response = _normalize_to_model_response(result)

        return _ComposedExtendedModelResponse(model_response=model_response, commands=commands)

    if len(handlers) == 1:
        # Single handler - wrap to normalize output
        single_handler = handlers[0]

        def normalized_single(
            request: ModelRequest[ContextT],
            handler: Callable[[ModelRequest[ContextT]], ModelResponse],
        ) -> ModelResponse:
            result = single_handler(request, handler)
            return _normalize_to_model_response(result)
        ) -> _ComposedExtendedModelResponse:
            return _to_composed_result(single_handler(request, handler))

        return normalized_single

    def compose_two(
        outer: Callable[
            [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
            ModelResponse | AIMessage,
        ],
        inner: Callable[
            [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
            ModelResponse | AIMessage,
        ],
    ) -> Callable[
        [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
        ModelResponse,
    ]:
        outer: _ModelCallHandler[ContextT] | _ComposedModelCallHandler[ContextT],
        inner: _ModelCallHandler[ContextT] | _ComposedModelCallHandler[ContextT],
    ) -> _ComposedModelCallHandler[ContextT]:
        """Compose two handlers where outer wraps inner."""

        def composed(
            request: ModelRequest[ContextT],
            handler: Callable[[ModelRequest[ContextT]], ModelResponse],
        ) -> ModelResponse:
            # Create a wrapper that calls inner with the base handler and normalizes
        ) -> _ComposedExtendedModelResponse:
            # Closure variable to capture inner's commands before normalizing
            accumulated_commands: list[Command[Any]] = []

            def inner_handler(req: ModelRequest[ContextT]) -> ModelResponse:
                # Clear on each call for retry safety
                accumulated_commands.clear()
                inner_result = inner(req, handler)
                if isinstance(inner_result, _ComposedExtendedModelResponse):
                    accumulated_commands.extend(inner_result.commands)
                    return inner_result.model_response
                if isinstance(inner_result, ExtendedModelResponse):
                    if inner_result.command is not None:
                        accumulated_commands.append(inner_result.command)
                    return inner_result.model_response
                return _normalize_to_model_response(inner_result)

            # Call outer with the wrapped inner as its handler and normalize
            outer_result = outer(request, inner_handler)
            return _normalize_to_model_response(outer_result)
            return _to_composed_result(
                outer_result,
                extra_commands=accumulated_commands or None,
            )

        return composed

    # Compose right-to-left: outer(inner(innermost(handler)))
    result = handlers[-1]
    for handler in reversed(handlers[:-1]):
        result = compose_two(handler, result)
    composed_handler = compose_two(handlers[-2], handlers[-1])
    for h in reversed(handlers[:-2]):
        composed_handler = compose_two(h, composed_handler)

    # Wrap to ensure final return type is exactly ModelResponse
    def final_normalized(
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], ModelResponse],
    ) -> ModelResponse:
        # result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
        final_result = result(request, handler)
        return _normalize_to_model_response(final_result)

    return final_normalized
    return composed_handler

def _chain_async_model_call_handlers(
    handlers: Sequence[
        Callable[
            [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
            Awaitable[ModelResponse | AIMessage],
        ]
    ],
) -> (
    Callable[
        [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
        Awaitable[ModelResponse],
    ]
    | None
):
    """Compose multiple async `wrap_model_call` handlers into single middleware stack.
    handlers: Sequence[_AsyncModelCallHandler[ContextT]],
) -> _ComposedAsyncModelCallHandler[ContextT] | None:
    """Compose multiple async ``wrap_model_call`` handlers into single middleware stack.

    Commands from each layer are accumulated into a list (inner-first, then outer)
    without merging.

    Args:
        handlers: List of async handlers.
@@ -245,69 +307,81 @@ def _chain_async_model_call_handlers(
            First handler wraps all others.

    Returns:
        Composed async handler, or `None` if handlers empty.
        Composed async handler returning ``_ComposedExtendedModelResponse``,
        or ``None`` if handlers empty.
    """
    if not handlers:
        return None

    def _to_composed_result(
        result: ModelResponse | AIMessage | ExtendedModelResponse | _ComposedExtendedModelResponse,
        extra_commands: list[Command[Any]] | None = None,
    ) -> _ComposedExtendedModelResponse:
        """Normalize any handler result to _ComposedExtendedModelResponse."""
        commands: list[Command[Any]] = list(extra_commands or [])
        if isinstance(result, _ComposedExtendedModelResponse):
            commands.extend(result.commands)
            model_response = result.model_response
        elif isinstance(result, ExtendedModelResponse):
            model_response = result.model_response
            if result.command is not None:
                commands.append(result.command)
        else:
            model_response = _normalize_to_model_response(result)

        return _ComposedExtendedModelResponse(model_response=model_response, commands=commands)

    if len(handlers) == 1:
        # Single handler - wrap to normalize output
        single_handler = handlers[0]

        async def normalized_single(
            request: ModelRequest[ContextT],
            handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]],
        ) -> ModelResponse:
            result = await single_handler(request, handler)
            return _normalize_to_model_response(result)
        ) -> _ComposedExtendedModelResponse:
            return _to_composed_result(await single_handler(request, handler))

        return normalized_single

    def compose_two(
        outer: Callable[
            [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
            Awaitable[ModelResponse | AIMessage],
        ],
        inner: Callable[
            [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
            Awaitable[ModelResponse | AIMessage],
        ],
    ) -> Callable[
        [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
        Awaitable[ModelResponse],
    ]:
        outer: _AsyncModelCallHandler[ContextT] | _ComposedAsyncModelCallHandler[ContextT],
        inner: _AsyncModelCallHandler[ContextT] | _ComposedAsyncModelCallHandler[ContextT],
    ) -> _ComposedAsyncModelCallHandler[ContextT]:
        """Compose two async handlers where outer wraps inner."""

        async def composed(
            request: ModelRequest[ContextT],
            handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]],
        ) -> ModelResponse:
            # Create a wrapper that calls inner with the base handler and normalizes
        ) -> _ComposedExtendedModelResponse:
            # Closure variable to capture inner's commands before normalizing
            accumulated_commands: list[Command[Any]] = []

            async def inner_handler(req: ModelRequest[ContextT]) -> ModelResponse:
                # Clear on each call for retry safety
                accumulated_commands.clear()
                inner_result = await inner(req, handler)
                if isinstance(inner_result, _ComposedExtendedModelResponse):
                    accumulated_commands.extend(inner_result.commands)
                    return inner_result.model_response
                if isinstance(inner_result, ExtendedModelResponse):
                    if inner_result.command is not None:
                        accumulated_commands.append(inner_result.command)
                    return inner_result.model_response
                return _normalize_to_model_response(inner_result)

            # Call outer with the wrapped inner as its handler and normalize
            outer_result = await outer(request, inner_handler)
            return _normalize_to_model_response(outer_result)
            return _to_composed_result(
                outer_result,
                extra_commands=accumulated_commands or None,
            )

        return composed

    # Compose right-to-left: outer(inner(innermost(handler)))
    result = handlers[-1]
    for handler in reversed(handlers[:-1]):
        result = compose_two(handler, result)
    composed_handler = compose_two(handlers[-2], handlers[-1])
    for h in reversed(handlers[:-2]):
        composed_handler = compose_two(h, composed_handler)

    # Wrap to ensure final return type is exactly ModelResponse
    async def final_normalized(
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]],
    ) -> ModelResponse:
        # result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
        final_result = await result(request, handler)
        return _normalize_to_model_response(final_result)

    return final_normalized
    return composed_handler


def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
@@ -1165,7 +1239,7 @@ def create_agent(
            structured_response=structured_response,
        )

    def model_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> dict[str, Any]:
    def model_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> list[Command[Any]]:
        """Sync model request handler with sequential middleware processing."""
        request = ModelRequest(
            model=model,
@@ -1179,18 +1253,11 @@ def create_agent(
        )

        if wrap_model_call_handler is None:
            # No handlers - execute directly
            response = _execute_model_sync(request)
        else:
            # Call composed handler with base handler
            response = wrap_model_call_handler(request, _execute_model_sync)
            model_response = _execute_model_sync(request)
            return _build_commands(model_response)

        # Extract state updates from ModelResponse
        state_updates = {"messages": response.result}
        if response.structured_response is not None:
            state_updates["structured_response"] = response.structured_response

        return state_updates
        result = wrap_model_call_handler(request, _execute_model_sync)
        return _build_commands(result.model_response, result.commands)

    async def _execute_model_async(request: ModelRequest[ContextT]) -> ModelResponse:
        """Execute model asynchronously and return response.
@@ -1220,7 +1287,7 @@ def create_agent(
            structured_response=structured_response,
        )

    async def amodel_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> dict[str, Any]:
    async def amodel_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> list[Command[Any]]:
        """Async model request handler with sequential middleware processing."""
        request = ModelRequest(
            model=model,
@@ -1234,18 +1301,11 @@ def create_agent(
        )

        if awrap_model_call_handler is None:
            # No async handlers - execute directly
            response = await _execute_model_async(request)
        else:
            # Call composed async handler with base handler
            response = await awrap_model_call_handler(request, _execute_model_async)
            model_response = await _execute_model_async(request)
            return _build_commands(model_response)

        # Extract state updates from ModelResponse
        state_updates = {"messages": response.result}
        if response.structured_response is not None:
            state_updates["structured_response"] = response.structured_response

        return state_updates
        result = await awrap_model_call_handler(request, _execute_model_async)
        return _build_commands(result.model_response, result.commands)

    # Use sync or async based on model capabilities
    graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False))

@@ -26,6 +26,8 @@ from langchain.agents.middleware.tool_selection import LLMToolSelectorMiddleware
from langchain.agents.middleware.types import (
    AgentMiddleware,
    AgentState,
    ExtendedModelResponse,
    ModelCallResult,
    ModelRequest,
    ModelResponse,
    ToolCallRequest,
@@ -46,6 +48,7 @@ __all__ = [
    "CodexSandboxExecutionPolicy",
    "ContextEditingMiddleware",
    "DockerExecutionPolicy",
    "ExtendedModelResponse",
    "FilesystemFileSearchMiddleware",
    "HostExecutionPolicy",
    "HumanInTheLoopMiddleware",
@@ -53,6 +56,7 @@ __all__ = [
    "LLMToolEmulator",
    "LLMToolSelectorMiddleware",
    "ModelCallLimitMiddleware",
    "ModelCallResult",
    "ModelFallbackMiddleware",
    "ModelRequest",
    "ModelResponse",

|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable
|
||||
|
||||
from langgraph.types import Command
|
||||
|
||||
# Needed as top level import for Pydantic schema generation on AgentState
|
||||
import warnings
|
||||
from typing import TypeAlias
|
||||
@@ -42,6 +40,7 @@ if TYPE_CHECKING:
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents.structured_output import ResponseFormat
|
||||
|
||||
@@ -49,6 +48,8 @@ __all__ = [
|
||||
"AgentMiddleware",
|
||||
"AgentState",
|
||||
"ContextT",
|
||||
"ExtendedModelResponse",
|
||||
"ModelCallResult",
|
||||
"ModelRequest",
|
||||
"ModelResponse",
|
||||
"OmitFromSchema",
|
||||
@@ -285,14 +286,43 @@ class ModelResponse(Generic[ResponseT]):
|
||||
"""Parsed structured output if `response_format` was specified, `None` otherwise."""
|
||||
|
||||
|
||||
# Type alias for middleware return type - allows returning either full response or just AIMessage
|
||||
ModelCallResult: TypeAlias = "ModelResponse[ResponseT] | AIMessage"
|
||||
@dataclass
|
||||
class ExtendedModelResponse(Generic[ResponseT]):
|
||||
"""Model response with an optional 'Command' from 'wrap_model_call' middleware.
|
||||
|
||||
Use this to return a 'Command' alongside the model response from a
|
||||
'wrap_model_call' handler. The command is applied as an additional state
|
||||
update after the model node completes, using the graph's reducers (e.g.
|
||||
'add_messages' for the 'messages' key).
|
||||
|
||||
Because each 'Command' is applied through the reducer, messages in the
|
||||
command are **added alongside** the model response messages rather than
|
||||
replacing them. For non-reducer state fields, later commands overwrite
|
||||
earlier ones (outermost middleware wins over inner).
|
||||
|
||||
Type Parameters:
|
||||
ResponseT: The type of the structured response. Defaults to 'Any' if not specified.
|
||||
"""
|
||||
|
||||
model_response: ModelResponse[ResponseT]
|
||||
"""The underlying model response."""
|
||||
|
||||
command: Command[Any] | None = None
|
||||
"""Optional command to apply as an additional state update."""
|
||||
|
||||
|
||||
ModelCallResult: TypeAlias = (
|
||||
"ModelResponse[ResponseT] | AIMessage | ExtendedModelResponse[ResponseT]"
|
||||
)
|
||||
"""`TypeAlias` for model call handler return value.
|
||||
|
||||
Middleware can return either:
|
||||
|
||||
- `ModelResponse`: Full response with messages and optional structured output
|
||||
- `AIMessage`: Simplified return for simple use cases
|
||||
- `ExtendedModelResponse`: Response with an optional `Command` for additional state updates
|
||||
`goto`, `resume`, and `graph` are not yet supported on these commands.
|
||||
A `NotImplementedError` will be raised if you try to use them.
|
||||
"""
|
||||
|
||||
|
||||
@@ -449,7 +479,7 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]):
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
) -> ModelResponse[ResponseT] | AIMessage | ExtendedModelResponse[ResponseT]:
|
||||
"""Intercept and control model execution via handler callback.
|
||||
|
||||
Async version is `awrap_model_call`
|
||||
@@ -544,7 +574,7 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]):
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
) -> ModelResponse[ResponseT] | AIMessage | ExtendedModelResponse[ResponseT]:
|
||||
"""Intercept and control async model execution via handler callback.
|
||||
|
||||
The handler callback executes the model request and returns a `ModelResponse`.
|
||||
|
||||
@@ -99,6 +99,7 @@ langchain-core = { path = "../core", editable = true }
langchain-tests = { path = "../standard-tests", editable = true }
langchain-text-splitters = { path = "../text-splitters", editable = true }
langchain-openai = { path = "../partners/openai", editable = true }
langchain-anthropic = { path = "../partners/anthropic", editable = true }

[tool.ruff]
line-length = 100

@@ -5,10 +5,11 @@ from typing import Any, TypedDict, cast

from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime
from langgraph.types import Command

from langchain.agents import AgentState
from langchain.agents.factory import _chain_model_call_handlers
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain.agents.factory import _chain_model_call_handlers, _ComposedExtendedModelResponse
from langchain.agents.middleware.types import ExtendedModelResponse, ModelRequest, ModelResponse


def create_test_request(**kwargs: Any) -> ModelRequest:
@@ -88,9 +89,41 @@ class TestChainModelCallHandlers:
            "inner-after",
            "outer-after",
        ]
        # Result is now ModelResponse
        assert isinstance(result, ModelResponse)
        assert result.result[0].content == "test"
        # Outermost result is always _ComposedExtendedModelResponse
        assert isinstance(result, _ComposedExtendedModelResponse)
        assert result.model_response.result[0].content == "test"

    def test_two_handlers_with_commands(self) -> None:
        """Test that commands from inner and outer are collected correctly."""

        def outer(
            request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
        ) -> ExtendedModelResponse:
            response = handler(request)
            return ExtendedModelResponse(
                model_response=response,
                command=Command(update={"outer_key": "outer_val"}),
            )

        def inner(
            request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
        ) -> ExtendedModelResponse:
            response = handler(request)
            return ExtendedModelResponse(
                model_response=response,
                command=Command(update={"inner_key": "inner_val"}),
            )

        composed = _chain_model_call_handlers([outer, inner])
        assert composed is not None

        result = composed(create_test_request(), create_mock_base_handler())

        assert isinstance(result, _ComposedExtendedModelResponse)
        # Commands are collected: inner first, then outer
        assert len(result.commands) == 2
        assert result.commands[0].update == {"inner_key": "inner_val"}
        assert result.commands[1].update == {"outer_key": "outer_val"}

    def test_three_handlers_composition(self) -> None:
        """Test composition of three handlers."""
@@ -134,8 +167,8 @@ class TestChainModelCallHandlers:
            "second-after",
            "first-after",
        ]
        assert isinstance(result, ModelResponse)
        assert result.result[0].content == "test"
        assert isinstance(result, _ComposedExtendedModelResponse)
        assert result.model_response.result[0].content == "test"

    def test_inner_handler_retry(self) -> None:
        """Test inner handler retrying before outer sees response."""
@@ -173,8 +206,8 @@ class TestChainModelCallHandlers:
        result = composed(create_test_request(), mock_base_handler)

        assert inner_attempts == [0, 1, 2]
        assert isinstance(result, ModelResponse)
        assert result.result[0].content == "success"
        assert isinstance(result, _ComposedExtendedModelResponse)
        assert result.model_response.result[0].content == "success"

    def test_error_to_success_conversion(self) -> None:
        """Test handler converting error to success response."""
@@ -202,10 +235,10 @@ class TestChainModelCallHandlers:

        result = composed(create_test_request(), mock_base_handler)

        # AIMessage was automatically converted to ModelResponse
        assert isinstance(result, ModelResponse)
        assert result.result[0].content == "Fallback response"
        assert result.structured_response is None
        # AIMessage was automatically normalized into ExtendedModelResponse
        assert isinstance(result, _ComposedExtendedModelResponse)
        assert result.model_response.result[0].content == "Fallback response"
        assert result.model_response.structured_response is None

    def test_request_modification(self) -> None:
        """Test handlers modifying the request."""
@@ -231,8 +264,8 @@ class TestChainModelCallHandlers:
        result = composed(create_test_request(), create_mock_base_handler(content="response"))

        assert requests_seen == ["Added by outer"]
        assert isinstance(result, ModelResponse)
        assert result.result[0].content == "response"
        assert isinstance(result, _ComposedExtendedModelResponse)
        assert result.model_response.result[0].content == "response"

    def test_composition_preserves_state_and_runtime(self) -> None:
        """Test that state and runtime are passed through composition."""
@@ -273,8 +306,8 @@ class TestChainModelCallHandlers:
        # Both handlers should see same state and runtime
        assert state_values == [("outer", test_state), ("inner", test_state)]
        assert runtime_values == [("outer", test_runtime), ("inner", test_runtime)]
        assert isinstance(result, ModelResponse)
        assert result.result[0].content == "test"
        assert isinstance(result, _ComposedExtendedModelResponse)
        assert result.model_response.result[0].content == "test"

    def test_multiple_yields_in_retry_loop(self) -> None:
        """Test handler that retries multiple times."""
@@ -312,5 +345,5 @@ class TestChainModelCallHandlers:
        # Outer called once, inner retried so base handler called twice
        assert call_count["value"] == 1
        assert attempt["value"] == 2
        assert isinstance(result, ModelResponse)
        assert result.result[0].content == "ok"
        assert isinstance(result, _ComposedExtendedModelResponse)
        assert result.model_response.result[0].content == "ok"

@@ -0,0 +1,917 @@
|
||||
"""Unit tests for ExtendedModelResponse command support in wrap_model_call.
|
||||
|
||||
Tests that wrap_model_call middleware can return ExtendedModelResponse to provide
|
||||
a Command alongside the model response. Commands are applied as separate state
|
||||
updates through graph reducers (e.g. add_messages for messages).
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langgraph.errors import InvalidUpdateError
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents import AgentState, create_agent
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ExtendedModelResponse,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
wrap_model_call,
|
||||
)
|
||||
|
||||
|
||||
class TestBasicCommand:
|
||||
"""Test basic ExtendedModelResponse functionality with Command."""
|
||||
|
||||
def test_command_messages_added_alongside_model_messages(self) -> None:
|
||||
"""Command messages are added alongside model response messages (additive)."""
|
||||
|
||||
class AddMessagesMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ExtendedModelResponse:
|
||||
response = handler(request)
|
||||
custom_msg = HumanMessage(content="Custom message", id="custom")
|
||||
return ExtendedModelResponse(
|
||||
model_response=response,
|
||||
command=Command(update={"messages": [custom_msg]}),
|
||||
)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
|
||||
agent = create_agent(model=model, middleware=[AddMessagesMiddleware()])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage(content="Hi")]})
|
||||
|
||||
# Both model response AND command messages appear (additive via add_messages)
|
||||
messages = result["messages"]
|
||||
assert len(messages) == 3
|
||||
assert messages[0].content == "Hi"
|
||||
assert messages[1].content == "Hello!"
|
||||
assert messages[2].content == "Custom message"
|
||||
|
||||
def test_command_with_extra_messages_and_model_response(self) -> None:
|
||||
"""Middleware can add extra messages via command alongside model messages."""
|
||||
|
||||
class ExtraMessagesMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ExtendedModelResponse:
|
||||
response = handler(request)
|
||||
summary = HumanMessage(content="Summary", id="summary")
|
||||
return ExtendedModelResponse(
|
||||
model_response=response,
|
||||
command=Command(update={"messages": [summary]}),
|
||||
)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
|
||||
agent = create_agent(model=model, middleware=[ExtraMessagesMiddleware()])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage(content="Hi")]})
|
||||
|
||||
messages = result["messages"]
|
||||
assert len(messages) == 3
|
||||
assert messages[0].content == "Hi"
|
||||
assert messages[1].content == "Hello!"
|
||||
assert messages[2].content == "Summary"
|
||||
|
||||
def test_command_structured_response_conflicts_with_model_response(self) -> None:
|
||||
"""Command and model response both setting structured_response raises."""
|
||||
|
||||
class OverrideMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ExtendedModelResponse:
|
||||
response = handler(request)
|
||||
response_with_structured = ModelResponse(
|
||||
result=response.result,
|
||||
structured_response={"from": "model"},
|
||||
)
|
||||
return ExtendedModelResponse(
|
||||
model_response=response_with_structured,
|
||||
command=Command(
|
||||
update={
|
||||
"structured_response": {"from": "command"},
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Model msg")]))
|
||||
agent = create_agent(model=model, middleware=[OverrideMiddleware()])
|
||||
|
||||
# Two Commands both setting structured_response (a LastValue channel)
|
||||
# in the same step raises InvalidUpdateError
|
||||
with pytest.raises(InvalidUpdateError):
|
||||
agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
def test_command_with_custom_state_field(self) -> None:
|
||||
"""When command updates a custom field, model response messages are preserved."""
|
||||
|
||||
class CustomFieldMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ExtendedModelResponse:
|
||||
response = handler(request)
|
||||
return ExtendedModelResponse(
|
||||
model_response=response,
|
||||
command=Command(update={"custom_key": "custom_value"}),
|
||||
)
|
||||
|
||||
class CustomState(AgentState):
|
||||
custom_key: str
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
middleware=[CustomFieldMiddleware()],
|
||||
state_schema=CustomState,
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
assert result["messages"][-1].content == "Hello"
|
||||
|
||||
|
||||
class TestCustomStateField:
|
||||
"""Test ExtendedModelResponse with custom state fields defined via state_schema."""
|
||||
|
||||
def test_custom_field_via_state_schema(self) -> None:
|
||||
"""Middleware updates a custom state field via ExtendedModelResponse."""
|
||||
|
||||
class MyState(AgentState):
|
||||
summary: str
|
||||
|
||||
class SummaryMiddleware(AgentMiddleware):
|
||||
state_schema = MyState # type: ignore[assignment]
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ExtendedModelResponse:
|
||||
response = handler(request)
|
||||
return ExtendedModelResponse(
|
||||
model_response=response,
|
||||
command=Command(update={"summary": "conversation summarized"}),
|
||||
)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
agent = create_agent(model=model, middleware=[SummaryMiddleware()])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
assert result["messages"][-1].content == "Hello"
|
||||
|
||||
def test_no_command(self) -> None:
|
||||
"""ExtendedModelResponse with no command works like ModelResponse."""
|
||||
|
||||
class NoCommandMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ExtendedModelResponse:
|
||||
response = handler(request)
|
||||
return ExtendedModelResponse(
|
||||
model_response=response,
|
||||
)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
agent = create_agent(model=model, middleware=[NoCommandMiddleware()])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
assert len(result["messages"]) == 2
|
||||
assert result["messages"][1].content == "Hello"
|
||||
|
||||
|
||||
class TestBackwardsCompatibility:
|
||||
"""Test that existing ModelResponse and AIMessage returns still work."""
|
||||
|
||||
def test_model_response_return_unchanged(self) -> None:
|
||||
"""Existing middleware returning ModelResponse works identically."""
|
||||
|
||||
class PassthroughMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
return handler(request)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
agent = create_agent(model=model, middleware=[PassthroughMiddleware()])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
assert len(result["messages"]) == 2
|
||||
assert result["messages"][1].content == "Hello"
|
||||
|
||||
def test_ai_message_return_unchanged(self) -> None:
|
||||
"""Existing middleware returning AIMessage works identically."""
|
||||
|
||||
class ShortCircuitMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> AIMessage:
|
||||
return AIMessage(content="Short-circuited")
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Should not appear")]))
|
||||
agent = create_agent(model=model, middleware=[ShortCircuitMiddleware()])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
assert len(result["messages"]) == 2
|
||||
assert result["messages"][1].content == "Short-circuited"
|
||||
|
||||
def test_no_middleware_unchanged(self) -> None:
|
||||
"""Agent without middleware works identically."""
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
agent = create_agent(model=model)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
assert len(result["messages"]) == 2
|
||||
assert result["messages"][1].content == "Hello"
|
||||
|
||||
|
||||
class TestAsyncExtendedModelResponse:
|
||||
"""Test async variant of ExtendedModelResponse."""
|
||||
|
||||
async def test_async_command_adds_messages(self) -> None:
|
||||
"""awrap_model_call command adds messages alongside model response."""
|
||||
|
||||
class AsyncAddMiddleware(AgentMiddleware):
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ExtendedModelResponse:
|
||||
response = await handler(request)
|
||||
custom = HumanMessage(content="Async custom", id="async-custom")
|
||||
return ExtendedModelResponse(
|
||||
model_response=response,
|
||||
command=Command(update={"messages": [custom]}),
|
||||
)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Async hello!")]))
|
||||
agent = create_agent(model=model, middleware=[AsyncAddMiddleware()])
|
||||
|
||||
result = await agent.ainvoke({"messages": [HumanMessage(content="Hi")]})
|
||||
|
||||
# Both model response and command messages are present (additive)
|
||||
messages = result["messages"]
|
||||
assert len(messages) == 3
|
||||
assert messages[0].content == "Hi"
|
||||
assert messages[1].content == "Async hello!"
|
||||
assert messages[2].content == "Async custom"
|
||||
|
||||
async def test_async_decorator_command(self) -> None:
|
||||
"""@wrap_model_call async decorator returns ExtendedModelResponse with command."""
|
||||
|
||||
@wrap_model_call
|
||||
async def command_middleware(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ExtendedModelResponse:
|
||||
response = await handler(request)
|
||||
return ExtendedModelResponse(
|
||||
model_response=response,
|
||||
command=Command(
|
||||
update={
|
||||
"messages": [
|
||||
HumanMessage(content="Decorator msg", id="dec"),
|
||||
]
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Async response")]))
|
||||
agent = create_agent(model=model, middleware=[command_middleware])
|
||||
|
||||
result = await agent.ainvoke({"messages": [HumanMessage(content="Hi")]})
|
||||
|
||||
messages = result["messages"]
|
||||
assert len(messages) == 3
|
||||
assert messages[1].content == "Async response"
|
||||
assert messages[2].content == "Decorator msg"
|
||||
|
||||
|
||||
class TestComposition:
|
||||
"""Test ExtendedModelResponse with composed middleware.
|
||||
|
||||
Key semantics: Commands are collected inner-first, then outer.
|
||||
For non-reducer fields, later Commands overwrite (outer wins).
|
||||
For reducer fields (messages), all Commands are additive.
|
||||
"""
|
||||
|
||||
def test_outer_command_messages_added_alongside_model(self) -> None:
|
||||
"""Outer middleware's command messages are added alongside model messages."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class OuterMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ExtendedModelResponse:
|
||||
execution_order.append("outer-before")
|
||||
response = handler(request)
|
||||
execution_order.append("outer-after")
|
||||
return ExtendedModelResponse(
|
||||
model_response=response,
|
||||
command=Command(
|
||||
update={"messages": [HumanMessage(content="Outer msg", id="outer-msg")]}
|
||||
),
|
||||
)
|
||||
|
||||
class InnerMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
execution_order.append("inner-before")
|
||||
response = handler(request)
|
||||
execution_order.append("inner-after")
|
||||
return response
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Composed")]))
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
middleware=[OuterMiddleware(), InnerMiddleware()],
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
# Execution order: outer wraps inner
|
||||
assert execution_order == [
|
||||
"outer-before",
|
||||
"inner-before",
|
||||
"inner-after",
|
||||
"outer-after",
|
||||
]
|
||||
|
||||
# Model messages + outer command messages (additive)
|
||||
messages = result["messages"]
|
||||
assert len(messages) == 3
|
||||
assert messages[0].content == "Hi"
|
||||
assert messages[1].content == "Composed"
|
||||
assert messages[2].content == "Outer msg"
|
||||
|
||||
def test_inner_command_propagated_through_composition(self) -> None:
|
||||
"""Inner middleware's ExtendedModelResponse command is propagated.
|
||||
|
||||
When inner middleware returns ExtendedModelResponse, its command is
|
||||
captured before normalizing to ModelResponse at the composition boundary
|
||||
and collected into the final result.
|
||||
"""
|
||||
|
||||
class OuterMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
# Outer sees a ModelResponse from handler (inner's ExtendedModelResponse
|
||||
# was normalized at the composition boundary)
|
||||
response = handler(request)
|
||||
assert isinstance(response, ModelResponse)
|
||||
return response
|
||||
|
||||
class InnerMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ExtendedModelResponse:
|
||||
response = handler(request)
|
||||
return ExtendedModelResponse(
|
||||
model_response=response,
|
||||
command=Command(
|
||||
update={
|
||||
"messages": [
|
||||
HumanMessage(content="Inner msg", id="inner"),
|
||||
]
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
middleware=[OuterMiddleware(), InnerMiddleware()],
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
# Model messages + inner command messages (additive)
|
||||
messages = result["messages"]
|
||||
assert len(messages) == 3
|
||||
assert messages[0].content == "Hi"
|
||||
assert messages[1].content == "Hello"
|
||||
assert messages[2].content == "Inner msg"
|
||||
|
||||
def test_non_reducer_key_conflict_raises(self) -> None:
|
||||
"""Multiple Commands setting the same non-reducer key raises.
|
||||
|
||||
LastValue channels (like custom_key) can only receive one value per
|
||||
step. Inner and outer both setting the same key is an error.
|
||||
"""
|
||||
|
||||
class MyState(AgentState):
|
||||
custom_key: str
|
||||
|
||||
class OuterMiddleware(AgentMiddleware):
|
||||
state_schema = MyState # type: ignore[assignment]
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ExtendedModelResponse:
|
||||
response = handler(request)
|
||||
return ExtendedModelResponse(
|
||||
model_response=response,
|
||||
command=Command(
|
||||
update={
|
||||
"messages": [HumanMessage(content="Outer msg", id="outer")],
|
||||
"custom_key": "outer_value",
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
class InnerMiddleware(AgentMiddleware):
|
||||
state_schema = MyState # type: ignore[assignment]
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ExtendedModelResponse:
|
||||
response = handler(request)
|
||||
return ExtendedModelResponse(
|
||||
model_response=response,
|
||||
command=Command(
|
||||
update={
|
||||
"messages": [HumanMessage(content="Inner msg", id="inner")],
|
||||
"custom_key": "inner_value",
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
middleware=[OuterMiddleware(), InnerMiddleware()],
|
||||
)
|
||||
|
||||
# Two Commands both setting custom_key (a LastValue channel)
|
||||
# in the same step raises InvalidUpdateError
|
||||
with pytest.raises(InvalidUpdateError):
|
||||
agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
def test_inner_state_preserved_when_outer_has_no_conflict(self) -> None:
|
||||
"""Inner's command keys are preserved when outer doesn't conflict."""
|
||||
|
||||
class MyState(AgentState):
|
||||
inner_key: str
|
||||
outer_key: str
|
||||
|
||||
class OuterMiddleware(AgentMiddleware):
|
||||
state_schema = MyState # type: ignore[assignment]
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ExtendedModelResponse:
|
||||
response = handler(request)
|
||||
return ExtendedModelResponse(
|
||||
model_response=response,
|
||||
command=Command(update={"outer_key": "from_outer"}),
|
||||
)
|
||||
|
||||
class InnerMiddleware(AgentMiddleware):
|
||||
state_schema = MyState # type: ignore[assignment]
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ExtendedModelResponse:
|
||||
response = handler(request)
|
||||
return ExtendedModelResponse(
|
||||
model_response=response,
|
||||
command=Command(update={"inner_key": "from_inner"}),
|
||||
)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
middleware=[OuterMiddleware(), InnerMiddleware()],
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
# Both keys survive since there's no conflict
|
||||
messages = result["messages"]
|
||||
assert messages[-1].content == "Hello"
|
||||
|
||||

    def test_inner_command_retry_safe(self) -> None:
        """When outer retries, only the last inner command is used."""
        call_count = 0

        class MyState(AgentState):
            attempt: str

        class OuterMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> ModelResponse:
                # Call handler twice (simulating retry)
                handler(request)
                return handler(request)

        class InnerMiddleware(AgentMiddleware):
            state_schema = MyState  # type: ignore[assignment]

            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> ExtendedModelResponse:
                nonlocal call_count
                call_count += 1
                response = handler(request)
                return ExtendedModelResponse(
                    model_response=response,
                    command=Command(update={"attempt": f"attempt_{call_count}"}),
                )

        model = GenericFakeChatModel(
            messages=iter([AIMessage(content="First"), AIMessage(content="Second")])
        )
        agent = create_agent(
            model=model,
            middleware=[OuterMiddleware(), InnerMiddleware()],
        )

        result = agent.invoke({"messages": [HumanMessage("Hi")]})

        # Only the last retry's inner state should survive
        messages = result["messages"]
        assert messages[-1].content == "Second"
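        # Hypothetical stronger check, assuming the `attempt` channel is
        # surfaced in the agent's output: the first attempt's command is
        # discarded when the outer middleware retries.
        #     assert result["attempt"] == "attempt_2"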

    def test_decorator_returns_wrap_result(self) -> None:
        """@wrap_model_call decorator can return ExtendedModelResponse with command."""

        @wrap_model_call
        def command_middleware(
            request: ModelRequest,
            handler: Callable[[ModelRequest], ModelResponse],
        ) -> ExtendedModelResponse:
            response = handler(request)
            return ExtendedModelResponse(
                model_response=response,
                command=Command(
                    update={
                        "messages": [
                            HumanMessage(content="From decorator", id="dec"),
                        ]
                    }
                ),
            )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Model response")]))
        agent = create_agent(model=model, middleware=[command_middleware])

        result = agent.invoke({"messages": [HumanMessage("Hi")]})

        messages = result["messages"]
        assert len(messages) == 3
        assert messages[1].content == "Model response"
        assert messages[2].content == "From decorator"

    def test_structured_response_preserved(self) -> None:
        """ExtendedModelResponse preserves structured_response from ModelResponse."""

        class StructuredMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> ExtendedModelResponse:
                response = handler(request)
                response_with_structured = ModelResponse(
                    result=response.result,
                    structured_response={"key": "value"},
                )
                return ExtendedModelResponse(
                    model_response=response_with_structured,
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
        agent = create_agent(model=model, middleware=[StructuredMiddleware()])

        result = agent.invoke({"messages": [HumanMessage("Hi")]})

        assert result.get("structured_response") == {"key": "value"}
        messages = result["messages"]
        assert len(messages) == 2
        assert messages[1].content == "Hello"


class TestAsyncComposition:
    """Test async ExtendedModelResponse propagation through composed middleware."""

    async def test_async_inner_command_propagated(self) -> None:
        """Async: inner middleware's ExtendedModelResponse command is propagated."""

        class OuterMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> ModelResponse:
                response = await handler(request)
                assert isinstance(response, ModelResponse)
                return response

        class InnerMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> ExtendedModelResponse:
                response = await handler(request)
                return ExtendedModelResponse(
                    model_response=response,
                    command=Command(
                        update={
                            "messages": [
                                HumanMessage(content="Inner msg", id="inner"),
                            ]
                        }
                    ),
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
        agent = create_agent(
            model=model,
            middleware=[OuterMiddleware(), InnerMiddleware()],
        )

        result = await agent.ainvoke({"messages": [HumanMessage("Hi")]})

        # Model messages + inner command messages (additive)
        messages = result["messages"]
        assert len(messages) == 3
        assert messages[0].content == "Hi"
        assert messages[1].content == "Hello"
        assert messages[2].content == "Inner msg"

    async def test_async_both_commands_additive_messages(self) -> None:
        """Async: both inner and outer command messages are added alongside model."""

        class OuterMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> ExtendedModelResponse:
                response = await handler(request)
                return ExtendedModelResponse(
                    model_response=response,
                    command=Command(
                        update={"messages": [HumanMessage(content="Outer msg", id="outer")]}
                    ),
                )

        class InnerMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> ExtendedModelResponse:
                response = await handler(request)
                return ExtendedModelResponse(
                    model_response=response,
                    command=Command(
                        update={"messages": [HumanMessage(content="Inner msg", id="inner")]}
                    ),
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
        agent = create_agent(
            model=model,
            middleware=[OuterMiddleware(), InnerMiddleware()],
        )

        result = await agent.ainvoke({"messages": [HumanMessage("Hi")]})

        # All messages additive: model + inner + outer
        messages = result["messages"]
        assert len(messages) == 4
        assert messages[0].content == "Hi"
        assert messages[1].content == "Hello"
        assert messages[2].content == "Inner msg"
        assert messages[3].content == "Outer msg"

    async def test_async_inner_command_retry_safe(self) -> None:
        """Async: when outer retries, only last inner command is used."""
        call_count = 0

        class MyState(AgentState):
            attempt: str

        class OuterMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> ModelResponse:
                # Call handler twice (simulating retry)
                await handler(request)
                return await handler(request)

        class InnerMiddleware(AgentMiddleware):
            state_schema = MyState  # type: ignore[assignment]

            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> ExtendedModelResponse:
                nonlocal call_count
                call_count += 1
                response = await handler(request)
                return ExtendedModelResponse(
                    model_response=response,
                    command=Command(update={"attempt": f"attempt_{call_count}"}),
                )

        model = GenericFakeChatModel(
            messages=iter([AIMessage(content="First"), AIMessage(content="Second")])
        )
        agent = create_agent(
            model=model,
            middleware=[OuterMiddleware(), InnerMiddleware()],
        )

        result = await agent.ainvoke({"messages": [HumanMessage("Hi")]})

        messages = result["messages"]
        assert any(m.content == "Second" for m in messages)


class TestCommandGotoDisallowed:
    """Test that Command goto raises NotImplementedError in wrap_model_call."""

    def test_command_goto_raises_not_implemented(self) -> None:
        """Command with goto in wrap_model_call raises NotImplementedError."""

        class GotoMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> ExtendedModelResponse:
                response = handler(request)
                return ExtendedModelResponse(
                    model_response=response,
                    command=Command(goto="__end__"),
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
        agent = create_agent(model=model, middleware=[GotoMiddleware()])

        with pytest.raises(NotImplementedError, match="Command goto is not yet supported"):
            agent.invoke({"messages": [HumanMessage(content="Hi")]})

    async def test_async_command_goto_raises_not_implemented(self) -> None:
        """Async: Command with goto in wrap_model_call raises NotImplementedError."""

        class AsyncGotoMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> ExtendedModelResponse:
                response = await handler(request)
                return ExtendedModelResponse(
                    model_response=response,
                    command=Command(goto="tools"),
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
        agent = create_agent(model=model, middleware=[AsyncGotoMiddleware()])

        with pytest.raises(NotImplementedError, match="Command goto is not yet supported"):
            await agent.ainvoke({"messages": [HumanMessage(content="Hi")]})


class TestCommandResumeDisallowed:
    """Test that Command resume raises NotImplementedError in wrap_model_call."""

    def test_command_resume_raises_not_implemented(self) -> None:
        """Command with resume in wrap_model_call raises NotImplementedError."""

        class ResumeMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> ExtendedModelResponse:
                response = handler(request)
                return ExtendedModelResponse(
                    model_response=response,
                    command=Command(resume="some_value"),
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
        agent = create_agent(model=model, middleware=[ResumeMiddleware()])

        with pytest.raises(NotImplementedError, match="Command resume is not yet supported"):
            agent.invoke({"messages": [HumanMessage(content="Hi")]})

    async def test_async_command_resume_raises_not_implemented(self) -> None:
        """Async: Command with resume in wrap_model_call raises NotImplementedError."""

        class AsyncResumeMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> ExtendedModelResponse:
                response = await handler(request)
                return ExtendedModelResponse(
                    model_response=response,
                    command=Command(resume="some_value"),
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
        agent = create_agent(model=model, middleware=[AsyncResumeMiddleware()])

        with pytest.raises(NotImplementedError, match="Command resume is not yet supported"):
            await agent.ainvoke({"messages": [HumanMessage(content="Hi")]})


class TestCommandGraphDisallowed:
    """Test that Command graph raises NotImplementedError in wrap_model_call."""

    def test_command_graph_raises_not_implemented(self) -> None:
        """Command with graph in wrap_model_call raises NotImplementedError."""

        class GraphMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> ExtendedModelResponse:
                response = handler(request)
                return ExtendedModelResponse(
                    model_response=response,
                    command=Command(graph=Command.PARENT, update={"messages": []}),
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
        agent = create_agent(model=model, middleware=[GraphMiddleware()])

        with pytest.raises(NotImplementedError, match="Command graph is not yet supported"):
            agent.invoke({"messages": [HumanMessage(content="Hi")]})

    async def test_async_command_graph_raises_not_implemented(self) -> None:
        """Async: Command with graph in wrap_model_call raises NotImplementedError."""

        class AsyncGraphMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> ExtendedModelResponse:
                response = await handler(request)
                return ExtendedModelResponse(
                    model_response=response,
                    command=Command(graph=Command.PARENT, update={"messages": []}),
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
        agent = create_agent(model=model, middleware=[AsyncGraphMiddleware()])

        with pytest.raises(NotImplementedError, match="Command graph is not yet supported"):
            await agent.ainvoke({"messages": [HumanMessage(content="Hi")]})
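
# Taken together, these tests pin down the currently supported surface:
# wrap_model_call middleware may only attach Command(update=...) to an
# ExtendedModelResponse; Command goto, resume, and graph each raise
# NotImplementedError for now.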

@@ -313,6 +313,7 @@ def test_configurable_with_default() -> None:
         "default_headers": None,
         "model_kwargs": {},
         "reuse_last_container": None,
+        "inference_geo": None,
         "streaming": False,
         "stream_usage": True,
         "output_version": None,

libs/langchain_v1/uv.lock (generated, 53 lines changed)
@@ -173,7 +173,7 @@ wheels = [

 [[package]]
 name = "anthropic"
-version = "0.72.1"
+version = "0.78.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "anyio" },
@@ -185,9 +185,9 @@ dependencies = [
     { name = "sniffio" },
     { name = "typing-extensions" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/dd/f3/feb750a21461090ecf48bbebcaa261cd09003cc1d14e2fa9643ad59edd4d/anthropic-0.72.1.tar.gz", hash = "sha256:a6d1d660e1f4af91dddc732f340786d19acaffa1ae8e69442e56be5fa6539d51", size = 415395, upload-time = "2025-11-11T16:53:29.001Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/ec/51/32849a48f9b1cfe80a508fd269b20bd8f0b1357c70ba092890fde5a6a10b/anthropic-0.78.0.tar.gz", hash = "sha256:55fd978ab9b049c61857463f4c4e9e092b24f892519c6d8078cee1713d8af06e", size = 509136, upload-time = "2026-02-05T17:52:04.986Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/51/05/d9d45edad1aa28330cea09a3b35e1590f7279f91bb5ab5237c70a0884ea3/anthropic-0.72.1-py3-none-any.whl", hash = "sha256:81e73cca55e8924776c8c4418003defe6bf9eaf0cd92beb94c8dbf537b95316f", size = 357373, upload-time = "2025-11-11T16:53:27.438Z" },
+    { url = "https://files.pythonhosted.org/packages/3b/03/2f50931a942e5e13f80e24d83406714672c57964be593fc046d81369335b/anthropic-0.78.0-py3-none-any.whl", hash = "sha256:2a9887d2e99d1b0f9fe08857a1e9fe5d2d4030455dbf9ac65aab052e2efaeac4", size = 405485, upload-time = "2026-02-05T17:52:03.674Z" },
 ]

 [[package]]
@@ -1975,7 +1975,7 @@ typing = [

 [package.metadata]
 requires-dist = [
-    { name = "langchain-anthropic", marker = "extra == 'anthropic'" },
+    { name = "langchain-anthropic", marker = "extra == 'anthropic'", editable = "../partners/anthropic" },
     { name = "langchain-aws", marker = "extra == 'aws'" },
     { name = "langchain-azure-ai", marker = "extra == 'azure-ai'" },
     { name = "langchain-community", marker = "extra == 'community'" },
@@ -2028,16 +2028,51 @@ typing = [

 [[package]]
 name = "langchain-anthropic"
-version = "1.0.3"
-source = { registry = "https://pypi.org/simple" }
+version = "1.3.1"
+source = { editable = "../partners/anthropic" }
 dependencies = [
     { name = "anthropic" },
     { name = "langchain-core" },
     { name = "pydantic" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/92/6b/aaa770beea6f4ed4c3f5c75fd6d80ed5c82708aec15318c06d9379dd3543/langchain_anthropic-1.0.3.tar.gz", hash = "sha256:91083c5df82634602f6772989918108d9448fa0b7499a11434687198f5bf9aef", size = 680336, upload-time = "2025-11-12T15:58:58.54Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/1e/3d/0499eeb10d333ea79e1c156b84d075e96cfde2fbc6a9ec9cbfa50ac3e47e/langchain_anthropic-1.0.3-py3-none-any.whl", hash = "sha256:0d4106111d57e19988e2976fb6c6e59b0c47ca7afb0f6a2f888362006f871871", size = 46608, upload-time = "2025-11-12T15:58:57.28Z" },
-]
+
+[package.metadata]
+requires-dist = [
+    { name = "anthropic", specifier = ">=0.75.0,<1.0.0" },
+    { name = "langchain-core", editable = "../core" },
+    { name = "pydantic", specifier = ">=2.7.4,<3.0.0" },
+]
+
+[package.metadata.requires-dev]
+dev = [{ name = "langchain-core", editable = "../core" }]
+lint = [{ name = "ruff", specifier = ">=0.13.1,<0.14.0" }]
+test = [
+    { name = "blockbuster", specifier = ">=1.5.5,<1.6" },
+    { name = "defusedxml", specifier = ">=0.7.1,<1.0.0" },
+    { name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
+    { name = "langchain", editable = "." },
+    { name = "langchain-core", editable = "../core" },
+    { name = "langchain-tests", editable = "../standard-tests" },
+    { name = "langgraph-prebuilt", specifier = ">=0.7.0a2" },
+    { name = "pytest", specifier = ">=7.3.0,<8.0.0" },
+    { name = "pytest-asyncio", specifier = ">=0.21.1,<1.0.0" },
+    { name = "pytest-mock", specifier = ">=3.10.0,<4.0.0" },
+    { name = "pytest-retry", specifier = ">=1.7.0,<1.8.0" },
+    { name = "pytest-socket", specifier = ">=0.7.0,<1.0.0" },
+    { name = "pytest-timeout", specifier = ">=2.3.1,<3.0.0" },
+    { name = "pytest-watcher", specifier = ">=0.3.4,<1.0.0" },
+    { name = "pytest-xdist", specifier = ">=3.8.0,<4.0.0" },
+    { name = "syrupy", specifier = ">=4.0.2,<5.0.0" },
+    { name = "vcrpy", specifier = ">=8.0.0,<9.0.0" },
+]
+test-integration = [
+    { name = "langchain-core", editable = "../core" },
+    { name = "requests", specifier = ">=2.32.3,<3.0.0" },
+]
+typing = [
+    { name = "langchain-core", editable = "../core" },
+    { name = "mypy", specifier = ">=1.17.1,<2.0.0" },
+    { name = "types-requests", specifier = ">=2.31.0,<3.0.0" },
+]

 [[package]]