mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-02 19:34:04 +00:00
core[patch]: handle serializable fields that cant be converted to bool (#25903)
This commit is contained in:
parent
7f857a02d5
commit
d19e074374
@ -262,7 +262,27 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
||||
field = inst.__fields__.get(key)
|
||||
if not field:
|
||||
return False
|
||||
return field.required is True or value or field.get_default() != value
|
||||
# Handle edge case: a value cannot be converted to a boolean (e.g. a
|
||||
# Pandas DataFrame).
|
||||
try:
|
||||
value_is_truthy = bool(value)
|
||||
except Exception as _:
|
||||
value_is_truthy = False
|
||||
|
||||
# Handle edge case: inequality of two objects does not evaluate to a bool (e.g. two
|
||||
# Pandas DataFrames).
|
||||
try:
|
||||
value_neq_default = bool(field.get_default() != value)
|
||||
except Exception as _:
|
||||
try:
|
||||
value_neq_default = all(field.get_default() != value)
|
||||
except Exception as _:
|
||||
try:
|
||||
value_neq_default = value is not field.default
|
||||
except Exception as _:
|
||||
value_neq_default = False
|
||||
|
||||
return field.required is True or value_is_truthy or value_neq_default
|
||||
|
||||
|
||||
def _replace_secrets(
|
||||
|
@ -1,6 +1,8 @@
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.load import Serializable, dumpd
|
||||
from langchain_core.load.serializable import _is_field_useful
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
|
||||
def test_simple_serialization() -> None:
|
||||
@ -69,3 +71,39 @@ def test_simple_serialization_secret() -> None:
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
}
|
||||
|
||||
|
||||
def test__is_field_useful() -> None:
|
||||
class ArrayObj:
|
||||
def __bool__(self) -> bool:
|
||||
raise ValueError("Truthiness can't be determined")
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return self # type: ignore[return-value]
|
||||
|
||||
class NonBoolObj:
|
||||
def __bool__(self) -> bool:
|
||||
raise ValueError("Truthiness can't be determined")
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
raise ValueError("Equality can't be determined")
|
||||
|
||||
default_x = ArrayObj()
|
||||
default_y = NonBoolObj()
|
||||
|
||||
class Foo(Serializable):
|
||||
x: ArrayObj = Field(default=default_x)
|
||||
y: NonBoolObj = Field(default=default_y)
|
||||
# Make sure works for fields without default.
|
||||
z: ArrayObj
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
foo = Foo(x=ArrayObj(), y=NonBoolObj(), z=ArrayObj())
|
||||
assert _is_field_useful(foo, "x", foo.x)
|
||||
assert _is_field_useful(foo, "y", foo.y)
|
||||
|
||||
foo = Foo(x=default_x, y=default_y, z=ArrayObj())
|
||||
assert not _is_field_useful(foo, "x", foo.x)
|
||||
assert not _is_field_useful(foo, "y", foo.y)
|
||||
|
Loading…
Reference in New Issue
Block a user