mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 22:59:05 +00:00
community[patch]: Databricks - fix scope of dangerous deserialization error in Databricks LLM connector (#20368)
fix scope of dangerous deserialization error in Databricks LLM connector --------- Signed-off-by: dbczumar <corey.zumar@databricks.com>
This commit is contained in:
parent
f1248f8d9a
commit
3a068b26f3
@ -221,8 +221,21 @@ def _is_hex_string(data: str) -> bool:
|
|||||||
return bool(re.match(pattern, data))
|
return bool(re.match(pattern, data))
|
||||||
|
|
||||||
|
|
||||||
def _load_pickled_fn_from_hex_string(data: str) -> Callable:
|
def _load_pickled_fn_from_hex_string(
|
||||||
|
data: str, allow_dangerous_deserialization: Optional[bool]
|
||||||
|
) -> Callable:
|
||||||
"""Loads a pickled function from a hexadecimal string."""
|
"""Loads a pickled function from a hexadecimal string."""
|
||||||
|
if not allow_dangerous_deserialization:
|
||||||
|
raise ValueError(
|
||||||
|
"This code relies on the pickle module. "
|
||||||
|
"You will need to set allow_dangerous_deserialization=True "
|
||||||
|
"if you want to opt-in to allow deserialization of data using pickle."
|
||||||
|
"Data can be compromised by a malicious actor if "
|
||||||
|
"not handled properly to include "
|
||||||
|
"a malicious payload that when deserialized with "
|
||||||
|
"pickle can execute arbitrary code on your machine."
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -443,25 +456,21 @@ class Databricks(LLM):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
def __init__(self, **data: Any):
|
def __init__(self, **data: Any):
|
||||||
if not data.get("allow_dangerous_deserialization"):
|
|
||||||
raise ValueError(
|
|
||||||
"This code relies on the pickle module. "
|
|
||||||
"You will need to set allow_dangerous_deserialization=True "
|
|
||||||
"if you want to opt-in to allow deserialization of data using pickle."
|
|
||||||
"Data can be compromised by a malicious actor if "
|
|
||||||
"not handled properly to include "
|
|
||||||
"a malicious payload that when deserialized with "
|
|
||||||
"pickle can execute arbitrary code on your machine."
|
|
||||||
)
|
|
||||||
if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
|
if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
|
||||||
data["transform_input_fn"] = _load_pickled_fn_from_hex_string(
|
data["transform_input_fn"] = _load_pickled_fn_from_hex_string(
|
||||||
data["transform_input_fn"]
|
data=data["transform_input_fn"],
|
||||||
|
allow_dangerous_deserialization=data.get(
|
||||||
|
"allow_dangerous_deserialization"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if "transform_output_fn" in data and _is_hex_string(
|
if "transform_output_fn" in data and _is_hex_string(
|
||||||
data["transform_output_fn"]
|
data["transform_output_fn"]
|
||||||
):
|
):
|
||||||
data["transform_output_fn"] = _load_pickled_fn_from_hex_string(
|
data["transform_output_fn"] = _load_pickled_fn_from_hex_string(
|
||||||
data["transform_output_fn"]
|
data=data["transform_output_fn"],
|
||||||
|
allow_dangerous_deserialization=data.get(
|
||||||
|
"allow_dangerous_deserialization"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
|
@ -56,7 +56,10 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
|
|||||||
assert params["transform_input_fn"] == pickled_string
|
assert params["transform_input_fn"] == pickled_string
|
||||||
|
|
||||||
request = {"prompt": "What is the meaning of life?"}
|
request = {"prompt": "What is the meaning of life?"}
|
||||||
fn = _load_pickled_fn_from_hex_string(params["transform_input_fn"])
|
fn = _load_pickled_fn_from_hex_string(
|
||||||
|
data=params["transform_input_fn"],
|
||||||
|
allow_dangerous_deserialization=True,
|
||||||
|
)
|
||||||
assert fn(**request) == transform_input(**request)
|
assert fn(**request) == transform_input(**request)
|
||||||
|
|
||||||
|
|
||||||
@ -69,15 +72,44 @@ def test_saving_loading_llm(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
|
|||||||
monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
|
monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
|
||||||
|
|
||||||
llm = Databricks(
|
llm = Databricks(
|
||||||
endpoint_name="chat", temperature=0.1, allow_dangerous_deserialization=True
|
endpoint_name="chat",
|
||||||
|
temperature=0.1,
|
||||||
)
|
)
|
||||||
llm.save(file_path=tmp_path / "databricks.yaml")
|
llm.save(file_path=tmp_path / "databricks.yaml")
|
||||||
|
|
||||||
# Loading without allowing_dangerous_deserialization=True should raise an error.
|
loaded_llm = load_llm(tmp_path / "databricks.yaml")
|
||||||
with pytest.raises(ValueError, match="This code relies on the pickle module."):
|
|
||||||
load_llm(tmp_path / "databricks.yaml")
|
|
||||||
|
|
||||||
loaded_llm = load_llm(
|
|
||||||
tmp_path / "databricks.yaml", allow_dangerous_deserialization=True
|
|
||||||
)
|
|
||||||
assert_llm_equality(llm, loaded_llm)
|
assert_llm_equality(llm, loaded_llm)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("cloudpickle")
|
||||||
|
def test_saving_loading_llm_dangerous_serde_check(
|
||||||
|
monkeypatch: MonkeyPatch, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"langchain_community.llms.databricks._DatabricksServingEndpointClient",
|
||||||
|
MockDatabricksServingEndpointClient,
|
||||||
|
)
|
||||||
|
monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
|
||||||
|
monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
|
||||||
|
|
||||||
|
llm1 = Databricks(
|
||||||
|
endpoint_name="chat",
|
||||||
|
temperature=0.1,
|
||||||
|
transform_input_fn=lambda x, y, **kwargs: {},
|
||||||
|
)
|
||||||
|
llm1.save(file_path=tmp_path / "databricks1.yaml")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="This code relies on the pickle module."):
|
||||||
|
load_llm(tmp_path / "databricks1.yaml")
|
||||||
|
|
||||||
|
load_llm(tmp_path / "databricks1.yaml", allow_dangerous_deserialization=True)
|
||||||
|
|
||||||
|
llm2 = Databricks(
|
||||||
|
endpoint_name="chat", temperature=0.1, transform_output_fn=lambda x: "test"
|
||||||
|
)
|
||||||
|
llm2.save(file_path=tmp_path / "databricks2.yaml")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="This code relies on the pickle module."):
|
||||||
|
load_llm(tmp_path / "databricks2.yaml")
|
||||||
|
|
||||||
|
load_llm(tmp_path / "databricks2.yaml", allow_dangerous_deserialization=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user