core: Add ruff rules PT (pytest) (#29381)

See https://docs.astral.sh/ruff/rules/#flake8-pytest-style-pt
This commit is contained in:
Christophe Bornet 2025-04-01 19:31:07 +02:00 committed by GitHub
parent 6896c863e8
commit 8a33402016
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 379 additions and 227 deletions

View File

@ -103,7 +103,6 @@ ignore = [
"PLC",
"PLE",
"PLR",
"PT",
"PYI",
"RET",
"RUF",

View File

@ -9,7 +9,7 @@ from langchain_core._api.beta_decorator import beta, warn_beta
@pytest.mark.parametrize(
"kwargs, expected_message",
("kwargs", "expected_message"),
[
(
{

View File

@ -13,7 +13,7 @@ from langchain_core._api.deprecation import (
@pytest.mark.parametrize(
"kwargs, expected_message",
("kwargs", "expected_message"),
[
(
{
@ -404,7 +404,9 @@ def test_deprecated_method_pydantic() -> None:
def test_raise_error_for_bad_decorator() -> None:
"""Verify that errors raised on init rather than on use."""
# Should not specify both `alternative` and `alternative_import`
with pytest.raises(ValueError):
with pytest.raises(
ValueError, match="Cannot specify both alternative and alternative_import"
):
@deprecated(since="2.0.0", alternative="NewClass", alternative_import="hello")
def deprecated_function() -> str:

View File

@ -28,7 +28,7 @@ def test_initialization() -> None:
assert cache_with_maxsize._cache == {}
assert cache_with_maxsize._maxsize == 2
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="maxsize must be greater than 0"):
InMemoryCache(maxsize=0)

View File

@ -6,7 +6,6 @@ from uuid import UUID
import pytest
from blockbuster import BlockBuster, blockbuster_ctx
from pytest import Config, Function, Parser
from pytest_mock import MockerFixture
@ -36,7 +35,7 @@ def blockbuster() -> Iterator[BlockBuster]:
yield bb
def pytest_addoption(parser: Parser) -> None:
def pytest_addoption(parser: pytest.Parser) -> None:
"""Add custom command line options to pytest."""
parser.addoption(
"--only-extended",
@ -50,7 +49,9 @@ def pytest_addoption(parser: Parser) -> None:
)
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
def pytest_collection_modifyitems(
config: pytest.Config, items: Sequence[pytest.Function]
) -> None:
"""Add implementations for handling custom markers.
At the moment, this adds support for a custom `requires` marker.
@ -118,7 +119,7 @@ def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) ->
)
@pytest.fixture()
@pytest.fixture
def deterministic_uuids(mocker: MockerFixture) -> MockerFixture:
side_effect = (
UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)

View File

@ -16,16 +16,16 @@ from langchain_core.indexing.in_memory import (
class TestDocumentIndexerTestSuite(DocumentIndexerTestSuite):
@pytest.fixture()
@pytest.fixture
def index(self) -> Generator[DocumentIndex, None, None]:
yield InMemoryDocumentIndex()
yield InMemoryDocumentIndex() # noqa: PT022
class TestAsyncDocumentIndexerTestSuite(AsyncDocumentIndexTestSuite):
# Something funky is going on with mypy and async pytest fixture
@pytest.fixture()
@pytest.fixture
async def index(self) -> AsyncGenerator[DocumentIndex, None]: # type: ignore
yield InMemoryDocumentIndex()
yield InMemoryDocumentIndex() # noqa: PT022
def test_sync_retriever() -> None:

View File

@ -7,7 +7,7 @@ import pytest_asyncio
from langchain_core.indexing import InMemoryRecordManager
@pytest.fixture()
@pytest.fixture
def manager() -> InMemoryRecordManager:
"""Initialize the test database and yield the TimestampedSet instance."""
# Initialize and yield the TimestampedSet instance

View File

@ -466,11 +466,19 @@ def test_incremental_fails_with_bad_source_ids(
]
)
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Source id key is required when cleanup mode is "
"incremental or scoped_full.",
):
# Should raise an error because no source id function was specified
index(loader, record_manager, vector_store, cleanup="incremental")
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Source ids are required when cleanup mode "
"is incremental or scoped_full.",
):
# Should raise an error because no source id function was specified
index(
loader,
@ -502,7 +510,11 @@ async def test_aincremental_fails_with_bad_source_ids(
]
)
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Source id key is required when cleanup mode "
"is incremental or scoped_full.",
):
# Should raise an error because no source id function was specified
await aindex(
loader,
@ -511,7 +523,11 @@ async def test_aincremental_fails_with_bad_source_ids(
cleanup="incremental",
)
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Source ids are required when cleanup mode "
"is incremental or scoped_full.",
):
# Should raise an error because no source id function was specified
await aindex(
loader,
@ -771,11 +787,19 @@ def test_scoped_full_fails_with_bad_source_ids(
]
)
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Source id key is required when cleanup mode "
"is incremental or scoped_full.",
):
# Should raise an error because no source id function was specified
index(loader, record_manager, vector_store, cleanup="scoped_full")
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Source ids are required when cleanup mode "
"is incremental or scoped_full.",
):
# Should raise an error because no source id function was specified
index(
loader,
@ -807,11 +831,19 @@ async def test_ascoped_full_fails_with_bad_source_ids(
]
)
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Source id key is required when cleanup mode "
"is incremental or scoped_full.",
):
# Should raise an error because no source id function was specified
await aindex(loader, arecord_manager, vector_store, cleanup="scoped_full")
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Source ids are required when cleanup mode "
"is incremental or scoped_full.",
):
# Should raise an error because no source id function was specified
await aindex(
loader,

View File

@ -99,6 +99,7 @@ async def test_async_batch_size(messages: list, messages_2: list) -> None:
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
@pytest.mark.xfail(reason="This test is failing due to a bug in the testing code")
async def test_stream_error_callback() -> None:
message = "test"

View File

@ -1,5 +1,6 @@
import base64
import json
import re
import typing
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
@ -513,7 +514,14 @@ def test_trim_messages_bound_model_token_counter() -> None:
def test_trim_messages_bad_token_counter() -> None:
trimmer = trim_messages(max_tokens=10, token_counter={})
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match=re.escape(
"'token_counter' expected to be a model that implements "
"'get_num_tokens_from_messages()' or a function. "
"Received object of type <class 'dict'>."
),
):
trimmer.invoke([HumanMessage("foobar")])
@ -852,7 +860,9 @@ def test_convert_to_messages_openai_refusal() -> None:
assert actual == expected
# Raises error if content is missing.
with pytest.raises(ValueError):
with pytest.raises(
ValueError, match="Message dict must contain 'role' and 'content' keys"
):
convert_to_messages([{"role": "assistant", "refusal": "9.1"}])

View File

@ -157,9 +157,10 @@ def test_pydantic_output_parser_fail() -> None:
pydantic_object=TestModel
)
with pytest.raises(OutputParserException) as e:
with pytest.raises(
OutputParserException, match="Failed to parse TestModel from completion"
):
pydantic_parser.parse(DEF_RESULT_FAIL)
assert "Failed to parse TestModel from completion" in str(e)
def test_pydantic_output_parser_type_inference() -> None:

View File

@ -1,3 +1,4 @@
import re
import warnings
from pathlib import Path
from typing import Any, Union, cast
@ -165,15 +166,14 @@ def test_create_system_message_prompt_list_template_partial_variables_not_null()
{variables}
"""
try:
graph_analyst_template = SystemMessagePromptTemplate.from_template(
with pytest.raises(
ValueError, match="Partial variables are not supported for list of templates."
):
_ = SystemMessagePromptTemplate.from_template(
template=[graph_creator_content1, graph_creator_content2],
input_variables=["variables"],
partial_variables={"variables": "foo"},
)
graph_analyst_template.format(variables="foo")
except ValueError as e:
assert str(e) == "Partial variables are not supported for list of templates."
def test_message_prompt_template_from_template_file() -> None:
@ -330,7 +330,7 @@ def test_chat_prompt_template_from_messages_jinja2() -> None:
@pytest.mark.requires("jinja2")
@pytest.mark.parametrize(
"template_format,image_type_placeholder,image_data_placeholder",
("template_format", "image_type_placeholder", "image_data_placeholder"),
[
("f-string", "{image_type}", "{image_data}"),
("mustache", "{{image_type}}", "{{image_data}}"),
@ -393,7 +393,12 @@ def test_chat_prompt_template_with_messages(
def test_chat_invalid_input_variables_extra() -> None:
messages = [HumanMessage(content="foo")]
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match=re.escape(
"Got mismatched input_variables. Expected: set(). Got: ['foo']"
),
):
ChatPromptTemplate(
messages=messages, # type: ignore[arg-type]
input_variables=["foo"],
@ -407,7 +412,10 @@ def test_chat_invalid_input_variables_extra() -> None:
def test_chat_invalid_input_variables_missing() -> None:
messages = [HumanMessagePromptTemplate.from_template("{foo}")]
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match=re.escape("Got mismatched input_variables. Expected: {'foo'}. Got: []"),
):
ChatPromptTemplate(
messages=messages, # type: ignore[arg-type]
input_variables=[],
@ -481,7 +489,7 @@ async def test_chat_from_role_strings() -> None:
@pytest.mark.parametrize(
"args,expected",
("args", "expected"),
[
(
("human", "{question}"),
@ -551,7 +559,7 @@ def test_chat_prompt_template_append_and_extend() -> None:
def test_convert_to_message_is_strict() -> None:
"""Verify that _convert_to_message is strict."""
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Unexpected message type: meow."):
# meow does not correspond to a valid message type.
# this test is here to ensure that functionality to interpret `meow`
# as a role is NOT added.
@ -762,14 +770,20 @@ async def test_chat_tmpl_from_messages_multipart_formatting_with_path() -> None:
),
]
)
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Loading images from 'path' has been removed as of 0.3.15 for security reasons.",
):
template.format_messages(
name="R2D2",
in_mem=in_mem,
file_path="some/path",
)
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Loading images from 'path' has been removed as of 0.3.15 for security reasons.",
):
await template.aformat_messages(
name="R2D2",
in_mem=in_mem,
@ -869,10 +883,10 @@ def test_chat_prompt_message_dict() -> None:
HumanMessage(content="bar"),
]
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Invalid template: False"):
ChatPromptTemplate([{"role": "system", "content": False}])
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Unexpected message type: foo."):
ChatPromptTemplate([{"role": "foo", "content": "foo"}])

View File

@ -1,5 +1,6 @@
"""Test few shot prompt template."""
import re
from collections.abc import Sequence
from typing import Any
@ -24,7 +25,7 @@ EXAMPLE_PROMPT = PromptTemplate(
)
@pytest.fixture()
@pytest.fixture
@pytest.mark.requires("jinja2")
def example_jinja2_prompt() -> tuple[PromptTemplate, list[dict[str, str]]]:
example_template = "{{ word }}: {{ antonym }}"
@ -74,7 +75,10 @@ def test_prompt_missing_input_variables() -> None:
"""Test error is raised when input variables are not provided."""
# Test when missing in suffix
template = "This is a {foo} test."
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match=re.escape("check for mismatched or missing input parameters from []"),
):
FewShotPromptTemplate(
input_variables=[],
suffix=template,
@ -91,7 +95,10 @@ def test_prompt_missing_input_variables() -> None:
# Test when missing in prefix
template = "This is a {foo} test."
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match=re.escape("check for mismatched or missing input parameters from []"),
):
FewShotPromptTemplate(
input_variables=[],
suffix="foo",

View File

@ -1,5 +1,7 @@
"""Test few shot prompt template."""
import re
import pytest
from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates
@ -58,7 +60,10 @@ def test_prompttemplate_validation() -> None:
{"question": "foo", "answer": "bar"},
{"question": "baz", "answer": "foo"},
]
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match=re.escape("Got input_variables=[], but based on prefix/suffix expected"),
):
FewShotPromptWithTemplates(
suffix=suffix,
prefix=prefix,

View File

@ -1,5 +1,6 @@
"""Test functionality related to prompts."""
import re
from typing import Any, Union
from unittest import mock
@ -264,7 +265,10 @@ def test_prompt_missing_input_variables() -> None:
"""Test error is raised when input variables are not provided."""
template = "This is a {foo} test."
input_variables: list = []
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match=re.escape("check for mismatched or missing input parameters from []"),
):
PromptTemplate(
input_variables=input_variables, template=template, validate_template=True
)
@ -275,7 +279,10 @@ def test_prompt_missing_input_variables() -> None:
def test_prompt_empty_input_variable() -> None:
"""Test error is raised when empty string input variable."""
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match=re.escape("check for mismatched or missing input parameters from ['']"),
):
PromptTemplate(input_variables=[""], template="{}", validate_template=True)
@ -283,7 +290,13 @@ def test_prompt_wrong_input_variables() -> None:
"""Test error is raised when name of input variable is wrong."""
template = "This is a {foo} test."
input_variables = ["bar"]
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match=re.escape(
"Invalid prompt schema; "
"check for mismatched or missing input parameters from ['bar']"
),
):
PromptTemplate(
input_variables=input_variables, template=template, validate_template=True
)
@ -330,7 +343,7 @@ def test_prompt_invalid_template_format() -> None:
"""Test initializing a prompt with invalid template format."""
template = "This is a {foo} test."
input_variables = ["foo"]
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Unsupported template format: bar"):
PromptTemplate(
input_variables=input_variables,
template=template,
@ -580,7 +593,7 @@ async def test_prompt_ainvoke_with_metadata() -> None:
@pytest.mark.parametrize(
"value, expected",
("value", "expected"),
[
("0", "0"),
(0, "0"),

View File

@ -330,7 +330,7 @@ test_cases = [
]
@pytest.mark.parametrize("runnable, cases", test_cases)
@pytest.mark.parametrize(("runnable", "cases"), test_cases)
def test_context_runnables(
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
) -> None:
@ -342,7 +342,7 @@ def test_context_runnables(
assert add(runnable.stream(cases[0].input)) == cases[0].output
@pytest.mark.parametrize("runnable, cases", test_cases)
@pytest.mark.parametrize(("runnable", "cases"), test_cases)
async def test_context_runnables_async(
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
) -> None:
@ -357,14 +357,19 @@ async def test_context_runnables_async(
def test_runnable_context_seq_key_not_found() -> None:
seq: Runnable = {"bar": Context.setter("input")} | Context.getter("foo")
with pytest.raises(ValueError):
with pytest.raises(
ValueError, match="Expected exactly one setter for context key foo"
):
seq.invoke("foo")
def test_runnable_context_seq_key_order() -> None:
seq: Runnable = {"bar": Context.getter("foo")} | Context.setter("foo")
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Context setter for key foo must be defined after all getters.",
):
seq.invoke("foo")
@ -374,7 +379,9 @@ def test_runnable_context_deadlock() -> None:
"foo": Context.setter("foo") | Context.getter("input"),
} | RunnablePassthrough()
with pytest.raises(ValueError):
with pytest.raises(
ValueError, match="Deadlock detected between context keys foo and input"
):
seq.invoke("foo")
@ -383,7 +390,9 @@ def test_runnable_context_seq_key_circular_ref() -> None:
"bar": Context.setter(input=Context.getter("input"))
} | Context.getter("foo")
with pytest.raises(ValueError):
with pytest.raises(
ValueError, match="Circular reference in context setter for key input"
):
seq.invoke("foo")

View File

@ -32,7 +32,7 @@ from langchain_core.runnables import (
from langchain_core.tools import BaseTool
@pytest.fixture()
@pytest.fixture
def llm() -> RunnableWithFallbacks:
error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"])
@ -40,7 +40,7 @@ def llm() -> RunnableWithFallbacks:
return error_llm.with_fallbacks([pass_llm])
@pytest.fixture()
@pytest.fixture
def llm_multi() -> RunnableWithFallbacks:
error_llm = FakeListLLM(responses=["foo"], i=1)
error_llm_2 = FakeListLLM(responses=["baz"], i=1)
@ -49,7 +49,7 @@ def llm_multi() -> RunnableWithFallbacks:
return error_llm.with_fallbacks([error_llm_2, pass_llm])
@pytest.fixture()
@pytest.fixture
def chain() -> Runnable:
error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"])
@ -70,7 +70,7 @@ def _dont_raise_error(inputs: dict) -> str:
raise ValueError
@pytest.fixture()
@pytest.fixture
def chain_pass_exceptions() -> Runnable:
fallback = RunnableLambda(_dont_raise_error)
return {"text": RunnablePassthrough()} | RunnableLambda(
@ -107,7 +107,8 @@ def _runnable(inputs: dict) -> str:
if inputs["text"] == "foo":
return "first"
if "exception" not in inputs:
raise ValueError
msg = "missing exception"
raise ValueError(msg)
if inputs["text"] == "bar":
return "second"
if isinstance(inputs["exception"], ValueError):
@ -128,7 +129,7 @@ def test_invoke_with_exception_key() -> None:
runnable_with_single = runnable.with_fallbacks(
[runnable], exception_key="exception"
)
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="missing exception"):
runnable_with_single.invoke({"text": "baz"})
actual = runnable_with_single.invoke({"text": "bar"})
@ -149,7 +150,7 @@ async def test_ainvoke_with_exception_key() -> None:
runnable_with_single = runnable.with_fallbacks(
[runnable], exception_key="exception"
)
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="missing exception"):
await runnable_with_single.ainvoke({"text": "baz"})
actual = await runnable_with_single.ainvoke({"text": "bar"})
@ -166,7 +167,7 @@ async def test_ainvoke_with_exception_key() -> None:
def test_batch() -> None:
runnable = RunnableLambda(_runnable)
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="missing exception"):
runnable.batch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
actual = runnable.batch(
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
@ -210,7 +211,7 @@ def test_batch() -> None:
async def test_abatch() -> None:
runnable = RunnableLambda(_runnable)
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="missing exception"):
await runnable.abatch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
actual = await runnable.abatch(
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
@ -263,13 +264,15 @@ def _generate(input: Iterator) -> Iterator[str]:
def _generate_immediate_error(input: Iterator) -> Iterator[str]:
raise ValueError
msg = "immmediate error"
raise ValueError(msg)
yield ""
def _generate_delayed_error(input: Iterator) -> Iterator[str]:
yield ""
raise ValueError
msg = "delayed error"
raise ValueError(msg)
def test_fallbacks_stream() -> None:
@ -278,10 +281,10 @@ def test_fallbacks_stream() -> None:
)
assert list(runnable.stream({})) == list("foo bar")
with pytest.raises(ValueError):
runnable = RunnableGenerator(_generate_delayed_error).with_fallbacks(
[RunnableGenerator(_generate)]
)
with pytest.raises(ValueError, match="delayed error"):
list(runnable.stream({}))
@ -291,13 +294,15 @@ async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]:
async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
raise ValueError
msg = "immmediate error"
raise ValueError(msg)
yield ""
async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
yield ""
raise ValueError
msg = "delayed error"
raise ValueError(msg)
async def test_fallbacks_astream() -> None:
@ -308,12 +313,11 @@ async def test_fallbacks_astream() -> None:
async for c in runnable.astream({}):
assert c == next(expected)
with pytest.raises(ValueError):
runnable = RunnableGenerator(_agenerate_delayed_error).with_fallbacks(
[RunnableGenerator(_agenerate)]
)
async for _ in runnable.astream({}):
pass
with pytest.raises(ValueError, match="delayed error"):
_ = [_ async for _ in runnable.astream({})]
class FakeStructuredOutputModel(BaseChatModel):

View File

@ -1,3 +1,4 @@
import re
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
@ -864,22 +865,24 @@ def test_get_output_messages_with_value_error() -> None:
"configurable": {"session_id": "1", "message_history": get_session_history("1")}
}
with pytest.raises(ValueError) as excinfo:
with_history.bound.invoke([HumanMessage(content="hello")], config)
excepted = (
with pytest.raises(
ValueError,
match=re.escape(
"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]."
f" Got {illegal_bool_message}."
)
assert excepted in str(excinfo.value)
),
):
with_history.bound.invoke([HumanMessage(content="hello")], config)
illegal_int_message = 123
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_int_message)
with_history = RunnableWithMessageHistory(runnable, get_session_history)
with pytest.raises(ValueError) as excinfo:
with_history.bound.invoke([HumanMessage(content="hello")], config)
excepted = (
with pytest.raises(
ValueError,
match=re.escape(
"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]."
f" Got {illegal_int_message}."
)
assert excepted in str(excinfo.value)
),
):
with_history.bound.invoke([HumanMessage(content="hello")], config)

View File

@ -1,4 +1,5 @@
import asyncio
import re
import sys
import uuid
import warnings
@ -3827,13 +3828,13 @@ def test_retrying(mocker: MockerFixture) -> None:
_lambda_mock = mocker.Mock(side_effect=_lambda)
runnable = RunnableLambda(_lambda_mock)
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="x is 1"):
runnable.invoke(1)
assert _lambda_mock.call_count == 1
_lambda_mock.reset_mock()
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="x is 1"):
runnable.with_retry(
stop_after_attempt=2,
retry_if_exception_type=(ValueError,),
@ -3852,7 +3853,7 @@ def test_retrying(mocker: MockerFixture) -> None:
assert _lambda_mock.call_count == 1 # did not retry
_lambda_mock.reset_mock()
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="x is 1"):
runnable.with_retry(
stop_after_attempt=2,
wait_exponential_jitter=False,
@ -3892,13 +3893,13 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
_lambda_mock = mocker.Mock(side_effect=_lambda)
runnable = RunnableLambda(_lambda_mock)
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="x is 1"):
await runnable.ainvoke(1)
assert _lambda_mock.call_count == 1
_lambda_mock.reset_mock()
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="x is 1"):
await runnable.with_retry(
stop_after_attempt=2,
wait_exponential_jitter=False,
@ -3918,7 +3919,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
assert _lambda_mock.call_count == 1 # did not retry
_lambda_mock.reset_mock()
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="x is 1"):
await runnable.with_retry(
stop_after_attempt=2,
wait_exponential_jitter=False,
@ -3982,9 +3983,8 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
raise ValueError(msg)
# Check that the chain on error is invoked
with pytest.raises(ValueError):
for _ in RunnableLambda(raise_value_error).stream(1000, config=config):
pass
with pytest.raises(ValueError, match="x is too large"):
_ = list(RunnableLambda(raise_value_error).stream(1000, config=config))
assert len(tracer.runs) == 2
assert "ValueError('x is too large')" in str(tracer.runs[1].error)
@ -4061,9 +4061,13 @@ async def test_runnable_lambda_astream_with_callbacks() -> None:
raise ValueError(msg)
# Check that the chain on error is invoked
with pytest.raises(ValueError):
async for _ in RunnableLambda(raise_value_error).astream(1000, config=config):
pass
with pytest.raises(ValueError, match="x is too large"):
_ = [
_
async for _ in RunnableLambda(raise_value_error).astream(
1000, config=config
)
]
assert len(tracer.runs) == 2
assert "ValueError('x is too large')" in str(tracer.runs[1].error)
@ -4088,7 +4092,11 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
outputs: list[Any] = []
for input in inputs:
if input.startswith(self.fail_starts_with):
outputs.append(ValueError())
outputs.append(
ValueError(
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}"
)
)
else:
outputs.append(input + "a")
return outputs
@ -4119,7 +4127,9 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
assert isinstance(chain, RunnableSequence)
# Test batch
with pytest.raises(ValueError):
with pytest.raises(
ValueError, match=re.escape("ControlledExceptionRunnable(bar) fail for bara")
):
chain.batch(["foo", "bar", "baz", "qux"])
spy = mocker.spy(ControlledExceptionRunnable, "batch")
@ -4155,32 +4165,44 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
parent_run_foo = parent_runs[0]
assert parent_run_foo.inputs["input"] == "foo"
assert repr(ValueError()) in str(parent_run_foo.error)
assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
parent_run_foo.error
)
assert len(parent_run_foo.child_runs) == 4
assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
None,
None,
None,
]
assert repr(ValueError()) in str(parent_run_foo.child_runs[-1].error)
assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
parent_run_foo.child_runs[-1].error
)
parent_run_bar = parent_runs[1]
assert parent_run_bar.inputs["input"] == "bar"
assert repr(ValueError()) in str(parent_run_bar.error)
assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
parent_run_bar.error
)
assert len(parent_run_bar.child_runs) == 2
assert parent_run_bar.child_runs[0].error is None
assert repr(ValueError()) in str(parent_run_bar.child_runs[1].error)
assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
parent_run_bar.child_runs[1].error
)
parent_run_baz = parent_runs[2]
assert parent_run_baz.inputs["input"] == "baz"
assert repr(ValueError()) in str(parent_run_baz.error)
assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
parent_run_baz.error
)
assert len(parent_run_baz.child_runs) == 3
assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
None,
None,
]
assert repr(ValueError()) in str(parent_run_baz.child_runs[-1].error)
assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
parent_run_baz.child_runs[-1].error
)
parent_run_qux = parent_runs[3]
assert parent_run_qux.inputs["input"] == "qux"
@ -4209,7 +4231,11 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
outputs: list[Any] = []
for input in inputs:
if input.startswith(self.fail_starts_with):
outputs.append(ValueError())
outputs.append(
ValueError(
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}"
)
)
else:
outputs.append(input + "a")
return outputs
@ -4240,7 +4266,9 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
assert isinstance(chain, RunnableSequence)
# Test abatch
with pytest.raises(ValueError):
with pytest.raises(
ValueError, match=re.escape("ControlledExceptionRunnable(bar) fail for bara")
):
await chain.abatch(["foo", "bar", "baz", "qux"])
spy = mocker.spy(ControlledExceptionRunnable, "abatch")
@ -4278,31 +4306,43 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
parent_run_foo = parent_runs[0]
assert parent_run_foo.inputs["input"] == "foo"
assert repr(ValueError()) in str(parent_run_foo.error)
assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
parent_run_foo.error
)
assert len(parent_run_foo.child_runs) == 4
assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
None,
None,
None,
]
assert repr(ValueError()) in str(parent_run_foo.child_runs[-1].error)
assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
parent_run_foo.child_runs[-1].error
)
parent_run_bar = parent_runs[1]
assert parent_run_bar.inputs["input"] == "bar"
assert repr(ValueError()) in str(parent_run_bar.error)
assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
parent_run_bar.error
)
assert len(parent_run_bar.child_runs) == 2
assert parent_run_bar.child_runs[0].error is None
assert repr(ValueError()) in str(parent_run_bar.child_runs[1].error)
assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
parent_run_bar.child_runs[1].error
)
parent_run_baz = parent_runs[2]
assert parent_run_baz.inputs["input"] == "baz"
assert repr(ValueError()) in str(parent_run_baz.error)
assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
parent_run_baz.error
)
assert len(parent_run_baz.child_runs) == 3
assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
None,
None,
]
assert repr(ValueError()) in str(parent_run_baz.child_runs[-1].error)
assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
parent_run_baz.child_runs[-1].error
)
parent_run_qux = parent_runs[3]
assert parent_run_qux.inputs["input"] == "qux"
@ -4319,11 +4359,15 @@ def test_runnable_branch_init() -> None:
condition = RunnableLambda(lambda x: x > 0)
# Test failure with less than 2 branches
with pytest.raises(ValueError):
with pytest.raises(
ValueError, match="RunnableBranch requires at least two branches"
):
RunnableBranch((condition, add))
# Test failure with less than 2 branches
with pytest.raises(ValueError):
with pytest.raises(
ValueError, match="RunnableBranch requires at least two branches"
):
RunnableBranch(condition)
@ -4408,7 +4452,7 @@ def test_runnable_branch_invoke() -> None:
assert branch.invoke(10) == 100
assert branch.invoke(0) == -1
# Should raise an exception
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="x is too large"):
branch.invoke(1000)
@ -4472,7 +4516,7 @@ def test_runnable_branch_invoke_callbacks() -> None:
assert tracer.runs[0].outputs == {"output": 0}
# Check that the chain on end is invoked
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="x is too large"):
branch.invoke(1000, config={"callbacks": [tracer]})
assert len(tracer.runs) == 2
@ -4500,7 +4544,7 @@ async def test_runnable_branch_ainvoke_callbacks() -> None:
assert tracer.runs[0].outputs == {"output": 0}
# Check that the chain on end is invoked
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="x is too large"):
await branch.ainvoke(1000, config={"callbacks": [tracer]})
assert len(tracer.runs) == 2
@ -4561,9 +4605,8 @@ def test_runnable_branch_stream_with_callbacks() -> None:
assert tracer.runs[0].outputs == {"output": llm_res}
# Verify that the chain on error is invoked
with pytest.raises(ValueError):
for _ in branch.stream("error", config=config):
pass
with pytest.raises(ValueError, match="x is error"):
_ = list(branch.stream("error", config=config))
assert len(tracer.runs) == 2
assert "ValueError('x is error')" in str(tracer.runs[1].error)
@ -4638,9 +4681,8 @@ async def test_runnable_branch_astream_with_callbacks() -> None:
assert tracer.runs[0].outputs == {"output": llm_res}
# Verify that the chain on error is invoked
with pytest.raises(ValueError):
async for _ in branch.astream("error", config=config):
pass
with pytest.raises(ValueError, match="x is error"):
_ = [_ async for _ in branch.astream("error", config=config)]
assert len(tracer.runs) == 2
assert "ValueError('x is error')" in str(tracer.runs[1].error)

View File

@ -1824,8 +1824,7 @@ async def test_runnable_each() -> None:
assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4]
with pytest.raises(NotImplementedError):
async for _ in add_one_map.astream_events([1, 2, 3], version="v1"):
pass
_ = [_ async for _ in add_one_map.astream_events([1, 2, 3], version="v1")]
async def test_events_astream_config() -> None:

View File

@ -1773,8 +1773,7 @@ async def test_runnable_each() -> None:
assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4]
with pytest.raises(NotImplementedError):
async for _ in add_one_map.astream_events([1, 2, 3], version="v2"):
pass
_ = [_ async for _ in add_one_map.astream_events([1, 2, 3], version="v2")]
async def test_events_astream_config() -> None:

View File

@ -419,7 +419,7 @@ class TestRunnableSequenceParallelTraceNesting:
self._check_posts()
@pytest.mark.parametrize("parent_type", ("ls", "lc"))
@pytest.mark.parametrize("parent_type", ["ls", "lc"])
def test_tree_is_constructed(parent_type: Literal["ls", "lc"]) -> None:
mock_session = MagicMock()
mock_client_ = Client(

View File

@ -15,7 +15,7 @@ from langchain_core.runnables.utils import (
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
@pytest.mark.parametrize(
"func, expected_source",
("func", "expected_source"),
[
(lambda x: x * 2, "lambda x: x * 2"),
(lambda a, b: a + b, "lambda a, b: a + b"),
@ -29,7 +29,7 @@ def test_get_lambda_source(func: Callable, expected_source: str) -> None:
@pytest.mark.parametrize(
"text,prefix,expected_output",
("text", "prefix", "expected_output"),
[
("line 1\nline 2\nline 3", "1", "line 1\n line 2\n line 3"),
("line 1\nline 2\nline 3", "ax", "line 1\n line 2\n line 3"),

View File

@ -184,7 +184,9 @@ def test_chat_message_chunks() -> None:
"ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk"
)
with pytest.raises(ValueError):
with pytest.raises(
ValueError, match="Cannot concatenate ChatMessageChunks with different roles."
):
ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
role="Assistant", content=" indeed."
)
@ -290,7 +292,10 @@ def test_function_message_chunks() -> None:
id="ai5", name="hello", content="I am indeed."
), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk"
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Cannot concatenate FunctionMessageChunks with different names.",
):
FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
name="bye", content=" indeed."
)
@ -303,7 +308,10 @@ def test_ai_message_chunks() -> None:
"AIMessageChunk + AIMessageChunk should be a AIMessageChunk"
)
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Cannot concatenate AIMessageChunks with different example values.",
):
AIMessageChunk(example=True, content="I am") + AIMessageChunk(
example=False, content=" indeed."
)
@ -320,30 +328,21 @@ class TestGetBufferString(unittest.TestCase):
self.tool_calls_msg = AIMessage(content="tool")
def test_empty_input(self) -> None:
self.assertEqual(get_buffer_string([]), "")
assert get_buffer_string([]) == ""
def test_valid_single_message(self) -> None:
expected_output = f"Human: {self.human_msg.content}"
self.assertEqual(
get_buffer_string([self.human_msg]),
expected_output,
)
assert get_buffer_string([self.human_msg]) == expected_output
def test_custom_human_prefix(self) -> None:
prefix = "H"
expected_output = f"{prefix}: {self.human_msg.content}"
self.assertEqual(
get_buffer_string([self.human_msg], human_prefix="H"),
expected_output,
)
assert get_buffer_string([self.human_msg], human_prefix="H") == expected_output
def test_custom_ai_prefix(self) -> None:
prefix = "A"
expected_output = f"{prefix}: {self.ai_msg.content}"
self.assertEqual(
get_buffer_string([self.ai_msg], ai_prefix="A"),
expected_output,
)
assert get_buffer_string([self.ai_msg], ai_prefix="A") == expected_output
def test_multiple_msg(self) -> None:
msgs = [
@ -366,10 +365,7 @@ class TestGetBufferString(unittest.TestCase):
"AI: tool",
]
)
self.assertEqual(
get_buffer_string(msgs),
expected_output,
)
assert get_buffer_string(msgs) == expected_output
def test_multiple_msg() -> None:
@ -991,7 +987,7 @@ def test_tool_message_str() -> None:
@pytest.mark.parametrize(
["first", "others", "expected"],
("first", "others", "expected"),
[
("", [""], ""),
("", [[]], [""]),

View File

@ -775,8 +775,8 @@ def test_exception_handling_callable() -> None:
def test_exception_handling_non_tool_exception() -> None:
_tool = _FakeExceptionTool(exception=ValueError())
with pytest.raises(ValueError):
_tool = _FakeExceptionTool(exception=ValueError("some error"))
with pytest.raises(ValueError, match="some error"):
_tool.run({})
@ -806,8 +806,8 @@ async def test_async_exception_handling_callable() -> None:
async def test_async_exception_handling_non_tool_exception() -> None:
_tool = _FakeExceptionTool(exception=ValueError())
with pytest.raises(ValueError):
_tool = _FakeExceptionTool(exception=ValueError("some error"))
with pytest.raises(ValueError, match="some error"):
await _tool.arun({})
@ -987,7 +987,7 @@ def test_optional_subset_model_rewrite() -> None:
@pytest.mark.parametrize(
"inputs, expected",
("inputs", "expected"),
[
# Check not required
({"bar": "bar"}, {"bar": "bar", "baz": 3, "buzz": "buzz"}),
@ -1323,6 +1323,10 @@ def test_tool_invalid_docstrings() -> None:
"""
return bar
for func in {foo3, foo4}:
with pytest.raises(ValueError, match="Found invalid Google-Style docstring."):
_ = tool(func, parse_docstring=True)
def foo5(bar: str, baz: int) -> str:
"""The foo.
@ -1332,9 +1336,10 @@ def test_tool_invalid_docstrings() -> None:
"""
return bar
for func in [foo3, foo4, foo5]:
with pytest.raises(ValueError):
_ = tool(func, parse_docstring=True)
with pytest.raises(
ValueError, match="Arg banana in docstring not found in function signature."
):
_ = tool(foo5, parse_docstring=True)
def test_tool_annotated_descriptions() -> None:
@ -2004,9 +2009,9 @@ def test__is_message_content_block(obj: Any, expected: bool) -> None:
@pytest.mark.parametrize(
("obj", "expected"),
[
["foo", True],
[valid_tool_result_blocks, True],
[invalid_tool_result_blocks, False],
("foo", True),
(valid_tool_result_blocks, True),
(invalid_tool_result_blocks, False),
],
)
def test__is_message_content_type(obj: Any, expected: bool) -> None:
@ -2268,7 +2273,8 @@ def test_imports() -> None:
"InjectedToolArg",
]
for module_name in expected_all:
assert hasattr(tools, module_name) and getattr(tools, module_name) is not None
assert hasattr(tools, module_name)
assert getattr(tools, module_name) is not None
def test_structured_tool_direct_init() -> None:
@ -2317,7 +2323,11 @@ def test_tool_injected_tool_call_id() -> None:
}
) == ToolMessage(0, tool_call_id="bar") # type: ignore
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="When tool includes an InjectedToolCallId argument, "
"tool must always be invoked with a full model ToolCall",
):
assert foo.invoke({"x": 0})
@tool
@ -2341,7 +2351,7 @@ def test_tool_uninjected_tool_call_id() -> None:
"""Foo."""
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="1 validation error for foo"):
foo.invoke({"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"})
assert foo.invoke(

View File

@ -22,6 +22,5 @@ def test_public_api() -> None:
# Assert that the object is actually present in the schema module
for module_name in expected_all:
assert (
hasattr(schemas, module_name) and getattr(schemas, module_name) is not None
)
assert hasattr(schemas, module_name)
assert getattr(schemas, module_name) is not None

View File

@ -6,7 +6,7 @@ from langchain_core.utils.aiter import abatch_iterate
@pytest.mark.parametrize(
"input_size, input_iterable, expected_output",
("input_size", "input_iterable", "expected_output"),
[
(2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
(3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),

View File

@ -51,7 +51,12 @@ def test_get_from_dict_or_env() -> None:
# Not the most obvious behavior, but
# this is how it works right now
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Did not find not exists, "
"please add an environment variable `__SOME_KEY_IN_ENV` which contains it, "
"or pass `not exists` as a named parameter.",
):
assert (
get_from_dict_or_env(
{

View File

@ -37,7 +37,7 @@ from langchain_core.utils.function_calling import (
)
@pytest.fixture()
@pytest.fixture
def pydantic() -> type[BaseModel]:
class dummy_function(BaseModel): # noqa: N801
"""Dummy function."""
@ -48,7 +48,7 @@ def pydantic() -> type[BaseModel]:
return dummy_function
@pytest.fixture()
@pytest.fixture
def annotated_function() -> Callable:
def dummy_function(
arg1: ExtensionsAnnotated[int, "foo"],
@ -59,7 +59,7 @@ def annotated_function() -> Callable:
return dummy_function
@pytest.fixture()
@pytest.fixture
def function() -> Callable:
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""Dummy function.
@ -72,7 +72,7 @@ def function() -> Callable:
return dummy_function
@pytest.fixture()
@pytest.fixture
def function_docstring_annotations() -> Callable:
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""Dummy function.
@ -85,7 +85,7 @@ def function_docstring_annotations() -> Callable:
return dummy_function
@pytest.fixture()
@pytest.fixture
def runnable() -> Runnable:
class Args(ExtensionsTypedDict):
arg1: ExtensionsAnnotated[int, "foo"]
@ -97,7 +97,7 @@ def runnable() -> Runnable:
return RunnableLambda(dummy_function)
@pytest.fixture()
@pytest.fixture
def dummy_tool() -> BaseTool:
class Schema(BaseModel):
arg1: int = Field(..., description="foo")
@ -114,7 +114,7 @@ def dummy_tool() -> BaseTool:
return DummyFunction()
@pytest.fixture()
@pytest.fixture
def dummy_structured_tool() -> StructuredTool:
class Schema(BaseModel):
arg1: int = Field(..., description="foo")
@ -128,7 +128,7 @@ def dummy_structured_tool() -> StructuredTool:
)
@pytest.fixture()
@pytest.fixture
def dummy_structured_tool_args_schema_dict() -> StructuredTool:
args_schema = {
"type": "object",
@ -150,7 +150,7 @@ def dummy_structured_tool_args_schema_dict() -> StructuredTool:
)
@pytest.fixture()
@pytest.fixture
def dummy_pydantic() -> type[BaseModel]:
class dummy_function(BaseModel): # noqa: N801
"""Dummy function."""
@ -161,7 +161,7 @@ def dummy_pydantic() -> type[BaseModel]:
return dummy_function
@pytest.fixture()
@pytest.fixture
def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
class dummy_function(BaseModelV2Maybe): # noqa: N801
"""Dummy function."""
@ -174,7 +174,7 @@ def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
return dummy_function
@pytest.fixture()
@pytest.fixture
def dummy_typing_typed_dict() -> type:
class dummy_function(TypingTypedDict): # noqa: N801
"""Dummy function."""
@ -185,7 +185,7 @@ def dummy_typing_typed_dict() -> type:
return dummy_function
@pytest.fixture()
@pytest.fixture
def dummy_typing_typed_dict_docstring() -> type:
class dummy_function(TypingTypedDict): # noqa: N801
"""Dummy function.
@ -201,7 +201,7 @@ def dummy_typing_typed_dict_docstring() -> type:
return dummy_function
@pytest.fixture()
@pytest.fixture
def dummy_extensions_typed_dict() -> type:
class dummy_function(ExtensionsTypedDict): # noqa: N801
"""Dummy function."""
@ -212,7 +212,7 @@ def dummy_extensions_typed_dict() -> type:
return dummy_function
@pytest.fixture()
@pytest.fixture
def dummy_extensions_typed_dict_docstring() -> type:
class dummy_function(ExtensionsTypedDict): # noqa: N801
"""Dummy function.
@ -228,7 +228,7 @@ def dummy_extensions_typed_dict_docstring() -> type:
return dummy_function
@pytest.fixture()
@pytest.fixture
def json_schema() -> dict:
return {
"title": "dummy_function",
@ -246,7 +246,7 @@ def json_schema() -> dict:
}
@pytest.fixture()
@pytest.fixture
def anthropic_tool() -> dict:
return {
"name": "dummy_function",
@ -266,7 +266,7 @@ def anthropic_tool() -> dict:
}
@pytest.fixture()
@pytest.fixture
def bedrock_converse_tool() -> dict:
return {
"toolSpec": {

View File

@ -4,7 +4,7 @@ from langchain_core.utils.iter import batch_iterate
@pytest.mark.parametrize(
"input_size, input_iterable, expected_output",
("input_size", "input_iterable", "expected_output"),
[
(2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
(3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),

View File

@ -147,7 +147,7 @@ def test_dereference_refs_remote_ref() -> None:
"first_name": {"$ref": "https://somewhere/else/name"},
},
}
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="ref paths are expected to be URI fragments"):
dereference_refs(schema)

View File

@ -220,7 +220,7 @@ output5 = {
@pytest.mark.parametrize(
"schema, output",
("schema", "output"),
[
(schema1, output1),
(schema2, output2),

View File

@ -29,12 +29,17 @@ def test_dict_int_op_nested() -> None:
def test_dict_int_op_max_depth_exceeded() -> None:
left = {"a": {"b": {"c": 1}}}
right = {"a": {"b": {"c": 2}}}
with pytest.raises(ValueError):
with pytest.raises(
ValueError, match="max_depth=2 exceeded, unable to combine dicts."
):
_dict_int_op(left, right, operator.add, max_depth=2)
def test_dict_int_op_invalid_types() -> None:
left = {"a": 1, "b": "string"}
right = {"a": 2, "b": 3}
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Only dict and int values are supported.",
):
_dict_int_op(left, right, operator.add)

View File

@ -46,7 +46,7 @@ def test_check_package_version(
@pytest.mark.parametrize(
("left", "right", "expected"),
(
[
# Merge `None` and `1`.
({"a": None}, {"a": 1}, {"a": 1}),
# Merge `1` and `None`.
@ -111,7 +111,7 @@ def test_check_package_version(
{"a": [{"idx": 0, "b": "f"}]},
{"a": [{"idx": 0, "b": "{"}, {"idx": 0, "b": "f"}]},
),
),
],
)
def test_merge_dicts(
left: dict, right: dict, expected: Union[dict, AbstractContextManager]
@ -130,7 +130,7 @@ def test_merge_dicts(
@pytest.mark.parametrize(
("left", "right", "expected"),
(
[
# 'type' special key handling
({"type": "foo"}, {"type": "foo"}, {"type": "foo"}),
(
@ -138,7 +138,7 @@ def test_merge_dicts(
{"type": "bar"},
pytest.raises(ValueError, match="Unable to merge."),
),
),
],
)
@pytest.mark.xfail(reason="Refactors to make in 0.3")
def test_merge_dicts_0_3(
@ -183,36 +183,32 @@ def test_guard_import(
@pytest.mark.parametrize(
("module_name", "pip_name", "package"),
("module_name", "pip_name", "package", "expected_pip_name"),
[
("langchain_core.utilsW", None, None),
("langchain_core.utilsW", "langchain-core-2", None),
("langchain_core.utilsW", None, "langchain-coreWX"),
("langchain_core.utilsW", "langchain-core-2", "langchain-coreWX"),
("langchain_coreW", None, None), # ModuleNotFoundError
("langchain_core.utilsW", None, None, "langchain-core"),
("langchain_core.utilsW", "langchain-core-2", None, "langchain-core-2"),
("langchain_core.utilsW", None, "langchain-coreWX", "langchain-core"),
(
"langchain_core.utilsW",
"langchain-core-2",
"langchain-coreWX",
"langchain-core-2",
),
("langchain_coreW", None, None, "langchain-coreW"), # ModuleNotFoundError
],
)
def test_guard_import_failure(
module_name: str, pip_name: Optional[str], package: Optional[str]
module_name: str,
pip_name: Optional[str],
package: Optional[str],
expected_pip_name: str,
) -> None:
with pytest.raises(ImportError) as exc_info:
if package is None and pip_name is None:
guard_import(module_name)
elif package is None and pip_name is not None:
guard_import(module_name, pip_name=pip_name)
elif package is not None and pip_name is None:
guard_import(module_name, package=package)
elif package is not None and pip_name is not None:
with pytest.raises(
ImportError,
match=f"Could not import {module_name} python package. "
f"Please install it with `pip install {expected_pip_name}`.",
):
guard_import(module_name, pip_name=pip_name, package=package)
else:
msg = "Invalid test case"
raise ValueError(msg)
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
err_msg = (
f"Could not import {module_name} python package. "
f"Please install it with `pip install {pip_name}`."
)
assert exc_info.value.msg == err_msg
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")