diff --git a/libs/core/langchain_core/_api/beta_decorator.py b/libs/core/langchain_core/_api/beta_decorator.py index 1c421b5b01b..46514760549 100644 --- a/libs/core/langchain_core/_api/beta_decorator.py +++ b/libs/core/langchain_core/_api/beta_decorator.py @@ -155,7 +155,7 @@ def beta( _name = _name or obj.fget.__qualname__ old_doc = obj.__doc__ - class _beta_property(property): + class _BetaProperty(property): """A beta property.""" def __init__(self, fget=None, fset=None, fdel=None, doc=None): @@ -186,7 +186,7 @@ def beta( def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any: """Finalize the property.""" - return _beta_property( + return _BetaProperty( fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc ) diff --git a/libs/core/langchain_core/_api/deprecation.py b/libs/core/langchain_core/_api/deprecation.py index 33d59e819e5..5153e8acdac 100644 --- a/libs/core/langchain_core/_api/deprecation.py +++ b/libs/core/langchain_core/_api/deprecation.py @@ -264,7 +264,7 @@ def deprecated( _name = _name or cast(Union[type, Callable], obj.fget).__qualname__ old_doc = obj.__doc__ - class _deprecated_property(property): + class _DeprecatedProperty(property): """A deprecated property.""" def __init__(self, fget=None, fset=None, fdel=None, doc=None): # type: ignore[no-untyped-def] @@ -297,7 +297,7 @@ def deprecated( """Finalize the property.""" return cast( T, - _deprecated_property( + _DeprecatedProperty( fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc ), ) diff --git a/libs/core/langchain_core/exceptions.py b/libs/core/langchain_core/exceptions.py index 3c3f2785112..7c60ccfa4db 100644 --- a/libs/core/langchain_core/exceptions.py +++ b/libs/core/langchain_core/exceptions.py @@ -3,7 +3,7 @@ from typing import Any, Optional -class LangChainException(Exception): +class LangChainException(Exception): # noqa: N818 """General LangChain exception.""" @@ -11,7 +11,7 @@ class TracerException(LangChainException): """Base class for exceptions in tracers module.""" -class OutputParserException(ValueError, LangChainException): +class OutputParserException(ValueError, LangChainException): # noqa: N818 """Exception that output parsers should raise to signify a parsing error. This exists to differentiate parsing errors from other code or execution errors diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index dc1106205f3..44ccd90e550 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -14,7 +14,7 @@ from typing import ( ) from pydantic import BaseModel, ConfigDict, Field, field_validator -from typing_extensions import TypeAlias, TypedDict +from typing_extensions import TypeAlias, TypedDict, override from langchain_core._api import deprecated from langchain_core.messages import ( @@ -143,6 +143,7 @@ class BaseLanguageModel( return verbose @property + @override def InputType(self) -> TypeAlias: """Get the input type for this runnable.""" from langchain_core.prompt_values import ( diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 9553cb61225..f322383065f 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -26,6 +26,7 @@ from pydantic import ( Field, model_validator, ) +from typing_extensions import override from langchain_core._api import deprecated from langchain_core.caches import BaseCache @@ -251,6 +252,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): # --- Runnable methods --- @property + @override def OutputType(self) -> Any: """Get the output type for this runnable.""" return AnyMessage diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index 7458efd3120..7eeb679b7a2 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -31,6 +31,7 @@ from tenacity import ( stop_after_attempt, wait_exponential, ) +from typing_extensions import override from langchain_core._api import deprecated from langchain_core.caches import BaseCache @@ -318,6 +319,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): # --- Runnable methods --- @property + @override def OutputType(self) -> type[str]: """Get the input type for this runnable.""" return str diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py index a6d539ecc6b..16464806dd3 100644 --- a/libs/core/langchain_core/output_parsers/base.py +++ b/libs/core/langchain_core/output_parsers/base.py @@ -10,6 +10,8 @@ from typing import ( Union, ) +from typing_extensions import override + from langchain_core.language_models import LanguageModelOutput from langchain_core.messages import AnyMessage, BaseMessage from langchain_core.outputs import ChatGeneration, Generation @@ -63,11 +65,13 @@ class BaseGenerationOutputParser( """Base class to parse the output of an LLM call.""" @property + @override def InputType(self) -> Any: """Return the input type for the parser.""" return Union[str, AnyMessage] @property + @override def OutputType(self) -> type[T]: """Return the output type for the parser.""" # even though mypy complains this isn't valid, @@ -148,11 +152,13 @@ class BaseOutputParser( """ # noqa: E501 @property + @override def InputType(self) -> Any: """Return the input type for the parser.""" return Union[str, AnyMessage] @property + @override def OutputType(self) -> type[T]: """Return the output type for the parser. diff --git a/libs/core/langchain_core/output_parsers/pydantic.py b/libs/core/langchain_core/output_parsers/pydantic.py index 081dc47094e..23c0dcf90a8 100644 --- a/libs/core/langchain_core/output_parsers/pydantic.py +++ b/libs/core/langchain_core/output_parsers/pydantic.py @@ -3,6 +3,7 @@ from typing import Annotated, Generic, Optional import pydantic from pydantic import SkipValidation +from typing_extensions import override from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import JsonOutputParser @@ -107,6 +108,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]): return "pydantic" @property + @override def OutputType(self) -> type[TBaseModel]: """Return the pydantic model.""" return self.pydantic_object diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index d7249f8c0db..96ca870cab6 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -1,6 +1,6 @@ import re import xml -import xml.etree.ElementTree as ET +import xml.etree.ElementTree as ET # noqa: N817 from collections.abc import AsyncIterator, Iterator from typing import Any, Literal, Optional, Union from xml.etree.ElementTree import TreeBuilder @@ -46,14 +46,14 @@ class _StreamingParser: """ if parser == "defusedxml": try: - from defusedxml import ElementTree as DET # type: ignore + import defusedxml # type: ignore except ImportError as e: raise ImportError( "defusedxml is not installed. " "Please install it to use the defusedxml parser." "You can install it with `pip install defusedxml` " ) from e - _parser = DET.DefusedXMLParser(target=TreeBuilder()) + _parser = defusedxml.ElementTree.DefusedXMLParser(target=TreeBuilder()) else: _parser = None self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser) @@ -189,7 +189,7 @@ class XMLOutputParser(BaseTransformOutputParser): # likely if you're reading this you can move them to the top of the file if self.parser == "defusedxml": try: - from defusedxml import ElementTree as DET # type: ignore + import defusedxml # type: ignore except ImportError as e: raise ImportError( "defusedxml is not installed. " @@ -197,9 +197,9 @@ class XMLOutputParser(BaseTransformOutputParser): "You can install it with `pip install defusedxml`" "See https://github.com/tiran/defusedxml for more details" ) from e - _ET = DET # Use the defusedxml parser + _et = defusedxml.ElementTree # Use the defusedxml parser else: - _ET = ET # Use the standard library parser + _et = ET # Use the standard library parser match = re.search(r"```(xml)?(.*)```", text, re.DOTALL) if match is not None: diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index e1c43157d59..f98d87d780c 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -18,7 +18,7 @@ from typing import ( import yaml from pydantic import BaseModel, ConfigDict, Field, model_validator -from typing_extensions import Self +from typing_extensions import Self, override from langchain_core.load import dumpd from langchain_core.output_parsers.base import BaseOutputParser @@ -107,6 +107,7 @@ class BasePromptTemplate( return dumpd(self) @property + @override def OutputType(self) -> Any: """Return the output type of the prompt.""" return Union[StringPromptValue, ChatPromptValueConcrete] diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 1960979317c..3c7262fb949 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -36,7 +36,7 @@ from typing import ( ) from pydantic import BaseModel, ConfigDict, Field, RootModel -from typing_extensions import Literal, get_args +from typing_extensions import Literal, get_args, override from langchain_core._api import beta_decorator from langchain_core.load.serializable import ( @@ -272,7 +272,7 @@ class Runnable(Generic[Input, Output], ABC): return name_ @property - def InputType(self) -> type[Input]: + def InputType(self) -> type[Input]: # noqa: N802 """The type of input this Runnable accepts specified as a type annotation.""" # First loop through all parent classes and if any of them is # a pydantic model, we will pick up the generic parameterization @@ -297,7 +297,7 @@ class Runnable(Generic[Input, Output], ABC): ) @property - def OutputType(self) -> type[Output]: + def OutputType(self) -> type[Output]: # noqa: N802 """The type of output this Runnable produces specified as a type annotation.""" # First loop through bases -- this will help generic # any pydantic models. @@ -2811,11 +2811,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]): ) @property + @override def InputType(self) -> type[Input]: """The type of the input to the Runnable.""" return self.first.InputType @property + @override def OutputType(self) -> type[Output]: """The type of the output of the Runnable.""" return self.last.OutputType @@ -3564,6 +3566,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): return super().get_name(suffix, name=name) @property + @override def InputType(self) -> Any: """The type of the input to the Runnable.""" for step in self.steps__.values(): @@ -4057,6 +4060,7 @@ class RunnableGenerator(Runnable[Input, Output]): self.name = "RunnableGenerator" @property + @override def InputType(self) -> Any: func = getattr(self, "_transform", None) or self._atransform try: @@ -4097,6 +4101,7 @@ class RunnableGenerator(Runnable[Input, Output]): ) @property + @override def OutputType(self) -> Any: func = getattr(self, "_transform", None) or self._atransform try: @@ -4346,6 +4351,7 @@ class RunnableLambda(Runnable[Input, Output]): pass @property + @override def InputType(self) -> Any: """The type of the input to this Runnable.""" func = getattr(self, "func", None) or self.afunc @@ -4405,6 +4411,7 @@ class RunnableLambda(Runnable[Input, Output]): return super().get_input_schema(config) @property + @override def OutputType(self) -> Any: """The type of the output of this Runnable as a type annotation. @@ -4958,6 +4965,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]): ) @property + @override def InputType(self) -> Any: return list[self.bound.InputType] # type: ignore[name-defined] @@ -4981,6 +4989,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]): ) @property + @override def OutputType(self) -> type[list[Output]]: return list[self.bound.OutputType] # type: ignore[name-defined] @@ -5274,6 +5283,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): return self.bound.get_name(suffix, name=name) @property + @override def InputType(self) -> type[Input]: return ( cast(type[Input], self.custom_input_type) @@ -5282,6 +5292,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): ) @property + @override def OutputType(self) -> type[Output]: return ( cast(type[Output], self.custom_output_type) diff --git a/libs/core/langchain_core/runnables/configurable.py b/libs/core/langchain_core/runnables/configurable.py index 373431c1bdc..bfd00c1de7f 100644 --- a/libs/core/langchain_core/runnables/configurable.py +++ b/libs/core/langchain_core/runnables/configurable.py @@ -16,6 +16,7 @@ from typing import ( from weakref import WeakValueDictionary from pydantic import BaseModel, ConfigDict +from typing_extensions import override from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.config import ( @@ -68,10 +69,12 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): return ["langchain", "schema", "runnable"] @property + @override def InputType(self) -> type[Input]: return self.default.InputType @property + @override def OutputType(self) -> type[Output]: return self.default.OutputType diff --git a/libs/core/langchain_core/runnables/fallbacks.py b/libs/core/langchain_core/runnables/fallbacks.py index c97ca9451bc..1d8e19bb034 100644 --- a/libs/core/langchain_core/runnables/fallbacks.py +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -13,6 +13,7 @@ from typing import ( ) from pydantic import BaseModel, ConfigDict +from typing_extensions import override from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.config import ( @@ -106,10 +107,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): ) @property + @override def InputType(self) -> type[Input]: return self.runnable.InputType @property + @override def OutputType(self) -> type[Output]: return self.runnable.OutputType diff --git a/libs/core/langchain_core/runnables/graph_ascii.py b/libs/core/langchain_core/runnables/graph_ascii.py index 5aa05c7ec20..91a647f20e2 100644 --- a/libs/core/langchain_core/runnables/graph_ascii.py +++ b/libs/core/langchain_core/runnables/graph_ascii.py @@ -246,27 +246,27 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str: # NOTE: coordinates might me negative, so we need to shift # everything to the positive plane before we actually draw it. - Xs = [] - Ys = [] + xlist = [] + ylist = [] sug = _build_sugiyama_layout(vertices, edges) for vertex in sug.g.sV: # NOTE: moving boxes w/2 to the left - Xs.append(vertex.view.xy[0] - vertex.view.w / 2.0) - Xs.append(vertex.view.xy[0] + vertex.view.w / 2.0) - Ys.append(vertex.view.xy[1]) - Ys.append(vertex.view.xy[1] + vertex.view.h) + xlist.append(vertex.view.xy[0] - vertex.view.w / 2.0) + xlist.append(vertex.view.xy[0] + vertex.view.w / 2.0) + ylist.append(vertex.view.xy[1]) + ylist.append(vertex.view.xy[1] + vertex.view.h) for edge in sug.g.sE: for x, y in edge.view._pts: - Xs.append(x) - Ys.append(y) + xlist.append(x) + ylist.append(y) - minx = min(Xs) - miny = min(Ys) - maxx = max(Xs) - maxy = max(Ys) + minx = min(xlist) + miny = min(ylist) + maxx = max(xlist) + maxy = max(ylist) canvas_cols = int(math.ceil(math.ceil(maxx) - math.floor(minx))) + 1 canvas_lines = int(round(maxy - miny)) diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 6b60c3b3d12..71135513c5a 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -12,6 +12,7 @@ from typing import ( ) from pydantic import BaseModel +from typing_extensions import override from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.load.load import load @@ -396,6 +397,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): ) @property + @override def OutputType(self) -> type[Output]: output_type = self._history_chain.OutputType return output_type diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index 68ca7e1c93d..3fc82ec9feb 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -16,6 +16,7 @@ from typing import ( ) from pydantic import BaseModel, RootModel +from typing_extensions import override from langchain_core.runnables.base import ( Other, @@ -193,10 +194,12 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): return ["langchain", "schema", "runnable"] @property + @override def InputType(self) -> Any: return self.input_type or Any @property + @override def OutputType(self) -> Any: return self.input_type or Any diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index 09a6877fc75..e464c971b71 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -28,7 +28,7 @@ from typing import ( Union, ) -from typing_extensions import TypeGuard +from typing_extensions import TypeGuard, override from langchain_core.runnables.schema import StreamEvent @@ -135,6 +135,7 @@ class IsLocalDict(ast.NodeVisitor): self.name = name self.keys = keys + @override def visit_Subscript(self, node: ast.Subscript) -> Any: """Visit a subscript node. @@ -154,6 +155,7 @@ class IsLocalDict(ast.NodeVisitor): # we've found a subscript access on the name we're looking for self.keys.add(node.slice.value) + @override def visit_Call(self, node: ast.Call) -> Any: """Visit a call node. @@ -182,6 +184,7 @@ class IsFunctionArgDict(ast.NodeVisitor): def __init__(self) -> None: self.keys: set[str] = set() + @override def visit_Lambda(self, node: ast.Lambda) -> Any: """Visit a lambda function. @@ -196,6 +199,7 @@ class IsFunctionArgDict(ast.NodeVisitor): input_arg_name = node.args.args[0].arg IsLocalDict(input_arg_name, self.keys).visit(node.body) + @override def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: """Visit a function definition. @@ -210,6 +214,7 @@ class IsFunctionArgDict(ast.NodeVisitor): input_arg_name = node.args.args[0].arg IsLocalDict(input_arg_name, self.keys).visit(node) + @override def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: """Visit an async function definition. @@ -232,6 +237,7 @@ class NonLocals(ast.NodeVisitor): self.loads: set[str] = set() self.stores: set[str] = set() + @override def visit_Name(self, node: ast.Name) -> Any: """Visit a name node. @@ -246,6 +252,7 @@ class NonLocals(ast.NodeVisitor): elif isinstance(node.ctx, ast.Store): self.stores.add(node.id) + @override def visit_Attribute(self, node: ast.Attribute) -> Any: """Visit an attribute node. @@ -272,6 +279,7 @@ class FunctionNonLocals(ast.NodeVisitor): def __init__(self) -> None: self.nonlocals: set[str] = set() + @override def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: """Visit a function definition. @@ -285,6 +293,7 @@ class FunctionNonLocals(ast.NodeVisitor): visitor.visit(node) self.nonlocals.update(visitor.loads - visitor.stores) + @override def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: """Visit an async function definition. @@ -298,6 +307,7 @@ class FunctionNonLocals(ast.NodeVisitor): visitor.visit(node) self.nonlocals.update(visitor.loads - visitor.stores) + @override def visit_Lambda(self, node: ast.Lambda) -> Any: """Visit a lambda function. @@ -320,6 +330,7 @@ class GetLambdaSource(ast.NodeVisitor): self.source: Optional[str] = None self.count = 0 + @override def visit_Lambda(self, node: ast.Lambda) -> Any: """Visit a lambda function. diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index a3fe0aba534..b540767ab65 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -311,7 +311,7 @@ def create_schema_from_function( ) -class ToolException(Exception): +class ToolException(Exception): # noqa: N818 """Optional exception that tool throws when execution error occurs. When this exception is thrown, the agent will not stop working, diff --git a/libs/core/langchain_core/tracers/langchain_v1.py b/libs/core/langchain_core/tracers/langchain_v1.py index bf1237d66ab..ea1c882ea67 100644 --- a/libs/core/langchain_core/tracers/langchain_v1.py +++ b/libs/core/langchain_core/tracers/langchain_v1.py @@ -9,7 +9,7 @@ def get_headers(*args: Any, **kwargs: Any) -> Any: ) -def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: +def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: # noqa: N802 """Throw an error because this has been replaced by LangChainTracer.""" raise RuntimeError( "LangChainTracerV1 is no longer supported. Please use LangChainTracer instead." diff --git a/libs/core/langchain_core/tracers/schemas.py b/libs/core/langchain_core/tracers/schemas.py index 72ae7d14b17..c9f493d3e2a 100644 --- a/libs/core/langchain_core/tracers/schemas.py +++ b/libs/core/langchain_core/tracers/schemas.py @@ -18,7 +18,7 @@ from langchain_core._api import deprecated @deprecated("0.1.0", alternative="Use string instead.", removal="1.0") -def RunTypeEnum() -> type[RunTypeEnumDep]: +def RunTypeEnum() -> type[RunTypeEnumDep]: # noqa: N802 """RunTypeEnum.""" warnings.warn( "RunTypeEnum is deprecated. Please directly use a string instead" diff --git a/libs/core/langchain_core/utils/aiter.py b/libs/core/langchain_core/utils/aiter.py index 9c42fc70e38..135b116a419 100644 --- a/libs/core/langchain_core/utils/aiter.py +++ b/libs/core/langchain_core/utils/aiter.py @@ -235,7 +235,7 @@ class Tee(Generic[T]): atee = Tee -class aclosing(AbstractAsyncContextManager): +class aclosing(AbstractAsyncContextManager): # noqa: N801 """Async context manager for safely finalizing an asynchronously cleaned-up resource such as an async generator, calling its ``aclose()`` method. diff --git a/libs/core/langchain_core/vectorstores/utils.py b/libs/core/langchain_core/vectorstores/utils.py index c744d2ef621..777bb68b68d 100644 --- a/libs/core/langchain_core/vectorstores/utils.py +++ b/libs/core/langchain_core/vectorstores/utils.py @@ -17,12 +17,12 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: +def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray: """Row-wise cosine similarity between two equal-width matrices. Args: - X: A matrix of shape (n, m). - Y: A matrix of shape (k, m). + x: A matrix of shape (n, m). + y: A matrix of shape (k, m). Returns: A matrix of shape (n, k) where each element (i, j) is the cosine similarity @@ -40,33 +40,33 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: "Please install numpy with `pip install numpy`." ) from e - if len(X) == 0 or len(Y) == 0: + if len(x) == 0 or len(y) == 0: return np.array([]) - X = np.array(X) - Y = np.array(Y) - if X.shape[1] != Y.shape[1]: + x = np.array(x) + y = np.array(y) + if x.shape[1] != y.shape[1]: raise ValueError( - f"Number of columns in X and Y must be the same. X has shape {X.shape} " - f"and Y has shape {Y.shape}." + f"Number of columns in X and Y must be the same. X has shape {x.shape} " + f"and Y has shape {y.shape}." ) try: import simsimd as simd # type: ignore[import-not-found] - X = np.array(X, dtype=np.float32) - Y = np.array(Y, dtype=np.float32) - Z = 1 - np.array(simd.cdist(X, Y, metric="cosine")) - return Z + x = np.array(x, dtype=np.float32) + y = np.array(y, dtype=np.float32) + z = 1 - np.array(simd.cdist(x, y, metric="cosine")) + return z except ImportError: logger.debug( "Unable to import simsimd, defaulting to NumPy implementation. If you want " "to use simsimd please install with `pip install simsimd`." ) - X_norm = np.linalg.norm(X, axis=1) - Y_norm = np.linalg.norm(Y, axis=1) + x_norm = np.linalg.norm(x, axis=1) + y_norm = np.linalg.norm(y, axis=1) # Ignore divide by zero errors run time warnings as those are handled below. with np.errstate(divide="ignore", invalid="ignore"): - similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) + similarity = np.dot(x, y.T) / np.outer(x_norm, y_norm) similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 return similarity diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 9647a19569a..bdece6d32d1 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -44,9 +44,17 @@ python = ">=3.12.4" target-version = "py39" [tool.ruff.lint] -select = ["B", "E", "F", "I", "T201", "UP"] +select = ["B", "E", "F", "I", "N", "T201", "UP"] ignore = ["UP007"] +[tool.ruff.lint.pep8-naming] +classmethod-decorators = [ + "classmethod", + "langchain_core.utils.pydantic.pre_init", + "pydantic.field_validator", + "pydantic.v1.root_validator", +] + [tool.coverage.run] omit = [ "tests/*",] diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py index 4e1e836f0b9..2e92cf2f18f 100644 --- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py +++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py @@ -9,9 +9,9 @@ from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatM from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from tests.unit_tests.stubs import ( - _AnyIdAIMessage, - _AnyIdAIMessageChunk, - _AnyIdHumanMessage, + _any_id_ai_message, + _any_id_ai_message_chunk, + _any_id_human_message, ) @@ -20,11 +20,11 @@ def test_generic_fake_chat_model_invoke() -> None: infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")]) model = GenericFakeChatModel(messages=infinite_cycle) response = model.invoke("meow") - assert response == _AnyIdAIMessage(content="hello") + assert response == _any_id_ai_message(content="hello") response = model.invoke("kitty") - assert response == _AnyIdAIMessage(content="goodbye") + assert response == _any_id_ai_message(content="goodbye") response = model.invoke("meow") - assert response == _AnyIdAIMessage(content="hello") + assert response == _any_id_ai_message(content="hello") async def test_generic_fake_chat_model_ainvoke() -> None: @@ -32,11 +32,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None: infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")]) model = GenericFakeChatModel(messages=infinite_cycle) response = await model.ainvoke("meow") - assert response == _AnyIdAIMessage(content="hello") + assert response == _any_id_ai_message(content="hello") response = await model.ainvoke("kitty") - assert response == _AnyIdAIMessage(content="goodbye") + assert response == _any_id_ai_message(content="goodbye") response = await model.ainvoke("meow") - assert response == _AnyIdAIMessage(content="hello") + assert response == _any_id_ai_message(content="hello") async def test_generic_fake_chat_model_stream() -> None: @@ -49,17 +49,17 @@ async def test_generic_fake_chat_model_stream() -> None: model = GenericFakeChatModel(messages=infinite_cycle) chunks = [chunk async for chunk in model.astream("meow")] assert chunks == [ - _AnyIdAIMessageChunk(content="hello"), - _AnyIdAIMessageChunk(content=" "), - _AnyIdAIMessageChunk(content="goodbye"), + _any_id_ai_message_chunk(content="hello"), + _any_id_ai_message_chunk(content=" "), + _any_id_ai_message_chunk(content="goodbye"), ] assert len({chunk.id for chunk in chunks}) == 1 chunks = [chunk for chunk in model.stream("meow")] assert chunks == [ - _AnyIdAIMessageChunk(content="hello"), - _AnyIdAIMessageChunk(content=" "), - _AnyIdAIMessageChunk(content="goodbye"), + _any_id_ai_message_chunk(content="hello"), + _any_id_ai_message_chunk(content=" "), + _any_id_ai_message_chunk(content="goodbye"), ] assert len({chunk.id for chunk in chunks}) == 1 @@ -69,8 +69,8 @@ async def test_generic_fake_chat_model_stream() -> None: model = GenericFakeChatModel(messages=cycle([message])) chunks = [chunk async for chunk in model.astream("meow")] assert chunks == [ - _AnyIdAIMessageChunk(content="", additional_kwargs={"foo": 42}), - _AnyIdAIMessageChunk(content="", additional_kwargs={"bar": 24}), + _any_id_ai_message_chunk(content="", additional_kwargs={"foo": 42}), + _any_id_ai_message_chunk(content="", additional_kwargs={"bar": 24}), ] assert len({chunk.id for chunk in chunks}) == 1 @@ -88,19 +88,19 @@ async def test_generic_fake_chat_model_stream() -> None: chunks = [chunk async for chunk in model.astream("meow")] assert chunks == [ - _AnyIdAIMessageChunk( + _any_id_ai_message_chunk( content="", additional_kwargs={"function_call": {"name": "move_file"}} ), - _AnyIdAIMessageChunk( + _any_id_ai_message_chunk( content="", additional_kwargs={ "function_call": {"arguments": '{\n "source_path": "foo"'}, }, ), - _AnyIdAIMessageChunk( + _any_id_ai_message_chunk( content="", additional_kwargs={"function_call": {"arguments": ","}} ), - _AnyIdAIMessageChunk( + _any_id_ai_message_chunk( content="", additional_kwargs={ "function_call": {"arguments": '\n "destination_path": "bar"\n}'}, @@ -138,9 +138,9 @@ async def test_generic_fake_chat_model_astream_log() -> None: ] final = log_patches[-1] assert final.state["streamed_output"] == [ - _AnyIdAIMessageChunk(content="hello"), - _AnyIdAIMessageChunk(content=" "), - _AnyIdAIMessageChunk(content="goodbye"), + _any_id_ai_message_chunk(content="hello"), + _any_id_ai_message_chunk(content=" "), + _any_id_ai_message_chunk(content="goodbye"), ] assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1 @@ -189,9 +189,9 @@ async def test_callback_handlers() -> None: # New model results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]})) assert results == [ - _AnyIdAIMessageChunk(content="hello"), - _AnyIdAIMessageChunk(content=" "), - _AnyIdAIMessageChunk(content="goodbye"), + _any_id_ai_message_chunk(content="hello"), + _any_id_ai_message_chunk(content=" "), + _any_id_ai_message_chunk(content="goodbye"), ] assert tokens == ["hello", " ", "goodbye"] assert len({chunk.id for chunk in results}) == 1 @@ -200,6 +200,8 @@ async def test_callback_handlers() -> None: def test_chat_model_inputs() -> None: fake = ParrotFakeChatModel() - assert fake.invoke("hello") == _AnyIdHumanMessage(content="hello") - assert fake.invoke([("ai", "blah")]) == _AnyIdAIMessage(content="blah") - assert fake.invoke([AIMessage(content="blah")]) == _AnyIdAIMessage(content="blah") + assert fake.invoke("hello") == _any_id_human_message(content="hello") + assert fake.invoke([("ai", "blah")]) == _any_id_ai_message(content="blah") + assert fake.invoke([AIMessage(content="blah")]) == _any_id_ai_message( + content="blah" + ) diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index 499711d0c9f..f8aab17f930 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -27,7 +27,7 @@ from tests.unit_tests.fake.callbacks import ( FakeAsyncCallbackHandler, FakeCallbackHandler, ) -from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk +from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk @pytest.fixture @@ -147,10 +147,10 @@ async def test_astream_fallback_to_ainvoke() -> None: model = ModelWithGenerate() chunks = [chunk for chunk in model.stream("anything")] - assert chunks == [_AnyIdAIMessage(content="hello")] + assert chunks == [_any_id_ai_message(content="hello")] chunks = [chunk async for chunk in model.astream("anything")] - assert chunks == [_AnyIdAIMessage(content="hello")] + assert chunks == [_any_id_ai_message(content="hello")] async def test_astream_implementation_fallback_to_stream() -> None: @@ -185,15 +185,15 @@ async def test_astream_implementation_fallback_to_stream() -> None: model = ModelWithSyncStream() chunks = [chunk for chunk in model.stream("anything")] assert chunks == [ - _AnyIdAIMessageChunk(content="a"), - _AnyIdAIMessageChunk(content="b"), + _any_id_ai_message_chunk(content="a"), + _any_id_ai_message_chunk(content="b"), ] assert len({chunk.id for chunk in chunks}) == 1 assert type(model)._astream == BaseChatModel._astream astream_chunks = [chunk async for chunk in model.astream("anything")] assert astream_chunks == [ - _AnyIdAIMessageChunk(content="a"), - _AnyIdAIMessageChunk(content="b"), + _any_id_ai_message_chunk(content="a"), + _any_id_ai_message_chunk(content="b"), ] assert len({chunk.id for chunk in astream_chunks}) == 1 @@ -230,8 +230,8 @@ async def test_astream_implementation_uses_astream() -> None: model = ModelWithAsyncStream() chunks = [chunk async for chunk in model.astream("anything")] assert chunks == [ - _AnyIdAIMessageChunk(content="a"), - _AnyIdAIMessageChunk(content="b"), + _any_id_ai_message_chunk(content="a"), + _any_id_ai_message_chunk(content="b"), ] assert len({chunk.id for chunk in chunks}) == 1 diff --git a/libs/core/tests/unit_tests/prompts/test_loading.py b/libs/core/tests/unit_tests/prompts/test_loading.py index cd9423de5a7..2a98a1e95ce 100644 --- a/libs/core/tests/unit_tests/prompts/test_loading.py +++ b/libs/core/tests/unit_tests/prompts/test_loading.py @@ -25,7 +25,7 @@ def change_directory(dir: Path) -> Iterator: os.chdir(origin) -def test_loading_from_YAML() -> None: +def test_loading_from_yaml() -> None: """Test loading from yaml file.""" prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.yaml") expected_prompt = PromptTemplate( @@ -36,7 +36,7 @@ def test_loading_from_YAML() -> None: assert prompt == expected_prompt -def test_loading_from_JSON() -> None: +def test_loading_from_json() -> None: """Test loading from json file.""" prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.json") expected_prompt = PromptTemplate( @@ -46,14 +46,14 @@ def test_loading_from_JSON() -> None: assert prompt == expected_prompt -def test_loading_jinja_from_JSON() -> None: +def test_loading_jinja_from_json() -> None: """Test that loading jinja2 format prompts from JSON raises ValueError.""" prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.json" with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"): load_prompt(prompt_path) -def test_loading_jinja_from_YAML() -> None: +def test_loading_jinja_from_yaml() -> None: """Test that loading jinja2 format prompts from YAML raises ValueError.""" prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.yaml" with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"): diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index 9f75682cdc3..2699e6cab48 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -2,6 +2,7 @@ from typing import Optional from pydantic import BaseModel from syrupy import SnapshotAssertion +from typing_extensions import override from langchain_core.language_models import FakeListLLM from langchain_core.output_parsers.list import CommaSeparatedListOutputParser @@ -353,9 +354,11 @@ def test_runnable_get_graph_with_invalid_input_type() -> None: class InvalidInputTypeRunnable(Runnable[int, int]): @property + @override def InputType(self) -> type: raise TypeError() + @override def invoke( self, input: int, @@ -375,9 +378,11 @@ def test_runnable_get_graph_with_invalid_output_type() -> None: class InvalidOutputTypeRunnable(Runnable[int, int]): @property + @override def OutputType(self) -> type: raise TypeError() + @override def invoke( self, input: int, diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index acc0a8d7792..710f66a1885 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -493,7 +493,7 @@ def test_get_output_schema() -> None: def test_get_input_schema_input_messages() -> None: from pydantic import RootModel - RunnableWithMessageHistoryInput = RootModel[Sequence[BaseMessage]] + runnable_with_message_history_input = RootModel[Sequence[BaseMessage]] runnable = RunnableLambda( lambda messages: { @@ -515,7 +515,7 @@ def test_get_input_schema_input_messages() -> None: with_history = RunnableWithMessageHistory( runnable, get_session_history, output_messages_key="output" ) - expected_schema = _schema(RunnableWithMessageHistoryInput) + expected_schema = _schema(runnable_with_message_history_input) expected_schema["title"] = "RunnableWithChatHistoryInput" assert _schema(with_history.get_input_schema()) == expected_schema diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index e4cdc124b01..69df6523d0d 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -87,7 +87,7 @@ from langchain_core.tracers import ( ) from langchain_core.tracers.context import collect_runs from tests.unit_tests.pydantic_utils import _normalize_schema, _schema -from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk +from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split("."))) @@ -1699,7 +1699,7 @@ def test_prompt_with_chat_model( tracer = FakeTracer() assert chain.invoke( {"question": "What is your name?"}, dict(callbacks=[tracer]) - ) == _AnyIdAIMessage(content="foo") + ) == _any_id_ai_message(content="foo") assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( messages=[ @@ -1724,8 +1724,8 @@ def test_prompt_with_chat_model( ], dict(callbacks=[tracer]), ) == [ - _AnyIdAIMessage(content="foo"), - _AnyIdAIMessage(content="foo"), + _any_id_ai_message(content="foo"), + _any_id_ai_message(content="foo"), ] assert prompt_spy.call_args.args[1] == [ {"question": "What is your name?"}, @@ -1765,9 +1765,9 @@ def test_prompt_with_chat_model( assert [ *chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer])) ] == [ - _AnyIdAIMessageChunk(content="f"), - _AnyIdAIMessageChunk(content="o"), - _AnyIdAIMessageChunk(content="o"), + _any_id_ai_message_chunk(content="f"), + _any_id_ai_message_chunk(content="o"), + _any_id_ai_message_chunk(content="o"), ] assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( @@ -1805,7 +1805,7 @@ async def test_prompt_with_chat_model_async( tracer = FakeTracer() assert await chain.ainvoke( {"question": "What is your name?"}, dict(callbacks=[tracer]) - ) == _AnyIdAIMessage(content="foo") + ) == _any_id_ai_message(content="foo") assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( messages=[ @@ -1830,8 +1830,8 @@ async def test_prompt_with_chat_model_async( ], dict(callbacks=[tracer]), ) == [ - _AnyIdAIMessage(content="foo"), - _AnyIdAIMessage(content="foo"), + _any_id_ai_message(content="foo"), + _any_id_ai_message(content="foo"), ] assert prompt_spy.call_args.args[1] == [ {"question": "What is your name?"}, @@ -1874,9 +1874,9 @@ async def test_prompt_with_chat_model_async( {"question": "What is your name?"}, dict(callbacks=[tracer]) ) ] == [ - _AnyIdAIMessageChunk(content="f"), - _AnyIdAIMessageChunk(content="o"), - _AnyIdAIMessageChunk(content="o"), + _any_id_ai_message_chunk(content="f"), + _any_id_ai_message_chunk(content="o"), + _any_id_ai_message_chunk(content="o"), ] assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( @@ -2548,7 +2548,7 @@ def test_prompt_with_chat_model_and_parser( HumanMessage(content="What is your name?"), ] ) - assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar") + assert parser_spy.call_args.args[1] == _any_id_ai_message(content="foo, bar") assert tracer.runs == snapshot @@ -2681,7 +2681,7 @@ Question: ), ] ) - assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar") + assert parser_spy.call_args.args[1] == _any_id_ai_message(content="foo, bar") assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1 parent_run = next(r for r in tracer.runs if r.parent_run_id is None) assert len(parent_run.child_runs) == 4 @@ -2727,7 +2727,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) -> assert chain.invoke( {"question": "What is your name?"}, dict(callbacks=[tracer]) ) == { - "chat": _AnyIdAIMessage(content="i'm a chatbot"), + "chat": _any_id_ai_message(content="i'm a chatbot"), "llm": "i'm a textbot", } assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} @@ -2936,7 +2936,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N assert chain.invoke( {"question": "What is your name?"}, dict(callbacks=[tracer]) ) == { - "chat": _AnyIdAIMessage(content="i'm a chatbot"), + "chat": _any_id_ai_message(content="i'm a chatbot"), "llm": "i'm a textbot", "passthrough": ChatPromptValue( messages=[ @@ -3000,7 +3000,7 @@ def test_map_stream() -> None: assert streamed_chunks[0] in [ {"passthrough": prompt.invoke({"question": "What is your name?"})}, {"llm": "i"}, - {"chat": _AnyIdAIMessageChunk(content="i")}, + {"chat": _any_id_ai_message_chunk(content="i")}, ] assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1 assert all(len(c.keys()) == 1 for c in streamed_chunks) @@ -3059,11 +3059,11 @@ def test_map_stream() -> None: assert streamed_chunks[0] in [ {"llm": "i"}, - {"chat": _AnyIdAIMessageChunk(content="i")}, + {"chat": _any_id_ai_message_chunk(content="i")}, ] if not ( # TODO(Rewrite properly) statement above streamed_chunks[0] == {"llm": "i"} - or {"chat": _AnyIdAIMessageChunk(content="i")} + or {"chat": _any_id_ai_message_chunk(content="i")} ): raise AssertionError(f"Got an unexpected chunk: {streamed_chunks[0]}") @@ -3108,7 +3108,7 @@ def test_map_stream_iterator_input() -> None: assert streamed_chunks[0] in [ {"passthrough": "i"}, {"llm": "i"}, - {"chat": _AnyIdAIMessageChunk(content="i")}, + {"chat": _any_id_ai_message_chunk(content="i")}, ] assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res) assert all(len(c.keys()) == 1 for c in streamed_chunks) @@ -3152,7 +3152,7 @@ async def test_map_astream() -> None: assert streamed_chunks[0] in [ {"passthrough": prompt.invoke({"question": "What is your name?"})}, {"llm": "i"}, - {"chat": _AnyIdAIMessageChunk(content="i")}, + {"chat": _any_id_ai_message_chunk(content="i")}, ] assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1 assert all(len(c.keys()) == 1 for c in streamed_chunks) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py index d82cfd6d9dd..378d947e6bb 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py @@ -32,7 +32,7 @@ from langchain_core.runnables import ( from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.schema import StreamEvent from langchain_core.tools import tool -from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk +from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]: @@ -503,7 +503,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": {"a": "b"}, "name": "my_model", @@ -512,7 +512,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": {"a": "b"}, "name": "my_model", @@ -521,7 +521,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": _any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": {"a": "b"}, "name": "my_model", @@ -530,7 +530,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"output": _AnyIdAIMessageChunk(content="hello world!")}, + "data": {"output": _any_id_ai_message_chunk(content="hello world!")}, "event": "on_chat_model_end", "metadata": {"a": "b"}, "name": "my_model", @@ -575,7 +575,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -588,7 +588,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -601,7 +601,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": _any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -621,7 +621,9 @@ async def test_astream_events_from_model() -> None: [ { "generation_info": None, - "message": _AnyIdAIMessage(content="hello world!"), + "message": _any_id_ai_message( + content="hello world!" + ), "text": "hello world!", "type": "ChatGeneration", } @@ -644,7 +646,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessage(content="hello world!")}, + "data": {"chunk": _any_id_ai_message(content="hello world!")}, "event": "on_chain_stream", "metadata": {}, "name": "i_dont_stream", @@ -653,7 +655,7 @@ async def test_astream_events_from_model() -> None: "tags": [], }, { - "data": {"output": _AnyIdAIMessage(content="hello world!")}, + "data": {"output": _any_id_ai_message(content="hello world!")}, "event": "on_chain_end", "metadata": {}, "name": "i_dont_stream", @@ -698,7 +700,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -711,7 +713,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -724,7 +726,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": _any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -744,7 +746,9 @@ async def test_astream_events_from_model() -> None: [ { "generation_info": None, - "message": _AnyIdAIMessage(content="hello world!"), + "message": _any_id_ai_message( + content="hello world!" + ), "text": "hello world!", "type": "ChatGeneration", } @@ -767,7 +771,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessage(content="hello world!")}, + "data": {"chunk": _any_id_ai_message(content="hello world!")}, "event": "on_chain_stream", "metadata": {}, "name": "ai_dont_stream", @@ -776,7 +780,7 @@ async def test_astream_events_from_model() -> None: "tags": [], }, { - "data": {"output": _AnyIdAIMessage(content="hello world!")}, + "data": {"output": _any_id_ai_message(content="hello world!")}, "event": "on_chain_end", "metadata": {}, "name": "ai_dont_stream", diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index 67383d28cde..4b4815217f2 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -47,7 +47,7 @@ from langchain_core.utils.aiter import aclosing from tests.unit_tests.runnables.test_runnable_events_v1 import ( _assert_events_equal_allow_superset_metadata, ) -from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk +from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]: @@ -533,7 +533,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -546,7 +546,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -559,7 +559,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": _any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -573,7 +573,7 @@ async def test_astream_events_from_model() -> None: }, { "data": { - "output": _AnyIdAIMessageChunk(content="hello world!"), + "output": _any_id_ai_message_chunk(content="hello world!"), }, "event": "on_chat_model_end", "metadata": { @@ -640,7 +640,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -653,7 +653,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -666,7 +666,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": _any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -681,7 +681,7 @@ async def test_astream_with_model_in_chain() -> None: { "data": { "input": {"messages": [[HumanMessage(content="hello")]]}, - "output": _AnyIdAIMessage(content="hello world!"), + "output": _any_id_ai_message(content="hello world!"), }, "event": "on_chat_model_end", "metadata": { @@ -695,7 +695,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessage(content="hello world!")}, + "data": {"chunk": _any_id_ai_message(content="hello world!")}, "event": "on_chain_stream", "metadata": {}, "name": "i_dont_stream", @@ -704,7 +704,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": [], }, { - "data": {"output": _AnyIdAIMessage(content="hello world!")}, + "data": {"output": _any_id_ai_message(content="hello world!")}, "event": "on_chain_end", "metadata": {}, "name": "i_dont_stream", @@ -749,7 +749,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -762,7 +762,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -775,7 +775,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": _any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -790,7 +790,7 @@ async def test_astream_with_model_in_chain() -> None: { "data": { "input": {"messages": [[HumanMessage(content="hello")]]}, - "output": _AnyIdAIMessage(content="hello world!"), + "output": _any_id_ai_message(content="hello world!"), }, "event": "on_chat_model_end", "metadata": { @@ -804,7 +804,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessage(content="hello world!")}, + "data": {"chunk": _any_id_ai_message(content="hello world!")}, "event": "on_chain_stream", "metadata": {}, "name": "ai_dont_stream", @@ -813,7 +813,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": [], }, { - "data": {"output": _AnyIdAIMessage(content="hello world!")}, + "data": {"output": _any_id_ai_message(content="hello world!")}, "event": "on_chain_end", "metadata": {}, "name": "ai_dont_stream", diff --git a/libs/core/tests/unit_tests/stubs.py b/libs/core/tests/unit_tests/stubs.py index b752364e3af..95f36b72b0f 100644 --- a/libs/core/tests/unit_tests/stubs.py +++ b/libs/core/tests/unit_tests/stubs.py @@ -16,28 +16,28 @@ class AnyStr(str): # subclassed strings. -def _AnyIdDocument(**kwargs: Any) -> Document: +def _any_id_document(**kwargs: Any) -> Document: """Create a document with an id field.""" message = Document(**kwargs) message.id = AnyStr() return message -def _AnyIdAIMessage(**kwargs: Any) -> AIMessage: +def _any_id_ai_message(**kwargs: Any) -> AIMessage: """Create ai message with an any id field.""" message = AIMessage(**kwargs) message.id = AnyStr() return message -def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk: +def _any_id_ai_message_chunk(**kwargs: Any) -> AIMessageChunk: """Create ai message with an any id field.""" message = AIMessageChunk(**kwargs) message.id = AnyStr() return message -def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage: +def _any_id_human_message(**kwargs: Any) -> HumanMessage: """Create a human with an any id field.""" message = HumanMessage(**kwargs) message.id = AnyStr() diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index cc3e3551efb..90868cc18a6 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -781,7 +781,7 @@ def test_convert_to_messages() -> None: @pytest.mark.parametrize( - "MessageClass", + "message_class", [ AIMessage, AIMessageChunk, @@ -790,39 +790,39 @@ def test_convert_to_messages() -> None: SystemMessage, ], ) -def test_message_name(MessageClass: type) -> None: - msg = MessageClass(content="foo", name="bar") +def test_message_name(message_class: type) -> None: + msg = message_class(content="foo", name="bar") assert msg.name == "bar" - msg2 = MessageClass(content="foo", name=None) + msg2 = message_class(content="foo", name=None) assert msg2.name is None - msg3 = MessageClass(content="foo") + msg3 = message_class(content="foo") assert msg3.name is None @pytest.mark.parametrize( - "MessageClass", + "message_class", [FunctionMessage, FunctionMessageChunk], ) -def test_message_name_function(MessageClass: type) -> None: +def test_message_name_function(message_class: type) -> None: # functionmessage doesn't support name=None - msg = MessageClass(name="foo", content="bar") + msg = message_class(name="foo", content="bar") assert msg.name == "foo" @pytest.mark.parametrize( - "MessageClass", + "message_class", [ChatMessage, ChatMessageChunk], ) -def test_message_name_chat(MessageClass: type) -> None: - msg = MessageClass(content="foo", role="user", name="bar") +def test_message_name_chat(message_class: type) -> None: + msg = message_class(content="foo", role="user", name="bar") assert msg.name == "bar" - msg2 = MessageClass(content="foo", role="user", name=None) + msg2 = message_class(content="foo", role="user", name=None) assert msg2.name is None - msg3 = MessageClass(content="foo", role="user") + msg3 = message_class(content="foo", role="user") assert msg3.name is None diff --git a/libs/core/tests/unit_tests/test_pydantic_serde.py b/libs/core/tests/unit_tests/test_pydantic_serde.py index c19d14cb249..87af2fa5a61 100644 --- a/libs/core/tests/unit_tests/test_pydantic_serde.py +++ b/libs/core/tests/unit_tests/test_pydantic_serde.py @@ -51,14 +51,14 @@ def test_serde_any_message() -> None: ), ] - Model = RootModel[AnyMessage] + model = RootModel[AnyMessage] for lc_object in lc_objects: d = lc_object.model_dump() assert "type" in d, f"Missing key `type` for {type(lc_object)}" - obj1 = Model.model_validate(d) + obj1 = model.model_validate(d) assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}" with pytest.raises((TypeError, ValidationError)): # Make sure that specifically validation error is raised - Model.model_validate({}) + model.model_validate({}) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 35ff717dd2e..dc70ef9dbbe 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -1421,7 +1421,7 @@ class InjectedTool(BaseTool): return y -class fooSchema(BaseModel): +class fooSchema(BaseModel): # noqa: N801 """foo.""" x: int = Field(..., description="abc") @@ -1568,14 +1568,14 @@ def test_tool_injected_arg() -> None: def test_tool_inherited_injected_arg() -> None: - class barSchema(BaseModel): + class BarSchema(BaseModel): """bar.""" y: Annotated[str, "foobar comment", InjectedToolArg()] = Field( ..., description="123" ) - class fooSchema(barSchema): + class FooSchema(BarSchema): """foo.""" x: int = Field(..., description="abc") @@ -1583,14 +1583,14 @@ def test_tool_inherited_injected_arg() -> None: class InheritedInjectedArgTool(BaseTool): name: str = "foo" description: str = "foo." - args_schema: type[BaseModel] = fooSchema + args_schema: type[BaseModel] = FooSchema def _run(self, x: int, y: str) -> Any: return y tool_ = InheritedInjectedArgTool() assert tool_.get_input_schema().model_json_schema() == { - "title": "fooSchema", # Matches the title from the provided schema + "title": "FooSchema", # Matches the title from the provided schema "description": "foo.", "type": "object", "properties": { @@ -1877,15 +1877,15 @@ def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None: A = TypeVar("A") if use_v1_namespace: - from pydantic.v1 import BaseModel as BM1 + from pydantic.v1 import BaseModel as BaseModel1 - class ModelA(BM1, Generic[A], extra="allow"): + class ModelA(BaseModel1, Generic[A], extra="allow"): a: A else: - from pydantic import BaseModel as BM2 + from pydantic import BaseModel as BaseModel2 from pydantic import ConfigDict - class ModelA(BM2, Generic[A]): # type: ignore[no-redef] + class ModelA(BaseModel2, Generic[A]): # type: ignore[no-redef] a: A model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") @@ -2081,13 +2081,13 @@ def test_structured_tool_direct_init() -> None: def foo(bar: str) -> str: return bar - async def asyncFoo(bar: str) -> str: + async def async_foo(bar: str) -> str: return bar - class fooSchema(BaseModel): + class FooSchema(BaseModel): bar: str = Field(..., description="The bar") - tool = StructuredTool(name="foo", args_schema=fooSchema, coroutine=asyncFoo) + tool = StructuredTool(name="foo", args_schema=FooSchema, coroutine=async_foo) with pytest.raises(NotImplementedError): assert tool.invoke("hello") == "hello" diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index 64d017b5e74..6f09624d743 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -38,7 +38,7 @@ from langchain_core.utils.function_calling import ( @pytest.fixture() def pydantic() -> type[BaseModel]: - class dummy_function(BaseModel): + class dummy_function(BaseModel): # noqa: N801 """dummy function""" arg1: int = Field(..., description="foo") @@ -48,7 +48,7 @@ def pydantic() -> type[BaseModel]: @pytest.fixture() -def Annotated_function() -> Callable: +def annotated_function() -> Callable: def dummy_function( arg1: ExtensionsAnnotated[int, "foo"], arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"], @@ -118,7 +118,7 @@ def dummy_structured_tool() -> StructuredTool: @pytest.fixture() def dummy_pydantic() -> type[BaseModel]: - class dummy_function(BaseModel): + class dummy_function(BaseModel): # noqa: N801 """dummy function""" arg1: int = Field(..., description="foo") @@ -129,7 +129,7 @@ def dummy_pydantic() -> type[BaseModel]: @pytest.fixture() def dummy_pydantic_v2() -> type[BaseModelV2Maybe]: - class dummy_function(BaseModelV2Maybe): + class dummy_function(BaseModelV2Maybe): # noqa: N801 """dummy function""" arg1: int = FieldV2Maybe(..., description="foo") @@ -142,7 +142,7 @@ def dummy_pydantic_v2() -> type[BaseModelV2Maybe]: @pytest.fixture() def dummy_typing_typed_dict() -> type: - class dummy_function(TypingTypedDict): + class dummy_function(TypingTypedDict): # noqa: N801 """dummy function""" arg1: TypingAnnotated[int, ..., "foo"] # noqa: F821 @@ -153,7 +153,7 @@ def dummy_typing_typed_dict() -> type: @pytest.fixture() def dummy_typing_typed_dict_docstring() -> type: - class dummy_function(TypingTypedDict): + class dummy_function(TypingTypedDict): # noqa: N801 """dummy function Args: @@ -169,7 +169,7 @@ def dummy_typing_typed_dict_docstring() -> type: @pytest.fixture() def dummy_extensions_typed_dict() -> type: - class dummy_function(ExtensionsTypedDict): + class dummy_function(ExtensionsTypedDict): # noqa: N801 """dummy function""" arg1: ExtensionsAnnotated[int, ..., "foo"] @@ -180,7 +180,7 @@ def dummy_extensions_typed_dict() -> type: @pytest.fixture() def dummy_extensions_typed_dict_docstring() -> type: - class dummy_function(ExtensionsTypedDict): + class dummy_function(ExtensionsTypedDict): # noqa: N801 """dummy function Args: @@ -241,7 +241,7 @@ def test_convert_to_openai_function( dummy_structured_tool: StructuredTool, dummy_tool: BaseTool, json_schema: dict, - Annotated_function: Callable, + annotated_function: Callable, dummy_pydantic: type[BaseModel], runnable: Runnable, dummy_typing_typed_dict: type, @@ -275,7 +275,7 @@ def test_convert_to_openai_function( expected, Dummy.dummy_function, DummyWithClassMethod.dummy_function, - Annotated_function, + annotated_function, dummy_pydantic, dummy_typing_typed_dict, dummy_typing_typed_dict_docstring, @@ -523,20 +523,20 @@ def test__convert_typed_dict_to_openai_function( use_extension_typed_dict: bool, use_extension_annotated: bool ) -> None: if use_extension_typed_dict: - TypedDict = ExtensionsTypedDict + typed_dict = ExtensionsTypedDict else: - TypedDict = TypingTypedDict + typed_dict = TypingTypedDict if use_extension_annotated: - Annotated = TypingAnnotated + annotated = TypingAnnotated else: - Annotated = TypingAnnotated + annotated = TypingAnnotated - class SubTool(TypedDict): + class SubTool(typed_dict): """Subtool docstring""" - args: Annotated[dict[str, Any], {}, "this does bar"] # noqa: F722 # type: ignore + args: annotated[dict[str, Any], {}, "this does bar"] # noqa: F722 # type: ignore - class Tool(TypedDict): + class Tool(typed_dict): """Docstring Args: @@ -546,20 +546,20 @@ def test__convert_typed_dict_to_openai_function( arg1: str arg2: Union[int, str, bool] arg3: Optional[list[SubTool]] - arg4: Annotated[Literal["bar", "baz"], ..., "this does foo"] # noqa: F722 - arg5: Annotated[Optional[float], None] - arg6: Annotated[ + arg4: annotated[Literal["bar", "baz"], ..., "this does foo"] # noqa: F722 + arg5: annotated[Optional[float], None] + arg6: annotated[ Optional[Sequence[Mapping[str, tuple[Iterable[Any], SubTool]]]], [] ] - arg7: Annotated[list[SubTool], ...] - arg8: Annotated[tuple[SubTool], ...] - arg9: Annotated[Sequence[SubTool], ...] - arg10: Annotated[Iterable[SubTool], ...] - arg11: Annotated[set[SubTool], ...] - arg12: Annotated[dict[str, SubTool], ...] - arg13: Annotated[Mapping[str, SubTool], ...] - arg14: Annotated[MutableMapping[str, SubTool], ...] - arg15: Annotated[bool, False, "flag"] # noqa: F821 # type: ignore + arg7: annotated[list[SubTool], ...] + arg8: annotated[tuple[SubTool], ...] + arg9: annotated[Sequence[SubTool], ...] + arg10: annotated[Iterable[SubTool], ...] + arg11: annotated[set[SubTool], ...] + arg12: annotated[dict[str, SubTool], ...] + arg13: annotated[Mapping[str, SubTool], ...] + arg14: annotated[MutableMapping[str, SubTool], ...] + arg15: annotated[bool, False, "flag"] # noqa: F821 # type: ignore expected = { "name": "Tool", diff --git a/libs/core/tests/unit_tests/vectorstores/test_in_memory.py b/libs/core/tests/unit_tests/vectorstores/test_in_memory.py index c0f9944d077..5373d022dbf 100644 --- a/libs/core/tests/unit_tests/vectorstores/test_in_memory.py +++ b/libs/core/tests/unit_tests/vectorstores/test_in_memory.py @@ -10,7 +10,7 @@ from langchain_standard_tests.integration_tests.vectorstores import ( from langchain_core.documents import Document from langchain_core.embeddings.fake import DeterministicFakeEmbedding from langchain_core.vectorstores import InMemoryVectorStore -from tests.unit_tests.stubs import _AnyIdDocument +from tests.unit_tests.stubs import _any_id_document class TestInMemoryReadWriteTestSuite(ReadWriteTestSuite): @@ -33,13 +33,13 @@ async def test_inmemory_similarity_search() -> None: # Check sync version output = store.similarity_search("foo", k=1) - assert output == [_AnyIdDocument(page_content="foo")] + assert output == [_any_id_document(page_content="foo")] # Check async version output = await store.asimilarity_search("bar", k=2) assert output == [ - _AnyIdDocument(page_content="bar"), - _AnyIdDocument(page_content="baz"), + _any_id_document(page_content="bar"), + _any_id_document(page_content="baz"), ] @@ -80,16 +80,16 @@ async def test_inmemory_mmr() -> None: # make sure we can k > docstore size output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1) assert len(output) == len(texts) - assert output[0] == _AnyIdDocument(page_content="foo") - assert output[1] == _AnyIdDocument(page_content="foy") + assert output[0] == _any_id_document(page_content="foo") + assert output[1] == _any_id_document(page_content="foy") # Check async version output = await docsearch.amax_marginal_relevance_search( "foo", k=10, lambda_mult=0.1 ) assert len(output) == len(texts) - assert output[0] == _AnyIdDocument(page_content="foo") - assert output[1] == _AnyIdDocument(page_content="foy") + assert output[0] == _any_id_document(page_content="foo") + assert output[1] == _any_id_document(page_content="foy") async def test_inmemory_dump_load(tmp_path: Path) -> None: @@ -117,7 +117,7 @@ async def test_inmemory_filter() -> None: # Check sync version output = store.similarity_search("fee", filter=lambda doc: doc.metadata["id"] == 1) - assert output == [_AnyIdDocument(page_content="foo", metadata={"id": 1})] + assert output == [_any_id_document(page_content="foo", metadata={"id": 1})] # filter with not stored document id output = await store.asimilarity_search( diff --git a/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py b/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py index a1797d777d1..29746bbe67d 100644 --- a/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py +++ b/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py @@ -50,6 +50,7 @@ class CustomAddTextsVectorstore(VectorStore): def get_by_ids(self, ids: Sequence[str], /) -> list[Document]: return [self.store[id] for id in ids if id in self.store] + @classmethod def from_texts( # type: ignore cls, texts: list[str],