core: Add ruff rules PYI (#29335)

See https://docs.astral.sh/ruff/rules/#flake8-pyi-pyi

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Christophe Bornet 2025-04-04 21:59:44 +02:00 committed by GitHub
parent d8e3b7667f
commit 6650b94627
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 78 additions and 66 deletions

View File

@ -3,7 +3,9 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from typing_extensions import Self
if TYPE_CHECKING:
from collections.abc import Sequence
@ -879,9 +881,6 @@ class AsyncCallbackHandler(BaseCallbackHandler):
"""
T = TypeVar("T", bound="BaseCallbackManager")
class BaseCallbackManager(CallbackManagerMixin):
"""Base callback manager for LangChain."""
@ -920,7 +919,7 @@ class BaseCallbackManager(CallbackManagerMixin):
self.metadata = metadata or {}
self.inheritable_metadata = inheritable_metadata or {}
def copy(self: T) -> T:
def copy(self) -> Self:
"""Copy the callback manager."""
return self.__class__(
handlers=self.handlers.copy(),
@ -932,7 +931,7 @@ class BaseCallbackManager(CallbackManagerMixin):
inheritable_metadata=self.inheritable_metadata.copy(),
)
def merge(self: T, other: BaseCallbackManager) -> T:
def merge(self, other: BaseCallbackManager) -> Self:
"""Merge the callback manager with another callback manager.
May be overwritten in subclasses. Primarily used internally

View File

@ -22,6 +22,7 @@ from typing import (
from uuid import UUID
from langsmith.run_helpers import get_tracing_context
from typing_extensions import Self
from langchain_core.callbacks.base import (
BaseCallbackHandler,
@ -444,9 +445,6 @@ async def ahandle_event(
)
BRM = TypeVar("BRM", bound="BaseRunManager")
class BaseRunManager(RunManagerMixin):
"""Base class for run manager (a bound callback manager)."""
@ -489,7 +487,7 @@ class BaseRunManager(RunManagerMixin):
self.inheritable_metadata = inheritable_metadata or {}
@classmethod
def get_noop_manager(cls: type[BRM]) -> BRM:
def get_noop_manager(cls) -> Self:
"""Return a manager that doesn't perform any operations.
Returns:

View File

@ -1258,7 +1258,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
],
*,
tool_choice: Optional[Union[str, Literal["any"]]] = None,
tool_choice: Optional[Union[str]] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tools to the model.

View File

@ -22,7 +22,7 @@ from pydantic import (
SkipValidation,
model_validator,
)
from typing_extensions import override
from typing_extensions import Self, override
from langchain_core._api import deprecated
from langchain_core.load import Serializable
@ -304,12 +304,12 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
@classmethod
def from_template(
cls: type[MessagePromptTemplateT],
cls,
template: str,
template_format: PromptTemplateFormat = "f-string",
partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> MessagePromptTemplateT:
) -> Self:
"""Create a class from a string template.
Args:
@ -335,11 +335,11 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
@classmethod
def from_template_file(
cls: type[MessagePromptTemplateT],
cls,
template_file: Union[str, Path],
input_variables: list[str],
**kwargs: Any,
) -> MessagePromptTemplateT:
) -> Self:
"""Create a class from a template file.
Args:
@ -456,11 +456,6 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
)
_StringImageMessagePromptTemplateT = TypeVar(
"_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate"
)
class _TextTemplateParam(TypedDict, total=False):
text: Union[str, dict]
@ -483,13 +478,13 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
@classmethod
def from_template(
cls: type[_StringImageMessagePromptTemplateT],
cls: type[Self],
template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
template_format: PromptTemplateFormat = "f-string",
*,
partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
) -> Self:
"""Create a class from a string template.
Args:
@ -576,11 +571,11 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
@classmethod
def from_template_file(
cls: type[_StringImageMessagePromptTemplateT],
cls: type[Self],
template_file: Union[str, Path],
input_variables: list[str],
**kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
) -> Self:
"""Create a class from a template file.
Args:

View File

@ -4200,7 +4200,7 @@ class RunnableGenerator(Runnable[Input, Output]):
)
@override
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
if isinstance(other, RunnableGenerator):
if hasattr(self, "_transform") and hasattr(other, "_transform"):
return self._transform == other._transform
@ -4582,7 +4582,7 @@ class RunnableLambda(Runnable[Input, Output]):
return graph
@override
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
if isinstance(other, RunnableLambda):
if hasattr(self, "func") and hasattr(other, "func"):
return self.func == other.func
@ -5880,22 +5880,24 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
class _RunnableCallableSync(Protocol[Input, Output]):
def __call__(self, __in: Input, *, config: RunnableConfig) -> Output: ...
def __call__(self, _in: Input, /, *, config: RunnableConfig) -> Output: ...
class _RunnableCallableAsync(Protocol[Input, Output]):
def __call__(self, __in: Input, *, config: RunnableConfig) -> Awaitable[Output]: ...
def __call__(
self, _in: Input, /, *, config: RunnableConfig
) -> Awaitable[Output]: ...
class _RunnableCallableIterator(Protocol[Input, Output]):
def __call__(
self, __in: Iterator[Input], *, config: RunnableConfig
self, _in: Iterator[Input], /, *, config: RunnableConfig
) -> Iterator[Output]: ...
class _RunnableCallableAsyncIterator(Protocol[Input, Output]):
def __call__(
self, __in: AsyncIterator[Input], *, config: RunnableConfig
self, _in: AsyncIterator[Input], /, *, config: RunnableConfig
) -> AsyncIterator[Output]: ...

View File

@ -515,7 +515,7 @@ _T_contra = TypeVar("_T_contra", contravariant=True)
class SupportsAdd(Protocol[_T_contra, _T_co]):
"""Protocol for objects that support addition."""
def __add__(self, __x: _T_contra) -> _T_co:
def __add__(self, x: _T_contra, /) -> _T_co:
"""Add the object to another object."""

View File

@ -88,7 +88,12 @@ class NoLock:
async def __aenter__(self) -> None:
"""Do nothing."""
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
"""Exception not handled."""
return False
@ -237,7 +242,12 @@ class Tee(Generic[T]):
"""Return the tee instance."""
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
"""Close all child iterators."""
await self.aclose()
return False

View File

@ -4,6 +4,7 @@ from collections import deque
from collections.abc import Generator, Iterable, Iterator
from contextlib import AbstractContextManager
from itertools import islice
from types import TracebackType
from typing import (
Any,
Generic,
@ -24,7 +25,12 @@ class NoLock:
def __enter__(self) -> None:
"""Do nothing."""
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Literal[False]:
"""Exception not handled."""
return False
@ -173,7 +179,12 @@ class Tee(Generic[T]):
"""Return Tee instance."""
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Literal[False]:
"""Close all child iterators."""
self.close()
return False

View File

@ -377,12 +377,7 @@ if IS_PYDANTIC_V2:
def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...
def get_fields(
model: Union[
BaseModelV2,
BaseModelV1,
type[BaseModelV2],
type[BaseModelV1],
],
model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
"""Get the field names of a Pydantic model."""
if hasattr(model, "model_fields"):
@ -491,19 +486,21 @@ def _create_root_model_cached(
@lru_cache(maxsize=256)
def _create_model_cached(
__model_name: str,
model_name: str,
/,
**field_definitions: Any,
) -> type[BaseModel]:
return _create_model_base(
__model_name,
model_name,
__config__=_SchemaConfig,
**_remap_field_definitions(field_definitions),
)
def create_model(
__model_name: str,
__module_name: Optional[str] = None,
model_name: str,
module_name: Optional[str] = None,
/,
**field_definitions: Any,
) -> type[BaseModel]:
"""Create a pydantic model with the given field definitions.
@ -511,8 +508,8 @@ def create_model(
Please use create_model_v2 instead of this function.
Args:
__model_name: The name of the model.
__module_name: The name of the module where the model is defined.
model_name: The name of the model.
module_name: The name of the module where the model is defined.
This is used by Pydantic to resolve any forward references.
**field_definitions: The field definitions for the model.
@ -524,8 +521,8 @@ def create_model(
kwargs["root"] = field_definitions.pop("__root__")
return create_model_v2(
__model_name,
module_name=__module_name,
model_name,
module_name=module_name,
field_definitions=field_definitions,
**kwargs,
)

View File

@ -36,6 +36,7 @@ from typing import (
)
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
@ -818,11 +819,11 @@ class VectorStore(ABC):
@classmethod
def from_documents(
cls: type[VST],
cls,
documents: list[Document],
embedding: Embeddings,
**kwargs: Any,
) -> VST:
) -> Self:
"""Return VectorStore initialized from documents and embeddings.
Args:
@ -848,11 +849,11 @@ class VectorStore(ABC):
@classmethod
async def afrom_documents(
cls: type[VST],
cls,
documents: list[Document],
embedding: Embeddings,
**kwargs: Any,
) -> VST:
) -> Self:
"""Async return VectorStore initialized from documents and embeddings.
Args:
@ -903,14 +904,14 @@ class VectorStore(ABC):
@classmethod
async def afrom_texts(
cls: type[VST],
cls,
texts: list[str],
embedding: Embeddings,
metadatas: Optional[list[dict]] = None,
*,
ids: Optional[list[str]] = None,
**kwargs: Any,
) -> VST:
) -> Self:
"""Async return VectorStore initialized from texts and embeddings.
Args:

View File

@ -103,7 +103,6 @@ ignore = [
"FBT002",
"PGH003",
"PLR",
"PYI",
"RUF",
"SLF",
]

View File

@ -5294,22 +5294,22 @@ async def test_ainvoke_on_returned_runnable() -> None:
be runthroughaasync path (issue #13407).
"""
def idchain_sync(__input: dict) -> bool:
def idchain_sync(_input: dict, /) -> bool:
return False
async def idchain_async(__input: dict) -> bool:
async def idchain_async(_input: dict, /) -> bool:
return True
idchain = RunnableLambda(func=idchain_sync, afunc=idchain_async)
def func(__input: dict) -> Runnable:
def func(_input: dict, /) -> Runnable:
return idchain
assert await RunnableLambda(func).ainvoke({})
def test_invoke_stream_passthrough_assign_trace() -> None:
def idchain_sync(__input: dict) -> bool:
def idchain_sync(_input: dict, /) -> bool:
return False
chain = RunnablePassthrough.assign(urls=idchain_sync)
@ -5329,7 +5329,7 @@ def test_invoke_stream_passthrough_assign_trace() -> None:
async def test_ainvoke_astream_passthrough_assign_trace() -> None:
def idchain_sync(__input: dict) -> bool:
def idchain_sync(_input: dict, /) -> bool:
return False
chain = RunnablePassthrough.assign(urls=idchain_sync)

View File

@ -7,7 +7,7 @@ from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
class AnyStr(str):
__slots__ = ()
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, str)

View File

@ -2369,7 +2369,7 @@ def test_tool_return_output_mixin() -> None:
def __init__(self, x: int) -> None:
self.x = x
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.x == other.x
@tool

View File

@ -994,12 +994,12 @@ def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None:
)
def test_convert_union_type_py_39() -> None:
@tool
def magic_function(input: int | float) -> str: # noqa: FA102
def magic_function(input: int | str) -> str: # noqa: FA102
"""Compute a magic function."""
result = convert_to_openai_function(magic_function)
assert result["parameters"]["properties"]["input"] == {
"anyOf": [{"type": "integer"}, {"type": "number"}]
"anyOf": [{"type": "integer"}, {"type": "string"}]
}