core: Add N(naming) ruff rules (#25362)

Public classes/functions are not renamed and rule is ignored for them.

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Christophe Bornet 2024-09-19 07:09:39 +02:00 committed by GitHub
parent 7835c0651f
commit fd21ffe293
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 289 additions and 222 deletions

View File

@ -155,7 +155,7 @@ def beta(
_name = _name or obj.fget.__qualname__
old_doc = obj.__doc__
class _beta_property(property):
class _BetaProperty(property):
"""A beta property."""
def __init__(self, fget=None, fset=None, fdel=None, doc=None):
@ -186,7 +186,7 @@ def beta(
def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any:
"""Finalize the property."""
return _beta_property(
return _BetaProperty(
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
)

View File

@ -264,7 +264,7 @@ def deprecated(
_name = _name or cast(Union[type, Callable], obj.fget).__qualname__
old_doc = obj.__doc__
class _deprecated_property(property):
class _DeprecatedProperty(property):
"""A deprecated property."""
def __init__(self, fget=None, fset=None, fdel=None, doc=None): # type: ignore[no-untyped-def]
@ -297,7 +297,7 @@ def deprecated(
"""Finalize the property."""
return cast(
T,
_deprecated_property(
_DeprecatedProperty(
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
),
)

View File

@ -3,7 +3,7 @@
from typing import Any, Optional
class LangChainException(Exception):
class LangChainException(Exception): # noqa: N818
"""General LangChain exception."""
@ -11,7 +11,7 @@ class TracerException(LangChainException):
"""Base class for exceptions in tracers module."""
class OutputParserException(ValueError, LangChainException):
class OutputParserException(ValueError, LangChainException): # noqa: N818
"""Exception that output parsers should raise to signify a parsing error.
This exists to differentiate parsing errors from other code or execution errors

View File

@ -14,7 +14,7 @@ from typing import (
)
from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import TypeAlias, TypedDict
from typing_extensions import TypeAlias, TypedDict, override
from langchain_core._api import deprecated
from langchain_core.messages import (
@ -143,6 +143,7 @@ class BaseLanguageModel(
return verbose
@property
@override
def InputType(self) -> TypeAlias:
"""Get the input type for this runnable."""
from langchain_core.prompt_values import (

View File

@ -26,6 +26,7 @@ from pydantic import (
Field,
model_validator,
)
from typing_extensions import override
from langchain_core._api import deprecated
from langchain_core.caches import BaseCache
@ -251,6 +252,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
# --- Runnable methods ---
@property
@override
def OutputType(self) -> Any:
"""Get the output type for this runnable."""
return AnyMessage

View File

@ -31,6 +31,7 @@ from tenacity import (
stop_after_attempt,
wait_exponential,
)
from typing_extensions import override
from langchain_core._api import deprecated
from langchain_core.caches import BaseCache
@ -318,6 +319,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
# --- Runnable methods ---
@property
@override
def OutputType(self) -> type[str]:
"""Get the input type for this runnable."""
return str

View File

@ -10,6 +10,8 @@ from typing import (
Union,
)
from typing_extensions import override
from langchain_core.language_models import LanguageModelOutput
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, Generation
@ -63,11 +65,13 @@ class BaseGenerationOutputParser(
"""Base class to parse the output of an LLM call."""
@property
@override
def InputType(self) -> Any:
"""Return the input type for the parser."""
return Union[str, AnyMessage]
@property
@override
def OutputType(self) -> type[T]:
"""Return the output type for the parser."""
# even though mypy complains this isn't valid,
@ -148,11 +152,13 @@ class BaseOutputParser(
""" # noqa: E501
@property
@override
def InputType(self) -> Any:
"""Return the input type for the parser."""
return Union[str, AnyMessage]
@property
@override
def OutputType(self) -> type[T]:
"""Return the output type for the parser.

View File

@ -3,6 +3,7 @@ from typing import Annotated, Generic, Optional
import pydantic
from pydantic import SkipValidation
from typing_extensions import override
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser
@ -107,6 +108,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
return "pydantic"
@property
@override
def OutputType(self) -> type[TBaseModel]:
"""Return the pydantic model."""
return self.pydantic_object

View File

@ -1,6 +1,6 @@
import re
import xml
import xml.etree.ElementTree as ET
import xml.etree.ElementTree as ET # noqa: N817
from collections.abc import AsyncIterator, Iterator
from typing import Any, Literal, Optional, Union
from xml.etree.ElementTree import TreeBuilder
@ -46,14 +46,14 @@ class _StreamingParser:
"""
if parser == "defusedxml":
try:
from defusedxml import ElementTree as DET # type: ignore
import defusedxml # type: ignore
except ImportError as e:
raise ImportError(
"defusedxml is not installed. "
"Please install it to use the defusedxml parser."
"You can install it with `pip install defusedxml` "
) from e
_parser = DET.DefusedXMLParser(target=TreeBuilder())
_parser = defusedxml.ElementTree.DefusedXMLParser(target=TreeBuilder())
else:
_parser = None
self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
@ -189,7 +189,7 @@ class XMLOutputParser(BaseTransformOutputParser):
# likely if you're reading this you can move them to the top of the file
if self.parser == "defusedxml":
try:
from defusedxml import ElementTree as DET # type: ignore
import defusedxml # type: ignore
except ImportError as e:
raise ImportError(
"defusedxml is not installed. "
@ -197,9 +197,9 @@ class XMLOutputParser(BaseTransformOutputParser):
"You can install it with `pip install defusedxml`"
"See https://github.com/tiran/defusedxml for more details"
) from e
_ET = DET # Use the defusedxml parser
_et = defusedxml.ElementTree # Use the defusedxml parser
else:
_ET = ET # Use the standard library parser
_et = ET # Use the standard library parser
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
if match is not None:

View File

@ -18,7 +18,7 @@ from typing import (
import yaml
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Self
from typing_extensions import Self, override
from langchain_core.load import dumpd
from langchain_core.output_parsers.base import BaseOutputParser
@ -107,6 +107,7 @@ class BasePromptTemplate(
return dumpd(self)
@property
@override
def OutputType(self) -> Any:
"""Return the output type of the prompt."""
return Union[StringPromptValue, ChatPromptValueConcrete]

View File

@ -36,7 +36,7 @@ from typing import (
)
from pydantic import BaseModel, ConfigDict, Field, RootModel
from typing_extensions import Literal, get_args
from typing_extensions import Literal, get_args, override
from langchain_core._api import beta_decorator
from langchain_core.load.serializable import (
@ -272,7 +272,7 @@ class Runnable(Generic[Input, Output], ABC):
return name_
@property
def InputType(self) -> type[Input]:
def InputType(self) -> type[Input]: # noqa: N802
"""The type of input this Runnable accepts specified as a type annotation."""
# First loop through all parent classes and if any of them is
# a pydantic model, we will pick up the generic parameterization
@ -297,7 +297,7 @@ class Runnable(Generic[Input, Output], ABC):
)
@property
def OutputType(self) -> type[Output]:
def OutputType(self) -> type[Output]: # noqa: N802
"""The type of output this Runnable produces specified as a type annotation."""
# First loop through bases -- this will help generic
# any pydantic models.
@ -2811,11 +2811,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
)
@property
@override
def InputType(self) -> type[Input]:
"""The type of the input to the Runnable."""
return self.first.InputType
@property
@override
def OutputType(self) -> type[Output]:
"""The type of the output of the Runnable."""
return self.last.OutputType
@ -3564,6 +3566,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
return super().get_name(suffix, name=name)
@property
@override
def InputType(self) -> Any:
"""The type of the input to the Runnable."""
for step in self.steps__.values():
@ -4057,6 +4060,7 @@ class RunnableGenerator(Runnable[Input, Output]):
self.name = "RunnableGenerator"
@property
@override
def InputType(self) -> Any:
func = getattr(self, "_transform", None) or self._atransform
try:
@ -4097,6 +4101,7 @@ class RunnableGenerator(Runnable[Input, Output]):
)
@property
@override
def OutputType(self) -> Any:
func = getattr(self, "_transform", None) or self._atransform
try:
@ -4346,6 +4351,7 @@ class RunnableLambda(Runnable[Input, Output]):
pass
@property
@override
def InputType(self) -> Any:
"""The type of the input to this Runnable."""
func = getattr(self, "func", None) or self.afunc
@ -4405,6 +4411,7 @@ class RunnableLambda(Runnable[Input, Output]):
return super().get_input_schema(config)
@property
@override
def OutputType(self) -> Any:
"""The type of the output of this Runnable as a type annotation.
@ -4958,6 +4965,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]):
)
@property
@override
def InputType(self) -> Any:
return list[self.bound.InputType] # type: ignore[name-defined]
@ -4981,6 +4989,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]):
)
@property
@override
def OutputType(self) -> type[list[Output]]:
return list[self.bound.OutputType] # type: ignore[name-defined]
@ -5274,6 +5283,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
return self.bound.get_name(suffix, name=name)
@property
@override
def InputType(self) -> type[Input]:
return (
cast(type[Input], self.custom_input_type)
@ -5282,6 +5292,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
)
@property
@override
def OutputType(self) -> type[Output]:
return (
cast(type[Output], self.custom_output_type)

View File

@ -16,6 +16,7 @@ from typing import (
from weakref import WeakValueDictionary
from pydantic import BaseModel, ConfigDict
from typing_extensions import override
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import (
@ -68,10 +69,12 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
return ["langchain", "schema", "runnable"]
@property
@override
def InputType(self) -> type[Input]:
return self.default.InputType
@property
@override
def OutputType(self) -> type[Output]:
return self.default.OutputType

View File

@ -13,6 +13,7 @@ from typing import (
)
from pydantic import BaseModel, ConfigDict
from typing_extensions import override
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import (
@ -106,10 +107,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
)
@property
@override
def InputType(self) -> type[Input]:
return self.runnable.InputType
@property
@override
def OutputType(self) -> type[Output]:
return self.runnable.OutputType

View File

@ -246,27 +246,27 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
# NOTE: coordinates might me negative, so we need to shift
# everything to the positive plane before we actually draw it.
Xs = []
Ys = []
xlist = []
ylist = []
sug = _build_sugiyama_layout(vertices, edges)
for vertex in sug.g.sV:
# NOTE: moving boxes w/2 to the left
Xs.append(vertex.view.xy[0] - vertex.view.w / 2.0)
Xs.append(vertex.view.xy[0] + vertex.view.w / 2.0)
Ys.append(vertex.view.xy[1])
Ys.append(vertex.view.xy[1] + vertex.view.h)
xlist.append(vertex.view.xy[0] - vertex.view.w / 2.0)
xlist.append(vertex.view.xy[0] + vertex.view.w / 2.0)
ylist.append(vertex.view.xy[1])
ylist.append(vertex.view.xy[1] + vertex.view.h)
for edge in sug.g.sE:
for x, y in edge.view._pts:
Xs.append(x)
Ys.append(y)
xlist.append(x)
ylist.append(y)
minx = min(Xs)
miny = min(Ys)
maxx = max(Xs)
maxy = max(Ys)
minx = min(xlist)
miny = min(ylist)
maxx = max(xlist)
maxy = max(ylist)
canvas_cols = int(math.ceil(math.ceil(maxx) - math.floor(minx))) + 1
canvas_lines = int(round(maxy - miny))

View File

@ -12,6 +12,7 @@ from typing import (
)
from pydantic import BaseModel
from typing_extensions import override
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.load.load import load
@ -396,6 +397,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
)
@property
@override
def OutputType(self) -> type[Output]:
output_type = self._history_chain.OutputType
return output_type

View File

@ -16,6 +16,7 @@ from typing import (
)
from pydantic import BaseModel, RootModel
from typing_extensions import override
from langchain_core.runnables.base import (
Other,
@ -193,10 +194,12 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
return ["langchain", "schema", "runnable"]
@property
@override
def InputType(self) -> Any:
return self.input_type or Any
@property
@override
def OutputType(self) -> Any:
return self.input_type or Any

View File

@ -28,7 +28,7 @@ from typing import (
Union,
)
from typing_extensions import TypeGuard
from typing_extensions import TypeGuard, override
from langchain_core.runnables.schema import StreamEvent
@ -135,6 +135,7 @@ class IsLocalDict(ast.NodeVisitor):
self.name = name
self.keys = keys
@override
def visit_Subscript(self, node: ast.Subscript) -> Any:
"""Visit a subscript node.
@ -154,6 +155,7 @@ class IsLocalDict(ast.NodeVisitor):
# we've found a subscript access on the name we're looking for
self.keys.add(node.slice.value)
@override
def visit_Call(self, node: ast.Call) -> Any:
"""Visit a call node.
@ -182,6 +184,7 @@ class IsFunctionArgDict(ast.NodeVisitor):
def __init__(self) -> None:
self.keys: set[str] = set()
@override
def visit_Lambda(self, node: ast.Lambda) -> Any:
"""Visit a lambda function.
@ -196,6 +199,7 @@ class IsFunctionArgDict(ast.NodeVisitor):
input_arg_name = node.args.args[0].arg
IsLocalDict(input_arg_name, self.keys).visit(node.body)
@override
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
"""Visit a function definition.
@ -210,6 +214,7 @@ class IsFunctionArgDict(ast.NodeVisitor):
input_arg_name = node.args.args[0].arg
IsLocalDict(input_arg_name, self.keys).visit(node)
@override
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
"""Visit an async function definition.
@ -232,6 +237,7 @@ class NonLocals(ast.NodeVisitor):
self.loads: set[str] = set()
self.stores: set[str] = set()
@override
def visit_Name(self, node: ast.Name) -> Any:
"""Visit a name node.
@ -246,6 +252,7 @@ class NonLocals(ast.NodeVisitor):
elif isinstance(node.ctx, ast.Store):
self.stores.add(node.id)
@override
def visit_Attribute(self, node: ast.Attribute) -> Any:
"""Visit an attribute node.
@ -272,6 +279,7 @@ class FunctionNonLocals(ast.NodeVisitor):
def __init__(self) -> None:
self.nonlocals: set[str] = set()
@override
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
"""Visit a function definition.
@ -285,6 +293,7 @@ class FunctionNonLocals(ast.NodeVisitor):
visitor.visit(node)
self.nonlocals.update(visitor.loads - visitor.stores)
@override
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
"""Visit an async function definition.
@ -298,6 +307,7 @@ class FunctionNonLocals(ast.NodeVisitor):
visitor.visit(node)
self.nonlocals.update(visitor.loads - visitor.stores)
@override
def visit_Lambda(self, node: ast.Lambda) -> Any:
"""Visit a lambda function.
@ -320,6 +330,7 @@ class GetLambdaSource(ast.NodeVisitor):
self.source: Optional[str] = None
self.count = 0
@override
def visit_Lambda(self, node: ast.Lambda) -> Any:
"""Visit a lambda function.

View File

@ -311,7 +311,7 @@ def create_schema_from_function(
)
class ToolException(Exception):
class ToolException(Exception): # noqa: N818
"""Optional exception that tool throws when execution error occurs.
When this exception is thrown, the agent will not stop working,

View File

@ -9,7 +9,7 @@ def get_headers(*args: Any, **kwargs: Any) -> Any:
)
def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any:
def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: # noqa: N802
"""Throw an error because this has been replaced by LangChainTracer."""
raise RuntimeError(
"LangChainTracerV1 is no longer supported. Please use LangChainTracer instead."

View File

@ -18,7 +18,7 @@ from langchain_core._api import deprecated
@deprecated("0.1.0", alternative="Use string instead.", removal="1.0")
def RunTypeEnum() -> type[RunTypeEnumDep]:
def RunTypeEnum() -> type[RunTypeEnumDep]: # noqa: N802
"""RunTypeEnum."""
warnings.warn(
"RunTypeEnum is deprecated. Please directly use a string instead"

View File

@ -235,7 +235,7 @@ class Tee(Generic[T]):
atee = Tee
class aclosing(AbstractAsyncContextManager):
class aclosing(AbstractAsyncContextManager): # noqa: N801
"""Async context manager for safely finalizing an asynchronously cleaned-up
resource such as an async generator, calling its ``aclose()`` method.

View File

@ -17,12 +17,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray:
"""Row-wise cosine similarity between two equal-width matrices.
Args:
X: A matrix of shape (n, m).
Y: A matrix of shape (k, m).
x: A matrix of shape (n, m).
y: A matrix of shape (k, m).
Returns:
A matrix of shape (n, k) where each element (i, j) is the cosine similarity
@ -40,33 +40,33 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
"Please install numpy with `pip install numpy`."
) from e
if len(X) == 0 or len(Y) == 0:
if len(x) == 0 or len(y) == 0:
return np.array([])
X = np.array(X)
Y = np.array(Y)
if X.shape[1] != Y.shape[1]:
x = np.array(x)
y = np.array(y)
if x.shape[1] != y.shape[1]:
raise ValueError(
f"Number of columns in X and Y must be the same. X has shape {X.shape} "
f"and Y has shape {Y.shape}."
f"Number of columns in X and Y must be the same. X has shape {x.shape} "
f"and Y has shape {y.shape}."
)
try:
import simsimd as simd # type: ignore[import-not-found]
X = np.array(X, dtype=np.float32)
Y = np.array(Y, dtype=np.float32)
Z = 1 - np.array(simd.cdist(X, Y, metric="cosine"))
return Z
x = np.array(x, dtype=np.float32)
y = np.array(y, dtype=np.float32)
z = 1 - np.array(simd.cdist(x, y, metric="cosine"))
return z
except ImportError:
logger.debug(
"Unable to import simsimd, defaulting to NumPy implementation. If you want "
"to use simsimd please install with `pip install simsimd`."
)
X_norm = np.linalg.norm(X, axis=1)
Y_norm = np.linalg.norm(Y, axis=1)
x_norm = np.linalg.norm(x, axis=1)
y_norm = np.linalg.norm(y, axis=1)
# Ignore divide by zero errors run time warnings as those are handled below.
with np.errstate(divide="ignore", invalid="ignore"):
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
similarity = np.dot(x, y.T) / np.outer(x_norm, y_norm)
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
return similarity

View File

@ -44,9 +44,17 @@ python = ">=3.12.4"
target-version = "py39"
[tool.ruff.lint]
select = ["B", "E", "F", "I", "T201", "UP"]
select = ["B", "E", "F", "I", "N", "T201", "UP"]
ignore = ["UP007"]
[tool.ruff.lint.pep8-naming]
classmethod-decorators = [
"classmethod",
"langchain_core.utils.pydantic.pre_init",
"pydantic.field_validator",
"pydantic.v1.root_validator",
]
[tool.coverage.run]
omit = [ "tests/*",]

View File

@ -9,9 +9,9 @@ from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatM
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.stubs import (
_AnyIdAIMessage,
_AnyIdAIMessageChunk,
_AnyIdHumanMessage,
_any_id_ai_message,
_any_id_ai_message_chunk,
_any_id_human_message,
)
@ -20,11 +20,11 @@ def test_generic_fake_chat_model_invoke() -> None:
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = model.invoke("meow")
assert response == _AnyIdAIMessage(content="hello")
assert response == _any_id_ai_message(content="hello")
response = model.invoke("kitty")
assert response == _AnyIdAIMessage(content="goodbye")
assert response == _any_id_ai_message(content="goodbye")
response = model.invoke("meow")
assert response == _AnyIdAIMessage(content="hello")
assert response == _any_id_ai_message(content="hello")
async def test_generic_fake_chat_model_ainvoke() -> None:
@ -32,11 +32,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = await model.ainvoke("meow")
assert response == _AnyIdAIMessage(content="hello")
assert response == _any_id_ai_message(content="hello")
response = await model.ainvoke("kitty")
assert response == _AnyIdAIMessage(content="goodbye")
assert response == _any_id_ai_message(content="goodbye")
response = await model.ainvoke("meow")
assert response == _AnyIdAIMessage(content="hello")
assert response == _any_id_ai_message(content="hello")
async def test_generic_fake_chat_model_stream() -> None:
@ -49,17 +49,17 @@ async def test_generic_fake_chat_model_stream() -> None:
model = GenericFakeChatModel(messages=infinite_cycle)
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
_AnyIdAIMessageChunk(content="hello"),
_AnyIdAIMessageChunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"),
_any_id_ai_message_chunk(content="hello"),
_any_id_ai_message_chunk(content=" "),
_any_id_ai_message_chunk(content="goodbye"),
]
assert len({chunk.id for chunk in chunks}) == 1
chunks = [chunk for chunk in model.stream("meow")]
assert chunks == [
_AnyIdAIMessageChunk(content="hello"),
_AnyIdAIMessageChunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"),
_any_id_ai_message_chunk(content="hello"),
_any_id_ai_message_chunk(content=" "),
_any_id_ai_message_chunk(content="goodbye"),
]
assert len({chunk.id for chunk in chunks}) == 1
@ -69,8 +69,8 @@ async def test_generic_fake_chat_model_stream() -> None:
model = GenericFakeChatModel(messages=cycle([message]))
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
_AnyIdAIMessageChunk(content="", additional_kwargs={"foo": 42}),
_AnyIdAIMessageChunk(content="", additional_kwargs={"bar": 24}),
_any_id_ai_message_chunk(content="", additional_kwargs={"foo": 42}),
_any_id_ai_message_chunk(content="", additional_kwargs={"bar": 24}),
]
assert len({chunk.id for chunk in chunks}) == 1
@ -88,19 +88,19 @@ async def test_generic_fake_chat_model_stream() -> None:
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
_AnyIdAIMessageChunk(
_any_id_ai_message_chunk(
content="", additional_kwargs={"function_call": {"name": "move_file"}}
),
_AnyIdAIMessageChunk(
_any_id_ai_message_chunk(
content="",
additional_kwargs={
"function_call": {"arguments": '{\n "source_path": "foo"'},
},
),
_AnyIdAIMessageChunk(
_any_id_ai_message_chunk(
content="", additional_kwargs={"function_call": {"arguments": ","}}
),
_AnyIdAIMessageChunk(
_any_id_ai_message_chunk(
content="",
additional_kwargs={
"function_call": {"arguments": '\n "destination_path": "bar"\n}'},
@ -138,9 +138,9 @@ async def test_generic_fake_chat_model_astream_log() -> None:
]
final = log_patches[-1]
assert final.state["streamed_output"] == [
_AnyIdAIMessageChunk(content="hello"),
_AnyIdAIMessageChunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"),
_any_id_ai_message_chunk(content="hello"),
_any_id_ai_message_chunk(content=" "),
_any_id_ai_message_chunk(content="goodbye"),
]
assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1
@ -189,9 +189,9 @@ async def test_callback_handlers() -> None:
# New model
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
assert results == [
_AnyIdAIMessageChunk(content="hello"),
_AnyIdAIMessageChunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"),
_any_id_ai_message_chunk(content="hello"),
_any_id_ai_message_chunk(content=" "),
_any_id_ai_message_chunk(content="goodbye"),
]
assert tokens == ["hello", " ", "goodbye"]
assert len({chunk.id for chunk in results}) == 1
@ -200,6 +200,8 @@ async def test_callback_handlers() -> None:
def test_chat_model_inputs() -> None:
fake = ParrotFakeChatModel()
assert fake.invoke("hello") == _AnyIdHumanMessage(content="hello")
assert fake.invoke([("ai", "blah")]) == _AnyIdAIMessage(content="blah")
assert fake.invoke([AIMessage(content="blah")]) == _AnyIdAIMessage(content="blah")
assert fake.invoke("hello") == _any_id_human_message(content="hello")
assert fake.invoke([("ai", "blah")]) == _any_id_ai_message(content="blah")
assert fake.invoke([AIMessage(content="blah")]) == _any_id_ai_message(
content="blah"
)

View File

@ -27,7 +27,7 @@ from tests.unit_tests.fake.callbacks import (
FakeAsyncCallbackHandler,
FakeCallbackHandler,
)
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk
@pytest.fixture
@ -147,10 +147,10 @@ async def test_astream_fallback_to_ainvoke() -> None:
model = ModelWithGenerate()
chunks = [chunk for chunk in model.stream("anything")]
assert chunks == [_AnyIdAIMessage(content="hello")]
assert chunks == [_any_id_ai_message(content="hello")]
chunks = [chunk async for chunk in model.astream("anything")]
assert chunks == [_AnyIdAIMessage(content="hello")]
assert chunks == [_any_id_ai_message(content="hello")]
async def test_astream_implementation_fallback_to_stream() -> None:
@ -185,15 +185,15 @@ async def test_astream_implementation_fallback_to_stream() -> None:
model = ModelWithSyncStream()
chunks = [chunk for chunk in model.stream("anything")]
assert chunks == [
_AnyIdAIMessageChunk(content="a"),
_AnyIdAIMessageChunk(content="b"),
_any_id_ai_message_chunk(content="a"),
_any_id_ai_message_chunk(content="b"),
]
assert len({chunk.id for chunk in chunks}) == 1
assert type(model)._astream == BaseChatModel._astream
astream_chunks = [chunk async for chunk in model.astream("anything")]
assert astream_chunks == [
_AnyIdAIMessageChunk(content="a"),
_AnyIdAIMessageChunk(content="b"),
_any_id_ai_message_chunk(content="a"),
_any_id_ai_message_chunk(content="b"),
]
assert len({chunk.id for chunk in astream_chunks}) == 1
@ -230,8 +230,8 @@ async def test_astream_implementation_uses_astream() -> None:
model = ModelWithAsyncStream()
chunks = [chunk async for chunk in model.astream("anything")]
assert chunks == [
_AnyIdAIMessageChunk(content="a"),
_AnyIdAIMessageChunk(content="b"),
_any_id_ai_message_chunk(content="a"),
_any_id_ai_message_chunk(content="b"),
]
assert len({chunk.id for chunk in chunks}) == 1

View File

@ -25,7 +25,7 @@ def change_directory(dir: Path) -> Iterator:
os.chdir(origin)
def test_loading_from_YAML() -> None:
def test_loading_from_yaml() -> None:
"""Test loading from yaml file."""
prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.yaml")
expected_prompt = PromptTemplate(
@ -36,7 +36,7 @@ def test_loading_from_YAML() -> None:
assert prompt == expected_prompt
def test_loading_from_JSON() -> None:
def test_loading_from_json() -> None:
"""Test loading from json file."""
prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.json")
expected_prompt = PromptTemplate(
@ -46,14 +46,14 @@ def test_loading_from_JSON() -> None:
assert prompt == expected_prompt
def test_loading_jinja_from_JSON() -> None:
def test_loading_jinja_from_json() -> None:
"""Test that loading jinja2 format prompts from JSON raises ValueError."""
prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.json"
with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"):
load_prompt(prompt_path)
def test_loading_jinja_from_YAML() -> None:
def test_loading_jinja_from_yaml() -> None:
"""Test that loading jinja2 format prompts from YAML raises ValueError."""
prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.yaml"
with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"):

View File

@ -2,6 +2,7 @@ from typing import Optional
from pydantic import BaseModel
from syrupy import SnapshotAssertion
from typing_extensions import override
from langchain_core.language_models import FakeListLLM
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
@ -353,9 +354,11 @@ def test_runnable_get_graph_with_invalid_input_type() -> None:
class InvalidInputTypeRunnable(Runnable[int, int]):
@property
@override
def InputType(self) -> type:
raise TypeError()
@override
def invoke(
self,
input: int,
@ -375,9 +378,11 @@ def test_runnable_get_graph_with_invalid_output_type() -> None:
class InvalidOutputTypeRunnable(Runnable[int, int]):
@property
@override
def OutputType(self) -> type:
raise TypeError()
@override
def invoke(
self,
input: int,

View File

@ -493,7 +493,7 @@ def test_get_output_schema() -> None:
def test_get_input_schema_input_messages() -> None:
from pydantic import RootModel
RunnableWithMessageHistoryInput = RootModel[Sequence[BaseMessage]]
runnable_with_message_history_input = RootModel[Sequence[BaseMessage]]
runnable = RunnableLambda(
lambda messages: {
@ -515,7 +515,7 @@ def test_get_input_schema_input_messages() -> None:
with_history = RunnableWithMessageHistory(
runnable, get_session_history, output_messages_key="output"
)
expected_schema = _schema(RunnableWithMessageHistoryInput)
expected_schema = _schema(runnable_with_message_history_input)
expected_schema["title"] = "RunnableWithChatHistoryInput"
assert _schema(with_history.get_input_schema()) == expected_schema

View File

@ -87,7 +87,7 @@ from langchain_core.tracers import (
)
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk
from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk
PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split(".")))
@ -1699,7 +1699,7 @@ def test_prompt_with_chat_model(
tracer = FakeTracer()
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == _AnyIdAIMessage(content="foo")
) == _any_id_ai_message(content="foo")
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
@ -1724,8 +1724,8 @@ def test_prompt_with_chat_model(
],
dict(callbacks=[tracer]),
) == [
_AnyIdAIMessage(content="foo"),
_AnyIdAIMessage(content="foo"),
_any_id_ai_message(content="foo"),
_any_id_ai_message(content="foo"),
]
assert prompt_spy.call_args.args[1] == [
{"question": "What is your name?"},
@ -1765,9 +1765,9 @@ def test_prompt_with_chat_model(
assert [
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
] == [
_AnyIdAIMessageChunk(content="f"),
_AnyIdAIMessageChunk(content="o"),
_AnyIdAIMessageChunk(content="o"),
_any_id_ai_message_chunk(content="f"),
_any_id_ai_message_chunk(content="o"),
_any_id_ai_message_chunk(content="o"),
]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
@ -1805,7 +1805,7 @@ async def test_prompt_with_chat_model_async(
tracer = FakeTracer()
assert await chain.ainvoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == _AnyIdAIMessage(content="foo")
) == _any_id_ai_message(content="foo")
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
@ -1830,8 +1830,8 @@ async def test_prompt_with_chat_model_async(
],
dict(callbacks=[tracer]),
) == [
_AnyIdAIMessage(content="foo"),
_AnyIdAIMessage(content="foo"),
_any_id_ai_message(content="foo"),
_any_id_ai_message(content="foo"),
]
assert prompt_spy.call_args.args[1] == [
{"question": "What is your name?"},
@ -1874,9 +1874,9 @@ async def test_prompt_with_chat_model_async(
{"question": "What is your name?"}, dict(callbacks=[tracer])
)
] == [
_AnyIdAIMessageChunk(content="f"),
_AnyIdAIMessageChunk(content="o"),
_AnyIdAIMessageChunk(content="o"),
_any_id_ai_message_chunk(content="f"),
_any_id_ai_message_chunk(content="o"),
_any_id_ai_message_chunk(content="o"),
]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
@ -2548,7 +2548,7 @@ def test_prompt_with_chat_model_and_parser(
HumanMessage(content="What is your name?"),
]
)
assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar")
assert parser_spy.call_args.args[1] == _any_id_ai_message(content="foo, bar")
assert tracer.runs == snapshot
@ -2681,7 +2681,7 @@ Question:
),
]
)
assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar")
assert parser_spy.call_args.args[1] == _any_id_ai_message(content="foo, bar")
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 4
@ -2727,7 +2727,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == {
"chat": _AnyIdAIMessage(content="i'm a chatbot"),
"chat": _any_id_ai_message(content="i'm a chatbot"),
"llm": "i'm a textbot",
}
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
@ -2936,7 +2936,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == {
"chat": _AnyIdAIMessage(content="i'm a chatbot"),
"chat": _any_id_ai_message(content="i'm a chatbot"),
"llm": "i'm a textbot",
"passthrough": ChatPromptValue(
messages=[
@ -3000,7 +3000,7 @@ def test_map_stream() -> None:
assert streamed_chunks[0] in [
{"passthrough": prompt.invoke({"question": "What is your name?"})},
{"llm": "i"},
{"chat": _AnyIdAIMessageChunk(content="i")},
{"chat": _any_id_ai_message_chunk(content="i")},
]
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
assert all(len(c.keys()) == 1 for c in streamed_chunks)
@ -3059,11 +3059,11 @@ def test_map_stream() -> None:
assert streamed_chunks[0] in [
{"llm": "i"},
{"chat": _AnyIdAIMessageChunk(content="i")},
{"chat": _any_id_ai_message_chunk(content="i")},
]
if not ( # TODO(Rewrite properly) statement above
streamed_chunks[0] == {"llm": "i"}
or {"chat": _AnyIdAIMessageChunk(content="i")}
or {"chat": _any_id_ai_message_chunk(content="i")}
):
raise AssertionError(f"Got an unexpected chunk: {streamed_chunks[0]}")
@ -3108,7 +3108,7 @@ def test_map_stream_iterator_input() -> None:
assert streamed_chunks[0] in [
{"passthrough": "i"},
{"llm": "i"},
{"chat": _AnyIdAIMessageChunk(content="i")},
{"chat": _any_id_ai_message_chunk(content="i")},
]
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
assert all(len(c.keys()) == 1 for c in streamed_chunks)
@ -3152,7 +3152,7 @@ async def test_map_astream() -> None:
assert streamed_chunks[0] in [
{"passthrough": prompt.invoke({"question": "What is your name?"})},
{"llm": "i"},
{"chat": _AnyIdAIMessageChunk(content="i")},
{"chat": _any_id_ai_message_chunk(content="i")},
]
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
assert all(len(c.keys()) == 1 for c in streamed_chunks)

View File

@ -32,7 +32,7 @@ from langchain_core.runnables import (
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import tool
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]:
@ -503,7 +503,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
"data": {"chunk": _any_id_ai_message_chunk(content="hello")},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
@ -512,7 +512,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
"data": {"chunk": _any_id_ai_message_chunk(content=" ")},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
@ -521,7 +521,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
"data": {"chunk": _any_id_ai_message_chunk(content="world!")},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
@ -530,7 +530,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"output": _AnyIdAIMessageChunk(content="hello world!")},
"data": {"output": _any_id_ai_message_chunk(content="hello world!")},
"event": "on_chat_model_end",
"metadata": {"a": "b"},
"name": "my_model",
@ -575,7 +575,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
"data": {"chunk": _any_id_ai_message_chunk(content="hello")},
"event": "on_chat_model_stream",
"metadata": {
"a": "b",
@ -588,7 +588,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
"data": {"chunk": _any_id_ai_message_chunk(content=" ")},
"event": "on_chat_model_stream",
"metadata": {
"a": "b",
@ -601,7 +601,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
"data": {"chunk": _any_id_ai_message_chunk(content="world!")},
"event": "on_chat_model_stream",
"metadata": {
"a": "b",
@ -621,7 +621,9 @@ async def test_astream_events_from_model() -> None:
[
{
"generation_info": None,
"message": _AnyIdAIMessage(content="hello world!"),
"message": _any_id_ai_message(
content="hello world!"
),
"text": "hello world!",
"type": "ChatGeneration",
}
@ -644,7 +646,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
"data": {"chunk": _any_id_ai_message(content="hello world!")},
"event": "on_chain_stream",
"metadata": {},
"name": "i_dont_stream",
@ -653,7 +655,7 @@ async def test_astream_events_from_model() -> None:
"tags": [],
},
{
"data": {"output": _AnyIdAIMessage(content="hello world!")},
"data": {"output": _any_id_ai_message(content="hello world!")},
"event": "on_chain_end",
"metadata": {},
"name": "i_dont_stream",
@ -698,7 +700,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
"data": {"chunk": _any_id_ai_message_chunk(content="hello")},
"event": "on_chat_model_stream",
"metadata": {
"a": "b",
@ -711,7 +713,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
"data": {"chunk": _any_id_ai_message_chunk(content=" ")},
"event": "on_chat_model_stream",
"metadata": {
"a": "b",
@ -724,7 +726,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
"data": {"chunk": _any_id_ai_message_chunk(content="world!")},
"event": "on_chat_model_stream",
"metadata": {
"a": "b",
@ -744,7 +746,9 @@ async def test_astream_events_from_model() -> None:
[
{
"generation_info": None,
"message": _AnyIdAIMessage(content="hello world!"),
"message": _any_id_ai_message(
content="hello world!"
),
"text": "hello world!",
"type": "ChatGeneration",
}
@ -767,7 +771,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
"data": {"chunk": _any_id_ai_message(content="hello world!")},
"event": "on_chain_stream",
"metadata": {},
"name": "ai_dont_stream",
@ -776,7 +780,7 @@ async def test_astream_events_from_model() -> None:
"tags": [],
},
{
"data": {"output": _AnyIdAIMessage(content="hello world!")},
"data": {"output": _any_id_ai_message(content="hello world!")},
"event": "on_chain_end",
"metadata": {},
"name": "ai_dont_stream",

View File

@ -47,7 +47,7 @@ from langchain_core.utils.aiter import aclosing
from tests.unit_tests.runnables.test_runnable_events_v1 import (
_assert_events_equal_allow_superset_metadata,
)
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]:
@ -533,7 +533,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
"data": {"chunk": _any_id_ai_message_chunk(content="hello")},
"event": "on_chat_model_stream",
"metadata": {
"a": "b",
@ -546,7 +546,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
"data": {"chunk": _any_id_ai_message_chunk(content=" ")},
"event": "on_chat_model_stream",
"metadata": {
"a": "b",
@ -559,7 +559,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
"data": {"chunk": _any_id_ai_message_chunk(content="world!")},
"event": "on_chat_model_stream",
"metadata": {
"a": "b",
@ -573,7 +573,7 @@ async def test_astream_events_from_model() -> None:
},
{
"data": {
"output": _AnyIdAIMessageChunk(content="hello world!"),
"output": _any_id_ai_message_chunk(content="hello world!"),
},
"event": "on_chat_model_end",
"metadata": {
@ -640,7 +640,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
"data": {"chunk": _any_id_ai_message_chunk(content="hello")},
"event": "on_chat_model_stream",
"metadata": {
"a": "b",
@ -653,7 +653,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
"data": {"chunk": _any_id_ai_message_chunk(content=" ")},
"event": "on_chat_model_stream",
"metadata": {
"a": "b",
@ -666,7 +666,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
"data": {"chunk": _any_id_ai_message_chunk(content="world!")},
"event": "on_chat_model_stream",
"metadata": {
"a": "b",
@ -681,7 +681,7 @@ async def test_astream_with_model_in_chain() -> None:
{
"data": {
"input": {"messages": [[HumanMessage(content="hello")]]},
"output": _AnyIdAIMessage(content="hello world!"),
"output": _any_id_ai_message(content="hello world!"),
},
"event": "on_chat_model_end",
"metadata": {
@ -695,7 +695,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
"data": {"chunk": _any_id_ai_message(content="hello world!")},
"event": "on_chain_stream",
"metadata": {},
"name": "i_dont_stream",
@ -704,7 +704,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": [],
},
{
"data": {"output": _AnyIdAIMessage(content="hello world!")},
"data": {"output": _any_id_ai_message(content="hello world!")},
"event": "on_chain_end",
"metadata": {},
"name": "i_dont_stream",
@ -749,7 +749,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
"data": {"chunk": _any_id_ai_message_chunk(content="hello")},
"event": "on_chat_model_stream",
"metadata": {
"a": "b",
@ -762,7 +762,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
"data": {"chunk": _any_id_ai_message_chunk(content=" ")},
"event": "on_chat_model_stream",
"metadata": {
"a": "b",
@ -775,7 +775,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
"data": {"chunk": _any_id_ai_message_chunk(content="world!")},
"event": "on_chat_model_stream",
"metadata": {
"a": "b",
@ -790,7 +790,7 @@ async def test_astream_with_model_in_chain() -> None:
{
"data": {
"input": {"messages": [[HumanMessage(content="hello")]]},
"output": _AnyIdAIMessage(content="hello world!"),
"output": _any_id_ai_message(content="hello world!"),
},
"event": "on_chat_model_end",
"metadata": {
@ -804,7 +804,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
"data": {"chunk": _any_id_ai_message(content="hello world!")},
"event": "on_chain_stream",
"metadata": {},
"name": "ai_dont_stream",
@ -813,7 +813,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": [],
},
{
"data": {"output": _AnyIdAIMessage(content="hello world!")},
"data": {"output": _any_id_ai_message(content="hello world!")},
"event": "on_chain_end",
"metadata": {},
"name": "ai_dont_stream",

View File

@ -16,28 +16,28 @@ class AnyStr(str):
# subclassed strings.
def _AnyIdDocument(**kwargs: Any) -> Document:
def _any_id_document(**kwargs: Any) -> Document:
"""Create a document with an id field."""
message = Document(**kwargs)
message.id = AnyStr()
return message
def _AnyIdAIMessage(**kwargs: Any) -> AIMessage:
def _any_id_ai_message(**kwargs: Any) -> AIMessage:
"""Create ai message with an any id field."""
message = AIMessage(**kwargs)
message.id = AnyStr()
return message
def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk:
def _any_id_ai_message_chunk(**kwargs: Any) -> AIMessageChunk:
"""Create ai message with an any id field."""
message = AIMessageChunk(**kwargs)
message.id = AnyStr()
return message
def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
def _any_id_human_message(**kwargs: Any) -> HumanMessage:
"""Create a human with an any id field."""
message = HumanMessage(**kwargs)
message.id = AnyStr()

View File

@ -781,7 +781,7 @@ def test_convert_to_messages() -> None:
@pytest.mark.parametrize(
"MessageClass",
"message_class",
[
AIMessage,
AIMessageChunk,
@ -790,39 +790,39 @@ def test_convert_to_messages() -> None:
SystemMessage,
],
)
def test_message_name(MessageClass: type) -> None:
msg = MessageClass(content="foo", name="bar")
def test_message_name(message_class: type) -> None:
msg = message_class(content="foo", name="bar")
assert msg.name == "bar"
msg2 = MessageClass(content="foo", name=None)
msg2 = message_class(content="foo", name=None)
assert msg2.name is None
msg3 = MessageClass(content="foo")
msg3 = message_class(content="foo")
assert msg3.name is None
@pytest.mark.parametrize(
"MessageClass",
"message_class",
[FunctionMessage, FunctionMessageChunk],
)
def test_message_name_function(MessageClass: type) -> None:
def test_message_name_function(message_class: type) -> None:
# functionmessage doesn't support name=None
msg = MessageClass(name="foo", content="bar")
msg = message_class(name="foo", content="bar")
assert msg.name == "foo"
@pytest.mark.parametrize(
"MessageClass",
"message_class",
[ChatMessage, ChatMessageChunk],
)
def test_message_name_chat(MessageClass: type) -> None:
msg = MessageClass(content="foo", role="user", name="bar")
def test_message_name_chat(message_class: type) -> None:
msg = message_class(content="foo", role="user", name="bar")
assert msg.name == "bar"
msg2 = MessageClass(content="foo", role="user", name=None)
msg2 = message_class(content="foo", role="user", name=None)
assert msg2.name is None
msg3 = MessageClass(content="foo", role="user")
msg3 = message_class(content="foo", role="user")
assert msg3.name is None

View File

@ -51,14 +51,14 @@ def test_serde_any_message() -> None:
),
]
Model = RootModel[AnyMessage]
model = RootModel[AnyMessage]
for lc_object in lc_objects:
d = lc_object.model_dump()
assert "type" in d, f"Missing key `type` for {type(lc_object)}"
obj1 = Model.model_validate(d)
obj1 = model.model_validate(d)
assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}"
with pytest.raises((TypeError, ValidationError)):
# Make sure that specifically validation error is raised
Model.model_validate({})
model.model_validate({})

View File

@ -1421,7 +1421,7 @@ class InjectedTool(BaseTool):
return y
class fooSchema(BaseModel):
class fooSchema(BaseModel): # noqa: N801
"""foo."""
x: int = Field(..., description="abc")
@ -1568,14 +1568,14 @@ def test_tool_injected_arg() -> None:
def test_tool_inherited_injected_arg() -> None:
class barSchema(BaseModel):
class BarSchema(BaseModel):
"""bar."""
y: Annotated[str, "foobar comment", InjectedToolArg()] = Field(
..., description="123"
)
class fooSchema(barSchema):
class FooSchema(BarSchema):
"""foo."""
x: int = Field(..., description="abc")
@ -1583,14 +1583,14 @@ def test_tool_inherited_injected_arg() -> None:
class InheritedInjectedArgTool(BaseTool):
name: str = "foo"
description: str = "foo."
args_schema: type[BaseModel] = fooSchema
args_schema: type[BaseModel] = FooSchema
def _run(self, x: int, y: str) -> Any:
return y
tool_ = InheritedInjectedArgTool()
assert tool_.get_input_schema().model_json_schema() == {
"title": "fooSchema", # Matches the title from the provided schema
"title": "FooSchema", # Matches the title from the provided schema
"description": "foo.",
"type": "object",
"properties": {
@ -1877,15 +1877,15 @@ def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None:
A = TypeVar("A")
if use_v1_namespace:
from pydantic.v1 import BaseModel as BM1
from pydantic.v1 import BaseModel as BaseModel1
class ModelA(BM1, Generic[A], extra="allow"):
class ModelA(BaseModel1, Generic[A], extra="allow"):
a: A
else:
from pydantic import BaseModel as BM2
from pydantic import BaseModel as BaseModel2
from pydantic import ConfigDict
class ModelA(BM2, Generic[A]): # type: ignore[no-redef]
class ModelA(BaseModel2, Generic[A]): # type: ignore[no-redef]
a: A
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
@ -2081,13 +2081,13 @@ def test_structured_tool_direct_init() -> None:
def foo(bar: str) -> str:
return bar
async def asyncFoo(bar: str) -> str:
async def async_foo(bar: str) -> str:
return bar
class fooSchema(BaseModel):
class FooSchema(BaseModel):
bar: str = Field(..., description="The bar")
tool = StructuredTool(name="foo", args_schema=fooSchema, coroutine=asyncFoo)
tool = StructuredTool(name="foo", args_schema=FooSchema, coroutine=async_foo)
with pytest.raises(NotImplementedError):
assert tool.invoke("hello") == "hello"

View File

@ -38,7 +38,7 @@ from langchain_core.utils.function_calling import (
@pytest.fixture()
def pydantic() -> type[BaseModel]:
class dummy_function(BaseModel):
class dummy_function(BaseModel): # noqa: N801
"""dummy function"""
arg1: int = Field(..., description="foo")
@ -48,7 +48,7 @@ def pydantic() -> type[BaseModel]:
@pytest.fixture()
def Annotated_function() -> Callable:
def annotated_function() -> Callable:
def dummy_function(
arg1: ExtensionsAnnotated[int, "foo"],
arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"],
@ -118,7 +118,7 @@ def dummy_structured_tool() -> StructuredTool:
@pytest.fixture()
def dummy_pydantic() -> type[BaseModel]:
class dummy_function(BaseModel):
class dummy_function(BaseModel): # noqa: N801
"""dummy function"""
arg1: int = Field(..., description="foo")
@ -129,7 +129,7 @@ def dummy_pydantic() -> type[BaseModel]:
@pytest.fixture()
def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
class dummy_function(BaseModelV2Maybe):
class dummy_function(BaseModelV2Maybe): # noqa: N801
"""dummy function"""
arg1: int = FieldV2Maybe(..., description="foo")
@ -142,7 +142,7 @@ def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
@pytest.fixture()
def dummy_typing_typed_dict() -> type:
class dummy_function(TypingTypedDict):
class dummy_function(TypingTypedDict): # noqa: N801
"""dummy function"""
arg1: TypingAnnotated[int, ..., "foo"] # noqa: F821
@ -153,7 +153,7 @@ def dummy_typing_typed_dict() -> type:
@pytest.fixture()
def dummy_typing_typed_dict_docstring() -> type:
class dummy_function(TypingTypedDict):
class dummy_function(TypingTypedDict): # noqa: N801
"""dummy function
Args:
@ -169,7 +169,7 @@ def dummy_typing_typed_dict_docstring() -> type:
@pytest.fixture()
def dummy_extensions_typed_dict() -> type:
class dummy_function(ExtensionsTypedDict):
class dummy_function(ExtensionsTypedDict): # noqa: N801
"""dummy function"""
arg1: ExtensionsAnnotated[int, ..., "foo"]
@ -180,7 +180,7 @@ def dummy_extensions_typed_dict() -> type:
@pytest.fixture()
def dummy_extensions_typed_dict_docstring() -> type:
class dummy_function(ExtensionsTypedDict):
class dummy_function(ExtensionsTypedDict): # noqa: N801
"""dummy function
Args:
@ -241,7 +241,7 @@ def test_convert_to_openai_function(
dummy_structured_tool: StructuredTool,
dummy_tool: BaseTool,
json_schema: dict,
Annotated_function: Callable,
annotated_function: Callable,
dummy_pydantic: type[BaseModel],
runnable: Runnable,
dummy_typing_typed_dict: type,
@ -275,7 +275,7 @@ def test_convert_to_openai_function(
expected,
Dummy.dummy_function,
DummyWithClassMethod.dummy_function,
Annotated_function,
annotated_function,
dummy_pydantic,
dummy_typing_typed_dict,
dummy_typing_typed_dict_docstring,
@ -523,20 +523,20 @@ def test__convert_typed_dict_to_openai_function(
use_extension_typed_dict: bool, use_extension_annotated: bool
) -> None:
if use_extension_typed_dict:
TypedDict = ExtensionsTypedDict
typed_dict = ExtensionsTypedDict
else:
TypedDict = TypingTypedDict
typed_dict = TypingTypedDict
if use_extension_annotated:
Annotated = TypingAnnotated
annotated = TypingAnnotated
else:
Annotated = TypingAnnotated
annotated = TypingAnnotated
class SubTool(TypedDict):
class SubTool(typed_dict):
"""Subtool docstring"""
args: Annotated[dict[str, Any], {}, "this does bar"] # noqa: F722 # type: ignore
args: annotated[dict[str, Any], {}, "this does bar"] # noqa: F722 # type: ignore
class Tool(TypedDict):
class Tool(typed_dict):
"""Docstring
Args:
@ -546,20 +546,20 @@ def test__convert_typed_dict_to_openai_function(
arg1: str
arg2: Union[int, str, bool]
arg3: Optional[list[SubTool]]
arg4: Annotated[Literal["bar", "baz"], ..., "this does foo"] # noqa: F722
arg5: Annotated[Optional[float], None]
arg6: Annotated[
arg4: annotated[Literal["bar", "baz"], ..., "this does foo"] # noqa: F722
arg5: annotated[Optional[float], None]
arg6: annotated[
Optional[Sequence[Mapping[str, tuple[Iterable[Any], SubTool]]]], []
]
arg7: Annotated[list[SubTool], ...]
arg8: Annotated[tuple[SubTool], ...]
arg9: Annotated[Sequence[SubTool], ...]
arg10: Annotated[Iterable[SubTool], ...]
arg11: Annotated[set[SubTool], ...]
arg12: Annotated[dict[str, SubTool], ...]
arg13: Annotated[Mapping[str, SubTool], ...]
arg14: Annotated[MutableMapping[str, SubTool], ...]
arg15: Annotated[bool, False, "flag"] # noqa: F821 # type: ignore
arg7: annotated[list[SubTool], ...]
arg8: annotated[tuple[SubTool], ...]
arg9: annotated[Sequence[SubTool], ...]
arg10: annotated[Iterable[SubTool], ...]
arg11: annotated[set[SubTool], ...]
arg12: annotated[dict[str, SubTool], ...]
arg13: annotated[Mapping[str, SubTool], ...]
arg14: annotated[MutableMapping[str, SubTool], ...]
arg15: annotated[bool, False, "flag"] # noqa: F821 # type: ignore
expected = {
"name": "Tool",

View File

@ -10,7 +10,7 @@ from langchain_standard_tests.integration_tests.vectorstores import (
from langchain_core.documents import Document
from langchain_core.embeddings.fake import DeterministicFakeEmbedding
from langchain_core.vectorstores import InMemoryVectorStore
from tests.unit_tests.stubs import _AnyIdDocument
from tests.unit_tests.stubs import _any_id_document
class TestInMemoryReadWriteTestSuite(ReadWriteTestSuite):
@ -33,13 +33,13 @@ async def test_inmemory_similarity_search() -> None:
# Check sync version
output = store.similarity_search("foo", k=1)
assert output == [_AnyIdDocument(page_content="foo")]
assert output == [_any_id_document(page_content="foo")]
# Check async version
output = await store.asimilarity_search("bar", k=2)
assert output == [
_AnyIdDocument(page_content="bar"),
_AnyIdDocument(page_content="baz"),
_any_id_document(page_content="bar"),
_any_id_document(page_content="baz"),
]
@ -80,16 +80,16 @@ async def test_inmemory_mmr() -> None:
# make sure we can k > docstore size
output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1)
assert len(output) == len(texts)
assert output[0] == _AnyIdDocument(page_content="foo")
assert output[1] == _AnyIdDocument(page_content="foy")
assert output[0] == _any_id_document(page_content="foo")
assert output[1] == _any_id_document(page_content="foy")
# Check async version
output = await docsearch.amax_marginal_relevance_search(
"foo", k=10, lambda_mult=0.1
)
assert len(output) == len(texts)
assert output[0] == _AnyIdDocument(page_content="foo")
assert output[1] == _AnyIdDocument(page_content="foy")
assert output[0] == _any_id_document(page_content="foo")
assert output[1] == _any_id_document(page_content="foy")
async def test_inmemory_dump_load(tmp_path: Path) -> None:
@ -117,7 +117,7 @@ async def test_inmemory_filter() -> None:
# Check sync version
output = store.similarity_search("fee", filter=lambda doc: doc.metadata["id"] == 1)
assert output == [_AnyIdDocument(page_content="foo", metadata={"id": 1})]
assert output == [_any_id_document(page_content="foo", metadata={"id": 1})]
# filter with not stored document id
output = await store.asimilarity_search(

View File

@ -50,6 +50,7 @@ class CustomAddTextsVectorstore(VectorStore):
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
return [self.store[id] for id in ids if id in self.store]
@classmethod
def from_texts( # type: ignore
cls,
texts: list[str],