core: Add ruff rules PT (pytest) (#29381)

See https://docs.astral.sh/ruff/rules/#flake8-pytest-style-pt
This commit is contained in:
Christophe Bornet
2025-04-01 19:31:07 +02:00
committed by GitHub
parent 6896c863e8
commit 8a33402016
34 changed files with 379 additions and 227 deletions

View File

@@ -6,7 +6,7 @@ from langchain_core.utils.aiter import abatch_iterate
@pytest.mark.parametrize(
"input_size, input_iterable, expected_output",
("input_size", "input_iterable", "expected_output"),
[
(2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
(3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),

View File

@@ -51,7 +51,12 @@ def test_get_from_dict_or_env() -> None:
# Not the most obvious behavior, but
# this is how it works right now
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Did not find not exists, "
"please add an environment variable `__SOME_KEY_IN_ENV` which contains it, "
"or pass `not exists` as a named parameter.",
):
assert (
get_from_dict_or_env(
{

View File

@@ -37,7 +37,7 @@ from langchain_core.utils.function_calling import (
)
@pytest.fixture()
@pytest.fixture
def pydantic() -> type[BaseModel]:
class dummy_function(BaseModel): # noqa: N801
"""Dummy function."""
@@ -48,7 +48,7 @@ def pydantic() -> type[BaseModel]:
return dummy_function
@pytest.fixture()
@pytest.fixture
def annotated_function() -> Callable:
def dummy_function(
arg1: ExtensionsAnnotated[int, "foo"],
@@ -59,7 +59,7 @@ def annotated_function() -> Callable:
return dummy_function
@pytest.fixture()
@pytest.fixture
def function() -> Callable:
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""Dummy function.
@@ -72,7 +72,7 @@ def function() -> Callable:
return dummy_function
@pytest.fixture()
@pytest.fixture
def function_docstring_annotations() -> Callable:
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""Dummy function.
@@ -85,7 +85,7 @@ def function_docstring_annotations() -> Callable:
return dummy_function
@pytest.fixture()
@pytest.fixture
def runnable() -> Runnable:
class Args(ExtensionsTypedDict):
arg1: ExtensionsAnnotated[int, "foo"]
@@ -97,7 +97,7 @@ def runnable() -> Runnable:
return RunnableLambda(dummy_function)
@pytest.fixture()
@pytest.fixture
def dummy_tool() -> BaseTool:
class Schema(BaseModel):
arg1: int = Field(..., description="foo")
@@ -114,7 +114,7 @@ def dummy_tool() -> BaseTool:
return DummyFunction()
@pytest.fixture()
@pytest.fixture
def dummy_structured_tool() -> StructuredTool:
class Schema(BaseModel):
arg1: int = Field(..., description="foo")
@@ -128,7 +128,7 @@ def dummy_structured_tool() -> StructuredTool:
)
@pytest.fixture()
@pytest.fixture
def dummy_structured_tool_args_schema_dict() -> StructuredTool:
args_schema = {
"type": "object",
@@ -150,7 +150,7 @@ def dummy_structured_tool_args_schema_dict() -> StructuredTool:
)
@pytest.fixture()
@pytest.fixture
def dummy_pydantic() -> type[BaseModel]:
class dummy_function(BaseModel): # noqa: N801
"""Dummy function."""
@@ -161,7 +161,7 @@ def dummy_pydantic() -> type[BaseModel]:
return dummy_function
@pytest.fixture()
@pytest.fixture
def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
class dummy_function(BaseModelV2Maybe): # noqa: N801
"""Dummy function."""
@@ -174,7 +174,7 @@ def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
return dummy_function
@pytest.fixture()
@pytest.fixture
def dummy_typing_typed_dict() -> type:
class dummy_function(TypingTypedDict): # noqa: N801
"""Dummy function."""
@@ -185,7 +185,7 @@ def dummy_typing_typed_dict() -> type:
return dummy_function
@pytest.fixture()
@pytest.fixture
def dummy_typing_typed_dict_docstring() -> type:
class dummy_function(TypingTypedDict): # noqa: N801
"""Dummy function.
@@ -201,7 +201,7 @@ def dummy_typing_typed_dict_docstring() -> type:
return dummy_function
@pytest.fixture()
@pytest.fixture
def dummy_extensions_typed_dict() -> type:
class dummy_function(ExtensionsTypedDict): # noqa: N801
"""Dummy function."""
@@ -212,7 +212,7 @@ def dummy_extensions_typed_dict() -> type:
return dummy_function
@pytest.fixture()
@pytest.fixture
def dummy_extensions_typed_dict_docstring() -> type:
class dummy_function(ExtensionsTypedDict): # noqa: N801
"""Dummy function.
@@ -228,7 +228,7 @@ def dummy_extensions_typed_dict_docstring() -> type:
return dummy_function
@pytest.fixture()
@pytest.fixture
def json_schema() -> dict:
return {
"title": "dummy_function",
@@ -246,7 +246,7 @@ def json_schema() -> dict:
}
@pytest.fixture()
@pytest.fixture
def anthropic_tool() -> dict:
return {
"name": "dummy_function",
@@ -266,7 +266,7 @@ def anthropic_tool() -> dict:
}
@pytest.fixture()
@pytest.fixture
def bedrock_converse_tool() -> dict:
return {
"toolSpec": {

View File

@@ -4,7 +4,7 @@ from langchain_core.utils.iter import batch_iterate
@pytest.mark.parametrize(
"input_size, input_iterable, expected_output",
("input_size", "input_iterable", "expected_output"),
[
(2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
(3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),

View File

@@ -147,7 +147,7 @@ def test_dereference_refs_remote_ref() -> None:
"first_name": {"$ref": "https://somewhere/else/name"},
},
}
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="ref paths are expected to be URI fragments"):
dereference_refs(schema)

View File

@@ -220,7 +220,7 @@ output5 = {
@pytest.mark.parametrize(
"schema, output",
("schema", "output"),
[
(schema1, output1),
(schema2, output2),

View File

@@ -29,12 +29,17 @@ def test_dict_int_op_nested() -> None:
def test_dict_int_op_max_depth_exceeded() -> None:
left = {"a": {"b": {"c": 1}}}
right = {"a": {"b": {"c": 2}}}
with pytest.raises(ValueError):
with pytest.raises(
ValueError, match="max_depth=2 exceeded, unable to combine dicts."
):
_dict_int_op(left, right, operator.add, max_depth=2)
def test_dict_int_op_invalid_types() -> None:
left = {"a": 1, "b": "string"}
right = {"a": 2, "b": 3}
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Only dict and int values are supported.",
):
_dict_int_op(left, right, operator.add)

View File

@@ -46,7 +46,7 @@ def test_check_package_version(
@pytest.mark.parametrize(
("left", "right", "expected"),
(
[
# Merge `None` and `1`.
({"a": None}, {"a": 1}, {"a": 1}),
# Merge `1` and `None`.
@@ -111,7 +111,7 @@ def test_check_package_version(
{"a": [{"idx": 0, "b": "f"}]},
{"a": [{"idx": 0, "b": "{"}, {"idx": 0, "b": "f"}]},
),
),
],
)
def test_merge_dicts(
left: dict, right: dict, expected: Union[dict, AbstractContextManager]
@@ -130,7 +130,7 @@ def test_merge_dicts(
@pytest.mark.parametrize(
("left", "right", "expected"),
(
[
# 'type' special key handling
({"type": "foo"}, {"type": "foo"}, {"type": "foo"}),
(
@@ -138,7 +138,7 @@ def test_merge_dicts(
{"type": "bar"},
pytest.raises(ValueError, match="Unable to merge."),
),
),
],
)
@pytest.mark.xfail(reason="Refactors to make in 0.3")
def test_merge_dicts_0_3(
@@ -183,36 +183,32 @@ def test_guard_import(
@pytest.mark.parametrize(
("module_name", "pip_name", "package"),
("module_name", "pip_name", "package", "expected_pip_name"),
[
("langchain_core.utilsW", None, None),
("langchain_core.utilsW", "langchain-core-2", None),
("langchain_core.utilsW", None, "langchain-coreWX"),
("langchain_core.utilsW", "langchain-core-2", "langchain-coreWX"),
("langchain_coreW", None, None), # ModuleNotFoundError
("langchain_core.utilsW", None, None, "langchain-core"),
("langchain_core.utilsW", "langchain-core-2", None, "langchain-core-2"),
("langchain_core.utilsW", None, "langchain-coreWX", "langchain-core"),
(
"langchain_core.utilsW",
"langchain-core-2",
"langchain-coreWX",
"langchain-core-2",
),
("langchain_coreW", None, None, "langchain-coreW"), # ModuleNotFoundError
],
)
def test_guard_import_failure(
module_name: str, pip_name: Optional[str], package: Optional[str]
module_name: str,
pip_name: Optional[str],
package: Optional[str],
expected_pip_name: str,
) -> None:
with pytest.raises(ImportError) as exc_info:
if package is None and pip_name is None:
guard_import(module_name)
elif package is None and pip_name is not None:
guard_import(module_name, pip_name=pip_name)
elif package is not None and pip_name is None:
guard_import(module_name, package=package)
elif package is not None and pip_name is not None:
guard_import(module_name, pip_name=pip_name, package=package)
else:
msg = "Invalid test case"
raise ValueError(msg)
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
err_msg = (
f"Could not import {module_name} python package. "
f"Please install it with `pip install {pip_name}`."
)
assert exc_info.value.msg == err_msg
with pytest.raises(
ImportError,
match=f"Could not import {module_name} python package. "
f"Please install it with `pip install {expected_pip_name}`.",
):
guard_import(module_name, pip_name=pip_name, package=package)
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")