core: Add ruff rules PYI (#29335)

See https://docs.astral.sh/ruff/rules/#flake8-pyi-pyi

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Christophe Bornet 2025-04-04 21:59:44 +02:00 committed by GitHub
parent d8e3b7667f
commit 6650b94627
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 78 additions and 66 deletions

View File

@ -3,7 +3,9 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from typing import TYPE_CHECKING, Any, Optional, Union
from typing_extensions import Self
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence from collections.abc import Sequence
@ -879,9 +881,6 @@ class AsyncCallbackHandler(BaseCallbackHandler):
""" """
T = TypeVar("T", bound="BaseCallbackManager")
class BaseCallbackManager(CallbackManagerMixin): class BaseCallbackManager(CallbackManagerMixin):
"""Base callback manager for LangChain.""" """Base callback manager for LangChain."""
@ -920,7 +919,7 @@ class BaseCallbackManager(CallbackManagerMixin):
self.metadata = metadata or {} self.metadata = metadata or {}
self.inheritable_metadata = inheritable_metadata or {} self.inheritable_metadata = inheritable_metadata or {}
def copy(self: T) -> T: def copy(self) -> Self:
"""Copy the callback manager.""" """Copy the callback manager."""
return self.__class__( return self.__class__(
handlers=self.handlers.copy(), handlers=self.handlers.copy(),
@ -932,7 +931,7 @@ class BaseCallbackManager(CallbackManagerMixin):
inheritable_metadata=self.inheritable_metadata.copy(), inheritable_metadata=self.inheritable_metadata.copy(),
) )
def merge(self: T, other: BaseCallbackManager) -> T: def merge(self, other: BaseCallbackManager) -> Self:
"""Merge the callback manager with another callback manager. """Merge the callback manager with another callback manager.
May be overwritten in subclasses. Primarily used internally May be overwritten in subclasses. Primarily used internally

View File

@ -22,6 +22,7 @@ from typing import (
from uuid import UUID from uuid import UUID
from langsmith.run_helpers import get_tracing_context from langsmith.run_helpers import get_tracing_context
from typing_extensions import Self
from langchain_core.callbacks.base import ( from langchain_core.callbacks.base import (
BaseCallbackHandler, BaseCallbackHandler,
@ -444,9 +445,6 @@ async def ahandle_event(
) )
BRM = TypeVar("BRM", bound="BaseRunManager")
class BaseRunManager(RunManagerMixin): class BaseRunManager(RunManagerMixin):
"""Base class for run manager (a bound callback manager).""" """Base class for run manager (a bound callback manager)."""
@ -489,7 +487,7 @@ class BaseRunManager(RunManagerMixin):
self.inheritable_metadata = inheritable_metadata or {} self.inheritable_metadata = inheritable_metadata or {}
@classmethod @classmethod
def get_noop_manager(cls: type[BRM]) -> BRM: def get_noop_manager(cls) -> Self:
"""Return a manager that doesn't perform any operations. """Return a manager that doesn't perform any operations.
Returns: Returns:

View File

@ -1258,7 +1258,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006 Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
], ],
*, *,
tool_choice: Optional[Union[str, Literal["any"]]] = None, tool_choice: Optional[Union[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]: ) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tools to the model. """Bind tools to the model.

View File

@ -22,7 +22,7 @@ from pydantic import (
SkipValidation, SkipValidation,
model_validator, model_validator,
) )
from typing_extensions import override from typing_extensions import Self, override
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.load import Serializable from langchain_core.load import Serializable
@ -304,12 +304,12 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
@classmethod @classmethod
def from_template( def from_template(
cls: type[MessagePromptTemplateT], cls,
template: str, template: str,
template_format: PromptTemplateFormat = "f-string", template_format: PromptTemplateFormat = "f-string",
partial_variables: Optional[dict[str, Any]] = None, partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> MessagePromptTemplateT: ) -> Self:
"""Create a class from a string template. """Create a class from a string template.
Args: Args:
@ -335,11 +335,11 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
@classmethod @classmethod
def from_template_file( def from_template_file(
cls: type[MessagePromptTemplateT], cls,
template_file: Union[str, Path], template_file: Union[str, Path],
input_variables: list[str], input_variables: list[str],
**kwargs: Any, **kwargs: Any,
) -> MessagePromptTemplateT: ) -> Self:
"""Create a class from a template file. """Create a class from a template file.
Args: Args:
@ -456,11 +456,6 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
) )
_StringImageMessagePromptTemplateT = TypeVar(
"_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate"
)
class _TextTemplateParam(TypedDict, total=False): class _TextTemplateParam(TypedDict, total=False):
text: Union[str, dict] text: Union[str, dict]
@ -483,13 +478,13 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
@classmethod @classmethod
def from_template( def from_template(
cls: type[_StringImageMessagePromptTemplateT], cls: type[Self],
template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]], template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
template_format: PromptTemplateFormat = "f-string", template_format: PromptTemplateFormat = "f-string",
*, *,
partial_variables: Optional[dict[str, Any]] = None, partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> _StringImageMessagePromptTemplateT: ) -> Self:
"""Create a class from a string template. """Create a class from a string template.
Args: Args:
@ -576,11 +571,11 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
@classmethod @classmethod
def from_template_file( def from_template_file(
cls: type[_StringImageMessagePromptTemplateT], cls: type[Self],
template_file: Union[str, Path], template_file: Union[str, Path],
input_variables: list[str], input_variables: list[str],
**kwargs: Any, **kwargs: Any,
) -> _StringImageMessagePromptTemplateT: ) -> Self:
"""Create a class from a template file. """Create a class from a template file.
Args: Args:

View File

@ -4200,7 +4200,7 @@ class RunnableGenerator(Runnable[Input, Output]):
) )
@override @override
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
if isinstance(other, RunnableGenerator): if isinstance(other, RunnableGenerator):
if hasattr(self, "_transform") and hasattr(other, "_transform"): if hasattr(self, "_transform") and hasattr(other, "_transform"):
return self._transform == other._transform return self._transform == other._transform
@ -4582,7 +4582,7 @@ class RunnableLambda(Runnable[Input, Output]):
return graph return graph
@override @override
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
if isinstance(other, RunnableLambda): if isinstance(other, RunnableLambda):
if hasattr(self, "func") and hasattr(other, "func"): if hasattr(self, "func") and hasattr(other, "func"):
return self.func == other.func return self.func == other.func
@ -5880,22 +5880,24 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
class _RunnableCallableSync(Protocol[Input, Output]): class _RunnableCallableSync(Protocol[Input, Output]):
def __call__(self, __in: Input, *, config: RunnableConfig) -> Output: ... def __call__(self, _in: Input, /, *, config: RunnableConfig) -> Output: ...
class _RunnableCallableAsync(Protocol[Input, Output]): class _RunnableCallableAsync(Protocol[Input, Output]):
def __call__(self, __in: Input, *, config: RunnableConfig) -> Awaitable[Output]: ... def __call__(
self, _in: Input, /, *, config: RunnableConfig
) -> Awaitable[Output]: ...
class _RunnableCallableIterator(Protocol[Input, Output]): class _RunnableCallableIterator(Protocol[Input, Output]):
def __call__( def __call__(
self, __in: Iterator[Input], *, config: RunnableConfig self, _in: Iterator[Input], /, *, config: RunnableConfig
) -> Iterator[Output]: ... ) -> Iterator[Output]: ...
class _RunnableCallableAsyncIterator(Protocol[Input, Output]): class _RunnableCallableAsyncIterator(Protocol[Input, Output]):
def __call__( def __call__(
self, __in: AsyncIterator[Input], *, config: RunnableConfig self, _in: AsyncIterator[Input], /, *, config: RunnableConfig
) -> AsyncIterator[Output]: ... ) -> AsyncIterator[Output]: ...

View File

@ -515,7 +515,7 @@ _T_contra = TypeVar("_T_contra", contravariant=True)
class SupportsAdd(Protocol[_T_contra, _T_co]): class SupportsAdd(Protocol[_T_contra, _T_co]):
"""Protocol for objects that support addition.""" """Protocol for objects that support addition."""
def __add__(self, __x: _T_contra) -> _T_co: def __add__(self, x: _T_contra, /) -> _T_co:
"""Add the object to another object.""" """Add the object to another object."""

View File

@ -88,7 +88,12 @@ class NoLock:
async def __aenter__(self) -> None: async def __aenter__(self) -> None:
"""Do nothing.""" """Do nothing."""
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
"""Exception not handled.""" """Exception not handled."""
return False return False
@ -237,7 +242,12 @@ class Tee(Generic[T]):
"""Return the tee instance.""" """Return the tee instance."""
return self return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
"""Close all child iterators.""" """Close all child iterators."""
await self.aclose() await self.aclose()
return False return False

View File

@ -4,6 +4,7 @@ from collections import deque
from collections.abc import Generator, Iterable, Iterator from collections.abc import Generator, Iterable, Iterator
from contextlib import AbstractContextManager from contextlib import AbstractContextManager
from itertools import islice from itertools import islice
from types import TracebackType
from typing import ( from typing import (
Any, Any,
Generic, Generic,
@ -24,7 +25,12 @@ class NoLock:
def __enter__(self) -> None: def __enter__(self) -> None:
"""Do nothing.""" """Do nothing."""
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]: def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Literal[False]:
"""Exception not handled.""" """Exception not handled."""
return False return False
@ -173,7 +179,12 @@ class Tee(Generic[T]):
"""Return Tee instance.""" """Return Tee instance."""
return self return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]: def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Literal[False]:
"""Close all child iterators.""" """Close all child iterators."""
self.close() self.close()
return False return False

View File

@ -377,12 +377,7 @@ if IS_PYDANTIC_V2:
def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ... def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...
def get_fields( def get_fields(
model: Union[ model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
BaseModelV2,
BaseModelV1,
type[BaseModelV2],
type[BaseModelV1],
],
) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]: ) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
"""Get the field names of a Pydantic model.""" """Get the field names of a Pydantic model."""
if hasattr(model, "model_fields"): if hasattr(model, "model_fields"):
@ -491,19 +486,21 @@ def _create_root_model_cached(
@lru_cache(maxsize=256) @lru_cache(maxsize=256)
def _create_model_cached( def _create_model_cached(
__model_name: str, model_name: str,
/,
**field_definitions: Any, **field_definitions: Any,
) -> type[BaseModel]: ) -> type[BaseModel]:
return _create_model_base( return _create_model_base(
__model_name, model_name,
__config__=_SchemaConfig, __config__=_SchemaConfig,
**_remap_field_definitions(field_definitions), **_remap_field_definitions(field_definitions),
) )
def create_model( def create_model(
__model_name: str, model_name: str,
__module_name: Optional[str] = None, module_name: Optional[str] = None,
/,
**field_definitions: Any, **field_definitions: Any,
) -> type[BaseModel]: ) -> type[BaseModel]:
"""Create a pydantic model with the given field definitions. """Create a pydantic model with the given field definitions.
@ -511,8 +508,8 @@ def create_model(
Please use create_model_v2 instead of this function. Please use create_model_v2 instead of this function.
Args: Args:
__model_name: The name of the model. model_name: The name of the model.
__module_name: The name of the module where the model is defined. module_name: The name of the module where the model is defined.
This is used by Pydantic to resolve any forward references. This is used by Pydantic to resolve any forward references.
**field_definitions: The field definitions for the model. **field_definitions: The field definitions for the model.
@ -524,8 +521,8 @@ def create_model(
kwargs["root"] = field_definitions.pop("__root__") kwargs["root"] = field_definitions.pop("__root__")
return create_model_v2( return create_model_v2(
__model_name, model_name,
module_name=__module_name, module_name=module_name,
field_definitions=field_definitions, field_definitions=field_definitions,
**kwargs, **kwargs,
) )

View File

@ -36,6 +36,7 @@ from typing import (
) )
from pydantic import ConfigDict, Field, model_validator from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
@ -818,11 +819,11 @@ class VectorStore(ABC):
@classmethod @classmethod
def from_documents( def from_documents(
cls: type[VST], cls,
documents: list[Document], documents: list[Document],
embedding: Embeddings, embedding: Embeddings,
**kwargs: Any, **kwargs: Any,
) -> VST: ) -> Self:
"""Return VectorStore initialized from documents and embeddings. """Return VectorStore initialized from documents and embeddings.
Args: Args:
@ -848,11 +849,11 @@ class VectorStore(ABC):
@classmethod @classmethod
async def afrom_documents( async def afrom_documents(
cls: type[VST], cls,
documents: list[Document], documents: list[Document],
embedding: Embeddings, embedding: Embeddings,
**kwargs: Any, **kwargs: Any,
) -> VST: ) -> Self:
"""Async return VectorStore initialized from documents and embeddings. """Async return VectorStore initialized from documents and embeddings.
Args: Args:
@ -903,14 +904,14 @@ class VectorStore(ABC):
@classmethod @classmethod
async def afrom_texts( async def afrom_texts(
cls: type[VST], cls,
texts: list[str], texts: list[str],
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[list[dict]] = None, metadatas: Optional[list[dict]] = None,
*, *,
ids: Optional[list[str]] = None, ids: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> VST: ) -> Self:
"""Async return VectorStore initialized from texts and embeddings. """Async return VectorStore initialized from texts and embeddings.
Args: Args:

View File

@ -103,7 +103,6 @@ ignore = [
"FBT002", "FBT002",
"PGH003", "PGH003",
"PLR", "PLR",
"PYI",
"RUF", "RUF",
"SLF", "SLF",
] ]

View File

@ -5294,22 +5294,22 @@ async def test_ainvoke_on_returned_runnable() -> None:
be runthroughaasync path (issue #13407). be runthroughaasync path (issue #13407).
""" """
def idchain_sync(__input: dict) -> bool: def idchain_sync(_input: dict, /) -> bool:
return False return False
async def idchain_async(__input: dict) -> bool: async def idchain_async(_input: dict, /) -> bool:
return True return True
idchain = RunnableLambda(func=idchain_sync, afunc=idchain_async) idchain = RunnableLambda(func=idchain_sync, afunc=idchain_async)
def func(__input: dict) -> Runnable: def func(_input: dict, /) -> Runnable:
return idchain return idchain
assert await RunnableLambda(func).ainvoke({}) assert await RunnableLambda(func).ainvoke({})
def test_invoke_stream_passthrough_assign_trace() -> None: def test_invoke_stream_passthrough_assign_trace() -> None:
def idchain_sync(__input: dict) -> bool: def idchain_sync(_input: dict, /) -> bool:
return False return False
chain = RunnablePassthrough.assign(urls=idchain_sync) chain = RunnablePassthrough.assign(urls=idchain_sync)
@ -5329,7 +5329,7 @@ def test_invoke_stream_passthrough_assign_trace() -> None:
async def test_ainvoke_astream_passthrough_assign_trace() -> None: async def test_ainvoke_astream_passthrough_assign_trace() -> None:
def idchain_sync(__input: dict) -> bool: def idchain_sync(_input: dict, /) -> bool:
return False return False
chain = RunnablePassthrough.assign(urls=idchain_sync) chain = RunnablePassthrough.assign(urls=idchain_sync)

View File

@ -7,7 +7,7 @@ from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
class AnyStr(str): class AnyStr(str):
__slots__ = () __slots__ = ()
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, str) return isinstance(other, str)

View File

@ -2369,7 +2369,7 @@ def test_tool_return_output_mixin() -> None:
def __init__(self, x: int) -> None: def __init__(self, x: int) -> None:
self.x = x self.x = x
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.x == other.x return isinstance(other, self.__class__) and self.x == other.x
@tool @tool

View File

@ -994,12 +994,12 @@ def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None:
) )
def test_convert_union_type_py_39() -> None: def test_convert_union_type_py_39() -> None:
@tool @tool
def magic_function(input: int | float) -> str: # noqa: FA102 def magic_function(input: int | str) -> str: # noqa: FA102
"""Compute a magic function.""" """Compute a magic function."""
result = convert_to_openai_function(magic_function) result = convert_to_openai_function(magic_function)
assert result["parameters"]["properties"]["input"] == { assert result["parameters"]["properties"]["input"] == {
"anyOf": [{"type": "integer"}, {"type": "number"}] "anyOf": [{"type": "integer"}, {"type": "string"}]
} }