Compare commits


7 Commits

Author          SHA1        Message  Date
Eugene Yurtsev  25f0464dfe  x        2025-10-06 18:10:39 -04:00
Eugene Yurtsev  14469d7fe9  x        2025-10-06 18:08:54 -04:00
Eugene Yurtsev  702add4a58  x        2025-10-06 18:03:25 -04:00
Eugene Yurtsev  cd81a2b92d  x        2025-10-06 17:46:53 -04:00
Eugene Yurtsev  7d1703be76  x        2025-10-06 17:37:13 -04:00
Eugene Yurtsev  6802a45a36  x        2025-10-06 16:15:03 -04:00
Eugene Yurtsev  e3fd9eac8e  x        2025-10-06 15:38:49 -04:00
5 changed files with 805 additions and 196 deletions

View File

@@ -17,6 +17,7 @@ from .types import (
AgentMiddleware,
AgentState,
ModelRequest,
ModelResponse,
after_agent,
after_model,
before_agent,
@@ -24,6 +25,7 @@ from .types import (
dynamic_prompt,
hook_config,
modify_model_request,
on_model_call,
)
__all__ = [
@@ -38,6 +40,7 @@ __all__ = [
"ModelCallLimitMiddleware",
"ModelFallbackMiddleware",
"ModelRequest",
"ModelResponse",
"PIIDetectionError",
"PIIMiddleware",
"PlanningMiddleware",
@@ -50,4 +53,5 @@ __all__ = [
"dynamic_prompt",
"hook_config",
"modify_model_request",
"on_model_call",
]
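
For reference, the two names added to `__all__` above can be imported alongside the existing exports. A minimal sketch; the package path `langchain.agents.middleware` matches the `langchain.agents.middleware.types` imports seen later in this diff:

```python
# Import sketch based on the exports added in this diff.
from langchain.agents.middleware import (
    ModelRequest,
    ModelResponse,  # new: wraps a model result or error for on_model_call
    on_model_call,  # new: decorator that builds middleware from a generator
)
```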

View File

@@ -2,14 +2,21 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelRequest,
ModelResponse,
)
from langchain.chat_models import init_chat_model
if TYPE_CHECKING:
from collections.abc import Generator
from langchain_core.language_models.chat_models import BaseChatModel
from langgraph.runtime import Runtime
from langgraph.typing import ContextT
class ModelFallbackMiddleware(AgentMiddleware):
@@ -64,31 +71,43 @@ class ModelFallbackMiddleware(AgentMiddleware):
else:
self.models.append(model)
def retry_model_request(
def on_model_call(
self,
error: Exception, # noqa: ARG002
request: ModelRequest,
state: AgentState, # noqa: ARG002
runtime: Runtime, # noqa: ARG002
attempt: int,
) -> ModelRequest | None:
"""Retry with the next fallback model.
state: Any, # noqa: ARG002
runtime: Runtime[ContextT], # noqa: ARG002
) -> Generator[ModelRequest, ModelResponse, ModelResponse]:
"""Try fallback models in sequence on errors.
Args:
error: The exception that occurred during model invocation.
request: The original model request that failed.
request: The initial model request.
state: The current agent state.
runtime: The langgraph runtime.
attempt: The current attempt number (1-indexed).
Yields:
ModelRequest: The request to execute.
Receives (via .send()):
ModelResponse: The response from the model call.
Returns:
ModelRequest with the next fallback model, or None if all models exhausted.
ModelResponse: The final response to use.
"""
# attempt 1 = primary model failed, try models[0] (first fallback)
fallback_index = attempt - 1
# All fallback models exhausted
if fallback_index >= len(self.models):
return None
# Try next fallback model
request.model = self.models[fallback_index]
return request
# Try primary model first
current_request = request
response = yield current_request
# If success, return immediately
if response.action == "return":
return response
# Try each fallback model
for fallback_model in self.models:
current_request.model = fallback_model
response = yield current_request
if response.action == "return":
return response
# All models failed, return last error
return response
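
To make the new control flow concrete, a minimal usage sketch. The constructor signature (fallback models passed positionally, primary model supplied to the agent) and the `create_agent` import path are assumptions inferred from this diff and its tests:

```python
# Hypothetical wiring; model identifiers are illustrative.
from langchain_core.messages import HumanMessage
from langchain.agents.middleware import ModelFallbackMiddleware

fallback = ModelFallbackMiddleware(
    "openai:gpt-4o-mini",           # tried first when the primary model errors
    "anthropic:claude-3-5-sonnet",  # tried next if the first fallback also errors
)

# create_agent is the factory exercised in the tests below (import path assumed).
agent = create_agent(model="openai:gpt-4o", middleware=[fallback]).compile()
result = agent.invoke({"messages": [HumanMessage("Hello")]})
```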

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from collections.abc import Callable
from collections.abc import Callable, Generator
from dataclasses import dataclass, field
from inspect import iscoroutinefunction
from typing import (
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
from collections.abc import Awaitable
# needed as top level import for pydantic schema generation on AgentState
from langchain_core.messages import AnyMessage # noqa: TC002
from langchain_core.messages import AIMessage, AnyMessage # noqa: TC002
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.channels.untracked_value import UntrackedValue
from langgraph.graph.message import add_messages
@@ -42,6 +42,7 @@ __all__ = [
"AgentState",
"ContextT",
"ModelRequest",
"ModelResponse",
"OmitFromSchema",
"PublicAgentState",
"after_agent",
@@ -51,6 +52,7 @@ __all__ = [
"dynamic_prompt",
"hook_config",
"modify_model_request",
"on_model_call",
]
JumpTo = Literal["tools", "model", "end"]
@@ -72,6 +74,45 @@ class ModelRequest:
model_settings: dict[str, Any] = field(default_factory=dict)
@dataclass
class ModelResponse:
"""Response from a model call.
Used in the on_model_call generator protocol to represent the result
of a model invocation, whether successful or failed.
Middleware can modify this response to:
- Rewrite the result message content
- Convert errors to successful responses (error recovery)
- Convert successful responses to errors (validation failures)
- Add metadata or modify message attributes
Examples:
Rewrite response content:
```python
if response.action == "return":
modified = AIMessage(content=f"Enhanced: {response.result.content}")
response = ModelResponse(action="return", result=modified)
```
Convert error to success:
```python
if response.action == "raise":
fallback = AIMessage(content="Using fallback response")
response = ModelResponse(action="return", result=fallback)
```
"""
action: Literal["return", "raise"]
"""The action to take: 'return' for success, 'raise' for error."""
result: AIMessage | None = None
"""The AI message result if action is 'return'. Can be rewritten by middleware."""
exception: Exception | None = None
"""The exception if action is 'raise'."""
@dataclass
class OmitFromSchema:
"""Annotation used to mark state attributes as omitted from input or output schemas."""
@@ -180,53 +221,89 @@ class AgentMiddleware(Generic[StateT, ContextT]):
) -> dict[str, Any] | None:
"""Async logic to run after the model is called."""
def retry_model_request(
def on_model_call(
self,
error: Exception, # noqa: ARG002
request: ModelRequest, # noqa: ARG002
state: StateT, # noqa: ARG002
runtime: Runtime[ContextT], # noqa: ARG002
attempt: int, # noqa: ARG002
) -> ModelRequest | None:
"""Logic to handle model invocation errors and optionally retry.
) -> Generator[ModelRequest, ModelResponse, ModelResponse]:
"""Generator-based hook to intercept and control model execution.
This hook allows middleware to:
- Intercept model calls before execution
- Modify requests dynamically
- Implement retry logic by yielding multiple times
- Handle errors and convert them to responses
- Rewrite response content (e.g., post-processing, formatting)
- Compose with other middleware (outer wraps inner)
The generator protocol:
1. Yield a ModelRequest to execute the model
2. Receive a ModelResponse via .send() with the result or error
3. Optionally yield again to retry with a modified request
4. Return a final ModelResponse (can be modified)
Args:
error: The exception that occurred during model invocation.
request: The original model request that failed.
request: The initial model request.
state: The current agent state.
runtime: The langgraph runtime.
attempt: The current attempt number (1-indexed).
Yields:
ModelRequest: The request to execute.
Receives (via .send()):
ModelResponse: The response from the model call.
Returns:
ModelRequest: Modified request to retry with.
None: Propagate the error (re-raise).
ModelResponse: The final response to use (can be rewritten).
Examples:
Implementing retry logic:
```python
def on_model_call(self, request, state, runtime):
max_retries = 3
for attempt in range(max_retries):
response = yield request
if response.action == "return":
return response
if attempt < max_retries - 1:
# Modify request and retry
continue
# All retries failed
return response
```
Rewriting response content:
```python
def on_model_call(self, request, state, runtime):
response = yield request
# Post-process successful responses
if response.action == "return" and response.result:
enhanced = AIMessage(content=self.enhance(response.result.content))
response = ModelResponse(action="return", result=enhanced)
return response
```
Converting errors to fallback responses:
```python
def on_model_call(self, request, state, runtime):
response = yield request
# Provide fallback on error
if response.action == "raise":
fallback = AIMessage(content="Service unavailable")
response = ModelResponse(action="return", result=fallback)
return response
```
"""
return None
async def aretry_model_request(
self,
error: Exception,
request: ModelRequest,
state: StateT,
runtime: Runtime[ContextT],
attempt: int,
) -> ModelRequest | None:
"""Async logic to handle model invocation errors and optionally retry.
Args:
error: The exception that occurred during model invocation.
request: The original model request that failed.
state: The current agent state.
runtime: The langgraph runtime.
attempt: The current attempt number (1-indexed).
Returns:
ModelRequest: Modified request to retry with.
None: Propagate the error (re-raise).
"""
return await run_in_executor(
None, self.retry_model_request, error, request, state, runtime, attempt
)
raise NotImplementedError
yield # Make this a generator for type checking
def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
"""Logic to run after the agent execution completes."""
@@ -267,6 +344,20 @@ class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]):
...
class _CallableReturningModelResponseGenerator(Protocol[StateT_contra, ContextT]):
"""Callable that returns a generator for model call interception.
Note: Returns a sync generator that works with both sync and async model execution,
following the same pattern as ToolNode's on_tool_call handler.
"""
def __call__(
self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
) -> Generator[ModelRequest, ModelResponse, ModelResponse]:
"""Generate responses to intercept and control model execution."""
...
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
@@ -1122,3 +1213,172 @@ def dynamic_prompt(
if func is not None:
return decorator(func)
return decorator
@overload
def on_model_call(
func: _CallableReturningModelResponseGenerator[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...
@overload
def on_model_call(
func: None = None,
*,
state_schema: type[StateT] | None = None,
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> Callable[
[_CallableReturningModelResponseGenerator[StateT, ContextT]],
AgentMiddleware[StateT, ContextT],
]: ...
def on_model_call(
func: _CallableReturningModelResponseGenerator[StateT, ContextT] | None = None,
*,
state_schema: type[StateT] | None = None,
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> (
Callable[
[_CallableReturningModelResponseGenerator[StateT, ContextT]],
AgentMiddleware[StateT, ContextT],
]
| AgentMiddleware[StateT, ContextT]
):
"""Decorator used to dynamically create a middleware with the on_model_call hook.
This decorator creates middleware that can intercept model calls, implement
retry logic, handle errors, and rewrite responses through a generator-based protocol.
Args:
func: The generator function to be decorated. Must accept:
`request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` and
yield ModelRequest, receive ModelResponse, and return ModelResponse.
state_schema: Optional custom state schema type. If not provided, uses the default
AgentState schema.
tools: Optional list of additional tools to register with this middleware.
name: Optional name for the generated middleware class. If not provided,
uses the decorated function's name.
Returns:
Either an AgentMiddleware instance (if func is provided) or a decorator function
that can be applied to a function.
The decorated function should:
- Yield ModelRequest to execute the model
- Receive ModelResponse via .send() with results or errors
- Return final ModelResponse (can be rewritten)
Examples:
Basic retry logic:
```python
@on_model_call
def retry_on_error(
request: ModelRequest, state: AgentState, runtime: Runtime
) -> Generator[ModelRequest, ModelResponse, ModelResponse]:
max_retries = 3
for attempt in range(max_retries):
response = yield request
if response.action == "return":
return response
if attempt < max_retries - 1:
# Modify request and retry
continue
# All retries failed
return response
```
Model fallback:
```python
@on_model_call
def fallback_model(
request: ModelRequest, state: AgentState, runtime: Runtime
) -> Generator[ModelRequest, ModelResponse, ModelResponse]:
# Try primary model
response = yield request
if response.action == "return":
return response
# Try fallback model
request.model = fallback_model_instance
response = yield request
return response
```
Rewrite response content:
```python
@on_model_call
def uppercase_responses(
request: ModelRequest, state: AgentState, runtime: Runtime
) -> Generator[ModelRequest, ModelResponse, ModelResponse]:
response = yield request
# Rewrite successful responses
if response.action == "return" and response.result:
modified = AIMessage(content=response.result.content.upper())
response = ModelResponse(action="return", result=modified)
return response
```
"""
def decorator(
func: _CallableReturningModelResponseGenerator[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]:
# on_model_call handlers are always sync generators, even when used with async execution
# This follows the same pattern as ToolNode's on_tool_call handler
def wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
state: StateT,
runtime: Runtime[ContextT],
) -> Generator[ModelRequest, ModelResponse, ModelResponse]:
# Forward the generator yields and sends
gen = func(request, state, runtime)
try:
# Initialize generator with first next()
request_to_send = next(gen)
response = yield request_to_send
# Forward subsequent sends
while True:
request_to_send = gen.send(response)
response = yield request_to_send
except StopIteration as e:
# Validate the return value
if e.value is None:
msg = (
"on_model_call handler must explicitly return a ModelResponse. "
"Ensure your handler ends with 'return response'."
)
raise ValueError(msg)
if not isinstance(e.value, ModelResponse):
msg = (
f"on_model_call handler must return a ModelResponse, "
f"got {type(e.value).__name__} instead"
)
raise TypeError(msg)
return e.value
middleware_name = name or cast("str", getattr(func, "__name__", "OnModelCallMiddleware"))
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"on_model_call": wrapped,
},
)()
if func is not None:
return decorator(func)
return decorator
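
As a wiring note, the decorator returns an `AgentMiddleware` instance, so the decorated function can be passed straight into the middleware list. A minimal sketch following the `create_agent(...).compile()` pattern from the tests below:

```python
# Minimal pass-through middleware built with the decorator. Every handler
# must yield at least once and end with an explicit `return response`,
# otherwise the wrapper raises ValueError/TypeError as shown above.
@on_model_call
def passthrough(request, state, runtime):
    response = yield request  # execute the model once
    return response

# model is any chat model instance; create_agent import path assumed.
agent = create_agent(model=model, middleware=[passthrough]).compile()
```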

View File

@@ -1,7 +1,8 @@
"""Middleware agent implementation."""
import itertools
from collections.abc import Callable, Sequence
from collections.abc import Callable, Generator, Sequence
from dataclasses import dataclass
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
from langchain_core.language_models.chat_models import BaseChatModel
@@ -21,6 +22,7 @@ from langchain.agents.middleware.types import (
AgentState,
JumpTo,
ModelRequest,
ModelResponse,
OmitFromSchema,
PublicAgentState,
)
@@ -42,6 +44,181 @@ STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
ResponseT = TypeVar("ResponseT")
@dataclass
class _InternalModelResponse:
"""Internal wrapper for model responses with additional metadata.
This wrapper contains the actual ModelResponse that middleware sees,
plus internal metadata needed for efficient implementation. Middleware
authors never see this wrapper - they only interact with ModelResponse.
"""
model_response: ModelResponse
"""The actual model response that middleware interacts with."""
effective_response_format: Any = None
"""The effective response format used for this model call (internal use)."""
def _validate_handler_return(value: Any) -> ModelResponse:
"""Validate that a handler returned a valid ModelResponse.
Args:
value: The value returned from the handler's StopIteration.
Returns:
The validated ModelResponse.
Raises:
ValueError: If the value is None or not a ModelResponse.
"""
if value is None:
msg = (
"on_model_call handler must explicitly return a ModelResponse. "
"Ensure your handler ends with 'return response'."
)
raise ValueError(msg)
if not isinstance(value, ModelResponse):
msg = (
f"on_model_call handler must return a ModelResponse, got {type(value).__name__} instead"
)
raise TypeError(msg)
return value
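
A quick illustration of what this guard accepts and rejects (each line independent, as a sketch with `ModelResponse` in scope):

```python
_validate_handler_return(ModelResponse(action="return"))  # passes, returned as-is
_validate_handler_return(None)        # raises ValueError: must explicitly return
_validate_handler_return("response")  # raises TypeError: got str instead
```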
def _chain_model_call_handlers(
handlers: Sequence[
Callable[
[ModelRequest, Any, Any],
Generator[ModelRequest, ModelResponse, ModelResponse],
]
],
) -> (
Callable[
[ModelRequest, Any, Any],
Generator[ModelRequest, ModelResponse, ModelResponse],
]
| None
):
"""Compose multiple model call handlers into a single middleware stack.
Args:
handlers: Handlers in middleware order (first = outermost layer).
Returns:
Single composed handler, or None if handlers is empty.
Example:
```python
# Auth middleware (outer) + retry (inner)
def auth(req, state, runtime):
resp = yield req
if "unauthorized" in str(resp.exception):
refresh_token()
resp = yield req # Retry
return resp
def retry(req, state, runtime):
for attempt in range(3):
resp = yield req
if resp.action == "return":
return resp
time.sleep(2**attempt)
return resp
handler = _chain_model_call_handlers([auth, retry])
# Request: auth -> retry -> model
# Response: model -> retry -> auth
```
"""
if not handlers:
return None
if len(handlers) == 1:
return handlers[0]
def _extract_return_value(stop_iteration: StopIteration) -> ModelResponse:
"""Extract ModelResponse from StopIteration, validating protocol compliance."""
return _validate_handler_return(stop_iteration.value)
def compose_two(
outer: Callable[
[ModelRequest, Any, Any],
Generator[ModelRequest, ModelResponse, ModelResponse],
],
inner: Callable[
[ModelRequest, Any, Any],
Generator[ModelRequest, ModelResponse, ModelResponse],
],
) -> Callable[
[ModelRequest, Any, Any],
Generator[ModelRequest, ModelResponse, ModelResponse],
]:
"""Compose two handlers where outer wraps inner."""
def composed(
request: ModelRequest,
state: Any,
runtime: Any,
) -> Generator[ModelRequest, ModelResponse, ModelResponse]:
outer_gen = outer(request, state, runtime)
# Initialize outer generator
try:
outer_request = next(outer_gen)
except StopIteration as e:
return _extract_return_value(e)
# Outer retry loop
while True:
inner_gen = inner(outer_request, state, runtime)
# Initialize inner generator
try:
inner_request = next(inner_gen)
except StopIteration as e:
# Inner returned immediately - send to outer
inner_response = _extract_return_value(e)
try:
outer_request = outer_gen.send(inner_response)
continue # Outer retrying
except StopIteration as e:
return _extract_return_value(e)
# Inner retry loop - yield to next layer (or model)
while True:
model_response = yield inner_request
try:
inner_request = inner_gen.send(model_response)
# Inner retrying - continue inner loop
except StopIteration as e:
# Inner done - send response to outer
inner_response = _extract_return_value(e)
break
# Send inner's final response to outer
try:
outer_request = outer_gen.send(inner_response)
# Outer retrying - continue outer loop
except StopIteration as e:
# Outer done - return final response
return _extract_return_value(e)
return composed
# Compose right-to-left: handlers[0](handlers[1](...(handlers[-1](model))))
result = handlers[-1]
for handler in reversed(handlers[:-1]):
result = compose_two(handler, result)
return result
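
To see the request/response flow end to end, a standalone sketch that drives a composed handler by hand with `next()`/`send()`, assuming `_chain_model_call_handlers` and `ModelResponse` are in scope (both are module-level in this file). Requests use a stand-in object, since the chain only forwards them, while responses must be real `ModelResponse` instances because `_extract_return_value` validates handler returns:

```python
from langchain_core.messages import AIMessage

class FakeRequest:
    """Stand-in for ModelRequest; the chain never inspects it."""
    model = "primary"

def outer(req, state, runtime):
    resp = yield req                 # forward to the inner handler
    return resp

def inner(req, state, runtime):
    resp = yield req                 # forward to the model layer
    if resp.action == "raise":
        resp = yield req             # one retry on error
    return resp

handler = _chain_model_call_handlers([outer, inner])
gen = handler(FakeRequest(), None, None)

req = next(gen)                      # composed stack asks for the first model call
req = gen.send(ModelResponse(action="raise", exception=RuntimeError("boom")))
try:                                 # inner retries once, then both layers return
    gen.send(ModelResponse(action="return", result=AIMessage(content="ok")))
except StopIteration as stop:
    final = stop.value               # ModelResponse(action="return", result=...)
```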
def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
"""Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
@@ -291,13 +468,28 @@ def create_agent( # noqa: PLR0915
if m.__class__.after_agent is not AgentMiddleware.after_agent
or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
]
middleware_w_retry = [
m
for m in middleware
if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request
or m.__class__.aretry_model_request is not AgentMiddleware.aretry_model_request
middleware_w_on_model_call = [
m for m in middleware if m.__class__.on_model_call is not AgentMiddleware.on_model_call
]
# Compose on_model_call handlers into a single middleware stack
on_model_call_handler = None
if middleware_w_on_model_call:
# Collect sync implementations from all middleware
sync_handlers = []
all_have_sync = True
for m in middleware_w_on_model_call:
if m.__class__.on_model_call is not AgentMiddleware.on_model_call:
sync_handlers.append(m.on_model_call)
else:
# No sync implementation for this middleware
all_have_sync = False
break
# Only compose if all have sync implementations
if all_have_sync and sync_handlers:
on_model_call_handler = _chain_model_call_handlers(sync_handlers)
state_schemas = {m.state_schema for m in middleware}
state_schemas.add(AgentState)
@@ -521,6 +713,29 @@ def create_agent( # noqa: PLR0915
)
return request.model.bind(**request.model_settings), None
def _execute_model_sync(request: ModelRequest) -> _InternalModelResponse:
"""Execute model and return response.
This is the core model execution logic wrapped by on_model_call handlers.
"""
try:
# Get the bound model (with auto-detection if needed)
model_, effective_response_format = _get_bound_model(request)
messages = request.messages
if request.system_prompt:
messages = [SystemMessage(request.system_prompt), *messages]
output = model_.invoke(messages)
return _InternalModelResponse(
model_response=ModelResponse(action="return", result=output),
effective_response_format=effective_response_format,
)
except Exception as error: # noqa: BLE001
# Catch all exceptions from model invocation to wrap in ModelResponse
return _InternalModelResponse(
model_response=ModelResponse(action="raise", exception=error)
)
def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Sync model request handler with sequential middleware processing."""
request = ModelRequest(
@@ -535,7 +750,7 @@ def create_agent( # noqa: PLR0915
# Apply modify_model_request middleware in sequence
for m in middleware_w_modify_model_request:
if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request:
m.modify_model_request(request, state, runtime)
request = m.modify_model_request(request, state, runtime)
else:
msg = (
f"No synchronous function provided for "
@@ -545,47 +760,79 @@ def create_agent( # noqa: PLR0915
)
raise TypeError(msg)
# Retry loop for model invocation with error handling
# Hard limit of 100 attempts to prevent infinite loops from buggy middleware
max_attempts = 100
for attempt in range(1, max_attempts + 1):
# Execute with or without handler
current_request = request
internal_response: _InternalModelResponse
final_response: ModelResponse
if on_model_call_handler is None:
# No handlers - execute directly
internal_response = _execute_model_sync(request)
final_response = internal_response.model_response
else:
# Use composed handler with generator protocol
gen = on_model_call_handler(request, state, runtime)
try:
# Get the bound model (with auto-detection if needed)
model_, effective_response_format = _get_bound_model(request)
messages = request.messages
if request.system_prompt:
messages = [SystemMessage(request.system_prompt), *messages]
current_request = next(gen)
except StopIteration:
msg = "on_model_call handler must yield at least once to request model execution"
raise ValueError(msg)
output = model_.invoke(messages)
return {
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
**_handle_model_output(output, effective_response_format),
}
except Exception as error:
# Try retry_model_request on each middleware
for m in middleware_w_retry:
if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request:
if retry_request := m.retry_model_request(
error, request, state, runtime, attempt
):
# Break on first middleware that wants to retry
request = retry_request
break
else:
msg = (
f"No synchronous function provided for "
f'{m.__class__.__name__}.aretry_model_request".'
"\nEither initialize with a synchronous function or invoke"
" via the async API (ainvoke, astream, etc.)"
)
raise TypeError(msg)
else:
raise
# Execution loop - generator controls termination via StopIteration
while True:
internal_response = _execute_model_sync(current_request)
# If we exit the loop, max attempts exceeded
msg = f"Maximum retry attempts ({max_attempts}) exceeded"
raise RuntimeError(msg)
try:
# Send only the ModelResponse to the generator (not the wrapper)
current_request = gen.send(internal_response.model_response)
# Handler yielded again - retry
except StopIteration as e:
final_response = _validate_handler_return(e.value)
break
# Process the final response
if final_response.action == "raise":
if final_response.exception is None:
msg = "ModelResponse with action='raise' must have an exception"
raise ValueError(msg)
raise final_response.exception
if final_response.result is None:
msg = "ModelResponse with action='return' must have a result"
raise ValueError(msg)
# Use cached effective_response_format from internal response
effective_response_format = internal_response.effective_response_format
return {
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
**_handle_model_output(final_response.result, effective_response_format),
}
async def _execute_model_async(request: ModelRequest) -> _InternalModelResponse:
"""Execute model asynchronously and return response.
This is the core async model execution logic wrapped by on_model_call handlers.
"""
try:
# Get the bound model (with auto-detection if needed)
model_, effective_response_format = _get_bound_model(request)
messages = request.messages
if request.system_prompt:
messages = [SystemMessage(request.system_prompt), *messages]
output = await model_.ainvoke(messages)
return _InternalModelResponse(
model_response=ModelResponse(action="return", result=output),
effective_response_format=effective_response_format,
)
except Exception as error: # noqa: BLE001
# Catch all exceptions from model invocation to wrap in ModelResponse
return _InternalModelResponse(
model_response=ModelResponse(action="raise", exception=error)
)
async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Async model request handler with sequential middleware processing."""
@@ -600,41 +847,59 @@ def create_agent( # noqa: PLR0915
# Apply modify_model_request middleware in sequence
for m in middleware_w_modify_model_request:
await m.amodify_model_request(request, state, runtime)
request = await m.amodify_model_request(request, state, runtime)
# Execute with or without handler
# Note: the handler is a sync generator, but model execution is async
current_request = request
internal_response: _InternalModelResponse
final_response: ModelResponse
if on_model_call_handler is None:
# No handlers - execute directly
internal_response = await _execute_model_async(request)
final_response = internal_response.model_response
else:
# Use composed handler with generator protocol (sync generator, async execution)
gen = on_model_call_handler(request, state, runtime)
# Retry loop for model invocation with error handling
# Hard limit of 100 attempts to prevent infinite loops from buggy middleware
max_attempts = 100
for attempt in range(1, max_attempts + 1):
try:
# Get the bound model (with auto-detection if needed)
model_, effective_response_format = _get_bound_model(request)
messages = request.messages
if request.system_prompt:
messages = [SystemMessage(request.system_prompt), *messages]
current_request = next(gen)
except StopIteration:
msg = "on_model_call handler must yield at least once to request model execution"
raise ValueError(msg)
output = await model_.ainvoke(messages)
return {
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
**_handle_model_output(output, effective_response_format),
}
except Exception as error:
# Try retry_model_request on each middleware
for m in middleware_w_retry:
if retry_request := await m.aretry_model_request(
error, request, state, runtime, attempt
):
# Break on first middleware that wants to retry
request = retry_request
break
else:
# If no middleware wants to retry, re-raise the error
raise
# Execution loop - generator controls termination via StopIteration
while True:
internal_response = await _execute_model_async(current_request)
# If we exit the loop, max attempts exceeded
msg = f"Maximum retry attempts ({max_attempts}) exceeded"
raise RuntimeError(msg)
try:
# Send only the ModelResponse to the generator (not the wrapper)
current_request = gen.send(internal_response.model_response)
# Handler yielded again - retry
except StopIteration as e:
final_response = _validate_handler_return(e.value)
break
# Process the final response
if final_response.action == "raise":
if final_response.exception is None:
msg = "ModelResponse with action='raise' must have an exception"
raise ValueError(msg)
raise final_response.exception
if final_response.result is None:
msg = "ModelResponse with action='return' must have a result"
raise ValueError(msg)
# Use cached effective_response_format from internal response
effective_response_format = internal_response.effective_response_format
return {
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
**_handle_model_output(final_response.result, effective_response_format),
}
# Use sync or async based on model capabilities
graph.add_node("model_request", RunnableCallable(model_request, amodel_request, trace=False))

View File

@@ -48,6 +48,7 @@ from langchain.agents.middleware.types import (
AgentState,
hook_config,
ModelRequest,
ModelResponse,
OmitFromInput,
OmitFromOutput,
PrivateStateAttr,
@@ -2110,9 +2111,9 @@ async def test_create_agent_mixed_sync_async_middleware() -> None:
]
# Tests for retry_model_request hook
def test_retry_model_request_hook() -> None:
"""Test that retry_model_request hook is called on model errors."""
# Tests for on_model_call hook
def test_on_model_call_hook() -> None:
"""Test that on_model_call hook is called on model errors."""
call_count = {"value": 0}
class FailingModel(BaseChatModel):
@@ -2135,10 +2136,15 @@ def test_retry_model_request_hook() -> None:
super().__init__()
self.retry_count = 0
def retry_model_request(self, error, request, state, runtime, attempt):
self.retry_count += 1
# Return the same request to retry
return request
def on_model_call(self, request, state, runtime):
response = yield request
if response.action == "raise":
# Retry on error
self.retry_count += 1
response = yield request
return response
failing_model = FailingModel()
retry_middleware = RetryMiddleware()
@@ -2154,8 +2160,8 @@ def test_retry_model_request_hook() -> None:
assert result["messages"][1].content == "Success on retry"
def test_retry_model_request_attempt_number() -> None:
"""Test that attempt number is correctly passed to retry_model_request."""
def test_on_model_call_retry_count() -> None:
"""Test that on_model_call can retry multiple times."""
class AlwaysFailingModel(BaseChatModel):
"""Model that always fails."""
@@ -2172,11 +2178,20 @@ def test_retry_model_request_attempt_number() -> None:
super().__init__()
self.attempts = []
def retry_model_request(self, error, request, state, runtime, attempt):
self.attempts.append(attempt)
if attempt < 3: # noqa: PLR2004
return request # Retry
return None # Stop after 3 attempts
def on_model_call(self, request, state, runtime):
max_retries = 3
for attempt in range(max_retries):
self.attempts.append(attempt + 1)
response = yield request
if response.action == "return":
return response
if attempt < max_retries - 1:
continue # Retry
# All retries failed
return response
model = AlwaysFailingModel()
tracker = AttemptTrackingMiddleware()
@@ -2186,12 +2201,12 @@ def test_retry_model_request_attempt_number() -> None:
with pytest.raises(ValueError, match="Always fails"):
agent.invoke({"messages": [HumanMessage("Test")]})
# Should have been called with attempts 1, 2, 3
# Should have attempted 3 times
assert tracker.attempts == [1, 2, 3]
def test_retry_model_request_no_retry() -> None:
"""Test that error is propagated when no middleware wants to retry."""
def test_on_model_call_no_retry() -> None:
"""Test that error is propagated when middleware doesn't retry."""
class FailingModel(BaseChatModel):
"""Model that always fails."""
@@ -2204,9 +2219,10 @@ def test_retry_model_request_no_retry() -> None:
return "failing"
class NoRetryMiddleware(AgentMiddleware):
def retry_model_request(self, error, request, state, runtime, attempt):
# Always return None to not retry
return None
def on_model_call(self, request, state, runtime):
response = yield request
# Don't retry, just return the error response
return response
agent = create_agent(model=FailingModel(), middleware=[NoRetryMiddleware()]).compile()
@@ -2301,8 +2317,8 @@ def test_model_fallback_middleware_initialization() -> None:
assert len(middleware.models) == 2
def test_retry_model_request_max_attempts() -> None:
"""Test that retry stops after maximum attempts."""
def test_on_model_call_max_attempts() -> None:
"""Test that middleware controls termination via retry limits."""
class AlwaysFailingModel(BaseChatModel):
"""Model that always fails."""
@@ -2314,72 +2330,117 @@ def test_retry_model_request_max_attempts() -> None:
def _llm_type(self):
return "always_failing"
class InfiniteRetryMiddleware(AgentMiddleware):
"""Middleware that always wants to retry (buggy behavior)."""
class LimitedRetryMiddleware(AgentMiddleware):
"""Middleware that limits its own retries."""
def __init__(self):
def __init__(self, max_retries: int = 10):
super().__init__()
self.max_retries = max_retries
self.attempt_count = 0
def retry_model_request(self, error, request, state, runtime, attempt):
self.attempt_count = attempt
return request # Always retry (infinite loop without limit)
def on_model_call(self, request, state, runtime):
for attempt in range(self.max_retries):
self.attempt_count += 1
response = yield request
if response.action == "return":
return response
# Continue to retry
# All retries exhausted, return the last error
return response
model = AlwaysFailingModel()
middleware = InfiniteRetryMiddleware()
middleware = LimitedRetryMiddleware(max_retries=10)
agent = create_agent(model=model, middleware=[middleware]).compile()
# Should fail with max attempts error, not infinite loop
with pytest.raises(RuntimeError, match="Maximum retry attempts \\(100\\) exceeded"):
# Should fail with the model's error after middleware stops retrying
with pytest.raises(ValueError, match="Always fails"):
agent.invoke({"messages": [HumanMessage("Test")]})
# Should have attempted 100 times
assert middleware.attempt_count == 100
# Should have attempted exactly 10 times as configured
assert middleware.attempt_count == 10
async def test_retry_model_request_async() -> None:
"""Test async retry_model_request hook."""
call_count = {"value": 0}
def test_on_model_call_rewrite_response() -> None:
"""Test that middleware can rewrite model responses."""
class AsyncFailingModel(BaseChatModel):
"""Model that fails on first async call, succeeds on second."""
class SimpleModel(BaseChatModel):
"""Model that returns a simple response."""
def _generate(self, messages, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="sync"))])
async def _agenerate(self, messages, **kwargs):
call_count["value"] += 1
if call_count["value"] == 1:
raise ValueError("First async call fails")
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content="Async retry success"))]
generations=[ChatGeneration(message=AIMessage(content="Original response"))]
)
@property
def _llm_type(self):
return "async_failing"
return "simple"
class AsyncRetryMiddleware(AgentMiddleware):
def __init__(self):
super().__init__()
self.retry_count = 0
class ResponseRewriteMiddleware(AgentMiddleware):
"""Middleware that rewrites the response."""
async def aretry_model_request(self, error, request, state, runtime, attempt):
self.retry_count += 1
return request # Retry with same request
def on_model_call(self, request, state, runtime):
response = yield request
failing_model = AsyncFailingModel()
retry_middleware = AsyncRetryMiddleware()
# Rewrite the response
if response.action == "return" and response.result:
rewritten_message = AIMessage(content=f"REWRITTEN: {response.result.content}")
response = ModelResponse(action="return", result=rewritten_message)
agent = create_agent(model=failing_model, middleware=[retry_middleware]).compile()
return response
result = await agent.ainvoke({"messages": [HumanMessage("Test")]})
model = SimpleModel()
middleware = ResponseRewriteMiddleware()
# Should have retried once
assert retry_middleware.retry_count == 1
# Should have succeeded on second attempt
assert result["messages"][1].content == "Async retry success"
agent = create_agent(model=model, middleware=[middleware]).compile()
result = agent.invoke({"messages": [HumanMessage("Test")]})
# Response should be rewritten by middleware
assert result["messages"][1].content == "REWRITTEN: Original response"
def test_on_model_call_convert_error_to_response() -> None:
"""Test that middleware can convert errors to successful responses."""
class AlwaysFailingModel(BaseChatModel):
"""Model that always fails."""
def _generate(self, messages, **kwargs):
raise ValueError("Model error")
@property
def _llm_type(self):
return "failing"
class ErrorToResponseMiddleware(AgentMiddleware):
"""Middleware that converts errors to success responses."""
def on_model_call(self, request, state, runtime):
response = yield request
# Convert error to success response
if response.action == "raise":
fallback_message = AIMessage(
content=f"Error occurred: {response.exception}. Using fallback response."
)
response = ModelResponse(action="return", result=fallback_message)
return response
model = AlwaysFailingModel()
middleware = ErrorToResponseMiddleware()
agent = create_agent(model=model, middleware=[middleware]).compile()
# Should not raise, middleware converts error to response
result = agent.invoke({"messages": [HumanMessage("Test")]})
# Response should be the fallback from middleware
assert "Error occurred" in result["messages"][1].content
assert "fallback response" in result["messages"][1].content
def test_create_agent_sync_invoke_with_only_async_middleware_raises_error() -> None: