core: Add N(naming) ruff rules (#25362)

Public classes/functions are not renamed and rule is ignored for them.

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Christophe Bornet 2024-09-19 07:09:39 +02:00 committed by GitHub
parent 7835c0651f
commit fd21ffe293
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 289 additions and 222 deletions

View File

@ -155,7 +155,7 @@ def beta(
_name = _name or obj.fget.__qualname__ _name = _name or obj.fget.__qualname__
old_doc = obj.__doc__ old_doc = obj.__doc__
class _beta_property(property): class _BetaProperty(property):
"""A beta property.""" """A beta property."""
def __init__(self, fget=None, fset=None, fdel=None, doc=None): def __init__(self, fget=None, fset=None, fdel=None, doc=None):
@ -186,7 +186,7 @@ def beta(
def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any: def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any:
"""Finalize the property.""" """Finalize the property."""
return _beta_property( return _BetaProperty(
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
) )

View File

@ -264,7 +264,7 @@ def deprecated(
_name = _name or cast(Union[type, Callable], obj.fget).__qualname__ _name = _name or cast(Union[type, Callable], obj.fget).__qualname__
old_doc = obj.__doc__ old_doc = obj.__doc__
class _deprecated_property(property): class _DeprecatedProperty(property):
"""A deprecated property.""" """A deprecated property."""
def __init__(self, fget=None, fset=None, fdel=None, doc=None): # type: ignore[no-untyped-def] def __init__(self, fget=None, fset=None, fdel=None, doc=None): # type: ignore[no-untyped-def]
@ -297,7 +297,7 @@ def deprecated(
"""Finalize the property.""" """Finalize the property."""
return cast( return cast(
T, T,
_deprecated_property( _DeprecatedProperty(
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
), ),
) )

View File

@ -3,7 +3,7 @@
from typing import Any, Optional from typing import Any, Optional
class LangChainException(Exception): class LangChainException(Exception): # noqa: N818
"""General LangChain exception.""" """General LangChain exception."""
@ -11,7 +11,7 @@ class TracerException(LangChainException):
"""Base class for exceptions in tracers module.""" """Base class for exceptions in tracers module."""
class OutputParserException(ValueError, LangChainException): class OutputParserException(ValueError, LangChainException): # noqa: N818
"""Exception that output parsers should raise to signify a parsing error. """Exception that output parsers should raise to signify a parsing error.
This exists to differentiate parsing errors from other code or execution errors This exists to differentiate parsing errors from other code or execution errors

View File

@ -14,7 +14,7 @@ from typing import (
) )
from pydantic import BaseModel, ConfigDict, Field, field_validator from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import TypeAlias, TypedDict from typing_extensions import TypeAlias, TypedDict, override
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.messages import ( from langchain_core.messages import (
@ -143,6 +143,7 @@ class BaseLanguageModel(
return verbose return verbose
@property @property
@override
def InputType(self) -> TypeAlias: def InputType(self) -> TypeAlias:
"""Get the input type for this runnable.""" """Get the input type for this runnable."""
from langchain_core.prompt_values import ( from langchain_core.prompt_values import (

View File

@ -26,6 +26,7 @@ from pydantic import (
Field, Field,
model_validator, model_validator,
) )
from typing_extensions import override
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.caches import BaseCache from langchain_core.caches import BaseCache
@ -251,6 +252,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
# --- Runnable methods --- # --- Runnable methods ---
@property @property
@override
def OutputType(self) -> Any: def OutputType(self) -> Any:
"""Get the output type for this runnable.""" """Get the output type for this runnable."""
return AnyMessage return AnyMessage

View File

@ -31,6 +31,7 @@ from tenacity import (
stop_after_attempt, stop_after_attempt,
wait_exponential, wait_exponential,
) )
from typing_extensions import override
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.caches import BaseCache from langchain_core.caches import BaseCache
@ -318,6 +319,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
# --- Runnable methods --- # --- Runnable methods ---
@property @property
@override
def OutputType(self) -> type[str]: def OutputType(self) -> type[str]:
"""Get the input type for this runnable.""" """Get the input type for this runnable."""
return str return str

View File

@ -10,6 +10,8 @@ from typing import (
Union, Union,
) )
from typing_extensions import override
from langchain_core.language_models import LanguageModelOutput from langchain_core.language_models import LanguageModelOutput
from langchain_core.messages import AnyMessage, BaseMessage from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
@ -63,11 +65,13 @@ class BaseGenerationOutputParser(
"""Base class to parse the output of an LLM call.""" """Base class to parse the output of an LLM call."""
@property @property
@override
def InputType(self) -> Any: def InputType(self) -> Any:
"""Return the input type for the parser.""" """Return the input type for the parser."""
return Union[str, AnyMessage] return Union[str, AnyMessage]
@property @property
@override
def OutputType(self) -> type[T]: def OutputType(self) -> type[T]:
"""Return the output type for the parser.""" """Return the output type for the parser."""
# even though mypy complains this isn't valid, # even though mypy complains this isn't valid,
@ -148,11 +152,13 @@ class BaseOutputParser(
""" # noqa: E501 """ # noqa: E501
@property @property
@override
def InputType(self) -> Any: def InputType(self) -> Any:
"""Return the input type for the parser.""" """Return the input type for the parser."""
return Union[str, AnyMessage] return Union[str, AnyMessage]
@property @property
@override
def OutputType(self) -> type[T]: def OutputType(self) -> type[T]:
"""Return the output type for the parser. """Return the output type for the parser.

View File

@ -3,6 +3,7 @@ from typing import Annotated, Generic, Optional
import pydantic import pydantic
from pydantic import SkipValidation from pydantic import SkipValidation
from typing_extensions import override
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser from langchain_core.output_parsers import JsonOutputParser
@ -107,6 +108,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
return "pydantic" return "pydantic"
@property @property
@override
def OutputType(self) -> type[TBaseModel]: def OutputType(self) -> type[TBaseModel]:
"""Return the pydantic model.""" """Return the pydantic model."""
return self.pydantic_object return self.pydantic_object

View File

@ -1,6 +1,6 @@
import re import re
import xml import xml
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET # noqa: N817
from collections.abc import AsyncIterator, Iterator from collections.abc import AsyncIterator, Iterator
from typing import Any, Literal, Optional, Union from typing import Any, Literal, Optional, Union
from xml.etree.ElementTree import TreeBuilder from xml.etree.ElementTree import TreeBuilder
@ -46,14 +46,14 @@ class _StreamingParser:
""" """
if parser == "defusedxml": if parser == "defusedxml":
try: try:
from defusedxml import ElementTree as DET # type: ignore import defusedxml # type: ignore
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"defusedxml is not installed. " "defusedxml is not installed. "
"Please install it to use the defusedxml parser." "Please install it to use the defusedxml parser."
"You can install it with `pip install defusedxml` " "You can install it with `pip install defusedxml` "
) from e ) from e
_parser = DET.DefusedXMLParser(target=TreeBuilder()) _parser = defusedxml.ElementTree.DefusedXMLParser(target=TreeBuilder())
else: else:
_parser = None _parser = None
self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser) self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
@ -189,7 +189,7 @@ class XMLOutputParser(BaseTransformOutputParser):
# likely if you're reading this you can move them to the top of the file # likely if you're reading this you can move them to the top of the file
if self.parser == "defusedxml": if self.parser == "defusedxml":
try: try:
from defusedxml import ElementTree as DET # type: ignore import defusedxml # type: ignore
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"defusedxml is not installed. " "defusedxml is not installed. "
@ -197,9 +197,9 @@ class XMLOutputParser(BaseTransformOutputParser):
"You can install it with `pip install defusedxml`" "You can install it with `pip install defusedxml`"
"See https://github.com/tiran/defusedxml for more details" "See https://github.com/tiran/defusedxml for more details"
) from e ) from e
_ET = DET # Use the defusedxml parser _et = defusedxml.ElementTree # Use the defusedxml parser
else: else:
_ET = ET # Use the standard library parser _et = ET # Use the standard library parser
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL) match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
if match is not None: if match is not None:

View File

@ -18,7 +18,7 @@ from typing import (
import yaml import yaml
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Self from typing_extensions import Self, override
from langchain_core.load import dumpd from langchain_core.load import dumpd
from langchain_core.output_parsers.base import BaseOutputParser from langchain_core.output_parsers.base import BaseOutputParser
@ -107,6 +107,7 @@ class BasePromptTemplate(
return dumpd(self) return dumpd(self)
@property @property
@override
def OutputType(self) -> Any: def OutputType(self) -> Any:
"""Return the output type of the prompt.""" """Return the output type of the prompt."""
return Union[StringPromptValue, ChatPromptValueConcrete] return Union[StringPromptValue, ChatPromptValueConcrete]

View File

@ -36,7 +36,7 @@ from typing import (
) )
from pydantic import BaseModel, ConfigDict, Field, RootModel from pydantic import BaseModel, ConfigDict, Field, RootModel
from typing_extensions import Literal, get_args from typing_extensions import Literal, get_args, override
from langchain_core._api import beta_decorator from langchain_core._api import beta_decorator
from langchain_core.load.serializable import ( from langchain_core.load.serializable import (
@ -272,7 +272,7 @@ class Runnable(Generic[Input, Output], ABC):
return name_ return name_
@property @property
def InputType(self) -> type[Input]: def InputType(self) -> type[Input]: # noqa: N802
"""The type of input this Runnable accepts specified as a type annotation.""" """The type of input this Runnable accepts specified as a type annotation."""
# First loop through all parent classes and if any of them is # First loop through all parent classes and if any of them is
# a pydantic model, we will pick up the generic parameterization # a pydantic model, we will pick up the generic parameterization
@ -297,7 +297,7 @@ class Runnable(Generic[Input, Output], ABC):
) )
@property @property
def OutputType(self) -> type[Output]: def OutputType(self) -> type[Output]: # noqa: N802
"""The type of output this Runnable produces specified as a type annotation.""" """The type of output this Runnable produces specified as a type annotation."""
# First loop through bases -- this will help generic # First loop through bases -- this will help generic
# any pydantic models. # any pydantic models.
@ -2811,11 +2811,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
) )
@property @property
@override
def InputType(self) -> type[Input]: def InputType(self) -> type[Input]:
"""The type of the input to the Runnable.""" """The type of the input to the Runnable."""
return self.first.InputType return self.first.InputType
@property @property
@override
def OutputType(self) -> type[Output]: def OutputType(self) -> type[Output]:
"""The type of the output of the Runnable.""" """The type of the output of the Runnable."""
return self.last.OutputType return self.last.OutputType
@ -3564,6 +3566,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
return super().get_name(suffix, name=name) return super().get_name(suffix, name=name)
@property @property
@override
def InputType(self) -> Any: def InputType(self) -> Any:
"""The type of the input to the Runnable.""" """The type of the input to the Runnable."""
for step in self.steps__.values(): for step in self.steps__.values():
@ -4057,6 +4060,7 @@ class RunnableGenerator(Runnable[Input, Output]):
self.name = "RunnableGenerator" self.name = "RunnableGenerator"
@property @property
@override
def InputType(self) -> Any: def InputType(self) -> Any:
func = getattr(self, "_transform", None) or self._atransform func = getattr(self, "_transform", None) or self._atransform
try: try:
@ -4097,6 +4101,7 @@ class RunnableGenerator(Runnable[Input, Output]):
) )
@property @property
@override
def OutputType(self) -> Any: def OutputType(self) -> Any:
func = getattr(self, "_transform", None) or self._atransform func = getattr(self, "_transform", None) or self._atransform
try: try:
@ -4346,6 +4351,7 @@ class RunnableLambda(Runnable[Input, Output]):
pass pass
@property @property
@override
def InputType(self) -> Any: def InputType(self) -> Any:
"""The type of the input to this Runnable.""" """The type of the input to this Runnable."""
func = getattr(self, "func", None) or self.afunc func = getattr(self, "func", None) or self.afunc
@ -4405,6 +4411,7 @@ class RunnableLambda(Runnable[Input, Output]):
return super().get_input_schema(config) return super().get_input_schema(config)
@property @property
@override
def OutputType(self) -> Any: def OutputType(self) -> Any:
"""The type of the output of this Runnable as a type annotation. """The type of the output of this Runnable as a type annotation.
@ -4958,6 +4965,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]):
) )
@property @property
@override
def InputType(self) -> Any: def InputType(self) -> Any:
return list[self.bound.InputType] # type: ignore[name-defined] return list[self.bound.InputType] # type: ignore[name-defined]
@ -4981,6 +4989,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]):
) )
@property @property
@override
def OutputType(self) -> type[list[Output]]: def OutputType(self) -> type[list[Output]]:
return list[self.bound.OutputType] # type: ignore[name-defined] return list[self.bound.OutputType] # type: ignore[name-defined]
@ -5274,6 +5283,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
return self.bound.get_name(suffix, name=name) return self.bound.get_name(suffix, name=name)
@property @property
@override
def InputType(self) -> type[Input]: def InputType(self) -> type[Input]:
return ( return (
cast(type[Input], self.custom_input_type) cast(type[Input], self.custom_input_type)
@ -5282,6 +5292,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
) )
@property @property
@override
def OutputType(self) -> type[Output]: def OutputType(self) -> type[Output]:
return ( return (
cast(type[Output], self.custom_output_type) cast(type[Output], self.custom_output_type)

View File

@ -16,6 +16,7 @@ from typing import (
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from typing_extensions import override
from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import ( from langchain_core.runnables.config import (
@ -68,10 +69,12 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
@property @property
@override
def InputType(self) -> type[Input]: def InputType(self) -> type[Input]:
return self.default.InputType return self.default.InputType
@property @property
@override
def OutputType(self) -> type[Output]: def OutputType(self) -> type[Output]:
return self.default.OutputType return self.default.OutputType

View File

@ -13,6 +13,7 @@ from typing import (
) )
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from typing_extensions import override
from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import ( from langchain_core.runnables.config import (
@ -106,10 +107,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
) )
@property @property
@override
def InputType(self) -> type[Input]: def InputType(self) -> type[Input]:
return self.runnable.InputType return self.runnable.InputType
@property @property
@override
def OutputType(self) -> type[Output]: def OutputType(self) -> type[Output]:
return self.runnable.OutputType return self.runnable.OutputType

View File

@ -246,27 +246,27 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
# NOTE: coordinates might me negative, so we need to shift # NOTE: coordinates might me negative, so we need to shift
# everything to the positive plane before we actually draw it. # everything to the positive plane before we actually draw it.
Xs = [] xlist = []
Ys = [] ylist = []
sug = _build_sugiyama_layout(vertices, edges) sug = _build_sugiyama_layout(vertices, edges)
for vertex in sug.g.sV: for vertex in sug.g.sV:
# NOTE: moving boxes w/2 to the left # NOTE: moving boxes w/2 to the left
Xs.append(vertex.view.xy[0] - vertex.view.w / 2.0) xlist.append(vertex.view.xy[0] - vertex.view.w / 2.0)
Xs.append(vertex.view.xy[0] + vertex.view.w / 2.0) xlist.append(vertex.view.xy[0] + vertex.view.w / 2.0)
Ys.append(vertex.view.xy[1]) ylist.append(vertex.view.xy[1])
Ys.append(vertex.view.xy[1] + vertex.view.h) ylist.append(vertex.view.xy[1] + vertex.view.h)
for edge in sug.g.sE: for edge in sug.g.sE:
for x, y in edge.view._pts: for x, y in edge.view._pts:
Xs.append(x) xlist.append(x)
Ys.append(y) ylist.append(y)
minx = min(Xs) minx = min(xlist)
miny = min(Ys) miny = min(ylist)
maxx = max(Xs) maxx = max(xlist)
maxy = max(Ys) maxy = max(ylist)
canvas_cols = int(math.ceil(math.ceil(maxx) - math.floor(minx))) + 1 canvas_cols = int(math.ceil(math.ceil(maxx) - math.floor(minx))) + 1
canvas_lines = int(round(maxy - miny)) canvas_lines = int(round(maxy - miny))

View File

@ -12,6 +12,7 @@ from typing import (
) )
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import override
from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.load.load import load from langchain_core.load.load import load
@ -396,6 +397,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
) )
@property @property
@override
def OutputType(self) -> type[Output]: def OutputType(self) -> type[Output]:
output_type = self._history_chain.OutputType output_type = self._history_chain.OutputType
return output_type return output_type

View File

@ -16,6 +16,7 @@ from typing import (
) )
from pydantic import BaseModel, RootModel from pydantic import BaseModel, RootModel
from typing_extensions import override
from langchain_core.runnables.base import ( from langchain_core.runnables.base import (
Other, Other,
@ -193,10 +194,12 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
@property @property
@override
def InputType(self) -> Any: def InputType(self) -> Any:
return self.input_type or Any return self.input_type or Any
@property @property
@override
def OutputType(self) -> Any: def OutputType(self) -> Any:
return self.input_type or Any return self.input_type or Any

View File

@ -28,7 +28,7 @@ from typing import (
Union, Union,
) )
from typing_extensions import TypeGuard from typing_extensions import TypeGuard, override
from langchain_core.runnables.schema import StreamEvent from langchain_core.runnables.schema import StreamEvent
@ -135,6 +135,7 @@ class IsLocalDict(ast.NodeVisitor):
self.name = name self.name = name
self.keys = keys self.keys = keys
@override
def visit_Subscript(self, node: ast.Subscript) -> Any: def visit_Subscript(self, node: ast.Subscript) -> Any:
"""Visit a subscript node. """Visit a subscript node.
@ -154,6 +155,7 @@ class IsLocalDict(ast.NodeVisitor):
# we've found a subscript access on the name we're looking for # we've found a subscript access on the name we're looking for
self.keys.add(node.slice.value) self.keys.add(node.slice.value)
@override
def visit_Call(self, node: ast.Call) -> Any: def visit_Call(self, node: ast.Call) -> Any:
"""Visit a call node. """Visit a call node.
@ -182,6 +184,7 @@ class IsFunctionArgDict(ast.NodeVisitor):
def __init__(self) -> None: def __init__(self) -> None:
self.keys: set[str] = set() self.keys: set[str] = set()
@override
def visit_Lambda(self, node: ast.Lambda) -> Any: def visit_Lambda(self, node: ast.Lambda) -> Any:
"""Visit a lambda function. """Visit a lambda function.
@ -196,6 +199,7 @@ class IsFunctionArgDict(ast.NodeVisitor):
input_arg_name = node.args.args[0].arg input_arg_name = node.args.args[0].arg
IsLocalDict(input_arg_name, self.keys).visit(node.body) IsLocalDict(input_arg_name, self.keys).visit(node.body)
@override
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
"""Visit a function definition. """Visit a function definition.
@ -210,6 +214,7 @@ class IsFunctionArgDict(ast.NodeVisitor):
input_arg_name = node.args.args[0].arg input_arg_name = node.args.args[0].arg
IsLocalDict(input_arg_name, self.keys).visit(node) IsLocalDict(input_arg_name, self.keys).visit(node)
@override
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
"""Visit an async function definition. """Visit an async function definition.
@ -232,6 +237,7 @@ class NonLocals(ast.NodeVisitor):
self.loads: set[str] = set() self.loads: set[str] = set()
self.stores: set[str] = set() self.stores: set[str] = set()
@override
def visit_Name(self, node: ast.Name) -> Any: def visit_Name(self, node: ast.Name) -> Any:
"""Visit a name node. """Visit a name node.
@ -246,6 +252,7 @@ class NonLocals(ast.NodeVisitor):
elif isinstance(node.ctx, ast.Store): elif isinstance(node.ctx, ast.Store):
self.stores.add(node.id) self.stores.add(node.id)
@override
def visit_Attribute(self, node: ast.Attribute) -> Any: def visit_Attribute(self, node: ast.Attribute) -> Any:
"""Visit an attribute node. """Visit an attribute node.
@ -272,6 +279,7 @@ class FunctionNonLocals(ast.NodeVisitor):
def __init__(self) -> None: def __init__(self) -> None:
self.nonlocals: set[str] = set() self.nonlocals: set[str] = set()
@override
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
"""Visit a function definition. """Visit a function definition.
@ -285,6 +293,7 @@ class FunctionNonLocals(ast.NodeVisitor):
visitor.visit(node) visitor.visit(node)
self.nonlocals.update(visitor.loads - visitor.stores) self.nonlocals.update(visitor.loads - visitor.stores)
@override
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
"""Visit an async function definition. """Visit an async function definition.
@ -298,6 +307,7 @@ class FunctionNonLocals(ast.NodeVisitor):
visitor.visit(node) visitor.visit(node)
self.nonlocals.update(visitor.loads - visitor.stores) self.nonlocals.update(visitor.loads - visitor.stores)
@override
def visit_Lambda(self, node: ast.Lambda) -> Any: def visit_Lambda(self, node: ast.Lambda) -> Any:
"""Visit a lambda function. """Visit a lambda function.
@ -320,6 +330,7 @@ class GetLambdaSource(ast.NodeVisitor):
self.source: Optional[str] = None self.source: Optional[str] = None
self.count = 0 self.count = 0
@override
def visit_Lambda(self, node: ast.Lambda) -> Any: def visit_Lambda(self, node: ast.Lambda) -> Any:
"""Visit a lambda function. """Visit a lambda function.

View File

@ -311,7 +311,7 @@ def create_schema_from_function(
) )
class ToolException(Exception): class ToolException(Exception): # noqa: N818
"""Optional exception that tool throws when execution error occurs. """Optional exception that tool throws when execution error occurs.
When this exception is thrown, the agent will not stop working, When this exception is thrown, the agent will not stop working,

View File

@ -9,7 +9,7 @@ def get_headers(*args: Any, **kwargs: Any) -> Any:
) )
def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: # noqa: N802
"""Throw an error because this has been replaced by LangChainTracer.""" """Throw an error because this has been replaced by LangChainTracer."""
raise RuntimeError( raise RuntimeError(
"LangChainTracerV1 is no longer supported. Please use LangChainTracer instead." "LangChainTracerV1 is no longer supported. Please use LangChainTracer instead."

View File

@ -18,7 +18,7 @@ from langchain_core._api import deprecated
@deprecated("0.1.0", alternative="Use string instead.", removal="1.0") @deprecated("0.1.0", alternative="Use string instead.", removal="1.0")
def RunTypeEnum() -> type[RunTypeEnumDep]: def RunTypeEnum() -> type[RunTypeEnumDep]: # noqa: N802
"""RunTypeEnum.""" """RunTypeEnum."""
warnings.warn( warnings.warn(
"RunTypeEnum is deprecated. Please directly use a string instead" "RunTypeEnum is deprecated. Please directly use a string instead"

View File

@ -235,7 +235,7 @@ class Tee(Generic[T]):
atee = Tee atee = Tee
class aclosing(AbstractAsyncContextManager): class aclosing(AbstractAsyncContextManager): # noqa: N801
"""Async context manager for safely finalizing an asynchronously cleaned-up """Async context manager for safely finalizing an asynchronously cleaned-up
resource such as an async generator, calling its ``aclose()`` method. resource such as an async generator, calling its ``aclose()`` method.

View File

@ -17,12 +17,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray:
"""Row-wise cosine similarity between two equal-width matrices. """Row-wise cosine similarity between two equal-width matrices.
Args: Args:
X: A matrix of shape (n, m). x: A matrix of shape (n, m).
Y: A matrix of shape (k, m). y: A matrix of shape (k, m).
Returns: Returns:
A matrix of shape (n, k) where each element (i, j) is the cosine similarity A matrix of shape (n, k) where each element (i, j) is the cosine similarity
@ -40,33 +40,33 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
"Please install numpy with `pip install numpy`." "Please install numpy with `pip install numpy`."
) from e ) from e
if len(X) == 0 or len(Y) == 0: if len(x) == 0 or len(y) == 0:
return np.array([]) return np.array([])
X = np.array(X) x = np.array(x)
Y = np.array(Y) y = np.array(y)
if X.shape[1] != Y.shape[1]: if x.shape[1] != y.shape[1]:
raise ValueError( raise ValueError(
f"Number of columns in X and Y must be the same. X has shape {X.shape} " f"Number of columns in X and Y must be the same. X has shape {x.shape} "
f"and Y has shape {Y.shape}." f"and Y has shape {y.shape}."
) )
try: try:
import simsimd as simd # type: ignore[import-not-found] import simsimd as simd # type: ignore[import-not-found]
X = np.array(X, dtype=np.float32) x = np.array(x, dtype=np.float32)
Y = np.array(Y, dtype=np.float32) y = np.array(y, dtype=np.float32)
Z = 1 - np.array(simd.cdist(X, Y, metric="cosine")) z = 1 - np.array(simd.cdist(x, y, metric="cosine"))
return Z return z
except ImportError: except ImportError:
logger.debug( logger.debug(
"Unable to import simsimd, defaulting to NumPy implementation. If you want " "Unable to import simsimd, defaulting to NumPy implementation. If you want "
"to use simsimd please install with `pip install simsimd`." "to use simsimd please install with `pip install simsimd`."
) )
X_norm = np.linalg.norm(X, axis=1) x_norm = np.linalg.norm(x, axis=1)
Y_norm = np.linalg.norm(Y, axis=1) y_norm = np.linalg.norm(y, axis=1)
# Ignore divide by zero errors run time warnings as those are handled below. # Ignore divide by zero errors run time warnings as those are handled below.
with np.errstate(divide="ignore", invalid="ignore"): with np.errstate(divide="ignore", invalid="ignore"):
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) similarity = np.dot(x, y.T) / np.outer(x_norm, y_norm)
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
return similarity return similarity

View File

@ -44,9 +44,17 @@ python = ">=3.12.4"
target-version = "py39" target-version = "py39"
[tool.ruff.lint] [tool.ruff.lint]
select = ["B", "E", "F", "I", "T201", "UP"] select = ["B", "E", "F", "I", "N", "T201", "UP"]
ignore = ["UP007"] ignore = ["UP007"]
[tool.ruff.lint.pep8-naming]
classmethod-decorators = [
"classmethod",
"langchain_core.utils.pydantic.pre_init",
"pydantic.field_validator",
"pydantic.v1.root_validator",
]
[tool.coverage.run] [tool.coverage.run]
omit = [ "tests/*",] omit = [ "tests/*",]

View File

@ -9,9 +9,9 @@ from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatM
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.stubs import ( from tests.unit_tests.stubs import (
_AnyIdAIMessage, _any_id_ai_message,
_AnyIdAIMessageChunk, _any_id_ai_message_chunk,
_AnyIdHumanMessage, _any_id_human_message,
) )
@ -20,11 +20,11 @@ def test_generic_fake_chat_model_invoke() -> None:
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")]) infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle) model = GenericFakeChatModel(messages=infinite_cycle)
response = model.invoke("meow") response = model.invoke("meow")
assert response == _AnyIdAIMessage(content="hello") assert response == _any_id_ai_message(content="hello")
response = model.invoke("kitty") response = model.invoke("kitty")
assert response == _AnyIdAIMessage(content="goodbye") assert response == _any_id_ai_message(content="goodbye")
response = model.invoke("meow") response = model.invoke("meow")
assert response == _AnyIdAIMessage(content="hello") assert response == _any_id_ai_message(content="hello")
async def test_generic_fake_chat_model_ainvoke() -> None: async def test_generic_fake_chat_model_ainvoke() -> None:
@ -32,11 +32,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")]) infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle) model = GenericFakeChatModel(messages=infinite_cycle)
response = await model.ainvoke("meow") response = await model.ainvoke("meow")
assert response == _AnyIdAIMessage(content="hello") assert response == _any_id_ai_message(content="hello")
response = await model.ainvoke("kitty") response = await model.ainvoke("kitty")
assert response == _AnyIdAIMessage(content="goodbye") assert response == _any_id_ai_message(content="goodbye")
response = await model.ainvoke("meow") response = await model.ainvoke("meow")
assert response == _AnyIdAIMessage(content="hello") assert response == _any_id_ai_message(content="hello")
async def test_generic_fake_chat_model_stream() -> None: async def test_generic_fake_chat_model_stream() -> None:
@ -49,17 +49,17 @@ async def test_generic_fake_chat_model_stream() -> None:
model = GenericFakeChatModel(messages=infinite_cycle) model = GenericFakeChatModel(messages=infinite_cycle)
chunks = [chunk async for chunk in model.astream("meow")] chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [ assert chunks == [
_AnyIdAIMessageChunk(content="hello"), _any_id_ai_message_chunk(content="hello"),
_AnyIdAIMessageChunk(content=" "), _any_id_ai_message_chunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"), _any_id_ai_message_chunk(content="goodbye"),
] ]
assert len({chunk.id for chunk in chunks}) == 1 assert len({chunk.id for chunk in chunks}) == 1
chunks = [chunk for chunk in model.stream("meow")] chunks = [chunk for chunk in model.stream("meow")]
assert chunks == [ assert chunks == [
_AnyIdAIMessageChunk(content="hello"), _any_id_ai_message_chunk(content="hello"),
_AnyIdAIMessageChunk(content=" "), _any_id_ai_message_chunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"), _any_id_ai_message_chunk(content="goodbye"),
] ]
assert len({chunk.id for chunk in chunks}) == 1 assert len({chunk.id for chunk in chunks}) == 1
@ -69,8 +69,8 @@ async def test_generic_fake_chat_model_stream() -> None:
model = GenericFakeChatModel(messages=cycle([message])) model = GenericFakeChatModel(messages=cycle([message]))
chunks = [chunk async for chunk in model.astream("meow")] chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [ assert chunks == [
_AnyIdAIMessageChunk(content="", additional_kwargs={"foo": 42}), _any_id_ai_message_chunk(content="", additional_kwargs={"foo": 42}),
_AnyIdAIMessageChunk(content="", additional_kwargs={"bar": 24}), _any_id_ai_message_chunk(content="", additional_kwargs={"bar": 24}),
] ]
assert len({chunk.id for chunk in chunks}) == 1 assert len({chunk.id for chunk in chunks}) == 1
@ -88,19 +88,19 @@ async def test_generic_fake_chat_model_stream() -> None:
chunks = [chunk async for chunk in model.astream("meow")] chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [ assert chunks == [
_AnyIdAIMessageChunk( _any_id_ai_message_chunk(
content="", additional_kwargs={"function_call": {"name": "move_file"}} content="", additional_kwargs={"function_call": {"name": "move_file"}}
), ),
_AnyIdAIMessageChunk( _any_id_ai_message_chunk(
content="", content="",
additional_kwargs={ additional_kwargs={
"function_call": {"arguments": '{\n "source_path": "foo"'}, "function_call": {"arguments": '{\n "source_path": "foo"'},
}, },
), ),
_AnyIdAIMessageChunk( _any_id_ai_message_chunk(
content="", additional_kwargs={"function_call": {"arguments": ","}} content="", additional_kwargs={"function_call": {"arguments": ","}}
), ),
_AnyIdAIMessageChunk( _any_id_ai_message_chunk(
content="", content="",
additional_kwargs={ additional_kwargs={
"function_call": {"arguments": '\n "destination_path": "bar"\n}'}, "function_call": {"arguments": '\n "destination_path": "bar"\n}'},
@ -138,9 +138,9 @@ async def test_generic_fake_chat_model_astream_log() -> None:
] ]
final = log_patches[-1] final = log_patches[-1]
assert final.state["streamed_output"] == [ assert final.state["streamed_output"] == [
_AnyIdAIMessageChunk(content="hello"), _any_id_ai_message_chunk(content="hello"),
_AnyIdAIMessageChunk(content=" "), _any_id_ai_message_chunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"), _any_id_ai_message_chunk(content="goodbye"),
] ]
assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1 assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1
@ -189,9 +189,9 @@ async def test_callback_handlers() -> None:
# New model # New model
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]})) results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
assert results == [ assert results == [
_AnyIdAIMessageChunk(content="hello"), _any_id_ai_message_chunk(content="hello"),
_AnyIdAIMessageChunk(content=" "), _any_id_ai_message_chunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"), _any_id_ai_message_chunk(content="goodbye"),
] ]
assert tokens == ["hello", " ", "goodbye"] assert tokens == ["hello", " ", "goodbye"]
assert len({chunk.id for chunk in results}) == 1 assert len({chunk.id for chunk in results}) == 1
@ -200,6 +200,8 @@ async def test_callback_handlers() -> None:
def test_chat_model_inputs() -> None: def test_chat_model_inputs() -> None:
fake = ParrotFakeChatModel() fake = ParrotFakeChatModel()
assert fake.invoke("hello") == _AnyIdHumanMessage(content="hello") assert fake.invoke("hello") == _any_id_human_message(content="hello")
assert fake.invoke([("ai", "blah")]) == _AnyIdAIMessage(content="blah") assert fake.invoke([("ai", "blah")]) == _any_id_ai_message(content="blah")
assert fake.invoke([AIMessage(content="blah")]) == _AnyIdAIMessage(content="blah") assert fake.invoke([AIMessage(content="blah")]) == _any_id_ai_message(
content="blah"
)

View File

@ -27,7 +27,7 @@ from tests.unit_tests.fake.callbacks import (
FakeAsyncCallbackHandler, FakeAsyncCallbackHandler,
FakeCallbackHandler, FakeCallbackHandler,
) )
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk
@pytest.fixture @pytest.fixture
@ -147,10 +147,10 @@ async def test_astream_fallback_to_ainvoke() -> None:
model = ModelWithGenerate() model = ModelWithGenerate()
chunks = [chunk for chunk in model.stream("anything")] chunks = [chunk for chunk in model.stream("anything")]
assert chunks == [_AnyIdAIMessage(content="hello")] assert chunks == [_any_id_ai_message(content="hello")]
chunks = [chunk async for chunk in model.astream("anything")] chunks = [chunk async for chunk in model.astream("anything")]
assert chunks == [_AnyIdAIMessage(content="hello")] assert chunks == [_any_id_ai_message(content="hello")]
async def test_astream_implementation_fallback_to_stream() -> None: async def test_astream_implementation_fallback_to_stream() -> None:
@ -185,15 +185,15 @@ async def test_astream_implementation_fallback_to_stream() -> None:
model = ModelWithSyncStream() model = ModelWithSyncStream()
chunks = [chunk for chunk in model.stream("anything")] chunks = [chunk for chunk in model.stream("anything")]
assert chunks == [ assert chunks == [
_AnyIdAIMessageChunk(content="a"), _any_id_ai_message_chunk(content="a"),
_AnyIdAIMessageChunk(content="b"), _any_id_ai_message_chunk(content="b"),
] ]
assert len({chunk.id for chunk in chunks}) == 1 assert len({chunk.id for chunk in chunks}) == 1
assert type(model)._astream == BaseChatModel._astream assert type(model)._astream == BaseChatModel._astream
astream_chunks = [chunk async for chunk in model.astream("anything")] astream_chunks = [chunk async for chunk in model.astream("anything")]
assert astream_chunks == [ assert astream_chunks == [
_AnyIdAIMessageChunk(content="a"), _any_id_ai_message_chunk(content="a"),
_AnyIdAIMessageChunk(content="b"), _any_id_ai_message_chunk(content="b"),
] ]
assert len({chunk.id for chunk in astream_chunks}) == 1 assert len({chunk.id for chunk in astream_chunks}) == 1
@ -230,8 +230,8 @@ async def test_astream_implementation_uses_astream() -> None:
model = ModelWithAsyncStream() model = ModelWithAsyncStream()
chunks = [chunk async for chunk in model.astream("anything")] chunks = [chunk async for chunk in model.astream("anything")]
assert chunks == [ assert chunks == [
_AnyIdAIMessageChunk(content="a"), _any_id_ai_message_chunk(content="a"),
_AnyIdAIMessageChunk(content="b"), _any_id_ai_message_chunk(content="b"),
] ]
assert len({chunk.id for chunk in chunks}) == 1 assert len({chunk.id for chunk in chunks}) == 1

View File

@ -25,7 +25,7 @@ def change_directory(dir: Path) -> Iterator:
os.chdir(origin) os.chdir(origin)
def test_loading_from_YAML() -> None: def test_loading_from_yaml() -> None:
"""Test loading from yaml file.""" """Test loading from yaml file."""
prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.yaml") prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.yaml")
expected_prompt = PromptTemplate( expected_prompt = PromptTemplate(
@ -36,7 +36,7 @@ def test_loading_from_YAML() -> None:
assert prompt == expected_prompt assert prompt == expected_prompt
def test_loading_from_JSON() -> None: def test_loading_from_json() -> None:
"""Test loading from json file.""" """Test loading from json file."""
prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.json") prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.json")
expected_prompt = PromptTemplate( expected_prompt = PromptTemplate(
@ -46,14 +46,14 @@ def test_loading_from_JSON() -> None:
assert prompt == expected_prompt assert prompt == expected_prompt
def test_loading_jinja_from_JSON() -> None: def test_loading_jinja_from_json() -> None:
"""Test that loading jinja2 format prompts from JSON raises ValueError.""" """Test that loading jinja2 format prompts from JSON raises ValueError."""
prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.json" prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.json"
with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"): with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"):
load_prompt(prompt_path) load_prompt(prompt_path)
def test_loading_jinja_from_YAML() -> None: def test_loading_jinja_from_yaml() -> None:
"""Test that loading jinja2 format prompts from YAML raises ValueError.""" """Test that loading jinja2 format prompts from YAML raises ValueError."""
prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.yaml" prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.yaml"
with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"): with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"):

View File

@ -2,6 +2,7 @@ from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
from syrupy import SnapshotAssertion from syrupy import SnapshotAssertion
from typing_extensions import override
from langchain_core.language_models import FakeListLLM from langchain_core.language_models import FakeListLLM
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
@ -353,9 +354,11 @@ def test_runnable_get_graph_with_invalid_input_type() -> None:
class InvalidInputTypeRunnable(Runnable[int, int]): class InvalidInputTypeRunnable(Runnable[int, int]):
@property @property
@override
def InputType(self) -> type: def InputType(self) -> type:
raise TypeError() raise TypeError()
@override
def invoke( def invoke(
self, self,
input: int, input: int,
@ -375,9 +378,11 @@ def test_runnable_get_graph_with_invalid_output_type() -> None:
class InvalidOutputTypeRunnable(Runnable[int, int]): class InvalidOutputTypeRunnable(Runnable[int, int]):
@property @property
@override
def OutputType(self) -> type: def OutputType(self) -> type:
raise TypeError() raise TypeError()
@override
def invoke( def invoke(
self, self,
input: int, input: int,

View File

@ -493,7 +493,7 @@ def test_get_output_schema() -> None:
def test_get_input_schema_input_messages() -> None: def test_get_input_schema_input_messages() -> None:
from pydantic import RootModel from pydantic import RootModel
RunnableWithMessageHistoryInput = RootModel[Sequence[BaseMessage]] runnable_with_message_history_input = RootModel[Sequence[BaseMessage]]
runnable = RunnableLambda( runnable = RunnableLambda(
lambda messages: { lambda messages: {
@ -515,7 +515,7 @@ def test_get_input_schema_input_messages() -> None:
with_history = RunnableWithMessageHistory( with_history = RunnableWithMessageHistory(
runnable, get_session_history, output_messages_key="output" runnable, get_session_history, output_messages_key="output"
) )
expected_schema = _schema(RunnableWithMessageHistoryInput) expected_schema = _schema(runnable_with_message_history_input)
expected_schema["title"] = "RunnableWithChatHistoryInput" expected_schema["title"] = "RunnableWithChatHistoryInput"
assert _schema(with_history.get_input_schema()) == expected_schema assert _schema(with_history.get_input_schema()) == expected_schema

View File

@ -87,7 +87,7 @@ from langchain_core.tracers import (
) )
from langchain_core.tracers.context import collect_runs from langchain_core.tracers.context import collect_runs
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk
PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split("."))) PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split(".")))
@ -1699,7 +1699,7 @@ def test_prompt_with_chat_model(
tracer = FakeTracer() tracer = FakeTracer()
assert chain.invoke( assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer]) {"question": "What is your name?"}, dict(callbacks=[tracer])
) == _AnyIdAIMessage(content="foo") ) == _any_id_ai_message(content="foo")
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue( assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[ messages=[
@ -1724,8 +1724,8 @@ def test_prompt_with_chat_model(
], ],
dict(callbacks=[tracer]), dict(callbacks=[tracer]),
) == [ ) == [
_AnyIdAIMessage(content="foo"), _any_id_ai_message(content="foo"),
_AnyIdAIMessage(content="foo"), _any_id_ai_message(content="foo"),
] ]
assert prompt_spy.call_args.args[1] == [ assert prompt_spy.call_args.args[1] == [
{"question": "What is your name?"}, {"question": "What is your name?"},
@ -1765,9 +1765,9 @@ def test_prompt_with_chat_model(
assert [ assert [
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer])) *chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
] == [ ] == [
_AnyIdAIMessageChunk(content="f"), _any_id_ai_message_chunk(content="f"),
_AnyIdAIMessageChunk(content="o"), _any_id_ai_message_chunk(content="o"),
_AnyIdAIMessageChunk(content="o"), _any_id_ai_message_chunk(content="o"),
] ]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue( assert chat_spy.call_args.args[1] == ChatPromptValue(
@ -1805,7 +1805,7 @@ async def test_prompt_with_chat_model_async(
tracer = FakeTracer() tracer = FakeTracer()
assert await chain.ainvoke( assert await chain.ainvoke(
{"question": "What is your name?"}, dict(callbacks=[tracer]) {"question": "What is your name?"}, dict(callbacks=[tracer])
) == _AnyIdAIMessage(content="foo") ) == _any_id_ai_message(content="foo")
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue( assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[ messages=[
@ -1830,8 +1830,8 @@ async def test_prompt_with_chat_model_async(
], ],
dict(callbacks=[tracer]), dict(callbacks=[tracer]),
) == [ ) == [
_AnyIdAIMessage(content="foo"), _any_id_ai_message(content="foo"),
_AnyIdAIMessage(content="foo"), _any_id_ai_message(content="foo"),
] ]
assert prompt_spy.call_args.args[1] == [ assert prompt_spy.call_args.args[1] == [
{"question": "What is your name?"}, {"question": "What is your name?"},
@ -1874,9 +1874,9 @@ async def test_prompt_with_chat_model_async(
{"question": "What is your name?"}, dict(callbacks=[tracer]) {"question": "What is your name?"}, dict(callbacks=[tracer])
) )
] == [ ] == [
_AnyIdAIMessageChunk(content="f"), _any_id_ai_message_chunk(content="f"),
_AnyIdAIMessageChunk(content="o"), _any_id_ai_message_chunk(content="o"),
_AnyIdAIMessageChunk(content="o"), _any_id_ai_message_chunk(content="o"),
] ]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue( assert chat_spy.call_args.args[1] == ChatPromptValue(
@ -2548,7 +2548,7 @@ def test_prompt_with_chat_model_and_parser(
HumanMessage(content="What is your name?"), HumanMessage(content="What is your name?"),
] ]
) )
assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar") assert parser_spy.call_args.args[1] == _any_id_ai_message(content="foo, bar")
assert tracer.runs == snapshot assert tracer.runs == snapshot
@ -2681,7 +2681,7 @@ Question:
), ),
] ]
) )
assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar") assert parser_spy.call_args.args[1] == _any_id_ai_message(content="foo, bar")
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1 assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
parent_run = next(r for r in tracer.runs if r.parent_run_id is None) parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 4 assert len(parent_run.child_runs) == 4
@ -2727,7 +2727,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
assert chain.invoke( assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer]) {"question": "What is your name?"}, dict(callbacks=[tracer])
) == { ) == {
"chat": _AnyIdAIMessage(content="i'm a chatbot"), "chat": _any_id_ai_message(content="i'm a chatbot"),
"llm": "i'm a textbot", "llm": "i'm a textbot",
} }
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
@ -2936,7 +2936,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
assert chain.invoke( assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer]) {"question": "What is your name?"}, dict(callbacks=[tracer])
) == { ) == {
"chat": _AnyIdAIMessage(content="i'm a chatbot"), "chat": _any_id_ai_message(content="i'm a chatbot"),
"llm": "i'm a textbot", "llm": "i'm a textbot",
"passthrough": ChatPromptValue( "passthrough": ChatPromptValue(
messages=[ messages=[
@ -3000,7 +3000,7 @@ def test_map_stream() -> None:
assert streamed_chunks[0] in [ assert streamed_chunks[0] in [
{"passthrough": prompt.invoke({"question": "What is your name?"})}, {"passthrough": prompt.invoke({"question": "What is your name?"})},
{"llm": "i"}, {"llm": "i"},
{"chat": _AnyIdAIMessageChunk(content="i")}, {"chat": _any_id_ai_message_chunk(content="i")},
] ]
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1 assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
assert all(len(c.keys()) == 1 for c in streamed_chunks) assert all(len(c.keys()) == 1 for c in streamed_chunks)
@ -3059,11 +3059,11 @@ def test_map_stream() -> None:
assert streamed_chunks[0] in [ assert streamed_chunks[0] in [
{"llm": "i"}, {"llm": "i"},
{"chat": _AnyIdAIMessageChunk(content="i")}, {"chat": _any_id_ai_message_chunk(content="i")},
] ]
if not ( # TODO(Rewrite properly) statement above if not ( # TODO(Rewrite properly) statement above
streamed_chunks[0] == {"llm": "i"} streamed_chunks[0] == {"llm": "i"}
or {"chat": _AnyIdAIMessageChunk(content="i")} or {"chat": _any_id_ai_message_chunk(content="i")}
): ):
raise AssertionError(f"Got an unexpected chunk: {streamed_chunks[0]}") raise AssertionError(f"Got an unexpected chunk: {streamed_chunks[0]}")
@ -3108,7 +3108,7 @@ def test_map_stream_iterator_input() -> None:
assert streamed_chunks[0] in [ assert streamed_chunks[0] in [
{"passthrough": "i"}, {"passthrough": "i"},
{"llm": "i"}, {"llm": "i"},
{"chat": _AnyIdAIMessageChunk(content="i")}, {"chat": _any_id_ai_message_chunk(content="i")},
] ]
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res) assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
assert all(len(c.keys()) == 1 for c in streamed_chunks) assert all(len(c.keys()) == 1 for c in streamed_chunks)
@ -3152,7 +3152,7 @@ async def test_map_astream() -> None:
assert streamed_chunks[0] in [ assert streamed_chunks[0] in [
{"passthrough": prompt.invoke({"question": "What is your name?"})}, {"passthrough": prompt.invoke({"question": "What is your name?"})},
{"llm": "i"}, {"llm": "i"},
{"chat": _AnyIdAIMessageChunk(content="i")}, {"chat": _any_id_ai_message_chunk(content="i")},
] ]
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1 assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
assert all(len(c.keys()) == 1 for c in streamed_chunks) assert all(len(c.keys()) == 1 for c in streamed_chunks)

View File

@ -32,7 +32,7 @@ from langchain_core.runnables import (
from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.schema import StreamEvent from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import tool from langchain_core.tools import tool
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]: def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]:
@ -503,7 +503,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, "data": {"chunk": _any_id_ai_message_chunk(content="hello")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": {"a": "b"}, "metadata": {"a": "b"},
"name": "my_model", "name": "my_model",
@ -512,7 +512,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, "data": {"chunk": _any_id_ai_message_chunk(content=" ")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": {"a": "b"}, "metadata": {"a": "b"},
"name": "my_model", "name": "my_model",
@ -521,7 +521,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, "data": {"chunk": _any_id_ai_message_chunk(content="world!")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": {"a": "b"}, "metadata": {"a": "b"},
"name": "my_model", "name": "my_model",
@ -530,7 +530,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"output": _AnyIdAIMessageChunk(content="hello world!")}, "data": {"output": _any_id_ai_message_chunk(content="hello world!")},
"event": "on_chat_model_end", "event": "on_chat_model_end",
"metadata": {"a": "b"}, "metadata": {"a": "b"},
"name": "my_model", "name": "my_model",
@ -575,7 +575,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, "data": {"chunk": _any_id_ai_message_chunk(content="hello")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": { "metadata": {
"a": "b", "a": "b",
@ -588,7 +588,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, "data": {"chunk": _any_id_ai_message_chunk(content=" ")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": { "metadata": {
"a": "b", "a": "b",
@ -601,7 +601,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, "data": {"chunk": _any_id_ai_message_chunk(content="world!")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": { "metadata": {
"a": "b", "a": "b",
@ -621,7 +621,9 @@ async def test_astream_events_from_model() -> None:
[ [
{ {
"generation_info": None, "generation_info": None,
"message": _AnyIdAIMessage(content="hello world!"), "message": _any_id_ai_message(
content="hello world!"
),
"text": "hello world!", "text": "hello world!",
"type": "ChatGeneration", "type": "ChatGeneration",
} }
@ -644,7 +646,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessage(content="hello world!")}, "data": {"chunk": _any_id_ai_message(content="hello world!")},
"event": "on_chain_stream", "event": "on_chain_stream",
"metadata": {}, "metadata": {},
"name": "i_dont_stream", "name": "i_dont_stream",
@ -653,7 +655,7 @@ async def test_astream_events_from_model() -> None:
"tags": [], "tags": [],
}, },
{ {
"data": {"output": _AnyIdAIMessage(content="hello world!")}, "data": {"output": _any_id_ai_message(content="hello world!")},
"event": "on_chain_end", "event": "on_chain_end",
"metadata": {}, "metadata": {},
"name": "i_dont_stream", "name": "i_dont_stream",
@ -698,7 +700,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, "data": {"chunk": _any_id_ai_message_chunk(content="hello")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": { "metadata": {
"a": "b", "a": "b",
@ -711,7 +713,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, "data": {"chunk": _any_id_ai_message_chunk(content=" ")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": { "metadata": {
"a": "b", "a": "b",
@ -724,7 +726,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, "data": {"chunk": _any_id_ai_message_chunk(content="world!")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": { "metadata": {
"a": "b", "a": "b",
@ -744,7 +746,9 @@ async def test_astream_events_from_model() -> None:
[ [
{ {
"generation_info": None, "generation_info": None,
"message": _AnyIdAIMessage(content="hello world!"), "message": _any_id_ai_message(
content="hello world!"
),
"text": "hello world!", "text": "hello world!",
"type": "ChatGeneration", "type": "ChatGeneration",
} }
@ -767,7 +771,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessage(content="hello world!")}, "data": {"chunk": _any_id_ai_message(content="hello world!")},
"event": "on_chain_stream", "event": "on_chain_stream",
"metadata": {}, "metadata": {},
"name": "ai_dont_stream", "name": "ai_dont_stream",
@ -776,7 +780,7 @@ async def test_astream_events_from_model() -> None:
"tags": [], "tags": [],
}, },
{ {
"data": {"output": _AnyIdAIMessage(content="hello world!")}, "data": {"output": _any_id_ai_message(content="hello world!")},
"event": "on_chain_end", "event": "on_chain_end",
"metadata": {}, "metadata": {},
"name": "ai_dont_stream", "name": "ai_dont_stream",

View File

@ -47,7 +47,7 @@ from langchain_core.utils.aiter import aclosing
from tests.unit_tests.runnables.test_runnable_events_v1 import ( from tests.unit_tests.runnables.test_runnable_events_v1 import (
_assert_events_equal_allow_superset_metadata, _assert_events_equal_allow_superset_metadata,
) )
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]: def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]:
@ -533,7 +533,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, "data": {"chunk": _any_id_ai_message_chunk(content="hello")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": { "metadata": {
"a": "b", "a": "b",
@ -546,7 +546,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, "data": {"chunk": _any_id_ai_message_chunk(content=" ")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": { "metadata": {
"a": "b", "a": "b",
@ -559,7 +559,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, "data": {"chunk": _any_id_ai_message_chunk(content="world!")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": { "metadata": {
"a": "b", "a": "b",
@ -573,7 +573,7 @@ async def test_astream_events_from_model() -> None:
}, },
{ {
"data": { "data": {
"output": _AnyIdAIMessageChunk(content="hello world!"), "output": _any_id_ai_message_chunk(content="hello world!"),
}, },
"event": "on_chat_model_end", "event": "on_chat_model_end",
"metadata": { "metadata": {
@ -640,7 +640,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, "data": {"chunk": _any_id_ai_message_chunk(content="hello")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": { "metadata": {
"a": "b", "a": "b",
@ -653,7 +653,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, "data": {"chunk": _any_id_ai_message_chunk(content=" ")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": { "metadata": {
"a": "b", "a": "b",
@ -666,7 +666,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, "data": {"chunk": _any_id_ai_message_chunk(content="world!")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": { "metadata": {
"a": "b", "a": "b",
@ -681,7 +681,7 @@ async def test_astream_with_model_in_chain() -> None:
{ {
"data": { "data": {
"input": {"messages": [[HumanMessage(content="hello")]]}, "input": {"messages": [[HumanMessage(content="hello")]]},
"output": _AnyIdAIMessage(content="hello world!"), "output": _any_id_ai_message(content="hello world!"),
}, },
"event": "on_chat_model_end", "event": "on_chat_model_end",
"metadata": { "metadata": {
@ -695,7 +695,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessage(content="hello world!")}, "data": {"chunk": _any_id_ai_message(content="hello world!")},
"event": "on_chain_stream", "event": "on_chain_stream",
"metadata": {}, "metadata": {},
"name": "i_dont_stream", "name": "i_dont_stream",
@ -704,7 +704,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": [], "tags": [],
}, },
{ {
"data": {"output": _AnyIdAIMessage(content="hello world!")}, "data": {"output": _any_id_ai_message(content="hello world!")},
"event": "on_chain_end", "event": "on_chain_end",
"metadata": {}, "metadata": {},
"name": "i_dont_stream", "name": "i_dont_stream",
@ -749,7 +749,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, "data": {"chunk": _any_id_ai_message_chunk(content="hello")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": { "metadata": {
"a": "b", "a": "b",
@ -762,7 +762,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, "data": {"chunk": _any_id_ai_message_chunk(content=" ")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": { "metadata": {
"a": "b", "a": "b",
@ -775,7 +775,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, "data": {"chunk": _any_id_ai_message_chunk(content="world!")},
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": { "metadata": {
"a": "b", "a": "b",
@ -790,7 +790,7 @@ async def test_astream_with_model_in_chain() -> None:
{ {
"data": { "data": {
"input": {"messages": [[HumanMessage(content="hello")]]}, "input": {"messages": [[HumanMessage(content="hello")]]},
"output": _AnyIdAIMessage(content="hello world!"), "output": _any_id_ai_message(content="hello world!"),
}, },
"event": "on_chat_model_end", "event": "on_chat_model_end",
"metadata": { "metadata": {
@ -804,7 +804,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"], "tags": ["my_model"],
}, },
{ {
"data": {"chunk": _AnyIdAIMessage(content="hello world!")}, "data": {"chunk": _any_id_ai_message(content="hello world!")},
"event": "on_chain_stream", "event": "on_chain_stream",
"metadata": {}, "metadata": {},
"name": "ai_dont_stream", "name": "ai_dont_stream",
@ -813,7 +813,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": [], "tags": [],
}, },
{ {
"data": {"output": _AnyIdAIMessage(content="hello world!")}, "data": {"output": _any_id_ai_message(content="hello world!")},
"event": "on_chain_end", "event": "on_chain_end",
"metadata": {}, "metadata": {},
"name": "ai_dont_stream", "name": "ai_dont_stream",

View File

@ -16,28 +16,28 @@ class AnyStr(str):
# subclassed strings. # subclassed strings.
def _AnyIdDocument(**kwargs: Any) -> Document: def _any_id_document(**kwargs: Any) -> Document:
"""Create a document with an id field.""" """Create a document with an id field."""
message = Document(**kwargs) message = Document(**kwargs)
message.id = AnyStr() message.id = AnyStr()
return message return message
def _AnyIdAIMessage(**kwargs: Any) -> AIMessage: def _any_id_ai_message(**kwargs: Any) -> AIMessage:
"""Create ai message with an any id field.""" """Create ai message with an any id field."""
message = AIMessage(**kwargs) message = AIMessage(**kwargs)
message.id = AnyStr() message.id = AnyStr()
return message return message
def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk: def _any_id_ai_message_chunk(**kwargs: Any) -> AIMessageChunk:
"""Create ai message with an any id field.""" """Create ai message with an any id field."""
message = AIMessageChunk(**kwargs) message = AIMessageChunk(**kwargs)
message.id = AnyStr() message.id = AnyStr()
return message return message
def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage: def _any_id_human_message(**kwargs: Any) -> HumanMessage:
"""Create a human with an any id field.""" """Create a human with an any id field."""
message = HumanMessage(**kwargs) message = HumanMessage(**kwargs)
message.id = AnyStr() message.id = AnyStr()

View File

@ -781,7 +781,7 @@ def test_convert_to_messages() -> None:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"MessageClass", "message_class",
[ [
AIMessage, AIMessage,
AIMessageChunk, AIMessageChunk,
@ -790,39 +790,39 @@ def test_convert_to_messages() -> None:
SystemMessage, SystemMessage,
], ],
) )
def test_message_name(MessageClass: type) -> None: def test_message_name(message_class: type) -> None:
msg = MessageClass(content="foo", name="bar") msg = message_class(content="foo", name="bar")
assert msg.name == "bar" assert msg.name == "bar"
msg2 = MessageClass(content="foo", name=None) msg2 = message_class(content="foo", name=None)
assert msg2.name is None assert msg2.name is None
msg3 = MessageClass(content="foo") msg3 = message_class(content="foo")
assert msg3.name is None assert msg3.name is None
@pytest.mark.parametrize( @pytest.mark.parametrize(
"MessageClass", "message_class",
[FunctionMessage, FunctionMessageChunk], [FunctionMessage, FunctionMessageChunk],
) )
def test_message_name_function(MessageClass: type) -> None: def test_message_name_function(message_class: type) -> None:
# functionmessage doesn't support name=None # functionmessage doesn't support name=None
msg = MessageClass(name="foo", content="bar") msg = message_class(name="foo", content="bar")
assert msg.name == "foo" assert msg.name == "foo"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"MessageClass", "message_class",
[ChatMessage, ChatMessageChunk], [ChatMessage, ChatMessageChunk],
) )
def test_message_name_chat(MessageClass: type) -> None: def test_message_name_chat(message_class: type) -> None:
msg = MessageClass(content="foo", role="user", name="bar") msg = message_class(content="foo", role="user", name="bar")
assert msg.name == "bar" assert msg.name == "bar"
msg2 = MessageClass(content="foo", role="user", name=None) msg2 = message_class(content="foo", role="user", name=None)
assert msg2.name is None assert msg2.name is None
msg3 = MessageClass(content="foo", role="user") msg3 = message_class(content="foo", role="user")
assert msg3.name is None assert msg3.name is None

View File

@ -51,14 +51,14 @@ def test_serde_any_message() -> None:
), ),
] ]
Model = RootModel[AnyMessage] model = RootModel[AnyMessage]
for lc_object in lc_objects: for lc_object in lc_objects:
d = lc_object.model_dump() d = lc_object.model_dump()
assert "type" in d, f"Missing key `type` for {type(lc_object)}" assert "type" in d, f"Missing key `type` for {type(lc_object)}"
obj1 = Model.model_validate(d) obj1 = model.model_validate(d)
assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}" assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}"
with pytest.raises((TypeError, ValidationError)): with pytest.raises((TypeError, ValidationError)):
# Make sure that specifically validation error is raised # Make sure that specifically validation error is raised
Model.model_validate({}) model.model_validate({})

View File

@ -1421,7 +1421,7 @@ class InjectedTool(BaseTool):
return y return y
class fooSchema(BaseModel): class fooSchema(BaseModel): # noqa: N801
"""foo.""" """foo."""
x: int = Field(..., description="abc") x: int = Field(..., description="abc")
@ -1568,14 +1568,14 @@ def test_tool_injected_arg() -> None:
def test_tool_inherited_injected_arg() -> None: def test_tool_inherited_injected_arg() -> None:
class barSchema(BaseModel): class BarSchema(BaseModel):
"""bar.""" """bar."""
y: Annotated[str, "foobar comment", InjectedToolArg()] = Field( y: Annotated[str, "foobar comment", InjectedToolArg()] = Field(
..., description="123" ..., description="123"
) )
class fooSchema(barSchema): class FooSchema(BarSchema):
"""foo.""" """foo."""
x: int = Field(..., description="abc") x: int = Field(..., description="abc")
@ -1583,14 +1583,14 @@ def test_tool_inherited_injected_arg() -> None:
class InheritedInjectedArgTool(BaseTool): class InheritedInjectedArgTool(BaseTool):
name: str = "foo" name: str = "foo"
description: str = "foo." description: str = "foo."
args_schema: type[BaseModel] = fooSchema args_schema: type[BaseModel] = FooSchema
def _run(self, x: int, y: str) -> Any: def _run(self, x: int, y: str) -> Any:
return y return y
tool_ = InheritedInjectedArgTool() tool_ = InheritedInjectedArgTool()
assert tool_.get_input_schema().model_json_schema() == { assert tool_.get_input_schema().model_json_schema() == {
"title": "fooSchema", # Matches the title from the provided schema "title": "FooSchema", # Matches the title from the provided schema
"description": "foo.", "description": "foo.",
"type": "object", "type": "object",
"properties": { "properties": {
@ -1877,15 +1877,15 @@ def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None:
A = TypeVar("A") A = TypeVar("A")
if use_v1_namespace: if use_v1_namespace:
from pydantic.v1 import BaseModel as BM1 from pydantic.v1 import BaseModel as BaseModel1
class ModelA(BM1, Generic[A], extra="allow"): class ModelA(BaseModel1, Generic[A], extra="allow"):
a: A a: A
else: else:
from pydantic import BaseModel as BM2 from pydantic import BaseModel as BaseModel2
from pydantic import ConfigDict from pydantic import ConfigDict
class ModelA(BM2, Generic[A]): # type: ignore[no-redef] class ModelA(BaseModel2, Generic[A]): # type: ignore[no-redef]
a: A a: A
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
@ -2081,13 +2081,13 @@ def test_structured_tool_direct_init() -> None:
def foo(bar: str) -> str: def foo(bar: str) -> str:
return bar return bar
async def asyncFoo(bar: str) -> str: async def async_foo(bar: str) -> str:
return bar return bar
class fooSchema(BaseModel): class FooSchema(BaseModel):
bar: str = Field(..., description="The bar") bar: str = Field(..., description="The bar")
tool = StructuredTool(name="foo", args_schema=fooSchema, coroutine=asyncFoo) tool = StructuredTool(name="foo", args_schema=FooSchema, coroutine=async_foo)
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
assert tool.invoke("hello") == "hello" assert tool.invoke("hello") == "hello"

View File

@ -38,7 +38,7 @@ from langchain_core.utils.function_calling import (
@pytest.fixture() @pytest.fixture()
def pydantic() -> type[BaseModel]: def pydantic() -> type[BaseModel]:
class dummy_function(BaseModel): class dummy_function(BaseModel): # noqa: N801
"""dummy function""" """dummy function"""
arg1: int = Field(..., description="foo") arg1: int = Field(..., description="foo")
@ -48,7 +48,7 @@ def pydantic() -> type[BaseModel]:
@pytest.fixture() @pytest.fixture()
def Annotated_function() -> Callable: def annotated_function() -> Callable:
def dummy_function( def dummy_function(
arg1: ExtensionsAnnotated[int, "foo"], arg1: ExtensionsAnnotated[int, "foo"],
arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"], arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"],
@ -118,7 +118,7 @@ def dummy_structured_tool() -> StructuredTool:
@pytest.fixture() @pytest.fixture()
def dummy_pydantic() -> type[BaseModel]: def dummy_pydantic() -> type[BaseModel]:
class dummy_function(BaseModel): class dummy_function(BaseModel): # noqa: N801
"""dummy function""" """dummy function"""
arg1: int = Field(..., description="foo") arg1: int = Field(..., description="foo")
@ -129,7 +129,7 @@ def dummy_pydantic() -> type[BaseModel]:
@pytest.fixture() @pytest.fixture()
def dummy_pydantic_v2() -> type[BaseModelV2Maybe]: def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
class dummy_function(BaseModelV2Maybe): class dummy_function(BaseModelV2Maybe): # noqa: N801
"""dummy function""" """dummy function"""
arg1: int = FieldV2Maybe(..., description="foo") arg1: int = FieldV2Maybe(..., description="foo")
@ -142,7 +142,7 @@ def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
@pytest.fixture() @pytest.fixture()
def dummy_typing_typed_dict() -> type: def dummy_typing_typed_dict() -> type:
class dummy_function(TypingTypedDict): class dummy_function(TypingTypedDict): # noqa: N801
"""dummy function""" """dummy function"""
arg1: TypingAnnotated[int, ..., "foo"] # noqa: F821 arg1: TypingAnnotated[int, ..., "foo"] # noqa: F821
@ -153,7 +153,7 @@ def dummy_typing_typed_dict() -> type:
@pytest.fixture() @pytest.fixture()
def dummy_typing_typed_dict_docstring() -> type: def dummy_typing_typed_dict_docstring() -> type:
class dummy_function(TypingTypedDict): class dummy_function(TypingTypedDict): # noqa: N801
"""dummy function """dummy function
Args: Args:
@ -169,7 +169,7 @@ def dummy_typing_typed_dict_docstring() -> type:
@pytest.fixture() @pytest.fixture()
def dummy_extensions_typed_dict() -> type: def dummy_extensions_typed_dict() -> type:
class dummy_function(ExtensionsTypedDict): class dummy_function(ExtensionsTypedDict): # noqa: N801
"""dummy function""" """dummy function"""
arg1: ExtensionsAnnotated[int, ..., "foo"] arg1: ExtensionsAnnotated[int, ..., "foo"]
@ -180,7 +180,7 @@ def dummy_extensions_typed_dict() -> type:
@pytest.fixture() @pytest.fixture()
def dummy_extensions_typed_dict_docstring() -> type: def dummy_extensions_typed_dict_docstring() -> type:
class dummy_function(ExtensionsTypedDict): class dummy_function(ExtensionsTypedDict): # noqa: N801
"""dummy function """dummy function
Args: Args:
@ -241,7 +241,7 @@ def test_convert_to_openai_function(
dummy_structured_tool: StructuredTool, dummy_structured_tool: StructuredTool,
dummy_tool: BaseTool, dummy_tool: BaseTool,
json_schema: dict, json_schema: dict,
Annotated_function: Callable, annotated_function: Callable,
dummy_pydantic: type[BaseModel], dummy_pydantic: type[BaseModel],
runnable: Runnable, runnable: Runnable,
dummy_typing_typed_dict: type, dummy_typing_typed_dict: type,
@ -275,7 +275,7 @@ def test_convert_to_openai_function(
expected, expected,
Dummy.dummy_function, Dummy.dummy_function,
DummyWithClassMethod.dummy_function, DummyWithClassMethod.dummy_function,
Annotated_function, annotated_function,
dummy_pydantic, dummy_pydantic,
dummy_typing_typed_dict, dummy_typing_typed_dict,
dummy_typing_typed_dict_docstring, dummy_typing_typed_dict_docstring,
@ -523,20 +523,20 @@ def test__convert_typed_dict_to_openai_function(
use_extension_typed_dict: bool, use_extension_annotated: bool use_extension_typed_dict: bool, use_extension_annotated: bool
) -> None: ) -> None:
if use_extension_typed_dict: if use_extension_typed_dict:
TypedDict = ExtensionsTypedDict typed_dict = ExtensionsTypedDict
else: else:
TypedDict = TypingTypedDict typed_dict = TypingTypedDict
if use_extension_annotated: if use_extension_annotated:
Annotated = TypingAnnotated annotated = TypingAnnotated
else: else:
Annotated = TypingAnnotated annotated = TypingAnnotated
class SubTool(TypedDict): class SubTool(typed_dict):
"""Subtool docstring""" """Subtool docstring"""
args: Annotated[dict[str, Any], {}, "this does bar"] # noqa: F722 # type: ignore args: annotated[dict[str, Any], {}, "this does bar"] # noqa: F722 # type: ignore
class Tool(TypedDict): class Tool(typed_dict):
"""Docstring """Docstring
Args: Args:
@ -546,20 +546,20 @@ def test__convert_typed_dict_to_openai_function(
arg1: str arg1: str
arg2: Union[int, str, bool] arg2: Union[int, str, bool]
arg3: Optional[list[SubTool]] arg3: Optional[list[SubTool]]
arg4: Annotated[Literal["bar", "baz"], ..., "this does foo"] # noqa: F722 arg4: annotated[Literal["bar", "baz"], ..., "this does foo"] # noqa: F722
arg5: Annotated[Optional[float], None] arg5: annotated[Optional[float], None]
arg6: Annotated[ arg6: annotated[
Optional[Sequence[Mapping[str, tuple[Iterable[Any], SubTool]]]], [] Optional[Sequence[Mapping[str, tuple[Iterable[Any], SubTool]]]], []
] ]
arg7: Annotated[list[SubTool], ...] arg7: annotated[list[SubTool], ...]
arg8: Annotated[tuple[SubTool], ...] arg8: annotated[tuple[SubTool], ...]
arg9: Annotated[Sequence[SubTool], ...] arg9: annotated[Sequence[SubTool], ...]
arg10: Annotated[Iterable[SubTool], ...] arg10: annotated[Iterable[SubTool], ...]
arg11: Annotated[set[SubTool], ...] arg11: annotated[set[SubTool], ...]
arg12: Annotated[dict[str, SubTool], ...] arg12: annotated[dict[str, SubTool], ...]
arg13: Annotated[Mapping[str, SubTool], ...] arg13: annotated[Mapping[str, SubTool], ...]
arg14: Annotated[MutableMapping[str, SubTool], ...] arg14: annotated[MutableMapping[str, SubTool], ...]
arg15: Annotated[bool, False, "flag"] # noqa: F821 # type: ignore arg15: annotated[bool, False, "flag"] # noqa: F821 # type: ignore
expected = { expected = {
"name": "Tool", "name": "Tool",

View File

@ -10,7 +10,7 @@ from langchain_standard_tests.integration_tests.vectorstores import (
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings.fake import DeterministicFakeEmbedding from langchain_core.embeddings.fake import DeterministicFakeEmbedding
from langchain_core.vectorstores import InMemoryVectorStore from langchain_core.vectorstores import InMemoryVectorStore
from tests.unit_tests.stubs import _AnyIdDocument from tests.unit_tests.stubs import _any_id_document
class TestInMemoryReadWriteTestSuite(ReadWriteTestSuite): class TestInMemoryReadWriteTestSuite(ReadWriteTestSuite):
@ -33,13 +33,13 @@ async def test_inmemory_similarity_search() -> None:
# Check sync version # Check sync version
output = store.similarity_search("foo", k=1) output = store.similarity_search("foo", k=1)
assert output == [_AnyIdDocument(page_content="foo")] assert output == [_any_id_document(page_content="foo")]
# Check async version # Check async version
output = await store.asimilarity_search("bar", k=2) output = await store.asimilarity_search("bar", k=2)
assert output == [ assert output == [
_AnyIdDocument(page_content="bar"), _any_id_document(page_content="bar"),
_AnyIdDocument(page_content="baz"), _any_id_document(page_content="baz"),
] ]
@ -80,16 +80,16 @@ async def test_inmemory_mmr() -> None:
# make sure we can k > docstore size # make sure we can k > docstore size
output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1) output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1)
assert len(output) == len(texts) assert len(output) == len(texts)
assert output[0] == _AnyIdDocument(page_content="foo") assert output[0] == _any_id_document(page_content="foo")
assert output[1] == _AnyIdDocument(page_content="foy") assert output[1] == _any_id_document(page_content="foy")
# Check async version # Check async version
output = await docsearch.amax_marginal_relevance_search( output = await docsearch.amax_marginal_relevance_search(
"foo", k=10, lambda_mult=0.1 "foo", k=10, lambda_mult=0.1
) )
assert len(output) == len(texts) assert len(output) == len(texts)
assert output[0] == _AnyIdDocument(page_content="foo") assert output[0] == _any_id_document(page_content="foo")
assert output[1] == _AnyIdDocument(page_content="foy") assert output[1] == _any_id_document(page_content="foy")
async def test_inmemory_dump_load(tmp_path: Path) -> None: async def test_inmemory_dump_load(tmp_path: Path) -> None:
@ -117,7 +117,7 @@ async def test_inmemory_filter() -> None:
# Check sync version # Check sync version
output = store.similarity_search("fee", filter=lambda doc: doc.metadata["id"] == 1) output = store.similarity_search("fee", filter=lambda doc: doc.metadata["id"] == 1)
assert output == [_AnyIdDocument(page_content="foo", metadata={"id": 1})] assert output == [_any_id_document(page_content="foo", metadata={"id": 1})]
# filter with not stored document id # filter with not stored document id
output = await store.asimilarity_search( output = await store.asimilarity_search(

View File

@ -50,6 +50,7 @@ class CustomAddTextsVectorstore(VectorStore):
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]: def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
return [self.store[id] for id in ids if id in self.store] return [self.store[id] for id in ids if id in self.store]
@classmethod
def from_texts( # type: ignore def from_texts( # type: ignore
cls, cls,
texts: list[str], texts: list[str],