ci: Add script to check for pickle usage in community (#22863)

Add script to check for pickle usage in community.
This commit is contained in:
Eugene Yurtsev 2024-06-13 16:13:15 -04:00 committed by GitHub
parent 77209f315e
commit 8f7cc73817
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 65 additions and 8 deletions

View File

@ -44,6 +44,7 @@ lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
./scripts/check_pydantic.sh .
./scripts/lint_imports.sh
./scripts/check_pickle.sh .
poetry run ruff .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)

View File

@ -242,7 +242,7 @@ def _load_pickled_fn_from_hex_string(
raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")
try:
return cloudpickle.loads(bytes.fromhex(data))
return cloudpickle.loads(bytes.fromhex(data)) # ignore[pickle]: explicit-opt-in
except Exception as e:
raise ValueError(
f"Failed to load the pickled function from a hexadecimal string. Error: {e}"

View File

@ -36,7 +36,9 @@ def _send_pipeline_to_device(pipeline: Any, device: int) -> Any:
"""Send a pipeline to a device on the cluster."""
if isinstance(pipeline, str):
with open(pipeline, "rb") as f:
pipeline = pickle.load(f)
# This code path can only be triggered if the user
# passed allow_dangerous_deserialization=True
pipeline = pickle.load(f) # ignore[pickle]: explicit-opt-in
if importlib.util.find_spec("torch") is not None:
import torch

View File

@ -152,6 +152,8 @@ class TFIDFRetriever(BaseRetriever):
# Load docs and tfidf array as pickle.
with open(path / f"{file_name}.pkl", "rb") as f:
docs, tfidf_array = pickle.load(f)
# This code path can only be triggered if the user
# passed allow_dangerous_deserialization=True
docs, tfidf_array = pickle.load(f) # ignore[pickle]: explicit-opt-in
return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array)

View File

@ -456,7 +456,14 @@ class Annoy(VectorStore):
annoy = guard_import("annoy")
# load docstore and index_to_docstore_id
with open(path / "index.pkl", "rb") as file:
docstore, index_to_docstore_id, config_object = pickle.load(file)
# Code path can only be reached if allow_dangerous_deserialization is True
(
docstore,
index_to_docstore_id,
config_object,
) = pickle.load( # ignore[pickle]: explicit-opt-in
file
)
f = int(config_object["ANNOY"]["f"])
metric = config_object["ANNOY"]["metric"]

View File

@ -1093,7 +1093,13 @@ class FAISS(VectorStore):
# load docstore and index_to_docstore_id
with open(path / f"{index_name}.pkl", "rb") as f:
docstore, index_to_docstore_id = pickle.load(f)
(
docstore,
index_to_docstore_id,
) = pickle.load( # ignore[pickle]: explicit-opt-in
f
)
return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs)
def serialize_to_bytes(self) -> bytes:
@ -1123,7 +1129,13 @@ class FAISS(VectorStore):
"loading a file from an untrusted source (e.g., some random site on "
"the internet.)."
)
index, docstore, index_to_docstore_id = pickle.loads(serialized)
(
index,
docstore,
index_to_docstore_id,
) = pickle.loads( # ignore[pickle]: explicit-opt-in
serialized
)
return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs)
def _select_relevance_score_fn(self) -> Callable[[float], float]:

View File

@ -493,7 +493,13 @@ class ScaNN(VectorStore):
# load docstore and index_to_docstore_id
with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f:
docstore, index_to_docstore_id = pickle.load(f)
(
docstore,
index_to_docstore_id,
) = pickle.load( # ignore[pickle]: explicit-opt-in
f
)
return cls(embedding, index, docstore, index_to_docstore_id, **kwargs)
def _select_relevance_score_fn(self) -> Callable[[float], float]:

View File

@ -188,7 +188,7 @@ class TileDB(VectorStore):
pickled_metadata = doc.get("metadata")
result_doc = Document(page_content=str(doc["text"][0]))
if pickled_metadata is not None:
metadata = pickle.loads(
metadata = pickle.loads( # ignore[pickle]: explicit-opt-in
np.array(pickled_metadata.tolist()).astype(np.uint8).tobytes()
)
result_doc.metadata = metadata

View File

@ -0,0 +1,27 @@
#!/bin/bash
#
# This checks for usage of pickle in the package.
#
# Usage: ./scripts/check_pickle.sh /path/to/repository
#
# Check if a path argument is provided
if [ $# -ne 1 ]; then
echo "Usage: $0 /path/to/repository"
exit 1
fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E 'pickle.load\(|pickle.loads\(' | grep -v '# ignore\[pickle\]: explicit-opt-in')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please avoid using pickle or cloudpickle."
echo "If you must, then add:"
echo "1. A security notice (scan the code for examples)"
echo "2. Code path should be opt-in."
exit 1
fi