mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-16 08:06:14 +00:00
attempt fix for generation
This commit is contained in:
parent
16e5a12806
commit
07d32a23c8
@ -209,18 +209,15 @@ class Serializable(BaseModel, ABC):
|
|||||||
if not self.is_lc_serializable():
|
if not self.is_lc_serializable():
|
||||||
return self.to_json_not_implemented()
|
return self.to_json_not_implemented()
|
||||||
|
|
||||||
model_fields = type(self).model_fields
|
fields = {**self.__class__.model_fields, **self.__class__.model_computed_fields}
|
||||||
secrets = {}
|
secrets = {}
|
||||||
# Get latest values for kwargs if there is an attribute with same name
|
# Get latest values for kwargs if there is an attribute with same name
|
||||||
lc_kwargs = {}
|
lc_kwargs = {}
|
||||||
for k, v in self:
|
for k, v in self.model_dump().items():
|
||||||
if not _is_field_useful(self, k, v):
|
if not _is_field_useful(self.__class__, k, v):
|
||||||
continue
|
|
||||||
# Do nothing if the field is excluded
|
|
||||||
if k in model_fields and model_fields[k].exclude:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
lc_kwargs[k] = getattr(self, k, v)
|
lc_kwargs[k] = v
|
||||||
|
|
||||||
# Merge the lc_secrets and lc_attributes from every class in the MRO
|
# Merge the lc_secrets and lc_attributes from every class in the MRO
|
||||||
for cls in [None, *self.__class__.mro()]:
|
for cls in [None, *self.__class__.mro()]:
|
||||||
@ -253,9 +250,7 @@ class Serializable(BaseModel, ABC):
|
|||||||
# that are not present in the fields.
|
# that are not present in the fields.
|
||||||
for key in list(secrets):
|
for key in list(secrets):
|
||||||
value = secrets[key]
|
value = secrets[key]
|
||||||
if (key in model_fields) and (
|
if (key in fields) and (alias := fields[key].alias) is not None:
|
||||||
alias := model_fields[key].alias
|
|
||||||
) is not None:
|
|
||||||
secrets[alias] = value
|
secrets[alias] = value
|
||||||
lc_kwargs.update(this.lc_attributes)
|
lc_kwargs.update(this.lc_attributes)
|
||||||
|
|
||||||
@ -280,11 +275,11 @@ class Serializable(BaseModel, ABC):
|
|||||||
return to_json_not_implemented(self)
|
return to_json_not_implemented(self)
|
||||||
|
|
||||||
|
|
||||||
def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
def _is_field_useful(class_: type[Serializable], key: str, value: Any) -> bool:
|
||||||
"""Check if a field is useful as a constructor argument.
|
"""Check if a field is useful as a constructor argument.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inst: The instance.
|
class_: The class.
|
||||||
key: The key.
|
key: The key.
|
||||||
value: The value.
|
value: The value.
|
||||||
|
|
||||||
@ -294,8 +289,10 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
|||||||
If the field is not required and the value is None, it is useful if the
|
If the field is not required and the value is None, it is useful if the
|
||||||
default value is different from the value.
|
default value is different from the value.
|
||||||
"""
|
"""
|
||||||
field = type(inst).model_fields.get(key)
|
if class_.model_computed_fields.get(key):
|
||||||
if not field:
|
return True
|
||||||
|
|
||||||
|
if not (field := class_.model_fields.get(key)):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if field.is_required():
|
if field.is_required():
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field, computed_field
|
||||||
|
|
||||||
from langchain_core.load import Serializable, dumpd, load
|
from langchain_core.load import Serializable, dumpd, load
|
||||||
from langchain_core.load.serializable import _is_field_useful
|
from langchain_core.load.serializable import _is_field_useful
|
||||||
@ -114,12 +114,26 @@ def test__is_field_useful() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
foo = Foo(x=ArrayObj(), y=NonBoolObj(), z=ArrayObj())
|
foo = Foo(x=ArrayObj(), y=NonBoolObj(), z=ArrayObj())
|
||||||
assert _is_field_useful(foo, "x", foo.x)
|
assert _is_field_useful(Foo, "x", foo.x)
|
||||||
assert _is_field_useful(foo, "y", foo.y)
|
assert _is_field_useful(Foo, "y", foo.y)
|
||||||
|
|
||||||
foo = Foo(x=default_x, y=default_y, z=ArrayObj())
|
foo = Foo(x=default_x, y=default_y, z=ArrayObj())
|
||||||
assert not _is_field_useful(foo, "x", foo.x)
|
assert not _is_field_useful(Foo, "x", foo.x)
|
||||||
assert not _is_field_useful(foo, "y", foo.y)
|
assert not _is_field_useful(Foo, "y", foo.y)
|
||||||
|
|
||||||
|
class Bar(Serializable):
|
||||||
|
x: int
|
||||||
|
y: int
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def z(self) -> int:
|
||||||
|
return self.x + self.y
|
||||||
|
|
||||||
|
bar = Bar(x=1, y=2)
|
||||||
|
assert _is_field_useful(Bar, "x", bar.x)
|
||||||
|
assert _is_field_useful(Bar, "y", bar.y)
|
||||||
|
assert _is_field_useful(Bar, "z", bar.z)
|
||||||
|
|
||||||
|
|
||||||
class Foo(Serializable):
|
class Foo(Serializable):
|
||||||
|
Loading…
Reference in New Issue
Block a user