mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 22:29:51 +00:00
community[major]: breaking change in some APIs to force users to opt-in for pickling (#18696)
This PR adds a dangerous-load parameter (`allow_dangerous_deserialization`) that forces users to explicitly opt in before pickle is used. It is meant to raise user awareness that the pickle module is involved.
This commit is contained in:
parent
0e52961562
commit
4c25b49229
@ -250,7 +250,6 @@ def _pickle_fn_to_hex_string(fn: Callable) -> str:
|
|||||||
|
|
||||||
|
|
||||||
class Databricks(LLM):
|
class Databricks(LLM):
|
||||||
|
|
||||||
"""Databricks serving endpoint or a cluster driver proxy app for LLM.
|
"""Databricks serving endpoint or a cluster driver proxy app for LLM.
|
||||||
|
|
||||||
It supports two endpoint types:
|
It supports two endpoint types:
|
||||||
@ -374,6 +373,15 @@ class Databricks(LLM):
|
|||||||
If not provided, the task is automatically inferred from the endpoint.
|
If not provided, the task is automatically inferred from the endpoint.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
allow_dangerous_deserialization: bool = False
|
||||||
|
"""Whether to allow dangerous deserialization of the data which
|
||||||
|
involves loading data using pickle.
|
||||||
|
|
||||||
|
If the data has been modified by a malicious actor, it can deliver a
|
||||||
|
malicious payload that results in execution of arbitrary code on the target
|
||||||
|
machine.
|
||||||
|
"""
|
||||||
|
|
||||||
_client: _DatabricksClientBase = PrivateAttr()
|
_client: _DatabricksClientBase = PrivateAttr()
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -435,6 +443,16 @@ class Databricks(LLM):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
def __init__(self, **data: Any):
|
def __init__(self, **data: Any):
|
||||||
|
if not data.get("allow_dangerous_deserialization"):
|
||||||
|
raise ValueError(
|
||||||
|
"This code relies on the pickle module. "
|
||||||
|
"You will need to set allow_dangerous_deserialization=True "
|
||||||
|
"if you want to opt-in to allow deserialization of data using pickle."
|
||||||
|
"Data can be compromised by a malicious actor if "
|
||||||
|
"not handled properly to include "
|
||||||
|
"a malicious payload that when deserialized with "
|
||||||
|
"pickle can execute arbitrary code on your machine."
|
||||||
|
)
|
||||||
if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
|
if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
|
||||||
data["transform_input_fn"] = _load_pickled_fn_from_hex_string(
|
data["transform_input_fn"] = _load_pickled_fn_from_hex_string(
|
||||||
data["transform_input_fn"]
|
data["transform_input_fn"]
|
||||||
|
@ -137,6 +137,11 @@ class SelfHostedPipeline(LLM):
|
|||||||
model_reqs: List[str] = ["./", "torch"]
|
model_reqs: List[str] = ["./", "torch"]
|
||||||
"""Requirements to install on hardware to inference the model."""
|
"""Requirements to install on hardware to inference the model."""
|
||||||
|
|
||||||
|
allow_dangerous_deserialization: bool = False
|
||||||
|
"""Allow deserialization using pickle which can be dangerous if
|
||||||
|
loading compromised data.
|
||||||
|
"""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
@ -149,6 +154,16 @@ class SelfHostedPipeline(LLM):
|
|||||||
and run on the server, i.e. in a module and not a REPL or closure.
|
and run on the server, i.e. in a module and not a REPL or closure.
|
||||||
Then, initialize the remote inference function.
|
Then, initialize the remote inference function.
|
||||||
"""
|
"""
|
||||||
|
if not kwargs.get("allow_dangerous_deserialization"):
|
||||||
|
raise ValueError(
|
||||||
|
"SelfHostedPipeline relies on the pickle module. "
|
||||||
|
"You will need to set allow_dangerous_deserialization=True "
|
||||||
|
"if you want to opt-in to allow deserialization of data using pickle."
|
||||||
|
"Data can be compromised by a malicious actor if "
|
||||||
|
"not handled properly to include "
|
||||||
|
"a malicious payload that when deserialized with "
|
||||||
|
"pickle can execute arbitrary code. "
|
||||||
|
)
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
try:
|
try:
|
||||||
import runhouse as rh
|
import runhouse as rh
|
||||||
|
@ -429,6 +429,8 @@ class Annoy(VectorStore):
|
|||||||
cls,
|
cls,
|
||||||
folder_path: str,
|
folder_path: str,
|
||||||
embeddings: Embeddings,
|
embeddings: Embeddings,
|
||||||
|
*,
|
||||||
|
allow_dangerous_deserialization: bool = False,
|
||||||
) -> Annoy:
|
) -> Annoy:
|
||||||
"""Load Annoy index, docstore, and index_to_docstore_id to disk.
|
"""Load Annoy index, docstore, and index_to_docstore_id to disk.
|
||||||
|
|
||||||
@ -436,7 +438,25 @@ class Annoy(VectorStore):
|
|||||||
folder_path: folder path to load index, docstore,
|
folder_path: folder path to load index, docstore,
|
||||||
and index_to_docstore_id from.
|
and index_to_docstore_id from.
|
||||||
embeddings: Embeddings to use when generating queries.
|
embeddings: Embeddings to use when generating queries.
|
||||||
|
allow_dangerous_deserialization: whether to allow deserialization
|
||||||
|
of the data which involves loading a pickle file.
|
||||||
|
Pickle files can be modified by malicious actors to deliver a
|
||||||
|
malicious payload that results in execution of
|
||||||
|
arbitrary code on your machine.
|
||||||
"""
|
"""
|
||||||
|
if not allow_dangerous_deserialization:
|
||||||
|
raise ValueError(
|
||||||
|
"The de-serialization relies loading a pickle file. "
|
||||||
|
"Pickle files can be modified to deliver a malicious payload that "
|
||||||
|
"results in execution of arbitrary code on your machine."
|
||||||
|
"You will need to set `allow_dangerous_deserialization` to `True` to "
|
||||||
|
"enable deserialization. If you do this, make sure that you "
|
||||||
|
"trust the source of the data. For example, if you are loading a "
|
||||||
|
"file that you created, and no that no one else has modified the file, "
|
||||||
|
"then this is safe to do. Do not set this to `True` if you are loading "
|
||||||
|
"a file from an untrusted source (e.g., some random site on the "
|
||||||
|
"internet.)."
|
||||||
|
)
|
||||||
path = Path(folder_path)
|
path = Path(folder_path)
|
||||||
# load index separately since it is not picklable
|
# load index separately since it is not picklable
|
||||||
annoy = dependable_annoy_import()
|
annoy = dependable_annoy_import()
|
||||||
|
@ -1093,6 +1093,8 @@ class FAISS(VectorStore):
|
|||||||
folder_path: str,
|
folder_path: str,
|
||||||
embeddings: Embeddings,
|
embeddings: Embeddings,
|
||||||
index_name: str = "index",
|
index_name: str = "index",
|
||||||
|
*,
|
||||||
|
allow_dangerous_deserialization: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> FAISS:
|
) -> FAISS:
|
||||||
"""Load FAISS index, docstore, and index_to_docstore_id from disk.
|
"""Load FAISS index, docstore, and index_to_docstore_id from disk.
|
||||||
@ -1102,8 +1104,26 @@ class FAISS(VectorStore):
|
|||||||
and index_to_docstore_id from.
|
and index_to_docstore_id from.
|
||||||
embeddings: Embeddings to use when generating queries
|
embeddings: Embeddings to use when generating queries
|
||||||
index_name: for saving with a specific index file name
|
index_name: for saving with a specific index file name
|
||||||
|
allow_dangerous_deserialization: whether to allow deserialization
|
||||||
|
of the data which involves loading a pickle file.
|
||||||
|
Pickle files can be modified by malicious actors to deliver a
|
||||||
|
malicious payload that results in execution of
|
||||||
|
arbitrary code on your machine.
|
||||||
asynchronous: whether to use async version or not
|
asynchronous: whether to use async version or not
|
||||||
"""
|
"""
|
||||||
|
if not allow_dangerous_deserialization:
|
||||||
|
raise ValueError(
|
||||||
|
"The de-serialization relies loading a pickle file. "
|
||||||
|
"Pickle files can be modified to deliver a malicious payload that "
|
||||||
|
"results in execution of arbitrary code on your machine."
|
||||||
|
"You will need to set `allow_dangerous_deserialization` to `True` to "
|
||||||
|
"enable deserialization. If you do this, make sure that you "
|
||||||
|
"trust the source of the data. For example, if you are loading a "
|
||||||
|
"file that you created, and no that no one else has modified the file, "
|
||||||
|
"then this is safe to do. Do not set this to `True` if you are loading "
|
||||||
|
"a file from an untrusted source (e.g., some random site on the "
|
||||||
|
"internet.)."
|
||||||
|
)
|
||||||
path = Path(folder_path)
|
path = Path(folder_path)
|
||||||
# load index separately since it is not picklable
|
# load index separately since it is not picklable
|
||||||
faiss = dependable_faiss_import()
|
faiss = dependable_faiss_import()
|
||||||
|
@ -460,6 +460,8 @@ class ScaNN(VectorStore):
|
|||||||
folder_path: str,
|
folder_path: str,
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
index_name: str = "index",
|
index_name: str = "index",
|
||||||
|
*,
|
||||||
|
allow_dangerous_deserialization: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ScaNN:
|
) -> ScaNN:
|
||||||
"""Load ScaNN index, docstore, and index_to_docstore_id from disk.
|
"""Load ScaNN index, docstore, and index_to_docstore_id from disk.
|
||||||
@ -469,7 +471,25 @@ class ScaNN(VectorStore):
|
|||||||
and index_to_docstore_id from.
|
and index_to_docstore_id from.
|
||||||
embeddings: Embeddings to use when generating queries
|
embeddings: Embeddings to use when generating queries
|
||||||
index_name: for saving with a specific index file name
|
index_name: for saving with a specific index file name
|
||||||
|
allow_dangerous_deserialization: whether to allow deserialization
|
||||||
|
of the data which involves loading a pickle file.
|
||||||
|
Pickle files can be modified by malicious actors to deliver a
|
||||||
|
malicious payload that results in execution of
|
||||||
|
arbitrary code on your machine.
|
||||||
"""
|
"""
|
||||||
|
if not allow_dangerous_deserialization:
|
||||||
|
raise ValueError(
|
||||||
|
"The de-serialization relies loading a pickle file. "
|
||||||
|
"Pickle files can be modified to deliver a malicious payload that "
|
||||||
|
"results in execution of arbitrary code on your machine."
|
||||||
|
"You will need to set `allow_dangerous_deserialization` to `True` to "
|
||||||
|
"enable deserialization. If you do this, make sure that you "
|
||||||
|
"trust the source of the data. For example, if you are loading a "
|
||||||
|
"file that you created, and no that no one else has modified the file, "
|
||||||
|
"then this is safe to do. Do not set this to `True` if you are loading "
|
||||||
|
"a file from an untrusted source (e.g., some random site on the "
|
||||||
|
"internet.)."
|
||||||
|
)
|
||||||
path = Path(folder_path)
|
path = Path(folder_path)
|
||||||
scann_path = path / "{index_name}.scann".format(index_name=index_name)
|
scann_path = path / "{index_name}.scann".format(index_name=index_name)
|
||||||
scann_path.mkdir(exist_ok=True, parents=True)
|
scann_path.mkdir(exist_ok=True, parents=True)
|
||||||
|
@ -87,9 +87,28 @@ class TileDB(VectorStore):
|
|||||||
docs_array_uri: str = "",
|
docs_array_uri: str = "",
|
||||||
config: Optional[Mapping[str, Any]] = None,
|
config: Optional[Mapping[str, Any]] = None,
|
||||||
timestamp: Any = None,
|
timestamp: Any = None,
|
||||||
|
allow_dangerous_deserialization: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
"""Initialize with necessary components."""
|
"""Initialize with necessary components.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allow_dangerous_deserialization: whether to allow deserialization
|
||||||
|
of the data which involves loading data using pickle.
|
||||||
|
data can be modified by malicious actors to deliver a
|
||||||
|
malicious payload that results in execution of
|
||||||
|
arbitrary code on your machine.
|
||||||
|
"""
|
||||||
|
if not allow_dangerous_deserialization:
|
||||||
|
raise ValueError(
|
||||||
|
"TileDB relies on pickle for serialization and deserialization. "
|
||||||
|
"This can be dangerous if the data is intercepted and/or modified "
|
||||||
|
"by malicious actors prior to being de-serialized. "
|
||||||
|
"If you are sure that the data is safe from modification, you can "
|
||||||
|
" set allow_dangerous_deserialization=True to proceed. "
|
||||||
|
"Loading of compromised data using pickle can result in execution of "
|
||||||
|
"arbitrary code on your machine."
|
||||||
|
)
|
||||||
self.embedding = embedding
|
self.embedding = embedding
|
||||||
self.embedding_function = embedding.embed_query
|
self.embedding_function = embedding.embed_query
|
||||||
self.index_uri = index_uri
|
self.index_uri = index_uri
|
||||||
|
@ -116,7 +116,9 @@ def test_annoy_local_save_load() -> None:
|
|||||||
|
|
||||||
temp_dir = tempfile.TemporaryDirectory()
|
temp_dir = tempfile.TemporaryDirectory()
|
||||||
docsearch.save_local(temp_dir.name)
|
docsearch.save_local(temp_dir.name)
|
||||||
loaded_docsearch = Annoy.load_local(temp_dir.name, FakeEmbeddings())
|
loaded_docsearch = Annoy.load_local(
|
||||||
|
temp_dir.name, FakeEmbeddings(), allow_dangerous_deserialization=True
|
||||||
|
)
|
||||||
|
|
||||||
assert docsearch.index_to_docstore_id == loaded_docsearch.index_to_docstore_id
|
assert docsearch.index_to_docstore_id == loaded_docsearch.index_to_docstore_id
|
||||||
assert docsearch.docstore.__dict__ == loaded_docsearch.docstore.__dict__
|
assert docsearch.docstore.__dict__ == loaded_docsearch.docstore.__dict__
|
||||||
|
@ -252,7 +252,9 @@ def test_scann_local_save_load() -> None:
|
|||||||
temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
|
temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
|
||||||
with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
|
with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
|
||||||
docsearch.save_local(temp_folder)
|
docsearch.save_local(temp_folder)
|
||||||
new_docsearch = ScaNN.load_local(temp_folder, FakeEmbeddings())
|
new_docsearch = ScaNN.load_local(
|
||||||
|
temp_folder, FakeEmbeddings(), allow_dangerous_deserialization=True
|
||||||
|
)
|
||||||
assert new_docsearch.index is not None
|
assert new_docsearch.index is not None
|
||||||
|
|
||||||
|
|
||||||
|
@ -44,8 +44,9 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
|
|||||||
monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
|
monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
|
||||||
|
|
||||||
llm = Databricks(
|
llm = Databricks(
|
||||||
endpoint_name="databricks-mixtral-8x7b-instruct",
|
endpoint_name="some_end_point_name", # Value should not matter for this test
|
||||||
transform_input_fn=transform_input,
|
transform_input_fn=transform_input,
|
||||||
|
allow_dangerous_deserialization=True,
|
||||||
)
|
)
|
||||||
params = llm._default_params
|
params = llm._default_params
|
||||||
pickled_string = cloudpickle.dumps(transform_input).hex()
|
pickled_string = cloudpickle.dumps(transform_input).hex()
|
||||||
|
@ -608,7 +608,9 @@ def test_faiss_local_save_load() -> None:
|
|||||||
temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
|
temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
|
||||||
with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
|
with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
|
||||||
docsearch.save_local(temp_folder)
|
docsearch.save_local(temp_folder)
|
||||||
new_docsearch = FAISS.load_local(temp_folder, FakeEmbeddings())
|
new_docsearch = FAISS.load_local(
|
||||||
|
temp_folder, FakeEmbeddings(), allow_dangerous_deserialization=True
|
||||||
|
)
|
||||||
assert new_docsearch.index is not None
|
assert new_docsearch.index is not None
|
||||||
|
|
||||||
|
|
||||||
@ -620,7 +622,9 @@ async def test_faiss_async_local_save_load() -> None:
|
|||||||
temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
|
temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
|
||||||
with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
|
with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
|
||||||
docsearch.save_local(temp_folder)
|
docsearch.save_local(temp_folder)
|
||||||
new_docsearch = FAISS.load_local(temp_folder, FakeEmbeddings())
|
new_docsearch = FAISS.load_local(
|
||||||
|
temp_folder, FakeEmbeddings(), allow_dangerous_deserialization=True
|
||||||
|
)
|
||||||
assert new_docsearch.index is not None
|
assert new_docsearch.index is not None
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user