core[patch]: fix deprecation pydantic bug (#25204)

#25004 is incompatible with pydantic < 1.10.17. Introduces fix for this.
This commit is contained in:
Bagatur 2024-08-08 16:39:38 -07:00 committed by GitHub
parent dc7423e88f
commit 7040013140
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 24 deletions

View File

@ -30,7 +30,8 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
# PUBLIC API # PUBLIC API
T = TypeVar("T", bound=Union[Type, Callable[..., Any]]) # Last Any should be FieldInfoV1 but this leads to circular imports
T = TypeVar("T", bound=Union[Type, Callable[..., Any], Any])
def _validate_deprecation_params( def _validate_deprecation_params(
@ -133,7 +134,7 @@ def deprecated(
_package: str = package, _package: str = package,
) -> T: ) -> T:
"""Implementation of the decorator returned by `deprecated`.""" """Implementation of the decorator returned by `deprecated`."""
from pydantic.v1.fields import FieldInfo # pydantic: ignore from langchain_core.utils.pydantic import FieldInfoV1
def emit_warning() -> None: def emit_warning() -> None:
"""Emit the warning.""" """Emit the warning."""
@ -208,9 +209,7 @@ def deprecated(
) )
return cast(T, obj) return cast(T, obj)
elif isinstance(obj, FieldInfo): elif isinstance(obj, FieldInfoV1):
from langchain_core.pydantic_v1 import Field
wrapped = None wrapped = None
if not _obj_type: if not _obj_type:
_obj_type = "attribute" _obj_type = "attribute"
@ -219,58 +218,64 @@ def deprecated(
old_doc = obj.description old_doc = obj.description
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
return Field( return cast(
T,
FieldInfoV1(
default=obj.default, default=obj.default,
default_factory=obj.default_factory, default_factory=obj.default_factory,
description=new_doc, description=new_doc,
alias=obj.alias, alias=obj.alias,
exclude=obj.exclude, exclude=obj.exclude,
),
) )
elif isinstance(obj, property): elif isinstance(obj, property):
if not _obj_type: if not _obj_type:
_obj_type = "attribute" _obj_type = "attribute"
wrapped = None wrapped = None
_name = _name or obj.fget.__qualname__ _name = _name or cast(Union[Type, Callable], obj.fget).__qualname__
old_doc = obj.__doc__ old_doc = obj.__doc__
class _deprecated_property(property): class _deprecated_property(property):
"""A deprecated property.""" """A deprecated property."""
def __init__(self, fget=None, fset=None, fdel=None, doc=None): def __init__(self, fget=None, fset=None, fdel=None, doc=None): # type: ignore[no-untyped-def]
super().__init__(fget, fset, fdel, doc) super().__init__(fget, fset, fdel, doc)
self.__orig_fget = fget self.__orig_fget = fget
self.__orig_fset = fset self.__orig_fset = fset
self.__orig_fdel = fdel self.__orig_fdel = fdel
def __get__(self, instance, owner=None): def __get__(self, instance, owner=None): # type: ignore[no-untyped-def]
if instance is not None or owner is not None: if instance is not None or owner is not None:
emit_warning() emit_warning()
return self.fget(instance) return self.fget(instance)
def __set__(self, instance, value): def __set__(self, instance, value): # type: ignore[no-untyped-def]
if instance is not None: if instance is not None:
emit_warning() emit_warning()
return self.fset(instance, value) return self.fset(instance, value)
def __delete__(self, instance): def __delete__(self, instance): # type: ignore[no-untyped-def]
if instance is not None: if instance is not None:
emit_warning() emit_warning()
return self.fdel(instance) return self.fdel(instance)
def __set_name__(self, owner, set_name): def __set_name__(self, owner, set_name): # type: ignore[no-untyped-def]
nonlocal _name nonlocal _name
if _name == "<lambda>": if _name == "<lambda>":
_name = set_name _name = set_name
def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any: def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
"""Finalize the property.""" """Finalize the property."""
return _deprecated_property( return cast(
T,
_deprecated_property(
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
),
) )
else: else:
_name = _name or obj.__qualname__ _name = _name or cast(Union[Type, Callable], obj).__qualname__
if not _obj_type: if not _obj_type:
# edge case: when a function is within another function # edge case: when a function is within another function
# within a test, this will call it a "method" not a "function" # within a test, this will call it a "method" not a "function"

View File

@ -26,9 +26,13 @@ PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
if PYDANTIC_MAJOR_VERSION == 1: if PYDANTIC_MAJOR_VERSION == 1:
from pydantic.fields import FieldInfo as FieldInfoV1
PydanticBaseModel = pydantic.BaseModel PydanticBaseModel = pydantic.BaseModel
TypeBaseModel = Type[BaseModel] TypeBaseModel = Type[BaseModel]
elif PYDANTIC_MAJOR_VERSION == 2: elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy. # Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore
TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]] # type: ignore TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]] # type: ignore
@ -272,7 +276,6 @@ if PYDANTIC_MAJOR_VERSION == 2:
from pydantic import BaseModel as BaseModelV2 from pydantic import BaseModel as BaseModelV2
from pydantic.fields import FieldInfo as FieldInfoV2 from pydantic.fields import FieldInfo as FieldInfoV2
from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1.fields import FieldInfo as FieldInfoV1
@overload @overload
def get_fields(model: Type[BaseModelV2]) -> Dict[str, FieldInfoV2]: ... def get_fields(model: Type[BaseModelV2]) -> Dict[str, FieldInfoV2]: ...
@ -304,11 +307,10 @@ if PYDANTIC_MAJOR_VERSION == 2:
raise TypeError(f"Expected a Pydantic model. Got {type(model)}") raise TypeError(f"Expected a Pydantic model. Got {type(model)}")
elif PYDANTIC_MAJOR_VERSION == 1: elif PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1_ from pydantic import BaseModel as BaseModelV1_
from pydantic.fields import FieldInfo as FieldInfoV1_
def get_fields( # type: ignore[no-redef] def get_fields( # type: ignore[no-redef]
model: Union[Type[BaseModelV1_], BaseModelV1_], model: Union[Type[BaseModelV1_], BaseModelV1_],
) -> Dict[str, FieldInfoV1_]: ) -> Dict[str, FieldInfoV1]:
"""Get the field names of a Pydantic model.""" """Get the field names of a Pydantic model."""
return model.__fields__ # type: ignore return model.__fields__ # type: ignore
else: else: