mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 22:29:51 +00:00
core[minor]: Add support for pydantic 2 to utility to get fields (#24899)
Add compatibility for pydantic 2 for a utility function. This will help push some small changes to master, so they don't have to be kept track of on a separate branch.
This commit is contained in:
parent
7d1694040d
commit
210623b409
@ -41,6 +41,18 @@ else:
|
|||||||
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
def is_pydantic_v1_subclass(cls: Type) -> bool:
|
||||||
|
"""Check if the installed Pydantic version is 1.x-like."""
|
||||||
|
if PYDANTIC_MAJOR_VERSION == 1:
|
||||||
|
return True
|
||||||
|
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
|
||||||
|
if issubclass(cls, BaseModelV1):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_basemodel_subclass(cls: Type) -> bool:
|
def is_basemodel_subclass(cls: Type) -> bool:
|
||||||
"""Check if the given class is a subclass of Pydantic BaseModel.
|
"""Check if the given class is a subclass of Pydantic BaseModel.
|
||||||
|
|
||||||
|
@ -12,6 +12,9 @@ from packaging.version import parse
|
|||||||
from requests import HTTPError, Response
|
from requests import HTTPError, Response
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import SecretStr
|
from langchain_core.pydantic_v1 import SecretStr
|
||||||
|
from langchain_core.utils.pydantic import (
|
||||||
|
is_pydantic_v1_subclass,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
|
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
|
||||||
@ -192,10 +195,16 @@ def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
|
|||||||
Set[str]: Field names.
|
Set[str]: Field names.
|
||||||
"""
|
"""
|
||||||
all_required_field_names = set()
|
all_required_field_names = set()
|
||||||
for field in pydantic_cls.__fields__.values():
|
if is_pydantic_v1_subclass(pydantic_cls):
|
||||||
all_required_field_names.add(field.name)
|
for field in pydantic_cls.__fields__.values():
|
||||||
if field.has_alias:
|
all_required_field_names.add(field.name)
|
||||||
all_required_field_names.add(field.alias)
|
if field.has_alias:
|
||||||
|
all_required_field_names.add(field.alias)
|
||||||
|
else: # Assuming pydantic 2 for now
|
||||||
|
for name, field in pydantic_cls.model_fields.items():
|
||||||
|
all_required_field_names.add(name)
|
||||||
|
if field.alias:
|
||||||
|
all_required_field_names.add(field.alias)
|
||||||
return all_required_field_names
|
return all_required_field_names
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,8 +6,13 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain_core import utils
|
from langchain_core import utils
|
||||||
from langchain_core.utils import check_package_version, guard_import
|
from langchain_core.utils import (
|
||||||
|
check_package_version,
|
||||||
|
get_pydantic_field_names,
|
||||||
|
guard_import,
|
||||||
|
)
|
||||||
from langchain_core.utils._merge import merge_dicts
|
from langchain_core.utils._merge import merge_dicts
|
||||||
|
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -171,3 +176,46 @@ def test_guard_import_failure(
|
|||||||
f"Please install it with `pip install {pip_name}`."
|
f"Please install it with `pip install {pip_name}`."
|
||||||
)
|
)
|
||||||
assert exc_info.value.msg == err_msg
|
assert exc_info.value.msg == err_msg
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")
|
||||||
|
def test_get_pydantic_field_names_v1_in_2() -> None:
|
||||||
|
from pydantic.v1 import BaseModel as PydanticV1BaseModel # pydantic: ignore
|
||||||
|
from pydantic.v1 import Field # pydantic: ignore
|
||||||
|
|
||||||
|
class PydanticV1Model(PydanticV1BaseModel):
|
||||||
|
field1: str
|
||||||
|
field2: int
|
||||||
|
alias_field: int = Field(alias="aliased_field")
|
||||||
|
|
||||||
|
result = get_pydantic_field_names(PydanticV1Model)
|
||||||
|
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")
|
||||||
|
def test_get_pydantic_field_names_v2_in_2() -> None:
|
||||||
|
from pydantic import BaseModel, Field # pydantic: ignore
|
||||||
|
|
||||||
|
class PydanticModel(BaseModel):
|
||||||
|
field1: str
|
||||||
|
field2: int
|
||||||
|
alias_field: int = Field(alias="aliased_field")
|
||||||
|
|
||||||
|
result = get_pydantic_field_names(PydanticModel)
|
||||||
|
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Requires pydantic 1")
|
||||||
|
def test_get_pydantic_field_names_v1() -> None:
|
||||||
|
from pydantic import BaseModel, Field # pydantic: ignore
|
||||||
|
|
||||||
|
class PydanticModel(BaseModel):
|
||||||
|
field1: str
|
||||||
|
field2: int
|
||||||
|
alias_field: int = Field(alias="aliased_field")
|
||||||
|
|
||||||
|
result = get_pydantic_field_names(PydanticModel)
|
||||||
|
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
||||||
|
assert result == expected
|
||||||
|
Loading…
Reference in New Issue
Block a user