diff --git a/libs/core/langchain_core/load/serializable.py b/libs/core/langchain_core/load/serializable.py index b5e7d8b9150..7655438be97 100644 --- a/libs/core/langchain_core/load/serializable.py +++ b/libs/core/langchain_core/load/serializable.py @@ -10,6 +10,7 @@ from typing import ( ) from pydantic import BaseModel, ConfigDict +from pydantic.fields import FieldInfo from typing_extensions import NotRequired @@ -77,10 +78,23 @@ def try_neq_default(value: Any, key: str, model: BaseModel) -> bool: Raises: Exception: If the key is not in the model. """ + field = model.model_fields[key] + return _try_neq_default(value, field) + + +def _try_neq_default(value: Any, field: FieldInfo) -> bool: + # Handle edge case: inequality of two objects does not evaluate to a bool (e.g. two + # Pandas DataFrames). try: - return model.model_fields[key].get_default() != value - except Exception: - return True + return bool(field.get_default() != value) + except Exception as _: + try: + return all(field.get_default() != value) + except Exception as _: + try: + return value is not field.default + except Exception as _: + return False class Serializable(BaseModel, ABC): @@ -297,18 +311,7 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool: if field.default_factory is list and isinstance(value, list): return False - # Handle edge case: inequality of two objects does not evaluate to a bool (e.g. two - # Pandas DataFrames). - try: - value_neq_default = bool(field.get_default() != value) - except Exception as _: - try: - value_neq_default = all(field.get_default() != value) - except Exception as _: - try: - value_neq_default = value is not field.default - except Exception as _: - value_neq_default = False + value_neq_default = _try_neq_default(value, field) # If value is falsy and does not match the default return value_is_truthy or value_neq_default diff --git a/libs/core/tests/unit_tests/load/test_serializable.py b/libs/core/tests/unit_tests/load/test_serializable.py index 65040a6841b..1c8b6772f09 100644 --- a/libs/core/tests/unit_tests/load/test_serializable.py +++ b/libs/core/tests/unit_tests/load/test_serializable.py @@ -4,6 +4,22 @@ from langchain_core.load import Serializable, dumpd, load from langchain_core.load.serializable import _is_field_useful +class NonBoolObj: + def __bool__(self) -> bool: + msg = "Truthiness can't be determined" + raise ValueError(msg) + + def __eq__(self, other: object) -> bool: + msg = "Equality can't be determined" + raise ValueError(msg) + + def __str__(self) -> str: + return self.__class__.__name__ + + def __repr__(self) -> str: + return self.__class__.__name__ + + def test_simple_serialization() -> None: class Foo(Serializable): bar: int @@ -82,15 +98,6 @@ def test__is_field_useful() -> None: def __eq__(self, other: object) -> bool: return self # type: ignore[return-value] - class NonBoolObj: - def __bool__(self) -> bool: - msg = "Truthiness can't be determined" - raise ValueError(msg) - - def __eq__(self, other: object) -> bool: - msg = "Equality can't be determined" - raise ValueError(msg) - default_x = ArrayObj() default_y = NonBoolObj() @@ -169,3 +176,30 @@ def test_simple_deserialization_with_additional_imports() -> None: }, ) assert isinstance(new_foo, Foo2) + + +class Foo3(Serializable): + model_config = ConfigDict(arbitrary_types_allowed=True) + + content: str + non_bool: NonBoolObj + + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + +def test_repr() -> None: + foo = Foo3( + content="repr", + non_bool=NonBoolObj(), + ) + assert repr(foo) == "Foo3(content='repr', non_bool=NonBoolObj)" + + +def test_str() -> None: + foo = Foo3( + content="str", + non_bool=NonBoolObj(), + ) + assert str(foo) == "content='str' non_bool=NonBoolObj"