mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-29 18:08:36 +00:00
core: Add ruff rules PT (pytest) (#29381)
See https://docs.astral.sh/ruff/rules/#flake8-pytest-style-pt
This commit is contained in:
parent
6896c863e8
commit
8a33402016
@ -103,7 +103,6 @@ ignore = [
|
|||||||
"PLC",
|
"PLC",
|
||||||
"PLE",
|
"PLE",
|
||||||
"PLR",
|
"PLR",
|
||||||
"PT",
|
|
||||||
"PYI",
|
"PYI",
|
||||||
"RET",
|
"RET",
|
||||||
"RUF",
|
"RUF",
|
||||||
|
@ -9,7 +9,7 @@ from langchain_core._api.beta_decorator import beta, warn_beta
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"kwargs, expected_message",
|
("kwargs", "expected_message"),
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
{
|
{
|
||||||
|
@ -13,7 +13,7 @@ from langchain_core._api.deprecation import (
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"kwargs, expected_message",
|
("kwargs", "expected_message"),
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
{
|
{
|
||||||
@ -404,7 +404,9 @@ def test_deprecated_method_pydantic() -> None:
|
|||||||
def test_raise_error_for_bad_decorator() -> None:
|
def test_raise_error_for_bad_decorator() -> None:
|
||||||
"""Verify that errors raised on init rather than on use."""
|
"""Verify that errors raised on init rather than on use."""
|
||||||
# Should not specify both `alternative` and `alternative_import`
|
# Should not specify both `alternative` and `alternative_import`
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError, match="Cannot specify both alternative and alternative_import"
|
||||||
|
):
|
||||||
|
|
||||||
@deprecated(since="2.0.0", alternative="NewClass", alternative_import="hello")
|
@deprecated(since="2.0.0", alternative="NewClass", alternative_import="hello")
|
||||||
def deprecated_function() -> str:
|
def deprecated_function() -> str:
|
||||||
|
@ -28,7 +28,7 @@ def test_initialization() -> None:
|
|||||||
assert cache_with_maxsize._cache == {}
|
assert cache_with_maxsize._cache == {}
|
||||||
assert cache_with_maxsize._maxsize == 2
|
assert cache_with_maxsize._maxsize == 2
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="maxsize must be greater than 0"):
|
||||||
InMemoryCache(maxsize=0)
|
InMemoryCache(maxsize=0)
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,7 +6,6 @@ from uuid import UUID
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from blockbuster import BlockBuster, blockbuster_ctx
|
from blockbuster import BlockBuster, blockbuster_ctx
|
||||||
from pytest import Config, Function, Parser
|
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
|
||||||
@ -36,7 +35,7 @@ def blockbuster() -> Iterator[BlockBuster]:
|
|||||||
yield bb
|
yield bb
|
||||||
|
|
||||||
|
|
||||||
def pytest_addoption(parser: Parser) -> None:
|
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||||
"""Add custom command line options to pytest."""
|
"""Add custom command line options to pytest."""
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--only-extended",
|
"--only-extended",
|
||||||
@ -50,7 +49,9 @@ def pytest_addoption(parser: Parser) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
|
def pytest_collection_modifyitems(
|
||||||
|
config: pytest.Config, items: Sequence[pytest.Function]
|
||||||
|
) -> None:
|
||||||
"""Add implementations for handling custom markers.
|
"""Add implementations for handling custom markers.
|
||||||
|
|
||||||
At the moment, this adds support for a custom `requires` marker.
|
At the moment, this adds support for a custom `requires` marker.
|
||||||
@ -118,7 +119,7 @@ def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) ->
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def deterministic_uuids(mocker: MockerFixture) -> MockerFixture:
|
def deterministic_uuids(mocker: MockerFixture) -> MockerFixture:
|
||||||
side_effect = (
|
side_effect = (
|
||||||
UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
|
UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
|
||||||
|
@ -16,16 +16,16 @@ from langchain_core.indexing.in_memory import (
|
|||||||
|
|
||||||
|
|
||||||
class TestDocumentIndexerTestSuite(DocumentIndexerTestSuite):
|
class TestDocumentIndexerTestSuite(DocumentIndexerTestSuite):
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def index(self) -> Generator[DocumentIndex, None, None]:
|
def index(self) -> Generator[DocumentIndex, None, None]:
|
||||||
yield InMemoryDocumentIndex()
|
yield InMemoryDocumentIndex() # noqa: PT022
|
||||||
|
|
||||||
|
|
||||||
class TestAsyncDocumentIndexerTestSuite(AsyncDocumentIndexTestSuite):
|
class TestAsyncDocumentIndexerTestSuite(AsyncDocumentIndexTestSuite):
|
||||||
# Something funky is going on with mypy and async pytest fixture
|
# Something funky is going on with mypy and async pytest fixture
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
async def index(self) -> AsyncGenerator[DocumentIndex, None]: # type: ignore
|
async def index(self) -> AsyncGenerator[DocumentIndex, None]: # type: ignore
|
||||||
yield InMemoryDocumentIndex()
|
yield InMemoryDocumentIndex() # noqa: PT022
|
||||||
|
|
||||||
|
|
||||||
def test_sync_retriever() -> None:
|
def test_sync_retriever() -> None:
|
||||||
|
@ -7,7 +7,7 @@ import pytest_asyncio
|
|||||||
from langchain_core.indexing import InMemoryRecordManager
|
from langchain_core.indexing import InMemoryRecordManager
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def manager() -> InMemoryRecordManager:
|
def manager() -> InMemoryRecordManager:
|
||||||
"""Initialize the test database and yield the TimestampedSet instance."""
|
"""Initialize the test database and yield the TimestampedSet instance."""
|
||||||
# Initialize and yield the TimestampedSet instance
|
# Initialize and yield the TimestampedSet instance
|
||||||
|
@ -466,11 +466,19 @@ def test_incremental_fails_with_bad_source_ids(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Source id key is required when cleanup mode is "
|
||||||
|
"incremental or scoped_full.",
|
||||||
|
):
|
||||||
# Should raise an error because no source id function was specified
|
# Should raise an error because no source id function was specified
|
||||||
index(loader, record_manager, vector_store, cleanup="incremental")
|
index(loader, record_manager, vector_store, cleanup="incremental")
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Source ids are required when cleanup mode "
|
||||||
|
"is incremental or scoped_full.",
|
||||||
|
):
|
||||||
# Should raise an error because no source id function was specified
|
# Should raise an error because no source id function was specified
|
||||||
index(
|
index(
|
||||||
loader,
|
loader,
|
||||||
@ -502,7 +510,11 @@ async def test_aincremental_fails_with_bad_source_ids(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Source id key is required when cleanup mode "
|
||||||
|
"is incremental or scoped_full.",
|
||||||
|
):
|
||||||
# Should raise an error because no source id function was specified
|
# Should raise an error because no source id function was specified
|
||||||
await aindex(
|
await aindex(
|
||||||
loader,
|
loader,
|
||||||
@ -511,7 +523,11 @@ async def test_aincremental_fails_with_bad_source_ids(
|
|||||||
cleanup="incremental",
|
cleanup="incremental",
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Source ids are required when cleanup mode "
|
||||||
|
"is incremental or scoped_full.",
|
||||||
|
):
|
||||||
# Should raise an error because no source id function was specified
|
# Should raise an error because no source id function was specified
|
||||||
await aindex(
|
await aindex(
|
||||||
loader,
|
loader,
|
||||||
@ -771,11 +787,19 @@ def test_scoped_full_fails_with_bad_source_ids(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Source id key is required when cleanup mode "
|
||||||
|
"is incremental or scoped_full.",
|
||||||
|
):
|
||||||
# Should raise an error because no source id function was specified
|
# Should raise an error because no source id function was specified
|
||||||
index(loader, record_manager, vector_store, cleanup="scoped_full")
|
index(loader, record_manager, vector_store, cleanup="scoped_full")
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Source ids are required when cleanup mode "
|
||||||
|
"is incremental or scoped_full.",
|
||||||
|
):
|
||||||
# Should raise an error because no source id function was specified
|
# Should raise an error because no source id function was specified
|
||||||
index(
|
index(
|
||||||
loader,
|
loader,
|
||||||
@ -807,11 +831,19 @@ async def test_ascoped_full_fails_with_bad_source_ids(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Source id key is required when cleanup mode "
|
||||||
|
"is incremental or scoped_full.",
|
||||||
|
):
|
||||||
# Should raise an error because no source id function was specified
|
# Should raise an error because no source id function was specified
|
||||||
await aindex(loader, arecord_manager, vector_store, cleanup="scoped_full")
|
await aindex(loader, arecord_manager, vector_store, cleanup="scoped_full")
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Source ids are required when cleanup mode "
|
||||||
|
"is incremental or scoped_full.",
|
||||||
|
):
|
||||||
# Should raise an error because no source id function was specified
|
# Should raise an error because no source id function was specified
|
||||||
await aindex(
|
await aindex(
|
||||||
loader,
|
loader,
|
||||||
|
@ -99,6 +99,7 @@ async def test_async_batch_size(messages: list, messages_2: list) -> None:
|
|||||||
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
|
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.xfail(reason="This test is failing due to a bug in the testing code")
|
||||||
async def test_stream_error_callback() -> None:
|
async def test_stream_error_callback() -> None:
|
||||||
message = "test"
|
message = "test"
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
import typing
|
import typing
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
@ -513,7 +514,14 @@ def test_trim_messages_bound_model_token_counter() -> None:
|
|||||||
|
|
||||||
def test_trim_messages_bad_token_counter() -> None:
|
def test_trim_messages_bad_token_counter() -> None:
|
||||||
trimmer = trim_messages(max_tokens=10, token_counter={})
|
trimmer = trim_messages(max_tokens=10, token_counter={})
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=re.escape(
|
||||||
|
"'token_counter' expected to be a model that implements "
|
||||||
|
"'get_num_tokens_from_messages()' or a function. "
|
||||||
|
"Received object of type <class 'dict'>."
|
||||||
|
),
|
||||||
|
):
|
||||||
trimmer.invoke([HumanMessage("foobar")])
|
trimmer.invoke([HumanMessage("foobar")])
|
||||||
|
|
||||||
|
|
||||||
@ -852,7 +860,9 @@ def test_convert_to_messages_openai_refusal() -> None:
|
|||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
# Raises error if content is missing.
|
# Raises error if content is missing.
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError, match="Message dict must contain 'role' and 'content' keys"
|
||||||
|
):
|
||||||
convert_to_messages([{"role": "assistant", "refusal": "9.1"}])
|
convert_to_messages([{"role": "assistant", "refusal": "9.1"}])
|
||||||
|
|
||||||
|
|
||||||
|
@ -157,9 +157,10 @@ def test_pydantic_output_parser_fail() -> None:
|
|||||||
pydantic_object=TestModel
|
pydantic_object=TestModel
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(OutputParserException) as e:
|
with pytest.raises(
|
||||||
|
OutputParserException, match="Failed to parse TestModel from completion"
|
||||||
|
):
|
||||||
pydantic_parser.parse(DEF_RESULT_FAIL)
|
pydantic_parser.parse(DEF_RESULT_FAIL)
|
||||||
assert "Failed to parse TestModel from completion" in str(e)
|
|
||||||
|
|
||||||
|
|
||||||
def test_pydantic_output_parser_type_inference() -> None:
|
def test_pydantic_output_parser_type_inference() -> None:
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Union, cast
|
from typing import Any, Union, cast
|
||||||
@ -165,15 +166,14 @@ def test_create_system_message_prompt_list_template_partial_variables_not_null()
|
|||||||
{variables}
|
{variables}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
with pytest.raises(
|
||||||
graph_analyst_template = SystemMessagePromptTemplate.from_template(
|
ValueError, match="Partial variables are not supported for list of templates."
|
||||||
|
):
|
||||||
|
_ = SystemMessagePromptTemplate.from_template(
|
||||||
template=[graph_creator_content1, graph_creator_content2],
|
template=[graph_creator_content1, graph_creator_content2],
|
||||||
input_variables=["variables"],
|
input_variables=["variables"],
|
||||||
partial_variables={"variables": "foo"},
|
partial_variables={"variables": "foo"},
|
||||||
)
|
)
|
||||||
graph_analyst_template.format(variables="foo")
|
|
||||||
except ValueError as e:
|
|
||||||
assert str(e) == "Partial variables are not supported for list of templates."
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_prompt_template_from_template_file() -> None:
|
def test_message_prompt_template_from_template_file() -> None:
|
||||||
@ -330,7 +330,7 @@ def test_chat_prompt_template_from_messages_jinja2() -> None:
|
|||||||
|
|
||||||
@pytest.mark.requires("jinja2")
|
@pytest.mark.requires("jinja2")
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"template_format,image_type_placeholder,image_data_placeholder",
|
("template_format", "image_type_placeholder", "image_data_placeholder"),
|
||||||
[
|
[
|
||||||
("f-string", "{image_type}", "{image_data}"),
|
("f-string", "{image_type}", "{image_data}"),
|
||||||
("mustache", "{{image_type}}", "{{image_data}}"),
|
("mustache", "{{image_type}}", "{{image_data}}"),
|
||||||
@ -393,7 +393,12 @@ def test_chat_prompt_template_with_messages(
|
|||||||
|
|
||||||
def test_chat_invalid_input_variables_extra() -> None:
|
def test_chat_invalid_input_variables_extra() -> None:
|
||||||
messages = [HumanMessage(content="foo")]
|
messages = [HumanMessage(content="foo")]
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=re.escape(
|
||||||
|
"Got mismatched input_variables. Expected: set(). Got: ['foo']"
|
||||||
|
),
|
||||||
|
):
|
||||||
ChatPromptTemplate(
|
ChatPromptTemplate(
|
||||||
messages=messages, # type: ignore[arg-type]
|
messages=messages, # type: ignore[arg-type]
|
||||||
input_variables=["foo"],
|
input_variables=["foo"],
|
||||||
@ -407,7 +412,10 @@ def test_chat_invalid_input_variables_extra() -> None:
|
|||||||
|
|
||||||
def test_chat_invalid_input_variables_missing() -> None:
|
def test_chat_invalid_input_variables_missing() -> None:
|
||||||
messages = [HumanMessagePromptTemplate.from_template("{foo}")]
|
messages = [HumanMessagePromptTemplate.from_template("{foo}")]
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=re.escape("Got mismatched input_variables. Expected: {'foo'}. Got: []"),
|
||||||
|
):
|
||||||
ChatPromptTemplate(
|
ChatPromptTemplate(
|
||||||
messages=messages, # type: ignore[arg-type]
|
messages=messages, # type: ignore[arg-type]
|
||||||
input_variables=[],
|
input_variables=[],
|
||||||
@ -481,7 +489,7 @@ async def test_chat_from_role_strings() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"args,expected",
|
("args", "expected"),
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
("human", "{question}"),
|
("human", "{question}"),
|
||||||
@ -551,7 +559,7 @@ def test_chat_prompt_template_append_and_extend() -> None:
|
|||||||
|
|
||||||
def test_convert_to_message_is_strict() -> None:
|
def test_convert_to_message_is_strict() -> None:
|
||||||
"""Verify that _convert_to_message is strict."""
|
"""Verify that _convert_to_message is strict."""
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="Unexpected message type: meow."):
|
||||||
# meow does not correspond to a valid message type.
|
# meow does not correspond to a valid message type.
|
||||||
# this test is here to ensure that functionality to interpret `meow`
|
# this test is here to ensure that functionality to interpret `meow`
|
||||||
# as a role is NOT added.
|
# as a role is NOT added.
|
||||||
@ -762,14 +770,20 @@ async def test_chat_tmpl_from_messages_multipart_formatting_with_path() -> None:
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Loading images from 'path' has been removed as of 0.3.15 for security reasons.",
|
||||||
|
):
|
||||||
template.format_messages(
|
template.format_messages(
|
||||||
name="R2D2",
|
name="R2D2",
|
||||||
in_mem=in_mem,
|
in_mem=in_mem,
|
||||||
file_path="some/path",
|
file_path="some/path",
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Loading images from 'path' has been removed as of 0.3.15 for security reasons.",
|
||||||
|
):
|
||||||
await template.aformat_messages(
|
await template.aformat_messages(
|
||||||
name="R2D2",
|
name="R2D2",
|
||||||
in_mem=in_mem,
|
in_mem=in_mem,
|
||||||
@ -869,10 +883,10 @@ def test_chat_prompt_message_dict() -> None:
|
|||||||
HumanMessage(content="bar"),
|
HumanMessage(content="bar"),
|
||||||
]
|
]
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="Invalid template: False"):
|
||||||
ChatPromptTemplate([{"role": "system", "content": False}])
|
ChatPromptTemplate([{"role": "system", "content": False}])
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="Unexpected message type: foo."):
|
||||||
ChatPromptTemplate([{"role": "foo", "content": "foo"}])
|
ChatPromptTemplate([{"role": "foo", "content": "foo"}])
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Test few shot prompt template."""
|
"""Test few shot prompt template."""
|
||||||
|
|
||||||
|
import re
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -24,7 +25,7 @@ EXAMPLE_PROMPT = PromptTemplate(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
@pytest.mark.requires("jinja2")
|
@pytest.mark.requires("jinja2")
|
||||||
def example_jinja2_prompt() -> tuple[PromptTemplate, list[dict[str, str]]]:
|
def example_jinja2_prompt() -> tuple[PromptTemplate, list[dict[str, str]]]:
|
||||||
example_template = "{{ word }}: {{ antonym }}"
|
example_template = "{{ word }}: {{ antonym }}"
|
||||||
@ -74,7 +75,10 @@ def test_prompt_missing_input_variables() -> None:
|
|||||||
"""Test error is raised when input variables are not provided."""
|
"""Test error is raised when input variables are not provided."""
|
||||||
# Test when missing in suffix
|
# Test when missing in suffix
|
||||||
template = "This is a {foo} test."
|
template = "This is a {foo} test."
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=re.escape("check for mismatched or missing input parameters from []"),
|
||||||
|
):
|
||||||
FewShotPromptTemplate(
|
FewShotPromptTemplate(
|
||||||
input_variables=[],
|
input_variables=[],
|
||||||
suffix=template,
|
suffix=template,
|
||||||
@ -91,7 +95,10 @@ def test_prompt_missing_input_variables() -> None:
|
|||||||
|
|
||||||
# Test when missing in prefix
|
# Test when missing in prefix
|
||||||
template = "This is a {foo} test."
|
template = "This is a {foo} test."
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=re.escape("check for mismatched or missing input parameters from []"),
|
||||||
|
):
|
||||||
FewShotPromptTemplate(
|
FewShotPromptTemplate(
|
||||||
input_variables=[],
|
input_variables=[],
|
||||||
suffix="foo",
|
suffix="foo",
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""Test few shot prompt template."""
|
"""Test few shot prompt template."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates
|
from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates
|
||||||
@ -58,7 +60,10 @@ def test_prompttemplate_validation() -> None:
|
|||||||
{"question": "foo", "answer": "bar"},
|
{"question": "foo", "answer": "bar"},
|
||||||
{"question": "baz", "answer": "foo"},
|
{"question": "baz", "answer": "foo"},
|
||||||
]
|
]
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=re.escape("Got input_variables=[], but based on prefix/suffix expected"),
|
||||||
|
):
|
||||||
FewShotPromptWithTemplates(
|
FewShotPromptWithTemplates(
|
||||||
suffix=suffix,
|
suffix=suffix,
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Test functionality related to prompts."""
|
"""Test functionality related to prompts."""
|
||||||
|
|
||||||
|
import re
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
@ -264,7 +265,10 @@ def test_prompt_missing_input_variables() -> None:
|
|||||||
"""Test error is raised when input variables are not provided."""
|
"""Test error is raised when input variables are not provided."""
|
||||||
template = "This is a {foo} test."
|
template = "This is a {foo} test."
|
||||||
input_variables: list = []
|
input_variables: list = []
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=re.escape("check for mismatched or missing input parameters from []"),
|
||||||
|
):
|
||||||
PromptTemplate(
|
PromptTemplate(
|
||||||
input_variables=input_variables, template=template, validate_template=True
|
input_variables=input_variables, template=template, validate_template=True
|
||||||
)
|
)
|
||||||
@ -275,7 +279,10 @@ def test_prompt_missing_input_variables() -> None:
|
|||||||
|
|
||||||
def test_prompt_empty_input_variable() -> None:
|
def test_prompt_empty_input_variable() -> None:
|
||||||
"""Test error is raised when empty string input variable."""
|
"""Test error is raised when empty string input variable."""
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=re.escape("check for mismatched or missing input parameters from ['']"),
|
||||||
|
):
|
||||||
PromptTemplate(input_variables=[""], template="{}", validate_template=True)
|
PromptTemplate(input_variables=[""], template="{}", validate_template=True)
|
||||||
|
|
||||||
|
|
||||||
@ -283,7 +290,13 @@ def test_prompt_wrong_input_variables() -> None:
|
|||||||
"""Test error is raised when name of input variable is wrong."""
|
"""Test error is raised when name of input variable is wrong."""
|
||||||
template = "This is a {foo} test."
|
template = "This is a {foo} test."
|
||||||
input_variables = ["bar"]
|
input_variables = ["bar"]
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=re.escape(
|
||||||
|
"Invalid prompt schema; "
|
||||||
|
"check for mismatched or missing input parameters from ['bar']"
|
||||||
|
),
|
||||||
|
):
|
||||||
PromptTemplate(
|
PromptTemplate(
|
||||||
input_variables=input_variables, template=template, validate_template=True
|
input_variables=input_variables, template=template, validate_template=True
|
||||||
)
|
)
|
||||||
@ -330,7 +343,7 @@ def test_prompt_invalid_template_format() -> None:
|
|||||||
"""Test initializing a prompt with invalid template format."""
|
"""Test initializing a prompt with invalid template format."""
|
||||||
template = "This is a {foo} test."
|
template = "This is a {foo} test."
|
||||||
input_variables = ["foo"]
|
input_variables = ["foo"]
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="Unsupported template format: bar"):
|
||||||
PromptTemplate(
|
PromptTemplate(
|
||||||
input_variables=input_variables,
|
input_variables=input_variables,
|
||||||
template=template,
|
template=template,
|
||||||
@ -580,7 +593,7 @@ async def test_prompt_ainvoke_with_metadata() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"value, expected",
|
("value", "expected"),
|
||||||
[
|
[
|
||||||
("0", "0"),
|
("0", "0"),
|
||||||
(0, "0"),
|
(0, "0"),
|
||||||
|
@ -330,7 +330,7 @@ test_cases = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("runnable, cases", test_cases)
|
@pytest.mark.parametrize(("runnable", "cases"), test_cases)
|
||||||
def test_context_runnables(
|
def test_context_runnables(
|
||||||
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
|
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -342,7 +342,7 @@ def test_context_runnables(
|
|||||||
assert add(runnable.stream(cases[0].input)) == cases[0].output
|
assert add(runnable.stream(cases[0].input)) == cases[0].output
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("runnable, cases", test_cases)
|
@pytest.mark.parametrize(("runnable", "cases"), test_cases)
|
||||||
async def test_context_runnables_async(
|
async def test_context_runnables_async(
|
||||||
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
|
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -357,14 +357,19 @@ async def test_context_runnables_async(
|
|||||||
def test_runnable_context_seq_key_not_found() -> None:
|
def test_runnable_context_seq_key_not_found() -> None:
|
||||||
seq: Runnable = {"bar": Context.setter("input")} | Context.getter("foo")
|
seq: Runnable = {"bar": Context.setter("input")} | Context.getter("foo")
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError, match="Expected exactly one setter for context key foo"
|
||||||
|
):
|
||||||
seq.invoke("foo")
|
seq.invoke("foo")
|
||||||
|
|
||||||
|
|
||||||
def test_runnable_context_seq_key_order() -> None:
|
def test_runnable_context_seq_key_order() -> None:
|
||||||
seq: Runnable = {"bar": Context.getter("foo")} | Context.setter("foo")
|
seq: Runnable = {"bar": Context.getter("foo")} | Context.setter("foo")
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Context setter for key foo must be defined after all getters.",
|
||||||
|
):
|
||||||
seq.invoke("foo")
|
seq.invoke("foo")
|
||||||
|
|
||||||
|
|
||||||
@ -374,7 +379,9 @@ def test_runnable_context_deadlock() -> None:
|
|||||||
"foo": Context.setter("foo") | Context.getter("input"),
|
"foo": Context.setter("foo") | Context.getter("input"),
|
||||||
} | RunnablePassthrough()
|
} | RunnablePassthrough()
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError, match="Deadlock detected between context keys foo and input"
|
||||||
|
):
|
||||||
seq.invoke("foo")
|
seq.invoke("foo")
|
||||||
|
|
||||||
|
|
||||||
@ -383,7 +390,9 @@ def test_runnable_context_seq_key_circular_ref() -> None:
|
|||||||
"bar": Context.setter(input=Context.getter("input"))
|
"bar": Context.setter(input=Context.getter("input"))
|
||||||
} | Context.getter("foo")
|
} | Context.getter("foo")
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError, match="Circular reference in context setter for key input"
|
||||||
|
):
|
||||||
seq.invoke("foo")
|
seq.invoke("foo")
|
||||||
|
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ from langchain_core.runnables import (
|
|||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def llm() -> RunnableWithFallbacks:
|
def llm() -> RunnableWithFallbacks:
|
||||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||||
pass_llm = FakeListLLM(responses=["bar"])
|
pass_llm = FakeListLLM(responses=["bar"])
|
||||||
@ -40,7 +40,7 @@ def llm() -> RunnableWithFallbacks:
|
|||||||
return error_llm.with_fallbacks([pass_llm])
|
return error_llm.with_fallbacks([pass_llm])
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def llm_multi() -> RunnableWithFallbacks:
|
def llm_multi() -> RunnableWithFallbacks:
|
||||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||||
error_llm_2 = FakeListLLM(responses=["baz"], i=1)
|
error_llm_2 = FakeListLLM(responses=["baz"], i=1)
|
||||||
@ -49,7 +49,7 @@ def llm_multi() -> RunnableWithFallbacks:
|
|||||||
return error_llm.with_fallbacks([error_llm_2, pass_llm])
|
return error_llm.with_fallbacks([error_llm_2, pass_llm])
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def chain() -> Runnable:
|
def chain() -> Runnable:
|
||||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||||
pass_llm = FakeListLLM(responses=["bar"])
|
pass_llm = FakeListLLM(responses=["bar"])
|
||||||
@ -70,7 +70,7 @@ def _dont_raise_error(inputs: dict) -> str:
|
|||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def chain_pass_exceptions() -> Runnable:
|
def chain_pass_exceptions() -> Runnable:
|
||||||
fallback = RunnableLambda(_dont_raise_error)
|
fallback = RunnableLambda(_dont_raise_error)
|
||||||
return {"text": RunnablePassthrough()} | RunnableLambda(
|
return {"text": RunnablePassthrough()} | RunnableLambda(
|
||||||
@ -107,7 +107,8 @@ def _runnable(inputs: dict) -> str:
|
|||||||
if inputs["text"] == "foo":
|
if inputs["text"] == "foo":
|
||||||
return "first"
|
return "first"
|
||||||
if "exception" not in inputs:
|
if "exception" not in inputs:
|
||||||
raise ValueError
|
msg = "missing exception"
|
||||||
|
raise ValueError(msg)
|
||||||
if inputs["text"] == "bar":
|
if inputs["text"] == "bar":
|
||||||
return "second"
|
return "second"
|
||||||
if isinstance(inputs["exception"], ValueError):
|
if isinstance(inputs["exception"], ValueError):
|
||||||
@ -128,7 +129,7 @@ def test_invoke_with_exception_key() -> None:
|
|||||||
runnable_with_single = runnable.with_fallbacks(
|
runnable_with_single = runnable.with_fallbacks(
|
||||||
[runnable], exception_key="exception"
|
[runnable], exception_key="exception"
|
||||||
)
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="missing exception"):
|
||||||
runnable_with_single.invoke({"text": "baz"})
|
runnable_with_single.invoke({"text": "baz"})
|
||||||
|
|
||||||
actual = runnable_with_single.invoke({"text": "bar"})
|
actual = runnable_with_single.invoke({"text": "bar"})
|
||||||
@ -149,7 +150,7 @@ async def test_ainvoke_with_exception_key() -> None:
|
|||||||
runnable_with_single = runnable.with_fallbacks(
|
runnable_with_single = runnable.with_fallbacks(
|
||||||
[runnable], exception_key="exception"
|
[runnable], exception_key="exception"
|
||||||
)
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="missing exception"):
|
||||||
await runnable_with_single.ainvoke({"text": "baz"})
|
await runnable_with_single.ainvoke({"text": "baz"})
|
||||||
|
|
||||||
actual = await runnable_with_single.ainvoke({"text": "bar"})
|
actual = await runnable_with_single.ainvoke({"text": "bar"})
|
||||||
@ -166,7 +167,7 @@ async def test_ainvoke_with_exception_key() -> None:
|
|||||||
|
|
||||||
def test_batch() -> None:
|
def test_batch() -> None:
|
||||||
runnable = RunnableLambda(_runnable)
|
runnable = RunnableLambda(_runnable)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="missing exception"):
|
||||||
runnable.batch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
|
runnable.batch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
|
||||||
actual = runnable.batch(
|
actual = runnable.batch(
|
||||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||||
@ -210,7 +211,7 @@ def test_batch() -> None:
|
|||||||
|
|
||||||
async def test_abatch() -> None:
|
async def test_abatch() -> None:
|
||||||
runnable = RunnableLambda(_runnable)
|
runnable = RunnableLambda(_runnable)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="missing exception"):
|
||||||
await runnable.abatch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
|
await runnable.abatch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
|
||||||
actual = await runnable.abatch(
|
actual = await runnable.abatch(
|
||||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||||
@ -263,13 +264,15 @@ def _generate(input: Iterator) -> Iterator[str]:
|
|||||||
|
|
||||||
|
|
||||||
def _generate_immediate_error(input: Iterator) -> Iterator[str]:
|
def _generate_immediate_error(input: Iterator) -> Iterator[str]:
|
||||||
raise ValueError
|
msg = "immmediate error"
|
||||||
|
raise ValueError(msg)
|
||||||
yield ""
|
yield ""
|
||||||
|
|
||||||
|
|
||||||
def _generate_delayed_error(input: Iterator) -> Iterator[str]:
|
def _generate_delayed_error(input: Iterator) -> Iterator[str]:
|
||||||
yield ""
|
yield ""
|
||||||
raise ValueError
|
msg = "delayed error"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
def test_fallbacks_stream() -> None:
|
def test_fallbacks_stream() -> None:
|
||||||
@ -278,10 +281,10 @@ def test_fallbacks_stream() -> None:
|
|||||||
)
|
)
|
||||||
assert list(runnable.stream({})) == list("foo bar")
|
assert list(runnable.stream({})) == list("foo bar")
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
runnable = RunnableGenerator(_generate_delayed_error).with_fallbacks(
|
runnable = RunnableGenerator(_generate_delayed_error).with_fallbacks(
|
||||||
[RunnableGenerator(_generate)]
|
[RunnableGenerator(_generate)]
|
||||||
)
|
)
|
||||||
|
with pytest.raises(ValueError, match="delayed error"):
|
||||||
list(runnable.stream({}))
|
list(runnable.stream({}))
|
||||||
|
|
||||||
|
|
||||||
@ -291,13 +294,15 @@ async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]:
|
|||||||
|
|
||||||
|
|
||||||
async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
|
async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
|
||||||
raise ValueError
|
msg = "immmediate error"
|
||||||
|
raise ValueError(msg)
|
||||||
yield ""
|
yield ""
|
||||||
|
|
||||||
|
|
||||||
async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
|
async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
|
||||||
yield ""
|
yield ""
|
||||||
raise ValueError
|
msg = "delayed error"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
async def test_fallbacks_astream() -> None:
|
async def test_fallbacks_astream() -> None:
|
||||||
@ -308,12 +313,11 @@ async def test_fallbacks_astream() -> None:
|
|||||||
async for c in runnable.astream({}):
|
async for c in runnable.astream({}):
|
||||||
assert c == next(expected)
|
assert c == next(expected)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
runnable = RunnableGenerator(_agenerate_delayed_error).with_fallbacks(
|
runnable = RunnableGenerator(_agenerate_delayed_error).with_fallbacks(
|
||||||
[RunnableGenerator(_agenerate)]
|
[RunnableGenerator(_agenerate)]
|
||||||
)
|
)
|
||||||
async for _ in runnable.astream({}):
|
with pytest.raises(ValueError, match="delayed error"):
|
||||||
pass
|
_ = [_ async for _ in runnable.astream({})]
|
||||||
|
|
||||||
|
|
||||||
class FakeStructuredOutputModel(BaseChatModel):
|
class FakeStructuredOutputModel(BaseChatModel):
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import re
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
@ -864,22 +865,24 @@ def test_get_output_messages_with_value_error() -> None:
|
|||||||
"configurable": {"session_id": "1", "message_history": get_session_history("1")}
|
"configurable": {"session_id": "1", "message_history": get_session_history("1")}
|
||||||
}
|
}
|
||||||
|
|
||||||
with pytest.raises(ValueError) as excinfo:
|
with pytest.raises(
|
||||||
with_history.bound.invoke([HumanMessage(content="hello")], config)
|
ValueError,
|
||||||
excepted = (
|
match=re.escape(
|
||||||
"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]."
|
"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]."
|
||||||
f" Got {illegal_bool_message}."
|
f" Got {illegal_bool_message}."
|
||||||
)
|
),
|
||||||
assert excepted in str(excinfo.value)
|
):
|
||||||
|
with_history.bound.invoke([HumanMessage(content="hello")], config)
|
||||||
|
|
||||||
illegal_int_message = 123
|
illegal_int_message = 123
|
||||||
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_int_message)
|
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_int_message)
|
||||||
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
||||||
|
|
||||||
with pytest.raises(ValueError) as excinfo:
|
with pytest.raises(
|
||||||
with_history.bound.invoke([HumanMessage(content="hello")], config)
|
ValueError,
|
||||||
excepted = (
|
match=re.escape(
|
||||||
"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]."
|
"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]."
|
||||||
f" Got {illegal_int_message}."
|
f" Got {illegal_int_message}."
|
||||||
)
|
),
|
||||||
assert excepted in str(excinfo.value)
|
):
|
||||||
|
with_history.bound.invoke([HumanMessage(content="hello")], config)
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
@ -3827,13 +3828,13 @@ def test_retrying(mocker: MockerFixture) -> None:
|
|||||||
_lambda_mock = mocker.Mock(side_effect=_lambda)
|
_lambda_mock = mocker.Mock(side_effect=_lambda)
|
||||||
runnable = RunnableLambda(_lambda_mock)
|
runnable = RunnableLambda(_lambda_mock)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="x is 1"):
|
||||||
runnable.invoke(1)
|
runnable.invoke(1)
|
||||||
|
|
||||||
assert _lambda_mock.call_count == 1
|
assert _lambda_mock.call_count == 1
|
||||||
_lambda_mock.reset_mock()
|
_lambda_mock.reset_mock()
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="x is 1"):
|
||||||
runnable.with_retry(
|
runnable.with_retry(
|
||||||
stop_after_attempt=2,
|
stop_after_attempt=2,
|
||||||
retry_if_exception_type=(ValueError,),
|
retry_if_exception_type=(ValueError,),
|
||||||
@ -3852,7 +3853,7 @@ def test_retrying(mocker: MockerFixture) -> None:
|
|||||||
assert _lambda_mock.call_count == 1 # did not retry
|
assert _lambda_mock.call_count == 1 # did not retry
|
||||||
_lambda_mock.reset_mock()
|
_lambda_mock.reset_mock()
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="x is 1"):
|
||||||
runnable.with_retry(
|
runnable.with_retry(
|
||||||
stop_after_attempt=2,
|
stop_after_attempt=2,
|
||||||
wait_exponential_jitter=False,
|
wait_exponential_jitter=False,
|
||||||
@ -3892,13 +3893,13 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
|
|||||||
_lambda_mock = mocker.Mock(side_effect=_lambda)
|
_lambda_mock = mocker.Mock(side_effect=_lambda)
|
||||||
runnable = RunnableLambda(_lambda_mock)
|
runnable = RunnableLambda(_lambda_mock)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="x is 1"):
|
||||||
await runnable.ainvoke(1)
|
await runnable.ainvoke(1)
|
||||||
|
|
||||||
assert _lambda_mock.call_count == 1
|
assert _lambda_mock.call_count == 1
|
||||||
_lambda_mock.reset_mock()
|
_lambda_mock.reset_mock()
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="x is 1"):
|
||||||
await runnable.with_retry(
|
await runnable.with_retry(
|
||||||
stop_after_attempt=2,
|
stop_after_attempt=2,
|
||||||
wait_exponential_jitter=False,
|
wait_exponential_jitter=False,
|
||||||
@ -3918,7 +3919,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
|
|||||||
assert _lambda_mock.call_count == 1 # did not retry
|
assert _lambda_mock.call_count == 1 # did not retry
|
||||||
_lambda_mock.reset_mock()
|
_lambda_mock.reset_mock()
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="x is 1"):
|
||||||
await runnable.with_retry(
|
await runnable.with_retry(
|
||||||
stop_after_attempt=2,
|
stop_after_attempt=2,
|
||||||
wait_exponential_jitter=False,
|
wait_exponential_jitter=False,
|
||||||
@ -3982,9 +3983,8 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
# Check that the chain on error is invoked
|
# Check that the chain on error is invoked
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="x is too large"):
|
||||||
for _ in RunnableLambda(raise_value_error).stream(1000, config=config):
|
_ = list(RunnableLambda(raise_value_error).stream(1000, config=config))
|
||||||
pass
|
|
||||||
|
|
||||||
assert len(tracer.runs) == 2
|
assert len(tracer.runs) == 2
|
||||||
assert "ValueError('x is too large')" in str(tracer.runs[1].error)
|
assert "ValueError('x is too large')" in str(tracer.runs[1].error)
|
||||||
@ -4061,9 +4061,13 @@ async def test_runnable_lambda_astream_with_callbacks() -> None:
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
# Check that the chain on error is invoked
|
# Check that the chain on error is invoked
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="x is too large"):
|
||||||
async for _ in RunnableLambda(raise_value_error).astream(1000, config=config):
|
_ = [
|
||||||
pass
|
_
|
||||||
|
async for _ in RunnableLambda(raise_value_error).astream(
|
||||||
|
1000, config=config
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
assert len(tracer.runs) == 2
|
assert len(tracer.runs) == 2
|
||||||
assert "ValueError('x is too large')" in str(tracer.runs[1].error)
|
assert "ValueError('x is too large')" in str(tracer.runs[1].error)
|
||||||
@ -4088,7 +4092,11 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
|
|||||||
outputs: list[Any] = []
|
outputs: list[Any] = []
|
||||||
for input in inputs:
|
for input in inputs:
|
||||||
if input.startswith(self.fail_starts_with):
|
if input.startswith(self.fail_starts_with):
|
||||||
outputs.append(ValueError())
|
outputs.append(
|
||||||
|
ValueError(
|
||||||
|
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}"
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
outputs.append(input + "a")
|
outputs.append(input + "a")
|
||||||
return outputs
|
return outputs
|
||||||
@ -4119,7 +4127,9 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
|
|||||||
assert isinstance(chain, RunnableSequence)
|
assert isinstance(chain, RunnableSequence)
|
||||||
|
|
||||||
# Test batch
|
# Test batch
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError, match=re.escape("ControlledExceptionRunnable(bar) fail for bara")
|
||||||
|
):
|
||||||
chain.batch(["foo", "bar", "baz", "qux"])
|
chain.batch(["foo", "bar", "baz", "qux"])
|
||||||
|
|
||||||
spy = mocker.spy(ControlledExceptionRunnable, "batch")
|
spy = mocker.spy(ControlledExceptionRunnable, "batch")
|
||||||
@ -4155,32 +4165,44 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
|
|||||||
|
|
||||||
parent_run_foo = parent_runs[0]
|
parent_run_foo = parent_runs[0]
|
||||||
assert parent_run_foo.inputs["input"] == "foo"
|
assert parent_run_foo.inputs["input"] == "foo"
|
||||||
assert repr(ValueError()) in str(parent_run_foo.error)
|
assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
|
||||||
|
parent_run_foo.error
|
||||||
|
)
|
||||||
assert len(parent_run_foo.child_runs) == 4
|
assert len(parent_run_foo.child_runs) == 4
|
||||||
assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
|
assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
]
|
]
|
||||||
assert repr(ValueError()) in str(parent_run_foo.child_runs[-1].error)
|
assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
|
||||||
|
parent_run_foo.child_runs[-1].error
|
||||||
|
)
|
||||||
|
|
||||||
parent_run_bar = parent_runs[1]
|
parent_run_bar = parent_runs[1]
|
||||||
assert parent_run_bar.inputs["input"] == "bar"
|
assert parent_run_bar.inputs["input"] == "bar"
|
||||||
assert repr(ValueError()) in str(parent_run_bar.error)
|
assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
|
||||||
|
parent_run_bar.error
|
||||||
|
)
|
||||||
assert len(parent_run_bar.child_runs) == 2
|
assert len(parent_run_bar.child_runs) == 2
|
||||||
assert parent_run_bar.child_runs[0].error is None
|
assert parent_run_bar.child_runs[0].error is None
|
||||||
assert repr(ValueError()) in str(parent_run_bar.child_runs[1].error)
|
assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
|
||||||
|
parent_run_bar.child_runs[1].error
|
||||||
|
)
|
||||||
|
|
||||||
parent_run_baz = parent_runs[2]
|
parent_run_baz = parent_runs[2]
|
||||||
assert parent_run_baz.inputs["input"] == "baz"
|
assert parent_run_baz.inputs["input"] == "baz"
|
||||||
assert repr(ValueError()) in str(parent_run_baz.error)
|
assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
|
||||||
|
parent_run_baz.error
|
||||||
|
)
|
||||||
assert len(parent_run_baz.child_runs) == 3
|
assert len(parent_run_baz.child_runs) == 3
|
||||||
|
|
||||||
assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
|
assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
]
|
]
|
||||||
assert repr(ValueError()) in str(parent_run_baz.child_runs[-1].error)
|
assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
|
||||||
|
parent_run_baz.child_runs[-1].error
|
||||||
|
)
|
||||||
|
|
||||||
parent_run_qux = parent_runs[3]
|
parent_run_qux = parent_runs[3]
|
||||||
assert parent_run_qux.inputs["input"] == "qux"
|
assert parent_run_qux.inputs["input"] == "qux"
|
||||||
@ -4209,7 +4231,11 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
|
|||||||
outputs: list[Any] = []
|
outputs: list[Any] = []
|
||||||
for input in inputs:
|
for input in inputs:
|
||||||
if input.startswith(self.fail_starts_with):
|
if input.startswith(self.fail_starts_with):
|
||||||
outputs.append(ValueError())
|
outputs.append(
|
||||||
|
ValueError(
|
||||||
|
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}"
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
outputs.append(input + "a")
|
outputs.append(input + "a")
|
||||||
return outputs
|
return outputs
|
||||||
@ -4240,7 +4266,9 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
|
|||||||
assert isinstance(chain, RunnableSequence)
|
assert isinstance(chain, RunnableSequence)
|
||||||
|
|
||||||
# Test abatch
|
# Test abatch
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError, match=re.escape("ControlledExceptionRunnable(bar) fail for bara")
|
||||||
|
):
|
||||||
await chain.abatch(["foo", "bar", "baz", "qux"])
|
await chain.abatch(["foo", "bar", "baz", "qux"])
|
||||||
|
|
||||||
spy = mocker.spy(ControlledExceptionRunnable, "abatch")
|
spy = mocker.spy(ControlledExceptionRunnable, "abatch")
|
||||||
@ -4278,31 +4306,43 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
|
|||||||
|
|
||||||
parent_run_foo = parent_runs[0]
|
parent_run_foo = parent_runs[0]
|
||||||
assert parent_run_foo.inputs["input"] == "foo"
|
assert parent_run_foo.inputs["input"] == "foo"
|
||||||
assert repr(ValueError()) in str(parent_run_foo.error)
|
assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
|
||||||
|
parent_run_foo.error
|
||||||
|
)
|
||||||
assert len(parent_run_foo.child_runs) == 4
|
assert len(parent_run_foo.child_runs) == 4
|
||||||
assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
|
assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
]
|
]
|
||||||
assert repr(ValueError()) in str(parent_run_foo.child_runs[-1].error)
|
assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
|
||||||
|
parent_run_foo.child_runs[-1].error
|
||||||
|
)
|
||||||
|
|
||||||
parent_run_bar = parent_runs[1]
|
parent_run_bar = parent_runs[1]
|
||||||
assert parent_run_bar.inputs["input"] == "bar"
|
assert parent_run_bar.inputs["input"] == "bar"
|
||||||
assert repr(ValueError()) in str(parent_run_bar.error)
|
assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
|
||||||
|
parent_run_bar.error
|
||||||
|
)
|
||||||
assert len(parent_run_bar.child_runs) == 2
|
assert len(parent_run_bar.child_runs) == 2
|
||||||
assert parent_run_bar.child_runs[0].error is None
|
assert parent_run_bar.child_runs[0].error is None
|
||||||
assert repr(ValueError()) in str(parent_run_bar.child_runs[1].error)
|
assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
|
||||||
|
parent_run_bar.child_runs[1].error
|
||||||
|
)
|
||||||
|
|
||||||
parent_run_baz = parent_runs[2]
|
parent_run_baz = parent_runs[2]
|
||||||
assert parent_run_baz.inputs["input"] == "baz"
|
assert parent_run_baz.inputs["input"] == "baz"
|
||||||
assert repr(ValueError()) in str(parent_run_baz.error)
|
assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
|
||||||
|
parent_run_baz.error
|
||||||
|
)
|
||||||
assert len(parent_run_baz.child_runs) == 3
|
assert len(parent_run_baz.child_runs) == 3
|
||||||
assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
|
assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
]
|
]
|
||||||
assert repr(ValueError()) in str(parent_run_baz.child_runs[-1].error)
|
assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
|
||||||
|
parent_run_baz.child_runs[-1].error
|
||||||
|
)
|
||||||
|
|
||||||
parent_run_qux = parent_runs[3]
|
parent_run_qux = parent_runs[3]
|
||||||
assert parent_run_qux.inputs["input"] == "qux"
|
assert parent_run_qux.inputs["input"] == "qux"
|
||||||
@ -4319,11 +4359,15 @@ def test_runnable_branch_init() -> None:
|
|||||||
condition = RunnableLambda(lambda x: x > 0)
|
condition = RunnableLambda(lambda x: x > 0)
|
||||||
|
|
||||||
# Test failure with less than 2 branches
|
# Test failure with less than 2 branches
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError, match="RunnableBranch requires at least two branches"
|
||||||
|
):
|
||||||
RunnableBranch((condition, add))
|
RunnableBranch((condition, add))
|
||||||
|
|
||||||
# Test failure with less than 2 branches
|
# Test failure with less than 2 branches
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError, match="RunnableBranch requires at least two branches"
|
||||||
|
):
|
||||||
RunnableBranch(condition)
|
RunnableBranch(condition)
|
||||||
|
|
||||||
|
|
||||||
@ -4408,7 +4452,7 @@ def test_runnable_branch_invoke() -> None:
|
|||||||
assert branch.invoke(10) == 100
|
assert branch.invoke(10) == 100
|
||||||
assert branch.invoke(0) == -1
|
assert branch.invoke(0) == -1
|
||||||
# Should raise an exception
|
# Should raise an exception
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="x is too large"):
|
||||||
branch.invoke(1000)
|
branch.invoke(1000)
|
||||||
|
|
||||||
|
|
||||||
@ -4472,7 +4516,7 @@ def test_runnable_branch_invoke_callbacks() -> None:
|
|||||||
assert tracer.runs[0].outputs == {"output": 0}
|
assert tracer.runs[0].outputs == {"output": 0}
|
||||||
|
|
||||||
# Check that the chain on end is invoked
|
# Check that the chain on end is invoked
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="x is too large"):
|
||||||
branch.invoke(1000, config={"callbacks": [tracer]})
|
branch.invoke(1000, config={"callbacks": [tracer]})
|
||||||
|
|
||||||
assert len(tracer.runs) == 2
|
assert len(tracer.runs) == 2
|
||||||
@ -4500,7 +4544,7 @@ async def test_runnable_branch_ainvoke_callbacks() -> None:
|
|||||||
assert tracer.runs[0].outputs == {"output": 0}
|
assert tracer.runs[0].outputs == {"output": 0}
|
||||||
|
|
||||||
# Check that the chain on end is invoked
|
# Check that the chain on end is invoked
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="x is too large"):
|
||||||
await branch.ainvoke(1000, config={"callbacks": [tracer]})
|
await branch.ainvoke(1000, config={"callbacks": [tracer]})
|
||||||
|
|
||||||
assert len(tracer.runs) == 2
|
assert len(tracer.runs) == 2
|
||||||
@ -4561,9 +4605,8 @@ def test_runnable_branch_stream_with_callbacks() -> None:
|
|||||||
assert tracer.runs[0].outputs == {"output": llm_res}
|
assert tracer.runs[0].outputs == {"output": llm_res}
|
||||||
|
|
||||||
# Verify that the chain on error is invoked
|
# Verify that the chain on error is invoked
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="x is error"):
|
||||||
for _ in branch.stream("error", config=config):
|
_ = list(branch.stream("error", config=config))
|
||||||
pass
|
|
||||||
|
|
||||||
assert len(tracer.runs) == 2
|
assert len(tracer.runs) == 2
|
||||||
assert "ValueError('x is error')" in str(tracer.runs[1].error)
|
assert "ValueError('x is error')" in str(tracer.runs[1].error)
|
||||||
@ -4638,9 +4681,8 @@ async def test_runnable_branch_astream_with_callbacks() -> None:
|
|||||||
assert tracer.runs[0].outputs == {"output": llm_res}
|
assert tracer.runs[0].outputs == {"output": llm_res}
|
||||||
|
|
||||||
# Verify that the chain on error is invoked
|
# Verify that the chain on error is invoked
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="x is error"):
|
||||||
async for _ in branch.astream("error", config=config):
|
_ = [_ async for _ in branch.astream("error", config=config)]
|
||||||
pass
|
|
||||||
|
|
||||||
assert len(tracer.runs) == 2
|
assert len(tracer.runs) == 2
|
||||||
assert "ValueError('x is error')" in str(tracer.runs[1].error)
|
assert "ValueError('x is error')" in str(tracer.runs[1].error)
|
||||||
|
@ -1824,8 +1824,7 @@ async def test_runnable_each() -> None:
|
|||||||
assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4]
|
assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4]
|
||||||
|
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
async for _ in add_one_map.astream_events([1, 2, 3], version="v1"):
|
_ = [_ async for _ in add_one_map.astream_events([1, 2, 3], version="v1")]
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def test_events_astream_config() -> None:
|
async def test_events_astream_config() -> None:
|
||||||
|
@ -1773,8 +1773,7 @@ async def test_runnable_each() -> None:
|
|||||||
assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4]
|
assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4]
|
||||||
|
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
async for _ in add_one_map.astream_events([1, 2, 3], version="v2"):
|
_ = [_ async for _ in add_one_map.astream_events([1, 2, 3], version="v2")]
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def test_events_astream_config() -> None:
|
async def test_events_astream_config() -> None:
|
||||||
|
@ -419,7 +419,7 @@ class TestRunnableSequenceParallelTraceNesting:
|
|||||||
self._check_posts()
|
self._check_posts()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("parent_type", ("ls", "lc"))
|
@pytest.mark.parametrize("parent_type", ["ls", "lc"])
|
||||||
def test_tree_is_constructed(parent_type: Literal["ls", "lc"]) -> None:
|
def test_tree_is_constructed(parent_type: Literal["ls", "lc"]) -> None:
|
||||||
mock_session = MagicMock()
|
mock_session = MagicMock()
|
||||||
mock_client_ = Client(
|
mock_client_ = Client(
|
||||||
|
@ -15,7 +15,7 @@ from langchain_core.runnables.utils import (
|
|||||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"func, expected_source",
|
("func", "expected_source"),
|
||||||
[
|
[
|
||||||
(lambda x: x * 2, "lambda x: x * 2"),
|
(lambda x: x * 2, "lambda x: x * 2"),
|
||||||
(lambda a, b: a + b, "lambda a, b: a + b"),
|
(lambda a, b: a + b, "lambda a, b: a + b"),
|
||||||
@ -29,7 +29,7 @@ def test_get_lambda_source(func: Callable, expected_source: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"text,prefix,expected_output",
|
("text", "prefix", "expected_output"),
|
||||||
[
|
[
|
||||||
("line 1\nline 2\nline 3", "1", "line 1\n line 2\n line 3"),
|
("line 1\nline 2\nline 3", "1", "line 1\n line 2\n line 3"),
|
||||||
("line 1\nline 2\nline 3", "ax", "line 1\n line 2\n line 3"),
|
("line 1\nline 2\nline 3", "ax", "line 1\n line 2\n line 3"),
|
||||||
|
@ -184,7 +184,9 @@ def test_chat_message_chunks() -> None:
|
|||||||
"ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk"
|
"ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk"
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError, match="Cannot concatenate ChatMessageChunks with different roles."
|
||||||
|
):
|
||||||
ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
|
ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
|
||||||
role="Assistant", content=" indeed."
|
role="Assistant", content=" indeed."
|
||||||
)
|
)
|
||||||
@ -290,7 +292,10 @@ def test_function_message_chunks() -> None:
|
|||||||
id="ai5", name="hello", content="I am indeed."
|
id="ai5", name="hello", content="I am indeed."
|
||||||
), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk"
|
), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk"
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Cannot concatenate FunctionMessageChunks with different names.",
|
||||||
|
):
|
||||||
FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
|
FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
|
||||||
name="bye", content=" indeed."
|
name="bye", content=" indeed."
|
||||||
)
|
)
|
||||||
@ -303,7 +308,10 @@ def test_ai_message_chunks() -> None:
|
|||||||
"AIMessageChunk + AIMessageChunk should be a AIMessageChunk"
|
"AIMessageChunk + AIMessageChunk should be a AIMessageChunk"
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Cannot concatenate AIMessageChunks with different example values.",
|
||||||
|
):
|
||||||
AIMessageChunk(example=True, content="I am") + AIMessageChunk(
|
AIMessageChunk(example=True, content="I am") + AIMessageChunk(
|
||||||
example=False, content=" indeed."
|
example=False, content=" indeed."
|
||||||
)
|
)
|
||||||
@ -320,30 +328,21 @@ class TestGetBufferString(unittest.TestCase):
|
|||||||
self.tool_calls_msg = AIMessage(content="tool")
|
self.tool_calls_msg = AIMessage(content="tool")
|
||||||
|
|
||||||
def test_empty_input(self) -> None:
|
def test_empty_input(self) -> None:
|
||||||
self.assertEqual(get_buffer_string([]), "")
|
assert get_buffer_string([]) == ""
|
||||||
|
|
||||||
def test_valid_single_message(self) -> None:
|
def test_valid_single_message(self) -> None:
|
||||||
expected_output = f"Human: {self.human_msg.content}"
|
expected_output = f"Human: {self.human_msg.content}"
|
||||||
self.assertEqual(
|
assert get_buffer_string([self.human_msg]) == expected_output
|
||||||
get_buffer_string([self.human_msg]),
|
|
||||||
expected_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_custom_human_prefix(self) -> None:
|
def test_custom_human_prefix(self) -> None:
|
||||||
prefix = "H"
|
prefix = "H"
|
||||||
expected_output = f"{prefix}: {self.human_msg.content}"
|
expected_output = f"{prefix}: {self.human_msg.content}"
|
||||||
self.assertEqual(
|
assert get_buffer_string([self.human_msg], human_prefix="H") == expected_output
|
||||||
get_buffer_string([self.human_msg], human_prefix="H"),
|
|
||||||
expected_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_custom_ai_prefix(self) -> None:
|
def test_custom_ai_prefix(self) -> None:
|
||||||
prefix = "A"
|
prefix = "A"
|
||||||
expected_output = f"{prefix}: {self.ai_msg.content}"
|
expected_output = f"{prefix}: {self.ai_msg.content}"
|
||||||
self.assertEqual(
|
assert get_buffer_string([self.ai_msg], ai_prefix="A") == expected_output
|
||||||
get_buffer_string([self.ai_msg], ai_prefix="A"),
|
|
||||||
expected_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_multiple_msg(self) -> None:
|
def test_multiple_msg(self) -> None:
|
||||||
msgs = [
|
msgs = [
|
||||||
@ -366,10 +365,7 @@ class TestGetBufferString(unittest.TestCase):
|
|||||||
"AI: tool",
|
"AI: tool",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
assert get_buffer_string(msgs) == expected_output
|
||||||
get_buffer_string(msgs),
|
|
||||||
expected_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_msg() -> None:
|
def test_multiple_msg() -> None:
|
||||||
@ -991,7 +987,7 @@ def test_tool_message_str() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
["first", "others", "expected"],
|
("first", "others", "expected"),
|
||||||
[
|
[
|
||||||
("", [""], ""),
|
("", [""], ""),
|
||||||
("", [[]], [""]),
|
("", [[]], [""]),
|
||||||
|
@ -775,8 +775,8 @@ def test_exception_handling_callable() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_exception_handling_non_tool_exception() -> None:
|
def test_exception_handling_non_tool_exception() -> None:
|
||||||
_tool = _FakeExceptionTool(exception=ValueError())
|
_tool = _FakeExceptionTool(exception=ValueError("some error"))
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="some error"):
|
||||||
_tool.run({})
|
_tool.run({})
|
||||||
|
|
||||||
|
|
||||||
@ -806,8 +806,8 @@ async def test_async_exception_handling_callable() -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def test_async_exception_handling_non_tool_exception() -> None:
|
async def test_async_exception_handling_non_tool_exception() -> None:
|
||||||
_tool = _FakeExceptionTool(exception=ValueError())
|
_tool = _FakeExceptionTool(exception=ValueError("some error"))
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="some error"):
|
||||||
await _tool.arun({})
|
await _tool.arun({})
|
||||||
|
|
||||||
|
|
||||||
@ -987,7 +987,7 @@ def test_optional_subset_model_rewrite() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"inputs, expected",
|
("inputs", "expected"),
|
||||||
[
|
[
|
||||||
# Check not required
|
# Check not required
|
||||||
({"bar": "bar"}, {"bar": "bar", "baz": 3, "buzz": "buzz"}),
|
({"bar": "bar"}, {"bar": "bar", "baz": 3, "buzz": "buzz"}),
|
||||||
@ -1323,6 +1323,10 @@ def test_tool_invalid_docstrings() -> None:
|
|||||||
"""
|
"""
|
||||||
return bar
|
return bar
|
||||||
|
|
||||||
|
for func in {foo3, foo4}:
|
||||||
|
with pytest.raises(ValueError, match="Found invalid Google-Style docstring."):
|
||||||
|
_ = tool(func, parse_docstring=True)
|
||||||
|
|
||||||
def foo5(bar: str, baz: int) -> str:
|
def foo5(bar: str, baz: int) -> str:
|
||||||
"""The foo.
|
"""The foo.
|
||||||
|
|
||||||
@ -1332,9 +1336,10 @@ def test_tool_invalid_docstrings() -> None:
|
|||||||
"""
|
"""
|
||||||
return bar
|
return bar
|
||||||
|
|
||||||
for func in [foo3, foo4, foo5]:
|
with pytest.raises(
|
||||||
with pytest.raises(ValueError):
|
ValueError, match="Arg banana in docstring not found in function signature."
|
||||||
_ = tool(func, parse_docstring=True)
|
):
|
||||||
|
_ = tool(foo5, parse_docstring=True)
|
||||||
|
|
||||||
|
|
||||||
def test_tool_annotated_descriptions() -> None:
|
def test_tool_annotated_descriptions() -> None:
|
||||||
@ -2004,9 +2009,9 @@ def test__is_message_content_block(obj: Any, expected: bool) -> None:
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("obj", "expected"),
|
("obj", "expected"),
|
||||||
[
|
[
|
||||||
["foo", True],
|
("foo", True),
|
||||||
[valid_tool_result_blocks, True],
|
(valid_tool_result_blocks, True),
|
||||||
[invalid_tool_result_blocks, False],
|
(invalid_tool_result_blocks, False),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test__is_message_content_type(obj: Any, expected: bool) -> None:
|
def test__is_message_content_type(obj: Any, expected: bool) -> None:
|
||||||
@ -2268,7 +2273,8 @@ def test_imports() -> None:
|
|||||||
"InjectedToolArg",
|
"InjectedToolArg",
|
||||||
]
|
]
|
||||||
for module_name in expected_all:
|
for module_name in expected_all:
|
||||||
assert hasattr(tools, module_name) and getattr(tools, module_name) is not None
|
assert hasattr(tools, module_name)
|
||||||
|
assert getattr(tools, module_name) is not None
|
||||||
|
|
||||||
|
|
||||||
def test_structured_tool_direct_init() -> None:
|
def test_structured_tool_direct_init() -> None:
|
||||||
@ -2317,7 +2323,11 @@ def test_tool_injected_tool_call_id() -> None:
|
|||||||
}
|
}
|
||||||
) == ToolMessage(0, tool_call_id="bar") # type: ignore
|
) == ToolMessage(0, tool_call_id="bar") # type: ignore
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="When tool includes an InjectedToolCallId argument, "
|
||||||
|
"tool must always be invoked with a full model ToolCall",
|
||||||
|
):
|
||||||
assert foo.invoke({"x": 0})
|
assert foo.invoke({"x": 0})
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@ -2341,7 +2351,7 @@ def test_tool_uninjected_tool_call_id() -> None:
|
|||||||
"""Foo."""
|
"""Foo."""
|
||||||
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
|
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="1 validation error for foo"):
|
||||||
foo.invoke({"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"})
|
foo.invoke({"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"})
|
||||||
|
|
||||||
assert foo.invoke(
|
assert foo.invoke(
|
||||||
|
@ -22,6 +22,5 @@ def test_public_api() -> None:
|
|||||||
|
|
||||||
# Assert that the object is actually present in the schema module
|
# Assert that the object is actually present in the schema module
|
||||||
for module_name in expected_all:
|
for module_name in expected_all:
|
||||||
assert (
|
assert hasattr(schemas, module_name)
|
||||||
hasattr(schemas, module_name) and getattr(schemas, module_name) is not None
|
assert getattr(schemas, module_name) is not None
|
||||||
)
|
|
||||||
|
@ -6,7 +6,7 @@ from langchain_core.utils.aiter import abatch_iterate
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"input_size, input_iterable, expected_output",
|
("input_size", "input_iterable", "expected_output"),
|
||||||
[
|
[
|
||||||
(2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
|
(2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
|
||||||
(3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),
|
(3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),
|
||||||
|
@ -51,7 +51,12 @@ def test_get_from_dict_or_env() -> None:
|
|||||||
|
|
||||||
# Not the most obvious behavior, but
|
# Not the most obvious behavior, but
|
||||||
# this is how it works right now
|
# this is how it works right now
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Did not find not exists, "
|
||||||
|
"please add an environment variable `__SOME_KEY_IN_ENV` which contains it, "
|
||||||
|
"or pass `not exists` as a named parameter.",
|
||||||
|
):
|
||||||
assert (
|
assert (
|
||||||
get_from_dict_or_env(
|
get_from_dict_or_env(
|
||||||
{
|
{
|
||||||
|
@ -37,7 +37,7 @@ from langchain_core.utils.function_calling import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def pydantic() -> type[BaseModel]:
|
def pydantic() -> type[BaseModel]:
|
||||||
class dummy_function(BaseModel): # noqa: N801
|
class dummy_function(BaseModel): # noqa: N801
|
||||||
"""Dummy function."""
|
"""Dummy function."""
|
||||||
@ -48,7 +48,7 @@ def pydantic() -> type[BaseModel]:
|
|||||||
return dummy_function
|
return dummy_function
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def annotated_function() -> Callable:
|
def annotated_function() -> Callable:
|
||||||
def dummy_function(
|
def dummy_function(
|
||||||
arg1: ExtensionsAnnotated[int, "foo"],
|
arg1: ExtensionsAnnotated[int, "foo"],
|
||||||
@ -59,7 +59,7 @@ def annotated_function() -> Callable:
|
|||||||
return dummy_function
|
return dummy_function
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def function() -> Callable:
|
def function() -> Callable:
|
||||||
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
|
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
|
||||||
"""Dummy function.
|
"""Dummy function.
|
||||||
@ -72,7 +72,7 @@ def function() -> Callable:
|
|||||||
return dummy_function
|
return dummy_function
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def function_docstring_annotations() -> Callable:
|
def function_docstring_annotations() -> Callable:
|
||||||
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
|
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
|
||||||
"""Dummy function.
|
"""Dummy function.
|
||||||
@ -85,7 +85,7 @@ def function_docstring_annotations() -> Callable:
|
|||||||
return dummy_function
|
return dummy_function
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def runnable() -> Runnable:
|
def runnable() -> Runnable:
|
||||||
class Args(ExtensionsTypedDict):
|
class Args(ExtensionsTypedDict):
|
||||||
arg1: ExtensionsAnnotated[int, "foo"]
|
arg1: ExtensionsAnnotated[int, "foo"]
|
||||||
@ -97,7 +97,7 @@ def runnable() -> Runnable:
|
|||||||
return RunnableLambda(dummy_function)
|
return RunnableLambda(dummy_function)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def dummy_tool() -> BaseTool:
|
def dummy_tool() -> BaseTool:
|
||||||
class Schema(BaseModel):
|
class Schema(BaseModel):
|
||||||
arg1: int = Field(..., description="foo")
|
arg1: int = Field(..., description="foo")
|
||||||
@ -114,7 +114,7 @@ def dummy_tool() -> BaseTool:
|
|||||||
return DummyFunction()
|
return DummyFunction()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def dummy_structured_tool() -> StructuredTool:
|
def dummy_structured_tool() -> StructuredTool:
|
||||||
class Schema(BaseModel):
|
class Schema(BaseModel):
|
||||||
arg1: int = Field(..., description="foo")
|
arg1: int = Field(..., description="foo")
|
||||||
@ -128,7 +128,7 @@ def dummy_structured_tool() -> StructuredTool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def dummy_structured_tool_args_schema_dict() -> StructuredTool:
|
def dummy_structured_tool_args_schema_dict() -> StructuredTool:
|
||||||
args_schema = {
|
args_schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@ -150,7 +150,7 @@ def dummy_structured_tool_args_schema_dict() -> StructuredTool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def dummy_pydantic() -> type[BaseModel]:
|
def dummy_pydantic() -> type[BaseModel]:
|
||||||
class dummy_function(BaseModel): # noqa: N801
|
class dummy_function(BaseModel): # noqa: N801
|
||||||
"""Dummy function."""
|
"""Dummy function."""
|
||||||
@ -161,7 +161,7 @@ def dummy_pydantic() -> type[BaseModel]:
|
|||||||
return dummy_function
|
return dummy_function
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
|
def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
|
||||||
class dummy_function(BaseModelV2Maybe): # noqa: N801
|
class dummy_function(BaseModelV2Maybe): # noqa: N801
|
||||||
"""Dummy function."""
|
"""Dummy function."""
|
||||||
@ -174,7 +174,7 @@ def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
|
|||||||
return dummy_function
|
return dummy_function
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def dummy_typing_typed_dict() -> type:
|
def dummy_typing_typed_dict() -> type:
|
||||||
class dummy_function(TypingTypedDict): # noqa: N801
|
class dummy_function(TypingTypedDict): # noqa: N801
|
||||||
"""Dummy function."""
|
"""Dummy function."""
|
||||||
@ -185,7 +185,7 @@ def dummy_typing_typed_dict() -> type:
|
|||||||
return dummy_function
|
return dummy_function
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def dummy_typing_typed_dict_docstring() -> type:
|
def dummy_typing_typed_dict_docstring() -> type:
|
||||||
class dummy_function(TypingTypedDict): # noqa: N801
|
class dummy_function(TypingTypedDict): # noqa: N801
|
||||||
"""Dummy function.
|
"""Dummy function.
|
||||||
@ -201,7 +201,7 @@ def dummy_typing_typed_dict_docstring() -> type:
|
|||||||
return dummy_function
|
return dummy_function
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def dummy_extensions_typed_dict() -> type:
|
def dummy_extensions_typed_dict() -> type:
|
||||||
class dummy_function(ExtensionsTypedDict): # noqa: N801
|
class dummy_function(ExtensionsTypedDict): # noqa: N801
|
||||||
"""Dummy function."""
|
"""Dummy function."""
|
||||||
@ -212,7 +212,7 @@ def dummy_extensions_typed_dict() -> type:
|
|||||||
return dummy_function
|
return dummy_function
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def dummy_extensions_typed_dict_docstring() -> type:
|
def dummy_extensions_typed_dict_docstring() -> type:
|
||||||
class dummy_function(ExtensionsTypedDict): # noqa: N801
|
class dummy_function(ExtensionsTypedDict): # noqa: N801
|
||||||
"""Dummy function.
|
"""Dummy function.
|
||||||
@ -228,7 +228,7 @@ def dummy_extensions_typed_dict_docstring() -> type:
|
|||||||
return dummy_function
|
return dummy_function
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def json_schema() -> dict:
|
def json_schema() -> dict:
|
||||||
return {
|
return {
|
||||||
"title": "dummy_function",
|
"title": "dummy_function",
|
||||||
@ -246,7 +246,7 @@ def json_schema() -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def anthropic_tool() -> dict:
|
def anthropic_tool() -> dict:
|
||||||
return {
|
return {
|
||||||
"name": "dummy_function",
|
"name": "dummy_function",
|
||||||
@ -266,7 +266,7 @@ def anthropic_tool() -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def bedrock_converse_tool() -> dict:
|
def bedrock_converse_tool() -> dict:
|
||||||
return {
|
return {
|
||||||
"toolSpec": {
|
"toolSpec": {
|
||||||
|
@ -4,7 +4,7 @@ from langchain_core.utils.iter import batch_iterate
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"input_size, input_iterable, expected_output",
|
("input_size", "input_iterable", "expected_output"),
|
||||||
[
|
[
|
||||||
(2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
|
(2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
|
||||||
(3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),
|
(3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),
|
||||||
|
@ -147,7 +147,7 @@ def test_dereference_refs_remote_ref() -> None:
|
|||||||
"first_name": {"$ref": "https://somewhere/else/name"},
|
"first_name": {"$ref": "https://somewhere/else/name"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="ref paths are expected to be URI fragments"):
|
||||||
dereference_refs(schema)
|
dereference_refs(schema)
|
||||||
|
|
||||||
|
|
||||||
|
@ -220,7 +220,7 @@ output5 = {
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"schema, output",
|
("schema", "output"),
|
||||||
[
|
[
|
||||||
(schema1, output1),
|
(schema1, output1),
|
||||||
(schema2, output2),
|
(schema2, output2),
|
||||||
|
@ -29,12 +29,17 @@ def test_dict_int_op_nested() -> None:
|
|||||||
def test_dict_int_op_max_depth_exceeded() -> None:
|
def test_dict_int_op_max_depth_exceeded() -> None:
|
||||||
left = {"a": {"b": {"c": 1}}}
|
left = {"a": {"b": {"c": 1}}}
|
||||||
right = {"a": {"b": {"c": 2}}}
|
right = {"a": {"b": {"c": 2}}}
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError, match="max_depth=2 exceeded, unable to combine dicts."
|
||||||
|
):
|
||||||
_dict_int_op(left, right, operator.add, max_depth=2)
|
_dict_int_op(left, right, operator.add, max_depth=2)
|
||||||
|
|
||||||
|
|
||||||
def test_dict_int_op_invalid_types() -> None:
|
def test_dict_int_op_invalid_types() -> None:
|
||||||
left = {"a": 1, "b": "string"}
|
left = {"a": 1, "b": "string"}
|
||||||
right = {"a": 2, "b": 3}
|
right = {"a": 2, "b": 3}
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Only dict and int values are supported.",
|
||||||
|
):
|
||||||
_dict_int_op(left, right, operator.add)
|
_dict_int_op(left, right, operator.add)
|
||||||
|
@ -46,7 +46,7 @@ def test_check_package_version(
|
|||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("left", "right", "expected"),
|
("left", "right", "expected"),
|
||||||
(
|
[
|
||||||
# Merge `None` and `1`.
|
# Merge `None` and `1`.
|
||||||
({"a": None}, {"a": 1}, {"a": 1}),
|
({"a": None}, {"a": 1}, {"a": 1}),
|
||||||
# Merge `1` and `None`.
|
# Merge `1` and `None`.
|
||||||
@ -111,7 +111,7 @@ def test_check_package_version(
|
|||||||
{"a": [{"idx": 0, "b": "f"}]},
|
{"a": [{"idx": 0, "b": "f"}]},
|
||||||
{"a": [{"idx": 0, "b": "{"}, {"idx": 0, "b": "f"}]},
|
{"a": [{"idx": 0, "b": "{"}, {"idx": 0, "b": "f"}]},
|
||||||
),
|
),
|
||||||
),
|
],
|
||||||
)
|
)
|
||||||
def test_merge_dicts(
|
def test_merge_dicts(
|
||||||
left: dict, right: dict, expected: Union[dict, AbstractContextManager]
|
left: dict, right: dict, expected: Union[dict, AbstractContextManager]
|
||||||
@ -130,7 +130,7 @@ def test_merge_dicts(
|
|||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("left", "right", "expected"),
|
("left", "right", "expected"),
|
||||||
(
|
[
|
||||||
# 'type' special key handling
|
# 'type' special key handling
|
||||||
({"type": "foo"}, {"type": "foo"}, {"type": "foo"}),
|
({"type": "foo"}, {"type": "foo"}, {"type": "foo"}),
|
||||||
(
|
(
|
||||||
@ -138,7 +138,7 @@ def test_merge_dicts(
|
|||||||
{"type": "bar"},
|
{"type": "bar"},
|
||||||
pytest.raises(ValueError, match="Unable to merge."),
|
pytest.raises(ValueError, match="Unable to merge."),
|
||||||
),
|
),
|
||||||
),
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.xfail(reason="Refactors to make in 0.3")
|
@pytest.mark.xfail(reason="Refactors to make in 0.3")
|
||||||
def test_merge_dicts_0_3(
|
def test_merge_dicts_0_3(
|
||||||
@ -183,36 +183,32 @@ def test_guard_import(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("module_name", "pip_name", "package"),
|
("module_name", "pip_name", "package", "expected_pip_name"),
|
||||||
[
|
[
|
||||||
("langchain_core.utilsW", None, None),
|
("langchain_core.utilsW", None, None, "langchain-core"),
|
||||||
("langchain_core.utilsW", "langchain-core-2", None),
|
("langchain_core.utilsW", "langchain-core-2", None, "langchain-core-2"),
|
||||||
("langchain_core.utilsW", None, "langchain-coreWX"),
|
("langchain_core.utilsW", None, "langchain-coreWX", "langchain-core"),
|
||||||
("langchain_core.utilsW", "langchain-core-2", "langchain-coreWX"),
|
(
|
||||||
("langchain_coreW", None, None), # ModuleNotFoundError
|
"langchain_core.utilsW",
|
||||||
|
"langchain-core-2",
|
||||||
|
"langchain-coreWX",
|
||||||
|
"langchain-core-2",
|
||||||
|
),
|
||||||
|
("langchain_coreW", None, None, "langchain-coreW"), # ModuleNotFoundError
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_guard_import_failure(
|
def test_guard_import_failure(
|
||||||
module_name: str, pip_name: Optional[str], package: Optional[str]
|
module_name: str,
|
||||||
|
pip_name: Optional[str],
|
||||||
|
package: Optional[str],
|
||||||
|
expected_pip_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
with pytest.raises(ImportError) as exc_info:
|
with pytest.raises(
|
||||||
if package is None and pip_name is None:
|
ImportError,
|
||||||
guard_import(module_name)
|
match=f"Could not import {module_name} python package. "
|
||||||
elif package is None and pip_name is not None:
|
f"Please install it with `pip install {expected_pip_name}`.",
|
||||||
guard_import(module_name, pip_name=pip_name)
|
):
|
||||||
elif package is not None and pip_name is None:
|
|
||||||
guard_import(module_name, package=package)
|
|
||||||
elif package is not None and pip_name is not None:
|
|
||||||
guard_import(module_name, pip_name=pip_name, package=package)
|
guard_import(module_name, pip_name=pip_name, package=package)
|
||||||
else:
|
|
||||||
msg = "Invalid test case"
|
|
||||||
raise ValueError(msg)
|
|
||||||
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
|
|
||||||
err_msg = (
|
|
||||||
f"Could not import {module_name} python package. "
|
|
||||||
f"Please install it with `pip install {pip_name}`."
|
|
||||||
)
|
|
||||||
assert exc_info.value.msg == err_msg
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")
|
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")
|
||||||
|
Loading…
Reference in New Issue
Block a user