diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py index 34a9120c96c..f9a54094a66 100644 --- a/libs/core/langchain_core/callbacks/base.py +++ b/libs/core/langchain_core/callbacks/base.py @@ -3,7 +3,9 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Optional, Union + +from typing_extensions import Self if TYPE_CHECKING: from collections.abc import Sequence @@ -879,9 +881,6 @@ class AsyncCallbackHandler(BaseCallbackHandler): """ -T = TypeVar("T", bound="BaseCallbackManager") - - class BaseCallbackManager(CallbackManagerMixin): """Base callback manager for LangChain.""" @@ -920,7 +919,7 @@ class BaseCallbackManager(CallbackManagerMixin): self.metadata = metadata or {} self.inheritable_metadata = inheritable_metadata or {} - def copy(self: T) -> T: + def copy(self) -> Self: """Copy the callback manager.""" return self.__class__( handlers=self.handlers.copy(), @@ -932,7 +931,7 @@ class BaseCallbackManager(CallbackManagerMixin): inheritable_metadata=self.inheritable_metadata.copy(), ) - def merge(self: T, other: BaseCallbackManager) -> T: + def merge(self, other: BaseCallbackManager) -> Self: """Merge the callback manager with another callback manager. May be overwritten in subclasses. Primarily used internally diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 5b1db481953..12287c330e9 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -22,6 +22,7 @@ from typing import ( from uuid import UUID from langsmith.run_helpers import get_tracing_context +from typing_extensions import Self from langchain_core.callbacks.base import ( BaseCallbackHandler, @@ -444,9 +445,6 @@ async def ahandle_event( ) -BRM = TypeVar("BRM", bound="BaseRunManager") - - class BaseRunManager(RunManagerMixin): """Base class for run manager (a bound callback manager).""" @@ -489,7 +487,7 @@ class BaseRunManager(RunManagerMixin): self.inheritable_metadata = inheritable_metadata or {} @classmethod - def get_noop_manager(cls: type[BRM]) -> BRM: + def get_noop_manager(cls) -> Self: """Return a manager that doesn't perform any operations. Returns: diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 6a3474aec1f..d98c1ce4091 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -1258,7 +1258,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006 ], *, - tool_choice: Optional[Union[str, Literal["any"]]] = None, + tool_choice: Optional[Union[str]] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind tools to the model. diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 50de9c25adb..5f75df829f6 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -22,7 +22,7 @@ from pydantic import ( SkipValidation, model_validator, ) -from typing_extensions import override +from typing_extensions import Self, override from langchain_core._api import deprecated from langchain_core.load import Serializable @@ -304,12 +304,12 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): @classmethod def from_template( - cls: type[MessagePromptTemplateT], + cls, template: str, template_format: PromptTemplateFormat = "f-string", partial_variables: Optional[dict[str, Any]] = None, **kwargs: Any, - ) -> MessagePromptTemplateT: + ) -> Self: """Create a class from a string template. Args: @@ -335,11 +335,11 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): @classmethod def from_template_file( - cls: type[MessagePromptTemplateT], + cls, template_file: Union[str, Path], input_variables: list[str], **kwargs: Any, - ) -> MessagePromptTemplateT: + ) -> Self: """Create a class from a template file. Args: @@ -456,11 +456,6 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate): ) -_StringImageMessagePromptTemplateT = TypeVar( - "_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate" -) - - class _TextTemplateParam(TypedDict, total=False): text: Union[str, dict] @@ -483,13 +478,13 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): @classmethod def from_template( - cls: type[_StringImageMessagePromptTemplateT], + cls: type[Self], template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]], template_format: PromptTemplateFormat = "f-string", *, partial_variables: Optional[dict[str, Any]] = None, **kwargs: Any, - ) -> _StringImageMessagePromptTemplateT: + ) -> Self: """Create a class from a string template. Args: @@ -576,11 +571,11 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): @classmethod def from_template_file( - cls: type[_StringImageMessagePromptTemplateT], + cls: type[Self], template_file: Union[str, Path], input_variables: list[str], **kwargs: Any, - ) -> _StringImageMessagePromptTemplateT: + ) -> Self: """Create a class from a template file. Args: diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 3d5fef4e2f2..cb08448ea48 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -4200,7 +4200,7 @@ class RunnableGenerator(Runnable[Input, Output]): ) @override - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, RunnableGenerator): if hasattr(self, "_transform") and hasattr(other, "_transform"): return self._transform == other._transform @@ -4582,7 +4582,7 @@ class RunnableLambda(Runnable[Input, Output]): return graph @override - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, RunnableLambda): if hasattr(self, "func") and hasattr(other, "func"): return self.func == other.func @@ -5880,22 +5880,24 @@ class RunnableBinding(RunnableBindingBase[Input, Output]): class _RunnableCallableSync(Protocol[Input, Output]): - def __call__(self, __in: Input, *, config: RunnableConfig) -> Output: ... + def __call__(self, _in: Input, /, *, config: RunnableConfig) -> Output: ... class _RunnableCallableAsync(Protocol[Input, Output]): - def __call__(self, __in: Input, *, config: RunnableConfig) -> Awaitable[Output]: ... + def __call__( + self, _in: Input, /, *, config: RunnableConfig + ) -> Awaitable[Output]: ... class _RunnableCallableIterator(Protocol[Input, Output]): def __call__( - self, __in: Iterator[Input], *, config: RunnableConfig + self, _in: Iterator[Input], /, *, config: RunnableConfig ) -> Iterator[Output]: ... class _RunnableCallableAsyncIterator(Protocol[Input, Output]): def __call__( - self, __in: AsyncIterator[Input], *, config: RunnableConfig + self, _in: AsyncIterator[Input], /, *, config: RunnableConfig ) -> AsyncIterator[Output]: ... diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index 2a882e427ee..9b3530f00ab 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -515,7 +515,7 @@ _T_contra = TypeVar("_T_contra", contravariant=True) class SupportsAdd(Protocol[_T_contra, _T_co]): """Protocol for objects that support addition.""" - def __add__(self, __x: _T_contra) -> _T_co: + def __add__(self, x: _T_contra, /) -> _T_co: """Add the object to another object.""" diff --git a/libs/core/langchain_core/utils/aiter.py b/libs/core/langchain_core/utils/aiter.py index 9dd658681ed..8931d2074e2 100644 --- a/libs/core/langchain_core/utils/aiter.py +++ b/libs/core/langchain_core/utils/aiter.py @@ -88,7 +88,12 @@ class NoLock: async def __aenter__(self) -> None: """Do nothing.""" - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: """Exception not handled.""" return False @@ -237,7 +242,12 @@ class Tee(Generic[T]): """Return the tee instance.""" return self - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: """Close all child iterators.""" await self.aclose() return False diff --git a/libs/core/langchain_core/utils/iter.py b/libs/core/langchain_core/utils/iter.py index 445280dff33..15c203a5332 100644 --- a/libs/core/langchain_core/utils/iter.py +++ b/libs/core/langchain_core/utils/iter.py @@ -4,6 +4,7 @@ from collections import deque from collections.abc import Generator, Iterable, Iterator from contextlib import AbstractContextManager from itertools import islice +from types import TracebackType from typing import ( Any, Generic, @@ -24,7 +25,12 @@ class NoLock: def __enter__(self) -> None: """Do nothing.""" - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]: + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Literal[False]: """Exception not handled.""" return False @@ -173,7 +179,12 @@ class Tee(Generic[T]): """Return Tee instance.""" return self - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]: + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Literal[False]: """Close all child iterators.""" self.close() return False diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index 9a8f9fd832b..fcea09490bf 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -377,12 +377,7 @@ if IS_PYDANTIC_V2: def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ... def get_fields( - model: Union[ - BaseModelV2, - BaseModelV1, - type[BaseModelV2], - type[BaseModelV1], - ], + model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1], ) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]: """Get the field names of a Pydantic model.""" if hasattr(model, "model_fields"): @@ -491,19 +486,21 @@ def _create_root_model_cached( @lru_cache(maxsize=256) def _create_model_cached( - __model_name: str, + model_name: str, + /, **field_definitions: Any, ) -> type[BaseModel]: return _create_model_base( - __model_name, + model_name, __config__=_SchemaConfig, **_remap_field_definitions(field_definitions), ) def create_model( - __model_name: str, - __module_name: Optional[str] = None, + model_name: str, + module_name: Optional[str] = None, + /, **field_definitions: Any, ) -> type[BaseModel]: """Create a pydantic model with the given field definitions. @@ -511,8 +508,8 @@ def create_model( Please use create_model_v2 instead of this function. Args: - __model_name: The name of the model. - __module_name: The name of the module where the model is defined. + model_name: The name of the model. + module_name: The name of the module where the model is defined. This is used by Pydantic to resolve any forward references. **field_definitions: The field definitions for the model. @@ -524,8 +521,8 @@ def create_model( kwargs["root"] = field_definitions.pop("__root__") return create_model_v2( - __model_name, - module_name=__module_name, + model_name, + module_name=module_name, field_definitions=field_definitions, **kwargs, ) diff --git a/libs/core/langchain_core/vectorstores/base.py b/libs/core/langchain_core/vectorstores/base.py index 8368a091e08..2c9615b5945 100644 --- a/libs/core/langchain_core/vectorstores/base.py +++ b/libs/core/langchain_core/vectorstores/base.py @@ -36,6 +36,7 @@ from typing import ( ) from pydantic import ConfigDict, Field, model_validator +from typing_extensions import Self from langchain_core.embeddings import Embeddings from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams @@ -818,11 +819,11 @@ class VectorStore(ABC): @classmethod def from_documents( - cls: type[VST], + cls, documents: list[Document], embedding: Embeddings, **kwargs: Any, - ) -> VST: + ) -> Self: """Return VectorStore initialized from documents and embeddings. Args: @@ -848,11 +849,11 @@ class VectorStore(ABC): @classmethod async def afrom_documents( - cls: type[VST], + cls, documents: list[Document], embedding: Embeddings, **kwargs: Any, - ) -> VST: + ) -> Self: """Async return VectorStore initialized from documents and embeddings. Args: @@ -903,14 +904,14 @@ class VectorStore(ABC): @classmethod async def afrom_texts( - cls: type[VST], + cls, texts: list[str], embedding: Embeddings, metadatas: Optional[list[dict]] = None, *, ids: Optional[list[str]] = None, **kwargs: Any, - ) -> VST: + ) -> Self: """Async return VectorStore initialized from texts and embeddings. Args: diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index a7aa0452d57..14fe60c712b 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -103,7 +103,6 @@ ignore = [ "FBT002", "PGH003", "PLR", - "PYI", "RUF", "SLF", ] diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 72b6c6461eb..4facc2979ab 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -5294,22 +5294,22 @@ async def test_ainvoke_on_returned_runnable() -> None: be runthroughaasync path (issue #13407). """ - def idchain_sync(__input: dict) -> bool: + def idchain_sync(_input: dict, /) -> bool: return False - async def idchain_async(__input: dict) -> bool: + async def idchain_async(_input: dict, /) -> bool: return True idchain = RunnableLambda(func=idchain_sync, afunc=idchain_async) - def func(__input: dict) -> Runnable: + def func(_input: dict, /) -> Runnable: return idchain assert await RunnableLambda(func).ainvoke({}) def test_invoke_stream_passthrough_assign_trace() -> None: - def idchain_sync(__input: dict) -> bool: + def idchain_sync(_input: dict, /) -> bool: return False chain = RunnablePassthrough.assign(urls=idchain_sync) @@ -5329,7 +5329,7 @@ def test_invoke_stream_passthrough_assign_trace() -> None: async def test_ainvoke_astream_passthrough_assign_trace() -> None: - def idchain_sync(__input: dict) -> bool: + def idchain_sync(_input: dict, /) -> bool: return False chain = RunnablePassthrough.assign(urls=idchain_sync) diff --git a/libs/core/tests/unit_tests/stubs.py b/libs/core/tests/unit_tests/stubs.py index 8fdeccfc7d8..5226b6aca28 100644 --- a/libs/core/tests/unit_tests/stubs.py +++ b/libs/core/tests/unit_tests/stubs.py @@ -7,7 +7,7 @@ from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage class AnyStr(str): __slots__ = () - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, str) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index bc5e849c157..64368de0dda 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -2369,7 +2369,7 @@ def test_tool_return_output_mixin() -> None: def __init__(self, x: int) -> None: self.x = x - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, self.__class__) and self.x == other.x @tool diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index 6e32c2d0b00..5562ee5da19 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -994,12 +994,12 @@ def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None: ) def test_convert_union_type_py_39() -> None: @tool - def magic_function(input: int | float) -> str: # noqa: FA102 + def magic_function(input: int | str) -> str: # noqa: FA102 """Compute a magic function.""" result = convert_to_openai_function(magic_function) assert result["parameters"]["properties"]["input"] == { - "anyOf": [{"type": "integer"}, {"type": "number"}] + "anyOf": [{"type": "integer"}, {"type": "string"}] }