community[major]: breaking change in some APIs to force users to opt-in for pickling (#18696)
This PR adds an allow_dangerous_deserialization parameter that forces users to opt in before pickle is used for deserialization. It is meant to raise user awareness that the pickle module is involved.
parent 0e52961562
commit 4c25b49229
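The resulting opt-in pattern is sketched below. This is a minimal sketch, not code from the commit: it assumes faiss is installed, and the folder name and the FakeEmbeddings helper are placeholders.

```python
# Minimal sketch of the new opt-in flow (paths and embeddings are placeholders).
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import FAISS

embeddings = FakeEmbeddings(size=4)
store = FAISS.from_texts(["hello", "world"], embeddings)
store.save_local("faiss_index")  # pickles the docstore alongside the index

# Previously this call deserialized the pickled docstore silently; it now raises
# a ValueError unless the caller explicitly opts in.
reloaded = FAISS.load_local(
    "faiss_index",
    embeddings,
    allow_dangerous_deserialization=True,  # only for data you trust
)
```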
@@ -250,7 +250,6 @@ def _pickle_fn_to_hex_string(fn: Callable) -> str:

 class Databricks(LLM):
     """Databricks serving endpoint or a cluster driver proxy app for LLM.

     It supports two endpoint types:
@@ -374,6 +373,15 @@ class Databricks(LLM):
     If not provided, the task is automatically inferred from the endpoint.
     """

+    allow_dangerous_deserialization: bool = False
+    """Whether to allow dangerous deserialization of the data which
+    involves loading data using pickle.
+
+    If the data has been modified by a malicious actor, it can deliver a
+    malicious payload that results in execution of arbitrary code on the target
+    machine.
+    """
+
     _client: _DatabricksClientBase = PrivateAttr()

     class Config:
@@ -435,6 +443,16 @@ class Databricks(LLM):
         return v

     def __init__(self, **data: Any):
+        if not data.get("allow_dangerous_deserialization"):
+            raise ValueError(
+                "This code relies on the pickle module. "
+                "You will need to set allow_dangerous_deserialization=True "
+                "if you want to opt in to allow deserialization of data using pickle. "
+                "Data can be compromised by a malicious actor if "
+                "not handled properly, and can include "
+                "a malicious payload that, when deserialized with "
+                "pickle, can execute arbitrary code on your machine."
+            )
         if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
             data["transform_input_fn"] = _load_pickled_fn_from_hex_string(
                 data["transform_input_fn"]
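For the Databricks LLM the guard lives in __init__ because transform_input_fn and transform_output_fn may arrive as hex-encoded pickled functions. A rough usage sketch, mirroring the updated unit test further down; the endpoint name, the transform function body, and the environment variables are placeholders, not values from this diff.

```python
# Sketch only: assumes DATABRICKS_HOST / DATABRICKS_TOKEN are already set in the
# environment; the endpoint name is a placeholder.
from langchain_community.llms import Databricks


def transform_input(**request):
    # Illustrative preprocessing; this function is what gets pickled/serialized.
    request["prompt"] = f"[INST] {request['prompt']} [/INST]"
    return request


llm = Databricks(
    endpoint_name="some_end_point_name",
    transform_input_fn=transform_input,
    allow_dangerous_deserialization=True,  # required now that pickle is opt-in
)
```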
@@ -137,6 +137,11 @@ class SelfHostedPipeline(LLM):
     model_reqs: List[str] = ["./", "torch"]
     """Requirements to install on hardware to run inference with the model."""

+    allow_dangerous_deserialization: bool = False
+    """Allow deserialization using pickle, which can be dangerous when
+    loading compromised data.
+    """
+
     class Config:
         """Configuration for this pydantic object."""

@@ -149,6 +154,16 @@ class SelfHostedPipeline(LLM):
         and run on the server, i.e. in a module and not a REPL or closure.
         Then, initialize the remote inference function.
         """
+        if not kwargs.get("allow_dangerous_deserialization"):
+            raise ValueError(
+                "SelfHostedPipeline relies on the pickle module. "
+                "You will need to set allow_dangerous_deserialization=True "
+                "if you want to opt in to allow deserialization of data using pickle. "
+                "Data can be compromised by a malicious actor if "
+                "not handled properly, and can include "
+                "a malicious payload that, when deserialized with "
+                "pickle, can execute arbitrary code."
+            )
         super().__init__(**kwargs)
         try:
             import runhouse as rh
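Because the check runs before super().__init__() and before runhouse is imported, the failure is immediate and does not depend on any remote hardware being configured. A quick illustration, assuming only that langchain_community is installed:

```python
# Illustrative: the guard fires before pydantic validation or the runhouse import,
# so constructing the class without the flag fails fast with a ValueError.
from langchain_community.llms import SelfHostedPipeline

try:
    SelfHostedPipeline()  # allow_dangerous_deserialization not set
except ValueError as err:
    print(err)  # explains that pickle is involved and how to opt in
```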
@@ -429,6 +429,8 @@ class Annoy(VectorStore):
         cls,
         folder_path: str,
         embeddings: Embeddings,
+        *,
+        allow_dangerous_deserialization: bool = False,
     ) -> Annoy:
         """Load Annoy index, docstore, and index_to_docstore_id from disk.

@@ -436,7 +438,25 @@ class Annoy(VectorStore):
             folder_path: folder path to load index, docstore,
                 and index_to_docstore_id from.
             embeddings: Embeddings to use when generating queries.
+            allow_dangerous_deserialization: whether to allow deserialization
+                of the data which involves loading a pickle file.
+                Pickle files can be modified by malicious actors to deliver a
+                malicious payload that results in execution of
+                arbitrary code on your machine.
         """
+        if not allow_dangerous_deserialization:
+            raise ValueError(
+                "The de-serialization relies on loading a pickle file. "
+                "Pickle files can be modified to deliver a malicious payload that "
+                "results in execution of arbitrary code on your machine. "
+                "You will need to set `allow_dangerous_deserialization` to `True` to "
+                "enable deserialization. If you do this, make sure that you "
+                "trust the source of the data. For example, if you are loading a "
+                "file that you created, and know that no one else has modified the "
+                "file, then this is safe to do. Do not set this to `True` if you are "
+                "loading a file from an untrusted source (e.g., some random site on "
+                "the internet)."
+            )
        path = Path(folder_path)
        # load index separately since it is not picklable
        annoy = dependable_annoy_import()
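The same ordering applies to the vector stores: the ValueError is raised before the optional dependency is imported or any file is opened, so the failure mode is easy to observe. For instance (the folder name below is a placeholder; nothing is read from disk):

```python
# Illustrative: without the opt-in flag, load_local raises before importing annoy
# or touching anything under the given folder.
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import Annoy

try:
    Annoy.load_local("annoy_index", FakeEmbeddings(size=4))
except ValueError as err:
    print(err)  # suggests passing allow_dangerous_deserialization=True
```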
@@ -1093,6 +1093,8 @@ class FAISS(VectorStore):
         folder_path: str,
         embeddings: Embeddings,
         index_name: str = "index",
+        *,
+        allow_dangerous_deserialization: bool = False,
         **kwargs: Any,
     ) -> FAISS:
         """Load FAISS index, docstore, and index_to_docstore_id from disk.
@@ -1102,8 +1104,26 @@ class FAISS(VectorStore):
                 and index_to_docstore_id from.
             embeddings: Embeddings to use when generating queries
             index_name: for saving with a specific index file name
+            allow_dangerous_deserialization: whether to allow deserialization
+                of the data which involves loading a pickle file.
+                Pickle files can be modified by malicious actors to deliver a
+                malicious payload that results in execution of
+                arbitrary code on your machine.
             asynchronous: whether to use async version or not
         """
+        if not allow_dangerous_deserialization:
+            raise ValueError(
+                "The de-serialization relies on loading a pickle file. "
+                "Pickle files can be modified to deliver a malicious payload that "
+                "results in execution of arbitrary code on your machine. "
+                "You will need to set `allow_dangerous_deserialization` to `True` to "
+                "enable deserialization. If you do this, make sure that you "
+                "trust the source of the data. For example, if you are loading a "
+                "file that you created, and know that no one else has modified the "
+                "file, then this is safe to do. Do not set this to `True` if you are "
+                "loading a file from an untrusted source (e.g., some random site on "
+                "the internet)."
+            )
         path = Path(folder_path)
         # load index separately since it is not picklable
         faiss = dependable_faiss_import()
@@ -460,6 +460,8 @@ class ScaNN(VectorStore):
         folder_path: str,
         embedding: Embeddings,
         index_name: str = "index",
+        *,
+        allow_dangerous_deserialization: bool = False,
         **kwargs: Any,
     ) -> ScaNN:
         """Load ScaNN index, docstore, and index_to_docstore_id from disk.
@@ -469,7 +471,25 @@ class ScaNN(VectorStore):
                 and index_to_docstore_id from.
             embeddings: Embeddings to use when generating queries
             index_name: for saving with a specific index file name
+            allow_dangerous_deserialization: whether to allow deserialization
+                of the data which involves loading a pickle file.
+                Pickle files can be modified by malicious actors to deliver a
+                malicious payload that results in execution of
+                arbitrary code on your machine.
         """
+        if not allow_dangerous_deserialization:
+            raise ValueError(
+                "The de-serialization relies on loading a pickle file. "
+                "Pickle files can be modified to deliver a malicious payload that "
+                "results in execution of arbitrary code on your machine. "
+                "You will need to set `allow_dangerous_deserialization` to `True` to "
+                "enable deserialization. If you do this, make sure that you "
+                "trust the source of the data. For example, if you are loading a "
+                "file that you created, and know that no one else has modified the "
+                "file, then this is safe to do. Do not set this to `True` if you are "
+                "loading a file from an untrusted source (e.g., some random site on "
+                "the internet)."
+            )
         path = Path(folder_path)
         scann_path = path / "{index_name}.scann".format(index_name=index_name)
         scann_path.mkdir(exist_ok=True, parents=True)
@@ -87,9 +87,28 @@ class TileDB(VectorStore):
         docs_array_uri: str = "",
         config: Optional[Mapping[str, Any]] = None,
         timestamp: Any = None,
+        allow_dangerous_deserialization: bool = False,
         **kwargs: Any,
     ):
-        """Initialize with necessary components."""
+        """Initialize with necessary components.
+
+        Args:
+            allow_dangerous_deserialization: whether to allow deserialization
+                of the data which involves loading data using pickle.
+                Data can be modified by malicious actors to deliver a
+                malicious payload that results in execution of
+                arbitrary code on your machine.
+        """
+        if not allow_dangerous_deserialization:
+            raise ValueError(
+                "TileDB relies on pickle for serialization and deserialization. "
+                "This can be dangerous if the data is intercepted and/or modified "
+                "by malicious actors prior to being de-serialized. "
+                "If you are sure that the data is safe from modification, you can "
+                "set allow_dangerous_deserialization=True to proceed. "
+                "Loading of compromised data using pickle can result in execution of "
+                "arbitrary code on your machine."
+            )
         self.embedding = embedding
         self.embedding_function = embedding.embed_query
         self.index_uri = index_uri
@@ -116,7 +116,9 @@ def test_annoy_local_save_load() -> None:

     temp_dir = tempfile.TemporaryDirectory()
     docsearch.save_local(temp_dir.name)
-    loaded_docsearch = Annoy.load_local(temp_dir.name, FakeEmbeddings())
+    loaded_docsearch = Annoy.load_local(
+        temp_dir.name, FakeEmbeddings(), allow_dangerous_deserialization=True
+    )

     assert docsearch.index_to_docstore_id == loaded_docsearch.index_to_docstore_id
     assert docsearch.docstore.__dict__ == loaded_docsearch.docstore.__dict__
@@ -252,7 +252,9 @@ def test_scann_local_save_load() -> None:
     temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
     with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
         docsearch.save_local(temp_folder)
-        new_docsearch = ScaNN.load_local(temp_folder, FakeEmbeddings())
+        new_docsearch = ScaNN.load_local(
+            temp_folder, FakeEmbeddings(), allow_dangerous_deserialization=True
+        )
         assert new_docsearch.index is not None
@@ -44,8 +44,9 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
     monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")

     llm = Databricks(
-        endpoint_name="databricks-mixtral-8x7b-instruct",
+        endpoint_name="some_end_point_name",  # Value should not matter for this test
         transform_input_fn=transform_input,
+        allow_dangerous_deserialization=True,
     )
     params = llm._default_params
     pickled_string = cloudpickle.dumps(transform_input).hex()
@@ -608,7 +608,9 @@ def test_faiss_local_save_load() -> None:
     temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
     with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
         docsearch.save_local(temp_folder)
-        new_docsearch = FAISS.load_local(temp_folder, FakeEmbeddings())
+        new_docsearch = FAISS.load_local(
+            temp_folder, FakeEmbeddings(), allow_dangerous_deserialization=True
+        )
         assert new_docsearch.index is not None

@@ -620,7 +622,9 @@ async def test_faiss_async_local_save_load() -> None:
     temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
     with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
         docsearch.save_local(temp_folder)
-        new_docsearch = FAISS.load_local(temp_folder, FakeEmbeddings())
+        new_docsearch = FAISS.load_local(
+            temp_folder, FakeEmbeddings(), allow_dangerous_deserialization=True
+        )
         assert new_docsearch.index is not None
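The updated tests above cover the happy path. A negative-path check along the following lines would be a natural companion; this is a hypothetical test that is not part of this commit, and it assumes faiss and the FakeEmbeddings helper are available.

```python
# Hypothetical regression test: loading without the flag must raise.
import pytest

from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import FAISS


def test_faiss_load_local_requires_opt_in(tmp_path) -> None:
    embeddings = FakeEmbeddings(size=4)
    docsearch = FAISS.from_texts(["foo", "bar"], embeddings)
    docsearch.save_local(str(tmp_path))
    with pytest.raises(ValueError, match="allow_dangerous_deserialization"):
        FAISS.load_local(str(tmp_path), embeddings)
```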