community[patch]: Databricks SerDe uses cloudpickle instead of pickle (#18607)

- **Description:** The Databricks SerDe now uses cloudpickle instead of pickle when serializing the user-defined function `transform_input_fn`, since pickle cannot serialize functions defined in `__main__`, while cloudpickle can.
- **Dependencies:** cloudpickle>=2.0.0
- Added a unit test.
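As background, a minimal sketch of the limitation motivating the change (the lambda and assertions are illustrative, not part of the patch):

```python
import pickle

import cloudpickle

fn = lambda x: x + 1  # defined in __main__, as a user's transform_input_fn typically is

# Plain pickle stores functions by reference (module + qualified name), so a
# lambda or an interactively defined function cannot be pickled for another process.
try:
    pickle.dumps(fn)
except Exception as e:
    print(f"pickle failed: {e}")

# cloudpickle serializes the function by value, so the payload round-trips.
data = cloudpickle.dumps(fn)
assert cloudpickle.loads(data)(1) == 2
```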
This commit is contained in:
parent f3e28289f6
commit 81985b31e6
libs/community/langchain_community/llms/databricks.py

```diff
@@ -1,5 +1,4 @@
 import os
-import pickle
 import re
 import warnings
 from abc import ABC, abstractmethod
@@ -225,7 +224,12 @@ def _is_hex_string(data: str) -> bool:
 def _load_pickled_fn_from_hex_string(data: str) -> Callable:
     """Loads a pickled function from a hexadecimal string."""
     try:
-        return pickle.loads(bytes.fromhex(data))
+        import cloudpickle
+    except Exception as e:
+        raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")
+
+    try:
+        return cloudpickle.loads(bytes.fromhex(data))
     except Exception as e:
         raise ValueError(
             f"Failed to load the pickled function from a hexadecimal string. Error: {e}"
@@ -235,7 +239,12 @@ def _load_pickled_fn_from_hex_string(data: str) -> Callable:
 def _pickle_fn_to_hex_string(fn: Callable) -> str:
     """Pickles a function and returns the hexadecimal string."""
     try:
-        return pickle.dumps(fn).hex()
+        import cloudpickle
+    except Exception as e:
+        raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")
+
+    try:
+        return cloudpickle.dumps(fn).hex()
     except Exception as e:
         raise ValueError(f"Failed to pickle the function: {e}")
```
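For orientation, a minimal round-trip sketch using the two private helpers changed above (assuming cloudpickle is installed; the sample function is illustrative):

```python
from langchain_community.llms.databricks import (
    _load_pickled_fn_from_hex_string,
    _pickle_fn_to_hex_string,
)

def transform_input(**request):
    # Illustrative transform; any callable cloudpickle can serialize works.
    request["prompt"] = request["prompt"].upper()
    return request

hex_string = _pickle_fn_to_hex_string(transform_input)  # fn -> cloudpickle bytes -> hex
fn = _load_pickled_fn_from_hex_string(hex_string)       # hex -> bytes -> fn
assert fn(prompt="hi") == transform_input(prompt="hi")
```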
libs/community/poetry.lock (generated; 8 lines changed)
```diff
@@ -3650,7 +3650,7 @@ files = [
 
 [[package]]
 name = "langchain-core"
-version = "0.1.28"
+version = "0.1.29"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -3687,7 +3687,7 @@ develop = true
 langchain-core = "^0.1.28"
 
 [package.extras]
-extended-testing = []
+extended-testing = ["lxml (>=5.1.0,<6.0.0)"]
 
 [package.source]
 type = "directory"
@@ -9176,9 +9176,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 
 [extras]
 cli = ["typer"]
-extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "xata", "xmltodict", "zhipuai"]
+extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "xata", "xmltodict", "zhipuai"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "d64381a1891a09e6215818c25ba7ca7b14a8708351695feab9ae53f4485f3b3e"
+content-hash = "d110eaaa4ecba8f6ed7faa2577b058c1f7c74171a6dbc53bc880f3c8598fc34b"
```
libs/community/pyproject.toml

```diff
@@ -81,6 +81,7 @@ hologres-vector = {version = "^0.0.6", optional = true}
 praw = {version = "^7.7.1", optional = true}
 msal = {version = "^1.25.0", optional = true}
 databricks-vectorsearch = {version = "^0.21", optional = true}
+cloudpickle = {version = ">=2.0.0", optional = true}
 dgml-utils = {version = "^0.3.0", optional = true}
 datasets = {version = "^2.15.0", optional = true}
 tree-sitter = {version = "^0.20.2", optional = true}
@@ -249,6 +250,7 @@ extended_testing = [
     "hologres-vector",
     "praw",
     "databricks-vectorsearch",
+    "cloudpickle",
     "dgml-utils",
     "cohere",
     "tree-sitter",
@@ -260,7 +262,8 @@ extended_testing = [
     "elasticsearch",
     "hdbcli",
     "oci",
-    "rdflib"
+    "rdflib",
+    "cloudpickle",
 ]
 
 [tool.ruff]
```
libs/community/tests/unit_tests/llms/test_databricks.py

```diff
@@ -1,10 +1,13 @@
 """test Databricks LLM"""
-import pickle
 from typing import Any, Dict
 
+import pytest
 from pytest import MonkeyPatch
 
-from langchain_community.llms.databricks import Databricks
+from langchain_community.llms.databricks import (
+    Databricks,
+    _load_pickled_fn_from_hex_string,
+)
 
 
 class MockDatabricksServingEndpointClient:
@@ -29,7 +32,10 @@ def transform_input(**request: Any) -> Dict[str, Any]:
     return request
 
 
+@pytest.mark.requires("cloudpickle")
 def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
+    import cloudpickle
+
     monkeypatch.setattr(
         "langchain_community.llms.databricks._DatabricksServingEndpointClient",
         MockDatabricksServingEndpointClient,
@@ -42,5 +48,9 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
         transform_input_fn=transform_input,
     )
     params = llm._default_params
-    pickled_string = pickle.dumps(transform_input).hex()
+    pickled_string = cloudpickle.dumps(transform_input).hex()
     assert params["transform_input_fn"] == pickled_string
+
+    request = {"prompt": "What is the meaning of life?"}
+    fn = _load_pickled_fn_from_hex_string(params["transform_input_fn"])
+    assert fn(**request) == transform_input(**request)
```
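The test asserts byte-for-byte equality of the hex payload; the operational point is that a cloudpickle payload can be rebuilt in a separate process. A minimal sketch of that property (the subprocess and sample function are illustrative, and cloudpickle must also be importable in the loading environment):

```python
import subprocess
import sys

import cloudpickle

def transform_input(**request):  # defined in __main__, like a user's transform_input_fn
    return {"inputs": [request["prompt"]]}

payload = cloudpickle.dumps(transform_input).hex()

# A fresh interpreter never saw this __main__, yet it can rebuild the function,
# because cloudpickle stored the function by value rather than by reference.
loader = (
    "import pickle, sys; "
    "fn = pickle.loads(bytes.fromhex(sys.argv[1])); "
    "print(fn(prompt='hello'))"
)
out = subprocess.run(
    [sys.executable, "-c", loader, payload], capture_output=True, text=True, check=True
)
print(out.stdout)  # {'inputs': ['hello']}
```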