chore(langchain): activate mypy warn_return_any (#34249)

Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
Christophe Bornet
2026-06-10 19:53:06 +02:00
committed by GitHub
parent 2b4735712c
commit 23ce677870
6 changed files with 335 additions and 191 deletions

View File

@@ -454,7 +454,9 @@ def _resolve_schema(
if not should_omit:
all_annotations[field_name] = field_type
return TypedDict(schema_name, all_annotations) # type: ignore[operator]
# `TypedDict` dynamically creates a class, but type checkers don't infer that
# the runtime result satisfies this function's `type` return contract.
return cast("type", TypedDict(schema_name, all_annotations)) # type: ignore[operator]
def _extract_metadata(type_: type) -> list[Any]:
@@ -493,7 +495,8 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
and sync_method is not base_sync_method
and hasattr(sync_method, "__can_jump_to__")
):
return sync_method.__can_jump_to__
# `hasattr` proves the metadata exists at runtime, but not its value type.
return cast("list[JumpTo]", sync_method.__can_jump_to__)
# Try async method - only if it's overridden from base class
async_method = getattr(middleware.__class__, f"a{hook_name}", None)
@@ -502,7 +505,8 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
and async_method is not base_async_method
and hasattr(async_method, "__can_jump_to__")
):
return async_method.__can_jump_to__
# `hasattr` proves the metadata exists at runtime, but not its value type.
return cast("list[JumpTo]", async_method.__can_jump_to__)
return []

View File

@@ -764,7 +764,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState[ResponseT], ContextT, R
payload: dict[str, Any],
*,
tool_call_id: str | None,
) -> Any:
) -> ToolMessage | str:
session = resources.session
if payload.get("restart"):

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import warnings
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass, field, replace
from inspect import iscoroutinefunction
@@ -12,17 +13,12 @@ from typing import (
Generic,
Literal,
Protocol,
TypeAlias,
cast,
overload,
)
if TYPE_CHECKING:
from collections.abc import Awaitable
# Needed as top level import for Pydantic schema generation on AgentState
import warnings
from typing import TypeAlias
from langchain_core.messages import (
AIMessage,
AnyMessage,
@@ -33,15 +29,15 @@ from langchain_core.messages import (
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.graph.message import add_messages
from langgraph.prebuilt.tool_node import ToolCallRequest, ToolCallWrapper
from langgraph.runtime import Runtime
from langgraph.types import Command
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
if TYPE_CHECKING:
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.runtime import Runtime
from langgraph.stream._mux import TransformerFactory
from langgraph.types import Command
from langchain.agents.structured_output import ResponseFormat
@@ -811,38 +807,54 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]):
raise NotImplementedError(msg)
class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
"""Callable with `AgentState` and `Runtime` as arguments."""
_SyncCallableWithStateAndRuntime = Callable[
[StateT_contra, Runtime[ContextT]], dict[str, Any] | Command[Any] | None
]
_AsyncCallableWithStateAndRuntime = Callable[
[StateT_contra, Runtime[ContextT]], Awaitable[dict[str, Any] | Command[Any] | None]
]
_CallableWithStateAndRuntime = (
_SyncCallableWithStateAndRuntime[StateT_contra, ContextT]
| _AsyncCallableWithStateAndRuntime[StateT_contra, ContextT]
)
def __call__(
self, state: StateT_contra, runtime: Runtime[ContextT]
) -> dict[str, Any] | Command[Any] | None | Awaitable[dict[str, Any] | Command[Any] | None]:
"""Perform some logic with the state and runtime."""
...
_SyncCallableReturningSystemMessage = Callable[[ModelRequest[ContextT]], str | SystemMessage]
_AsyncCallableReturningSystemMessage = Callable[
[ModelRequest[ContextT]], Awaitable[str | SystemMessage]
]
_CallableReturningSystemMessage = (
_SyncCallableReturningSystemMessage[ContextT] | _AsyncCallableReturningSystemMessage[ContextT]
)
class _CallableReturningSystemMessage(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
"""Callable that returns a prompt string or SystemMessage given `ModelRequest`."""
def __call__(
self, request: ModelRequest[ContextT]
) -> str | SystemMessage | Awaitable[str | SystemMessage]:
"""Generate a system prompt string or SystemMessage based on the request."""
...
class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT, ResponseT]): # type: ignore[misc]
class _CallableReturningModelResponse(Protocol[ContextT, ResponseT]):
"""Callable for model call interception with handler callback.
Receives handler callback to execute model and returns `ModelResponse` or
`AIMessage`.
"""
@overload
def __call__(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> ModelResponse[ResponseT] | AIMessage:
) -> ModelResponse[ResponseT] | AIMessage: ...
@overload
def __call__(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> Awaitable[ModelResponse[ResponseT] | AIMessage]: ...
def __call__(
self,
request: ModelRequest[ContextT],
handler: Callable[
[ModelRequest[ContextT]], ModelResponse[ResponseT] | Awaitable[ModelResponse[ResponseT]]
],
) -> ModelResponse[ResponseT] | AIMessage | Awaitable[ModelResponse[ResponseT] | AIMessage]:
"""Intercept model execution via handler callback."""
...
@@ -854,11 +866,27 @@ class _CallableReturningToolResponse(Protocol):
`Command`.
"""
@overload
def __call__(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
) -> ToolMessage | Command[Any]:
) -> ToolMessage | Command[Any]: ...
@overload
def __call__(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> Awaitable[ToolMessage | Command[Any]]: ...
def __call__(
self,
request: ToolCallRequest,
handler: Callable[
[ToolCallRequest], ToolMessage | Command[Any] | Awaitable[ToolMessage | Command[Any]]
],
) -> ToolMessage | Command[Any] | Awaitable[ToolMessage | Command[Any]]:
"""Intercept tool execution via handler callback."""
...
@@ -1041,7 +1069,11 @@ def before_model(
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command[Any] | None:
return await func(state, runtime) # type: ignore[misc]
# `iscoroutinefunction` narrows `func` at runtime, but type checkers
# cannot narrow this sync-or-async callable union.
return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)(
state, runtime
)
# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
@@ -1051,22 +1083,29 @@ def before_model(
"str", getattr(func, "__name__", "BeforeModelMiddleware")
)
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"abefore_model": async_wrapped,
},
)()
# `type(...)` builds the correct middleware subclass at runtime, but
# type checkers cannot infer its generic `AgentMiddleware` parameters.
return cast(
"AgentMiddleware[StateT, ContextT]",
type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"abefore_model": async_wrapped,
},
)(),
)
def wrapped(
_self: AgentMiddleware[StateT, ContextT],
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command[Any] | None:
return func(state, runtime) # type: ignore[return-value]
# `iscoroutinefunction` narrows `func` at runtime, but type checkers
# cannot narrow this sync-or-async callable union.
return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime)
# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
@@ -1075,15 +1114,20 @@ def before_model(
# Use function name as default if no name provided
middleware_name = name or cast("str", getattr(func, "__name__", "BeforeModelMiddleware"))
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"before_model": wrapped,
},
)()
# `type(...)` builds the correct middleware subclass at runtime, but
# type checkers cannot infer its generic `AgentMiddleware` parameters.
return cast(
"AgentMiddleware[StateT, ContextT]",
type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"before_model": wrapped,
},
)(),
)
if func is not None:
return decorator(func)
@@ -1201,7 +1245,11 @@ def after_model(
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command[Any] | None:
return await func(state, runtime) # type: ignore[misc]
# `iscoroutinefunction` narrows `func` at runtime, but type checkers
# cannot narrow this sync-or-async callable union.
return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)(
state, runtime
)
# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
@@ -1209,22 +1257,29 @@ def after_model(
middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"aafter_model": async_wrapped,
},
)()
# `type(...)` builds the correct middleware subclass at runtime, but
# type checkers cannot infer its generic `AgentMiddleware` parameters.
return cast(
"AgentMiddleware[StateT, ContextT]",
type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"aafter_model": async_wrapped,
},
)(),
)
def wrapped(
_self: AgentMiddleware[StateT, ContextT],
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command[Any] | None:
return func(state, runtime) # type: ignore[return-value]
# `iscoroutinefunction` narrows `func` at runtime, but type checkers
# cannot narrow this sync-or-async callable union.
return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime)
# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
@@ -1233,15 +1288,20 @@ def after_model(
# Use function name as default if no name provided
middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"after_model": wrapped,
},
)()
# `type(...)` builds the correct middleware subclass at runtime, but
# type checkers cannot infer its generic `AgentMiddleware` parameters.
return cast(
"AgentMiddleware[StateT, ContextT]",
type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"after_model": wrapped,
},
)(),
)
if func is not None:
return decorator(func)
@@ -1392,7 +1452,11 @@ def before_agent(
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command[Any] | None:
return await func(state, runtime) # type: ignore[misc]
# `iscoroutinefunction` narrows `func` at runtime, but type checkers
# cannot narrow this sync-or-async callable union.
return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)(
state, runtime
)
# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
@@ -1402,22 +1466,29 @@ def before_agent(
"str", getattr(func, "__name__", "BeforeAgentMiddleware")
)
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"abefore_agent": async_wrapped,
},
)()
# `type(...)` builds the correct middleware subclass at runtime, but
# type checkers cannot infer its generic `AgentMiddleware` parameters.
return cast(
"AgentMiddleware[StateT, ContextT]",
type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"abefore_agent": async_wrapped,
},
)(),
)
def wrapped(
_self: AgentMiddleware[StateT, ContextT],
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command[Any] | None:
return func(state, runtime) # type: ignore[return-value]
# `iscoroutinefunction` narrows `func` at runtime, but type checkers
# cannot narrow this sync-or-async callable union.
return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime)
# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
@@ -1426,15 +1497,20 @@ def before_agent(
# Use function name as default if no name provided
middleware_name = name or cast("str", getattr(func, "__name__", "BeforeAgentMiddleware"))
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"before_agent": wrapped,
},
)()
# `type(...)` builds the correct middleware subclass at runtime, but
# type checkers cannot infer its generic `AgentMiddleware` parameters.
return cast(
"AgentMiddleware[StateT, ContextT]",
type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"before_agent": wrapped,
},
)(),
)
if func is not None:
return decorator(func)
@@ -1553,7 +1629,11 @@ def after_agent(
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command[Any] | None:
return await func(state, runtime) # type: ignore[misc]
# `iscoroutinefunction` narrows `func` at runtime, but type checkers
# cannot narrow this sync-or-async callable union.
return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)(
state, runtime
)
# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
@@ -1561,22 +1641,29 @@ def after_agent(
middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware"))
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"aafter_agent": async_wrapped,
},
)()
# `type(...)` builds the correct middleware subclass at runtime, but
# type checkers cannot infer its generic `AgentMiddleware` parameters.
return cast(
"AgentMiddleware[StateT, ContextT]",
type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"aafter_agent": async_wrapped,
},
)(),
)
def wrapped(
_self: AgentMiddleware[StateT, ContextT],
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command[Any] | None:
return func(state, runtime) # type: ignore[return-value]
# `iscoroutinefunction` narrows `func` at runtime, but type checkers
# cannot narrow this sync-or-async callable union.
return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime)
# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
@@ -1585,15 +1672,20 @@ def after_agent(
# Use function name as default if no name provided
middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware"))
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"after_agent": wrapped,
},
)()
# `type(...)` builds the correct middleware subclass at runtime, but
# type checkers cannot infer its generic `AgentMiddleware` parameters.
return cast(
"AgentMiddleware[StateT, ContextT]",
type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"after_agent": wrapped,
},
)(),
)
if func is not None:
return decorator(func)
@@ -1602,7 +1694,7 @@ def after_agent(
@overload
def dynamic_prompt(
func: _CallableReturningSystemMessage[StateT, ContextT],
func: _CallableReturningSystemMessage[ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...
@@ -1610,16 +1702,16 @@ def dynamic_prompt(
def dynamic_prompt(
func: None = None,
) -> Callable[
[_CallableReturningSystemMessage[StateT, ContextT]],
[_CallableReturningSystemMessage[ContextT]],
AgentMiddleware[StateT, ContextT],
]: ...
def dynamic_prompt(
func: _CallableReturningSystemMessage[StateT, ContextT] | None = None,
func: _CallableReturningSystemMessage[ContextT] | None = None,
) -> (
Callable[
[_CallableReturningSystemMessage[StateT, ContextT]],
[_CallableReturningSystemMessage[ContextT]],
AgentMiddleware[StateT, ContextT],
]
| AgentMiddleware[StateT, ContextT]
@@ -1673,7 +1765,7 @@ def dynamic_prompt(
"""
def decorator(
func: _CallableReturningSystemMessage[StateT, ContextT],
func: _CallableReturningSystemMessage[ContextT],
) -> AgentMiddleware[StateT, ContextT]:
is_async = iscoroutinefunction(func)
@@ -1684,7 +1776,7 @@ def dynamic_prompt(
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[Any]]],
) -> ModelResponse[Any] | AIMessage:
prompt = await func(request) # type: ignore[misc]
prompt = await cast("_AsyncCallableReturningSystemMessage[ContextT]", func)(request)
if isinstance(prompt, SystemMessage):
request = request.override(system_message=prompt)
else:
@@ -1693,22 +1785,27 @@ def dynamic_prompt(
middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": AgentState,
"tools": [],
"awrap_model_call": async_wrapped,
},
)()
# `type(...)` builds the correct middleware subclass at runtime, but
# type checkers cannot infer its generic `AgentMiddleware` parameters.
return cast(
"AgentMiddleware[StateT, ContextT]",
type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": AgentState,
"tools": [],
"awrap_model_call": async_wrapped,
},
)(),
)
def wrapped(
_self: AgentMiddleware[StateT, ContextT],
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[Any]],
) -> ModelResponse[Any] | AIMessage:
prompt = cast("Callable[[ModelRequest[ContextT]], SystemMessage | str]", func)(request)
prompt = cast("_SyncCallableReturningSystemMessage[ContextT]", func)(request)
if isinstance(prompt, SystemMessage):
request = request.override(system_message=prompt)
else:
@@ -1721,7 +1818,7 @@ def dynamic_prompt(
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[Any]]],
) -> ModelResponse[Any] | AIMessage:
# Delegate to sync function
prompt = cast("Callable[[ModelRequest[ContextT]], SystemMessage | str]", func)(request)
prompt = cast("_SyncCallableReturningSystemMessage[ContextT]", func)(request)
if isinstance(prompt, SystemMessage):
request = request.override(system_message=prompt)
else:
@@ -1730,16 +1827,21 @@ def dynamic_prompt(
middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": AgentState,
"tools": [],
"wrap_model_call": wrapped,
"awrap_model_call": async_wrapped_from_sync,
},
)()
# `type(...)` builds the correct middleware subclass at runtime, but
# type checkers cannot infer its generic `AgentMiddleware` parameters.
return cast(
"AgentMiddleware[StateT, ContextT]",
type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": AgentState,
"tools": [],
"wrap_model_call": wrapped,
"awrap_model_call": async_wrapped_from_sync,
},
)(),
)
if func is not None:
return decorator(func)
@@ -1748,7 +1850,7 @@ def dynamic_prompt(
@overload
def wrap_model_call(
func: _CallableReturningModelResponse[StateT, ContextT, ResponseT],
func: _CallableReturningModelResponse[ContextT, ResponseT],
) -> AgentMiddleware[StateT, ContextT]: ...
@@ -1760,20 +1862,20 @@ def wrap_model_call(
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> Callable[
[_CallableReturningModelResponse[StateT, ContextT, ResponseT]],
[_CallableReturningModelResponse[ContextT, ResponseT]],
AgentMiddleware[StateT, ContextT],
]: ...
def wrap_model_call(
func: _CallableReturningModelResponse[StateT, ContextT, ResponseT] | None = None,
func: _CallableReturningModelResponse[ContextT, ResponseT] | None = None,
*,
state_schema: type[StateT] | None = None,
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> (
Callable[
[_CallableReturningModelResponse[StateT, ContextT, ResponseT]],
[_CallableReturningModelResponse[ContextT, ResponseT]],
AgentMiddleware[StateT, ContextT],
]
| AgentMiddleware[StateT, ContextT]
@@ -1854,7 +1956,7 @@ def wrap_model_call(
"""
def decorator(
func: _CallableReturningModelResponse[StateT, ContextT, ResponseT],
func: _CallableReturningModelResponse[ContextT, ResponseT],
) -> AgentMiddleware[StateT, ContextT]:
is_async = iscoroutinefunction(func)
@@ -1865,21 +1967,26 @@ def wrap_model_call(
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> ModelResponse[ResponseT] | AIMessage:
return await func(request, handler) # type: ignore[misc, arg-type]
return await func(request, handler)
middleware_name = name or cast(
"str", getattr(func, "__name__", "WrapModelCallMiddleware")
)
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"awrap_model_call": async_wrapped,
},
)()
# `type(...)` builds the correct middleware subclass at runtime, but
# type checkers cannot infer its generic `AgentMiddleware` parameters.
return cast(
"AgentMiddleware[StateT, ContextT]",
type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"awrap_model_call": async_wrapped,
},
)(),
)
def wrapped(
_self: AgentMiddleware[StateT, ContextT],
@@ -1890,15 +1997,20 @@ def wrap_model_call(
middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"wrap_model_call": wrapped,
},
)()
# `type(...)` builds the correct middleware subclass at runtime, but
# type checkers cannot infer its generic `AgentMiddleware` parameters.
return cast(
"AgentMiddleware[StateT, ContextT]",
type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"wrap_model_call": wrapped,
},
)(),
)
if func is not None:
return decorator(func)
@@ -2025,21 +2137,26 @@ def wrap_tool_call(
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
return await func(request, handler) # type: ignore[arg-type,misc]
return await func(request, handler)
middleware_name = name or cast(
"str", getattr(func, "__name__", "WrapToolCallMiddleware")
)
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": AgentState,
"tools": tools or [],
"awrap_tool_call": async_wrapped,
},
)()
# `type(...)` builds the correct middleware subclass at runtime, but
# type checkers cannot infer its generic `AgentMiddleware` parameters.
return cast(
"AgentMiddleware",
type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": AgentState,
"tools": tools or [],
"awrap_tool_call": async_wrapped,
},
)(),
)
def wrapped(
_self: AgentMiddleware,
@@ -2050,15 +2167,20 @@ def wrap_tool_call(
middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": AgentState,
"tools": tools or [],
"wrap_tool_call": wrapped,
},
)()
# `type(...)` builds the correct middleware subclass at runtime, but
# type checkers cannot infer its generic `AgentMiddleware` parameters.
return cast(
"AgentMiddleware",
type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": AgentState,
"tools": tools or [],
"wrap_tool_call": wrapped,
},
)(),
)
if func is not None:
return decorator(func)

View File

@@ -76,7 +76,7 @@ class StructuredOutputValidationError(StructuredOutputError):
def _parse_with_schema(
schema: type[SchemaT] | dict[str, Any], schema_kind: SchemaKind, data: dict[str, Any]
) -> Any:
) -> SchemaT | dict[str, Any]:
"""Parse data using for any supported schema type.
Args:
@@ -92,9 +92,10 @@ def _parse_with_schema(
ValueError: If parsing fails
"""
if schema_kind == "json_schema":
# Raw JSON schema has no corresponding Python type to instantiate.
return data
try:
adapter: TypeAdapter[SchemaT] = TypeAdapter(schema)
adapter = TypeAdapter[SchemaT](schema)
return adapter.validate_python(data)
except Exception as e:
schema_name = getattr(schema, "__name__", str(schema))
@@ -344,7 +345,7 @@ class OutputToolBinding(Generic[SchemaT]):
),
)
def parse(self, tool_args: dict[str, Any]) -> SchemaT:
def parse(self, tool_args: dict[str, Any]) -> SchemaT | dict[str, Any]:
"""Parse tool arguments according to the schema.
Args:
@@ -391,7 +392,7 @@ class ProviderStrategyBinding(Generic[SchemaT]):
schema_kind=schema_spec.schema_kind,
)
def parse(self, response: AIMessage) -> SchemaT:
def parse(self, response: AIMessage) -> SchemaT | dict[str, Any]:
"""Parse `AIMessage` content according to the schema.
Args:

View File

@@ -1017,17 +1017,37 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
yield x
# Explicitly added to satisfy downstream linters.
# `bind_tools` is implemented by concrete models because tool binding is
# provider-specific. A configurable model may not have a concrete model instance
# yet, since invocation config can choose it later. Save the `bind_tools` tools
# and kwargs now. When `_model` later builds the selected provider model, it calls
# `selected_model.bind_tools(tools, **kwargs)` and returns that runnable.
# Cast so callers still get the public return type.
def bind_tools(
self,
tools: Sequence[dict[str, Any] | type[BaseModel] | Callable[..., Any] | BaseTool],
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
return self.__getattr__("bind_tools")(tools, **kwargs)
return cast(
"Runnable[LanguageModelInput, AIMessage]",
self.__getattr__("bind_tools")(tools, **kwargs),
)
# Explicitly added to satisfy downstream linters.
# `with_structured_output` is implemented by concrete models because structured
# output support is provider-specific. A configurable model may not have a
# concrete model instance yet, since invocation config can choose it later. Save
# the structured-output schema and kwargs now. When `_model` later builds the
# selected provider model, it calls
# `selected_model.with_structured_output(schema, **kwargs)` and returns that
# runnable.
# Cast so callers still get the public return type.
def with_structured_output(
self,
schema: dict[str, Any] | type[BaseModel],
**kwargs: Any,
) -> Runnable[LanguageModelInput, dict[str, Any] | BaseModel]:
return self.__getattr__("with_structured_output")(schema, **kwargs)
return cast(
"Runnable[LanguageModelInput, dict[str, Any] | BaseModel]",
self.__getattr__("with_structured_output")(schema, **kwargs),
)

View File

@@ -116,9 +116,6 @@ exclude = [
"tests/unit_tests/agents/test_.*\\.py",
]
# TODO: activate for 'strict' checking
warn_return_any = false
[[tool.mypy.overrides]]
module = ["pytest_socket.*", "vcr.*"]
ignore_missing_imports = true