mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-04 02:33:05 +00:00
core: Add N(naming) ruff rules (#25362)
Public classes/functions are not renamed and rule is ignored for them. Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
7835c0651f
commit
fd21ffe293
@ -155,7 +155,7 @@ def beta(
|
|||||||
_name = _name or obj.fget.__qualname__
|
_name = _name or obj.fget.__qualname__
|
||||||
old_doc = obj.__doc__
|
old_doc = obj.__doc__
|
||||||
|
|
||||||
class _beta_property(property):
|
class _BetaProperty(property):
|
||||||
"""A beta property."""
|
"""A beta property."""
|
||||||
|
|
||||||
def __init__(self, fget=None, fset=None, fdel=None, doc=None):
|
def __init__(self, fget=None, fset=None, fdel=None, doc=None):
|
||||||
@ -186,7 +186,7 @@ def beta(
|
|||||||
|
|
||||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any:
|
def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any:
|
||||||
"""Finalize the property."""
|
"""Finalize the property."""
|
||||||
return _beta_property(
|
return _BetaProperty(
|
||||||
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
|
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -264,7 +264,7 @@ def deprecated(
|
|||||||
_name = _name or cast(Union[type, Callable], obj.fget).__qualname__
|
_name = _name or cast(Union[type, Callable], obj.fget).__qualname__
|
||||||
old_doc = obj.__doc__
|
old_doc = obj.__doc__
|
||||||
|
|
||||||
class _deprecated_property(property):
|
class _DeprecatedProperty(property):
|
||||||
"""A deprecated property."""
|
"""A deprecated property."""
|
||||||
|
|
||||||
def __init__(self, fget=None, fset=None, fdel=None, doc=None): # type: ignore[no-untyped-def]
|
def __init__(self, fget=None, fset=None, fdel=None, doc=None): # type: ignore[no-untyped-def]
|
||||||
@ -297,7 +297,7 @@ def deprecated(
|
|||||||
"""Finalize the property."""
|
"""Finalize the property."""
|
||||||
return cast(
|
return cast(
|
||||||
T,
|
T,
|
||||||
_deprecated_property(
|
_DeprecatedProperty(
|
||||||
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
|
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
class LangChainException(Exception):
|
class LangChainException(Exception): # noqa: N818
|
||||||
"""General LangChain exception."""
|
"""General LangChain exception."""
|
||||||
|
|
||||||
|
|
||||||
@ -11,7 +11,7 @@ class TracerException(LangChainException):
|
|||||||
"""Base class for exceptions in tracers module."""
|
"""Base class for exceptions in tracers module."""
|
||||||
|
|
||||||
|
|
||||||
class OutputParserException(ValueError, LangChainException):
|
class OutputParserException(ValueError, LangChainException): # noqa: N818
|
||||||
"""Exception that output parsers should raise to signify a parsing error.
|
"""Exception that output parsers should raise to signify a parsing error.
|
||||||
|
|
||||||
This exists to differentiate parsing errors from other code or execution errors
|
This exists to differentiate parsing errors from other code or execution errors
|
||||||
|
@ -14,7 +14,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
from typing_extensions import TypeAlias, TypedDict
|
from typing_extensions import TypeAlias, TypedDict, override
|
||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
@ -143,6 +143,7 @@ class BaseLanguageModel(
|
|||||||
return verbose
|
return verbose
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def InputType(self) -> TypeAlias:
|
def InputType(self) -> TypeAlias:
|
||||||
"""Get the input type for this runnable."""
|
"""Get the input type for this runnable."""
|
||||||
from langchain_core.prompt_values import (
|
from langchain_core.prompt_values import (
|
||||||
|
@ -26,6 +26,7 @@ from pydantic import (
|
|||||||
Field,
|
Field,
|
||||||
model_validator,
|
model_validator,
|
||||||
)
|
)
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.caches import BaseCache
|
from langchain_core.caches import BaseCache
|
||||||
@ -251,6 +252,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
# --- Runnable methods ---
|
# --- Runnable methods ---
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def OutputType(self) -> Any:
|
def OutputType(self) -> Any:
|
||||||
"""Get the output type for this runnable."""
|
"""Get the output type for this runnable."""
|
||||||
return AnyMessage
|
return AnyMessage
|
||||||
|
@ -31,6 +31,7 @@ from tenacity import (
|
|||||||
stop_after_attempt,
|
stop_after_attempt,
|
||||||
wait_exponential,
|
wait_exponential,
|
||||||
)
|
)
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.caches import BaseCache
|
from langchain_core.caches import BaseCache
|
||||||
@ -318,6 +319,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
# --- Runnable methods ---
|
# --- Runnable methods ---
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def OutputType(self) -> type[str]:
|
def OutputType(self) -> type[str]:
|
||||||
"""Get the input type for this runnable."""
|
"""Get the input type for this runnable."""
|
||||||
return str
|
return str
|
||||||
|
@ -10,6 +10,8 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.language_models import LanguageModelOutput
|
from langchain_core.language_models import LanguageModelOutput
|
||||||
from langchain_core.messages import AnyMessage, BaseMessage
|
from langchain_core.messages import AnyMessage, BaseMessage
|
||||||
from langchain_core.outputs import ChatGeneration, Generation
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
@ -63,11 +65,13 @@ class BaseGenerationOutputParser(
|
|||||||
"""Base class to parse the output of an LLM call."""
|
"""Base class to parse the output of an LLM call."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def InputType(self) -> Any:
|
def InputType(self) -> Any:
|
||||||
"""Return the input type for the parser."""
|
"""Return the input type for the parser."""
|
||||||
return Union[str, AnyMessage]
|
return Union[str, AnyMessage]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def OutputType(self) -> type[T]:
|
def OutputType(self) -> type[T]:
|
||||||
"""Return the output type for the parser."""
|
"""Return the output type for the parser."""
|
||||||
# even though mypy complains this isn't valid,
|
# even though mypy complains this isn't valid,
|
||||||
@ -148,11 +152,13 @@ class BaseOutputParser(
|
|||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def InputType(self) -> Any:
|
def InputType(self) -> Any:
|
||||||
"""Return the input type for the parser."""
|
"""Return the input type for the parser."""
|
||||||
return Union[str, AnyMessage]
|
return Union[str, AnyMessage]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def OutputType(self) -> type[T]:
|
def OutputType(self) -> type[T]:
|
||||||
"""Return the output type for the parser.
|
"""Return the output type for the parser.
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ from typing import Annotated, Generic, Optional
|
|||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
from pydantic import SkipValidation
|
from pydantic import SkipValidation
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.output_parsers import JsonOutputParser
|
from langchain_core.output_parsers import JsonOutputParser
|
||||||
@ -107,6 +108,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
|||||||
return "pydantic"
|
return "pydantic"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def OutputType(self) -> type[TBaseModel]:
|
def OutputType(self) -> type[TBaseModel]:
|
||||||
"""Return the pydantic model."""
|
"""Return the pydantic model."""
|
||||||
return self.pydantic_object
|
return self.pydantic_object
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
import xml
|
import xml
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET # noqa: N817
|
||||||
from collections.abc import AsyncIterator, Iterator
|
from collections.abc import AsyncIterator, Iterator
|
||||||
from typing import Any, Literal, Optional, Union
|
from typing import Any, Literal, Optional, Union
|
||||||
from xml.etree.ElementTree import TreeBuilder
|
from xml.etree.ElementTree import TreeBuilder
|
||||||
@ -46,14 +46,14 @@ class _StreamingParser:
|
|||||||
"""
|
"""
|
||||||
if parser == "defusedxml":
|
if parser == "defusedxml":
|
||||||
try:
|
try:
|
||||||
from defusedxml import ElementTree as DET # type: ignore
|
import defusedxml # type: ignore
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"defusedxml is not installed. "
|
"defusedxml is not installed. "
|
||||||
"Please install it to use the defusedxml parser."
|
"Please install it to use the defusedxml parser."
|
||||||
"You can install it with `pip install defusedxml` "
|
"You can install it with `pip install defusedxml` "
|
||||||
) from e
|
) from e
|
||||||
_parser = DET.DefusedXMLParser(target=TreeBuilder())
|
_parser = defusedxml.ElementTree.DefusedXMLParser(target=TreeBuilder())
|
||||||
else:
|
else:
|
||||||
_parser = None
|
_parser = None
|
||||||
self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
|
self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
|
||||||
@ -189,7 +189,7 @@ class XMLOutputParser(BaseTransformOutputParser):
|
|||||||
# likely if you're reading this you can move them to the top of the file
|
# likely if you're reading this you can move them to the top of the file
|
||||||
if self.parser == "defusedxml":
|
if self.parser == "defusedxml":
|
||||||
try:
|
try:
|
||||||
from defusedxml import ElementTree as DET # type: ignore
|
import defusedxml # type: ignore
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"defusedxml is not installed. "
|
"defusedxml is not installed. "
|
||||||
@ -197,9 +197,9 @@ class XMLOutputParser(BaseTransformOutputParser):
|
|||||||
"You can install it with `pip install defusedxml`"
|
"You can install it with `pip install defusedxml`"
|
||||||
"See https://github.com/tiran/defusedxml for more details"
|
"See https://github.com/tiran/defusedxml for more details"
|
||||||
) from e
|
) from e
|
||||||
_ET = DET # Use the defusedxml parser
|
_et = defusedxml.ElementTree # Use the defusedxml parser
|
||||||
else:
|
else:
|
||||||
_ET = ET # Use the standard library parser
|
_et = ET # Use the standard library parser
|
||||||
|
|
||||||
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
|
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
|
||||||
if match is not None:
|
if match is not None:
|
||||||
|
@ -18,7 +18,7 @@ from typing import (
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self, override
|
||||||
|
|
||||||
from langchain_core.load import dumpd
|
from langchain_core.load import dumpd
|
||||||
from langchain_core.output_parsers.base import BaseOutputParser
|
from langchain_core.output_parsers.base import BaseOutputParser
|
||||||
@ -107,6 +107,7 @@ class BasePromptTemplate(
|
|||||||
return dumpd(self)
|
return dumpd(self)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def OutputType(self) -> Any:
|
def OutputType(self) -> Any:
|
||||||
"""Return the output type of the prompt."""
|
"""Return the output type of the prompt."""
|
||||||
return Union[StringPromptValue, ChatPromptValueConcrete]
|
return Union[StringPromptValue, ChatPromptValueConcrete]
|
||||||
|
@ -36,7 +36,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel
|
from pydantic import BaseModel, ConfigDict, Field, RootModel
|
||||||
from typing_extensions import Literal, get_args
|
from typing_extensions import Literal, get_args, override
|
||||||
|
|
||||||
from langchain_core._api import beta_decorator
|
from langchain_core._api import beta_decorator
|
||||||
from langchain_core.load.serializable import (
|
from langchain_core.load.serializable import (
|
||||||
@ -272,7 +272,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
return name_
|
return name_
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def InputType(self) -> type[Input]:
|
def InputType(self) -> type[Input]: # noqa: N802
|
||||||
"""The type of input this Runnable accepts specified as a type annotation."""
|
"""The type of input this Runnable accepts specified as a type annotation."""
|
||||||
# First loop through all parent classes and if any of them is
|
# First loop through all parent classes and if any of them is
|
||||||
# a pydantic model, we will pick up the generic parameterization
|
# a pydantic model, we will pick up the generic parameterization
|
||||||
@ -297,7 +297,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def OutputType(self) -> type[Output]:
|
def OutputType(self) -> type[Output]: # noqa: N802
|
||||||
"""The type of output this Runnable produces specified as a type annotation."""
|
"""The type of output this Runnable produces specified as a type annotation."""
|
||||||
# First loop through bases -- this will help generic
|
# First loop through bases -- this will help generic
|
||||||
# any pydantic models.
|
# any pydantic models.
|
||||||
@ -2811,11 +2811,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def InputType(self) -> type[Input]:
|
def InputType(self) -> type[Input]:
|
||||||
"""The type of the input to the Runnable."""
|
"""The type of the input to the Runnable."""
|
||||||
return self.first.InputType
|
return self.first.InputType
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def OutputType(self) -> type[Output]:
|
def OutputType(self) -> type[Output]:
|
||||||
"""The type of the output of the Runnable."""
|
"""The type of the output of the Runnable."""
|
||||||
return self.last.OutputType
|
return self.last.OutputType
|
||||||
@ -3564,6 +3566,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
|||||||
return super().get_name(suffix, name=name)
|
return super().get_name(suffix, name=name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def InputType(self) -> Any:
|
def InputType(self) -> Any:
|
||||||
"""The type of the input to the Runnable."""
|
"""The type of the input to the Runnable."""
|
||||||
for step in self.steps__.values():
|
for step in self.steps__.values():
|
||||||
@ -4057,6 +4060,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
|||||||
self.name = "RunnableGenerator"
|
self.name = "RunnableGenerator"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def InputType(self) -> Any:
|
def InputType(self) -> Any:
|
||||||
func = getattr(self, "_transform", None) or self._atransform
|
func = getattr(self, "_transform", None) or self._atransform
|
||||||
try:
|
try:
|
||||||
@ -4097,6 +4101,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def OutputType(self) -> Any:
|
def OutputType(self) -> Any:
|
||||||
func = getattr(self, "_transform", None) or self._atransform
|
func = getattr(self, "_transform", None) or self._atransform
|
||||||
try:
|
try:
|
||||||
@ -4346,6 +4351,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def InputType(self) -> Any:
|
def InputType(self) -> Any:
|
||||||
"""The type of the input to this Runnable."""
|
"""The type of the input to this Runnable."""
|
||||||
func = getattr(self, "func", None) or self.afunc
|
func = getattr(self, "func", None) or self.afunc
|
||||||
@ -4405,6 +4411,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
return super().get_input_schema(config)
|
return super().get_input_schema(config)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def OutputType(self) -> Any:
|
def OutputType(self) -> Any:
|
||||||
"""The type of the output of this Runnable as a type annotation.
|
"""The type of the output of this Runnable as a type annotation.
|
||||||
|
|
||||||
@ -4958,6 +4965,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def InputType(self) -> Any:
|
def InputType(self) -> Any:
|
||||||
return list[self.bound.InputType] # type: ignore[name-defined]
|
return list[self.bound.InputType] # type: ignore[name-defined]
|
||||||
|
|
||||||
@ -4981,6 +4989,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def OutputType(self) -> type[list[Output]]:
|
def OutputType(self) -> type[list[Output]]:
|
||||||
return list[self.bound.OutputType] # type: ignore[name-defined]
|
return list[self.bound.OutputType] # type: ignore[name-defined]
|
||||||
|
|
||||||
@ -5274,6 +5283,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
|||||||
return self.bound.get_name(suffix, name=name)
|
return self.bound.get_name(suffix, name=name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def InputType(self) -> type[Input]:
|
def InputType(self) -> type[Input]:
|
||||||
return (
|
return (
|
||||||
cast(type[Input], self.custom_input_type)
|
cast(type[Input], self.custom_input_type)
|
||||||
@ -5282,6 +5292,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def OutputType(self) -> type[Output]:
|
def OutputType(self) -> type[Output]:
|
||||||
return (
|
return (
|
||||||
cast(type[Output], self.custom_output_type)
|
cast(type[Output], self.custom_output_type)
|
||||||
|
@ -16,6 +16,7 @@ from typing import (
|
|||||||
from weakref import WeakValueDictionary
|
from weakref import WeakValueDictionary
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||||
from langchain_core.runnables.config import (
|
from langchain_core.runnables.config import (
|
||||||
@ -68,10 +69,12 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
return ["langchain", "schema", "runnable"]
|
return ["langchain", "schema", "runnable"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def InputType(self) -> type[Input]:
|
def InputType(self) -> type[Input]:
|
||||||
return self.default.InputType
|
return self.default.InputType
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def OutputType(self) -> type[Output]:
|
def OutputType(self) -> type[Output]:
|
||||||
return self.default.OutputType
|
return self.default.OutputType
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||||
from langchain_core.runnables.config import (
|
from langchain_core.runnables.config import (
|
||||||
@ -106,10 +107,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def InputType(self) -> type[Input]:
|
def InputType(self) -> type[Input]:
|
||||||
return self.runnable.InputType
|
return self.runnable.InputType
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def OutputType(self) -> type[Output]:
|
def OutputType(self) -> type[Output]:
|
||||||
return self.runnable.OutputType
|
return self.runnable.OutputType
|
||||||
|
|
||||||
|
@ -246,27 +246,27 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
|
|||||||
|
|
||||||
# NOTE: coordinates might me negative, so we need to shift
|
# NOTE: coordinates might me negative, so we need to shift
|
||||||
# everything to the positive plane before we actually draw it.
|
# everything to the positive plane before we actually draw it.
|
||||||
Xs = []
|
xlist = []
|
||||||
Ys = []
|
ylist = []
|
||||||
|
|
||||||
sug = _build_sugiyama_layout(vertices, edges)
|
sug = _build_sugiyama_layout(vertices, edges)
|
||||||
|
|
||||||
for vertex in sug.g.sV:
|
for vertex in sug.g.sV:
|
||||||
# NOTE: moving boxes w/2 to the left
|
# NOTE: moving boxes w/2 to the left
|
||||||
Xs.append(vertex.view.xy[0] - vertex.view.w / 2.0)
|
xlist.append(vertex.view.xy[0] - vertex.view.w / 2.0)
|
||||||
Xs.append(vertex.view.xy[0] + vertex.view.w / 2.0)
|
xlist.append(vertex.view.xy[0] + vertex.view.w / 2.0)
|
||||||
Ys.append(vertex.view.xy[1])
|
ylist.append(vertex.view.xy[1])
|
||||||
Ys.append(vertex.view.xy[1] + vertex.view.h)
|
ylist.append(vertex.view.xy[1] + vertex.view.h)
|
||||||
|
|
||||||
for edge in sug.g.sE:
|
for edge in sug.g.sE:
|
||||||
for x, y in edge.view._pts:
|
for x, y in edge.view._pts:
|
||||||
Xs.append(x)
|
xlist.append(x)
|
||||||
Ys.append(y)
|
ylist.append(y)
|
||||||
|
|
||||||
minx = min(Xs)
|
minx = min(xlist)
|
||||||
miny = min(Ys)
|
miny = min(ylist)
|
||||||
maxx = max(Xs)
|
maxx = max(xlist)
|
||||||
maxy = max(Ys)
|
maxy = max(ylist)
|
||||||
|
|
||||||
canvas_cols = int(math.ceil(math.ceil(maxx) - math.floor(minx))) + 1
|
canvas_cols = int(math.ceil(math.ceil(maxx) - math.floor(minx))) + 1
|
||||||
canvas_lines = int(round(maxy - miny))
|
canvas_lines = int(round(maxy - miny))
|
||||||
|
@ -12,6 +12,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.chat_history import BaseChatMessageHistory
|
from langchain_core.chat_history import BaseChatMessageHistory
|
||||||
from langchain_core.load.load import load
|
from langchain_core.load.load import load
|
||||||
@ -396,6 +397,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def OutputType(self) -> type[Output]:
|
def OutputType(self) -> type[Output]:
|
||||||
output_type = self._history_chain.OutputType
|
output_type = self._history_chain.OutputType
|
||||||
return output_type
|
return output_type
|
||||||
|
@ -16,6 +16,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel, RootModel
|
from pydantic import BaseModel, RootModel
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.runnables.base import (
|
from langchain_core.runnables.base import (
|
||||||
Other,
|
Other,
|
||||||
@ -193,10 +194,12 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
|||||||
return ["langchain", "schema", "runnable"]
|
return ["langchain", "schema", "runnable"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def InputType(self) -> Any:
|
def InputType(self) -> Any:
|
||||||
return self.input_type or Any
|
return self.input_type or Any
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def OutputType(self) -> Any:
|
def OutputType(self) -> Any:
|
||||||
return self.input_type or Any
|
return self.input_type or Any
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from typing_extensions import TypeGuard
|
from typing_extensions import TypeGuard, override
|
||||||
|
|
||||||
from langchain_core.runnables.schema import StreamEvent
|
from langchain_core.runnables.schema import StreamEvent
|
||||||
|
|
||||||
@ -135,6 +135,7 @@ class IsLocalDict(ast.NodeVisitor):
|
|||||||
self.name = name
|
self.name = name
|
||||||
self.keys = keys
|
self.keys = keys
|
||||||
|
|
||||||
|
@override
|
||||||
def visit_Subscript(self, node: ast.Subscript) -> Any:
|
def visit_Subscript(self, node: ast.Subscript) -> Any:
|
||||||
"""Visit a subscript node.
|
"""Visit a subscript node.
|
||||||
|
|
||||||
@ -154,6 +155,7 @@ class IsLocalDict(ast.NodeVisitor):
|
|||||||
# we've found a subscript access on the name we're looking for
|
# we've found a subscript access on the name we're looking for
|
||||||
self.keys.add(node.slice.value)
|
self.keys.add(node.slice.value)
|
||||||
|
|
||||||
|
@override
|
||||||
def visit_Call(self, node: ast.Call) -> Any:
|
def visit_Call(self, node: ast.Call) -> Any:
|
||||||
"""Visit a call node.
|
"""Visit a call node.
|
||||||
|
|
||||||
@ -182,6 +184,7 @@ class IsFunctionArgDict(ast.NodeVisitor):
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.keys: set[str] = set()
|
self.keys: set[str] = set()
|
||||||
|
|
||||||
|
@override
|
||||||
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
||||||
"""Visit a lambda function.
|
"""Visit a lambda function.
|
||||||
|
|
||||||
@ -196,6 +199,7 @@ class IsFunctionArgDict(ast.NodeVisitor):
|
|||||||
input_arg_name = node.args.args[0].arg
|
input_arg_name = node.args.args[0].arg
|
||||||
IsLocalDict(input_arg_name, self.keys).visit(node.body)
|
IsLocalDict(input_arg_name, self.keys).visit(node.body)
|
||||||
|
|
||||||
|
@override
|
||||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
||||||
"""Visit a function definition.
|
"""Visit a function definition.
|
||||||
|
|
||||||
@ -210,6 +214,7 @@ class IsFunctionArgDict(ast.NodeVisitor):
|
|||||||
input_arg_name = node.args.args[0].arg
|
input_arg_name = node.args.args[0].arg
|
||||||
IsLocalDict(input_arg_name, self.keys).visit(node)
|
IsLocalDict(input_arg_name, self.keys).visit(node)
|
||||||
|
|
||||||
|
@override
|
||||||
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
|
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
|
||||||
"""Visit an async function definition.
|
"""Visit an async function definition.
|
||||||
|
|
||||||
@ -232,6 +237,7 @@ class NonLocals(ast.NodeVisitor):
|
|||||||
self.loads: set[str] = set()
|
self.loads: set[str] = set()
|
||||||
self.stores: set[str] = set()
|
self.stores: set[str] = set()
|
||||||
|
|
||||||
|
@override
|
||||||
def visit_Name(self, node: ast.Name) -> Any:
|
def visit_Name(self, node: ast.Name) -> Any:
|
||||||
"""Visit a name node.
|
"""Visit a name node.
|
||||||
|
|
||||||
@ -246,6 +252,7 @@ class NonLocals(ast.NodeVisitor):
|
|||||||
elif isinstance(node.ctx, ast.Store):
|
elif isinstance(node.ctx, ast.Store):
|
||||||
self.stores.add(node.id)
|
self.stores.add(node.id)
|
||||||
|
|
||||||
|
@override
|
||||||
def visit_Attribute(self, node: ast.Attribute) -> Any:
|
def visit_Attribute(self, node: ast.Attribute) -> Any:
|
||||||
"""Visit an attribute node.
|
"""Visit an attribute node.
|
||||||
|
|
||||||
@ -272,6 +279,7 @@ class FunctionNonLocals(ast.NodeVisitor):
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.nonlocals: set[str] = set()
|
self.nonlocals: set[str] = set()
|
||||||
|
|
||||||
|
@override
|
||||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
||||||
"""Visit a function definition.
|
"""Visit a function definition.
|
||||||
|
|
||||||
@ -285,6 +293,7 @@ class FunctionNonLocals(ast.NodeVisitor):
|
|||||||
visitor.visit(node)
|
visitor.visit(node)
|
||||||
self.nonlocals.update(visitor.loads - visitor.stores)
|
self.nonlocals.update(visitor.loads - visitor.stores)
|
||||||
|
|
||||||
|
@override
|
||||||
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
|
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
|
||||||
"""Visit an async function definition.
|
"""Visit an async function definition.
|
||||||
|
|
||||||
@ -298,6 +307,7 @@ class FunctionNonLocals(ast.NodeVisitor):
|
|||||||
visitor.visit(node)
|
visitor.visit(node)
|
||||||
self.nonlocals.update(visitor.loads - visitor.stores)
|
self.nonlocals.update(visitor.loads - visitor.stores)
|
||||||
|
|
||||||
|
@override
|
||||||
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
||||||
"""Visit a lambda function.
|
"""Visit a lambda function.
|
||||||
|
|
||||||
@ -320,6 +330,7 @@ class GetLambdaSource(ast.NodeVisitor):
|
|||||||
self.source: Optional[str] = None
|
self.source: Optional[str] = None
|
||||||
self.count = 0
|
self.count = 0
|
||||||
|
|
||||||
|
@override
|
||||||
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
||||||
"""Visit a lambda function.
|
"""Visit a lambda function.
|
||||||
|
|
||||||
|
@ -311,7 +311,7 @@ def create_schema_from_function(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolException(Exception):
|
class ToolException(Exception): # noqa: N818
|
||||||
"""Optional exception that tool throws when execution error occurs.
|
"""Optional exception that tool throws when execution error occurs.
|
||||||
|
|
||||||
When this exception is thrown, the agent will not stop working,
|
When this exception is thrown, the agent will not stop working,
|
||||||
|
@ -9,7 +9,7 @@ def get_headers(*args: Any, **kwargs: Any) -> Any:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any:
|
def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: # noqa: N802
|
||||||
"""Throw an error because this has been replaced by LangChainTracer."""
|
"""Throw an error because this has been replaced by LangChainTracer."""
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"LangChainTracerV1 is no longer supported. Please use LangChainTracer instead."
|
"LangChainTracerV1 is no longer supported. Please use LangChainTracer instead."
|
||||||
|
@ -18,7 +18,7 @@ from langchain_core._api import deprecated
|
|||||||
|
|
||||||
|
|
||||||
@deprecated("0.1.0", alternative="Use string instead.", removal="1.0")
|
@deprecated("0.1.0", alternative="Use string instead.", removal="1.0")
|
||||||
def RunTypeEnum() -> type[RunTypeEnumDep]:
|
def RunTypeEnum() -> type[RunTypeEnumDep]: # noqa: N802
|
||||||
"""RunTypeEnum."""
|
"""RunTypeEnum."""
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"RunTypeEnum is deprecated. Please directly use a string instead"
|
"RunTypeEnum is deprecated. Please directly use a string instead"
|
||||||
|
@ -235,7 +235,7 @@ class Tee(Generic[T]):
|
|||||||
atee = Tee
|
atee = Tee
|
||||||
|
|
||||||
|
|
||||||
class aclosing(AbstractAsyncContextManager):
|
class aclosing(AbstractAsyncContextManager): # noqa: N801
|
||||||
"""Async context manager for safely finalizing an asynchronously cleaned-up
|
"""Async context manager for safely finalizing an asynchronously cleaned-up
|
||||||
resource such as an async generator, calling its ``aclose()`` method.
|
resource such as an async generator, calling its ``aclose()`` method.
|
||||||
|
|
||||||
|
@ -17,12 +17,12 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
|
def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray:
|
||||||
"""Row-wise cosine similarity between two equal-width matrices.
|
"""Row-wise cosine similarity between two equal-width matrices.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
X: A matrix of shape (n, m).
|
x: A matrix of shape (n, m).
|
||||||
Y: A matrix of shape (k, m).
|
y: A matrix of shape (k, m).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A matrix of shape (n, k) where each element (i, j) is the cosine similarity
|
A matrix of shape (n, k) where each element (i, j) is the cosine similarity
|
||||||
@ -40,33 +40,33 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
|
|||||||
"Please install numpy with `pip install numpy`."
|
"Please install numpy with `pip install numpy`."
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
if len(X) == 0 or len(Y) == 0:
|
if len(x) == 0 or len(y) == 0:
|
||||||
return np.array([])
|
return np.array([])
|
||||||
|
|
||||||
X = np.array(X)
|
x = np.array(x)
|
||||||
Y = np.array(Y)
|
y = np.array(y)
|
||||||
if X.shape[1] != Y.shape[1]:
|
if x.shape[1] != y.shape[1]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Number of columns in X and Y must be the same. X has shape {X.shape} "
|
f"Number of columns in X and Y must be the same. X has shape {x.shape} "
|
||||||
f"and Y has shape {Y.shape}."
|
f"and Y has shape {y.shape}."
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
import simsimd as simd # type: ignore[import-not-found]
|
import simsimd as simd # type: ignore[import-not-found]
|
||||||
|
|
||||||
X = np.array(X, dtype=np.float32)
|
x = np.array(x, dtype=np.float32)
|
||||||
Y = np.array(Y, dtype=np.float32)
|
y = np.array(y, dtype=np.float32)
|
||||||
Z = 1 - np.array(simd.cdist(X, Y, metric="cosine"))
|
z = 1 - np.array(simd.cdist(x, y, metric="cosine"))
|
||||||
return Z
|
return z
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Unable to import simsimd, defaulting to NumPy implementation. If you want "
|
"Unable to import simsimd, defaulting to NumPy implementation. If you want "
|
||||||
"to use simsimd please install with `pip install simsimd`."
|
"to use simsimd please install with `pip install simsimd`."
|
||||||
)
|
)
|
||||||
X_norm = np.linalg.norm(X, axis=1)
|
x_norm = np.linalg.norm(x, axis=1)
|
||||||
Y_norm = np.linalg.norm(Y, axis=1)
|
y_norm = np.linalg.norm(y, axis=1)
|
||||||
# Ignore divide by zero errors run time warnings as those are handled below.
|
# Ignore divide by zero errors run time warnings as those are handled below.
|
||||||
with np.errstate(divide="ignore", invalid="ignore"):
|
with np.errstate(divide="ignore", invalid="ignore"):
|
||||||
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
|
similarity = np.dot(x, y.T) / np.outer(x_norm, y_norm)
|
||||||
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
|
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
|
||||||
return similarity
|
return similarity
|
||||||
|
|
||||||
|
@ -44,9 +44,17 @@ python = ">=3.12.4"
|
|||||||
target-version = "py39"
|
target-version = "py39"
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ["B", "E", "F", "I", "T201", "UP"]
|
select = ["B", "E", "F", "I", "N", "T201", "UP"]
|
||||||
ignore = ["UP007"]
|
ignore = ["UP007"]
|
||||||
|
|
||||||
|
[tool.ruff.lint.pep8-naming]
|
||||||
|
classmethod-decorators = [
|
||||||
|
"classmethod",
|
||||||
|
"langchain_core.utils.pydantic.pre_init",
|
||||||
|
"pydantic.field_validator",
|
||||||
|
"pydantic.v1.root_validator",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.coverage.run]
|
[tool.coverage.run]
|
||||||
omit = [ "tests/*",]
|
omit = [ "tests/*",]
|
||||||
|
|
||||||
|
@ -9,9 +9,9 @@ from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatM
|
|||||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||||
from tests.unit_tests.stubs import (
|
from tests.unit_tests.stubs import (
|
||||||
_AnyIdAIMessage,
|
_any_id_ai_message,
|
||||||
_AnyIdAIMessageChunk,
|
_any_id_ai_message_chunk,
|
||||||
_AnyIdHumanMessage,
|
_any_id_human_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -20,11 +20,11 @@ def test_generic_fake_chat_model_invoke() -> None:
|
|||||||
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
||||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||||
response = model.invoke("meow")
|
response = model.invoke("meow")
|
||||||
assert response == _AnyIdAIMessage(content="hello")
|
assert response == _any_id_ai_message(content="hello")
|
||||||
response = model.invoke("kitty")
|
response = model.invoke("kitty")
|
||||||
assert response == _AnyIdAIMessage(content="goodbye")
|
assert response == _any_id_ai_message(content="goodbye")
|
||||||
response = model.invoke("meow")
|
response = model.invoke("meow")
|
||||||
assert response == _AnyIdAIMessage(content="hello")
|
assert response == _any_id_ai_message(content="hello")
|
||||||
|
|
||||||
|
|
||||||
async def test_generic_fake_chat_model_ainvoke() -> None:
|
async def test_generic_fake_chat_model_ainvoke() -> None:
|
||||||
@ -32,11 +32,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
|
|||||||
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
||||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||||
response = await model.ainvoke("meow")
|
response = await model.ainvoke("meow")
|
||||||
assert response == _AnyIdAIMessage(content="hello")
|
assert response == _any_id_ai_message(content="hello")
|
||||||
response = await model.ainvoke("kitty")
|
response = await model.ainvoke("kitty")
|
||||||
assert response == _AnyIdAIMessage(content="goodbye")
|
assert response == _any_id_ai_message(content="goodbye")
|
||||||
response = await model.ainvoke("meow")
|
response = await model.ainvoke("meow")
|
||||||
assert response == _AnyIdAIMessage(content="hello")
|
assert response == _any_id_ai_message(content="hello")
|
||||||
|
|
||||||
|
|
||||||
async def test_generic_fake_chat_model_stream() -> None:
|
async def test_generic_fake_chat_model_stream() -> None:
|
||||||
@ -49,17 +49,17 @@ async def test_generic_fake_chat_model_stream() -> None:
|
|||||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||||
chunks = [chunk async for chunk in model.astream("meow")]
|
chunks = [chunk async for chunk in model.astream("meow")]
|
||||||
assert chunks == [
|
assert chunks == [
|
||||||
_AnyIdAIMessageChunk(content="hello"),
|
_any_id_ai_message_chunk(content="hello"),
|
||||||
_AnyIdAIMessageChunk(content=" "),
|
_any_id_ai_message_chunk(content=" "),
|
||||||
_AnyIdAIMessageChunk(content="goodbye"),
|
_any_id_ai_message_chunk(content="goodbye"),
|
||||||
]
|
]
|
||||||
assert len({chunk.id for chunk in chunks}) == 1
|
assert len({chunk.id for chunk in chunks}) == 1
|
||||||
|
|
||||||
chunks = [chunk for chunk in model.stream("meow")]
|
chunks = [chunk for chunk in model.stream("meow")]
|
||||||
assert chunks == [
|
assert chunks == [
|
||||||
_AnyIdAIMessageChunk(content="hello"),
|
_any_id_ai_message_chunk(content="hello"),
|
||||||
_AnyIdAIMessageChunk(content=" "),
|
_any_id_ai_message_chunk(content=" "),
|
||||||
_AnyIdAIMessageChunk(content="goodbye"),
|
_any_id_ai_message_chunk(content="goodbye"),
|
||||||
]
|
]
|
||||||
assert len({chunk.id for chunk in chunks}) == 1
|
assert len({chunk.id for chunk in chunks}) == 1
|
||||||
|
|
||||||
@ -69,8 +69,8 @@ async def test_generic_fake_chat_model_stream() -> None:
|
|||||||
model = GenericFakeChatModel(messages=cycle([message]))
|
model = GenericFakeChatModel(messages=cycle([message]))
|
||||||
chunks = [chunk async for chunk in model.astream("meow")]
|
chunks = [chunk async for chunk in model.astream("meow")]
|
||||||
assert chunks == [
|
assert chunks == [
|
||||||
_AnyIdAIMessageChunk(content="", additional_kwargs={"foo": 42}),
|
_any_id_ai_message_chunk(content="", additional_kwargs={"foo": 42}),
|
||||||
_AnyIdAIMessageChunk(content="", additional_kwargs={"bar": 24}),
|
_any_id_ai_message_chunk(content="", additional_kwargs={"bar": 24}),
|
||||||
]
|
]
|
||||||
assert len({chunk.id for chunk in chunks}) == 1
|
assert len({chunk.id for chunk in chunks}) == 1
|
||||||
|
|
||||||
@ -88,19 +88,19 @@ async def test_generic_fake_chat_model_stream() -> None:
|
|||||||
chunks = [chunk async for chunk in model.astream("meow")]
|
chunks = [chunk async for chunk in model.astream("meow")]
|
||||||
|
|
||||||
assert chunks == [
|
assert chunks == [
|
||||||
_AnyIdAIMessageChunk(
|
_any_id_ai_message_chunk(
|
||||||
content="", additional_kwargs={"function_call": {"name": "move_file"}}
|
content="", additional_kwargs={"function_call": {"name": "move_file"}}
|
||||||
),
|
),
|
||||||
_AnyIdAIMessageChunk(
|
_any_id_ai_message_chunk(
|
||||||
content="",
|
content="",
|
||||||
additional_kwargs={
|
additional_kwargs={
|
||||||
"function_call": {"arguments": '{\n "source_path": "foo"'},
|
"function_call": {"arguments": '{\n "source_path": "foo"'},
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
_AnyIdAIMessageChunk(
|
_any_id_ai_message_chunk(
|
||||||
content="", additional_kwargs={"function_call": {"arguments": ","}}
|
content="", additional_kwargs={"function_call": {"arguments": ","}}
|
||||||
),
|
),
|
||||||
_AnyIdAIMessageChunk(
|
_any_id_ai_message_chunk(
|
||||||
content="",
|
content="",
|
||||||
additional_kwargs={
|
additional_kwargs={
|
||||||
"function_call": {"arguments": '\n "destination_path": "bar"\n}'},
|
"function_call": {"arguments": '\n "destination_path": "bar"\n}'},
|
||||||
@ -138,9 +138,9 @@ async def test_generic_fake_chat_model_astream_log() -> None:
|
|||||||
]
|
]
|
||||||
final = log_patches[-1]
|
final = log_patches[-1]
|
||||||
assert final.state["streamed_output"] == [
|
assert final.state["streamed_output"] == [
|
||||||
_AnyIdAIMessageChunk(content="hello"),
|
_any_id_ai_message_chunk(content="hello"),
|
||||||
_AnyIdAIMessageChunk(content=" "),
|
_any_id_ai_message_chunk(content=" "),
|
||||||
_AnyIdAIMessageChunk(content="goodbye"),
|
_any_id_ai_message_chunk(content="goodbye"),
|
||||||
]
|
]
|
||||||
assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1
|
assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1
|
||||||
|
|
||||||
@ -189,9 +189,9 @@ async def test_callback_handlers() -> None:
|
|||||||
# New model
|
# New model
|
||||||
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
|
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
|
||||||
assert results == [
|
assert results == [
|
||||||
_AnyIdAIMessageChunk(content="hello"),
|
_any_id_ai_message_chunk(content="hello"),
|
||||||
_AnyIdAIMessageChunk(content=" "),
|
_any_id_ai_message_chunk(content=" "),
|
||||||
_AnyIdAIMessageChunk(content="goodbye"),
|
_any_id_ai_message_chunk(content="goodbye"),
|
||||||
]
|
]
|
||||||
assert tokens == ["hello", " ", "goodbye"]
|
assert tokens == ["hello", " ", "goodbye"]
|
||||||
assert len({chunk.id for chunk in results}) == 1
|
assert len({chunk.id for chunk in results}) == 1
|
||||||
@ -200,6 +200,8 @@ async def test_callback_handlers() -> None:
|
|||||||
def test_chat_model_inputs() -> None:
|
def test_chat_model_inputs() -> None:
|
||||||
fake = ParrotFakeChatModel()
|
fake = ParrotFakeChatModel()
|
||||||
|
|
||||||
assert fake.invoke("hello") == _AnyIdHumanMessage(content="hello")
|
assert fake.invoke("hello") == _any_id_human_message(content="hello")
|
||||||
assert fake.invoke([("ai", "blah")]) == _AnyIdAIMessage(content="blah")
|
assert fake.invoke([("ai", "blah")]) == _any_id_ai_message(content="blah")
|
||||||
assert fake.invoke([AIMessage(content="blah")]) == _AnyIdAIMessage(content="blah")
|
assert fake.invoke([AIMessage(content="blah")]) == _any_id_ai_message(
|
||||||
|
content="blah"
|
||||||
|
)
|
||||||
|
@ -27,7 +27,7 @@ from tests.unit_tests.fake.callbacks import (
|
|||||||
FakeAsyncCallbackHandler,
|
FakeAsyncCallbackHandler,
|
||||||
FakeCallbackHandler,
|
FakeCallbackHandler,
|
||||||
)
|
)
|
||||||
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
|
from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -147,10 +147,10 @@ async def test_astream_fallback_to_ainvoke() -> None:
|
|||||||
|
|
||||||
model = ModelWithGenerate()
|
model = ModelWithGenerate()
|
||||||
chunks = [chunk for chunk in model.stream("anything")]
|
chunks = [chunk for chunk in model.stream("anything")]
|
||||||
assert chunks == [_AnyIdAIMessage(content="hello")]
|
assert chunks == [_any_id_ai_message(content="hello")]
|
||||||
|
|
||||||
chunks = [chunk async for chunk in model.astream("anything")]
|
chunks = [chunk async for chunk in model.astream("anything")]
|
||||||
assert chunks == [_AnyIdAIMessage(content="hello")]
|
assert chunks == [_any_id_ai_message(content="hello")]
|
||||||
|
|
||||||
|
|
||||||
async def test_astream_implementation_fallback_to_stream() -> None:
|
async def test_astream_implementation_fallback_to_stream() -> None:
|
||||||
@ -185,15 +185,15 @@ async def test_astream_implementation_fallback_to_stream() -> None:
|
|||||||
model = ModelWithSyncStream()
|
model = ModelWithSyncStream()
|
||||||
chunks = [chunk for chunk in model.stream("anything")]
|
chunks = [chunk for chunk in model.stream("anything")]
|
||||||
assert chunks == [
|
assert chunks == [
|
||||||
_AnyIdAIMessageChunk(content="a"),
|
_any_id_ai_message_chunk(content="a"),
|
||||||
_AnyIdAIMessageChunk(content="b"),
|
_any_id_ai_message_chunk(content="b"),
|
||||||
]
|
]
|
||||||
assert len({chunk.id for chunk in chunks}) == 1
|
assert len({chunk.id for chunk in chunks}) == 1
|
||||||
assert type(model)._astream == BaseChatModel._astream
|
assert type(model)._astream == BaseChatModel._astream
|
||||||
astream_chunks = [chunk async for chunk in model.astream("anything")]
|
astream_chunks = [chunk async for chunk in model.astream("anything")]
|
||||||
assert astream_chunks == [
|
assert astream_chunks == [
|
||||||
_AnyIdAIMessageChunk(content="a"),
|
_any_id_ai_message_chunk(content="a"),
|
||||||
_AnyIdAIMessageChunk(content="b"),
|
_any_id_ai_message_chunk(content="b"),
|
||||||
]
|
]
|
||||||
assert len({chunk.id for chunk in astream_chunks}) == 1
|
assert len({chunk.id for chunk in astream_chunks}) == 1
|
||||||
|
|
||||||
@ -230,8 +230,8 @@ async def test_astream_implementation_uses_astream() -> None:
|
|||||||
model = ModelWithAsyncStream()
|
model = ModelWithAsyncStream()
|
||||||
chunks = [chunk async for chunk in model.astream("anything")]
|
chunks = [chunk async for chunk in model.astream("anything")]
|
||||||
assert chunks == [
|
assert chunks == [
|
||||||
_AnyIdAIMessageChunk(content="a"),
|
_any_id_ai_message_chunk(content="a"),
|
||||||
_AnyIdAIMessageChunk(content="b"),
|
_any_id_ai_message_chunk(content="b"),
|
||||||
]
|
]
|
||||||
assert len({chunk.id for chunk in chunks}) == 1
|
assert len({chunk.id for chunk in chunks}) == 1
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ def change_directory(dir: Path) -> Iterator:
|
|||||||
os.chdir(origin)
|
os.chdir(origin)
|
||||||
|
|
||||||
|
|
||||||
def test_loading_from_YAML() -> None:
|
def test_loading_from_yaml() -> None:
|
||||||
"""Test loading from yaml file."""
|
"""Test loading from yaml file."""
|
||||||
prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.yaml")
|
prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.yaml")
|
||||||
expected_prompt = PromptTemplate(
|
expected_prompt = PromptTemplate(
|
||||||
@ -36,7 +36,7 @@ def test_loading_from_YAML() -> None:
|
|||||||
assert prompt == expected_prompt
|
assert prompt == expected_prompt
|
||||||
|
|
||||||
|
|
||||||
def test_loading_from_JSON() -> None:
|
def test_loading_from_json() -> None:
|
||||||
"""Test loading from json file."""
|
"""Test loading from json file."""
|
||||||
prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.json")
|
prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.json")
|
||||||
expected_prompt = PromptTemplate(
|
expected_prompt = PromptTemplate(
|
||||||
@ -46,14 +46,14 @@ def test_loading_from_JSON() -> None:
|
|||||||
assert prompt == expected_prompt
|
assert prompt == expected_prompt
|
||||||
|
|
||||||
|
|
||||||
def test_loading_jinja_from_JSON() -> None:
|
def test_loading_jinja_from_json() -> None:
|
||||||
"""Test that loading jinja2 format prompts from JSON raises ValueError."""
|
"""Test that loading jinja2 format prompts from JSON raises ValueError."""
|
||||||
prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.json"
|
prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.json"
|
||||||
with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"):
|
with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"):
|
||||||
load_prompt(prompt_path)
|
load_prompt(prompt_path)
|
||||||
|
|
||||||
|
|
||||||
def test_loading_jinja_from_YAML() -> None:
|
def test_loading_jinja_from_yaml() -> None:
|
||||||
"""Test that loading jinja2 format prompts from YAML raises ValueError."""
|
"""Test that loading jinja2 format prompts from YAML raises ValueError."""
|
||||||
prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.yaml"
|
prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.yaml"
|
||||||
with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"):
|
with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"):
|
||||||
|
@ -2,6 +2,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from syrupy import SnapshotAssertion
|
from syrupy import SnapshotAssertion
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.language_models import FakeListLLM
|
from langchain_core.language_models import FakeListLLM
|
||||||
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
|
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
|
||||||
@ -353,9 +354,11 @@ def test_runnable_get_graph_with_invalid_input_type() -> None:
|
|||||||
|
|
||||||
class InvalidInputTypeRunnable(Runnable[int, int]):
|
class InvalidInputTypeRunnable(Runnable[int, int]):
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def InputType(self) -> type:
|
def InputType(self) -> type:
|
||||||
raise TypeError()
|
raise TypeError()
|
||||||
|
|
||||||
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
input: int,
|
input: int,
|
||||||
@ -375,9 +378,11 @@ def test_runnable_get_graph_with_invalid_output_type() -> None:
|
|||||||
|
|
||||||
class InvalidOutputTypeRunnable(Runnable[int, int]):
|
class InvalidOutputTypeRunnable(Runnable[int, int]):
|
||||||
@property
|
@property
|
||||||
|
@override
|
||||||
def OutputType(self) -> type:
|
def OutputType(self) -> type:
|
||||||
raise TypeError()
|
raise TypeError()
|
||||||
|
|
||||||
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
input: int,
|
input: int,
|
||||||
|
@ -493,7 +493,7 @@ def test_get_output_schema() -> None:
|
|||||||
def test_get_input_schema_input_messages() -> None:
|
def test_get_input_schema_input_messages() -> None:
|
||||||
from pydantic import RootModel
|
from pydantic import RootModel
|
||||||
|
|
||||||
RunnableWithMessageHistoryInput = RootModel[Sequence[BaseMessage]]
|
runnable_with_message_history_input = RootModel[Sequence[BaseMessage]]
|
||||||
|
|
||||||
runnable = RunnableLambda(
|
runnable = RunnableLambda(
|
||||||
lambda messages: {
|
lambda messages: {
|
||||||
@ -515,7 +515,7 @@ def test_get_input_schema_input_messages() -> None:
|
|||||||
with_history = RunnableWithMessageHistory(
|
with_history = RunnableWithMessageHistory(
|
||||||
runnable, get_session_history, output_messages_key="output"
|
runnable, get_session_history, output_messages_key="output"
|
||||||
)
|
)
|
||||||
expected_schema = _schema(RunnableWithMessageHistoryInput)
|
expected_schema = _schema(runnable_with_message_history_input)
|
||||||
expected_schema["title"] = "RunnableWithChatHistoryInput"
|
expected_schema["title"] = "RunnableWithChatHistoryInput"
|
||||||
assert _schema(with_history.get_input_schema()) == expected_schema
|
assert _schema(with_history.get_input_schema()) == expected_schema
|
||||||
|
|
||||||
|
@ -87,7 +87,7 @@ from langchain_core.tracers import (
|
|||||||
)
|
)
|
||||||
from langchain_core.tracers.context import collect_runs
|
from langchain_core.tracers.context import collect_runs
|
||||||
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
|
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
|
||||||
from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk
|
from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk
|
||||||
|
|
||||||
PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split(".")))
|
PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split(".")))
|
||||||
|
|
||||||
@ -1699,7 +1699,7 @@ def test_prompt_with_chat_model(
|
|||||||
tracer = FakeTracer()
|
tracer = FakeTracer()
|
||||||
assert chain.invoke(
|
assert chain.invoke(
|
||||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
) == _AnyIdAIMessage(content="foo")
|
) == _any_id_ai_message(content="foo")
|
||||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||||
messages=[
|
messages=[
|
||||||
@ -1724,8 +1724,8 @@ def test_prompt_with_chat_model(
|
|||||||
],
|
],
|
||||||
dict(callbacks=[tracer]),
|
dict(callbacks=[tracer]),
|
||||||
) == [
|
) == [
|
||||||
_AnyIdAIMessage(content="foo"),
|
_any_id_ai_message(content="foo"),
|
||||||
_AnyIdAIMessage(content="foo"),
|
_any_id_ai_message(content="foo"),
|
||||||
]
|
]
|
||||||
assert prompt_spy.call_args.args[1] == [
|
assert prompt_spy.call_args.args[1] == [
|
||||||
{"question": "What is your name?"},
|
{"question": "What is your name?"},
|
||||||
@ -1765,9 +1765,9 @@ def test_prompt_with_chat_model(
|
|||||||
assert [
|
assert [
|
||||||
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
|
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
|
||||||
] == [
|
] == [
|
||||||
_AnyIdAIMessageChunk(content="f"),
|
_any_id_ai_message_chunk(content="f"),
|
||||||
_AnyIdAIMessageChunk(content="o"),
|
_any_id_ai_message_chunk(content="o"),
|
||||||
_AnyIdAIMessageChunk(content="o"),
|
_any_id_ai_message_chunk(content="o"),
|
||||||
]
|
]
|
||||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||||
@ -1805,7 +1805,7 @@ async def test_prompt_with_chat_model_async(
|
|||||||
tracer = FakeTracer()
|
tracer = FakeTracer()
|
||||||
assert await chain.ainvoke(
|
assert await chain.ainvoke(
|
||||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
) == _AnyIdAIMessage(content="foo")
|
) == _any_id_ai_message(content="foo")
|
||||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||||
messages=[
|
messages=[
|
||||||
@ -1830,8 +1830,8 @@ async def test_prompt_with_chat_model_async(
|
|||||||
],
|
],
|
||||||
dict(callbacks=[tracer]),
|
dict(callbacks=[tracer]),
|
||||||
) == [
|
) == [
|
||||||
_AnyIdAIMessage(content="foo"),
|
_any_id_ai_message(content="foo"),
|
||||||
_AnyIdAIMessage(content="foo"),
|
_any_id_ai_message(content="foo"),
|
||||||
]
|
]
|
||||||
assert prompt_spy.call_args.args[1] == [
|
assert prompt_spy.call_args.args[1] == [
|
||||||
{"question": "What is your name?"},
|
{"question": "What is your name?"},
|
||||||
@ -1874,9 +1874,9 @@ async def test_prompt_with_chat_model_async(
|
|||||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
)
|
)
|
||||||
] == [
|
] == [
|
||||||
_AnyIdAIMessageChunk(content="f"),
|
_any_id_ai_message_chunk(content="f"),
|
||||||
_AnyIdAIMessageChunk(content="o"),
|
_any_id_ai_message_chunk(content="o"),
|
||||||
_AnyIdAIMessageChunk(content="o"),
|
_any_id_ai_message_chunk(content="o"),
|
||||||
]
|
]
|
||||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||||
@ -2548,7 +2548,7 @@ def test_prompt_with_chat_model_and_parser(
|
|||||||
HumanMessage(content="What is your name?"),
|
HumanMessage(content="What is your name?"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar")
|
assert parser_spy.call_args.args[1] == _any_id_ai_message(content="foo, bar")
|
||||||
|
|
||||||
assert tracer.runs == snapshot
|
assert tracer.runs == snapshot
|
||||||
|
|
||||||
@ -2681,7 +2681,7 @@ Question:
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar")
|
assert parser_spy.call_args.args[1] == _any_id_ai_message(content="foo, bar")
|
||||||
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
|
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
|
||||||
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
|
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
|
||||||
assert len(parent_run.child_runs) == 4
|
assert len(parent_run.child_runs) == 4
|
||||||
@ -2727,7 +2727,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
|
|||||||
assert chain.invoke(
|
assert chain.invoke(
|
||||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
) == {
|
) == {
|
||||||
"chat": _AnyIdAIMessage(content="i'm a chatbot"),
|
"chat": _any_id_ai_message(content="i'm a chatbot"),
|
||||||
"llm": "i'm a textbot",
|
"llm": "i'm a textbot",
|
||||||
}
|
}
|
||||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
@ -2936,7 +2936,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
|
|||||||
assert chain.invoke(
|
assert chain.invoke(
|
||||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
) == {
|
) == {
|
||||||
"chat": _AnyIdAIMessage(content="i'm a chatbot"),
|
"chat": _any_id_ai_message(content="i'm a chatbot"),
|
||||||
"llm": "i'm a textbot",
|
"llm": "i'm a textbot",
|
||||||
"passthrough": ChatPromptValue(
|
"passthrough": ChatPromptValue(
|
||||||
messages=[
|
messages=[
|
||||||
@ -3000,7 +3000,7 @@ def test_map_stream() -> None:
|
|||||||
assert streamed_chunks[0] in [
|
assert streamed_chunks[0] in [
|
||||||
{"passthrough": prompt.invoke({"question": "What is your name?"})},
|
{"passthrough": prompt.invoke({"question": "What is your name?"})},
|
||||||
{"llm": "i"},
|
{"llm": "i"},
|
||||||
{"chat": _AnyIdAIMessageChunk(content="i")},
|
{"chat": _any_id_ai_message_chunk(content="i")},
|
||||||
]
|
]
|
||||||
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
|
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
|
||||||
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
||||||
@ -3059,11 +3059,11 @@ def test_map_stream() -> None:
|
|||||||
|
|
||||||
assert streamed_chunks[0] in [
|
assert streamed_chunks[0] in [
|
||||||
{"llm": "i"},
|
{"llm": "i"},
|
||||||
{"chat": _AnyIdAIMessageChunk(content="i")},
|
{"chat": _any_id_ai_message_chunk(content="i")},
|
||||||
]
|
]
|
||||||
if not ( # TODO(Rewrite properly) statement above
|
if not ( # TODO(Rewrite properly) statement above
|
||||||
streamed_chunks[0] == {"llm": "i"}
|
streamed_chunks[0] == {"llm": "i"}
|
||||||
or {"chat": _AnyIdAIMessageChunk(content="i")}
|
or {"chat": _any_id_ai_message_chunk(content="i")}
|
||||||
):
|
):
|
||||||
raise AssertionError(f"Got an unexpected chunk: {streamed_chunks[0]}")
|
raise AssertionError(f"Got an unexpected chunk: {streamed_chunks[0]}")
|
||||||
|
|
||||||
@ -3108,7 +3108,7 @@ def test_map_stream_iterator_input() -> None:
|
|||||||
assert streamed_chunks[0] in [
|
assert streamed_chunks[0] in [
|
||||||
{"passthrough": "i"},
|
{"passthrough": "i"},
|
||||||
{"llm": "i"},
|
{"llm": "i"},
|
||||||
{"chat": _AnyIdAIMessageChunk(content="i")},
|
{"chat": _any_id_ai_message_chunk(content="i")},
|
||||||
]
|
]
|
||||||
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
|
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
|
||||||
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
||||||
@ -3152,7 +3152,7 @@ async def test_map_astream() -> None:
|
|||||||
assert streamed_chunks[0] in [
|
assert streamed_chunks[0] in [
|
||||||
{"passthrough": prompt.invoke({"question": "What is your name?"})},
|
{"passthrough": prompt.invoke({"question": "What is your name?"})},
|
||||||
{"llm": "i"},
|
{"llm": "i"},
|
||||||
{"chat": _AnyIdAIMessageChunk(content="i")},
|
{"chat": _any_id_ai_message_chunk(content="i")},
|
||||||
]
|
]
|
||||||
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
|
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
|
||||||
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
||||||
|
@ -32,7 +32,7 @@ from langchain_core.runnables import (
|
|||||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||||
from langchain_core.runnables.schema import StreamEvent
|
from langchain_core.runnables.schema import StreamEvent
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
|
from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk
|
||||||
|
|
||||||
|
|
||||||
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]:
|
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]:
|
||||||
@ -503,7 +503,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
"data": {"chunk": _any_id_ai_message_chunk(content="hello")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b"},
|
"metadata": {"a": "b"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -512,7 +512,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
"data": {"chunk": _any_id_ai_message_chunk(content=" ")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b"},
|
"metadata": {"a": "b"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -521,7 +521,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
"data": {"chunk": _any_id_ai_message_chunk(content="world!")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b"},
|
"metadata": {"a": "b"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -530,7 +530,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"output": _AnyIdAIMessageChunk(content="hello world!")},
|
"data": {"output": _any_id_ai_message_chunk(content="hello world!")},
|
||||||
"event": "on_chat_model_end",
|
"event": "on_chat_model_end",
|
||||||
"metadata": {"a": "b"},
|
"metadata": {"a": "b"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -575,7 +575,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
"data": {"chunk": _any_id_ai_message_chunk(content="hello")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"a": "b",
|
"a": "b",
|
||||||
@ -588,7 +588,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
"data": {"chunk": _any_id_ai_message_chunk(content=" ")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"a": "b",
|
"a": "b",
|
||||||
@ -601,7 +601,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
"data": {"chunk": _any_id_ai_message_chunk(content="world!")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"a": "b",
|
"a": "b",
|
||||||
@ -621,7 +621,9 @@ async def test_astream_events_from_model() -> None:
|
|||||||
[
|
[
|
||||||
{
|
{
|
||||||
"generation_info": None,
|
"generation_info": None,
|
||||||
"message": _AnyIdAIMessage(content="hello world!"),
|
"message": _any_id_ai_message(
|
||||||
|
content="hello world!"
|
||||||
|
),
|
||||||
"text": "hello world!",
|
"text": "hello world!",
|
||||||
"type": "ChatGeneration",
|
"type": "ChatGeneration",
|
||||||
}
|
}
|
||||||
@ -644,7 +646,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
|
"data": {"chunk": _any_id_ai_message(content="hello world!")},
|
||||||
"event": "on_chain_stream",
|
"event": "on_chain_stream",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "i_dont_stream",
|
"name": "i_dont_stream",
|
||||||
@ -653,7 +655,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": [],
|
"tags": [],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"output": _AnyIdAIMessage(content="hello world!")},
|
"data": {"output": _any_id_ai_message(content="hello world!")},
|
||||||
"event": "on_chain_end",
|
"event": "on_chain_end",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "i_dont_stream",
|
"name": "i_dont_stream",
|
||||||
@ -698,7 +700,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
"data": {"chunk": _any_id_ai_message_chunk(content="hello")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"a": "b",
|
"a": "b",
|
||||||
@ -711,7 +713,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
"data": {"chunk": _any_id_ai_message_chunk(content=" ")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"a": "b",
|
"a": "b",
|
||||||
@ -724,7 +726,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
"data": {"chunk": _any_id_ai_message_chunk(content="world!")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"a": "b",
|
"a": "b",
|
||||||
@ -744,7 +746,9 @@ async def test_astream_events_from_model() -> None:
|
|||||||
[
|
[
|
||||||
{
|
{
|
||||||
"generation_info": None,
|
"generation_info": None,
|
||||||
"message": _AnyIdAIMessage(content="hello world!"),
|
"message": _any_id_ai_message(
|
||||||
|
content="hello world!"
|
||||||
|
),
|
||||||
"text": "hello world!",
|
"text": "hello world!",
|
||||||
"type": "ChatGeneration",
|
"type": "ChatGeneration",
|
||||||
}
|
}
|
||||||
@ -767,7 +771,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
|
"data": {"chunk": _any_id_ai_message(content="hello world!")},
|
||||||
"event": "on_chain_stream",
|
"event": "on_chain_stream",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "ai_dont_stream",
|
"name": "ai_dont_stream",
|
||||||
@ -776,7 +780,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": [],
|
"tags": [],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"output": _AnyIdAIMessage(content="hello world!")},
|
"data": {"output": _any_id_ai_message(content="hello world!")},
|
||||||
"event": "on_chain_end",
|
"event": "on_chain_end",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "ai_dont_stream",
|
"name": "ai_dont_stream",
|
||||||
|
@ -47,7 +47,7 @@ from langchain_core.utils.aiter import aclosing
|
|||||||
from tests.unit_tests.runnables.test_runnable_events_v1 import (
|
from tests.unit_tests.runnables.test_runnable_events_v1 import (
|
||||||
_assert_events_equal_allow_superset_metadata,
|
_assert_events_equal_allow_superset_metadata,
|
||||||
)
|
)
|
||||||
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
|
from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk
|
||||||
|
|
||||||
|
|
||||||
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]:
|
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]:
|
||||||
@ -533,7 +533,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
"data": {"chunk": _any_id_ai_message_chunk(content="hello")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"a": "b",
|
"a": "b",
|
||||||
@ -546,7 +546,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
"data": {"chunk": _any_id_ai_message_chunk(content=" ")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"a": "b",
|
"a": "b",
|
||||||
@ -559,7 +559,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
"data": {"chunk": _any_id_ai_message_chunk(content="world!")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"a": "b",
|
"a": "b",
|
||||||
@ -573,7 +573,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"output": _AnyIdAIMessageChunk(content="hello world!"),
|
"output": _any_id_ai_message_chunk(content="hello world!"),
|
||||||
},
|
},
|
||||||
"event": "on_chat_model_end",
|
"event": "on_chat_model_end",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@ -640,7 +640,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
"data": {"chunk": _any_id_ai_message_chunk(content="hello")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"a": "b",
|
"a": "b",
|
||||||
@ -653,7 +653,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
"data": {"chunk": _any_id_ai_message_chunk(content=" ")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"a": "b",
|
"a": "b",
|
||||||
@ -666,7 +666,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
"data": {"chunk": _any_id_ai_message_chunk(content="world!")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"a": "b",
|
"a": "b",
|
||||||
@ -681,7 +681,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"input": {"messages": [[HumanMessage(content="hello")]]},
|
"input": {"messages": [[HumanMessage(content="hello")]]},
|
||||||
"output": _AnyIdAIMessage(content="hello world!"),
|
"output": _any_id_ai_message(content="hello world!"),
|
||||||
},
|
},
|
||||||
"event": "on_chat_model_end",
|
"event": "on_chat_model_end",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@ -695,7 +695,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
|
"data": {"chunk": _any_id_ai_message(content="hello world!")},
|
||||||
"event": "on_chain_stream",
|
"event": "on_chain_stream",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "i_dont_stream",
|
"name": "i_dont_stream",
|
||||||
@ -704,7 +704,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": [],
|
"tags": [],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"output": _AnyIdAIMessage(content="hello world!")},
|
"data": {"output": _any_id_ai_message(content="hello world!")},
|
||||||
"event": "on_chain_end",
|
"event": "on_chain_end",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "i_dont_stream",
|
"name": "i_dont_stream",
|
||||||
@ -749,7 +749,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
"data": {"chunk": _any_id_ai_message_chunk(content="hello")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"a": "b",
|
"a": "b",
|
||||||
@ -762,7 +762,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
"data": {"chunk": _any_id_ai_message_chunk(content=" ")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"a": "b",
|
"a": "b",
|
||||||
@ -775,7 +775,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
"data": {"chunk": _any_id_ai_message_chunk(content="world!")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"a": "b",
|
"a": "b",
|
||||||
@ -790,7 +790,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"input": {"messages": [[HumanMessage(content="hello")]]},
|
"input": {"messages": [[HumanMessage(content="hello")]]},
|
||||||
"output": _AnyIdAIMessage(content="hello world!"),
|
"output": _any_id_ai_message(content="hello world!"),
|
||||||
},
|
},
|
||||||
"event": "on_chat_model_end",
|
"event": "on_chat_model_end",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@ -804,7 +804,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
|
"data": {"chunk": _any_id_ai_message(content="hello world!")},
|
||||||
"event": "on_chain_stream",
|
"event": "on_chain_stream",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "ai_dont_stream",
|
"name": "ai_dont_stream",
|
||||||
@ -813,7 +813,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": [],
|
"tags": [],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"output": _AnyIdAIMessage(content="hello world!")},
|
"data": {"output": _any_id_ai_message(content="hello world!")},
|
||||||
"event": "on_chain_end",
|
"event": "on_chain_end",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "ai_dont_stream",
|
"name": "ai_dont_stream",
|
||||||
|
@ -16,28 +16,28 @@ class AnyStr(str):
|
|||||||
# subclassed strings.
|
# subclassed strings.
|
||||||
|
|
||||||
|
|
||||||
def _AnyIdDocument(**kwargs: Any) -> Document:
|
def _any_id_document(**kwargs: Any) -> Document:
|
||||||
"""Create a document with an id field."""
|
"""Create a document with an id field."""
|
||||||
message = Document(**kwargs)
|
message = Document(**kwargs)
|
||||||
message.id = AnyStr()
|
message.id = AnyStr()
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
def _AnyIdAIMessage(**kwargs: Any) -> AIMessage:
|
def _any_id_ai_message(**kwargs: Any) -> AIMessage:
|
||||||
"""Create ai message with an any id field."""
|
"""Create ai message with an any id field."""
|
||||||
message = AIMessage(**kwargs)
|
message = AIMessage(**kwargs)
|
||||||
message.id = AnyStr()
|
message.id = AnyStr()
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk:
|
def _any_id_ai_message_chunk(**kwargs: Any) -> AIMessageChunk:
|
||||||
"""Create ai message with an any id field."""
|
"""Create ai message with an any id field."""
|
||||||
message = AIMessageChunk(**kwargs)
|
message = AIMessageChunk(**kwargs)
|
||||||
message.id = AnyStr()
|
message.id = AnyStr()
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
|
def _any_id_human_message(**kwargs: Any) -> HumanMessage:
|
||||||
"""Create a human with an any id field."""
|
"""Create a human with an any id field."""
|
||||||
message = HumanMessage(**kwargs)
|
message = HumanMessage(**kwargs)
|
||||||
message.id = AnyStr()
|
message.id = AnyStr()
|
||||||
|
@ -781,7 +781,7 @@ def test_convert_to_messages() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"MessageClass",
|
"message_class",
|
||||||
[
|
[
|
||||||
AIMessage,
|
AIMessage,
|
||||||
AIMessageChunk,
|
AIMessageChunk,
|
||||||
@ -790,39 +790,39 @@ def test_convert_to_messages() -> None:
|
|||||||
SystemMessage,
|
SystemMessage,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_message_name(MessageClass: type) -> None:
|
def test_message_name(message_class: type) -> None:
|
||||||
msg = MessageClass(content="foo", name="bar")
|
msg = message_class(content="foo", name="bar")
|
||||||
assert msg.name == "bar"
|
assert msg.name == "bar"
|
||||||
|
|
||||||
msg2 = MessageClass(content="foo", name=None)
|
msg2 = message_class(content="foo", name=None)
|
||||||
assert msg2.name is None
|
assert msg2.name is None
|
||||||
|
|
||||||
msg3 = MessageClass(content="foo")
|
msg3 = message_class(content="foo")
|
||||||
assert msg3.name is None
|
assert msg3.name is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"MessageClass",
|
"message_class",
|
||||||
[FunctionMessage, FunctionMessageChunk],
|
[FunctionMessage, FunctionMessageChunk],
|
||||||
)
|
)
|
||||||
def test_message_name_function(MessageClass: type) -> None:
|
def test_message_name_function(message_class: type) -> None:
|
||||||
# functionmessage doesn't support name=None
|
# functionmessage doesn't support name=None
|
||||||
msg = MessageClass(name="foo", content="bar")
|
msg = message_class(name="foo", content="bar")
|
||||||
assert msg.name == "foo"
|
assert msg.name == "foo"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"MessageClass",
|
"message_class",
|
||||||
[ChatMessage, ChatMessageChunk],
|
[ChatMessage, ChatMessageChunk],
|
||||||
)
|
)
|
||||||
def test_message_name_chat(MessageClass: type) -> None:
|
def test_message_name_chat(message_class: type) -> None:
|
||||||
msg = MessageClass(content="foo", role="user", name="bar")
|
msg = message_class(content="foo", role="user", name="bar")
|
||||||
assert msg.name == "bar"
|
assert msg.name == "bar"
|
||||||
|
|
||||||
msg2 = MessageClass(content="foo", role="user", name=None)
|
msg2 = message_class(content="foo", role="user", name=None)
|
||||||
assert msg2.name is None
|
assert msg2.name is None
|
||||||
|
|
||||||
msg3 = MessageClass(content="foo", role="user")
|
msg3 = message_class(content="foo", role="user")
|
||||||
assert msg3.name is None
|
assert msg3.name is None
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,14 +51,14 @@ def test_serde_any_message() -> None:
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
Model = RootModel[AnyMessage]
|
model = RootModel[AnyMessage]
|
||||||
|
|
||||||
for lc_object in lc_objects:
|
for lc_object in lc_objects:
|
||||||
d = lc_object.model_dump()
|
d = lc_object.model_dump()
|
||||||
assert "type" in d, f"Missing key `type` for {type(lc_object)}"
|
assert "type" in d, f"Missing key `type` for {type(lc_object)}"
|
||||||
obj1 = Model.model_validate(d)
|
obj1 = model.model_validate(d)
|
||||||
assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}"
|
assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}"
|
||||||
|
|
||||||
with pytest.raises((TypeError, ValidationError)):
|
with pytest.raises((TypeError, ValidationError)):
|
||||||
# Make sure that specifically validation error is raised
|
# Make sure that specifically validation error is raised
|
||||||
Model.model_validate({})
|
model.model_validate({})
|
||||||
|
@ -1421,7 +1421,7 @@ class InjectedTool(BaseTool):
|
|||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
class fooSchema(BaseModel):
|
class fooSchema(BaseModel): # noqa: N801
|
||||||
"""foo."""
|
"""foo."""
|
||||||
|
|
||||||
x: int = Field(..., description="abc")
|
x: int = Field(..., description="abc")
|
||||||
@ -1568,14 +1568,14 @@ def test_tool_injected_arg() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_tool_inherited_injected_arg() -> None:
|
def test_tool_inherited_injected_arg() -> None:
|
||||||
class barSchema(BaseModel):
|
class BarSchema(BaseModel):
|
||||||
"""bar."""
|
"""bar."""
|
||||||
|
|
||||||
y: Annotated[str, "foobar comment", InjectedToolArg()] = Field(
|
y: Annotated[str, "foobar comment", InjectedToolArg()] = Field(
|
||||||
..., description="123"
|
..., description="123"
|
||||||
)
|
)
|
||||||
|
|
||||||
class fooSchema(barSchema):
|
class FooSchema(BarSchema):
|
||||||
"""foo."""
|
"""foo."""
|
||||||
|
|
||||||
x: int = Field(..., description="abc")
|
x: int = Field(..., description="abc")
|
||||||
@ -1583,14 +1583,14 @@ def test_tool_inherited_injected_arg() -> None:
|
|||||||
class InheritedInjectedArgTool(BaseTool):
|
class InheritedInjectedArgTool(BaseTool):
|
||||||
name: str = "foo"
|
name: str = "foo"
|
||||||
description: str = "foo."
|
description: str = "foo."
|
||||||
args_schema: type[BaseModel] = fooSchema
|
args_schema: type[BaseModel] = FooSchema
|
||||||
|
|
||||||
def _run(self, x: int, y: str) -> Any:
|
def _run(self, x: int, y: str) -> Any:
|
||||||
return y
|
return y
|
||||||
|
|
||||||
tool_ = InheritedInjectedArgTool()
|
tool_ = InheritedInjectedArgTool()
|
||||||
assert tool_.get_input_schema().model_json_schema() == {
|
assert tool_.get_input_schema().model_json_schema() == {
|
||||||
"title": "fooSchema", # Matches the title from the provided schema
|
"title": "FooSchema", # Matches the title from the provided schema
|
||||||
"description": "foo.",
|
"description": "foo.",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
@ -1877,15 +1877,15 @@ def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None:
|
|||||||
A = TypeVar("A")
|
A = TypeVar("A")
|
||||||
|
|
||||||
if use_v1_namespace:
|
if use_v1_namespace:
|
||||||
from pydantic.v1 import BaseModel as BM1
|
from pydantic.v1 import BaseModel as BaseModel1
|
||||||
|
|
||||||
class ModelA(BM1, Generic[A], extra="allow"):
|
class ModelA(BaseModel1, Generic[A], extra="allow"):
|
||||||
a: A
|
a: A
|
||||||
else:
|
else:
|
||||||
from pydantic import BaseModel as BM2
|
from pydantic import BaseModel as BaseModel2
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
|
|
||||||
class ModelA(BM2, Generic[A]): # type: ignore[no-redef]
|
class ModelA(BaseModel2, Generic[A]): # type: ignore[no-redef]
|
||||||
a: A
|
a: A
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
|
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
|
||||||
|
|
||||||
@ -2081,13 +2081,13 @@ def test_structured_tool_direct_init() -> None:
|
|||||||
def foo(bar: str) -> str:
|
def foo(bar: str) -> str:
|
||||||
return bar
|
return bar
|
||||||
|
|
||||||
async def asyncFoo(bar: str) -> str:
|
async def async_foo(bar: str) -> str:
|
||||||
return bar
|
return bar
|
||||||
|
|
||||||
class fooSchema(BaseModel):
|
class FooSchema(BaseModel):
|
||||||
bar: str = Field(..., description="The bar")
|
bar: str = Field(..., description="The bar")
|
||||||
|
|
||||||
tool = StructuredTool(name="foo", args_schema=fooSchema, coroutine=asyncFoo)
|
tool = StructuredTool(name="foo", args_schema=FooSchema, coroutine=async_foo)
|
||||||
|
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
assert tool.invoke("hello") == "hello"
|
assert tool.invoke("hello") == "hello"
|
||||||
|
@ -38,7 +38,7 @@ from langchain_core.utils.function_calling import (
|
|||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def pydantic() -> type[BaseModel]:
|
def pydantic() -> type[BaseModel]:
|
||||||
class dummy_function(BaseModel):
|
class dummy_function(BaseModel): # noqa: N801
|
||||||
"""dummy function"""
|
"""dummy function"""
|
||||||
|
|
||||||
arg1: int = Field(..., description="foo")
|
arg1: int = Field(..., description="foo")
|
||||||
@ -48,7 +48,7 @@ def pydantic() -> type[BaseModel]:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def Annotated_function() -> Callable:
|
def annotated_function() -> Callable:
|
||||||
def dummy_function(
|
def dummy_function(
|
||||||
arg1: ExtensionsAnnotated[int, "foo"],
|
arg1: ExtensionsAnnotated[int, "foo"],
|
||||||
arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"],
|
arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"],
|
||||||
@ -118,7 +118,7 @@ def dummy_structured_tool() -> StructuredTool:
|
|||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def dummy_pydantic() -> type[BaseModel]:
|
def dummy_pydantic() -> type[BaseModel]:
|
||||||
class dummy_function(BaseModel):
|
class dummy_function(BaseModel): # noqa: N801
|
||||||
"""dummy function"""
|
"""dummy function"""
|
||||||
|
|
||||||
arg1: int = Field(..., description="foo")
|
arg1: int = Field(..., description="foo")
|
||||||
@ -129,7 +129,7 @@ def dummy_pydantic() -> type[BaseModel]:
|
|||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
|
def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
|
||||||
class dummy_function(BaseModelV2Maybe):
|
class dummy_function(BaseModelV2Maybe): # noqa: N801
|
||||||
"""dummy function"""
|
"""dummy function"""
|
||||||
|
|
||||||
arg1: int = FieldV2Maybe(..., description="foo")
|
arg1: int = FieldV2Maybe(..., description="foo")
|
||||||
@ -142,7 +142,7 @@ def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
|
|||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def dummy_typing_typed_dict() -> type:
|
def dummy_typing_typed_dict() -> type:
|
||||||
class dummy_function(TypingTypedDict):
|
class dummy_function(TypingTypedDict): # noqa: N801
|
||||||
"""dummy function"""
|
"""dummy function"""
|
||||||
|
|
||||||
arg1: TypingAnnotated[int, ..., "foo"] # noqa: F821
|
arg1: TypingAnnotated[int, ..., "foo"] # noqa: F821
|
||||||
@ -153,7 +153,7 @@ def dummy_typing_typed_dict() -> type:
|
|||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def dummy_typing_typed_dict_docstring() -> type:
|
def dummy_typing_typed_dict_docstring() -> type:
|
||||||
class dummy_function(TypingTypedDict):
|
class dummy_function(TypingTypedDict): # noqa: N801
|
||||||
"""dummy function
|
"""dummy function
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -169,7 +169,7 @@ def dummy_typing_typed_dict_docstring() -> type:
|
|||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def dummy_extensions_typed_dict() -> type:
|
def dummy_extensions_typed_dict() -> type:
|
||||||
class dummy_function(ExtensionsTypedDict):
|
class dummy_function(ExtensionsTypedDict): # noqa: N801
|
||||||
"""dummy function"""
|
"""dummy function"""
|
||||||
|
|
||||||
arg1: ExtensionsAnnotated[int, ..., "foo"]
|
arg1: ExtensionsAnnotated[int, ..., "foo"]
|
||||||
@ -180,7 +180,7 @@ def dummy_extensions_typed_dict() -> type:
|
|||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def dummy_extensions_typed_dict_docstring() -> type:
|
def dummy_extensions_typed_dict_docstring() -> type:
|
||||||
class dummy_function(ExtensionsTypedDict):
|
class dummy_function(ExtensionsTypedDict): # noqa: N801
|
||||||
"""dummy function
|
"""dummy function
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -241,7 +241,7 @@ def test_convert_to_openai_function(
|
|||||||
dummy_structured_tool: StructuredTool,
|
dummy_structured_tool: StructuredTool,
|
||||||
dummy_tool: BaseTool,
|
dummy_tool: BaseTool,
|
||||||
json_schema: dict,
|
json_schema: dict,
|
||||||
Annotated_function: Callable,
|
annotated_function: Callable,
|
||||||
dummy_pydantic: type[BaseModel],
|
dummy_pydantic: type[BaseModel],
|
||||||
runnable: Runnable,
|
runnable: Runnable,
|
||||||
dummy_typing_typed_dict: type,
|
dummy_typing_typed_dict: type,
|
||||||
@ -275,7 +275,7 @@ def test_convert_to_openai_function(
|
|||||||
expected,
|
expected,
|
||||||
Dummy.dummy_function,
|
Dummy.dummy_function,
|
||||||
DummyWithClassMethod.dummy_function,
|
DummyWithClassMethod.dummy_function,
|
||||||
Annotated_function,
|
annotated_function,
|
||||||
dummy_pydantic,
|
dummy_pydantic,
|
||||||
dummy_typing_typed_dict,
|
dummy_typing_typed_dict,
|
||||||
dummy_typing_typed_dict_docstring,
|
dummy_typing_typed_dict_docstring,
|
||||||
@ -523,20 +523,20 @@ def test__convert_typed_dict_to_openai_function(
|
|||||||
use_extension_typed_dict: bool, use_extension_annotated: bool
|
use_extension_typed_dict: bool, use_extension_annotated: bool
|
||||||
) -> None:
|
) -> None:
|
||||||
if use_extension_typed_dict:
|
if use_extension_typed_dict:
|
||||||
TypedDict = ExtensionsTypedDict
|
typed_dict = ExtensionsTypedDict
|
||||||
else:
|
else:
|
||||||
TypedDict = TypingTypedDict
|
typed_dict = TypingTypedDict
|
||||||
if use_extension_annotated:
|
if use_extension_annotated:
|
||||||
Annotated = TypingAnnotated
|
annotated = TypingAnnotated
|
||||||
else:
|
else:
|
||||||
Annotated = TypingAnnotated
|
annotated = TypingAnnotated
|
||||||
|
|
||||||
class SubTool(TypedDict):
|
class SubTool(typed_dict):
|
||||||
"""Subtool docstring"""
|
"""Subtool docstring"""
|
||||||
|
|
||||||
args: Annotated[dict[str, Any], {}, "this does bar"] # noqa: F722 # type: ignore
|
args: annotated[dict[str, Any], {}, "this does bar"] # noqa: F722 # type: ignore
|
||||||
|
|
||||||
class Tool(TypedDict):
|
class Tool(typed_dict):
|
||||||
"""Docstring
|
"""Docstring
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -546,20 +546,20 @@ def test__convert_typed_dict_to_openai_function(
|
|||||||
arg1: str
|
arg1: str
|
||||||
arg2: Union[int, str, bool]
|
arg2: Union[int, str, bool]
|
||||||
arg3: Optional[list[SubTool]]
|
arg3: Optional[list[SubTool]]
|
||||||
arg4: Annotated[Literal["bar", "baz"], ..., "this does foo"] # noqa: F722
|
arg4: annotated[Literal["bar", "baz"], ..., "this does foo"] # noqa: F722
|
||||||
arg5: Annotated[Optional[float], None]
|
arg5: annotated[Optional[float], None]
|
||||||
arg6: Annotated[
|
arg6: annotated[
|
||||||
Optional[Sequence[Mapping[str, tuple[Iterable[Any], SubTool]]]], []
|
Optional[Sequence[Mapping[str, tuple[Iterable[Any], SubTool]]]], []
|
||||||
]
|
]
|
||||||
arg7: Annotated[list[SubTool], ...]
|
arg7: annotated[list[SubTool], ...]
|
||||||
arg8: Annotated[tuple[SubTool], ...]
|
arg8: annotated[tuple[SubTool], ...]
|
||||||
arg9: Annotated[Sequence[SubTool], ...]
|
arg9: annotated[Sequence[SubTool], ...]
|
||||||
arg10: Annotated[Iterable[SubTool], ...]
|
arg10: annotated[Iterable[SubTool], ...]
|
||||||
arg11: Annotated[set[SubTool], ...]
|
arg11: annotated[set[SubTool], ...]
|
||||||
arg12: Annotated[dict[str, SubTool], ...]
|
arg12: annotated[dict[str, SubTool], ...]
|
||||||
arg13: Annotated[Mapping[str, SubTool], ...]
|
arg13: annotated[Mapping[str, SubTool], ...]
|
||||||
arg14: Annotated[MutableMapping[str, SubTool], ...]
|
arg14: annotated[MutableMapping[str, SubTool], ...]
|
||||||
arg15: Annotated[bool, False, "flag"] # noqa: F821 # type: ignore
|
arg15: annotated[bool, False, "flag"] # noqa: F821 # type: ignore
|
||||||
|
|
||||||
expected = {
|
expected = {
|
||||||
"name": "Tool",
|
"name": "Tool",
|
||||||
|
@ -10,7 +10,7 @@ from langchain_standard_tests.integration_tests.vectorstores import (
|
|||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings.fake import DeterministicFakeEmbedding
|
from langchain_core.embeddings.fake import DeterministicFakeEmbedding
|
||||||
from langchain_core.vectorstores import InMemoryVectorStore
|
from langchain_core.vectorstores import InMemoryVectorStore
|
||||||
from tests.unit_tests.stubs import _AnyIdDocument
|
from tests.unit_tests.stubs import _any_id_document
|
||||||
|
|
||||||
|
|
||||||
class TestInMemoryReadWriteTestSuite(ReadWriteTestSuite):
|
class TestInMemoryReadWriteTestSuite(ReadWriteTestSuite):
|
||||||
@ -33,13 +33,13 @@ async def test_inmemory_similarity_search() -> None:
|
|||||||
|
|
||||||
# Check sync version
|
# Check sync version
|
||||||
output = store.similarity_search("foo", k=1)
|
output = store.similarity_search("foo", k=1)
|
||||||
assert output == [_AnyIdDocument(page_content="foo")]
|
assert output == [_any_id_document(page_content="foo")]
|
||||||
|
|
||||||
# Check async version
|
# Check async version
|
||||||
output = await store.asimilarity_search("bar", k=2)
|
output = await store.asimilarity_search("bar", k=2)
|
||||||
assert output == [
|
assert output == [
|
||||||
_AnyIdDocument(page_content="bar"),
|
_any_id_document(page_content="bar"),
|
||||||
_AnyIdDocument(page_content="baz"),
|
_any_id_document(page_content="baz"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -80,16 +80,16 @@ async def test_inmemory_mmr() -> None:
|
|||||||
# make sure we can k > docstore size
|
# make sure we can k > docstore size
|
||||||
output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1)
|
output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1)
|
||||||
assert len(output) == len(texts)
|
assert len(output) == len(texts)
|
||||||
assert output[0] == _AnyIdDocument(page_content="foo")
|
assert output[0] == _any_id_document(page_content="foo")
|
||||||
assert output[1] == _AnyIdDocument(page_content="foy")
|
assert output[1] == _any_id_document(page_content="foy")
|
||||||
|
|
||||||
# Check async version
|
# Check async version
|
||||||
output = await docsearch.amax_marginal_relevance_search(
|
output = await docsearch.amax_marginal_relevance_search(
|
||||||
"foo", k=10, lambda_mult=0.1
|
"foo", k=10, lambda_mult=0.1
|
||||||
)
|
)
|
||||||
assert len(output) == len(texts)
|
assert len(output) == len(texts)
|
||||||
assert output[0] == _AnyIdDocument(page_content="foo")
|
assert output[0] == _any_id_document(page_content="foo")
|
||||||
assert output[1] == _AnyIdDocument(page_content="foy")
|
assert output[1] == _any_id_document(page_content="foy")
|
||||||
|
|
||||||
|
|
||||||
async def test_inmemory_dump_load(tmp_path: Path) -> None:
|
async def test_inmemory_dump_load(tmp_path: Path) -> None:
|
||||||
@ -117,7 +117,7 @@ async def test_inmemory_filter() -> None:
|
|||||||
|
|
||||||
# Check sync version
|
# Check sync version
|
||||||
output = store.similarity_search("fee", filter=lambda doc: doc.metadata["id"] == 1)
|
output = store.similarity_search("fee", filter=lambda doc: doc.metadata["id"] == 1)
|
||||||
assert output == [_AnyIdDocument(page_content="foo", metadata={"id": 1})]
|
assert output == [_any_id_document(page_content="foo", metadata={"id": 1})]
|
||||||
|
|
||||||
# filter with not stored document id
|
# filter with not stored document id
|
||||||
output = await store.asimilarity_search(
|
output = await store.asimilarity_search(
|
||||||
|
@ -50,6 +50,7 @@ class CustomAddTextsVectorstore(VectorStore):
|
|||||||
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
|
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
|
||||||
return [self.store[id] for id in ids if id in self.store]
|
return [self.store[id] for id in ids if id in self.store]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
def from_texts( # type: ignore
|
def from_texts( # type: ignore
|
||||||
cls,
|
cls,
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
|
Loading…
Reference in New Issue
Block a user