mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-29 14:37:21 +00:00
core[patch]: Add pydantic get_fields adapter (#25187)
Add adapter to get fields
This commit is contained in:
parent
c72e522e96
commit
2f209d84fa
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import textwrap
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload
|
||||
|
||||
import pydantic # pydantic: ignore
|
||||
|
||||
@ -266,3 +266,50 @@ def _create_subset_model(
|
||||
raise NotImplementedError(
|
||||
f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
|
||||
)
|
||||
|
||||
|
||||
if PYDANTIC_MAJOR_VERSION == 2:
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.fields import FieldInfo as FieldInfoV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from pydantic.v1.fields import FieldInfo as FieldInfoV1
|
||||
|
||||
@overload
|
||||
def get_fields(model: Type[BaseModelV2]) -> Dict[str, FieldInfoV2]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: BaseModelV2) -> Dict[str, FieldInfoV2]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: Type[BaseModelV1]) -> Dict[str, FieldInfoV1]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: BaseModelV1) -> Dict[str, FieldInfoV1]: ...
|
||||
|
||||
def get_fields(
|
||||
model: Union[
|
||||
BaseModelV2,
|
||||
BaseModelV1,
|
||||
Type[BaseModelV2],
|
||||
Type[BaseModelV1],
|
||||
],
|
||||
) -> Union[Dict[str, FieldInfoV2], Dict[str, FieldInfoV1]]:
|
||||
"""Get the field names of a Pydantic model."""
|
||||
if hasattr(model, "model_fields"):
|
||||
return model.model_fields # type: ignore
|
||||
|
||||
elif hasattr(model, "__fields__"):
|
||||
return model.__fields__ # type: ignore
|
||||
else:
|
||||
raise TypeError(f"Expected a Pydantic model. Got {type(model)}")
|
||||
elif PYDANTIC_MAJOR_VERSION == 1:
|
||||
from pydantic import BaseModel as BaseModelV1_
|
||||
from pydantic.fields import FieldInfo as FieldInfoV1_
|
||||
|
||||
def get_fields( # type: ignore[no-redef]
|
||||
model: Union[Type[BaseModelV1_], BaseModelV1_],
|
||||
) -> Dict[str, FieldInfoV1_]:
|
||||
"""Get the field names of a Pydantic model."""
|
||||
return model.__fields__ # type: ignore
|
||||
else:
|
||||
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")
|
||||
|
@ -4,10 +4,10 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.utils.pydantic import (
|
||||
PYDANTIC_MAJOR_VERSION,
|
||||
_create_subset_model_v2,
|
||||
get_fields,
|
||||
is_basemodel_instance,
|
||||
is_basemodel_subclass,
|
||||
pre_init,
|
||||
@ -15,6 +15,8 @@ from langchain_core.utils.pydantic import (
|
||||
|
||||
|
||||
def test_pre_init_decorator() -> None:
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
class Foo(BaseModel):
|
||||
x: int = 5
|
||||
y: int
|
||||
@ -32,6 +34,8 @@ def test_pre_init_decorator() -> None:
|
||||
|
||||
|
||||
def test_pre_init_decorator_with_more_defaults() -> None:
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
class Foo(BaseModel):
|
||||
a: int = 1
|
||||
b: Optional[int] = None
|
||||
@ -52,6 +56,8 @@ def test_pre_init_decorator_with_more_defaults() -> None:
|
||||
|
||||
|
||||
def test_with_aliases() -> None:
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
class Foo(BaseModel):
|
||||
x: int = Field(default=1, alias="y")
|
||||
z: int
|
||||
@ -153,3 +159,36 @@ def test_with_field_metadata() -> None:
|
||||
"title": "Foo",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Only tests Pydantic v1")
|
||||
def test_fields_pydantic_v1() -> None:
|
||||
from pydantic import BaseModel # pydantic: ignore
|
||||
|
||||
class Foo(BaseModel):
|
||||
x: int
|
||||
|
||||
fields = get_fields(Foo)
|
||||
assert fields == {"x": Foo.__fields__["x"]} # type: ignore[index]
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")
|
||||
def test_fields_pydantic_v2_proper() -> None:
|
||||
from pydantic import BaseModel # pydantic: ignore
|
||||
|
||||
class Foo(BaseModel):
|
||||
x: int
|
||||
|
||||
fields = get_fields(Foo)
|
||||
assert fields == {"x": Foo.model_fields["x"]}
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")
|
||||
def test_fields_pydantic_v1_from_2() -> None:
|
||||
from pydantic.v1 import BaseModel # pydantic: ignore
|
||||
|
||||
class Foo(BaseModel):
|
||||
x: int
|
||||
|
||||
fields = get_fields(Foo)
|
||||
assert fields == {"x": Foo.__fields__["x"]}
|
||||
|
Loading…
Reference in New Issue
Block a user