core: Rework pydantic version checks (#30653)

This pull request includes various changes to the `langchain_core`
library, focusing on improving compatibility with different versions of
Pydantic. The primary change involves replacing checks for Pydantic
major versions with boolean flags, which simplifies the code and
improves readability.
This also solves ruff rule checks for
[RUF048](https://docs.astral.sh/ruff/rules/map-int-version-parsing/) and
[PLR2004](https://docs.astral.sh/ruff/rules/magic-value-comparison/).

Key changes include:

### Compatibility Improvements:
*
[`libs/core/langchain_core/output_parsers/json.py`](diffhunk://#diff-5add0cf7134636ae4198a1e0df49ee332ae0c9123c3a2395101e02687c717646L22-R24):
Replaced `PYDANTIC_MAJOR_VERSION` with `IS_PYDANTIC_V1` to check for
Pydantic version 1.
*
[`libs/core/langchain_core/output_parsers/pydantic.py`](diffhunk://#diff-2364b5b4aee01c462aa5dbda5dc3a877dcd20f29df173ad540dc8adf8b192361L14-R14):
Updated version checks from `PYDANTIC_MAJOR_VERSION` to `IS_PYDANTIC_V2`
in the `PydanticOutputParser` class.
[[1]](diffhunk://#diff-2364b5b4aee01c462aa5dbda5dc3a877dcd20f29df173ad540dc8adf8b192361L14-R14)
[[2]](diffhunk://#diff-2364b5b4aee01c462aa5dbda5dc3a877dcd20f29df173ad540dc8adf8b192361L27-R27)

### Utility Enhancements:
*
[`libs/core/langchain_core/utils/pydantic.py`](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896R23):
Introduced `IS_PYDANTIC_V1` and `IS_PYDANTIC_V2` flags and deprecated
the `get_pydantic_major_version` function. Updated various functions to
use these flags instead of version numbers.
[[1]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896R23)
[[2]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896R42-R78)
[[3]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L90-R89)
[[4]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L104-R101)
[[5]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L120-R122)
[[6]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L135-R132)
[[7]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L149-R151)
[[8]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L164-R161)
[[9]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L248-R250)
[[10]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L330-R335)
[[11]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L356-R357)
[[12]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L393-R390)
[[13]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L403-R400)

### Test Updates:
*
[`libs/core/tests/unit_tests/output_parsers/test_openai_tools.py`](diffhunk://#diff-694cc0318edbd6bbca34f53304934062ad59ba9f5a788252ce6c5f5452489d67L19-R22):
Updated tests to use `IS_PYDANTIC_V1` and `IS_PYDANTIC_V2` for version
checks.
[[1]](diffhunk://#diff-694cc0318edbd6bbca34f53304934062ad59ba9f5a788252ce6c5f5452489d67L19-R22)
[[2]](diffhunk://#diff-694cc0318edbd6bbca34f53304934062ad59ba9f5a788252ce6c5f5452489d67L532-R535)
[[3]](diffhunk://#diff-694cc0318edbd6bbca34f53304934062ad59ba9f5a788252ce6c5f5452489d67L567-R570)
[[4]](diffhunk://#diff-694cc0318edbd6bbca34f53304934062ad59ba9f5a788252ce6c5f5452489d67L602-R605)
*
[`libs/core/tests/unit_tests/prompts/test_chat.py`](diffhunk://#diff-3e60e744842086a4f3c4b21bc83e819c3435720eab210078e77e2430fb8c7e84R7):
Replaced version tuple checks with `PYDANTIC_VERSION` comparisons.
[[1]](diffhunk://#diff-3e60e744842086a4f3c4b21bc83e819c3435720eab210078e77e2430fb8c7e84R7)
[[2]](diffhunk://#diff-3e60e744842086a4f3c4b21bc83e819c3435720eab210078e77e2430fb8c7e84L35-R38)
[[3]](diffhunk://#diff-3e60e744842086a4f3c4b21bc83e819c3435720eab210078e77e2430fb8c7e84L924-R927)
[[4]](diffhunk://#diff-3e60e744842086a4f3c4b21bc83e819c3435720eab210078e77e2430fb8c7e84L935-R938)
*
[`libs/core/tests/unit_tests/runnables/test_graph.py`](diffhunk://#diff-99a290330ef40103d0ce02e52e21310d6fadea142bfdea13c94d23fc81c0bb5dR3):
Simplified version checks using `PYDANTIC_VERSION`.
[[1]](diffhunk://#diff-99a290330ef40103d0ce02e52e21310d6fadea142bfdea13c94d23fc81c0bb5dR3)
[[2]](diffhunk://#diff-99a290330ef40103d0ce02e52e21310d6fadea142bfdea13c94d23fc81c0bb5dL15-R18)
[[3]](diffhunk://#diff-99a290330ef40103d0ce02e52e21310d6fadea142bfdea13c94d23fc81c0bb5dL234-L239)
*
[`libs/core/tests/unit_tests/runnables/test_runnable.py`](diffhunk://#diff-06bed920c0dad0cfd41d57a8d9e47a7b56832409649c10151061a791860d5bb5L18-R20):
Introduced `PYDANTIC_VERSION_AT_LEAST_29` and
`PYDANTIC_VERSION_AT_LEAST_210` for more readable version checks.
[[1]](diffhunk://#diff-06bed920c0dad0cfd41d57a8d9e47a7b56832409649c10151061a791860d5bb5L18-R20)
[[2]](diffhunk://#diff-06bed920c0dad0cfd41d57a8d9e47a7b56832409649c10151061a791860d5bb5L92-R99)
[[3]](diffhunk://#diff-06bed920c0dad0cfd41d57a8d9e47a7b56832409649c10151061a791860d5bb5L230-R233)
[[4]](diffhunk://#diff-06bed920c0dad0cfd41d57a8d9e47a7b56832409649c10151061a791860d5bb5L652-R655)
This commit is contained in:
Christophe Bornet 2025-04-04 19:42:30 +02:00 committed by GitHub
parent 43b5dc7191
commit 5e418c2666
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 115 additions and 96 deletions

View File

@ -19,9 +19,9 @@ from langchain_core.utils.json import (
parse_json_markdown,
parse_partial_json,
)
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
from langchain_core.utils.pydantic import IS_PYDANTIC_V1
if PYDANTIC_MAJOR_VERSION < 2:
if IS_PYDANTIC_V1:
PydanticBaseModel = pydantic.BaseModel
else:

View File

@ -11,7 +11,7 @@ from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import Generation
from langchain_core.utils.pydantic import (
PYDANTIC_MAJOR_VERSION,
IS_PYDANTIC_V2,
PydanticBaseModel,
TBaseModel,
)
@ -24,7 +24,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
"""The pydantic model to parse."""
def _parse_obj(self, obj: dict) -> TBaseModel:
if PYDANTIC_MAJOR_VERSION == 2:
if IS_PYDANTIC_V2:
try:
if issubclass(self.pydantic_object, pydantic.BaseModel):
return self.pydantic_object.model_validate(obj)

View File

@ -20,6 +20,7 @@ from typing import (
)
import pydantic
from packaging import version
from pydantic import (
BaseModel,
ConfigDict,
@ -41,44 +42,46 @@ from pydantic.json_schema import (
if TYPE_CHECKING:
from pydantic_core import core_schema
try:
import pydantic
PYDANTIC_VERSION = version.parse(pydantic.__version__)
except ImportError:
PYDANTIC_VERSION = version.parse("0.0.0")
def get_pydantic_major_version() -> int:
"""Get the major version of Pydantic."""
try:
import pydantic
"""DEPRECATED - Get the major version of Pydantic.
return int(pydantic.__version__.split(".")[0])
except ImportError:
return 0
Use PYDANTIC_VERSION.major instead.
"""
warnings.warn(
"get_pydantic_major_version is deprecated. Use PYDANTIC_VERSION.major instead.",
DeprecationWarning,
stacklevel=2,
)
return PYDANTIC_VERSION.major
def _get_pydantic_minor_version() -> int:
"""Get the minor version of Pydantic."""
try:
import pydantic
PYDANTIC_MAJOR_VERSION = PYDANTIC_VERSION.major
PYDANTIC_MINOR_VERSION = PYDANTIC_VERSION.minor
return int(pydantic.__version__.split(".")[1])
except ImportError:
return 0
IS_PYDANTIC_V1 = PYDANTIC_VERSION.major == 1
IS_PYDANTIC_V2 = PYDANTIC_VERSION.major == 2
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
PYDANTIC_MINOR_VERSION = _get_pydantic_minor_version()
if PYDANTIC_MAJOR_VERSION == 1:
if IS_PYDANTIC_V1:
from pydantic.fields import FieldInfo as FieldInfoV1
PydanticBaseModel = pydantic.BaseModel
TypeBaseModel = type[BaseModel]
elif PYDANTIC_MAJOR_VERSION == 2:
elif IS_PYDANTIC_V2:
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
TypeBaseModel = Union[type[BaseModel], type[pydantic.BaseModel]] # type: ignore[misc]
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}"
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
@ -87,9 +90,9 @@ TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
def is_pydantic_v1_subclass(cls: type) -> bool:
"""Check if the installed Pydantic version is 1.x-like."""
if PYDANTIC_MAJOR_VERSION == 1:
if IS_PYDANTIC_V1:
return True
if PYDANTIC_MAJOR_VERSION == 2:
if IS_PYDANTIC_V2:
from pydantic.v1 import BaseModel as BaseModelV1
if issubclass(cls, BaseModelV1):
@ -101,7 +104,7 @@ def is_pydantic_v2_subclass(cls: type) -> bool:
"""Check if the installed Pydantic version is 1.x-like."""
from pydantic import BaseModel
return PYDANTIC_MAJOR_VERSION == 2 and issubclass(cls, BaseModel)
return IS_PYDANTIC_V2 and issubclass(cls, BaseModel)
def is_basemodel_subclass(cls: type) -> bool:
@ -117,12 +120,12 @@ def is_basemodel_subclass(cls: type) -> bool:
if not inspect.isclass(cls) or isinstance(cls, GenericAlias):
return False
if PYDANTIC_MAJOR_VERSION == 1:
if IS_PYDANTIC_V1:
from pydantic import BaseModel as BaseModelV1Proper
if issubclass(cls, BaseModelV1Proper):
return True
elif PYDANTIC_MAJOR_VERSION == 2:
elif IS_PYDANTIC_V2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
@ -132,7 +135,7 @@ def is_basemodel_subclass(cls: type) -> bool:
if issubclass(cls, BaseModelV1):
return True
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}"
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
return False
@ -146,12 +149,12 @@ def is_basemodel_instance(obj: Any) -> bool:
* pydantic.BaseModel in Pydantic 2.x
* pydantic.v1.BaseModel in Pydantic 2.x
"""
if PYDANTIC_MAJOR_VERSION == 1:
if IS_PYDANTIC_V1:
from pydantic import BaseModel as BaseModelV1Proper
if isinstance(obj, BaseModelV1Proper):
return True
elif PYDANTIC_MAJOR_VERSION == 2:
elif IS_PYDANTIC_V2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
@ -161,7 +164,7 @@ def is_basemodel_instance(obj: Any) -> bool:
if isinstance(obj, BaseModelV1):
return True
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}"
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
return False
@ -245,12 +248,12 @@ def _create_subset_model_v1(
fn_description: Optional[str] = None,
) -> type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
if PYDANTIC_MAJOR_VERSION == 1:
if IS_PYDANTIC_V1:
from pydantic import create_model
elif PYDANTIC_MAJOR_VERSION == 2:
elif IS_PYDANTIC_V2:
from pydantic.v1 import create_model # type: ignore
else:
msg = f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}"
raise NotImplementedError(msg)
fields = {}
@ -327,7 +330,7 @@ def _create_subset_model(
fn_description: Optional[str] = None,
) -> type[BaseModel]:
"""Create subset model using the same pydantic version as the input model."""
if PYDANTIC_MAJOR_VERSION == 1:
if IS_PYDANTIC_V1:
return _create_subset_model_v1(
name,
model,
@ -335,7 +338,7 @@ def _create_subset_model(
descriptions=descriptions,
fn_description=fn_description,
)
if PYDANTIC_MAJOR_VERSION == 2:
if IS_PYDANTIC_V2:
from pydantic.v1 import BaseModel as BaseModelV1
if issubclass(model, BaseModelV1):
@ -353,11 +356,11 @@ def _create_subset_model(
descriptions=descriptions,
fn_description=fn_description,
)
msg = f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}"
raise NotImplementedError(msg)
if PYDANTIC_MAJOR_VERSION == 2:
if IS_PYDANTIC_V2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
@ -390,7 +393,7 @@ if PYDANTIC_MAJOR_VERSION == 2:
msg = f"Expected a Pydantic model. Got {type(model)}"
raise TypeError(msg)
elif PYDANTIC_MAJOR_VERSION == 1:
elif IS_PYDANTIC_V1:
from pydantic import BaseModel as BaseModelV1_
def get_fields( # type: ignore[no-redef]
@ -400,7 +403,7 @@ elif PYDANTIC_MAJOR_VERSION == 1:
return model.__fields__ # type: ignore
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}"
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
_SchemaConfig = ConfigDict(

View File

@ -16,7 +16,10 @@ from langchain_core.output_parsers.openai_tools import (
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
from langchain_core.utils.pydantic import (
IS_PYDANTIC_V1,
IS_PYDANTIC_V2,
)
STREAMED_MESSAGES: list = [
AIMessageChunk(content=""),
@ -529,7 +532,7 @@ async def test_partial_pydantic_output_parser_async() -> None:
assert actual == EXPECTED_STREAMED_PYDANTIC
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2")
def test_parse_with_different_pydantic_2_v1() -> None:
"""Test with pydantic.v1.BaseModel from pydantic 2."""
import pydantic
@ -564,7 +567,7 @@ def test_parse_with_different_pydantic_2_v1() -> None:
]
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2")
def test_parse_with_different_pydantic_2_proper() -> None:
"""Test with pydantic.BaseModel from pydantic 2."""
import pydantic
@ -599,7 +602,7 @@ def test_parse_with_different_pydantic_2_proper() -> None:
]
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="This test is for pydantic 1")
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="This test is for pydantic 1")
def test_parse_with_different_pydantic_1_proper() -> None:
"""Test with pydantic.BaseModel from pydantic 1."""
import pydantic

View File

@ -4,6 +4,7 @@ from pathlib import Path
from typing import Any, Union, cast
import pytest
from packaging import version
from pydantic import ValidationError
from syrupy import SnapshotAssertion
@ -32,7 +33,9 @@ from langchain_core.prompts.chat import (
_convert_to_message,
)
from langchain_core.prompts.string import PromptTemplateFormat
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION
from langchain_core.utils.pydantic import (
PYDANTIC_VERSION,
)
from tests.unit_tests.pydantic_utils import _normalize_schema
@ -921,7 +924,7 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None:
with pytest.raises(ValidationError):
prompt_all_required.input_schema(input="")
if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10):
if version.parse("2.10") <= PYDANTIC_VERSION:
assert _normalize_schema(
prompt_all_required.get_input_jsonschema()
) == snapshot(name="required")
@ -932,7 +935,7 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None:
assert set(prompt_optional.input_variables) == {"input"}
prompt_optional.input_schema(input="") # won't raise error
if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10):
if version.parse("2.10") <= PYDANTIC_VERSION:
assert _normalize_schema(prompt_optional.get_input_jsonschema()) == snapshot(
name="partial"
)

View File

@ -4,16 +4,17 @@ import re
from typing import Any, Union
from unittest import mock
import pydantic
import pytest
from packaging import version
from syrupy import SnapshotAssertion
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import PromptTemplateFormat
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
from langchain_core.utils.pydantic import PYDANTIC_VERSION
from tests.unit_tests.pydantic_utils import _normalize_schema
PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split(".")))
PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION
def test_prompt_valid() -> None:
@ -117,7 +118,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None:
"This foo is a bar test baz."
)
assert prompt.input_variables == ["foo", "obj"]
if PYDANTIC_VERSION >= (2, 9):
if PYDANTIC_VERSION_AT_LEAST_29:
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
name="schema_0"
)
@ -144,7 +145,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
if PYDANTIC_VERSION >= (2, 9):
if PYDANTIC_VERSION_AT_LEAST_29:
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
name="schema_2"
)
@ -168,7 +169,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
if PYDANTIC_VERSION >= (2, 9):
if PYDANTIC_VERSION_AT_LEAST_29:
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
name="schema_3"
)
@ -206,7 +207,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
if PYDANTIC_VERSION >= (2, 9):
if PYDANTIC_VERSION_AT_LEAST_29:
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
name="schema_4"
)
@ -224,7 +225,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None:
is a test.""" # noqa: W293
)
assert prompt.input_variables == ["foo"]
if PYDANTIC_VERSION >= (2, 9):
if PYDANTIC_VERSION_AT_LEAST_29:
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
name="schema_5"
)

View File

@ -1,5 +1,6 @@
from typing import Any, Optional
from packaging import version
from pydantic import BaseModel
from syrupy import SnapshotAssertion
from typing_extensions import override
@ -12,7 +13,9 @@ from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables.base import Runnable, RunnableConfig
from langchain_core.runnables.graph import Edge, Graph, Node
from langchain_core.runnables.graph_mermaid import _escape_node_label
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION
from langchain_core.utils.pydantic import (
PYDANTIC_VERSION,
)
from tests.unit_tests.pydantic_utils import _normalize_schema
@ -231,12 +234,10 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
)
graph = sequence.get_graph()
if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10):
if version.parse("2.10") <= PYDANTIC_VERSION:
assert _normalize_schema(graph.to_json(with_schemas=True)) == snapshot(
name="graph_with_schema"
)
if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10):
assert _normalize_schema(graph.to_json()) == snapshot(name="graph_no_schemas")
assert graph.draw_ascii() == snapshot(name="ascii")

View File

@ -2,8 +2,8 @@ import re
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
import pydantic
import pytest
from packaging import version
from pydantic import BaseModel
from langchain_core.callbacks import (
@ -19,10 +19,9 @@ from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
from langchain_core.tracers import Run
from langchain_core.utils.pydantic import PYDANTIC_VERSION
from tests.unit_tests.pydantic_utils import _schema
PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split(".")))
def test_interfaces() -> None:
history = InMemoryChatMessageHistory()
@ -492,7 +491,7 @@ def test_get_output_schema() -> None:
"title": "RunnableWithChatHistoryOutput",
"type": "object",
}
if PYDANTIC_VERSION >= (2, 11):
if version.parse("2.11") <= PYDANTIC_VERSION:
expected_schema["additionalProperties"] = True
assert _schema(output_type) == expected_schema

View File

@ -15,9 +15,9 @@ from typing import (
)
from uuid import UUID
import pydantic
import pytest
from freezegun import freeze_time
from packaging import version
from pydantic import BaseModel, Field
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
@ -89,11 +89,14 @@ from langchain_core.tracers import (
RunLogPatch,
)
from langchain_core.tracers.context import collect_runs
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION
from langchain_core.utils.pydantic import (
PYDANTIC_VERSION,
)
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk
PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split(".")))
PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION
PYDANTIC_VERSION_AT_LEAST_210 = version.parse("2.10") <= PYDANTIC_VERSION
class FakeTracer(BaseTracer):
@ -227,7 +230,7 @@ class FakeRetriever(BaseRetriever):
@pytest.mark.skipif(
(PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10),
PYDANTIC_VERSION_AT_LEAST_210,
reason=(
"Only test with most recent version of pydantic. "
"Pydantic introduced small fixes to generated JSONSchema on minor versions."
@ -649,7 +652,7 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
}
)
if PYDANTIC_VERSION >= (2, 9):
if PYDANTIC_VERSION_AT_LEAST_29:
assert _normalize_schema(
RunnableLambda(aget_values_typed).get_output_jsonschema() # type: ignore
) == snapshot(name="schema8")
@ -764,7 +767,7 @@ def test_configurable_fields(snapshot: SnapshotAssertion) -> None:
assert fake_llm_configurable.invoke("...") == "a"
if PYDANTIC_VERSION >= (2, 9):
if PYDANTIC_VERSION_AT_LEAST_29:
assert _normalize_schema(
fake_llm_configurable.get_config_jsonschema()
) == snapshot(name="schema2")
@ -791,7 +794,7 @@ def test_configurable_fields(snapshot: SnapshotAssertion) -> None:
text="Hello, John!"
)
if PYDANTIC_VERSION >= (2, 9):
if PYDANTIC_VERSION_AT_LEAST_29:
assert _normalize_schema(
prompt_configurable.get_config_jsonschema()
) == snapshot(name="schema3")
@ -820,7 +823,7 @@ def test_configurable_fields(snapshot: SnapshotAssertion) -> None:
assert chain_configurable.invoke({"name": "John"}) == "a"
if PYDANTIC_VERSION >= (2, 9):
if PYDANTIC_VERSION_AT_LEAST_29:
assert _normalize_schema(
chain_configurable.get_config_jsonschema()
) == snapshot(name="schema4")
@ -865,7 +868,7 @@ def test_configurable_fields(snapshot: SnapshotAssertion) -> None:
"llm3": "a",
}
if PYDANTIC_VERSION >= (2, 9):
if PYDANTIC_VERSION_AT_LEAST_29:
assert _normalize_schema(
chain_with_map_configurable.get_config_jsonschema()
) == snapshot(name="schema5")
@ -938,7 +941,7 @@ def test_configurable_fields_prefix_keys(snapshot: SnapshotAssertion) -> None:
chain = prompt | fake_llm
if PYDANTIC_VERSION >= (2, 9):
if PYDANTIC_VERSION_AT_LEAST_29:
assert _normalize_schema(_schema(chain.config_schema())) == snapshot(
name="schema6"
)
@ -990,7 +993,7 @@ def test_configurable_fields_example(snapshot: SnapshotAssertion) -> None:
assert chain_configurable.invoke({"name": "John"}) == "a"
if PYDANTIC_VERSION >= (2, 9):
if PYDANTIC_VERSION_AT_LEAST_29:
assert _normalize_schema(
chain_configurable.get_config_jsonschema()
) == snapshot(name="schema7")
@ -3089,7 +3092,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
assert chain.middle == [RunnableLambda(passthrough)]
assert isinstance(chain.last, RunnableParallel)
if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10):
if PYDANTIC_VERSION_AT_LEAST_210:
assert dumps(chain, pretty=True) == snapshot
# Test invoke

View File

@ -65,7 +65,8 @@ from langchain_core.utils.function_calling import (
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import (
PYDANTIC_MAJOR_VERSION,
IS_PYDANTIC_V1,
IS_PYDANTIC_V2,
_create_subset_model,
create_model_v2,
)
@ -2017,7 +2018,7 @@ def test__is_message_content_type(obj: Any, *, expected: bool) -> None:
assert _is_message_content_type(obj) is expected
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.")
@pytest.mark.parametrize("use_v1_namespace", [True, False])
def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
A = TypeVar("A")
@ -2086,7 +2087,7 @@ def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
assert actual == expected
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Testing pydantic v1.")
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Testing pydantic v1.")
def test__get_all_basemodel_annotations_v1() -> None:
A = TypeVar("A")
@ -2214,7 +2215,7 @@ def test_create_retriever_tool() -> None:
)
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.")
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
from pydantic import BaseModel as BaseModelV2
from pydantic import Field as FieldV2

View File

@ -7,7 +7,9 @@ import pytest
from pydantic import ConfigDict
from langchain_core.utils.pydantic import (
PYDANTIC_MAJOR_VERSION,
IS_PYDANTIC_V1,
IS_PYDANTIC_V2,
PYDANTIC_VERSION,
_create_subset_model_v2,
create_model_v2,
get_fields,
@ -95,11 +97,11 @@ def test_with_aliases() -> None:
def test_is_basemodel_subclass() -> None:
"""Test pydantic."""
if PYDANTIC_MAJOR_VERSION == 1:
if IS_PYDANTIC_V1:
from pydantic import BaseModel as BaseModelV1Proper
assert is_basemodel_subclass(BaseModelV1Proper)
elif PYDANTIC_MAJOR_VERSION == 2:
elif IS_PYDANTIC_V2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
@ -107,20 +109,20 @@ def test_is_basemodel_subclass() -> None:
assert is_basemodel_subclass(BaseModelV1)
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}"
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
def test_is_basemodel_instance() -> None:
"""Test pydantic."""
if PYDANTIC_MAJOR_VERSION == 1:
if IS_PYDANTIC_V1:
from pydantic import BaseModel as BaseModelV1Proper
class FooV1(BaseModelV1Proper):
x: int
assert is_basemodel_instance(FooV1(x=5))
elif PYDANTIC_MAJOR_VERSION == 2:
elif IS_PYDANTIC_V2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
@ -134,11 +136,11 @@ def test_is_basemodel_instance() -> None:
assert is_basemodel_instance(Bar(x=5))
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}"
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
def test_with_field_metadata() -> None:
"""Test pydantic with field metadata."""
from pydantic import BaseModel as BaseModelV2
@ -167,7 +169,7 @@ def test_with_field_metadata() -> None:
}
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Only tests Pydantic v1")
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Only tests Pydantic v1")
def test_fields_pydantic_v1() -> None:
from pydantic import BaseModel
@ -178,7 +180,7 @@ def test_fields_pydantic_v1() -> None:
assert fields == {"x": Foo.model_fields["x"]} # type: ignore[index]
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
def test_fields_pydantic_v2_proper() -> None:
from pydantic import BaseModel
@ -189,7 +191,7 @@ def test_fields_pydantic_v2_proper() -> None:
assert fields == {"x": Foo.model_fields["x"]}
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
def test_fields_pydantic_v1_from_2() -> None:
from pydantic.v1 import BaseModel

View File

@ -16,7 +16,10 @@ from langchain_core.utils import (
guard_import,
)
from langchain_core.utils._merge import merge_dicts
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
from langchain_core.utils.pydantic import (
IS_PYDANTIC_V1,
IS_PYDANTIC_V2,
)
from langchain_core.utils.utils import secret_from_env
@ -211,7 +214,7 @@ def test_guard_import_failure(
guard_import(module_name, pip_name=pip_name, package=package)
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2")
def test_get_pydantic_field_names_v1_in_2() -> None:
from pydantic.v1 import BaseModel as PydanticV1BaseModel
from pydantic.v1 import Field
@ -226,7 +229,7 @@ def test_get_pydantic_field_names_v1_in_2() -> None:
assert result == expected
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2")
def test_get_pydantic_field_names_v2_in_2() -> None:
from pydantic import BaseModel, Field
@ -240,7 +243,7 @@ def test_get_pydantic_field_names_v2_in_2() -> None:
assert result == expected
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Requires pydantic 1")
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Requires pydantic 1")
def test_get_pydantic_field_names_v1() -> None:
from pydantic import BaseModel, Field