mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
chore: Simpler Serializable
This commit is contained in:
@@ -13,8 +13,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import NotRequired, override
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -69,39 +68,6 @@ class SerializedNotImplemented(BaseSerialized):
|
||||
repr: Optional[str]
|
||||
|
||||
|
||||
def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
|
||||
"""Try to determine if a value is different from the default.
|
||||
|
||||
Args:
|
||||
value: The value.
|
||||
key: The key.
|
||||
model: The pydantic model.
|
||||
|
||||
Returns:
|
||||
Whether the value is different from the default.
|
||||
|
||||
Raises:
|
||||
Exception: If the key is not in the model.
|
||||
"""
|
||||
field = type(model).model_fields[key]
|
||||
return _try_neq_default(value, field)
|
||||
|
||||
|
||||
def _try_neq_default(value: Any, field: FieldInfo) -> bool:
|
||||
# Handle edge case: inequality of two objects does not evaluate to a bool (e.g. two
|
||||
# Pandas DataFrames).
|
||||
try:
|
||||
return bool(field.get_default() != value)
|
||||
except Exception as _:
|
||||
try:
|
||||
return all(field.get_default() != value)
|
||||
except Exception as _:
|
||||
try:
|
||||
return value is not field.default
|
||||
except Exception as _:
|
||||
return False
|
||||
|
||||
|
||||
class Serializable(BaseModel, ABC):
|
||||
"""Serializable base class.
|
||||
|
||||
@@ -192,14 +158,6 @@ class Serializable(BaseModel, ABC):
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
@override
|
||||
def __repr_args__(self) -> Any:
|
||||
return [
|
||||
(k, v)
|
||||
for k, v in super().__repr_args__()
|
||||
if (k not in type(self).model_fields or try_neq_default(v, k, self))
|
||||
]
|
||||
|
||||
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
|
||||
"""Serialize the object to JSON.
|
||||
|
||||
@@ -270,9 +228,9 @@ class Serializable(BaseModel, ABC):
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": self.lc_id(),
|
||||
"kwargs": lc_kwargs
|
||||
if not secrets
|
||||
else _replace_secrets(lc_kwargs, secrets),
|
||||
"kwargs": (
|
||||
lc_kwargs if not secrets else _replace_secrets(lc_kwargs, secrets)
|
||||
),
|
||||
}
|
||||
|
||||
def to_json_not_implemented(self) -> SerializedNotImplemented:
|
||||
@@ -301,28 +259,13 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
||||
if field.is_required():
|
||||
return True
|
||||
|
||||
# Handle edge case: a value cannot be converted to a boolean (e.g. a
|
||||
# Pandas DataFrame).
|
||||
try:
|
||||
value_is_truthy = bool(value)
|
||||
except Exception as _:
|
||||
value_is_truthy = False
|
||||
|
||||
if value_is_truthy:
|
||||
return True
|
||||
|
||||
# Value is still falsy here!
|
||||
if field.default_factory is dict and isinstance(value, dict):
|
||||
if value is field.default:
|
||||
return False
|
||||
|
||||
# Value is still falsy here!
|
||||
if field.default_factory is list and isinstance(value, list):
|
||||
if field.default_factory is dict and isinstance(value, dict) and not value:
|
||||
return False
|
||||
|
||||
value_neq_default = _try_neq_default(value, field)
|
||||
|
||||
# If value is falsy and does not match the default
|
||||
return value_is_truthy or value_neq_default
|
||||
return not (field.default_factory is list and isinstance(value, list) and not value)
|
||||
|
||||
|
||||
def _replace_secrets(
|
||||
|
||||
Reference in New Issue
Block a user