mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
core[minor]: Add support for pydantic 2 to utility to get fields (#24899)
Add compatibility for pydantic 2 for a utility function. This will help push some small changes to master, so they don't have to be kept track of on a separate branch.
This commit is contained in:
parent
7d1694040d
commit
210623b409
@ -41,6 +41,18 @@ else:
|
||||
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
||||
|
||||
|
||||
def is_pydantic_v1_subclass(cls: Type) -> bool:
|
||||
"""Check if the installed Pydantic version is 1.x-like."""
|
||||
if PYDANTIC_MAJOR_VERSION == 1:
|
||||
return True
|
||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
if issubclass(cls, BaseModelV1):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_basemodel_subclass(cls: Type) -> bool:
|
||||
"""Check if the given class is a subclass of Pydantic BaseModel.
|
||||
|
||||
|
@ -12,6 +12,9 @@ from packaging.version import parse
|
||||
from requests import HTTPError, Response
|
||||
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from langchain_core.utils.pydantic import (
|
||||
is_pydantic_v1_subclass,
|
||||
)
|
||||
|
||||
|
||||
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
|
||||
@ -192,10 +195,16 @@ def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
|
||||
Set[str]: Field names.
|
||||
"""
|
||||
all_required_field_names = set()
|
||||
for field in pydantic_cls.__fields__.values():
|
||||
all_required_field_names.add(field.name)
|
||||
if field.has_alias:
|
||||
all_required_field_names.add(field.alias)
|
||||
if is_pydantic_v1_subclass(pydantic_cls):
|
||||
for field in pydantic_cls.__fields__.values():
|
||||
all_required_field_names.add(field.name)
|
||||
if field.has_alias:
|
||||
all_required_field_names.add(field.alias)
|
||||
else: # Assuming pydantic 2 for now
|
||||
for name, field in pydantic_cls.model_fields.items():
|
||||
all_required_field_names.add(name)
|
||||
if field.alias:
|
||||
all_required_field_names.add(field.alias)
|
||||
return all_required_field_names
|
||||
|
||||
|
||||
|
@ -6,8 +6,13 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from langchain_core import utils
|
||||
from langchain_core.utils import check_package_version, guard_import
|
||||
from langchain_core.utils import (
|
||||
check_package_version,
|
||||
get_pydantic_field_names,
|
||||
guard_import,
|
||||
)
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -171,3 +176,46 @@ def test_guard_import_failure(
|
||||
f"Please install it with `pip install {pip_name}`."
|
||||
)
|
||||
assert exc_info.value.msg == err_msg
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")
|
||||
def test_get_pydantic_field_names_v1_in_2() -> None:
|
||||
from pydantic.v1 import BaseModel as PydanticV1BaseModel # pydantic: ignore
|
||||
from pydantic.v1 import Field # pydantic: ignore
|
||||
|
||||
class PydanticV1Model(PydanticV1BaseModel):
|
||||
field1: str
|
||||
field2: int
|
||||
alias_field: int = Field(alias="aliased_field")
|
||||
|
||||
result = get_pydantic_field_names(PydanticV1Model)
|
||||
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")
|
||||
def test_get_pydantic_field_names_v2_in_2() -> None:
|
||||
from pydantic import BaseModel, Field # pydantic: ignore
|
||||
|
||||
class PydanticModel(BaseModel):
|
||||
field1: str
|
||||
field2: int
|
||||
alias_field: int = Field(alias="aliased_field")
|
||||
|
||||
result = get_pydantic_field_names(PydanticModel)
|
||||
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Requires pydantic 1")
|
||||
def test_get_pydantic_field_names_v1() -> None:
|
||||
from pydantic import BaseModel, Field # pydantic: ignore
|
||||
|
||||
class PydanticModel(BaseModel):
|
||||
field1: str
|
||||
field2: int
|
||||
alias_field: int = Field(alias="aliased_field")
|
||||
|
||||
result = get_pydantic_field_names(PydanticModel)
|
||||
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
||||
assert result == expected
|
||||
|
Loading…
Reference in New Issue
Block a user