mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 17:38:36 +00:00
core: Add ruff rules PT (pytest) (#29381)
See https://docs.astral.sh/ruff/rules/#flake8-pytest-style-pt
This commit is contained in:
parent
6896c863e8
commit
8a33402016
@ -103,7 +103,6 @@ ignore = [
|
||||
"PLC",
|
||||
"PLE",
|
||||
"PLR",
|
||||
"PT",
|
||||
"PYI",
|
||||
"RET",
|
||||
"RUF",
|
||||
|
@ -9,7 +9,7 @@ from langchain_core._api.beta_decorator import beta, warn_beta
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kwargs, expected_message",
|
||||
("kwargs", "expected_message"),
|
||||
[
|
||||
(
|
||||
{
|
||||
|
@ -13,7 +13,7 @@ from langchain_core._api.deprecation import (
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kwargs, expected_message",
|
||||
("kwargs", "expected_message"),
|
||||
[
|
||||
(
|
||||
{
|
||||
@ -404,7 +404,9 @@ def test_deprecated_method_pydantic() -> None:
|
||||
def test_raise_error_for_bad_decorator() -> None:
|
||||
"""Verify that errors raised on init rather than on use."""
|
||||
# Should not specify both `alternative` and `alternative_import`
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError, match="Cannot specify both alternative and alternative_import"
|
||||
):
|
||||
|
||||
@deprecated(since="2.0.0", alternative="NewClass", alternative_import="hello")
|
||||
def deprecated_function() -> str:
|
||||
|
@ -28,7 +28,7 @@ def test_initialization() -> None:
|
||||
assert cache_with_maxsize._cache == {}
|
||||
assert cache_with_maxsize._maxsize == 2
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="maxsize must be greater than 0"):
|
||||
InMemoryCache(maxsize=0)
|
||||
|
||||
|
||||
|
@ -6,7 +6,6 @@ from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from blockbuster import BlockBuster, blockbuster_ctx
|
||||
from pytest import Config, Function, Parser
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
|
||||
@ -36,7 +35,7 @@ def blockbuster() -> Iterator[BlockBuster]:
|
||||
yield bb
|
||||
|
||||
|
||||
def pytest_addoption(parser: Parser) -> None:
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
"""Add custom command line options to pytest."""
|
||||
parser.addoption(
|
||||
"--only-extended",
|
||||
@ -50,7 +49,9 @@ def pytest_addoption(parser: Parser) -> None:
|
||||
)
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
|
||||
def pytest_collection_modifyitems(
|
||||
config: pytest.Config, items: Sequence[pytest.Function]
|
||||
) -> None:
|
||||
"""Add implementations for handling custom markers.
|
||||
|
||||
At the moment, this adds support for a custom `requires` marker.
|
||||
@ -118,7 +119,7 @@ def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) ->
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def deterministic_uuids(mocker: MockerFixture) -> MockerFixture:
|
||||
side_effect = (
|
||||
UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
|
||||
|
@ -16,16 +16,16 @@ from langchain_core.indexing.in_memory import (
|
||||
|
||||
|
||||
class TestDocumentIndexerTestSuite(DocumentIndexerTestSuite):
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def index(self) -> Generator[DocumentIndex, None, None]:
|
||||
yield InMemoryDocumentIndex()
|
||||
yield InMemoryDocumentIndex() # noqa: PT022
|
||||
|
||||
|
||||
class TestAsyncDocumentIndexerTestSuite(AsyncDocumentIndexTestSuite):
|
||||
# Something funky is going on with mypy and async pytest fixture
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
async def index(self) -> AsyncGenerator[DocumentIndex, None]: # type: ignore
|
||||
yield InMemoryDocumentIndex()
|
||||
yield InMemoryDocumentIndex() # noqa: PT022
|
||||
|
||||
|
||||
def test_sync_retriever() -> None:
|
||||
|
@ -7,7 +7,7 @@ import pytest_asyncio
|
||||
from langchain_core.indexing import InMemoryRecordManager
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def manager() -> InMemoryRecordManager:
|
||||
"""Initialize the test database and yield the TimestampedSet instance."""
|
||||
# Initialize and yield the TimestampedSet instance
|
||||
|
@ -466,11 +466,19 @@ def test_incremental_fails_with_bad_source_ids(
|
||||
]
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Source id key is required when cleanup mode is "
|
||||
"incremental or scoped_full.",
|
||||
):
|
||||
# Should raise an error because no source id function was specified
|
||||
index(loader, record_manager, vector_store, cleanup="incremental")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Source ids are required when cleanup mode "
|
||||
"is incremental or scoped_full.",
|
||||
):
|
||||
# Should raise an error because no source id function was specified
|
||||
index(
|
||||
loader,
|
||||
@ -502,7 +510,11 @@ async def test_aincremental_fails_with_bad_source_ids(
|
||||
]
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Source id key is required when cleanup mode "
|
||||
"is incremental or scoped_full.",
|
||||
):
|
||||
# Should raise an error because no source id function was specified
|
||||
await aindex(
|
||||
loader,
|
||||
@ -511,7 +523,11 @@ async def test_aincremental_fails_with_bad_source_ids(
|
||||
cleanup="incremental",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Source ids are required when cleanup mode "
|
||||
"is incremental or scoped_full.",
|
||||
):
|
||||
# Should raise an error because no source id function was specified
|
||||
await aindex(
|
||||
loader,
|
||||
@ -771,11 +787,19 @@ def test_scoped_full_fails_with_bad_source_ids(
|
||||
]
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Source id key is required when cleanup mode "
|
||||
"is incremental or scoped_full.",
|
||||
):
|
||||
# Should raise an error because no source id function was specified
|
||||
index(loader, record_manager, vector_store, cleanup="scoped_full")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Source ids are required when cleanup mode "
|
||||
"is incremental or scoped_full.",
|
||||
):
|
||||
# Should raise an error because no source id function was specified
|
||||
index(
|
||||
loader,
|
||||
@ -807,11 +831,19 @@ async def test_ascoped_full_fails_with_bad_source_ids(
|
||||
]
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Source id key is required when cleanup mode "
|
||||
"is incremental or scoped_full.",
|
||||
):
|
||||
# Should raise an error because no source id function was specified
|
||||
await aindex(loader, arecord_manager, vector_store, cleanup="scoped_full")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Source ids are required when cleanup mode "
|
||||
"is incremental or scoped_full.",
|
||||
):
|
||||
# Should raise an error because no source id function was specified
|
||||
await aindex(
|
||||
loader,
|
||||
|
@ -99,6 +99,7 @@ async def test_async_batch_size(messages: list, messages_2: list) -> None:
|
||||
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="This test is failing due to a bug in the testing code")
|
||||
async def test_stream_error_callback() -> None:
|
||||
message = "test"
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
import typing
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Optional, Union
|
||||
@ -513,7 +514,14 @@ def test_trim_messages_bound_model_token_counter() -> None:
|
||||
|
||||
def test_trim_messages_bad_token_counter() -> None:
|
||||
trimmer = trim_messages(max_tokens=10, token_counter={})
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape(
|
||||
"'token_counter' expected to be a model that implements "
|
||||
"'get_num_tokens_from_messages()' or a function. "
|
||||
"Received object of type <class 'dict'>."
|
||||
),
|
||||
):
|
||||
trimmer.invoke([HumanMessage("foobar")])
|
||||
|
||||
|
||||
@ -852,7 +860,9 @@ def test_convert_to_messages_openai_refusal() -> None:
|
||||
assert actual == expected
|
||||
|
||||
# Raises error if content is missing.
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError, match="Message dict must contain 'role' and 'content' keys"
|
||||
):
|
||||
convert_to_messages([{"role": "assistant", "refusal": "9.1"}])
|
||||
|
||||
|
||||
|
@ -157,9 +157,10 @@ def test_pydantic_output_parser_fail() -> None:
|
||||
pydantic_object=TestModel
|
||||
)
|
||||
|
||||
with pytest.raises(OutputParserException) as e:
|
||||
with pytest.raises(
|
||||
OutputParserException, match="Failed to parse TestModel from completion"
|
||||
):
|
||||
pydantic_parser.parse(DEF_RESULT_FAIL)
|
||||
assert "Failed to parse TestModel from completion" in str(e)
|
||||
|
||||
|
||||
def test_pydantic_output_parser_type_inference() -> None:
|
||||
|
@ -1,3 +1,4 @@
|
||||
import re
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Union, cast
|
||||
@ -165,15 +166,14 @@ def test_create_system_message_prompt_list_template_partial_variables_not_null()
|
||||
{variables}
|
||||
"""
|
||||
|
||||
try:
|
||||
graph_analyst_template = SystemMessagePromptTemplate.from_template(
|
||||
with pytest.raises(
|
||||
ValueError, match="Partial variables are not supported for list of templates."
|
||||
):
|
||||
_ = SystemMessagePromptTemplate.from_template(
|
||||
template=[graph_creator_content1, graph_creator_content2],
|
||||
input_variables=["variables"],
|
||||
partial_variables={"variables": "foo"},
|
||||
)
|
||||
graph_analyst_template.format(variables="foo")
|
||||
except ValueError as e:
|
||||
assert str(e) == "Partial variables are not supported for list of templates."
|
||||
|
||||
|
||||
def test_message_prompt_template_from_template_file() -> None:
|
||||
@ -330,7 +330,7 @@ def test_chat_prompt_template_from_messages_jinja2() -> None:
|
||||
|
||||
@pytest.mark.requires("jinja2")
|
||||
@pytest.mark.parametrize(
|
||||
"template_format,image_type_placeholder,image_data_placeholder",
|
||||
("template_format", "image_type_placeholder", "image_data_placeholder"),
|
||||
[
|
||||
("f-string", "{image_type}", "{image_data}"),
|
||||
("mustache", "{{image_type}}", "{{image_data}}"),
|
||||
@ -393,7 +393,12 @@ def test_chat_prompt_template_with_messages(
|
||||
|
||||
def test_chat_invalid_input_variables_extra() -> None:
|
||||
messages = [HumanMessage(content="foo")]
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape(
|
||||
"Got mismatched input_variables. Expected: set(). Got: ['foo']"
|
||||
),
|
||||
):
|
||||
ChatPromptTemplate(
|
||||
messages=messages, # type: ignore[arg-type]
|
||||
input_variables=["foo"],
|
||||
@ -407,7 +412,10 @@ def test_chat_invalid_input_variables_extra() -> None:
|
||||
|
||||
def test_chat_invalid_input_variables_missing() -> None:
|
||||
messages = [HumanMessagePromptTemplate.from_template("{foo}")]
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape("Got mismatched input_variables. Expected: {'foo'}. Got: []"),
|
||||
):
|
||||
ChatPromptTemplate(
|
||||
messages=messages, # type: ignore[arg-type]
|
||||
input_variables=[],
|
||||
@ -481,7 +489,7 @@ async def test_chat_from_role_strings() -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"args,expected",
|
||||
("args", "expected"),
|
||||
[
|
||||
(
|
||||
("human", "{question}"),
|
||||
@ -551,7 +559,7 @@ def test_chat_prompt_template_append_and_extend() -> None:
|
||||
|
||||
def test_convert_to_message_is_strict() -> None:
|
||||
"""Verify that _convert_to_message is strict."""
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="Unexpected message type: meow."):
|
||||
# meow does not correspond to a valid message type.
|
||||
# this test is here to ensure that functionality to interpret `meow`
|
||||
# as a role is NOT added.
|
||||
@ -762,14 +770,20 @@ async def test_chat_tmpl_from_messages_multipart_formatting_with_path() -> None:
|
||||
),
|
||||
]
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Loading images from 'path' has been removed as of 0.3.15 for security reasons.",
|
||||
):
|
||||
template.format_messages(
|
||||
name="R2D2",
|
||||
in_mem=in_mem,
|
||||
file_path="some/path",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Loading images from 'path' has been removed as of 0.3.15 for security reasons.",
|
||||
):
|
||||
await template.aformat_messages(
|
||||
name="R2D2",
|
||||
in_mem=in_mem,
|
||||
@ -869,10 +883,10 @@ def test_chat_prompt_message_dict() -> None:
|
||||
HumanMessage(content="bar"),
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="Invalid template: False"):
|
||||
ChatPromptTemplate([{"role": "system", "content": False}])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="Unexpected message type: foo."):
|
||||
ChatPromptTemplate([{"role": "foo", "content": "foo"}])
|
||||
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Test few shot prompt template."""
|
||||
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
@ -24,7 +25,7 @@ EXAMPLE_PROMPT = PromptTemplate(
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
@pytest.mark.requires("jinja2")
|
||||
def example_jinja2_prompt() -> tuple[PromptTemplate, list[dict[str, str]]]:
|
||||
example_template = "{{ word }}: {{ antonym }}"
|
||||
@ -74,7 +75,10 @@ def test_prompt_missing_input_variables() -> None:
|
||||
"""Test error is raised when input variables are not provided."""
|
||||
# Test when missing in suffix
|
||||
template = "This is a {foo} test."
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape("check for mismatched or missing input parameters from []"),
|
||||
):
|
||||
FewShotPromptTemplate(
|
||||
input_variables=[],
|
||||
suffix=template,
|
||||
@ -91,7 +95,10 @@ def test_prompt_missing_input_variables() -> None:
|
||||
|
||||
# Test when missing in prefix
|
||||
template = "This is a {foo} test."
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape("check for mismatched or missing input parameters from []"),
|
||||
):
|
||||
FewShotPromptTemplate(
|
||||
input_variables=[],
|
||||
suffix="foo",
|
||||
|
@ -1,5 +1,7 @@
|
||||
"""Test few shot prompt template."""
|
||||
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates
|
||||
@ -58,7 +60,10 @@ def test_prompttemplate_validation() -> None:
|
||||
{"question": "foo", "answer": "bar"},
|
||||
{"question": "baz", "answer": "foo"},
|
||||
]
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape("Got input_variables=[], but based on prefix/suffix expected"),
|
||||
):
|
||||
FewShotPromptWithTemplates(
|
||||
suffix=suffix,
|
||||
prefix=prefix,
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Test functionality related to prompts."""
|
||||
|
||||
import re
|
||||
from typing import Any, Union
|
||||
from unittest import mock
|
||||
|
||||
@ -264,7 +265,10 @@ def test_prompt_missing_input_variables() -> None:
|
||||
"""Test error is raised when input variables are not provided."""
|
||||
template = "This is a {foo} test."
|
||||
input_variables: list = []
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape("check for mismatched or missing input parameters from []"),
|
||||
):
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, validate_template=True
|
||||
)
|
||||
@ -275,7 +279,10 @@ def test_prompt_missing_input_variables() -> None:
|
||||
|
||||
def test_prompt_empty_input_variable() -> None:
|
||||
"""Test error is raised when empty string input variable."""
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape("check for mismatched or missing input parameters from ['']"),
|
||||
):
|
||||
PromptTemplate(input_variables=[""], template="{}", validate_template=True)
|
||||
|
||||
|
||||
@ -283,7 +290,13 @@ def test_prompt_wrong_input_variables() -> None:
|
||||
"""Test error is raised when name of input variable is wrong."""
|
||||
template = "This is a {foo} test."
|
||||
input_variables = ["bar"]
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape(
|
||||
"Invalid prompt schema; "
|
||||
"check for mismatched or missing input parameters from ['bar']"
|
||||
),
|
||||
):
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, validate_template=True
|
||||
)
|
||||
@ -330,7 +343,7 @@ def test_prompt_invalid_template_format() -> None:
|
||||
"""Test initializing a prompt with invalid template format."""
|
||||
template = "This is a {foo} test."
|
||||
input_variables = ["foo"]
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="Unsupported template format: bar"):
|
||||
PromptTemplate(
|
||||
input_variables=input_variables,
|
||||
template=template,
|
||||
@ -580,7 +593,7 @@ async def test_prompt_ainvoke_with_metadata() -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value, expected",
|
||||
("value", "expected"),
|
||||
[
|
||||
("0", "0"),
|
||||
(0, "0"),
|
||||
|
@ -330,7 +330,7 @@ test_cases = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("runnable, cases", test_cases)
|
||||
@pytest.mark.parametrize(("runnable", "cases"), test_cases)
|
||||
def test_context_runnables(
|
||||
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
|
||||
) -> None:
|
||||
@ -342,7 +342,7 @@ def test_context_runnables(
|
||||
assert add(runnable.stream(cases[0].input)) == cases[0].output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("runnable, cases", test_cases)
|
||||
@pytest.mark.parametrize(("runnable", "cases"), test_cases)
|
||||
async def test_context_runnables_async(
|
||||
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
|
||||
) -> None:
|
||||
@ -357,14 +357,19 @@ async def test_context_runnables_async(
|
||||
def test_runnable_context_seq_key_not_found() -> None:
|
||||
seq: Runnable = {"bar": Context.setter("input")} | Context.getter("foo")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError, match="Expected exactly one setter for context key foo"
|
||||
):
|
||||
seq.invoke("foo")
|
||||
|
||||
|
||||
def test_runnable_context_seq_key_order() -> None:
|
||||
seq: Runnable = {"bar": Context.getter("foo")} | Context.setter("foo")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Context setter for key foo must be defined after all getters.",
|
||||
):
|
||||
seq.invoke("foo")
|
||||
|
||||
|
||||
@ -374,7 +379,9 @@ def test_runnable_context_deadlock() -> None:
|
||||
"foo": Context.setter("foo") | Context.getter("input"),
|
||||
} | RunnablePassthrough()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError, match="Deadlock detected between context keys foo and input"
|
||||
):
|
||||
seq.invoke("foo")
|
||||
|
||||
|
||||
@ -383,7 +390,9 @@ def test_runnable_context_seq_key_circular_ref() -> None:
|
||||
"bar": Context.setter(input=Context.getter("input"))
|
||||
} | Context.getter("foo")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError, match="Circular reference in context setter for key input"
|
||||
):
|
||||
seq.invoke("foo")
|
||||
|
||||
|
||||
|
@ -32,7 +32,7 @@ from langchain_core.runnables import (
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def llm() -> RunnableWithFallbacks:
|
||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||
pass_llm = FakeListLLM(responses=["bar"])
|
||||
@ -40,7 +40,7 @@ def llm() -> RunnableWithFallbacks:
|
||||
return error_llm.with_fallbacks([pass_llm])
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def llm_multi() -> RunnableWithFallbacks:
|
||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||
error_llm_2 = FakeListLLM(responses=["baz"], i=1)
|
||||
@ -49,7 +49,7 @@ def llm_multi() -> RunnableWithFallbacks:
|
||||
return error_llm.with_fallbacks([error_llm_2, pass_llm])
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def chain() -> Runnable:
|
||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||
pass_llm = FakeListLLM(responses=["bar"])
|
||||
@ -70,7 +70,7 @@ def _dont_raise_error(inputs: dict) -> str:
|
||||
raise ValueError
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def chain_pass_exceptions() -> Runnable:
|
||||
fallback = RunnableLambda(_dont_raise_error)
|
||||
return {"text": RunnablePassthrough()} | RunnableLambda(
|
||||
@ -107,7 +107,8 @@ def _runnable(inputs: dict) -> str:
|
||||
if inputs["text"] == "foo":
|
||||
return "first"
|
||||
if "exception" not in inputs:
|
||||
raise ValueError
|
||||
msg = "missing exception"
|
||||
raise ValueError(msg)
|
||||
if inputs["text"] == "bar":
|
||||
return "second"
|
||||
if isinstance(inputs["exception"], ValueError):
|
||||
@ -128,7 +129,7 @@ def test_invoke_with_exception_key() -> None:
|
||||
runnable_with_single = runnable.with_fallbacks(
|
||||
[runnable], exception_key="exception"
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="missing exception"):
|
||||
runnable_with_single.invoke({"text": "baz"})
|
||||
|
||||
actual = runnable_with_single.invoke({"text": "bar"})
|
||||
@ -149,7 +150,7 @@ async def test_ainvoke_with_exception_key() -> None:
|
||||
runnable_with_single = runnable.with_fallbacks(
|
||||
[runnable], exception_key="exception"
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="missing exception"):
|
||||
await runnable_with_single.ainvoke({"text": "baz"})
|
||||
|
||||
actual = await runnable_with_single.ainvoke({"text": "bar"})
|
||||
@ -166,7 +167,7 @@ async def test_ainvoke_with_exception_key() -> None:
|
||||
|
||||
def test_batch() -> None:
|
||||
runnable = RunnableLambda(_runnable)
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="missing exception"):
|
||||
runnable.batch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
|
||||
actual = runnable.batch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
@ -210,7 +211,7 @@ def test_batch() -> None:
|
||||
|
||||
async def test_abatch() -> None:
|
||||
runnable = RunnableLambda(_runnable)
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="missing exception"):
|
||||
await runnable.abatch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
|
||||
actual = await runnable.abatch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
@ -263,13 +264,15 @@ def _generate(input: Iterator) -> Iterator[str]:
|
||||
|
||||
|
||||
def _generate_immediate_error(input: Iterator) -> Iterator[str]:
|
||||
raise ValueError
|
||||
msg = "immmediate error"
|
||||
raise ValueError(msg)
|
||||
yield ""
|
||||
|
||||
|
||||
def _generate_delayed_error(input: Iterator) -> Iterator[str]:
|
||||
yield ""
|
||||
raise ValueError
|
||||
msg = "delayed error"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def test_fallbacks_stream() -> None:
|
||||
@ -278,10 +281,10 @@ def test_fallbacks_stream() -> None:
|
||||
)
|
||||
assert list(runnable.stream({})) == list("foo bar")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runnable = RunnableGenerator(_generate_delayed_error).with_fallbacks(
|
||||
[RunnableGenerator(_generate)]
|
||||
)
|
||||
with pytest.raises(ValueError, match="delayed error"):
|
||||
list(runnable.stream({}))
|
||||
|
||||
|
||||
@ -291,13 +294,15 @@ async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]:
|
||||
|
||||
|
||||
async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
|
||||
raise ValueError
|
||||
msg = "immmediate error"
|
||||
raise ValueError(msg)
|
||||
yield ""
|
||||
|
||||
|
||||
async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
|
||||
yield ""
|
||||
raise ValueError
|
||||
msg = "delayed error"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
async def test_fallbacks_astream() -> None:
|
||||
@ -308,12 +313,11 @@ async def test_fallbacks_astream() -> None:
|
||||
async for c in runnable.astream({}):
|
||||
assert c == next(expected)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runnable = RunnableGenerator(_agenerate_delayed_error).with_fallbacks(
|
||||
[RunnableGenerator(_agenerate)]
|
||||
)
|
||||
async for _ in runnable.astream({}):
|
||||
pass
|
||||
with pytest.raises(ValueError, match="delayed error"):
|
||||
_ = [_ async for _ in runnable.astream({})]
|
||||
|
||||
|
||||
class FakeStructuredOutputModel(BaseChatModel):
|
||||
|
@ -1,3 +1,4 @@
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
@ -864,22 +865,24 @@ def test_get_output_messages_with_value_error() -> None:
|
||||
"configurable": {"session_id": "1", "message_history": get_session_history("1")}
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
with_history.bound.invoke([HumanMessage(content="hello")], config)
|
||||
excepted = (
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape(
|
||||
"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]."
|
||||
f" Got {illegal_bool_message}."
|
||||
)
|
||||
assert excepted in str(excinfo.value)
|
||||
),
|
||||
):
|
||||
with_history.bound.invoke([HumanMessage(content="hello")], config)
|
||||
|
||||
illegal_int_message = 123
|
||||
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_int_message)
|
||||
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
with_history.bound.invoke([HumanMessage(content="hello")], config)
|
||||
excepted = (
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape(
|
||||
"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]."
|
||||
f" Got {illegal_int_message}."
|
||||
)
|
||||
assert excepted in str(excinfo.value)
|
||||
),
|
||||
):
|
||||
with_history.bound.invoke([HumanMessage(content="hello")], config)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
import warnings
|
||||
@ -3827,13 +3828,13 @@ def test_retrying(mocker: MockerFixture) -> None:
|
||||
_lambda_mock = mocker.Mock(side_effect=_lambda)
|
||||
runnable = RunnableLambda(_lambda_mock)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="x is 1"):
|
||||
runnable.invoke(1)
|
||||
|
||||
assert _lambda_mock.call_count == 1
|
||||
_lambda_mock.reset_mock()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="x is 1"):
|
||||
runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
@ -3852,7 +3853,7 @@ def test_retrying(mocker: MockerFixture) -> None:
|
||||
assert _lambda_mock.call_count == 1 # did not retry
|
||||
_lambda_mock.reset_mock()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="x is 1"):
|
||||
runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
@ -3892,13 +3893,13 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
|
||||
_lambda_mock = mocker.Mock(side_effect=_lambda)
|
||||
runnable = RunnableLambda(_lambda_mock)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="x is 1"):
|
||||
await runnable.ainvoke(1)
|
||||
|
||||
assert _lambda_mock.call_count == 1
|
||||
_lambda_mock.reset_mock()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="x is 1"):
|
||||
await runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
@ -3918,7 +3919,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
|
||||
assert _lambda_mock.call_count == 1 # did not retry
|
||||
_lambda_mock.reset_mock()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="x is 1"):
|
||||
await runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
@ -3982,9 +3983,8 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
|
||||
raise ValueError(msg)
|
||||
|
||||
# Check that the chain on error is invoked
|
||||
with pytest.raises(ValueError):
|
||||
for _ in RunnableLambda(raise_value_error).stream(1000, config=config):
|
||||
pass
|
||||
with pytest.raises(ValueError, match="x is too large"):
|
||||
_ = list(RunnableLambda(raise_value_error).stream(1000, config=config))
|
||||
|
||||
assert len(tracer.runs) == 2
|
||||
assert "ValueError('x is too large')" in str(tracer.runs[1].error)
|
||||
@ -4061,9 +4061,13 @@ async def test_runnable_lambda_astream_with_callbacks() -> None:
|
||||
raise ValueError(msg)
|
||||
|
||||
# Check that the chain on error is invoked
|
||||
with pytest.raises(ValueError):
|
||||
async for _ in RunnableLambda(raise_value_error).astream(1000, config=config):
|
||||
pass
|
||||
with pytest.raises(ValueError, match="x is too large"):
|
||||
_ = [
|
||||
_
|
||||
async for _ in RunnableLambda(raise_value_error).astream(
|
||||
1000, config=config
|
||||
)
|
||||
]
|
||||
|
||||
assert len(tracer.runs) == 2
|
||||
assert "ValueError('x is too large')" in str(tracer.runs[1].error)
|
||||
@ -4088,7 +4092,11 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
outputs: list[Any] = []
|
||||
for input in inputs:
|
||||
if input.startswith(self.fail_starts_with):
|
||||
outputs.append(ValueError())
|
||||
outputs.append(
|
||||
ValueError(
|
||||
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}"
|
||||
)
|
||||
)
|
||||
else:
|
||||
outputs.append(input + "a")
|
||||
return outputs
|
||||
@ -4119,7 +4127,9 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
|
||||
# Test batch
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError, match=re.escape("ControlledExceptionRunnable(bar) fail for bara")
|
||||
):
|
||||
chain.batch(["foo", "bar", "baz", "qux"])
|
||||
|
||||
spy = mocker.spy(ControlledExceptionRunnable, "batch")
|
||||
@ -4155,32 +4165,44 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
|
||||
parent_run_foo = parent_runs[0]
|
||||
assert parent_run_foo.inputs["input"] == "foo"
|
||||
assert repr(ValueError()) in str(parent_run_foo.error)
|
||||
assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
|
||||
parent_run_foo.error
|
||||
)
|
||||
assert len(parent_run_foo.child_runs) == 4
|
||||
assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
]
|
||||
assert repr(ValueError()) in str(parent_run_foo.child_runs[-1].error)
|
||||
assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
|
||||
parent_run_foo.child_runs[-1].error
|
||||
)
|
||||
|
||||
parent_run_bar = parent_runs[1]
|
||||
assert parent_run_bar.inputs["input"] == "bar"
|
||||
assert repr(ValueError()) in str(parent_run_bar.error)
|
||||
assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
|
||||
parent_run_bar.error
|
||||
)
|
||||
assert len(parent_run_bar.child_runs) == 2
|
||||
assert parent_run_bar.child_runs[0].error is None
|
||||
assert repr(ValueError()) in str(parent_run_bar.child_runs[1].error)
|
||||
assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
|
||||
parent_run_bar.child_runs[1].error
|
||||
)
|
||||
|
||||
parent_run_baz = parent_runs[2]
|
||||
assert parent_run_baz.inputs["input"] == "baz"
|
||||
assert repr(ValueError()) in str(parent_run_baz.error)
|
||||
assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
|
||||
parent_run_baz.error
|
||||
)
|
||||
assert len(parent_run_baz.child_runs) == 3
|
||||
|
||||
assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
|
||||
None,
|
||||
None,
|
||||
]
|
||||
assert repr(ValueError()) in str(parent_run_baz.child_runs[-1].error)
|
||||
assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
|
||||
parent_run_baz.child_runs[-1].error
|
||||
)
|
||||
|
||||
parent_run_qux = parent_runs[3]
|
||||
assert parent_run_qux.inputs["input"] == "qux"
|
||||
@ -4209,7 +4231,11 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
outputs: list[Any] = []
|
||||
for input in inputs:
|
||||
if input.startswith(self.fail_starts_with):
|
||||
outputs.append(ValueError())
|
||||
outputs.append(
|
||||
ValueError(
|
||||
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}"
|
||||
)
|
||||
)
|
||||
else:
|
||||
outputs.append(input + "a")
|
||||
return outputs
|
||||
@ -4240,7 +4266,9 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
|
||||
# Test abatch
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError, match=re.escape("ControlledExceptionRunnable(bar) fail for bara")
|
||||
):
|
||||
await chain.abatch(["foo", "bar", "baz", "qux"])
|
||||
|
||||
spy = mocker.spy(ControlledExceptionRunnable, "abatch")
|
||||
@ -4278,31 +4306,43 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
|
||||
parent_run_foo = parent_runs[0]
|
||||
assert parent_run_foo.inputs["input"] == "foo"
|
||||
assert repr(ValueError()) in str(parent_run_foo.error)
|
||||
assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
|
||||
parent_run_foo.error
|
||||
)
|
||||
assert len(parent_run_foo.child_runs) == 4
|
||||
assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
]
|
||||
assert repr(ValueError()) in str(parent_run_foo.child_runs[-1].error)
|
||||
assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
|
||||
parent_run_foo.child_runs[-1].error
|
||||
)
|
||||
|
||||
parent_run_bar = parent_runs[1]
|
||||
assert parent_run_bar.inputs["input"] == "bar"
|
||||
assert repr(ValueError()) in str(parent_run_bar.error)
|
||||
assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
|
||||
parent_run_bar.error
|
||||
)
|
||||
assert len(parent_run_bar.child_runs) == 2
|
||||
assert parent_run_bar.child_runs[0].error is None
|
||||
assert repr(ValueError()) in str(parent_run_bar.child_runs[1].error)
|
||||
assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
|
||||
parent_run_bar.child_runs[1].error
|
||||
)
|
||||
|
||||
parent_run_baz = parent_runs[2]
|
||||
assert parent_run_baz.inputs["input"] == "baz"
|
||||
assert repr(ValueError()) in str(parent_run_baz.error)
|
||||
assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
|
||||
parent_run_baz.error
|
||||
)
|
||||
assert len(parent_run_baz.child_runs) == 3
|
||||
assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
|
||||
None,
|
||||
None,
|
||||
]
|
||||
assert repr(ValueError()) in str(parent_run_baz.child_runs[-1].error)
|
||||
assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
|
||||
parent_run_baz.child_runs[-1].error
|
||||
)
|
||||
|
||||
parent_run_qux = parent_runs[3]
|
||||
assert parent_run_qux.inputs["input"] == "qux"
|
||||
@ -4319,11 +4359,15 @@ def test_runnable_branch_init() -> None:
|
||||
condition = RunnableLambda(lambda x: x > 0)
|
||||
|
||||
# Test failure with less than 2 branches
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError, match="RunnableBranch requires at least two branches"
|
||||
):
|
||||
RunnableBranch((condition, add))
|
||||
|
||||
# Test failure with less than 2 branches
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError, match="RunnableBranch requires at least two branches"
|
||||
):
|
||||
RunnableBranch(condition)
|
||||
|
||||
|
||||
@ -4408,7 +4452,7 @@ def test_runnable_branch_invoke() -> None:
|
||||
assert branch.invoke(10) == 100
|
||||
assert branch.invoke(0) == -1
|
||||
# Should raise an exception
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="x is too large"):
|
||||
branch.invoke(1000)
|
||||
|
||||
|
||||
@ -4472,7 +4516,7 @@ def test_runnable_branch_invoke_callbacks() -> None:
|
||||
assert tracer.runs[0].outputs == {"output": 0}
|
||||
|
||||
# Check that the chain on end is invoked
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="x is too large"):
|
||||
branch.invoke(1000, config={"callbacks": [tracer]})
|
||||
|
||||
assert len(tracer.runs) == 2
|
||||
@ -4500,7 +4544,7 @@ async def test_runnable_branch_ainvoke_callbacks() -> None:
|
||||
assert tracer.runs[0].outputs == {"output": 0}
|
||||
|
||||
# Check that the chain on end is invoked
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="x is too large"):
|
||||
await branch.ainvoke(1000, config={"callbacks": [tracer]})
|
||||
|
||||
assert len(tracer.runs) == 2
|
||||
@ -4561,9 +4605,8 @@ def test_runnable_branch_stream_with_callbacks() -> None:
|
||||
assert tracer.runs[0].outputs == {"output": llm_res}
|
||||
|
||||
# Verify that the chain on error is invoked
|
||||
with pytest.raises(ValueError):
|
||||
for _ in branch.stream("error", config=config):
|
||||
pass
|
||||
with pytest.raises(ValueError, match="x is error"):
|
||||
_ = list(branch.stream("error", config=config))
|
||||
|
||||
assert len(tracer.runs) == 2
|
||||
assert "ValueError('x is error')" in str(tracer.runs[1].error)
|
||||
@ -4638,9 +4681,8 @@ async def test_runnable_branch_astream_with_callbacks() -> None:
|
||||
assert tracer.runs[0].outputs == {"output": llm_res}
|
||||
|
||||
# Verify that the chain on error is invoked
|
||||
with pytest.raises(ValueError):
|
||||
async for _ in branch.astream("error", config=config):
|
||||
pass
|
||||
with pytest.raises(ValueError, match="x is error"):
|
||||
_ = [_ async for _ in branch.astream("error", config=config)]
|
||||
|
||||
assert len(tracer.runs) == 2
|
||||
assert "ValueError('x is error')" in str(tracer.runs[1].error)
|
||||
|
@ -1824,8 +1824,7 @@ async def test_runnable_each() -> None:
|
||||
assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4]
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
async for _ in add_one_map.astream_events([1, 2, 3], version="v1"):
|
||||
pass
|
||||
_ = [_ async for _ in add_one_map.astream_events([1, 2, 3], version="v1")]
|
||||
|
||||
|
||||
async def test_events_astream_config() -> None:
|
||||
|
@ -1773,8 +1773,7 @@ async def test_runnable_each() -> None:
|
||||
assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4]
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
async for _ in add_one_map.astream_events([1, 2, 3], version="v2"):
|
||||
pass
|
||||
_ = [_ async for _ in add_one_map.astream_events([1, 2, 3], version="v2")]
|
||||
|
||||
|
||||
async def test_events_astream_config() -> None:
|
||||
|
@ -419,7 +419,7 @@ class TestRunnableSequenceParallelTraceNesting:
|
||||
self._check_posts()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("parent_type", ("ls", "lc"))
|
||||
@pytest.mark.parametrize("parent_type", ["ls", "lc"])
|
||||
def test_tree_is_constructed(parent_type: Literal["ls", "lc"]) -> None:
|
||||
mock_session = MagicMock()
|
||||
mock_client_ = Client(
|
||||
|
@ -15,7 +15,7 @@ from langchain_core.runnables.utils import (
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"func, expected_source",
|
||||
("func", "expected_source"),
|
||||
[
|
||||
(lambda x: x * 2, "lambda x: x * 2"),
|
||||
(lambda a, b: a + b, "lambda a, b: a + b"),
|
||||
@ -29,7 +29,7 @@ def test_get_lambda_source(func: Callable, expected_source: str) -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text,prefix,expected_output",
|
||||
("text", "prefix", "expected_output"),
|
||||
[
|
||||
("line 1\nline 2\nline 3", "1", "line 1\n line 2\n line 3"),
|
||||
("line 1\nline 2\nline 3", "ax", "line 1\n line 2\n line 3"),
|
||||
|
@ -184,7 +184,9 @@ def test_chat_message_chunks() -> None:
|
||||
"ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError, match="Cannot concatenate ChatMessageChunks with different roles."
|
||||
):
|
||||
ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
|
||||
role="Assistant", content=" indeed."
|
||||
)
|
||||
@ -290,7 +292,10 @@ def test_function_message_chunks() -> None:
|
||||
id="ai5", name="hello", content="I am indeed."
|
||||
), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Cannot concatenate FunctionMessageChunks with different names.",
|
||||
):
|
||||
FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
|
||||
name="bye", content=" indeed."
|
||||
)
|
||||
@ -303,7 +308,10 @@ def test_ai_message_chunks() -> None:
|
||||
"AIMessageChunk + AIMessageChunk should be a AIMessageChunk"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Cannot concatenate AIMessageChunks with different example values.",
|
||||
):
|
||||
AIMessageChunk(example=True, content="I am") + AIMessageChunk(
|
||||
example=False, content=" indeed."
|
||||
)
|
||||
@ -320,30 +328,21 @@ class TestGetBufferString(unittest.TestCase):
|
||||
self.tool_calls_msg = AIMessage(content="tool")
|
||||
|
||||
def test_empty_input(self) -> None:
|
||||
self.assertEqual(get_buffer_string([]), "")
|
||||
assert get_buffer_string([]) == ""
|
||||
|
||||
def test_valid_single_message(self) -> None:
|
||||
expected_output = f"Human: {self.human_msg.content}"
|
||||
self.assertEqual(
|
||||
get_buffer_string([self.human_msg]),
|
||||
expected_output,
|
||||
)
|
||||
assert get_buffer_string([self.human_msg]) == expected_output
|
||||
|
||||
def test_custom_human_prefix(self) -> None:
|
||||
prefix = "H"
|
||||
expected_output = f"{prefix}: {self.human_msg.content}"
|
||||
self.assertEqual(
|
||||
get_buffer_string([self.human_msg], human_prefix="H"),
|
||||
expected_output,
|
||||
)
|
||||
assert get_buffer_string([self.human_msg], human_prefix="H") == expected_output
|
||||
|
||||
def test_custom_ai_prefix(self) -> None:
|
||||
prefix = "A"
|
||||
expected_output = f"{prefix}: {self.ai_msg.content}"
|
||||
self.assertEqual(
|
||||
get_buffer_string([self.ai_msg], ai_prefix="A"),
|
||||
expected_output,
|
||||
)
|
||||
assert get_buffer_string([self.ai_msg], ai_prefix="A") == expected_output
|
||||
|
||||
def test_multiple_msg(self) -> None:
|
||||
msgs = [
|
||||
@ -366,10 +365,7 @@ class TestGetBufferString(unittest.TestCase):
|
||||
"AI: tool",
|
||||
]
|
||||
)
|
||||
self.assertEqual(
|
||||
get_buffer_string(msgs),
|
||||
expected_output,
|
||||
)
|
||||
assert get_buffer_string(msgs) == expected_output
|
||||
|
||||
|
||||
def test_multiple_msg() -> None:
|
||||
@ -991,7 +987,7 @@ def test_tool_message_str() -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["first", "others", "expected"],
|
||||
("first", "others", "expected"),
|
||||
[
|
||||
("", [""], ""),
|
||||
("", [[]], [""]),
|
||||
|
@ -775,8 +775,8 @@ def test_exception_handling_callable() -> None:
|
||||
|
||||
|
||||
def test_exception_handling_non_tool_exception() -> None:
|
||||
_tool = _FakeExceptionTool(exception=ValueError())
|
||||
with pytest.raises(ValueError):
|
||||
_tool = _FakeExceptionTool(exception=ValueError("some error"))
|
||||
with pytest.raises(ValueError, match="some error"):
|
||||
_tool.run({})
|
||||
|
||||
|
||||
@ -806,8 +806,8 @@ async def test_async_exception_handling_callable() -> None:
|
||||
|
||||
|
||||
async def test_async_exception_handling_non_tool_exception() -> None:
|
||||
_tool = _FakeExceptionTool(exception=ValueError())
|
||||
with pytest.raises(ValueError):
|
||||
_tool = _FakeExceptionTool(exception=ValueError("some error"))
|
||||
with pytest.raises(ValueError, match="some error"):
|
||||
await _tool.arun({})
|
||||
|
||||
|
||||
@ -987,7 +987,7 @@ def test_optional_subset_model_rewrite() -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inputs, expected",
|
||||
("inputs", "expected"),
|
||||
[
|
||||
# Check not required
|
||||
({"bar": "bar"}, {"bar": "bar", "baz": 3, "buzz": "buzz"}),
|
||||
@ -1323,6 +1323,10 @@ def test_tool_invalid_docstrings() -> None:
|
||||
"""
|
||||
return bar
|
||||
|
||||
for func in {foo3, foo4}:
|
||||
with pytest.raises(ValueError, match="Found invalid Google-Style docstring."):
|
||||
_ = tool(func, parse_docstring=True)
|
||||
|
||||
def foo5(bar: str, baz: int) -> str:
|
||||
"""The foo.
|
||||
|
||||
@ -1332,9 +1336,10 @@ def test_tool_invalid_docstrings() -> None:
|
||||
"""
|
||||
return bar
|
||||
|
||||
for func in [foo3, foo4, foo5]:
|
||||
with pytest.raises(ValueError):
|
||||
_ = tool(func, parse_docstring=True)
|
||||
with pytest.raises(
|
||||
ValueError, match="Arg banana in docstring not found in function signature."
|
||||
):
|
||||
_ = tool(foo5, parse_docstring=True)
|
||||
|
||||
|
||||
def test_tool_annotated_descriptions() -> None:
|
||||
@ -2004,9 +2009,9 @@ def test__is_message_content_block(obj: Any, expected: bool) -> None:
|
||||
@pytest.mark.parametrize(
|
||||
("obj", "expected"),
|
||||
[
|
||||
["foo", True],
|
||||
[valid_tool_result_blocks, True],
|
||||
[invalid_tool_result_blocks, False],
|
||||
("foo", True),
|
||||
(valid_tool_result_blocks, True),
|
||||
(invalid_tool_result_blocks, False),
|
||||
],
|
||||
)
|
||||
def test__is_message_content_type(obj: Any, expected: bool) -> None:
|
||||
@ -2268,7 +2273,8 @@ def test_imports() -> None:
|
||||
"InjectedToolArg",
|
||||
]
|
||||
for module_name in expected_all:
|
||||
assert hasattr(tools, module_name) and getattr(tools, module_name) is not None
|
||||
assert hasattr(tools, module_name)
|
||||
assert getattr(tools, module_name) is not None
|
||||
|
||||
|
||||
def test_structured_tool_direct_init() -> None:
|
||||
@ -2317,7 +2323,11 @@ def test_tool_injected_tool_call_id() -> None:
|
||||
}
|
||||
) == ToolMessage(0, tool_call_id="bar") # type: ignore
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="When tool includes an InjectedToolCallId argument, "
|
||||
"tool must always be invoked with a full model ToolCall",
|
||||
):
|
||||
assert foo.invoke({"x": 0})
|
||||
|
||||
@tool
|
||||
@ -2341,7 +2351,7 @@ def test_tool_uninjected_tool_call_id() -> None:
|
||||
"""Foo."""
|
||||
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="1 validation error for foo"):
|
||||
foo.invoke({"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"})
|
||||
|
||||
assert foo.invoke(
|
||||
|
@ -22,6 +22,5 @@ def test_public_api() -> None:
|
||||
|
||||
# Assert that the object is actually present in the schema module
|
||||
for module_name in expected_all:
|
||||
assert (
|
||||
hasattr(schemas, module_name) and getattr(schemas, module_name) is not None
|
||||
)
|
||||
assert hasattr(schemas, module_name)
|
||||
assert getattr(schemas, module_name) is not None
|
||||
|
@ -6,7 +6,7 @@ from langchain_core.utils.aiter import abatch_iterate
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_size, input_iterable, expected_output",
|
||||
("input_size", "input_iterable", "expected_output"),
|
||||
[
|
||||
(2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
|
||||
(3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),
|
||||
|
@ -51,7 +51,12 @@ def test_get_from_dict_or_env() -> None:
|
||||
|
||||
# Not the most obvious behavior, but
|
||||
# this is how it works right now
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Did not find not exists, "
|
||||
"please add an environment variable `__SOME_KEY_IN_ENV` which contains it, "
|
||||
"or pass `not exists` as a named parameter.",
|
||||
):
|
||||
assert (
|
||||
get_from_dict_or_env(
|
||||
{
|
||||
|
@ -37,7 +37,7 @@ from langchain_core.utils.function_calling import (
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def pydantic() -> type[BaseModel]:
|
||||
class dummy_function(BaseModel): # noqa: N801
|
||||
"""Dummy function."""
|
||||
@ -48,7 +48,7 @@ def pydantic() -> type[BaseModel]:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def annotated_function() -> Callable:
|
||||
def dummy_function(
|
||||
arg1: ExtensionsAnnotated[int, "foo"],
|
||||
@ -59,7 +59,7 @@ def annotated_function() -> Callable:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def function() -> Callable:
|
||||
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
|
||||
"""Dummy function.
|
||||
@ -72,7 +72,7 @@ def function() -> Callable:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def function_docstring_annotations() -> Callable:
|
||||
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
|
||||
"""Dummy function.
|
||||
@ -85,7 +85,7 @@ def function_docstring_annotations() -> Callable:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def runnable() -> Runnable:
|
||||
class Args(ExtensionsTypedDict):
|
||||
arg1: ExtensionsAnnotated[int, "foo"]
|
||||
@ -97,7 +97,7 @@ def runnable() -> Runnable:
|
||||
return RunnableLambda(dummy_function)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_tool() -> BaseTool:
|
||||
class Schema(BaseModel):
|
||||
arg1: int = Field(..., description="foo")
|
||||
@ -114,7 +114,7 @@ def dummy_tool() -> BaseTool:
|
||||
return DummyFunction()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_structured_tool() -> StructuredTool:
|
||||
class Schema(BaseModel):
|
||||
arg1: int = Field(..., description="foo")
|
||||
@ -128,7 +128,7 @@ def dummy_structured_tool() -> StructuredTool:
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_structured_tool_args_schema_dict() -> StructuredTool:
|
||||
args_schema = {
|
||||
"type": "object",
|
||||
@ -150,7 +150,7 @@ def dummy_structured_tool_args_schema_dict() -> StructuredTool:
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_pydantic() -> type[BaseModel]:
|
||||
class dummy_function(BaseModel): # noqa: N801
|
||||
"""Dummy function."""
|
||||
@ -161,7 +161,7 @@ def dummy_pydantic() -> type[BaseModel]:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
|
||||
class dummy_function(BaseModelV2Maybe): # noqa: N801
|
||||
"""Dummy function."""
|
||||
@ -174,7 +174,7 @@ def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_typing_typed_dict() -> type:
|
||||
class dummy_function(TypingTypedDict): # noqa: N801
|
||||
"""Dummy function."""
|
||||
@ -185,7 +185,7 @@ def dummy_typing_typed_dict() -> type:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_typing_typed_dict_docstring() -> type:
|
||||
class dummy_function(TypingTypedDict): # noqa: N801
|
||||
"""Dummy function.
|
||||
@ -201,7 +201,7 @@ def dummy_typing_typed_dict_docstring() -> type:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_extensions_typed_dict() -> type:
|
||||
class dummy_function(ExtensionsTypedDict): # noqa: N801
|
||||
"""Dummy function."""
|
||||
@ -212,7 +212,7 @@ def dummy_extensions_typed_dict() -> type:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def dummy_extensions_typed_dict_docstring() -> type:
|
||||
class dummy_function(ExtensionsTypedDict): # noqa: N801
|
||||
"""Dummy function.
|
||||
@ -228,7 +228,7 @@ def dummy_extensions_typed_dict_docstring() -> type:
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def json_schema() -> dict:
|
||||
return {
|
||||
"title": "dummy_function",
|
||||
@ -246,7 +246,7 @@ def json_schema() -> dict:
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def anthropic_tool() -> dict:
|
||||
return {
|
||||
"name": "dummy_function",
|
||||
@ -266,7 +266,7 @@ def anthropic_tool() -> dict:
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def bedrock_converse_tool() -> dict:
|
||||
return {
|
||||
"toolSpec": {
|
||||
|
@ -4,7 +4,7 @@ from langchain_core.utils.iter import batch_iterate
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_size, input_iterable, expected_output",
|
||||
("input_size", "input_iterable", "expected_output"),
|
||||
[
|
||||
(2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
|
||||
(3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),
|
||||
|
@ -147,7 +147,7 @@ def test_dereference_refs_remote_ref() -> None:
|
||||
"first_name": {"$ref": "https://somewhere/else/name"},
|
||||
},
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="ref paths are expected to be URI fragments"):
|
||||
dereference_refs(schema)
|
||||
|
||||
|
||||
|
@ -220,7 +220,7 @@ output5 = {
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"schema, output",
|
||||
("schema", "output"),
|
||||
[
|
||||
(schema1, output1),
|
||||
(schema2, output2),
|
||||
|
@ -29,12 +29,17 @@ def test_dict_int_op_nested() -> None:
|
||||
def test_dict_int_op_max_depth_exceeded() -> None:
|
||||
left = {"a": {"b": {"c": 1}}}
|
||||
right = {"a": {"b": {"c": 2}}}
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError, match="max_depth=2 exceeded, unable to combine dicts."
|
||||
):
|
||||
_dict_int_op(left, right, operator.add, max_depth=2)
|
||||
|
||||
|
||||
def test_dict_int_op_invalid_types() -> None:
|
||||
left = {"a": 1, "b": "string"}
|
||||
right = {"a": 2, "b": 3}
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Only dict and int values are supported.",
|
||||
):
|
||||
_dict_int_op(left, right, operator.add)
|
||||
|
@ -46,7 +46,7 @@ def test_check_package_version(
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("left", "right", "expected"),
|
||||
(
|
||||
[
|
||||
# Merge `None` and `1`.
|
||||
({"a": None}, {"a": 1}, {"a": 1}),
|
||||
# Merge `1` and `None`.
|
||||
@ -111,7 +111,7 @@ def test_check_package_version(
|
||||
{"a": [{"idx": 0, "b": "f"}]},
|
||||
{"a": [{"idx": 0, "b": "{"}, {"idx": 0, "b": "f"}]},
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_merge_dicts(
|
||||
left: dict, right: dict, expected: Union[dict, AbstractContextManager]
|
||||
@ -130,7 +130,7 @@ def test_merge_dicts(
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("left", "right", "expected"),
|
||||
(
|
||||
[
|
||||
# 'type' special key handling
|
||||
({"type": "foo"}, {"type": "foo"}, {"type": "foo"}),
|
||||
(
|
||||
@ -138,7 +138,7 @@ def test_merge_dicts(
|
||||
{"type": "bar"},
|
||||
pytest.raises(ValueError, match="Unable to merge."),
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.xfail(reason="Refactors to make in 0.3")
|
||||
def test_merge_dicts_0_3(
|
||||
@ -183,36 +183,32 @@ def test_guard_import(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("module_name", "pip_name", "package"),
|
||||
("module_name", "pip_name", "package", "expected_pip_name"),
|
||||
[
|
||||
("langchain_core.utilsW", None, None),
|
||||
("langchain_core.utilsW", "langchain-core-2", None),
|
||||
("langchain_core.utilsW", None, "langchain-coreWX"),
|
||||
("langchain_core.utilsW", "langchain-core-2", "langchain-coreWX"),
|
||||
("langchain_coreW", None, None), # ModuleNotFoundError
|
||||
("langchain_core.utilsW", None, None, "langchain-core"),
|
||||
("langchain_core.utilsW", "langchain-core-2", None, "langchain-core-2"),
|
||||
("langchain_core.utilsW", None, "langchain-coreWX", "langchain-core"),
|
||||
(
|
||||
"langchain_core.utilsW",
|
||||
"langchain-core-2",
|
||||
"langchain-coreWX",
|
||||
"langchain-core-2",
|
||||
),
|
||||
("langchain_coreW", None, None, "langchain-coreW"), # ModuleNotFoundError
|
||||
],
|
||||
)
|
||||
def test_guard_import_failure(
|
||||
module_name: str, pip_name: Optional[str], package: Optional[str]
|
||||
module_name: str,
|
||||
pip_name: Optional[str],
|
||||
package: Optional[str],
|
||||
expected_pip_name: str,
|
||||
) -> None:
|
||||
with pytest.raises(ImportError) as exc_info:
|
||||
if package is None and pip_name is None:
|
||||
guard_import(module_name)
|
||||
elif package is None and pip_name is not None:
|
||||
guard_import(module_name, pip_name=pip_name)
|
||||
elif package is not None and pip_name is None:
|
||||
guard_import(module_name, package=package)
|
||||
elif package is not None and pip_name is not None:
|
||||
with pytest.raises(
|
||||
ImportError,
|
||||
match=f"Could not import {module_name} python package. "
|
||||
f"Please install it with `pip install {expected_pip_name}`.",
|
||||
):
|
||||
guard_import(module_name, pip_name=pip_name, package=package)
|
||||
else:
|
||||
msg = "Invalid test case"
|
||||
raise ValueError(msg)
|
||||
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
|
||||
err_msg = (
|
||||
f"Could not import {module_name} python package. "
|
||||
f"Please install it with `pip install {pip_name}`."
|
||||
)
|
||||
assert exc_info.value.msg == err_msg
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")
|
||||
|
Loading…
Reference in New Issue
Block a user