mirror of
https://github.com/hwchase17/langchain.git
synced 2026-07-01 22:59:06 +00:00
fix(langchain): support async middleware decorator typing (#34584)
Fixes #35244 Users can write async agent middleware with `@wrap_model_call`, and LangChain already supports that behavior at runtime by detecting coroutine functions and wiring them to `awrap_model_call`. However, the decorator's public typing currently describes only the sync callable shape. As a result, valid async middleware is rejected by type checkers such as mypy and ty, even though the same code runs correctly. This updates the middleware decorator types so async `wrap_model_call` and `wrap_tool_call` functions type-check consistently with their runtime behavior. It also simplifies related callable aliases and uses casts where `iscoroutinefunction` narrows the callable at runtime but static type checkers cannot follow that narrowing. --------- Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
committed by
GitHub
parent
c16499c339
commit
8fc58c6013
@@ -12,8 +12,6 @@ from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
Protocol,
|
||||
TypeAlias,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
@@ -310,10 +308,8 @@ class ExtendedModelResponse(Generic[ResponseT]):
|
||||
"""Optional command to apply as an additional state update."""
|
||||
|
||||
|
||||
ModelCallResult: TypeAlias = (
|
||||
"ModelResponse[ResponseT] | AIMessage | ExtendedModelResponse[ResponseT]"
|
||||
)
|
||||
"""`TypeAlias` for model call handler return value.
|
||||
ModelCallResult = ModelResponse[ResponseT] | AIMessage | ExtendedModelResponse[ResponseT]
|
||||
"""Return type for model call handlers.
|
||||
|
||||
Middleware can return either:
|
||||
|
||||
@@ -826,69 +822,35 @@ _CallableReturningSystemMessage = (
|
||||
_SyncCallableReturningSystemMessage[ContextT] | _AsyncCallableReturningSystemMessage[ContextT]
|
||||
)
|
||||
|
||||
# Sync/async signatures for `wrap_model_call` interception; see `@wrap_model_call`.
|
||||
_SyncCallableReturningModelResponse = Callable[
|
||||
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]]],
|
||||
ModelCallResult,
|
||||
]
|
||||
_AsyncCallableReturningModelResponse = Callable[
|
||||
[
|
||||
ModelRequest[ContextT],
|
||||
Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
|
||||
],
|
||||
Awaitable[ModelCallResult],
|
||||
]
|
||||
_CallableReturningModelResponse = (
|
||||
_SyncCallableReturningModelResponse[ContextT, ResponseT]
|
||||
| _AsyncCallableReturningModelResponse[ContextT, ResponseT]
|
||||
)
|
||||
|
||||
class _CallableReturningModelResponse(Protocol[ContextT, ResponseT]):
|
||||
"""Callable for model call interception with handler callback.
|
||||
|
||||
Receives handler callback to execute model and returns `ModelResponse` or
|
||||
`AIMessage`.
|
||||
"""
|
||||
|
||||
@overload
|
||||
def __call__(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
|
||||
) -> ModelResponse[ResponseT] | AIMessage: ...
|
||||
|
||||
@overload
|
||||
def __call__(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
|
||||
) -> Awaitable[ModelResponse[ResponseT] | AIMessage]: ...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], ModelResponse[ResponseT] | Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT] | AIMessage | Awaitable[ModelResponse[ResponseT] | AIMessage]:
|
||||
"""Intercept model execution via handler callback."""
|
||||
...
|
||||
|
||||
|
||||
class _CallableReturningToolResponse(Protocol):
|
||||
"""Callable for tool call interception with handler callback.
|
||||
|
||||
Receives handler callback to execute tool and returns final `ToolMessage` or
|
||||
`Command`.
|
||||
"""
|
||||
|
||||
@overload
|
||||
def __call__(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
||||
) -> ToolMessage | Command[Any]: ...
|
||||
|
||||
@overload
|
||||
def __call__(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> Awaitable[ToolMessage | Command[Any]]: ...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[
|
||||
[ToolCallRequest], ToolMessage | Command[Any] | Awaitable[ToolMessage | Command[Any]]
|
||||
],
|
||||
) -> ToolMessage | Command[Any] | Awaitable[ToolMessage | Command[Any]]:
|
||||
"""Intercept tool execution via handler callback."""
|
||||
...
|
||||
# Sync/async signatures for `wrap_tool_call` interception; see `@wrap_tool_call`.
|
||||
_SyncCallableReturningToolResponse = Callable[
|
||||
[ToolCallRequest, Callable[[ToolCallRequest], ToolMessage | Command[Any]]],
|
||||
ToolMessage | Command[Any],
|
||||
]
|
||||
_AsyncCallableReturningToolResponse = Callable[
|
||||
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]]],
|
||||
Awaitable[ToolMessage | Command[Any]],
|
||||
]
|
||||
_CallableReturningToolResponse = (
|
||||
_SyncCallableReturningToolResponse | _AsyncCallableReturningToolResponse
|
||||
)
|
||||
|
||||
|
||||
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
|
||||
@@ -1966,8 +1928,10 @@ def wrap_model_call(
|
||||
_self: AgentMiddleware[StateT, ContextT],
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
return await func(request, handler)
|
||||
) -> ModelCallResult[ResponseT]:
|
||||
return await cast(
|
||||
"_AsyncCallableReturningModelResponse[ContextT, ResponseT]", func
|
||||
)(request, handler)
|
||||
|
||||
middleware_name = name or cast(
|
||||
"str", getattr(func, "__name__", "WrapModelCallMiddleware")
|
||||
@@ -1992,8 +1956,10 @@ def wrap_model_call(
|
||||
_self: AgentMiddleware[StateT, ContextT],
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
return func(request, handler)
|
||||
) -> ModelCallResult[ResponseT]:
|
||||
return cast("_SyncCallableReturningModelResponse[ContextT, ResponseT]", func)(
|
||||
request, handler
|
||||
)
|
||||
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))
|
||||
|
||||
@@ -2137,7 +2103,7 @@ def wrap_tool_call(
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
return await func(request, handler)
|
||||
return await cast("_AsyncCallableReturningToolResponse", func)(request, handler)
|
||||
|
||||
middleware_name = name or cast(
|
||||
"str", getattr(func, "__name__", "WrapToolCallMiddleware")
|
||||
@@ -2163,7 +2129,7 @@ def wrap_tool_call(
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
return func(request, handler)
|
||||
return cast("_SyncCallableReturningToolResponse", func)(request, handler)
|
||||
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user