diff --git a/libs/core/langchain_core/output_parsers/json.py b/libs/core/langchain_core/output_parsers/json.py index 3dafaecc755..0b64c594a31 100644 --- a/libs/core/langchain_core/output_parsers/json.py +++ b/libs/core/langchain_core/output_parsers/json.py @@ -19,9 +19,9 @@ from langchain_core.utils.json import ( parse_json_markdown, parse_partial_json, ) -from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION +from langchain_core.utils.pydantic import IS_PYDANTIC_V1 -if PYDANTIC_MAJOR_VERSION < 2: +if IS_PYDANTIC_V1: PydanticBaseModel = pydantic.BaseModel else: diff --git a/libs/core/langchain_core/output_parsers/pydantic.py b/libs/core/langchain_core/output_parsers/pydantic.py index 6550d4ce7d8..b1ee96e0fa0 100644 --- a/libs/core/langchain_core/output_parsers/pydantic.py +++ b/libs/core/langchain_core/output_parsers/pydantic.py @@ -11,7 +11,7 @@ from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import JsonOutputParser from langchain_core.outputs import Generation from langchain_core.utils.pydantic import ( - PYDANTIC_MAJOR_VERSION, + IS_PYDANTIC_V2, PydanticBaseModel, TBaseModel, ) @@ -24,7 +24,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]): """The pydantic model to parse.""" def _parse_obj(self, obj: dict) -> TBaseModel: - if PYDANTIC_MAJOR_VERSION == 2: + if IS_PYDANTIC_V2: try: if issubclass(self.pydantic_object, pydantic.BaseModel): return self.pydantic_object.model_validate(obj) diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index df698d9f13d..9a8f9fd832b 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -20,6 +20,7 @@ from typing import ( ) import pydantic +from packaging import version from pydantic import ( BaseModel, ConfigDict, @@ -41,44 +42,46 @@ from pydantic.json_schema import ( if TYPE_CHECKING: from pydantic_core import core_schema +try: + import pydantic + + PYDANTIC_VERSION = version.parse(pydantic.__version__) +except ImportError: + PYDANTIC_VERSION = version.parse("0.0.0") + def get_pydantic_major_version() -> int: - """Get the major version of Pydantic.""" - try: - import pydantic + """DEPRECATED - Get the major version of Pydantic. - return int(pydantic.__version__.split(".")[0]) - except ImportError: - return 0 + Use PYDANTIC_VERSION.major instead. + """ + warnings.warn( + "get_pydantic_major_version is deprecated. Use PYDANTIC_VERSION.major instead.", + DeprecationWarning, + stacklevel=2, + ) + return PYDANTIC_VERSION.major -def _get_pydantic_minor_version() -> int: - """Get the minor version of Pydantic.""" - try: - import pydantic +PYDANTIC_MAJOR_VERSION = PYDANTIC_VERSION.major +PYDANTIC_MINOR_VERSION = PYDANTIC_VERSION.minor - return int(pydantic.__version__.split(".")[1]) - except ImportError: - return 0 +IS_PYDANTIC_V1 = PYDANTIC_VERSION.major == 1 +IS_PYDANTIC_V2 = PYDANTIC_VERSION.major == 2 - -PYDANTIC_MAJOR_VERSION = get_pydantic_major_version() -PYDANTIC_MINOR_VERSION = _get_pydantic_minor_version() - - -if PYDANTIC_MAJOR_VERSION == 1: +if IS_PYDANTIC_V1: from pydantic.fields import FieldInfo as FieldInfoV1 PydanticBaseModel = pydantic.BaseModel TypeBaseModel = type[BaseModel] -elif PYDANTIC_MAJOR_VERSION == 2: +elif IS_PYDANTIC_V2: from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment] # Union type needs to be last assignment to PydanticBaseModel to make mypy happy. PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc] TypeBaseModel = Union[type[BaseModel], type[pydantic.BaseModel]] # type: ignore[misc] else: - msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}" + msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}" raise ValueError(msg) @@ -87,9 +90,9 @@ TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel) def is_pydantic_v1_subclass(cls: type) -> bool: """Check if the installed Pydantic version is 1.x-like.""" - if PYDANTIC_MAJOR_VERSION == 1: + if IS_PYDANTIC_V1: return True - if PYDANTIC_MAJOR_VERSION == 2: + if IS_PYDANTIC_V2: from pydantic.v1 import BaseModel as BaseModelV1 if issubclass(cls, BaseModelV1): @@ -101,7 +104,7 @@ def is_pydantic_v2_subclass(cls: type) -> bool: """Check if the installed Pydantic version is 1.x-like.""" from pydantic import BaseModel - return PYDANTIC_MAJOR_VERSION == 2 and issubclass(cls, BaseModel) + return IS_PYDANTIC_V2 and issubclass(cls, BaseModel) def is_basemodel_subclass(cls: type) -> bool: @@ -117,12 +120,12 @@ def is_basemodel_subclass(cls: type) -> bool: if not inspect.isclass(cls) or isinstance(cls, GenericAlias): return False - if PYDANTIC_MAJOR_VERSION == 1: + if IS_PYDANTIC_V1: from pydantic import BaseModel as BaseModelV1Proper if issubclass(cls, BaseModelV1Proper): return True - elif PYDANTIC_MAJOR_VERSION == 2: + elif IS_PYDANTIC_V2: from pydantic import BaseModel as BaseModelV2 from pydantic.v1 import BaseModel as BaseModelV1 @@ -132,7 +135,7 @@ def is_basemodel_subclass(cls: type) -> bool: if issubclass(cls, BaseModelV1): return True else: - msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}" + msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}" raise ValueError(msg) return False @@ -146,12 +149,12 @@ def is_basemodel_instance(obj: Any) -> bool: * pydantic.BaseModel in Pydantic 2.x * pydantic.v1.BaseModel in Pydantic 2.x """ - if PYDANTIC_MAJOR_VERSION == 1: + if IS_PYDANTIC_V1: from pydantic import BaseModel as BaseModelV1Proper if isinstance(obj, BaseModelV1Proper): return True - elif PYDANTIC_MAJOR_VERSION == 2: + elif IS_PYDANTIC_V2: from pydantic import BaseModel as BaseModelV2 from pydantic.v1 import BaseModel as BaseModelV1 @@ -161,7 +164,7 @@ def is_basemodel_instance(obj: Any) -> bool: if isinstance(obj, BaseModelV1): return True else: - msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}" + msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}" raise ValueError(msg) return False @@ -245,12 +248,12 @@ def _create_subset_model_v1( fn_description: Optional[str] = None, ) -> type[BaseModel]: """Create a pydantic model with only a subset of model's fields.""" - if PYDANTIC_MAJOR_VERSION == 1: + if IS_PYDANTIC_V1: from pydantic import create_model - elif PYDANTIC_MAJOR_VERSION == 2: + elif IS_PYDANTIC_V2: from pydantic.v1 import create_model # type: ignore else: - msg = f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}" + msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}" raise NotImplementedError(msg) fields = {} @@ -327,7 +330,7 @@ def _create_subset_model( fn_description: Optional[str] = None, ) -> type[BaseModel]: """Create subset model using the same pydantic version as the input model.""" - if PYDANTIC_MAJOR_VERSION == 1: + if IS_PYDANTIC_V1: return _create_subset_model_v1( name, model, @@ -335,7 +338,7 @@ def _create_subset_model( descriptions=descriptions, fn_description=fn_description, ) - if PYDANTIC_MAJOR_VERSION == 2: + if IS_PYDANTIC_V2: from pydantic.v1 import BaseModel as BaseModelV1 if issubclass(model, BaseModelV1): @@ -353,11 +356,11 @@ def _create_subset_model( descriptions=descriptions, fn_description=fn_description, ) - msg = f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}" + msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}" raise NotImplementedError(msg) -if PYDANTIC_MAJOR_VERSION == 2: +if IS_PYDANTIC_V2: from pydantic import BaseModel as BaseModelV2 from pydantic.v1 import BaseModel as BaseModelV1 @@ -390,7 +393,7 @@ if PYDANTIC_MAJOR_VERSION == 2: msg = f"Expected a Pydantic model. Got {type(model)}" raise TypeError(msg) -elif PYDANTIC_MAJOR_VERSION == 1: +elif IS_PYDANTIC_V1: from pydantic import BaseModel as BaseModelV1_ def get_fields( # type: ignore[no-redef] @@ -400,7 +403,7 @@ elif PYDANTIC_MAJOR_VERSION == 1: return model.__fields__ # type: ignore else: - msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}" + msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}" raise ValueError(msg) _SchemaConfig = ConfigDict( diff --git a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py index e5fe0f3076c..9fbbd6c446f 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py +++ b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py @@ -16,7 +16,10 @@ from langchain_core.output_parsers.openai_tools import ( PydanticToolsParser, ) from langchain_core.outputs import ChatGeneration -from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION +from langchain_core.utils.pydantic import ( + IS_PYDANTIC_V1, + IS_PYDANTIC_V2, +) STREAMED_MESSAGES: list = [ AIMessageChunk(content=""), @@ -529,7 +532,7 @@ async def test_partial_pydantic_output_parser_async() -> None: assert actual == EXPECTED_STREAMED_PYDANTIC -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2") def test_parse_with_different_pydantic_2_v1() -> None: """Test with pydantic.v1.BaseModel from pydantic 2.""" import pydantic @@ -564,7 +567,7 @@ def test_parse_with_different_pydantic_2_v1() -> None: ] -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2") def test_parse_with_different_pydantic_2_proper() -> None: """Test with pydantic.BaseModel from pydantic 2.""" import pydantic @@ -599,7 +602,7 @@ def test_parse_with_different_pydantic_2_proper() -> None: ] -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="This test is for pydantic 1") +@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="This test is for pydantic 1") def test_parse_with_different_pydantic_1_proper() -> None: """Test with pydantic.BaseModel from pydantic 1.""" import pydantic diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index ba82ff8c760..84a83c6ae2d 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Any, Union, cast import pytest +from packaging import version from pydantic import ValidationError from syrupy import SnapshotAssertion @@ -32,7 +33,9 @@ from langchain_core.prompts.chat import ( _convert_to_message, ) from langchain_core.prompts.string import PromptTemplateFormat -from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION +from langchain_core.utils.pydantic import ( + PYDANTIC_VERSION, +) from tests.unit_tests.pydantic_utils import _normalize_schema @@ -921,7 +924,7 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None: with pytest.raises(ValidationError): prompt_all_required.input_schema(input="") - if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10): + if version.parse("2.10") <= PYDANTIC_VERSION: assert _normalize_schema( prompt_all_required.get_input_jsonschema() ) == snapshot(name="required") @@ -932,7 +935,7 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None: assert set(prompt_optional.input_variables) == {"input"} prompt_optional.input_schema(input="") # won't raise error - if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10): + if version.parse("2.10") <= PYDANTIC_VERSION: assert _normalize_schema(prompt_optional.get_input_jsonschema()) == snapshot( name="partial" ) diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index f05451bb0fd..03941ad6ff4 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -4,16 +4,17 @@ import re from typing import Any, Union from unittest import mock -import pydantic import pytest +from packaging import version from syrupy import SnapshotAssertion from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.string import PromptTemplateFormat from langchain_core.tracers.run_collector import RunCollectorCallbackHandler +from langchain_core.utils.pydantic import PYDANTIC_VERSION from tests.unit_tests.pydantic_utils import _normalize_schema -PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split("."))) +PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION def test_prompt_valid() -> None: @@ -117,7 +118,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None: "This foo is a bar test baz." ) assert prompt.input_variables == ["foo", "obj"] - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot( name="schema_0" ) @@ -144,7 +145,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None: is a test.""" ) assert prompt.input_variables == ["foo"] - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot( name="schema_2" ) @@ -168,7 +169,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None: is a test.""" ) assert prompt.input_variables == ["foo"] - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot( name="schema_3" ) @@ -206,7 +207,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None: is a test.""" ) assert prompt.input_variables == ["foo"] - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot( name="schema_4" ) @@ -224,7 +225,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None: is a test.""" # noqa: W293 ) assert prompt.input_variables == ["foo"] - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot( name="schema_5" ) diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index 171d8819186..545db59c8be 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -1,5 +1,6 @@ from typing import Any, Optional +from packaging import version from pydantic import BaseModel from syrupy import SnapshotAssertion from typing_extensions import override @@ -12,7 +13,9 @@ from langchain_core.prompts.prompt import PromptTemplate from langchain_core.runnables.base import Runnable, RunnableConfig from langchain_core.runnables.graph import Edge, Graph, Node from langchain_core.runnables.graph_mermaid import _escape_node_label -from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION +from langchain_core.utils.pydantic import ( + PYDANTIC_VERSION, +) from tests.unit_tests.pydantic_utils import _normalize_schema @@ -231,12 +234,10 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None: ) graph = sequence.get_graph() - if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10): + if version.parse("2.10") <= PYDANTIC_VERSION: assert _normalize_schema(graph.to_json(with_schemas=True)) == snapshot( name="graph_with_schema" ) - - if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10): assert _normalize_schema(graph.to_json()) == snapshot(name="graph_no_schemas") assert graph.draw_ascii() == snapshot(name="ascii") diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index d5dc1072869..35f6c2b3801 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -2,8 +2,8 @@ import re from collections.abc import Sequence from typing import Any, Callable, Optional, Union -import pydantic import pytest +from packaging import version from pydantic import BaseModel from langchain_core.callbacks import ( @@ -19,10 +19,9 @@ from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output from langchain_core.tracers import Run +from langchain_core.utils.pydantic import PYDANTIC_VERSION from tests.unit_tests.pydantic_utils import _schema -PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split("."))) - def test_interfaces() -> None: history = InMemoryChatMessageHistory() @@ -492,7 +491,7 @@ def test_get_output_schema() -> None: "title": "RunnableWithChatHistoryOutput", "type": "object", } - if PYDANTIC_VERSION >= (2, 11): + if version.parse("2.11") <= PYDANTIC_VERSION: expected_schema["additionalProperties"] = True assert _schema(output_type) == expected_schema diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 71b006d6622..72b6c6461eb 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -15,9 +15,9 @@ from typing import ( ) from uuid import UUID -import pydantic import pytest from freezegun import freeze_time +from packaging import version from pydantic import BaseModel, Field from pytest_mock import MockerFixture from syrupy import SnapshotAssertion @@ -89,11 +89,14 @@ from langchain_core.tracers import ( RunLogPatch, ) from langchain_core.tracers.context import collect_runs -from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION +from langchain_core.utils.pydantic import ( + PYDANTIC_VERSION, +) from tests.unit_tests.pydantic_utils import _normalize_schema, _schema from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk -PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split("."))) +PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION +PYDANTIC_VERSION_AT_LEAST_210 = version.parse("2.10") <= PYDANTIC_VERSION class FakeTracer(BaseTracer): @@ -227,7 +230,7 @@ class FakeRetriever(BaseRetriever): @pytest.mark.skipif( - (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10), + PYDANTIC_VERSION_AT_LEAST_210, reason=( "Only test with most recent version of pydantic. " "Pydantic introduced small fixes to generated JSONSchema on minor versions." @@ -649,7 +652,7 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None: } ) - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema( RunnableLambda(aget_values_typed).get_output_jsonschema() # type: ignore ) == snapshot(name="schema8") @@ -764,7 +767,7 @@ def test_configurable_fields(snapshot: SnapshotAssertion) -> None: assert fake_llm_configurable.invoke("...") == "a" - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema( fake_llm_configurable.get_config_jsonschema() ) == snapshot(name="schema2") @@ -791,7 +794,7 @@ def test_configurable_fields(snapshot: SnapshotAssertion) -> None: text="Hello, John!" ) - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema( prompt_configurable.get_config_jsonschema() ) == snapshot(name="schema3") @@ -820,7 +823,7 @@ def test_configurable_fields(snapshot: SnapshotAssertion) -> None: assert chain_configurable.invoke({"name": "John"}) == "a" - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema( chain_configurable.get_config_jsonschema() ) == snapshot(name="schema4") @@ -865,7 +868,7 @@ def test_configurable_fields(snapshot: SnapshotAssertion) -> None: "llm3": "a", } - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema( chain_with_map_configurable.get_config_jsonschema() ) == snapshot(name="schema5") @@ -938,7 +941,7 @@ def test_configurable_fields_prefix_keys(snapshot: SnapshotAssertion) -> None: chain = prompt | fake_llm - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema(_schema(chain.config_schema())) == snapshot( name="schema6" ) @@ -990,7 +993,7 @@ def test_configurable_fields_example(snapshot: SnapshotAssertion) -> None: assert chain_configurable.invoke({"name": "John"}) == "a" - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema( chain_configurable.get_config_jsonschema() ) == snapshot(name="schema7") @@ -3089,7 +3092,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N assert chain.middle == [RunnableLambda(passthrough)] assert isinstance(chain.last, RunnableParallel) - if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10): + if PYDANTIC_VERSION_AT_LEAST_210: assert dumps(chain, pretty=True) == snapshot # Test invoke diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index e267388cbb3..bc5e849c157 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -65,7 +65,8 @@ from langchain_core.utils.function_calling import ( convert_to_openai_tool, ) from langchain_core.utils.pydantic import ( - PYDANTIC_MAJOR_VERSION, + IS_PYDANTIC_V1, + IS_PYDANTIC_V2, _create_subset_model, create_model_v2, ) @@ -2017,7 +2018,7 @@ def test__is_message_content_type(obj: Any, *, expected: bool) -> None: assert _is_message_content_type(obj) is expected -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.") @pytest.mark.parametrize("use_v1_namespace", [True, False]) def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None: A = TypeVar("A") @@ -2086,7 +2087,7 @@ def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None: assert actual == expected -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Testing pydantic v1.") +@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Testing pydantic v1.") def test__get_all_basemodel_annotations_v1() -> None: A = TypeVar("A") @@ -2214,7 +2215,7 @@ def test_create_retriever_tool() -> None: ) -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.") def test_tool_args_schema_pydantic_v2_with_metadata() -> None: from pydantic import BaseModel as BaseModelV2 from pydantic import Field as FieldV2 diff --git a/libs/core/tests/unit_tests/utils/test_pydantic.py b/libs/core/tests/unit_tests/utils/test_pydantic.py index 6bbcdcc3a94..070c4e01e12 100644 --- a/libs/core/tests/unit_tests/utils/test_pydantic.py +++ b/libs/core/tests/unit_tests/utils/test_pydantic.py @@ -7,7 +7,9 @@ import pytest from pydantic import ConfigDict from langchain_core.utils.pydantic import ( - PYDANTIC_MAJOR_VERSION, + IS_PYDANTIC_V1, + IS_PYDANTIC_V2, + PYDANTIC_VERSION, _create_subset_model_v2, create_model_v2, get_fields, @@ -95,11 +97,11 @@ def test_with_aliases() -> None: def test_is_basemodel_subclass() -> None: """Test pydantic.""" - if PYDANTIC_MAJOR_VERSION == 1: + if IS_PYDANTIC_V1: from pydantic import BaseModel as BaseModelV1Proper assert is_basemodel_subclass(BaseModelV1Proper) - elif PYDANTIC_MAJOR_VERSION == 2: + elif IS_PYDANTIC_V2: from pydantic import BaseModel as BaseModelV2 from pydantic.v1 import BaseModel as BaseModelV1 @@ -107,20 +109,20 @@ def test_is_basemodel_subclass() -> None: assert is_basemodel_subclass(BaseModelV1) else: - msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}" + msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}" raise ValueError(msg) def test_is_basemodel_instance() -> None: """Test pydantic.""" - if PYDANTIC_MAJOR_VERSION == 1: + if IS_PYDANTIC_V1: from pydantic import BaseModel as BaseModelV1Proper class FooV1(BaseModelV1Proper): x: int assert is_basemodel_instance(FooV1(x=5)) - elif PYDANTIC_MAJOR_VERSION == 2: + elif IS_PYDANTIC_V2: from pydantic import BaseModel as BaseModelV2 from pydantic.v1 import BaseModel as BaseModelV1 @@ -134,11 +136,11 @@ def test_is_basemodel_instance() -> None: assert is_basemodel_instance(Bar(x=5)) else: - msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}" + msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}" raise ValueError(msg) -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2") def test_with_field_metadata() -> None: """Test pydantic with field metadata.""" from pydantic import BaseModel as BaseModelV2 @@ -167,7 +169,7 @@ def test_with_field_metadata() -> None: } -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Only tests Pydantic v1") +@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Only tests Pydantic v1") def test_fields_pydantic_v1() -> None: from pydantic import BaseModel @@ -178,7 +180,7 @@ def test_fields_pydantic_v1() -> None: assert fields == {"x": Foo.model_fields["x"]} # type: ignore[index] -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2") def test_fields_pydantic_v2_proper() -> None: from pydantic import BaseModel @@ -189,7 +191,7 @@ def test_fields_pydantic_v2_proper() -> None: assert fields == {"x": Foo.model_fields["x"]} -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2") def test_fields_pydantic_v1_from_2() -> None: from pydantic.v1 import BaseModel diff --git a/libs/core/tests/unit_tests/utils/test_utils.py b/libs/core/tests/unit_tests/utils/test_utils.py index 08dbb118c9b..e0d08913f9a 100644 --- a/libs/core/tests/unit_tests/utils/test_utils.py +++ b/libs/core/tests/unit_tests/utils/test_utils.py @@ -16,7 +16,10 @@ from langchain_core.utils import ( guard_import, ) from langchain_core.utils._merge import merge_dicts -from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION +from langchain_core.utils.pydantic import ( + IS_PYDANTIC_V1, + IS_PYDANTIC_V2, +) from langchain_core.utils.utils import secret_from_env @@ -211,7 +214,7 @@ def test_guard_import_failure( guard_import(module_name, pip_name=pip_name, package=package) -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2") def test_get_pydantic_field_names_v1_in_2() -> None: from pydantic.v1 import BaseModel as PydanticV1BaseModel from pydantic.v1 import Field @@ -226,7 +229,7 @@ def test_get_pydantic_field_names_v1_in_2() -> None: assert result == expected -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2") def test_get_pydantic_field_names_v2_in_2() -> None: from pydantic import BaseModel, Field @@ -240,7 +243,7 @@ def test_get_pydantic_field_names_v2_in_2() -> None: assert result == expected -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Requires pydantic 1") +@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Requires pydantic 1") def test_get_pydantic_field_names_v1() -> None: from pydantic import BaseModel, Field