From 23ce67787002e826414f4fd9bc39ca21f8503a91 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Wed, 10 Jun 2026 19:53:06 +0200 Subject: [PATCH] chore(langchain): activate mypy `warn_return_any` (#34249) Co-authored-by: Mason Daugherty --- libs/langchain_v1/langchain/agents/factory.py | 10 +- .../langchain/agents/middleware/shell_tool.py | 2 +- .../langchain/agents/middleware/types.py | 478 +++++++++++------- .../langchain/agents/structured_output.py | 9 +- .../langchain/chat_models/base.py | 24 +- libs/langchain_v1/pyproject.toml | 3 - 6 files changed, 335 insertions(+), 191 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py index ffa23d910fa..1c07d275da4 100644 --- a/libs/langchain_v1/langchain/agents/factory.py +++ b/libs/langchain_v1/langchain/agents/factory.py @@ -454,7 +454,9 @@ def _resolve_schema( if not should_omit: all_annotations[field_name] = field_type - return TypedDict(schema_name, all_annotations) # type: ignore[operator] + # `TypedDict` dynamically creates a class, but type checkers don't infer that + # the runtime result satisfies this function's `type` return contract. + return cast("type", TypedDict(schema_name, all_annotations)) # type: ignore[operator] def _extract_metadata(type_: type) -> list[Any]: @@ -493,7 +495,8 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l and sync_method is not base_sync_method and hasattr(sync_method, "__can_jump_to__") ): - return sync_method.__can_jump_to__ + # `hasattr` proves the metadata exists at runtime, but not its value type. + return cast("list[JumpTo]", sync_method.__can_jump_to__) # Try async method - only if it's overridden from base class async_method = getattr(middleware.__class__, f"a{hook_name}", None) @@ -502,7 +505,8 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l and async_method is not base_async_method and hasattr(async_method, "__can_jump_to__") ): - return async_method.__can_jump_to__ + # `hasattr` proves the metadata exists at runtime, but not its value type. + return cast("list[JumpTo]", async_method.__can_jump_to__) return [] diff --git a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py index 93fd978ca4f..e2326640932 100644 --- a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py +++ b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py @@ -764,7 +764,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState[ResponseT], ContextT, R payload: dict[str, Any], *, tool_call_id: str | None, - ) -> Any: + ) -> ToolMessage | str: session = resources.session if payload.get("restart"): diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index 1570ebd9402..1c0645ad276 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -2,6 +2,7 @@ from __future__ import annotations +import warnings from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass, field, replace from inspect import iscoroutinefunction @@ -12,17 +13,12 @@ from typing import ( Generic, Literal, Protocol, + TypeAlias, cast, overload, ) -if TYPE_CHECKING: - from collections.abc import Awaitable - # Needed as top level import for Pydantic schema generation on AgentState -import warnings -from typing import TypeAlias - from langchain_core.messages import ( AIMessage, AnyMessage, @@ -33,15 +29,15 @@ from langchain_core.messages import ( from langgraph.channels.ephemeral_value import EphemeralValue from langgraph.graph.message import add_messages from langgraph.prebuilt.tool_node import ToolCallRequest, ToolCallWrapper +from langgraph.runtime import Runtime +from langgraph.types import Command from langgraph.typing import ContextT from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack if TYPE_CHECKING: from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.tools import BaseTool - from langgraph.runtime import Runtime from langgraph.stream._mux import TransformerFactory - from langgraph.types import Command from langchain.agents.structured_output import ResponseFormat @@ -811,38 +807,54 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]): raise NotImplementedError(msg) -class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]): - """Callable with `AgentState` and `Runtime` as arguments.""" +_SyncCallableWithStateAndRuntime = Callable[ + [StateT_contra, Runtime[ContextT]], dict[str, Any] | Command[Any] | None +] +_AsyncCallableWithStateAndRuntime = Callable[ + [StateT_contra, Runtime[ContextT]], Awaitable[dict[str, Any] | Command[Any] | None] +] +_CallableWithStateAndRuntime = ( + _SyncCallableWithStateAndRuntime[StateT_contra, ContextT] + | _AsyncCallableWithStateAndRuntime[StateT_contra, ContextT] +) - def __call__( - self, state: StateT_contra, runtime: Runtime[ContextT] - ) -> dict[str, Any] | Command[Any] | None | Awaitable[dict[str, Any] | Command[Any] | None]: - """Perform some logic with the state and runtime.""" - ... +_SyncCallableReturningSystemMessage = Callable[[ModelRequest[ContextT]], str | SystemMessage] +_AsyncCallableReturningSystemMessage = Callable[ + [ModelRequest[ContextT]], Awaitable[str | SystemMessage] +] +_CallableReturningSystemMessage = ( + _SyncCallableReturningSystemMessage[ContextT] | _AsyncCallableReturningSystemMessage[ContextT] +) -class _CallableReturningSystemMessage(Protocol[StateT_contra, ContextT]): # type: ignore[misc] - """Callable that returns a prompt string or SystemMessage given `ModelRequest`.""" - - def __call__( - self, request: ModelRequest[ContextT] - ) -> str | SystemMessage | Awaitable[str | SystemMessage]: - """Generate a system prompt string or SystemMessage based on the request.""" - ... - - -class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT, ResponseT]): # type: ignore[misc] +class _CallableReturningModelResponse(Protocol[ContextT, ResponseT]): """Callable for model call interception with handler callback. Receives handler callback to execute model and returns `ModelResponse` or `AIMessage`. """ + @overload def __call__( self, request: ModelRequest[ContextT], handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], - ) -> ModelResponse[ResponseT] | AIMessage: + ) -> ModelResponse[ResponseT] | AIMessage: ... + + @overload + def __call__( + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], + ) -> Awaitable[ModelResponse[ResponseT] | AIMessage]: ... + + def __call__( + self, + request: ModelRequest[ContextT], + handler: Callable[ + [ModelRequest[ContextT]], ModelResponse[ResponseT] | Awaitable[ModelResponse[ResponseT]] + ], + ) -> ModelResponse[ResponseT] | AIMessage | Awaitable[ModelResponse[ResponseT] | AIMessage]: """Intercept model execution via handler callback.""" ... @@ -854,11 +866,27 @@ class _CallableReturningToolResponse(Protocol): `Command`. """ + @overload def __call__( self, request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]], - ) -> ToolMessage | Command[Any]: + ) -> ToolMessage | Command[Any]: ... + + @overload + def __call__( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], + ) -> Awaitable[ToolMessage | Command[Any]]: ... + + def __call__( + self, + request: ToolCallRequest, + handler: Callable[ + [ToolCallRequest], ToolMessage | Command[Any] | Awaitable[ToolMessage | Command[Any]] + ], + ) -> ToolMessage | Command[Any] | Awaitable[ToolMessage | Command[Any]]: """Intercept tool execution via handler callback.""" ... @@ -1041,7 +1069,11 @@ def before_model( state: StateT, runtime: Runtime[ContextT], ) -> dict[str, Any] | Command[Any] | None: - return await func(state, runtime) # type: ignore[misc] + # `iscoroutinefunction` narrows `func` at runtime, but type checkers + # cannot narrow this sync-or-async callable union. + return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)( + state, runtime + ) # Preserve can_jump_to metadata on the wrapped function if func_can_jump_to: @@ -1051,22 +1083,29 @@ def before_model( "str", getattr(func, "__name__", "BeforeModelMiddleware") ) - return type( - middleware_name, - (AgentMiddleware,), - { - "state_schema": state_schema or AgentState, - "tools": tools or [], - "abefore_model": async_wrapped, - }, - )() + # `type(...)` builds the correct middleware subclass at runtime, but + # type checkers cannot infer its generic `AgentMiddleware` parameters. + return cast( + "AgentMiddleware[StateT, ContextT]", + type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "abefore_model": async_wrapped, + }, + )(), + ) def wrapped( _self: AgentMiddleware[StateT, ContextT], state: StateT, runtime: Runtime[ContextT], ) -> dict[str, Any] | Command[Any] | None: - return func(state, runtime) # type: ignore[return-value] + # `iscoroutinefunction` narrows `func` at runtime, but type checkers + # cannot narrow this sync-or-async callable union. + return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime) # Preserve can_jump_to metadata on the wrapped function if func_can_jump_to: @@ -1075,15 +1114,20 @@ def before_model( # Use function name as default if no name provided middleware_name = name or cast("str", getattr(func, "__name__", "BeforeModelMiddleware")) - return type( - middleware_name, - (AgentMiddleware,), - { - "state_schema": state_schema or AgentState, - "tools": tools or [], - "before_model": wrapped, - }, - )() + # `type(...)` builds the correct middleware subclass at runtime, but + # type checkers cannot infer its generic `AgentMiddleware` parameters. + return cast( + "AgentMiddleware[StateT, ContextT]", + type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "before_model": wrapped, + }, + )(), + ) if func is not None: return decorator(func) @@ -1201,7 +1245,11 @@ def after_model( state: StateT, runtime: Runtime[ContextT], ) -> dict[str, Any] | Command[Any] | None: - return await func(state, runtime) # type: ignore[misc] + # `iscoroutinefunction` narrows `func` at runtime, but type checkers + # cannot narrow this sync-or-async callable union. + return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)( + state, runtime + ) # Preserve can_jump_to metadata on the wrapped function if func_can_jump_to: @@ -1209,22 +1257,29 @@ def after_model( middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware")) - return type( - middleware_name, - (AgentMiddleware,), - { - "state_schema": state_schema or AgentState, - "tools": tools or [], - "aafter_model": async_wrapped, - }, - )() + # `type(...)` builds the correct middleware subclass at runtime, but + # type checkers cannot infer its generic `AgentMiddleware` parameters. + return cast( + "AgentMiddleware[StateT, ContextT]", + type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "aafter_model": async_wrapped, + }, + )(), + ) def wrapped( _self: AgentMiddleware[StateT, ContextT], state: StateT, runtime: Runtime[ContextT], ) -> dict[str, Any] | Command[Any] | None: - return func(state, runtime) # type: ignore[return-value] + # `iscoroutinefunction` narrows `func` at runtime, but type checkers + # cannot narrow this sync-or-async callable union. + return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime) # Preserve can_jump_to metadata on the wrapped function if func_can_jump_to: @@ -1233,15 +1288,20 @@ def after_model( # Use function name as default if no name provided middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware")) - return type( - middleware_name, - (AgentMiddleware,), - { - "state_schema": state_schema or AgentState, - "tools": tools or [], - "after_model": wrapped, - }, - )() + # `type(...)` builds the correct middleware subclass at runtime, but + # type checkers cannot infer its generic `AgentMiddleware` parameters. + return cast( + "AgentMiddleware[StateT, ContextT]", + type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "after_model": wrapped, + }, + )(), + ) if func is not None: return decorator(func) @@ -1392,7 +1452,11 @@ def before_agent( state: StateT, runtime: Runtime[ContextT], ) -> dict[str, Any] | Command[Any] | None: - return await func(state, runtime) # type: ignore[misc] + # `iscoroutinefunction` narrows `func` at runtime, but type checkers + # cannot narrow this sync-or-async callable union. + return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)( + state, runtime + ) # Preserve can_jump_to metadata on the wrapped function if func_can_jump_to: @@ -1402,22 +1466,29 @@ def before_agent( "str", getattr(func, "__name__", "BeforeAgentMiddleware") ) - return type( - middleware_name, - (AgentMiddleware,), - { - "state_schema": state_schema or AgentState, - "tools": tools or [], - "abefore_agent": async_wrapped, - }, - )() + # `type(...)` builds the correct middleware subclass at runtime, but + # type checkers cannot infer its generic `AgentMiddleware` parameters. + return cast( + "AgentMiddleware[StateT, ContextT]", + type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "abefore_agent": async_wrapped, + }, + )(), + ) def wrapped( _self: AgentMiddleware[StateT, ContextT], state: StateT, runtime: Runtime[ContextT], ) -> dict[str, Any] | Command[Any] | None: - return func(state, runtime) # type: ignore[return-value] + # `iscoroutinefunction` narrows `func` at runtime, but type checkers + # cannot narrow this sync-or-async callable union. + return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime) # Preserve can_jump_to metadata on the wrapped function if func_can_jump_to: @@ -1426,15 +1497,20 @@ def before_agent( # Use function name as default if no name provided middleware_name = name or cast("str", getattr(func, "__name__", "BeforeAgentMiddleware")) - return type( - middleware_name, - (AgentMiddleware,), - { - "state_schema": state_schema or AgentState, - "tools": tools or [], - "before_agent": wrapped, - }, - )() + # `type(...)` builds the correct middleware subclass at runtime, but + # type checkers cannot infer its generic `AgentMiddleware` parameters. + return cast( + "AgentMiddleware[StateT, ContextT]", + type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "before_agent": wrapped, + }, + )(), + ) if func is not None: return decorator(func) @@ -1553,7 +1629,11 @@ def after_agent( state: StateT, runtime: Runtime[ContextT], ) -> dict[str, Any] | Command[Any] | None: - return await func(state, runtime) # type: ignore[misc] + # `iscoroutinefunction` narrows `func` at runtime, but type checkers + # cannot narrow this sync-or-async callable union. + return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)( + state, runtime + ) # Preserve can_jump_to metadata on the wrapped function if func_can_jump_to: @@ -1561,22 +1641,29 @@ def after_agent( middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware")) - return type( - middleware_name, - (AgentMiddleware,), - { - "state_schema": state_schema or AgentState, - "tools": tools or [], - "aafter_agent": async_wrapped, - }, - )() + # `type(...)` builds the correct middleware subclass at runtime, but + # type checkers cannot infer its generic `AgentMiddleware` parameters. + return cast( + "AgentMiddleware[StateT, ContextT]", + type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "aafter_agent": async_wrapped, + }, + )(), + ) def wrapped( _self: AgentMiddleware[StateT, ContextT], state: StateT, runtime: Runtime[ContextT], ) -> dict[str, Any] | Command[Any] | None: - return func(state, runtime) # type: ignore[return-value] + # `iscoroutinefunction` narrows `func` at runtime, but type checkers + # cannot narrow this sync-or-async callable union. + return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime) # Preserve can_jump_to metadata on the wrapped function if func_can_jump_to: @@ -1585,15 +1672,20 @@ def after_agent( # Use function name as default if no name provided middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware")) - return type( - middleware_name, - (AgentMiddleware,), - { - "state_schema": state_schema or AgentState, - "tools": tools or [], - "after_agent": wrapped, - }, - )() + # `type(...)` builds the correct middleware subclass at runtime, but + # type checkers cannot infer its generic `AgentMiddleware` parameters. + return cast( + "AgentMiddleware[StateT, ContextT]", + type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "after_agent": wrapped, + }, + )(), + ) if func is not None: return decorator(func) @@ -1602,7 +1694,7 @@ def after_agent( @overload def dynamic_prompt( - func: _CallableReturningSystemMessage[StateT, ContextT], + func: _CallableReturningSystemMessage[ContextT], ) -> AgentMiddleware[StateT, ContextT]: ... @@ -1610,16 +1702,16 @@ def dynamic_prompt( def dynamic_prompt( func: None = None, ) -> Callable[ - [_CallableReturningSystemMessage[StateT, ContextT]], + [_CallableReturningSystemMessage[ContextT]], AgentMiddleware[StateT, ContextT], ]: ... def dynamic_prompt( - func: _CallableReturningSystemMessage[StateT, ContextT] | None = None, + func: _CallableReturningSystemMessage[ContextT] | None = None, ) -> ( Callable[ - [_CallableReturningSystemMessage[StateT, ContextT]], + [_CallableReturningSystemMessage[ContextT]], AgentMiddleware[StateT, ContextT], ] | AgentMiddleware[StateT, ContextT] @@ -1673,7 +1765,7 @@ def dynamic_prompt( """ def decorator( - func: _CallableReturningSystemMessage[StateT, ContextT], + func: _CallableReturningSystemMessage[ContextT], ) -> AgentMiddleware[StateT, ContextT]: is_async = iscoroutinefunction(func) @@ -1684,7 +1776,7 @@ def dynamic_prompt( request: ModelRequest[ContextT], handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[Any]]], ) -> ModelResponse[Any] | AIMessage: - prompt = await func(request) # type: ignore[misc] + prompt = await cast("_AsyncCallableReturningSystemMessage[ContextT]", func)(request) if isinstance(prompt, SystemMessage): request = request.override(system_message=prompt) else: @@ -1693,22 +1785,27 @@ def dynamic_prompt( middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware")) - return type( - middleware_name, - (AgentMiddleware,), - { - "state_schema": AgentState, - "tools": [], - "awrap_model_call": async_wrapped, - }, - )() + # `type(...)` builds the correct middleware subclass at runtime, but + # type checkers cannot infer its generic `AgentMiddleware` parameters. + return cast( + "AgentMiddleware[StateT, ContextT]", + type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": AgentState, + "tools": [], + "awrap_model_call": async_wrapped, + }, + )(), + ) def wrapped( _self: AgentMiddleware[StateT, ContextT], request: ModelRequest[ContextT], handler: Callable[[ModelRequest[ContextT]], ModelResponse[Any]], ) -> ModelResponse[Any] | AIMessage: - prompt = cast("Callable[[ModelRequest[ContextT]], SystemMessage | str]", func)(request) + prompt = cast("_SyncCallableReturningSystemMessage[ContextT]", func)(request) if isinstance(prompt, SystemMessage): request = request.override(system_message=prompt) else: @@ -1721,7 +1818,7 @@ def dynamic_prompt( handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[Any]]], ) -> ModelResponse[Any] | AIMessage: # Delegate to sync function - prompt = cast("Callable[[ModelRequest[ContextT]], SystemMessage | str]", func)(request) + prompt = cast("_SyncCallableReturningSystemMessage[ContextT]", func)(request) if isinstance(prompt, SystemMessage): request = request.override(system_message=prompt) else: @@ -1730,16 +1827,21 @@ def dynamic_prompt( middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware")) - return type( - middleware_name, - (AgentMiddleware,), - { - "state_schema": AgentState, - "tools": [], - "wrap_model_call": wrapped, - "awrap_model_call": async_wrapped_from_sync, - }, - )() + # `type(...)` builds the correct middleware subclass at runtime, but + # type checkers cannot infer its generic `AgentMiddleware` parameters. + return cast( + "AgentMiddleware[StateT, ContextT]", + type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": AgentState, + "tools": [], + "wrap_model_call": wrapped, + "awrap_model_call": async_wrapped_from_sync, + }, + )(), + ) if func is not None: return decorator(func) @@ -1748,7 +1850,7 @@ def dynamic_prompt( @overload def wrap_model_call( - func: _CallableReturningModelResponse[StateT, ContextT, ResponseT], + func: _CallableReturningModelResponse[ContextT, ResponseT], ) -> AgentMiddleware[StateT, ContextT]: ... @@ -1760,20 +1862,20 @@ def wrap_model_call( tools: list[BaseTool] | None = None, name: str | None = None, ) -> Callable[ - [_CallableReturningModelResponse[StateT, ContextT, ResponseT]], + [_CallableReturningModelResponse[ContextT, ResponseT]], AgentMiddleware[StateT, ContextT], ]: ... def wrap_model_call( - func: _CallableReturningModelResponse[StateT, ContextT, ResponseT] | None = None, + func: _CallableReturningModelResponse[ContextT, ResponseT] | None = None, *, state_schema: type[StateT] | None = None, tools: list[BaseTool] | None = None, name: str | None = None, ) -> ( Callable[ - [_CallableReturningModelResponse[StateT, ContextT, ResponseT]], + [_CallableReturningModelResponse[ContextT, ResponseT]], AgentMiddleware[StateT, ContextT], ] | AgentMiddleware[StateT, ContextT] @@ -1854,7 +1956,7 @@ def wrap_model_call( """ def decorator( - func: _CallableReturningModelResponse[StateT, ContextT, ResponseT], + func: _CallableReturningModelResponse[ContextT, ResponseT], ) -> AgentMiddleware[StateT, ContextT]: is_async = iscoroutinefunction(func) @@ -1865,21 +1967,26 @@ def wrap_model_call( request: ModelRequest[ContextT], handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], ) -> ModelResponse[ResponseT] | AIMessage: - return await func(request, handler) # type: ignore[misc, arg-type] + return await func(request, handler) middleware_name = name or cast( "str", getattr(func, "__name__", "WrapModelCallMiddleware") ) - return type( - middleware_name, - (AgentMiddleware,), - { - "state_schema": state_schema or AgentState, - "tools": tools or [], - "awrap_model_call": async_wrapped, - }, - )() + # `type(...)` builds the correct middleware subclass at runtime, but + # type checkers cannot infer its generic `AgentMiddleware` parameters. + return cast( + "AgentMiddleware[StateT, ContextT]", + type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "awrap_model_call": async_wrapped, + }, + )(), + ) def wrapped( _self: AgentMiddleware[StateT, ContextT], @@ -1890,15 +1997,20 @@ def wrap_model_call( middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware")) - return type( - middleware_name, - (AgentMiddleware,), - { - "state_schema": state_schema or AgentState, - "tools": tools or [], - "wrap_model_call": wrapped, - }, - )() + # `type(...)` builds the correct middleware subclass at runtime, but + # type checkers cannot infer its generic `AgentMiddleware` parameters. + return cast( + "AgentMiddleware[StateT, ContextT]", + type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "wrap_model_call": wrapped, + }, + )(), + ) if func is not None: return decorator(func) @@ -2025,21 +2137,26 @@ def wrap_tool_call( request: ToolCallRequest, handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], ) -> ToolMessage | Command[Any]: - return await func(request, handler) # type: ignore[arg-type,misc] + return await func(request, handler) middleware_name = name or cast( "str", getattr(func, "__name__", "WrapToolCallMiddleware") ) - return type( - middleware_name, - (AgentMiddleware,), - { - "state_schema": AgentState, - "tools": tools or [], - "awrap_tool_call": async_wrapped, - }, - )() + # `type(...)` builds the correct middleware subclass at runtime, but + # type checkers cannot infer its generic `AgentMiddleware` parameters. + return cast( + "AgentMiddleware", + type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": AgentState, + "tools": tools or [], + "awrap_tool_call": async_wrapped, + }, + )(), + ) def wrapped( _self: AgentMiddleware, @@ -2050,15 +2167,20 @@ def wrap_tool_call( middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware")) - return type( - middleware_name, - (AgentMiddleware,), - { - "state_schema": AgentState, - "tools": tools or [], - "wrap_tool_call": wrapped, - }, - )() + # `type(...)` builds the correct middleware subclass at runtime, but + # type checkers cannot infer its generic `AgentMiddleware` parameters. + return cast( + "AgentMiddleware", + type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": AgentState, + "tools": tools or [], + "wrap_tool_call": wrapped, + }, + )(), + ) if func is not None: return decorator(func) diff --git a/libs/langchain_v1/langchain/agents/structured_output.py b/libs/langchain_v1/langchain/agents/structured_output.py index fda470bcee2..a482951b9da 100644 --- a/libs/langchain_v1/langchain/agents/structured_output.py +++ b/libs/langchain_v1/langchain/agents/structured_output.py @@ -76,7 +76,7 @@ class StructuredOutputValidationError(StructuredOutputError): def _parse_with_schema( schema: type[SchemaT] | dict[str, Any], schema_kind: SchemaKind, data: dict[str, Any] -) -> Any: +) -> SchemaT | dict[str, Any]: """Parse data using for any supported schema type. Args: @@ -92,9 +92,10 @@ def _parse_with_schema( ValueError: If parsing fails """ if schema_kind == "json_schema": + # Raw JSON schema has no corresponding Python type to instantiate. return data try: - adapter: TypeAdapter[SchemaT] = TypeAdapter(schema) + adapter = TypeAdapter[SchemaT](schema) return adapter.validate_python(data) except Exception as e: schema_name = getattr(schema, "__name__", str(schema)) @@ -344,7 +345,7 @@ class OutputToolBinding(Generic[SchemaT]): ), ) - def parse(self, tool_args: dict[str, Any]) -> SchemaT: + def parse(self, tool_args: dict[str, Any]) -> SchemaT | dict[str, Any]: """Parse tool arguments according to the schema. Args: @@ -391,7 +392,7 @@ class ProviderStrategyBinding(Generic[SchemaT]): schema_kind=schema_spec.schema_kind, ) - def parse(self, response: AIMessage) -> SchemaT: + def parse(self, response: AIMessage) -> SchemaT | dict[str, Any]: """Parse `AIMessage` content according to the schema. Args: diff --git a/libs/langchain_v1/langchain/chat_models/base.py b/libs/langchain_v1/langchain/chat_models/base.py index b210d1c74c6..89b7f9da1ab 100644 --- a/libs/langchain_v1/langchain/chat_models/base.py +++ b/libs/langchain_v1/langchain/chat_models/base.py @@ -1017,17 +1017,37 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]): yield x # Explicitly added to satisfy downstream linters. + # `bind_tools` is implemented by concrete models because tool binding is + # provider-specific. A configurable model may not have a concrete model instance + # yet, since invocation config can choose it later. Save the `bind_tools` tools + # and kwargs now. When `_model` later builds the selected provider model, it calls + # `selected_model.bind_tools(tools, **kwargs)` and returns that runnable. + # Cast so callers still get the public return type. def bind_tools( self, tools: Sequence[dict[str, Any] | type[BaseModel] | Callable[..., Any] | BaseTool], **kwargs: Any, ) -> Runnable[LanguageModelInput, AIMessage]: - return self.__getattr__("bind_tools")(tools, **kwargs) + return cast( + "Runnable[LanguageModelInput, AIMessage]", + self.__getattr__("bind_tools")(tools, **kwargs), + ) # Explicitly added to satisfy downstream linters. + # `with_structured_output` is implemented by concrete models because structured + # output support is provider-specific. A configurable model may not have a + # concrete model instance yet, since invocation config can choose it later. Save + # the structured-output schema and kwargs now. When `_model` later builds the + # selected provider model, it calls + # `selected_model.with_structured_output(schema, **kwargs)` and returns that + # runnable. + # Cast so callers still get the public return type. def with_structured_output( self, schema: dict[str, Any] | type[BaseModel], **kwargs: Any, ) -> Runnable[LanguageModelInput, dict[str, Any] | BaseModel]: - return self.__getattr__("with_structured_output")(schema, **kwargs) + return cast( + "Runnable[LanguageModelInput, dict[str, Any] | BaseModel]", + self.__getattr__("with_structured_output")(schema, **kwargs), + ) diff --git a/libs/langchain_v1/pyproject.toml b/libs/langchain_v1/pyproject.toml index a83ca6d59c5..3732cfae1e0 100644 --- a/libs/langchain_v1/pyproject.toml +++ b/libs/langchain_v1/pyproject.toml @@ -116,9 +116,6 @@ exclude = [ "tests/unit_tests/agents/test_.*\\.py", ] -# TODO: activate for 'strict' checking -warn_return_any = false - [[tool.mypy.overrides]] module = ["pytest_socket.*", "vcr.*"] ignore_missing_imports = true