diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index 1c0645ad276..58ebcac29b1 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -12,8 +12,6 @@ from typing import ( Any, Generic, Literal, - Protocol, - TypeAlias, cast, overload, ) @@ -310,10 +308,8 @@ class ExtendedModelResponse(Generic[ResponseT]): """Optional command to apply as an additional state update.""" -ModelCallResult: TypeAlias = ( - "ModelResponse[ResponseT] | AIMessage | ExtendedModelResponse[ResponseT]" -) -"""`TypeAlias` for model call handler return value. +ModelCallResult = ModelResponse[ResponseT] | AIMessage | ExtendedModelResponse[ResponseT] +"""Return type for model call handlers. Middleware can return either: @@ -826,69 +822,35 @@ _CallableReturningSystemMessage = ( _SyncCallableReturningSystemMessage[ContextT] | _AsyncCallableReturningSystemMessage[ContextT] ) +# Sync/async signatures for `wrap_model_call` interception; see `@wrap_model_call`. +_SyncCallableReturningModelResponse = Callable[ + [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]]], + ModelCallResult, +] +_AsyncCallableReturningModelResponse = Callable[ + [ + ModelRequest[ContextT], + Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], + ], + Awaitable[ModelCallResult], +] +_CallableReturningModelResponse = ( + _SyncCallableReturningModelResponse[ContextT, ResponseT] + | _AsyncCallableReturningModelResponse[ContextT, ResponseT] +) -class _CallableReturningModelResponse(Protocol[ContextT, ResponseT]): - """Callable for model call interception with handler callback. - - Receives handler callback to execute model and returns `ModelResponse` or - `AIMessage`. - """ - - @overload - def __call__( - self, - request: ModelRequest[ContextT], - handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], - ) -> ModelResponse[ResponseT] | AIMessage: ... - - @overload - def __call__( - self, - request: ModelRequest[ContextT], - handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], - ) -> Awaitable[ModelResponse[ResponseT] | AIMessage]: ... - - def __call__( - self, - request: ModelRequest[ContextT], - handler: Callable[ - [ModelRequest[ContextT]], ModelResponse[ResponseT] | Awaitable[ModelResponse[ResponseT]] - ], - ) -> ModelResponse[ResponseT] | AIMessage | Awaitable[ModelResponse[ResponseT] | AIMessage]: - """Intercept model execution via handler callback.""" - ... - - -class _CallableReturningToolResponse(Protocol): - """Callable for tool call interception with handler callback. - - Receives handler callback to execute tool and returns final `ToolMessage` or - `Command`. - """ - - @overload - def __call__( - self, - request: ToolCallRequest, - handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]], - ) -> ToolMessage | Command[Any]: ... - - @overload - def __call__( - self, - request: ToolCallRequest, - handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], - ) -> Awaitable[ToolMessage | Command[Any]]: ... - - def __call__( - self, - request: ToolCallRequest, - handler: Callable[ - [ToolCallRequest], ToolMessage | Command[Any] | Awaitable[ToolMessage | Command[Any]] - ], - ) -> ToolMessage | Command[Any] | Awaitable[ToolMessage | Command[Any]]: - """Intercept tool execution via handler callback.""" - ... +# Sync/async signatures for `wrap_tool_call` interception; see `@wrap_tool_call`. +_SyncCallableReturningToolResponse = Callable[ + [ToolCallRequest, Callable[[ToolCallRequest], ToolMessage | Command[Any]]], + ToolMessage | Command[Any], +] +_AsyncCallableReturningToolResponse = Callable[ + [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]]], + Awaitable[ToolMessage | Command[Any]], +] +_CallableReturningToolResponse = ( + _SyncCallableReturningToolResponse | _AsyncCallableReturningToolResponse +) CallableT = TypeVar("CallableT", bound=Callable[..., Any]) @@ -1966,8 +1928,10 @@ def wrap_model_call( _self: AgentMiddleware[StateT, ContextT], request: ModelRequest[ContextT], handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], - ) -> ModelResponse[ResponseT] | AIMessage: - return await func(request, handler) + ) -> ModelCallResult[ResponseT]: + return await cast( + "_AsyncCallableReturningModelResponse[ContextT, ResponseT]", func + )(request, handler) middleware_name = name or cast( "str", getattr(func, "__name__", "WrapModelCallMiddleware") @@ -1992,8 +1956,10 @@ def wrap_model_call( _self: AgentMiddleware[StateT, ContextT], request: ModelRequest[ContextT], handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], - ) -> ModelResponse[ResponseT] | AIMessage: - return func(request, handler) + ) -> ModelCallResult[ResponseT]: + return cast("_SyncCallableReturningModelResponse[ContextT, ResponseT]", func)( + request, handler + ) middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware")) @@ -2137,7 +2103,7 @@ def wrap_tool_call( request: ToolCallRequest, handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], ) -> ToolMessage | Command[Any]: - return await func(request, handler) + return await cast("_AsyncCallableReturningToolResponse", func)(request, handler) middleware_name = name or cast( "str", getattr(func, "__name__", "WrapToolCallMiddleware") @@ -2163,7 +2129,7 @@ def wrap_tool_call( request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]], ) -> ToolMessage | Command[Any]: - return func(request, handler) + return cast("_SyncCallableReturningToolResponse", func)(request, handler) middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))