mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 07:35:18 +00:00
core[patch]: fix repr and str for Serializable (#26786)
Fixes #26499 --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
parent
2d58a8a08d
commit
20b56a0233
@ -10,6 +10,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
|
||||
@ -77,10 +78,23 @@ def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
|
||||
Raises:
|
||||
Exception: If the key is not in the model.
|
||||
"""
|
||||
field = model.model_fields[key]
|
||||
return _try_neq_default(value, field)
|
||||
|
||||
|
||||
def _try_neq_default(value: Any, field: FieldInfo) -> bool:
|
||||
# Handle edge case: inequality of two objects does not evaluate to a bool (e.g. two
|
||||
# Pandas DataFrames).
|
||||
try:
|
||||
return model.model_fields[key].get_default() != value
|
||||
except Exception:
|
||||
return True
|
||||
return bool(field.get_default() != value)
|
||||
except Exception as _:
|
||||
try:
|
||||
return all(field.get_default() != value)
|
||||
except Exception as _:
|
||||
try:
|
||||
return value is not field.default
|
||||
except Exception as _:
|
||||
return False
|
||||
|
||||
|
||||
class Serializable(BaseModel, ABC):
|
||||
@ -297,18 +311,7 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
||||
if field.default_factory is list and isinstance(value, list):
|
||||
return False
|
||||
|
||||
# Handle edge case: inequality of two objects does not evaluate to a bool (e.g. two
|
||||
# Pandas DataFrames).
|
||||
try:
|
||||
value_neq_default = bool(field.get_default() != value)
|
||||
except Exception as _:
|
||||
try:
|
||||
value_neq_default = all(field.get_default() != value)
|
||||
except Exception as _:
|
||||
try:
|
||||
value_neq_default = value is not field.default
|
||||
except Exception as _:
|
||||
value_neq_default = False
|
||||
value_neq_default = _try_neq_default(value, field)
|
||||
|
||||
# If value is falsy and does not match the default
|
||||
return value_is_truthy or value_neq_default
|
||||
|
@ -4,6 +4,22 @@ from langchain_core.load import Serializable, dumpd, load
|
||||
from langchain_core.load.serializable import _is_field_useful
|
||||
|
||||
|
||||
class NonBoolObj:
|
||||
def __bool__(self) -> bool:
|
||||
msg = "Truthiness can't be determined"
|
||||
raise ValueError(msg)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
msg = "Equality can't be determined"
|
||||
raise ValueError(msg)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
|
||||
def test_simple_serialization() -> None:
|
||||
class Foo(Serializable):
|
||||
bar: int
|
||||
@ -82,15 +98,6 @@ def test__is_field_useful() -> None:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return self # type: ignore[return-value]
|
||||
|
||||
class NonBoolObj:
|
||||
def __bool__(self) -> bool:
|
||||
msg = "Truthiness can't be determined"
|
||||
raise ValueError(msg)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
msg = "Equality can't be determined"
|
||||
raise ValueError(msg)
|
||||
|
||||
default_x = ArrayObj()
|
||||
default_y = NonBoolObj()
|
||||
|
||||
@ -169,3 +176,30 @@ def test_simple_deserialization_with_additional_imports() -> None:
|
||||
},
|
||||
)
|
||||
assert isinstance(new_foo, Foo2)
|
||||
|
||||
|
||||
class Foo3(Serializable):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
content: str
|
||||
non_bool: NonBoolObj
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def test_repr() -> None:
|
||||
foo = Foo3(
|
||||
content="repr",
|
||||
non_bool=NonBoolObj(),
|
||||
)
|
||||
assert repr(foo) == "Foo3(content='repr', non_bool=NonBoolObj)"
|
||||
|
||||
|
||||
def test_str() -> None:
|
||||
foo = Foo3(
|
||||
content="str",
|
||||
non_bool=NonBoolObj(),
|
||||
)
|
||||
assert str(foo) == "content='str' non_bool=NonBoolObj"
|
||||
|
Loading…
Reference in New Issue
Block a user