From 8b8d90bea5a4b8698e02d4f55e27373adc08f3f7 Mon Sep 17 00:00:00 2001
From: Christophe Bornet
Date: Mon, 21 Jul 2025 19:15:05 +0200
Subject: [PATCH] feat(langchain): add ruff rules PT (#32010)

See https://docs.astral.sh/ruff/rules/#flake8-pytest-style-pt
---
 libs/langchain/pyproject.toml                 |  1 +
 .../integration_tests/embeddings/test_base.py |  2 +-
 .../tests/unit_tests/agents/test_types.py     |  7 ++---
 .../chains/query_constructor/test_parser.py   | 24 +++++++-------
 .../test_map_rerank_prompt.py                 |  2 +-
 .../tests/unit_tests/chains/test_base.py      | 27 ++++++++++------
 .../chains/test_combine_documents.py          | 19 ++++++++++--
 .../unit_tests/chains/test_conversation.py    | 20 ++++++++----
 .../tests/unit_tests/chains/test_llm_math.py  |  2 +-
 .../unit_tests/chains/test_qa_with_sources.py |  2 +-
 .../unit_tests/chains/test_sequential.py      | 31 +++++++++++++++----
 .../tests/unit_tests/chains/test_transform.py |  2 +-
 .../tests/unit_tests/chat_models/test_base.py |  4 +--
 libs/langchain/tests/unit_tests/conftest.py   |  7 +++--
 .../tests/unit_tests/embeddings/test_base.py  |  5 +--
 .../evaluation/comparison/test_eval_chain.py  |  4 ++-
 .../evaluation/criteria/test_eval_chain.py    |  6 ++--
 .../evaluation/qa/test_eval_chain.py          |  2 +-
 .../evaluation/scoring/test_eval_chain.py     | 10 ++++--
 .../tests/unit_tests/indexes/test_indexing.py | 24 +++++++++++---
 .../unit_tests/memory/test_combined_memory.py |  6 +++-
 .../output_parsers/test_boolean_parser.py     | 14 +++++++--
 .../output_parsers/test_datetime_parser.py    | 22 ++++++-------
 .../unit_tests/output_parsers/test_fix.py     |  4 +--
 .../unit_tests/output_parsers/test_retry.py   |  8 ++---
 .../output_parsers/test_yaml_parser.py        | 10 ++----
 .../unit_tests/retrievers/test_multi_query.py |  4 +--
 .../unit_tests/storage/test_filesystem.py     |  2 +-
 .../tests/unit_tests/test_formatting.py       |  6 +++-
 libs/langchain/tests/unit_tests/test_utils.py |  4 ++-
 .../tests/unit_tests/utils/test_iter.py       |  2 +-
 31 files changed, 181 insertions(+), 102 deletions(-)

diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml
index 1b6c4f055b2..52e20a4fdff 100644
--- a/libs/langchain/pyproject.toml
+++ b/libs/langchain/pyproject.toml
@@ -167,6 +167,7 @@ select = [
     "PGH", # pygrep-hooks
     "PIE", # flake8-pie
     "PERF", # flake8-perf
+    "PT", # flake8-pytest-style
     "PTH", # flake8-use-pathlib
     "PYI", # flake8-pyi
     "Q", # flake8-quotes
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_base.py b/libs/langchain/tests/integration_tests/embeddings/test_base.py
index 204754642fd..4af0b24aa2c 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_base.py
+++ b/libs/langchain/tests/integration_tests/embeddings/test_base.py
@@ -9,7 +9,7 @@ from langchain.embeddings.base import _SUPPORTED_PROVIDERS, init_embeddings
 
 
 @pytest.mark.parametrize(
-    "provider, model",
+    ("provider", "model"),
     [
         ("openai", "text-embedding-3-large"),
         ("google_vertexai", "text-embedding-gecko@003"),
diff --git a/libs/langchain/tests/unit_tests/agents/test_types.py b/libs/langchain/tests/unit_tests/agents/test_types.py
index 536d1f1d177..649c49717f0 100644
--- a/libs/langchain/tests/unit_tests/agents/test_types.py
+++ b/libs/langchain/tests/unit_tests/agents/test_types.py
@@ -1,9 +1,6 @@
-import unittest
-
 from langchain.agents.agent_types import AgentType
 from langchain.agents.types import AGENT_TO_CLASS
 
 
-class TestTypes(unittest.TestCase):
-    def test_confirm_full_coverage(self) -> None:
-        self.assertEqual(list(AgentType), list(AGENT_TO_CLASS.keys()))
+def test_confirm_full_coverage() -> None:
+    assert list(AgentType) == list(AGENT_TO_CLASS.keys())
diff --git a/libs/langchain/tests/unit_tests/chains/query_constructor/test_parser.py b/libs/langchain/tests/unit_tests/chains/query_constructor/test_parser.py
index 73846eb0929..8090c7922e0 100644
--- a/libs/langchain/tests/unit_tests/chains/query_constructor/test_parser.py
+++ b/libs/langchain/tests/unit_tests/chains/query_constructor/test_parser.py
@@ -16,7 +16,7 @@ from langchain.chains.query_constructor.parser import get_parser
 DEFAULT_PARSER = get_parser()
 
 
-@pytest.mark.parametrize("x", ("", "foo", 'foo("bar", "baz")'))
+@pytest.mark.parametrize("x", ["", "foo", 'foo("bar", "baz")'])
 def test_parse_invalid_grammar(x: str) -> None:
     with pytest.raises((ValueError, lark.exceptions.UnexpectedToken)):
         DEFAULT_PARSER.parse(x)
@@ -71,13 +71,13 @@ def test_parse_nested_operation() -> None:
 
 
 def test_parse_disallowed_comparator() -> None:
     parser = get_parser(allowed_comparators=[Comparator.EQ])
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Received disallowed comparator gt."):
         parser.parse('gt("a", 2)')
 
 
 def test_parse_disallowed_operator() -> None:
     parser = get_parser(allowed_operators=[Operator.AND])
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Received disallowed operator not."):
         parser.parse('not(gt("a", 2))')
@@ -87,29 +87,29 @@ def _test_parse_value(x: Any) -> None:
     assert actual == x
 
 
-@pytest.mark.parametrize("x", (-1, 0, 1_000_000))
+@pytest.mark.parametrize("x", [-1, 0, 1_000_000])
 def test_parse_int_value(x: int) -> None:
     _test_parse_value(x)
 
 
-@pytest.mark.parametrize("x", (-1.001, 0.00000002, 1_234_567.6543210))
+@pytest.mark.parametrize("x", [-1.001, 0.00000002, 1_234_567.6543210])
 def test_parse_float_value(x: float) -> None:
     _test_parse_value(x)
 
 
-@pytest.mark.parametrize("x", ([], [1, "b", "true"]))
+@pytest.mark.parametrize("x", [[], [1, "b", "true"]])
 def test_parse_list_value(x: list) -> None:
     _test_parse_value(x)
 
 
-@pytest.mark.parametrize("x", ('""', '" "', '"foo"', "'foo'"))
+@pytest.mark.parametrize("x", ['""', '" "', '"foo"', "'foo'"])
 def test_parse_string_value(x: str) -> None:
     parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
     actual = parsed.value
     assert actual == x[1:-1]
 
 
-@pytest.mark.parametrize("x", ("true", "True", "TRUE", "false", "False", "FALSE"))
+@pytest.mark.parametrize("x", ["true", "True", "TRUE", "false", "False", "FALSE"])
 def test_parse_bool_value(x: str) -> None:
     parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
     actual = parsed.value
@@ -117,15 +117,15 @@ def test_parse_bool_value(x: str) -> None:
     assert actual == expected
 
 
-@pytest.mark.parametrize("op", ("and", "or"))
-@pytest.mark.parametrize("arg", ('eq("foo", 2)', 'and(eq("foo", 2), lte("bar", 1.1))'))
+@pytest.mark.parametrize("op", ["and", "or"])
+@pytest.mark.parametrize("arg", ['eq("foo", 2)', 'and(eq("foo", 2), lte("bar", 1.1))'])
 def test_parser_unpack_single_arg_operation(op: str, arg: str) -> None:
     expected = DEFAULT_PARSER.parse(arg)
     actual = DEFAULT_PARSER.parse(f"{op}({arg})")
     assert expected == actual
 
 
-@pytest.mark.parametrize("x", ('"2022-10-20"', "'2022-10-20'", "2022-10-20"))
+@pytest.mark.parametrize("x", ['"2022-10-20"', "'2022-10-20'", "2022-10-20"])
 def test_parse_date_value(x: str) -> None:
     parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
     actual = parsed.value["date"]
@@ -133,7 +133,7 @@
 
 
 @pytest.mark.parametrize(
-    "x, expected",
+    ("x", "expected"),
     [
         (
             '"2021-01-01T00:00:00"',
diff --git a/libs/langchain/tests/unit_tests/chains/question_answering/test_map_rerank_prompt.py b/libs/langchain/tests/unit_tests/chains/question_answering/test_map_rerank_prompt.py
index 8d27fa0635e..5628429226e 100644
--- a/libs/langchain/tests/unit_tests/chains/question_answering/test_map_rerank_prompt.py
+++ b/libs/langchain/tests/unit_tests/chains/question_answering/test_map_rerank_prompt.py
@@ -8,7 +8,7 @@
 GOOD_SCORE = "foo bar answer.\nScore: 80"
 SCORE_WITH_EXPLANATION = "foo bar answer.\nScore: 80 (fully answers the question, but could provide more detail on the specific error message)"  # noqa: E501
 
 
-@pytest.mark.parametrize("answer", (GOOD_SCORE, SCORE_WITH_EXPLANATION))
+@pytest.mark.parametrize("answer", [GOOD_SCORE, SCORE_WITH_EXPLANATION])
 def test_parse_scores(answer: str) -> None:
     result = output_parser.parse(answer)
diff --git a/libs/langchain/tests/unit_tests/chains/test_base.py b/libs/langchain/tests/unit_tests/chains/test_base.py
index 819197e2ef2..0e47902bbee 100644
--- a/libs/langchain/tests/unit_tests/chains/test_base.py
+++ b/libs/langchain/tests/unit_tests/chains/test_base.py
@@ -65,14 +65,14 @@ class FakeChain(Chain):
 
 
 def test_bad_inputs() -> None:
     """Test errors are raised if input keys are not found."""
     chain = FakeChain()
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Missing some input keys: {'foo'}"):
         chain({"foobar": "baz"})
 
 
 def test_bad_outputs() -> None:
     """Test errors are raised if outputs keys are not found."""
     chain = FakeChain(be_correct=False)
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Missing some output keys: {'bar'}"):
         chain({"foo": "baz"})
@@ -102,7 +102,7 @@ def test_single_input_correct() -> None:
 
 
 def test_single_input_error() -> None:
     """Test passing single input errors as expected."""
     chain = FakeChain(the_input_keys=["foo", "bar"])
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Missing some input keys:"):
         chain("bar")
@@ -116,7 +116,9 @@ def test_run_single_arg() -> None:
 
 
 def test_run_multiple_args_error() -> None:
     """Test run method with multiple args errors as expected."""
     chain = FakeChain()
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="`run` supports only one positional argument."
+    ):
         chain.run("bar", "foo")
@@ -130,21 +132,28 @@ def test_run_kwargs() -> None:
 
 
 def test_run_kwargs_error() -> None:
     """Test run method with kwargs errors as expected."""
     chain = FakeChain(the_input_keys=["foo", "bar"])
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Missing some input keys: {'bar'}"):
         chain.run(foo="bar", baz="foo")
 
 
 def test_run_args_and_kwargs_error() -> None:
     """Test run method with args and kwargs."""
     chain = FakeChain(the_input_keys=["foo", "bar"])
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="`run` supported with either positional arguments "
+        "or keyword arguments but not both.",
+    ):
         chain.run("bar", foo="bar")
 
 
 def test_multiple_output_keys_error() -> None:
     """Test run with multiple output keys errors as expected."""
     chain = FakeChain(the_output_keys=["foo", "bar"])
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="`run` not supported when there is not exactly one output key.",
+    ):
         chain.run("bar")
@@ -175,7 +184,7 @@ def test_run_with_callback_and_input_error() -> None:
         callbacks=[handler],
     )
 
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Missing some input keys: {'foo'}"):
         chain({"bar": "foo"})
 
     assert handler.starts == 1
@@ -222,7 +231,7 @@ def test_run_with_callback_and_output_error() -> None:
         callbacks=[handler],
     )
 
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Missing some output keys: {'foo'}"):
         chain("foo")
 
     assert handler.starts == 1
diff --git a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py
index fd26b36569d..ff08295179e 100644
--- a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py
+++ b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py
@@ -1,5 +1,6 @@
 """Test functionality related to combining documents."""
 
+import re
 from typing import Any
 
 import pytest
@@ -30,7 +31,9 @@ def test_multiple_input_keys() -> None:
 
 
 def test__split_list_long_single_doc() -> None:
     """Test splitting of a long single doc."""
     docs = [Document(page_content="foo" * 100)]
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="A single document was longer than the context length"
+    ):
         split_list_of_docs(docs, _fake_docs_len_func, 100)
@@ -140,7 +143,17 @@ async def test_format_doc_missing_metadata() -> None:
         input_variables=["page_content", "bar"],
         template="{page_content}, {bar}",
     )
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Document prompt requires documents to have metadata variables: ['bar']."
+        ),
+    ):
         format_document(doc, prompt)
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Document prompt requires documents to have metadata variables: ['bar']."
+        ),
+    ):
         await aformat_document(doc, prompt)
diff --git a/libs/langchain/tests/unit_tests/chains/test_conversation.py b/libs/langchain/tests/unit_tests/chains/test_conversation.py
index 2dabee35831..ec494d85088 100644
--- a/libs/langchain/tests/unit_tests/chains/test_conversation.py
+++ b/libs/langchain/tests/unit_tests/chains/test_conversation.py
@@ -1,5 +1,6 @@
 """Test conversation chain and memory."""
 
+import re
 from typing import Any, Optional
 
 import pytest
@@ -76,7 +77,9 @@ def test_conversation_chain_errors_bad_prompt() -> None:
     """Test that conversation chain raise error with bad prompt."""
     llm = FakeLLM()
     prompt = PromptTemplate(input_variables=[], template="nothing here")
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Value error, Got unexpected prompt input variables."
+    ):
         ConversationChain(llm=llm, prompt=prompt)
@@ -85,7 +88,12 @@ def test_conversation_chain_errors_bad_variable() -> None:
     llm = FakeLLM()
     prompt = PromptTemplate(input_variables=["foo"], template="{foo}")
     memory = ConversationBufferMemory(memory_key="foo")
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Value error, The input key foo was also found in the memory keys (['foo'])"
+        ),
+    ):
         ConversationChain(llm=llm, prompt=prompt, memory=memory, input_key="foo")
@@ -106,18 +114,18 @@ def test_conversation_memory(memory: BaseMemory) -> None:
     memory.save_context(good_inputs, good_outputs)
     # This is a bad input because there are two variables that aren't the same as baz.
     bad_inputs = {"foo": "bar", "foo1": "bar"}
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="One input key expected"):
         memory.save_context(bad_inputs, good_outputs)
     # This is a bad input because the only variable is the same as baz.
     bad_inputs = {"baz": "bar"}
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match=re.escape("One input key expected got []")):
         memory.save_context(bad_inputs, good_outputs)
     # This is a bad output because it is empty.
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Got multiple output keys"):
         memory.save_context(good_inputs, {})
     # This is a bad output because there are two keys.
     bad_outputs = {"foo": "bar", "foo1": "bar"}
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Got multiple output keys"):
         memory.save_context(good_inputs, bad_outputs)
diff --git a/libs/langchain/tests/unit_tests/chains/test_llm_math.py b/libs/langchain/tests/unit_tests/chains/test_llm_math.py
index 26955939aea..42937287ec5 100644
--- a/libs/langchain/tests/unit_tests/chains/test_llm_math.py
+++ b/libs/langchain/tests/unit_tests/chains/test_llm_math.py
@@ -39,5 +39,5 @@ def test_complex_question(fake_llm_math_chain: LLMMathChain) -> None:
 @pytest.mark.requires("numexpr")
 def test_error(fake_llm_math_chain: LLMMathChain) -> None:
     """Test question that raises error."""
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="unknown format from LLM: foo"):
         fake_llm_math_chain.run("foo")
diff --git a/libs/langchain/tests/unit_tests/chains/test_qa_with_sources.py b/libs/langchain/tests/unit_tests/chains/test_qa_with_sources.py
index 9788ac900a8..ae21f558534 100644
--- a/libs/langchain/tests/unit_tests/chains/test_qa_with_sources.py
+++ b/libs/langchain/tests/unit_tests/chains/test_qa_with_sources.py
@@ -5,7 +5,7 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
 
 
 @pytest.mark.parametrize(
-    "text,answer,sources",
+    ("text", "answer", "sources"),
     [
         (
             "This Agreement is governed by English law.\nSOURCES: 28-pl",
diff --git a/libs/langchain/tests/unit_tests/chains/test_sequential.py b/libs/langchain/tests/unit_tests/chains/test_sequential.py
index fe5aa283b97..f3313e9139f 100644
--- a/libs/langchain/tests/unit_tests/chains/test_sequential.py
+++ b/libs/langchain/tests/unit_tests/chains/test_sequential.py
@@ -1,5 +1,6 @@
 """Test pipeline functionality."""
 
+import re
 from typing import Optional
 
 import pytest
@@ -94,7 +95,12 @@ def test_sequential_usage_memory() -> None:
     memory = SimpleMemory(memories={"zab": "rab", "foo": "rab"})
     chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
     chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Value error, The input key(s) foo are found in the Memory keys"
+        ),
+    ):
         SequentialChain(  # type: ignore[call-arg]
             memory=memory,
             chains=[chain_1, chain_2],
@@ -136,7 +142,9 @@ def test_sequential_missing_inputs() -> None:
     """Test error is raised when input variables are missing."""
     chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
     chain_2 = FakeChain(input_variables=["bar", "test"], output_variables=["baz"])
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Value error, Missing required input keys: {'test'}"
+    ):
         # Also needs "test" as an input
         SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"])  # type: ignore[call-arg]
@@ -145,7 +153,10 @@ def test_sequential_bad_outputs() -> None:
     """Test error is raised when bad outputs are specified."""
     chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
     chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Value error, Expected output variables that were not found: {'test'}.",
+    ):
         # "test" is not present as an output variable.
         SequentialChain(
             chains=[chain_1, chain_2],
@@ -172,7 +183,9 @@ def test_sequential_overlapping_inputs() -> None:
     """Test error is raised when input variables are overlapping."""
     chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])
     chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Value error, Chain returned keys that already exist"
+    ):
         # "test" is specified as an input, but also is an output of one step
         SequentialChain(chains=[chain_1, chain_2], input_variables=["foo", "test"])  # type: ignore[call-arg]
@@ -226,7 +239,10 @@ def test_multi_input_errors() -> None:
     """Test simple sequential errors if multiple input variables are expected."""
     chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
     chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Value error, Chains used in SimplePipeline should all have one input",
+    ):
         SimpleSequentialChain(chains=[chain_1, chain_2])
@@ -234,5 +250,8 @@ def test_multi_output_errors() -> None:
     """Test simple sequential errors if multiple output variables are expected."""
     chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "grok"])
     chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Value error, Chains used in SimplePipeline should all have one output",
+    ):
         SimpleSequentialChain(chains=[chain_1, chain_2])
diff --git a/libs/langchain/tests/unit_tests/chains/test_transform.py b/libs/langchain/tests/unit_tests/chains/test_transform.py
index 26e95ef263b..cf2ecaa1eb6 100644
--- a/libs/langchain/tests/unit_tests/chains/test_transform.py
+++ b/libs/langchain/tests/unit_tests/chains/test_transform.py
@@ -35,5 +35,5 @@ def test_transform_chain_bad_inputs() -> None:
         transform=dummy_transform,
     )
     input_dict = {"name": "Leroy", "last_name": "Jenkins"}
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Missing some input keys: {'first_name'}"):
         _ = transform_chain(input_dict)
diff --git a/libs/langchain/tests/unit_tests/chat_models/test_base.py b/libs/langchain/tests/unit_tests/chat_models/test_base.py
index e86fb892584..18279c95e67 100644
--- a/libs/langchain/tests/unit_tests/chat_models/test_base.py
+++ b/libs/langchain/tests/unit_tests/chat_models/test_base.py
@@ -30,7 +30,7 @@ def test_all_imports() -> None:
     "langchain_groq",
 )
 @pytest.mark.parametrize(
-    ["model_name", "model_provider"],
+    ("model_name", "model_provider"),
     [
         ("gpt-4o", "openai"),
         ("claude-3-opus-20240229", "anthropic"),
@@ -57,7 +57,7 @@ def test_init_missing_dep() -> None:
 
 
 def test_init_unknown_provider() -> None:
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Unsupported model_provider='bar'."):
         init_chat_model("foo", model_provider="bar")
diff --git a/libs/langchain/tests/unit_tests/conftest.py b/libs/langchain/tests/unit_tests/conftest.py
index 2b3977b8b47..fe0b62e766f 100644
--- a/libs/langchain/tests/unit_tests/conftest.py
+++ b/libs/langchain/tests/unit_tests/conftest.py
@@ -5,7 +5,6 @@ from importlib import util
 
 import pytest
 from blockbuster import blockbuster_ctx
-from pytest import Config, Function, Parser
 
 
 @pytest.fixture(autouse=True)
@@ -40,7 +39,7 @@ def blockbuster() -> Iterator[None]:
     yield
 
 
-def pytest_addoption(parser: Parser) -> None:
+def pytest_addoption(parser: pytest.Parser) -> None:
     """Add custom command line options to pytest."""
     parser.addoption(
         "--only-extended",
@@ -62,7 +61,9 @@ def pytest_addoption(parser: Parser) -> None:
     )
 
 
-def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
+def pytest_collection_modifyitems(
+    config: pytest.Config, items: Sequence[pytest.Function]
+) -> None:
     """Add implementations for handling custom markers.
 
     At the moment, this adds support for a custom `requires` marker.
diff --git a/libs/langchain/tests/unit_tests/embeddings/test_base.py b/libs/langchain/tests/unit_tests/embeddings/test_base.py
index 5df628ad19b..30bfaeb6777 100644
--- a/libs/langchain/tests/unit_tests/embeddings/test_base.py
+++ b/libs/langchain/tests/unit_tests/embeddings/test_base.py
@@ -88,12 +88,9 @@ def test_infer_model_and_provider_errors() -> None:
         _infer_model_and_provider("model", provider="")
 
     # Test invalid provider
-    with pytest.raises(ValueError, match="is not supported"):
+    with pytest.raises(ValueError, match="Provider 'invalid' is not supported.") as exc:
         _infer_model_and_provider("model", provider="invalid")
 
-    # Test provider list is in error
-    with pytest.raises(ValueError) as exc:
-        _infer_model_and_provider("model", provider="invalid")
     for provider in _SUPPORTED_PROVIDERS:
         assert provider in str(exc.value)
diff --git a/libs/langchain/tests/unit_tests/evaluation/comparison/test_eval_chain.py b/libs/langchain/tests/unit_tests/evaluation/comparison/test_eval_chain.py
index 583b63c1f92..cf1df2509b7 100644
--- a/libs/langchain/tests/unit_tests/evaluation/comparison/test_eval_chain.py
+++ b/libs/langchain/tests/unit_tests/evaluation/comparison/test_eval_chain.py
@@ -112,7 +112,9 @@ def test_labeled_pairwise_string_comparison_chain_missing_ref() -> None:
         sequential_responses=True,
    )
     chain = LabeledPairwiseStringEvalChain.from_llm(llm=llm)
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="LabeledPairwiseStringEvalChain requires a reference string."
+    ):
         chain.evaluate_string_pairs(
             prediction="I like pie.",
             prediction_b="I love pie.",
diff --git a/libs/langchain/tests/unit_tests/evaluation/criteria/test_eval_chain.py b/libs/langchain/tests/unit_tests/evaluation/criteria/test_eval_chain.py
index fa6605b8eed..9a8dc153bc3 100644
--- a/libs/langchain/tests/unit_tests/evaluation/criteria/test_eval_chain.py
+++ b/libs/langchain/tests/unit_tests/evaluation/criteria/test_eval_chain.py
@@ -23,7 +23,7 @@ def test_resolve_criteria_str() -> None:
 
 
 @pytest.mark.parametrize(
-    "text,want",
+    ("text", "want"),
     [
         ("Y", {"reasoning": "", "value": "Y", "score": 1}),
         (
@@ -91,7 +91,9 @@ def test_criteria_eval_chain_missing_reference() -> None:
         ),
         criteria={"my criterion": "my criterion description"},
     )
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="LabeledCriteriaEvalChain requires a reference string."
+    ):
         chain.evaluate_strings(prediction="my prediction", input="my input")
diff --git a/libs/langchain/tests/unit_tests/evaluation/qa/test_eval_chain.py b/libs/langchain/tests/unit_tests/evaluation/qa/test_eval_chain.py
index 5d9c5cda163..ad560a17f3d 100644
--- a/libs/langchain/tests/unit_tests/evaluation/qa/test_eval_chain.py
+++ b/libs/langchain/tests/unit_tests/evaluation/qa/test_eval_chain.py
@@ -91,7 +91,7 @@ def test_returns_expected_results(
 
 
 @pytest.mark.parametrize(
-    "output,expected",
+    ("output", "expected"),
     [
         (
             """ GRADE: CORRECT
diff --git a/libs/langchain/tests/unit_tests/evaluation/scoring/test_eval_chain.py b/libs/langchain/tests/unit_tests/evaluation/scoring/test_eval_chain.py
index 5d50ad5d2e7..932b708c41c 100644
--- a/libs/langchain/tests/unit_tests/evaluation/scoring/test_eval_chain.py
+++ b/libs/langchain/tests/unit_tests/evaluation/scoring/test_eval_chain.py
@@ -26,13 +26,15 @@ Rating: [[10]]"""
 
     text = """This answer is really good.
 Rating: 10"""
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Output must contain a double bracketed string"
+    ):
         output_parser.parse(text)
 
     text = """This answer is really good.
 Rating: [[0]]"""
     # Not in range [1, 10]
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="with the verdict between 1 and 10"):
         output_parser.parse(text)
@@ -69,7 +71,9 @@ def test_labeled_pairwise_string_comparison_chain_missing_ref() -> None:
         sequential_responses=True,
     )
     chain = LabeledScoreStringEvalChain.from_llm(llm=llm)
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="LabeledScoreStringEvalChain requires a reference string."
+    ):
         chain.evaluate_strings(
             prediction="I like pie.",
             input="What is your favorite food?",
diff --git a/libs/langchain/tests/unit_tests/indexes/test_indexing.py b/libs/langchain/tests/unit_tests/indexes/test_indexing.py
index 7e861ef63ab..e5954a5e374 100644
--- a/libs/langchain/tests/unit_tests/indexes/test_indexing.py
+++ b/libs/langchain/tests/unit_tests/indexes/test_indexing.py
@@ -432,11 +432,19 @@ def test_incremental_fails_with_bad_source_ids(
         ],
     )
 
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Source id key is required when cleanup mode is incremental "
+        "or scoped_full.",
+    ):
         # Should raise an error because no source id function was specified
         index(loader, record_manager, vector_store, cleanup="incremental")
 
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Source ids are required when cleanup mode is incremental "
+        "or scoped_full.",
+    ):
         # Should raise an error because no source id function was specified
         index(
             loader,
@@ -470,7 +478,11 @@ async def test_aincremental_fails_with_bad_source_ids(
         ],
     )
 
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Source id key is required when cleanup mode is incremental "
+        "or scoped_full.",
+    ):
         # Should raise an error because no source id function was specified
         await aindex(
             loader,
@@ -479,7 +491,11 @@ async def test_aincremental_fails_with_bad_source_ids(
             record_manager,
             vector_store,
             cleanup="incremental",
         )
 
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Source ids are required when cleanup mode is incremental "
+        "or scoped_full.",
+    ):
         # Should raise an error because no source id function was specified
         await aindex(
             loader,
diff --git a/libs/langchain/tests/unit_tests/memory/test_combined_memory.py b/libs/langchain/tests/unit_tests/memory/test_combined_memory.py
index 0e5df20a9bf..a8853bd919d 100644
--- a/libs/langchain/tests/unit_tests/memory/test_combined_memory.py
+++ b/libs/langchain/tests/unit_tests/memory/test_combined_memory.py
@@ -34,5 +34,9 @@ def test_basic_functionality(example_memory: list[ConversationBufferMemory]) ->
 
 def test_repeated_memory_var(example_memory: list[ConversationBufferMemory]) -> None:
     """Test raising error when repeated memory variables found"""
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Value error, The same variables {'bar'} are found in "
+        "multiplememory object, which is not allowed by CombinedMemory.",
+    ):
         CombinedMemory(memories=[example_memory[1], example_memory[2]])
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_boolean_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_boolean_parser.py
index 79eebb4ed08..3987764bf7e 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_boolean_parser.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_boolean_parser.py
@@ -31,13 +31,21 @@ def test_boolean_output_parser_parse() -> None:
     assert result is True
 
     # Test ambiguous input
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Ambiguous response. Both YES and NO in received: YES NO."
+    ):
         parser.parse("YES NO")
 
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Ambiguous response. Both YES and NO in received: NO YES."
+    ):
         parser.parse("NO YES")
 
     # Bad input
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="BooleanOutputParser expected output value to include either YES or NO. "
+        "Received BOOM.",
+    ):
         parser.parse("BOOM")
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_datetime_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_datetime_parser.py
index 4c541e5c23b..6ca691976c8 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_datetime_parser.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_datetime_parser.py
@@ -19,24 +19,20 @@ def test_datetime_output_parser_parse() -> None:
     parser.format = "%Y-%m-%dT%H:%M:%S"
     datestr = date.strftime(parser.format)
     result = parser.parse(datestr)
-    assert (
-        result.year == date.year
-        and result.month == date.month
-        and result.day == date.day
-        and result.hour == date.hour
-        and result.minute == date.minute
-        and result.second == date.second
-    )
+    assert result.year == date.year
+    assert result.month == date.month
+    assert result.day == date.day
+    assert result.hour == date.hour
+    assert result.minute == date.minute
+    assert result.second == date.second
 
     # Test valid input
     parser.format = "%H:%M:%S"
     datestr = date.strftime(parser.format)
     result = parser.parse(datestr)
-    assert (
-        result.hour == date.hour
-        and result.minute == date.minute
-        and result.second == date.second
-    )
+    assert result.hour == date.hour
+    assert result.minute == date.minute
+    assert result.second == date.second
 
     # Test invalid input
     with pytest.raises(OutputParserException):
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_fix.py b/libs/langchain/tests/unit_tests/output_parsers/test_fix.py
index aa3eff95c57..9b914bce786 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_fix.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_fix.py
@@ -148,7 +148,7 @@ def test_output_fixing_parser_output_type(
 
 
 @pytest.mark.parametrize(
-    "completion,base_parser,retry_chain,expected",
+    ("completion", "base_parser", "retry_chain", "expected"),
     [
         (
             "2024/07/08",
@@ -185,7 +185,7 @@ def test_output_fixing_parser_parse_with_retry_chain(
 
 
 @pytest.mark.parametrize(
-    "completion,base_parser,retry_chain,expected",
+    ("completion", "base_parser", "retry_chain", "expected"),
     [
         (
             "2024/07/08",
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_retry.py b/libs/langchain/tests/unit_tests/output_parsers/test_retry.py
index 5df3247abff..4dc8cbac55f 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_retry.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_retry.py
@@ -205,7 +205,7 @@ def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
 
 
 @pytest.mark.parametrize(
-    "completion,prompt,base_parser,retry_chain,expected",
+    ("completion", "prompt", "base_parser", "retry_chain", "expected"),
     [
         (
             "2024/07/08",
@@ -233,7 +233,7 @@ def test_retry_output_parser_parse_with_prompt_with_retry_chain(
 
 
 @pytest.mark.parametrize(
-    "completion,prompt,base_parser,retry_chain,expected",
+    ("completion", "prompt", "base_parser", "retry_chain", "expected"),
     [
         (
             "2024/07/08",
@@ -262,7 +262,7 @@ async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
 
 
 @pytest.mark.parametrize(
-    "completion,prompt,base_parser,retry_chain,expected",
+    ("completion", "prompt", "base_parser", "retry_chain", "expected"),
     [
         (
             "2024/07/08",
@@ -291,7 +291,7 @@ def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
 
 
 @pytest.mark.parametrize(
-    "completion,prompt,base_parser,retry_chain,expected",
+    ("completion", "prompt", "base_parser", "retry_chain", "expected"),
     [
         (
             "2024/07/08",
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py
index 5d06f8e2c90..09742525290 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py
@@ -87,14 +87,10 @@ def test_yaml_output_parser_fail() -> None:
         pydantic_object=TestModel,
     )
 
-    try:
+    with pytest.raises(OutputParserException) as exc_info:
         yaml_parser.parse(DEF_RESULT_FAIL)
-    except OutputParserException as e:
-        print("parse_result:", e)  # noqa: T201
-        assert "Failed to parse TestModel from completion" in str(e)
-    else:
-        msg = "Expected OutputParserException"
-        raise AssertionError(msg)
+
+    assert "Failed to parse TestModel from completion" in str(exc_info.value)
 
 
 def test_yaml_output_parser_output_type() -> None:
diff --git a/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py b/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py
index 1bce2775b87..c9bdea85f28 100644
--- a/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py
+++ b/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py
@@ -5,7 +5,7 @@ from langchain.retrievers.multi_query import LineListOutputParser, _unique_documents
 
 
 @pytest.mark.parametrize(
-    "documents,expected",
+    ("documents", "expected"),
     [
         ([], []),
         ([Document(page_content="foo")], [Document(page_content="foo")]),
@@ -39,7 +39,7 @@ def test__unique_documents(documents: list[Document], expected: list[Document])
 
 
 @pytest.mark.parametrize(
-    "text,expected",
+    ("text", "expected"),
     [
         ("foo\nbar\nbaz", ["foo", "bar", "baz"]),
         ("foo\nbar\nbaz\n", ["foo", "bar", "baz"]),
diff --git a/libs/langchain/tests/unit_tests/storage/test_filesystem.py b/libs/langchain/tests/unit_tests/storage/test_filesystem.py
index c7cb14b823c..49c0e7a153d 100644
--- a/libs/langchain/tests/unit_tests/storage/test_filesystem.py
+++ b/libs/langchain/tests/unit_tests/storage/test_filesystem.py
@@ -30,7 +30,7 @@ def test_mset_and_mget(file_store: LocalFileStore) -> None:
 
 
 @pytest.mark.parametrize(
-    "chmod_dir_s, chmod_file_s",
+    ("chmod_dir_s", "chmod_file_s"),
     [("777", "666"), ("770", "660"), ("700", "600")],
 )
 def test_mset_chmod(chmod_dir_s: str, chmod_file_s: str) -> None:
diff --git a/libs/langchain/tests/unit_tests/test_formatting.py b/libs/langchain/tests/unit_tests/test_formatting.py
index 5025c700d5b..ce1fe59999c 100644
--- a/libs/langchain/tests/unit_tests/test_formatting.py
+++ b/libs/langchain/tests/unit_tests/test_formatting.py
@@ -15,7 +15,11 @@ def test_valid_formatting() -> None:
 
 
 def test_does_not_allow_args() -> None:
     """Test formatting raises error when args are provided."""
     template = "This is a {} test."
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="No arguments should be provided, "
+        "everything should be passed as keyword arguments.",
+    ):
         formatter.format(template, "good")
diff --git a/libs/langchain/tests/unit_tests/test_utils.py b/libs/langchain/tests/unit_tests/test_utils.py
index d94f2f276e7..a6aeff77fd5 100644
--- a/libs/langchain/tests/unit_tests/test_utils.py
+++ b/libs/langchain/tests/unit_tests/test_utils.py
@@ -7,5 +7,7 @@ def test_check_package_version_pass() -> None:
 
 
 def test_check_package_version_fail() -> None:
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Expected PyYAML version to be < 5.4.1. Received "
+    ):
         check_package_version("PyYAML", lt_version="5.4.1")
diff --git a/libs/langchain/tests/unit_tests/utils/test_iter.py b/libs/langchain/tests/unit_tests/utils/test_iter.py
index 85b577f10d1..c707ae8f314 100644
--- a/libs/langchain/tests/unit_tests/utils/test_iter.py
+++ b/libs/langchain/tests/unit_tests/utils/test_iter.py
@@ -3,7 +3,7 @@ from langchain_core.utils.iter import batch_iterate
 
 
 @pytest.mark.parametrize(
-    "input_size, input_iterable, expected_output",
+    ("input_size", "input_iterable", "expected_output"),
     [
         (2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
         (3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),