core[patch]: handle serializable fields that cant be converted to bool (#25903)

This commit is contained in:
Bagatur 2024-09-01 16:44:33 -07:00 committed by GitHub
parent 7f857a02d5
commit d19e074374
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 59 additions and 1 deletions

View File

@ -262,7 +262,27 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
field = inst.__fields__.get(key)
if not field:
return False
return field.required is True or value or field.get_default() != value
# Handle edge case: a value cannot be converted to a boolean (e.g. a
# Pandas DataFrame).
try:
value_is_truthy = bool(value)
except Exception as _:
value_is_truthy = False
# Handle edge case: inequality of two objects does not evaluate to a bool (e.g. two
# Pandas DataFrames).
try:
value_neq_default = bool(field.get_default() != value)
except Exception as _:
try:
value_neq_default = all(field.get_default() != value)
except Exception as _:
try:
value_neq_default = value is not field.default
except Exception as _:
value_neq_default = False
return field.required is True or value_is_truthy or value_neq_default
def _replace_secrets(

View File

@ -1,6 +1,8 @@
from typing import Dict
from langchain_core.load import Serializable, dumpd
from langchain_core.load.serializable import _is_field_useful
from langchain_core.pydantic_v1 import Field
def test_simple_serialization() -> None:
@ -69,3 +71,39 @@ def test_simple_serialization_secret() -> None:
"lc": 1,
"type": "constructor",
}
def test__is_field_useful() -> None:
class ArrayObj:
def __bool__(self) -> bool:
raise ValueError("Truthiness can't be determined")
def __eq__(self, other: object) -> bool:
return self # type: ignore[return-value]
class NonBoolObj:
def __bool__(self) -> bool:
raise ValueError("Truthiness can't be determined")
def __eq__(self, other: object) -> bool:
raise ValueError("Equality can't be determined")
default_x = ArrayObj()
default_y = NonBoolObj()
class Foo(Serializable):
x: ArrayObj = Field(default=default_x)
y: NonBoolObj = Field(default=default_y)
# Make sure works for fields without default.
z: ArrayObj
class Config:
arbitrary_types_allowed = True
foo = Foo(x=ArrayObj(), y=NonBoolObj(), z=ArrayObj())
assert _is_field_useful(foo, "x", foo.x)
assert _is_field_useful(foo, "y", foo.y)
foo = Foo(x=default_x, y=default_y, z=ArrayObj())
assert not _is_field_useful(foo, "x", foo.x)
assert not _is_field_useful(foo, "y", foo.y)