core[patch]: fix repr and str for Serializable (#26786)

Fixes #26499

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
Tibor Reiss 2024-10-24 17:36:35 +02:00 committed by GitHub
parent 2d58a8a08d
commit 20b56a0233
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 61 additions and 24 deletions

View File

@ -10,6 +10,7 @@ from typing import (
)
from pydantic import BaseModel, ConfigDict
from pydantic.fields import FieldInfo
from typing_extensions import NotRequired
@ -77,10 +78,23 @@ def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
Raises:
Exception: If the key is not in the model.
"""
field = model.model_fields[key]
return _try_neq_default(value, field)
def _try_neq_default(value: Any, field: FieldInfo) -> bool:
# Handle edge case: inequality of two objects does not evaluate to a bool (e.g. two
# Pandas DataFrames).
try:
return model.model_fields[key].get_default() != value
except Exception:
return True
return bool(field.get_default() != value)
except Exception as _:
try:
return all(field.get_default() != value)
except Exception as _:
try:
return value is not field.default
except Exception as _:
return False
class Serializable(BaseModel, ABC):
@ -297,18 +311,7 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
if field.default_factory is list and isinstance(value, list):
return False
# Handle edge case: inequality of two objects does not evaluate to a bool (e.g. two
# Pandas DataFrames).
try:
value_neq_default = bool(field.get_default() != value)
except Exception as _:
try:
value_neq_default = all(field.get_default() != value)
except Exception as _:
try:
value_neq_default = value is not field.default
except Exception as _:
value_neq_default = False
value_neq_default = _try_neq_default(value, field)
# If value is falsy and does not match the default
return value_is_truthy or value_neq_default

View File

@ -4,6 +4,22 @@ from langchain_core.load import Serializable, dumpd, load
from langchain_core.load.serializable import _is_field_useful
class NonBoolObj:
    """Test double that can be neither truth-tested nor compared for equality.

    Both ``bool(obj)`` and ``obj == other`` raise ``ValueError``, which lets
    tests exercise code paths that must cope with values whose comparison
    result cannot be reduced to a plain ``bool``.
    """

    # Defining ``__eq__`` already makes the class unhashable; spell it out so
    # the intent is visible to readers and linters.
    __hash__ = None  # type: ignore[assignment]

    def __bool__(self) -> bool:
        """Always raise: the truthiness of this object is undefined."""
        msg = "Truthiness can't be determined"
        raise ValueError(msg)

    def __eq__(self, other: object) -> bool:
        """Always raise: equality with any other object is undefined."""
        msg = "Equality can't be determined"
        raise ValueError(msg)

    def __str__(self) -> str:
        """Return the bare class name."""
        return self.__class__.__name__

    def __repr__(self) -> str:
        """Return the bare class name so reprs embedding it are stable."""
        return self.__class__.__name__
def test_simple_serialization() -> None:
class Foo(Serializable):
bar: int
@ -82,15 +98,6 @@ def test__is_field_useful() -> None:
def __eq__(self, other: object) -> bool:
return self # type: ignore[return-value]
class NonBoolObj:
def __bool__(self) -> bool:
msg = "Truthiness can't be determined"
raise ValueError(msg)
def __eq__(self, other: object) -> bool:
msg = "Equality can't be determined"
raise ValueError(msg)
default_x = ArrayObj()
default_y = NonBoolObj()
@ -169,3 +176,30 @@ def test_simple_deserialization_with_additional_imports() -> None:
},
)
assert isinstance(new_foo, Foo2)
class Foo3(Serializable):
    """Serializable model pairing a plain string with a ``NonBoolObj`` field."""

    # ``NonBoolObj`` is an ordinary Python class, so the model is configured to
    # accept arbitrary field types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    content: str
    non_bool: NonBoolObj

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report this class as serializable."""
        return True
def test_repr() -> None:
    """``repr()`` of a model holding a ``NonBoolObj`` yields the expected text."""
    instance = Foo3(content="repr", non_bool=NonBoolObj())
    expected = "Foo3(content='repr', non_bool=NonBoolObj)"
    assert repr(instance) == expected
def test_str() -> None:
    """``str()`` of a model holding a ``NonBoolObj`` yields the expected text."""
    instance = Foo3(content="str", non_bool=NonBoolObj())
    expected = "content='str' non_bool=NonBoolObj"
    assert str(instance) == expected