fix(core): serialization patch (#34455)

- `allowed_objects` kwarg in `load`
- escape lc-ser formatted dicts on `dump`
- block jinja2 templates during deserialization by default

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
Author: ccurme
Date: 2025-12-22 18:33:31 -05:00 (committed by GitHub)
Commit: 5ec0fa69de (parent: 6a416c6186)
25 changed files with 1769 additions and 116 deletions


@@ -6,7 +6,7 @@ from langchain_core._import_utils import import_attr
if TYPE_CHECKING:
from langchain_core.load.dump import dumpd, dumps
from langchain_core.load.load import loads
from langchain_core.load.load import InitValidator, loads
from langchain_core.load.serializable import Serializable
# Unfortunately, we have to eagerly import load from langchain_core/load/load.py
@@ -15,11 +15,19 @@ if TYPE_CHECKING:
# the `from langchain_core.load.load import load` absolute import should also work.
from langchain_core.load.load import load
__all__ = ("Serializable", "dumpd", "dumps", "load", "loads")
__all__ = (
"InitValidator",
"Serializable",
"dumpd",
"dumps",
"load",
"loads",
)
_dynamic_imports = {
"dumpd": "dump",
"dumps": "dump",
"InitValidator": "load",
"loads": "load",
"Serializable": "serializable",
}
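With these additions, the new `InitValidator` alias is importable from the package root alongside the existing serialization helpers; a minimal sketch of the resulting public surface:

```python
# Everything listed in __all__ above resolves through the lazy import table.
from langchain_core.load import InitValidator, Serializable, dumpd, dumps, load, loads
```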


@@ -0,0 +1,176 @@
"""Validation utilities for LangChain serialization.
Provides escape-based protection against injection attacks in serialized objects. The
approach uses an allowlist design: only dicts explicitly produced by
`Serializable.to_json()` are treated as LC objects during deserialization.
## How escaping works
During serialization, plain dicts (user data) that contain an `'lc'` key are wrapped:
```python
{"lc": 1, ...} # user data that looks like LC object
# becomes:
{"__lc_escaped__": {"lc": 1, ...}}
```
During deserialization, escaped dicts are unwrapped and returned as plain dicts,
NOT instantiated as LC objects.
"""
from typing import Any
_LC_ESCAPED_KEY = "__lc_escaped__"
"""Sentinel key used to mark escaped user dicts during serialization.
When a plain dict contains an `'lc'` key (which could be confused with LC objects),
we wrap it as {"__lc_escaped__": {...original...}}.
"""
def _needs_escaping(obj: dict[str, Any]) -> bool:
"""Check if a dict needs escaping to prevent confusion with LC objects.
A dict needs escaping if:
1. It has an `'lc'` key (could be confused with LC serialization format)
2. It has only the escape key (would be mistaken for an escaped dict)
"""
return "lc" in obj or (len(obj) == 1 and _LC_ESCAPED_KEY in obj)
def _escape_dict(obj: dict[str, Any]) -> dict[str, Any]:
"""Wrap a dict in the escape marker.
Example:
```python
{"key": "value"} # becomes {"__lc_escaped__": {"key": "value"}}
```
"""
return {_LC_ESCAPED_KEY: obj}
def _is_escaped_dict(obj: dict[str, Any]) -> bool:
"""Check if a dict is an escaped user dict.
Example:
```python
{"__lc_escaped__": {...}} # is an escaped dict
```
"""
return len(obj) == 1 and _LC_ESCAPED_KEY in obj
def _serialize_value(obj: Any) -> Any:
"""Serialize a value with escaping of user dicts.
Called recursively on kwarg values to escape any plain dicts that could be confused
with LC objects.
Args:
obj: The value to serialize.
Returns:
The serialized value with user dicts escaped as needed.
"""
from langchain_core.load.serializable import ( # noqa: PLC0415
Serializable,
to_json_not_implemented,
)
if isinstance(obj, Serializable):
# This is an LC object - serialize it properly (not escaped)
return _serialize_lc_object(obj)
if isinstance(obj, dict):
if not all(isinstance(k, (str, int, float, bool, type(None))) for k in obj):
# if keys are not json serializable
return to_json_not_implemented(obj)
# Check if dict needs escaping BEFORE recursing into values.
# If it needs escaping, wrap it as-is - the contents are user data that
# will be returned as-is during deserialization (no instantiation).
# This prevents re-escaping of already-escaped nested content.
if _needs_escaping(obj):
return _escape_dict(obj)
# Safe dict (no 'lc' key) - recurse into values
return {k: _serialize_value(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [_serialize_value(item) for item in obj]
if isinstance(obj, (str, int, float, bool, type(None))):
return obj
# Non-JSON-serializable object (datetime, custom objects, etc.)
return to_json_not_implemented(obj)
def _is_lc_secret(obj: Any) -> bool:
"""Check if an object is a LangChain secret marker."""
expected_num_keys = 3
return (
isinstance(obj, dict)
and obj.get("lc") == 1
and obj.get("type") == "secret"
and "id" in obj
and len(obj) == expected_num_keys
)
def _serialize_lc_object(obj: Any) -> dict[str, Any]:
"""Serialize a `Serializable` object with escaping of user data in kwargs.
Args:
obj: The `Serializable` object to serialize.
Returns:
The serialized dict with user data in kwargs escaped as needed.
Note:
Kwargs values are processed with `_serialize_value` to escape user data (like
metadata) that contains `'lc'` keys. Secret fields (from `lc_secrets`) are
skipped because `to_json()` replaces their values with secret markers.
"""
from langchain_core.load.serializable import Serializable # noqa: PLC0415
if not isinstance(obj, Serializable):
msg = f"Expected Serializable, got {type(obj)}"
raise TypeError(msg)
serialized: dict[str, Any] = dict(obj.to_json())
# Process kwargs to escape user data that could be confused with LC objects
# Skip secret fields - to_json() already converted them to secret markers
if serialized.get("type") == "constructor" and "kwargs" in serialized:
serialized["kwargs"] = {
k: v if _is_lc_secret(v) else _serialize_value(v)
for k, v in serialized["kwargs"].items()
}
return serialized
def _unescape_value(obj: Any) -> Any:
"""Unescape a value, processing escape markers in dict values and lists.
When an escaped dict is encountered (`{"__lc_escaped__": ...}`), it's
unwrapped and the contents are returned AS-IS (no further processing).
The contents represent user data that should not be modified.
For regular dicts and lists, we recurse to find any nested escape markers.
Args:
obj: The value to unescape.
Returns:
The unescaped value.
"""
if isinstance(obj, dict):
if _is_escaped_dict(obj):
# Unwrap and return the user data as-is (no further unescaping).
# The contents are user data that may contain more escape keys,
# but those are part of the user's actual data.
return obj[_LC_ESCAPED_KEY]
# Regular dict - recurse into values to find nested escape markers
return {k: _unescape_value(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_unescape_value(item) for item in obj]
return obj
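A minimal sketch of the round trip these helpers implement. The underscore-prefixed names are internal to `langchain_core.load._validation` and are exercised here purely for illustration:

```python
from langchain_core.load._validation import (
    _escape_dict,
    _is_escaped_dict,
    _needs_escaping,
    _unescape_value,
)

# User data that happens to look like the LC serialization format.
user_data = {"lc": 1, "type": "secret", "id": ["SOME_ENV_VAR"]}

# Serialization side: the dict is flagged and wrapped in the escape marker.
assert _needs_escaping(user_data)
escaped = _escape_dict(user_data)
assert escaped == {"__lc_escaped__": user_data}

# Deserialization side: the wrapper is detected and removed, and the dict
# comes back as plain data rather than being treated as an LC secret.
assert _is_escaped_dict(escaped)
assert _unescape_value(escaped) == user_data
```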


@@ -1,10 +1,26 @@
"""Dump objects to json."""
"""Serialize LangChain objects to JSON.
Provides `dumps` (to JSON string) and `dumpd` (to dict) for serializing
`Serializable` objects.
## Escaping
During serialization, plain dicts (user data) that contain an `'lc'` key are escaped
by wrapping them: `{"__lc_escaped__": {...original...}}`. This prevents injection
attacks where malicious data could trick the deserializer into instantiating
arbitrary classes. The escape marker is removed during deserialization.
This is an allowlist approach: only dicts explicitly produced by
`Serializable.to_json()` are treated as LC objects; everything else is escaped if it
could be confused with the LC format.
"""
import json
from typing import Any
from pydantic import BaseModel
from langchain_core.load._validation import _serialize_value
from langchain_core.load.serializable import Serializable, to_json_not_implemented
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration
@@ -25,6 +41,20 @@ def default(obj: Any) -> Any:
def _dump_pydantic_models(obj: Any) -> Any:
"""Convert nested Pydantic models to dicts for JSON serialization.
Handles the special case where a `ChatGeneration` contains an `AIMessage`
with a parsed Pydantic model in `additional_kwargs["parsed"]`. Since
Pydantic models aren't directly JSON serializable, this converts them to
dicts.
Args:
obj: The object to process.
Returns:
A copy of the object with nested Pydantic models converted to dicts, or
the original object unchanged if no conversion was needed.
"""
if (
isinstance(obj, ChatGeneration)
and isinstance(obj.message, AIMessage)
@@ -40,10 +70,17 @@ def _dump_pydantic_models(obj: Any) -> Any:
def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
"""Return a JSON string representation of an object.
Note:
Plain dicts containing an `'lc'` key are automatically escaped to prevent
confusion with LC serialization format. The escape marker is removed during
deserialization.
Args:
obj: The object to dump.
pretty: Whether to pretty print the json. If `True`, the json will be
indented with 2 spaces (if no indent is provided as part of `kwargs`).
pretty: Whether to pretty print the json.
If `True`, the json will be indented by either 2 spaces or the amount
provided in the `indent` kwarg.
**kwargs: Additional arguments to pass to `json.dumps`
Returns:
@@ -55,28 +92,29 @@ def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
if "default" in kwargs:
msg = "`default` should not be passed to dumps"
raise ValueError(msg)
try:
obj = _dump_pydantic_models(obj)
if pretty:
indent = kwargs.pop("indent", 2)
return json.dumps(obj, default=default, indent=indent, **kwargs)
return json.dumps(obj, default=default, **kwargs)
except TypeError:
if pretty:
indent = kwargs.pop("indent", 2)
return json.dumps(to_json_not_implemented(obj), indent=indent, **kwargs)
return json.dumps(to_json_not_implemented(obj), **kwargs)
obj = _dump_pydantic_models(obj)
serialized = _serialize_value(obj)
if pretty:
indent = kwargs.pop("indent", 2)
return json.dumps(serialized, indent=indent, **kwargs)
return json.dumps(serialized, **kwargs)
def dumpd(obj: Any) -> Any:
"""Return a dict representation of an object.
Note:
Plain dicts containing an `'lc'` key are automatically escaped to prevent
confusion with LC serialization format. The escape marker is removed during
deserialization.
Args:
obj: The object to dump.
Returns:
Dictionary that can be serialized to json using `json.dumps`.
"""
# Unfortunately this function is not as efficient as it could be because it first
# dumps the object to a json string and then loads it back into a dictionary.
return json.loads(dumps(obj))
obj = _dump_pydantic_models(obj)
return _serialize_value(obj)
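A short sketch of the escaping `dumpd` now applies to user data, mirroring the round-trip behavior exercised by the tests added in this commit:

```python
from langchain_core.documents import Document
from langchain_core.load import dumpd, load

# Metadata that looks like the LC serialization format is plain user data here.
doc = Document(page_content="test", metadata={"lc": 1, "type": "constructor"})

serialized = dumpd(doc)
# The suspicious dict is wrapped rather than emitted bare.
assert serialized["kwargs"]["metadata"] == {
    "__lc_escaped__": {"lc": 1, "type": "constructor"}
}

# On load, the wrapper is stripped and the metadata returns as a plain dict.
loaded = load(serialized, allowed_objects=[Document])
assert loaded.metadata == {"lc": 1, "type": "constructor"}
```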


@@ -1,16 +1,83 @@
"""Load LangChain objects from JSON strings or objects.
!!! warning
`load` and `loads` are vulnerable to remote code execution. Never use with untrusted
input.
## How it works
Each `Serializable` LangChain object has a unique identifier (its "class path"), which
is a list of strings representing the module path and class name. For example:
- `AIMessage` -> `["langchain_core", "messages", "ai", "AIMessage"]`
- `ChatPromptTemplate` -> `["langchain_core", "prompts", "chat", "ChatPromptTemplate"]`
When deserializing, the class path from the JSON `'id'` field is checked against an
allowlist. If the class is not in the allowlist, deserialization raises a `ValueError`.
## Security model
The `allowed_objects` parameter controls which classes can be deserialized:
- **`'core'` (default)**: Allow classes defined in the serialization mappings for
langchain_core.
- **`'all'`**: Allow classes defined in the serialization mappings. This
includes core LangChain types (messages, prompts, documents, etc.) and trusted
partner integrations. See `langchain_core.load.mapping` for the full list.
- **Explicit list of classes**: Only those specific classes are allowed.
For simple data types like messages and documents, the default allowlist is safe to use.
These classes do not perform side effects during initialization.
!!! note "Side effects in allowed classes"
Deserialization calls `__init__` on allowed classes. If those classes perform side
effects during initialization (network calls, file operations, etc.), those side
effects will occur. The allowlist prevents instantiation of classes outside the
allowlist, but does not sandbox the allowed classes themselves.
Import paths are also validated against trusted namespaces before any module is
imported.
### Injection protection (escape-based)
During serialization, plain dicts that contain an `'lc'` key are escaped by wrapping
them: `{"__lc_escaped__": {...}}`. During deserialization, escaped dicts are unwrapped
and returned as plain dicts, NOT instantiated as LC objects.
This is an allowlist approach: only dicts explicitly produced by
`Serializable.to_json()` (which are NOT escaped) are treated as LC objects;
everything else is user data.
Even if an attacker's payload includes `__lc_escaped__` wrappers, it will be unwrapped
to plain dicts and NOT instantiated as malicious objects.
## Examples
```python
from langchain_core.load import load
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import AIMessage, HumanMessage
# Use default allowlist (classes from mappings) - recommended
obj = load(data)
# Allow only specific classes (most restrictive)
obj = load(
data,
allowed_objects=[
ChatPromptTemplate,
AIMessage,
HumanMessage,
],
)
```
"""
import importlib
import json
import os
from typing import Any
from collections.abc import Callable, Iterable
from typing import Any, Literal, cast
from langchain_core._api import beta
from langchain_core.load._validation import _is_escaped_dict, _unescape_value
from langchain_core.load.mapping import (
_JS_SERIALIZABLE_MAPPING,
_OG_SERIALIZABLE_MAPPING,
@@ -49,34 +116,209 @@ ALL_SERIALIZABLE_MAPPINGS = {
**_JS_SERIALIZABLE_MAPPING,
}
# Cache for the default allowed class paths computed from mappings
# Maps mode ("all" or "core") to the cached set of paths
_default_class_paths_cache: dict[str, set[tuple[str, ...]]] = {}
def _get_default_allowed_class_paths(
allowed_object_mode: Literal["all", "core"],
) -> set[tuple[str, ...]]:
"""Get the default allowed class paths from the serialization mappings.
This uses the mappings as the source of truth for what classes are allowed
by default. Both the legacy paths (keys) and current paths (values) are included.
Args:
allowed_object_mode: either `'all'` or `'core'`.
Returns:
Set of class path tuples that are allowed by default.
"""
if allowed_object_mode in _default_class_paths_cache:
return _default_class_paths_cache[allowed_object_mode]
allowed_paths: set[tuple[str, ...]] = set()
for key, value in ALL_SERIALIZABLE_MAPPINGS.items():
if allowed_object_mode == "core" and value[0] != "langchain_core":
continue
allowed_paths.add(key)
allowed_paths.add(value)
_default_class_paths_cache[allowed_object_mode] = allowed_paths
return _default_class_paths_cache[allowed_object_mode]
def _block_jinja2_templates(
class_path: tuple[str, ...],
kwargs: dict[str, Any],
) -> None:
"""Block jinja2 templates during deserialization for security.
Jinja2 templates can execute arbitrary code, so they are blocked by default when
deserializing objects with `template_format='jinja2'`.
Note:
We intentionally do NOT check the `class_path` here to keep this simple and
future-proof. If any new class is added that accepts `template_format='jinja2'`,
it will be automatically blocked without needing to update this function.
Args:
class_path: The class path tuple being deserialized (unused).
kwargs: The kwargs dict for the class constructor.
Raises:
ValueError: If `template_format` is `'jinja2'`.
"""
_ = class_path # Unused - see docstring for rationale. Kept to satisfy signature.
if kwargs.get("template_format") == "jinja2":
msg = (
"Jinja2 templates are not allowed during deserialization for security "
"reasons. Use 'f-string' template format instead, or explicitly allow "
"jinja2 by providing a custom init_validator."
)
raise ValueError(msg)
def default_init_validator(
class_path: tuple[str, ...],
kwargs: dict[str, Any],
) -> None:
"""Default init validator that blocks jinja2 templates.
This is the default validator used by `load()` and `loads()` when no custom
validator is provided.
Args:
class_path: The class path tuple being deserialized.
kwargs: The kwargs dict for the class constructor.
Raises:
ValueError: If template_format is `'jinja2'`.
"""
_block_jinja2_templates(class_path, kwargs)
AllowedObject = type[Serializable]
"""Type alias for classes that can be included in the `allowed_objects` parameter.
Must be a `Serializable` subclass (the class itself, not an instance).
"""
InitValidator = Callable[[tuple[str, ...], dict[str, Any]], None]
"""Type alias for a callable that validates kwargs during deserialization.
The callable receives:
- `class_path`: A tuple of strings identifying the class being instantiated
(e.g., `('langchain', 'schema', 'messages', 'AIMessage')`).
- `kwargs`: The kwargs dict that will be passed to the constructor.
The validator should raise an exception if the object should not be deserialized.
"""
def _compute_allowed_class_paths(
allowed_objects: Iterable[AllowedObject],
import_mappings: dict[tuple[str, ...], tuple[str, ...]],
) -> set[tuple[str, ...]]:
"""Return allowed class paths from an explicit list of classes.
A class path is a tuple of strings identifying a serializable class, derived from
`Serializable.lc_id()`. For example: `('langchain_core', 'messages', 'AIMessage')`.
Args:
allowed_objects: Iterable of `Serializable` subclasses to allow.
import_mappings: Mapping of legacy class paths to current class paths.
Returns:
Set of allowed class paths.
Example:
```python
# Allow a specific class
_compute_allowed_class_paths([MyPrompt], {}) ->
{("langchain_core", "prompts", "MyPrompt")}
# Include legacy paths that map to the same class
import_mappings = {("old", "Prompt"): ("langchain_core", "prompts", "MyPrompt")}
_compute_allowed_class_paths([MyPrompt], import_mappings) ->
{("langchain_core", "prompts", "MyPrompt"), ("old", "Prompt")}
```
"""
allowed_objects_list = list(allowed_objects)
allowed_class_paths: set[tuple[str, ...]] = set()
for allowed_obj in allowed_objects_list:
if not isinstance(allowed_obj, type) or not issubclass(
allowed_obj, Serializable
):
msg = "allowed_objects must contain Serializable subclasses."
raise TypeError(msg)
class_path = tuple(allowed_obj.lc_id())
allowed_class_paths.add(class_path)
# Add legacy paths that map to the same class.
for mapping_key, mapping_value in import_mappings.items():
if tuple(mapping_value) == class_path:
allowed_class_paths.add(mapping_key)
return allowed_class_paths
class Reviver:
"""Reviver for JSON objects."""
"""Reviver for JSON objects.
Used as the `object_hook` for `json.loads` to reconstruct LangChain objects from
their serialized JSON representation.
Only classes in the allowlist can be instantiated.
"""
def __init__(
self,
allowed_objects: Iterable[AllowedObject] | Literal["all", "core"] = "core",
secrets_map: dict[str, str] | None = None,
valid_namespaces: list[str] | None = None,
secrets_from_env: bool = True, # noqa: FBT001,FBT002
secrets_from_env: bool = False, # noqa: FBT001,FBT002
additional_import_mappings: dict[tuple[str, ...], tuple[str, ...]]
| None = None,
*,
ignore_unserializable_fields: bool = False,
init_validator: InitValidator | None = default_init_validator,
) -> None:
"""Initialize the reviver.
Args:
secrets_map: A map of secrets to load.
allowed_objects: Allowlist of classes that can be deserialized.
- `'core'` (default): Allow classes defined in the serialization
mappings for `langchain_core`.
- `'all'`: Allow classes defined in the serialization mappings.
This includes core LangChain types (messages, prompts, documents,
etc.) and trusted partner integrations. See
`langchain_core.load.mapping` for the full list.
- Explicit list of classes: Only those specific classes are allowed.
secrets_map: A map of secrets to load.
If a secret is not found in the map, it will be loaded from the
environment if `secrets_from_env` is `True`.
valid_namespaces: A list of additional namespaces (modules)
to allow to be deserialized.
valid_namespaces: Additional namespaces (modules) to allow during
deserialization, beyond the default trusted namespaces.
secrets_from_env: Whether to load secrets from the environment.
additional_import_mappings: A dictionary of additional namespace mappings
additional_import_mappings: A dictionary of additional namespace mappings.
You can use this to override default mappings or add new mappings.
When `allowed_objects` is `'all'` or `'core'` (using the defaults), paths from
these mappings are also added to the allowed class paths.
ignore_unserializable_fields: Whether to ignore unserializable fields.
init_validator: Optional callable to validate kwargs before instantiation.
If provided, this function is called with `(class_path, kwargs)` where
`class_path` is the class path tuple and `kwargs` is the kwargs dict.
The validator should raise an exception if the object should not be
deserialized, otherwise return `None`.
Defaults to `default_init_validator` which blocks jinja2 templates.
"""
self.secrets_from_env = secrets_from_env
self.secrets_map = secrets_map or {}
@@ -95,7 +337,26 @@ class Reviver:
if self.additional_import_mappings
else ALL_SERIALIZABLE_MAPPINGS
)
# Compute allowed class paths:
# - "all" -> use default paths from mappings (+ additional_import_mappings)
# - Explicit list -> compute from those classes
if allowed_objects in ("all", "core"):
self.allowed_class_paths: set[tuple[str, ...]] | None = (
_get_default_allowed_class_paths(
cast("Literal['all', 'core']", allowed_objects)
).copy()
)
# Add paths from additional_import_mappings to the defaults
if self.additional_import_mappings:
for key, value in self.additional_import_mappings.items():
self.allowed_class_paths.add(key)
self.allowed_class_paths.add(value)
else:
self.allowed_class_paths = _compute_allowed_class_paths(
cast("Iterable[AllowedObject]", allowed_objects), self.import_mappings
)
self.ignore_unserializable_fields = ignore_unserializable_fields
self.init_validator = init_validator
def __call__(self, value: dict[str, Any]) -> Any:
"""Revive the value.
@@ -146,6 +407,20 @@ class Reviver:
[*namespace, name] = value["id"]
mapping_key = tuple(value["id"])
if (
self.allowed_class_paths is not None
and mapping_key not in self.allowed_class_paths
):
msg = (
f"Deserialization of {mapping_key!r} is not allowed. "
"The default (allowed_objects='core') only permits core "
"langchain-core classes. To allow trusted partner integrations, "
"use allowed_objects='all'. Alternatively, pass an explicit list "
"of allowed classes via allowed_objects=[...]. "
"See langchain_core.load.mapping for the full allowlist."
)
raise ValueError(msg)
if (
namespace[0] not in self.valid_namespaces
# The root namespace ["langchain"] is not a valid identifier.
@@ -153,13 +428,11 @@ class Reviver:
):
msg = f"Invalid namespace: {value}"
raise ValueError(msg)
# Has explicit import path.
# Determine explicit import path
if mapping_key in self.import_mappings:
import_path = self.import_mappings[mapping_key]
# Split into module and name
import_dir, name = import_path[:-1], import_path[-1]
# Import module
mod = importlib.import_module(".".join(import_dir))
elif namespace[0] in DISALLOW_LOAD_FROM_PATH:
msg = (
"Trying to deserialize something that cannot "
@@ -167,9 +440,16 @@ class Reviver:
f"{mapping_key}."
)
raise ValueError(msg)
# Otherwise, treat namespace as path.
else:
mod = importlib.import_module(".".join(namespace))
# Otherwise, treat namespace as path.
import_dir = namespace
# Validate import path is in trusted namespaces before importing
if import_dir[0] not in self.valid_namespaces:
msg = f"Invalid namespace: {value}"
raise ValueError(msg)
mod = importlib.import_module(".".join(import_dir))
cls = getattr(mod, name)
@@ -181,6 +461,10 @@ class Reviver:
# We don't need to recurse on kwargs
# as json.loads will do that for us.
kwargs = value.get("kwargs", {})
if self.init_validator is not None:
self.init_validator(mapping_key, kwargs)
return cls(**kwargs)
return value
@@ -190,46 +474,74 @@ class Reviver:
def loads(
text: str,
*,
allowed_objects: Iterable[AllowedObject] | Literal["all", "core"] = "core",
secrets_map: dict[str, str] | None = None,
valid_namespaces: list[str] | None = None,
secrets_from_env: bool = True,
secrets_from_env: bool = False,
additional_import_mappings: dict[tuple[str, ...], tuple[str, ...]] | None = None,
ignore_unserializable_fields: bool = False,
init_validator: InitValidator | None = default_init_validator,
) -> Any:
"""Revive a LangChain class from a JSON string.
!!! warning
This function is vulnerable to remote code execution. Never use with untrusted
input.
Equivalent to `load(json.loads(text))`.
Only classes in the allowlist can be instantiated. The default allowlist includes
core LangChain types (messages, prompts, documents, etc.). See
`langchain_core.load.mapping` for the full list.
Args:
text: The string to load.
allowed_objects: Allowlist of classes that can be deserialized.
- `'core'` (default): Allow classes defined in the serialization mappings
for langchain_core.
- `'all'`: Allow classes defined in the serialization mappings.
This includes core LangChain types (messages, prompts, documents, etc.)
and trusted partner integrations. See `langchain_core.load.mapping` for
the full list.
- Explicit list of classes: Only those specific classes are allowed.
- `[]`: Disallow all deserialization (will raise on any object).
secrets_map: A map of secrets to load.
If a secret is not found in the map, it will be loaded from the environment
if `secrets_from_env` is `True`.
valid_namespaces: A list of additional namespaces (modules)
to allow to be deserialized.
valid_namespaces: Additional namespaces (modules) to allow during
deserialization, beyond the default trusted namespaces.
secrets_from_env: Whether to load secrets from the environment.
additional_import_mappings: A dictionary of additional namespace mappings
additional_import_mappings: A dictionary of additional namespace mappings.
You can use this to override default mappings or add new mappings.
When `allowed_objects` is `'all'` or `'core'` (using the defaults), paths from
these mappings are also added to the allowed class paths.
ignore_unserializable_fields: Whether to ignore unserializable fields.
init_validator: Optional callable to validate kwargs before instantiation.
If provided, this function is called with `(class_path, kwargs)` where
`class_path` is the class path tuple and `kwargs` is the kwargs dict.
The validator should raise an exception if the object should not be
deserialized, otherwise return `None`. Defaults to
`default_init_validator` which blocks jinja2 templates.
Returns:
Revived LangChain objects.
Raises:
ValueError: If an object's class path is not in the `allowed_objects` allowlist.
"""
return json.loads(
text,
object_hook=Reviver(
secrets_map,
valid_namespaces,
secrets_from_env,
additional_import_mappings,
ignore_unserializable_fields=ignore_unserializable_fields,
),
# Parse JSON and delegate to load() for proper escape handling
raw_obj = json.loads(text)
return load(
raw_obj,
allowed_objects=allowed_objects,
secrets_map=secrets_map,
valid_namespaces=valid_namespaces,
secrets_from_env=secrets_from_env,
additional_import_mappings=additional_import_mappings,
ignore_unserializable_fields=ignore_unserializable_fields,
init_validator=init_validator,
)
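A minimal string round trip through `dumps`/`loads`, matching the behavior covered by the new `test_allowed_objects` test (the `Document` class is used only to demonstrate the rejection path):

```python
from langchain_core.documents import Document
from langchain_core.load import dumps, loads
from langchain_core.messages import AIMessage

msg = AIMessage(content="foo")
text = dumps(msg)

# The default allowlist ("core") covers langchain_core classes like AIMessage.
assert loads(text) == msg
assert loads(text, allowed_objects=[AIMessage]) == msg

# A class path outside the allowlist raises instead of instantiating.
try:
    loads(text, allowed_objects=[Document])
except ValueError as err:
    assert "not allowed" in str(err)
```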
@@ -237,49 +549,105 @@ def loads(
def load(
obj: Any,
*,
allowed_objects: Iterable[AllowedObject] | Literal["all", "core"] = "core",
secrets_map: dict[str, str] | None = None,
valid_namespaces: list[str] | None = None,
secrets_from_env: bool = True,
secrets_from_env: bool = False,
additional_import_mappings: dict[tuple[str, ...], tuple[str, ...]] | None = None,
ignore_unserializable_fields: bool = False,
init_validator: InitValidator | None = default_init_validator,
) -> Any:
"""Revive a LangChain class from a JSON object.
!!! warning
This function is vulnerable to remote code execution. Never use with untrusted
input.
Use this if you already have a parsed JSON object, eg. from `json.load` or
`orjson.loads`.
Use this if you already have a parsed JSON object,
e.g. from `json.load` or `orjson.loads`.
Only classes in the allowlist can be instantiated. The default allowlist includes
core LangChain types (messages, prompts, documents, etc.). See
`langchain_core.load.mapping` for the full list.
Args:
obj: The object to load.
allowed_objects: Allowlist of classes that can be deserialized.
- `'core'` (default): Allow classes defined in the serialization mappings
for langchain_core.
- `'all'`: Allow classes defined in the serialization mappings.
This includes core LangChain types (messages, prompts, documents, etc.)
and trusted partner integrations. See `langchain_core.load.mapping` for
the full list.
- Explicit list of classes: Only those specific classes are allowed.
- `[]`: Disallow all deserialization (will raise on any object).
secrets_map: A map of secrets to load.
If a secret is not found in the map, it will be loaded from the environment
if `secrets_from_env` is `True`.
valid_namespaces: A list of additional namespaces (modules)
to allow to be deserialized.
valid_namespaces: Additional namespaces (modules) to allow during
deserialization, beyond the default trusted namespaces.
secrets_from_env: Whether to load secrets from the environment.
additional_import_mappings: A dictionary of additional namespace mappings
additional_import_mappings: A dictionary of additional namespace mappings.
You can use this to override default mappings or add new mappings.
When `allowed_objects` is `'all'` or `'core'` (using the defaults), paths from
these mappings are also added to the allowed class paths.
ignore_unserializable_fields: Whether to ignore unserializable fields.
init_validator: Optional callable to validate kwargs before instantiation.
If provided, this function is called with `(class_path, kwargs)` where
`class_path` is the class path tuple and `kwargs` is the kwargs dict.
The validator should raise an exception if the object should not be
deserialized, otherwise return `None`. Defaults to
`default_init_validator` which blocks jinja2 templates.
Returns:
Revived LangChain objects.
Raises:
ValueError: If an object's class path is not in the `allowed_objects` allowlist.
Example:
```python
from langchain_core.load import load, dumpd
from langchain_core.messages import AIMessage
msg = AIMessage(content="Hello")
data = dumpd(msg)
# Deserialize using default allowlist
loaded = load(data)
# Or with explicit allowlist
loaded = load(data, allowed_objects=[AIMessage])
# Or extend defaults with additional mappings
loaded = load(
data,
additional_import_mappings={
("my_pkg", "MyClass"): ("my_pkg", "module", "MyClass"),
},
)
```
"""
reviver = Reviver(
allowed_objects,
secrets_map,
valid_namespaces,
secrets_from_env,
additional_import_mappings,
ignore_unserializable_fields=ignore_unserializable_fields,
init_validator=init_validator,
)
def _load(obj: Any) -> Any:
if isinstance(obj, dict):
# Need to revive leaf nodes before reviving this node
# Check for escaped dict FIRST (before recursing).
# Escaped dicts are user data that should NOT be processed as LC objects.
if _is_escaped_dict(obj):
return _unescape_value(obj)
# Not escaped - recurse into children then apply reviver
loaded_obj = {k: _load(v) for k, v in obj.items()}
return reviver(loaded_obj)
if isinstance(obj, list):


@@ -1,21 +1,19 @@
"""Serialization mapping.
This file contains a mapping between the lc_namespace path for a given
subclass that implements from Serializable to the namespace
This file contains a mapping between the `lc_namespace` path for a given
subclass that inherits from `Serializable` to the namespace
where that class is actually located.
This mapping helps maintain the ability to serialize and deserialize
well-known LangChain objects even if they are moved around in the codebase
across different LangChain versions.
For example,
For example, the code for the `AIMessage` class is located in
`langchain_core.messages.ai.AIMessage`. This message is associated with the
`lc_namespace` of `["langchain", "schema", "messages", "AIMessage"]`,
because this code was originally in `langchain.schema.messages.AIMessage`.
The code for AIMessage class is located in langchain_core.messages.ai.AIMessage,
This message is associated with the lc_namespace
["langchain", "schema", "messages", "AIMessage"],
because this code was originally in langchain.schema.messages.AIMessage.
The mapping allows us to deserialize an AIMessage created with an older
The mapping allows us to deserialize an `AIMessage` created with an older
version of LangChain where the code was in a different location.
"""
@@ -275,6 +273,11 @@ SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
"chat_models",
"ChatGroq",
),
("langchain_xai", "chat_models", "ChatXAI"): (
"langchain_xai",
"chat_models",
"ChatXAI",
),
("langchain", "chat_models", "fireworks", "ChatFireworks"): (
"langchain_fireworks",
"chat_models",
@@ -529,16 +532,6 @@ SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
"structured",
"StructuredPrompt",
),
("langchain_sambanova", "chat_models", "ChatSambaNovaCloud"): (
"langchain_sambanova",
"chat_models",
"ChatSambaNovaCloud",
),
("langchain_sambanova", "chat_models", "ChatSambaStudio"): (
"langchain_sambanova",
"chat_models",
"ChatSambaStudio",
),
("langchain_core", "prompts", "message", "_DictMessagePromptTemplate"): (
"langchain_core",
"prompts",


@@ -539,7 +539,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
# Get the input messages
inputs = load(run.inputs)
inputs = load(run.inputs, allowed_objects="all")
input_messages = self._get_input_messages(inputs)
# If historic messages were prepended to the input messages, remove them to
# avoid adding duplicate messages to history.
@@ -548,7 +548,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
input_messages = input_messages[len(historic_messages) :]
# Get the output messages
output_val = load(run.outputs)
output_val = load(run.outputs, allowed_objects="all")
output_messages = self._get_output_messages(output_val)
hist.add_messages(input_messages + output_messages)
@@ -556,7 +556,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
# Get the input messages
inputs = load(run.inputs)
inputs = load(run.inputs, allowed_objects="all")
input_messages = self._get_input_messages(inputs)
# If historic messages were prepended to the input messages, remove them to
# avoid adding duplicate messages to history.
@@ -565,7 +565,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
input_messages = input_messages[len(historic_messages) :]
# Get the output messages
output_val = load(run.outputs)
output_val = load(run.outputs, allowed_objects="all")
output_messages = self._get_output_messages(output_val)
await hist.aadd_messages(input_messages + output_messages)


@@ -563,7 +563,7 @@ def _get_standardized_inputs(
)
raise NotImplementedError(msg)
inputs = load(run.inputs)
inputs = load(run.inputs, allowed_objects="all")
if run.run_type in {"retriever", "llm", "chat_model"}:
return inputs
@@ -595,7 +595,7 @@ def _get_standardized_outputs(
Returns:
An output if returned, otherwise a None
"""
outputs = load(run.outputs)
outputs = load(run.outputs, allowed_objects="all")
if schema_format == "original":
if run.run_type == "prompt" and "output" in outputs:
# These were previously dumped before the tracer.


@@ -528,7 +528,7 @@ class InMemoryVectorStore(VectorStore):
"""
path_: Path = Path(path)
with path_.open("r", encoding="utf-8") as f:
store = load(json.load(f))
store = load(json.load(f), allowed_objects=[Document])
vectorstore = cls(embedding=embedding, **kwargs)
vectorstore.store = store
return vectorstore


@@ -1,6 +1,13 @@
from langchain_core.load import __all__
EXPECTED_ALL = ["dumpd", "dumps", "load", "loads", "Serializable"]
EXPECTED_ALL = [
"InitValidator",
"Serializable",
"dumpd",
"dumps",
"load",
"loads",
]
def test_all_imports() -> None:


@@ -0,0 +1,431 @@
"""Tests for secret injection prevention in serialization.
Verify that user-provided data containing secret-like structures cannot be used to
extract environment variables during deserialization.
"""
import json
import os
import re
from typing import Any
from unittest import mock
import pytest
from pydantic import BaseModel
from langchain_core.documents import Document
from langchain_core.load import dumpd, dumps, load
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.outputs import ChatGeneration
SENTINEL_ENV_VAR = "TEST_SECRET_INJECTION_VAR"
"""Sentinel value that should NEVER appear in serialized output."""
SENTINEL_VALUE = "LEAKED_SECRET_MEOW_12345"
"""Sentinel value that should NEVER appear in serialized output."""
MALICIOUS_SECRET_DICT: dict[str, Any] = {
"lc": 1,
"type": "secret",
"id": [SENTINEL_ENV_VAR],
}
"""The malicious secret-like dict that tries to read the env var"""
@pytest.fixture(autouse=True)
def _set_sentinel_env_var() -> Any:
"""Set the sentinel env var for all tests in this module."""
with mock.patch.dict(os.environ, {SENTINEL_ENV_VAR: SENTINEL_VALUE}):
yield
def _assert_no_secret_leak(payload: Any) -> None:
"""Assert that serializing/deserializing payload doesn't leak the secret."""
# First serialize
serialized = dumps(payload)
# Deserialize with secrets_from_env=True (the dangerous setting)
deserialized = load(serialized, secrets_from_env=True)
# Re-serialize to string
reserialized = dumps(deserialized)
assert SENTINEL_VALUE not in reserialized, (
f"Secret was leaked! Found '{SENTINEL_VALUE}' in output.\n"
f"Original payload type: {type(payload)}\n"
f"Reserialized output: {reserialized[:500]}..."
)
assert SENTINEL_VALUE not in repr(deserialized), (
f"Secret was leaked in deserialized object! Found '{SENTINEL_VALUE}'.\n"
f"Deserialized: {deserialized!r}"
)
class TestSerializableTopLevel:
"""Tests with `Serializable` objects at the top level."""
def test_human_message_with_secret_in_content(self) -> None:
"""`HumanMessage` with secret-like dict in `content`."""
msg = HumanMessage(
content=[
{"type": "text", "text": "Hello"},
{"type": "text", "text": MALICIOUS_SECRET_DICT},
]
)
_assert_no_secret_leak(msg)
def test_human_message_with_secret_in_additional_kwargs(self) -> None:
"""`HumanMessage` with secret-like dict in `additional_kwargs`."""
msg = HumanMessage(
content="Hello",
additional_kwargs={"data": MALICIOUS_SECRET_DICT},
)
_assert_no_secret_leak(msg)
def test_human_message_with_secret_in_nested_additional_kwargs(self) -> None:
"""`HumanMessage` with secret-like dict nested in `additional_kwargs`."""
msg = HumanMessage(
content="Hello",
additional_kwargs={"nested": {"deep": MALICIOUS_SECRET_DICT}},
)
_assert_no_secret_leak(msg)
def test_human_message_with_secret_in_list_in_additional_kwargs(self) -> None:
"""`HumanMessage` with secret-like dict in a list in `additional_kwargs`."""
msg = HumanMessage(
content="Hello",
additional_kwargs={"items": [MALICIOUS_SECRET_DICT]},
)
_assert_no_secret_leak(msg)
def test_ai_message_with_secret_in_response_metadata(self) -> None:
"""`AIMessage` with secret-like dict in respo`nse_metadata."""
msg = AIMessage(
content="Hello",
response_metadata={"data": MALICIOUS_SECRET_DICT},
)
_assert_no_secret_leak(msg)
def test_document_with_secret_in_metadata(self) -> None:
"""Document with secret-like dict in `metadata`."""
doc = Document(
page_content="Hello",
metadata={"data": MALICIOUS_SECRET_DICT},
)
_assert_no_secret_leak(doc)
def test_nested_serializable_with_secret(self) -> None:
"""`AIMessage` containing `dumpd(HumanMessage)` with secret in kwargs."""
inner = HumanMessage(
content="Hello",
additional_kwargs={"secret": MALICIOUS_SECRET_DICT},
)
outer = AIMessage(
content="Outer",
additional_kwargs={"nested": [dumpd(inner)]},
)
_assert_no_secret_leak(outer)
class TestDictTopLevel:
"""Tests with plain dicts at the top level."""
def test_dict_with_serializable_containing_secret(self) -> None:
"""Dict containing a `Serializable` with secret-like dict."""
msg = HumanMessage(
content="Hello",
additional_kwargs={"data": MALICIOUS_SECRET_DICT},
)
payload = {"message": msg}
_assert_no_secret_leak(payload)
def test_dict_with_secret_no_serializable(self) -> None:
"""Dict with secret-like dict, no `Serializable` objects."""
payload = {"data": MALICIOUS_SECRET_DICT}
_assert_no_secret_leak(payload)
def test_dict_with_nested_secret_no_serializable(self) -> None:
"""Dict with nested secret-like dict, no `Serializable` objects."""
payload = {"outer": {"inner": MALICIOUS_SECRET_DICT}}
_assert_no_secret_leak(payload)
def test_dict_with_secret_in_list(self) -> None:
"""Dict with secret-like dict in a list."""
payload = {"items": [MALICIOUS_SECRET_DICT]}
_assert_no_secret_leak(payload)
def test_dict_mimicking_lc_constructor_with_secret(self) -> None:
"""Dict that looks like an LC constructor containing a secret."""
payload = {
"lc": 1,
"type": "constructor",
"id": ["langchain_core", "messages", "ai", "AIMessage"],
"kwargs": {
"content": "Hello",
"additional_kwargs": {"secret": MALICIOUS_SECRET_DICT},
},
}
_assert_no_secret_leak(payload)
class TestPydanticModelTopLevel:
"""Tests with Pydantic models (non-`Serializable`) at the top level."""
def test_pydantic_model_with_serializable_containing_secret(self) -> None:
"""Pydantic model containing a `Serializable` with secret-like dict."""
class MyModel(BaseModel):
message: Any
msg = HumanMessage(
content="Hello",
additional_kwargs={"data": MALICIOUS_SECRET_DICT},
)
payload = MyModel(message=msg)
_assert_no_secret_leak(payload)
def test_pydantic_model_with_secret_dict(self) -> None:
"""Pydantic model containing a secret-like dict directly."""
class MyModel(BaseModel):
data: dict[str, Any]
payload = MyModel(data=MALICIOUS_SECRET_DICT)
_assert_no_secret_leak(payload)
# Test treatment of "parsed" in additional_kwargs
msg = AIMessage(content=[], additional_kwargs={"parsed": payload})
gen = ChatGeneration(message=msg)
_assert_no_secret_leak(gen)
round_trip = load(dumpd(gen))
assert MyModel(**(round_trip.message.additional_kwargs["parsed"])) == payload
def test_pydantic_model_with_nested_secret(self) -> None:
"""Pydantic model with nested secret-like dict."""
class MyModel(BaseModel):
nested: dict[str, Any]
payload = MyModel(nested={"inner": MALICIOUS_SECRET_DICT})
_assert_no_secret_leak(payload)
class TestNonSerializableClassTopLevel:
"""Tests with classes at the top level."""
def test_custom_class_with_serializable_containing_secret(self) -> None:
"""Custom class containing a `Serializable` with secret-like dict."""
class MyClass:
def __init__(self, message: Any) -> None:
self.message = message
msg = HumanMessage(
content="Hello",
additional_kwargs={"data": MALICIOUS_SECRET_DICT},
)
payload = MyClass(message=msg)
# This will serialize as not_implemented, but let's verify no leak
_assert_no_secret_leak(payload)
def test_custom_class_with_secret_dict(self) -> None:
"""Custom class containing a secret-like dict directly."""
class MyClass:
def __init__(self, data: dict[str, Any]) -> None:
self.data = data
payload = MyClass(data=MALICIOUS_SECRET_DICT)
_assert_no_secret_leak(payload)
class TestDumpdInKwargs:
"""Tests for the specific pattern of `dumpd()` result stored in kwargs."""
def test_dumpd_human_message_in_ai_message_kwargs(self) -> None:
"""`AIMessage` with `dumpd(HumanMessage)` in `additional_kwargs`."""
h = HumanMessage("Hello")
a = AIMessage("foo", additional_kwargs={"bar": [dumpd(h)]})
_assert_no_secret_leak(a)
def test_dumpd_human_message_with_secret_in_ai_message_kwargs(self) -> None:
"""`AIMessage` with `dumpd(HumanMessage w/ secret)` in `additional_kwargs`."""
h = HumanMessage(
"Hello",
additional_kwargs={"secret": MALICIOUS_SECRET_DICT},
)
a = AIMessage("foo", additional_kwargs={"bar": [dumpd(h)]})
_assert_no_secret_leak(a)
def test_double_dumpd_nesting(self) -> None:
"""Double nesting: `dumpd(AIMessage(dumpd(HumanMessage)))`."""
h = HumanMessage(
"Hello",
additional_kwargs={"secret": MALICIOUS_SECRET_DICT},
)
a = AIMessage("foo", additional_kwargs={"bar": [dumpd(h)]})
outer = AIMessage("outer", additional_kwargs={"nested": [dumpd(a)]})
_assert_no_secret_leak(outer)
class TestRoundTrip:
"""Tests that verify round-trip serialization preserves data structure."""
def test_human_message_with_secret_round_trip(self) -> None:
"""Verify secret-like dict is preserved as dict after round-trip."""
msg = HumanMessage(
content="Hello",
additional_kwargs={"data": MALICIOUS_SECRET_DICT},
)
serialized = dumpd(msg)
deserialized = load(serialized, secrets_from_env=True)
# The secret-like dict should be preserved as a plain dict
assert deserialized.additional_kwargs["data"] == MALICIOUS_SECRET_DICT
assert isinstance(deserialized.additional_kwargs["data"], dict)
def test_document_with_secret_round_trip(self) -> None:
"""Verify secret-like dict in `Document` metadata is preserved."""
doc = Document(
page_content="Hello",
metadata={"data": MALICIOUS_SECRET_DICT},
)
serialized = dumpd(doc)
deserialized = load(
serialized, secrets_from_env=True, allowed_objects=[Document]
)
# The secret-like dict should be preserved as a plain dict
assert deserialized.metadata["data"] == MALICIOUS_SECRET_DICT
assert isinstance(deserialized.metadata["data"], dict)
def test_plain_dict_with_secret_round_trip(self) -> None:
"""Verify secret-like dict in plain dict is preserved."""
payload = {"data": MALICIOUS_SECRET_DICT}
serialized = dumpd(payload)
deserialized = load(serialized, secrets_from_env=True)
# The secret-like dict should be preserved as a plain dict
assert deserialized["data"] == MALICIOUS_SECRET_DICT
assert isinstance(deserialized["data"], dict)
class TestEscapingEfficiency:
"""Tests that escaping doesn't cause excessive nesting."""
def test_no_triple_escaping(self) -> None:
"""Verify dumpd doesn't cause triple/multiple escaping."""
h = HumanMessage(
"Hello",
additional_kwargs={"bar": [MALICIOUS_SECRET_DICT]},
)
a = AIMessage("foo", additional_kwargs={"bar": [dumpd(h)]})
d = dumpd(a)
serialized = json.dumps(d)
# Count nested escape markers -
# should be max 2 (one for HumanMessage, one for secret)
# Not 3+ which would indicate re-escaping of already-escaped content
escape_count = len(re.findall(r"__lc_escaped__", serialized))
# The HumanMessage dict gets escaped (1), the secret inside gets escaped (1)
# Total should be 2, not 4 (which would mean triple nesting)
assert escape_count <= 2, (
f"Found {escape_count} escape markers, expected <= 2. "
f"This indicates unnecessary re-escaping.\n{serialized}"
)
def test_double_nesting_no_quadruple_escape(self) -> None:
"""Verify double dumpd nesting doesn't explode escape markers."""
h = HumanMessage(
"Hello",
additional_kwargs={"secret": MALICIOUS_SECRET_DICT},
)
a = AIMessage("middle", additional_kwargs={"nested": [dumpd(h)]})
outer = AIMessage("outer", additional_kwargs={"deep": [dumpd(a)]})
d = dumpd(outer)
serialized = json.dumps(d)
escape_count = len(re.findall(r"__lc_escaped__", serialized))
# Should be:
# outer escapes middle (1),
# middle escapes h (1),
# h escapes secret (1) = 3
# Not 6+ which would indicate re-escaping
assert escape_count <= 3, (
f"Found {escape_count} escape markers, expected <= 3. "
f"This indicates unnecessary re-escaping."
)
class TestConstructorInjection:
"""Tests for constructor-type injection (not just secrets)."""
def test_constructor_in_metadata_not_instantiated(self) -> None:
"""Verify constructor-like dict in metadata is not instantiated."""
malicious_constructor = {
"lc": 1,
"type": "constructor",
"id": ["langchain_core", "messages", "ai", "AIMessage"],
"kwargs": {"content": "injected"},
}
doc = Document(
page_content="Hello",
metadata={"data": malicious_constructor},
)
serialized = dumpd(doc)
deserialized = load(
serialized,
secrets_from_env=True,
allowed_objects=[Document, AIMessage],
)
# The constructor-like dict should be a plain dict, NOT an AIMessage
assert isinstance(deserialized.metadata["data"], dict)
assert deserialized.metadata["data"] == malicious_constructor
def test_constructor_in_content_not_instantiated(self) -> None:
"""Verify constructor-like dict in message content is not instantiated."""
malicious_constructor = {
"lc": 1,
"type": "constructor",
"id": ["langchain_core", "messages", "human", "HumanMessage"],
"kwargs": {"content": "injected"},
}
msg = AIMessage(
content="Hello",
additional_kwargs={"nested": malicious_constructor},
)
serialized = dumpd(msg)
deserialized = load(
serialized,
secrets_from_env=True,
allowed_objects=[AIMessage, HumanMessage],
)
# The constructor-like dict should be a plain dict, NOT a HumanMessage
assert isinstance(deserialized.additional_kwargs["nested"], dict)
assert deserialized.additional_kwargs["nested"] == malicious_constructor
def test_allowed_objects() -> None:
# Core object
msg = AIMessage(content="foo")
serialized = dumpd(msg)
assert load(serialized) == msg
assert load(serialized, allowed_objects=[AIMessage]) == msg
assert load(serialized, allowed_objects="core") == msg
with pytest.raises(ValueError, match="not allowed"):
load(serialized, allowed_objects=[])
with pytest.raises(ValueError, match="not allowed"):
load(serialized, allowed_objects=[Document])


@@ -1,12 +1,19 @@
import json
from typing import Any
import pytest
from pydantic import BaseModel, ConfigDict, Field, SecretStr
from langchain_core.load import Serializable, dumpd, dumps, load
from langchain_core.documents import Document
from langchain_core.load import InitValidator, Serializable, dumpd, dumps, load, loads
from langchain_core.load.serializable import _is_field_useful
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
)
class NonBoolObj:
@@ -145,10 +152,17 @@ def test_simple_deserialization() -> None:
"lc": 1,
"type": "constructor",
}
new_foo = load(serialized_foo, valid_namespaces=["tests"])
new_foo = load(serialized_foo, allowed_objects=[Foo], valid_namespaces=["tests"])
assert new_foo == foo
def test_disallowed_deserialization() -> None:
foo = Foo(bar=1, baz="hello")
serialized_foo = dumpd(foo)
with pytest.raises(ValueError, match="not allowed"):
load(serialized_foo, allowed_objects=[], valid_namespaces=["tests"])
class Foo2(Serializable):
bar: int
baz: str
@@ -170,6 +184,7 @@ def test_simple_deserialization_with_additional_imports() -> None:
}
new_foo = load(
serialized_foo,
allowed_objects=[Foo2],
valid_namespaces=["tests"],
additional_import_mappings={
("tests", "unit_tests", "load", "test_serializable", "Foo"): (
@@ -223,7 +238,7 @@ def test_serialization_with_pydantic() -> None:
)
)
ser = dumpd(llm_response)
deser = load(ser)
deser = load(ser, allowed_objects=[ChatGeneration, AIMessage])
assert isinstance(deser, ChatGeneration)
assert deser.message.content
assert deser.message.additional_kwargs["parsed"] == my_model.model_dump()
@@ -260,8 +275,8 @@ def test_serialization_with_ignore_unserializable_fields() -> None:
]
]
}
ser = dumpd(data)
deser = load(ser, ignore_unserializable_fields=True)
# Load directly (no dumpd - this is already serialized data)
deser = load(data, allowed_objects=[AIMessage], ignore_unserializable_fields=True)
assert deser == {
"messages": [
[
@@ -365,3 +380,514 @@ def test_dumps_mixed_data_structure() -> None:
# Primitives should remain unchanged
assert parsed["list"] == [1, 2, {"nested": "value"}]
assert parsed["primitive"] == "string"
def test_document_normal_metadata_allowed() -> None:
"""Test that `Document` metadata without `'lc'` key works fine."""
doc = Document(
page_content="Hello world",
metadata={"source": "test.txt", "page": 1, "nested": {"key": "value"}},
)
serialized = dumpd(doc)
loaded = load(serialized, allowed_objects=[Document])
assert loaded.page_content == "Hello world"
expected = {"source": "test.txt", "page": 1, "nested": {"key": "value"}}
assert loaded.metadata == expected
class TestEscaping:
"""Tests that escape-based serialization prevents injection attacks.
When user data contains an `'lc'` key, it's escaped during serialization
(wrapped in `{"__lc_escaped__": ...}`). During deserialization, escaped
dicts are unwrapped and returned as plain dicts - NOT instantiated as
LC objects.
"""
def test_document_metadata_with_lc_key_escaped(self) -> None:
"""Test that `Document` metadata with `'lc'` key round-trips as plain dict."""
# User data that looks like an LC constructor - should be escaped, not executed
suspicious_metadata = {"lc": 1, "type": "constructor", "id": ["some", "module"]}
doc = Document(page_content="test", metadata=suspicious_metadata)
# Serialize - should escape the metadata
serialized = dumpd(doc)
assert serialized["kwargs"]["metadata"] == {
"__lc_escaped__": suspicious_metadata
}
# Deserialize - should restore original metadata as plain dict
loaded = load(serialized, allowed_objects=[Document])
assert loaded.metadata == suspicious_metadata # Plain dict, not instantiated
def test_document_metadata_with_nested_lc_key_escaped(self) -> None:
"""Test that nested `'lc'` key in `Document` metadata is escaped."""
suspicious_nested = {"lc": 1, "type": "constructor", "id": ["some", "module"]}
doc = Document(page_content="test", metadata={"nested": suspicious_nested})
serialized = dumpd(doc)
# The nested dict with 'lc' key should be escaped
assert serialized["kwargs"]["metadata"]["nested"] == {
"__lc_escaped__": suspicious_nested
}
loaded = load(serialized, allowed_objects=[Document])
assert loaded.metadata == {"nested": suspicious_nested}
def test_document_metadata_with_lc_key_in_list_escaped(self) -> None:
"""Test that `'lc'` key in list items within `Document` metadata is escaped."""
suspicious_item = {"lc": 1, "type": "constructor", "id": ["some", "module"]}
doc = Document(page_content="test", metadata={"items": [suspicious_item]})
serialized = dumpd(doc)
assert serialized["kwargs"]["metadata"]["items"][0] == {
"__lc_escaped__": suspicious_item
}
loaded = load(serialized, allowed_objects=[Document])
assert loaded.metadata == {"items": [suspicious_item]}
def test_malicious_payload_not_instantiated(self) -> None:
"""Test that malicious LC-like structures in user data are NOT instantiated."""
# An attacker might craft a payload with a valid AIMessage structure in metadata
malicious_data = {
"lc": 1,
"type": "constructor",
"id": ["langchain", "schema", "document", "Document"],
"kwargs": {
"page_content": "test",
"metadata": {
# This looks like a valid LC object but is in escaped form
"__lc_escaped__": {
"lc": 1,
"type": "constructor",
"id": ["langchain_core", "messages", "ai", "AIMessage"],
"kwargs": {"content": "injected message"},
}
},
},
}
# Even though AIMessage is allowed, the metadata should remain as dict
loaded = load(malicious_data, allowed_objects=[Document, AIMessage])
assert loaded.page_content == "test"
# The metadata is the original dict (unescaped), NOT an AIMessage instance
assert loaded.metadata == {
"lc": 1,
"type": "constructor",
"id": ["langchain_core", "messages", "ai", "AIMessage"],
"kwargs": {"content": "injected message"},
}
assert not isinstance(loaded.metadata, AIMessage)
def test_message_additional_kwargs_with_lc_key_escaped(self) -> None:
"""Test that `AIMessage` `additional_kwargs` with `'lc'` is escaped."""
suspicious_data = {"lc": 1, "type": "constructor", "id": ["x", "y"]}
msg = AIMessage(
content="Hello",
additional_kwargs={"data": suspicious_data},
)
serialized = dumpd(msg)
assert serialized["kwargs"]["additional_kwargs"]["data"] == {
"__lc_escaped__": suspicious_data
}
loaded = load(serialized, allowed_objects=[AIMessage])
assert loaded.additional_kwargs == {"data": suspicious_data}
def test_message_response_metadata_with_lc_key_escaped(self) -> None:
"""Test that `AIMessage` `response_metadata` with `'lc'` is escaped."""
suspicious_data = {"lc": 1, "type": "constructor", "id": ["x", "y"]}
msg = AIMessage(content="Hello", response_metadata=suspicious_data)
serialized = dumpd(msg)
assert serialized["kwargs"]["response_metadata"] == {
"__lc_escaped__": suspicious_data
}
loaded = load(serialized, allowed_objects=[AIMessage])
assert loaded.response_metadata == suspicious_data
def test_double_escape_handling(self) -> None:
"""Test that data containing escape key itself is properly handled."""
# User data that contains our escape key
data_with_escape_key = {"__lc_escaped__": "some_value"}
doc = Document(page_content="test", metadata=data_with_escape_key)
serialized = dumpd(doc)
# Should be double-escaped since it looks like an escaped dict
assert serialized["kwargs"]["metadata"] == {
"__lc_escaped__": {"__lc_escaped__": "some_value"}
}
loaded = load(serialized, allowed_objects=[Document])
assert loaded.metadata == {"__lc_escaped__": "some_value"}
class TestDumpdEscapesLcKeyInPlainDicts:
"""Tests that `dumpd()` escapes `'lc'` keys in plain dict kwargs."""
def test_normal_message_not_escaped(self) -> None:
"""Test that normal `AIMessage` without `'lc'` key is not escaped."""
msg = AIMessage(
content="Hello",
additional_kwargs={"tool_calls": []},
response_metadata={"model": "gpt-4"},
)
serialized = dumpd(msg)
assert serialized["kwargs"]["content"] == "Hello"
# No escape wrappers for normal data
assert "__lc_escaped__" not in str(serialized)
def test_document_metadata_with_lc_key_escaped(self) -> None:
"""Test that `Document` with `'lc'` key in metadata is escaped."""
doc = Document(
page_content="test",
metadata={"lc": 1, "type": "constructor"},
)
serialized = dumpd(doc)
# Should be escaped, not blocked
assert serialized["kwargs"]["metadata"] == {
"__lc_escaped__": {"lc": 1, "type": "constructor"}
}
def test_document_metadata_with_nested_lc_key_escaped(self) -> None:
"""Test that `Document` with nested `'lc'` in metadata is escaped."""
doc = Document(
page_content="test",
metadata={"nested": {"lc": 1}},
)
serialized = dumpd(doc)
assert serialized["kwargs"]["metadata"]["nested"] == {
"__lc_escaped__": {"lc": 1}
}
def test_message_additional_kwargs_with_lc_key_escaped(self) -> None:
"""Test `AIMessage` with `'lc'` in `additional_kwargs` is escaped."""
msg = AIMessage(
content="Hello",
additional_kwargs={"malicious": {"lc": 1}},
)
serialized = dumpd(msg)
assert serialized["kwargs"]["additional_kwargs"]["malicious"] == {
"__lc_escaped__": {"lc": 1}
}
def test_message_response_metadata_with_lc_key_escaped(self) -> None:
"""Test `AIMessage` with `'lc'` in `response_metadata` is escaped."""
msg = AIMessage(
content="Hello",
response_metadata={"lc": 1},
)
serialized = dumpd(msg)
assert serialized["kwargs"]["response_metadata"] == {
"__lc_escaped__": {"lc": 1}
}
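
The tests above pin down the escaping round trip this patch is meant to guarantee: `dumpd()` wraps user dicts that merely look like LC serialization, and `load()` unwraps them without ever instantiating them. A minimal sketch of that round trip, using only the public `langchain_core.documents.Document` and `langchain_core.load` entry points already exercised in this suite (the `user_data` and `payload` names are illustrative):

```python
from langchain_core.documents import Document
from langchain_core.load import dumpd, load

# A plain dict that merely *looks* like the LC serialization format.
user_data = {"lc": 1, "type": "constructor", "id": ["not", "real"]}

doc = Document(page_content="example", metadata={"payload": user_data})
serialized = dumpd(doc)

# dumpd wraps the suspicious dict instead of leaving it bare.
assert serialized["kwargs"]["metadata"]["payload"] == {"__lc_escaped__": user_data}

# On load, the wrapper is removed and the value stays a plain dict,
# so it is never instantiated as an LC object.
restored = load(serialized, allowed_objects=[Document])
assert restored.metadata == {"payload": user_data}
```
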
class TestInitValidator:
"""Tests for `init_validator` on `load()` and `loads()`."""
def test_init_validator_allows_valid_kwargs(self) -> None:
"""Test that `init_validator` returning None allows deserialization."""
msg = AIMessage(content="Hello")
serialized = dumpd(msg)
def allow_all(_class_path: tuple[str, ...], _kwargs: dict[str, Any]) -> None:
pass # Allow all by doing nothing
loaded = load(serialized, allowed_objects=[AIMessage], init_validator=allow_all)
assert loaded == msg
def test_init_validator_blocks_deserialization(self) -> None:
"""Test that `init_validator` can block deserialization by raising."""
doc = Document(page_content="test", metadata={"source": "test.txt"})
serialized = dumpd(doc)
def block_metadata(
_class_path: tuple[str, ...], kwargs: dict[str, Any]
) -> None:
if "metadata" in kwargs:
msg = "Metadata not allowed"
raise ValueError(msg)
with pytest.raises(ValueError, match="Metadata not allowed"):
load(serialized, allowed_objects=[Document], init_validator=block_metadata)
def test_init_validator_receives_correct_class_path(self) -> None:
"""Test that `init_validator` receives the correct class path."""
msg = AIMessage(content="Hello")
serialized = dumpd(msg)
received_class_paths: list[tuple[str, ...]] = []
def capture_class_path(
class_path: tuple[str, ...], _kwargs: dict[str, Any]
) -> None:
received_class_paths.append(class_path)
load(serialized, allowed_objects=[AIMessage], init_validator=capture_class_path)
assert len(received_class_paths) == 1
assert received_class_paths[0] == (
"langchain",
"schema",
"messages",
"AIMessage",
)
def test_init_validator_receives_correct_kwargs(self) -> None:
"""Test that `init_validator` receives the kwargs dict."""
msg = AIMessage(content="Hello world", name="test_name")
serialized = dumpd(msg)
received_kwargs: list[dict[str, Any]] = []
def capture_kwargs(
_class_path: tuple[str, ...], kwargs: dict[str, Any]
) -> None:
received_kwargs.append(kwargs)
load(serialized, allowed_objects=[AIMessage], init_validator=capture_kwargs)
assert len(received_kwargs) == 1
assert "content" in received_kwargs[0]
assert received_kwargs[0]["content"] == "Hello world"
assert "name" in received_kwargs[0]
assert received_kwargs[0]["name"] == "test_name"
def test_init_validator_with_loads(self) -> None:
"""Test that `init_validator` works with `loads()` function."""
doc = Document(page_content="test", metadata={"key": "value"})
json_str = dumps(doc)
def block_metadata(
_class_path: tuple[str, ...], kwargs: dict[str, Any]
) -> None:
if "metadata" in kwargs:
msg = "Metadata not allowed"
raise ValueError(msg)
with pytest.raises(ValueError, match="Metadata not allowed"):
loads(json_str, allowed_objects=[Document], init_validator=block_metadata)
def test_init_validator_none_allows_all(self) -> None:
"""Test that `init_validator=None` (default) allows all kwargs."""
msg = AIMessage(content="Hello")
serialized = dumpd(msg)
# Should work without init_validator
loaded = load(serialized, allowed_objects=[AIMessage])
assert loaded == msg
def test_init_validator_type_alias_exists(self) -> None:
"""Test that `InitValidator` type alias is exported and usable."""
def my_validator(_class_path: tuple[str, ...], _kwargs: dict[str, Any]) -> None:
pass
validator_typed: InitValidator = my_validator
assert callable(validator_typed)
def test_init_validator_blocks_specific_class(self) -> None:
"""Test blocking deserialization for a specific class."""
doc = Document(page_content="test", metadata={"source": "test.txt"})
serialized = dumpd(doc)
def block_documents(
class_path: tuple[str, ...], _kwargs: dict[str, Any]
) -> None:
if class_path == ("langchain", "schema", "document", "Document"):
msg = "Documents not allowed"
raise ValueError(msg)
with pytest.raises(ValueError, match="Documents not allowed"):
load(serialized, allowed_objects=[Document], init_validator=block_documents)
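
The validator hook composes with `allowed_objects`: even for classes on the allowlist, the callback can veto an individual constructor call by raising. A small sketch under the same `(class_path, kwargs)` signature the tests exercise; the `fstring_only` name and its f-string-only policy are illustrative, not part of the patch:

```python
from typing import Any

from langchain_core.load import InitValidator, dumpd, load
from langchain_core.prompts.prompt import PromptTemplate


def fstring_only(_class_path: tuple[str, ...], kwargs: dict[str, Any]) -> None:
    # Defense in depth: refuse any template format other than f-string,
    # even for classes that are otherwise allowed.
    if kwargs.get("template_format", "f-string") != "f-string":
        msg = "Only f-string templates may be deserialized"
        raise ValueError(msg)


validator: InitValidator = fstring_only

prompt = PromptTemplate.from_template("Hello {name}")
restored = load(dumpd(prompt), allowed_objects=[PromptTemplate], init_validator=validator)
assert restored.template == "Hello {name}"
```
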
class TestJinja2SecurityBlocking:
"""Tests blocking Jinja2 templates by default."""
def test_fstring_template_allowed(self) -> None:
"""Test that f-string templates deserialize successfully."""
# Serialized ChatPromptTemplate with f-string format
serialized = {
"lc": 1,
"type": "constructor",
"id": ["langchain", "prompts", "chat", "ChatPromptTemplate"],
"kwargs": {
"input_variables": ["name"],
"messages": [
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"chat",
"HumanMessagePromptTemplate",
],
"kwargs": {
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate",
],
"kwargs": {
"input_variables": ["name"],
"template": "Hello {name}",
"template_format": "f-string",
},
}
},
}
],
},
}
# f-string should deserialize successfully
loaded = load(
serialized,
allowed_objects=[
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
],
)
assert isinstance(loaded, ChatPromptTemplate)
assert loaded.input_variables == ["name"]
def test_jinja2_template_blocked(self) -> None:
"""Test that Jinja2 templates are blocked by default."""
# Malicious serialized payload attempting to use jinja2
malicious_serialized = {
"lc": 1,
"type": "constructor",
"id": ["langchain", "prompts", "chat", "ChatPromptTemplate"],
"kwargs": {
"input_variables": ["name"],
"messages": [
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"chat",
"HumanMessagePromptTemplate",
],
"kwargs": {
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate",
],
"kwargs": {
"input_variables": ["name"],
"template": "{{ name }}",
"template_format": "jinja2",
},
}
},
}
],
},
}
# jinja2 should be blocked by default
with pytest.raises(ValueError, match="Jinja2 templates are not allowed"):
load(
malicious_serialized,
allowed_objects=[
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
],
)
def test_jinja2_blocked_standalone_prompt_template(self) -> None:
"""Test blocking Jinja2 on standalone `PromptTemplate`."""
serialized_jinja2 = {
"lc": 1,
"type": "constructor",
"id": ["langchain", "prompts", "prompt", "PromptTemplate"],
"kwargs": {
"input_variables": ["name"],
"template": "{{ name }}",
"template_format": "jinja2",
},
}
serialized_fstring = {
"lc": 1,
"type": "constructor",
"id": ["langchain", "prompts", "prompt", "PromptTemplate"],
"kwargs": {
"input_variables": ["name"],
"template": "{name}",
"template_format": "f-string",
},
}
# f-string should work
loaded = load(
serialized_fstring,
allowed_objects=[PromptTemplate],
)
assert isinstance(loaded, PromptTemplate)
assert loaded.template == "{name}"
# jinja2 should be blocked by default
with pytest.raises(ValueError, match="Jinja2 templates are not allowed"):
load(
serialized_jinja2,
allowed_objects=[PromptTemplate],
)
def test_jinja2_blocked_by_default(self) -> None:
"""Test that Jinja2 templates are blocked by default."""
serialized_jinja2 = {
"lc": 1,
"type": "constructor",
"id": ["langchain", "prompts", "prompt", "PromptTemplate"],
"kwargs": {
"input_variables": ["name"],
"template": "{{ name }}",
"template_format": "jinja2",
},
}
serialized_fstring = {
"lc": 1,
"type": "constructor",
"id": ["langchain", "prompts", "prompt", "PromptTemplate"],
"kwargs": {
"input_variables": ["name"],
"template": "{name}",
"template_format": "f-string",
},
}
# f-string should work
loaded = load(serialized_fstring, allowed_objects=[PromptTemplate])
assert isinstance(loaded, PromptTemplate)
assert loaded.template == "{name}"
# jinja2 should be blocked by default
with pytest.raises(ValueError, match="Jinja2 templates are not allowed"):
load(serialized_jinja2, allowed_objects=[PromptTemplate])
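
At an application boundary the same rejection can be handled rather than propagated. A sketch reusing the jinja2 payload shape from the tests above; the `untrusted_payload` name and the try/except handling are illustrative:

```python
from langchain_core.load import load
from langchain_core.prompts.prompt import PromptTemplate

untrusted_payload = {
    "lc": 1,
    "type": "constructor",
    "id": ["langchain", "prompts", "prompt", "PromptTemplate"],
    "kwargs": {
        "input_variables": ["name"],
        "template": "{{ name }}",
        "template_format": "jinja2",
    },
}

try:
    load(untrusted_payload, allowed_objects=[PromptTemplate])
except ValueError as exc:
    # Jinja2 templates are rejected during deserialization by default.
    print(f"rejected: {exc}")
```
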

View File

@@ -47,7 +47,7 @@ def test_serdes_message() -> None:
}
actual = dumpd(msg)
assert actual == expected
assert load(actual) == msg
assert load(actual, allowed_objects=[AIMessage]) == msg
def test_serdes_message_chunk() -> None:
@@ -102,7 +102,7 @@ def test_serdes_message_chunk() -> None:
}
actual = dumpd(chunk)
assert actual == expected
assert load(actual) == chunk
assert load(actual, allowed_objects=[AIMessageChunk]) == chunk
def test_add_usage_both_none() -> None:

View File

@@ -1123,7 +1123,7 @@ def test_data_prompt_template_deserializable() -> None:
)
]
)
)
),
)

View File

@@ -31,4 +31,4 @@ def test_deserialize_legacy() -> None:
expected = DictPromptTemplate(
template={"type": "audio", "audio": "{audio_data}"}, template_format="f-string"
)
assert load(ser) == expected
assert load(ser, allowed_objects=[DictPromptTemplate]) == expected

View File

@@ -11,7 +11,7 @@ def test_image_prompt_template_deserializable() -> None:
ChatPromptTemplate.from_messages(
[("system", [{"type": "image", "image_url": "{img}"}])]
)
)
),
)
@@ -105,5 +105,5 @@ def test_image_prompt_template_deserializable_old() -> None:
"input_variables": ["img", "input"],
},
}
)
),
)

View File

@@ -82,7 +82,6 @@ def test_structured_prompt_dict() -> None:
assert loads(dumps(prompt)).model_dump() == prompt.model_dump()
chain = loads(dumps(prompt)) | model
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}

View File

@@ -925,7 +925,7 @@ def test_tool_message_serdes() -> None:
},
}
assert dumpd(message) == ser_message
assert load(dumpd(message)) == message
assert load(dumpd(message), allowed_objects=[ToolMessage]) == message
class BadObject:
@@ -954,7 +954,7 @@ def test_tool_message_ser_non_serializable() -> None:
}
assert dumpd(message) == ser_message
with pytest.raises(NotImplementedError):
load(dumpd(ser_message))
load(dumpd(message), allowed_objects=[ToolMessage])
def test_tool_message_to_dict() -> None:

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import json
from collections.abc import Sequence
from typing import Any
from typing import Any, Literal
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
@@ -139,7 +139,8 @@ def pull(
if hasattr(client, "pull_repo"):
# >= 0.1.15
res_dict = client.pull_repo(owner_repo_commit)
obj = loads(json.dumps(res_dict["manifest"]))
allowed_objects: Literal["all", "core"] = "all" if include_model else "core"
obj = loads(json.dumps(res_dict["manifest"]), allowed_objects=allowed_objects)
if isinstance(obj, BasePromptTemplate):
if obj.metadata is None:
obj.metadata = {}
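
The hub `pull()` change above derives the deserialization policy from `include_model`. A sketch of that selection with a concrete f-string manifest; it assumes the string policy `"core"` covers core prompt classes such as `PromptTemplate`, which is what the no-model branch implies:

```python
import json

from langchain_core.load.load import loads

# A minimal manifest of the kind pull() receives: an f-string PromptTemplate.
manifest = {
    "lc": 1,
    "type": "constructor",
    "id": ["langchain", "prompts", "prompt", "PromptTemplate"],
    "kwargs": {
        "input_variables": ["name"],
        "template": "Hello {name}",
        "template_format": "f-string",
    },
}

include_model = False  # mirrors pull(..., include_model=False)
allowed_objects = "all" if include_model else "core"  # assumed "core" admits core prompts
obj = loads(json.dumps(manifest), allowed_objects=allowed_objects)
```
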

View File

@@ -7,6 +7,8 @@ from pytest_benchmark.fixture import BenchmarkFixture # type: ignore[import-unt
from langchain_anthropic import ChatAnthropic
_MODEL = "claude-3-haiku-20240307"
class TestAnthropicStandard(ChatModelUnitTests):
"""Use the standard chat model unit tests against the `ChatAnthropic` class."""
@@ -17,7 +19,15 @@ class TestAnthropicStandard(ChatModelUnitTests):
@property
def chat_model_params(self) -> dict:
return {"model": "claude-3-haiku-20240307"}
return {"model": _MODEL}
@property
def init_from_env_params(self) -> tuple[dict, dict, dict]:
return (
{"ANTHROPIC_API_KEY": "test"},
{"model": _MODEL},
{"anthropic_api_key": "test"},
)
@pytest.mark.benchmark

View File

@@ -495,7 +495,7 @@ wheels = [
[[package]]
name = "langchain"
version = "1.1.3"
version = "1.2.0"
source = { editable = "../../langchain_v1" }
dependencies = [
{ name = "langchain-core" },
@@ -642,7 +642,7 @@ typing = [
[[package]]
name = "langchain-core"
version = "1.2.0"
version = "1.2.4"
source = { editable = "../../core" }
dependencies = [
{ name = "jsonpatch" },
@@ -702,7 +702,7 @@ typing = [
[[package]]
name = "langchain-tests"
version = "1.1.0"
version = "1.1.1"
source = { editable = "../../standard-tests" }
dependencies = [
{ name = "httpx" },

View File

@@ -274,6 +274,7 @@ def test_groq_serialization() -> None:
dump,
valid_namespaces=["langchain_groq"],
secrets_map={"GROQ_API_KEY": api_key2},
allowed_objects="all",
)
assert type(llm2) is ChatGroq

View File

@@ -1360,7 +1360,7 @@ def test_structured_outputs_parser() -> None:
partial(_oai_structured_outputs_parser, schema=GenerateUsername)
)
serialized = dumps(llm_output)
deserialized = loads(serialized)
deserialized = loads(serialized, allowed_objects=[ChatGeneration, AIMessage])
assert isinstance(deserialized, ChatGeneration)
result = output_parser.invoke(cast(AIMessage, deserialized.message))
assert result == parsed_response

View File

@@ -1,4 +1,7 @@
from langchain_core.load import dumpd, dumps, load, loads
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import RunnableSequence
from langchain_openai import ChatOpenAI, OpenAI
@@ -6,7 +9,11 @@ from langchain_openai import ChatOpenAI, OpenAI
def test_loads_openai_llm() -> None:
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello", top_p=0.8) # type: ignore[call-arg]
llm_string = dumps(llm)
llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
llm2 = loads(
llm_string,
secrets_map={"OPENAI_API_KEY": "hello"},
allowed_objects=[OpenAI],
)
assert llm2.dict() == llm.dict()
llm_string_2 = dumps(llm2)
@@ -17,7 +24,11 @@ def test_loads_openai_llm() -> None:
def test_load_openai_llm() -> None:
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello") # type: ignore[call-arg]
llm_obj = dumpd(llm)
llm2 = load(llm_obj, secrets_map={"OPENAI_API_KEY": "hello"})
llm2 = load(
llm_obj,
secrets_map={"OPENAI_API_KEY": "hello"},
allowed_objects=[OpenAI],
)
assert llm2.dict() == llm.dict()
assert dumpd(llm2) == llm_obj
@@ -27,7 +38,11 @@ def test_load_openai_llm() -> None:
def test_loads_openai_chat() -> None:
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.5, openai_api_key="hello") # type: ignore[call-arg]
llm_string = dumps(llm)
llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
llm2 = loads(
llm_string,
secrets_map={"OPENAI_API_KEY": "hello"},
allowed_objects=[ChatOpenAI],
)
assert llm2.dict() == llm.dict()
llm_string_2 = dumps(llm2)
@@ -38,8 +53,85 @@ def test_loads_openai_chat() -> None:
def test_load_openai_chat() -> None:
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.5, openai_api_key="hello") # type: ignore[call-arg]
llm_obj = dumpd(llm)
llm2 = load(llm_obj, secrets_map={"OPENAI_API_KEY": "hello"})
llm2 = load(
llm_obj,
secrets_map={"OPENAI_API_KEY": "hello"},
allowed_objects=[ChatOpenAI],
)
assert llm2.dict() == llm.dict()
assert dumpd(llm2) == llm_obj
assert isinstance(llm2, ChatOpenAI)
def test_loads_runnable_sequence_prompt_model() -> None:
"""Test serialization/deserialization of a chain:
`prompt | model` (`RunnableSequence`).
"""
prompt = ChatPromptTemplate.from_messages([("user", "Hello, {name}!")])
model = ChatOpenAI(model="gpt-4o-mini", temperature=0.5, openai_api_key="hello") # type: ignore[call-arg]
chain = prompt | model
# Verify the chain is a RunnableSequence
assert isinstance(chain, RunnableSequence)
# Serialize
chain_string = dumps(chain)
# Deserialize
# (ChatPromptTemplate contains HumanMessagePromptTemplate and PromptTemplate)
chain2 = loads(
chain_string,
secrets_map={"OPENAI_API_KEY": "hello"},
allowed_objects=[
RunnableSequence,
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
ChatOpenAI,
],
)
# Verify structure
assert isinstance(chain2, RunnableSequence)
assert isinstance(chain2.first, ChatPromptTemplate)
assert isinstance(chain2.last, ChatOpenAI)
# Verify round-trip serialization
assert dumps(chain2) == chain_string
def test_load_runnable_sequence_prompt_model() -> None:
"""Test load() with a chain:
`prompt | model` (`RunnableSequence`).
"""
prompt = ChatPromptTemplate.from_messages([("user", "Tell me about {topic}")])
model = ChatOpenAI(model="gpt-4o-mini", temperature=0.7, openai_api_key="hello") # type: ignore[call-arg]
chain = prompt | model
# Serialize
chain_obj = dumpd(chain)
# Deserialize
# (ChatPromptTemplate contains HumanMessagePromptTemplate and PromptTemplate)
chain2 = load(
chain_obj,
secrets_map={"OPENAI_API_KEY": "hello"},
allowed_objects=[
RunnableSequence,
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
ChatOpenAI,
],
)
# Verify structure
assert isinstance(chain2, RunnableSequence)
assert isinstance(chain2.first, ChatPromptTemplate)
assert isinstance(chain2.last, ChatOpenAI)
# Verify round-trip serialization
assert dumpd(chain2) == chain_obj

View File

@@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = ">=3.10.0, <4.0.0"
resolution-markers = [
"python_full_version >= '3.13' and platform_python_implementation == 'PyPy'",
@@ -621,7 +621,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "1.2.3"
version = "1.2.4"
source = { editable = "../../core" }
dependencies = [
{ name = "jsonpatch" },

View File

@@ -1126,7 +1126,10 @@ class ChatModelUnitTests(ChatModelTests):
assert (
model.dict()
== load(
dumpd(model), valid_namespaces=model.get_lc_namespace()[:1]
dumpd(model),
valid_namespaces=model.get_lc_namespace()[:1],
allowed_objects="all",
secrets_from_env=True,
).dict()
)