mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-15 17:44:08 +00:00
Compare commits
7 Commits
v1.2
...
sr/wrap-mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9dac16e66a | ||
|
|
e8544fc660 | ||
|
|
a69ceb5f20 | ||
|
|
6a4778743d | ||
|
|
8ce516c26b | ||
|
|
10224015f3 | ||
|
|
0626c439c1 |
@@ -33,6 +33,7 @@ from langchain.agents.middleware.types import (
|
||||
OmitFromSchema,
|
||||
ResponseT,
|
||||
StateT_co,
|
||||
WrapModelCallResult,
|
||||
_InputAgentState,
|
||||
_OutputAgentState,
|
||||
)
|
||||
@@ -102,31 +103,80 @@ FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [
|
||||
]
|
||||
|
||||
|
||||
def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
|
||||
"""Normalize middleware return value to ModelResponse."""
|
||||
def _normalize_to_model_response(
|
||||
result: ModelResponse | AIMessage | WrapModelCallResult,
|
||||
) -> ModelResponse:
|
||||
"""Normalize middleware return value to ModelResponse.
|
||||
|
||||
At inner composition boundaries, ``WrapModelCallResult`` is unwrapped to its
|
||||
underlying ``ModelResponse`` so that inner middleware always sees ``ModelResponse``
|
||||
from the handler.
|
||||
"""
|
||||
if isinstance(result, AIMessage):
|
||||
return ModelResponse(result=[result], structured_response=None)
|
||||
if isinstance(result, WrapModelCallResult):
|
||||
return result.model_response
|
||||
return result
|
||||
|
||||
|
||||
def _normalize_to_wrap_model_call_result(
|
||||
result: ModelResponse | AIMessage | WrapModelCallResult,
|
||||
) -> WrapModelCallResult:
|
||||
"""Normalize middleware return value to WrapModelCallResult.
|
||||
|
||||
At the outermost composition boundary, ensures the result is always a
|
||||
``WrapModelCallResult`` so the model node has a single code path.
|
||||
"""
|
||||
if isinstance(result, WrapModelCallResult):
|
||||
return result
|
||||
return WrapModelCallResult(
|
||||
model_response=_normalize_to_model_response(result),
|
||||
state_update={},
|
||||
)
|
||||
|
||||
|
||||
def _build_state_updates_from_wrap_result(response: WrapModelCallResult) -> dict[str, Any]:
|
||||
"""Build state updates from a ``WrapModelCallResult``.
|
||||
|
||||
Starts with model response defaults (messages, structured_response), then
|
||||
overlays the middleware's ``state_update``. If ``state_update`` contains a key
|
||||
that conflicts with defaults, the ``state_update`` value wins.
|
||||
|
||||
Args:
|
||||
response: The wrap model call result containing model response and state updates.
|
||||
|
||||
Returns:
|
||||
State updates dict ready to be returned from a model node.
|
||||
"""
|
||||
state_updates: dict[str, Any] = {"messages": response.model_response.result}
|
||||
|
||||
if response.model_response.structured_response is not None:
|
||||
state_updates["structured_response"] = response.model_response.structured_response
|
||||
|
||||
state_updates.update(response.state_update)
|
||||
|
||||
return state_updates
|
||||
|
||||
|
||||
def _chain_model_call_handlers(
|
||||
handlers: Sequence[
|
||||
Callable[
|
||||
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
|
||||
ModelResponse | AIMessage,
|
||||
ModelResponse | AIMessage | WrapModelCallResult,
|
||||
]
|
||||
],
|
||||
) -> (
|
||||
Callable[
|
||||
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
|
||||
ModelResponse,
|
||||
WrapModelCallResult,
|
||||
]
|
||||
| None
|
||||
):
|
||||
"""Compose multiple `wrap_model_call` handlers into single middleware stack.
|
||||
|
||||
Composes handlers so first in list becomes outermost layer. Each handler receives a
|
||||
handler callback to execute inner layers.
|
||||
handler callback to execute inner layers. The outermost result is always normalized
|
||||
to ``WrapModelCallResult`` so callers have a single code path.
|
||||
|
||||
Args:
|
||||
handlers: List of handlers.
|
||||
@@ -164,81 +214,90 @@ def _chain_model_call_handlers(
|
||||
return None
|
||||
|
||||
if len(handlers) == 1:
|
||||
# Single handler - wrap to normalize output
|
||||
single_handler = handlers[0]
|
||||
|
||||
def normalized_single(
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
result = single_handler(request, handler)
|
||||
return _normalize_to_model_response(result)
|
||||
) -> WrapModelCallResult:
|
||||
return _normalize_to_wrap_model_call_result(single_handler(request, handler))
|
||||
|
||||
return normalized_single
|
||||
|
||||
def compose_two(
|
||||
outer: Callable[
|
||||
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
|
||||
ModelResponse | AIMessage,
|
||||
ModelResponse | AIMessage | WrapModelCallResult,
|
||||
],
|
||||
inner: Callable[
|
||||
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
|
||||
ModelResponse | AIMessage,
|
||||
ModelResponse | AIMessage | WrapModelCallResult,
|
||||
],
|
||||
) -> Callable[
|
||||
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
|
||||
ModelResponse,
|
||||
WrapModelCallResult,
|
||||
]:
|
||||
"""Compose two handlers where outer wraps inner."""
|
||||
|
||||
def composed(
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
) -> WrapModelCallResult:
|
||||
# Closure variable to capture inner's state_update before normalizing
|
||||
accumulated_inner_state: list[dict[str, Any]] = []
|
||||
|
||||
# Create a wrapper that calls inner with the base handler and normalizes
|
||||
# Inner boundaries always normalize to ModelResponse
|
||||
def inner_handler(req: ModelRequest[ContextT]) -> ModelResponse:
|
||||
# Clear on each call for retry safety — only the last call's
|
||||
# state updates survive
|
||||
accumulated_inner_state.clear()
|
||||
inner_result = inner(req, handler)
|
||||
if isinstance(inner_result, WrapModelCallResult):
|
||||
accumulated_inner_state.append(inner_result.state_update)
|
||||
return _normalize_to_model_response(inner_result)
|
||||
|
||||
# Call outer with the wrapped inner as its handler and normalize
|
||||
# Call outer with the wrapped inner as its handler
|
||||
outer_result = outer(request, inner_handler)
|
||||
return _normalize_to_model_response(outer_result)
|
||||
|
||||
# Normalize outer result then merge inner state under outer (outer wins)
|
||||
inner_state = accumulated_inner_state[0] if accumulated_inner_state else {}
|
||||
outer_wrapped = _normalize_to_wrap_model_call_result(outer_result)
|
||||
return WrapModelCallResult(
|
||||
model_response=outer_wrapped.model_response,
|
||||
state_update={**inner_state, **outer_wrapped.state_update},
|
||||
)
|
||||
|
||||
return composed
|
||||
|
||||
# Compose right-to-left: outer(inner(innermost(handler)))
|
||||
result = handlers[-1]
|
||||
for handler in reversed(handlers[:-1]):
|
||||
result = compose_two(handler, result)
|
||||
# Seed with the innermost pair so the variable is typed as WrapModelCallResult
|
||||
composed_handler = compose_two(handlers[-2], handlers[-1])
|
||||
for h in reversed(handlers[:-2]):
|
||||
composed_handler = compose_two(h, composed_handler)
|
||||
|
||||
# Wrap to ensure final return type is exactly ModelResponse
|
||||
def final_normalized(
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
|
||||
final_result = result(request, handler)
|
||||
return _normalize_to_model_response(final_result)
|
||||
|
||||
return final_normalized
|
||||
return composed_handler
|
||||
|
||||
|
||||
def _chain_async_model_call_handlers(
|
||||
handlers: Sequence[
|
||||
Callable[
|
||||
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
|
||||
Awaitable[ModelResponse | AIMessage],
|
||||
Awaitable[ModelResponse | AIMessage | WrapModelCallResult],
|
||||
]
|
||||
],
|
||||
) -> (
|
||||
Callable[
|
||||
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
|
||||
Awaitable[ModelResponse],
|
||||
Awaitable[WrapModelCallResult],
|
||||
]
|
||||
| None
|
||||
):
|
||||
"""Compose multiple async `wrap_model_call` handlers into single middleware stack.
|
||||
|
||||
The outermost result is always normalized to ``WrapModelCallResult`` so callers
|
||||
have a single code path.
|
||||
|
||||
Args:
|
||||
handlers: List of async handlers.
|
||||
|
||||
@@ -251,63 +310,69 @@ def _chain_async_model_call_handlers(
|
||||
return None
|
||||
|
||||
if len(handlers) == 1:
|
||||
# Single handler - wrap to normalize output
|
||||
single_handler = handlers[0]
|
||||
|
||||
async def normalized_single(
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
result = await single_handler(request, handler)
|
||||
return _normalize_to_model_response(result)
|
||||
) -> WrapModelCallResult:
|
||||
return _normalize_to_wrap_model_call_result(await single_handler(request, handler))
|
||||
|
||||
return normalized_single
|
||||
|
||||
def compose_two(
|
||||
outer: Callable[
|
||||
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
|
||||
Awaitable[ModelResponse | AIMessage],
|
||||
Awaitable[ModelResponse | AIMessage | WrapModelCallResult],
|
||||
],
|
||||
inner: Callable[
|
||||
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
|
||||
Awaitable[ModelResponse | AIMessage],
|
||||
Awaitable[ModelResponse | AIMessage | WrapModelCallResult],
|
||||
],
|
||||
) -> Callable[
|
||||
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]],
|
||||
Awaitable[ModelResponse],
|
||||
Awaitable[WrapModelCallResult],
|
||||
]:
|
||||
"""Compose two async handlers where outer wraps inner."""
|
||||
|
||||
async def composed(
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
) -> WrapModelCallResult:
|
||||
# Closure variable to capture inner's state_update before normalizing
|
||||
accumulated_inner_state: list[dict[str, Any]] = []
|
||||
|
||||
# Create a wrapper that calls inner with the base handler and normalizes
|
||||
# Inner boundaries always normalize to ModelResponse
|
||||
async def inner_handler(req: ModelRequest[ContextT]) -> ModelResponse:
|
||||
# Clear on each call for retry safety — only the last call's
|
||||
# state updates survive
|
||||
accumulated_inner_state.clear()
|
||||
inner_result = await inner(req, handler)
|
||||
if isinstance(inner_result, WrapModelCallResult):
|
||||
accumulated_inner_state.append(inner_result.state_update)
|
||||
return _normalize_to_model_response(inner_result)
|
||||
|
||||
# Call outer with the wrapped inner as its handler and normalize
|
||||
# Call outer with the wrapped inner as its handler
|
||||
outer_result = await outer(request, inner_handler)
|
||||
return _normalize_to_model_response(outer_result)
|
||||
|
||||
# Normalize outer result then merge inner state under outer (outer wins)
|
||||
inner_state = accumulated_inner_state[0] if accumulated_inner_state else {}
|
||||
outer_wrapped = _normalize_to_wrap_model_call_result(outer_result)
|
||||
return WrapModelCallResult(
|
||||
model_response=outer_wrapped.model_response,
|
||||
state_update={**inner_state, **outer_wrapped.state_update},
|
||||
)
|
||||
|
||||
return composed
|
||||
|
||||
# Compose right-to-left: outer(inner(innermost(handler)))
|
||||
result = handlers[-1]
|
||||
for handler in reversed(handlers[:-1]):
|
||||
result = compose_two(handler, result)
|
||||
# Seed with the innermost pair so the variable is typed as WrapModelCallResult
|
||||
composed_handler = compose_two(handlers[-2], handlers[-1])
|
||||
for h in reversed(handlers[:-2]):
|
||||
composed_handler = compose_two(h, composed_handler)
|
||||
|
||||
# Wrap to ensure final return type is exactly ModelResponse
|
||||
async def final_normalized(
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
|
||||
final_result = await result(request, handler)
|
||||
return _normalize_to_model_response(final_result)
|
||||
|
||||
return final_normalized
|
||||
return composed_handler
|
||||
|
||||
|
||||
def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
|
||||
@@ -1179,18 +1244,11 @@ def create_agent(
|
||||
)
|
||||
|
||||
if wrap_model_call_handler is None:
|
||||
# No handlers - execute directly
|
||||
response = _execute_model_sync(request)
|
||||
response = _normalize_to_wrap_model_call_result(_execute_model_sync(request))
|
||||
else:
|
||||
# Call composed handler with base handler
|
||||
response = wrap_model_call_handler(request, _execute_model_sync)
|
||||
|
||||
# Extract state updates from ModelResponse
|
||||
state_updates = {"messages": response.result}
|
||||
if response.structured_response is not None:
|
||||
state_updates["structured_response"] = response.structured_response
|
||||
|
||||
return state_updates
|
||||
return _build_state_updates_from_wrap_result(response)
|
||||
|
||||
async def _execute_model_async(request: ModelRequest[ContextT]) -> ModelResponse:
|
||||
"""Execute model asynchronously and return response.
|
||||
@@ -1234,18 +1292,11 @@ def create_agent(
|
||||
)
|
||||
|
||||
if awrap_model_call_handler is None:
|
||||
# No async handlers - execute directly
|
||||
response = await _execute_model_async(request)
|
||||
response = _normalize_to_wrap_model_call_result(await _execute_model_async(request))
|
||||
else:
|
||||
# Call composed async handler with base handler
|
||||
response = await awrap_model_call_handler(request, _execute_model_async)
|
||||
|
||||
# Extract state updates from ModelResponse
|
||||
state_updates = {"messages": response.result}
|
||||
if response.structured_response is not None:
|
||||
state_updates["structured_response"] = response.structured_response
|
||||
|
||||
return state_updates
|
||||
return _build_state_updates_from_wrap_result(response)
|
||||
|
||||
# Use sync or async based on model capabilities
|
||||
graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False))
|
||||
|
||||
@@ -29,6 +29,7 @@ from langchain.agents.middleware.types import (
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ToolCallRequest,
|
||||
WrapModelCallResult,
|
||||
after_agent,
|
||||
after_model,
|
||||
before_agent,
|
||||
@@ -66,6 +67,7 @@ __all__ = [
|
||||
"ToolCallLimitMiddleware",
|
||||
"ToolCallRequest",
|
||||
"ToolRetryMiddleware",
|
||||
"WrapModelCallResult",
|
||||
"after_agent",
|
||||
"after_model",
|
||||
"before_agent",
|
||||
|
||||
@@ -285,14 +285,42 @@ class ModelResponse(Generic[ResponseT]):
|
||||
"""Parsed structured output if `response_format` was specified, `None` otherwise."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class WrapModelCallResult(Generic[ResponseT]):
|
||||
"""Model response with additional state updates from ``wrap_model_call`` middleware.
|
||||
|
||||
Use this to return state updates alongside the model response from a
|
||||
``wrap_model_call`` handler. State updates are merged into the agent state
|
||||
after the model node completes.
|
||||
|
||||
The ``state_update`` dict overwrites the default model response state updates.
|
||||
For example, if ``state_update`` contains a ``"messages"`` key, those messages
|
||||
replace the model response messages entirely. If you want to include both custom
|
||||
messages and the model response, include them explicitly in your state update.
|
||||
|
||||
When multiple middleware return ``WrapModelCallResult``, the outermost
|
||||
middleware's ``state_update`` wins on key conflicts.
|
||||
|
||||
Type Parameters:
|
||||
ResponseT: The type of the structured response. Defaults to `Any` if not specified.
|
||||
"""
|
||||
|
||||
model_response: ModelResponse[ResponseT]
|
||||
"""The underlying model response."""
|
||||
|
||||
state_update: dict[str, Any]
|
||||
"""Additional state updates to merge into the agent state."""
|
||||
|
||||
|
||||
# Type alias for middleware return type - allows returning either full response or just AIMessage
|
||||
ModelCallResult: TypeAlias = "ModelResponse[ResponseT] | AIMessage"
|
||||
ModelCallResult: TypeAlias = "ModelResponse[ResponseT] | AIMessage | WrapModelCallResult[ResponseT]"
|
||||
"""`TypeAlias` for model call handler return value.
|
||||
|
||||
Middleware can return either:
|
||||
|
||||
- `ModelResponse`: Full response with messages and optional structured output
|
||||
- `AIMessage`: Simplified return for simple use cases
|
||||
- `WrapModelCallResult`: Response with additional state updates
|
||||
"""
|
||||
|
||||
|
||||
@@ -449,7 +477,7 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]):
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
) -> ModelResponse[ResponseT] | AIMessage | WrapModelCallResult[ResponseT]:
|
||||
"""Intercept and control model execution via handler callback.
|
||||
|
||||
Async version is `awrap_model_call`
|
||||
@@ -544,7 +572,7 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]):
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
) -> ModelResponse[ResponseT] | AIMessage | WrapModelCallResult[ResponseT]:
|
||||
"""Intercept and control async model execution via handler callback.
|
||||
|
||||
The handler callback executes the model request and returns a `ModelResponse`.
|
||||
|
||||
@@ -8,7 +8,7 @@ from langgraph.runtime import Runtime
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.factory import _chain_model_call_handlers
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse, WrapModelCallResult
|
||||
|
||||
|
||||
def create_test_request(**kwargs: Any) -> ModelRequest:
|
||||
@@ -88,9 +88,9 @@ class TestChainModelCallHandlers:
|
||||
"inner-after",
|
||||
"outer-after",
|
||||
]
|
||||
# Result is now ModelResponse
|
||||
assert isinstance(result, ModelResponse)
|
||||
assert result.result[0].content == "test"
|
||||
# Outermost result is always WrapModelCallResult
|
||||
assert isinstance(result, WrapModelCallResult)
|
||||
assert result.model_response.result[0].content == "test"
|
||||
|
||||
def test_three_handlers_composition(self) -> None:
|
||||
"""Test composition of three handlers."""
|
||||
@@ -134,8 +134,8 @@ class TestChainModelCallHandlers:
|
||||
"second-after",
|
||||
"first-after",
|
||||
]
|
||||
assert isinstance(result, ModelResponse)
|
||||
assert result.result[0].content == "test"
|
||||
assert isinstance(result, WrapModelCallResult)
|
||||
assert result.model_response.result[0].content == "test"
|
||||
|
||||
def test_inner_handler_retry(self) -> None:
|
||||
"""Test inner handler retrying before outer sees response."""
|
||||
@@ -173,8 +173,8 @@ class TestChainModelCallHandlers:
|
||||
result = composed(create_test_request(), mock_base_handler)
|
||||
|
||||
assert inner_attempts == [0, 1, 2]
|
||||
assert isinstance(result, ModelResponse)
|
||||
assert result.result[0].content == "success"
|
||||
assert isinstance(result, WrapModelCallResult)
|
||||
assert result.model_response.result[0].content == "success"
|
||||
|
||||
def test_error_to_success_conversion(self) -> None:
|
||||
"""Test handler converting error to success response."""
|
||||
@@ -202,10 +202,10 @@ class TestChainModelCallHandlers:
|
||||
|
||||
result = composed(create_test_request(), mock_base_handler)
|
||||
|
||||
# AIMessage was automatically converted to ModelResponse
|
||||
assert isinstance(result, ModelResponse)
|
||||
assert result.result[0].content == "Fallback response"
|
||||
assert result.structured_response is None
|
||||
# AIMessage was automatically normalized into WrapModelCallResult
|
||||
assert isinstance(result, WrapModelCallResult)
|
||||
assert result.model_response.result[0].content == "Fallback response"
|
||||
assert result.model_response.structured_response is None
|
||||
|
||||
def test_request_modification(self) -> None:
|
||||
"""Test handlers modifying the request."""
|
||||
@@ -231,8 +231,8 @@ class TestChainModelCallHandlers:
|
||||
result = composed(create_test_request(), create_mock_base_handler(content="response"))
|
||||
|
||||
assert requests_seen == ["Added by outer"]
|
||||
assert isinstance(result, ModelResponse)
|
||||
assert result.result[0].content == "response"
|
||||
assert isinstance(result, WrapModelCallResult)
|
||||
assert result.model_response.result[0].content == "response"
|
||||
|
||||
def test_composition_preserves_state_and_runtime(self) -> None:
|
||||
"""Test that state and runtime are passed through composition."""
|
||||
@@ -273,8 +273,8 @@ class TestChainModelCallHandlers:
|
||||
# Both handlers should see same state and runtime
|
||||
assert state_values == [("outer", test_state), ("inner", test_state)]
|
||||
assert runtime_values == [("outer", test_runtime), ("inner", test_runtime)]
|
||||
assert isinstance(result, ModelResponse)
|
||||
assert result.result[0].content == "test"
|
||||
assert isinstance(result, WrapModelCallResult)
|
||||
assert result.model_response.result[0].content == "test"
|
||||
|
||||
def test_multiple_yields_in_retry_loop(self) -> None:
|
||||
"""Test handler that retries multiple times."""
|
||||
@@ -312,5 +312,5 @@ class TestChainModelCallHandlers:
|
||||
# Outer called once, inner retried so base handler called twice
|
||||
assert call_count["value"] == 1
|
||||
assert attempt["value"] == 2
|
||||
assert isinstance(result, ModelResponse)
|
||||
assert result.result[0].content == "ok"
|
||||
assert isinstance(result, WrapModelCallResult)
|
||||
assert result.model_response.result[0].content == "ok"
|
||||
|
||||
@@ -0,0 +1,765 @@
|
||||
"""Unit tests for WrapModelCallResult state update support in wrap_model_call.
|
||||
|
||||
Tests that wrap_model_call middleware can return WrapModelCallResult to provide
|
||||
state updates alongside the model response, with outermost middleware winning
|
||||
on key conflicts.
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from langchain.agents import AgentState, create_agent
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
WrapModelCallResult,
|
||||
wrap_model_call,
|
||||
)
|
||||
|
||||
|
||||
class TestBasicStateUpdate:
|
||||
"""Test basic WrapModelCallResult functionality."""
|
||||
|
||||
def test_state_update_overwrites_model_messages(self) -> None:
|
||||
"""state_update with 'messages' key overwrites model response messages."""
|
||||
|
||||
class OverwriteMessagesMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> WrapModelCallResult:
|
||||
response = handler(request)
|
||||
custom_msg = HumanMessage(content="Custom message", id="custom")
|
||||
return WrapModelCallResult(
|
||||
model_response=response,
|
||||
state_update={"messages": [custom_msg]},
|
||||
)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
|
||||
agent = create_agent(model=model, middleware=[OverwriteMessagesMiddleware()])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage(content="Hi")]})
|
||||
|
||||
# Model response messages are overwritten — only custom message survives
|
||||
messages = result["messages"]
|
||||
assert len(messages) == 2
|
||||
assert messages[0].content == "Hi"
|
||||
assert messages[1].content == "Custom message"
|
||||
|
||||
def test_state_update_includes_model_messages_explicitly(self) -> None:
|
||||
"""Middleware can include model messages alongside custom ones explicitly."""
|
||||
|
||||
class ExplicitMessagesMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> WrapModelCallResult:
|
||||
response = handler(request)
|
||||
summary = HumanMessage(content="Summary", id="summary")
|
||||
return WrapModelCallResult(
|
||||
model_response=response,
|
||||
state_update={"messages": [summary, *response.result]},
|
||||
)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
|
||||
agent = create_agent(model=model, middleware=[ExplicitMessagesMiddleware()])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage(content="Hi")]})
|
||||
|
||||
messages = result["messages"]
|
||||
assert len(messages) == 3
|
||||
assert messages[0].content == "Hi"
|
||||
assert messages[1].content == "Summary"
|
||||
assert messages[2].content == "Hello!"
|
||||
|
||||
def test_state_update_takes_priority_over_model_response(self) -> None:
|
||||
"""state_update messages and structured_response take priority over model response."""
|
||||
|
||||
class OverrideMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> WrapModelCallResult:
|
||||
response = handler(request)
|
||||
# Model response has its own messages and structured_response,
|
||||
# but state_update should win for both.
|
||||
response_with_structured = ModelResponse(
|
||||
result=response.result,
|
||||
structured_response={"from": "model"},
|
||||
)
|
||||
return WrapModelCallResult(
|
||||
model_response=response_with_structured,
|
||||
state_update={
|
||||
"messages": [HumanMessage(content="From state_update", id="override")],
|
||||
"structured_response": {"from": "state_update"},
|
||||
},
|
||||
)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Model msg")]))
|
||||
agent = create_agent(model=model, middleware=[OverrideMiddleware()])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
# state_update messages win over model response messages
|
||||
messages = result["messages"]
|
||||
assert len(messages) == 2
|
||||
assert messages[0].content == "Hi"
|
||||
assert messages[1].content == "From state_update"
|
||||
|
||||
# state_update structured_response wins over model response structured_response
|
||||
assert result["structured_response"] == {"from": "state_update"}
|
||||
|
||||
def test_state_update_without_messages_key(self) -> None:
|
||||
"""When state_update doesn't include 'messages', model response messages are used."""
|
||||
|
||||
class CustomFieldMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> WrapModelCallResult:
|
||||
response = handler(request)
|
||||
return WrapModelCallResult(
|
||||
model_response=response,
|
||||
state_update={"custom_key": "custom_value"},
|
||||
)
|
||||
|
||||
class CustomState(AgentState):
|
||||
custom_key: str
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
middleware=[CustomFieldMiddleware()],
|
||||
state_schema=CustomState,
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
assert result["messages"][-1].content == "Hello"
|
||||
|
||||
|
||||
class TestCustomStateField:
|
||||
"""Test WrapModelCallResult with custom state fields defined via state_schema."""
|
||||
|
||||
def test_custom_field_via_state_schema(self) -> None:
|
||||
"""Middleware updates a custom state field via WrapModelCallResult."""
|
||||
|
||||
class MyState(AgentState):
|
||||
summary: str
|
||||
|
||||
class SummaryMiddleware(AgentMiddleware):
|
||||
state_schema = MyState # type: ignore[assignment]
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> WrapModelCallResult:
|
||||
response = handler(request)
|
||||
return WrapModelCallResult(
|
||||
model_response=response,
|
||||
state_update={"summary": "conversation summarized"},
|
||||
)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
agent = create_agent(model=model, middleware=[SummaryMiddleware()])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
assert result["messages"][-1].content == "Hello"
|
||||
|
||||
def test_empty_state_update(self) -> None:
|
||||
"""WrapModelCallResult with empty state_update works like ModelResponse."""
|
||||
|
||||
class EmptyUpdateMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> WrapModelCallResult:
|
||||
response = handler(request)
|
||||
return WrapModelCallResult(
|
||||
model_response=response,
|
||||
state_update={},
|
||||
)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
agent = create_agent(model=model, middleware=[EmptyUpdateMiddleware()])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
assert len(result["messages"]) == 2
|
||||
assert result["messages"][1].content == "Hello"
|
||||
|
||||
|
||||
class TestBackwardsCompatibility:
|
||||
"""Test that existing ModelResponse and AIMessage returns still work."""
|
||||
|
||||
def test_model_response_return_unchanged(self) -> None:
|
||||
"""Existing middleware returning ModelResponse works identically."""
|
||||
|
||||
class PassthroughMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
return handler(request)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
agent = create_agent(model=model, middleware=[PassthroughMiddleware()])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
assert len(result["messages"]) == 2
|
||||
assert result["messages"][1].content == "Hello"
|
||||
|
||||
def test_ai_message_return_unchanged(self) -> None:
|
||||
"""Existing middleware returning AIMessage works identically."""
|
||||
|
||||
class ShortCircuitMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> AIMessage:
|
||||
return AIMessage(content="Short-circuited")
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Should not appear")]))
|
||||
agent = create_agent(model=model, middleware=[ShortCircuitMiddleware()])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
assert len(result["messages"]) == 2
|
||||
assert result["messages"][1].content == "Short-circuited"
|
||||
|
||||
def test_no_middleware_unchanged(self) -> None:
|
||||
"""Agent without middleware works identically."""
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
agent = create_agent(model=model)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
assert len(result["messages"]) == 2
|
||||
assert result["messages"][1].content == "Hello"
|
||||
|
||||
|
||||
class TestAsyncWrapModelCallResult:
    """Test async variant of WrapModelCallResult."""

    async def test_async_state_update_overwrites(self) -> None:
        """awrap_model_call state_update overwrites model response messages."""

        class AsyncOverwriteMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> WrapModelCallResult:
                response = await handler(request)
                # Replace the model's messages entirely with a custom one.
                custom = HumanMessage(content="Async custom", id="async-custom")
                return WrapModelCallResult(
                    model_response=response,
                    state_update={"messages": [custom]},
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Async hello!")]))
        agent = create_agent(model=model, middleware=[AsyncOverwriteMiddleware()])

        result = await agent.ainvoke({"messages": [HumanMessage(content="Hi")]})

        messages = result["messages"]
        # Only the input and the custom message survive — "Async hello!" was
        # overwritten by the middleware's state_update.
        assert len(messages) == 2
        assert messages[0].content == "Hi"
        assert messages[1].content == "Async custom"

    async def test_async_decorator_state_update(self) -> None:
        """@wrap_model_call async decorator returns WrapModelCallResult."""

        @wrap_model_call
        async def state_update_middleware(
            request: ModelRequest,
            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
        ) -> WrapModelCallResult:
            response = await handler(request)
            # Prepend a custom message while keeping the model's output.
            return WrapModelCallResult(
                model_response=response,
                state_update={
                    "messages": [
                        HumanMessage(content="Decorator msg", id="dec"),
                        *response.result,
                    ]
                },
            )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Async response")]))
        agent = create_agent(model=model, middleware=[state_update_middleware])

        result = await agent.ainvoke({"messages": [HumanMessage(content="Hi")]})

        messages = result["messages"]
        assert len(messages) == 3
        assert messages[1].content == "Decorator msg"
        assert messages[2].content == "Async response"
|
||||
|
||||
|
||||
class TestComposition:
    """Test WrapModelCallResult with composed middleware.

    Key semantics: outermost middleware's state_update wins on key conflicts.
    """

    def test_outer_wrap_result_overwrites_model_messages(self) -> None:
        """Outer middleware's state_update overwrites model response messages."""
        execution_order: list[str] = []

        class OuterMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> WrapModelCallResult:
                execution_order.append("outer-before")
                response = handler(request)
                execution_order.append("outer-after")
                return WrapModelCallResult(
                    model_response=response,
                    state_update={"messages": [HumanMessage(content="Outer msg", id="outer-msg")]},
                )

        class InnerMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> ModelResponse:
                execution_order.append("inner-before")
                response = handler(request)
                execution_order.append("inner-after")
                return response

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Composed")]))
        agent = create_agent(
            model=model,
            middleware=[OuterMiddleware(), InnerMiddleware()],
        )

        result = agent.invoke({"messages": [HumanMessage("Hi")]})

        # Execution order: outer wraps inner
        assert execution_order == [
            "outer-before",
            "inner-before",
            "inner-after",
            "outer-after",
        ]

        # Outer's state_update overwrites model messages
        messages = result["messages"]
        assert len(messages) == 2
        assert messages[0].content == "Hi"
        assert messages[1].content == "Outer msg"

    def test_inner_wrap_result_propagated_through_composition(self) -> None:
        """Inner middleware's WrapModelCallResult state_update is propagated.

        When inner middleware returns WrapModelCallResult, its state_update is
        captured before normalizing to ModelResponse at the composition boundary
        and merged into the final result.
        """

        class OuterMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> ModelResponse:
                # Outer sees a ModelResponse from handler (inner's WrapModelCallResult
                # was normalized at the composition boundary)
                response = handler(request)
                assert isinstance(response, ModelResponse)
                return response

        class InnerMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> WrapModelCallResult:
                response = handler(request)
                return WrapModelCallResult(
                    model_response=response,
                    state_update={
                        "messages": [
                            HumanMessage(content="Inner msg", id="inner"),
                            *response.result,
                        ]
                    },
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
        agent = create_agent(
            model=model,
            middleware=[OuterMiddleware(), InnerMiddleware()],
        )

        result = agent.invoke({"messages": [HumanMessage("Hi")]})

        messages = result["messages"]
        assert len(messages) == 3
        assert messages[1].content == "Inner msg"
        assert messages[2].content == "Hello"

    def test_outer_state_update_wins_on_all_key_conflicts(self) -> None:
        """Outer's state_update fully overwrites inner's on all conflicting keys.

        This applies to all keys including 'messages' — no special casing.
        """

        class MyState(AgentState):
            custom_key: str

        class OuterMiddleware(AgentMiddleware):
            state_schema = MyState  # type: ignore[assignment]

            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> WrapModelCallResult:
                response = handler(request)
                return WrapModelCallResult(
                    model_response=response,
                    state_update={
                        "messages": [HumanMessage(content="Outer msg", id="outer")],
                        "custom_key": "outer_value",
                    },
                )

        class InnerMiddleware(AgentMiddleware):
            state_schema = MyState  # type: ignore[assignment]

            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> WrapModelCallResult:
                response = handler(request)
                return WrapModelCallResult(
                    model_response=response,
                    state_update={
                        "messages": [HumanMessage(content="Inner msg", id="inner")],
                        "custom_key": "inner_value",
                    },
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
        agent = create_agent(
            model=model,
            middleware=[OuterMiddleware(), InnerMiddleware()],
        )

        result = agent.invoke({"messages": [HumanMessage("Hi")]})

        # Outer wins on all keys — inner's messages and custom_key are overwritten
        messages = result["messages"]
        assert len(messages) == 2
        assert messages[0].content == "Hi"
        assert messages[1].content == "Outer msg"

    def test_inner_state_preserved_when_outer_has_no_conflict(self) -> None:
        """Inner's state_update keys are preserved when outer doesn't conflict."""

        class MyState(AgentState):
            inner_key: str
            outer_key: str

        class OuterMiddleware(AgentMiddleware):
            state_schema = MyState  # type: ignore[assignment]

            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> WrapModelCallResult:
                response = handler(request)
                return WrapModelCallResult(
                    model_response=response,
                    state_update={"outer_key": "from_outer"},
                )

        class InnerMiddleware(AgentMiddleware):
            state_schema = MyState  # type: ignore[assignment]

            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> WrapModelCallResult:
                response = handler(request)
                return WrapModelCallResult(
                    model_response=response,
                    state_update={"inner_key": "from_inner"},
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
        agent = create_agent(
            model=model,
            middleware=[OuterMiddleware(), InnerMiddleware()],
        )

        result = agent.invoke({"messages": [HumanMessage("Hi")]})

        # Both keys survive since there's no conflict
        messages = result["messages"]
        assert messages[-1].content == "Hello"

    def test_inner_state_update_retry_safe(self) -> None:
        """When outer retries, only the last inner state update is used."""
        call_count = 0

        class MyState(AgentState):
            attempt: str

        class OuterMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> ModelResponse:
                # Call handler twice (simulating retry)
                handler(request)
                return handler(request)

        class InnerMiddleware(AgentMiddleware):
            state_schema = MyState  # type: ignore[assignment]

            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> WrapModelCallResult:
                nonlocal call_count
                call_count += 1
                response = handler(request)
                return WrapModelCallResult(
                    model_response=response,
                    state_update={"attempt": f"attempt_{call_count}"},
                )

        model = GenericFakeChatModel(
            messages=iter([AIMessage(content="First"), AIMessage(content="Second")])
        )
        agent = create_agent(
            model=model,
            middleware=[OuterMiddleware(), InnerMiddleware()],
        )

        result = agent.invoke({"messages": [HumanMessage("Hi")]})

        # Only the last retry's inner state should survive
        messages = result["messages"]
        assert messages[-1].content == "Second"

    def test_decorator_returns_wrap_result(self) -> None:
        """@wrap_model_call decorator can return WrapModelCallResult."""

        @wrap_model_call
        def state_update_middleware(
            request: ModelRequest,
            handler: Callable[[ModelRequest], ModelResponse],
        ) -> WrapModelCallResult:
            response = handler(request)
            return WrapModelCallResult(
                model_response=response,
                state_update={
                    "messages": [
                        HumanMessage(content="From decorator", id="dec"),
                        *response.result,
                    ]
                },
            )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Model response")]))
        agent = create_agent(model=model, middleware=[state_update_middleware])

        result = agent.invoke({"messages": [HumanMessage("Hi")]})

        messages = result["messages"]
        assert len(messages) == 3
        assert messages[1].content == "From decorator"
        assert messages[2].content == "Model response"

    def test_structured_response_preserved(self) -> None:
        """WrapModelCallResult preserves structured_response from ModelResponse."""

        class StructuredMiddleware(AgentMiddleware):
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], ModelResponse],
            ) -> WrapModelCallResult:
                response = handler(request)
                # Rebuild the response with a structured payload attached.
                response_with_structured = ModelResponse(
                    result=response.result,
                    structured_response={"key": "value"},
                )
                return WrapModelCallResult(
                    model_response=response_with_structured,
                    state_update={},
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
        agent = create_agent(model=model, middleware=[StructuredMiddleware()])

        result = agent.invoke({"messages": [HumanMessage("Hi")]})

        assert result.get("structured_response") == {"key": "value"}
        messages = result["messages"]
        assert len(messages) == 2
        assert messages[1].content == "Hello"
|
||||
|
||||
|
||||
class TestAsyncComposition:
    """Test async WrapModelCallResult propagation through composed middleware."""

    async def test_async_inner_wrap_result_propagated(self) -> None:
        """Async: inner middleware's WrapModelCallResult state_update is propagated."""

        class OuterMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> ModelResponse:
                response = await handler(request)
                # Inner's WrapModelCallResult is normalized to ModelResponse
                # before reaching the outer middleware.
                assert isinstance(response, ModelResponse)
                return response

        class InnerMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> WrapModelCallResult:
                response = await handler(request)
                return WrapModelCallResult(
                    model_response=response,
                    state_update={
                        "messages": [
                            HumanMessage(content="Inner msg", id="inner"),
                            *response.result,
                        ]
                    },
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
        agent = create_agent(
            model=model,
            middleware=[OuterMiddleware(), InnerMiddleware()],
        )

        result = await agent.ainvoke({"messages": [HumanMessage("Hi")]})

        messages = result["messages"]
        assert len(messages) == 3
        assert messages[1].content == "Inner msg"
        assert messages[2].content == "Hello"

    async def test_async_outer_wins_on_conflict(self) -> None:
        """Async: outer's state_update fully overwrites inner's on conflicts."""

        class OuterMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> WrapModelCallResult:
                response = await handler(request)
                return WrapModelCallResult(
                    model_response=response,
                    state_update={
                        "messages": [HumanMessage(content="Outer msg", id="outer")],
                    },
                )

        class InnerMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> WrapModelCallResult:
                response = await handler(request)
                return WrapModelCallResult(
                    model_response=response,
                    state_update={
                        "messages": [HumanMessage(content="Inner msg", id="inner")],
                    },
                )

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
        agent = create_agent(
            model=model,
            middleware=[OuterMiddleware(), InnerMiddleware()],
        )

        result = await agent.ainvoke({"messages": [HumanMessage("Hi")]})

        # Outer wins — inner's messages overwritten
        messages = result["messages"]
        assert len(messages) == 2
        assert messages[0].content == "Hi"
        assert messages[1].content == "Outer msg"

    async def test_async_inner_state_update_retry_safe(self) -> None:
        """Async: when outer retries, only last inner state update is used."""
        call_count = 0

        class MyState(AgentState):
            attempt: str

        class OuterMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> ModelResponse:
                # Call handler twice (simulating retry)
                await handler(request)
                return await handler(request)

        class InnerMiddleware(AgentMiddleware):
            state_schema = MyState  # type: ignore[assignment]

            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
            ) -> WrapModelCallResult:
                nonlocal call_count
                call_count += 1
                response = await handler(request)
                return WrapModelCallResult(
                    model_response=response,
                    state_update={"attempt": f"attempt_{call_count}"},
                )

        model = GenericFakeChatModel(
            messages=iter([AIMessage(content="First"), AIMessage(content="Second")])
        )
        agent = create_agent(
            model=model,
            middleware=[OuterMiddleware(), InnerMiddleware()],
        )

        result = await agent.ainvoke({"messages": [HumanMessage("Hi")]})

        messages = result["messages"]
        assert any(m.content == "Second" for m in messages)
|
||||