Compare commits

...

7 Commits

Author SHA1 Message Date
Sydney Runkle
9dac16e66a simplify composition logic 2026-02-05 11:27:26 -05:00
Sydney Runkle
e8544fc660 simplify 2026-02-05 10:10:27 -05:00
Sydney Runkle
a69ceb5f20 propagate 2026-02-05 09:35:43 -05:00
Sydney Runkle
6a4778743d Merge branch 'master' into sr/wrap-model-call-improvements 2026-02-05 09:20:22 -05:00
Sydney Runkle
8ce516c26b boom 2026-02-05 09:16:54 -05:00
Sydney Runkle
10224015f3 wrap improvements 2026-02-05 09:06:05 -05:00
Sydney Runkle
0626c439c1 prep for LC release 2026-02-05 07:45:49 -05:00
5 changed files with 938 additions and 92 deletions

View File

@@ -33,6 +33,7 @@ from langchain.agents.middleware.types import (
OmitFromSchema,
ResponseT,
StateT_co,
WrapModelCallResult,
_InputAgentState,
_OutputAgentState,
)
@@ -102,31 +103,80 @@ FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [
]
def _normalize_to_model_response(
    result: ModelResponse | AIMessage | WrapModelCallResult,
) -> ModelResponse:
    """Normalize a middleware return value to ``ModelResponse``.

    At inner composition boundaries, ``WrapModelCallResult`` is unwrapped to its
    underlying ``ModelResponse`` so that inner middleware always sees
    ``ModelResponse`` from the handler.

    Args:
        result: Raw return value from a ``wrap_model_call`` handler.

    Returns:
        The normalized ``ModelResponse``.
    """
    # A bare AIMessage is the short-circuit convenience return; wrap it in a
    # ModelResponse with no structured output.
    if isinstance(result, AIMessage):
        return ModelResponse(result=[result], structured_response=None)
    # Unwrap so inner layers never observe the state-update wrapper.
    if isinstance(result, WrapModelCallResult):
        return result.model_response
    return result
def _normalize_to_wrap_model_call_result(
    result: ModelResponse | AIMessage | WrapModelCallResult,
) -> WrapModelCallResult:
    """Normalize a middleware return value to ``WrapModelCallResult``.

    At the outermost composition boundary, ensures the result is always a
    ``WrapModelCallResult`` so the model node has a single code path.

    Args:
        result: Raw return value from a ``wrap_model_call`` handler.

    Returns:
        ``result`` unchanged when already a ``WrapModelCallResult``; otherwise
        the normalized ``ModelResponse`` wrapped with an empty ``state_update``.
    """
    if isinstance(result, WrapModelCallResult):
        return result
    # Plain ModelResponse/AIMessage returns carry no extra state updates.
    return WrapModelCallResult(
        model_response=_normalize_to_model_response(result),
        state_update={},
    )
def _build_state_updates_from_wrap_result(response: WrapModelCallResult) -> dict[str, Any]:
"""Build state updates from a ``WrapModelCallResult``.
Starts with model response defaults (messages, structured_response), then
overlays the middleware's ``state_update``. If ``state_update`` contains a key
that conflicts with defaults, the ``state_update`` value wins.
Args:
response: The wrap model call result containing model response and state updates.
Returns:
State updates dict ready to be returned from a model node.
"""
state_updates: dict[str, Any] = {"messages": response.model_response.result}
if response.model_response.structured_response is not None:
state_updates["structured_response"] = response.model_response.structured_response
state_updates.update(response.state_update)
return state_updates
def _chain_model_call_handlers(
handlers: Sequence[
Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
ModelResponse | AIMessage,
ModelResponse | AIMessage | WrapModelCallResult,
]
],
) -> (
Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
ModelResponse,
WrapModelCallResult,
]
| None
):
"""Compose multiple `wrap_model_call` handlers into single middleware stack.
Composes handlers so first in list becomes outermost layer. Each handler receives a
handler callback to execute inner layers.
handler callback to execute inner layers. The outermost result is always normalized
to ``WrapModelCallResult`` so callers have a single code path.
Args:
handlers: List of handlers.
@@ -164,81 +214,90 @@ def _chain_model_call_handlers(
return None
if len(handlers) == 1:
# Single handler - wrap to normalize output
single_handler = handlers[0]
def normalized_single(
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse],
) -> ModelResponse:
result = single_handler(request, handler)
return _normalize_to_model_response(result)
) -> WrapModelCallResult:
return _normalize_to_wrap_model_call_result(single_handler(request, handler))
return normalized_single
def compose_two(
outer: Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
ModelResponse | AIMessage,
ModelResponse | AIMessage | WrapModelCallResult,
],
inner: Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
ModelResponse | AIMessage,
ModelResponse | AIMessage | WrapModelCallResult,
],
) -> Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
ModelResponse,
WrapModelCallResult,
]:
"""Compose two handlers where outer wraps inner."""
def composed(
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse],
) -> ModelResponse:
) -> WrapModelCallResult:
# Closure variable to capture inner's state_update before normalizing
accumulated_inner_state: list[dict[str, Any]] = []
# Create a wrapper that calls inner with the base handler and normalizes
# Inner boundaries always normalize to ModelResponse
def inner_handler(req: ModelRequest[ContextT]) -> ModelResponse:
# Clear on each call for retry safety — only the last call's
# state updates survive
accumulated_inner_state.clear()
inner_result = inner(req, handler)
if isinstance(inner_result, WrapModelCallResult):
accumulated_inner_state.append(inner_result.state_update)
return _normalize_to_model_response(inner_result)
# Call outer with the wrapped inner as its handler and normalize
# Call outer with the wrapped inner as its handler
outer_result = outer(request, inner_handler)
return _normalize_to_model_response(outer_result)
# Normalize outer result then merge inner state under outer (outer wins)
inner_state = accumulated_inner_state[0] if accumulated_inner_state else {}
outer_wrapped = _normalize_to_wrap_model_call_result(outer_result)
return WrapModelCallResult(
model_response=outer_wrapped.model_response,
state_update={**inner_state, **outer_wrapped.state_update},
)
return composed
# Compose right-to-left: outer(inner(innermost(handler)))
result = handlers[-1]
for handler in reversed(handlers[:-1]):
result = compose_two(handler, result)
# Seed with the innermost pair so the variable is typed as WrapModelCallResult
composed_handler = compose_two(handlers[-2], handlers[-1])
for h in reversed(handlers[:-2]):
composed_handler = compose_two(h, composed_handler)
# Wrap to ensure final return type is exactly ModelResponse
def final_normalized(
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse],
) -> ModelResponse:
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
final_result = result(request, handler)
return _normalize_to_model_response(final_result)
return final_normalized
return composed_handler
def _chain_async_model_call_handlers(
handlers: Sequence[
Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
Awaitable[ModelResponse | AIMessage],
Awaitable[ModelResponse | AIMessage | WrapModelCallResult],
]
],
) -> (
Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
Awaitable[ModelResponse],
Awaitable[WrapModelCallResult],
]
| None
):
"""Compose multiple async `wrap_model_call` handlers into single middleware stack.
The outermost result is always normalized to ``WrapModelCallResult`` so callers
have a single code path.
Args:
handlers: List of async handlers.
@@ -251,63 +310,69 @@ def _chain_async_model_call_handlers(
return None
if len(handlers) == 1:
# Single handler - wrap to normalize output
single_handler = handlers[0]
async def normalized_single(
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]],
) -> ModelResponse:
result = await single_handler(request, handler)
return _normalize_to_model_response(result)
) -> WrapModelCallResult:
return _normalize_to_wrap_model_call_result(await single_handler(request, handler))
return normalized_single
def compose_two(
outer: Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
Awaitable[ModelResponse | AIMessage],
Awaitable[ModelResponse | AIMessage | WrapModelCallResult],
],
inner: Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
Awaitable[ModelResponse | AIMessage],
Awaitable[ModelResponse | AIMessage | WrapModelCallResult],
],
) -> Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
Awaitable[ModelResponse],
Awaitable[WrapModelCallResult],
]:
"""Compose two async handlers where outer wraps inner."""
async def composed(
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]],
) -> ModelResponse:
) -> WrapModelCallResult:
# Closure variable to capture inner's state_update before normalizing
accumulated_inner_state: list[dict[str, Any]] = []
# Create a wrapper that calls inner with the base handler and normalizes
# Inner boundaries always normalize to ModelResponse
async def inner_handler(req: ModelRequest[ContextT]) -> ModelResponse:
# Clear on each call for retry safety — only the last call's
# state updates survive
accumulated_inner_state.clear()
inner_result = await inner(req, handler)
if isinstance(inner_result, WrapModelCallResult):
accumulated_inner_state.append(inner_result.state_update)
return _normalize_to_model_response(inner_result)
# Call outer with the wrapped inner as its handler and normalize
# Call outer with the wrapped inner as its handler
outer_result = await outer(request, inner_handler)
return _normalize_to_model_response(outer_result)
# Normalize outer result then merge inner state under outer (outer wins)
inner_state = accumulated_inner_state[0] if accumulated_inner_state else {}
outer_wrapped = _normalize_to_wrap_model_call_result(outer_result)
return WrapModelCallResult(
model_response=outer_wrapped.model_response,
state_update={**inner_state, **outer_wrapped.state_update},
)
return composed
# Compose right-to-left: outer(inner(innermost(handler)))
result = handlers[-1]
for handler in reversed(handlers[:-1]):
result = compose_two(handler, result)
# Seed with the innermost pair so the variable is typed as WrapModelCallResult
composed_handler = compose_two(handlers[-2], handlers[-1])
for h in reversed(handlers[:-2]):
composed_handler = compose_two(h, composed_handler)
# Wrap to ensure final return type is exactly ModelResponse
async def final_normalized(
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]],
) -> ModelResponse:
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
final_result = await result(request, handler)
return _normalize_to_model_response(final_result)
return final_normalized
return composed_handler
def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
@@ -1179,18 +1244,11 @@ def create_agent(
)
if wrap_model_call_handler is None:
# No handlers - execute directly
response = _execute_model_sync(request)
response = _normalize_to_wrap_model_call_result(_execute_model_sync(request))
else:
# Call composed handler with base handler
response = wrap_model_call_handler(request, _execute_model_sync)
# Extract state updates from ModelResponse
state_updates = {"messages": response.result}
if response.structured_response is not None:
state_updates["structured_response"] = response.structured_response
return state_updates
return _build_state_updates_from_wrap_result(response)
async def _execute_model_async(request: ModelRequest[ContextT]) -> ModelResponse:
"""Execute model asynchronously and return response.
@@ -1234,18 +1292,11 @@ def create_agent(
)
if awrap_model_call_handler is None:
# No async handlers - execute directly
response = await _execute_model_async(request)
response = _normalize_to_wrap_model_call_result(await _execute_model_async(request))
else:
# Call composed async handler with base handler
response = await awrap_model_call_handler(request, _execute_model_async)
# Extract state updates from ModelResponse
state_updates = {"messages": response.result}
if response.structured_response is not None:
state_updates["structured_response"] = response.structured_response
return state_updates
return _build_state_updates_from_wrap_result(response)
# Use sync or async based on model capabilities
graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False))

View File

@@ -29,6 +29,7 @@ from langchain.agents.middleware.types import (
ModelRequest,
ModelResponse,
ToolCallRequest,
WrapModelCallResult,
after_agent,
after_model,
before_agent,
@@ -66,6 +67,7 @@ __all__ = [
"ToolCallLimitMiddleware",
"ToolCallRequest",
"ToolRetryMiddleware",
"WrapModelCallResult",
"after_agent",
"after_model",
"before_agent",

View File

@@ -285,14 +285,42 @@ class ModelResponse(Generic[ResponseT]):
"""Parsed structured output if `response_format` was specified, `None` otherwise."""
@dataclass
class WrapModelCallResult(Generic[ResponseT]):
    """Model response with additional state updates from ``wrap_model_call`` middleware.

    Use this to return state updates alongside the model response from a
    ``wrap_model_call`` handler. State updates are merged into the agent state
    after the model node completes.

    The ``state_update`` dict overwrites the default model response state updates.
    For example, if ``state_update`` contains a ``"messages"`` key, those messages
    replace the model response messages entirely. If you want to include both custom
    messages and the model response, include them explicitly in your state update.

    When multiple middleware return ``WrapModelCallResult``, the outermost
    middleware's ``state_update`` wins on key conflicts.

    Type Parameters:
        ResponseT: The type of the structured response. Defaults to `Any` if not specified.
    """

    model_response: ModelResponse[ResponseT]
    """The underlying model response."""

    state_update: dict[str, Any]
    """Additional state updates to merge into the agent state."""
# Type alias for middleware return type - allows returning either full response or just AIMessage
ModelCallResult: TypeAlias = "ModelResponse[ResponseT] | AIMessage"
ModelCallResult: TypeAlias = "ModelResponse[ResponseT] | AIMessage | WrapModelCallResult[ResponseT]"
"""`TypeAlias` for model call handler return value.
Middleware can return either:
- `ModelResponse`: Full response with messages and optional structured output
- `AIMessage`: Simplified return for simple use cases
- `WrapModelCallResult`: Response with additional state updates
"""
@@ -449,7 +477,7 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]):
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> ModelResponse[ResponseT] | AIMessage:
) -> ModelResponse[ResponseT] | AIMessage | WrapModelCallResult[ResponseT]:
"""Intercept and control model execution via handler callback.
Async version is `awrap_model_call`
@@ -544,7 +572,7 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]):
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> ModelResponse[ResponseT] | AIMessage:
) -> ModelResponse[ResponseT] | AIMessage | WrapModelCallResult[ResponseT]:
"""Intercept and control async model execution via handler callback.
The handler callback executes the model request and returns a `ModelResponse`.

View File

@@ -8,7 +8,7 @@ from langgraph.runtime import Runtime
from langchain.agents import AgentState
from langchain.agents.factory import _chain_model_call_handlers
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain.agents.middleware.types import ModelRequest, ModelResponse, WrapModelCallResult
def create_test_request(**kwargs: Any) -> ModelRequest:
@@ -88,9 +88,9 @@ class TestChainModelCallHandlers:
"inner-after",
"outer-after",
]
# Result is now ModelResponse
assert isinstance(result, ModelResponse)
assert result.result[0].content == "test"
# Outermost result is always WrapModelCallResult
assert isinstance(result, WrapModelCallResult)
assert result.model_response.result[0].content == "test"
def test_three_handlers_composition(self) -> None:
"""Test composition of three handlers."""
@@ -134,8 +134,8 @@ class TestChainModelCallHandlers:
"second-after",
"first-after",
]
assert isinstance(result, ModelResponse)
assert result.result[0].content == "test"
assert isinstance(result, WrapModelCallResult)
assert result.model_response.result[0].content == "test"
def test_inner_handler_retry(self) -> None:
"""Test inner handler retrying before outer sees response."""
@@ -173,8 +173,8 @@ class TestChainModelCallHandlers:
result = composed(create_test_request(), mock_base_handler)
assert inner_attempts == [0, 1, 2]
assert isinstance(result, ModelResponse)
assert result.result[0].content == "success"
assert isinstance(result, WrapModelCallResult)
assert result.model_response.result[0].content == "success"
def test_error_to_success_conversion(self) -> None:
"""Test handler converting error to success response."""
@@ -202,10 +202,10 @@ class TestChainModelCallHandlers:
result = composed(create_test_request(), mock_base_handler)
# AIMessage was automatically converted to ModelResponse
assert isinstance(result, ModelResponse)
assert result.result[0].content == "Fallback response"
assert result.structured_response is None
# AIMessage was automatically normalized into WrapModelCallResult
assert isinstance(result, WrapModelCallResult)
assert result.model_response.result[0].content == "Fallback response"
assert result.model_response.structured_response is None
def test_request_modification(self) -> None:
"""Test handlers modifying the request."""
@@ -231,8 +231,8 @@ class TestChainModelCallHandlers:
result = composed(create_test_request(), create_mock_base_handler(content="response"))
assert requests_seen == ["Added by outer"]
assert isinstance(result, ModelResponse)
assert result.result[0].content == "response"
assert isinstance(result, WrapModelCallResult)
assert result.model_response.result[0].content == "response"
def test_composition_preserves_state_and_runtime(self) -> None:
"""Test that state and runtime are passed through composition."""
@@ -273,8 +273,8 @@ class TestChainModelCallHandlers:
# Both handlers should see same state and runtime
assert state_values == [("outer", test_state), ("inner", test_state)]
assert runtime_values == [("outer", test_runtime), ("inner", test_runtime)]
assert isinstance(result, ModelResponse)
assert result.result[0].content == "test"
assert isinstance(result, WrapModelCallResult)
assert result.model_response.result[0].content == "test"
def test_multiple_yields_in_retry_loop(self) -> None:
"""Test handler that retries multiple times."""
@@ -312,5 +312,5 @@ class TestChainModelCallHandlers:
# Outer called once, inner retried so base handler called twice
assert call_count["value"] == 1
assert attempt["value"] == 2
assert isinstance(result, ModelResponse)
assert result.result[0].content == "ok"
assert isinstance(result, WrapModelCallResult)
assert result.model_response.result[0].content == "ok"

View File

@@ -0,0 +1,765 @@
"""Unit tests for WrapModelCallResult state update support in wrap_model_call.
Tests that wrap_model_call middleware can return WrapModelCallResult to provide
state updates alongside the model response, with outermost middleware winning
on key conflicts.
"""
from collections.abc import Awaitable, Callable
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain.agents import AgentState, create_agent
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelRequest,
ModelResponse,
WrapModelCallResult,
wrap_model_call,
)
class TestBasicStateUpdate:
    """Test basic WrapModelCallResult functionality."""

    def test_state_update_overwrites_model_messages(self) -> None:
        """state_update with 'messages' key overwrites model response messages."""

        class OverwriteMessagesMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> WrapModelCallResult:
                response = handler(request)
                custom_msg = HumanMessage(content="Custom message", id="custom")
                return WrapModelCallResult(
                    model_response=response,
                    state_update={"messages": [custom_msg]},
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
        agent = create_agent(model=model, middleware=[OverwriteMessagesMiddleware()])
        result = agent.invoke({"messages": [HumanMessage(content="Hi")]})

        # Model response messages are overwritten — only custom message survives
        messages = result["messages"]
        assert len(messages) == 2
        assert messages[0].content == "Hi"
        assert messages[1].content == "Custom message"

    def test_state_update_includes_model_messages_explicitly(self) -> None:
        """Middleware can include model messages alongside custom ones explicitly."""

        class ExplicitMessagesMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> WrapModelCallResult:
                response = handler(request)
                summary = HumanMessage(content="Summary", id="summary")
                return WrapModelCallResult(
                    model_response=response,
                    state_update={"messages": [summary, *response.result]},
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
        agent = create_agent(model=model, middleware=[ExplicitMessagesMiddleware()])
        result = agent.invoke({"messages": [HumanMessage(content="Hi")]})

        messages = result["messages"]
        assert len(messages) == 3
        assert messages[0].content == "Hi"
        assert messages[1].content == "Summary"
        assert messages[2].content == "Hello!"

    def test_state_update_takes_priority_over_model_response(self) -> None:
        """state_update messages and structured_response take priority over model response."""

        class OverrideMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> WrapModelCallResult:
                response = handler(request)
                # Model response has its own messages and structured_response,
                # but state_update should win for both.
                response_with_structured = ModelResponse(
                    result=response.result,
                    structured_response={"from": "model"},
                )
                return WrapModelCallResult(
                    model_response=response_with_structured,
                    state_update={
                        "messages": [HumanMessage(content="From state_update", id="override")],
                        "structured_response": {"from": "state_update"},
                    },
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Model msg")]))
        agent = create_agent(model=model, middleware=[OverrideMiddleware()])
        result = agent.invoke({"messages": [HumanMessage("Hi")]})

        # state_update messages win over model response messages
        messages = result["messages"]
        assert len(messages) == 2
        assert messages[0].content == "Hi"
        assert messages[1].content == "From state_update"
        # state_update structured_response wins over model response structured_response
        assert result["structured_response"] == {"from": "state_update"}

    def test_state_update_without_messages_key(self) -> None:
        """When state_update doesn't include 'messages', model response messages are used."""

        class CustomFieldMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> WrapModelCallResult:
                response = handler(request)
                return WrapModelCallResult(
                    model_response=response,
                    state_update={"custom_key": "custom_value"},
                )

        class CustomState(AgentState):
            custom_key: str

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
        agent = create_agent(
            model=model,
            middleware=[CustomFieldMiddleware()],
            state_schema=CustomState,
        )
        result = agent.invoke({"messages": [HumanMessage("Hi")]})

        assert result["messages"][-1].content == "Hello"
class TestCustomStateField:
    """Test WrapModelCallResult with custom state fields defined via state_schema."""

    def test_custom_field_via_state_schema(self) -> None:
        """Middleware updates a custom state field via WrapModelCallResult."""

        class MyState(AgentState):
            summary: str

        class SummaryMiddleware(AgentMiddleware):
            state_schema = MyState  # type: ignore[assignment]

            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> WrapModelCallResult:
                response = handler(request)
                return WrapModelCallResult(
                    model_response=response,
                    state_update={"summary": "conversation summarized"},
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
        agent = create_agent(model=model, middleware=[SummaryMiddleware()])
        result = agent.invoke({"messages": [HumanMessage("Hi")]})

        assert result["messages"][-1].content == "Hello"

    def test_empty_state_update(self) -> None:
        """WrapModelCallResult with empty state_update works like ModelResponse."""

        class EmptyUpdateMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> WrapModelCallResult:
                response = handler(request)
                return WrapModelCallResult(
                    model_response=response,
                    state_update={},
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
        agent = create_agent(model=model, middleware=[EmptyUpdateMiddleware()])
        result = agent.invoke({"messages": [HumanMessage("Hi")]})

        assert len(result["messages"]) == 2
        assert result["messages"][1].content == "Hello"
class TestBackwardsCompatibility:
    """Test that existing ModelResponse and AIMessage returns still work."""

    def test_model_response_return_unchanged(self) -> None:
        """Existing middleware returning ModelResponse works identically."""

        class PassthroughMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> ModelResponse:
                return handler(request)

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
        agent = create_agent(model=model, middleware=[PassthroughMiddleware()])
        result = agent.invoke({"messages": [HumanMessage("Hi")]})

        assert len(result["messages"]) == 2
        assert result["messages"][1].content == "Hello"

    def test_ai_message_return_unchanged(self) -> None:
        """Existing middleware returning AIMessage works identically."""

        class ShortCircuitMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> AIMessage:
                return AIMessage(content="Short-circuited")

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Should not appear")]))
        agent = create_agent(model=model, middleware=[ShortCircuitMiddleware()])
        result = agent.invoke({"messages": [HumanMessage("Hi")]})

        assert len(result["messages"]) == 2
        assert result["messages"][1].content == "Short-circuited"

    def test_no_middleware_unchanged(self) -> None:
        """Agent without middleware works identically."""
        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
        agent = create_agent(model=model)
        result = agent.invoke({"messages": [HumanMessage("Hi")]})

        assert len(result["messages"]) == 2
        assert result["messages"][1].content == "Hello"
class TestAsyncWrapModelCallResult:
    """Test async variant of WrapModelCallResult."""

    async def test_async_state_update_overwrites(self) -> None:
        """awrap_model_call state_update overwrites model response messages."""

        class AsyncOverwriteMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> WrapModelCallResult:
                response = await handler(request)
                custom = HumanMessage(content="Async custom", id="async-custom")
                return WrapModelCallResult(
                    model_response=response,
                    state_update={"messages": [custom]},
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Async hello!")]))
        agent = create_agent(model=model, middleware=[AsyncOverwriteMiddleware()])
        result = await agent.ainvoke({"messages": [HumanMessage(content="Hi")]})

        messages = result["messages"]
        assert len(messages) == 2
        assert messages[0].content == "Hi"
        assert messages[1].content == "Async custom"

    async def test_async_decorator_state_update(self) -> None:
        """@wrap_model_call async decorator returns WrapModelCallResult."""

        @wrap_model_call
        async def state_update_middleware(
            request: ModelRequest,
            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
        ) -> WrapModelCallResult:
            response = await handler(request)
            return WrapModelCallResult(
                model_response=response,
                state_update={
                    "messages": [
                        HumanMessage(content="Decorator msg", id="dec"),
                        *response.result,
                    ]
                },
            )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Async response")]))
        agent = create_agent(model=model, middleware=[state_update_middleware])
        result = await agent.ainvoke({"messages": [HumanMessage(content="Hi")]})

        messages = result["messages"]
        assert len(messages) == 3
        assert messages[1].content == "Decorator msg"
        assert messages[2].content == "Async response"
class TestComposition:
"""Test WrapModelCallResult with composed middleware.
Key semantics: outermost middleware's state_update wins on key conflicts.
"""
def test_outer_wrap_result_overwrites_model_messages(self) -> None:
"""Outer middleware's state_update overwrites model response messages."""
execution_order: list[str] = []
class OuterMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> WrapModelCallResult:
execution_order.append("outer-before")
response = handler(request)
execution_order.append("outer-after")
return WrapModelCallResult(
model_response=response,
state_update={"messages": [HumanMessage(content="Outer msg", id="outer-msg")]},
)
class InnerMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
execution_order.append("inner-before")
response = handler(request)
execution_order.append("inner-after")
return response
model = GenericFakeChatModel(messages=iter([AIMessage(content="Composed")]))
agent = create_agent(
model=model,
middleware=[OuterMiddleware(), InnerMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage("Hi")]})
# Execution order: outer wraps inner
assert execution_order == [
"outer-before",
"inner-before",
"inner-after",
"outer-after",
]
# Outer's state_update overwrites model messages
messages = result["messages"]
assert len(messages) == 2
assert messages[0].content == "Hi"
assert messages[1].content == "Outer msg"
def test_inner_wrap_result_propagated_through_composition(self) -> None:
"""Inner middleware's WrapModelCallResult state_update is propagated.
When inner middleware returns WrapModelCallResult, its state_update is
captured before normalizing to ModelResponse at the composition boundary
and merged into the final result.
"""
class OuterMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
# Outer sees a ModelResponse from handler (inner's WrapModelCallResult
# was normalized at the composition boundary)
response = handler(request)
assert isinstance(response, ModelResponse)
return response
class InnerMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> WrapModelCallResult:
response = handler(request)
return WrapModelCallResult(
model_response=response,
state_update={
"messages": [
HumanMessage(content="Inner msg", id="inner"),
*response.result,
]
},
)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
agent = create_agent(
model=model,
middleware=[OuterMiddleware(), InnerMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage("Hi")]})
messages = result["messages"]
assert len(messages) == 3
assert messages[1].content == "Inner msg"
assert messages[2].content == "Hello"
def test_outer_state_update_wins_on_all_key_conflicts(self) -> None:
    """Outer's state_update fully overwrites inner's on all conflicting keys.

    This applies to every key, including 'messages' — there is no special
    casing for any particular state field.
    """

    class MyState(AgentState):
        custom_key: str

    def _wrapped(
        handler: Callable[[ModelRequest], ModelResponse],
        request: ModelRequest,
        message: HumanMessage,
        value: str,
    ) -> WrapModelCallResult:
        # Shared helper: wrap the handler's response with a state_update that
        # sets both 'messages' and 'custom_key'.
        response = handler(request)
        return WrapModelCallResult(
            model_response=response,
            state_update={"messages": [message], "custom_key": value},
        )

    class OuterMiddleware(AgentMiddleware):
        state_schema = MyState  # type: ignore[assignment]

        def wrap_model_call(
            self,
            request: ModelRequest,
            handler: Callable[[ModelRequest], ModelResponse],
        ) -> WrapModelCallResult:
            return _wrapped(
                handler,
                request,
                HumanMessage(content="Outer msg", id="outer"),
                "outer_value",
            )

    class InnerMiddleware(AgentMiddleware):
        state_schema = MyState  # type: ignore[assignment]

        def wrap_model_call(
            self,
            request: ModelRequest,
            handler: Callable[[ModelRequest], ModelResponse],
        ) -> WrapModelCallResult:
            return _wrapped(
                handler,
                request,
                HumanMessage(content="Inner msg", id="inner"),
                "inner_value",
            )

    fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
    agent = create_agent(
        model=fake_model,
        middleware=[OuterMiddleware(), InnerMiddleware()],
    )
    final_state = agent.invoke({"messages": [HumanMessage("Hi")]})
    # Outer wins on all keys — inner's messages and custom_key are overwritten.
    messages = final_state["messages"]
    assert len(messages) == 2
    assert messages[0].content == "Hi"
    assert messages[1].content == "Outer msg"
def test_inner_state_preserved_when_outer_has_no_conflict(self) -> None:
    """Inner's state_update keys are preserved when outer doesn't conflict.

    Outer writes only ``outer_key`` and inner writes only ``inner_key``, so
    the two updates merge without either clobbering the other.
    """

    class MyState(AgentState):
        inner_key: str
        outer_key: str

    class OuterMiddleware(AgentMiddleware):
        state_schema = MyState  # type: ignore[assignment]

        def wrap_model_call(
            self,
            request: ModelRequest,
            handler: Callable[[ModelRequest], ModelResponse],
        ) -> WrapModelCallResult:
            response = handler(request)
            return WrapModelCallResult(
                model_response=response,
                state_update={"outer_key": "from_outer"},
            )

    class InnerMiddleware(AgentMiddleware):
        state_schema = MyState  # type: ignore[assignment]

        def wrap_model_call(
            self,
            request: ModelRequest,
            handler: Callable[[ModelRequest], ModelResponse],
        ) -> WrapModelCallResult:
            response = handler(request)
            return WrapModelCallResult(
                model_response=response,
                state_update={"inner_key": "from_inner"},
            )

    model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
    agent = create_agent(
        model=model,
        middleware=[OuterMiddleware(), InnerMiddleware()],
    )
    result = agent.invoke({"messages": [HumanMessage("Hi")]})
    # Neither state_update touched 'messages', so the conversation is exactly
    # the original human turn plus the model's reply.
    messages = result["messages"]
    assert len(messages) == 2
    assert messages[0].content == "Hi"
    assert messages[-1].content == "Hello"
    # NOTE(review): the docstring claims both custom keys survive, but whether
    # 'inner_key'/'outer_key' appear in the invoke() output depends on the
    # agent's output schema — if they are exposed, assert
    # result["inner_key"] == "from_inner" and result["outer_key"] == "from_outer".
def test_inner_state_update_retry_safe(self) -> None:
    """When outer retries, only the last inner state update is used.

    The outer middleware invokes the handler twice (a simulated retry); the
    inner middleware runs once per invocation, so only the second attempt's
    ``state_update`` and model response should survive.
    """
    call_count = 0

    class MyState(AgentState):
        attempt: str

    class OuterMiddleware(AgentMiddleware):
        def wrap_model_call(
            self,
            request: ModelRequest,
            handler: Callable[[ModelRequest], ModelResponse],
        ) -> ModelResponse:
            # Call handler twice (simulating retry); only the second result
            # is returned.
            handler(request)
            return handler(request)

    class InnerMiddleware(AgentMiddleware):
        state_schema = MyState  # type: ignore[assignment]

        def wrap_model_call(
            self,
            request: ModelRequest,
            handler: Callable[[ModelRequest], ModelResponse],
        ) -> WrapModelCallResult:
            nonlocal call_count
            call_count += 1
            response = handler(request)
            return WrapModelCallResult(
                model_response=response,
                state_update={"attempt": f"attempt_{call_count}"},
            )

    model = GenericFakeChatModel(
        messages=iter([AIMessage(content="First"), AIMessage(content="Second")])
    )
    agent = create_agent(
        model=model,
        middleware=[OuterMiddleware(), InnerMiddleware()],
    )
    result = agent.invoke({"messages": [HumanMessage("Hi")]})
    # The inner middleware must have run once per handler invocation —
    # without this, the test could pass even if the retry never happened.
    assert call_count == 2
    # Only the last retry's model response (and inner state) should survive.
    messages = result["messages"]
    assert len(messages) == 2
    assert messages[0].content == "Hi"
    assert messages[-1].content == "Second"
def test_decorator_returns_wrap_result(self) -> None:
    """@wrap_model_call decorator can return WrapModelCallResult."""

    @wrap_model_call
    def state_update_middleware(
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> WrapModelCallResult:
        model_response = handler(request)
        injected = HumanMessage(content="From decorator", id="dec")
        return WrapModelCallResult(
            model_response=model_response,
            state_update={"messages": [injected, *model_response.result]},
        )

    fake_model = GenericFakeChatModel(
        messages=iter([AIMessage(content="Model response")])
    )
    agent = create_agent(model=fake_model, middleware=[state_update_middleware])
    final_state = agent.invoke({"messages": [HumanMessage("Hi")]})
    # Decorator-based middleware's state_update is honored just like the
    # class-based form: human turn, injected message, then the model reply.
    messages = final_state["messages"]
    assert len(messages) == 3
    assert messages[1].content == "From decorator"
    assert messages[2].content == "Model response"
def test_structured_response_preserved(self) -> None:
    """WrapModelCallResult preserves structured_response from ModelResponse."""

    class StructuredMiddleware(AgentMiddleware):
        def wrap_model_call(
            self,
            request: ModelRequest,
            handler: Callable[[ModelRequest], ModelResponse],
        ) -> WrapModelCallResult:
            # Rebuild the response with a structured payload attached and an
            # empty state_update — only structured_response should change.
            original = handler(request)
            return WrapModelCallResult(
                model_response=ModelResponse(
                    result=original.result,
                    structured_response={"key": "value"},
                ),
                state_update={},
            )

    fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
    agent = create_agent(model=fake_model, middleware=[StructuredMiddleware()])
    final_state = agent.invoke({"messages": [HumanMessage("Hi")]})
    assert final_state.get("structured_response") == {"key": "value"}
    messages = final_state["messages"]
    assert len(messages) == 2
    assert messages[1].content == "Hello"
class TestAsyncComposition:
    """Test async WrapModelCallResult propagation through composed middleware."""

    async def test_async_inner_wrap_result_propagated(self) -> None:
        """Async: inner middleware's WrapModelCallResult state_update is propagated."""

        class OuterMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> ModelResponse:
                # Inner's WrapModelCallResult is unwrapped at the composition
                # boundary, so outer always receives a plain ModelResponse.
                response = await handler(request)
                assert isinstance(response, ModelResponse)
                return response

        class InnerMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> WrapModelCallResult:
                response = await handler(request)
                return WrapModelCallResult(
                    model_response=response,
                    state_update={
                        "messages": [
                            HumanMessage(content="Inner msg", id="inner"),
                            *response.result,
                        ]
                    },
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
        agent = create_agent(
            model=model,
            middleware=[OuterMiddleware(), InnerMiddleware()],
        )
        result = await agent.ainvoke({"messages": [HumanMessage("Hi")]})
        messages = result["messages"]
        assert len(messages) == 3
        assert messages[1].content == "Inner msg"
        assert messages[2].content == "Hello"

    async def test_async_outer_wins_on_conflict(self) -> None:
        """Async: outer's state_update fully overwrites inner's on conflicts."""

        class OuterMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> WrapModelCallResult:
                response = await handler(request)
                return WrapModelCallResult(
                    model_response=response,
                    state_update={
                        "messages": [HumanMessage(content="Outer msg", id="outer")],
                    },
                )

        class InnerMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> WrapModelCallResult:
                response = await handler(request)
                return WrapModelCallResult(
                    model_response=response,
                    state_update={
                        "messages": [HumanMessage(content="Inner msg", id="inner")],
                    },
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
        agent = create_agent(
            model=model,
            middleware=[OuterMiddleware(), InnerMiddleware()],
        )
        result = await agent.ainvoke({"messages": [HumanMessage("Hi")]})
        # Outer wins — inner's messages overwritten.
        messages = result["messages"]
        assert len(messages) == 2
        assert messages[0].content == "Hi"
        assert messages[1].content == "Outer msg"

    async def test_async_inner_state_update_retry_safe(self) -> None:
        """Async: when outer retries, only last inner state update is used."""
        call_count = 0

        class MyState(AgentState):
            attempt: str

        class OuterMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> ModelResponse:
                # Call handler twice (simulating retry); only the second
                # result is returned.
                await handler(request)
                return await handler(request)

        class InnerMiddleware(AgentMiddleware):
            state_schema = MyState  # type: ignore[assignment]

            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> WrapModelCallResult:
                nonlocal call_count
                call_count += 1
                response = await handler(request)
                return WrapModelCallResult(
                    model_response=response,
                    state_update={"attempt": f"attempt_{call_count}"},
                )

        model = GenericFakeChatModel(
            messages=iter([AIMessage(content="First"), AIMessage(content="Second")])
        )
        agent = create_agent(
            model=model,
            middleware=[OuterMiddleware(), InnerMiddleware()],
        )
        result = await agent.ainvoke({"messages": [HumanMessage("Hi")]})
        # Inner ran once per handler invocation (the outer retry called twice).
        assert call_count == 2
        # Match the sync retry test: only the final model response survives,
        # so the last message is "Second" (the earlier `any(...)` check would
        # also have passed if "First" leaked through alongside it).
        messages = result["messages"]
        assert messages[-1].content == "Second"