community[patch]: Databricks SerDe uses cloudpickle instead of pickle (#18607)

- **Description:** Databricks SerDe uses cloudpickle instead of pickle
when serializing a user-defined function transform_input_fn since pickle
does not support functions defined in `__main__`, and cloudpickle
supports this.
- **Dependencies:** cloudpickle>=2.0.0

Added a unit test.
This commit is contained in:
Liang Zhang 2024-03-05 18:04:45 -08:00 committed by GitHub
parent f3e28289f6
commit 81985b31e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 33 additions and 11 deletions

View File

@ -1,5 +1,4 @@
import os
import pickle
import re
import warnings
from abc import ABC, abstractmethod
@ -225,7 +224,12 @@ def _is_hex_string(data: str) -> bool:
def _load_pickled_fn_from_hex_string(data: str) -> Callable:
"""Loads a pickled function from a hexadecimal string."""
try:
return pickle.loads(bytes.fromhex(data))
import cloudpickle
except Exception as e:
raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")
try:
return cloudpickle.loads(bytes.fromhex(data))
except Exception as e:
raise ValueError(
f"Failed to load the pickled function from a hexadecimal string. Error: {e}"
@ -235,7 +239,12 @@ def _load_pickled_fn_from_hex_string(data: str) -> Callable:
def _pickle_fn_to_hex_string(fn: Callable) -> str:
"""Pickles a function and returns the hexadecimal string."""
try:
return pickle.dumps(fn).hex()
import cloudpickle
except Exception as e:
raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")
try:
return cloudpickle.dumps(fn).hex()
except Exception as e:
raise ValueError(f"Failed to pickle the function: {e}")

View File

@ -3650,7 +3650,7 @@ files = [
[[package]]
name = "langchain-core"
version = "0.1.28"
version = "0.1.29"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -3687,7 +3687,7 @@ develop = true
langchain-core = "^0.1.28"
[package.extras]
extended-testing = []
extended-testing = ["lxml (>=5.1.0,<6.0.0)"]
[package.source]
type = "directory"
@ -9176,9 +9176,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[extras]
cli = ["typer"]
extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "xata", "xmltodict", "zhipuai"]
extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "xata", "xmltodict", "zhipuai"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "d64381a1891a09e6215818c25ba7ca7b14a8708351695feab9ae53f4485f3b3e"
content-hash = "d110eaaa4ecba8f6ed7faa2577b058c1f7c74171a6dbc53bc880f3c8598fc34b"

View File

@ -81,6 +81,7 @@ hologres-vector = {version = "^0.0.6", optional = true}
praw = {version = "^7.7.1", optional = true}
msal = {version = "^1.25.0", optional = true}
databricks-vectorsearch = {version = "^0.21", optional = true}
cloudpickle = {version = ">=2.0.0", optional = true}
dgml-utils = {version = "^0.3.0", optional = true}
datasets = {version = "^2.15.0", optional = true}
tree-sitter = {version = "^0.20.2", optional = true}
@ -249,6 +250,7 @@ extended_testing = [
"hologres-vector",
"praw",
"databricks-vectorsearch",
"cloudpickle",
"dgml-utils",
"cohere",
"tree-sitter",
@ -260,7 +262,8 @@ extended_testing = [
"elasticsearch",
"hdbcli",
"oci",
"rdflib"
"rdflib",
"cloudpickle",
]
[tool.ruff]

View File

@ -1,10 +1,13 @@
"""test Databricks LLM"""
import pickle
from typing import Any, Dict
import pytest
from pytest import MonkeyPatch
from langchain_community.llms.databricks import Databricks
from langchain_community.llms.databricks import (
Databricks,
_load_pickled_fn_from_hex_string,
)
class MockDatabricksServingEndpointClient:
@ -29,7 +32,10 @@ def transform_input(**request: Any) -> Dict[str, Any]:
return request
@pytest.mark.requires("cloudpickle")
def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
import cloudpickle
monkeypatch.setattr(
"langchain_community.llms.databricks._DatabricksServingEndpointClient",
MockDatabricksServingEndpointClient,
@ -42,5 +48,9 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
transform_input_fn=transform_input,
)
params = llm._default_params
pickled_string = pickle.dumps(transform_input).hex()
pickled_string = cloudpickle.dumps(transform_input).hex()
assert params["transform_input_fn"] == pickled_string
request = {"prompt": "What is the meaning of life?"}
fn = _load_pickled_fn_from_hex_string(params["transform_input_fn"])
assert fn(**request) == transform_input(**request)