Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-29 15:59:48 +00:00
core: Add N(naming) ruff rules (#25362)
Public classes/functions are not renamed, and the rule is ignored for them.

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent 7835c0651f
commit fd21ffe293
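For context on the pattern this commit applies throughout: ruff's pep8-naming ("N") rules flag identifiers that break PEP 8 conventions, but public API symbols such as `Runnable.InputType` cannot be renamed without breaking downstream callers, so their definitions keep the old names and suppress the specific rule with a targeted `# noqa` comment, while private helpers are simply renamed. A minimal sketch of that pattern (the `Example` class below is hypothetical, not from the codebase):

class Example:
    """Illustrative only: shows the suppression pattern used in this commit."""

    @property
    def InputType(self) -> type:  # noqa: N802
        # N802 flags non-snake_case function names; the targeted noqa above
        # silences only that rule, and only on this line. The name is public
        # API, so it is kept for backward compatibility.
        return str

    def _helper_value(self) -> int:
        # Private names can be renamed freely, so no suppression is needed.
        return 0

With the "N" selection added to pyproject.toml below, `ruff check --select N` would pass on a file like this, while the same property without the `# noqa: N802` would be flagged.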
@@ -155,7 +155,7 @@ def beta(
         _name = _name or obj.fget.__qualname__
         old_doc = obj.__doc__

-        class _beta_property(property):
+        class _BetaProperty(property):
             """A beta property."""

             def __init__(self, fget=None, fset=None, fdel=None, doc=None):
@@ -186,7 +186,7 @@ def beta(

         def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any:
             """Finalize the property."""
-            return _beta_property(
+            return _BetaProperty(
                 fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
             )

@@ -264,7 +264,7 @@ def deprecated(
         _name = _name or cast(Union[type, Callable], obj.fget).__qualname__
         old_doc = obj.__doc__

-        class _deprecated_property(property):
+        class _DeprecatedProperty(property):
             """A deprecated property."""

             def __init__(self, fget=None, fset=None, fdel=None, doc=None):  # type: ignore[no-untyped-def]
@@ -297,7 +297,7 @@ def deprecated(
             """Finalize the property."""
             return cast(
                 T,
-                _deprecated_property(
+                _DeprecatedProperty(
                     fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
                 ),
             )
@@ -3,7 +3,7 @@
 from typing import Any, Optional


-class LangChainException(Exception):
+class LangChainException(Exception):  # noqa: N818
     """General LangChain exception."""


@@ -11,7 +11,7 @@ class TracerException(LangChainException):
     """Base class for exceptions in tracers module."""


-class OutputParserException(ValueError, LangChainException):
+class OutputParserException(ValueError, LangChainException):  # noqa: N818
     """Exception that output parsers should raise to signify a parsing error.

     This exists to differentiate parsing errors from other code or execution errors
@@ -14,7 +14,7 @@ from typing import (
 )

 from pydantic import BaseModel, ConfigDict, Field, field_validator
-from typing_extensions import TypeAlias, TypedDict
+from typing_extensions import TypeAlias, TypedDict, override

 from langchain_core._api import deprecated
 from langchain_core.messages import (
@@ -143,6 +143,7 @@ class BaseLanguageModel(
         return verbose

     @property
+    @override
     def InputType(self) -> TypeAlias:
         """Get the input type for this runnable."""
         from langchain_core.prompt_values import (
@@ -26,6 +26,7 @@ from pydantic import (
     Field,
     model_validator,
 )
+from typing_extensions import override

 from langchain_core._api import deprecated
 from langchain_core.caches import BaseCache
@@ -251,6 +252,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     # --- Runnable methods ---

     @property
+    @override
     def OutputType(self) -> Any:
         """Get the output type for this runnable."""
         return AnyMessage
@@ -31,6 +31,7 @@ from tenacity import (
     stop_after_attempt,
     wait_exponential,
 )
+from typing_extensions import override

 from langchain_core._api import deprecated
 from langchain_core.caches import BaseCache
@@ -318,6 +319,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     # --- Runnable methods ---

     @property
+    @override
     def OutputType(self) -> type[str]:
         """Get the input type for this runnable."""
         return str
@@ -10,6 +10,8 @@ from typing import (
     Union,
 )

+from typing_extensions import override
+
 from langchain_core.language_models import LanguageModelOutput
 from langchain_core.messages import AnyMessage, BaseMessage
 from langchain_core.outputs import ChatGeneration, Generation
@@ -63,11 +65,13 @@ class BaseGenerationOutputParser(
     """Base class to parse the output of an LLM call."""

     @property
+    @override
     def InputType(self) -> Any:
         """Return the input type for the parser."""
         return Union[str, AnyMessage]

     @property
+    @override
     def OutputType(self) -> type[T]:
         """Return the output type for the parser."""
         # even though mypy complains this isn't valid,
@@ -148,11 +152,13 @@ class BaseOutputParser(
     """  # noqa: E501

     @property
+    @override
     def InputType(self) -> Any:
         """Return the input type for the parser."""
         return Union[str, AnyMessage]

     @property
+    @override
     def OutputType(self) -> type[T]:
         """Return the output type for the parser.

@@ -3,6 +3,7 @@ from typing import Annotated, Generic, Optional

 import pydantic
 from pydantic import SkipValidation
+from typing_extensions import override

 from langchain_core.exceptions import OutputParserException
 from langchain_core.output_parsers import JsonOutputParser
@@ -107,6 +108,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
         return "pydantic"

     @property
+    @override
     def OutputType(self) -> type[TBaseModel]:
         """Return the pydantic model."""
         return self.pydantic_object
@@ -1,6 +1,6 @@
 import re
 import xml
-import xml.etree.ElementTree as ET
+import xml.etree.ElementTree as ET  # noqa: N817
 from collections.abc import AsyncIterator, Iterator
 from typing import Any, Literal, Optional, Union
 from xml.etree.ElementTree import TreeBuilder
@@ -46,14 +46,14 @@ class _StreamingParser:
         """
         if parser == "defusedxml":
             try:
-                from defusedxml import ElementTree as DET  # type: ignore
+                import defusedxml  # type: ignore
             except ImportError as e:
                 raise ImportError(
                     "defusedxml is not installed. "
                     "Please install it to use the defusedxml parser."
                     "You can install it with `pip install defusedxml` "
                 ) from e
-            _parser = DET.DefusedXMLParser(target=TreeBuilder())
+            _parser = defusedxml.ElementTree.DefusedXMLParser(target=TreeBuilder())
         else:
             _parser = None
         self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
@@ -189,7 +189,7 @@ class XMLOutputParser(BaseTransformOutputParser):
         # likely if you're reading this you can move them to the top of the file
         if self.parser == "defusedxml":
             try:
-                from defusedxml import ElementTree as DET  # type: ignore
+                import defusedxml  # type: ignore
             except ImportError as e:
                 raise ImportError(
                     "defusedxml is not installed. "
@@ -197,9 +197,9 @@ class XMLOutputParser(BaseTransformOutputParser):
                     "You can install it with `pip install defusedxml`"
                     "See https://github.com/tiran/defusedxml for more details"
                 ) from e
-            _ET = DET  # Use the defusedxml parser
+            _et = defusedxml.ElementTree  # Use the defusedxml parser
         else:
-            _ET = ET  # Use the standard library parser
+            _et = ET  # Use the standard library parser

         match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
         if match is not None:
@@ -18,7 +18,7 @@ from typing import (

 import yaml
 from pydantic import BaseModel, ConfigDict, Field, model_validator
-from typing_extensions import Self
+from typing_extensions import Self, override

 from langchain_core.load import dumpd
 from langchain_core.output_parsers.base import BaseOutputParser
@@ -107,6 +107,7 @@ class BasePromptTemplate(
         return dumpd(self)

     @property
+    @override
     def OutputType(self) -> Any:
         """Return the output type of the prompt."""
         return Union[StringPromptValue, ChatPromptValueConcrete]
@@ -36,7 +36,7 @@ from typing import (
 )

 from pydantic import BaseModel, ConfigDict, Field, RootModel
-from typing_extensions import Literal, get_args
+from typing_extensions import Literal, get_args, override

 from langchain_core._api import beta_decorator
 from langchain_core.load.serializable import (
@@ -272,7 +272,7 @@ class Runnable(Generic[Input, Output], ABC):
         return name_

     @property
-    def InputType(self) -> type[Input]:
+    def InputType(self) -> type[Input]:  # noqa: N802
         """The type of input this Runnable accepts specified as a type annotation."""
         # First loop through all parent classes and if any of them is
         # a pydantic model, we will pick up the generic parameterization
@@ -297,7 +297,7 @@ class Runnable(Generic[Input, Output], ABC):
         )

     @property
-    def OutputType(self) -> type[Output]:
+    def OutputType(self) -> type[Output]:  # noqa: N802
         """The type of output this Runnable produces specified as a type annotation."""
         # First loop through bases -- this will help generic
         # any pydantic models.
@@ -2811,11 +2811,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
         )

     @property
+    @override
     def InputType(self) -> type[Input]:
         """The type of the input to the Runnable."""
         return self.first.InputType

     @property
+    @override
     def OutputType(self) -> type[Output]:
         """The type of the output of the Runnable."""
         return self.last.OutputType
@@ -3564,6 +3566,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
         return super().get_name(suffix, name=name)

     @property
+    @override
     def InputType(self) -> Any:
         """The type of the input to the Runnable."""
         for step in self.steps__.values():
@@ -4057,6 +4060,7 @@ class RunnableGenerator(Runnable[Input, Output]):
             self.name = "RunnableGenerator"

     @property
+    @override
     def InputType(self) -> Any:
         func = getattr(self, "_transform", None) or self._atransform
         try:
@@ -4097,6 +4101,7 @@ class RunnableGenerator(Runnable[Input, Output]):
         )

     @property
+    @override
     def OutputType(self) -> Any:
         func = getattr(self, "_transform", None) or self._atransform
         try:
@@ -4346,6 +4351,7 @@ class RunnableLambda(Runnable[Input, Output]):
             pass

     @property
+    @override
     def InputType(self) -> Any:
         """The type of the input to this Runnable."""
         func = getattr(self, "func", None) or self.afunc
@@ -4405,6 +4411,7 @@ class RunnableLambda(Runnable[Input, Output]):
         return super().get_input_schema(config)

     @property
+    @override
     def OutputType(self) -> Any:
         """The type of the output of this Runnable as a type annotation.

@@ -4958,6 +4965,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]):
         )

     @property
+    @override
     def InputType(self) -> Any:
         return list[self.bound.InputType]  # type: ignore[name-defined]

@@ -4981,6 +4989,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]):
         )

     @property
+    @override
     def OutputType(self) -> type[list[Output]]:
         return list[self.bound.OutputType]  # type: ignore[name-defined]

@@ -5274,6 +5283,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
         return self.bound.get_name(suffix, name=name)

     @property
+    @override
     def InputType(self) -> type[Input]:
         return (
             cast(type[Input], self.custom_input_type)
@@ -5282,6 +5292,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
         )

     @property
+    @override
     def OutputType(self) -> type[Output]:
         return (
             cast(type[Output], self.custom_output_type)
@@ -16,6 +16,7 @@ from typing import (
 from weakref import WeakValueDictionary

 from pydantic import BaseModel, ConfigDict
+from typing_extensions import override

 from langchain_core.runnables.base import Runnable, RunnableSerializable
 from langchain_core.runnables.config import (
@@ -68,10 +69,12 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
         return ["langchain", "schema", "runnable"]

     @property
+    @override
     def InputType(self) -> type[Input]:
         return self.default.InputType

     @property
+    @override
     def OutputType(self) -> type[Output]:
         return self.default.OutputType

@@ -13,6 +13,7 @@ from typing import (
 )

 from pydantic import BaseModel, ConfigDict
+from typing_extensions import override

 from langchain_core.runnables.base import Runnable, RunnableSerializable
 from langchain_core.runnables.config import (
@@ -106,10 +107,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
         )

     @property
+    @override
     def InputType(self) -> type[Input]:
         return self.runnable.InputType

     @property
+    @override
     def OutputType(self) -> type[Output]:
         return self.runnable.OutputType

@@ -246,27 +246,27 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:

     # NOTE: coordinates might me negative, so we need to shift
     # everything to the positive plane before we actually draw it.
-    Xs = []
-    Ys = []
+    xlist = []
+    ylist = []

     sug = _build_sugiyama_layout(vertices, edges)

     for vertex in sug.g.sV:
         # NOTE: moving boxes w/2 to the left
-        Xs.append(vertex.view.xy[0] - vertex.view.w / 2.0)
-        Xs.append(vertex.view.xy[0] + vertex.view.w / 2.0)
-        Ys.append(vertex.view.xy[1])
-        Ys.append(vertex.view.xy[1] + vertex.view.h)
+        xlist.append(vertex.view.xy[0] - vertex.view.w / 2.0)
+        xlist.append(vertex.view.xy[0] + vertex.view.w / 2.0)
+        ylist.append(vertex.view.xy[1])
+        ylist.append(vertex.view.xy[1] + vertex.view.h)

     for edge in sug.g.sE:
         for x, y in edge.view._pts:
-            Xs.append(x)
-            Ys.append(y)
+            xlist.append(x)
+            ylist.append(y)

-    minx = min(Xs)
-    miny = min(Ys)
-    maxx = max(Xs)
-    maxy = max(Ys)
+    minx = min(xlist)
+    miny = min(ylist)
+    maxx = max(xlist)
+    maxy = max(ylist)

     canvas_cols = int(math.ceil(math.ceil(maxx) - math.floor(minx))) + 1
     canvas_lines = int(round(maxy - miny))
@@ -12,6 +12,7 @@ from typing import (
 )

 from pydantic import BaseModel
+from typing_extensions import override

 from langchain_core.chat_history import BaseChatMessageHistory
 from langchain_core.load.load import load
@@ -396,6 +397,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
         )

     @property
+    @override
     def OutputType(self) -> type[Output]:
         output_type = self._history_chain.OutputType
         return output_type
@@ -16,6 +16,7 @@ from typing import (
 )

 from pydantic import BaseModel, RootModel
+from typing_extensions import override

 from langchain_core.runnables.base import (
     Other,
@@ -193,6 +194,12 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
         return ["langchain", "schema", "runnable"]

     @property
+    @override
     def InputType(self) -> Any:
         return self.input_type or Any

     @property
+    @override
     def OutputType(self) -> Any:
         return self.input_type or Any

@@ -28,7 +28,7 @@ from typing import (
     Union,
 )

-from typing_extensions import TypeGuard
+from typing_extensions import TypeGuard, override

 from langchain_core.runnables.schema import StreamEvent

@@ -135,6 +135,7 @@ class IsLocalDict(ast.NodeVisitor):
         self.name = name
         self.keys = keys

+    @override
     def visit_Subscript(self, node: ast.Subscript) -> Any:
         """Visit a subscript node.

@@ -154,6 +155,7 @@ class IsLocalDict(ast.NodeVisitor):
             # we've found a subscript access on the name we're looking for
             self.keys.add(node.slice.value)

+    @override
     def visit_Call(self, node: ast.Call) -> Any:
         """Visit a call node.

@@ -182,6 +184,7 @@ class IsFunctionArgDict(ast.NodeVisitor):
     def __init__(self) -> None:
         self.keys: set[str] = set()

+    @override
     def visit_Lambda(self, node: ast.Lambda) -> Any:
         """Visit a lambda function.

@@ -196,6 +199,7 @@ class IsFunctionArgDict(ast.NodeVisitor):
         input_arg_name = node.args.args[0].arg
         IsLocalDict(input_arg_name, self.keys).visit(node.body)

+    @override
     def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
         """Visit a function definition.

@@ -210,6 +214,7 @@ class IsFunctionArgDict(ast.NodeVisitor):
         input_arg_name = node.args.args[0].arg
         IsLocalDict(input_arg_name, self.keys).visit(node)

+    @override
     def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
         """Visit an async function definition.

@@ -232,6 +237,7 @@ class NonLocals(ast.NodeVisitor):
         self.loads: set[str] = set()
         self.stores: set[str] = set()

+    @override
     def visit_Name(self, node: ast.Name) -> Any:
         """Visit a name node.

@@ -246,6 +252,7 @@ class NonLocals(ast.NodeVisitor):
         elif isinstance(node.ctx, ast.Store):
             self.stores.add(node.id)

+    @override
     def visit_Attribute(self, node: ast.Attribute) -> Any:
         """Visit an attribute node.

@@ -272,6 +279,7 @@ class FunctionNonLocals(ast.NodeVisitor):
     def __init__(self) -> None:
         self.nonlocals: set[str] = set()

+    @override
     def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
         """Visit a function definition.

@@ -285,6 +293,7 @@ class FunctionNonLocals(ast.NodeVisitor):
         visitor.visit(node)
         self.nonlocals.update(visitor.loads - visitor.stores)

+    @override
     def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
         """Visit an async function definition.

@@ -298,6 +307,7 @@ class FunctionNonLocals(ast.NodeVisitor):
         visitor.visit(node)
         self.nonlocals.update(visitor.loads - visitor.stores)

+    @override
     def visit_Lambda(self, node: ast.Lambda) -> Any:
         """Visit a lambda function.

@@ -320,6 +330,7 @@ class GetLambdaSource(ast.NodeVisitor):
         self.source: Optional[str] = None
         self.count = 0

+    @override
     def visit_Lambda(self, node: ast.Lambda) -> Any:
         """Visit a lambda function.

@@ -311,7 +311,7 @@ def create_schema_from_function(
     )


-class ToolException(Exception):
+class ToolException(Exception):  # noqa: N818
     """Optional exception that tool throws when execution error occurs.

     When this exception is thrown, the agent will not stop working,
@@ -9,7 +9,7 @@ def get_headers(*args: Any, **kwargs: Any) -> Any:
     )


-def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any:
+def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any:  # noqa: N802
     """Throw an error because this has been replaced by LangChainTracer."""
     raise RuntimeError(
         "LangChainTracerV1 is no longer supported. Please use LangChainTracer instead."
@@ -18,7 +18,7 @@ from langchain_core._api import deprecated


 @deprecated("0.1.0", alternative="Use string instead.", removal="1.0")
-def RunTypeEnum() -> type[RunTypeEnumDep]:
+def RunTypeEnum() -> type[RunTypeEnumDep]:  # noqa: N802
     """RunTypeEnum."""
     warnings.warn(
         "RunTypeEnum is deprecated. Please directly use a string instead"
@@ -235,7 +235,7 @@ class Tee(Generic[T]):
 atee = Tee


-class aclosing(AbstractAsyncContextManager):
+class aclosing(AbstractAsyncContextManager):  # noqa: N801
     """Async context manager for safely finalizing an asynchronously cleaned-up
     resource such as an async generator, calling its ``aclose()`` method.

@@ -17,12 +17,12 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


-def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
+def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray:
     """Row-wise cosine similarity between two equal-width matrices.

     Args:
-        X: A matrix of shape (n, m).
-        Y: A matrix of shape (k, m).
+        x: A matrix of shape (n, m).
+        y: A matrix of shape (k, m).

     Returns:
         A matrix of shape (n, k) where each element (i, j) is the cosine similarity
@@ -40,33 +40,33 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
             "Please install numpy with `pip install numpy`."
         ) from e

-    if len(X) == 0 or len(Y) == 0:
+    if len(x) == 0 or len(y) == 0:
         return np.array([])

-    X = np.array(X)
-    Y = np.array(Y)
-    if X.shape[1] != Y.shape[1]:
+    x = np.array(x)
+    y = np.array(y)
+    if x.shape[1] != y.shape[1]:
         raise ValueError(
-            f"Number of columns in X and Y must be the same. X has shape {X.shape} "
-            f"and Y has shape {Y.shape}."
+            f"Number of columns in X and Y must be the same. X has shape {x.shape} "
+            f"and Y has shape {y.shape}."
         )
     try:
         import simsimd as simd  # type: ignore[import-not-found]

-        X = np.array(X, dtype=np.float32)
-        Y = np.array(Y, dtype=np.float32)
-        Z = 1 - np.array(simd.cdist(X, Y, metric="cosine"))
-        return Z
+        x = np.array(x, dtype=np.float32)
+        y = np.array(y, dtype=np.float32)
+        z = 1 - np.array(simd.cdist(x, y, metric="cosine"))
+        return z
     except ImportError:
         logger.debug(
             "Unable to import simsimd, defaulting to NumPy implementation. If you want "
             "to use simsimd please install with `pip install simsimd`."
         )
-        X_norm = np.linalg.norm(X, axis=1)
-        Y_norm = np.linalg.norm(Y, axis=1)
+        x_norm = np.linalg.norm(x, axis=1)
+        y_norm = np.linalg.norm(y, axis=1)
         # Ignore divide by zero errors run time warnings as those are handled below.
         with np.errstate(divide="ignore", invalid="ignore"):
-            similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
+            similarity = np.dot(x, y.T) / np.outer(x_norm, y_norm)
             similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
             return similarity

@@ -44,9 +44,17 @@ python = ">=3.12.4"
 target-version = "py39"

 [tool.ruff.lint]
-select = ["B", "E", "F", "I", "T201", "UP"]
+select = ["B", "E", "F", "I", "N", "T201", "UP"]
 ignore = ["UP007"]

+[tool.ruff.lint.pep8-naming]
+classmethod-decorators = [
+    "classmethod",
+    "langchain_core.utils.pydantic.pre_init",
+    "pydantic.field_validator",
+    "pydantic.v1.root_validator",
+]
+
 [tool.coverage.run]
 omit = [ "tests/*",]

@@ -9,9 +9,9 @@ from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatModel
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
 from tests.unit_tests.stubs import (
-    _AnyIdAIMessage,
-    _AnyIdAIMessageChunk,
-    _AnyIdHumanMessage,
+    _any_id_ai_message,
+    _any_id_ai_message_chunk,
+    _any_id_human_message,
 )

@@ -20,11 +20,11 @@ def test_generic_fake_chat_model_invoke() -> None:
    infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
    model = GenericFakeChatModel(messages=infinite_cycle)
    response = model.invoke("meow")
-    assert response == _AnyIdAIMessage(content="hello")
+    assert response == _any_id_ai_message(content="hello")
    response = model.invoke("kitty")
-    assert response == _AnyIdAIMessage(content="goodbye")
+    assert response == _any_id_ai_message(content="goodbye")
    response = model.invoke("meow")
-    assert response == _AnyIdAIMessage(content="hello")
+    assert response == _any_id_ai_message(content="hello")


 async def test_generic_fake_chat_model_ainvoke() -> None:
@@ -32,11 +32,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
    infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
    model = GenericFakeChatModel(messages=infinite_cycle)
    response = await model.ainvoke("meow")
-    assert response == _AnyIdAIMessage(content="hello")
+    assert response == _any_id_ai_message(content="hello")
    response = await model.ainvoke("kitty")
-    assert response == _AnyIdAIMessage(content="goodbye")
+    assert response == _any_id_ai_message(content="goodbye")
    response = await model.ainvoke("meow")
-    assert response == _AnyIdAIMessage(content="hello")
+    assert response == _any_id_ai_message(content="hello")


 async def test_generic_fake_chat_model_stream() -> None:
@@ -49,17 +49,17 @@ async def test_generic_fake_chat_model_stream() -> None:
    model = GenericFakeChatModel(messages=infinite_cycle)
    chunks = [chunk async for chunk in model.astream("meow")]
    assert chunks == [
-        _AnyIdAIMessageChunk(content="hello"),
-        _AnyIdAIMessageChunk(content=" "),
-        _AnyIdAIMessageChunk(content="goodbye"),
+        _any_id_ai_message_chunk(content="hello"),
+        _any_id_ai_message_chunk(content=" "),
+        _any_id_ai_message_chunk(content="goodbye"),
    ]
    assert len({chunk.id for chunk in chunks}) == 1

    chunks = [chunk for chunk in model.stream("meow")]
    assert chunks == [
-        _AnyIdAIMessageChunk(content="hello"),
-        _AnyIdAIMessageChunk(content=" "),
-        _AnyIdAIMessageChunk(content="goodbye"),
+        _any_id_ai_message_chunk(content="hello"),
+        _any_id_ai_message_chunk(content=" "),
+        _any_id_ai_message_chunk(content="goodbye"),
    ]
    assert len({chunk.id for chunk in chunks}) == 1

@@ -69,8 +69,8 @@ async def test_generic_fake_chat_model_stream() -> None:
    model = GenericFakeChatModel(messages=cycle([message]))
    chunks = [chunk async for chunk in model.astream("meow")]
    assert chunks == [
-        _AnyIdAIMessageChunk(content="", additional_kwargs={"foo": 42}),
-        _AnyIdAIMessageChunk(content="", additional_kwargs={"bar": 24}),
+        _any_id_ai_message_chunk(content="", additional_kwargs={"foo": 42}),
+        _any_id_ai_message_chunk(content="", additional_kwargs={"bar": 24}),
    ]
    assert len({chunk.id for chunk in chunks}) == 1

@@ -88,19 +88,19 @@ async def test_generic_fake_chat_model_stream() -> None:
    chunks = [chunk async for chunk in model.astream("meow")]

    assert chunks == [
-        _AnyIdAIMessageChunk(
+        _any_id_ai_message_chunk(
            content="", additional_kwargs={"function_call": {"name": "move_file"}}
        ),
-        _AnyIdAIMessageChunk(
+        _any_id_ai_message_chunk(
            content="",
            additional_kwargs={
                "function_call": {"arguments": '{\n "source_path": "foo"'},
            },
        ),
-        _AnyIdAIMessageChunk(
+        _any_id_ai_message_chunk(
            content="", additional_kwargs={"function_call": {"arguments": ","}}
        ),
-        _AnyIdAIMessageChunk(
+        _any_id_ai_message_chunk(
            content="",
            additional_kwargs={
                "function_call": {"arguments": '\n "destination_path": "bar"\n}'},
@@ -138,9 +138,9 @@ async def test_generic_fake_chat_model_astream_log() -> None:
    ]
    final = log_patches[-1]
    assert final.state["streamed_output"] == [
-        _AnyIdAIMessageChunk(content="hello"),
-        _AnyIdAIMessageChunk(content=" "),
-        _AnyIdAIMessageChunk(content="goodbye"),
+        _any_id_ai_message_chunk(content="hello"),
+        _any_id_ai_message_chunk(content=" "),
+        _any_id_ai_message_chunk(content="goodbye"),
    ]
    assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1

@@ -189,9 +189,9 @@ async def test_callback_handlers() -> None:
    # New model
    results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
    assert results == [
-        _AnyIdAIMessageChunk(content="hello"),
-        _AnyIdAIMessageChunk(content=" "),
-        _AnyIdAIMessageChunk(content="goodbye"),
+        _any_id_ai_message_chunk(content="hello"),
+        _any_id_ai_message_chunk(content=" "),
+        _any_id_ai_message_chunk(content="goodbye"),
    ]
    assert tokens == ["hello", " ", "goodbye"]
    assert len({chunk.id for chunk in results}) == 1
@@ -200,6 +200,8 @@ async def test_callback_handlers() -> None:
 def test_chat_model_inputs() -> None:
    fake = ParrotFakeChatModel()

-    assert fake.invoke("hello") == _AnyIdHumanMessage(content="hello")
-    assert fake.invoke([("ai", "blah")]) == _AnyIdAIMessage(content="blah")
-    assert fake.invoke([AIMessage(content="blah")]) == _AnyIdAIMessage(content="blah")
+    assert fake.invoke("hello") == _any_id_human_message(content="hello")
+    assert fake.invoke([("ai", "blah")]) == _any_id_ai_message(content="blah")
+    assert fake.invoke([AIMessage(content="blah")]) == _any_id_ai_message(
+        content="blah"
+    )
@@ -27,7 +27,7 @@ from tests.unit_tests.fake.callbacks import (
     FakeAsyncCallbackHandler,
     FakeCallbackHandler,
 )
-from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
+from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk


 @pytest.fixture
@@ -147,10 +147,10 @@ async def test_astream_fallback_to_ainvoke() -> None:

    model = ModelWithGenerate()
    chunks = [chunk for chunk in model.stream("anything")]
-    assert chunks == [_AnyIdAIMessage(content="hello")]
+    assert chunks == [_any_id_ai_message(content="hello")]

    chunks = [chunk async for chunk in model.astream("anything")]
-    assert chunks == [_AnyIdAIMessage(content="hello")]
+    assert chunks == [_any_id_ai_message(content="hello")]


 async def test_astream_implementation_fallback_to_stream() -> None:
@@ -185,15 +185,15 @@ async def test_astream_implementation_fallback_to_stream() -> None:
    model = ModelWithSyncStream()
    chunks = [chunk for chunk in model.stream("anything")]
    assert chunks == [
-        _AnyIdAIMessageChunk(content="a"),
-        _AnyIdAIMessageChunk(content="b"),
+        _any_id_ai_message_chunk(content="a"),
+        _any_id_ai_message_chunk(content="b"),
    ]
    assert len({chunk.id for chunk in chunks}) == 1
    assert type(model)._astream == BaseChatModel._astream
    astream_chunks = [chunk async for chunk in model.astream("anything")]
    assert astream_chunks == [
-        _AnyIdAIMessageChunk(content="a"),
-        _AnyIdAIMessageChunk(content="b"),
+        _any_id_ai_message_chunk(content="a"),
+        _any_id_ai_message_chunk(content="b"),
    ]
    assert len({chunk.id for chunk in astream_chunks}) == 1

@@ -230,8 +230,8 @@ async def test_astream_implementation_uses_astream() -> None:
    model = ModelWithAsyncStream()
    chunks = [chunk async for chunk in model.astream("anything")]
    assert chunks == [
-        _AnyIdAIMessageChunk(content="a"),
-        _AnyIdAIMessageChunk(content="b"),
+        _any_id_ai_message_chunk(content="a"),
+        _any_id_ai_message_chunk(content="b"),
    ]
    assert len({chunk.id for chunk in chunks}) == 1

@@ -25,7 +25,7 @@ def change_directory(dir: Path) -> Iterator:
         os.chdir(origin)


-def test_loading_from_YAML() -> None:
+def test_loading_from_yaml() -> None:
     """Test loading from yaml file."""
     prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.yaml")
     expected_prompt = PromptTemplate(
@@ -36,7 +36,7 @@ def test_loading_from_YAML() -> None:
     assert prompt == expected_prompt


-def test_loading_from_JSON() -> None:
+def test_loading_from_json() -> None:
     """Test loading from json file."""
     prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.json")
     expected_prompt = PromptTemplate(
@@ -46,14 +46,14 @@ def test_loading_from_JSON() -> None:
     assert prompt == expected_prompt


-def test_loading_jinja_from_JSON() -> None:
+def test_loading_jinja_from_json() -> None:
     """Test that loading jinja2 format prompts from JSON raises ValueError."""
     prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.json"
     with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"):
         load_prompt(prompt_path)


-def test_loading_jinja_from_YAML() -> None:
+def test_loading_jinja_from_yaml() -> None:
     """Test that loading jinja2 format prompts from YAML raises ValueError."""
     prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.yaml"
     with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"):
@@ -2,6 +2,7 @@ from typing import Optional

 from pydantic import BaseModel
 from syrupy import SnapshotAssertion
+from typing_extensions import override

 from langchain_core.language_models import FakeListLLM
 from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
@@ -353,9 +354,11 @@ def test_runnable_get_graph_with_invalid_input_type() -> None:

     class InvalidInputTypeRunnable(Runnable[int, int]):
         @property
+        @override
         def InputType(self) -> type:
             raise TypeError()

+        @override
         def invoke(
             self,
             input: int,
@@ -375,9 +378,11 @@ def test_runnable_get_graph_with_invalid_output_type() -> None:

     class InvalidOutputTypeRunnable(Runnable[int, int]):
         @property
+        @override
         def OutputType(self) -> type:
             raise TypeError()

+        @override
         def invoke(
             self,
             input: int,
@@ -493,7 +493,7 @@ def test_get_output_schema() -> None:
 def test_get_input_schema_input_messages() -> None:
     from pydantic import RootModel

-    RunnableWithMessageHistoryInput = RootModel[Sequence[BaseMessage]]
+    runnable_with_message_history_input = RootModel[Sequence[BaseMessage]]

     runnable = RunnableLambda(
         lambda messages: {
@@ -515,7 +515,7 @@ def test_get_input_schema_input_messages() -> None:
     with_history = RunnableWithMessageHistory(
         runnable, get_session_history, output_messages_key="output"
     )
-    expected_schema = _schema(RunnableWithMessageHistoryInput)
+    expected_schema = _schema(runnable_with_message_history_input)
     expected_schema["title"] = "RunnableWithChatHistoryInput"
     assert _schema(with_history.get_input_schema()) == expected_schema

|
||||
)
|
||||
from langchain_core.tracers.context import collect_runs
|
||||
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
|
||||
from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk
|
||||
from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk
|
||||
|
||||
PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split(".")))
|
||||
|
||||
@ -1699,7 +1699,7 @@ def test_prompt_with_chat_model(
|
||||
tracer = FakeTracer()
|
||||
assert chain.invoke(
|
||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||
) == _AnyIdAIMessage(content="foo")
|
||||
) == _any_id_ai_message(content="foo")
|
||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||
messages=[
|
||||
@@ -1724,8 +1724,8 @@ def test_prompt_with_chat_model(
         ],
         dict(callbacks=[tracer]),
     ) == [
-        _AnyIdAIMessage(content="foo"),
-        _AnyIdAIMessage(content="foo"),
+        _any_id_ai_message(content="foo"),
+        _any_id_ai_message(content="foo"),
     ]
     assert prompt_spy.call_args.args[1] == [
         {"question": "What is your name?"},
@@ -1765,9 +1765,9 @@ def test_prompt_with_chat_model(
     assert [
         *chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
     ] == [
-        _AnyIdAIMessageChunk(content="f"),
-        _AnyIdAIMessageChunk(content="o"),
-        _AnyIdAIMessageChunk(content="o"),
+        _any_id_ai_message_chunk(content="f"),
+        _any_id_ai_message_chunk(content="o"),
+        _any_id_ai_message_chunk(content="o"),
     ]
     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
     assert chat_spy.call_args.args[1] == ChatPromptValue(
@@ -1805,7 +1805,7 @@ async def test_prompt_with_chat_model_async(
     tracer = FakeTracer()
     assert await chain.ainvoke(
         {"question": "What is your name?"}, dict(callbacks=[tracer])
-    ) == _AnyIdAIMessage(content="foo")
+    ) == _any_id_ai_message(content="foo")
     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
     assert chat_spy.call_args.args[1] == ChatPromptValue(
         messages=[
@@ -1830,8 +1830,8 @@ async def test_prompt_with_chat_model_async(
         ],
         dict(callbacks=[tracer]),
     ) == [
-        _AnyIdAIMessage(content="foo"),
-        _AnyIdAIMessage(content="foo"),
+        _any_id_ai_message(content="foo"),
+        _any_id_ai_message(content="foo"),
     ]
     assert prompt_spy.call_args.args[1] == [
         {"question": "What is your name?"},
@@ -1874,9 +1874,9 @@ async def test_prompt_with_chat_model_async(
             {"question": "What is your name?"}, dict(callbacks=[tracer])
         )
     ] == [
-        _AnyIdAIMessageChunk(content="f"),
-        _AnyIdAIMessageChunk(content="o"),
-        _AnyIdAIMessageChunk(content="o"),
+        _any_id_ai_message_chunk(content="f"),
+        _any_id_ai_message_chunk(content="o"),
+        _any_id_ai_message_chunk(content="o"),
     ]
     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
     assert chat_spy.call_args.args[1] == ChatPromptValue(
@@ -2548,7 +2548,7 @@ def test_prompt_with_chat_model_and_parser(
             HumanMessage(content="What is your name?"),
         ]
     )
-    assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar")
+    assert parser_spy.call_args.args[1] == _any_id_ai_message(content="foo, bar")

     assert tracer.runs == snapshot

@@ -2681,7 +2681,7 @@ Question:
            ),
        ]
    )
-    assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar")
+    assert parser_spy.call_args.args[1] == _any_id_ai_message(content="foo, bar")
    assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
    parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
    assert len(parent_run.child_runs) == 4
@@ -2727,7 +2727,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
    assert chain.invoke(
        {"question": "What is your name?"}, dict(callbacks=[tracer])
    ) == {
-        "chat": _AnyIdAIMessage(content="i'm a chatbot"),
+        "chat": _any_id_ai_message(content="i'm a chatbot"),
        "llm": "i'm a textbot",
    }
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
@@ -2936,7 +2936,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
    assert chain.invoke(
        {"question": "What is your name?"}, dict(callbacks=[tracer])
    ) == {
-        "chat": _AnyIdAIMessage(content="i'm a chatbot"),
+        "chat": _any_id_ai_message(content="i'm a chatbot"),
        "llm": "i'm a textbot",
        "passthrough": ChatPromptValue(
            messages=[
@@ -3000,7 +3000,7 @@ def test_map_stream() -> None:
    assert streamed_chunks[0] in [
        {"passthrough": prompt.invoke({"question": "What is your name?"})},
        {"llm": "i"},
-        {"chat": _AnyIdAIMessageChunk(content="i")},
+        {"chat": _any_id_ai_message_chunk(content="i")},
    ]
    assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
    assert all(len(c.keys()) == 1 for c in streamed_chunks)
@@ -3059,11 +3059,11 @@ def test_map_stream() -> None:

    assert streamed_chunks[0] in [
        {"llm": "i"},
-        {"chat": _AnyIdAIMessageChunk(content="i")},
+        {"chat": _any_id_ai_message_chunk(content="i")},
    ]
    if not (  # TODO(Rewrite properly) statement above
        streamed_chunks[0] == {"llm": "i"}
-        or {"chat": _AnyIdAIMessageChunk(content="i")}
+        or {"chat": _any_id_ai_message_chunk(content="i")}
    ):
        raise AssertionError(f"Got an unexpected chunk: {streamed_chunks[0]}")

@@ -3108,7 +3108,7 @@ def test_map_stream_iterator_input() -> None:
    assert streamed_chunks[0] in [
        {"passthrough": "i"},
        {"llm": "i"},
-        {"chat": _AnyIdAIMessageChunk(content="i")},
+        {"chat": _any_id_ai_message_chunk(content="i")},
    ]
    assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
    assert all(len(c.keys()) == 1 for c in streamed_chunks)
@@ -3152,7 +3152,7 @@ async def test_map_astream() -> None:
    assert streamed_chunks[0] in [
        {"passthrough": prompt.invoke({"question": "What is your name?"})},
        {"llm": "i"},
-        {"chat": _AnyIdAIMessageChunk(content="i")},
+        {"chat": _any_id_ai_message_chunk(content="i")},
    ]
    assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
    assert all(len(c.keys()) == 1 for c in streamed_chunks)
@@ -32,7 +32,7 @@ from langchain_core.runnables import (
 from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_core.runnables.schema import StreamEvent
 from langchain_core.tools import tool
-from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
+from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk


 def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]:
@@ -503,7 +503,7 @@ async def test_astream_events_from_model() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
+            "data": {"chunk": _any_id_ai_message_chunk(content="hello")},
            "event": "on_chat_model_stream",
            "metadata": {"a": "b"},
            "name": "my_model",
@@ -512,7 +512,7 @@ async def test_astream_events_from_model() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
+            "data": {"chunk": _any_id_ai_message_chunk(content=" ")},
            "event": "on_chat_model_stream",
            "metadata": {"a": "b"},
            "name": "my_model",
@@ -521,7 +521,7 @@ async def test_astream_events_from_model() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
+            "data": {"chunk": _any_id_ai_message_chunk(content="world!")},
            "event": "on_chat_model_stream",
            "metadata": {"a": "b"},
            "name": "my_model",
@@ -530,7 +530,7 @@ async def test_astream_events_from_model() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"output": _AnyIdAIMessageChunk(content="hello world!")},
+            "data": {"output": _any_id_ai_message_chunk(content="hello world!")},
            "event": "on_chat_model_end",
            "metadata": {"a": "b"},
            "name": "my_model",
@@ -575,7 +575,7 @@ async def test_astream_events_from_model() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
+            "data": {"chunk": _any_id_ai_message_chunk(content="hello")},
            "event": "on_chat_model_stream",
            "metadata": {
                "a": "b",
@@ -588,7 +588,7 @@ async def test_astream_events_from_model() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
+            "data": {"chunk": _any_id_ai_message_chunk(content=" ")},
            "event": "on_chat_model_stream",
            "metadata": {
                "a": "b",
@@ -601,7 +601,7 @@ async def test_astream_events_from_model() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
+            "data": {"chunk": _any_id_ai_message_chunk(content="world!")},
            "event": "on_chat_model_stream",
            "metadata": {
                "a": "b",
@@ -621,7 +621,9 @@ async def test_astream_events_from_model() -> None:
                [
                    {
                        "generation_info": None,
-                        "message": _AnyIdAIMessage(content="hello world!"),
+                        "message": _any_id_ai_message(
+                            content="hello world!"
+                        ),
                        "text": "hello world!",
                        "type": "ChatGeneration",
                    }
@@ -644,7 +646,7 @@ async def test_astream_events_from_model() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessage(content="hello world!")},
+            "data": {"chunk": _any_id_ai_message(content="hello world!")},
            "event": "on_chain_stream",
            "metadata": {},
            "name": "i_dont_stream",
@@ -653,7 +655,7 @@ async def test_astream_events_from_model() -> None:
            "tags": [],
        },
        {
-            "data": {"output": _AnyIdAIMessage(content="hello world!")},
+            "data": {"output": _any_id_ai_message(content="hello world!")},
            "event": "on_chain_end",
            "metadata": {},
            "name": "i_dont_stream",
@@ -698,7 +700,7 @@ async def test_astream_events_from_model() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
+            "data": {"chunk": _any_id_ai_message_chunk(content="hello")},
            "event": "on_chat_model_stream",
            "metadata": {
                "a": "b",
@@ -711,7 +713,7 @@ async def test_astream_events_from_model() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
+            "data": {"chunk": _any_id_ai_message_chunk(content=" ")},
            "event": "on_chat_model_stream",
            "metadata": {
                "a": "b",
@@ -724,7 +726,7 @@ async def test_astream_events_from_model() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
+            "data": {"chunk": _any_id_ai_message_chunk(content="world!")},
            "event": "on_chat_model_stream",
            "metadata": {
                "a": "b",
@@ -744,7 +746,9 @@ async def test_astream_events_from_model() -> None:
                [
                    {
                        "generation_info": None,
-                        "message": _AnyIdAIMessage(content="hello world!"),
+                        "message": _any_id_ai_message(
+                            content="hello world!"
+                        ),
                        "text": "hello world!",
                        "type": "ChatGeneration",
                    }
@@ -767,7 +771,7 @@ async def test_astream_events_from_model() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessage(content="hello world!")},
+            "data": {"chunk": _any_id_ai_message(content="hello world!")},
            "event": "on_chain_stream",
            "metadata": {},
            "name": "ai_dont_stream",
@@ -776,7 +780,7 @@ async def test_astream_events_from_model() -> None:
            "tags": [],
        },
        {
-            "data": {"output": _AnyIdAIMessage(content="hello world!")},
+            "data": {"output": _any_id_ai_message(content="hello world!")},
            "event": "on_chain_end",
            "metadata": {},
            "name": "ai_dont_stream",
@@ -47,7 +47,7 @@ from langchain_core.utils.aiter import aclosing
 from tests.unit_tests.runnables.test_runnable_events_v1 import (
     _assert_events_equal_allow_superset_metadata,
 )
-from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
+from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk


 def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]:
@@ -533,7 +533,7 @@ async def test_astream_events_from_model() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
+            "data": {"chunk": _any_id_ai_message_chunk(content="hello")},
            "event": "on_chat_model_stream",
            "metadata": {
                "a": "b",
@@ -546,7 +546,7 @@ async def test_astream_events_from_model() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
+            "data": {"chunk": _any_id_ai_message_chunk(content=" ")},
            "event": "on_chat_model_stream",
            "metadata": {
                "a": "b",
@@ -559,7 +559,7 @@ async def test_astream_events_from_model() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
+            "data": {"chunk": _any_id_ai_message_chunk(content="world!")},
            "event": "on_chat_model_stream",
            "metadata": {
                "a": "b",
@@ -573,7 +573,7 @@ async def test_astream_events_from_model() -> None:
        },
        {
            "data": {
-                "output": _AnyIdAIMessageChunk(content="hello world!"),
+                "output": _any_id_ai_message_chunk(content="hello world!"),
            },
            "event": "on_chat_model_end",
            "metadata": {
@@ -640,7 +640,7 @@ async def test_astream_with_model_in_chain() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
+            "data": {"chunk": _any_id_ai_message_chunk(content="hello")},
            "event": "on_chat_model_stream",
            "metadata": {
                "a": "b",
@@ -653,7 +653,7 @@ async def test_astream_with_model_in_chain() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
+            "data": {"chunk": _any_id_ai_message_chunk(content=" ")},
            "event": "on_chat_model_stream",
            "metadata": {
                "a": "b",
@@ -666,7 +666,7 @@ async def test_astream_with_model_in_chain() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
+            "data": {"chunk": _any_id_ai_message_chunk(content="world!")},
            "event": "on_chat_model_stream",
            "metadata": {
                "a": "b",
@@ -681,7 +681,7 @@ async def test_astream_with_model_in_chain() -> None:
        {
            "data": {
                "input": {"messages": [[HumanMessage(content="hello")]]},
-                "output": _AnyIdAIMessage(content="hello world!"),
+                "output": _any_id_ai_message(content="hello world!"),
            },
            "event": "on_chat_model_end",
            "metadata": {
@@ -695,7 +695,7 @@ async def test_astream_with_model_in_chain() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessage(content="hello world!")},
+            "data": {"chunk": _any_id_ai_message(content="hello world!")},
            "event": "on_chain_stream",
            "metadata": {},
            "name": "i_dont_stream",
@@ -704,7 +704,7 @@ async def test_astream_with_model_in_chain() -> None:
            "tags": [],
        },
        {
-            "data": {"output": _AnyIdAIMessage(content="hello world!")},
+            "data": {"output": _any_id_ai_message(content="hello world!")},
            "event": "on_chain_end",
            "metadata": {},
            "name": "i_dont_stream",
@@ -749,7 +749,7 @@ async def test_astream_with_model_in_chain() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
+            "data": {"chunk": _any_id_ai_message_chunk(content="hello")},
            "event": "on_chat_model_stream",
            "metadata": {
                "a": "b",
@@ -762,7 +762,7 @@ async def test_astream_with_model_in_chain() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
+            "data": {"chunk": _any_id_ai_message_chunk(content=" ")},
            "event": "on_chat_model_stream",
            "metadata": {
                "a": "b",
@@ -775,7 +775,7 @@ async def test_astream_with_model_in_chain() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
+            "data": {"chunk": _any_id_ai_message_chunk(content="world!")},
            "event": "on_chat_model_stream",
            "metadata": {
                "a": "b",
@@ -790,7 +790,7 @@ async def test_astream_with_model_in_chain() -> None:
        {
            "data": {
                "input": {"messages": [[HumanMessage(content="hello")]]},
-                "output": _AnyIdAIMessage(content="hello world!"),
+                "output": _any_id_ai_message(content="hello world!"),
            },
            "event": "on_chat_model_end",
            "metadata": {
@@ -804,7 +804,7 @@ async def test_astream_with_model_in_chain() -> None:
            "tags": ["my_model"],
        },
        {
-            "data": {"chunk": _AnyIdAIMessage(content="hello world!")},
+            "data": {"chunk": _any_id_ai_message(content="hello world!")},
            "event": "on_chain_stream",
            "metadata": {},
            "name": "ai_dont_stream",
@@ -813,7 +813,7 @@ async def test_astream_with_model_in_chain() -> None:
            "tags": [],
        },
        {
-            "data": {"output": _AnyIdAIMessage(content="hello world!")},
+            "data": {"output": _any_id_ai_message(content="hello world!")},
            "event": "on_chain_end",
            "metadata": {},
            "name": "ai_dont_stream",
@@ -16,28 +16,28 @@ class AnyStr(str):
     # subclassed strings.


-def _AnyIdDocument(**kwargs: Any) -> Document:
+def _any_id_document(**kwargs: Any) -> Document:
     """Create a document with an id field."""
     message = Document(**kwargs)
     message.id = AnyStr()
     return message


-def _AnyIdAIMessage(**kwargs: Any) -> AIMessage:
+def _any_id_ai_message(**kwargs: Any) -> AIMessage:
     """Create ai message with an any id field."""
     message = AIMessage(**kwargs)
     message.id = AnyStr()
     return message


-def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk:
+def _any_id_ai_message_chunk(**kwargs: Any) -> AIMessageChunk:
     """Create ai message with an any id field."""
     message = AIMessageChunk(**kwargs)
     message.id = AnyStr()
     return message


-def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
+def _any_id_human_message(**kwargs: Any) -> HumanMessage:
     """Create a human with an any id field."""
     message = HumanMessage(**kwargs)
     message.id = AnyStr()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"MessageClass",
|
||||
"message_class",
|
||||
[
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
@ -790,39 +790,39 @@ def test_convert_to_messages() -> None:
|
||||
SystemMessage,
|
||||
],
|
||||
)
|
||||
def test_message_name(MessageClass: type) -> None:
|
||||
msg = MessageClass(content="foo", name="bar")
|
||||
def test_message_name(message_class: type) -> None:
|
||||
msg = message_class(content="foo", name="bar")
|
||||
assert msg.name == "bar"
|
||||
|
||||
msg2 = MessageClass(content="foo", name=None)
|
||||
msg2 = message_class(content="foo", name=None)
|
||||
assert msg2.name is None
|
||||
|
||||
msg3 = MessageClass(content="foo")
|
||||
msg3 = message_class(content="foo")
|
||||
assert msg3.name is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"MessageClass",
|
||||
"message_class",
|
||||
[FunctionMessage, FunctionMessageChunk],
|
||||
)
|
||||
def test_message_name_function(MessageClass: type) -> None:
|
||||
def test_message_name_function(message_class: type) -> None:
|
||||
# functionmessage doesn't support name=None
|
||||
msg = MessageClass(name="foo", content="bar")
|
||||
msg = message_class(name="foo", content="bar")
|
||||
assert msg.name == "foo"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"MessageClass",
|
||||
"message_class",
|
||||
[ChatMessage, ChatMessageChunk],
|
||||
)
|
||||
def test_message_name_chat(MessageClass: type) -> None:
|
||||
msg = MessageClass(content="foo", role="user", name="bar")
|
||||
def test_message_name_chat(message_class: type) -> None:
|
||||
msg = message_class(content="foo", role="user", name="bar")
|
||||
assert msg.name == "bar"
|
||||
|
||||
msg2 = MessageClass(content="foo", role="user", name=None)
|
||||
msg2 = message_class(content="foo", role="user", name=None)
|
||||
assert msg2.name is None
|
||||
|
||||
msg3 = MessageClass(content="foo", role="user")
|
||||
msg3 = message_class(content="foo", role="user")
|
||||
assert msg3.name is None
|
||||
|
||||
|
||||
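These renames satisfy N803 (argument name should be lowercase). The one subtlety is that @pytest.mark.parametrize binds by name, so the quoted parameter string and the test signature must change in lockstep, as they do above. A toy version of the pattern:

import pytest


@pytest.mark.parametrize(
    "message_class",  # must match the argument name in the signature
    [str, bytes],
)
def test_default_is_falsy(message_class: type) -> None:
    assert not message_class()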
@@ -51,14 +51,14 @@ def test_serde_any_message() -> None:
         ),
     ]

-    Model = RootModel[AnyMessage]
+    model = RootModel[AnyMessage]

     for lc_object in lc_objects:
         d = lc_object.model_dump()
         assert "type" in d, f"Missing key `type` for {type(lc_object)}"
-        obj1 = Model.model_validate(d)
+        obj1 = model.model_validate(d)
         assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}"

         with pytest.raises((TypeError, ValidationError)):
             # Make sure that specifically validation error is raised
-            Model.model_validate({})
+            model.model_validate({})
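Model becomes model here to satisfy N806 (variable in function should be lowercase): even though the value happens to be a class object, it is a local binding, and pep8-naming keys off the binding, not the value. The same pattern outside the test, as a sketch:

from pydantic import RootModel


def validate_payload(payload: dict) -> dict:
    # `Model = RootModel[dict]` would trip N806 here; a lowercase
    # binding is compliant and works identically.
    model = RootModel[dict]
    return model.model_validate(payload).root


assert validate_payload({"type": "human"}) == {"type": "human"}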
@@ -1421,7 +1421,7 @@ class InjectedTool(BaseTool):
         return y


-class fooSchema(BaseModel):
+class fooSchema(BaseModel):  # noqa: N801
     """foo."""

     x: int = Field(..., description="abc")
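Here the module-level fooSchema keeps its lowercase name and N801 (class name should use CapWords) is silenced with a targeted # noqa, presumably because other tests assert on the literal string "fooSchema" (for example as a JSON-schema title). The trade-off is rename-plus-update versus suppress-in-place; this commit uses both, and the next hunks show the rename branch. A sketch of the two options:

# Suppress: keeps a name that other code or tests depend on.
class fooSchema:  # noqa: N801
    pass


# Rename: preferred when nothing depends on the old spelling.
class FooSchema:
    pass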
@@ -1568,14 +1568,14 @@ def test_tool_injected_arg() -> None:


 def test_tool_inherited_injected_arg() -> None:
-    class barSchema(BaseModel):
+    class BarSchema(BaseModel):
         """bar."""

         y: Annotated[str, "foobar comment", InjectedToolArg()] = Field(
             ..., description="123"
         )

-    class fooSchema(barSchema):
+    class FooSchema(BarSchema):
         """foo."""

         x: int = Field(..., description="abc")
@@ -1583,14 +1583,14 @@ def test_tool_inherited_injected_arg() -> None:
     class InheritedInjectedArgTool(BaseTool):
         name: str = "foo"
         description: str = "foo."
-        args_schema: type[BaseModel] = fooSchema
+        args_schema: type[BaseModel] = FooSchema

         def _run(self, x: int, y: str) -> Any:
             return y

     tool_ = InheritedInjectedArgTool()
     assert tool_.get_input_schema().model_json_schema() == {
-        "title": "fooSchema",  # Matches the title from the provided schema
+        "title": "FooSchema",  # Matches the title from the provided schema
         "description": "foo.",
         "type": "object",
         "properties": {
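The title assertion changes together with the class rename because pydantic derives a model's JSON-schema title from the class name; renaming fooSchema to FooSchema without touching the assertion would fail the test. A quick demonstration:

from pydantic import BaseModel, Field


class FooSchema(BaseModel):
    """foo."""

    x: int = Field(..., description="abc")


assert FooSchema.model_json_schema()["title"] == "FooSchema"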
@@ -1877,15 +1877,15 @@ def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None:
     A = TypeVar("A")

     if use_v1_namespace:
-        from pydantic.v1 import BaseModel as BM1
+        from pydantic.v1 import BaseModel as BaseModel1

-        class ModelA(BM1, Generic[A], extra="allow"):
+        class ModelA(BaseModel1, Generic[A], extra="allow"):
             a: A
     else:
-        from pydantic import BaseModel as BM2
+        from pydantic import BaseModel as BaseModel2
         from pydantic import ConfigDict

-        class ModelA(BM2, Generic[A]):  # type: ignore[no-redef]
+        class ModelA(BaseModel2, Generic[A]):  # type: ignore[no-redef]
             a: A
             model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")

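The alias renames keep the imported class looking like a class: an acronym-style alias such as BM1 for BaseModel is what ruff's import-naming checks flag (most likely N817, camelcase imported as acronym; an assumption, since the rule code isn't shown in the diff). The shape of the fix:

# from pydantic import BaseModel as BM  # acronym alias: flagged by ruff's N rules

# CamelCase alias: the naming checks stay quiet.
from pydantic import BaseModel as BaseModel2


class Settings(BaseModel2):
    retries: int = 3


assert Settings().retries == 3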
@@ -2081,13 +2081,13 @@ def test_structured_tool_direct_init() -> None:
     def foo(bar: str) -> str:
         return bar

-    async def asyncFoo(bar: str) -> str:
+    async def async_foo(bar: str) -> str:
         return bar

-    class fooSchema(BaseModel):
+    class FooSchema(BaseModel):
         bar: str = Field(..., description="The bar")

-    tool = StructuredTool(name="foo", args_schema=fooSchema, coroutine=asyncFoo)
+    tool = StructuredTool(name="foo", args_schema=FooSchema, coroutine=async_foo)

     with pytest.raises(NotImplementedError):
         assert tool.invoke("hello") == "hello"
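asyncFoo to async_foo is N802 (function name should be lowercase), which applies to async def exactly as to def. A toy before/after:

import asyncio


# async def asyncFoo(bar: str) -> str:   # flagged by N802
async def async_foo(bar: str) -> str:  # compliant snake_case
    return bar


assert asyncio.run(async_foo("hello")) == "hello"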
@@ -38,7 +38,7 @@ from langchain_core.utils.function_calling import (

 @pytest.fixture()
 def pydantic() -> type[BaseModel]:
-    class dummy_function(BaseModel):
+    class dummy_function(BaseModel):  # noqa: N801
         """dummy function"""

         arg1: int = Field(..., description="foo")
@@ -48,7 +48,7 @@ def pydantic() -> type[BaseModel]:


 @pytest.fixture()
-def Annotated_function() -> Callable:
+def annotated_function() -> Callable:
     def dummy_function(
         arg1: ExtensionsAnnotated[int, "foo"],
         arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"],
@@ -118,7 +118,7 @@ def dummy_structured_tool() -> StructuredTool:

 @pytest.fixture()
 def dummy_pydantic() -> type[BaseModel]:
-    class dummy_function(BaseModel):
+    class dummy_function(BaseModel):  # noqa: N801
         """dummy function"""

         arg1: int = Field(..., description="foo")
@@ -129,7 +129,7 @@ def dummy_pydantic() -> type[BaseModel]:

 @pytest.fixture()
 def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
-    class dummy_function(BaseModelV2Maybe):
+    class dummy_function(BaseModelV2Maybe):  # noqa: N801
         """dummy function"""

         arg1: int = FieldV2Maybe(..., description="foo")
@@ -142,7 +142,7 @@ def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:

 @pytest.fixture()
 def dummy_typing_typed_dict() -> type:
-    class dummy_function(TypingTypedDict):
+    class dummy_function(TypingTypedDict):  # noqa: N801
         """dummy function"""

         arg1: TypingAnnotated[int, ..., "foo"]  # noqa: F821
@@ -153,7 +153,7 @@ def dummy_typing_typed_dict() -> type:

 @pytest.fixture()
 def dummy_typing_typed_dict_docstring() -> type:
-    class dummy_function(TypingTypedDict):
+    class dummy_function(TypingTypedDict):  # noqa: N801
         """dummy function

         Args:
@@ -169,7 +169,7 @@ def dummy_typing_typed_dict_docstring() -> type:

 @pytest.fixture()
 def dummy_extensions_typed_dict() -> type:
-    class dummy_function(ExtensionsTypedDict):
+    class dummy_function(ExtensionsTypedDict):  # noqa: N801
         """dummy function"""

         arg1: ExtensionsAnnotated[int, ..., "foo"]
@@ -180,7 +180,7 @@ def dummy_extensions_typed_dict() -> type:

 @pytest.fixture()
 def dummy_extensions_typed_dict_docstring() -> type:
-    class dummy_function(ExtensionsTypedDict):
+    class dummy_function(ExtensionsTypedDict):  # noqa: N801
         """dummy function

         Args:
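All of these fixture classes stay lowercase with # noqa: N801 instead of being renamed. A plausible reason (not stated in the diff, but consistent with the commit message's policy of not renaming names with observable behavior) is that convert_to_openai_function derives the generated function's name from the class name, so downstream tests assert on the literal string "dummy_function". A sketch of that coupling:

from pydantic import BaseModel, Field


class dummy_function(BaseModel):  # noqa: N801
    """dummy function"""

    arg1: int = Field(..., description="foo")


# The class name flows into the generated tool/function name, so a rename
# to DummyFunction would silently change the expected test output.
assert dummy_function.__name__ == "dummy_function"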
@@ -241,7 +241,7 @@ def test_convert_to_openai_function(
     dummy_structured_tool: StructuredTool,
     dummy_tool: BaseTool,
     json_schema: dict,
-    Annotated_function: Callable,
+    annotated_function: Callable,
     dummy_pydantic: type[BaseModel],
     runnable: Runnable,
     dummy_typing_typed_dict: type,
@@ -275,7 +275,7 @@ def test_convert_to_openai_function(
         expected,
         Dummy.dummy_function,
         DummyWithClassMethod.dummy_function,
-        Annotated_function,
+        annotated_function,
         dummy_pydantic,
         dummy_typing_typed_dict,
         dummy_typing_typed_dict_docstring,
@@ -523,20 +523,20 @@ def test__convert_typed_dict_to_openai_function(
     use_extension_typed_dict: bool, use_extension_annotated: bool
 ) -> None:
     if use_extension_typed_dict:
-        TypedDict = ExtensionsTypedDict
+        typed_dict = ExtensionsTypedDict
     else:
-        TypedDict = TypingTypedDict
+        typed_dict = TypingTypedDict
     if use_extension_annotated:
-        Annotated = TypingAnnotated
+        annotated = TypingAnnotated
     else:
-        Annotated = TypingAnnotated
+        annotated = TypingAnnotated

-    class SubTool(TypedDict):
+    class SubTool(typed_dict):
         """Subtool docstring"""

-        args: Annotated[dict[str, Any], {}, "this does bar"]  # noqa: F722 # type: ignore
+        args: annotated[dict[str, Any], {}, "this does bar"]  # noqa: F722 # type: ignore

-    class Tool(TypedDict):
+    class Tool(typed_dict):
         """Docstring

         Args:
@@ -546,20 +546,20 @@ def test__convert_typed_dict_to_openai_function(
         arg1: str
         arg2: Union[int, str, bool]
         arg3: Optional[list[SubTool]]
-        arg4: Annotated[Literal["bar", "baz"], ..., "this does foo"]  # noqa: F722
-        arg5: Annotated[Optional[float], None]
-        arg6: Annotated[
+        arg4: annotated[Literal["bar", "baz"], ..., "this does foo"]  # noqa: F722
+        arg5: annotated[Optional[float], None]
+        arg6: annotated[
             Optional[Sequence[Mapping[str, tuple[Iterable[Any], SubTool]]]], []
         ]
-        arg7: Annotated[list[SubTool], ...]
-        arg8: Annotated[tuple[SubTool], ...]
-        arg9: Annotated[Sequence[SubTool], ...]
-        arg10: Annotated[Iterable[SubTool], ...]
-        arg11: Annotated[set[SubTool], ...]
-        arg12: Annotated[dict[str, SubTool], ...]
-        arg13: Annotated[Mapping[str, SubTool], ...]
-        arg14: Annotated[MutableMapping[str, SubTool], ...]
-        arg15: Annotated[bool, False, "flag"]  # noqa: F821 # type: ignore
+        arg7: annotated[list[SubTool], ...]
+        arg8: annotated[tuple[SubTool], ...]
+        arg9: annotated[Sequence[SubTool], ...]
+        arg10: annotated[Iterable[SubTool], ...]
+        arg11: annotated[set[SubTool], ...]
+        arg12: annotated[dict[str, SubTool], ...]
+        arg13: annotated[Mapping[str, SubTool], ...]
+        arg14: annotated[MutableMapping[str, SubTool], ...]
+        arg15: annotated[bool, False, "flag"]  # noqa: F821 # type: ignore

     expected = {
         "name": "Tool",
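Rebinding TypedDict and Annotated to typed_dict and annotated satisfies N806 for the local variables, and everything still works at runtime because both are ordinary objects that can be assigned, subscripted, and subclassed through any name. A minimal check of that claim:

from typing import Any

from typing_extensions import Annotated as TypingAnnotated
from typing_extensions import TypedDict as ExtensionsTypedDict

typed_dict = ExtensionsTypedDict
annotated = TypingAnnotated


class SubTool(typed_dict):  # type: ignore[misc]  # mypy wants a literal base
    args: annotated[dict[str, Any], "this does bar"]


assert list(SubTool.__annotations__) == ["args"]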
@@ -10,7 +10,7 @@ from langchain_standard_tests.integration_tests.vectorstores import (
 from langchain_core.documents import Document
 from langchain_core.embeddings.fake import DeterministicFakeEmbedding
 from langchain_core.vectorstores import InMemoryVectorStore
-from tests.unit_tests.stubs import _AnyIdDocument
+from tests.unit_tests.stubs import _any_id_document


 class TestInMemoryReadWriteTestSuite(ReadWriteTestSuite):
@@ -33,13 +33,13 @@ async def test_inmemory_similarity_search() -> None:

     # Check sync version
     output = store.similarity_search("foo", k=1)
-    assert output == [_AnyIdDocument(page_content="foo")]
+    assert output == [_any_id_document(page_content="foo")]

     # Check async version
     output = await store.asimilarity_search("bar", k=2)
     assert output == [
-        _AnyIdDocument(page_content="bar"),
-        _AnyIdDocument(page_content="baz"),
+        _any_id_document(page_content="bar"),
+        _any_id_document(page_content="baz"),
     ]
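For context, the store under test can be reproduced in a few lines; DeterministicFakeEmbedding gives stable vectors, and the ids assigned during add_texts are exactly why the any-id helper is needed in these equality checks. A sketch (the size parameter is chosen arbitrarily):

from langchain_core.embeddings.fake import DeterministicFakeEmbedding
from langchain_core.vectorstores import InMemoryVectorStore

store = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=8))
store.add_texts(["foo", "bar", "baz"])

top = store.similarity_search("foo", k=1)[0]
print(top.page_content)  # -> foo
print(top.id)            # random uuid: exact Document equality would fail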
@@ -80,16 +80,16 @@ async def test_inmemory_mmr() -> None:
     # make sure we can k > docstore size
     output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1)
     assert len(output) == len(texts)
-    assert output[0] == _AnyIdDocument(page_content="foo")
-    assert output[1] == _AnyIdDocument(page_content="foy")
+    assert output[0] == _any_id_document(page_content="foo")
+    assert output[1] == _any_id_document(page_content="foy")

     # Check async version
     output = await docsearch.amax_marginal_relevance_search(
         "foo", k=10, lambda_mult=0.1
     )
     assert len(output) == len(texts)
-    assert output[0] == _AnyIdDocument(page_content="foo")
-    assert output[1] == _AnyIdDocument(page_content="foy")
+    assert output[0] == _any_id_document(page_content="foo")
+    assert output[1] == _any_id_document(page_content="foy")


 async def test_inmemory_dump_load(tmp_path: Path) -> None:
@@ -117,7 +117,7 @@ async def test_inmemory_filter() -> None:

     # Check sync version
     output = store.similarity_search("fee", filter=lambda doc: doc.metadata["id"] == 1)
-    assert output == [_AnyIdDocument(page_content="foo", metadata={"id": 1})]
+    assert output == [_any_id_document(page_content="foo", metadata={"id": 1})]

     # filter with not stored document id
     output = await store.asimilarity_search(
@@ -50,6 +50,7 @@ class CustomAddTextsVectorstore(VectorStore):
     def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
         return [self.store[id] for id in ids if id in self.store]

     @classmethod
     def from_texts(  # type: ignore
         cls,
         texts: list[str],
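The final hunk only shows context around get_by_ids, but its body is a useful pattern on its own: preserve the requested order and silently skip unknown ids. A standalone sketch of the same logic (TinyStore is not the class from the diff):

from collections.abc import Sequence


class TinyStore:
    """Illustrates the get_by_ids pattern from the hunk above."""

    def __init__(self) -> None:
        self.store: dict[str, str] = {"a": "doc-a", "b": "doc-b"}

    def get_by_ids(self, ids: Sequence[str], /) -> list[str]:
        # Request order preserved; missing ids dropped rather than raising.
        return [self.store[id_] for id_ in ids if id_ in self.store]


assert TinyStore().get_by_ids(["b", "zzz", "a"]) == ["doc-b", "doc-a"]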