diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index b49a649a4a4..00c17676775 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -41,6 +41,18 @@ else: TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel) +def is_pydantic_v1_subclass(cls: Type) -> bool: + """Check if the installed Pydantic version is 1.x-like.""" + if PYDANTIC_MAJOR_VERSION == 1: + return True + elif PYDANTIC_MAJOR_VERSION == 2: + from pydantic.v1 import BaseModel as BaseModelV1 + + if issubclass(cls, BaseModelV1): + return True + return False + + def is_basemodel_subclass(cls: Type) -> bool: """Check if the given class is a subclass of Pydantic BaseModel. diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index 4ee4ee6de67..8711b399cd3 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -12,6 +12,9 @@ from packaging.version import parse from requests import HTTPError, Response from langchain_core.pydantic_v1 import SecretStr +from langchain_core.utils.pydantic import ( + is_pydantic_v1_subclass, +) def xor_args(*arg_groups: Tuple[str, ...]) -> Callable: @@ -192,10 +195,16 @@ def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]: Set[str]: Field names. """ all_required_field_names = set() - for field in pydantic_cls.__fields__.values(): - all_required_field_names.add(field.name) - if field.has_alias: - all_required_field_names.add(field.alias) + if is_pydantic_v1_subclass(pydantic_cls): + for field in pydantic_cls.__fields__.values(): + all_required_field_names.add(field.name) + if field.has_alias: + all_required_field_names.add(field.alias) + else: # Assuming pydantic 2 for now + for name, field in pydantic_cls.model_fields.items(): + all_required_field_names.add(name) + if field.alias: + all_required_field_names.add(field.alias) return all_required_field_names diff --git a/libs/core/tests/unit_tests/utils/test_utils.py b/libs/core/tests/unit_tests/utils/test_utils.py index 2524aec5f9c..0aed89e1a64 100644 --- a/libs/core/tests/unit_tests/utils/test_utils.py +++ b/libs/core/tests/unit_tests/utils/test_utils.py @@ -6,8 +6,13 @@ from unittest.mock import patch import pytest from langchain_core import utils -from langchain_core.utils import check_package_version, guard_import +from langchain_core.utils import ( + check_package_version, + get_pydantic_field_names, + guard_import, +) from langchain_core.utils._merge import merge_dicts +from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION @pytest.mark.parametrize( @@ -171,3 +176,46 @@ def test_guard_import_failure( f"Please install it with `pip install {pip_name}`." ) assert exc_info.value.msg == err_msg + + +@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2") +def test_get_pydantic_field_names_v1_in_2() -> None: + from pydantic.v1 import BaseModel as PydanticV1BaseModel # pydantic: ignore + from pydantic.v1 import Field # pydantic: ignore + + class PydanticV1Model(PydanticV1BaseModel): + field1: str + field2: int + alias_field: int = Field(alias="aliased_field") + + result = get_pydantic_field_names(PydanticV1Model) + expected = {"field1", "field2", "aliased_field", "alias_field"} + assert result == expected + + +@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2") +def test_get_pydantic_field_names_v2_in_2() -> None: + from pydantic import BaseModel, Field # pydantic: ignore + + class PydanticModel(BaseModel): + field1: str + field2: int + alias_field: int = Field(alias="aliased_field") + + result = get_pydantic_field_names(PydanticModel) + expected = {"field1", "field2", "aliased_field", "alias_field"} + assert result == expected + + +@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Requires pydantic 1") +def test_get_pydantic_field_names_v1() -> None: + from pydantic import BaseModel, Field # pydantic: ignore + + class PydanticModel(BaseModel): + field1: str + field2: int + alias_field: int = Field(alias="aliased_field") + + result = get_pydantic_field_names(PydanticModel) + expected = {"field1", "field2", "aliased_field", "alias_field"} + assert result == expected