mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 07:09:31 +00:00
core: Add ruff rules PYI (#29335)
See https://docs.astral.sh/ruff/rules/#flake8-pyi-pyi --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
d8e3b7667f
commit
6650b94627
@ -3,7 +3,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
@ -879,9 +881,6 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
"""
|
||||
|
||||
|
||||
T = TypeVar("T", bound="BaseCallbackManager")
|
||||
|
||||
|
||||
class BaseCallbackManager(CallbackManagerMixin):
|
||||
"""Base callback manager for LangChain."""
|
||||
|
||||
@ -920,7 +919,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
self.metadata = metadata or {}
|
||||
self.inheritable_metadata = inheritable_metadata or {}
|
||||
|
||||
def copy(self: T) -> T:
|
||||
def copy(self) -> Self:
|
||||
"""Copy the callback manager."""
|
||||
return self.__class__(
|
||||
handlers=self.handlers.copy(),
|
||||
@ -932,7 +931,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
inheritable_metadata=self.inheritable_metadata.copy(),
|
||||
)
|
||||
|
||||
def merge(self: T, other: BaseCallbackManager) -> T:
|
||||
def merge(self, other: BaseCallbackManager) -> Self:
|
||||
"""Merge the callback manager with another callback manager.
|
||||
|
||||
May be overwritten in subclasses. Primarily used internally
|
||||
|
@ -22,6 +22,7 @@ from typing import (
|
||||
from uuid import UUID
|
||||
|
||||
from langsmith.run_helpers import get_tracing_context
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.callbacks.base import (
|
||||
BaseCallbackHandler,
|
||||
@ -444,9 +445,6 @@ async def ahandle_event(
|
||||
)
|
||||
|
||||
|
||||
BRM = TypeVar("BRM", bound="BaseRunManager")
|
||||
|
||||
|
||||
class BaseRunManager(RunManagerMixin):
|
||||
"""Base class for run manager (a bound callback manager)."""
|
||||
|
||||
@ -489,7 +487,7 @@ class BaseRunManager(RunManagerMixin):
|
||||
self.inheritable_metadata = inheritable_metadata or {}
|
||||
|
||||
@classmethod
|
||||
def get_noop_manager(cls: type[BRM]) -> BRM:
|
||||
def get_noop_manager(cls) -> Self:
|
||||
"""Return a manager that doesn't perform any operations.
|
||||
|
||||
Returns:
|
||||
|
@ -1258,7 +1258,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
|
||||
],
|
||||
*,
|
||||
tool_choice: Optional[Union[str, Literal["any"]]] = None,
|
||||
tool_choice: Optional[Union[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tools to the model.
|
||||
|
@ -22,7 +22,7 @@ from pydantic import (
|
||||
SkipValidation,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import override
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.load import Serializable
|
||||
@ -304,12 +304,12 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
|
||||
@classmethod
|
||||
def from_template(
|
||||
cls: type[MessagePromptTemplateT],
|
||||
cls,
|
||||
template: str,
|
||||
template_format: PromptTemplateFormat = "f-string",
|
||||
partial_variables: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> MessagePromptTemplateT:
|
||||
) -> Self:
|
||||
"""Create a class from a string template.
|
||||
|
||||
Args:
|
||||
@ -335,11 +335,11 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
|
||||
@classmethod
|
||||
def from_template_file(
|
||||
cls: type[MessagePromptTemplateT],
|
||||
cls,
|
||||
template_file: Union[str, Path],
|
||||
input_variables: list[str],
|
||||
**kwargs: Any,
|
||||
) -> MessagePromptTemplateT:
|
||||
) -> Self:
|
||||
"""Create a class from a template file.
|
||||
|
||||
Args:
|
||||
@ -456,11 +456,6 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
)
|
||||
|
||||
|
||||
_StringImageMessagePromptTemplateT = TypeVar(
|
||||
"_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate"
|
||||
)
|
||||
|
||||
|
||||
class _TextTemplateParam(TypedDict, total=False):
|
||||
text: Union[str, dict]
|
||||
|
||||
@ -483,13 +478,13 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
|
||||
@classmethod
|
||||
def from_template(
|
||||
cls: type[_StringImageMessagePromptTemplateT],
|
||||
cls: type[Self],
|
||||
template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
|
||||
template_format: PromptTemplateFormat = "f-string",
|
||||
*,
|
||||
partial_variables: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> _StringImageMessagePromptTemplateT:
|
||||
) -> Self:
|
||||
"""Create a class from a string template.
|
||||
|
||||
Args:
|
||||
@ -576,11 +571,11 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
|
||||
@classmethod
|
||||
def from_template_file(
|
||||
cls: type[_StringImageMessagePromptTemplateT],
|
||||
cls: type[Self],
|
||||
template_file: Union[str, Path],
|
||||
input_variables: list[str],
|
||||
**kwargs: Any,
|
||||
) -> _StringImageMessagePromptTemplateT:
|
||||
) -> Self:
|
||||
"""Create a class from a template file.
|
||||
|
||||
Args:
|
||||
|
@ -4200,7 +4200,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
)
|
||||
|
||||
@override
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, RunnableGenerator):
|
||||
if hasattr(self, "_transform") and hasattr(other, "_transform"):
|
||||
return self._transform == other._transform
|
||||
@ -4582,7 +4582,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
return graph
|
||||
|
||||
@override
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, RunnableLambda):
|
||||
if hasattr(self, "func") and hasattr(other, "func"):
|
||||
return self.func == other.func
|
||||
@ -5880,22 +5880,24 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
|
||||
|
||||
|
||||
class _RunnableCallableSync(Protocol[Input, Output]):
|
||||
def __call__(self, __in: Input, *, config: RunnableConfig) -> Output: ...
|
||||
def __call__(self, _in: Input, /, *, config: RunnableConfig) -> Output: ...
|
||||
|
||||
|
||||
class _RunnableCallableAsync(Protocol[Input, Output]):
|
||||
def __call__(self, __in: Input, *, config: RunnableConfig) -> Awaitable[Output]: ...
|
||||
def __call__(
|
||||
self, _in: Input, /, *, config: RunnableConfig
|
||||
) -> Awaitable[Output]: ...
|
||||
|
||||
|
||||
class _RunnableCallableIterator(Protocol[Input, Output]):
|
||||
def __call__(
|
||||
self, __in: Iterator[Input], *, config: RunnableConfig
|
||||
self, _in: Iterator[Input], /, *, config: RunnableConfig
|
||||
) -> Iterator[Output]: ...
|
||||
|
||||
|
||||
class _RunnableCallableAsyncIterator(Protocol[Input, Output]):
|
||||
def __call__(
|
||||
self, __in: AsyncIterator[Input], *, config: RunnableConfig
|
||||
self, _in: AsyncIterator[Input], /, *, config: RunnableConfig
|
||||
) -> AsyncIterator[Output]: ...
|
||||
|
||||
|
||||
|
@ -515,7 +515,7 @@ _T_contra = TypeVar("_T_contra", contravariant=True)
|
||||
class SupportsAdd(Protocol[_T_contra, _T_co]):
|
||||
"""Protocol for objects that support addition."""
|
||||
|
||||
def __add__(self, __x: _T_contra) -> _T_co:
|
||||
def __add__(self, x: _T_contra, /) -> _T_co:
|
||||
"""Add the object to another object."""
|
||||
|
||||
|
||||
|
@ -88,7 +88,12 @@ class NoLock:
|
||||
async def __aenter__(self) -> None:
|
||||
"""Do nothing."""
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> bool:
|
||||
"""Exception not handled."""
|
||||
return False
|
||||
|
||||
@ -237,7 +242,12 @@ class Tee(Generic[T]):
|
||||
"""Return the tee instance."""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> bool:
|
||||
"""Close all child iterators."""
|
||||
await self.aclose()
|
||||
return False
|
||||
|
@ -4,6 +4,7 @@ from collections import deque
|
||||
from collections.abc import Generator, Iterable, Iterator
|
||||
from contextlib import AbstractContextManager
|
||||
from itertools import islice
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
@ -24,7 +25,12 @@ class NoLock:
|
||||
def __enter__(self) -> None:
|
||||
"""Do nothing."""
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
"""Exception not handled."""
|
||||
return False
|
||||
|
||||
@ -173,7 +179,12 @@ class Tee(Generic[T]):
|
||||
"""Return Tee instance."""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
"""Close all child iterators."""
|
||||
self.close()
|
||||
return False
|
||||
|
@ -377,12 +377,7 @@ if IS_PYDANTIC_V2:
|
||||
def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...
|
||||
|
||||
def get_fields(
|
||||
model: Union[
|
||||
BaseModelV2,
|
||||
BaseModelV1,
|
||||
type[BaseModelV2],
|
||||
type[BaseModelV1],
|
||||
],
|
||||
model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
|
||||
) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
|
||||
"""Get the field names of a Pydantic model."""
|
||||
if hasattr(model, "model_fields"):
|
||||
@ -491,19 +486,21 @@ def _create_root_model_cached(
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
def _create_model_cached(
|
||||
__model_name: str,
|
||||
model_name: str,
|
||||
/,
|
||||
**field_definitions: Any,
|
||||
) -> type[BaseModel]:
|
||||
return _create_model_base(
|
||||
__model_name,
|
||||
model_name,
|
||||
__config__=_SchemaConfig,
|
||||
**_remap_field_definitions(field_definitions),
|
||||
)
|
||||
|
||||
|
||||
def create_model(
|
||||
__model_name: str,
|
||||
__module_name: Optional[str] = None,
|
||||
model_name: str,
|
||||
module_name: Optional[str] = None,
|
||||
/,
|
||||
**field_definitions: Any,
|
||||
) -> type[BaseModel]:
|
||||
"""Create a pydantic model with the given field definitions.
|
||||
@ -511,8 +508,8 @@ def create_model(
|
||||
Please use create_model_v2 instead of this function.
|
||||
|
||||
Args:
|
||||
__model_name: The name of the model.
|
||||
__module_name: The name of the module where the model is defined.
|
||||
model_name: The name of the model.
|
||||
module_name: The name of the module where the model is defined.
|
||||
This is used by Pydantic to resolve any forward references.
|
||||
**field_definitions: The field definitions for the model.
|
||||
|
||||
@ -524,8 +521,8 @@ def create_model(
|
||||
kwargs["root"] = field_definitions.pop("__root__")
|
||||
|
||||
return create_model_v2(
|
||||
__model_name,
|
||||
module_name=__module_name,
|
||||
model_name,
|
||||
module_name=module_name,
|
||||
field_definitions=field_definitions,
|
||||
**kwargs,
|
||||
)
|
||||
|
@ -36,6 +36,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
|
||||
@ -818,11 +819,11 @@ class VectorStore(ABC):
|
||||
|
||||
@classmethod
|
||||
def from_documents(
|
||||
cls: type[VST],
|
||||
cls,
|
||||
documents: list[Document],
|
||||
embedding: Embeddings,
|
||||
**kwargs: Any,
|
||||
) -> VST:
|
||||
) -> Self:
|
||||
"""Return VectorStore initialized from documents and embeddings.
|
||||
|
||||
Args:
|
||||
@ -848,11 +849,11 @@ class VectorStore(ABC):
|
||||
|
||||
@classmethod
|
||||
async def afrom_documents(
|
||||
cls: type[VST],
|
||||
cls,
|
||||
documents: list[Document],
|
||||
embedding: Embeddings,
|
||||
**kwargs: Any,
|
||||
) -> VST:
|
||||
) -> Self:
|
||||
"""Async return VectorStore initialized from documents and embeddings.
|
||||
|
||||
Args:
|
||||
@ -903,14 +904,14 @@ class VectorStore(ABC):
|
||||
|
||||
@classmethod
|
||||
async def afrom_texts(
|
||||
cls: type[VST],
|
||||
cls,
|
||||
texts: list[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
*,
|
||||
ids: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> VST:
|
||||
) -> Self:
|
||||
"""Async return VectorStore initialized from texts and embeddings.
|
||||
|
||||
Args:
|
||||
|
@ -103,7 +103,6 @@ ignore = [
|
||||
"FBT002",
|
||||
"PGH003",
|
||||
"PLR",
|
||||
"PYI",
|
||||
"RUF",
|
||||
"SLF",
|
||||
]
|
||||
|
@ -5294,22 +5294,22 @@ async def test_ainvoke_on_returned_runnable() -> None:
|
||||
be runthroughaasync path (issue #13407).
|
||||
"""
|
||||
|
||||
def idchain_sync(__input: dict) -> bool:
|
||||
def idchain_sync(_input: dict, /) -> bool:
|
||||
return False
|
||||
|
||||
async def idchain_async(__input: dict) -> bool:
|
||||
async def idchain_async(_input: dict, /) -> bool:
|
||||
return True
|
||||
|
||||
idchain = RunnableLambda(func=idchain_sync, afunc=idchain_async)
|
||||
|
||||
def func(__input: dict) -> Runnable:
|
||||
def func(_input: dict, /) -> Runnable:
|
||||
return idchain
|
||||
|
||||
assert await RunnableLambda(func).ainvoke({})
|
||||
|
||||
|
||||
def test_invoke_stream_passthrough_assign_trace() -> None:
|
||||
def idchain_sync(__input: dict) -> bool:
|
||||
def idchain_sync(_input: dict, /) -> bool:
|
||||
return False
|
||||
|
||||
chain = RunnablePassthrough.assign(urls=idchain_sync)
|
||||
@ -5329,7 +5329,7 @@ def test_invoke_stream_passthrough_assign_trace() -> None:
|
||||
|
||||
|
||||
async def test_ainvoke_astream_passthrough_assign_trace() -> None:
|
||||
def idchain_sync(__input: dict) -> bool:
|
||||
def idchain_sync(_input: dict, /) -> bool:
|
||||
return False
|
||||
|
||||
chain = RunnablePassthrough.assign(urls=idchain_sync)
|
||||
|
@ -7,7 +7,7 @@ from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
|
||||
class AnyStr(str):
|
||||
__slots__ = ()
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, str)
|
||||
|
||||
|
||||
|
@ -2369,7 +2369,7 @@ def test_tool_return_output_mixin() -> None:
|
||||
def __init__(self, x: int) -> None:
|
||||
self.x = x
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, self.__class__) and self.x == other.x
|
||||
|
||||
@tool
|
||||
|
@ -994,12 +994,12 @@ def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None:
|
||||
)
|
||||
def test_convert_union_type_py_39() -> None:
|
||||
@tool
|
||||
def magic_function(input: int | float) -> str: # noqa: FA102
|
||||
def magic_function(input: int | str) -> str: # noqa: FA102
|
||||
"""Compute a magic function."""
|
||||
|
||||
result = convert_to_openai_function(magic_function)
|
||||
assert result["parameters"]["properties"]["input"] == {
|
||||
"anyOf": [{"type": "integer"}, {"type": "number"}]
|
||||
"anyOf": [{"type": "integer"}, {"type": "string"}]
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user