diff --git a/libs/community/langchain_community/retrievers/tfidf.py b/libs/community/langchain_community/retrievers/tfidf.py index 9a033265b34..f6a3b57307a 100644 --- a/libs/community/langchain_community/retrievers/tfidf.py +++ b/libs/community/langchain_community/retrievers/tfidf.py @@ -106,8 +106,26 @@ class TFIDFRetriever(BaseRetriever): def load_local( cls, folder_path: str, + *, + allow_dangerous_deserialization: bool = False, file_name: str = "tfidf_vectorizer", ) -> TFIDFRetriever: + """Load the retriever from local storage. + + Args: + folder_path: Folder path to load from. + allow_dangerous_deserialization: Whether to allow dangerous deserialization. + Defaults to False. + The deserialization relies on .joblib and .pkl files, which can be + modified to deliver a malicious payload that results in execution of + arbitrary code on your machine. You will need to set this to `True` to + use deserialization. If you do this, make sure you trust the source of + the file. + file_name: File name to load from. Defaults to "tfidf_vectorizer". + + Returns: + TFIDFRetriever: Loaded retriever. + """ try: import joblib except ImportError: @@ -115,6 +133,18 @@ class TFIDFRetriever(BaseRetriever): "Could not import joblib, please install with `pip install joblib`." ) + if not allow_dangerous_deserialization: + raise ValueError( + "The de-serialization of this retriever is based on .joblib and " + ".pkl files." + "Such files can be modified to deliver a malicious payload that " + "results in execution of arbitrary code on your machine." + "You will need to set `allow_dangerous_deserialization` to `True` to " + "load this retriever. If you do this, make sure you trust the source " + "of the file, and you are responsible for validating the the file " + "came from a trusted source." + ) + path = Path(folder_path) # Load vectorizer with joblib load. diff --git a/libs/community/tests/unit_tests/retrievers/test_tfidf.py b/libs/community/tests/unit_tests/retrievers/test_tfidf.py index cbdb14d8cf1..2417b0dd78e 100644 --- a/libs/community/tests/unit_tests/retrievers/test_tfidf.py +++ b/libs/community/tests/unit_tests/retrievers/test_tfidf.py @@ -56,6 +56,11 @@ def test_save_local_load_local() -> None: loaded_tfidf_retriever = TFIDFRetriever.load_local( folder_path=temp_folder, file_name=file_name, + # Not a realistic security risk in this case. + # OK to allow for testing purposes. + # If the file has been compromised during this test, there's + # a much bigger problem. + allow_dangerous_deserialization=True, ) assert len(loaded_tfidf_retriever.docs) == 3 assert loaded_tfidf_retriever.tfidf_array.toarray().shape == (3, 5)