Compare commits

...

1 Commits

Author SHA1 Message Date
Sydney Runkle
a53ed2ba5d typing fixes 2025-12-17 16:09:49 -05:00

View File

@@ -709,7 +709,7 @@ class _CallableReturningSystemMessage(Protocol[StateT_contra, ContextT]): # typ
class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
"""Callable for model call interception with handler callback.
"""Callable for sync model call interception with handler callback.
Receives handler callback to execute model and returns `ModelResponse` or
`AIMessage`.
@@ -724,8 +724,24 @@ class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # typ
...
class _AsyncCallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
"""Callable for async model call interception with handler callback.
Receives async handler callback to execute model and returns `ModelResponse` or
`AIMessage`.
"""
def __call__(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> Awaitable[ModelCallResult]:
"""Intercept async model execution via handler callback."""
...
class _CallableReturningToolResponse(Protocol):
"""Callable for tool call interception with handler callback.
"""Callable for sync tool call interception with handler callback.
Receives handler callback to execute tool and returns final `ToolMessage` or
`Command`.
@@ -740,6 +756,22 @@ class _CallableReturningToolResponse(Protocol):
...
class _AsyncCallableReturningToolResponse(Protocol):
"""Callable for async tool call interception with handler callback.
Receives async handler callback to execute tool and returns final `ToolMessage` or
`Command`.
"""
def __call__(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> Awaitable[ToolMessage | Command]:
"""Intercept async tool execution via handler callback."""
...
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
@@ -1535,6 +1567,12 @@ def wrap_model_call(
) -> AgentMiddleware[StateT, ContextT]: ...
@overload
def wrap_model_call(
func: _AsyncCallableReturningModelResponse[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...
@overload
def wrap_model_call(
func: None = None,
@@ -1543,20 +1581,28 @@ def wrap_model_call(
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> Callable[
[_CallableReturningModelResponse[StateT, ContextT]],
[
_CallableReturningModelResponse[StateT, ContextT]
| _AsyncCallableReturningModelResponse[StateT, ContextT]
],
AgentMiddleware[StateT, ContextT],
]: ...
def wrap_model_call(
func: _CallableReturningModelResponse[StateT, ContextT] | None = None,
func: _CallableReturningModelResponse[StateT, ContextT]
| _AsyncCallableReturningModelResponse[StateT, ContextT]
| None = None,
*,
state_schema: type[StateT] | None = None,
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> (
Callable[
[_CallableReturningModelResponse[StateT, ContextT]],
[
_CallableReturningModelResponse[StateT, ContextT]
| _AsyncCallableReturningModelResponse[StateT, ContextT]
],
AgentMiddleware[StateT, ContextT],
]
| AgentMiddleware[StateT, ContextT]
@@ -1637,18 +1683,22 @@ def wrap_model_call(
"""
def decorator(
func: _CallableReturningModelResponse[StateT, ContextT],
func: _CallableReturningModelResponse[StateT, ContextT]
| _AsyncCallableReturningModelResponse[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]:
is_async = iscoroutinefunction(func)
if is_async:
async_func = cast(
"_AsyncCallableReturningModelResponse[StateT, ContextT]", func
)
async def async_wrapped(
_self: AgentMiddleware[StateT, ContextT],
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
return await func(request, handler) # type: ignore[misc, arg-type]
return await async_func(request, handler)
middleware_name = name or cast(
"str", getattr(func, "__name__", "WrapModelCallMiddleware")
@@ -1664,12 +1714,14 @@ def wrap_model_call(
},
)()
sync_func = cast("_CallableReturningModelResponse[StateT, ContextT]", func)
def wrapped(
_self: AgentMiddleware[StateT, ContextT],
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return func(request, handler)
return sync_func(request, handler)
middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))
@@ -1694,6 +1746,12 @@ def wrap_tool_call(
) -> AgentMiddleware: ...
@overload
def wrap_tool_call(
func: _AsyncCallableReturningToolResponse,
) -> AgentMiddleware: ...
@overload
def wrap_tool_call(
func: None = None,
@@ -1701,19 +1759,19 @@ def wrap_tool_call(
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> Callable[
[_CallableReturningToolResponse],
[_CallableReturningToolResponse | _AsyncCallableReturningToolResponse],
AgentMiddleware,
]: ...
def wrap_tool_call(
func: _CallableReturningToolResponse | None = None,
func: _CallableReturningToolResponse | _AsyncCallableReturningToolResponse | None = None,
*,
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> (
Callable[
[_CallableReturningToolResponse],
[_CallableReturningToolResponse | _AsyncCallableReturningToolResponse],
AgentMiddleware,
]
| AgentMiddleware
@@ -1797,18 +1855,19 @@ def wrap_tool_call(
"""
def decorator(
func: _CallableReturningToolResponse,
func: _CallableReturningToolResponse | _AsyncCallableReturningToolResponse,
) -> AgentMiddleware:
is_async = iscoroutinefunction(func)
if is_async:
async_func = cast("_AsyncCallableReturningToolResponse", func)
async def async_wrapped(
_self: AgentMiddleware,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
return await func(request, handler) # type: ignore[arg-type,misc]
return await async_func(request, handler)
middleware_name = name or cast(
"str", getattr(func, "__name__", "WrapToolCallMiddleware")
@@ -1824,12 +1883,14 @@ def wrap_tool_call(
},
)()
sync_func = cast("_CallableReturningToolResponse", func)
def wrapped(
_self: AgentMiddleware,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
return func(request, handler)
return sync_func(request, handler)
middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))