From 5e418c2666a4cebbdedf79e9dc92655d741b0874 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Fri, 4 Apr 2025 19:42:30 +0200 Subject: [PATCH] core: Rework pydantic version checks (#30653) This pull request includes various changes to the `langchain_core` library, focusing on improving compatibility with different versions of Pydantic. The primary change involves replacing checks for Pydantic major versions with boolean flags, which simplifies the code and improves readability. This also solves ruff rule checks for [RUF048](https://docs.astral.sh/ruff/rules/map-int-version-parsing/) and [PLR2004](https://docs.astral.sh/ruff/rules/magic-value-comparison/). Key changes include: ### Compatibility Improvements: * [`libs/core/langchain_core/output_parsers/json.py`](diffhunk://#diff-5add0cf7134636ae4198a1e0df49ee332ae0c9123c3a2395101e02687c717646L22-R24): Replaced `PYDANTIC_MAJOR_VERSION` with `IS_PYDANTIC_V1` to check for Pydantic version 1. * [`libs/core/langchain_core/output_parsers/pydantic.py`](diffhunk://#diff-2364b5b4aee01c462aa5dbda5dc3a877dcd20f29df173ad540dc8adf8b192361L14-R14): Updated version checks from `PYDANTIC_MAJOR_VERSION` to `IS_PYDANTIC_V2` in the `PydanticOutputParser` class. [[1]](diffhunk://#diff-2364b5b4aee01c462aa5dbda5dc3a877dcd20f29df173ad540dc8adf8b192361L14-R14) [[2]](diffhunk://#diff-2364b5b4aee01c462aa5dbda5dc3a877dcd20f29df173ad540dc8adf8b192361L27-R27) ### Utility Enhancements: * [`libs/core/langchain_core/utils/pydantic.py`](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896R23): Introduced `IS_PYDANTIC_V1` and `IS_PYDANTIC_V2` flags and deprecated the `get_pydantic_major_version` function. Updated various functions to use these flags instead of version numbers. [[1]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896R23) [[2]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896R42-R78) [[3]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L90-R89) [[4]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L104-R101) [[5]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L120-R122) [[6]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L135-R132) [[7]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L149-R151) [[8]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L164-R161) [[9]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L248-R250) [[10]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L330-R335) [[11]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L356-R357) [[12]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L393-R390) [[13]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L403-R400) ### Test Updates: * [`libs/core/tests/unit_tests/output_parsers/test_openai_tools.py`](diffhunk://#diff-694cc0318edbd6bbca34f53304934062ad59ba9f5a788252ce6c5f5452489d67L19-R22): Updated tests to use `IS_PYDANTIC_V1` and `IS_PYDANTIC_V2` for version checks. [[1]](diffhunk://#diff-694cc0318edbd6bbca34f53304934062ad59ba9f5a788252ce6c5f5452489d67L19-R22) [[2]](diffhunk://#diff-694cc0318edbd6bbca34f53304934062ad59ba9f5a788252ce6c5f5452489d67L532-R535) [[3]](diffhunk://#diff-694cc0318edbd6bbca34f53304934062ad59ba9f5a788252ce6c5f5452489d67L567-R570) [[4]](diffhunk://#diff-694cc0318edbd6bbca34f53304934062ad59ba9f5a788252ce6c5f5452489d67L602-R605) * [`libs/core/tests/unit_tests/prompts/test_chat.py`](diffhunk://#diff-3e60e744842086a4f3c4b21bc83e819c3435720eab210078e77e2430fb8c7e84R7): Replaced version tuple checks with `PYDANTIC_VERSION` comparisons. [[1]](diffhunk://#diff-3e60e744842086a4f3c4b21bc83e819c3435720eab210078e77e2430fb8c7e84R7) [[2]](diffhunk://#diff-3e60e744842086a4f3c4b21bc83e819c3435720eab210078e77e2430fb8c7e84L35-R38) [[3]](diffhunk://#diff-3e60e744842086a4f3c4b21bc83e819c3435720eab210078e77e2430fb8c7e84L924-R927) [[4]](diffhunk://#diff-3e60e744842086a4f3c4b21bc83e819c3435720eab210078e77e2430fb8c7e84L935-R938) * [`libs/core/tests/unit_tests/runnables/test_graph.py`](diffhunk://#diff-99a290330ef40103d0ce02e52e21310d6fadea142bfdea13c94d23fc81c0bb5dR3): Simplified version checks using `PYDANTIC_VERSION`. [[1]](diffhunk://#diff-99a290330ef40103d0ce02e52e21310d6fadea142bfdea13c94d23fc81c0bb5dR3) [[2]](diffhunk://#diff-99a290330ef40103d0ce02e52e21310d6fadea142bfdea13c94d23fc81c0bb5dL15-R18) [[3]](diffhunk://#diff-99a290330ef40103d0ce02e52e21310d6fadea142bfdea13c94d23fc81c0bb5dL234-L239) * [`libs/core/tests/unit_tests/runnables/test_runnable.py`](diffhunk://#diff-06bed920c0dad0cfd41d57a8d9e47a7b56832409649c10151061a791860d5bb5L18-R20): Introduced `PYDANTIC_VERSION_AT_LEAST_29` and `PYDANTIC_VERSION_AT_LEAST_210` for more readable version checks. [[1]](diffhunk://#diff-06bed920c0dad0cfd41d57a8d9e47a7b56832409649c10151061a791860d5bb5L18-R20) [[2]](diffhunk://#diff-06bed920c0dad0cfd41d57a8d9e47a7b56832409649c10151061a791860d5bb5L92-R99) [[3]](diffhunk://#diff-06bed920c0dad0cfd41d57a8d9e47a7b56832409649c10151061a791860d5bb5L230-R233) [[4]](diffhunk://#diff-06bed920c0dad0cfd41d57a8d9e47a7b56832409649c10151061a791860d5bb5L652-R655) --- .../langchain_core/output_parsers/json.py | 4 +- .../langchain_core/output_parsers/pydantic.py | 4 +- libs/core/langchain_core/utils/pydantic.py | 81 ++++++++++--------- .../output_parsers/test_openai_tools.py | 11 ++- .../tests/unit_tests/prompts/test_chat.py | 9 ++- .../tests/unit_tests/prompts/test_prompt.py | 15 ++-- .../tests/unit_tests/runnables/test_graph.py | 9 ++- .../unit_tests/runnables/test_history.py | 7 +- .../unit_tests/runnables/test_runnable.py | 27 ++++--- libs/core/tests/unit_tests/test_tools.py | 9 ++- .../tests/unit_tests/utils/test_pydantic.py | 24 +++--- .../core/tests/unit_tests/utils/test_utils.py | 11 ++- 12 files changed, 115 insertions(+), 96 deletions(-) diff --git a/libs/core/langchain_core/output_parsers/json.py b/libs/core/langchain_core/output_parsers/json.py index 3dafaecc755..0b64c594a31 100644 --- a/libs/core/langchain_core/output_parsers/json.py +++ b/libs/core/langchain_core/output_parsers/json.py @@ -19,9 +19,9 @@ from langchain_core.utils.json import ( parse_json_markdown, parse_partial_json, ) -from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION +from langchain_core.utils.pydantic import IS_PYDANTIC_V1 -if PYDANTIC_MAJOR_VERSION < 2: +if IS_PYDANTIC_V1: PydanticBaseModel = pydantic.BaseModel else: diff --git a/libs/core/langchain_core/output_parsers/pydantic.py b/libs/core/langchain_core/output_parsers/pydantic.py index 6550d4ce7d8..b1ee96e0fa0 100644 --- a/libs/core/langchain_core/output_parsers/pydantic.py +++ b/libs/core/langchain_core/output_parsers/pydantic.py @@ -11,7 +11,7 @@ from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import JsonOutputParser from langchain_core.outputs import Generation from langchain_core.utils.pydantic import ( - PYDANTIC_MAJOR_VERSION, + IS_PYDANTIC_V2, PydanticBaseModel, TBaseModel, ) @@ -24,7 +24,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]): """The pydantic model to parse.""" def _parse_obj(self, obj: dict) -> TBaseModel: - if PYDANTIC_MAJOR_VERSION == 2: + if IS_PYDANTIC_V2: try: if issubclass(self.pydantic_object, pydantic.BaseModel): return self.pydantic_object.model_validate(obj) diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index df698d9f13d..9a8f9fd832b 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -20,6 +20,7 @@ from typing import ( ) import pydantic +from packaging import version from pydantic import ( BaseModel, ConfigDict, @@ -41,44 +42,46 @@ from pydantic.json_schema import ( if TYPE_CHECKING: from pydantic_core import core_schema +try: + import pydantic + + PYDANTIC_VERSION = version.parse(pydantic.__version__) +except ImportError: + PYDANTIC_VERSION = version.parse("0.0.0") + def get_pydantic_major_version() -> int: - """Get the major version of Pydantic.""" - try: - import pydantic + """DEPRECATED - Get the major version of Pydantic. - return int(pydantic.__version__.split(".")[0]) - except ImportError: - return 0 + Use PYDANTIC_VERSION.major instead. + """ + warnings.warn( + "get_pydantic_major_version is deprecated. Use PYDANTIC_VERSION.major instead.", + DeprecationWarning, + stacklevel=2, + ) + return PYDANTIC_VERSION.major -def _get_pydantic_minor_version() -> int: - """Get the minor version of Pydantic.""" - try: - import pydantic +PYDANTIC_MAJOR_VERSION = PYDANTIC_VERSION.major +PYDANTIC_MINOR_VERSION = PYDANTIC_VERSION.minor - return int(pydantic.__version__.split(".")[1]) - except ImportError: - return 0 +IS_PYDANTIC_V1 = PYDANTIC_VERSION.major == 1 +IS_PYDANTIC_V2 = PYDANTIC_VERSION.major == 2 - -PYDANTIC_MAJOR_VERSION = get_pydantic_major_version() -PYDANTIC_MINOR_VERSION = _get_pydantic_minor_version() - - -if PYDANTIC_MAJOR_VERSION == 1: +if IS_PYDANTIC_V1: from pydantic.fields import FieldInfo as FieldInfoV1 PydanticBaseModel = pydantic.BaseModel TypeBaseModel = type[BaseModel] -elif PYDANTIC_MAJOR_VERSION == 2: +elif IS_PYDANTIC_V2: from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment] # Union type needs to be last assignment to PydanticBaseModel to make mypy happy. PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc] TypeBaseModel = Union[type[BaseModel], type[pydantic.BaseModel]] # type: ignore[misc] else: - msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}" + msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}" raise ValueError(msg) @@ -87,9 +90,9 @@ TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel) def is_pydantic_v1_subclass(cls: type) -> bool: """Check if the installed Pydantic version is 1.x-like.""" - if PYDANTIC_MAJOR_VERSION == 1: + if IS_PYDANTIC_V1: return True - if PYDANTIC_MAJOR_VERSION == 2: + if IS_PYDANTIC_V2: from pydantic.v1 import BaseModel as BaseModelV1 if issubclass(cls, BaseModelV1): @@ -101,7 +104,7 @@ def is_pydantic_v2_subclass(cls: type) -> bool: """Check if the installed Pydantic version is 1.x-like.""" from pydantic import BaseModel - return PYDANTIC_MAJOR_VERSION == 2 and issubclass(cls, BaseModel) + return IS_PYDANTIC_V2 and issubclass(cls, BaseModel) def is_basemodel_subclass(cls: type) -> bool: @@ -117,12 +120,12 @@ def is_basemodel_subclass(cls: type) -> bool: if not inspect.isclass(cls) or isinstance(cls, GenericAlias): return False - if PYDANTIC_MAJOR_VERSION == 1: + if IS_PYDANTIC_V1: from pydantic import BaseModel as BaseModelV1Proper if issubclass(cls, BaseModelV1Proper): return True - elif PYDANTIC_MAJOR_VERSION == 2: + elif IS_PYDANTIC_V2: from pydantic import BaseModel as BaseModelV2 from pydantic.v1 import BaseModel as BaseModelV1 @@ -132,7 +135,7 @@ def is_basemodel_subclass(cls: type) -> bool: if issubclass(cls, BaseModelV1): return True else: - msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}" + msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}" raise ValueError(msg) return False @@ -146,12 +149,12 @@ def is_basemodel_instance(obj: Any) -> bool: * pydantic.BaseModel in Pydantic 2.x * pydantic.v1.BaseModel in Pydantic 2.x """ - if PYDANTIC_MAJOR_VERSION == 1: + if IS_PYDANTIC_V1: from pydantic import BaseModel as BaseModelV1Proper if isinstance(obj, BaseModelV1Proper): return True - elif PYDANTIC_MAJOR_VERSION == 2: + elif IS_PYDANTIC_V2: from pydantic import BaseModel as BaseModelV2 from pydantic.v1 import BaseModel as BaseModelV1 @@ -161,7 +164,7 @@ def is_basemodel_instance(obj: Any) -> bool: if isinstance(obj, BaseModelV1): return True else: - msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}" + msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}" raise ValueError(msg) return False @@ -245,12 +248,12 @@ def _create_subset_model_v1( fn_description: Optional[str] = None, ) -> type[BaseModel]: """Create a pydantic model with only a subset of model's fields.""" - if PYDANTIC_MAJOR_VERSION == 1: + if IS_PYDANTIC_V1: from pydantic import create_model - elif PYDANTIC_MAJOR_VERSION == 2: + elif IS_PYDANTIC_V2: from pydantic.v1 import create_model # type: ignore else: - msg = f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}" + msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}" raise NotImplementedError(msg) fields = {} @@ -327,7 +330,7 @@ def _create_subset_model( fn_description: Optional[str] = None, ) -> type[BaseModel]: """Create subset model using the same pydantic version as the input model.""" - if PYDANTIC_MAJOR_VERSION == 1: + if IS_PYDANTIC_V1: return _create_subset_model_v1( name, model, @@ -335,7 +338,7 @@ def _create_subset_model( descriptions=descriptions, fn_description=fn_description, ) - if PYDANTIC_MAJOR_VERSION == 2: + if IS_PYDANTIC_V2: from pydantic.v1 import BaseModel as BaseModelV1 if issubclass(model, BaseModelV1): @@ -353,11 +356,11 @@ def _create_subset_model( descriptions=descriptions, fn_description=fn_description, ) - msg = f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}" + msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}" raise NotImplementedError(msg) -if PYDANTIC_MAJOR_VERSION == 2: +if IS_PYDANTIC_V2: from pydantic import BaseModel as BaseModelV2 from pydantic.v1 import BaseModel as BaseModelV1 @@ -390,7 +393,7 @@ if PYDANTIC_MAJOR_VERSION == 2: msg = f"Expected a Pydantic model. Got {type(model)}" raise TypeError(msg) -elif PYDANTIC_MAJOR_VERSION == 1: +elif IS_PYDANTIC_V1: from pydantic import BaseModel as BaseModelV1_ def get_fields( # type: ignore[no-redef] @@ -400,7 +403,7 @@ elif PYDANTIC_MAJOR_VERSION == 1: return model.__fields__ # type: ignore else: - msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}" + msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}" raise ValueError(msg) _SchemaConfig = ConfigDict( diff --git a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py index e5fe0f3076c..9fbbd6c446f 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py +++ b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py @@ -16,7 +16,10 @@ from langchain_core.output_parsers.openai_tools import ( PydanticToolsParser, ) from langchain_core.outputs import ChatGeneration -from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION +from langchain_core.utils.pydantic import ( + IS_PYDANTIC_V1, + IS_PYDANTIC_V2, +) STREAMED_MESSAGES: list = [ AIMessageChunk(content=""), @@ -529,7 +532,7 @@ async def test_partial_pydantic_output_parser_async() -> None: assert actual == EXPECTED_STREAMED_PYDANTIC -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2") def test_parse_with_different_pydantic_2_v1() -> None: """Test with pydantic.v1.BaseModel from pydantic 2.""" import pydantic @@ -564,7 +567,7 @@ def test_parse_with_different_pydantic_2_v1() -> None: ] -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2") def test_parse_with_different_pydantic_2_proper() -> None: """Test with pydantic.BaseModel from pydantic 2.""" import pydantic @@ -599,7 +602,7 @@ def test_parse_with_different_pydantic_2_proper() -> None: ] -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="This test is for pydantic 1") +@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="This test is for pydantic 1") def test_parse_with_different_pydantic_1_proper() -> None: """Test with pydantic.BaseModel from pydantic 1.""" import pydantic diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index ba82ff8c760..84a83c6ae2d 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Any, Union, cast import pytest +from packaging import version from pydantic import ValidationError from syrupy import SnapshotAssertion @@ -32,7 +33,9 @@ from langchain_core.prompts.chat import ( _convert_to_message, ) from langchain_core.prompts.string import PromptTemplateFormat -from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION +from langchain_core.utils.pydantic import ( + PYDANTIC_VERSION, +) from tests.unit_tests.pydantic_utils import _normalize_schema @@ -921,7 +924,7 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None: with pytest.raises(ValidationError): prompt_all_required.input_schema(input="") - if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10): + if version.parse("2.10") <= PYDANTIC_VERSION: assert _normalize_schema( prompt_all_required.get_input_jsonschema() ) == snapshot(name="required") @@ -932,7 +935,7 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None: assert set(prompt_optional.input_variables) == {"input"} prompt_optional.input_schema(input="") # won't raise error - if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10): + if version.parse("2.10") <= PYDANTIC_VERSION: assert _normalize_schema(prompt_optional.get_input_jsonschema()) == snapshot( name="partial" ) diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index f05451bb0fd..03941ad6ff4 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -4,16 +4,17 @@ import re from typing import Any, Union from unittest import mock -import pydantic import pytest +from packaging import version from syrupy import SnapshotAssertion from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.string import PromptTemplateFormat from langchain_core.tracers.run_collector import RunCollectorCallbackHandler +from langchain_core.utils.pydantic import PYDANTIC_VERSION from tests.unit_tests.pydantic_utils import _normalize_schema -PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split("."))) +PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION def test_prompt_valid() -> None: @@ -117,7 +118,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None: "This foo is a bar test baz." ) assert prompt.input_variables == ["foo", "obj"] - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot( name="schema_0" ) @@ -144,7 +145,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None: is a test.""" ) assert prompt.input_variables == ["foo"] - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot( name="schema_2" ) @@ -168,7 +169,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None: is a test.""" ) assert prompt.input_variables == ["foo"] - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot( name="schema_3" ) @@ -206,7 +207,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None: is a test.""" ) assert prompt.input_variables == ["foo"] - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot( name="schema_4" ) @@ -224,7 +225,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None: is a test.""" # noqa: W293 ) assert prompt.input_variables == ["foo"] - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot( name="schema_5" ) diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index 171d8819186..545db59c8be 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -1,5 +1,6 @@ from typing import Any, Optional +from packaging import version from pydantic import BaseModel from syrupy import SnapshotAssertion from typing_extensions import override @@ -12,7 +13,9 @@ from langchain_core.prompts.prompt import PromptTemplate from langchain_core.runnables.base import Runnable, RunnableConfig from langchain_core.runnables.graph import Edge, Graph, Node from langchain_core.runnables.graph_mermaid import _escape_node_label -from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION +from langchain_core.utils.pydantic import ( + PYDANTIC_VERSION, +) from tests.unit_tests.pydantic_utils import _normalize_schema @@ -231,12 +234,10 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None: ) graph = sequence.get_graph() - if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10): + if version.parse("2.10") <= PYDANTIC_VERSION: assert _normalize_schema(graph.to_json(with_schemas=True)) == snapshot( name="graph_with_schema" ) - - if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10): assert _normalize_schema(graph.to_json()) == snapshot(name="graph_no_schemas") assert graph.draw_ascii() == snapshot(name="ascii") diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index d5dc1072869..35f6c2b3801 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -2,8 +2,8 @@ import re from collections.abc import Sequence from typing import Any, Callable, Optional, Union -import pydantic import pytest +from packaging import version from pydantic import BaseModel from langchain_core.callbacks import ( @@ -19,10 +19,9 @@ from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output from langchain_core.tracers import Run +from langchain_core.utils.pydantic import PYDANTIC_VERSION from tests.unit_tests.pydantic_utils import _schema -PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split("."))) - def test_interfaces() -> None: history = InMemoryChatMessageHistory() @@ -492,7 +491,7 @@ def test_get_output_schema() -> None: "title": "RunnableWithChatHistoryOutput", "type": "object", } - if PYDANTIC_VERSION >= (2, 11): + if version.parse("2.11") <= PYDANTIC_VERSION: expected_schema["additionalProperties"] = True assert _schema(output_type) == expected_schema diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 71b006d6622..72b6c6461eb 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -15,9 +15,9 @@ from typing import ( ) from uuid import UUID -import pydantic import pytest from freezegun import freeze_time +from packaging import version from pydantic import BaseModel, Field from pytest_mock import MockerFixture from syrupy import SnapshotAssertion @@ -89,11 +89,14 @@ from langchain_core.tracers import ( RunLogPatch, ) from langchain_core.tracers.context import collect_runs -from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION +from langchain_core.utils.pydantic import ( + PYDANTIC_VERSION, +) from tests.unit_tests.pydantic_utils import _normalize_schema, _schema from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk -PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split("."))) +PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION +PYDANTIC_VERSION_AT_LEAST_210 = version.parse("2.10") <= PYDANTIC_VERSION class FakeTracer(BaseTracer): @@ -227,7 +230,7 @@ class FakeRetriever(BaseRetriever): @pytest.mark.skipif( - (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10), + PYDANTIC_VERSION_AT_LEAST_210, reason=( "Only test with most recent version of pydantic. " "Pydantic introduced small fixes to generated JSONSchema on minor versions." @@ -649,7 +652,7 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None: } ) - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema( RunnableLambda(aget_values_typed).get_output_jsonschema() # type: ignore ) == snapshot(name="schema8") @@ -764,7 +767,7 @@ def test_configurable_fields(snapshot: SnapshotAssertion) -> None: assert fake_llm_configurable.invoke("...") == "a" - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema( fake_llm_configurable.get_config_jsonschema() ) == snapshot(name="schema2") @@ -791,7 +794,7 @@ def test_configurable_fields(snapshot: SnapshotAssertion) -> None: text="Hello, John!" ) - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema( prompt_configurable.get_config_jsonschema() ) == snapshot(name="schema3") @@ -820,7 +823,7 @@ def test_configurable_fields(snapshot: SnapshotAssertion) -> None: assert chain_configurable.invoke({"name": "John"}) == "a" - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema( chain_configurable.get_config_jsonschema() ) == snapshot(name="schema4") @@ -865,7 +868,7 @@ def test_configurable_fields(snapshot: SnapshotAssertion) -> None: "llm3": "a", } - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema( chain_with_map_configurable.get_config_jsonschema() ) == snapshot(name="schema5") @@ -938,7 +941,7 @@ def test_configurable_fields_prefix_keys(snapshot: SnapshotAssertion) -> None: chain = prompt | fake_llm - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema(_schema(chain.config_schema())) == snapshot( name="schema6" ) @@ -990,7 +993,7 @@ def test_configurable_fields_example(snapshot: SnapshotAssertion) -> None: assert chain_configurable.invoke({"name": "John"}) == "a" - if PYDANTIC_VERSION >= (2, 9): + if PYDANTIC_VERSION_AT_LEAST_29: assert _normalize_schema( chain_configurable.get_config_jsonschema() ) == snapshot(name="schema7") @@ -3089,7 +3092,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N assert chain.middle == [RunnableLambda(passthrough)] assert isinstance(chain.last, RunnableParallel) - if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10): + if PYDANTIC_VERSION_AT_LEAST_210: assert dumps(chain, pretty=True) == snapshot # Test invoke diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index e267388cbb3..bc5e849c157 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -65,7 +65,8 @@ from langchain_core.utils.function_calling import ( convert_to_openai_tool, ) from langchain_core.utils.pydantic import ( - PYDANTIC_MAJOR_VERSION, + IS_PYDANTIC_V1, + IS_PYDANTIC_V2, _create_subset_model, create_model_v2, ) @@ -2017,7 +2018,7 @@ def test__is_message_content_type(obj: Any, *, expected: bool) -> None: assert _is_message_content_type(obj) is expected -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.") @pytest.mark.parametrize("use_v1_namespace", [True, False]) def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None: A = TypeVar("A") @@ -2086,7 +2087,7 @@ def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None: assert actual == expected -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Testing pydantic v1.") +@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Testing pydantic v1.") def test__get_all_basemodel_annotations_v1() -> None: A = TypeVar("A") @@ -2214,7 +2215,7 @@ def test_create_retriever_tool() -> None: ) -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.") def test_tool_args_schema_pydantic_v2_with_metadata() -> None: from pydantic import BaseModel as BaseModelV2 from pydantic import Field as FieldV2 diff --git a/libs/core/tests/unit_tests/utils/test_pydantic.py b/libs/core/tests/unit_tests/utils/test_pydantic.py index 6bbcdcc3a94..070c4e01e12 100644 --- a/libs/core/tests/unit_tests/utils/test_pydantic.py +++ b/libs/core/tests/unit_tests/utils/test_pydantic.py @@ -7,7 +7,9 @@ import pytest from pydantic import ConfigDict from langchain_core.utils.pydantic import ( - PYDANTIC_MAJOR_VERSION, + IS_PYDANTIC_V1, + IS_PYDANTIC_V2, + PYDANTIC_VERSION, _create_subset_model_v2, create_model_v2, get_fields, @@ -95,11 +97,11 @@ def test_with_aliases() -> None: def test_is_basemodel_subclass() -> None: """Test pydantic.""" - if PYDANTIC_MAJOR_VERSION == 1: + if IS_PYDANTIC_V1: from pydantic import BaseModel as BaseModelV1Proper assert is_basemodel_subclass(BaseModelV1Proper) - elif PYDANTIC_MAJOR_VERSION == 2: + elif IS_PYDANTIC_V2: from pydantic import BaseModel as BaseModelV2 from pydantic.v1 import BaseModel as BaseModelV1 @@ -107,20 +109,20 @@ def test_is_basemodel_subclass() -> None: assert is_basemodel_subclass(BaseModelV1) else: - msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}" + msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}" raise ValueError(msg) def test_is_basemodel_instance() -> None: """Test pydantic.""" - if PYDANTIC_MAJOR_VERSION == 1: + if IS_PYDANTIC_V1: from pydantic import BaseModel as BaseModelV1Proper class FooV1(BaseModelV1Proper): x: int assert is_basemodel_instance(FooV1(x=5)) - elif PYDANTIC_MAJOR_VERSION == 2: + elif IS_PYDANTIC_V2: from pydantic import BaseModel as BaseModelV2 from pydantic.v1 import BaseModel as BaseModelV1 @@ -134,11 +136,11 @@ def test_is_basemodel_instance() -> None: assert is_basemodel_instance(Bar(x=5)) else: - msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}" + msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}" raise ValueError(msg) -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2") def test_with_field_metadata() -> None: """Test pydantic with field metadata.""" from pydantic import BaseModel as BaseModelV2 @@ -167,7 +169,7 @@ def test_with_field_metadata() -> None: } -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Only tests Pydantic v1") +@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Only tests Pydantic v1") def test_fields_pydantic_v1() -> None: from pydantic import BaseModel @@ -178,7 +180,7 @@ def test_fields_pydantic_v1() -> None: assert fields == {"x": Foo.model_fields["x"]} # type: ignore[index] -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2") def test_fields_pydantic_v2_proper() -> None: from pydantic import BaseModel @@ -189,7 +191,7 @@ def test_fields_pydantic_v2_proper() -> None: assert fields == {"x": Foo.model_fields["x"]} -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2") def test_fields_pydantic_v1_from_2() -> None: from pydantic.v1 import BaseModel diff --git a/libs/core/tests/unit_tests/utils/test_utils.py b/libs/core/tests/unit_tests/utils/test_utils.py index 08dbb118c9b..e0d08913f9a 100644 --- a/libs/core/tests/unit_tests/utils/test_utils.py +++ b/libs/core/tests/unit_tests/utils/test_utils.py @@ -16,7 +16,10 @@ from langchain_core.utils import ( guard_import, ) from langchain_core.utils._merge import merge_dicts -from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION +from langchain_core.utils.pydantic import ( + IS_PYDANTIC_V1, + IS_PYDANTIC_V2, +) from langchain_core.utils.utils import secret_from_env @@ -211,7 +214,7 @@ def test_guard_import_failure( guard_import(module_name, pip_name=pip_name, package=package) -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2") def test_get_pydantic_field_names_v1_in_2() -> None: from pydantic.v1 import BaseModel as PydanticV1BaseModel from pydantic.v1 import Field @@ -226,7 +229,7 @@ def test_get_pydantic_field_names_v1_in_2() -> None: assert result == expected -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2") +@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2") def test_get_pydantic_field_names_v2_in_2() -> None: from pydantic import BaseModel, Field @@ -240,7 +243,7 @@ def test_get_pydantic_field_names_v2_in_2() -> None: assert result == expected -@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Requires pydantic 1") +@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Requires pydantic 1") def test_get_pydantic_field_names_v1() -> None: from pydantic import BaseModel, Field