mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 11:08:55 +00:00
feat(core): add an option to make deserialization more permissive (#32054)
## Description Currently when deserializing objects that contain non-deserializable values, we throw an error. However, there are cases (e.g. proxies that return response fields containing extra fields like Python datetimes), where these values are not important and we just want to drop them. Twitter handle: @hacubu --------- Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
parent
3628dccbf3
commit
535ba43b0d
@ -56,6 +56,8 @@ class Reviver:
|
||||
additional_import_mappings: Optional[
|
||||
dict[tuple[str, ...], tuple[str, ...]]
|
||||
] = None,
|
||||
*,
|
||||
ignore_unserializable_fields: bool = False,
|
||||
) -> None:
|
||||
"""Initialize the reviver.
|
||||
|
||||
@ -70,6 +72,8 @@ class Reviver:
|
||||
additional_import_mappings: A dictionary of additional namespace mappings
|
||||
You can use this to override default mappings or add new mappings.
|
||||
Defaults to None.
|
||||
ignore_unserializable_fields: Whether to ignore unserializable fields.
|
||||
Defaults to False.
|
||||
"""
|
||||
self.secrets_from_env = secrets_from_env
|
||||
self.secrets_map = secrets_map or {}
|
||||
@ -88,6 +92,7 @@ class Reviver:
|
||||
if self.additional_import_mappings
|
||||
else ALL_SERIALIZABLE_MAPPINGS
|
||||
)
|
||||
self.ignore_unserializable_fields = ignore_unserializable_fields
|
||||
|
||||
def __call__(self, value: dict[str, Any]) -> Any:
|
||||
"""Revive the value."""
|
||||
@ -108,6 +113,8 @@ class Reviver:
|
||||
and value.get("type") == "not_implemented"
|
||||
and value.get("id") is not None
|
||||
):
|
||||
if self.ignore_unserializable_fields:
|
||||
return None
|
||||
msg = (
|
||||
"Trying to load an object that doesn't implement "
|
||||
f"serialization: {value}"
|
||||
@ -170,6 +177,7 @@ def loads(
|
||||
valid_namespaces: Optional[list[str]] = None,
|
||||
secrets_from_env: bool = True,
|
||||
additional_import_mappings: Optional[dict[tuple[str, ...], tuple[str, ...]]] = None,
|
||||
ignore_unserializable_fields: bool = False,
|
||||
) -> Any:
|
||||
"""Revive a LangChain class from a JSON string.
|
||||
|
||||
@ -187,6 +195,8 @@ def loads(
|
||||
additional_import_mappings: A dictionary of additional namespace mappings
|
||||
You can use this to override default mappings or add new mappings.
|
||||
Defaults to None.
|
||||
ignore_unserializable_fields: Whether to ignore unserializable fields.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
Revived LangChain objects.
|
||||
@ -194,7 +204,11 @@ def loads(
|
||||
return json.loads(
|
||||
text,
|
||||
object_hook=Reviver(
|
||||
secrets_map, valid_namespaces, secrets_from_env, additional_import_mappings
|
||||
secrets_map,
|
||||
valid_namespaces,
|
||||
secrets_from_env,
|
||||
additional_import_mappings,
|
||||
ignore_unserializable_fields=ignore_unserializable_fields,
|
||||
),
|
||||
)
|
||||
|
||||
@ -207,6 +221,7 @@ def load(
|
||||
valid_namespaces: Optional[list[str]] = None,
|
||||
secrets_from_env: bool = True,
|
||||
additional_import_mappings: Optional[dict[tuple[str, ...], tuple[str, ...]]] = None,
|
||||
ignore_unserializable_fields: bool = False,
|
||||
) -> Any:
|
||||
"""Revive a LangChain class from a JSON object.
|
||||
|
||||
@ -225,12 +240,18 @@ def load(
|
||||
additional_import_mappings: A dictionary of additional namespace mappings
|
||||
You can use this to override default mappings or add new mappings.
|
||||
Defaults to None.
|
||||
ignore_unserializable_fields: Whether to ignore unserializable fields.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
Revived LangChain objects.
|
||||
"""
|
||||
reviver = Reviver(
|
||||
secrets_map, valid_namespaces, secrets_from_env, additional_import_mappings
|
||||
secrets_map,
|
||||
valid_namespaces,
|
||||
secrets_from_env,
|
||||
additional_import_mappings,
|
||||
ignore_unserializable_fields=ignore_unserializable_fields,
|
||||
)
|
||||
|
||||
def _load(obj: Any) -> Any:
|
||||
|
@ -232,3 +232,47 @@ def test_serialization_with_pydantic() -> None:
|
||||
def test_serialization_with_generation() -> None:
|
||||
generation = Generation(text="hello-world")
|
||||
assert dumpd(generation)["kwargs"] == {"text": "hello-world", "type": "Generation"}
|
||||
|
||||
|
||||
def test_serialization_with_ignore_unserializable_fields() -> None:
|
||||
data = {
|
||||
"messages": [
|
||||
[
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": ["langchain", "schema", "messages", "AIMessage"],
|
||||
"kwargs": {
|
||||
"content": "Call tools to get entity details",
|
||||
"response_metadata": {
|
||||
"other_field": "foo",
|
||||
"create_date": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": ["datetime", "datetime"],
|
||||
"repr": "datetime.datetime(2025, 7, 15, 13, 14, 0, 000000, tzinfo=datetime.timezone.utc)", # noqa: E501
|
||||
},
|
||||
},
|
||||
"type": "ai",
|
||||
"id": "00000000-0000-0000-0000-000000000000",
|
||||
},
|
||||
},
|
||||
]
|
||||
]
|
||||
}
|
||||
ser = dumpd(data)
|
||||
deser = load(ser, ignore_unserializable_fields=True)
|
||||
assert deser == {
|
||||
"messages": [
|
||||
[
|
||||
AIMessage(
|
||||
id="00000000-0000-0000-0000-000000000000",
|
||||
content="Call tools to get entity details",
|
||||
response_metadata={
|
||||
"other_field": "foo",
|
||||
"create_date": None,
|
||||
},
|
||||
)
|
||||
]
|
||||
]
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user