diff --git a/libs/core/langchain_core/load/__init__.py b/libs/core/langchain_core/load/__init__.py index baeb72a488a..e656fedaa19 100644 --- a/libs/core/langchain_core/load/__init__.py +++ b/libs/core/langchain_core/load/__init__.py @@ -6,7 +6,7 @@ from langchain_core._import_utils import import_attr if TYPE_CHECKING: from langchain_core.load.dump import dumpd, dumps - from langchain_core.load.load import loads + from langchain_core.load.load import InitValidator, loads from langchain_core.load.serializable import Serializable # Unfortunately, we have to eagerly import load from langchain_core/load/load.py @@ -15,11 +15,19 @@ if TYPE_CHECKING: # the `from langchain_core.load.load import load` absolute import should also work. from langchain_core.load.load import load -__all__ = ("Serializable", "dumpd", "dumps", "load", "loads") +__all__ = ( + "InitValidator", + "Serializable", + "dumpd", + "dumps", + "load", + "loads", +) _dynamic_imports = { "dumpd": "dump", "dumps": "dump", + "InitValidator": "load", "loads": "load", "Serializable": "serializable", } diff --git a/libs/core/langchain_core/load/_validation.py b/libs/core/langchain_core/load/_validation.py new file mode 100644 index 00000000000..a8adc36f040 --- /dev/null +++ b/libs/core/langchain_core/load/_validation.py @@ -0,0 +1,176 @@ +"""Validation utilities for LangChain serialization. + +Provides escape-based protection against injection attacks in serialized objects. The +approach uses an allowlist design: only dicts explicitly produced by +`Serializable.to_json()` are treated as LC objects during deserialization. + +## How escaping works + +During serialization, plain dicts (user data) that contain an `'lc'` key are wrapped: + +```python +{"lc": 1, ...} # user data that looks like LC object +# becomes: +{"__lc_escaped__": {"lc": 1, ...}} +``` + +During deserialization, escaped dicts are unwrapped and returned as plain dicts, +NOT instantiated as LC objects. 
+""" + +from typing import Any + +_LC_ESCAPED_KEY = "__lc_escaped__" +"""Sentinel key used to mark escaped user dicts during serialization. + +When a plain dict contains 'lc' key (which could be confused with LC objects), +we wrap it as {"__lc_escaped__": {...original...}}. +""" + + +def _needs_escaping(obj: dict[str, Any]) -> bool: + """Check if a dict needs escaping to prevent confusion with LC objects. + + A dict needs escaping if: + + 1. It has an `'lc'` key (could be confused with LC serialization format) + 2. It has only the escape key (would be mistaken for an escaped dict) + """ + return "lc" in obj or (len(obj) == 1 and _LC_ESCAPED_KEY in obj) + + +def _escape_dict(obj: dict[str, Any]) -> dict[str, Any]: + """Wrap a dict in the escape marker. + + Example: + ```python + {"key": "value"} # becomes {"__lc_escaped__": {"key": "value"}} + ``` + """ + return {_LC_ESCAPED_KEY: obj} + + +def _is_escaped_dict(obj: dict[str, Any]) -> bool: + """Check if a dict is an escaped user dict. + + Example: + ```python + {"__lc_escaped__": {...}} # is an escaped dict + ``` + """ + return len(obj) == 1 and _LC_ESCAPED_KEY in obj + + +def _serialize_value(obj: Any) -> Any: + """Serialize a value with escaping of user dicts. + + Called recursively on kwarg values to escape any plain dicts that could be confused + with LC objects. + + Args: + obj: The value to serialize. + + Returns: + The serialized value with user dicts escaped as needed. + """ + from langchain_core.load.serializable import ( # noqa: PLC0415 + Serializable, + to_json_not_implemented, + ) + + if isinstance(obj, Serializable): + # This is an LC object - serialize it properly (not escaped) + return _serialize_lc_object(obj) + if isinstance(obj, dict): + if not all(isinstance(k, (str, int, float, bool, type(None))) for k in obj): + # if keys are not json serializable + return to_json_not_implemented(obj) + # Check if dict needs escaping BEFORE recursing into values. 
+ # If it needs escaping, wrap it as-is - the contents are user data that + # will be returned as-is during deserialization (no instantiation). + # This prevents re-escaping of already-escaped nested content. + if _needs_escaping(obj): + return _escape_dict(obj) + # Safe dict (no 'lc' key) - recurse into values + return {k: _serialize_value(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_serialize_value(item) for item in obj] + if isinstance(obj, (str, int, float, bool, type(None))): + return obj + + # Non-JSON-serializable object (datetime, custom objects, etc.) + return to_json_not_implemented(obj) + + +def _is_lc_secret(obj: Any) -> bool: + """Check if an object is a LangChain secret marker.""" + expected_num_keys = 3 + return ( + isinstance(obj, dict) + and obj.get("lc") == 1 + and obj.get("type") == "secret" + and "id" in obj + and len(obj) == expected_num_keys + ) + + +def _serialize_lc_object(obj: Any) -> dict[str, Any]: + """Serialize a `Serializable` object with escaping of user data in kwargs. + + Args: + obj: The `Serializable` object to serialize. + + Returns: + The serialized dict with user data in kwargs escaped as needed. + + Note: + Kwargs values are processed with `_serialize_value` to escape user data (like + metadata) that contains `'lc'` keys. Secret fields (from `lc_secrets`) are + skipped because `to_json()` replaces their values with secret markers. 
+ """ + from langchain_core.load.serializable import Serializable # noqa: PLC0415 + + if not isinstance(obj, Serializable): + msg = f"Expected Serializable, got {type(obj)}" + raise TypeError(msg) + + serialized: dict[str, Any] = dict(obj.to_json()) + + # Process kwargs to escape user data that could be confused with LC objects + # Skip secret fields - to_json() already converted them to secret markers + if serialized.get("type") == "constructor" and "kwargs" in serialized: + serialized["kwargs"] = { + k: v if _is_lc_secret(v) else _serialize_value(v) + for k, v in serialized["kwargs"].items() + } + + return serialized + + +def _unescape_value(obj: Any) -> Any: + """Unescape a value, processing escape markers in dict values and lists. + + When an escaped dict is encountered (`{"__lc_escaped__": ...}`), it's + unwrapped and the contents are returned AS-IS (no further processing). + The contents represent user data that should not be modified. + + For regular dicts and lists, we recurse to find any nested escape markers. + + Args: + obj: The value to unescape. + + Returns: + The unescaped value. + """ + if isinstance(obj, dict): + if _is_escaped_dict(obj): + # Unwrap and return the user data as-is (no further unescaping). + # The contents are user data that may contain more escape keys, + # but those are part of the user's actual data. + return obj[_LC_ESCAPED_KEY] + + # Regular dict - recurse into values to find nested escape markers + return {k: _unescape_value(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_unescape_value(item) for item in obj] + return obj diff --git a/libs/core/langchain_core/load/dump.py b/libs/core/langchain_core/load/dump.py index 4cb9ca59892..07bc3099b6c 100644 --- a/libs/core/langchain_core/load/dump.py +++ b/libs/core/langchain_core/load/dump.py @@ -1,10 +1,26 @@ -"""Dump objects to json.""" +"""Serialize LangChain objects to JSON. 
+ +Provides `dumps` (to JSON string) and `dumpd` (to dict) for serializing +`Serializable` objects. + +## Escaping + +During serialization, plain dicts (user data) that contain an `'lc'` key are escaped +by wrapping them: `{"__lc_escaped__": {...original...}}`. This prevents injection +attacks where malicious data could trick the deserializer into instantiating +arbitrary classes. The escape marker is removed during deserialization. + +This is an allowlist approach: only dicts explicitly produced by +`Serializable.to_json()` are treated as LC objects; everything else is escaped if it +could be confused with the LC format. +""" import json from typing import Any from pydantic import BaseModel +from langchain_core.load._validation import _serialize_value from langchain_core.load.serializable import Serializable, to_json_not_implemented from langchain_core.messages import AIMessage from langchain_core.outputs import ChatGeneration @@ -25,6 +41,20 @@ def default(obj: Any) -> Any: def _dump_pydantic_models(obj: Any) -> Any: + """Convert nested Pydantic models to dicts for JSON serialization. + + Handles the special case where a `ChatGeneration` contains an `AIMessage` + with a parsed Pydantic model in `additional_kwargs["parsed"]`. Since + Pydantic models aren't directly JSON serializable, this converts them to + dicts. + + Args: + obj: The object to process. + + Returns: + A copy of the object with nested Pydantic models converted to dicts, or + the original object unchanged if no conversion was needed. + """ if ( isinstance(obj, ChatGeneration) and isinstance(obj.message, AIMessage) @@ -40,10 +70,17 @@ def _dump_pydantic_models(obj: Any) -> Any: def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str: """Return a JSON string representation of an object. + Note: + Plain dicts containing an `'lc'` key are automatically escaped to prevent + confusion with LC serialization format. The escape marker is removed during + deserialization. 
+ Args: obj: The object to dump. - pretty: Whether to pretty print the json. If `True`, the json will be - indented with 2 spaces (if no indent is provided as part of `kwargs`). + pretty: Whether to pretty print the json. + + If `True`, the json will be indented by either 2 spaces or the amount + provided in the `indent` kwarg. **kwargs: Additional arguments to pass to `json.dumps` Returns: @@ -55,28 +92,29 @@ def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str: if "default" in kwargs: msg = "`default` should not be passed to dumps" raise ValueError(msg) - try: - obj = _dump_pydantic_models(obj) - if pretty: - indent = kwargs.pop("indent", 2) - return json.dumps(obj, default=default, indent=indent, **kwargs) - return json.dumps(obj, default=default, **kwargs) - except TypeError: - if pretty: - indent = kwargs.pop("indent", 2) - return json.dumps(to_json_not_implemented(obj), indent=indent, **kwargs) - return json.dumps(to_json_not_implemented(obj), **kwargs) + + obj = _dump_pydantic_models(obj) + serialized = _serialize_value(obj) + + if pretty: + indent = kwargs.pop("indent", 2) + return json.dumps(serialized, indent=indent, **kwargs) + return json.dumps(serialized, **kwargs) def dumpd(obj: Any) -> Any: """Return a dict representation of an object. + Note: + Plain dicts containing an `'lc'` key are automatically escaped to prevent + confusion with LC serialization format. The escape marker is removed during + deserialization. + Args: obj: The object to dump. Returns: Dictionary that can be serialized to json using `json.dumps`. """ - # Unfortunately this function is not as efficient as it could be because it first - # dumps the object to a json string and then loads it back into a dictionary. 
- return json.loads(dumps(obj)) + obj = _dump_pydantic_models(obj) + return _serialize_value(obj) diff --git a/libs/core/langchain_core/load/load.py b/libs/core/langchain_core/load/load.py index 1c989b8595e..41849ed0e78 100644 --- a/libs/core/langchain_core/load/load.py +++ b/libs/core/langchain_core/load/load.py @@ -1,16 +1,83 @@ """Load LangChain objects from JSON strings or objects. -!!! warning - `load` and `loads` are vulnerable to remote code execution. Never use with untrusted - input. +## How it works + +Each `Serializable` LangChain object has a unique identifier (its "class path"), which +is a list of strings representing the module path and class name. For example: + +- `AIMessage` -> `["langchain_core", "messages", "ai", "AIMessage"]` +- `ChatPromptTemplate` -> `["langchain_core", "prompts", "chat", "ChatPromptTemplate"]` + +When deserializing, the class path from the JSON `'id'` field is checked against an +allowlist. If the class is not in the allowlist, deserialization raises a `ValueError`. + +## Security model + +The `allowed_objects` parameter controls which classes can be deserialized: + +- **`'core'` (default)**: Allow classes defined in the serialization mappings for + langchain_core. +- **`'all'`**: Allow classes defined in the serialization mappings. This + includes core LangChain types (messages, prompts, documents, etc.) and trusted + partner integrations. See `langchain_core.load.mapping` for the full list. +- **Explicit list of classes**: Only those specific classes are allowed. + +For simple data types like messages and documents, the default allowlist is safe to use. +These classes do not perform side effects during initialization. + +!!! note "Side effects in allowed classes" + + Deserialization calls `__init__` on allowed classes. If those classes perform side + effects during initialization (network calls, file operations, etc.), those side + effects will occur. 
The allowlist prevents instantiation of classes outside the + allowlist, but does not sandbox the allowed classes themselves. + +Import paths are also validated against trusted namespaces before any module is +imported. + +### Injection protection (escape-based) + +During serialization, plain dicts that contain an `'lc'` key are escaped by wrapping +them: `{"__lc_escaped__": {...}}`. During deserialization, escaped dicts are unwrapped +and returned as plain dicts, NOT instantiated as LC objects. + +This is an allowlist approach: only dicts explicitly produced by +`Serializable.to_json()` (which are NOT escaped) are treated as LC objects; +everything else is user data. + +Even if an attacker's payload includes `__lc_escaped__` wrappers, it will be unwrapped +to plain dicts and NOT instantiated as malicious objects. + +## Examples + +```python +from langchain_core.load import load +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.messages import AIMessage, HumanMessage + +# Use default allowlist (classes from mappings) - recommended +obj = load(data) + +# Allow only specific classes (most restrictive) +obj = load( + data, + allowed_objects=[ + ChatPromptTemplate, + AIMessage, + HumanMessage, + ], +) +``` """ import importlib import json import os -from typing import Any +from collections.abc import Callable, Iterable +from typing import Any, Literal, cast from langchain_core._api import beta +from langchain_core.load._validation import _is_escaped_dict, _unescape_value from langchain_core.load.mapping import ( _JS_SERIALIZABLE_MAPPING, _OG_SERIALIZABLE_MAPPING, @@ -49,34 +116,209 @@ ALL_SERIALIZABLE_MAPPINGS = { **_JS_SERIALIZABLE_MAPPING, } +# Cache for the default allowed class paths computed from mappings +# Maps mode ("all" or "core") to the cached set of paths +_default_class_paths_cache: dict[str, set[tuple[str, ...]]] = {} + + +def _get_default_allowed_class_paths( + allowed_object_mode: Literal["all", "core"], +) -> set[tuple[str, 
...]]: + """Get the default allowed class paths from the serialization mappings. + + This uses the mappings as the source of truth for what classes are allowed + by default. Both the legacy paths (keys) and current paths (values) are included. + + Args: + allowed_object_mode: either `'all'` or `'core'`. + + Returns: + Set of class path tuples that are allowed by default. + """ + if allowed_object_mode in _default_class_paths_cache: + return _default_class_paths_cache[allowed_object_mode] + + allowed_paths: set[tuple[str, ...]] = set() + for key, value in ALL_SERIALIZABLE_MAPPINGS.items(): + if allowed_object_mode == "core" and value[0] != "langchain_core": + continue + allowed_paths.add(key) + allowed_paths.add(value) + + _default_class_paths_cache[allowed_object_mode] = allowed_paths + return _default_class_paths_cache[allowed_object_mode] + + +def _block_jinja2_templates( + class_path: tuple[str, ...], + kwargs: dict[str, Any], +) -> None: + """Block jinja2 templates during deserialization for security. + + Jinja2 templates can execute arbitrary code, so they are blocked by default when + deserializing objects with `template_format='jinja2'`. + + Note: + We intentionally do NOT check the `class_path` here to keep this simple and + future-proof. If any new class is added that accepts `template_format='jinja2'`, + it will be automatically blocked without needing to update this function. + + Args: + class_path: The class path tuple being deserialized (unused). + kwargs: The kwargs dict for the class constructor. + + Raises: + ValueError: If `template_format` is `'jinja2'`. + """ + _ = class_path # Unused - see docstring for rationale. Kept to satisfy signature. + if kwargs.get("template_format") == "jinja2": + msg = ( + "Jinja2 templates are not allowed during deserialization for security " + "reasons. Use 'f-string' template format instead, or explicitly allow " + "jinja2 by providing a custom init_validator." 
+ ) + raise ValueError(msg) + + +def default_init_validator( + class_path: tuple[str, ...], + kwargs: dict[str, Any], +) -> None: + """Default init validator that blocks jinja2 templates. + + This is the default validator used by `load()` and `loads()` when no custom + validator is provided. + + Args: + class_path: The class path tuple being deserialized. + kwargs: The kwargs dict for the class constructor. + + Raises: + ValueError: If template_format is `'jinja2'`. + """ + _block_jinja2_templates(class_path, kwargs) + + +AllowedObject = type[Serializable] +"""Type alias for classes that can be included in the `allowed_objects` parameter. + +Must be a `Serializable` subclass (the class itself, not an instance). +""" + +InitValidator = Callable[[tuple[str, ...], dict[str, Any]], None] +"""Type alias for a callable that validates kwargs during deserialization. + +The callable receives: + +- `class_path`: A tuple of strings identifying the class being instantiated + (e.g., `('langchain', 'schema', 'messages', 'AIMessage')`). +- `kwargs`: The kwargs dict that will be passed to the constructor. + +The validator should raise an exception if the object should not be deserialized. +""" + + +def _compute_allowed_class_paths( + allowed_objects: Iterable[AllowedObject], + import_mappings: dict[tuple[str, ...], tuple[str, ...]], +) -> set[tuple[str, ...]]: + """Return allowed class paths from an explicit list of classes. + + A class path is a tuple of strings identifying a serializable class, derived from + `Serializable.lc_id()`. For example: `('langchain_core', 'messages', 'AIMessage')`. + + Args: + allowed_objects: Iterable of `Serializable` subclasses to allow. + import_mappings: Mapping of legacy class paths to current class paths. + + Returns: + Set of allowed class paths. 
+ + Example: + ```python + # Allow a specific class + _compute_allowed_class_paths([MyPrompt], {}) -> + {("langchain_core", "prompts", "MyPrompt")} + + # Include legacy paths that map to the same class + import_mappings = {("old", "Prompt"): ("langchain_core", "prompts", "MyPrompt")} + _compute_allowed_class_paths([MyPrompt], import_mappings) -> + {("langchain_core", "prompts", "MyPrompt"), ("old", "Prompt")} + ``` + """ + allowed_objects_list = list(allowed_objects) + + allowed_class_paths: set[tuple[str, ...]] = set() + for allowed_obj in allowed_objects_list: + if not isinstance(allowed_obj, type) or not issubclass( + allowed_obj, Serializable + ): + msg = "allowed_objects must contain Serializable subclasses." + raise TypeError(msg) + + class_path = tuple(allowed_obj.lc_id()) + allowed_class_paths.add(class_path) + # Add legacy paths that map to the same class. + for mapping_key, mapping_value in import_mappings.items(): + if tuple(mapping_value) == class_path: + allowed_class_paths.add(mapping_key) + return allowed_class_paths + class Reviver: - """Reviver for JSON objects.""" + """Reviver for JSON objects. + + Used as the `object_hook` for `json.loads` to reconstruct LangChain objects from + their serialized JSON representation. + + Only classes in the allowlist can be instantiated. + """ def __init__( self, + allowed_objects: Iterable[AllowedObject] | Literal["all", "core"] = "core", secrets_map: dict[str, str] | None = None, valid_namespaces: list[str] | None = None, - secrets_from_env: bool = True, # noqa: FBT001,FBT002 + secrets_from_env: bool = False, # noqa: FBT001,FBT002 additional_import_mappings: dict[tuple[str, ...], tuple[str, ...]] | None = None, *, ignore_unserializable_fields: bool = False, + init_validator: InitValidator | None = default_init_validator, ) -> None: """Initialize the reviver. Args: - secrets_map: A map of secrets to load. + allowed_objects: Allowlist of classes that can be deserialized. 
+ - `'core'` (default): Allow classes defined in the serialization + mappings for `langchain_core`. + - `'all'`: Allow classes defined in the serialization mappings. + This includes core LangChain types (messages, prompts, documents, + etc.) and trusted partner integrations. See + `langchain_core.load.mapping` for the full list. + - Explicit list of classes: Only those specific classes are allowed. + secrets_map: A map of secrets to load. If a secret is not found in the map, it will be loaded from the environment if `secrets_from_env` is `True`. - valid_namespaces: A list of additional namespaces (modules) - to allow to be deserialized. + valid_namespaces: Additional namespaces (modules) to allow during + deserialization, beyond the default trusted namespaces. secrets_from_env: Whether to load secrets from the environment. - additional_import_mappings: A dictionary of additional namespace mappings + additional_import_mappings: A dictionary of additional namespace mappings. You can use this to override default mappings or add new mappings. + + When `allowed_objects` is `None` (using defaults), paths from these + mappings are also added to the allowed class paths. ignore_unserializable_fields: Whether to ignore unserializable fields. + init_validator: Optional callable to validate kwargs before instantiation. + + If provided, this function is called with `(class_path, kwargs)` where + `class_path` is the class path tuple and `kwargs` is the kwargs dict. + The validator should raise an exception if the object should not be + deserialized, otherwise return `None`. + + Defaults to `default_init_validator` which blocks jinja2 templates. 
""" self.secrets_from_env = secrets_from_env self.secrets_map = secrets_map or {} @@ -95,7 +337,26 @@ class Reviver: if self.additional_import_mappings else ALL_SERIALIZABLE_MAPPINGS ) + # Compute allowed class paths: + # - "all" -> use default paths from mappings (+ additional_import_mappings) + # - Explicit list -> compute from those classes + if allowed_objects in ("all", "core"): + self.allowed_class_paths: set[tuple[str, ...]] | None = ( + _get_default_allowed_class_paths( + cast("Literal['all', 'core']", allowed_objects) + ).copy() + ) + # Add paths from additional_import_mappings to the defaults + if self.additional_import_mappings: + for key, value in self.additional_import_mappings.items(): + self.allowed_class_paths.add(key) + self.allowed_class_paths.add(value) + else: + self.allowed_class_paths = _compute_allowed_class_paths( + cast("Iterable[AllowedObject]", allowed_objects), self.import_mappings + ) self.ignore_unserializable_fields = ignore_unserializable_fields + self.init_validator = init_validator def __call__(self, value: dict[str, Any]) -> Any: """Revive the value. @@ -146,6 +407,20 @@ class Reviver: [*namespace, name] = value["id"] mapping_key = tuple(value["id"]) + if ( + self.allowed_class_paths is not None + and mapping_key not in self.allowed_class_paths + ): + msg = ( + f"Deserialization of {mapping_key!r} is not allowed. " + "The default (allowed_objects='core') only permits core " + "langchain-core classes. To allow trusted partner integrations, " + "use allowed_objects='all'. Alternatively, pass an explicit list " + "of allowed classes via allowed_objects=[...]. " + "See langchain_core.load.mapping for the full allowlist." + ) + raise ValueError(msg) + if ( namespace[0] not in self.valid_namespaces # The root namespace ["langchain"] is not a valid identifier. @@ -153,13 +428,11 @@ class Reviver: ): msg = f"Invalid namespace: {value}" raise ValueError(msg) - # Has explicit import path. 
+ # Determine explicit import path if mapping_key in self.import_mappings: import_path = self.import_mappings[mapping_key] # Split into module and name import_dir, name = import_path[:-1], import_path[-1] - # Import module - mod = importlib.import_module(".".join(import_dir)) elif namespace[0] in DISALLOW_LOAD_FROM_PATH: msg = ( "Trying to deserialize something that cannot " @@ -167,9 +440,16 @@ class Reviver: f"{mapping_key}." ) raise ValueError(msg) - # Otherwise, treat namespace as path. else: - mod = importlib.import_module(".".join(namespace)) + # Otherwise, treat namespace as path. + import_dir = namespace + + # Validate import path is in trusted namespaces before importing + if import_dir[0] not in self.valid_namespaces: + msg = f"Invalid namespace: {value}" + raise ValueError(msg) + + mod = importlib.import_module(".".join(import_dir)) cls = getattr(mod, name) @@ -181,6 +461,10 @@ class Reviver: # We don't need to recurse on kwargs # as json.loads will do that for us. kwargs = value.get("kwargs", {}) + + if self.init_validator is not None: + self.init_validator(mapping_key, kwargs) + return cls(**kwargs) return value @@ -190,46 +474,74 @@ class Reviver: def loads( text: str, *, + allowed_objects: Iterable[AllowedObject] | Literal["all", "core"] = "core", secrets_map: dict[str, str] | None = None, valid_namespaces: list[str] | None = None, - secrets_from_env: bool = True, + secrets_from_env: bool = False, additional_import_mappings: dict[tuple[str, ...], tuple[str, ...]] | None = None, ignore_unserializable_fields: bool = False, + init_validator: InitValidator | None = default_init_validator, ) -> Any: """Revive a LangChain class from a JSON string. - !!! warning - This function is vulnerable to remote code execution. Never use with untrusted - input. - Equivalent to `load(json.loads(text))`. + Only classes in the allowlist can be instantiated. The default allowlist includes + core LangChain types (messages, prompts, documents, etc.). 
See + `langchain_core.load.mapping` for the full list. + Args: text: The string to load. + allowed_objects: Allowlist of classes that can be deserialized. + + - `'core'` (default): Allow classes defined in the serialization mappings + for langchain_core. + - `'all'`: Allow classes defined in the serialization mappings. + + This includes core LangChain types (messages, prompts, documents, etc.) + and trusted partner integrations. See `langchain_core.load.mapping` for + the full list. + - Explicit list of classes: Only those specific classes are allowed. + - `[]`: Disallow all deserialization (will raise on any object). secrets_map: A map of secrets to load. If a secret is not found in the map, it will be loaded from the environment if `secrets_from_env` is `True`. - valid_namespaces: A list of additional namespaces (modules) - to allow to be deserialized. + valid_namespaces: Additional namespaces (modules) to allow during + deserialization, beyond the default trusted namespaces. secrets_from_env: Whether to load secrets from the environment. - additional_import_mappings: A dictionary of additional namespace mappings + additional_import_mappings: A dictionary of additional namespace mappings. You can use this to override default mappings or add new mappings. + + When `allowed_objects` is `None` (using defaults), paths from these + mappings are also added to the allowed class paths. ignore_unserializable_fields: Whether to ignore unserializable fields. + init_validator: Optional callable to validate kwargs before instantiation. + + If provided, this function is called with `(class_path, kwargs)` where + `class_path` is the class path tuple and `kwargs` is the kwargs dict. + The validator should raise an exception if the object should not be + deserialized, otherwise return `None`. Defaults to + `default_init_validator` which blocks jinja2 templates. Returns: Revived LangChain objects. 
+ + Raises: + ValueError: If an object's class path is not in the `allowed_objects` allowlist. """ - return json.loads( - text, - object_hook=Reviver( - secrets_map, - valid_namespaces, - secrets_from_env, - additional_import_mappings, - ignore_unserializable_fields=ignore_unserializable_fields, - ), + # Parse JSON and delegate to load() for proper escape handling + raw_obj = json.loads(text) + return load( + raw_obj, + allowed_objects=allowed_objects, + secrets_map=secrets_map, + valid_namespaces=valid_namespaces, + secrets_from_env=secrets_from_env, + additional_import_mappings=additional_import_mappings, + ignore_unserializable_fields=ignore_unserializable_fields, + init_validator=init_validator, ) @@ -237,49 +549,105 @@ def loads( def load( obj: Any, *, + allowed_objects: Iterable[AllowedObject] | Literal["all", "core"] = "core", secrets_map: dict[str, str] | None = None, valid_namespaces: list[str] | None = None, - secrets_from_env: bool = True, + secrets_from_env: bool = False, additional_import_mappings: dict[tuple[str, ...], tuple[str, ...]] | None = None, ignore_unserializable_fields: bool = False, + init_validator: InitValidator | None = default_init_validator, ) -> Any: """Revive a LangChain class from a JSON object. - !!! warning - This function is vulnerable to remote code execution. Never use with untrusted - input. + Use this if you already have a parsed JSON object, eg. from `json.load` or + `orjson.loads`. - Use this if you already have a parsed JSON object, - eg. from `json.load` or `orjson.loads`. + Only classes in the allowlist can be instantiated. The default allowlist includes + core LangChain types (messages, prompts, documents, etc.). See + `langchain_core.load.mapping` for the full list. Args: obj: The object to load. + allowed_objects: Allowlist of classes that can be deserialized. + + - `'core'` (default): Allow classes defined in the serialization mappings + for langchain_core. 
+ - `'all'`: Allow classes defined in the serialization mappings. + + This includes core LangChain types (messages, prompts, documents, etc.) + and trusted partner integrations. See `langchain_core.load.mapping` for + the full list. + - Explicit list of classes: Only those specific classes are allowed. + - `[]`: Disallow all deserialization (will raise on any object). secrets_map: A map of secrets to load. If a secret is not found in the map, it will be loaded from the environment if `secrets_from_env` is `True`. - valid_namespaces: A list of additional namespaces (modules) - to allow to be deserialized. + valid_namespaces: Additional namespaces (modules) to allow during + deserialization, beyond the default trusted namespaces. secrets_from_env: Whether to load secrets from the environment. - additional_import_mappings: A dictionary of additional namespace mappings + additional_import_mappings: A dictionary of additional namespace mappings. You can use this to override default mappings or add new mappings. + + When `allowed_objects` is `None` (using defaults), paths from these + mappings are also added to the allowed class paths. ignore_unserializable_fields: Whether to ignore unserializable fields. + init_validator: Optional callable to validate kwargs before instantiation. + + If provided, this function is called with `(class_path, kwargs)` where + `class_path` is the class path tuple and `kwargs` is the kwargs dict. + The validator should raise an exception if the object should not be + deserialized, otherwise return `None`. Defaults to + `default_init_validator` which blocks jinja2 templates. Returns: Revived LangChain objects. + + Raises: + ValueError: If an object's class path is not in the `allowed_objects` allowlist. 
+ + Example: + ```python + from langchain_core.load import load, dumpd + from langchain_core.messages import AIMessage + + msg = AIMessage(content="Hello") + data = dumpd(msg) + + # Deserialize using default allowlist + loaded = load(data) + + # Or with explicit allowlist + loaded = load(data, allowed_objects=[AIMessage]) + + # Or extend defaults with additional mappings + loaded = load( + data, + additional_import_mappings={ + ("my_pkg", "MyClass"): ("my_pkg", "module", "MyClass"), + }, + ) + ``` """ reviver = Reviver( + allowed_objects, secrets_map, valid_namespaces, secrets_from_env, additional_import_mappings, ignore_unserializable_fields=ignore_unserializable_fields, + init_validator=init_validator, ) def _load(obj: Any) -> Any: if isinstance(obj, dict): - # Need to revive leaf nodes before reviving this node + # Check for escaped dict FIRST (before recursing). + # Escaped dicts are user data that should NOT be processed as LC objects. + if _is_escaped_dict(obj): + return _unescape_value(obj) + + # Not escaped - recurse into children then apply reviver loaded_obj = {k: _load(v) for k, v in obj.items()} return reviver(loaded_obj) if isinstance(obj, list): diff --git a/libs/core/langchain_core/load/mapping.py b/libs/core/langchain_core/load/mapping.py index e2367ea1ebb..59c64a08e37 100644 --- a/libs/core/langchain_core/load/mapping.py +++ b/libs/core/langchain_core/load/mapping.py @@ -1,21 +1,19 @@ """Serialization mapping. -This file contains a mapping between the lc_namespace path for a given -subclass that implements from Serializable to the namespace +This file contains a mapping between the `lc_namespace` path for a given +subclass that implements from `Serializable` to the namespace where that class is actually located. This mapping helps maintain the ability to serialize and deserialize well-known LangChain objects even if they are moved around in the codebase across different LangChain versions. 
-For example, +For example, the code for the `AIMessage` class is located in +`langchain_core.messages.ai.AIMessage`. This message is associated with the +`lc_namespace` of `["langchain", "schema", "messages", "AIMessage"]`, +because this code was originally in `langchain.schema.messages.AIMessage`. -The code for AIMessage class is located in langchain_core.messages.ai.AIMessage, -This message is associated with the lc_namespace -["langchain", "schema", "messages", "AIMessage"], -because this code was originally in langchain.schema.messages.AIMessage. - -The mapping allows us to deserialize an AIMessage created with an older +The mapping allows us to deserialize an `AIMessage` created with an older version of LangChain where the code was in a different location. """ @@ -275,6 +273,11 @@ SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = { "chat_models", "ChatGroq", ), + ("langchain_xai", "chat_models", "ChatXAI"): ( + "langchain_xai", + "chat_models", + "ChatXAI", + ), ("langchain", "chat_models", "fireworks", "ChatFireworks"): ( "langchain_fireworks", "chat_models", @@ -529,16 +532,6 @@ SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = { "structured", "StructuredPrompt", ), - ("langchain_sambanova", "chat_models", "ChatSambaNovaCloud"): ( - "langchain_sambanova", - "chat_models", - "ChatSambaNovaCloud", - ), - ("langchain_sambanova", "chat_models", "ChatSambaStudio"): ( - "langchain_sambanova", - "chat_models", - "ChatSambaStudio", - ), ("langchain_core", "prompts", "message", "_DictMessagePromptTemplate"): ( "langchain_core", "prompts", diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 91628a8b1d9..a4abbd38575 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -539,7 +539,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef] hist: BaseChatMessageHistory = 
config["configurable"]["message_history"] # Get the input messages - inputs = load(run.inputs) + inputs = load(run.inputs, allowed_objects="all") input_messages = self._get_input_messages(inputs) # If historic messages were prepended to the input messages, remove them to # avoid adding duplicate messages to history. @@ -548,7 +548,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef] input_messages = input_messages[len(historic_messages) :] # Get the output messages - output_val = load(run.outputs) + output_val = load(run.outputs, allowed_objects="all") output_messages = self._get_output_messages(output_val) hist.add_messages(input_messages + output_messages) @@ -556,7 +556,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef] hist: BaseChatMessageHistory = config["configurable"]["message_history"] # Get the input messages - inputs = load(run.inputs) + inputs = load(run.inputs, allowed_objects="all") input_messages = self._get_input_messages(inputs) # If historic messages were prepended to the input messages, remove them to # avoid adding duplicate messages to history. 
@@ -565,7 +565,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef] input_messages = input_messages[len(historic_messages) :] # Get the output messages - output_val = load(run.outputs) + output_val = load(run.outputs, allowed_objects="all") output_messages = self._get_output_messages(output_val) await hist.aadd_messages(input_messages + output_messages) diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index 345d5176735..1240d9eb851 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -563,7 +563,7 @@ def _get_standardized_inputs( ) raise NotImplementedError(msg) - inputs = load(run.inputs) + inputs = load(run.inputs, allowed_objects="all") if run.run_type in {"retriever", "llm", "chat_model"}: return inputs @@ -595,7 +595,7 @@ def _get_standardized_outputs( Returns: An output if returned, otherwise a None """ - outputs = load(run.outputs) + outputs = load(run.outputs, allowed_objects="all") if schema_format == "original": if run.run_type == "prompt" and "output" in outputs: # These were previously dumped before the tracer. 
diff --git a/libs/core/langchain_core/vectorstores/in_memory.py b/libs/core/langchain_core/vectorstores/in_memory.py index 919cd6a377a..ad893a02637 100644 --- a/libs/core/langchain_core/vectorstores/in_memory.py +++ b/libs/core/langchain_core/vectorstores/in_memory.py @@ -528,7 +528,7 @@ class InMemoryVectorStore(VectorStore): """ path_: Path = Path(path) with path_.open("r", encoding="utf-8") as f: - store = load(json.load(f)) + store = load(json.load(f), allowed_objects=[Document]) vectorstore = cls(embedding=embedding, **kwargs) vectorstore.store = store return vectorstore diff --git a/libs/core/tests/unit_tests/load/test_imports.py b/libs/core/tests/unit_tests/load/test_imports.py index a6e87afb5a9..9284773fa47 100644 --- a/libs/core/tests/unit_tests/load/test_imports.py +++ b/libs/core/tests/unit_tests/load/test_imports.py @@ -1,6 +1,13 @@ from langchain_core.load import __all__ -EXPECTED_ALL = ["dumpd", "dumps", "load", "loads", "Serializable"] +EXPECTED_ALL = [ + "InitValidator", + "Serializable", + "dumpd", + "dumps", + "load", + "loads", +] def test_all_imports() -> None: diff --git a/libs/core/tests/unit_tests/load/test_secret_injection.py b/libs/core/tests/unit_tests/load/test_secret_injection.py new file mode 100644 index 00000000000..7810638e62c --- /dev/null +++ b/libs/core/tests/unit_tests/load/test_secret_injection.py @@ -0,0 +1,431 @@ +"""Tests for secret injection prevention in serialization. + +Verify that user-provided data containing secret-like structures cannot be used to +extract environment variables during deserialization. 
+""" + +import json +import os +import re +from typing import Any +from unittest import mock + +import pytest +from pydantic import BaseModel + +from langchain_core.documents import Document +from langchain_core.load import dumpd, dumps, load +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.outputs import ChatGeneration + +SENTINEL_ENV_VAR = "TEST_SECRET_INJECTION_VAR" +"""Sentinel value that should NEVER appear in serialized output.""" + +SENTINEL_VALUE = "LEAKED_SECRET_MEOW_12345" +"""Sentinel value that should NEVER appear in serialized output.""" + +MALICIOUS_SECRET_DICT: dict[str, Any] = { + "lc": 1, + "type": "secret", + "id": [SENTINEL_ENV_VAR], +} +"""The malicious secret-like dict that tries to read the env var""" + + +@pytest.fixture(autouse=True) +def _set_sentinel_env_var() -> Any: + """Set the sentinel env var for all tests in this module.""" + with mock.patch.dict(os.environ, {SENTINEL_ENV_VAR: SENTINEL_VALUE}): + yield + + +def _assert_no_secret_leak(payload: Any) -> None: + """Assert that serializing/deserializing payload doesn't leak the secret.""" + # First serialize + serialized = dumps(payload) + + # Deserialize with secrets_from_env=True (the dangerous setting) + deserialized = load(serialized, secrets_from_env=True) + + # Re-serialize to string + reserialized = dumps(deserialized) + + assert SENTINEL_VALUE not in reserialized, ( + f"Secret was leaked! Found '{SENTINEL_VALUE}' in output.\n" + f"Original payload type: {type(payload)}\n" + f"Reserialized output: {reserialized[:500]}..." + ) + + assert SENTINEL_VALUE not in repr(deserialized), ( + f"Secret was leaked in deserialized object! 
Found '{SENTINEL_VALUE}'.\n"
+        f"Deserialized: {deserialized!r}"
+    )
+
+
+class TestSerializableTopLevel:
+    """Tests with `Serializable` objects at the top level."""
+
+    def test_human_message_with_secret_in_content(self) -> None:
+        """`HumanMessage` with secret-like dict in `content`."""
+        msg = HumanMessage(
+            content=[
+                {"type": "text", "text": "Hello"},
+                {"type": "text", "text": MALICIOUS_SECRET_DICT},
+            ]
+        )
+        _assert_no_secret_leak(msg)
+
+    def test_human_message_with_secret_in_additional_kwargs(self) -> None:
+        """`HumanMessage` with secret-like dict in `additional_kwargs`."""
+        msg = HumanMessage(
+            content="Hello",
+            additional_kwargs={"data": MALICIOUS_SECRET_DICT},
+        )
+        _assert_no_secret_leak(msg)
+
+    def test_human_message_with_secret_in_nested_additional_kwargs(self) -> None:
+        """`HumanMessage` with secret-like dict nested in `additional_kwargs`."""
+        msg = HumanMessage(
+            content="Hello",
+            additional_kwargs={"nested": {"deep": MALICIOUS_SECRET_DICT}},
+        )
+        _assert_no_secret_leak(msg)
+
+    def test_human_message_with_secret_in_list_in_additional_kwargs(self) -> None:
+        """`HumanMessage` with secret-like dict in a list in `additional_kwargs`."""
+        msg = HumanMessage(
+            content="Hello",
+            additional_kwargs={"items": [MALICIOUS_SECRET_DICT]},
+        )
+        _assert_no_secret_leak(msg)
+
+    def test_ai_message_with_secret_in_response_metadata(self) -> None:
+        """`AIMessage` with secret-like dict in `response_metadata`."""
+        msg = AIMessage(
+            content="Hello",
+            response_metadata={"data": MALICIOUS_SECRET_DICT},
+        )
+        _assert_no_secret_leak(msg)
+
+    def test_document_with_secret_in_metadata(self) -> None:
+        """Document with secret-like dict in `metadata`."""
+        doc = Document(
+            page_content="Hello",
+            metadata={"data": MALICIOUS_SECRET_DICT},
+        )
+        _assert_no_secret_leak(doc)
+
+    def test_nested_serializable_with_secret(self) -> None:
+        """`AIMessage` containing `dumpd(HumanMessage)` with secret in kwargs."""
+        inner = HumanMessage(
+            content="Hello",
+
additional_kwargs={"secret": MALICIOUS_SECRET_DICT}, + ) + outer = AIMessage( + content="Outer", + additional_kwargs={"nested": [dumpd(inner)]}, + ) + _assert_no_secret_leak(outer) + + +class TestDictTopLevel: + """Tests with plain dicts at the top level.""" + + def test_dict_with_serializable_containing_secret(self) -> None: + """Dict containing a `Serializable` with secret-like dict.""" + msg = HumanMessage( + content="Hello", + additional_kwargs={"data": MALICIOUS_SECRET_DICT}, + ) + payload = {"message": msg} + _assert_no_secret_leak(payload) + + def test_dict_with_secret_no_serializable(self) -> None: + """Dict with secret-like dict, no `Serializable` objects.""" + payload = {"data": MALICIOUS_SECRET_DICT} + _assert_no_secret_leak(payload) + + def test_dict_with_nested_secret_no_serializable(self) -> None: + """Dict with nested secret-like dict, no `Serializable` objects.""" + payload = {"outer": {"inner": MALICIOUS_SECRET_DICT}} + _assert_no_secret_leak(payload) + + def test_dict_with_secret_in_list(self) -> None: + """Dict with secret-like dict in a list.""" + payload = {"items": [MALICIOUS_SECRET_DICT]} + _assert_no_secret_leak(payload) + + def test_dict_mimicking_lc_constructor_with_secret(self) -> None: + """Dict that looks like an LC constructor containing a secret.""" + payload = { + "lc": 1, + "type": "constructor", + "id": ["langchain_core", "messages", "ai", "AIMessage"], + "kwargs": { + "content": "Hello", + "additional_kwargs": {"secret": MALICIOUS_SECRET_DICT}, + }, + } + _assert_no_secret_leak(payload) + + +class TestPydanticModelTopLevel: + """Tests with Pydantic models (non-`Serializable`) at the top level.""" + + def test_pydantic_model_with_serializable_containing_secret(self) -> None: + """Pydantic model containing a `Serializable` with secret-like dict.""" + + class MyModel(BaseModel): + message: Any + + msg = HumanMessage( + content="Hello", + additional_kwargs={"data": MALICIOUS_SECRET_DICT}, + ) + payload = MyModel(message=msg) + 
_assert_no_secret_leak(payload) + + def test_pydantic_model_with_secret_dict(self) -> None: + """Pydantic model containing a secret-like dict directly.""" + + class MyModel(BaseModel): + data: dict[str, Any] + + payload = MyModel(data=MALICIOUS_SECRET_DICT) + _assert_no_secret_leak(payload) + + # Test treatment of "parsed" in additional_kwargs + msg = AIMessage(content=[], additional_kwargs={"parsed": payload}) + gen = ChatGeneration(message=msg) + _assert_no_secret_leak(gen) + round_trip = load(dumpd(gen)) + assert MyModel(**(round_trip.message.additional_kwargs["parsed"])) == payload + + def test_pydantic_model_with_nested_secret(self) -> None: + """Pydantic model with nested secret-like dict.""" + + class MyModel(BaseModel): + nested: dict[str, Any] + + payload = MyModel(nested={"inner": MALICIOUS_SECRET_DICT}) + _assert_no_secret_leak(payload) + + +class TestNonSerializableClassTopLevel: + """Tests with classes at the top level.""" + + def test_custom_class_with_serializable_containing_secret(self) -> None: + """Custom class containing a `Serializable` with secret-like dict.""" + + class MyClass: + def __init__(self, message: Any) -> None: + self.message = message + + msg = HumanMessage( + content="Hello", + additional_kwargs={"data": MALICIOUS_SECRET_DICT}, + ) + payload = MyClass(message=msg) + # This will serialize as not_implemented, but let's verify no leak + _assert_no_secret_leak(payload) + + def test_custom_class_with_secret_dict(self) -> None: + """Custom class containing a secret-like dict directly.""" + + class MyClass: + def __init__(self, data: dict[str, Any]) -> None: + self.data = data + + payload = MyClass(data=MALICIOUS_SECRET_DICT) + _assert_no_secret_leak(payload) + + +class TestDumpdInKwargs: + """Tests for the specific pattern of `dumpd()` result stored in kwargs.""" + + def test_dumpd_human_message_in_ai_message_kwargs(self) -> None: + """`AIMessage` with `dumpd(HumanMessage)` in `additional_kwargs`.""" + h = HumanMessage("Hello") + a = 
AIMessage("foo", additional_kwargs={"bar": [dumpd(h)]}) + _assert_no_secret_leak(a) + + def test_dumpd_human_message_with_secret_in_ai_message_kwargs(self) -> None: + """`AIMessage` with `dumpd(HumanMessage w/ secret)` in `additional_kwargs`.""" + h = HumanMessage( + "Hello", + additional_kwargs={"secret": MALICIOUS_SECRET_DICT}, + ) + a = AIMessage("foo", additional_kwargs={"bar": [dumpd(h)]}) + _assert_no_secret_leak(a) + + def test_double_dumpd_nesting(self) -> None: + """Double nesting: `dumpd(AIMessage(dumpd(HumanMessage)))`.""" + h = HumanMessage( + "Hello", + additional_kwargs={"secret": MALICIOUS_SECRET_DICT}, + ) + a = AIMessage("foo", additional_kwargs={"bar": [dumpd(h)]}) + outer = AIMessage("outer", additional_kwargs={"nested": [dumpd(a)]}) + _assert_no_secret_leak(outer) + + +class TestRoundTrip: + """Tests that verify round-trip serialization preserves data structure.""" + + def test_human_message_with_secret_round_trip(self) -> None: + """Verify secret-like dict is preserved as dict after round-trip.""" + msg = HumanMessage( + content="Hello", + additional_kwargs={"data": MALICIOUS_SECRET_DICT}, + ) + + serialized = dumpd(msg) + deserialized = load(serialized, secrets_from_env=True) + + # The secret-like dict should be preserved as a plain dict + assert deserialized.additional_kwargs["data"] == MALICIOUS_SECRET_DICT + assert isinstance(deserialized.additional_kwargs["data"], dict) + + def test_document_with_secret_round_trip(self) -> None: + """Verify secret-like dict in `Document` metadata is preserved.""" + doc = Document( + page_content="Hello", + metadata={"data": MALICIOUS_SECRET_DICT}, + ) + + serialized = dumpd(doc) + deserialized = load( + serialized, secrets_from_env=True, allowed_objects=[Document] + ) + + # The secret-like dict should be preserved as a plain dict + assert deserialized.metadata["data"] == MALICIOUS_SECRET_DICT + assert isinstance(deserialized.metadata["data"], dict) + + def test_plain_dict_with_secret_round_trip(self) -> 
None: + """Verify secret-like dict in plain dict is preserved.""" + payload = {"data": MALICIOUS_SECRET_DICT} + + serialized = dumpd(payload) + deserialized = load(serialized, secrets_from_env=True) + + # The secret-like dict should be preserved as a plain dict + assert deserialized["data"] == MALICIOUS_SECRET_DICT + assert isinstance(deserialized["data"], dict) + + +class TestEscapingEfficiency: + """Tests that escaping doesn't cause excessive nesting.""" + + def test_no_triple_escaping(self) -> None: + """Verify dumpd doesn't cause triple/multiple escaping.""" + h = HumanMessage( + "Hello", + additional_kwargs={"bar": [MALICIOUS_SECRET_DICT]}, + ) + a = AIMessage("foo", additional_kwargs={"bar": [dumpd(h)]}) + d = dumpd(a) + + serialized = json.dumps(d) + # Count nested escape markers - + # should be max 2 (one for HumanMessage, one for secret) + # Not 3+ which would indicate re-escaping of already-escaped content + escape_count = len(re.findall(r"__lc_escaped__", serialized)) + + # The HumanMessage dict gets escaped (1), the secret inside gets escaped (1) + # Total should be 2, not 4 (which would mean triple nesting) + assert escape_count <= 2, ( + f"Found {escape_count} escape markers, expected <= 2. " + f"This indicates unnecessary re-escaping.\n{serialized}" + ) + + def test_double_nesting_no_quadruple_escape(self) -> None: + """Verify double dumpd nesting doesn't explode escape markers.""" + h = HumanMessage( + "Hello", + additional_kwargs={"secret": MALICIOUS_SECRET_DICT}, + ) + a = AIMessage("middle", additional_kwargs={"nested": [dumpd(h)]}) + outer = AIMessage("outer", additional_kwargs={"deep": [dumpd(a)]}) + d = dumpd(outer) + + serialized = json.dumps(d) + escape_count = len(re.findall(r"__lc_escaped__", serialized)) + + # Should be: + # outer escapes middle (1), + # middle escapes h (1), + # h escapes secret (1) = 3 + # Not 6+ which would indicate re-escaping + assert escape_count <= 3, ( + f"Found {escape_count} escape markers, expected <= 3. 
" + f"This indicates unnecessary re-escaping." + ) + + +class TestConstructorInjection: + """Tests for constructor-type injection (not just secrets).""" + + def test_constructor_in_metadata_not_instantiated(self) -> None: + """Verify constructor-like dict in metadata is not instantiated.""" + malicious_constructor = { + "lc": 1, + "type": "constructor", + "id": ["langchain_core", "messages", "ai", "AIMessage"], + "kwargs": {"content": "injected"}, + } + + doc = Document( + page_content="Hello", + metadata={"data": malicious_constructor}, + ) + + serialized = dumpd(doc) + deserialized = load( + serialized, + secrets_from_env=True, + allowed_objects=[Document, AIMessage], + ) + + # The constructor-like dict should be a plain dict, NOT an AIMessage + assert isinstance(deserialized.metadata["data"], dict) + assert deserialized.metadata["data"] == malicious_constructor + + def test_constructor_in_content_not_instantiated(self) -> None: + """Verify constructor-like dict in message content is not instantiated.""" + malicious_constructor = { + "lc": 1, + "type": "constructor", + "id": ["langchain_core", "messages", "human", "HumanMessage"], + "kwargs": {"content": "injected"}, + } + + msg = AIMessage( + content="Hello", + additional_kwargs={"nested": malicious_constructor}, + ) + + serialized = dumpd(msg) + deserialized = load( + serialized, + secrets_from_env=True, + allowed_objects=[AIMessage, HumanMessage], + ) + + # The constructor-like dict should be a plain dict, NOT a HumanMessage + assert isinstance(deserialized.additional_kwargs["nested"], dict) + assert deserialized.additional_kwargs["nested"] == malicious_constructor + + +def test_allowed_objects() -> None: + # Core object + msg = AIMessage(content="foo") + serialized = dumpd(msg) + assert load(serialized) == msg + assert load(serialized, allowed_objects=[AIMessage]) == msg + assert load(serialized, allowed_objects="core") == msg + + with pytest.raises(ValueError, match="not allowed"): + load(serialized, 
allowed_objects=[]) + with pytest.raises(ValueError, match="not allowed"): + load(serialized, allowed_objects=[Document]) diff --git a/libs/core/tests/unit_tests/load/test_serializable.py b/libs/core/tests/unit_tests/load/test_serializable.py index 8630df1b47e..ee2618a61d2 100644 --- a/libs/core/tests/unit_tests/load/test_serializable.py +++ b/libs/core/tests/unit_tests/load/test_serializable.py @@ -1,12 +1,19 @@ import json +from typing import Any import pytest from pydantic import BaseModel, ConfigDict, Field, SecretStr -from langchain_core.load import Serializable, dumpd, dumps, load +from langchain_core.documents import Document +from langchain_core.load import InitValidator, Serializable, dumpd, dumps, load, loads from langchain_core.load.serializable import _is_field_useful from langchain_core.messages import AIMessage from langchain_core.outputs import ChatGeneration, Generation +from langchain_core.prompts import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + PromptTemplate, +) class NonBoolObj: @@ -145,10 +152,17 @@ def test_simple_deserialization() -> None: "lc": 1, "type": "constructor", } - new_foo = load(serialized_foo, valid_namespaces=["tests"]) + new_foo = load(serialized_foo, allowed_objects=[Foo], valid_namespaces=["tests"]) assert new_foo == foo +def test_disallowed_deserialization() -> None: + foo = Foo(bar=1, baz="hello") + serialized_foo = dumpd(foo) + with pytest.raises(ValueError, match="not allowed"): + load(serialized_foo, allowed_objects=[], valid_namespaces=["tests"]) + + class Foo2(Serializable): bar: int baz: str @@ -170,6 +184,7 @@ def test_simple_deserialization_with_additional_imports() -> None: } new_foo = load( serialized_foo, + allowed_objects=[Foo2], valid_namespaces=["tests"], additional_import_mappings={ ("tests", "unit_tests", "load", "test_serializable", "Foo"): ( @@ -223,7 +238,7 @@ def test_serialization_with_pydantic() -> None: ) ) ser = dumpd(llm_response) - deser = load(ser) + deser = load(ser, 
allowed_objects=[ChatGeneration, AIMessage]) assert isinstance(deser, ChatGeneration) assert deser.message.content assert deser.message.additional_kwargs["parsed"] == my_model.model_dump() @@ -260,8 +275,8 @@ def test_serialization_with_ignore_unserializable_fields() -> None: ] ] } - ser = dumpd(data) - deser = load(ser, ignore_unserializable_fields=True) + # Load directly (no dumpd - this is already serialized data) + deser = load(data, allowed_objects=[AIMessage], ignore_unserializable_fields=True) assert deser == { "messages": [ [ @@ -365,3 +380,514 @@ def test_dumps_mixed_data_structure() -> None: # Primitives should remain unchanged assert parsed["list"] == [1, 2, {"nested": "value"}] assert parsed["primitive"] == "string" + + +def test_document_normal_metadata_allowed() -> None: + """Test that `Document` metadata without `'lc'` key works fine.""" + doc = Document( + page_content="Hello world", + metadata={"source": "test.txt", "page": 1, "nested": {"key": "value"}}, + ) + serialized = dumpd(doc) + + loaded = load(serialized, allowed_objects=[Document]) + assert loaded.page_content == "Hello world" + + expected = {"source": "test.txt", "page": 1, "nested": {"key": "value"}} + assert loaded.metadata == expected + + +class TestEscaping: + """Tests that escape-based serialization prevents injection attacks. + + When user data contains an `'lc'` key, it's escaped during serialization + (wrapped in `{"__lc_escaped__": ...}`). During deserialization, escaped + dicts are unwrapped and returned as plain dicts - NOT instantiated as + LC objects. 
+ """ + + def test_document_metadata_with_lc_key_escaped(self) -> None: + """Test that `Document` metadata with `'lc'` key round-trips as plain dict.""" + # User data that looks like an LC constructor - should be escaped, not executed + suspicious_metadata = {"lc": 1, "type": "constructor", "id": ["some", "module"]} + doc = Document(page_content="test", metadata=suspicious_metadata) + + # Serialize - should escape the metadata + serialized = dumpd(doc) + assert serialized["kwargs"]["metadata"] == { + "__lc_escaped__": suspicious_metadata + } + + # Deserialize - should restore original metadata as plain dict + loaded = load(serialized, allowed_objects=[Document]) + assert loaded.metadata == suspicious_metadata # Plain dict, not instantiated + + def test_document_metadata_with_nested_lc_key_escaped(self) -> None: + """Test that nested `'lc'` key in `Document` metadata is escaped.""" + suspicious_nested = {"lc": 1, "type": "constructor", "id": ["some", "module"]} + doc = Document(page_content="test", metadata={"nested": suspicious_nested}) + + serialized = dumpd(doc) + # The nested dict with 'lc' key should be escaped + assert serialized["kwargs"]["metadata"]["nested"] == { + "__lc_escaped__": suspicious_nested + } + + loaded = load(serialized, allowed_objects=[Document]) + assert loaded.metadata == {"nested": suspicious_nested} + + def test_document_metadata_with_lc_key_in_list_escaped(self) -> None: + """Test that `'lc'` key in list items within `Document` metadata is escaped.""" + suspicious_item = {"lc": 1, "type": "constructor", "id": ["some", "module"]} + doc = Document(page_content="test", metadata={"items": [suspicious_item]}) + + serialized = dumpd(doc) + assert serialized["kwargs"]["metadata"]["items"][0] == { + "__lc_escaped__": suspicious_item + } + + loaded = load(serialized, allowed_objects=[Document]) + assert loaded.metadata == {"items": [suspicious_item]} + + def test_malicious_payload_not_instantiated(self) -> None: + """Test that malicious LC-like 
structures in user data are NOT instantiated.""" + # An attacker might craft a payload with a valid AIMessage structure in metadata + malicious_data = { + "lc": 1, + "type": "constructor", + "id": ["langchain", "schema", "document", "Document"], + "kwargs": { + "page_content": "test", + "metadata": { + # This looks like a valid LC object but is in escaped form + "__lc_escaped__": { + "lc": 1, + "type": "constructor", + "id": ["langchain_core", "messages", "ai", "AIMessage"], + "kwargs": {"content": "injected message"}, + } + }, + }, + } + + # Even though AIMessage is allowed, the metadata should remain as dict + loaded = load(malicious_data, allowed_objects=[Document, AIMessage]) + assert loaded.page_content == "test" + # The metadata is the original dict (unescaped), NOT an AIMessage instance + assert loaded.metadata == { + "lc": 1, + "type": "constructor", + "id": ["langchain_core", "messages", "ai", "AIMessage"], + "kwargs": {"content": "injected message"}, + } + assert not isinstance(loaded.metadata, AIMessage) + + def test_message_additional_kwargs_with_lc_key_escaped(self) -> None: + """Test that `AIMessage` `additional_kwargs` with `'lc'` is escaped.""" + suspicious_data = {"lc": 1, "type": "constructor", "id": ["x", "y"]} + msg = AIMessage( + content="Hello", + additional_kwargs={"data": suspicious_data}, + ) + + serialized = dumpd(msg) + assert serialized["kwargs"]["additional_kwargs"]["data"] == { + "__lc_escaped__": suspicious_data + } + + loaded = load(serialized, allowed_objects=[AIMessage]) + assert loaded.additional_kwargs == {"data": suspicious_data} + + def test_message_response_metadata_with_lc_key_escaped(self) -> None: + """Test that `AIMessage` `response_metadata` with `'lc'` is escaped.""" + suspicious_data = {"lc": 1, "type": "constructor", "id": ["x", "y"]} + msg = AIMessage(content="Hello", response_metadata=suspicious_data) + + serialized = dumpd(msg) + assert serialized["kwargs"]["response_metadata"] == { + "__lc_escaped__": 
suspicious_data + } + + loaded = load(serialized, allowed_objects=[AIMessage]) + assert loaded.response_metadata == suspicious_data + + def test_double_escape_handling(self) -> None: + """Test that data containing escape key itself is properly handled.""" + # User data that contains our escape key + data_with_escape_key = {"__lc_escaped__": "some_value"} + doc = Document(page_content="test", metadata=data_with_escape_key) + + serialized = dumpd(doc) + # Should be double-escaped since it looks like an escaped dict + assert serialized["kwargs"]["metadata"] == { + "__lc_escaped__": {"__lc_escaped__": "some_value"} + } + + loaded = load(serialized, allowed_objects=[Document]) + assert loaded.metadata == {"__lc_escaped__": "some_value"} + + +class TestDumpdEscapesLcKeyInPlainDicts: + """Tests that `dumpd()` escapes `'lc'` keys in plain dict kwargs.""" + + def test_normal_message_not_escaped(self) -> None: + """Test that normal `AIMessage` without `'lc'` key is not escaped.""" + msg = AIMessage( + content="Hello", + additional_kwargs={"tool_calls": []}, + response_metadata={"model": "gpt-4"}, + ) + serialized = dumpd(msg) + assert serialized["kwargs"]["content"] == "Hello" + # No escape wrappers for normal data + assert "__lc_escaped__" not in str(serialized) + + def test_document_metadata_with_lc_key_escaped(self) -> None: + """Test that `Document` with `'lc'` key in metadata is escaped.""" + doc = Document( + page_content="test", + metadata={"lc": 1, "type": "constructor"}, + ) + + serialized = dumpd(doc) + # Should be escaped, not blocked + assert serialized["kwargs"]["metadata"] == { + "__lc_escaped__": {"lc": 1, "type": "constructor"} + } + + def test_document_metadata_with_nested_lc_key_escaped(self) -> None: + """Test that `Document` with nested `'lc'` in metadata is escaped.""" + doc = Document( + page_content="test", + metadata={"nested": {"lc": 1}}, + ) + + serialized = dumpd(doc) + assert serialized["kwargs"]["metadata"]["nested"] == { + "__lc_escaped__": 
{"lc": 1} + } + + def test_message_additional_kwargs_with_lc_key_escaped(self) -> None: + """Test `AIMessage` with `'lc'` in `additional_kwargs` is escaped.""" + msg = AIMessage( + content="Hello", + additional_kwargs={"malicious": {"lc": 1}}, + ) + + serialized = dumpd(msg) + assert serialized["kwargs"]["additional_kwargs"]["malicious"] == { + "__lc_escaped__": {"lc": 1} + } + + def test_message_response_metadata_with_lc_key_escaped(self) -> None: + """Test `AIMessage` with `'lc'` in `response_metadata` is escaped.""" + msg = AIMessage( + content="Hello", + response_metadata={"lc": 1}, + ) + + serialized = dumpd(msg) + assert serialized["kwargs"]["response_metadata"] == { + "__lc_escaped__": {"lc": 1} + } + + +class TestInitValidator: + """Tests for `init_validator` on `load()` and `loads()`.""" + + def test_init_validator_allows_valid_kwargs(self) -> None: + """Test that `init_validator` returning None allows deserialization.""" + msg = AIMessage(content="Hello") + serialized = dumpd(msg) + + def allow_all(_class_path: tuple[str, ...], _kwargs: dict[str, Any]) -> None: + pass # Allow all by doing nothing + + loaded = load(serialized, allowed_objects=[AIMessage], init_validator=allow_all) + assert loaded == msg + + def test_init_validator_blocks_deserialization(self) -> None: + """Test that `init_validator` can block deserialization by raising.""" + doc = Document(page_content="test", metadata={"source": "test.txt"}) + serialized = dumpd(doc) + + def block_metadata( + _class_path: tuple[str, ...], kwargs: dict[str, Any] + ) -> None: + if "metadata" in kwargs: + msg = "Metadata not allowed" + raise ValueError(msg) + + with pytest.raises(ValueError, match="Metadata not allowed"): + load(serialized, allowed_objects=[Document], init_validator=block_metadata) + + def test_init_validator_receives_correct_class_path(self) -> None: + """Test that `init_validator` receives the correct class path.""" + msg = AIMessage(content="Hello") + serialized = dumpd(msg) + + 
received_class_paths: list[tuple[str, ...]] = [] + + def capture_class_path( + class_path: tuple[str, ...], _kwargs: dict[str, Any] + ) -> None: + received_class_paths.append(class_path) + + load(serialized, allowed_objects=[AIMessage], init_validator=capture_class_path) + + assert len(received_class_paths) == 1 + assert received_class_paths[0] == ( + "langchain", + "schema", + "messages", + "AIMessage", + ) + + def test_init_validator_receives_correct_kwargs(self) -> None: + """Test that `init_validator` receives the kwargs dict.""" + msg = AIMessage(content="Hello world", name="test_name") + serialized = dumpd(msg) + + received_kwargs: list[dict[str, Any]] = [] + + def capture_kwargs( + _class_path: tuple[str, ...], kwargs: dict[str, Any] + ) -> None: + received_kwargs.append(kwargs) + + load(serialized, allowed_objects=[AIMessage], init_validator=capture_kwargs) + + assert len(received_kwargs) == 1 + assert "content" in received_kwargs[0] + assert received_kwargs[0]["content"] == "Hello world" + assert "name" in received_kwargs[0] + assert received_kwargs[0]["name"] == "test_name" + + def test_init_validator_with_loads(self) -> None: + """Test that `init_validator` works with `loads()` function.""" + doc = Document(page_content="test", metadata={"key": "value"}) + json_str = dumps(doc) + + def block_metadata( + _class_path: tuple[str, ...], kwargs: dict[str, Any] + ) -> None: + if "metadata" in kwargs: + msg = "Metadata not allowed" + raise ValueError(msg) + + with pytest.raises(ValueError, match="Metadata not allowed"): + loads(json_str, allowed_objects=[Document], init_validator=block_metadata) + + def test_init_validator_none_allows_all(self) -> None: + """Test that `init_validator=None` (default) allows all kwargs.""" + msg = AIMessage(content="Hello") + serialized = dumpd(msg) + + # Should work without init_validator + loaded = load(serialized, allowed_objects=[AIMessage]) + assert loaded == msg + + def test_init_validator_type_alias_exists(self) -> None: + 
"""Test that `InitValidator` type alias is exported and usable.""" + + def my_validator(_class_path: tuple[str, ...], _kwargs: dict[str, Any]) -> None: + pass + + validator_typed: InitValidator = my_validator + assert callable(validator_typed) + + def test_init_validator_blocks_specific_class(self) -> None: + """Test blocking deserialization for a specific class.""" + doc = Document(page_content="test", metadata={"source": "test.txt"}) + serialized = dumpd(doc) + + def block_documents( + class_path: tuple[str, ...], _kwargs: dict[str, Any] + ) -> None: + if class_path == ("langchain", "schema", "document", "Document"): + msg = "Documents not allowed" + raise ValueError(msg) + + with pytest.raises(ValueError, match="Documents not allowed"): + load(serialized, allowed_objects=[Document], init_validator=block_documents) + + +class TestJinja2SecurityBlocking: + """Tests blocking Jinja2 templates by default.""" + + def test_fstring_template_allowed(self) -> None: + """Test that f-string templates deserialize successfully.""" + # Serialized ChatPromptTemplate with f-string format + serialized = { + "lc": 1, + "type": "constructor", + "id": ["langchain", "prompts", "chat", "ChatPromptTemplate"], + "kwargs": { + "input_variables": ["name"], + "messages": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "HumanMessagePromptTemplate", + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate", + ], + "kwargs": { + "input_variables": ["name"], + "template": "Hello {name}", + "template_format": "f-string", + }, + } + }, + } + ], + }, + } + + # f-string should deserialize successfully + loaded = load( + serialized, + allowed_objects=[ + ChatPromptTemplate, + HumanMessagePromptTemplate, + PromptTemplate, + ], + ) + assert isinstance(loaded, ChatPromptTemplate) + assert loaded.input_variables == ["name"] + + def test_jinja2_template_blocked(self) -> None: + 
"""Test that Jinja2 templates are blocked by default.""" + # Malicious serialized payload attempting to use jinja2 + malicious_serialized = { + "lc": 1, + "type": "constructor", + "id": ["langchain", "prompts", "chat", "ChatPromptTemplate"], + "kwargs": { + "input_variables": ["name"], + "messages": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "HumanMessagePromptTemplate", + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate", + ], + "kwargs": { + "input_variables": ["name"], + "template": "{{ name }}", + "template_format": "jinja2", + }, + } + }, + } + ], + }, + } + + # jinja2 should be blocked by default + with pytest.raises(ValueError, match="Jinja2 templates are not allowed"): + load( + malicious_serialized, + allowed_objects=[ + ChatPromptTemplate, + HumanMessagePromptTemplate, + PromptTemplate, + ], + ) + + def test_jinja2_blocked_standalone_prompt_template(self) -> None: + """Test blocking Jinja2 on standalone `PromptTemplate`.""" + serialized_jinja2 = { + "lc": 1, + "type": "constructor", + "id": ["langchain", "prompts", "prompt", "PromptTemplate"], + "kwargs": { + "input_variables": ["name"], + "template": "{{ name }}", + "template_format": "jinja2", + }, + } + + serialized_fstring = { + "lc": 1, + "type": "constructor", + "id": ["langchain", "prompts", "prompt", "PromptTemplate"], + "kwargs": { + "input_variables": ["name"], + "template": "{name}", + "template_format": "f-string", + }, + } + + # f-string should work + loaded = load( + serialized_fstring, + allowed_objects=[PromptTemplate], + ) + assert isinstance(loaded, PromptTemplate) + assert loaded.template == "{name}" + + # jinja2 should be blocked by default + with pytest.raises(ValueError, match="Jinja2 templates are not allowed"): + load( + serialized_jinja2, + allowed_objects=[PromptTemplate], + ) + + def test_jinja2_blocked_by_default(self) -> None: + """Test that 
Jinja2 templates are blocked by default.""" + serialized_jinja2 = { + "lc": 1, + "type": "constructor", + "id": ["langchain", "prompts", "prompt", "PromptTemplate"], + "kwargs": { + "input_variables": ["name"], + "template": "{{ name }}", + "template_format": "jinja2", + }, + } + + serialized_fstring = { + "lc": 1, + "type": "constructor", + "id": ["langchain", "prompts", "prompt", "PromptTemplate"], + "kwargs": { + "input_variables": ["name"], + "template": "{name}", + "template_format": "f-string", + }, + } + + # f-string should work + loaded = load(serialized_fstring, allowed_objects=[PromptTemplate]) + assert isinstance(loaded, PromptTemplate) + assert loaded.template == "{name}" + + # jinja2 should be blocked by default + with pytest.raises(ValueError, match="Jinja2 templates are not allowed"): + load(serialized_jinja2, allowed_objects=[PromptTemplate]) diff --git a/libs/core/tests/unit_tests/messages/test_ai.py b/libs/core/tests/unit_tests/messages/test_ai.py index 742f23b68be..31f8b3e1bb5 100644 --- a/libs/core/tests/unit_tests/messages/test_ai.py +++ b/libs/core/tests/unit_tests/messages/test_ai.py @@ -47,7 +47,7 @@ def test_serdes_message() -> None: } actual = dumpd(msg) assert actual == expected - assert load(actual) == msg + assert load(actual, allowed_objects=[AIMessage]) == msg def test_serdes_message_chunk() -> None: @@ -102,7 +102,7 @@ def test_serdes_message_chunk() -> None: } actual = dumpd(chunk) assert actual == expected - assert load(actual) == chunk + assert load(actual, allowed_objects=[AIMessageChunk]) == chunk def test_add_usage_both_none() -> None: diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 0e2954cf5c3..33fe267353b 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -1123,7 +1123,7 @@ def test_data_prompt_template_deserializable() -> None: ) ] ) - ) + ), ) diff --git 
a/libs/core/tests/unit_tests/prompts/test_dict.py b/libs/core/tests/unit_tests/prompts/test_dict.py index 581e418b6b5..13761223421 100644 --- a/libs/core/tests/unit_tests/prompts/test_dict.py +++ b/libs/core/tests/unit_tests/prompts/test_dict.py @@ -31,4 +31,4 @@ def test_deserialize_legacy() -> None: expected = DictPromptTemplate( template={"type": "audio", "audio": "{audio_data}"}, template_format="f-string" ) - assert load(ser) == expected + assert load(ser, allowed_objects=[DictPromptTemplate]) == expected diff --git a/libs/core/tests/unit_tests/prompts/test_image.py b/libs/core/tests/unit_tests/prompts/test_image.py index 746c099b803..11f0f578735 100644 --- a/libs/core/tests/unit_tests/prompts/test_image.py +++ b/libs/core/tests/unit_tests/prompts/test_image.py @@ -11,7 +11,7 @@ def test_image_prompt_template_deserializable() -> None: ChatPromptTemplate.from_messages( [("system", [{"type": "image", "image_url": "{img}"}])] ) - ) + ), ) @@ -105,5 +105,5 @@ def test_image_prompt_template_deserializable_old() -> None: "input_variables": ["img", "input"], }, } - ) + ), ) diff --git a/libs/core/tests/unit_tests/prompts/test_structured.py b/libs/core/tests/unit_tests/prompts/test_structured.py index 5df7631d9d4..5e6419d8edc 100644 --- a/libs/core/tests/unit_tests/prompts/test_structured.py +++ b/libs/core/tests/unit_tests/prompts/test_structured.py @@ -82,7 +82,6 @@ def test_structured_prompt_dict() -> None: assert loads(dumps(prompt)).model_dump() == prompt.model_dump() chain = loads(dumps(prompt)) | model - assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42} diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index 92b00c84035..a2313b7122d 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -925,7 +925,7 @@ def test_tool_message_serdes() -> None: }, } assert dumpd(message) == ser_message - assert load(dumpd(message)) == message + assert 
load(dumpd(message), allowed_objects=[ToolMessage]) == message class BadObject: @@ -954,7 +954,7 @@ def test_tool_message_ser_non_serializable() -> None: } assert dumpd(message) == ser_message with pytest.raises(NotImplementedError): - load(dumpd(ser_message)) + load(dumpd(message), allowed_objects=[ToolMessage]) def test_tool_message_to_dict() -> None: diff --git a/libs/langchain/langchain_classic/hub.py b/libs/langchain/langchain_classic/hub.py index f23887a4546..74490cbff74 100644 --- a/libs/langchain/langchain_classic/hub.py +++ b/libs/langchain/langchain_classic/hub.py @@ -4,7 +4,7 @@ from __future__ import annotations import json from collections.abc import Sequence -from typing import Any +from typing import Any, Literal from langchain_core.load.dump import dumps from langchain_core.load.load import loads @@ -139,7 +139,8 @@ def pull( if hasattr(client, "pull_repo"): # >= 0.1.15 res_dict = client.pull_repo(owner_repo_commit) - obj = loads(json.dumps(res_dict["manifest"])) + allowed_objects: Literal["all", "core"] = "all" if include_model else "core" + obj = loads(json.dumps(res_dict["manifest"]), allowed_objects=allowed_objects) if isinstance(obj, BasePromptTemplate): if obj.metadata is None: obj.metadata = {} diff --git a/libs/partners/anthropic/tests/unit_tests/test_standard.py b/libs/partners/anthropic/tests/unit_tests/test_standard.py index 3b9071c5798..35844ab3e59 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_standard.py +++ b/libs/partners/anthropic/tests/unit_tests/test_standard.py @@ -7,6 +7,8 @@ from pytest_benchmark.fixture import BenchmarkFixture # type: ignore[import-unt from langchain_anthropic import ChatAnthropic +_MODEL = "claude-3-haiku-20240307" + class TestAnthropicStandard(ChatModelUnitTests): """Use the standard chat model unit tests against the `ChatAnthropic` class.""" @@ -17,7 +19,15 @@ class TestAnthropicStandard(ChatModelUnitTests): @property def chat_model_params(self) -> dict: - return {"model": 
"claude-3-haiku-20240307"} + return {"model": _MODEL} + + @property + def init_from_env_params(self) -> tuple[dict, dict, dict]: + return ( + {"ANTHROPIC_API_KEY": "test"}, + {"model": _MODEL}, + {"anthropic_api_key": "test"}, + ) @pytest.mark.benchmark diff --git a/libs/partners/anthropic/uv.lock b/libs/partners/anthropic/uv.lock index ee29d8a3c63..9f7c2b1c6f7 100644 --- a/libs/partners/anthropic/uv.lock +++ b/libs/partners/anthropic/uv.lock @@ -495,7 +495,7 @@ wheels = [ [[package]] name = "langchain" -version = "1.1.3" +version = "1.2.0" source = { editable = "../../langchain_v1" } dependencies = [ { name = "langchain-core" }, @@ -642,7 +642,7 @@ typing = [ [[package]] name = "langchain-core" -version = "1.2.0" +version = "1.2.4" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, @@ -702,7 +702,7 @@ typing = [ [[package]] name = "langchain-tests" -version = "1.1.0" +version = "1.1.1" source = { editable = "../../standard-tests" } dependencies = [ { name = "httpx" }, diff --git a/libs/partners/groq/tests/unit_tests/test_chat_models.py b/libs/partners/groq/tests/unit_tests/test_chat_models.py index 0a9991c7598..c219f88a13e 100644 --- a/libs/partners/groq/tests/unit_tests/test_chat_models.py +++ b/libs/partners/groq/tests/unit_tests/test_chat_models.py @@ -274,6 +274,7 @@ def test_groq_serialization() -> None: dump, valid_namespaces=["langchain_groq"], secrets_map={"GROQ_API_KEY": api_key2}, + allowed_objects="all", ) assert type(llm2) is ChatGroq diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 0f7fad73c99..1bd36e774b5 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -1360,7 +1360,7 @@ def test_structured_outputs_parser() -> None: partial(_oai_structured_outputs_parser, schema=GenerateUsername) ) serialized = dumps(llm_output) - deserialized = 
loads(serialized) + deserialized = loads(serialized, allowed_objects=[ChatGeneration, AIMessage]) assert isinstance(deserialized, ChatGeneration) result = output_parser.invoke(cast(AIMessage, deserialized.message)) assert result == parsed_response diff --git a/libs/partners/openai/tests/unit_tests/test_load.py b/libs/partners/openai/tests/unit_tests/test_load.py index b1aabd07ccc..fa2c20ddbf8 100644 --- a/libs/partners/openai/tests/unit_tests/test_load.py +++ b/libs/partners/openai/tests/unit_tests/test_load.py @@ -1,4 +1,7 @@ from langchain_core.load import dumpd, dumps, load, loads +from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.runnables import RunnableSequence from langchain_openai import ChatOpenAI, OpenAI @@ -6,7 +9,11 @@ from langchain_openai import ChatOpenAI, OpenAI def test_loads_openai_llm() -> None: llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello", top_p=0.8) # type: ignore[call-arg] llm_string = dumps(llm) - llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"}) + llm2 = loads( + llm_string, + secrets_map={"OPENAI_API_KEY": "hello"}, + allowed_objects=[OpenAI], + ) assert llm2.dict() == llm.dict() llm_string_2 = dumps(llm2) @@ -17,7 +24,11 @@ def test_loads_openai_llm() -> None: def test_load_openai_llm() -> None: llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello") # type: ignore[call-arg] llm_obj = dumpd(llm) - llm2 = load(llm_obj, secrets_map={"OPENAI_API_KEY": "hello"}) + llm2 = load( + llm_obj, + secrets_map={"OPENAI_API_KEY": "hello"}, + allowed_objects=[OpenAI], + ) assert llm2.dict() == llm.dict() assert dumpd(llm2) == llm_obj @@ -27,7 +38,11 @@ def test_load_openai_llm() -> None: def test_loads_openai_chat() -> None: llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.5, openai_api_key="hello") # type: ignore[call-arg] llm_string = dumps(llm) - llm2 = loads(llm_string, 
secrets_map={"OPENAI_API_KEY": "hello"}) + llm2 = loads( + llm_string, + secrets_map={"OPENAI_API_KEY": "hello"}, + allowed_objects=[ChatOpenAI], + ) assert llm2.dict() == llm.dict() llm_string_2 = dumps(llm2) @@ -38,8 +53,85 @@ def test_loads_openai_chat() -> None: def test_load_openai_chat() -> None: llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.5, openai_api_key="hello") # type: ignore[call-arg] llm_obj = dumpd(llm) - llm2 = load(llm_obj, secrets_map={"OPENAI_API_KEY": "hello"}) + llm2 = load( + llm_obj, + secrets_map={"OPENAI_API_KEY": "hello"}, + allowed_objects=[ChatOpenAI], + ) assert llm2.dict() == llm.dict() assert dumpd(llm2) == llm_obj assert isinstance(llm2, ChatOpenAI) + + +def test_loads_runnable_sequence_prompt_model() -> None: + """Test serialization/deserialization of a chain: + + `prompt | model (RunnableSequence)` + """ + prompt = ChatPromptTemplate.from_messages([("user", "Hello, {name}!")]) + model = ChatOpenAI(model="gpt-4o-mini", temperature=0.5, openai_api_key="hello") # type: ignore[call-arg] + chain = prompt | model + + # Verify the chain is a RunnableSequence + assert isinstance(chain, RunnableSequence) + + # Serialize + chain_string = dumps(chain) + + # Deserialize + # (ChatPromptTemplate contains HumanMessagePromptTemplate and PromptTemplate) + chain2 = loads( + chain_string, + secrets_map={"OPENAI_API_KEY": "hello"}, + allowed_objects=[ + RunnableSequence, + ChatPromptTemplate, + HumanMessagePromptTemplate, + PromptTemplate, + ChatOpenAI, + ], + ) + + # Verify structure + assert isinstance(chain2, RunnableSequence) + assert isinstance(chain2.first, ChatPromptTemplate) + assert isinstance(chain2.last, ChatOpenAI) + + # Verify round-trip serialization + assert dumps(chain2) == chain_string + + +def test_load_runnable_sequence_prompt_model() -> None: + """Test load() with a chain: + + `prompt | model (RunnableSequence)`. 
+ """ + prompt = ChatPromptTemplate.from_messages([("user", "Tell me about {topic}")]) + model = ChatOpenAI(model="gpt-4o-mini", temperature=0.7, openai_api_key="hello") # type: ignore[call-arg] + chain = prompt | model + + # Serialize + chain_obj = dumpd(chain) + + # Deserialize + # (ChatPromptTemplate contains HumanMessagePromptTemplate and PromptTemplate) + chain2 = load( + chain_obj, + secrets_map={"OPENAI_API_KEY": "hello"}, + allowed_objects=[ + RunnableSequence, + ChatPromptTemplate, + HumanMessagePromptTemplate, + PromptTemplate, + ChatOpenAI, + ], + ) + + # Verify structure + assert isinstance(chain2, RunnableSequence) + assert isinstance(chain2.first, ChatPromptTemplate) + assert isinstance(chain2.last, ChatOpenAI) + + # Verify round-trip serialization + assert dumpd(chain2) == chain_obj diff --git a/libs/partners/xai/uv.lock b/libs/partners/xai/uv.lock index 8ba57d53f4f..b39c7233d20 100644 --- a/libs/partners/xai/uv.lock +++ b/libs/partners/xai/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10.0, <4.0.0" resolution-markers = [ "python_full_version >= '3.13' and platform_python_implementation == 'PyPy'", @@ -621,7 +621,7 @@ wheels = [ [[package]] name = "langchain-core" -version = "1.2.3" +version = "1.2.4" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, diff --git a/libs/standard-tests/langchain_tests/unit_tests/chat_models.py b/libs/standard-tests/langchain_tests/unit_tests/chat_models.py index b3b0a24a7d4..9f6d1e4534b 100644 --- a/libs/standard-tests/langchain_tests/unit_tests/chat_models.py +++ b/libs/standard-tests/langchain_tests/unit_tests/chat_models.py @@ -1126,7 +1126,10 @@ class ChatModelUnitTests(ChatModelTests): assert ( model.dict() == load( - dumpd(model), valid_namespaces=model.get_lc_namespace()[:1] + dumpd(model), + valid_namespaces=model.get_lc_namespace()[:1], + allowed_objects="all", + secrets_from_env=True, ).dict() )