fix(langchain): support async middleware decorator typing (#34584)

Fixes #35244

Users can write async agent middleware with `@wrap_model_call`, and
LangChain already supports that behavior at runtime by detecting
coroutine functions and wiring them to `awrap_model_call`.

However, the decorator's public typing currently describes only the sync
callable shape. As a result, valid async middleware is rejected by type
checkers such as mypy and ty, even though the same code runs correctly.

This updates the middleware decorator types so async `wrap_model_call`
and `wrap_tool_call` functions type-check consistently with their
runtime behavior. It also simplifies related callable aliases and uses
casts where `iscoroutinefunction` narrows the callable at runtime but
static type checkers cannot follow that narrowing.

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
Christophe Bornet
2026-06-11 04:08:06 +02:00
committed by GitHub
parent c16499c339
commit 8fc58c6013

View File

@@ -12,8 +12,6 @@ from typing import (
Any,
Generic,
Literal,
Protocol,
TypeAlias,
cast,
overload,
)
@@ -310,10 +308,8 @@ class ExtendedModelResponse(Generic[ResponseT]):
"""Optional command to apply as an additional state update."""
ModelCallResult: TypeAlias = (
"ModelResponse[ResponseT] | AIMessage | ExtendedModelResponse[ResponseT]"
)
"""`TypeAlias` for model call handler return value.
ModelCallResult = ModelResponse[ResponseT] | AIMessage | ExtendedModelResponse[ResponseT]
"""Return type for model call handlers.
Middleware can return either:
@@ -826,69 +822,35 @@ _CallableReturningSystemMessage = (
_SyncCallableReturningSystemMessage[ContextT] | _AsyncCallableReturningSystemMessage[ContextT]
)
# Sync/async signatures for `wrap_model_call` interception; see `@wrap_model_call`.
_SyncCallableReturningModelResponse = Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]]],
ModelCallResult,
]
_AsyncCallableReturningModelResponse = Callable[
[
ModelRequest[ContextT],
Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
],
Awaitable[ModelCallResult],
]
_CallableReturningModelResponse = (
_SyncCallableReturningModelResponse[ContextT, ResponseT]
| _AsyncCallableReturningModelResponse[ContextT, ResponseT]
)
class _CallableReturningModelResponse(Protocol[ContextT, ResponseT]):
"""Callable for model call interception with handler callback.
Receives handler callback to execute model and returns `ModelResponse` or
`AIMessage`.
"""
@overload
def __call__(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> ModelResponse[ResponseT] | AIMessage: ...
@overload
def __call__(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> Awaitable[ModelResponse[ResponseT] | AIMessage]: ...
def __call__(
self,
request: ModelRequest[ContextT],
handler: Callable[
[ModelRequest[ContextT]], ModelResponse[ResponseT] | Awaitable[ModelResponse[ResponseT]]
],
) -> ModelResponse[ResponseT] | AIMessage | Awaitable[ModelResponse[ResponseT] | AIMessage]:
"""Intercept model execution via handler callback."""
...
class _CallableReturningToolResponse(Protocol):
"""Callable for tool call interception with handler callback.
Receives handler callback to execute tool and returns final `ToolMessage` or
`Command`.
"""
@overload
def __call__(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
) -> ToolMessage | Command[Any]: ...
@overload
def __call__(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> Awaitable[ToolMessage | Command[Any]]: ...
def __call__(
self,
request: ToolCallRequest,
handler: Callable[
[ToolCallRequest], ToolMessage | Command[Any] | Awaitable[ToolMessage | Command[Any]]
],
) -> ToolMessage | Command[Any] | Awaitable[ToolMessage | Command[Any]]:
"""Intercept tool execution via handler callback."""
...
# Sync/async signatures for `wrap_tool_call` interception; see `@wrap_tool_call`.
_SyncCallableReturningToolResponse = Callable[
[ToolCallRequest, Callable[[ToolCallRequest], ToolMessage | Command[Any]]],
ToolMessage | Command[Any],
]
_AsyncCallableReturningToolResponse = Callable[
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]]],
Awaitable[ToolMessage | Command[Any]],
]
_CallableReturningToolResponse = (
_SyncCallableReturningToolResponse | _AsyncCallableReturningToolResponse
)
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
@@ -1966,8 +1928,10 @@ def wrap_model_call(
_self: AgentMiddleware[StateT, ContextT],
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> ModelResponse[ResponseT] | AIMessage:
return await func(request, handler)
) -> ModelCallResult[ResponseT]:
return await cast(
"_AsyncCallableReturningModelResponse[ContextT, ResponseT]", func
)(request, handler)
middleware_name = name or cast(
"str", getattr(func, "__name__", "WrapModelCallMiddleware")
@@ -1992,8 +1956,10 @@ def wrap_model_call(
_self: AgentMiddleware[StateT, ContextT],
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> ModelResponse[ResponseT] | AIMessage:
return func(request, handler)
) -> ModelCallResult[ResponseT]:
return cast("_SyncCallableReturningModelResponse[ContextT, ResponseT]", func)(
request, handler
)
middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))
@@ -2137,7 +2103,7 @@ def wrap_tool_call(
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
return await func(request, handler)
return await cast("_AsyncCallableReturningToolResponse", func)(request, handler)
middleware_name = name or cast(
"str", getattr(func, "__name__", "WrapToolCallMiddleware")
@@ -2163,7 +2129,7 @@ def wrap_tool_call(
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
) -> ToolMessage | Command[Any]:
return func(request, handler)
return cast("_SyncCallableReturningToolResponse", func)(request, handler)
middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))