core[patch]: Add pydantic get_fields adapter (#25187)

Add adapter to get fields
This commit is contained in:
Eugene Yurtsev 2024-08-08 13:47:42 -04:00 committed by GitHub
parent c72e522e96
commit 2f209d84fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 88 additions and 2 deletions

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import inspect
import textwrap
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload
import pydantic # pydantic: ignore
@ -266,3 +266,50 @@ def _create_subset_model(
raise NotImplementedError(
f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
)
if PYDANTIC_MAJOR_VERSION == 2:
from pydantic import BaseModel as BaseModelV2
from pydantic.fields import FieldInfo as FieldInfoV2
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1.fields import FieldInfo as FieldInfoV1
@overload
def get_fields(model: Type[BaseModelV2]) -> Dict[str, FieldInfoV2]: ...
@overload
def get_fields(model: BaseModelV2) -> Dict[str, FieldInfoV2]: ...
@overload
def get_fields(model: Type[BaseModelV1]) -> Dict[str, FieldInfoV1]: ...
@overload
def get_fields(model: BaseModelV1) -> Dict[str, FieldInfoV1]: ...
def get_fields(
model: Union[
BaseModelV2,
BaseModelV1,
Type[BaseModelV2],
Type[BaseModelV1],
],
) -> Union[Dict[str, FieldInfoV2], Dict[str, FieldInfoV1]]:
"""Get the field names of a Pydantic model."""
if hasattr(model, "model_fields"):
return model.model_fields # type: ignore
elif hasattr(model, "__fields__"):
return model.__fields__ # type: ignore
else:
raise TypeError(f"Expected a Pydantic model. Got {type(model)}")
elif PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1_
from pydantic.fields import FieldInfo as FieldInfoV1_
def get_fields( # type: ignore[no-redef]
model: Union[Type[BaseModelV1_], BaseModelV1_],
) -> Dict[str, FieldInfoV1_]:
"""Get the field names of a Pydantic model."""
return model.__fields__ # type: ignore
else:
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")

View File

@ -4,10 +4,10 @@ from typing import Any, Dict, List, Optional
import pytest
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.pydantic import (
PYDANTIC_MAJOR_VERSION,
_create_subset_model_v2,
get_fields,
is_basemodel_instance,
is_basemodel_subclass,
pre_init,
@ -15,6 +15,8 @@ from langchain_core.utils.pydantic import (
def test_pre_init_decorator() -> None:
from langchain_core.pydantic_v1 import BaseModel
class Foo(BaseModel):
x: int = 5
y: int
@ -32,6 +34,8 @@ def test_pre_init_decorator() -> None:
def test_pre_init_decorator_with_more_defaults() -> None:
from langchain_core.pydantic_v1 import BaseModel, Field
class Foo(BaseModel):
a: int = 1
b: Optional[int] = None
@ -52,6 +56,8 @@ def test_pre_init_decorator_with_more_defaults() -> None:
def test_with_aliases() -> None:
from langchain_core.pydantic_v1 import BaseModel, Field
class Foo(BaseModel):
x: int = Field(default=1, alias="y")
z: int
@ -153,3 +159,36 @@ def test_with_field_metadata() -> None:
"title": "Foo",
"type": "object",
}
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Only tests Pydantic v1")
def test_fields_pydantic_v1() -> None:
from pydantic import BaseModel # pydantic: ignore
class Foo(BaseModel):
x: int
fields = get_fields(Foo)
assert fields == {"x": Foo.__fields__["x"]} # type: ignore[index]
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")
def test_fields_pydantic_v2_proper() -> None:
from pydantic import BaseModel # pydantic: ignore
class Foo(BaseModel):
x: int
fields = get_fields(Foo)
assert fields == {"x": Foo.model_fields["x"]}
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")
def test_fields_pydantic_v1_from_2() -> None:
from pydantic.v1 import BaseModel # pydantic: ignore
class Foo(BaseModel):
x: int
fields = get_fields(Foo)
assert fields == {"x": Foo.__fields__["x"]}