mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-29 22:46:27 +00:00
core[patch]: Add pydantic get_fields adapter (#25187)
Add adapter to get fields
This commit is contained in:
parent
c72e522e96
commit
2f209d84fa
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
import inspect
|
import inspect
|
||||||
import textwrap
|
import textwrap
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
|
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload
|
||||||
|
|
||||||
import pydantic # pydantic: ignore
|
import pydantic # pydantic: ignore
|
||||||
|
|
||||||
@ -266,3 +266,50 @@ def _create_subset_model(
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
|
f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if PYDANTIC_MAJOR_VERSION == 2:
|
||||||
|
from pydantic import BaseModel as BaseModelV2
|
||||||
|
from pydantic.fields import FieldInfo as FieldInfoV2
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
from pydantic.v1.fields import FieldInfo as FieldInfoV1
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_fields(model: Type[BaseModelV2]) -> Dict[str, FieldInfoV2]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_fields(model: BaseModelV2) -> Dict[str, FieldInfoV2]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_fields(model: Type[BaseModelV1]) -> Dict[str, FieldInfoV1]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_fields(model: BaseModelV1) -> Dict[str, FieldInfoV1]: ...
|
||||||
|
|
||||||
|
def get_fields(
|
||||||
|
model: Union[
|
||||||
|
BaseModelV2,
|
||||||
|
BaseModelV1,
|
||||||
|
Type[BaseModelV2],
|
||||||
|
Type[BaseModelV1],
|
||||||
|
],
|
||||||
|
) -> Union[Dict[str, FieldInfoV2], Dict[str, FieldInfoV1]]:
|
||||||
|
"""Get the field names of a Pydantic model."""
|
||||||
|
if hasattr(model, "model_fields"):
|
||||||
|
return model.model_fields # type: ignore
|
||||||
|
|
||||||
|
elif hasattr(model, "__fields__"):
|
||||||
|
return model.__fields__ # type: ignore
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Expected a Pydantic model. Got {type(model)}")
|
||||||
|
elif PYDANTIC_MAJOR_VERSION == 1:
|
||||||
|
from pydantic import BaseModel as BaseModelV1_
|
||||||
|
from pydantic.fields import FieldInfo as FieldInfoV1_
|
||||||
|
|
||||||
|
def get_fields( # type: ignore[no-redef]
|
||||||
|
model: Union[Type[BaseModelV1_], BaseModelV1_],
|
||||||
|
) -> Dict[str, FieldInfoV1_]:
|
||||||
|
"""Get the field names of a Pydantic model."""
|
||||||
|
return model.__fields__ # type: ignore
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")
|
||||||
|
@ -4,10 +4,10 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
||||||
from langchain_core.utils.pydantic import (
|
from langchain_core.utils.pydantic import (
|
||||||
PYDANTIC_MAJOR_VERSION,
|
PYDANTIC_MAJOR_VERSION,
|
||||||
_create_subset_model_v2,
|
_create_subset_model_v2,
|
||||||
|
get_fields,
|
||||||
is_basemodel_instance,
|
is_basemodel_instance,
|
||||||
is_basemodel_subclass,
|
is_basemodel_subclass,
|
||||||
pre_init,
|
pre_init,
|
||||||
@ -15,6 +15,8 @@ from langchain_core.utils.pydantic import (
|
|||||||
|
|
||||||
|
|
||||||
def test_pre_init_decorator() -> None:
|
def test_pre_init_decorator() -> None:
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
|
||||||
class Foo(BaseModel):
|
class Foo(BaseModel):
|
||||||
x: int = 5
|
x: int = 5
|
||||||
y: int
|
y: int
|
||||||
@ -32,6 +34,8 @@ def test_pre_init_decorator() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_pre_init_decorator_with_more_defaults() -> None:
|
def test_pre_init_decorator_with_more_defaults() -> None:
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
|
|
||||||
class Foo(BaseModel):
|
class Foo(BaseModel):
|
||||||
a: int = 1
|
a: int = 1
|
||||||
b: Optional[int] = None
|
b: Optional[int] = None
|
||||||
@ -52,6 +56,8 @@ def test_pre_init_decorator_with_more_defaults() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_with_aliases() -> None:
|
def test_with_aliases() -> None:
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
|
|
||||||
class Foo(BaseModel):
|
class Foo(BaseModel):
|
||||||
x: int = Field(default=1, alias="y")
|
x: int = Field(default=1, alias="y")
|
||||||
z: int
|
z: int
|
||||||
@ -153,3 +159,36 @@ def test_with_field_metadata() -> None:
|
|||||||
"title": "Foo",
|
"title": "Foo",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Only tests Pydantic v1")
|
||||||
|
def test_fields_pydantic_v1() -> None:
|
||||||
|
from pydantic import BaseModel # pydantic: ignore
|
||||||
|
|
||||||
|
class Foo(BaseModel):
|
||||||
|
x: int
|
||||||
|
|
||||||
|
fields = get_fields(Foo)
|
||||||
|
assert fields == {"x": Foo.__fields__["x"]} # type: ignore[index]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")
|
||||||
|
def test_fields_pydantic_v2_proper() -> None:
|
||||||
|
from pydantic import BaseModel # pydantic: ignore
|
||||||
|
|
||||||
|
class Foo(BaseModel):
|
||||||
|
x: int
|
||||||
|
|
||||||
|
fields = get_fields(Foo)
|
||||||
|
assert fields == {"x": Foo.model_fields["x"]}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")
|
||||||
|
def test_fields_pydantic_v1_from_2() -> None:
|
||||||
|
from pydantic.v1 import BaseModel # pydantic: ignore
|
||||||
|
|
||||||
|
class Foo(BaseModel):
|
||||||
|
x: int
|
||||||
|
|
||||||
|
fields = get_fields(Foo)
|
||||||
|
assert fields == {"x": Foo.__fields__["x"]}
|
||||||
|
Loading…
Reference in New Issue
Block a user