From c979c6187b6d82f4bef35b10b84b39fa44806b22 Mon Sep 17 00:00:00 2001 From: Nick Hollon Date: Tue, 5 May 2026 14:36:58 -0400 Subject: [PATCH] fix(core, langchain): harden `load()` against untrusted manifests (#37197) --- libs/core/langchain_core/load/_validation.py | 51 ++-- libs/core/langchain_core/load/load.py | 237 +++++++++++----- libs/core/langchain_core/runnables/base.py | 14 + libs/core/langchain_core/runnables/history.py | 17 +- .../core/langchain_core/tracers/log_stream.py | 4 +- .../unit_tests/load/test_serializable.py | 259 +++++++++++++++++- libs/langchain/langchain_classic/hub.py | 38 ++- libs/langchain/tests/unit_tests/test_hub.py | 28 ++ 8 files changed, 550 insertions(+), 98 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/test_hub.py diff --git a/libs/core/langchain_core/load/_validation.py b/libs/core/langchain_core/load/_validation.py index a8502d34b29..8bf9f76a486 100644 --- a/libs/core/langchain_core/load/_validation.py +++ b/libs/core/langchain_core/load/_validation.py @@ -18,7 +18,7 @@ During deserialization, escaped dicts are unwrapped and returned as plain dicts, NOT instantiated as LC objects. """ -from typing import Any +from typing import Any, cast from langchain_core.load.serializable import ( Serializable, @@ -102,16 +102,25 @@ def _serialize_value(obj: Any) -> Any: return to_json_not_implemented(obj) -def _is_lc_secret(obj: Any) -> bool: - """Check if an object is a LangChain secret marker.""" - expected_num_keys = 3 - return ( - isinstance(obj, dict) - and obj.get("lc") == 1 - and obj.get("type") == "secret" - and "id" in obj - and len(obj) == expected_num_keys - ) +def _get_secret_keys(obj: Serializable) -> set[str]: + """Return the merged set of constructor kwarg names declared as secrets. + + Mirrors the MRO walk in `Serializable.to_json` so the keys returned here + match the keys whose values `_replace_secrets` rewrites into secret + markers. Used by `_serialize_lc_object` to decide which kwargs to skip + when escaping user data. + """ + secrets: dict[str, str] = {} + model_fields = type(obj).model_fields + for cls in [None, *obj.__class__.mro()]: + if cls is Serializable: + break + this = cast("Serializable", obj if cls is None else super(cls, obj)) + secrets.update(this.lc_secrets) + for key in list(secrets): + if (key in model_fields) and (alias := model_fields[key].alias) is not None: + secrets[alias] = secrets[key] + return set(secrets) def _serialize_lc_object(obj: Any) -> dict[str, Any]: @@ -124,9 +133,15 @@ def _serialize_lc_object(obj: Any) -> dict[str, Any]: The serialized dict with user data in kwargs escaped as needed. Note: - Kwargs values are processed with `_serialize_value` to escape user data (like - metadata) that contains `'lc'` keys. Secret fields (from `lc_secrets`) are - skipped because `to_json()` replaces their values with secret markers. + Kwargs values are processed with `_serialize_value` to escape user data + (like metadata) that contains `'lc'` keys. Secret fields are identified + by the class's declared `lc_secrets` and skipped because `to_json()` + already converted their values to secret markers. + + The check is key-based rather than shape-based. A shape-based check + ("this dict looks like a secret marker") can be forged by user data, + letting attacker-controlled free-form dicts bypass escaping and reach + the Reviver. """ if not isinstance(obj, Serializable): msg = f"Expected Serializable, got {type(obj)}" @@ -134,11 +149,13 @@ def _serialize_lc_object(obj: Any) -> dict[str, Any]: serialized: dict[str, Any] = dict(obj.to_json()) - # Process kwargs to escape user data that could be confused with LC objects - # Skip secret fields - to_json() already converted them to secret markers + # Process kwargs to escape user data that could be confused with LC objects. + # Skip kwargs declared as secrets - `to_json()` already replaced their + # values with secret markers via `_replace_secrets`. if serialized.get("type") == "constructor" and "kwargs" in serialized: + secret_keys = _get_secret_keys(obj) serialized["kwargs"] = { - k: v if _is_lc_secret(v) else _serialize_value(v) + k: v if k in secret_keys else _serialize_value(v) for k, v in serialized["kwargs"].items() } diff --git a/libs/core/langchain_core/load/load.py b/libs/core/langchain_core/load/load.py index 1a6f1b23b94..0c992e8c7aa 100644 --- a/libs/core/langchain_core/load/load.py +++ b/libs/core/langchain_core/load/load.py @@ -11,48 +11,58 @@ is a list of strings representing the module path and class name. For example: When deserializing, the class path from the JSON `'id'` field is checked against an allowlist. If the class is not in the allowlist, deserialization raises a `ValueError`. -## Security model +## Threat model -!!! warning "Exercise caution with untrusted input" +A serialized LangChain payload crosses a trust boundary because the manifest +may contain serialized objects and configuration that affect runtime behavior. +For example, a payload can configure a chat model with a custom `base_url`, +custom headers, a different model name, or other constructor arguments. These +are supported features, but they also mean the payload contents should be +treated as executable configuration rather than plain text. - These functions deserialize by instantiating Python objects, which means - constructors (`__init__`) and validators may run and can trigger side effects. - With the default settings, deserialization is restricted to a core allowlist - of `langchain_core` types (for example: messages, documents, and prompts) - defined in `langchain_core.load.mapping`. +Concretely, deserialization instantiates Python objects, so any constructor +(`__init__`) or validator on an allowed class can run during `load()`. A +crafted payload that is allowed to reach an unintended class — or an intended +class with attacker-controlled kwargs — could cause network calls, file +operations, or environment-variable access while the object is being built. - If you broaden `allowed_objects` (for example, by using `'all'` or adding - additional classes), treat the serialized payload as a manifest and only - deserialize data that comes from a trusted source. A crafted payload that - is allowed to instantiate unintended classes could cause network calls, - file operations, or environment variable access during `__init__`. +!!! warning "Do not use with untrusted input" + + If the source is untrusted, avoid calling `load()` / `loads()` on it. If + you must, restrict `allowed_objects` to types that do not execute logic + during init — `allowed_objects='messages'` (or an explicit list of + message classes) is the safe choice. Keep `secrets_from_env=False`. The `allowed_objects` parameter controls which classes can be deserialized: -- **`'core'` (default)**: Allow classes defined in the serialization mappings for - langchain_core. -- **`'all'`**: Allow classes defined in the serialization mappings. This - includes core LangChain types (messages, prompts, documents, etc.) and trusted - partner integrations. See `langchain_core.load.mapping` for the full list. -- **Explicit list of classes**: Only those specific classes are allowed. - -For simple data types like messages and documents, the default allowlist is safe to use. -These classes do not perform side effects during initialization. +- **Explicit list of classes** (recommended for untrusted input): only those + specific classes are allowed. +- **`'messages'`**: chat-message classes only (e.g. `AIMessage`, + `HumanMessage`). Safe for untrusted input. +- **`'core'` (current default)** — *unsafe with untrusted manifests.* + Classes defined in the serialization mappings under `langchain_core` + (messages, documents, prompts, etc.). +- **`'all'`** — *unsafe with untrusted manifests.* Every class in the + serialization mappings, including partner chat models and LLMs and their + constructor kwargs (endpoint URLs, headers, model names, etc.). !!! note "Side effects in allowed classes" - Deserialization calls `__init__` on allowed classes. If those classes perform side - effects during initialization (network calls, file operations, etc.), those side - effects will occur. The allowlist prevents instantiation of classes outside the - allowlist, but does not sandbox the allowed classes themselves. + Deserialization calls `__init__` on allowed classes. If those classes perform + side effects during initialization (network calls, file operations, etc.), + those side effects will occur. The allowlist prevents instantiation of + classes outside the allowlist, but does not sandbox the allowed classes + themselves or constrain their constructor kwargs. Import paths are also validated against trusted namespaces before any module is imported. ### Best practices -- Use the most restrictive `allowed_objects` possible. Prefer an explicit list - of classes over `'core'` or `'all'`. +- Use the most restrictive `allowed_objects` possible. For untrusted input, + pass an explicit list of classes or `'messages'`. `'core'` and `'all'` + are unsafe with untrusted manifests — only use them when the source + serves the entire payload, including its configuration. - Keep `secrets_from_env` set to `False` (the default). If you must use it, ensure the serialized data comes from a fully trusted source, as a crafted payload can read arbitrary environment variables. @@ -101,6 +111,7 @@ from collections.abc import Callable, Iterable from typing import Any, Literal, cast from langchain_core._api import beta +from langchain_core._api.deprecation import warn_deprecated from langchain_core.load._validation import _is_escaped_dict, _unescape_value from langchain_core.load.mapping import ( _JS_SERIALIZABLE_MAPPING, @@ -141,13 +152,31 @@ ALL_SERIALIZABLE_MAPPINGS = { **_JS_SERIALIZABLE_MAPPING, } +# Modern message classes admitted by `allowed_objects='messages'`. Legacy types +# (BaseMessage / BaseMessageChunk, ChatMessage / ChatMessageChunk, FunctionMessage / +# FunctionMessageChunk) are intentionally excluded — `BaseMessage` is abstract and +# the chat/function variants are superseded by `ToolMessage` and tool calling. +_MESSAGES_ALLOWED_CLASS_NAMES = frozenset( + { + "AIMessage", + "AIMessageChunk", + "HumanMessage", + "HumanMessageChunk", + "SystemMessage", + "SystemMessageChunk", + "ToolMessage", + "ToolMessageChunk", + "RemoveMessage", + } +) + # Cache for the default allowed class paths computed from mappings -# Maps mode ("all" or "core") to the cached set of paths +# Maps mode ("all", "core", or "messages") to the cached set of paths _default_class_paths_cache: dict[str, set[tuple[str, ...]]] = {} def _get_default_allowed_class_paths( - allowed_object_mode: Literal["all", "core"], + allowed_object_mode: Literal["all", "core", "messages"], ) -> set[tuple[str, ...]]: """Get the default allowed class paths from the serialization mappings. @@ -155,7 +184,7 @@ def _get_default_allowed_class_paths( by default. Both the legacy paths (keys) and current paths (values) are included. Args: - allowed_object_mode: either `'all'` or `'core'`. + allowed_object_mode: either `'all'`, `'core'`, or `'messages'`. Returns: Set of class path tuples that are allowed by default. @@ -167,6 +196,11 @@ def _get_default_allowed_class_paths( for key, value in ALL_SERIALIZABLE_MAPPINGS.items(): if allowed_object_mode == "core" and value[0] != "langchain_core": continue + if allowed_object_mode == "messages" and ( + value[0] != "langchain_core" + or value[-1] not in _MESSAGES_ALLOWED_CLASS_NAMES + ): + continue allowed_paths.add(key) allowed_paths.add(value) @@ -301,7 +335,9 @@ class Reviver: def __init__( self, - allowed_objects: Iterable[AllowedObject] | Literal["all", "core"] = "core", + allowed_objects: Iterable[AllowedObject] + | Literal["all", "core", "messages"] + | None = None, secrets_map: dict[str, str] | None = None, valid_namespaces: list[str] | None = None, secrets_from_env: bool = False, # noqa: FBT001,FBT002 @@ -313,16 +349,24 @@ class Reviver: ) -> None: """Initialize the reviver. + See the module docstring for the threat model around `load()`/`loads()`: + a serialized payload may carry constructor configuration that affects + runtime behavior (custom `base_url`, headers, model name, etc.). Do not + use `'core'` or `'all'` with untrusted manifests. + Args: allowed_objects: Allowlist of classes that can be deserialized. - - `'core'` (default): Allow classes defined in the serialization - mappings for `langchain_core`. - - `'all'`: Allow classes defined in the serialization mappings. - - This includes core LangChain types (messages, prompts, documents, - etc.) and trusted partner integrations. See + - Explicit list of classes (recommended for untrusted input): + only those specific classes are allowed. + - `'messages'`: chat-message classes only (e.g. `AIMessage`, + `HumanMessage`). Safe for untrusted input. + - `'core'` (current default): unsafe with untrusted manifests. + Classes defined in the serialization mappings under + `langchain_core`. + - `'all'`: unsafe with untrusted manifests. Every class in the + serialization mappings, including partner chat models and + LLMs and their constructor kwargs. See `langchain_core.load.mapping` for the full list. - - Explicit list of classes: Only those specific classes are allowed. secrets_map: A map of secrets to load. Only include the specific secrets the serialized object @@ -352,6 +396,19 @@ class Reviver: Defaults to `default_init_validator` which blocks jinja2 templates. """ + if allowed_objects is None: + warn_deprecated( + since="1.4.0", + message=( + "The default value of `allowed_objects` will change in a future " + "version. Pass an explicit value (e.g., " + "allowed_objects='messages' or allowed_objects='core') to suppress " + "this warning." + ), + pending=True, + ) + allowed_objects = "core" + self.secrets_from_env = secrets_from_env self.secrets_map = secrets_map or {} # By default, only support langchain, but user can pass in additional namespaces @@ -372,10 +429,10 @@ class Reviver: # Compute allowed class paths: # - "all" -> use default paths from mappings (+ additional_import_mappings) # - Explicit list -> compute from those classes - if allowed_objects in ("all", "core"): + if allowed_objects in ("all", "core", "messages"): self.allowed_class_paths: set[tuple[str, ...]] | None = ( _get_default_allowed_class_paths( - cast("Literal['all', 'core']", allowed_objects) + cast("Literal['all', 'core', 'messages']", allowed_objects) ).copy() ) # Add paths from additional_import_mappings to the defaults @@ -512,7 +569,9 @@ class Reviver: def loads( text: str, *, - allowed_objects: Iterable[AllowedObject] | Literal["all", "core"] = "core", + allowed_objects: Iterable[AllowedObject] + | Literal["all", "core", "messages"] + | None = None, secrets_map: dict[str, str] | None = None, valid_namespaces: list[str] | None = None, secrets_from_env: bool = False, @@ -524,30 +583,33 @@ def loads( Equivalent to `load(json.loads(text))`. - Only classes in the allowlist can be instantiated. The default allowlist includes - core LangChain types (messages, prompts, documents, etc.). See + Only classes in the allowlist can be instantiated. The default allowlist + includes core LangChain types (messages, prompts, documents, etc.). See `langchain_core.load.mapping` for the full list. !!! warning "Do not use with untrusted input" - This function instantiates Python objects and can trigger side effects - during deserialization. **Never call `loads()` on data from an untrusted - or unauthenticated source.** See the module-level security model - documentation for details and best practices. + A serialized payload may carry constructor kwargs that affect runtime + behavior (custom `base_url`, headers, model name, etc.), so it should be + treated as executable configuration rather than plain text. If the + source is untrusted, avoid calling `loads()` on it; if you must, pass + `allowed_objects='messages'` or an explicit list of message classes. + See the module-level threat model for details. Args: text: The string to load. allowed_objects: Allowlist of classes that can be deserialized. - - `'core'` (default): Allow classes defined in the serialization mappings - for `langchain_core`. - - `'all'`: Allow classes defined in the serialization mappings. - - This includes core LangChain types (messages, prompts, documents, etc.) - and trusted partner integrations. See `langchain_core.load.mapping` for - the full list. - - - Explicit list of classes: Only those specific classes are allowed. + - Explicit list of classes (recommended for untrusted input): only + those specific classes are allowed. + - `'messages'`: chat-message classes only. Safe for untrusted input. + - `'core'` (current default): unsafe with untrusted manifests. + Classes defined in the serialization mappings under + `langchain_core`. + - `'all'`: unsafe with untrusted manifests. Every class in the + serialization mappings, including partner chat models and LLMs + and their constructor kwargs. See `langchain_core.load.mapping` + for the full list. - `[]`: Disallow all deserialization (will raise on any object). secrets_map: A map of secrets to load. @@ -584,6 +646,19 @@ def loads( Raises: ValueError: If an object's class path is not in the `allowed_objects` allowlist. """ + if allowed_objects is None: + warn_deprecated( + since="1.4.0", + message=( + "The default value of `allowed_objects` will change in a future " + "version. Pass an explicit list of allowed classes (or " + "'messages' for untrusted input that contains only chat " + "messages) to suppress this warning." + ), + pending=True, + ) + allowed_objects = "core" + # Parse JSON and delegate to load() for proper escape handling raw_obj = json.loads(text) return load( @@ -602,7 +677,9 @@ def loads( def load( obj: Any, *, - allowed_objects: Iterable[AllowedObject] | Literal["all", "core"] = "core", + allowed_objects: Iterable[AllowedObject] + | Literal["all", "core", "messages"] + | None = None, secrets_map: dict[str, str] | None = None, valid_namespaces: list[str] | None = None, secrets_from_env: bool = False, @@ -615,30 +692,33 @@ def load( Use this if you already have a parsed JSON object, eg. from `json.load` or `orjson.loads`. - Only classes in the allowlist can be instantiated. The default allowlist includes - core LangChain types (messages, prompts, documents, etc.). See + Only classes in the allowlist can be instantiated. The default allowlist + includes core LangChain types (messages, prompts, documents, etc.). See `langchain_core.load.mapping` for the full list. !!! warning "Do not use with untrusted input" - This function instantiates Python objects and can trigger side effects - during deserialization. **Never call `load()` on data from an untrusted - or unauthenticated source.** See the module-level security model - documentation for details and best practices. + A serialized payload may carry constructor kwargs that affect runtime + behavior (custom `base_url`, headers, model name, etc.), so it should be + treated as executable configuration rather than plain text. If the + source is untrusted, avoid calling `load()` on it; if you must, pass + `allowed_objects='messages'` or an explicit list of message classes. + See the module-level threat model for details. Args: obj: The object to load. allowed_objects: Allowlist of classes that can be deserialized. - - `'core'` (default): Allow classes defined in the serialization mappings - for `langchain_core`. - - `'all'`: Allow classes defined in the serialization mappings. - - This includes core LangChain types (messages, prompts, documents, etc.) - and trusted partner integrations. See `langchain_core.load.mapping` for - the full list. - - - Explicit list of classes: Only those specific classes are allowed. + - Explicit list of classes (recommended for untrusted input): only + those specific classes are allowed. + - `'messages'`: chat-message classes only. Safe for untrusted input. + - `'core'` (current default): unsafe with untrusted manifests. + Classes defined in the serialization mappings under + `langchain_core`. + - `'all'`: unsafe with untrusted manifests. Every class in the + serialization mappings, including partner chat models and LLMs + and their constructor kwargs. See `langchain_core.load.mapping` + for the full list. - `[]`: Disallow all deserialization (will raise on any object). secrets_map: A map of secrets to load. @@ -699,6 +779,19 @@ def load( ) ``` """ + if allowed_objects is None: + warn_deprecated( + since="1.4.0", + message=( + "The default value of `allowed_objects` will change in a future " + "version. Pass an explicit list of allowed classes (or " + "'messages' for untrusted input that contains only chat " + "messages) to suppress this warning." + ), + pending=True, + ) + allowed_objects = "core" + reviver = Reviver( allowed_objects, secrets_map, diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index c63b5c6fce0..ac6374faf9f 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -41,6 +41,7 @@ from pydantic import BaseModel, ConfigDict, Field, RootModel from typing_extensions import override from langchain_core._api import beta_decorator +from langchain_core._api.deprecation import warn_deprecated from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager from langchain_core.load.serializable import ( Serializable, @@ -1289,6 +1290,11 @@ class Runnable(ABC, Generic[Input, Output]): A `RunLogPatch` or `RunLog` object. """ + warn_deprecated( + since="1.4.0", + message=("astream_log is deprecated. Use astream instead."), + removal="2.0.0", + ) stream = LogStreamCallbackHandler( auto_close=False, include_names=include_names, @@ -1538,6 +1544,14 @@ class Runnable(ABC, Generic[Input, Output]): **kwargs, ) elif version == "v1": + warn_deprecated( + since="1.4.0", + message=( + "astream_events version='v1' is deprecated. " + "Use version='v2' or astream instead." + ), + removal="2.0.0", + ) # First implementation, built on top of astream_log API # This implementation will be deprecated as of 0.2.0 event_stream = _astream_events_implementation_v1( diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 2521cfae753..8e197c0379b 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -13,6 +13,7 @@ from typing import ( from pydantic import BaseModel from typing_extensions import override +from langchain_core._api.deprecation import warn_deprecated from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.load.load import load from langchain_core.messages import AIMessage, BaseMessage, HumanMessage @@ -320,6 +321,14 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef] `RunnableBindingBase` init. """ + warn_deprecated( + since="1.4.0", + message=( + "RunnableWithMessageHistory is deprecated. " + "Use LangGraph's built-in persistence instead." + ), + removal="2.0.0", + ) history_chain: Runnable[Any, Any] = RunnableLambda( self._enter_history, self._aenter_history ).with_config(run_name="load_history") @@ -539,7 +548,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef] hist: BaseChatMessageHistory = config["configurable"]["message_history"] # Get the input messages - inputs = load(run.inputs, allowed_objects="all") + inputs = load(run.inputs, allowed_objects="messages") input_messages = self._get_input_messages(inputs) # If historic messages were prepended to the input messages, remove them to # avoid adding duplicate messages to history. @@ -548,7 +557,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef] input_messages = input_messages[len(historic_messages) :] # Get the output messages - output_val = load(run.outputs, allowed_objects="all") + output_val = load(run.outputs, allowed_objects="messages") output_messages = self._get_output_messages(output_val) hist.add_messages(input_messages + output_messages) @@ -556,7 +565,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef] hist: BaseChatMessageHistory = config["configurable"]["message_history"] # Get the input messages - inputs = load(run.inputs, allowed_objects="all") + inputs = load(run.inputs, allowed_objects="messages") input_messages = self._get_input_messages(inputs) # If historic messages were prepended to the input messages, remove them to # avoid adding duplicate messages to history. @@ -565,7 +574,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef] input_messages = input_messages[len(historic_messages) :] # Get the output messages - output_val = load(run.outputs, allowed_objects="all") + output_val = load(run.outputs, allowed_objects="messages") output_messages = self._get_output_messages(output_val) await hist.aadd_messages(input_messages + output_messages) diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index 312d416799d..5131815ebdf 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -585,7 +585,7 @@ def _get_standardized_inputs( ) raise NotImplementedError(msg) - inputs = load(run.inputs, allowed_objects="all") + inputs = load(run.inputs, allowed_objects="messages") if run.run_type in {"retriever", "llm", "chat_model"}: return inputs @@ -617,7 +617,7 @@ def _get_standardized_outputs( Returns: An output if returned, otherwise `None`. """ - outputs = load(run.outputs, allowed_objects="all") + outputs = load(run.outputs, allowed_objects="messages") if schema_format == "original": if run.run_type == "prompt" and "output" in outputs: # These were previously dumped before the tracer. diff --git a/libs/core/tests/unit_tests/load/test_serializable.py b/libs/core/tests/unit_tests/load/test_serializable.py index 857587bd32f..b99395a43e0 100644 --- a/libs/core/tests/unit_tests/load/test_serializable.py +++ b/libs/core/tests/unit_tests/load/test_serializable.py @@ -1,13 +1,20 @@ import contextlib +import inspect import json +import warnings from typing import Any import pytest from pydantic import BaseModel, ConfigDict, Field, SecretStr +from langchain_core._api import LangChainDeprecationWarning +from langchain_core._api.deprecation import LangChainPendingDeprecationWarning from langchain_core.documents import Document from langchain_core.load import InitValidator, Serializable, dumpd, dumps, load, loads -from langchain_core.load.load import ALL_SERIALIZABLE_MAPPINGS +from langchain_core.load.load import ( + ALL_SERIALIZABLE_MAPPINGS, + _get_default_allowed_class_paths, +) from langchain_core.load.serializable import _is_field_useful from langchain_core.load.validators import CLASS_INIT_VALIDATORS, _bedrock_validator from langchain_core.messages import AIMessage @@ -17,6 +24,8 @@ from langchain_core.prompts import ( HumanMessagePromptTemplate, PromptTemplate, ) +from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain_core.tracers import log_stream class NonBoolObj: @@ -594,6 +603,40 @@ class TestDumpdEscapesLcKeyInPlainDicts: "__lc_escaped__": {"lc": 1} } + def test_fake_secret_marker_in_metadata_is_escaped(self) -> None: + """A free-form dict shaped like a secret marker must not bypass escaping. + + Previously the shape check accepted any value for `id`, letting a + constructor dict nested inside `id` reach the Reviver and get + instantiated on the way back in. + """ + poisoned_metadata = { + "lc": 1, + "type": "secret", + "id": [ + { + "lc": 1, + "type": "constructor", + "id": ["langchain_core", "documents", "base", "Document"], + "kwargs": {"page_content": "injected"}, + } + ], + } + doc = Document(page_content="hello", metadata=poisoned_metadata) + + serialized = dumpd(doc) + # The fake marker must be wrapped in `__lc_escaped__`, not passed + # through as if it were a real secret. + assert serialized["kwargs"]["metadata"] == {"__lc_escaped__": poisoned_metadata} + + # And on round-trip, the nested constructor must not be instantiated: + # the metadata comes back as plain data, even with the most permissive + # allowlist. + roundtripped = load(serialized, allowed_objects="all") + assert isinstance(roundtripped, Document) + assert roundtripped.metadata == poisoned_metadata + assert isinstance(roundtripped.metadata["id"][0], dict) + class TestInitValidator: """Tests for `init_validator` on `load()` and `loads()`.""" @@ -1158,3 +1201,217 @@ class TestBedrockValidators: } _bedrock_validator(class_path, kwargs) + + +class TestMessagesAllowlistTier: + """Tests for the 'messages' allowlist tier.""" + + def test_messages_tier_contains_expected_types(self) -> None: + expected = { + "AIMessage", + "AIMessageChunk", + "HumanMessage", + "HumanMessageChunk", + "SystemMessage", + "SystemMessageChunk", + "ToolMessage", + "ToolMessageChunk", + "RemoveMessage", + } + paths = _get_default_allowed_class_paths("messages") + actual = {t[-1] for t in paths} + assert expected.issubset(actual), f"Missing: {expected - actual}" + + def test_messages_tier_excludes_legacy_and_abstract_types(self) -> None: + legacy = { + "BaseMessage", + "BaseMessageChunk", + "ChatMessage", + "ChatMessageChunk", + "FunctionMessage", + "FunctionMessageChunk", + } + paths = _get_default_allowed_class_paths("messages") + actual = {t[-1] for t in paths} + overlap = legacy & actual + assert not overlap, f"Legacy/abstract message types in tier: {overlap}" + + def test_messages_tier_excludes_non_message_types(self) -> None: + non_messages = { + "Document", + "Generation", + "ChatGeneration", + "GenerationChunk", + "ChatGenerationChunk", + "PromptValue", + "StringPromptValue", + "ChatPromptValue", + "AgentAction", + "AgentActionMessageLog", + "AgentFinish", + } + paths = _get_default_allowed_class_paths("messages") + actual = {t[-1] for t in paths} + overlap = non_messages & actual + assert not overlap, f"Non-message types in messages tier: {overlap}" + + def test_messages_tier_excludes_dangerous_types(self) -> None: + dangerous = { + "ChatOpenAI", + "ChatAnthropic", + "OpenAI", + "PromptTemplate", + "ChatPromptTemplate", + "FewShotPromptWithTemplates", + "RunnableBinding", + "RunnableBranch", + "RunnableParallel", + "RunnableConfigurableFields", + "RunnableConfigurableAlternatives", + "DynamicRunnable", + "HubRunnable", + "OutputFixingParser", + } + paths = _get_default_allowed_class_paths("messages") + actual = {t[-1] for t in paths} + overlap = dangerous & actual + assert not overlap, f"Dangerous types in messages tier: {overlap}" + + def test_messages_tier_load_allows_message(self) -> None: + serialized = { + "lc": 1, + "type": "constructor", + "id": ["langchain", "schema", "messages", "AIMessage"], + "kwargs": {"content": "hello"}, + } + loaded = load(serialized, allowed_objects="messages") + assert isinstance(loaded, AIMessage) + assert loaded.content == "hello" + + def test_messages_tier_load_blocks_prompt_template(self) -> None: + serialized = { + "lc": 1, + "type": "constructor", + "id": ["langchain", "prompts", "prompt", "PromptTemplate"], + "kwargs": { + "input_variables": ["name"], + "template": "{name}", + "template_format": "f-string", + }, + } + with pytest.raises(ValueError, match="not allowed"): + load(serialized, allowed_objects="messages") + + def test_messages_tier_load_blocks_chat_model(self) -> None: + serialized = { + "lc": 1, + "type": "constructor", + "id": ["langchain", "chat_models", "openai", "ChatOpenAI"], + "kwargs": {"model": "gpt-4"}, + } + with pytest.raises(ValueError, match="not allowed"): + load(serialized, allowed_objects="messages") + + +class TestAllowedObjectsDeprecation: + """Tests for the pending-default warning emitted when `allowed_objects` is unset.""" + + def test_unset_default_emits_pending_warning(self) -> None: + """load() with no allowed_objects emits pending deprecation warning.""" + serialized = { + "lc": 1, + "type": "constructor", + "id": ["langchain", "schema", "messages", "AIMessage"], + "kwargs": {"content": "hello"}, + } + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + loaded = load(serialized) + dep_warnings = [ + x + for x in w + if issubclass( + x.category, + ( + LangChainDeprecationWarning, + LangChainPendingDeprecationWarning, + ), + ) + ] + assert len(dep_warnings) >= 1 + assert "allowed_objects" in str(dep_warnings[0].message) + assert isinstance(loaded, AIMessage) + + def test_explicit_core_no_warning(self) -> None: + """load() with explicit allowed_objects='core' does NOT warn.""" + serialized = { + "lc": 1, + "type": "constructor", + "id": ["langchain", "schema", "messages", "AIMessage"], + "kwargs": {"content": "hello"}, + } + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + load(serialized, allowed_objects="core") + dep_warnings = [ + x + for x in w + if issubclass( + x.category, + ( + LangChainDeprecationWarning, + LangChainPendingDeprecationWarning, + ), + ) + ] + assert len(dep_warnings) == 0 + + def test_explicit_messages_no_deprecation_warning(self) -> None: + serialized = { + "lc": 1, + "type": "constructor", + "id": ["langchain", "schema", "messages", "AIMessage"], + "kwargs": {"content": "hello"}, + } + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + load(serialized, allowed_objects="messages") + dep_warnings = [ + x for x in w if issubclass(x.category, LangChainDeprecationWarning) + ] + assert len(dep_warnings) == 0 + + def test_explicit_list_no_deprecation_warning(self) -> None: + serialized = { + "lc": 1, + "type": "constructor", + "id": ["langchain", "schema", "messages", "AIMessage"], + "kwargs": {"content": "hello"}, + } + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + load(serialized, allowed_objects=[AIMessage]) + dep_warnings = [ + x for x in w if issubclass(x.category, LangChainDeprecationWarning) + ] + assert len(dep_warnings) == 0 + + +class TestInternalCallSitesUseMessages: + """Tests that internal call sites use 'messages' tier, not 'all'.""" + + def test_history_py_does_not_use_all(self) -> None: + source = inspect.getsource(RunnableWithMessageHistory) + assert 'allowed_objects="all"' not in source + assert ( + 'allowed_objects="messages"' in source + or "allowed_objects='messages'" in source + ) + + def test_log_stream_does_not_use_all(self) -> None: + source = inspect.getsource(log_stream) + assert 'allowed_objects="all"' not in source + assert ( + 'allowed_objects="messages"' in source + or "allowed_objects='messages'" in source + ) diff --git a/libs/langchain/langchain_classic/hub.py b/libs/langchain/langchain_classic/hub.py index 74490cbff74..602209d13aa 100644 --- a/libs/langchain/langchain_classic/hub.py +++ b/libs/langchain/langchain_classic/hub.py @@ -6,6 +6,7 @@ import json from collections.abc import Sequence from typing import Any, Literal +from langchain_core._api.deprecation import deprecated from langchain_core.load.dump import dumps from langchain_core.load.load import loads from langchain_core.prompts import BasePromptTemplate @@ -108,6 +109,13 @@ def push( ) +@deprecated( + since="1.4.0", + removal="2.0.0", + message=( + "langchain_classic.hub.pull is deprecated. Use the LangSmith SDK instead." + ), +) def pull( owner_repo_commit: str, *, @@ -117,11 +125,37 @@ def pull( ) -> Any: """Pull an object from the hub and returns it as a LangChain object. + !!! danger "Hub manifests are untrusted input" + + Treat every prompt pulled from the hub as untrusted, regardless of + the owner. Public prompts authored by other users are obviously + external content, but prompts from your own account — or your + organization's account — are also unsafe if that account, a + teammate's account, or the upstream prompt has been compromised. + A single malicious commit to a prompt your code pulls is enough to + execute attacker-controlled configuration on every machine that runs + `pull()`. + + `pull()` deserializes the manifest via `load()`, so the + `langchain_core.load.load` threat model applies — a manifest can + intentionally configure a model with a custom base URL, headers, + model name, or other constructor arguments. These are supported + features, but they also mean the prompt contents are executable + configuration rather than plain text: a compromised prompt can + redirect API traffic, inject headers, or trigger arbitrary code paths + in the classes it instantiates. + + Prefer the LangSmith SDK directly. If you must use `pull()`, pin the + commit hash, audit the manifest before deserializing, and never run + it against an account whose access controls you cannot vouch for. + Args: owner_repo_commit: The full name of the prompt to pull from in the format of `owner/prompt_name:commit_hash` or `owner/prompt_name` or just `prompt_name` if it's your own prompt. - include_model: Whether to include the model configuration in the pulled prompt. + include_model: Whether to include the model configuration in the pulled + prompt. When `True`, the model declared by the prompt is also + deserialized. api_url: The URL of the LangChain Hub API. Defaults to the hosted API service if you have an API key set, or a localhost instance if not. api_key: The API key to use to authenticate with the LangChain Hub API. @@ -151,4 +185,4 @@ def pull( # Then it's < 0.1.15 langchainhub resp: str = client.pull(owner_repo_commit) - return loads(resp) + return loads(resp, allowed_objects="core") diff --git a/libs/langchain/tests/unit_tests/test_hub.py b/libs/langchain/tests/unit_tests/test_hub.py new file mode 100644 index 00000000000..2a882ec7af2 --- /dev/null +++ b/libs/langchain/tests/unit_tests/test_hub.py @@ -0,0 +1,28 @@ +import warnings +from unittest.mock import MagicMock, patch + + +class TestHubPullDeprecation: + """Tests that `hub.pull` is deprecated in favor of the LangSmith SDK.""" + + def test_pull_emits_deprecation(self) -> None: + from langchain_core._api import LangChainDeprecationWarning + + from langchain_classic.hub import pull + + mock_client = MagicMock() + mock_client.pull_prompt = MagicMock(return_value=MagicMock()) + + with ( + patch("langchain_classic.hub._get_client", return_value=mock_client), + warnings.catch_warnings(record=True) as w, + ): + warnings.simplefilter("always") + pull("owner/repo") + dep_warnings = [ + x for x in w if issubclass(x.category, LangChainDeprecationWarning) + ] + assert len(dep_warnings) >= 1 + msg = str(dep_warnings[0].message) + assert "hub.pull" in msg + assert "LangSmith" in msg