From 2f209d84fab4a859883c04ee1248a27309864e28 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 8 Aug 2024 13:47:42 -0400 Subject: [PATCH] core[patch]: Add pydantic get_fields adapter (#25187) Add adapter to get fields --- libs/core/langchain_core/utils/pydantic.py | 49 ++++++++++++++++++- .../tests/unit_tests/utils/test_pydantic.py | 41 +++++++++++++++- 2 files changed, 88 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index fb2f3488e0c..cf99d3947ad 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -5,7 +5,7 @@ from __future__ import annotations import inspect import textwrap from functools import wraps -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload import pydantic # pydantic: ignore @@ -266,3 +266,50 @@ def _create_subset_model( raise NotImplementedError( f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}" ) + + +if PYDANTIC_MAJOR_VERSION == 2: + from pydantic import BaseModel as BaseModelV2 + from pydantic.fields import FieldInfo as FieldInfoV2 + from pydantic.v1 import BaseModel as BaseModelV1 + from pydantic.v1.fields import FieldInfo as FieldInfoV1 + + @overload + def get_fields(model: Type[BaseModelV2]) -> Dict[str, FieldInfoV2]: ... + + @overload + def get_fields(model: BaseModelV2) -> Dict[str, FieldInfoV2]: ... + + @overload + def get_fields(model: Type[BaseModelV1]) -> Dict[str, FieldInfoV1]: ... + + @overload + def get_fields(model: BaseModelV1) -> Dict[str, FieldInfoV1]: ... + + def get_fields( + model: Union[ + BaseModelV2, + BaseModelV1, + Type[BaseModelV2], + Type[BaseModelV1], + ], + ) -> Union[Dict[str, FieldInfoV2], Dict[str, FieldInfoV1]]: + """Get the field names of a Pydantic model.""" + if hasattr(model, "model_fields"): + return model.model_fields # type: ignore + + elif hasattr(model, "__fields__"): + return model.__fields__ # type: ignore + else: + raise TypeError(f"Expected a Pydantic model. Got {type(model)}") +elif PYDANTIC_MAJOR_VERSION == 1: + from pydantic import BaseModel as BaseModelV1_ + from pydantic.fields import FieldInfo as FieldInfoV1_ + + def get_fields( # type: ignore[no-redef] + model: Union[Type[BaseModelV1_], BaseModelV1_], + ) -> Dict[str, FieldInfoV1_]: + """Get the field names of a Pydantic model.""" + return model.__fields__ # type: ignore +else: + raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}") diff --git a/libs/core/tests/unit_tests/utils/test_pydantic.py b/libs/core/tests/unit_tests/utils/test_pydantic.py index 3f0df33a89b..3517d5c1640 100644 --- a/libs/core/tests/unit_tests/utils/test_pydantic.py +++ b/libs/core/tests/unit_tests/utils/test_pydantic.py @@ -4,10 +4,10 @@ from typing import Any, Dict, List, Optional import pytest -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.utils.pydantic import ( PYDANTIC_MAJOR_VERSION, _create_subset_model_v2, + get_fields, is_basemodel_instance, is_basemodel_subclass, pre_init, @@ -15,6 +15,8 @@ from langchain_core.utils.pydantic import ( def test_pre_init_decorator() -> None: + from langchain_core.pydantic_v1 import BaseModel + class Foo(BaseModel): x: int = 5 y: int @@ -32,6 +34,8 @@ def test_pre_init_decorator() -> None: def test_pre_init_decorator_with_more_defaults() -> None: + from langchain_core.pydantic_v1 import BaseModel, Field + class Foo(BaseModel): a: int = 1 b: Optional[int] = None @@ -52,6 +56,8 @@ def test_pre_init_decorator_with_more_defaults() -> None: def test_with_aliases() -> None: + from langchain_core.pydantic_v1 import BaseModel, Field + class Foo(BaseModel): x: int = Field(default=1, alias="y") z: int @@ -153,3 +159,36 @@ def test_with_field_metadata() -> None: "title": "Foo", "type": "object", } + + +@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Only tests Pydantic v1") +def test_fields_pydantic_v1() -> None: + from pydantic import BaseModel # pydantic: ignore + + class Foo(BaseModel): + x: int + + fields = get_fields(Foo) + assert fields == {"x": Foo.__fields__["x"]} # type: ignore[index] + + +@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2") +def test_fields_pydantic_v2_proper() -> None: + from pydantic import BaseModel # pydantic: ignore + + class Foo(BaseModel): + x: int + + fields = get_fields(Foo) + assert fields == {"x": Foo.model_fields["x"]} + + +@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2") +def test_fields_pydantic_v1_from_2() -> None: + from pydantic.v1 import BaseModel # pydantic: ignore + + class Foo(BaseModel): + x: int + + fields = get_fields(Foo) + assert fields == {"x": Foo.__fields__["x"]}