Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-22 19:08:40 +00:00
feat(langchain): add ruff rules PT (#32010)
See https://docs.astral.sh/ruff/rules/#flake8-pytest-style-pt
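For context, a brief hypothetical before/after sketch (not code from this repository) of the pytest style the PT rules enforce and that the hunks below apply: parametrize names as a tuple of strings rather than a comma-separated string (PT006), parametrize values as a list rather than a tuple (PT007), plain asserts instead of unittest-style assertions (PT009), and pytest.raises narrowed with a match= pattern (PT011); the rule codes are taken from the ruff flake8-pytest-style docs linked above.

import pytest


# Old style flagged by the PT rules: comma-separated parametrize names (PT006),
# tuple-valued parametrize arguments (PT007), and a bare pytest.raises (PT011).
@pytest.mark.parametrize("x, y", ((1, 2), (2, 3)))
def test_old_style(x: int, y: int) -> None:
    assert y == x + 1
    with pytest.raises(ValueError):
        int("not a number")


# After auto-fixing: names as a tuple of strings, values as a list, and
# pytest.raises narrowed with a match= regex.
@pytest.mark.parametrize(("x", "y"), [(1, 2), (2, 3)])
def test_new_style(x: int, y: int) -> None:
    assert y == x + 1
    with pytest.raises(ValueError, match="invalid literal"):
        int("not a number")

Note that pytest treats match= as a regular expression searched against the error message, which is why several hunks below wrap messages containing brackets or parentheses in re.escape.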
parent 095f4a7c28
commit 8b8d90bea5
@@ -167,6 +167,7 @@ select = [
     "PGH", # pygrep-hooks
     "PIE", # flake8-pie
     "PERF", # flake8-perf
+    "PT", # flake8-pytest-style
     "PTH", # flake8-use-pathlib
     "PYI", # flake8-pyi
     "Q", # flake8-quotes
@@ -9,7 +9,7 @@ from langchain.embeddings.base import _SUPPORTED_PROVIDERS, init_embeddings


 @pytest.mark.parametrize(
-    "provider, model",
+    ("provider", "model"),
     [
         ("openai", "text-embedding-3-large"),
         ("google_vertexai", "text-embedding-gecko@003"),
@@ -1,9 +1,6 @@
-import unittest
-
 from langchain.agents.agent_types import AgentType
 from langchain.agents.types import AGENT_TO_CLASS


-class TestTypes(unittest.TestCase):
-    def test_confirm_full_coverage(self) -> None:
-        self.assertEqual(list(AgentType), list(AGENT_TO_CLASS.keys()))
+def test_confirm_full_coverage() -> None:
+    assert list(AgentType) == list(AGENT_TO_CLASS.keys())
@@ -16,7 +16,7 @@ from langchain.chains.query_constructor.parser import get_parser
 DEFAULT_PARSER = get_parser()


-@pytest.mark.parametrize("x", ("", "foo", 'foo("bar", "baz")'))
+@pytest.mark.parametrize("x", ["", "foo", 'foo("bar", "baz")'])
 def test_parse_invalid_grammar(x: str) -> None:
     with pytest.raises((ValueError, lark.exceptions.UnexpectedToken)):
         DEFAULT_PARSER.parse(x)
@@ -71,13 +71,13 @@ def test_parse_nested_operation() -> None:

 def test_parse_disallowed_comparator() -> None:
     parser = get_parser(allowed_comparators=[Comparator.EQ])
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Received disallowed comparator gt."):
         parser.parse('gt("a", 2)')


 def test_parse_disallowed_operator() -> None:
     parser = get_parser(allowed_operators=[Operator.AND])
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Received disallowed operator not."):
         parser.parse('not(gt("a", 2))')
@@ -87,29 +87,29 @@ def _test_parse_value(x: Any) -> None:
     assert actual == x


-@pytest.mark.parametrize("x", (-1, 0, 1_000_000))
+@pytest.mark.parametrize("x", [-1, 0, 1_000_000])
 def test_parse_int_value(x: int) -> None:
     _test_parse_value(x)


-@pytest.mark.parametrize("x", (-1.001, 0.00000002, 1_234_567.6543210))
+@pytest.mark.parametrize("x", [-1.001, 0.00000002, 1_234_567.6543210])
 def test_parse_float_value(x: float) -> None:
     _test_parse_value(x)


-@pytest.mark.parametrize("x", ([], [1, "b", "true"]))
+@pytest.mark.parametrize("x", [[], [1, "b", "true"]])
 def test_parse_list_value(x: list) -> None:
     _test_parse_value(x)


-@pytest.mark.parametrize("x", ('""', '" "', '"foo"', "'foo'"))
+@pytest.mark.parametrize("x", ['""', '" "', '"foo"', "'foo'"])
 def test_parse_string_value(x: str) -> None:
     parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
     actual = parsed.value
     assert actual == x[1:-1]


-@pytest.mark.parametrize("x", ("true", "True", "TRUE", "false", "False", "FALSE"))
+@pytest.mark.parametrize("x", ["true", "True", "TRUE", "false", "False", "FALSE"])
 def test_parse_bool_value(x: str) -> None:
     parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
     actual = parsed.value
@@ -117,15 +117,15 @@ def test_parse_bool_value(x: str) -> None:
     assert actual == expected


-@pytest.mark.parametrize("op", ("and", "or"))
-@pytest.mark.parametrize("arg", ('eq("foo", 2)', 'and(eq("foo", 2), lte("bar", 1.1))'))
+@pytest.mark.parametrize("op", ["and", "or"])
+@pytest.mark.parametrize("arg", ['eq("foo", 2)', 'and(eq("foo", 2), lte("bar", 1.1))'])
 def test_parser_unpack_single_arg_operation(op: str, arg: str) -> None:
     expected = DEFAULT_PARSER.parse(arg)
     actual = DEFAULT_PARSER.parse(f"{op}({arg})")
     assert expected == actual


-@pytest.mark.parametrize("x", ('"2022-10-20"', "'2022-10-20'", "2022-10-20"))
+@pytest.mark.parametrize("x", ['"2022-10-20"', "'2022-10-20'", "2022-10-20"])
 def test_parse_date_value(x: str) -> None:
     parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
     actual = parsed.value["date"]
@@ -133,7 +133,7 @@ def test_parse_date_value(x: str) -> None:


 @pytest.mark.parametrize(
-    "x, expected",
+    ("x", "expected"),
     [
         (
             '"2021-01-01T00:00:00"',
@@ -8,7 +8,7 @@ GOOD_SCORE = "foo bar answer.\nScore: 80"
 SCORE_WITH_EXPLANATION = "foo bar answer.\nScore: 80 (fully answers the question, but could provide more detail on the specific error message)"  # noqa: E501


-@pytest.mark.parametrize("answer", (GOOD_SCORE, SCORE_WITH_EXPLANATION))
+@pytest.mark.parametrize("answer", [GOOD_SCORE, SCORE_WITH_EXPLANATION])
 def test_parse_scores(answer: str) -> None:
     result = output_parser.parse(answer)
@@ -65,14 +65,14 @@ class FakeChain(Chain):
 def test_bad_inputs() -> None:
     """Test errors are raised if input keys are not found."""
     chain = FakeChain()
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Missing some input keys: {'foo'}"):
         chain({"foobar": "baz"})


 def test_bad_outputs() -> None:
     """Test errors are raised if outputs keys are not found."""
     chain = FakeChain(be_correct=False)
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Missing some output keys: {'bar'}"):
         chain({"foo": "baz"})
@@ -102,7 +102,7 @@ def test_single_input_correct() -> None:
 def test_single_input_error() -> None:
     """Test passing single input errors as expected."""
     chain = FakeChain(the_input_keys=["foo", "bar"])
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Missing some input keys:"):
         chain("bar")
@@ -116,7 +116,9 @@ def test_run_single_arg() -> None:
 def test_run_multiple_args_error() -> None:
     """Test run method with multiple args errors as expected."""
     chain = FakeChain()
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="`run` supports only one positional argument."
+    ):
         chain.run("bar", "foo")
@@ -130,21 +132,28 @@ def test_run_kwargs() -> None:
 def test_run_kwargs_error() -> None:
     """Test run method with kwargs errors as expected."""
     chain = FakeChain(the_input_keys=["foo", "bar"])
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Missing some input keys: {'bar'}"):
         chain.run(foo="bar", baz="foo")


 def test_run_args_and_kwargs_error() -> None:
     """Test run method with args and kwargs."""
     chain = FakeChain(the_input_keys=["foo", "bar"])
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="`run` supported with either positional arguments "
+        "or keyword arguments but not both.",
+    ):
         chain.run("bar", foo="bar")


 def test_multiple_output_keys_error() -> None:
     """Test run with multiple output keys errors as expected."""
     chain = FakeChain(the_output_keys=["foo", "bar"])
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="`run` not supported when there is not exactly one output key.",
+    ):
         chain.run("bar")
@@ -175,7 +184,7 @@ def test_run_with_callback_and_input_error() -> None:
         callbacks=[handler],
     )

-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Missing some input keys: {'foo'}"):
         chain({"bar": "foo"})

     assert handler.starts == 1
@@ -222,7 +231,7 @@ def test_run_with_callback_and_output_error() -> None:
         callbacks=[handler],
     )

-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Missing some output keys: {'foo'}"):
         chain("foo")

     assert handler.starts == 1
@@ -1,5 +1,6 @@
 """Test functionality related to combining documents."""

+import re
 from typing import Any

 import pytest
@@ -30,7 +31,9 @@ def test_multiple_input_keys() -> None:
 def test__split_list_long_single_doc() -> None:
     """Test splitting of a long single doc."""
     docs = [Document(page_content="foo" * 100)]
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="A single document was longer than the context length"
+    ):
         split_list_of_docs(docs, _fake_docs_len_func, 100)
@@ -140,7 +143,17 @@ async def test_format_doc_missing_metadata() -> None:
         input_variables=["page_content", "bar"],
         template="{page_content}, {bar}",
     )
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Document prompt requires documents to have metadata variables: ['bar']."
+        ),
+    ):
         format_document(doc, prompt)
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Document prompt requires documents to have metadata variables: ['bar']."
+        ),
+    ):
         await aformat_document(doc, prompt)
@@ -1,5 +1,6 @@
 """Test conversation chain and memory."""

+import re
 from typing import Any, Optional

 import pytest
@@ -76,7 +77,9 @@ def test_conversation_chain_errors_bad_prompt() -> None:
     """Test that conversation chain raise error with bad prompt."""
     llm = FakeLLM()
     prompt = PromptTemplate(input_variables=[], template="nothing here")
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Value error, Got unexpected prompt input variables."
+    ):
         ConversationChain(llm=llm, prompt=prompt)
@@ -85,7 +88,12 @@ def test_conversation_chain_errors_bad_variable() -> None:
     llm = FakeLLM()
     prompt = PromptTemplate(input_variables=["foo"], template="{foo}")
     memory = ConversationBufferMemory(memory_key="foo")
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Value error, The input key foo was also found in the memory keys (['foo'])"
+        ),
+    ):
         ConversationChain(llm=llm, prompt=prompt, memory=memory, input_key="foo")
@@ -106,18 +114,18 @@ def test_conversation_memory(memory: BaseMemory) -> None:
     memory.save_context(good_inputs, good_outputs)
     # This is a bad input because there are two variables that aren't the same as baz.
     bad_inputs = {"foo": "bar", "foo1": "bar"}
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="One input key expected"):
         memory.save_context(bad_inputs, good_outputs)
     # This is a bad input because the only variable is the same as baz.
     bad_inputs = {"baz": "bar"}
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match=re.escape("One input key expected got []")):
         memory.save_context(bad_inputs, good_outputs)
     # This is a bad output because it is empty.
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Got multiple output keys"):
         memory.save_context(good_inputs, {})
     # This is a bad output because there are two keys.
     bad_outputs = {"foo": "bar", "foo1": "bar"}
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Got multiple output keys"):
         memory.save_context(good_inputs, bad_outputs)
@@ -39,5 +39,5 @@ def test_complex_question(fake_llm_math_chain: LLMMathChain) -> None:
 @pytest.mark.requires("numexpr")
 def test_error(fake_llm_math_chain: LLMMathChain) -> None:
     """Test question that raises error."""
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="unknown format from LLM: foo"):
         fake_llm_math_chain.run("foo")
@@ -5,7 +5,7 @@ from tests.unit_tests.llms.fake_llm import FakeLLM


 @pytest.mark.parametrize(
-    "text,answer,sources",
+    ("text", "answer", "sources"),
     [
         (
             "This Agreement is governed by English law.\nSOURCES: 28-pl",
@@ -1,5 +1,6 @@
 """Test pipeline functionality."""

+import re
 from typing import Optional

 import pytest
@@ -94,7 +95,12 @@ def test_sequential_usage_memory() -> None:
     memory = SimpleMemory(memories={"zab": "rab", "foo": "rab"})
     chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
     chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Value error, The input key(s) foo are found in the Memory keys"
+        ),
+    ):
         SequentialChain(  # type: ignore[call-arg]
             memory=memory,
             chains=[chain_1, chain_2],
@@ -136,7 +142,9 @@ def test_sequential_missing_inputs() -> None:
     """Test error is raised when input variables are missing."""
     chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
     chain_2 = FakeChain(input_variables=["bar", "test"], output_variables=["baz"])
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Value error, Missing required input keys: {'test'}"
+    ):
         # Also needs "test" as an input
         SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"])  # type: ignore[call-arg]
@@ -145,7 +153,10 @@ def test_sequential_bad_outputs() -> None:
     """Test error is raised when bad outputs are specified."""
     chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
     chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Value error, Expected output variables that were not found: {'test'}.",
+    ):
         # "test" is not present as an output variable.
         SequentialChain(
             chains=[chain_1, chain_2],
@@ -172,7 +183,9 @@ def test_sequential_overlapping_inputs() -> None:
     """Test error is raised when input variables are overlapping."""
     chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])
     chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Value error, Chain returned keys that already exist"
+    ):
         # "test" is specified as an input, but also is an output of one step
         SequentialChain(chains=[chain_1, chain_2], input_variables=["foo", "test"])  # type: ignore[call-arg]
@@ -226,7 +239,10 @@ def test_multi_input_errors() -> None:
     """Test simple sequential errors if multiple input variables are expected."""
     chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
     chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Value error, Chains used in SimplePipeline should all have one input",
+    ):
         SimpleSequentialChain(chains=[chain_1, chain_2])
@@ -234,5 +250,8 @@ def test_multi_output_errors() -> None:
     """Test simple sequential errors if multiple output variables are expected."""
     chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "grok"])
     chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Value error, Chains used in SimplePipeline should all have one output",
+    ):
         SimpleSequentialChain(chains=[chain_1, chain_2])
@@ -35,5 +35,5 @@ def test_transform_chain_bad_inputs() -> None:
         transform=dummy_transform,
     )
     input_dict = {"name": "Leroy", "last_name": "Jenkins"}
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Missing some input keys: {'first_name'}"):
         _ = transform_chain(input_dict)
@@ -30,7 +30,7 @@ def test_all_imports() -> None:
     "langchain_groq",
 )
 @pytest.mark.parametrize(
-    ["model_name", "model_provider"],
+    ("model_name", "model_provider"),
     [
         ("gpt-4o", "openai"),
         ("claude-3-opus-20240229", "anthropic"),
@@ -57,7 +57,7 @@ def test_init_missing_dep() -> None:


 def test_init_unknown_provider() -> None:
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Unsupported model_provider='bar'."):
         init_chat_model("foo", model_provider="bar")
@@ -5,7 +5,6 @@ from importlib import util

 import pytest
 from blockbuster import blockbuster_ctx
-from pytest import Config, Function, Parser


 @pytest.fixture(autouse=True)
@@ -40,7 +39,7 @@ def blockbuster() -> Iterator[None]:
     yield


-def pytest_addoption(parser: Parser) -> None:
+def pytest_addoption(parser: pytest.Parser) -> None:
     """Add custom command line options to pytest."""
     parser.addoption(
         "--only-extended",
@@ -62,7 +61,9 @@ def pytest_addoption(parser: Parser) -> None:
     )


-def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
+def pytest_collection_modifyitems(
+    config: pytest.Config, items: Sequence[pytest.Function]
+) -> None:
     """Add implementations for handling custom markers.

     At the moment, this adds support for a custom `requires` marker.
@@ -88,12 +88,9 @@ def test_infer_model_and_provider_errors() -> None:
         _infer_model_and_provider("model", provider="")

     # Test invalid provider
-    with pytest.raises(ValueError, match="is not supported"):
+    with pytest.raises(ValueError, match="Provider 'invalid' is not supported.") as exc:
         _infer_model_and_provider("model", provider="invalid")

-    # Test provider list is in error
-    with pytest.raises(ValueError) as exc:
-        _infer_model_and_provider("model", provider="invalid")
     for provider in _SUPPORTED_PROVIDERS:
         assert provider in str(exc.value)
@@ -112,7 +112,9 @@ def test_labeled_pairwise_string_comparison_chain_missing_ref() -> None:
         sequential_responses=True,
     )
     chain = LabeledPairwiseStringEvalChain.from_llm(llm=llm)
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="LabeledPairwiseStringEvalChain requires a reference string."
+    ):
         chain.evaluate_string_pairs(
             prediction="I like pie.",
             prediction_b="I love pie.",
@@ -23,7 +23,7 @@ def test_resolve_criteria_str() -> None:


 @pytest.mark.parametrize(
-    "text,want",
+    ("text", "want"),
     [
         ("Y", {"reasoning": "", "value": "Y", "score": 1}),
         (
@@ -91,7 +91,9 @@ def test_criteria_eval_chain_missing_reference() -> None:
         ),
         criteria={"my criterion": "my criterion description"},
     )
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="LabeledCriteriaEvalChain requires a reference string."
+    ):
         chain.evaluate_strings(prediction="my prediction", input="my input")
@@ -91,7 +91,7 @@ def test_returns_expected_results(


 @pytest.mark.parametrize(
-    "output,expected",
+    ("output", "expected"),
    [
         (
             """ GRADE: CORRECT
@@ -26,13 +26,15 @@ Rating: [[10]]"""

     text = """This answer is really good.
 Rating: 10"""
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Output must contain a double bracketed string"
+    ):
         output_parser.parse(text)

     text = """This answer is really good.
 Rating: [[0]]"""
     # Not in range [1, 10]
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="with the verdict between 1 and 10"):
         output_parser.parse(text)
@@ -69,7 +71,9 @@ def test_labeled_pairwise_string_comparison_chain_missing_ref() -> None:
         sequential_responses=True,
     )
     chain = LabeledScoreStringEvalChain.from_llm(llm=llm)
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="LabeledScoreStringEvalChain requires a reference string."
+    ):
         chain.evaluate_strings(
             prediction="I like pie.",
             input="What is your favorite food?",
@@ -432,11 +432,19 @@ def test_incremental_fails_with_bad_source_ids(
         ],
     )

-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Source id key is required when cleanup mode is incremental "
+        "or scoped_full.",
+    ):
         # Should raise an error because no source id function was specified
         index(loader, record_manager, vector_store, cleanup="incremental")

-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Source ids are required when cleanup mode is incremental "
+        "or scoped_full.",
+    ):
         # Should raise an error because no source id function was specified
         index(
             loader,
@@ -470,7 +478,11 @@ async def test_aincremental_fails_with_bad_source_ids(
         ],
     )

-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Source id key is required when cleanup mode is incremental "
+        "or scoped_full.",
+    ):
         # Should raise an error because no source id function was specified
         await aindex(
             loader,
@@ -479,7 +491,11 @@ async def test_aincremental_fails_with_bad_source_ids(
             cleanup="incremental",
         )

-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Source ids are required when cleanup mode is incremental "
+        "or scoped_full.",
+    ):
         # Should raise an error because no source id function was specified
         await aindex(
             loader,
@@ -34,5 +34,9 @@ def test_basic_functionality(example_memory: list[ConversationBufferMemory]) ->

 def test_repeated_memory_var(example_memory: list[ConversationBufferMemory]) -> None:
     """Test raising error when repeated memory variables found"""
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Value error, The same variables {'bar'} are found in "
+        "multiplememory object, which is not allowed by CombinedMemory.",
+    ):
         CombinedMemory(memories=[example_memory[1], example_memory[2]])
@@ -31,13 +31,21 @@ def test_boolean_output_parser_parse() -> None:
     assert result is True

     # Test ambiguous input
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Ambiguous response. Both YES and NO in received: YES NO."
+    ):
         parser.parse("YES NO")

-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Ambiguous response. Both YES and NO in received: NO YES."
+    ):
         parser.parse("NO YES")
     # Bad input
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="BooleanOutputParser expected output value to include either YES or NO. "
+        "Received BOOM.",
+    ):
         parser.parse("BOOM")
@@ -19,24 +19,20 @@ def test_datetime_output_parser_parse() -> None:
     parser.format = "%Y-%m-%dT%H:%M:%S"
     datestr = date.strftime(parser.format)
     result = parser.parse(datestr)
-    assert (
-        result.year == date.year
-        and result.month == date.month
-        and result.day == date.day
-        and result.hour == date.hour
-        and result.minute == date.minute
-        and result.second == date.second
-    )
+    assert result.year == date.year
+    assert result.month == date.month
+    assert result.day == date.day
+    assert result.hour == date.hour
+    assert result.minute == date.minute
+    assert result.second == date.second

     # Test valid input
     parser.format = "%H:%M:%S"
     datestr = date.strftime(parser.format)
     result = parser.parse(datestr)
-    assert (
-        result.hour == date.hour
-        and result.minute == date.minute
-        and result.second == date.second
-    )
+    assert result.hour == date.hour
+    assert result.minute == date.minute
+    assert result.second == date.second

     # Test invalid input
     with pytest.raises(OutputParserException):
@@ -148,7 +148,7 @@ def test_output_fixing_parser_output_type(


 @pytest.mark.parametrize(
-    "completion,base_parser,retry_chain,expected",
+    ("completion", "base_parser", "retry_chain", "expected"),
     [
         (
             "2024/07/08",
@@ -185,7 +185,7 @@ def test_output_fixing_parser_parse_with_retry_chain(


 @pytest.mark.parametrize(
-    "completion,base_parser,retry_chain,expected",
+    ("completion", "base_parser", "retry_chain", "expected"),
     [
         (
             "2024/07/08",
@@ -205,7 +205,7 @@ def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:


 @pytest.mark.parametrize(
-    "completion,prompt,base_parser,retry_chain,expected",
+    ("completion", "prompt", "base_parser", "retry_chain", "expected"),
     [
         (
             "2024/07/08",
@@ -233,7 +233,7 @@ def test_retry_output_parser_parse_with_prompt_with_retry_chain(


 @pytest.mark.parametrize(
-    "completion,prompt,base_parser,retry_chain,expected",
+    ("completion", "prompt", "base_parser", "retry_chain", "expected"),
     [
         (
             "2024/07/08",
@@ -262,7 +262,7 @@ async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(


 @pytest.mark.parametrize(
-    "completion,prompt,base_parser,retry_chain,expected",
+    ("completion", "prompt", "base_parser", "retry_chain", "expected"),
     [
         (
             "2024/07/08",
@@ -291,7 +291,7 @@ def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(


 @pytest.mark.parametrize(
-    "completion,prompt,base_parser,retry_chain,expected",
+    ("completion", "prompt", "base_parser", "retry_chain", "expected"),
     [
         (
             "2024/07/08",
@@ -87,14 +87,10 @@ def test_yaml_output_parser_fail() -> None:
         pydantic_object=TestModel,
     )

-    try:
+    with pytest.raises(OutputParserException) as exc_info:
         yaml_parser.parse(DEF_RESULT_FAIL)
-    except OutputParserException as e:
-        print("parse_result:", e)  # noqa: T201
-        assert "Failed to parse TestModel from completion" in str(e)
-    else:
-        msg = "Expected OutputParserException"
-        raise AssertionError(msg)
+
+    assert "Failed to parse TestModel from completion" in str(exc_info.value)


 def test_yaml_output_parser_output_type() -> None:
@@ -5,7 +5,7 @@ from langchain.retrievers.multi_query import LineListOutputParser, _unique_docum


 @pytest.mark.parametrize(
-    "documents,expected",
+    ("documents", "expected"),
     [
         ([], []),
         ([Document(page_content="foo")], [Document(page_content="foo")]),
@@ -39,7 +39,7 @@ def test__unique_documents(documents: list[Document], expected: list[Document])


 @pytest.mark.parametrize(
-    "text,expected",
+    ("text", "expected"),
     [
         ("foo\nbar\nbaz", ["foo", "bar", "baz"]),
         ("foo\nbar\nbaz\n", ["foo", "bar", "baz"]),
@@ -30,7 +30,7 @@ def test_mset_and_mget(file_store: LocalFileStore) -> None:


 @pytest.mark.parametrize(
-    "chmod_dir_s, chmod_file_s",
+    ("chmod_dir_s", "chmod_file_s"),
     [("777", "666"), ("770", "660"), ("700", "600")],
 )
 def test_mset_chmod(chmod_dir_s: str, chmod_file_s: str) -> None:
@@ -15,7 +15,11 @@ def test_valid_formatting() -> None:
 def test_does_not_allow_args() -> None:
     """Test formatting raises error when args are provided."""
     template = "This is a {} test."
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="No arguments should be provided, "
+        "everything should be passed as keyword arguments.",
+    ):
         formatter.format(template, "good")
@@ -7,5 +7,7 @@ def test_check_package_version_pass() -> None:


 def test_check_package_version_fail() -> None:
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Expected PyYAML version to be < 5.4.1. Received "
+    ):
         check_package_version("PyYAML", lt_version="5.4.1")
@@ -3,7 +3,7 @@ from langchain_core.utils.iter import batch_iterate


 @pytest.mark.parametrize(
-    "input_size, input_iterable, expected_output",
+    ("input_size", "input_iterable", "expected_output"),
     [
         (2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
         (3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),