mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-17 07:26:16 +00:00
core: Add ruff rules PT (pytest) (#29381)
See https://docs.astral.sh/ruff/rules/#flake8-pytest-style-pt
This commit is contained in:
committed by
GitHub
parent
6896c863e8
commit
8a33402016
@@ -6,7 +6,7 @@ from langchain_core.utils.aiter import abatch_iterate
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_size, input_iterable, expected_output",
|
||||
("input_size", "input_iterable", "expected_output"),
|
||||
[
|
||||
(2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
|
||||
(3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),
|
||||
|
@@ -51,7 +51,12 @@ def test_get_from_dict_or_env() -> None:
|
||||
|
||||
# Not the most obvious behavior, but
|
||||
# this is how it works right now
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Did not find not exists, "
|
||||
"please add an environment variable `__SOME_KEY_IN_ENV` which contains it, "
|
||||
"or pass `not exists` as a named parameter.",
|
||||
):
|
||||
assert (
|
||||
get_from_dict_or_env(
|
||||
{
|
||||
|
@@ -37,7 +37,7 @@ from langchain_core.utils.function_calling import (
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def pydantic() -> type[BaseModel]:
|
||||
class dummy_function(BaseModel): # noqa: N801
|
||||
"""Dummy function."""
|
||||
@@ -48,7 +48,7 @@ def pydantic() -> type[BaseModel]:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def annotated_function() -> Callable:
|
||||
def dummy_function(
|
||||
arg1: ExtensionsAnnotated[int, "foo"],
|
||||
@@ -59,7 +59,7 @@ def annotated_function() -> Callable:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def function() -> Callable:
|
||||
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
|
||||
"""Dummy function.
|
||||
@@ -72,7 +72,7 @@ def function() -> Callable:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def function_docstring_annotations() -> Callable:
|
||||
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
|
||||
"""Dummy function.
|
||||
@@ -85,7 +85,7 @@ def function_docstring_annotations() -> Callable:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def runnable() -> Runnable:
|
||||
class Args(ExtensionsTypedDict):
|
||||
arg1: ExtensionsAnnotated[int, "foo"]
|
||||
@@ -97,7 +97,7 @@ def runnable() -> Runnable:
|
||||
return RunnableLambda(dummy_function)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_tool() -> BaseTool:
|
||||
class Schema(BaseModel):
|
||||
arg1: int = Field(..., description="foo")
|
||||
@@ -114,7 +114,7 @@ def dummy_tool() -> BaseTool:
|
||||
return DummyFunction()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_structured_tool() -> StructuredTool:
|
||||
class Schema(BaseModel):
|
||||
arg1: int = Field(..., description="foo")
|
||||
@@ -128,7 +128,7 @@ def dummy_structured_tool() -> StructuredTool:
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_structured_tool_args_schema_dict() -> StructuredTool:
|
||||
args_schema = {
|
||||
"type": "object",
|
||||
@@ -150,7 +150,7 @@ def dummy_structured_tool_args_schema_dict() -> StructuredTool:
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_pydantic() -> type[BaseModel]:
|
||||
class dummy_function(BaseModel): # noqa: N801
|
||||
"""Dummy function."""
|
||||
@@ -161,7 +161,7 @@ def dummy_pydantic() -> type[BaseModel]:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
|
||||
class dummy_function(BaseModelV2Maybe): # noqa: N801
|
||||
"""Dummy function."""
|
||||
@@ -174,7 +174,7 @@ def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_typing_typed_dict() -> type:
|
||||
class dummy_function(TypingTypedDict): # noqa: N801
|
||||
"""Dummy function."""
|
||||
@@ -185,7 +185,7 @@ def dummy_typing_typed_dict() -> type:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_typing_typed_dict_docstring() -> type:
|
||||
class dummy_function(TypingTypedDict): # noqa: N801
|
||||
"""Dummy function.
|
||||
@@ -201,7 +201,7 @@ def dummy_typing_typed_dict_docstring() -> type:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_extensions_typed_dict() -> type:
|
||||
class dummy_function(ExtensionsTypedDict): # noqa: N801
|
||||
"""Dummy function."""
|
||||
@@ -212,7 +212,7 @@ def dummy_extensions_typed_dict() -> type:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_extensions_typed_dict_docstring() -> type:
|
||||
class dummy_function(ExtensionsTypedDict): # noqa: N801
|
||||
"""Dummy function.
|
||||
@@ -228,7 +228,7 @@ def dummy_extensions_typed_dict_docstring() -> type:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def json_schema() -> dict:
|
||||
return {
|
||||
"title": "dummy_function",
|
||||
@@ -246,7 +246,7 @@ def json_schema() -> dict:
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def anthropic_tool() -> dict:
|
||||
return {
|
||||
"name": "dummy_function",
|
||||
@@ -266,7 +266,7 @@ def anthropic_tool() -> dict:
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def bedrock_converse_tool() -> dict:
|
||||
return {
|
||||
"toolSpec": {
|
||||
|
@@ -4,7 +4,7 @@ from langchain_core.utils.iter import batch_iterate
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_size, input_iterable, expected_output",
|
||||
("input_size", "input_iterable", "expected_output"),
|
||||
[
|
||||
(2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
|
||||
(3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),
|
||||
|
@@ -147,7 +147,7 @@ def test_dereference_refs_remote_ref() -> None:
|
||||
"first_name": {"$ref": "https://somewhere/else/name"},
|
||||
},
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="ref paths are expected to be URI fragments"):
|
||||
dereference_refs(schema)
|
||||
|
||||
|
||||
|
@@ -220,7 +220,7 @@ output5 = {
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"schema, output",
|
||||
("schema", "output"),
|
||||
[
|
||||
(schema1, output1),
|
||||
(schema2, output2),
|
||||
|
@@ -29,12 +29,17 @@ def test_dict_int_op_nested() -> None:
|
||||
def test_dict_int_op_max_depth_exceeded() -> None:
|
||||
left = {"a": {"b": {"c": 1}}}
|
||||
right = {"a": {"b": {"c": 2}}}
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError, match="max_depth=2 exceeded, unable to combine dicts."
|
||||
):
|
||||
_dict_int_op(left, right, operator.add, max_depth=2)
|
||||
|
||||
|
||||
def test_dict_int_op_invalid_types() -> None:
|
||||
left = {"a": 1, "b": "string"}
|
||||
right = {"a": 2, "b": 3}
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Only dict and int values are supported.",
|
||||
):
|
||||
_dict_int_op(left, right, operator.add)
|
||||
|
@@ -46,7 +46,7 @@ def test_check_package_version(
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("left", "right", "expected"),
|
||||
(
|
||||
[
|
||||
# Merge `None` and `1`.
|
||||
({"a": None}, {"a": 1}, {"a": 1}),
|
||||
# Merge `1` and `None`.
|
||||
@@ -111,7 +111,7 @@ def test_check_package_version(
|
||||
{"a": [{"idx": 0, "b": "f"}]},
|
||||
{"a": [{"idx": 0, "b": "{"}, {"idx": 0, "b": "f"}]},
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_merge_dicts(
|
||||
left: dict, right: dict, expected: Union[dict, AbstractContextManager]
|
||||
@@ -130,7 +130,7 @@ def test_merge_dicts(
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("left", "right", "expected"),
|
||||
(
|
||||
[
|
||||
# 'type' special key handling
|
||||
({"type": "foo"}, {"type": "foo"}, {"type": "foo"}),
|
||||
(
|
||||
@@ -138,7 +138,7 @@ def test_merge_dicts(
|
||||
{"type": "bar"},
|
||||
pytest.raises(ValueError, match="Unable to merge."),
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.xfail(reason="Refactors to make in 0.3")
|
||||
def test_merge_dicts_0_3(
|
||||
@@ -183,36 +183,32 @@ def test_guard_import(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("module_name", "pip_name", "package"),
|
||||
("module_name", "pip_name", "package", "expected_pip_name"),
|
||||
[
|
||||
("langchain_core.utilsW", None, None),
|
||||
("langchain_core.utilsW", "langchain-core-2", None),
|
||||
("langchain_core.utilsW", None, "langchain-coreWX"),
|
||||
("langchain_core.utilsW", "langchain-core-2", "langchain-coreWX"),
|
||||
("langchain_coreW", None, None), # ModuleNotFoundError
|
||||
("langchain_core.utilsW", None, None, "langchain-core"),
|
||||
("langchain_core.utilsW", "langchain-core-2", None, "langchain-core-2"),
|
||||
("langchain_core.utilsW", None, "langchain-coreWX", "langchain-core"),
|
||||
(
|
||||
"langchain_core.utilsW",
|
||||
"langchain-core-2",
|
||||
"langchain-coreWX",
|
||||
"langchain-core-2",
|
||||
),
|
||||
("langchain_coreW", None, None, "langchain-coreW"), # ModuleNotFoundError
|
||||
],
|
||||
)
|
||||
def test_guard_import_failure(
|
||||
module_name: str, pip_name: Optional[str], package: Optional[str]
|
||||
module_name: str,
|
||||
pip_name: Optional[str],
|
||||
package: Optional[str],
|
||||
expected_pip_name: str,
|
||||
) -> None:
|
||||
with pytest.raises(ImportError) as exc_info:
|
||||
if package is None and pip_name is None:
|
||||
guard_import(module_name)
|
||||
elif package is None and pip_name is not None:
|
||||
guard_import(module_name, pip_name=pip_name)
|
||||
elif package is not None and pip_name is None:
|
||||
guard_import(module_name, package=package)
|
||||
elif package is not None and pip_name is not None:
|
||||
guard_import(module_name, pip_name=pip_name, package=package)
|
||||
else:
|
||||
msg = "Invalid test case"
|
||||
raise ValueError(msg)
|
||||
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
|
||||
err_msg = (
|
||||
f"Could not import {module_name} python package. "
|
||||
f"Please install it with `pip install {pip_name}`."
|
||||
)
|
||||
assert exc_info.value.msg == err_msg
|
||||
with pytest.raises(
|
||||
ImportError,
|
||||
match=f"Could not import {module_name} python package. "
|
||||
f"Please install it with `pip install {expected_pip_name}`.",
|
||||
):
|
||||
guard_import(module_name, pip_name=pip_name, package=package)
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")
|
||||
|
Reference in New Issue
Block a user