core: Add ruff rules PT (pytest) (#29381)

See https://docs.astral.sh/ruff/rules/#flake8-pytest-style-pt
This commit is contained in:
Christophe Bornet 2025-04-01 19:31:07 +02:00 committed by GitHub
parent 6896c863e8
commit 8a33402016
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 379 additions and 227 deletions

View File

@ -103,7 +103,6 @@ ignore = [
"PLC", "PLC",
"PLE", "PLE",
"PLR", "PLR",
"PT",
"PYI", "PYI",
"RET", "RET",
"RUF", "RUF",

View File

@ -9,7 +9,7 @@ from langchain_core._api.beta_decorator import beta, warn_beta
@pytest.mark.parametrize( @pytest.mark.parametrize(
"kwargs, expected_message", ("kwargs", "expected_message"),
[ [
( (
{ {

View File

@ -13,7 +13,7 @@ from langchain_core._api.deprecation import (
@pytest.mark.parametrize( @pytest.mark.parametrize(
"kwargs, expected_message", ("kwargs", "expected_message"),
[ [
( (
{ {
@ -404,7 +404,9 @@ def test_deprecated_method_pydantic() -> None:
def test_raise_error_for_bad_decorator() -> None: def test_raise_error_for_bad_decorator() -> None:
"""Verify that errors raised on init rather than on use.""" """Verify that errors raised on init rather than on use."""
# Should not specify both `alternative` and `alternative_import` # Should not specify both `alternative` and `alternative_import`
with pytest.raises(ValueError): with pytest.raises(
ValueError, match="Cannot specify both alternative and alternative_import"
):
@deprecated(since="2.0.0", alternative="NewClass", alternative_import="hello") @deprecated(since="2.0.0", alternative="NewClass", alternative_import="hello")
def deprecated_function() -> str: def deprecated_function() -> str:

View File

@ -28,7 +28,7 @@ def test_initialization() -> None:
assert cache_with_maxsize._cache == {} assert cache_with_maxsize._cache == {}
assert cache_with_maxsize._maxsize == 2 assert cache_with_maxsize._maxsize == 2
with pytest.raises(ValueError): with pytest.raises(ValueError, match="maxsize must be greater than 0"):
InMemoryCache(maxsize=0) InMemoryCache(maxsize=0)

View File

@ -6,7 +6,6 @@ from uuid import UUID
import pytest import pytest
from blockbuster import BlockBuster, blockbuster_ctx from blockbuster import BlockBuster, blockbuster_ctx
from pytest import Config, Function, Parser
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
@ -36,7 +35,7 @@ def blockbuster() -> Iterator[BlockBuster]:
yield bb yield bb
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: pytest.Parser) -> None:
"""Add custom command line options to pytest.""" """Add custom command line options to pytest."""
parser.addoption( parser.addoption(
"--only-extended", "--only-extended",
@ -50,7 +49,9 @@ def pytest_addoption(parser: Parser) -> None:
) )
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None: def pytest_collection_modifyitems(
config: pytest.Config, items: Sequence[pytest.Function]
) -> None:
"""Add implementations for handling custom markers. """Add implementations for handling custom markers.
At the moment, this adds support for a custom `requires` marker. At the moment, this adds support for a custom `requires` marker.
@ -118,7 +119,7 @@ def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) ->
) )
@pytest.fixture() @pytest.fixture
def deterministic_uuids(mocker: MockerFixture) -> MockerFixture: def deterministic_uuids(mocker: MockerFixture) -> MockerFixture:
side_effect = ( side_effect = (
UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000) UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)

View File

@ -16,16 +16,16 @@ from langchain_core.indexing.in_memory import (
class TestDocumentIndexerTestSuite(DocumentIndexerTestSuite): class TestDocumentIndexerTestSuite(DocumentIndexerTestSuite):
@pytest.fixture() @pytest.fixture
def index(self) -> Generator[DocumentIndex, None, None]: def index(self) -> Generator[DocumentIndex, None, None]:
yield InMemoryDocumentIndex() yield InMemoryDocumentIndex() # noqa: PT022
class TestAsyncDocumentIndexerTestSuite(AsyncDocumentIndexTestSuite): class TestAsyncDocumentIndexerTestSuite(AsyncDocumentIndexTestSuite):
# Something funky is going on with mypy and async pytest fixture # Something funky is going on with mypy and async pytest fixture
@pytest.fixture() @pytest.fixture
async def index(self) -> AsyncGenerator[DocumentIndex, None]: # type: ignore async def index(self) -> AsyncGenerator[DocumentIndex, None]: # type: ignore
yield InMemoryDocumentIndex() yield InMemoryDocumentIndex() # noqa: PT022
def test_sync_retriever() -> None: def test_sync_retriever() -> None:

View File

@ -7,7 +7,7 @@ import pytest_asyncio
from langchain_core.indexing import InMemoryRecordManager from langchain_core.indexing import InMemoryRecordManager
@pytest.fixture() @pytest.fixture
def manager() -> InMemoryRecordManager: def manager() -> InMemoryRecordManager:
"""Initialize the test database and yield the TimestampedSet instance.""" """Initialize the test database and yield the TimestampedSet instance."""
# Initialize and yield the TimestampedSet instance # Initialize and yield the TimestampedSet instance

View File

@ -466,11 +466,19 @@ def test_incremental_fails_with_bad_source_ids(
] ]
) )
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="Source id key is required when cleanup mode is "
"incremental or scoped_full.",
):
# Should raise an error because no source id function was specified # Should raise an error because no source id function was specified
index(loader, record_manager, vector_store, cleanup="incremental") index(loader, record_manager, vector_store, cleanup="incremental")
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="Source ids are required when cleanup mode "
"is incremental or scoped_full.",
):
# Should raise an error because no source id function was specified # Should raise an error because no source id function was specified
index( index(
loader, loader,
@ -502,7 +510,11 @@ async def test_aincremental_fails_with_bad_source_ids(
] ]
) )
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="Source id key is required when cleanup mode "
"is incremental or scoped_full.",
):
# Should raise an error because no source id function was specified # Should raise an error because no source id function was specified
await aindex( await aindex(
loader, loader,
@ -511,7 +523,11 @@ async def test_aincremental_fails_with_bad_source_ids(
cleanup="incremental", cleanup="incremental",
) )
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="Source ids are required when cleanup mode "
"is incremental or scoped_full.",
):
# Should raise an error because no source id function was specified # Should raise an error because no source id function was specified
await aindex( await aindex(
loader, loader,
@ -771,11 +787,19 @@ def test_scoped_full_fails_with_bad_source_ids(
] ]
) )
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="Source id key is required when cleanup mode "
"is incremental or scoped_full.",
):
# Should raise an error because no source id function was specified # Should raise an error because no source id function was specified
index(loader, record_manager, vector_store, cleanup="scoped_full") index(loader, record_manager, vector_store, cleanup="scoped_full")
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="Source ids are required when cleanup mode "
"is incremental or scoped_full.",
):
# Should raise an error because no source id function was specified # Should raise an error because no source id function was specified
index( index(
loader, loader,
@ -807,11 +831,19 @@ async def test_ascoped_full_fails_with_bad_source_ids(
] ]
) )
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="Source id key is required when cleanup mode "
"is incremental or scoped_full.",
):
# Should raise an error because no source id function was specified # Should raise an error because no source id function was specified
await aindex(loader, arecord_manager, vector_store, cleanup="scoped_full") await aindex(loader, arecord_manager, vector_store, cleanup="scoped_full")
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="Source ids are required when cleanup mode "
"is incremental or scoped_full.",
):
# Should raise an error because no source id function was specified # Should raise an error because no source id function was specified
await aindex( await aindex(
loader, loader,

View File

@ -99,6 +99,7 @@ async def test_async_batch_size(messages: list, messages_2: list) -> None:
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1 assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
@pytest.mark.xfail(reason="This test is failing due to a bug in the testing code")
async def test_stream_error_callback() -> None: async def test_stream_error_callback() -> None:
message = "test" message = "test"

View File

@ -1,5 +1,6 @@
import base64 import base64
import json import json
import re
import typing import typing
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
@ -513,7 +514,14 @@ def test_trim_messages_bound_model_token_counter() -> None:
def test_trim_messages_bad_token_counter() -> None: def test_trim_messages_bad_token_counter() -> None:
trimmer = trim_messages(max_tokens=10, token_counter={}) trimmer = trim_messages(max_tokens=10, token_counter={})
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match=re.escape(
"'token_counter' expected to be a model that implements "
"'get_num_tokens_from_messages()' or a function. "
"Received object of type <class 'dict'>."
),
):
trimmer.invoke([HumanMessage("foobar")]) trimmer.invoke([HumanMessage("foobar")])
@ -852,7 +860,9 @@ def test_convert_to_messages_openai_refusal() -> None:
assert actual == expected assert actual == expected
# Raises error if content is missing. # Raises error if content is missing.
with pytest.raises(ValueError): with pytest.raises(
ValueError, match="Message dict must contain 'role' and 'content' keys"
):
convert_to_messages([{"role": "assistant", "refusal": "9.1"}]) convert_to_messages([{"role": "assistant", "refusal": "9.1"}])

View File

@ -157,9 +157,10 @@ def test_pydantic_output_parser_fail() -> None:
pydantic_object=TestModel pydantic_object=TestModel
) )
with pytest.raises(OutputParserException) as e: with pytest.raises(
OutputParserException, match="Failed to parse TestModel from completion"
):
pydantic_parser.parse(DEF_RESULT_FAIL) pydantic_parser.parse(DEF_RESULT_FAIL)
assert "Failed to parse TestModel from completion" in str(e)
def test_pydantic_output_parser_type_inference() -> None: def test_pydantic_output_parser_type_inference() -> None:

View File

@ -1,3 +1,4 @@
import re
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Any, Union, cast from typing import Any, Union, cast
@ -165,15 +166,14 @@ def test_create_system_message_prompt_list_template_partial_variables_not_null()
{variables} {variables}
""" """
try: with pytest.raises(
graph_analyst_template = SystemMessagePromptTemplate.from_template( ValueError, match="Partial variables are not supported for list of templates."
):
_ = SystemMessagePromptTemplate.from_template(
template=[graph_creator_content1, graph_creator_content2], template=[graph_creator_content1, graph_creator_content2],
input_variables=["variables"], input_variables=["variables"],
partial_variables={"variables": "foo"}, partial_variables={"variables": "foo"},
) )
graph_analyst_template.format(variables="foo")
except ValueError as e:
assert str(e) == "Partial variables are not supported for list of templates."
def test_message_prompt_template_from_template_file() -> None: def test_message_prompt_template_from_template_file() -> None:
@ -330,7 +330,7 @@ def test_chat_prompt_template_from_messages_jinja2() -> None:
@pytest.mark.requires("jinja2") @pytest.mark.requires("jinja2")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"template_format,image_type_placeholder,image_data_placeholder", ("template_format", "image_type_placeholder", "image_data_placeholder"),
[ [
("f-string", "{image_type}", "{image_data}"), ("f-string", "{image_type}", "{image_data}"),
("mustache", "{{image_type}}", "{{image_data}}"), ("mustache", "{{image_type}}", "{{image_data}}"),
@ -393,7 +393,12 @@ def test_chat_prompt_template_with_messages(
def test_chat_invalid_input_variables_extra() -> None: def test_chat_invalid_input_variables_extra() -> None:
messages = [HumanMessage(content="foo")] messages = [HumanMessage(content="foo")]
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match=re.escape(
"Got mismatched input_variables. Expected: set(). Got: ['foo']"
),
):
ChatPromptTemplate( ChatPromptTemplate(
messages=messages, # type: ignore[arg-type] messages=messages, # type: ignore[arg-type]
input_variables=["foo"], input_variables=["foo"],
@ -407,7 +412,10 @@ def test_chat_invalid_input_variables_extra() -> None:
def test_chat_invalid_input_variables_missing() -> None: def test_chat_invalid_input_variables_missing() -> None:
messages = [HumanMessagePromptTemplate.from_template("{foo}")] messages = [HumanMessagePromptTemplate.from_template("{foo}")]
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match=re.escape("Got mismatched input_variables. Expected: {'foo'}. Got: []"),
):
ChatPromptTemplate( ChatPromptTemplate(
messages=messages, # type: ignore[arg-type] messages=messages, # type: ignore[arg-type]
input_variables=[], input_variables=[],
@ -481,7 +489,7 @@ async def test_chat_from_role_strings() -> None:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"args,expected", ("args", "expected"),
[ [
( (
("human", "{question}"), ("human", "{question}"),
@ -551,7 +559,7 @@ def test_chat_prompt_template_append_and_extend() -> None:
def test_convert_to_message_is_strict() -> None: def test_convert_to_message_is_strict() -> None:
"""Verify that _convert_to_message is strict.""" """Verify that _convert_to_message is strict."""
with pytest.raises(ValueError): with pytest.raises(ValueError, match="Unexpected message type: meow."):
# meow does not correspond to a valid message type. # meow does not correspond to a valid message type.
# this test is here to ensure that functionality to interpret `meow` # this test is here to ensure that functionality to interpret `meow`
# as a role is NOT added. # as a role is NOT added.
@ -762,14 +770,20 @@ async def test_chat_tmpl_from_messages_multipart_formatting_with_path() -> None:
), ),
] ]
) )
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="Loading images from 'path' has been removed as of 0.3.15 for security reasons.",
):
template.format_messages( template.format_messages(
name="R2D2", name="R2D2",
in_mem=in_mem, in_mem=in_mem,
file_path="some/path", file_path="some/path",
) )
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="Loading images from 'path' has been removed as of 0.3.15 for security reasons.",
):
await template.aformat_messages( await template.aformat_messages(
name="R2D2", name="R2D2",
in_mem=in_mem, in_mem=in_mem,
@ -869,10 +883,10 @@ def test_chat_prompt_message_dict() -> None:
HumanMessage(content="bar"), HumanMessage(content="bar"),
] ]
with pytest.raises(ValueError): with pytest.raises(ValueError, match="Invalid template: False"):
ChatPromptTemplate([{"role": "system", "content": False}]) ChatPromptTemplate([{"role": "system", "content": False}])
with pytest.raises(ValueError): with pytest.raises(ValueError, match="Unexpected message type: foo."):
ChatPromptTemplate([{"role": "foo", "content": "foo"}]) ChatPromptTemplate([{"role": "foo", "content": "foo"}])

View File

@ -1,5 +1,6 @@
"""Test few shot prompt template.""" """Test few shot prompt template."""
import re
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any from typing import Any
@ -24,7 +25,7 @@ EXAMPLE_PROMPT = PromptTemplate(
) )
@pytest.fixture() @pytest.fixture
@pytest.mark.requires("jinja2") @pytest.mark.requires("jinja2")
def example_jinja2_prompt() -> tuple[PromptTemplate, list[dict[str, str]]]: def example_jinja2_prompt() -> tuple[PromptTemplate, list[dict[str, str]]]:
example_template = "{{ word }}: {{ antonym }}" example_template = "{{ word }}: {{ antonym }}"
@ -74,7 +75,10 @@ def test_prompt_missing_input_variables() -> None:
"""Test error is raised when input variables are not provided.""" """Test error is raised when input variables are not provided."""
# Test when missing in suffix # Test when missing in suffix
template = "This is a {foo} test." template = "This is a {foo} test."
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match=re.escape("check for mismatched or missing input parameters from []"),
):
FewShotPromptTemplate( FewShotPromptTemplate(
input_variables=[], input_variables=[],
suffix=template, suffix=template,
@ -91,7 +95,10 @@ def test_prompt_missing_input_variables() -> None:
# Test when missing in prefix # Test when missing in prefix
template = "This is a {foo} test." template = "This is a {foo} test."
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match=re.escape("check for mismatched or missing input parameters from []"),
):
FewShotPromptTemplate( FewShotPromptTemplate(
input_variables=[], input_variables=[],
suffix="foo", suffix="foo",

View File

@ -1,5 +1,7 @@
"""Test few shot prompt template.""" """Test few shot prompt template."""
import re
import pytest import pytest
from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates
@ -58,7 +60,10 @@ def test_prompttemplate_validation() -> None:
{"question": "foo", "answer": "bar"}, {"question": "foo", "answer": "bar"},
{"question": "baz", "answer": "foo"}, {"question": "baz", "answer": "foo"},
] ]
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match=re.escape("Got input_variables=[], but based on prefix/suffix expected"),
):
FewShotPromptWithTemplates( FewShotPromptWithTemplates(
suffix=suffix, suffix=suffix,
prefix=prefix, prefix=prefix,

View File

@ -1,5 +1,6 @@
"""Test functionality related to prompts.""" """Test functionality related to prompts."""
import re
from typing import Any, Union from typing import Any, Union
from unittest import mock from unittest import mock
@ -264,7 +265,10 @@ def test_prompt_missing_input_variables() -> None:
"""Test error is raised when input variables are not provided.""" """Test error is raised when input variables are not provided."""
template = "This is a {foo} test." template = "This is a {foo} test."
input_variables: list = [] input_variables: list = []
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match=re.escape("check for mismatched or missing input parameters from []"),
):
PromptTemplate( PromptTemplate(
input_variables=input_variables, template=template, validate_template=True input_variables=input_variables, template=template, validate_template=True
) )
@ -275,7 +279,10 @@ def test_prompt_missing_input_variables() -> None:
def test_prompt_empty_input_variable() -> None: def test_prompt_empty_input_variable() -> None:
"""Test error is raised when empty string input variable.""" """Test error is raised when empty string input variable."""
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match=re.escape("check for mismatched or missing input parameters from ['']"),
):
PromptTemplate(input_variables=[""], template="{}", validate_template=True) PromptTemplate(input_variables=[""], template="{}", validate_template=True)
@ -283,7 +290,13 @@ def test_prompt_wrong_input_variables() -> None:
"""Test error is raised when name of input variable is wrong.""" """Test error is raised when name of input variable is wrong."""
template = "This is a {foo} test." template = "This is a {foo} test."
input_variables = ["bar"] input_variables = ["bar"]
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match=re.escape(
"Invalid prompt schema; "
"check for mismatched or missing input parameters from ['bar']"
),
):
PromptTemplate( PromptTemplate(
input_variables=input_variables, template=template, validate_template=True input_variables=input_variables, template=template, validate_template=True
) )
@ -330,7 +343,7 @@ def test_prompt_invalid_template_format() -> None:
"""Test initializing a prompt with invalid template format.""" """Test initializing a prompt with invalid template format."""
template = "This is a {foo} test." template = "This is a {foo} test."
input_variables = ["foo"] input_variables = ["foo"]
with pytest.raises(ValueError): with pytest.raises(ValueError, match="Unsupported template format: bar"):
PromptTemplate( PromptTemplate(
input_variables=input_variables, input_variables=input_variables,
template=template, template=template,
@ -580,7 +593,7 @@ async def test_prompt_ainvoke_with_metadata() -> None:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"value, expected", ("value", "expected"),
[ [
("0", "0"), ("0", "0"),
(0, "0"), (0, "0"),

View File

@ -330,7 +330,7 @@ test_cases = [
] ]
@pytest.mark.parametrize("runnable, cases", test_cases) @pytest.mark.parametrize(("runnable", "cases"), test_cases)
def test_context_runnables( def test_context_runnables(
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase] runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
) -> None: ) -> None:
@ -342,7 +342,7 @@ def test_context_runnables(
assert add(runnable.stream(cases[0].input)) == cases[0].output assert add(runnable.stream(cases[0].input)) == cases[0].output
@pytest.mark.parametrize("runnable, cases", test_cases) @pytest.mark.parametrize(("runnable", "cases"), test_cases)
async def test_context_runnables_async( async def test_context_runnables_async(
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase] runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
) -> None: ) -> None:
@ -357,14 +357,19 @@ async def test_context_runnables_async(
def test_runnable_context_seq_key_not_found() -> None: def test_runnable_context_seq_key_not_found() -> None:
seq: Runnable = {"bar": Context.setter("input")} | Context.getter("foo") seq: Runnable = {"bar": Context.setter("input")} | Context.getter("foo")
with pytest.raises(ValueError): with pytest.raises(
ValueError, match="Expected exactly one setter for context key foo"
):
seq.invoke("foo") seq.invoke("foo")
def test_runnable_context_seq_key_order() -> None: def test_runnable_context_seq_key_order() -> None:
seq: Runnable = {"bar": Context.getter("foo")} | Context.setter("foo") seq: Runnable = {"bar": Context.getter("foo")} | Context.setter("foo")
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="Context setter for key foo must be defined after all getters.",
):
seq.invoke("foo") seq.invoke("foo")
@ -374,7 +379,9 @@ def test_runnable_context_deadlock() -> None:
"foo": Context.setter("foo") | Context.getter("input"), "foo": Context.setter("foo") | Context.getter("input"),
} | RunnablePassthrough() } | RunnablePassthrough()
with pytest.raises(ValueError): with pytest.raises(
ValueError, match="Deadlock detected between context keys foo and input"
):
seq.invoke("foo") seq.invoke("foo")
@ -383,7 +390,9 @@ def test_runnable_context_seq_key_circular_ref() -> None:
"bar": Context.setter(input=Context.getter("input")) "bar": Context.setter(input=Context.getter("input"))
} | Context.getter("foo") } | Context.getter("foo")
with pytest.raises(ValueError): with pytest.raises(
ValueError, match="Circular reference in context setter for key input"
):
seq.invoke("foo") seq.invoke("foo")

View File

@ -32,7 +32,7 @@ from langchain_core.runnables import (
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
@pytest.fixture() @pytest.fixture
def llm() -> RunnableWithFallbacks: def llm() -> RunnableWithFallbacks:
error_llm = FakeListLLM(responses=["foo"], i=1) error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"]) pass_llm = FakeListLLM(responses=["bar"])
@ -40,7 +40,7 @@ def llm() -> RunnableWithFallbacks:
return error_llm.with_fallbacks([pass_llm]) return error_llm.with_fallbacks([pass_llm])
@pytest.fixture() @pytest.fixture
def llm_multi() -> RunnableWithFallbacks: def llm_multi() -> RunnableWithFallbacks:
error_llm = FakeListLLM(responses=["foo"], i=1) error_llm = FakeListLLM(responses=["foo"], i=1)
error_llm_2 = FakeListLLM(responses=["baz"], i=1) error_llm_2 = FakeListLLM(responses=["baz"], i=1)
@ -49,7 +49,7 @@ def llm_multi() -> RunnableWithFallbacks:
return error_llm.with_fallbacks([error_llm_2, pass_llm]) return error_llm.with_fallbacks([error_llm_2, pass_llm])
@pytest.fixture() @pytest.fixture
def chain() -> Runnable: def chain() -> Runnable:
error_llm = FakeListLLM(responses=["foo"], i=1) error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"]) pass_llm = FakeListLLM(responses=["bar"])
@ -70,7 +70,7 @@ def _dont_raise_error(inputs: dict) -> str:
raise ValueError raise ValueError
@pytest.fixture() @pytest.fixture
def chain_pass_exceptions() -> Runnable: def chain_pass_exceptions() -> Runnable:
fallback = RunnableLambda(_dont_raise_error) fallback = RunnableLambda(_dont_raise_error)
return {"text": RunnablePassthrough()} | RunnableLambda( return {"text": RunnablePassthrough()} | RunnableLambda(
@ -107,7 +107,8 @@ def _runnable(inputs: dict) -> str:
if inputs["text"] == "foo": if inputs["text"] == "foo":
return "first" return "first"
if "exception" not in inputs: if "exception" not in inputs:
raise ValueError msg = "missing exception"
raise ValueError(msg)
if inputs["text"] == "bar": if inputs["text"] == "bar":
return "second" return "second"
if isinstance(inputs["exception"], ValueError): if isinstance(inputs["exception"], ValueError):
@ -128,7 +129,7 @@ def test_invoke_with_exception_key() -> None:
runnable_with_single = runnable.with_fallbacks( runnable_with_single = runnable.with_fallbacks(
[runnable], exception_key="exception" [runnable], exception_key="exception"
) )
with pytest.raises(ValueError): with pytest.raises(ValueError, match="missing exception"):
runnable_with_single.invoke({"text": "baz"}) runnable_with_single.invoke({"text": "baz"})
actual = runnable_with_single.invoke({"text": "bar"}) actual = runnable_with_single.invoke({"text": "bar"})
@ -149,7 +150,7 @@ async def test_ainvoke_with_exception_key() -> None:
runnable_with_single = runnable.with_fallbacks( runnable_with_single = runnable.with_fallbacks(
[runnable], exception_key="exception" [runnable], exception_key="exception"
) )
with pytest.raises(ValueError): with pytest.raises(ValueError, match="missing exception"):
await runnable_with_single.ainvoke({"text": "baz"}) await runnable_with_single.ainvoke({"text": "baz"})
actual = await runnable_with_single.ainvoke({"text": "bar"}) actual = await runnable_with_single.ainvoke({"text": "bar"})
@ -166,7 +167,7 @@ async def test_ainvoke_with_exception_key() -> None:
def test_batch() -> None: def test_batch() -> None:
runnable = RunnableLambda(_runnable) runnable = RunnableLambda(_runnable)
with pytest.raises(ValueError): with pytest.raises(ValueError, match="missing exception"):
runnable.batch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}]) runnable.batch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
actual = runnable.batch( actual = runnable.batch(
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True [{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
@ -210,7 +211,7 @@ def test_batch() -> None:
async def test_abatch() -> None: async def test_abatch() -> None:
runnable = RunnableLambda(_runnable) runnable = RunnableLambda(_runnable)
with pytest.raises(ValueError): with pytest.raises(ValueError, match="missing exception"):
await runnable.abatch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}]) await runnable.abatch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
actual = await runnable.abatch( actual = await runnable.abatch(
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True [{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
@ -263,13 +264,15 @@ def _generate(input: Iterator) -> Iterator[str]:
def _generate_immediate_error(input: Iterator) -> Iterator[str]: def _generate_immediate_error(input: Iterator) -> Iterator[str]:
raise ValueError msg = "immmediate error"
raise ValueError(msg)
yield "" yield ""
def _generate_delayed_error(input: Iterator) -> Iterator[str]: def _generate_delayed_error(input: Iterator) -> Iterator[str]:
yield "" yield ""
raise ValueError msg = "delayed error"
raise ValueError(msg)
def test_fallbacks_stream() -> None: def test_fallbacks_stream() -> None:
@ -278,10 +281,10 @@ def test_fallbacks_stream() -> None:
) )
assert list(runnable.stream({})) == list("foo bar") assert list(runnable.stream({})) == list("foo bar")
with pytest.raises(ValueError): runnable = RunnableGenerator(_generate_delayed_error).with_fallbacks(
runnable = RunnableGenerator(_generate_delayed_error).with_fallbacks( [RunnableGenerator(_generate)]
[RunnableGenerator(_generate)] )
) with pytest.raises(ValueError, match="delayed error"):
list(runnable.stream({})) list(runnable.stream({}))
@ -291,13 +294,15 @@ async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]:
async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]: async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
raise ValueError msg = "immmediate error"
raise ValueError(msg)
yield "" yield ""
async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]: async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
yield "" yield ""
raise ValueError msg = "delayed error"
raise ValueError(msg)
async def test_fallbacks_astream() -> None: async def test_fallbacks_astream() -> None:
@ -308,12 +313,11 @@ async def test_fallbacks_astream() -> None:
async for c in runnable.astream({}): async for c in runnable.astream({}):
assert c == next(expected) assert c == next(expected)
with pytest.raises(ValueError): runnable = RunnableGenerator(_agenerate_delayed_error).with_fallbacks(
runnable = RunnableGenerator(_agenerate_delayed_error).with_fallbacks( [RunnableGenerator(_agenerate)]
[RunnableGenerator(_agenerate)] )
) with pytest.raises(ValueError, match="delayed error"):
async for _ in runnable.astream({}): _ = [_ async for _ in runnable.astream({})]
pass
class FakeStructuredOutputModel(BaseChatModel): class FakeStructuredOutputModel(BaseChatModel):

View File

@ -1,3 +1,4 @@
import re
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
@ -864,22 +865,24 @@ def test_get_output_messages_with_value_error() -> None:
"configurable": {"session_id": "1", "message_history": get_session_history("1")} "configurable": {"session_id": "1", "message_history": get_session_history("1")}
} }
with pytest.raises(ValueError) as excinfo: with pytest.raises(
ValueError,
match=re.escape(
"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]."
f" Got {illegal_bool_message}."
),
):
with_history.bound.invoke([HumanMessage(content="hello")], config) with_history.bound.invoke([HumanMessage(content="hello")], config)
excepted = (
"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]."
f" Got {illegal_bool_message}."
)
assert excepted in str(excinfo.value)
illegal_int_message = 123 illegal_int_message = 123
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_int_message) runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_int_message)
with_history = RunnableWithMessageHistory(runnable, get_session_history) with_history = RunnableWithMessageHistory(runnable, get_session_history)
with pytest.raises(ValueError) as excinfo: with pytest.raises(
ValueError,
match=re.escape(
"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]."
f" Got {illegal_int_message}."
),
):
with_history.bound.invoke([HumanMessage(content="hello")], config) with_history.bound.invoke([HumanMessage(content="hello")], config)
excepted = (
"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]."
f" Got {illegal_int_message}."
)
assert excepted in str(excinfo.value)

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import re
import sys import sys
import uuid import uuid
import warnings import warnings
@ -3827,13 +3828,13 @@ def test_retrying(mocker: MockerFixture) -> None:
_lambda_mock = mocker.Mock(side_effect=_lambda) _lambda_mock = mocker.Mock(side_effect=_lambda)
runnable = RunnableLambda(_lambda_mock) runnable = RunnableLambda(_lambda_mock)
with pytest.raises(ValueError): with pytest.raises(ValueError, match="x is 1"):
runnable.invoke(1) runnable.invoke(1)
assert _lambda_mock.call_count == 1 assert _lambda_mock.call_count == 1
_lambda_mock.reset_mock() _lambda_mock.reset_mock()
with pytest.raises(ValueError): with pytest.raises(ValueError, match="x is 1"):
runnable.with_retry( runnable.with_retry(
stop_after_attempt=2, stop_after_attempt=2,
retry_if_exception_type=(ValueError,), retry_if_exception_type=(ValueError,),
@ -3852,7 +3853,7 @@ def test_retrying(mocker: MockerFixture) -> None:
assert _lambda_mock.call_count == 1 # did not retry assert _lambda_mock.call_count == 1 # did not retry
_lambda_mock.reset_mock() _lambda_mock.reset_mock()
with pytest.raises(ValueError): with pytest.raises(ValueError, match="x is 1"):
runnable.with_retry( runnable.with_retry(
stop_after_attempt=2, stop_after_attempt=2,
wait_exponential_jitter=False, wait_exponential_jitter=False,
@ -3892,13 +3893,13 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
_lambda_mock = mocker.Mock(side_effect=_lambda) _lambda_mock = mocker.Mock(side_effect=_lambda)
runnable = RunnableLambda(_lambda_mock) runnable = RunnableLambda(_lambda_mock)
with pytest.raises(ValueError): with pytest.raises(ValueError, match="x is 1"):
await runnable.ainvoke(1) await runnable.ainvoke(1)
assert _lambda_mock.call_count == 1 assert _lambda_mock.call_count == 1
_lambda_mock.reset_mock() _lambda_mock.reset_mock()
with pytest.raises(ValueError): with pytest.raises(ValueError, match="x is 1"):
await runnable.with_retry( await runnable.with_retry(
stop_after_attempt=2, stop_after_attempt=2,
wait_exponential_jitter=False, wait_exponential_jitter=False,
@ -3918,7 +3919,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
assert _lambda_mock.call_count == 1 # did not retry assert _lambda_mock.call_count == 1 # did not retry
_lambda_mock.reset_mock() _lambda_mock.reset_mock()
with pytest.raises(ValueError): with pytest.raises(ValueError, match="x is 1"):
await runnable.with_retry( await runnable.with_retry(
stop_after_attempt=2, stop_after_attempt=2,
wait_exponential_jitter=False, wait_exponential_jitter=False,
@ -3982,9 +3983,8 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
raise ValueError(msg) raise ValueError(msg)
# Check that the chain on error is invoked # Check that the chain on error is invoked
with pytest.raises(ValueError): with pytest.raises(ValueError, match="x is too large"):
for _ in RunnableLambda(raise_value_error).stream(1000, config=config): _ = list(RunnableLambda(raise_value_error).stream(1000, config=config))
pass
assert len(tracer.runs) == 2 assert len(tracer.runs) == 2
assert "ValueError('x is too large')" in str(tracer.runs[1].error) assert "ValueError('x is too large')" in str(tracer.runs[1].error)
@ -4061,9 +4061,13 @@ async def test_runnable_lambda_astream_with_callbacks() -> None:
raise ValueError(msg) raise ValueError(msg)
# Check that the chain on error is invoked # Check that the chain on error is invoked
with pytest.raises(ValueError): with pytest.raises(ValueError, match="x is too large"):
async for _ in RunnableLambda(raise_value_error).astream(1000, config=config): _ = [
pass _
async for _ in RunnableLambda(raise_value_error).astream(
1000, config=config
)
]
assert len(tracer.runs) == 2 assert len(tracer.runs) == 2
assert "ValueError('x is too large')" in str(tracer.runs[1].error) assert "ValueError('x is too large')" in str(tracer.runs[1].error)
@ -4088,7 +4092,11 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
outputs: list[Any] = [] outputs: list[Any] = []
for input in inputs: for input in inputs:
if input.startswith(self.fail_starts_with): if input.startswith(self.fail_starts_with):
outputs.append(ValueError()) outputs.append(
ValueError(
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}"
)
)
else: else:
outputs.append(input + "a") outputs.append(input + "a")
return outputs return outputs
@ -4119,7 +4127,9 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
assert isinstance(chain, RunnableSequence) assert isinstance(chain, RunnableSequence)
# Test batch # Test batch
with pytest.raises(ValueError): with pytest.raises(
ValueError, match=re.escape("ControlledExceptionRunnable(bar) fail for bara")
):
chain.batch(["foo", "bar", "baz", "qux"]) chain.batch(["foo", "bar", "baz", "qux"])
spy = mocker.spy(ControlledExceptionRunnable, "batch") spy = mocker.spy(ControlledExceptionRunnable, "batch")
@ -4155,32 +4165,44 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
parent_run_foo = parent_runs[0] parent_run_foo = parent_runs[0]
assert parent_run_foo.inputs["input"] == "foo" assert parent_run_foo.inputs["input"] == "foo"
assert repr(ValueError()) in str(parent_run_foo.error) assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
parent_run_foo.error
)
assert len(parent_run_foo.child_runs) == 4 assert len(parent_run_foo.child_runs) == 4
assert [r.error for r in parent_run_foo.child_runs[:-1]] == [ assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
None, None,
None, None,
None, None,
] ]
assert repr(ValueError()) in str(parent_run_foo.child_runs[-1].error) assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
parent_run_foo.child_runs[-1].error
)
parent_run_bar = parent_runs[1] parent_run_bar = parent_runs[1]
assert parent_run_bar.inputs["input"] == "bar" assert parent_run_bar.inputs["input"] == "bar"
assert repr(ValueError()) in str(parent_run_bar.error) assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
parent_run_bar.error
)
assert len(parent_run_bar.child_runs) == 2 assert len(parent_run_bar.child_runs) == 2
assert parent_run_bar.child_runs[0].error is None assert parent_run_bar.child_runs[0].error is None
assert repr(ValueError()) in str(parent_run_bar.child_runs[1].error) assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
parent_run_bar.child_runs[1].error
)
parent_run_baz = parent_runs[2] parent_run_baz = parent_runs[2]
assert parent_run_baz.inputs["input"] == "baz" assert parent_run_baz.inputs["input"] == "baz"
assert repr(ValueError()) in str(parent_run_baz.error) assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
parent_run_baz.error
)
assert len(parent_run_baz.child_runs) == 3 assert len(parent_run_baz.child_runs) == 3
assert [r.error for r in parent_run_baz.child_runs[:-1]] == [ assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
None, None,
None, None,
] ]
assert repr(ValueError()) in str(parent_run_baz.child_runs[-1].error) assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
parent_run_baz.child_runs[-1].error
)
parent_run_qux = parent_runs[3] parent_run_qux = parent_runs[3]
assert parent_run_qux.inputs["input"] == "qux" assert parent_run_qux.inputs["input"] == "qux"
@ -4209,7 +4231,11 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
outputs: list[Any] = [] outputs: list[Any] = []
for input in inputs: for input in inputs:
if input.startswith(self.fail_starts_with): if input.startswith(self.fail_starts_with):
outputs.append(ValueError()) outputs.append(
ValueError(
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}"
)
)
else: else:
outputs.append(input + "a") outputs.append(input + "a")
return outputs return outputs
@ -4240,7 +4266,9 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
assert isinstance(chain, RunnableSequence) assert isinstance(chain, RunnableSequence)
# Test abatch # Test abatch
with pytest.raises(ValueError): with pytest.raises(
ValueError, match=re.escape("ControlledExceptionRunnable(bar) fail for bara")
):
await chain.abatch(["foo", "bar", "baz", "qux"]) await chain.abatch(["foo", "bar", "baz", "qux"])
spy = mocker.spy(ControlledExceptionRunnable, "abatch") spy = mocker.spy(ControlledExceptionRunnable, "abatch")
@ -4278,31 +4306,43 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
parent_run_foo = parent_runs[0] parent_run_foo = parent_runs[0]
assert parent_run_foo.inputs["input"] == "foo" assert parent_run_foo.inputs["input"] == "foo"
assert repr(ValueError()) in str(parent_run_foo.error) assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
parent_run_foo.error
)
assert len(parent_run_foo.child_runs) == 4 assert len(parent_run_foo.child_runs) == 4
assert [r.error for r in parent_run_foo.child_runs[:-1]] == [ assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
None, None,
None, None,
None, None,
] ]
assert repr(ValueError()) in str(parent_run_foo.child_runs[-1].error) assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
parent_run_foo.child_runs[-1].error
)
parent_run_bar = parent_runs[1] parent_run_bar = parent_runs[1]
assert parent_run_bar.inputs["input"] == "bar" assert parent_run_bar.inputs["input"] == "bar"
assert repr(ValueError()) in str(parent_run_bar.error) assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
parent_run_bar.error
)
assert len(parent_run_bar.child_runs) == 2 assert len(parent_run_bar.child_runs) == 2
assert parent_run_bar.child_runs[0].error is None assert parent_run_bar.child_runs[0].error is None
assert repr(ValueError()) in str(parent_run_bar.child_runs[1].error) assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
parent_run_bar.child_runs[1].error
)
parent_run_baz = parent_runs[2] parent_run_baz = parent_runs[2]
assert parent_run_baz.inputs["input"] == "baz" assert parent_run_baz.inputs["input"] == "baz"
assert repr(ValueError()) in str(parent_run_baz.error) assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
parent_run_baz.error
)
assert len(parent_run_baz.child_runs) == 3 assert len(parent_run_baz.child_runs) == 3
assert [r.error for r in parent_run_baz.child_runs[:-1]] == [ assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
None, None,
None, None,
] ]
assert repr(ValueError()) in str(parent_run_baz.child_runs[-1].error) assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
parent_run_baz.child_runs[-1].error
)
parent_run_qux = parent_runs[3] parent_run_qux = parent_runs[3]
assert parent_run_qux.inputs["input"] == "qux" assert parent_run_qux.inputs["input"] == "qux"
@ -4319,11 +4359,15 @@ def test_runnable_branch_init() -> None:
condition = RunnableLambda(lambda x: x > 0) condition = RunnableLambda(lambda x: x > 0)
# Test failure with less than 2 branches # Test failure with less than 2 branches
with pytest.raises(ValueError): with pytest.raises(
ValueError, match="RunnableBranch requires at least two branches"
):
RunnableBranch((condition, add)) RunnableBranch((condition, add))
# Test failure with less than 2 branches # Test failure with less than 2 branches
with pytest.raises(ValueError): with pytest.raises(
ValueError, match="RunnableBranch requires at least two branches"
):
RunnableBranch(condition) RunnableBranch(condition)
@ -4408,7 +4452,7 @@ def test_runnable_branch_invoke() -> None:
assert branch.invoke(10) == 100 assert branch.invoke(10) == 100
assert branch.invoke(0) == -1 assert branch.invoke(0) == -1
# Should raise an exception # Should raise an exception
with pytest.raises(ValueError): with pytest.raises(ValueError, match="x is too large"):
branch.invoke(1000) branch.invoke(1000)
@ -4472,7 +4516,7 @@ def test_runnable_branch_invoke_callbacks() -> None:
assert tracer.runs[0].outputs == {"output": 0} assert tracer.runs[0].outputs == {"output": 0}
# Check that the chain on end is invoked # Check that the chain on end is invoked
with pytest.raises(ValueError): with pytest.raises(ValueError, match="x is too large"):
branch.invoke(1000, config={"callbacks": [tracer]}) branch.invoke(1000, config={"callbacks": [tracer]})
assert len(tracer.runs) == 2 assert len(tracer.runs) == 2
@ -4500,7 +4544,7 @@ async def test_runnable_branch_ainvoke_callbacks() -> None:
assert tracer.runs[0].outputs == {"output": 0} assert tracer.runs[0].outputs == {"output": 0}
# Check that the chain on end is invoked # Check that the chain on end is invoked
with pytest.raises(ValueError): with pytest.raises(ValueError, match="x is too large"):
await branch.ainvoke(1000, config={"callbacks": [tracer]}) await branch.ainvoke(1000, config={"callbacks": [tracer]})
assert len(tracer.runs) == 2 assert len(tracer.runs) == 2
@ -4561,9 +4605,8 @@ def test_runnable_branch_stream_with_callbacks() -> None:
assert tracer.runs[0].outputs == {"output": llm_res} assert tracer.runs[0].outputs == {"output": llm_res}
# Verify that the chain on error is invoked # Verify that the chain on error is invoked
with pytest.raises(ValueError): with pytest.raises(ValueError, match="x is error"):
for _ in branch.stream("error", config=config): _ = list(branch.stream("error", config=config))
pass
assert len(tracer.runs) == 2 assert len(tracer.runs) == 2
assert "ValueError('x is error')" in str(tracer.runs[1].error) assert "ValueError('x is error')" in str(tracer.runs[1].error)
@ -4638,9 +4681,8 @@ async def test_runnable_branch_astream_with_callbacks() -> None:
assert tracer.runs[0].outputs == {"output": llm_res} assert tracer.runs[0].outputs == {"output": llm_res}
# Verify that the chain on error is invoked # Verify that the chain on error is invoked
with pytest.raises(ValueError): with pytest.raises(ValueError, match="x is error"):
async for _ in branch.astream("error", config=config): _ = [_ async for _ in branch.astream("error", config=config)]
pass
assert len(tracer.runs) == 2 assert len(tracer.runs) == 2
assert "ValueError('x is error')" in str(tracer.runs[1].error) assert "ValueError('x is error')" in str(tracer.runs[1].error)

View File

@ -1824,8 +1824,7 @@ async def test_runnable_each() -> None:
assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4] assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4]
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
async for _ in add_one_map.astream_events([1, 2, 3], version="v1"): _ = [_ async for _ in add_one_map.astream_events([1, 2, 3], version="v1")]
pass
async def test_events_astream_config() -> None: async def test_events_astream_config() -> None:

View File

@ -1773,8 +1773,7 @@ async def test_runnable_each() -> None:
assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4] assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4]
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
async for _ in add_one_map.astream_events([1, 2, 3], version="v2"): _ = [_ async for _ in add_one_map.astream_events([1, 2, 3], version="v2")]
pass
async def test_events_astream_config() -> None: async def test_events_astream_config() -> None:

View File

@ -419,7 +419,7 @@ class TestRunnableSequenceParallelTraceNesting:
self._check_posts() self._check_posts()
@pytest.mark.parametrize("parent_type", ("ls", "lc")) @pytest.mark.parametrize("parent_type", ["ls", "lc"])
def test_tree_is_constructed(parent_type: Literal["ls", "lc"]) -> None: def test_tree_is_constructed(parent_type: Literal["ls", "lc"]) -> None:
mock_session = MagicMock() mock_session = MagicMock()
mock_client_ = Client( mock_client_ = Client(

View File

@ -15,7 +15,7 @@ from langchain_core.runnables.utils import (
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"func, expected_source", ("func", "expected_source"),
[ [
(lambda x: x * 2, "lambda x: x * 2"), (lambda x: x * 2, "lambda x: x * 2"),
(lambda a, b: a + b, "lambda a, b: a + b"), (lambda a, b: a + b, "lambda a, b: a + b"),
@ -29,7 +29,7 @@ def test_get_lambda_source(func: Callable, expected_source: str) -> None:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"text,prefix,expected_output", ("text", "prefix", "expected_output"),
[ [
("line 1\nline 2\nline 3", "1", "line 1\n line 2\n line 3"), ("line 1\nline 2\nline 3", "1", "line 1\n line 2\n line 3"),
("line 1\nline 2\nline 3", "ax", "line 1\n line 2\n line 3"), ("line 1\nline 2\nline 3", "ax", "line 1\n line 2\n line 3"),

View File

@ -184,7 +184,9 @@ def test_chat_message_chunks() -> None:
"ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk" "ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk"
) )
with pytest.raises(ValueError): with pytest.raises(
ValueError, match="Cannot concatenate ChatMessageChunks with different roles."
):
ChatMessageChunk(role="User", content="I am") + ChatMessageChunk( ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
role="Assistant", content=" indeed." role="Assistant", content=" indeed."
) )
@ -290,7 +292,10 @@ def test_function_message_chunks() -> None:
id="ai5", name="hello", content="I am indeed." id="ai5", name="hello", content="I am indeed."
), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk" ), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk"
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="Cannot concatenate FunctionMessageChunks with different names.",
):
FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk( FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
name="bye", content=" indeed." name="bye", content=" indeed."
) )
@ -303,7 +308,10 @@ def test_ai_message_chunks() -> None:
"AIMessageChunk + AIMessageChunk should be a AIMessageChunk" "AIMessageChunk + AIMessageChunk should be a AIMessageChunk"
) )
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="Cannot concatenate AIMessageChunks with different example values.",
):
AIMessageChunk(example=True, content="I am") + AIMessageChunk( AIMessageChunk(example=True, content="I am") + AIMessageChunk(
example=False, content=" indeed." example=False, content=" indeed."
) )
@ -320,30 +328,21 @@ class TestGetBufferString(unittest.TestCase):
self.tool_calls_msg = AIMessage(content="tool") self.tool_calls_msg = AIMessage(content="tool")
def test_empty_input(self) -> None: def test_empty_input(self) -> None:
self.assertEqual(get_buffer_string([]), "") assert get_buffer_string([]) == ""
def test_valid_single_message(self) -> None: def test_valid_single_message(self) -> None:
expected_output = f"Human: {self.human_msg.content}" expected_output = f"Human: {self.human_msg.content}"
self.assertEqual( assert get_buffer_string([self.human_msg]) == expected_output
get_buffer_string([self.human_msg]),
expected_output,
)
def test_custom_human_prefix(self) -> None: def test_custom_human_prefix(self) -> None:
prefix = "H" prefix = "H"
expected_output = f"{prefix}: {self.human_msg.content}" expected_output = f"{prefix}: {self.human_msg.content}"
self.assertEqual( assert get_buffer_string([self.human_msg], human_prefix="H") == expected_output
get_buffer_string([self.human_msg], human_prefix="H"),
expected_output,
)
def test_custom_ai_prefix(self) -> None: def test_custom_ai_prefix(self) -> None:
prefix = "A" prefix = "A"
expected_output = f"{prefix}: {self.ai_msg.content}" expected_output = f"{prefix}: {self.ai_msg.content}"
self.assertEqual( assert get_buffer_string([self.ai_msg], ai_prefix="A") == expected_output
get_buffer_string([self.ai_msg], ai_prefix="A"),
expected_output,
)
def test_multiple_msg(self) -> None: def test_multiple_msg(self) -> None:
msgs = [ msgs = [
@ -366,10 +365,7 @@ class TestGetBufferString(unittest.TestCase):
"AI: tool", "AI: tool",
] ]
) )
self.assertEqual( assert get_buffer_string(msgs) == expected_output
get_buffer_string(msgs),
expected_output,
)
def test_multiple_msg() -> None: def test_multiple_msg() -> None:
@ -991,7 +987,7 @@ def test_tool_message_str() -> None:
@pytest.mark.parametrize( @pytest.mark.parametrize(
["first", "others", "expected"], ("first", "others", "expected"),
[ [
("", [""], ""), ("", [""], ""),
("", [[]], [""]), ("", [[]], [""]),

View File

@ -775,8 +775,8 @@ def test_exception_handling_callable() -> None:
def test_exception_handling_non_tool_exception() -> None: def test_exception_handling_non_tool_exception() -> None:
_tool = _FakeExceptionTool(exception=ValueError()) _tool = _FakeExceptionTool(exception=ValueError("some error"))
with pytest.raises(ValueError): with pytest.raises(ValueError, match="some error"):
_tool.run({}) _tool.run({})
@ -806,8 +806,8 @@ async def test_async_exception_handling_callable() -> None:
async def test_async_exception_handling_non_tool_exception() -> None: async def test_async_exception_handling_non_tool_exception() -> None:
_tool = _FakeExceptionTool(exception=ValueError()) _tool = _FakeExceptionTool(exception=ValueError("some error"))
with pytest.raises(ValueError): with pytest.raises(ValueError, match="some error"):
await _tool.arun({}) await _tool.arun({})
@ -987,7 +987,7 @@ def test_optional_subset_model_rewrite() -> None:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"inputs, expected", ("inputs", "expected"),
[ [
# Check not required # Check not required
({"bar": "bar"}, {"bar": "bar", "baz": 3, "buzz": "buzz"}), ({"bar": "bar"}, {"bar": "bar", "baz": 3, "buzz": "buzz"}),
@ -1323,6 +1323,10 @@ def test_tool_invalid_docstrings() -> None:
""" """
return bar return bar
for func in {foo3, foo4}:
with pytest.raises(ValueError, match="Found invalid Google-Style docstring."):
_ = tool(func, parse_docstring=True)
def foo5(bar: str, baz: int) -> str: def foo5(bar: str, baz: int) -> str:
"""The foo. """The foo.
@ -1332,9 +1336,10 @@ def test_tool_invalid_docstrings() -> None:
""" """
return bar return bar
for func in [foo3, foo4, foo5]: with pytest.raises(
with pytest.raises(ValueError): ValueError, match="Arg banana in docstring not found in function signature."
_ = tool(func, parse_docstring=True) ):
_ = tool(foo5, parse_docstring=True)
def test_tool_annotated_descriptions() -> None: def test_tool_annotated_descriptions() -> None:
@ -2004,9 +2009,9 @@ def test__is_message_content_block(obj: Any, expected: bool) -> None:
@pytest.mark.parametrize( @pytest.mark.parametrize(
("obj", "expected"), ("obj", "expected"),
[ [
["foo", True], ("foo", True),
[valid_tool_result_blocks, True], (valid_tool_result_blocks, True),
[invalid_tool_result_blocks, False], (invalid_tool_result_blocks, False),
], ],
) )
def test__is_message_content_type(obj: Any, expected: bool) -> None: def test__is_message_content_type(obj: Any, expected: bool) -> None:
@ -2268,7 +2273,8 @@ def test_imports() -> None:
"InjectedToolArg", "InjectedToolArg",
] ]
for module_name in expected_all: for module_name in expected_all:
assert hasattr(tools, module_name) and getattr(tools, module_name) is not None assert hasattr(tools, module_name)
assert getattr(tools, module_name) is not None
def test_structured_tool_direct_init() -> None: def test_structured_tool_direct_init() -> None:
@ -2317,7 +2323,11 @@ def test_tool_injected_tool_call_id() -> None:
} }
) == ToolMessage(0, tool_call_id="bar") # type: ignore ) == ToolMessage(0, tool_call_id="bar") # type: ignore
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="When tool includes an InjectedToolCallId argument, "
"tool must always be invoked with a full model ToolCall",
):
assert foo.invoke({"x": 0}) assert foo.invoke({"x": 0})
@tool @tool
@ -2341,7 +2351,7 @@ def test_tool_uninjected_tool_call_id() -> None:
"""Foo.""" """Foo."""
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
with pytest.raises(ValueError): with pytest.raises(ValueError, match="1 validation error for foo"):
foo.invoke({"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}) foo.invoke({"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"})
assert foo.invoke( assert foo.invoke(

View File

@ -22,6 +22,5 @@ def test_public_api() -> None:
# Assert that the object is actually present in the schema module # Assert that the object is actually present in the schema module
for module_name in expected_all: for module_name in expected_all:
assert ( assert hasattr(schemas, module_name)
hasattr(schemas, module_name) and getattr(schemas, module_name) is not None assert getattr(schemas, module_name) is not None
)

View File

@ -6,7 +6,7 @@ from langchain_core.utils.aiter import abatch_iterate
@pytest.mark.parametrize( @pytest.mark.parametrize(
"input_size, input_iterable, expected_output", ("input_size", "input_iterable", "expected_output"),
[ [
(2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]), (2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
(3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]), (3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),

View File

@ -51,7 +51,12 @@ def test_get_from_dict_or_env() -> None:
# Not the most obvious behavior, but # Not the most obvious behavior, but
# this is how it works right now # this is how it works right now
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="Did not find not exists, "
"please add an environment variable `__SOME_KEY_IN_ENV` which contains it, "
"or pass `not exists` as a named parameter.",
):
assert ( assert (
get_from_dict_or_env( get_from_dict_or_env(
{ {

View File

@ -37,7 +37,7 @@ from langchain_core.utils.function_calling import (
) )
@pytest.fixture() @pytest.fixture
def pydantic() -> type[BaseModel]: def pydantic() -> type[BaseModel]:
class dummy_function(BaseModel): # noqa: N801 class dummy_function(BaseModel): # noqa: N801
"""Dummy function.""" """Dummy function."""
@ -48,7 +48,7 @@ def pydantic() -> type[BaseModel]:
return dummy_function return dummy_function
@pytest.fixture() @pytest.fixture
def annotated_function() -> Callable: def annotated_function() -> Callable:
def dummy_function( def dummy_function(
arg1: ExtensionsAnnotated[int, "foo"], arg1: ExtensionsAnnotated[int, "foo"],
@ -59,7 +59,7 @@ def annotated_function() -> Callable:
return dummy_function return dummy_function
@pytest.fixture() @pytest.fixture
def function() -> Callable: def function() -> Callable:
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None: def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""Dummy function. """Dummy function.
@ -72,7 +72,7 @@ def function() -> Callable:
return dummy_function return dummy_function
@pytest.fixture() @pytest.fixture
def function_docstring_annotations() -> Callable: def function_docstring_annotations() -> Callable:
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None: def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""Dummy function. """Dummy function.
@ -85,7 +85,7 @@ def function_docstring_annotations() -> Callable:
return dummy_function return dummy_function
@pytest.fixture() @pytest.fixture
def runnable() -> Runnable: def runnable() -> Runnable:
class Args(ExtensionsTypedDict): class Args(ExtensionsTypedDict):
arg1: ExtensionsAnnotated[int, "foo"] arg1: ExtensionsAnnotated[int, "foo"]
@ -97,7 +97,7 @@ def runnable() -> Runnable:
return RunnableLambda(dummy_function) return RunnableLambda(dummy_function)
@pytest.fixture() @pytest.fixture
def dummy_tool() -> BaseTool: def dummy_tool() -> BaseTool:
class Schema(BaseModel): class Schema(BaseModel):
arg1: int = Field(..., description="foo") arg1: int = Field(..., description="foo")
@ -114,7 +114,7 @@ def dummy_tool() -> BaseTool:
return DummyFunction() return DummyFunction()
@pytest.fixture() @pytest.fixture
def dummy_structured_tool() -> StructuredTool: def dummy_structured_tool() -> StructuredTool:
class Schema(BaseModel): class Schema(BaseModel):
arg1: int = Field(..., description="foo") arg1: int = Field(..., description="foo")
@ -128,7 +128,7 @@ def dummy_structured_tool() -> StructuredTool:
) )
@pytest.fixture() @pytest.fixture
def dummy_structured_tool_args_schema_dict() -> StructuredTool: def dummy_structured_tool_args_schema_dict() -> StructuredTool:
args_schema = { args_schema = {
"type": "object", "type": "object",
@ -150,7 +150,7 @@ def dummy_structured_tool_args_schema_dict() -> StructuredTool:
) )
@pytest.fixture() @pytest.fixture
def dummy_pydantic() -> type[BaseModel]: def dummy_pydantic() -> type[BaseModel]:
class dummy_function(BaseModel): # noqa: N801 class dummy_function(BaseModel): # noqa: N801
"""Dummy function.""" """Dummy function."""
@ -161,7 +161,7 @@ def dummy_pydantic() -> type[BaseModel]:
return dummy_function return dummy_function
@pytest.fixture() @pytest.fixture
def dummy_pydantic_v2() -> type[BaseModelV2Maybe]: def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
class dummy_function(BaseModelV2Maybe): # noqa: N801 class dummy_function(BaseModelV2Maybe): # noqa: N801
"""Dummy function.""" """Dummy function."""
@ -174,7 +174,7 @@ def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
return dummy_function return dummy_function
@pytest.fixture() @pytest.fixture
def dummy_typing_typed_dict() -> type: def dummy_typing_typed_dict() -> type:
class dummy_function(TypingTypedDict): # noqa: N801 class dummy_function(TypingTypedDict): # noqa: N801
"""Dummy function.""" """Dummy function."""
@ -185,7 +185,7 @@ def dummy_typing_typed_dict() -> type:
return dummy_function return dummy_function
@pytest.fixture() @pytest.fixture
def dummy_typing_typed_dict_docstring() -> type: def dummy_typing_typed_dict_docstring() -> type:
class dummy_function(TypingTypedDict): # noqa: N801 class dummy_function(TypingTypedDict): # noqa: N801
"""Dummy function. """Dummy function.
@ -201,7 +201,7 @@ def dummy_typing_typed_dict_docstring() -> type:
return dummy_function return dummy_function
@pytest.fixture() @pytest.fixture
def dummy_extensions_typed_dict() -> type: def dummy_extensions_typed_dict() -> type:
class dummy_function(ExtensionsTypedDict): # noqa: N801 class dummy_function(ExtensionsTypedDict): # noqa: N801
"""Dummy function.""" """Dummy function."""
@ -212,7 +212,7 @@ def dummy_extensions_typed_dict() -> type:
return dummy_function return dummy_function
@pytest.fixture() @pytest.fixture
def dummy_extensions_typed_dict_docstring() -> type: def dummy_extensions_typed_dict_docstring() -> type:
class dummy_function(ExtensionsTypedDict): # noqa: N801 class dummy_function(ExtensionsTypedDict): # noqa: N801
"""Dummy function. """Dummy function.
@ -228,7 +228,7 @@ def dummy_extensions_typed_dict_docstring() -> type:
return dummy_function return dummy_function
@pytest.fixture() @pytest.fixture
def json_schema() -> dict: def json_schema() -> dict:
return { return {
"title": "dummy_function", "title": "dummy_function",
@ -246,7 +246,7 @@ def json_schema() -> dict:
} }
@pytest.fixture() @pytest.fixture
def anthropic_tool() -> dict: def anthropic_tool() -> dict:
return { return {
"name": "dummy_function", "name": "dummy_function",
@ -266,7 +266,7 @@ def anthropic_tool() -> dict:
} }
@pytest.fixture() @pytest.fixture
def bedrock_converse_tool() -> dict: def bedrock_converse_tool() -> dict:
return { return {
"toolSpec": { "toolSpec": {

View File

@ -4,7 +4,7 @@ from langchain_core.utils.iter import batch_iterate
@pytest.mark.parametrize( @pytest.mark.parametrize(
"input_size, input_iterable, expected_output", ("input_size", "input_iterable", "expected_output"),
[ [
(2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]), (2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
(3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]), (3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),

View File

@ -147,7 +147,7 @@ def test_dereference_refs_remote_ref() -> None:
"first_name": {"$ref": "https://somewhere/else/name"}, "first_name": {"$ref": "https://somewhere/else/name"},
}, },
} }
with pytest.raises(ValueError): with pytest.raises(ValueError, match="ref paths are expected to be URI fragments"):
dereference_refs(schema) dereference_refs(schema)

View File

@ -220,7 +220,7 @@ output5 = {
@pytest.mark.parametrize( @pytest.mark.parametrize(
"schema, output", ("schema", "output"),
[ [
(schema1, output1), (schema1, output1),
(schema2, output2), (schema2, output2),

View File

@ -29,12 +29,17 @@ def test_dict_int_op_nested() -> None:
def test_dict_int_op_max_depth_exceeded() -> None: def test_dict_int_op_max_depth_exceeded() -> None:
left = {"a": {"b": {"c": 1}}} left = {"a": {"b": {"c": 1}}}
right = {"a": {"b": {"c": 2}}} right = {"a": {"b": {"c": 2}}}
with pytest.raises(ValueError): with pytest.raises(
ValueError, match="max_depth=2 exceeded, unable to combine dicts."
):
_dict_int_op(left, right, operator.add, max_depth=2) _dict_int_op(left, right, operator.add, max_depth=2)
def test_dict_int_op_invalid_types() -> None: def test_dict_int_op_invalid_types() -> None:
left = {"a": 1, "b": "string"} left = {"a": 1, "b": "string"}
right = {"a": 2, "b": 3} right = {"a": 2, "b": 3}
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="Only dict and int values are supported.",
):
_dict_int_op(left, right, operator.add) _dict_int_op(left, right, operator.add)

View File

@ -46,7 +46,7 @@ def test_check_package_version(
@pytest.mark.parametrize( @pytest.mark.parametrize(
("left", "right", "expected"), ("left", "right", "expected"),
( [
# Merge `None` and `1`. # Merge `None` and `1`.
({"a": None}, {"a": 1}, {"a": 1}), ({"a": None}, {"a": 1}, {"a": 1}),
# Merge `1` and `None`. # Merge `1` and `None`.
@ -111,7 +111,7 @@ def test_check_package_version(
{"a": [{"idx": 0, "b": "f"}]}, {"a": [{"idx": 0, "b": "f"}]},
{"a": [{"idx": 0, "b": "{"}, {"idx": 0, "b": "f"}]}, {"a": [{"idx": 0, "b": "{"}, {"idx": 0, "b": "f"}]},
), ),
), ],
) )
def test_merge_dicts( def test_merge_dicts(
left: dict, right: dict, expected: Union[dict, AbstractContextManager] left: dict, right: dict, expected: Union[dict, AbstractContextManager]
@ -130,7 +130,7 @@ def test_merge_dicts(
@pytest.mark.parametrize( @pytest.mark.parametrize(
("left", "right", "expected"), ("left", "right", "expected"),
( [
# 'type' special key handling # 'type' special key handling
({"type": "foo"}, {"type": "foo"}, {"type": "foo"}), ({"type": "foo"}, {"type": "foo"}, {"type": "foo"}),
( (
@ -138,7 +138,7 @@ def test_merge_dicts(
{"type": "bar"}, {"type": "bar"},
pytest.raises(ValueError, match="Unable to merge."), pytest.raises(ValueError, match="Unable to merge."),
), ),
), ],
) )
@pytest.mark.xfail(reason="Refactors to make in 0.3") @pytest.mark.xfail(reason="Refactors to make in 0.3")
def test_merge_dicts_0_3( def test_merge_dicts_0_3(
@ -183,36 +183,32 @@ def test_guard_import(
@pytest.mark.parametrize( @pytest.mark.parametrize(
("module_name", "pip_name", "package"), ("module_name", "pip_name", "package", "expected_pip_name"),
[ [
("langchain_core.utilsW", None, None), ("langchain_core.utilsW", None, None, "langchain-core"),
("langchain_core.utilsW", "langchain-core-2", None), ("langchain_core.utilsW", "langchain-core-2", None, "langchain-core-2"),
("langchain_core.utilsW", None, "langchain-coreWX"), ("langchain_core.utilsW", None, "langchain-coreWX", "langchain-core"),
("langchain_core.utilsW", "langchain-core-2", "langchain-coreWX"), (
("langchain_coreW", None, None), # ModuleNotFoundError "langchain_core.utilsW",
"langchain-core-2",
"langchain-coreWX",
"langchain-core-2",
),
("langchain_coreW", None, None, "langchain-coreW"), # ModuleNotFoundError
], ],
) )
def test_guard_import_failure( def test_guard_import_failure(
module_name: str, pip_name: Optional[str], package: Optional[str] module_name: str,
pip_name: Optional[str],
package: Optional[str],
expected_pip_name: str,
) -> None: ) -> None:
with pytest.raises(ImportError) as exc_info: with pytest.raises(
if package is None and pip_name is None: ImportError,
guard_import(module_name) match=f"Could not import {module_name} python package. "
elif package is None and pip_name is not None: f"Please install it with `pip install {expected_pip_name}`.",
guard_import(module_name, pip_name=pip_name) ):
elif package is not None and pip_name is None: guard_import(module_name, pip_name=pip_name, package=package)
guard_import(module_name, package=package)
elif package is not None and pip_name is not None:
guard_import(module_name, pip_name=pip_name, package=package)
else:
msg = "Invalid test case"
raise ValueError(msg)
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
err_msg = (
f"Could not import {module_name} python package. "
f"Please install it with `pip install {pip_name}`."
)
assert exc_info.value.msg == err_msg
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2") @pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")