community[patch]: Databricks SerDe uses cloudpickle instead of pickle (#18607)

- **Description:** Databricks SerDe uses cloudpickle instead of pickle
when serializing a user-defined function transform_input_fn since pickle
does not support functions defined in `__main__`, and cloudpickle
supports this.
- **Dependencies:** cloudpickle>=2.0.0

Added a unit test.
This commit is contained in:
Liang Zhang
2024-03-05 18:04:45 -08:00
committed by GitHub
parent f3e28289f6
commit 81985b31e6
4 changed files with 33 additions and 11 deletions

View File

@@ -1,10 +1,13 @@
"""test Databricks LLM"""
import pickle
from typing import Any, Dict
import pytest
from pytest import MonkeyPatch
from langchain_community.llms.databricks import Databricks
from langchain_community.llms.databricks import (
Databricks,
_load_pickled_fn_from_hex_string,
)
class MockDatabricksServingEndpointClient:
@@ -29,7 +32,10 @@ def transform_input(**request: Any) -> Dict[str, Any]:
return request
@pytest.mark.requires("cloudpickle")
def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
import cloudpickle
monkeypatch.setattr(
"langchain_community.llms.databricks._DatabricksServingEndpointClient",
MockDatabricksServingEndpointClient,
@@ -42,5 +48,9 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
transform_input_fn=transform_input,
)
params = llm._default_params
pickled_string = pickle.dumps(transform_input).hex()
pickled_string = cloudpickle.dumps(transform_input).hex()
assert params["transform_input_fn"] == pickled_string
request = {"prompt": "What is the meaning of life?"}
fn = _load_pickled_fn_from_hex_string(params["transform_input_fn"])
assert fn(**request) == transform_input(**request)