fix(core): serialization patch (#34455)
- `allowed_objects` kwarg in `load`
- escape lc-ser formatted dicts on `dump`
- fix for jinja2

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
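A quick sketch of the new surface described above (assumes the `langchain_core.load` API after this patch; the message content and metadata are illustrative):

```python
from langchain_core.load import dumpd, load
from langchain_core.messages import AIMessage

msg = AIMessage(content="hi", additional_kwargs={"meta": {"lc": 1}})

# `dumpd` now escapes user data that merely *looks* like the LC wire format
data = dumpd(msg)

# `load` defaults to the 'core' allowlist; an explicit list is the most restrictive
restored = load(data, allowed_objects=[AIMessage])
assert restored.additional_kwargs["meta"] == {"lc": 1}
```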
@@ -6,7 +6,7 @@ from langchain_core._import_utils import import_attr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.load.dump import dumpd, dumps
|
||||
from langchain_core.load.load import loads
|
||||
from langchain_core.load.load import InitValidator, loads
|
||||
from langchain_core.load.serializable import Serializable
|
||||
|
||||
# Unfortunately, we have to eagerly import load from langchain_core/load/load.py
|
||||
@@ -15,11 +15,19 @@ if TYPE_CHECKING:
|
||||
# the `from langchain_core.load.load import load` absolute import should also work.
|
||||
from langchain_core.load.load import load
|
||||
|
||||
__all__ = ("Serializable", "dumpd", "dumps", "load", "loads")
|
||||
__all__ = (
|
||||
"InitValidator",
|
||||
"Serializable",
|
||||
"dumpd",
|
||||
"dumps",
|
||||
"load",
|
||||
"loads",
|
||||
)
|
||||
|
||||
_dynamic_imports = {
|
||||
"dumpd": "dump",
|
||||
"dumps": "dump",
|
||||
"InitValidator": "load",
|
||||
"loads": "load",
|
||||
"Serializable": "serializable",
|
||||
}
|
||||
|
||||
libs/core/langchain_core/load/_validation.py (new file, +176)
@@ -0,0 +1,176 @@
|
||||
"""Validation utilities for LangChain serialization.
|
||||
|
||||
Provides escape-based protection against injection attacks in serialized objects. The
|
||||
approach uses an allowlist design: only dicts explicitly produced by
|
||||
`Serializable.to_json()` are treated as LC objects during deserialization.
|
||||
|
||||
## How escaping works
|
||||
|
||||
During serialization, plain dicts (user data) that contain an `'lc'` key are wrapped:
|
||||
|
||||
```python
|
||||
{"lc": 1, ...} # user data that looks like LC object
|
||||
# becomes:
|
||||
{"__lc_escaped__": {"lc": 1, ...}}
|
||||
```
|
||||
|
||||
During deserialization, escaped dicts are unwrapped and returned as plain dicts,
|
||||
NOT instantiated as LC objects.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
_LC_ESCAPED_KEY = "__lc_escaped__"
|
||||
"""Sentinel key used to mark escaped user dicts during serialization.
|
||||
|
||||
When a plain dict contains an 'lc' key (which could be confused with LC objects),
|
||||
we wrap it as {"__lc_escaped__": {...original...}}.
|
||||
"""
|
||||
|
||||
|
||||
def _needs_escaping(obj: dict[str, Any]) -> bool:
|
||||
"""Check if a dict needs escaping to prevent confusion with LC objects.
|
||||
|
||||
A dict needs escaping if:
|
||||
|
||||
1. It has an `'lc'` key (could be confused with LC serialization format)
|
||||
2. It has only the escape key (would be mistaken for an escaped dict)
|
||||
"""
|
||||
return "lc" in obj or (len(obj) == 1 and _LC_ESCAPED_KEY in obj)
|
||||
|
||||
|
||||
def _escape_dict(obj: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Wrap a dict in the escape marker.
|
||||
|
||||
Example:
|
||||
```python
|
||||
{"key": "value"} # becomes {"__lc_escaped__": {"key": "value"}}
|
||||
```
|
||||
"""
|
||||
return {_LC_ESCAPED_KEY: obj}
|
||||
|
||||
|
||||
def _is_escaped_dict(obj: dict[str, Any]) -> bool:
|
||||
"""Check if a dict is an escaped user dict.
|
||||
|
||||
Example:
|
||||
```python
|
||||
{"__lc_escaped__": {...}} # is an escaped dict
|
||||
```
|
||||
"""
|
||||
return len(obj) == 1 and _LC_ESCAPED_KEY in obj
|
||||
|
||||
|
||||
def _serialize_value(obj: Any) -> Any:
|
||||
"""Serialize a value with escaping of user dicts.
|
||||
|
||||
Called recursively on kwarg values to escape any plain dicts that could be confused
|
||||
with LC objects.
|
||||
|
||||
Args:
|
||||
obj: The value to serialize.
|
||||
|
||||
Returns:
|
||||
The serialized value with user dicts escaped as needed.
|
||||
"""
|
||||
from langchain_core.load.serializable import ( # noqa: PLC0415
|
||||
Serializable,
|
||||
to_json_not_implemented,
|
||||
)
|
||||
|
||||
if isinstance(obj, Serializable):
|
||||
# This is an LC object - serialize it properly (not escaped)
|
||||
return _serialize_lc_object(obj)
|
||||
if isinstance(obj, dict):
|
||||
if not all(isinstance(k, (str, int, float, bool, type(None))) for k in obj):
|
||||
# if keys are not json serializable
|
||||
return to_json_not_implemented(obj)
|
||||
# Check if dict needs escaping BEFORE recursing into values.
|
||||
# If it needs escaping, wrap it as-is - the contents are user data that
|
||||
# will be returned as-is during deserialization (no instantiation).
|
||||
# This prevents re-escaping of already-escaped nested content.
|
||||
if _needs_escaping(obj):
|
||||
return _escape_dict(obj)
|
||||
# Safe dict (no 'lc' key) - recurse into values
|
||||
return {k: _serialize_value(v) for k, v in obj.items()}
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return [_serialize_value(item) for item in obj]
|
||||
if isinstance(obj, (str, int, float, bool, type(None))):
|
||||
return obj
|
||||
|
||||
# Non-JSON-serializable object (datetime, custom objects, etc.)
|
||||
return to_json_not_implemented(obj)
|
||||
|
||||
|
||||
def _is_lc_secret(obj: Any) -> bool:
|
||||
"""Check if an object is a LangChain secret marker."""
|
||||
expected_num_keys = 3
|
||||
return (
|
||||
isinstance(obj, dict)
|
||||
and obj.get("lc") == 1
|
||||
and obj.get("type") == "secret"
|
||||
and "id" in obj
|
||||
and len(obj) == expected_num_keys
|
||||
)
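For reference, a sketch of the secret-marker shape this helper matches; `to_json()` emits it for fields listed in `lc_secrets` (the env var name below is illustrative):

```python
marker = {"lc": 1, "type": "secret", "id": ["OPENAI_API_KEY"]}
assert _is_lc_secret(marker)
assert not _is_lc_secret({"lc": 1, "type": "secret"})  # missing "id", wrong key count
```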
|
||||
|
||||
|
||||
def _serialize_lc_object(obj: Any) -> dict[str, Any]:
|
||||
"""Serialize a `Serializable` object with escaping of user data in kwargs.
|
||||
|
||||
Args:
|
||||
obj: The `Serializable` object to serialize.
|
||||
|
||||
Returns:
|
||||
The serialized dict with user data in kwargs escaped as needed.
|
||||
|
||||
Note:
|
||||
Kwargs values are processed with `_serialize_value` to escape user data (like
|
||||
metadata) that contains `'lc'` keys. Secret fields (from `lc_secrets`) are
|
||||
skipped because `to_json()` replaces their values with secret markers.
|
||||
"""
|
||||
from langchain_core.load.serializable import Serializable # noqa: PLC0415
|
||||
|
||||
if not isinstance(obj, Serializable):
|
||||
msg = f"Expected Serializable, got {type(obj)}"
|
||||
raise TypeError(msg)
|
||||
|
||||
serialized: dict[str, Any] = dict(obj.to_json())
|
||||
|
||||
# Process kwargs to escape user data that could be confused with LC objects
|
||||
# Skip secret fields - to_json() already converted them to secret markers
|
||||
if serialized.get("type") == "constructor" and "kwargs" in serialized:
|
||||
serialized["kwargs"] = {
|
||||
k: v if _is_lc_secret(v) else _serialize_value(v)
|
||||
for k, v in serialized["kwargs"].items()
|
||||
}
|
||||
|
||||
return serialized
|
||||
|
||||
|
||||
def _unescape_value(obj: Any) -> Any:
|
||||
"""Unescape a value, processing escape markers in dict values and lists.
|
||||
|
||||
When an escaped dict is encountered (`{"__lc_escaped__": ...}`), it's
|
||||
unwrapped and the contents are returned AS-IS (no further processing).
|
||||
The contents represent user data that should not be modified.
|
||||
|
||||
For regular dicts and lists, we recurse to find any nested escape markers.
|
||||
|
||||
Args:
|
||||
obj: The value to unescape.
|
||||
|
||||
Returns:
|
||||
The unescaped value.
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
if _is_escaped_dict(obj):
|
||||
# Unwrap and return the user data as-is (no further unescaping).
|
||||
# The contents are user data that may contain more escape keys,
|
||||
# but those are part of the user's actual data.
|
||||
return obj[_LC_ESCAPED_KEY]
|
||||
|
||||
# Regular dict - recurse into values to find nested escape markers
|
||||
return {k: _unescape_value(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_unescape_value(item) for item in obj]
|
||||
return obj
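Putting these helpers together, a minimal round-trip sketch (assumes the module lands as `langchain_core.load._validation`; the env var name is illustrative):

```python
from langchain_core.load._validation import _serialize_value, _unescape_value

user_data = {"payload": {"lc": 1, "type": "secret", "id": ["SOME_ENV_VAR"]}}

escaped = _serialize_value(user_data)
# the inner dict is wrapped so the loader never treats it as an LC object
assert escaped == {
    "payload": {"__lc_escaped__": {"lc": 1, "type": "secret", "id": ["SOME_ENV_VAR"]}}
}

assert _unescape_value(escaped) == user_data
```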
|
||||
@@ -1,10 +1,26 @@
|
||||
"""Dump objects to json."""
|
||||
"""Serialize LangChain objects to JSON.
|
||||
|
||||
Provides `dumps` (to JSON string) and `dumpd` (to dict) for serializing
|
||||
`Serializable` objects.
|
||||
|
||||
## Escaping
|
||||
|
||||
During serialization, plain dicts (user data) that contain an `'lc'` key are escaped
|
||||
by wrapping them: `{"__lc_escaped__": {...original...}}`. This prevents injection
|
||||
attacks where malicious data could trick the deserializer into instantiating
|
||||
arbitrary classes. The escape marker is removed during deserialization.
|
||||
|
||||
This is an allowlist approach: only dicts explicitly produced by
|
||||
`Serializable.to_json()` are treated as LC objects; everything else is escaped if it
|
||||
could be confused with the LC format.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.load._validation import _serialize_value
|
||||
from langchain_core.load.serializable import Serializable, to_json_not_implemented
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
@@ -25,6 +41,20 @@ def default(obj: Any) -> Any:
|
||||
|
||||
|
||||
def _dump_pydantic_models(obj: Any) -> Any:
|
||||
"""Convert nested Pydantic models to dicts for JSON serialization.
|
||||
|
||||
Handles the special case where a `ChatGeneration` contains an `AIMessage`
|
||||
with a parsed Pydantic model in `additional_kwargs["parsed"]`. Since
|
||||
Pydantic models aren't directly JSON serializable, this converts them to
|
||||
dicts.
|
||||
|
||||
Args:
|
||||
obj: The object to process.
|
||||
|
||||
Returns:
|
||||
A copy of the object with nested Pydantic models converted to dicts, or
|
||||
the original object unchanged if no conversion was needed.
|
||||
"""
|
||||
if (
|
||||
isinstance(obj, ChatGeneration)
|
||||
and isinstance(obj.message, AIMessage)
|
||||
@@ -40,10 +70,17 @@ def _dump_pydantic_models(obj: Any) -> Any:
|
||||
def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
|
||||
"""Return a JSON string representation of an object.
|
||||
|
||||
Note:
|
||||
Plain dicts containing an `'lc'` key are automatically escaped to prevent
|
||||
confusion with LC serialization format. The escape marker is removed during
|
||||
deserialization.
|
||||
|
||||
Args:
|
||||
obj: The object to dump.
|
||||
pretty: Whether to pretty print the json. If `True`, the json will be
|
||||
indented with 2 spaces (if no indent is provided as part of `kwargs`).
|
||||
pretty: Whether to pretty print the json.
|
||||
|
||||
If `True`, the json will be indented by either 2 spaces or the amount
|
||||
provided in the `indent` kwarg.
|
||||
**kwargs: Additional arguments to pass to `json.dumps`
|
||||
|
||||
Returns:
|
||||
@@ -55,28 +92,29 @@ def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
|
||||
if "default" in kwargs:
|
||||
msg = "`default` should not be passed to dumps"
|
||||
raise ValueError(msg)
|
||||
try:
|
||||
obj = _dump_pydantic_models(obj)
|
||||
if pretty:
|
||||
indent = kwargs.pop("indent", 2)
|
||||
return json.dumps(obj, default=default, indent=indent, **kwargs)
|
||||
return json.dumps(obj, default=default, **kwargs)
|
||||
except TypeError:
|
||||
if pretty:
|
||||
indent = kwargs.pop("indent", 2)
|
||||
return json.dumps(to_json_not_implemented(obj), indent=indent, **kwargs)
|
||||
return json.dumps(to_json_not_implemented(obj), **kwargs)
|
||||
|
||||
obj = _dump_pydantic_models(obj)
|
||||
serialized = _serialize_value(obj)
|
||||
|
||||
if pretty:
|
||||
indent = kwargs.pop("indent", 2)
|
||||
return json.dumps(serialized, indent=indent, **kwargs)
|
||||
return json.dumps(serialized, **kwargs)
|
||||
|
||||
|
||||
def dumpd(obj: Any) -> Any:
|
||||
"""Return a dict representation of an object.
|
||||
|
||||
Note:
|
||||
Plain dicts containing an `'lc'` key are automatically escaped to prevent
|
||||
confusion with LC serialization format. The escape marker is removed during
|
||||
deserialization.
|
||||
|
||||
Args:
|
||||
obj: The object to dump.
|
||||
|
||||
Returns:
|
||||
Dictionary that can be serialized to json using `json.dumps`.
|
||||
"""
|
||||
# Unfortunately this function is not as efficient as it could be because it first
|
||||
# dumps the object to a json string and then loads it back into a dictionary.
|
||||
return json.loads(dumps(obj))
|
||||
obj = _dump_pydantic_models(obj)
|
||||
return _serialize_value(obj)
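A rough look at what the rewritten `dumpd` emits when user data mimics the LC format (a sketch; the `Document` metadata here is made up):

```python
from langchain_core.documents import Document
from langchain_core.load import dumpd

doc = Document(page_content="hi", metadata={"fake": {"lc": 1, "type": "constructor"}})
data = dumpd(doc)

# the genuine Document envelope stays unescaped ...
assert data["lc"] == 1 and data["type"] == "constructor"
# ... while the look-alike dict inside metadata is wrapped
assert data["kwargs"]["metadata"]["fake"] == {
    "__lc_escaped__": {"lc": 1, "type": "constructor"}
}
```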
|
||||
|
||||
@@ -1,16 +1,83 @@
|
||||
"""Load LangChain objects from JSON strings or objects.
|
||||
|
||||
!!! warning
|
||||
`load` and `loads` are vulnerable to remote code execution. Never use with untrusted
|
||||
input.
|
||||
## How it works
|
||||
|
||||
Each `Serializable` LangChain object has a unique identifier (its "class path"), which
|
||||
is a list of strings representing the module path and class name. For example:
|
||||
|
||||
- `AIMessage` -> `["langchain_core", "messages", "ai", "AIMessage"]`
|
||||
- `ChatPromptTemplate` -> `["langchain_core", "prompts", "chat", "ChatPromptTemplate"]`
|
||||
|
||||
When deserializing, the class path from the JSON `'id'` field is checked against an
|
||||
allowlist. If the class is not in the allowlist, deserialization raises a `ValueError`.
|
||||
|
||||
## Security model
|
||||
|
||||
The `allowed_objects` parameter controls which classes can be deserialized:
|
||||
|
||||
- **`'core'` (default)**: Allow classes defined in the serialization mappings for
|
||||
langchain_core.
|
||||
- **`'all'`**: Allow classes defined in the serialization mappings. This
|
||||
includes core LangChain types (messages, prompts, documents, etc.) and trusted
|
||||
partner integrations. See `langchain_core.load.mapping` for the full list.
|
||||
- **Explicit list of classes**: Only those specific classes are allowed.
|
||||
|
||||
For simple data types like messages and documents, the default allowlist is safe to use.
|
||||
These classes do not perform side effects during initialization.
|
||||
|
||||
!!! note "Side effects in allowed classes"
|
||||
|
||||
Deserialization calls `__init__` on allowed classes. If those classes perform side
|
||||
effects during initialization (network calls, file operations, etc.), those side
|
||||
effects will occur. The allowlist prevents instantiation of classes outside the
|
||||
allowlist, but does not sandbox the allowed classes themselves.
|
||||
|
||||
Import paths are also validated against trusted namespaces before any module is
|
||||
imported.
|
||||
|
||||
### Injection protection (escape-based)
|
||||
|
||||
During serialization, plain dicts that contain an `'lc'` key are escaped by wrapping
|
||||
them: `{"__lc_escaped__": {...}}`. During deserialization, escaped dicts are unwrapped
|
||||
and returned as plain dicts, NOT instantiated as LC objects.
|
||||
|
||||
This is an allowlist approach: only dicts explicitly produced by
|
||||
`Serializable.to_json()` (which are NOT escaped) are treated as LC objects;
|
||||
everything else is user data.
|
||||
|
||||
Even if an attacker's payload includes `__lc_escaped__` wrappers, it will be unwrapped
|
||||
to plain dicts and NOT instantiated as malicious objects.
|
||||
|
||||
## Examples
|
||||
|
||||
```python
|
||||
from langchain_core.load import load
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
# Use default allowlist (classes from mappings) - recommended
|
||||
obj = load(data)
|
||||
|
||||
# Allow only specific classes (most restrictive)
|
||||
obj = load(
|
||||
data,
|
||||
allowed_objects=[
|
||||
ChatPromptTemplate,
|
||||
AIMessage,
|
||||
HumanMessage,
|
||||
],
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.load._validation import _is_escaped_dict, _unescape_value
|
||||
from langchain_core.load.mapping import (
|
||||
_JS_SERIALIZABLE_MAPPING,
|
||||
_OG_SERIALIZABLE_MAPPING,
|
||||
@@ -49,34 +116,209 @@ ALL_SERIALIZABLE_MAPPINGS = {
|
||||
**_JS_SERIALIZABLE_MAPPING,
|
||||
}
|
||||
|
||||
# Cache for the default allowed class paths computed from mappings
|
||||
# Maps mode ("all" or "core") to the cached set of paths
|
||||
_default_class_paths_cache: dict[str, set[tuple[str, ...]]] = {}
|
||||
|
||||
|
||||
def _get_default_allowed_class_paths(
|
||||
allowed_object_mode: Literal["all", "core"],
|
||||
) -> set[tuple[str, ...]]:
|
||||
"""Get the default allowed class paths from the serialization mappings.
|
||||
|
||||
This uses the mappings as the source of truth for what classes are allowed
|
||||
by default. Both the legacy paths (keys) and current paths (values) are included.
|
||||
|
||||
Args:
|
||||
allowed_object_mode: either `'all'` or `'core'`.
|
||||
|
||||
Returns:
|
||||
Set of class path tuples that are allowed by default.
|
||||
"""
|
||||
if allowed_object_mode in _default_class_paths_cache:
|
||||
return _default_class_paths_cache[allowed_object_mode]
|
||||
|
||||
allowed_paths: set[tuple[str, ...]] = set()
|
||||
for key, value in ALL_SERIALIZABLE_MAPPINGS.items():
|
||||
if allowed_object_mode == "core" and value[0] != "langchain_core":
|
||||
continue
|
||||
allowed_paths.add(key)
|
||||
allowed_paths.add(value)
|
||||
|
||||
_default_class_paths_cache[allowed_object_mode] = allowed_paths
|
||||
return _default_class_paths_cache[allowed_object_mode]
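Roughly how the two modes differ (a sketch; the partner path is one of the mapping entries touched in this diff):

```python
core_paths = _get_default_allowed_class_paths("core")
all_paths = _get_default_allowed_class_paths("all")

# 'core' keeps only entries whose *current* location is in langchain_core
assert ("langchain_core", "messages", "ai", "AIMessage") in core_paths
assert ("langchain_xai", "chat_models", "ChatXAI") not in core_paths
assert ("langchain_xai", "chat_models", "ChatXAI") in all_paths
```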
|
||||
|
||||
|
||||
def _block_jinja2_templates(
|
||||
class_path: tuple[str, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
"""Block jinja2 templates during deserialization for security.
|
||||
|
||||
Jinja2 templates can execute arbitrary code, so they are blocked by default when
|
||||
deserializing objects with `template_format='jinja2'`.
|
||||
|
||||
Note:
|
||||
We intentionally do NOT check the `class_path` here to keep this simple and
|
||||
future-proof. If any new class is added that accepts `template_format='jinja2'`,
|
||||
it will be automatically blocked without needing to update this function.
|
||||
|
||||
Args:
|
||||
class_path: The class path tuple being deserialized (unused).
|
||||
kwargs: The kwargs dict for the class constructor.
|
||||
|
||||
Raises:
|
||||
ValueError: If `template_format` is `'jinja2'`.
|
||||
"""
|
||||
_ = class_path # Unused - see docstring for rationale. Kept to satisfy signature.
|
||||
if kwargs.get("template_format") == "jinja2":
|
||||
msg = (
|
||||
"Jinja2 templates are not allowed during deserialization for security "
|
||||
"reasons. Use 'f-string' template format instead, or explicitly allow "
|
||||
"jinja2 by providing a custom init_validator."
|
||||
)
|
||||
raise ValueError(msg)
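In practice the check is purely kwargs-based (a sketch; the class path shown is illustrative, since the function ignores it):

```python
import pytest

# f-string templates pass through untouched
_block_jinja2_templates(
    ("langchain", "prompts", "prompt", "PromptTemplate"),
    {"template": "{x}", "template_format": "f-string"},
)

# jinja2 templates are rejected regardless of the class being revived
with pytest.raises(ValueError, match="Jinja2 templates are not allowed"):
    _block_jinja2_templates(
        ("langchain", "prompts", "prompt", "PromptTemplate"),
        {"template": "{{ x }}", "template_format": "jinja2"},
    )
```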
|
||||
|
||||
|
||||
def default_init_validator(
|
||||
class_path: tuple[str, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
"""Default init validator that blocks jinja2 templates.
|
||||
|
||||
This is the default validator used by `load()` and `loads()` when no custom
|
||||
validator is provided.
|
||||
|
||||
Args:
|
||||
class_path: The class path tuple being deserialized.
|
||||
kwargs: The kwargs dict for the class constructor.
|
||||
|
||||
Raises:
|
||||
ValueError: If template_format is `'jinja2'`.
|
||||
"""
|
||||
_block_jinja2_templates(class_path, kwargs)
|
||||
|
||||
|
||||
AllowedObject = type[Serializable]
|
||||
"""Type alias for classes that can be included in the `allowed_objects` parameter.
|
||||
|
||||
Must be a `Serializable` subclass (the class itself, not an instance).
|
||||
"""
|
||||
|
||||
InitValidator = Callable[[tuple[str, ...], dict[str, Any]], None]
|
||||
"""Type alias for a callable that validates kwargs during deserialization.
|
||||
|
||||
The callable receives:
|
||||
|
||||
- `class_path`: A tuple of strings identifying the class being instantiated
|
||||
(e.g., `('langchain', 'schema', 'messages', 'AIMessage')`).
|
||||
- `kwargs`: The kwargs dict that will be passed to the constructor.
|
||||
|
||||
The validator should raise an exception if the object should not be deserialized.
|
||||
"""
|
||||
|
||||
|
||||
def _compute_allowed_class_paths(
|
||||
allowed_objects: Iterable[AllowedObject],
|
||||
import_mappings: dict[tuple[str, ...], tuple[str, ...]],
|
||||
) -> set[tuple[str, ...]]:
|
||||
"""Return allowed class paths from an explicit list of classes.
|
||||
|
||||
A class path is a tuple of strings identifying a serializable class, derived from
|
||||
`Serializable.lc_id()`. For example: `('langchain_core', 'messages', 'AIMessage')`.
|
||||
|
||||
Args:
|
||||
allowed_objects: Iterable of `Serializable` subclasses to allow.
|
||||
import_mappings: Mapping of legacy class paths to current class paths.
|
||||
|
||||
Returns:
|
||||
Set of allowed class paths.
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Allow a specific class
|
||||
_compute_allowed_class_paths([MyPrompt], {}) ->
|
||||
{("langchain_core", "prompts", "MyPrompt")}
|
||||
|
||||
# Include legacy paths that map to the same class
|
||||
import_mappings = {("old", "Prompt"): ("langchain_core", "prompts", "MyPrompt")}
|
||||
_compute_allowed_class_paths([MyPrompt], import_mappings) ->
|
||||
{("langchain_core", "prompts", "MyPrompt"), ("old", "Prompt")}
|
||||
```
|
||||
"""
|
||||
allowed_objects_list = list(allowed_objects)
|
||||
|
||||
allowed_class_paths: set[tuple[str, ...]] = set()
|
||||
for allowed_obj in allowed_objects_list:
|
||||
if not isinstance(allowed_obj, type) or not issubclass(
|
||||
allowed_obj, Serializable
|
||||
):
|
||||
msg = "allowed_objects must contain Serializable subclasses."
|
||||
raise TypeError(msg)
|
||||
|
||||
class_path = tuple(allowed_obj.lc_id())
|
||||
allowed_class_paths.add(class_path)
|
||||
# Add legacy paths that map to the same class.
|
||||
for mapping_key, mapping_value in import_mappings.items():
|
||||
if tuple(mapping_value) == class_path:
|
||||
allowed_class_paths.add(mapping_key)
|
||||
return allowed_class_paths
|
||||
|
||||
|
||||
class Reviver:
|
||||
"""Reviver for JSON objects."""
|
||||
"""Reviver for JSON objects.
|
||||
|
||||
Used as the `object_hook` for `json.loads` to reconstruct LangChain objects from
|
||||
their serialized JSON representation.
|
||||
|
||||
Only classes in the allowlist can be instantiated.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
allowed_objects: Iterable[AllowedObject] | Literal["all", "core"] = "core",
|
||||
secrets_map: dict[str, str] | None = None,
|
||||
valid_namespaces: list[str] | None = None,
|
||||
secrets_from_env: bool = True, # noqa: FBT001,FBT002
|
||||
secrets_from_env: bool = False, # noqa: FBT001,FBT002
|
||||
additional_import_mappings: dict[tuple[str, ...], tuple[str, ...]]
|
||||
| None = None,
|
||||
*,
|
||||
ignore_unserializable_fields: bool = False,
|
||||
init_validator: InitValidator | None = default_init_validator,
|
||||
) -> None:
|
||||
"""Initialize the reviver.
|
||||
|
||||
Args:
|
||||
secrets_map: A map of secrets to load.
|
||||
allowed_objects: Allowlist of classes that can be deserialized.
|
||||
- `'core'` (default): Allow classes defined in the serialization
|
||||
mappings for `langchain_core`.
|
||||
- `'all'`: Allow classes defined in the serialization mappings.
|
||||
|
||||
This includes core LangChain types (messages, prompts, documents,
|
||||
etc.) and trusted partner integrations. See
|
||||
`langchain_core.load.mapping` for the full list.
|
||||
- Explicit list of classes: Only those specific classes are allowed.
|
||||
secrets_map: A map of secrets to load.
|
||||
If a secret is not found in the map, it will be loaded from the
|
||||
environment if `secrets_from_env` is `True`.
|
||||
valid_namespaces: A list of additional namespaces (modules)
|
||||
to allow to be deserialized.
|
||||
valid_namespaces: Additional namespaces (modules) to allow during
|
||||
deserialization, beyond the default trusted namespaces.
|
||||
secrets_from_env: Whether to load secrets from the environment.
|
||||
additional_import_mappings: A dictionary of additional namespace mappings
|
||||
additional_import_mappings: A dictionary of additional namespace mappings.
|
||||
|
||||
You can use this to override default mappings or add new mappings.
|
||||
|
||||
When `allowed_objects` is `'core'` or `'all'` (the default modes), paths from these
|
||||
mappings are also added to the allowed class paths.
|
||||
ignore_unserializable_fields: Whether to ignore unserializable fields.
|
||||
init_validator: Optional callable to validate kwargs before instantiation.
|
||||
|
||||
If provided, this function is called with `(class_path, kwargs)` where
|
||||
`class_path` is the class path tuple and `kwargs` is the kwargs dict.
|
||||
The validator should raise an exception if the object should not be
|
||||
deserialized, otherwise return `None`.
|
||||
|
||||
Defaults to `default_init_validator` which blocks jinja2 templates.
|
||||
"""
|
||||
self.secrets_from_env = secrets_from_env
|
||||
self.secrets_map = secrets_map or {}
|
||||
@@ -95,7 +337,26 @@ class Reviver:
|
||||
if self.additional_import_mappings
|
||||
else ALL_SERIALIZABLE_MAPPINGS
|
||||
)
|
||||
# Compute allowed class paths:
|
||||
# - "all" -> use default paths from mappings (+ additional_import_mappings)
|
||||
# - Explicit list -> compute from those classes
|
||||
if allowed_objects in ("all", "core"):
|
||||
self.allowed_class_paths: set[tuple[str, ...]] | None = (
|
||||
_get_default_allowed_class_paths(
|
||||
cast("Literal['all', 'core']", allowed_objects)
|
||||
).copy()
|
||||
)
|
||||
# Add paths from additional_import_mappings to the defaults
|
||||
if self.additional_import_mappings:
|
||||
for key, value in self.additional_import_mappings.items():
|
||||
self.allowed_class_paths.add(key)
|
||||
self.allowed_class_paths.add(value)
|
||||
else:
|
||||
self.allowed_class_paths = _compute_allowed_class_paths(
|
||||
cast("Iterable[AllowedObject]", allowed_objects), self.import_mappings
|
||||
)
|
||||
self.ignore_unserializable_fields = ignore_unserializable_fields
|
||||
self.init_validator = init_validator
|
||||
|
||||
def __call__(self, value: dict[str, Any]) -> Any:
|
||||
"""Revive the value.
|
||||
@@ -146,6 +407,20 @@ class Reviver:
|
||||
[*namespace, name] = value["id"]
|
||||
mapping_key = tuple(value["id"])
|
||||
|
||||
if (
|
||||
self.allowed_class_paths is not None
|
||||
and mapping_key not in self.allowed_class_paths
|
||||
):
|
||||
msg = (
|
||||
f"Deserialization of {mapping_key!r} is not allowed. "
|
||||
"The default (allowed_objects='core') only permits core "
|
||||
"langchain-core classes. To allow trusted partner integrations, "
|
||||
"use allowed_objects='all'. Alternatively, pass an explicit list "
|
||||
"of allowed classes via allowed_objects=[...]. "
|
||||
"See langchain_core.load.mapping for the full allowlist."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
if (
|
||||
namespace[0] not in self.valid_namespaces
|
||||
# The root namespace ["langchain"] is not a valid identifier.
|
||||
@@ -153,13 +428,11 @@ class Reviver:
|
||||
):
|
||||
msg = f"Invalid namespace: {value}"
|
||||
raise ValueError(msg)
|
||||
# Has explicit import path.
|
||||
# Determine explicit import path
|
||||
if mapping_key in self.import_mappings:
|
||||
import_path = self.import_mappings[mapping_key]
|
||||
# Split into module and name
|
||||
import_dir, name = import_path[:-1], import_path[-1]
|
||||
# Import module
|
||||
mod = importlib.import_module(".".join(import_dir))
|
||||
elif namespace[0] in DISALLOW_LOAD_FROM_PATH:
|
||||
msg = (
|
||||
"Trying to deserialize something that cannot "
|
||||
@@ -167,9 +440,16 @@ class Reviver:
|
||||
f"{mapping_key}."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
# Otherwise, treat namespace as path.
|
||||
else:
|
||||
mod = importlib.import_module(".".join(namespace))
|
||||
# Otherwise, treat namespace as path.
|
||||
import_dir = namespace
|
||||
|
||||
# Validate import path is in trusted namespaces before importing
|
||||
if import_dir[0] not in self.valid_namespaces:
|
||||
msg = f"Invalid namespace: {value}"
|
||||
raise ValueError(msg)
|
||||
|
||||
mod = importlib.import_module(".".join(import_dir))
|
||||
|
||||
cls = getattr(mod, name)
|
||||
|
||||
@@ -181,6 +461,10 @@ class Reviver:
|
||||
# We don't need to recurse on kwargs
|
||||
# as json.loads will do that for us.
|
||||
kwargs = value.get("kwargs", {})
|
||||
|
||||
if self.init_validator is not None:
|
||||
self.init_validator(mapping_key, kwargs)
|
||||
|
||||
return cls(**kwargs)
|
||||
|
||||
return value
|
||||
@@ -190,46 +474,74 @@ class Reviver:
|
||||
def loads(
|
||||
text: str,
|
||||
*,
|
||||
allowed_objects: Iterable[AllowedObject] | Literal["all", "core"] = "core",
|
||||
secrets_map: dict[str, str] | None = None,
|
||||
valid_namespaces: list[str] | None = None,
|
||||
secrets_from_env: bool = True,
|
||||
secrets_from_env: bool = False,
|
||||
additional_import_mappings: dict[tuple[str, ...], tuple[str, ...]] | None = None,
|
||||
ignore_unserializable_fields: bool = False,
|
||||
init_validator: InitValidator | None = default_init_validator,
|
||||
) -> Any:
|
||||
"""Revive a LangChain class from a JSON string.
|
||||
|
||||
!!! warning
|
||||
This function is vulnerable to remote code execution. Never use with untrusted
|
||||
input.
|
||||
|
||||
Equivalent to `load(json.loads(text))`.
|
||||
|
||||
Only classes in the allowlist can be instantiated. The default allowlist includes
|
||||
core LangChain types (messages, prompts, documents, etc.). See
|
||||
`langchain_core.load.mapping` for the full list.
|
||||
|
||||
Args:
|
||||
text: The string to load.
|
||||
allowed_objects: Allowlist of classes that can be deserialized.
|
||||
|
||||
- `'core'` (default): Allow classes defined in the serialization mappings
|
||||
for langchain_core.
|
||||
- `'all'`: Allow classes defined in the serialization mappings.
|
||||
|
||||
This includes core LangChain types (messages, prompts, documents, etc.)
|
||||
and trusted partner integrations. See `langchain_core.load.mapping` for
|
||||
the full list.
|
||||
- Explicit list of classes: Only those specific classes are allowed.
|
||||
- `[]`: Disallow all deserialization (will raise on any object).
|
||||
secrets_map: A map of secrets to load.
|
||||
|
||||
If a secret is not found in the map, it will be loaded from the environment
|
||||
if `secrets_from_env` is `True`.
|
||||
valid_namespaces: A list of additional namespaces (modules)
|
||||
to allow to be deserialized.
|
||||
valid_namespaces: Additional namespaces (modules) to allow during
|
||||
deserialization, beyond the default trusted namespaces.
|
||||
secrets_from_env: Whether to load secrets from the environment.
|
||||
additional_import_mappings: A dictionary of additional namespace mappings
|
||||
additional_import_mappings: A dictionary of additional namespace mappings.
|
||||
|
||||
You can use this to override default mappings or add new mappings.
|
||||
|
||||
When `allowed_objects` is `'core'` or `'all'` (the default modes), paths from these
|
||||
mappings are also added to the allowed class paths.
|
||||
ignore_unserializable_fields: Whether to ignore unserializable fields.
|
||||
init_validator: Optional callable to validate kwargs before instantiation.
|
||||
|
||||
If provided, this function is called with `(class_path, kwargs)` where
|
||||
`class_path` is the class path tuple and `kwargs` is the kwargs dict.
|
||||
The validator should raise an exception if the object should not be
|
||||
deserialized, otherwise return `None`. Defaults to
|
||||
`default_init_validator` which blocks jinja2 templates.
|
||||
|
||||
Returns:
|
||||
Revived LangChain objects.
|
||||
|
||||
Raises:
|
||||
ValueError: If an object's class path is not in the `allowed_objects` allowlist.
|
||||
"""
|
||||
return json.loads(
|
||||
text,
|
||||
object_hook=Reviver(
|
||||
secrets_map,
|
||||
valid_namespaces,
|
||||
secrets_from_env,
|
||||
additional_import_mappings,
|
||||
ignore_unserializable_fields=ignore_unserializable_fields,
|
||||
),
|
||||
# Parse JSON and delegate to load() for proper escape handling
|
||||
raw_obj = json.loads(text)
|
||||
return load(
|
||||
raw_obj,
|
||||
allowed_objects=allowed_objects,
|
||||
secrets_map=secrets_map,
|
||||
valid_namespaces=valid_namespaces,
|
||||
secrets_from_env=secrets_from_env,
|
||||
additional_import_mappings=additional_import_mappings,
|
||||
ignore_unserializable_fields=ignore_unserializable_fields,
|
||||
init_validator=init_validator,
|
||||
)
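A typical round trip through `loads` under the new defaults (a sketch):

```python
from langchain_core.load import dumps, loads
from langchain_core.messages import HumanMessage

text = dumps(HumanMessage(content="hello"), pretty=True)

# the default 'core' allowlist covers message types
msg = loads(text)

# or pin the allowlist down to exactly what you expect
msg = loads(text, allowed_objects=[HumanMessage])
```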
|
||||
|
||||
|
||||
@@ -237,49 +549,105 @@ def loads(
|
||||
def load(
|
||||
obj: Any,
|
||||
*,
|
||||
allowed_objects: Iterable[AllowedObject] | Literal["all", "core"] = "core",
|
||||
secrets_map: dict[str, str] | None = None,
|
||||
valid_namespaces: list[str] | None = None,
|
||||
secrets_from_env: bool = True,
|
||||
secrets_from_env: bool = False,
|
||||
additional_import_mappings: dict[tuple[str, ...], tuple[str, ...]] | None = None,
|
||||
ignore_unserializable_fields: bool = False,
|
||||
init_validator: InitValidator | None = default_init_validator,
|
||||
) -> Any:
|
||||
"""Revive a LangChain class from a JSON object.
|
||||
|
||||
!!! warning
|
||||
This function is vulnerable to remote code execution. Never use with untrusted
|
||||
input.
|
||||
Use this if you already have a parsed JSON object, e.g. from `json.load` or
|
||||
`orjson.loads`.
|
||||
|
||||
Use this if you already have a parsed JSON object,
|
||||
eg. from `json.load` or `orjson.loads`.
|
||||
Only classes in the allowlist can be instantiated. The default allowlist includes
|
||||
core LangChain types (messages, prompts, documents, etc.). See
|
||||
`langchain_core.load.mapping` for the full list.
|
||||
|
||||
Args:
|
||||
obj: The object to load.
|
||||
allowed_objects: Allowlist of classes that can be deserialized.
|
||||
|
||||
- `'core'` (default): Allow classes defined in the serialization mappings
|
||||
for langchain_core.
|
||||
- `'all'`: Allow classes defined in the serialization mappings.
|
||||
|
||||
This includes core LangChain types (messages, prompts, documents, etc.)
|
||||
and trusted partner integrations. See `langchain_core.load.mapping` for
|
||||
the full list.
|
||||
- Explicit list of classes: Only those specific classes are allowed.
|
||||
- `[]`: Disallow all deserialization (will raise on any object).
|
||||
secrets_map: A map of secrets to load.
|
||||
|
||||
If a secret is not found in the map, it will be loaded from the environment
|
||||
if `secrets_from_env` is `True`.
|
||||
valid_namespaces: A list of additional namespaces (modules)
|
||||
to allow to be deserialized.
|
||||
valid_namespaces: Additional namespaces (modules) to allow during
|
||||
deserialization, beyond the default trusted namespaces.
|
||||
secrets_from_env: Whether to load secrets from the environment.
|
||||
additional_import_mappings: A dictionary of additional namespace mappings
|
||||
additional_import_mappings: A dictionary of additional namespace mappings.
|
||||
|
||||
You can use this to override default mappings or add new mappings.
|
||||
|
||||
When `allowed_objects` is `'core'` or `'all'` (the default modes), paths from these
|
||||
mappings are also added to the allowed class paths.
|
||||
ignore_unserializable_fields: Whether to ignore unserializable fields.
|
||||
init_validator: Optional callable to validate kwargs before instantiation.
|
||||
|
||||
If provided, this function is called with `(class_path, kwargs)` where
|
||||
`class_path` is the class path tuple and `kwargs` is the kwargs dict.
|
||||
The validator should raise an exception if the object should not be
|
||||
deserialized, otherwise return `None`. Defaults to
|
||||
`default_init_validator` which blocks jinja2 templates.
|
||||
|
||||
Returns:
|
||||
Revived LangChain objects.
|
||||
|
||||
Raises:
|
||||
ValueError: If an object's class path is not in the `allowed_objects` allowlist.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from langchain_core.load import load, dumpd
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
msg = AIMessage(content="Hello")
|
||||
data = dumpd(msg)
|
||||
|
||||
# Deserialize using default allowlist
|
||||
loaded = load(data)
|
||||
|
||||
# Or with explicit allowlist
|
||||
loaded = load(data, allowed_objects=[AIMessage])
|
||||
|
||||
# Or extend defaults with additional mappings
|
||||
loaded = load(
|
||||
data,
|
||||
additional_import_mappings={
|
||||
("my_pkg", "MyClass"): ("my_pkg", "module", "MyClass"),
|
||||
},
|
||||
)
|
||||
```
|
||||
"""
|
||||
reviver = Reviver(
|
||||
allowed_objects,
|
||||
secrets_map,
|
||||
valid_namespaces,
|
||||
secrets_from_env,
|
||||
additional_import_mappings,
|
||||
ignore_unserializable_fields=ignore_unserializable_fields,
|
||||
init_validator=init_validator,
|
||||
)
|
||||
|
||||
def _load(obj: Any) -> Any:
|
||||
if isinstance(obj, dict):
|
||||
# Need to revive leaf nodes before reviving this node
|
||||
# Check for escaped dict FIRST (before recursing).
|
||||
# Escaped dicts are user data that should NOT be processed as LC objects.
|
||||
if _is_escaped_dict(obj):
|
||||
return _unescape_value(obj)
|
||||
|
||||
# Not escaped - recurse into children then apply reviver
|
||||
loaded_obj = {k: _load(v) for k, v in obj.items()}
|
||||
return reviver(loaded_obj)
|
||||
if isinstance(obj, list):
|
||||
|
||||
@@ -1,21 +1,19 @@
|
||||
"""Serialization mapping.
|
||||
|
||||
This file contains a mapping between the lc_namespace path for a given
|
||||
subclass that implements from Serializable to the namespace
|
||||
This file contains a mapping between the `lc_namespace` path for a given
|
||||
subclass that inherits from `Serializable` to the namespace
|
||||
where that class is actually located.
|
||||
|
||||
This mapping helps maintain the ability to serialize and deserialize
|
||||
well-known LangChain objects even if they are moved around in the codebase
|
||||
across different LangChain versions.
|
||||
|
||||
For example,
|
||||
For example, the code for the `AIMessage` class is located in
|
||||
`langchain_core.messages.ai.AIMessage`. This message is associated with the
|
||||
`lc_namespace` of `["langchain", "schema", "messages", "AIMessage"]`,
|
||||
because this code was originally in `langchain.schema.messages.AIMessage`.
|
||||
|
||||
The code for AIMessage class is located in langchain_core.messages.ai.AIMessage,
|
||||
This message is associated with the lc_namespace
|
||||
["langchain", "schema", "messages", "AIMessage"],
|
||||
because this code was originally in langchain.schema.messages.AIMessage.
|
||||
|
||||
The mapping allows us to deserialize an AIMessage created with an older
|
||||
The mapping allows us to deserialize an `AIMessage` created with an older
|
||||
version of LangChain where the code was in a different location.
|
||||
"""
|
||||
|
||||
@@ -275,6 +273,11 @@ SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
|
||||
"chat_models",
|
||||
"ChatGroq",
|
||||
),
|
||||
("langchain_xai", "chat_models", "ChatXAI"): (
|
||||
"langchain_xai",
|
||||
"chat_models",
|
||||
"ChatXAI",
|
||||
),
|
||||
("langchain", "chat_models", "fireworks", "ChatFireworks"): (
|
||||
"langchain_fireworks",
|
||||
"chat_models",
|
||||
@@ -529,16 +532,6 @@ SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
|
||||
"structured",
|
||||
"StructuredPrompt",
|
||||
),
|
||||
("langchain_sambanova", "chat_models", "ChatSambaNovaCloud"): (
|
||||
"langchain_sambanova",
|
||||
"chat_models",
|
||||
"ChatSambaNovaCloud",
|
||||
),
|
||||
("langchain_sambanova", "chat_models", "ChatSambaStudio"): (
|
||||
"langchain_sambanova",
|
||||
"chat_models",
|
||||
"ChatSambaStudio",
|
||||
),
|
||||
("langchain_core", "prompts", "message", "_DictMessagePromptTemplate"): (
|
||||
"langchain_core",
|
||||
"prompts",
|
||||
|
||||
@@ -539,7 +539,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
|
||||
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
|
||||
|
||||
# Get the input messages
|
||||
inputs = load(run.inputs)
|
||||
inputs = load(run.inputs, allowed_objects="all")
|
||||
input_messages = self._get_input_messages(inputs)
|
||||
# If historic messages were prepended to the input messages, remove them to
|
||||
# avoid adding duplicate messages to history.
|
||||
@@ -548,7 +548,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
|
||||
input_messages = input_messages[len(historic_messages) :]
|
||||
|
||||
# Get the output messages
|
||||
output_val = load(run.outputs)
|
||||
output_val = load(run.outputs, allowed_objects="all")
|
||||
output_messages = self._get_output_messages(output_val)
|
||||
hist.add_messages(input_messages + output_messages)
|
||||
|
||||
@@ -556,7 +556,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
|
||||
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
|
||||
|
||||
# Get the input messages
|
||||
inputs = load(run.inputs)
|
||||
inputs = load(run.inputs, allowed_objects="all")
|
||||
input_messages = self._get_input_messages(inputs)
|
||||
# If historic messages were prepended to the input messages, remove them to
|
||||
# avoid adding duplicate messages to history.
|
||||
@@ -565,7 +565,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
|
||||
input_messages = input_messages[len(historic_messages) :]
|
||||
|
||||
# Get the output messages
|
||||
output_val = load(run.outputs)
|
||||
output_val = load(run.outputs, allowed_objects="all")
|
||||
output_messages = self._get_output_messages(output_val)
|
||||
await hist.aadd_messages(input_messages + output_messages)
|
||||
|
||||
|
||||
@@ -563,7 +563,7 @@ def _get_standardized_inputs(
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
inputs = load(run.inputs)
|
||||
inputs = load(run.inputs, allowed_objects="all")
|
||||
|
||||
if run.run_type in {"retriever", "llm", "chat_model"}:
|
||||
return inputs
|
||||
@@ -595,7 +595,7 @@ def _get_standardized_outputs(
|
||||
Returns:
|
||||
An output if returned, otherwise None.
|
||||
"""
|
||||
outputs = load(run.outputs)
|
||||
outputs = load(run.outputs, allowed_objects="all")
|
||||
if schema_format == "original":
|
||||
if run.run_type == "prompt" and "output" in outputs:
|
||||
# These were previously dumped before the tracer.
|
||||
|
||||
@@ -528,7 +528,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
"""
|
||||
path_: Path = Path(path)
|
||||
with path_.open("r", encoding="utf-8") as f:
|
||||
store = load(json.load(f))
|
||||
store = load(json.load(f), allowed_objects=[Document])
|
||||
vectorstore = cls(embedding=embedding, **kwargs)
|
||||
vectorstore.store = store
|
||||
return vectorstore
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
from langchain_core.load import __all__
|
||||
|
||||
EXPECTED_ALL = ["dumpd", "dumps", "load", "loads", "Serializable"]
|
||||
EXPECTED_ALL = [
|
||||
"InitValidator",
|
||||
"Serializable",
|
||||
"dumpd",
|
||||
"dumps",
|
||||
"load",
|
||||
"loads",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
|
||||
libs/core/tests/unit_tests/load/test_secret_injection.py (new file, +431)
@@ -0,0 +1,431 @@
|
||||
"""Tests for secret injection prevention in serialization.
|
||||
|
||||
Verify that user-provided data containing secret-like structures cannot be used to
|
||||
extract environment variables during deserialization.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.load import dumpd, dumps, load
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
|
||||
SENTINEL_ENV_VAR = "TEST_SECRET_INJECTION_VAR"
|
||||
"""Sentinel value that should NEVER appear in serialized output."""
|
||||
|
||||
SENTINEL_VALUE = "LEAKED_SECRET_MEOW_12345"
|
||||
"""Sentinel value that should NEVER appear in serialized output."""
|
||||
|
||||
MALICIOUS_SECRET_DICT: dict[str, Any] = {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [SENTINEL_ENV_VAR],
|
||||
}
|
||||
"""The malicious secret-like dict that tries to read the env var"""
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _set_sentinel_env_var() -> Any:
|
||||
"""Set the sentinel env var for all tests in this module."""
|
||||
with mock.patch.dict(os.environ, {SENTINEL_ENV_VAR: SENTINEL_VALUE}):
|
||||
yield
|
||||
|
||||
|
||||
def _assert_no_secret_leak(payload: Any) -> None:
|
||||
"""Assert that serializing/deserializing payload doesn't leak the secret."""
|
||||
# First serialize
|
||||
serialized = dumps(payload)
|
||||
|
||||
# Deserialize with secrets_from_env=True (the dangerous setting)
|
||||
deserialized = load(serialized, secrets_from_env=True)
|
||||
|
||||
# Re-serialize to string
|
||||
reserialized = dumps(deserialized)
|
||||
|
||||
assert SENTINEL_VALUE not in reserialized, (
|
||||
f"Secret was leaked! Found '{SENTINEL_VALUE}' in output.\n"
|
||||
f"Original payload type: {type(payload)}\n"
|
||||
f"Reserialized output: {reserialized[:500]}..."
|
||||
)
|
||||
|
||||
assert SENTINEL_VALUE not in repr(deserialized), (
|
||||
f"Secret was leaked in deserialized object! Found '{SENTINEL_VALUE}'.\n"
|
||||
f"Deserialized: {deserialized!r}"
|
||||
)
|
||||
|
||||
|
||||
class TestSerializableTopLevel:
|
||||
"""Tests with `Serializable` objects at the top level."""
|
||||
|
||||
def test_human_message_with_secret_in_content(self) -> None:
|
||||
"""`HumanMessage` with secret-like dict in `content`."""
|
||||
msg = HumanMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "text", "text": MALICIOUS_SECRET_DICT},
|
||||
]
|
||||
)
|
||||
_assert_no_secret_leak(msg)
|
||||
|
||||
def test_human_message_with_secret_in_additional_kwargs(self) -> None:
|
||||
"""`HumanMessage` with secret-like dict in `additional_kwargs`."""
|
||||
msg = HumanMessage(
|
||||
content="Hello",
|
||||
additional_kwargs={"data": MALICIOUS_SECRET_DICT},
|
||||
)
|
||||
_assert_no_secret_leak(msg)
|
||||
|
||||
def test_human_message_with_secret_in_nested_additional_kwargs(self) -> None:
|
||||
"""`HumanMessage` with secret-like dict nested in `additional_kwargs`."""
|
||||
msg = HumanMessage(
|
||||
content="Hello",
|
||||
additional_kwargs={"nested": {"deep": MALICIOUS_SECRET_DICT}},
|
||||
)
|
||||
_assert_no_secret_leak(msg)
|
||||
|
||||
def test_human_message_with_secret_in_list_in_additional_kwargs(self) -> None:
|
||||
"""`HumanMessage` with secret-like dict in a list in `additional_kwargs`."""
|
||||
msg = HumanMessage(
|
||||
content="Hello",
|
||||
additional_kwargs={"items": [MALICIOUS_SECRET_DICT]},
|
||||
)
|
||||
_assert_no_secret_leak(msg)
|
||||
|
||||
def test_ai_message_with_secret_in_response_metadata(self) -> None:
|
||||
"""`AIMessage` with secret-like dict in respo`nse_metadata."""
|
||||
msg = AIMessage(
|
||||
content="Hello",
|
||||
response_metadata={"data": MALICIOUS_SECRET_DICT},
|
||||
)
|
||||
_assert_no_secret_leak(msg)
|
||||
|
||||
def test_document_with_secret_in_metadata(self) -> None:
|
||||
"""Document with secret-like dict in `metadata`."""
|
||||
doc = Document(
|
||||
page_content="Hello",
|
||||
metadata={"data": MALICIOUS_SECRET_DICT},
|
||||
)
|
||||
_assert_no_secret_leak(doc)
|
||||
|
||||
def test_nested_serializable_with_secret(self) -> None:
|
||||
"""`AIMessage` containing `dumpd(HumanMessage)` with secret in kwargs."""
|
||||
inner = HumanMessage(
|
||||
content="Hello",
|
||||
additional_kwargs={"secret": MALICIOUS_SECRET_DICT},
|
||||
)
|
||||
outer = AIMessage(
|
||||
content="Outer",
|
||||
additional_kwargs={"nested": [dumpd(inner)]},
|
||||
)
|
||||
_assert_no_secret_leak(outer)
|
||||
|
||||
|
||||
class TestDictTopLevel:
|
||||
"""Tests with plain dicts at the top level."""
|
||||
|
||||
def test_dict_with_serializable_containing_secret(self) -> None:
|
||||
"""Dict containing a `Serializable` with secret-like dict."""
|
||||
msg = HumanMessage(
|
||||
content="Hello",
|
||||
additional_kwargs={"data": MALICIOUS_SECRET_DICT},
|
||||
)
|
||||
payload = {"message": msg}
|
||||
_assert_no_secret_leak(payload)
|
||||
|
||||
def test_dict_with_secret_no_serializable(self) -> None:
|
||||
"""Dict with secret-like dict, no `Serializable` objects."""
|
||||
payload = {"data": MALICIOUS_SECRET_DICT}
|
||||
_assert_no_secret_leak(payload)
|
||||
|
||||
def test_dict_with_nested_secret_no_serializable(self) -> None:
|
||||
"""Dict with nested secret-like dict, no `Serializable` objects."""
|
||||
payload = {"outer": {"inner": MALICIOUS_SECRET_DICT}}
|
||||
_assert_no_secret_leak(payload)
|
||||
|
||||
def test_dict_with_secret_in_list(self) -> None:
|
||||
"""Dict with secret-like dict in a list."""
|
||||
payload = {"items": [MALICIOUS_SECRET_DICT]}
|
||||
_assert_no_secret_leak(payload)
|
||||
|
||||
def test_dict_mimicking_lc_constructor_with_secret(self) -> None:
|
||||
"""Dict that looks like an LC constructor containing a secret."""
|
||||
payload = {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": ["langchain_core", "messages", "ai", "AIMessage"],
|
||||
"kwargs": {
|
||||
"content": "Hello",
|
||||
"additional_kwargs": {"secret": MALICIOUS_SECRET_DICT},
|
||||
},
|
||||
}
|
||||
_assert_no_secret_leak(payload)
|
||||
|
||||
|
||||
class TestPydanticModelTopLevel:
|
||||
"""Tests with Pydantic models (non-`Serializable`) at the top level."""
|
||||
|
||||
def test_pydantic_model_with_serializable_containing_secret(self) -> None:
|
||||
"""Pydantic model containing a `Serializable` with secret-like dict."""
|
||||
|
||||
class MyModel(BaseModel):
|
||||
message: Any
|
||||
|
||||
msg = HumanMessage(
|
||||
content="Hello",
|
||||
additional_kwargs={"data": MALICIOUS_SECRET_DICT},
|
||||
)
|
||||
payload = MyModel(message=msg)
|
||||
_assert_no_secret_leak(payload)
|
||||
|
||||
def test_pydantic_model_with_secret_dict(self) -> None:
|
||||
"""Pydantic model containing a secret-like dict directly."""
|
||||
|
||||
class MyModel(BaseModel):
|
||||
data: dict[str, Any]
|
||||
|
||||
payload = MyModel(data=MALICIOUS_SECRET_DICT)
|
||||
_assert_no_secret_leak(payload)
|
||||
|
||||
# Test treatment of "parsed" in additional_kwargs
|
||||
msg = AIMessage(content=[], additional_kwargs={"parsed": payload})
|
||||
gen = ChatGeneration(message=msg)
|
||||
_assert_no_secret_leak(gen)
|
||||
round_trip = load(dumpd(gen))
|
||||
assert MyModel(**(round_trip.message.additional_kwargs["parsed"])) == payload
|
||||
|
||||
def test_pydantic_model_with_nested_secret(self) -> None:
|
||||
"""Pydantic model with nested secret-like dict."""
|
||||
|
||||
class MyModel(BaseModel):
|
||||
nested: dict[str, Any]
|
||||
|
||||
payload = MyModel(nested={"inner": MALICIOUS_SECRET_DICT})
|
||||
_assert_no_secret_leak(payload)
|
||||
|
||||
|
||||
class TestNonSerializableClassTopLevel:
|
||||
"""Tests with classes at the top level."""
|
||||
|
||||
def test_custom_class_with_serializable_containing_secret(self) -> None:
|
||||
"""Custom class containing a `Serializable` with secret-like dict."""
|
||||
|
||||
class MyClass:
|
||||
def __init__(self, message: Any) -> None:
|
||||
self.message = message
|
||||
|
||||
msg = HumanMessage(
|
||||
content="Hello",
|
||||
additional_kwargs={"data": MALICIOUS_SECRET_DICT},
|
||||
)
|
||||
payload = MyClass(message=msg)
|
||||
# This will serialize as not_implemented, but let's verify no leak
|
||||
_assert_no_secret_leak(payload)
|
||||
|
||||
def test_custom_class_with_secret_dict(self) -> None:
|
||||
"""Custom class containing a secret-like dict directly."""
|
||||
|
||||
class MyClass:
|
||||
def __init__(self, data: dict[str, Any]) -> None:
|
||||
self.data = data
|
||||
|
||||
payload = MyClass(data=MALICIOUS_SECRET_DICT)
|
||||
_assert_no_secret_leak(payload)
|
||||
|
||||
|
||||
class TestDumpdInKwargs:
|
||||
"""Tests for the specific pattern of `dumpd()` result stored in kwargs."""
|
||||
|
||||
def test_dumpd_human_message_in_ai_message_kwargs(self) -> None:
|
||||
"""`AIMessage` with `dumpd(HumanMessage)` in `additional_kwargs`."""
|
||||
h = HumanMessage("Hello")
|
||||
a = AIMessage("foo", additional_kwargs={"bar": [dumpd(h)]})
|
||||
_assert_no_secret_leak(a)
|
||||
|
||||
def test_dumpd_human_message_with_secret_in_ai_message_kwargs(self) -> None:
|
||||
"""`AIMessage` with `dumpd(HumanMessage w/ secret)` in `additional_kwargs`."""
|
||||
h = HumanMessage(
|
||||
"Hello",
|
||||
additional_kwargs={"secret": MALICIOUS_SECRET_DICT},
|
||||
)
|
||||
a = AIMessage("foo", additional_kwargs={"bar": [dumpd(h)]})
|
||||
_assert_no_secret_leak(a)
|
||||
|
||||
def test_double_dumpd_nesting(self) -> None:
|
||||
"""Double nesting: `dumpd(AIMessage(dumpd(HumanMessage)))`."""
|
||||
h = HumanMessage(
|
||||
"Hello",
|
||||
additional_kwargs={"secret": MALICIOUS_SECRET_DICT},
|
||||
)
|
||||
a = AIMessage("foo", additional_kwargs={"bar": [dumpd(h)]})
|
||||
outer = AIMessage("outer", additional_kwargs={"nested": [dumpd(a)]})
|
||||
_assert_no_secret_leak(outer)
|
||||
|
||||
|
||||
class TestRoundTrip:
|
||||
"""Tests that verify round-trip serialization preserves data structure."""
|
||||
|
||||
def test_human_message_with_secret_round_trip(self) -> None:
|
||||
"""Verify secret-like dict is preserved as dict after round-trip."""
|
||||
msg = HumanMessage(
|
||||
content="Hello",
|
||||
additional_kwargs={"data": MALICIOUS_SECRET_DICT},
|
||||
)
|
||||
|
||||
serialized = dumpd(msg)
|
||||
deserialized = load(serialized, secrets_from_env=True)
|
||||
|
||||
# The secret-like dict should be preserved as a plain dict
|
||||
assert deserialized.additional_kwargs["data"] == MALICIOUS_SECRET_DICT
|
||||
assert isinstance(deserialized.additional_kwargs["data"], dict)
|
||||
|
||||
def test_document_with_secret_round_trip(self) -> None:
|
||||
"""Verify secret-like dict in `Document` metadata is preserved."""
|
||||
doc = Document(
|
||||
page_content="Hello",
|
||||
metadata={"data": MALICIOUS_SECRET_DICT},
|
||||
)
|
||||
|
||||
serialized = dumpd(doc)
|
||||
deserialized = load(
|
||||
serialized, secrets_from_env=True, allowed_objects=[Document]
|
||||
)
|
||||
|
||||
# The secret-like dict should be preserved as a plain dict
|
||||
assert deserialized.metadata["data"] == MALICIOUS_SECRET_DICT
|
||||
assert isinstance(deserialized.metadata["data"], dict)
|
||||
|
||||
def test_plain_dict_with_secret_round_trip(self) -> None:
|
||||
"""Verify secret-like dict in plain dict is preserved."""
|
||||
payload = {"data": MALICIOUS_SECRET_DICT}
|
||||
|
||||
serialized = dumpd(payload)
|
||||
deserialized = load(serialized, secrets_from_env=True)
|
||||
|
||||
# The secret-like dict should be preserved as a plain dict
|
||||
assert deserialized["data"] == MALICIOUS_SECRET_DICT
|
||||
assert isinstance(deserialized["data"], dict)
|
||||
|
||||
|
||||
class TestEscapingEfficiency:
    """Tests that escaping doesn't cause excessive nesting."""

    def test_no_triple_escaping(self) -> None:
        """Verify dumpd doesn't cause triple/multiple escaping."""
        h = HumanMessage(
            "Hello",
            additional_kwargs={"bar": [MALICIOUS_SECRET_DICT]},
        )
        a = AIMessage("foo", additional_kwargs={"bar": [dumpd(h)]})
        d = dumpd(a)

        serialized = json.dumps(d)
        # Count nested escape markers -
        # should be max 2 (one for HumanMessage, one for secret)
        # Not 3+ which would indicate re-escaping of already-escaped content
        escape_count = len(re.findall(r"__lc_escaped__", serialized))

        # The HumanMessage dict gets escaped (1), the secret inside gets escaped (1)
        # Total should be 2, not 4 (which would mean triple nesting)
        assert escape_count <= 2, (
            f"Found {escape_count} escape markers, expected <= 2. "
            f"This indicates unnecessary re-escaping.\n{serialized}"
        )

    def test_double_nesting_no_quadruple_escape(self) -> None:
        """Verify double dumpd nesting doesn't explode escape markers."""
        h = HumanMessage(
            "Hello",
            additional_kwargs={"secret": MALICIOUS_SECRET_DICT},
        )
        a = AIMessage("middle", additional_kwargs={"nested": [dumpd(h)]})
        outer = AIMessage("outer", additional_kwargs={"deep": [dumpd(a)]})
        d = dumpd(outer)

        serialized = json.dumps(d)
        escape_count = len(re.findall(r"__lc_escaped__", serialized))

        # Should be:
        # outer escapes middle (1),
        # middle escapes h (1),
        # h escapes secret (1) = 3
        # Not 6+ which would indicate re-escaping
        assert escape_count <= 3, (
            f"Found {escape_count} escape markers, expected <= 3. "
            f"This indicates unnecessary re-escaping."
        )


class TestConstructorInjection:
    """Tests for constructor-type injection (not just secrets)."""

    def test_constructor_in_metadata_not_instantiated(self) -> None:
        """Verify constructor-like dict in metadata is not instantiated."""
        malicious_constructor = {
            "lc": 1,
            "type": "constructor",
            "id": ["langchain_core", "messages", "ai", "AIMessage"],
            "kwargs": {"content": "injected"},
        }

        doc = Document(
            page_content="Hello",
            metadata={"data": malicious_constructor},
        )

        serialized = dumpd(doc)
        deserialized = load(
            serialized,
            secrets_from_env=True,
            allowed_objects=[Document, AIMessage],
        )

        # The constructor-like dict should be a plain dict, NOT an AIMessage
        assert isinstance(deserialized.metadata["data"], dict)
        assert deserialized.metadata["data"] == malicious_constructor

    def test_constructor_in_content_not_instantiated(self) -> None:
        """Verify constructor-like dict in message content is not instantiated."""
        malicious_constructor = {
            "lc": 1,
            "type": "constructor",
            "id": ["langchain_core", "messages", "human", "HumanMessage"],
            "kwargs": {"content": "injected"},
        }

        msg = AIMessage(
            content="Hello",
            additional_kwargs={"nested": malicious_constructor},
        )

        serialized = dumpd(msg)
        deserialized = load(
            serialized,
            secrets_from_env=True,
            allowed_objects=[AIMessage, HumanMessage],
        )

        # The constructor-like dict should be a plain dict, NOT a HumanMessage
        assert isinstance(deserialized.additional_kwargs["nested"], dict)
        assert deserialized.additional_kwargs["nested"] == malicious_constructor


def test_allowed_objects() -> None:
    # Core object
    msg = AIMessage(content="foo")
    serialized = dumpd(msg)
    assert load(serialized) == msg
    assert load(serialized, allowed_objects=[AIMessage]) == msg
    assert load(serialized, allowed_objects="core") == msg

    with pytest.raises(ValueError, match="not allowed"):
        load(serialized, allowed_objects=[])
    with pytest.raises(ValueError, match="not allowed"):
        load(serialized, allowed_objects=[Document])
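
A condensed sketch of the allowlist behavior `test_allowed_objects` exercises, using only the `langchain_core.load` calls that appear in the tests above (the `payload` name is just illustrative):

```python
from langchain_core.load import dumpd, load
from langchain_core.messages import AIMessage

# Serialize a trusted object, then deserialize it against an explicit allowlist.
payload = dumpd(AIMessage(content="hi"))
msg = load(payload, allowed_objects=[AIMessage])  # only AIMessage may be constructed
assert msg == AIMessage(content="hi")

# Constructors outside the allowlist are rejected with a ValueError at load time.
try:
    load(payload, allowed_objects=[])
except ValueError as err:
    print(err)  # message contains "not allowed"
```
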
@@ -1,12 +1,19 @@
import json
from typing import Any

import pytest
from pydantic import BaseModel, ConfigDict, Field, SecretStr

from langchain_core.load import Serializable, dumpd, dumps, load
from langchain_core.documents import Document
from langchain_core.load import InitValidator, Serializable, dumpd, dumps, load, loads
from langchain_core.load.serializable import _is_field_useful
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    PromptTemplate,
)


class NonBoolObj:
@@ -145,10 +152,17 @@ def test_simple_deserialization() -> None:
        "lc": 1,
        "type": "constructor",
    }
    new_foo = load(serialized_foo, valid_namespaces=["tests"])
    new_foo = load(serialized_foo, allowed_objects=[Foo], valid_namespaces=["tests"])
    assert new_foo == foo


def test_disallowed_deserialization() -> None:
    foo = Foo(bar=1, baz="hello")
    serialized_foo = dumpd(foo)
    with pytest.raises(ValueError, match="not allowed"):
        load(serialized_foo, allowed_objects=[], valid_namespaces=["tests"])


class Foo2(Serializable):
    bar: int
    baz: str
@@ -170,6 +184,7 @@ def test_simple_deserialization_with_additional_imports() -> None:
    }
    new_foo = load(
        serialized_foo,
        allowed_objects=[Foo2],
        valid_namespaces=["tests"],
        additional_import_mappings={
            ("tests", "unit_tests", "load", "test_serializable", "Foo"): (
@@ -223,7 +238,7 @@ def test_serialization_with_pydantic() -> None:
        )
    )
    ser = dumpd(llm_response)
    deser = load(ser)
    deser = load(ser, allowed_objects=[ChatGeneration, AIMessage])
    assert isinstance(deser, ChatGeneration)
    assert deser.message.content
    assert deser.message.additional_kwargs["parsed"] == my_model.model_dump()
@@ -260,8 +275,8 @@ def test_serialization_with_ignore_unserializable_fields() -> None:
            ]
        ]
    }
    ser = dumpd(data)
    deser = load(ser, ignore_unserializable_fields=True)
    # Load directly (no dumpd - this is already serialized data)
    deser = load(data, allowed_objects=[AIMessage], ignore_unserializable_fields=True)
    assert deser == {
        "messages": [
            [
@@ -365,3 +380,514 @@ def test_dumps_mixed_data_structure() -> None:
    # Primitives should remain unchanged
    assert parsed["list"] == [1, 2, {"nested": "value"}]
    assert parsed["primitive"] == "string"


def test_document_normal_metadata_allowed() -> None:
    """Test that `Document` metadata without `'lc'` key works fine."""
    doc = Document(
        page_content="Hello world",
        metadata={"source": "test.txt", "page": 1, "nested": {"key": "value"}},
    )
    serialized = dumpd(doc)

    loaded = load(serialized, allowed_objects=[Document])
    assert loaded.page_content == "Hello world"

    expected = {"source": "test.txt", "page": 1, "nested": {"key": "value"}}
    assert loaded.metadata == expected


class TestEscaping:
    """Tests that escape-based serialization prevents injection attacks.

    When user data contains an `'lc'` key, it's escaped during serialization
    (wrapped in `{"__lc_escaped__": ...}`). During deserialization, escaped
    dicts are unwrapped and returned as plain dicts - NOT instantiated as
    LC objects.
    """

    def test_document_metadata_with_lc_key_escaped(self) -> None:
        """Test that `Document` metadata with `'lc'` key round-trips as plain dict."""
        # User data that looks like an LC constructor - should be escaped, not executed
        suspicious_metadata = {"lc": 1, "type": "constructor", "id": ["some", "module"]}
        doc = Document(page_content="test", metadata=suspicious_metadata)

        # Serialize - should escape the metadata
        serialized = dumpd(doc)
        assert serialized["kwargs"]["metadata"] == {
            "__lc_escaped__": suspicious_metadata
        }

        # Deserialize - should restore original metadata as plain dict
        loaded = load(serialized, allowed_objects=[Document])
        assert loaded.metadata == suspicious_metadata  # Plain dict, not instantiated

    def test_document_metadata_with_nested_lc_key_escaped(self) -> None:
        """Test that nested `'lc'` key in `Document` metadata is escaped."""
        suspicious_nested = {"lc": 1, "type": "constructor", "id": ["some", "module"]}
        doc = Document(page_content="test", metadata={"nested": suspicious_nested})

        serialized = dumpd(doc)
        # The nested dict with 'lc' key should be escaped
        assert serialized["kwargs"]["metadata"]["nested"] == {
            "__lc_escaped__": suspicious_nested
        }

        loaded = load(serialized, allowed_objects=[Document])
        assert loaded.metadata == {"nested": suspicious_nested}

    def test_document_metadata_with_lc_key_in_list_escaped(self) -> None:
        """Test that `'lc'` key in list items within `Document` metadata is escaped."""
        suspicious_item = {"lc": 1, "type": "constructor", "id": ["some", "module"]}
        doc = Document(page_content="test", metadata={"items": [suspicious_item]})

        serialized = dumpd(doc)
        assert serialized["kwargs"]["metadata"]["items"][0] == {
            "__lc_escaped__": suspicious_item
        }

        loaded = load(serialized, allowed_objects=[Document])
        assert loaded.metadata == {"items": [suspicious_item]}

    def test_malicious_payload_not_instantiated(self) -> None:
        """Test that malicious LC-like structures in user data are NOT instantiated."""
        # An attacker might craft a payload with a valid AIMessage structure in metadata
        malicious_data = {
            "lc": 1,
            "type": "constructor",
            "id": ["langchain", "schema", "document", "Document"],
            "kwargs": {
                "page_content": "test",
                "metadata": {
                    # This looks like a valid LC object but is in escaped form
                    "__lc_escaped__": {
                        "lc": 1,
                        "type": "constructor",
                        "id": ["langchain_core", "messages", "ai", "AIMessage"],
                        "kwargs": {"content": "injected message"},
                    }
                },
            },
        }

        # Even though AIMessage is allowed, the metadata should remain as dict
        loaded = load(malicious_data, allowed_objects=[Document, AIMessage])
        assert loaded.page_content == "test"
        # The metadata is the original dict (unescaped), NOT an AIMessage instance
        assert loaded.metadata == {
            "lc": 1,
            "type": "constructor",
            "id": ["langchain_core", "messages", "ai", "AIMessage"],
            "kwargs": {"content": "injected message"},
        }
        assert not isinstance(loaded.metadata, AIMessage)

    def test_message_additional_kwargs_with_lc_key_escaped(self) -> None:
        """Test that `AIMessage` `additional_kwargs` with `'lc'` is escaped."""
        suspicious_data = {"lc": 1, "type": "constructor", "id": ["x", "y"]}
        msg = AIMessage(
            content="Hello",
            additional_kwargs={"data": suspicious_data},
        )

        serialized = dumpd(msg)
        assert serialized["kwargs"]["additional_kwargs"]["data"] == {
            "__lc_escaped__": suspicious_data
        }

        loaded = load(serialized, allowed_objects=[AIMessage])
        assert loaded.additional_kwargs == {"data": suspicious_data}

    def test_message_response_metadata_with_lc_key_escaped(self) -> None:
        """Test that `AIMessage` `response_metadata` with `'lc'` is escaped."""
        suspicious_data = {"lc": 1, "type": "constructor", "id": ["x", "y"]}
        msg = AIMessage(content="Hello", response_metadata=suspicious_data)

        serialized = dumpd(msg)
        assert serialized["kwargs"]["response_metadata"] == {
            "__lc_escaped__": suspicious_data
        }

        loaded = load(serialized, allowed_objects=[AIMessage])
        assert loaded.response_metadata == suspicious_data

    def test_double_escape_handling(self) -> None:
        """Test that data containing escape key itself is properly handled."""
        # User data that contains our escape key
        data_with_escape_key = {"__lc_escaped__": "some_value"}
        doc = Document(page_content="test", metadata=data_with_escape_key)

        serialized = dumpd(doc)
        # Should be double-escaped since it looks like an escaped dict
        assert serialized["kwargs"]["metadata"] == {
            "__lc_escaped__": {"__lc_escaped__": "some_value"}
        }

        loaded = load(serialized, allowed_objects=[Document])
        assert loaded.metadata == {"__lc_escaped__": "some_value"}


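
The escape round-trip that `TestEscaping` asserts, condensed into one sketch (same API, marker key, and allowlist as in the tests above):

```python
from langchain_core.documents import Document
from langchain_core.load import dumpd, load

# User data that merely *looks* like the LC serialization format...
suspicious = {"lc": 1, "type": "constructor", "id": ["some", "module"]}
doc = Document(page_content="test", metadata=suspicious)

# ...is wrapped under the escape marker when dumped...
serialized = dumpd(doc)
assert serialized["kwargs"]["metadata"] == {"__lc_escaped__": suspicious}

# ...and comes back as a plain dict on load, never instantiated as an LC object.
assert load(serialized, allowed_objects=[Document]).metadata == suspicious
```
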
class TestDumpdEscapesLcKeyInPlainDicts:
    """Tests that `dumpd()` escapes `'lc'` keys in plain dict kwargs."""

    def test_normal_message_not_escaped(self) -> None:
        """Test that normal `AIMessage` without `'lc'` key is not escaped."""
        msg = AIMessage(
            content="Hello",
            additional_kwargs={"tool_calls": []},
            response_metadata={"model": "gpt-4"},
        )
        serialized = dumpd(msg)
        assert serialized["kwargs"]["content"] == "Hello"
        # No escape wrappers for normal data
        assert "__lc_escaped__" not in str(serialized)

    def test_document_metadata_with_lc_key_escaped(self) -> None:
        """Test that `Document` with `'lc'` key in metadata is escaped."""
        doc = Document(
            page_content="test",
            metadata={"lc": 1, "type": "constructor"},
        )

        serialized = dumpd(doc)
        # Should be escaped, not blocked
        assert serialized["kwargs"]["metadata"] == {
            "__lc_escaped__": {"lc": 1, "type": "constructor"}
        }

    def test_document_metadata_with_nested_lc_key_escaped(self) -> None:
        """Test that `Document` with nested `'lc'` in metadata is escaped."""
        doc = Document(
            page_content="test",
            metadata={"nested": {"lc": 1}},
        )

        serialized = dumpd(doc)
        assert serialized["kwargs"]["metadata"]["nested"] == {
            "__lc_escaped__": {"lc": 1}
        }

    def test_message_additional_kwargs_with_lc_key_escaped(self) -> None:
        """Test `AIMessage` with `'lc'` in `additional_kwargs` is escaped."""
        msg = AIMessage(
            content="Hello",
            additional_kwargs={"malicious": {"lc": 1}},
        )

        serialized = dumpd(msg)
        assert serialized["kwargs"]["additional_kwargs"]["malicious"] == {
            "__lc_escaped__": {"lc": 1}
        }

    def test_message_response_metadata_with_lc_key_escaped(self) -> None:
        """Test `AIMessage` with `'lc'` in `response_metadata` is escaped."""
        msg = AIMessage(
            content="Hello",
            response_metadata={"lc": 1},
        )

        serialized = dumpd(msg)
        assert serialized["kwargs"]["response_metadata"] == {
            "__lc_escaped__": {"lc": 1}
        }


class TestInitValidator:
    """Tests for `init_validator` on `load()` and `loads()`."""

    def test_init_validator_allows_valid_kwargs(self) -> None:
        """Test that `init_validator` returning None allows deserialization."""
        msg = AIMessage(content="Hello")
        serialized = dumpd(msg)

        def allow_all(_class_path: tuple[str, ...], _kwargs: dict[str, Any]) -> None:
            pass  # Allow all by doing nothing

        loaded = load(serialized, allowed_objects=[AIMessage], init_validator=allow_all)
        assert loaded == msg

    def test_init_validator_blocks_deserialization(self) -> None:
        """Test that `init_validator` can block deserialization by raising."""
        doc = Document(page_content="test", metadata={"source": "test.txt"})
        serialized = dumpd(doc)

        def block_metadata(
            _class_path: tuple[str, ...], kwargs: dict[str, Any]
        ) -> None:
            if "metadata" in kwargs:
                msg = "Metadata not allowed"
                raise ValueError(msg)

        with pytest.raises(ValueError, match="Metadata not allowed"):
            load(serialized, allowed_objects=[Document], init_validator=block_metadata)

    def test_init_validator_receives_correct_class_path(self) -> None:
        """Test that `init_validator` receives the correct class path."""
        msg = AIMessage(content="Hello")
        serialized = dumpd(msg)

        received_class_paths: list[tuple[str, ...]] = []

        def capture_class_path(
            class_path: tuple[str, ...], _kwargs: dict[str, Any]
        ) -> None:
            received_class_paths.append(class_path)

        load(serialized, allowed_objects=[AIMessage], init_validator=capture_class_path)

        assert len(received_class_paths) == 1
        assert received_class_paths[0] == (
            "langchain",
            "schema",
            "messages",
            "AIMessage",
        )

    def test_init_validator_receives_correct_kwargs(self) -> None:
        """Test that `init_validator` receives the kwargs dict."""
        msg = AIMessage(content="Hello world", name="test_name")
        serialized = dumpd(msg)

        received_kwargs: list[dict[str, Any]] = []

        def capture_kwargs(
            _class_path: tuple[str, ...], kwargs: dict[str, Any]
        ) -> None:
            received_kwargs.append(kwargs)

        load(serialized, allowed_objects=[AIMessage], init_validator=capture_kwargs)

        assert len(received_kwargs) == 1
        assert "content" in received_kwargs[0]
        assert received_kwargs[0]["content"] == "Hello world"
        assert "name" in received_kwargs[0]
        assert received_kwargs[0]["name"] == "test_name"

    def test_init_validator_with_loads(self) -> None:
        """Test that `init_validator` works with `loads()` function."""
        doc = Document(page_content="test", metadata={"key": "value"})
        json_str = dumps(doc)

        def block_metadata(
            _class_path: tuple[str, ...], kwargs: dict[str, Any]
        ) -> None:
            if "metadata" in kwargs:
                msg = "Metadata not allowed"
                raise ValueError(msg)

        with pytest.raises(ValueError, match="Metadata not allowed"):
            loads(json_str, allowed_objects=[Document], init_validator=block_metadata)

    def test_init_validator_none_allows_all(self) -> None:
        """Test that `init_validator=None` (default) allows all kwargs."""
        msg = AIMessage(content="Hello")
        serialized = dumpd(msg)

        # Should work without init_validator
        loaded = load(serialized, allowed_objects=[AIMessage])
        assert loaded == msg

    def test_init_validator_type_alias_exists(self) -> None:
        """Test that `InitValidator` type alias is exported and usable."""

        def my_validator(_class_path: tuple[str, ...], _kwargs: dict[str, Any]) -> None:
            pass

        validator_typed: InitValidator = my_validator
        assert callable(validator_typed)

    def test_init_validator_blocks_specific_class(self) -> None:
        """Test blocking deserialization for a specific class."""
        doc = Document(page_content="test", metadata={"source": "test.txt"})
        serialized = dumpd(doc)

        def block_documents(
            class_path: tuple[str, ...], _kwargs: dict[str, Any]
        ) -> None:
            if class_path == ("langchain", "schema", "document", "Document"):
                msg = "Documents not allowed"
                raise ValueError(msg)

        with pytest.raises(ValueError, match="Documents not allowed"):
            load(serialized, allowed_objects=[Document], init_validator=block_documents)


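
A compact sketch of the validator contract `TestInitValidator` exercises above: an `InitValidator` receives the class path tuple and the constructor kwargs, returns None to allow, and raises to block. The `reject_named_messages` helper is purely illustrative:

```python
from typing import Any

from langchain_core.load import InitValidator, dumpd, load
from langchain_core.messages import AIMessage


def reject_named_messages(
    class_path: tuple[str, ...], kwargs: dict[str, Any]
) -> None:
    """Raise to veto a constructor call; returning None lets it through."""
    if kwargs.get("name") is not None:
        msg = "Named messages not allowed"
        raise ValueError(msg)


validator: InitValidator = reject_named_messages
loaded = load(
    dumpd(AIMessage(content="ok")),
    allowed_objects=[AIMessage],
    init_validator=validator,
)
assert loaded == AIMessage(content="ok")
```
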
class TestJinja2SecurityBlocking:
    """Tests blocking Jinja2 templates by default."""

    def test_fstring_template_allowed(self) -> None:
        """Test that f-string templates deserialize successfully."""
        # Serialized ChatPromptTemplate with f-string format
        serialized = {
            "lc": 1,
            "type": "constructor",
            "id": ["langchain", "prompts", "chat", "ChatPromptTemplate"],
            "kwargs": {
                "input_variables": ["name"],
                "messages": [
                    {
                        "lc": 1,
                        "type": "constructor",
                        "id": [
                            "langchain",
                            "prompts",
                            "chat",
                            "HumanMessagePromptTemplate",
                        ],
                        "kwargs": {
                            "prompt": {
                                "lc": 1,
                                "type": "constructor",
                                "id": [
                                    "langchain",
                                    "prompts",
                                    "prompt",
                                    "PromptTemplate",
                                ],
                                "kwargs": {
                                    "input_variables": ["name"],
                                    "template": "Hello {name}",
                                    "template_format": "f-string",
                                },
                            }
                        },
                    }
                ],
            },
        }

        # f-string should deserialize successfully
        loaded = load(
            serialized,
            allowed_objects=[
                ChatPromptTemplate,
                HumanMessagePromptTemplate,
                PromptTemplate,
            ],
        )
        assert isinstance(loaded, ChatPromptTemplate)
        assert loaded.input_variables == ["name"]

    def test_jinja2_template_blocked(self) -> None:
        """Test that Jinja2 templates are blocked by default."""
        # Malicious serialized payload attempting to use jinja2
        malicious_serialized = {
            "lc": 1,
            "type": "constructor",
            "id": ["langchain", "prompts", "chat", "ChatPromptTemplate"],
            "kwargs": {
                "input_variables": ["name"],
                "messages": [
                    {
                        "lc": 1,
                        "type": "constructor",
                        "id": [
                            "langchain",
                            "prompts",
                            "chat",
                            "HumanMessagePromptTemplate",
                        ],
                        "kwargs": {
                            "prompt": {
                                "lc": 1,
                                "type": "constructor",
                                "id": [
                                    "langchain",
                                    "prompts",
                                    "prompt",
                                    "PromptTemplate",
                                ],
                                "kwargs": {
                                    "input_variables": ["name"],
                                    "template": "{{ name }}",
                                    "template_format": "jinja2",
                                },
                            }
                        },
                    }
                ],
            },
        }

        # jinja2 should be blocked by default
        with pytest.raises(ValueError, match="Jinja2 templates are not allowed"):
            load(
                malicious_serialized,
                allowed_objects=[
                    ChatPromptTemplate,
                    HumanMessagePromptTemplate,
                    PromptTemplate,
                ],
            )

    def test_jinja2_blocked_standalone_prompt_template(self) -> None:
        """Test blocking Jinja2 on standalone `PromptTemplate`."""
        serialized_jinja2 = {
            "lc": 1,
            "type": "constructor",
            "id": ["langchain", "prompts", "prompt", "PromptTemplate"],
            "kwargs": {
                "input_variables": ["name"],
                "template": "{{ name }}",
                "template_format": "jinja2",
            },
        }

        serialized_fstring = {
            "lc": 1,
            "type": "constructor",
            "id": ["langchain", "prompts", "prompt", "PromptTemplate"],
            "kwargs": {
                "input_variables": ["name"],
                "template": "{name}",
                "template_format": "f-string",
            },
        }

        # f-string should work
        loaded = load(
            serialized_fstring,
            allowed_objects=[PromptTemplate],
        )
        assert isinstance(loaded, PromptTemplate)
        assert loaded.template == "{name}"

        # jinja2 should be blocked by default
        with pytest.raises(ValueError, match="Jinja2 templates are not allowed"):
            load(
                serialized_jinja2,
                allowed_objects=[PromptTemplate],
            )

    def test_jinja2_blocked_by_default(self) -> None:
        """Test that Jinja2 templates are blocked by default."""
        serialized_jinja2 = {
            "lc": 1,
            "type": "constructor",
            "id": ["langchain", "prompts", "prompt", "PromptTemplate"],
            "kwargs": {
                "input_variables": ["name"],
                "template": "{{ name }}",
                "template_format": "jinja2",
            },
        }

        serialized_fstring = {
            "lc": 1,
            "type": "constructor",
            "id": ["langchain", "prompts", "prompt", "PromptTemplate"],
            "kwargs": {
                "input_variables": ["name"],
                "template": "{name}",
                "template_format": "f-string",
            },
        }

        # f-string should work
        loaded = load(serialized_fstring, allowed_objects=[PromptTemplate])
        assert isinstance(loaded, PromptTemplate)
        assert loaded.template == "{name}"

        # jinja2 should be blocked by default
        with pytest.raises(ValueError, match="Jinja2 templates are not allowed"):
            load(serialized_jinja2, allowed_objects=[PromptTemplate])

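
For quick reference, the standalone `PromptTemplate` case from `TestJinja2SecurityBlocking` boiled down to its essentials (identical payload shape and error message as in the tests above):

```python
import pytest

from langchain_core.load import load
from langchain_core.prompts import PromptTemplate

jinja2_payload = {
    "lc": 1,
    "type": "constructor",
    "id": ["langchain", "prompts", "prompt", "PromptTemplate"],
    "kwargs": {
        "input_variables": ["name"],
        "template": "{{ name }}",
        "template_format": "jinja2",
    },
}

# Jinja2 templates are refused by default; f-string templates load normally.
with pytest.raises(ValueError, match="Jinja2 templates are not allowed"):
    load(jinja2_payload, allowed_objects=[PromptTemplate])
```
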
@@ -47,7 +47,7 @@ def test_serdes_message() -> None:
    }
    actual = dumpd(msg)
    assert actual == expected
    assert load(actual) == msg
    assert load(actual, allowed_objects=[AIMessage]) == msg


def test_serdes_message_chunk() -> None:
@@ -102,7 +102,7 @@ def test_serdes_message_chunk() -> None:
    }
    actual = dumpd(chunk)
    assert actual == expected
    assert load(actual) == chunk
    assert load(actual, allowed_objects=[AIMessageChunk]) == chunk


def test_add_usage_both_none() -> None:

@@ -1123,7 +1123,7 @@ def test_data_prompt_template_deserializable() -> None:
)
]
)
)
),
)


@@ -31,4 +31,4 @@ def test_deserialize_legacy() -> None:
    expected = DictPromptTemplate(
        template={"type": "audio", "audio": "{audio_data}"}, template_format="f-string"
    )
    assert load(ser) == expected
    assert load(ser, allowed_objects=[DictPromptTemplate]) == expected

@@ -11,7 +11,7 @@ def test_image_prompt_template_deserializable() -> None:
ChatPromptTemplate.from_messages(
    [("system", [{"type": "image", "image_url": "{img}"}])]
)
)
),
)


@@ -105,5 +105,5 @@ def test_image_prompt_template_deserializable_old() -> None:
"input_variables": ["img", "input"],
},
}
)
),
)

@@ -82,7 +82,6 @@ def test_structured_prompt_dict() -> None:
    assert loads(dumps(prompt)).model_dump() == prompt.model_dump()

    chain = loads(dumps(prompt)) | model

    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}


@@ -925,7 +925,7 @@ def test_tool_message_serdes() -> None:
        },
    }
    assert dumpd(message) == ser_message
    assert load(dumpd(message)) == message
    assert load(dumpd(message), allowed_objects=[ToolMessage]) == message


class BadObject:
@@ -954,7 +954,7 @@ def test_tool_message_ser_non_serializable() -> None:
    }
    assert dumpd(message) == ser_message
    with pytest.raises(NotImplementedError):
        load(dumpd(ser_message))
        load(dumpd(message), allowed_objects=[ToolMessage])


def test_tool_message_to_dict() -> None:

@@ -4,7 +4,7 @@ from __future__ import annotations

import json
from collections.abc import Sequence
from typing import Any
from typing import Any, Literal

from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
@@ -139,7 +139,8 @@ def pull(
    if hasattr(client, "pull_repo"):
        # >= 0.1.15
        res_dict = client.pull_repo(owner_repo_commit)
        obj = loads(json.dumps(res_dict["manifest"]))
        allowed_objects: Literal["all", "core"] = "all" if include_model else "core"
        obj = loads(json.dumps(res_dict["manifest"]), allowed_objects=allowed_objects)
        if isinstance(obj, BasePromptTemplate):
            if obj.metadata is None:
                obj.metadata = {}

@@ -7,6 +7,8 @@ from pytest_benchmark.fixture import BenchmarkFixture # type: ignore[import-unt

from langchain_anthropic import ChatAnthropic

_MODEL = "claude-3-haiku-20240307"


class TestAnthropicStandard(ChatModelUnitTests):
    """Use the standard chat model unit tests against the `ChatAnthropic` class."""
@@ -17,7 +19,15 @@ class TestAnthropicStandard(ChatModelUnitTests):

    @property
    def chat_model_params(self) -> dict:
        return {"model": "claude-3-haiku-20240307"}
        return {"model": _MODEL}

    @property
    def init_from_env_params(self) -> tuple[dict, dict, dict]:
        return (
            {"ANTHROPIC_API_KEY": "test"},
            {"model": _MODEL},
            {"anthropic_api_key": "test"},
        )


@pytest.mark.benchmark

6
libs/partners/anthropic/uv.lock
generated
@@ -495,7 +495,7 @@ wheels = [

[[package]]
name = "langchain"
version = "1.1.3"
version = "1.2.0"
source = { editable = "../../langchain_v1" }
dependencies = [
    { name = "langchain-core" },
@@ -642,7 +642,7 @@ typing = [

[[package]]
name = "langchain-core"
version = "1.2.0"
version = "1.2.4"
source = { editable = "../../core" }
dependencies = [
    { name = "jsonpatch" },
@@ -702,7 +702,7 @@ typing = [

[[package]]
name = "langchain-tests"
version = "1.1.0"
version = "1.1.1"
source = { editable = "../../standard-tests" }
dependencies = [
    { name = "httpx" },

@@ -274,6 +274,7 @@ def test_groq_serialization() -> None:
        dump,
        valid_namespaces=["langchain_groq"],
        secrets_map={"GROQ_API_KEY": api_key2},
        allowed_objects="all",
    )

    assert type(llm2) is ChatGroq

@@ -1360,7 +1360,7 @@ def test_structured_outputs_parser() -> None:
        partial(_oai_structured_outputs_parser, schema=GenerateUsername)
    )
    serialized = dumps(llm_output)
    deserialized = loads(serialized)
    deserialized = loads(serialized, allowed_objects=[ChatGeneration, AIMessage])
    assert isinstance(deserialized, ChatGeneration)
    result = output_parser.invoke(cast(AIMessage, deserialized.message))
    assert result == parsed_response

@@ -1,4 +1,7 @@
from langchain_core.load import dumpd, dumps, load, loads
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import RunnableSequence

from langchain_openai import ChatOpenAI, OpenAI

@@ -6,7 +9,11 @@ from langchain_openai import ChatOpenAI, OpenAI
def test_loads_openai_llm() -> None:
    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello", top_p=0.8)  # type: ignore[call-arg]
    llm_string = dumps(llm)
    llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
    llm2 = loads(
        llm_string,
        secrets_map={"OPENAI_API_KEY": "hello"},
        allowed_objects=[OpenAI],
    )

    assert llm2.dict() == llm.dict()
    llm_string_2 = dumps(llm2)
@@ -17,7 +24,11 @@ def test_loads_openai_llm() -> None:
def test_load_openai_llm() -> None:
    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")  # type: ignore[call-arg]
    llm_obj = dumpd(llm)
    llm2 = load(llm_obj, secrets_map={"OPENAI_API_KEY": "hello"})
    llm2 = load(
        llm_obj,
        secrets_map={"OPENAI_API_KEY": "hello"},
        allowed_objects=[OpenAI],
    )

    assert llm2.dict() == llm.dict()
    assert dumpd(llm2) == llm_obj
@@ -27,7 +38,11 @@ def test_load_openai_llm() -> None:
def test_loads_openai_chat() -> None:
    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.5, openai_api_key="hello")  # type: ignore[call-arg]
    llm_string = dumps(llm)
    llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
    llm2 = loads(
        llm_string,
        secrets_map={"OPENAI_API_KEY": "hello"},
        allowed_objects=[ChatOpenAI],
    )

    assert llm2.dict() == llm.dict()
    llm_string_2 = dumps(llm2)
@@ -38,8 +53,85 @@ def test_loads_openai_chat() -> None:
def test_load_openai_chat() -> None:
    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.5, openai_api_key="hello")  # type: ignore[call-arg]
    llm_obj = dumpd(llm)
    llm2 = load(llm_obj, secrets_map={"OPENAI_API_KEY": "hello"})
    llm2 = load(
        llm_obj,
        secrets_map={"OPENAI_API_KEY": "hello"},
        allowed_objects=[ChatOpenAI],
    )

    assert llm2.dict() == llm.dict()
    assert dumpd(llm2) == llm_obj
    assert isinstance(llm2, ChatOpenAI)


def test_loads_runnable_sequence_prompt_model() -> None:
    """Test serialization/deserialization of a chain:

    `prompt | model (RunnableSequence)`
    """
    prompt = ChatPromptTemplate.from_messages([("user", "Hello, {name}!")])
    model = ChatOpenAI(model="gpt-4o-mini", temperature=0.5, openai_api_key="hello")  # type: ignore[call-arg]
    chain = prompt | model

    # Verify the chain is a RunnableSequence
    assert isinstance(chain, RunnableSequence)

    # Serialize
    chain_string = dumps(chain)

    # Deserialize
    # (ChatPromptTemplate contains HumanMessagePromptTemplate and PromptTemplate)
    chain2 = loads(
        chain_string,
        secrets_map={"OPENAI_API_KEY": "hello"},
        allowed_objects=[
            RunnableSequence,
            ChatPromptTemplate,
            HumanMessagePromptTemplate,
            PromptTemplate,
            ChatOpenAI,
        ],
    )

    # Verify structure
    assert isinstance(chain2, RunnableSequence)
    assert isinstance(chain2.first, ChatPromptTemplate)
    assert isinstance(chain2.last, ChatOpenAI)

    # Verify round-trip serialization
    assert dumps(chain2) == chain_string


def test_load_runnable_sequence_prompt_model() -> None:
    """Test load() with a chain:

    `prompt | model (RunnableSequence)`.
    """
    prompt = ChatPromptTemplate.from_messages([("user", "Tell me about {topic}")])
    model = ChatOpenAI(model="gpt-4o-mini", temperature=0.7, openai_api_key="hello")  # type: ignore[call-arg]
    chain = prompt | model

    # Serialize
    chain_obj = dumpd(chain)

    # Deserialize
    # (ChatPromptTemplate contains HumanMessagePromptTemplate and PromptTemplate)
    chain2 = load(
        chain_obj,
        secrets_map={"OPENAI_API_KEY": "hello"},
        allowed_objects=[
            RunnableSequence,
            ChatPromptTemplate,
            HumanMessagePromptTemplate,
            PromptTemplate,
            ChatOpenAI,
        ],
    )

    # Verify structure
    assert isinstance(chain2, RunnableSequence)
    assert isinstance(chain2.first, ChatPromptTemplate)
    assert isinstance(chain2.last, ChatOpenAI)

    # Verify round-trip serialization
    assert dumpd(chain2) == chain_obj

4
libs/partners/xai/uv.lock
generated
@@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = ">=3.10.0, <4.0.0"
resolution-markers = [
    "python_full_version >= '3.13' and platform_python_implementation == 'PyPy'",
@@ -621,7 +621,7 @@ wheels = [

[[package]]
name = "langchain-core"
version = "1.2.3"
version = "1.2.4"
source = { editable = "../../core" }
dependencies = [
    { name = "jsonpatch" },

@@ -1126,7 +1126,10 @@ class ChatModelUnitTests(ChatModelTests):
        assert (
            model.dict()
            == load(
                dumpd(model), valid_namespaces=model.get_lc_namespace()[:1]
                dumpd(model),
                valid_namespaces=model.get_lc_namespace()[:1],
                allowed_objects="all",
                secrets_from_env=True,
            ).dict()
        )