attempt fix for generation

This commit is contained in:
Sydney Runkle 2025-06-09 17:21:21 -04:00
parent 16e5a12806
commit 07d32a23c8
2 changed files with 30 additions and 19 deletions

View File

@ -209,18 +209,15 @@ class Serializable(BaseModel, ABC):
if not self.is_lc_serializable():
return self.to_json_not_implemented()
model_fields = type(self).model_fields
fields = {**self.__class__.model_fields, **self.__class__.model_computed_fields}
secrets = {}
# Get latest values for kwargs if there is an attribute with same name
lc_kwargs = {}
for k, v in self:
if not _is_field_useful(self, k, v):
continue
# Do nothing if the field is excluded
if k in model_fields and model_fields[k].exclude:
for k, v in self.model_dump().items():
if not _is_field_useful(self.__class__, k, v):
continue
lc_kwargs[k] = getattr(self, k, v)
lc_kwargs[k] = v
# Merge the lc_secrets and lc_attributes from every class in the MRO
for cls in [None, *self.__class__.mro()]:
@ -253,9 +250,7 @@ class Serializable(BaseModel, ABC):
# that are not present in the fields.
for key in list(secrets):
value = secrets[key]
if (key in model_fields) and (
alias := model_fields[key].alias
) is not None:
if (key in fields) and (alias := fields[key].alias) is not None:
secrets[alias] = value
lc_kwargs.update(this.lc_attributes)
@ -280,11 +275,11 @@ class Serializable(BaseModel, ABC):
return to_json_not_implemented(self)
def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
def _is_field_useful(class_: type[Serializable], key: str, value: Any) -> bool:
"""Check if a field is useful as a constructor argument.
Args:
inst: The instance.
class_: The class.
key: The key.
value: The value.
@ -294,8 +289,10 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
If the field is not required and the value is None, it is useful if the
default value is different from the value.
"""
field = type(inst).model_fields.get(key)
if not field:
if class_.model_computed_fields.get(key):
return True
if not (field := class_.model_fields.get(key)):
return False
if field.is_required():

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, computed_field
from langchain_core.load import Serializable, dumpd, load
from langchain_core.load.serializable import _is_field_useful
@ -114,12 +114,26 @@ def test__is_field_useful() -> None:
)
foo = Foo(x=ArrayObj(), y=NonBoolObj(), z=ArrayObj())
assert _is_field_useful(foo, "x", foo.x)
assert _is_field_useful(foo, "y", foo.y)
assert _is_field_useful(Foo, "x", foo.x)
assert _is_field_useful(Foo, "y", foo.y)
foo = Foo(x=default_x, y=default_y, z=ArrayObj())
assert not _is_field_useful(foo, "x", foo.x)
assert not _is_field_useful(foo, "y", foo.y)
assert not _is_field_useful(Foo, "x", foo.x)
assert not _is_field_useful(Foo, "y", foo.y)
class Bar(Serializable):
x: int
y: int
@computed_field
@property
def z(self) -> int:
return self.x + self.y
bar = Bar(x=1, y=2)
assert _is_field_useful(Bar, "x", bar.x)
assert _is_field_useful(Bar, "y", bar.y)
assert _is_field_useful(Bar, "z", bar.z)
class Foo(Serializable):