mirror of https://github.com/hwchase17/langchain.git
synced 2026-01-21 21:56:38 +00:00
Compare commits
7 Commits
langchain-
...
on_model_c
| Author | SHA1 | Date |
|---|---|---|
| | 25f0464dfe | |
| | 14469d7fe9 | |
| | 702add4a58 | |
| | cd81a2b92d | |
| | 7d1703be76 | |
| | 6802a45a36 | |
| | e3fd9eac8e | |
@@ -17,6 +17,7 @@ from .types import (
     AgentMiddleware,
     AgentState,
     ModelRequest,
+    ModelResponse,
     after_agent,
     after_model,
     before_agent,
@@ -24,6 +25,7 @@ from .types import (
     dynamic_prompt,
     hook_config,
     modify_model_request,
+    on_model_call,
 )

 __all__ = [
@@ -38,6 +40,7 @@ __all__ = [
     "ModelCallLimitMiddleware",
     "ModelFallbackMiddleware",
     "ModelRequest",
+    "ModelResponse",
     "PIIDetectionError",
     "PIIMiddleware",
     "PlanningMiddleware",
@@ -50,4 +53,5 @@ __all__ = [
     "dynamic_prompt",
     "hook_config",
     "modify_model_request",
+    "on_model_call",
 ]
@@ -2,14 +2,21 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

-from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    ModelRequest,
+    ModelResponse,
+)
 from langchain.chat_models import init_chat_model

 if TYPE_CHECKING:
+    from collections.abc import Generator
+
     from langchain_core.language_models.chat_models import BaseChatModel
     from langgraph.runtime import Runtime
+    from langgraph.typing import ContextT


 class ModelFallbackMiddleware(AgentMiddleware):
@@ -64,31 +71,43 @@ class ModelFallbackMiddleware(AgentMiddleware):
         else:
             self.models.append(model)

-    def retry_model_request(
+    def on_model_call(
         self,
-        error: Exception,  # noqa: ARG002
         request: ModelRequest,
-        state: AgentState,  # noqa: ARG002
-        runtime: Runtime,  # noqa: ARG002
-        attempt: int,
-    ) -> ModelRequest | None:
-        """Retry with the next fallback model.
+        state: Any,  # noqa: ARG002
+        runtime: Runtime[ContextT],  # noqa: ARG002
+    ) -> Generator[ModelRequest, ModelResponse, ModelResponse]:
+        """Try fallback models in sequence on errors.

         Args:
-            error: The exception that occurred during model invocation.
-            request: The original model request that failed.
+            request: The initial model request.
             state: The current agent state.
            runtime: The langgraph runtime.
-            attempt: The current attempt number (1-indexed).
+
+        Yields:
+            ModelRequest: The request to execute.
+
+        Receives (via .send()):
+            ModelResponse: The response from the model call.

         Returns:
-            ModelRequest with the next fallback model, or None if all models exhausted.
+            ModelResponse: The final response to use.
         """
-        # attempt 1 = primary model failed, try models[0] (first fallback)
-        fallback_index = attempt - 1
-        # All fallback models exhausted
-        if fallback_index >= len(self.models):
-            return None
-        # Try next fallback model
-        request.model = self.models[fallback_index]
-        return request
+        # Try primary model first
+        current_request = request
+        response = yield current_request
+
+        # If success, return immediately
+        if response.action == "return":
+            return response
+
+        # Try each fallback model
+        for fallback_model in self.models:
+            current_request.model = fallback_model
+            response = yield current_request
+
+            if response.action == "return":
+                return response
+
+        # All models failed, return last error
+        return response
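To make the generator protocol above concrete, here is a minimal, self-contained sketch of how an agent core could drive a fallback-style handler. `Req`, `Resp`, and `drive` are hypothetical stand-ins for `ModelRequest`, `ModelResponse`, and the real execution loop; they are not part of the library.

```python
from dataclasses import dataclass


@dataclass
class Req:
    """Stand-in for ModelRequest (illustration only)."""
    model: str


@dataclass
class Resp:
    """Stand-in for ModelResponse."""
    action: str          # "return" or "raise"
    content: str = ""


def fallback_handler(req):
    # Mirrors the middleware above: try the primary model, then each fallback.
    resp = yield req
    if resp.action == "return":
        return resp
    for model in ("fallback-a", "fallback-b"):
        req.model = model
        resp = yield req
        if resp.action == "return":
            return resp
    return resp


def drive(handler, req, scripted_results):
    # Mirrors the agent core: next() starts the generator, send() feeds results.
    gen = handler(req)
    request = next(gen)
    for result in scripted_results:
        try:
            request = gen.send(result)   # handler asked for another attempt
        except StopIteration as stop:
            return stop.value            # handler's final response
    raise RuntimeError("handler wanted more attempts than scripted")


# Primary fails, first fallback succeeds.
final = drive(fallback_handler, Req("primary"),
              [Resp("raise"), Resp("return", "ok from fallback-a")])
assert final.content == "ok from fallback-a"
```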
@@ -2,7 +2,7 @@

 from __future__ import annotations

-from collections.abc import Callable
+from collections.abc import Callable, Generator
 from dataclasses import dataclass, field
 from inspect import iscoroutinefunction
 from typing import (
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
     from collections.abc import Awaitable

 # needed as top level import for pydantic schema generation on AgentState
-from langchain_core.messages import AnyMessage  # noqa: TC002
+from langchain_core.messages import AIMessage, AnyMessage  # noqa: TC002
 from langgraph.channels.ephemeral_value import EphemeralValue
 from langgraph.channels.untracked_value import UntrackedValue
 from langgraph.graph.message import add_messages
@@ -42,6 +42,7 @@ __all__ = [
     "AgentState",
     "ContextT",
     "ModelRequest",
+    "ModelResponse",
     "OmitFromSchema",
     "PublicAgentState",
     "after_agent",
@@ -51,6 +52,7 @@ __all__ = [
     "dynamic_prompt",
     "hook_config",
     "modify_model_request",
+    "on_model_call",
 ]

 JumpTo = Literal["tools", "model", "end"]
@@ -72,6 +74,45 @@ class ModelRequest:
     model_settings: dict[str, Any] = field(default_factory=dict)


+@dataclass
+class ModelResponse:
+    """Response from a model call.
+
+    Used in the on_model_call generator protocol to represent the result
+    of a model invocation, whether successful or failed.
+
+    Middleware can modify this response to:
+    - Rewrite the result message content
+    - Convert errors to successful responses (error recovery)
+    - Convert successful responses to errors (validation failures)
+    - Add metadata or modify message attributes
+
+    Examples:
+        Rewrite response content:
+        ```python
+        if response.action == "return":
+            modified = AIMessage(content=f"Enhanced: {response.result.content}")
+            response = ModelResponse(action="return", result=modified)
+        ```
+
+        Convert error to success:
+        ```python
+        if response.action == "raise":
+            fallback = AIMessage(content="Using fallback response")
+            response = ModelResponse(action="return", result=fallback)
+        ```
+    """
+
+    action: Literal["return", "raise"]
+    """The action to take: 'return' for success, 'raise' for error."""
+
+    result: AIMessage | None = None
+    """The AI message result if action is 'return'. Can be rewritten by middleware."""
+
+    exception: Exception | None = None
+    """The exception if action is 'raise'."""
+
+
 @dataclass
 class OmitFromSchema:
     """Annotation used to mark state attributes as omitted from input or output schemas."""
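As a quick illustration of the two shapes this dataclass takes (a sketch assuming this branch is installed; the message contents are made up):

```python
from langchain_core.messages import AIMessage

from langchain.agents.middleware import ModelResponse

# A successful call: action="return" carries the AIMessage result.
ok = ModelResponse(action="return", result=AIMessage(content="hi"))
assert ok.result is not None and ok.exception is None

# A failed call: action="raise" carries the exception instead.
err = ModelResponse(action="raise", exception=TimeoutError("model timed out"))
assert err.exception is not None and err.result is None
```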
@@ -180,53 +221,89 @@ class AgentMiddleware(Generic[StateT, ContextT]):
     ) -> dict[str, Any] | None:
         """Async logic to run after the model is called."""

-    def retry_model_request(
+    def on_model_call(
         self,
-        error: Exception,  # noqa: ARG002
         request: ModelRequest,  # noqa: ARG002
         state: StateT,  # noqa: ARG002
         runtime: Runtime[ContextT],  # noqa: ARG002
-        attempt: int,  # noqa: ARG002
-    ) -> ModelRequest | None:
-        """Logic to handle model invocation errors and optionally retry.
+    ) -> Generator[ModelRequest, ModelResponse, ModelResponse]:
+        """Generator-based hook to intercept and control model execution.
+
+        This hook allows middleware to:
+        - Intercept model calls before execution
+        - Modify requests dynamically
+        - Implement retry logic by yielding multiple times
+        - Handle errors and convert them to responses
+        - Rewrite response content (e.g., post-processing, formatting)
+        - Compose with other middleware (outer wraps inner)
+
+        The generator protocol:
+        1. Yield a ModelRequest to execute the model
+        2. Receive a ModelResponse via .send() with the result or error
+        3. Optionally yield again to retry with a modified request
+        4. Return a final ModelResponse (can be modified)

         Args:
-            error: The exception that occurred during model invocation.
-            request: The original model request that failed.
+            request: The initial model request.
             state: The current agent state.
             runtime: The langgraph runtime.
-            attempt: The current attempt number (1-indexed).
+
+        Yields:
+            ModelRequest: The request to execute.
+
+        Receives (via .send()):
+            ModelResponse: The response from the model call.

         Returns:
-            ModelRequest: Modified request to retry with.
-            None: Propagate the error (re-raise).
+            ModelResponse: The final response to use (can be rewritten).
+
+        Examples:
+            Implementing retry logic:
+            ```python
+            def on_model_call(self, request, state, runtime):
+                max_retries = 3
+                for attempt in range(max_retries):
+                    response = yield request
+
+                    if response.action == "return":
+                        return response
+
+                    if attempt < max_retries - 1:
+                        # Modify request and retry
+                        continue
+
+                # All retries failed
+                return response
+            ```
+
+            Rewriting response content:
+            ```python
+            def on_model_call(self, request, state, runtime):
+                response = yield request
+
+                # Post-process successful responses
+                if response.action == "return" and response.result:
+                    enhanced = AIMessage(content=self.enhance(response.result.content))
+                    response = ModelResponse(action="return", result=enhanced)
+
+                return response
+            ```
+
+            Converting errors to fallback responses:
+            ```python
+            def on_model_call(self, request, state, runtime):
+                response = yield request
+
+                # Provide fallback on error
+                if response.action == "raise":
+                    fallback = AIMessage(content="Service unavailable")
+                    response = ModelResponse(action="return", result=fallback)
+
+                return response
+            ```
         """
-        return None
-
-    async def aretry_model_request(
-        self,
-        error: Exception,
-        request: ModelRequest,
-        state: StateT,
-        runtime: Runtime[ContextT],
-        attempt: int,
-    ) -> ModelRequest | None:
-        """Async logic to handle model invocation errors and optionally retry.
-
-        Args:
-            error: The exception that occurred during model invocation.
-            request: The original model request that failed.
-            state: The current agent state.
-            runtime: The langgraph runtime.
-            attempt: The current attempt number (1-indexed).
-
-        Returns:
-            ModelRequest: Modified request to retry with.
-            None: Propagate the error (re-raise).
-        """
-        return await run_in_executor(
-            None, self.retry_model_request, error, request, state, runtime, attempt
-        )
+        raise NotImplementedError
+        yield  # Make this a generator for type checking

     def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run after the agent execution completes."""
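A hedged sketch of a subclass built on this hook; `TimingMiddleware` is illustrative, not part of the library, and it deliberately never retries:

```python
import time

from langchain.agents.middleware import AgentMiddleware


class TimingMiddleware(AgentMiddleware):
    """Hypothetical middleware: measures a single model call, passes it through."""

    def on_model_call(self, request, state, runtime):
        start = time.monotonic()
        response = yield request               # exactly one attempt
        elapsed = time.monotonic() - start
        print(f"model call finished in {elapsed:.2f}s ({response.action})")
        return response
```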
@@ -267,6 +344,20 @@ class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]):
         ...


+class _CallableReturningModelResponseGenerator(Protocol[StateT_contra, ContextT]):
+    """Callable that returns a generator for model call interception.
+
+    Note: Returns a sync generator that works with both sync and async model execution,
+    following the same pattern as ToolNode's on_tool_call handler.
+    """
+
+    def __call__(
+        self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
+    ) -> Generator[ModelRequest, ModelResponse, ModelResponse]:
+        """Generate responses to intercept and control model execution."""
+        ...
+
+
 CallableT = TypeVar("CallableT", bound=Callable[..., Any])
@@ -1122,3 +1213,172 @@ def dynamic_prompt(
     if func is not None:
         return decorator(func)
     return decorator
+
+
+@overload
+def on_model_call(
+    func: _CallableReturningModelResponseGenerator[StateT, ContextT],
+) -> AgentMiddleware[StateT, ContextT]: ...
+
+
+@overload
+def on_model_call(
+    func: None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    name: str | None = None,
+) -> Callable[
+    [_CallableReturningModelResponseGenerator[StateT, ContextT]],
+    AgentMiddleware[StateT, ContextT],
+]: ...
+
+
+def on_model_call(
+    func: _CallableReturningModelResponseGenerator[StateT, ContextT] | None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    name: str | None = None,
+) -> (
+    Callable[
+        [_CallableReturningModelResponseGenerator[StateT, ContextT]],
+        AgentMiddleware[StateT, ContextT],
+    ]
+    | AgentMiddleware[StateT, ContextT]
+):
+    """Decorator used to dynamically create a middleware with the on_model_call hook.
+
+    This decorator creates middleware that can intercept model calls, implement
+    retry logic, handle errors, and rewrite responses through a generator-based protocol.
+
+    Args:
+        func: The generator function to be decorated. Must accept
+            `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]`;
+            it should yield ModelRequest, receive ModelResponse via .send(),
+            and return a ModelResponse.
+        state_schema: Optional custom state schema type. If not provided, uses the default
+            AgentState schema.
+        tools: Optional list of additional tools to register with this middleware.
+        name: Optional name for the generated middleware class. If not provided,
+            uses the decorated function's name.
+
+    Returns:
+        Either an AgentMiddleware instance (if func is provided) or a decorator function
+        that can be applied to a function.
+
+    The decorated function should:
+    - Yield ModelRequest to execute the model
+    - Receive ModelResponse via .send() with results or errors
+    - Return the final ModelResponse (can be rewritten)
+
+    Examples:
+        Basic retry logic:
+        ```python
+        @on_model_call
+        def retry_on_error(
+            request: ModelRequest, state: AgentState, runtime: Runtime
+        ) -> Generator[ModelRequest, ModelResponse, ModelResponse]:
+            max_retries = 3
+            for attempt in range(max_retries):
+                response = yield request
+
+                if response.action == "return":
+                    return response
+
+                if attempt < max_retries - 1:
+                    # Modify request and retry
+                    continue
+
+            # All retries failed
+            return response
+        ```
+
+        Model fallback:
+        ```python
+        @on_model_call
+        def fallback_model(
+            request: ModelRequest, state: AgentState, runtime: Runtime
+        ) -> Generator[ModelRequest, ModelResponse, ModelResponse]:
+            # Try primary model
+            response = yield request
+
+            if response.action == "return":
+                return response
+
+            # Try fallback model
+            request.model = fallback_model_instance
+            response = yield request
+            return response
+        ```
+
+        Rewrite response content:
+        ```python
+        @on_model_call
+        def uppercase_responses(
+            request: ModelRequest, state: AgentState, runtime: Runtime
+        ) -> Generator[ModelRequest, ModelResponse, ModelResponse]:
+            response = yield request
+
+            # Rewrite successful responses
+            if response.action == "return" and response.result:
+                modified = AIMessage(content=response.result.content.upper())
+                response = ModelResponse(action="return", result=modified)
+
+            return response
+        ```
+    """
+
+    def decorator(
+        func: _CallableReturningModelResponseGenerator[StateT, ContextT],
+    ) -> AgentMiddleware[StateT, ContextT]:
+        # on_model_call handlers are always sync generators, even when used with async execution
+        # This follows the same pattern as ToolNode's on_tool_call handler
+
+        def wrapped(
+            self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+            request: ModelRequest,
+            state: StateT,
+            runtime: Runtime[ContextT],
+        ) -> Generator[ModelRequest, ModelResponse, ModelResponse]:
+            # Forward the generator yields and sends
+            gen = func(request, state, runtime)
+            try:
+                # Initialize generator with first next()
+                request_to_send = next(gen)
+                response = yield request_to_send
+
+                # Forward subsequent sends
+                while True:
+                    request_to_send = gen.send(response)
+                    response = yield request_to_send
+            except StopIteration as e:
+                # Validate the return value
+                if e.value is None:
+                    msg = (
+                        "on_model_call handler must explicitly return a ModelResponse. "
+                        "Ensure your handler ends with 'return response'."
+                    )
+                    raise ValueError(msg)
+                if not isinstance(e.value, ModelResponse):
+                    msg = (
+                        f"on_model_call handler must return a ModelResponse, "
+                        f"got {type(e.value).__name__} instead"
+                    )
+                    raise TypeError(msg)
+                return e.value
+
+        middleware_name = name or cast("str", getattr(func, "__name__", "OnModelCallMiddleware"))
+
+        return type(
+            middleware_name,
+            (AgentMiddleware,),
+            {
+                "state_schema": state_schema or AgentState,
+                "tools": tools or [],
+                "on_model_call": wrapped,
+            },
+        )()
+
+    if func is not None:
+        return decorator(func)
+    return decorator
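A hedged usage sketch of the decorator defined above; the handler name and body are illustrative:

```python
from langchain.agents.middleware import on_model_call


@on_model_call
def log_model_calls(request, state, runtime):
    # Hypothetical handler: observe the request, pass the response through.
    print(f"invoking model with {len(request.messages)} messages")
    response = yield request
    return response


# log_model_calls is now an AgentMiddleware instance and can be passed to
# create_agent(..., middleware=[log_model_calls]) as the tests below do.
```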
@@ -1,7 +1,8 @@
 """Middleware agent implementation."""

 import itertools
-from collections.abc import Callable, Sequence
+from collections.abc import Callable, Generator, Sequence
 from dataclasses import dataclass
 from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints

 from langchain_core.language_models.chat_models import BaseChatModel
@@ -21,6 +22,7 @@ from langchain.agents.middleware.types import (
     AgentState,
     JumpTo,
     ModelRequest,
+    ModelResponse,
     OmitFromSchema,
     PublicAgentState,
 )
@@ -42,6 +44,181 @@ STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
 ResponseT = TypeVar("ResponseT")


+@dataclass
+class _InternalModelResponse:
+    """Internal wrapper for model responses with additional metadata.
+
+    This wrapper contains the actual ModelResponse that middleware sees,
+    plus internal metadata needed for efficient implementation. Middleware
+    authors never see this wrapper - they only interact with ModelResponse.
+    """
+
+    model_response: ModelResponse
+    """The actual model response that middleware interacts with."""
+
+    effective_response_format: Any = None
+    """The effective response format used for this model call (internal use)."""
+
+
+def _validate_handler_return(value: Any) -> ModelResponse:
+    """Validate that a handler returned a valid ModelResponse.
+
+    Args:
+        value: The value returned from the handler's StopIteration.
+
+    Returns:
+        The validated ModelResponse.
+
+    Raises:
+        ValueError: If the value is None.
+        TypeError: If the value is not a ModelResponse.
+    """
+    if value is None:
+        msg = (
+            "on_model_call handler must explicitly return a ModelResponse. "
+            "Ensure your handler ends with 'return response'."
+        )
+        raise ValueError(msg)
+
+    if not isinstance(value, ModelResponse):
+        msg = (
+            f"on_model_call handler must return a ModelResponse, got {type(value).__name__} instead"
+        )
+        raise TypeError(msg)
+
+    return value
+
+
+def _chain_model_call_handlers(
+    handlers: Sequence[
+        Callable[
+            [ModelRequest, Any, Any],
+            Generator[ModelRequest, ModelResponse, ModelResponse],
+        ]
+    ],
+) -> (
+    Callable[
+        [ModelRequest, Any, Any],
+        Generator[ModelRequest, ModelResponse, ModelResponse],
+    ]
+    | None
+):
+    """Compose multiple model call handlers into a single middleware stack.
+
+    Args:
+        handlers: Handlers in middleware order (first = outermost layer).
+
+    Returns:
+        Single composed handler, or None if handlers is empty.
+
+    Example:
+        ```python
+        # Auth middleware (outer) + retry (inner)
+        def auth(req, state, runtime):
+            resp = yield req
+            if "unauthorized" in str(resp.exception):
+                refresh_token()
+                resp = yield req  # Retry
+            return resp
+
+
+        def retry(req, state, runtime):
+            for attempt in range(3):
+                resp = yield req
+                if resp.action == "return":
+                    return resp
+                time.sleep(2**attempt)
+            return resp
+
+
+        handler = _chain_model_call_handlers([auth, retry])
+        # Request: auth -> retry -> model
+        # Response: model -> retry -> auth
+        ```
+    """
+    if not handlers:
+        return None
+
+    if len(handlers) == 1:
+        return handlers[0]
+
+    def _extract_return_value(stop_iteration: StopIteration) -> ModelResponse:
+        """Extract ModelResponse from StopIteration, validating protocol compliance."""
+        return _validate_handler_return(stop_iteration.value)
+
+    def compose_two(
+        outer: Callable[
+            [ModelRequest, Any, Any],
+            Generator[ModelRequest, ModelResponse, ModelResponse],
+        ],
+        inner: Callable[
+            [ModelRequest, Any, Any],
+            Generator[ModelRequest, ModelResponse, ModelResponse],
+        ],
+    ) -> Callable[
+        [ModelRequest, Any, Any],
+        Generator[ModelRequest, ModelResponse, ModelResponse],
+    ]:
+        """Compose two handlers where outer wraps inner."""
+
+        def composed(
+            request: ModelRequest,
+            state: Any,
+            runtime: Any,
+        ) -> Generator[ModelRequest, ModelResponse, ModelResponse]:
+            outer_gen = outer(request, state, runtime)
+
+            # Initialize outer generator
+            try:
+                outer_request = next(outer_gen)
+            except StopIteration as e:
+                return _extract_return_value(e)
+
+            # Outer retry loop
+            while True:
+                inner_gen = inner(outer_request, state, runtime)
+
+                # Initialize inner generator
+                try:
+                    inner_request = next(inner_gen)
+                except StopIteration as e:
+                    # Inner returned immediately - send to outer
+                    inner_response = _extract_return_value(e)
+                    try:
+                        outer_request = outer_gen.send(inner_response)
+                        continue  # Outer retrying
+                    except StopIteration as e:
+                        return _extract_return_value(e)
+
+                # Inner retry loop - yield to next layer (or model)
+                while True:
+                    model_response = yield inner_request
+
+                    try:
+                        inner_request = inner_gen.send(model_response)
+                        # Inner retrying - continue inner loop
+                    except StopIteration as e:
+                        # Inner done - send response to outer
+                        inner_response = _extract_return_value(e)
+                        break
+
+                # Send inner's final response to outer
+                try:
+                    outer_request = outer_gen.send(inner_response)
+                    # Outer retrying - continue outer loop
+                except StopIteration as e:
+                    # Outer done - return final response
+                    return _extract_return_value(e)
+
+        return composed
+
+    # Compose right-to-left: handlers[0](handlers[1](...(handlers[-1](model))))
+    result = handlers[-1]
+    for handler in reversed(handlers[:-1]):
+        result = compose_two(handler, result)
+
+    return result
+
+
 def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
     """Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
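A small sketch that drives a composed stack by hand to check the documented ordering. It assumes this branch is installed and that `_chain_model_call_handlers` is in scope (it is private to the agent module above, whose import path is not shown in this diff); the handlers are illustrative pass-throughs.

```python
from langchain_core.messages import AIMessage

from langchain.agents.middleware import ModelResponse

calls = []


def outer(req, state, runtime):
    calls.append("outer:request")
    resp = yield req
    calls.append("outer:response")
    return resp


def inner(req, state, runtime):
    calls.append("inner:request")
    resp = yield req
    calls.append("inner:response")
    return resp


composed = _chain_model_call_handlers([outer, inner])

gen = composed(None, None, None)        # request/state/runtime unused by these handlers
gen_request = next(gen)                 # request travels outer -> inner
try:
    gen.send(ModelResponse(action="return", result=AIMessage(content="hi")))
except StopIteration as stop:
    final = stop.value                  # response travels inner -> outer

assert calls == ["outer:request", "inner:request", "inner:response", "outer:response"]
assert final.result.content == "hi"
```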
@@ -291,13 +468,28 @@ def create_agent(  # noqa: PLR0915
         if m.__class__.after_agent is not AgentMiddleware.after_agent
         or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
     ]
-    middleware_w_retry = [
-        m
-        for m in middleware
-        if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request
-        or m.__class__.aretry_model_request is not AgentMiddleware.aretry_model_request
+    middleware_w_on_model_call = [
+        m for m in middleware if m.__class__.on_model_call is not AgentMiddleware.on_model_call
     ]

+    # Compose on_model_call handlers into a single middleware stack
+    on_model_call_handler = None
+    if middleware_w_on_model_call:
+        # Collect sync implementations from all middleware
+        sync_handlers = []
+        all_have_sync = True
+        for m in middleware_w_on_model_call:
+            if m.__class__.on_model_call is not AgentMiddleware.on_model_call:
+                sync_handlers.append(m.on_model_call)
+            else:
+                # No sync implementation for this middleware
+                all_have_sync = False
+                break
+
+        # Only compose if all have sync implementations
+        if all_have_sync and sync_handlers:
+            on_model_call_handler = _chain_model_call_handlers(sync_handlers)
+
     state_schemas = {m.state_schema for m in middleware}
     state_schemas.add(AgentState)
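The selection above relies on the class-attribute identity idiom: comparing the attribute on the subclass against the base class spots middleware that actually override the hook. A minimal, library-free sketch:

```python
class Base:
    def hook(self):
        raise NotImplementedError


class WithHook(Base):
    def hook(self):
        return "overridden"


class WithoutHook(Base):
    pass


instances = [WithHook(), WithoutHook()]
# Unbound class attributes are identical unless the subclass redefines them.
overriding = [m for m in instances if m.__class__.hook is not Base.hook]
assert [type(m).__name__ for m in overriding] == ["WithHook"]
```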
@@ -521,6 +713,29 @@ def create_agent(  # noqa: PLR0915
         )
         return request.model.bind(**request.model_settings), None

+    def _execute_model_sync(request: ModelRequest) -> _InternalModelResponse:
+        """Execute model and return response.
+
+        This is the core model execution logic wrapped by on_model_call handlers.
+        """
+        try:
+            # Get the bound model (with auto-detection if needed)
+            model_, effective_response_format = _get_bound_model(request)
+            messages = request.messages
+            if request.system_prompt:
+                messages = [SystemMessage(request.system_prompt), *messages]
+
+            output = model_.invoke(messages)
+            return _InternalModelResponse(
+                model_response=ModelResponse(action="return", result=output),
+                effective_response_format=effective_response_format,
+            )
+        except Exception as error:  # noqa: BLE001
+            # Catch all exceptions from model invocation to wrap in ModelResponse
+            return _InternalModelResponse(
+                model_response=ModelResponse(action="raise", exception=error)
+            )
+
     def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Sync model request handler with sequential middleware processing."""
         request = ModelRequest(
@@ -535,7 +750,7 @@ def create_agent(  # noqa: PLR0915
         # Apply modify_model_request middleware in sequence
         for m in middleware_w_modify_model_request:
             if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request:
-                m.modify_model_request(request, state, runtime)
+                request = m.modify_model_request(request, state, runtime)
             else:
                 msg = (
                     f"No synchronous function provided for "
@@ -545,47 +760,79 @@ def create_agent(  # noqa: PLR0915
                 )
                 raise TypeError(msg)

-        # Retry loop for model invocation with error handling
-        # Hard limit of 100 attempts to prevent infinite loops from buggy middleware
-        max_attempts = 100
-        for attempt in range(1, max_attempts + 1):
-            try:
-                # Get the bound model (with auto-detection if needed)
-                model_, effective_response_format = _get_bound_model(request)
-                messages = request.messages
-                if request.system_prompt:
-                    messages = [SystemMessage(request.system_prompt), *messages]
-
-                output = model_.invoke(messages)
-                return {
-                    "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
-                    "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-                    **_handle_model_output(output, effective_response_format),
-                }
-            except Exception as error:
-                # Try retry_model_request on each middleware
-                for m in middleware_w_retry:
-                    if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request:
-                        if retry_request := m.retry_model_request(
-                            error, request, state, runtime, attempt
-                        ):
-                            # Break on first middleware that wants to retry
-                            request = retry_request
-                            break
-                    else:
-                        msg = (
-                            f"No synchronous function provided for "
-                            f'{m.__class__.__name__}.aretry_model_request".'
-                            "\nEither initialize with a synchronous function or invoke"
-                            " via the async API (ainvoke, astream, etc.)"
-                        )
-                        raise TypeError(msg)
-                else:
-                    raise
-
-        # If we exit the loop, max attempts exceeded
-        msg = f"Maximum retry attempts ({max_attempts}) exceeded"
-        raise RuntimeError(msg)
+        # Execute with or without handler
+        current_request = request
+        internal_response: _InternalModelResponse
+        final_response: ModelResponse
+
+        if on_model_call_handler is None:
+            # No handlers - execute directly
+            internal_response = _execute_model_sync(request)
+            final_response = internal_response.model_response
+        else:
+            # Use composed handler with generator protocol
+            gen = on_model_call_handler(request, state, runtime)
+
+            try:
+                current_request = next(gen)
+            except StopIteration:
+                msg = "on_model_call handler must yield at least once to request model execution"
+                raise ValueError(msg)
+
+            # Execution loop - generator controls termination via StopIteration
+            while True:
+                internal_response = _execute_model_sync(current_request)
+
+                try:
+                    # Send only the ModelResponse to the generator (not the wrapper)
+                    current_request = gen.send(internal_response.model_response)
+                    # Handler yielded again - retry
+                except StopIteration as e:
+                    final_response = _validate_handler_return(e.value)
+                    break
+
+        # Process the final response
+        if final_response.action == "raise":
+            if final_response.exception is None:
+                msg = "ModelResponse with action='raise' must have an exception"
+                raise ValueError(msg)
+            raise final_response.exception
+
+        if final_response.result is None:
+            msg = "ModelResponse with action='return' must have a result"
+            raise ValueError(msg)
+
+        # Use cached effective_response_format from internal response
+        effective_response_format = internal_response.effective_response_format
+
+        return {
+            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
+            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
+            **_handle_model_output(final_response.result, effective_response_format),
+        }
+
+    async def _execute_model_async(request: ModelRequest) -> _InternalModelResponse:
+        """Execute model asynchronously and return response.
+
+        This is the core async model execution logic wrapped by on_model_call handlers.
+        """
+        try:
+            # Get the bound model (with auto-detection if needed)
+            model_, effective_response_format = _get_bound_model(request)
+            messages = request.messages
+            if request.system_prompt:
+                messages = [SystemMessage(request.system_prompt), *messages]
+
+            output = await model_.ainvoke(messages)
+            return _InternalModelResponse(
+                model_response=ModelResponse(action="return", result=output),
+                effective_response_format=effective_response_format,
+            )
+        except Exception as error:  # noqa: BLE001
+            # Catch all exceptions from model invocation to wrap in ModelResponse
+            return _InternalModelResponse(
+                model_response=ModelResponse(action="raise", exception=error)
+            )

     async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Async model request handler with sequential middleware processing."""
@@ -600,41 +847,59 @@ def create_agent(  # noqa: PLR0915

         # Apply modify_model_request middleware in sequence
         for m in middleware_w_modify_model_request:
-            await m.amodify_model_request(request, state, runtime)
+            request = await m.amodify_model_request(request, state, runtime)

-        # Retry loop for model invocation with error handling
-        # Hard limit of 100 attempts to prevent infinite loops from buggy middleware
-        max_attempts = 100
-        for attempt in range(1, max_attempts + 1):
-            try:
-                # Get the bound model (with auto-detection if needed)
-                model_, effective_response_format = _get_bound_model(request)
-                messages = request.messages
-                if request.system_prompt:
-                    messages = [SystemMessage(request.system_prompt), *messages]
-
-                output = await model_.ainvoke(messages)
-                return {
-                    "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
-                    "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-                    **_handle_model_output(output, effective_response_format),
-                }
-            except Exception as error:
-                # Try retry_model_request on each middleware
-                for m in middleware_w_retry:
-                    if retry_request := await m.aretry_model_request(
-                        error, request, state, runtime, attempt
-                    ):
-                        # Break on first middleware that wants to retry
-                        request = retry_request
-                        break
-                else:
-                    # If no middleware wants to retry, re-raise the error
-                    raise
-
-        # If we exit the loop, max attempts exceeded
-        msg = f"Maximum retry attempts ({max_attempts}) exceeded"
-        raise RuntimeError(msg)
+        # Execute with or without handler
+        # Note: handler is sync generator, but model execution is async
+        current_request = request
+        internal_response: _InternalModelResponse
+        final_response: ModelResponse
+
+        if on_model_call_handler is None:
+            # No handlers - execute directly
+            internal_response = await _execute_model_async(request)
+            final_response = internal_response.model_response
+        else:
+            # Use composed handler with generator protocol (sync generator, async execution)
+            gen = on_model_call_handler(request, state, runtime)
+
+            try:
+                current_request = next(gen)
+            except StopIteration:
+                msg = "on_model_call handler must yield at least once to request model execution"
+                raise ValueError(msg)
+
+            # Execution loop - generator controls termination via StopIteration
+            while True:
+                internal_response = await _execute_model_async(current_request)
+
+                try:
+                    # Send only the ModelResponse to the generator (not the wrapper)
+                    current_request = gen.send(internal_response.model_response)
+                    # Handler yielded again - retry
+                except StopIteration as e:
+                    final_response = _validate_handler_return(e.value)
+                    break
+
+        # Process the final response
+        if final_response.action == "raise":
+            if final_response.exception is None:
+                msg = "ModelResponse with action='raise' must have an exception"
+                raise ValueError(msg)
+            raise final_response.exception
+
+        if final_response.result is None:
+            msg = "ModelResponse with action='return' must have a result"
+            raise ValueError(msg)
+
+        # Use cached effective_response_format from internal response
+        effective_response_format = internal_response.effective_response_format
+
+        return {
+            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
+            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
+            **_handle_model_output(final_response.result, effective_response_format),
+        }

     # Use sync or async based on model capabilities
     graph.add_node("model_request", RunnableCallable(model_request, amodel_request, trace=False))
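The note above says the handler stays a sync generator even under async execution. A minimal sketch of why that works: the generator never awaits anything itself; the async driver awaits the model call and only then resumes the generator with `.send()`. The names here are illustrative.

```python
import asyncio


def handler(request):
    # Plain sync generator: it only yields requests and receives results.
    response = yield request
    return response


async def drive(request):
    gen = handler(request)
    req = next(gen)
    # Stand-in for `await model_.ainvoke(...)` in the real driver.
    result = await asyncio.sleep(0, result=f"ran {req}")
    try:
        gen.send(result)
    except StopIteration as stop:
        return stop.value


assert asyncio.run(drive("req-1")) == "ran req-1"
```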
@@ -48,6 +48,7 @@ from langchain.agents.middleware.types import (
     AgentState,
     hook_config,
     ModelRequest,
+    ModelResponse,
     OmitFromInput,
     OmitFromOutput,
     PrivateStateAttr,
@@ -2110,9 +2111,9 @@ async def test_create_agent_mixed_sync_async_middleware() -> None:
     ]


-# Tests for retry_model_request hook
-def test_retry_model_request_hook() -> None:
-    """Test that retry_model_request hook is called on model errors."""
+# Tests for on_model_call hook
+def test_on_model_call_hook() -> None:
+    """Test that on_model_call hook is called on model errors."""
     call_count = {"value": 0}

     class FailingModel(BaseChatModel):
@@ -2135,10 +2136,15 @@ def test_retry_model_request_hook() -> None:
             super().__init__()
             self.retry_count = 0

-        def retry_model_request(self, error, request, state, runtime, attempt):
-            self.retry_count += 1
-            # Return the same request to retry
-            return request
+        def on_model_call(self, request, state, runtime):
+            response = yield request
+
+            if response.action == "raise":
+                # Retry on error
+                self.retry_count += 1
+                response = yield request
+
+            return response

     failing_model = FailingModel()
     retry_middleware = RetryMiddleware()
@@ -2154,8 +2160,8 @@ def test_retry_model_request_hook() -> None:
     assert result["messages"][1].content == "Success on retry"


-def test_retry_model_request_attempt_number() -> None:
-    """Test that attempt number is correctly passed to retry_model_request."""
+def test_on_model_call_retry_count() -> None:
+    """Test that on_model_call can retry multiple times."""

     class AlwaysFailingModel(BaseChatModel):
         """Model that always fails."""
@@ -2172,11 +2178,20 @@ def test_retry_model_request_attempt_number() -> None:
             super().__init__()
             self.attempts = []

-        def retry_model_request(self, error, request, state, runtime, attempt):
-            self.attempts.append(attempt)
-            if attempt < 3:  # noqa: PLR2004
-                return request  # Retry
-            return None  # Stop after 3 attempts
+        def on_model_call(self, request, state, runtime):
+            max_retries = 3
+            for attempt in range(max_retries):
+                self.attempts.append(attempt + 1)
+                response = yield request
+
+                if response.action == "return":
+                    return response
+
+                if attempt < max_retries - 1:
+                    continue  # Retry
+
+            # All retries failed
+            return response

     model = AlwaysFailingModel()
     tracker = AttemptTrackingMiddleware()
@@ -2186,12 +2201,12 @@ def test_retry_model_request_attempt_number() -> None:
     with pytest.raises(ValueError, match="Always fails"):
         agent.invoke({"messages": [HumanMessage("Test")]})

-    # Should have been called with attempts 1, 2, 3
+    # Should have attempted 3 times
     assert tracker.attempts == [1, 2, 3]


-def test_retry_model_request_no_retry() -> None:
-    """Test that error is propagated when no middleware wants to retry."""
+def test_on_model_call_no_retry() -> None:
+    """Test that error is propagated when middleware doesn't retry."""

     class FailingModel(BaseChatModel):
         """Model that always fails."""
@@ -2204,9 +2219,10 @@ def test_retry_model_request_no_retry() -> None:
         return "failing"

     class NoRetryMiddleware(AgentMiddleware):
-        def retry_model_request(self, error, request, state, runtime, attempt):
-            # Always return None to not retry
-            return None
+        def on_model_call(self, request, state, runtime):
+            response = yield request
+            # Don't retry, just return the error response
+            return response

     agent = create_agent(model=FailingModel(), middleware=[NoRetryMiddleware()]).compile()
@@ -2301,8 +2317,8 @@ def test_model_fallback_middleware_initialization() -> None:
     assert len(middleware.models) == 2


-def test_retry_model_request_max_attempts() -> None:
-    """Test that retry stops after maximum attempts."""
+def test_on_model_call_max_attempts() -> None:
+    """Test that middleware controls termination via retry limits."""

     class AlwaysFailingModel(BaseChatModel):
         """Model that always fails."""
@@ -2314,72 +2330,117 @@ def test_retry_model_request_max_attempts() -> None:
     def _llm_type(self):
         return "always_failing"

-    class InfiniteRetryMiddleware(AgentMiddleware):
-        """Middleware that always wants to retry (buggy behavior)."""
+    class LimitedRetryMiddleware(AgentMiddleware):
+        """Middleware that limits its own retries."""

-        def __init__(self):
+        def __init__(self, max_retries: int = 10):
             super().__init__()
+            self.max_retries = max_retries
             self.attempt_count = 0

-        def retry_model_request(self, error, request, state, runtime, attempt):
-            self.attempt_count = attempt
-            return request  # Always retry (infinite loop without limit)
+        def on_model_call(self, request, state, runtime):
+            for attempt in range(self.max_retries):
+                self.attempt_count += 1
+                response = yield request
+
+                if response.action == "return":
+                    return response
+                # Continue to retry
+
+            # All retries exhausted, return the last error
+            return response

     model = AlwaysFailingModel()
-    middleware = InfiniteRetryMiddleware()
+    middleware = LimitedRetryMiddleware(max_retries=10)

     agent = create_agent(model=model, middleware=[middleware]).compile()

-    # Should fail with max attempts error, not infinite loop
-    with pytest.raises(RuntimeError, match="Maximum retry attempts \\(100\\) exceeded"):
+    # Should fail with the model's error after middleware stops retrying
+    with pytest.raises(ValueError, match="Always fails"):
         agent.invoke({"messages": [HumanMessage("Test")]})

-    # Should have attempted 100 times
-    assert middleware.attempt_count == 100
+    # Should have attempted exactly 10 times as configured
+    assert middleware.attempt_count == 10


-async def test_retry_model_request_async() -> None:
-    """Test async retry_model_request hook."""
-    call_count = {"value": 0}
+def test_on_model_call_rewrite_response() -> None:
+    """Test that middleware can rewrite model responses."""

-    class AsyncFailingModel(BaseChatModel):
-        """Model that fails on first async call, succeeds on second."""
+    class SimpleModel(BaseChatModel):
+        """Model that returns a simple response."""

         def _generate(self, messages, **kwargs):
-            return ChatResult(generations=[ChatGeneration(message=AIMessage(content="sync"))])
-
-        async def _agenerate(self, messages, **kwargs):
-            call_count["value"] += 1
-            if call_count["value"] == 1:
-                raise ValueError("First async call fails")
             return ChatResult(
-                generations=[ChatGeneration(message=AIMessage(content="Async retry success"))]
+                generations=[ChatGeneration(message=AIMessage(content="Original response"))]
             )

         @property
         def _llm_type(self):
-            return "async_failing"
+            return "simple"

-    class AsyncRetryMiddleware(AgentMiddleware):
-        def __init__(self):
-            super().__init__()
-            self.retry_count = 0
+    class ResponseRewriteMiddleware(AgentMiddleware):
+        """Middleware that rewrites the response."""

-        async def aretry_model_request(self, error, request, state, runtime, attempt):
-            self.retry_count += 1
-            return request  # Retry with same request
+        def on_model_call(self, request, state, runtime):
+            response = yield request

-    failing_model = AsyncFailingModel()
-    retry_middleware = AsyncRetryMiddleware()
+            # Rewrite the response
+            if response.action == "return" and response.result:
+                rewritten_message = AIMessage(content=f"REWRITTEN: {response.result.content}")
+                response = ModelResponse(action="return", result=rewritten_message)

-    agent = create_agent(model=failing_model, middleware=[retry_middleware]).compile()
+            return response

-    result = await agent.ainvoke({"messages": [HumanMessage("Test")]})
+    model = SimpleModel()
+    middleware = ResponseRewriteMiddleware()

-    # Should have retried once
-    assert retry_middleware.retry_count == 1
-    # Should have succeeded on second attempt
-    assert result["messages"][1].content == "Async retry success"
+    agent = create_agent(model=model, middleware=[middleware]).compile()
+
+    result = agent.invoke({"messages": [HumanMessage("Test")]})
+
+    # Response should be rewritten by middleware
+    assert result["messages"][1].content == "REWRITTEN: Original response"
+
+
+def test_on_model_call_convert_error_to_response() -> None:
+    """Test that middleware can convert errors to successful responses."""
+
+    class AlwaysFailingModel(BaseChatModel):
+        """Model that always fails."""
+
+        def _generate(self, messages, **kwargs):
+            raise ValueError("Model error")
+
+        @property
+        def _llm_type(self):
+            return "failing"
+
+    class ErrorToResponseMiddleware(AgentMiddleware):
+        """Middleware that converts errors to success responses."""
+
+        def on_model_call(self, request, state, runtime):
+            response = yield request
+
+            # Convert error to success response
+            if response.action == "raise":
+                fallback_message = AIMessage(
+                    content=f"Error occurred: {response.exception}. Using fallback response."
+                )
+                response = ModelResponse(action="return", result=fallback_message)
+
+            return response
+
+    model = AlwaysFailingModel()
+    middleware = ErrorToResponseMiddleware()
+
+    agent = create_agent(model=model, middleware=[middleware]).compile()
+
+    # Should not raise, middleware converts error to response
+    result = agent.invoke({"messages": [HumanMessage("Test")]})
+
+    # Response should be the fallback from middleware
+    assert "Error occurred" in result["messages"][1].content
+    assert "fallback response" in result["messages"][1].content


 def test_create_agent_sync_invoke_with_only_async_middleware_raises_error() -> None: