From 7040013140d9e4c34d43808d21affb55324ba810 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Thu, 8 Aug 2024 16:39:38 -0700 Subject: [PATCH] core[patch]: fix deprecation pydantic bug (#25204) #25004 is incompatible with pydantic < 1.10.17. Introduces fix for this. --- libs/core/langchain_core/_api/deprecation.py | 47 +++++++++++--------- libs/core/langchain_core/utils/pydantic.py | 8 ++-- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/libs/core/langchain_core/_api/deprecation.py b/libs/core/langchain_core/_api/deprecation.py index 5a003a80a18..b48215d0d6c 100644 --- a/libs/core/langchain_core/_api/deprecation.py +++ b/libs/core/langchain_core/_api/deprecation.py @@ -30,7 +30,8 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning): # PUBLIC API -T = TypeVar("T", bound=Union[Type, Callable[..., Any]]) +# Last Any should be FieldInfoV1 but this leads to circular imports +T = TypeVar("T", bound=Union[Type, Callable[..., Any], Any]) def _validate_deprecation_params( @@ -133,7 +134,7 @@ def deprecated( _package: str = package, ) -> T: """Implementation of the decorator returned by `deprecated`.""" - from pydantic.v1.fields import FieldInfo # pydantic: ignore + from langchain_core.utils.pydantic import FieldInfoV1 def emit_warning() -> None: """Emit the warning.""" @@ -208,9 +209,7 @@ def deprecated( ) return cast(T, obj) - elif isinstance(obj, FieldInfo): - from langchain_core.pydantic_v1 import Field - + elif isinstance(obj, FieldInfoV1): wrapped = None if not _obj_type: _obj_type = "attribute" @@ -219,58 +218,64 @@ def deprecated( old_doc = obj.description def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: - return Field( - default=obj.default, - default_factory=obj.default_factory, - description=new_doc, - alias=obj.alias, - exclude=obj.exclude, + return cast( + T, + FieldInfoV1( + default=obj.default, + default_factory=obj.default_factory, + description=new_doc, + alias=obj.alias, + exclude=obj.exclude, + ), ) elif isinstance(obj, property): if not _obj_type: _obj_type = "attribute" wrapped = None - _name = _name or obj.fget.__qualname__ + _name = _name or cast(Union[Type, Callable], obj.fget).__qualname__ old_doc = obj.__doc__ class _deprecated_property(property): """A deprecated property.""" - def __init__(self, fget=None, fset=None, fdel=None, doc=None): + def __init__(self, fget=None, fset=None, fdel=None, doc=None): # type: ignore[no-untyped-def] super().__init__(fget, fset, fdel, doc) self.__orig_fget = fget self.__orig_fset = fset self.__orig_fdel = fdel - def __get__(self, instance, owner=None): + def __get__(self, instance, owner=None): # type: ignore[no-untyped-def] if instance is not None or owner is not None: emit_warning() return self.fget(instance) - def __set__(self, instance, value): + def __set__(self, instance, value): # type: ignore[no-untyped-def] if instance is not None: emit_warning() return self.fset(instance, value) - def __delete__(self, instance): + def __delete__(self, instance): # type: ignore[no-untyped-def] if instance is not None: emit_warning() return self.fdel(instance) - def __set_name__(self, owner, set_name): + def __set_name__(self, owner, set_name): # type: ignore[no-untyped-def] nonlocal _name if _name == "": _name = set_name - def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any: + def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: """Finalize the property.""" - return _deprecated_property( - fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc + return cast( + T, + _deprecated_property( + fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc + ), ) else: - _name = _name or obj.__qualname__ + _name = _name or cast(Union[Type, Callable], obj).__qualname__ if not _obj_type: # edge case: when a function is within another function # within a test, this will call it a "method" not a "function" diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index cf99d3947ad..edb48d9ceff 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -26,9 +26,13 @@ PYDANTIC_MAJOR_VERSION = get_pydantic_major_version() if PYDANTIC_MAJOR_VERSION == 1: + from pydantic.fields import FieldInfo as FieldInfoV1 + PydanticBaseModel = pydantic.BaseModel TypeBaseModel = Type[BaseModel] elif PYDANTIC_MAJOR_VERSION == 2: + from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment] + # Union type needs to be last assignment to PydanticBaseModel to make mypy happy. PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]] # type: ignore @@ -272,7 +276,6 @@ if PYDANTIC_MAJOR_VERSION == 2: from pydantic import BaseModel as BaseModelV2 from pydantic.fields import FieldInfo as FieldInfoV2 from pydantic.v1 import BaseModel as BaseModelV1 - from pydantic.v1.fields import FieldInfo as FieldInfoV1 @overload def get_fields(model: Type[BaseModelV2]) -> Dict[str, FieldInfoV2]: ... @@ -304,11 +307,10 @@ if PYDANTIC_MAJOR_VERSION == 2: raise TypeError(f"Expected a Pydantic model. Got {type(model)}") elif PYDANTIC_MAJOR_VERSION == 1: from pydantic import BaseModel as BaseModelV1_ - from pydantic.fields import FieldInfo as FieldInfoV1_ def get_fields( # type: ignore[no-redef] model: Union[Type[BaseModelV1_], BaseModelV1_], - ) -> Dict[str, FieldInfoV1_]: + ) -> Dict[str, FieldInfoV1]: """Get the field names of a Pydantic model.""" return model.__fields__ # type: ignore else: