mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 23:26:34 +00:00
attempt fix for generation
This commit is contained in:
parent
16e5a12806
commit
07d32a23c8
@ -209,18 +209,15 @@ class Serializable(BaseModel, ABC):
|
||||
if not self.is_lc_serializable():
|
||||
return self.to_json_not_implemented()
|
||||
|
||||
model_fields = type(self).model_fields
|
||||
fields = {**self.__class__.model_fields, **self.__class__.model_computed_fields}
|
||||
secrets = {}
|
||||
# Get latest values for kwargs if there is an attribute with same name
|
||||
lc_kwargs = {}
|
||||
for k, v in self:
|
||||
if not _is_field_useful(self, k, v):
|
||||
continue
|
||||
# Do nothing if the field is excluded
|
||||
if k in model_fields and model_fields[k].exclude:
|
||||
for k, v in self.model_dump().items():
|
||||
if not _is_field_useful(self.__class__, k, v):
|
||||
continue
|
||||
|
||||
lc_kwargs[k] = getattr(self, k, v)
|
||||
lc_kwargs[k] = v
|
||||
|
||||
# Merge the lc_secrets and lc_attributes from every class in the MRO
|
||||
for cls in [None, *self.__class__.mro()]:
|
||||
@ -253,9 +250,7 @@ class Serializable(BaseModel, ABC):
|
||||
# that are not present in the fields.
|
||||
for key in list(secrets):
|
||||
value = secrets[key]
|
||||
if (key in model_fields) and (
|
||||
alias := model_fields[key].alias
|
||||
) is not None:
|
||||
if (key in fields) and (alias := fields[key].alias) is not None:
|
||||
secrets[alias] = value
|
||||
lc_kwargs.update(this.lc_attributes)
|
||||
|
||||
@ -280,11 +275,11 @@ class Serializable(BaseModel, ABC):
|
||||
return to_json_not_implemented(self)
|
||||
|
||||
|
||||
def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
||||
def _is_field_useful(class_: type[Serializable], key: str, value: Any) -> bool:
|
||||
"""Check if a field is useful as a constructor argument.
|
||||
|
||||
Args:
|
||||
inst: The instance.
|
||||
class_: The class.
|
||||
key: The key.
|
||||
value: The value.
|
||||
|
||||
@ -294,8 +289,10 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
||||
If the field is not required and the value is None, it is useful if the
|
||||
default value is different from the value.
|
||||
"""
|
||||
field = type(inst).model_fields.get(key)
|
||||
if not field:
|
||||
if class_.model_computed_fields.get(key):
|
||||
return True
|
||||
|
||||
if not (field := class_.model_fields.get(key)):
|
||||
return False
|
||||
|
||||
if field.is_required():
|
||||
|
@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, computed_field
|
||||
|
||||
from langchain_core.load import Serializable, dumpd, load
|
||||
from langchain_core.load.serializable import _is_field_useful
|
||||
@ -114,12 +114,26 @@ def test__is_field_useful() -> None:
|
||||
)
|
||||
|
||||
foo = Foo(x=ArrayObj(), y=NonBoolObj(), z=ArrayObj())
|
||||
assert _is_field_useful(foo, "x", foo.x)
|
||||
assert _is_field_useful(foo, "y", foo.y)
|
||||
assert _is_field_useful(Foo, "x", foo.x)
|
||||
assert _is_field_useful(Foo, "y", foo.y)
|
||||
|
||||
foo = Foo(x=default_x, y=default_y, z=ArrayObj())
|
||||
assert not _is_field_useful(foo, "x", foo.x)
|
||||
assert not _is_field_useful(foo, "y", foo.y)
|
||||
assert not _is_field_useful(Foo, "x", foo.x)
|
||||
assert not _is_field_useful(Foo, "y", foo.y)
|
||||
|
||||
class Bar(Serializable):
|
||||
x: int
|
||||
y: int
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def z(self) -> int:
|
||||
return self.x + self.y
|
||||
|
||||
bar = Bar(x=1, y=2)
|
||||
assert _is_field_useful(Bar, "x", bar.x)
|
||||
assert _is_field_useful(Bar, "y", bar.y)
|
||||
assert _is_field_useful(Bar, "z", bar.z)
|
||||
|
||||
|
||||
class Foo(Serializable):
|
||||
|
Loading…
Reference in New Issue
Block a user