core[minor]: Add support for pydantic 2 to utility to get fields (#24899)

Add compatibility for pydantic 2 for a utility function.

This will help push some small changes to master, so they don't have to
be kept track of on a separate branch.
This commit is contained in:
Eugene Yurtsev 2024-07-31 15:11:07 -04:00 committed by GitHub
parent 7d1694040d
commit 210623b409
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 74 additions and 5 deletions

View File

@ -41,6 +41,18 @@ else:
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
def is_pydantic_v1_subclass(cls: Type) -> bool:
"""Check if the installed Pydantic version is 1.x-like."""
if PYDANTIC_MAJOR_VERSION == 1:
return True
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic.v1 import BaseModel as BaseModelV1
if issubclass(cls, BaseModelV1):
return True
return False
def is_basemodel_subclass(cls: Type) -> bool:
"""Check if the given class is a subclass of Pydantic BaseModel.

View File

@ -12,6 +12,9 @@ from packaging.version import parse
from requests import HTTPError, Response
from langchain_core.pydantic_v1 import SecretStr
from langchain_core.utils.pydantic import (
is_pydantic_v1_subclass,
)
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
@ -192,10 +195,16 @@ def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
Set[str]: Field names.
"""
all_required_field_names = set()
for field in pydantic_cls.__fields__.values():
all_required_field_names.add(field.name)
if field.has_alias:
all_required_field_names.add(field.alias)
if is_pydantic_v1_subclass(pydantic_cls):
for field in pydantic_cls.__fields__.values():
all_required_field_names.add(field.name)
if field.has_alias:
all_required_field_names.add(field.alias)
else: # Assuming pydantic 2 for now
for name, field in pydantic_cls.model_fields.items():
all_required_field_names.add(name)
if field.alias:
all_required_field_names.add(field.alias)
return all_required_field_names

View File

@ -6,8 +6,13 @@ from unittest.mock import patch
import pytest
from langchain_core import utils
from langchain_core.utils import check_package_version, guard_import
from langchain_core.utils import (
check_package_version,
get_pydantic_field_names,
guard_import,
)
from langchain_core.utils._merge import merge_dicts
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
@pytest.mark.parametrize(
@ -171,3 +176,46 @@ def test_guard_import_failure(
f"Please install it with `pip install {pip_name}`."
)
assert exc_info.value.msg == err_msg
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")
def test_get_pydantic_field_names_v1_in_2() -> None:
from pydantic.v1 import BaseModel as PydanticV1BaseModel # pydantic: ignore
from pydantic.v1 import Field # pydantic: ignore
class PydanticV1Model(PydanticV1BaseModel):
field1: str
field2: int
alias_field: int = Field(alias="aliased_field")
result = get_pydantic_field_names(PydanticV1Model)
expected = {"field1", "field2", "aliased_field", "alias_field"}
assert result == expected
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")
def test_get_pydantic_field_names_v2_in_2() -> None:
from pydantic import BaseModel, Field # pydantic: ignore
class PydanticModel(BaseModel):
field1: str
field2: int
alias_field: int = Field(alias="aliased_field")
result = get_pydantic_field_names(PydanticModel)
expected = {"field1", "field2", "aliased_field", "alias_field"}
assert result == expected
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Requires pydantic 1")
def test_get_pydantic_field_names_v1() -> None:
from pydantic import BaseModel, Field # pydantic: ignore
class PydanticModel(BaseModel):
field1: str
field2: int
alias_field: int = Field(alias="aliased_field")
result = get_pydantic_field_names(PydanticModel)
expected = {"field1", "field2", "aliased_field", "alias_field"}
assert result == expected