From 3a068b26f3b475c790990cd8ff1a0fd839364260 Mon Sep 17 00:00:00 2001
From: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
Date: Fri, 12 Apr 2024 14:27:26 -0700
Subject: [PATCH] community[patch]: Databricks - fix scope of dangerous
 deserialization error in Databricks LLM connector (#20368)

fix scope of dangerous deserialization error in Databricks LLM connector

---------

Signed-off-by: dbczumar
---
 .../langchain_community/llms/databricks.py   | 35 ++++++++-----
 .../tests/unit_tests/llms/test_databricks.py | 50 +++++++++++++++----
 2 files changed, 63 insertions(+), 22 deletions(-)

diff --git a/libs/community/langchain_community/llms/databricks.py b/libs/community/langchain_community/llms/databricks.py
index 0debf87353a..06da23183e1 100644
--- a/libs/community/langchain_community/llms/databricks.py
+++ b/libs/community/langchain_community/llms/databricks.py
@@ -221,8 +221,21 @@ def _is_hex_string(data: str) -> bool:
     return bool(re.match(pattern, data))
 
 
-def _load_pickled_fn_from_hex_string(data: str) -> Callable:
+def _load_pickled_fn_from_hex_string(
+    data: str, allow_dangerous_deserialization: Optional[bool]
+) -> Callable:
     """Loads a pickled function from a hexadecimal string."""
+    if not allow_dangerous_deserialization:
+        raise ValueError(
+            "This code relies on the pickle module. "
+            "You will need to set allow_dangerous_deserialization=True "
+            "if you want to opt-in to allow deserialization of data using pickle."
+            "Data can be compromised by a malicious actor if "
+            "not handled properly to include "
+            "a malicious payload that when deserialized with "
+            "pickle can execute arbitrary code on your machine."
+        )
+
     try:
         import cloudpickle
     except Exception as e:
@@ -443,25 +456,21 @@ class Databricks(LLM):
         return v
 
     def __init__(self, **data: Any):
-        if not data.get("allow_dangerous_deserialization"):
-            raise ValueError(
-                "This code relies on the pickle module. "
-                "You will need to set allow_dangerous_deserialization=True "
-                "if you want to opt-in to allow deserialization of data using pickle."
-                "Data can be compromised by a malicious actor if "
-                "not handled properly to include "
-                "a malicious payload that when deserialized with "
-                "pickle can execute arbitrary code on your machine."
-            )
         if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
             data["transform_input_fn"] = _load_pickled_fn_from_hex_string(
-                data["transform_input_fn"]
+                data=data["transform_input_fn"],
+                allow_dangerous_deserialization=data.get(
+                    "allow_dangerous_deserialization"
+                ),
             )
         if "transform_output_fn" in data and _is_hex_string(
             data["transform_output_fn"]
         ):
             data["transform_output_fn"] = _load_pickled_fn_from_hex_string(
-                data["transform_output_fn"]
+                data=data["transform_output_fn"],
+                allow_dangerous_deserialization=data.get(
+                    "allow_dangerous_deserialization"
+                ),
             )
 
         super().__init__(**data)
diff --git a/libs/community/tests/unit_tests/llms/test_databricks.py b/libs/community/tests/unit_tests/llms/test_databricks.py
index 640f274a762..0eb24d39461 100644
--- a/libs/community/tests/unit_tests/llms/test_databricks.py
+++ b/libs/community/tests/unit_tests/llms/test_databricks.py
@@ -56,7 +56,10 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
     assert params["transform_input_fn"] == pickled_string
 
     request = {"prompt": "What is the meaning of life?"}
-    fn = _load_pickled_fn_from_hex_string(params["transform_input_fn"])
+    fn = _load_pickled_fn_from_hex_string(
+        data=params["transform_input_fn"],
+        allow_dangerous_deserialization=True,
+    )
     assert fn(**request) == transform_input(**request)
 
 
@@ -69,15 +72,44 @@ def test_saving_loading_llm(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
     monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
 
     llm = Databricks(
-        endpoint_name="chat", temperature=0.1, allow_dangerous_deserialization=True
+        endpoint_name="chat",
+        temperature=0.1,
     )
     llm.save(file_path=tmp_path / "databricks.yaml")
 
-    # Loading without allowing_dangerous_deserialization=True should raise an error.
-    with pytest.raises(ValueError, match="This code relies on the pickle module."):
-        load_llm(tmp_path / "databricks.yaml")
-
-    loaded_llm = load_llm(
-        tmp_path / "databricks.yaml", allow_dangerous_deserialization=True
-    )
+    loaded_llm = load_llm(tmp_path / "databricks.yaml")
     assert_llm_equality(llm, loaded_llm)
+
+
+@pytest.mark.requires("cloudpickle")
+def test_saving_loading_llm_dangerous_serde_check(
+    monkeypatch: MonkeyPatch, tmp_path: Path
+) -> None:
+    monkeypatch.setattr(
+        "langchain_community.llms.databricks._DatabricksServingEndpointClient",
+        MockDatabricksServingEndpointClient,
+    )
+    monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
+    monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
+
+    llm1 = Databricks(
+        endpoint_name="chat",
+        temperature=0.1,
+        transform_input_fn=lambda x, y, **kwargs: {},
+    )
+    llm1.save(file_path=tmp_path / "databricks1.yaml")
+
+    with pytest.raises(ValueError, match="This code relies on the pickle module."):
+        load_llm(tmp_path / "databricks1.yaml")
+
+    load_llm(tmp_path / "databricks1.yaml", allow_dangerous_deserialization=True)
+
+    llm2 = Databricks(
+        endpoint_name="chat", temperature=0.1, transform_output_fn=lambda x: "test"
+    )
+    llm2.save(file_path=tmp_path / "databricks2.yaml")
+
+    with pytest.raises(ValueError, match="This code relies on the pickle module."):
+        load_llm(tmp_path / "databricks2.yaml")
+
+    load_llm(tmp_path / "databricks2.yaml", allow_dangerous_deserialization=True)
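
Editor's note: the net effect of this patch is that the pickle opt-in check now fires only when a pickled `transform_input_fn`/`transform_output_fn` actually has to be deserialized, rather than on every `Databricks(...)` construction. The sketch below illustrates the resulting behavior from a caller's perspective. It is not part of the patch; the workspace host, token, endpoint name, and transform function are placeholder assumptions, and running it would require a real Databricks serving endpoint.

```python
# Illustrative sketch only (not part of the patch). Assumes a reachable
# Databricks workspace with a serving endpoint named "chat"; the host,
# token, and transform function below are placeholders.
import os

from langchain_community.llms import Databricks
from langchain_community.llms.loading import load_llm

os.environ["DATABRICKS_HOST"] = "https://my-workspace.cloud.databricks.com"
os.environ["DATABRICKS_TOKEN"] = "my-token"

# After this patch: no pickled transform functions are involved, so neither
# constructing nor loading requires the dangerous-deserialization opt-in.
llm = Databricks(endpoint_name="chat", temperature=0.1)
llm.save(file_path="databricks.yaml")
loaded = load_llm("databricks.yaml")

# A transform function is serialized (via cloudpickle) on save, so loading
# it back still fails closed unless the caller opts in explicitly.
llm_fn = Databricks(
    endpoint_name="chat",
    transform_input_fn=lambda prompt, **kwargs: {"prompt": prompt},
)
llm_fn.save(file_path="databricks_fn.yaml")

try:
    load_llm("databricks_fn.yaml")  # raises without the opt-in
except ValueError as e:
    print(e)  # "This code relies on the pickle module. ..."

loaded_fn = load_llm("databricks_fn.yaml", allow_dangerous_deserialization=True)
```

Moving the check into `_load_pickled_fn_from_hex_string` keeps the fail-closed default for untrusted pickled payloads while sparing users who never persist transform functions from having to pass the flag at all.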