mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-05 04:55:14 +00:00
community[patch]: Databricks SerDe uses cloudpickle instead of pickle (#18607)
- **Description:** Databricks SerDe uses cloudpickle instead of pickle when serializing a user-defined function transform_input_fn since pickle does not support functions defined in `__main__`, and cloudpickle supports this. - **Dependencies:** cloudpickle>=2.0.0 Added a unit test.
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
"""test Databricks LLM"""
|
||||
import pickle
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from langchain_community.llms.databricks import Databricks
|
||||
from langchain_community.llms.databricks import (
|
||||
Databricks,
|
||||
_load_pickled_fn_from_hex_string,
|
||||
)
|
||||
|
||||
|
||||
class MockDatabricksServingEndpointClient:
|
||||
@@ -29,7 +32,10 @@ def transform_input(**request: Any) -> Dict[str, Any]:
|
||||
return request
|
||||
|
||||
|
||||
@pytest.mark.requires("cloudpickle")
|
||||
def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
|
||||
import cloudpickle
|
||||
|
||||
monkeypatch.setattr(
|
||||
"langchain_community.llms.databricks._DatabricksServingEndpointClient",
|
||||
MockDatabricksServingEndpointClient,
|
||||
@@ -42,5 +48,9 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
|
||||
transform_input_fn=transform_input,
|
||||
)
|
||||
params = llm._default_params
|
||||
pickled_string = pickle.dumps(transform_input).hex()
|
||||
pickled_string = cloudpickle.dumps(transform_input).hex()
|
||||
assert params["transform_input_fn"] == pickled_string
|
||||
|
||||
request = {"prompt": "What is the meaning of life?"}
|
||||
fn = _load_pickled_fn_from_hex_string(params["transform_input_fn"])
|
||||
assert fn(**request) == transform_input(**request)
|
||||
|
Reference in New Issue
Block a user