mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-03 18:24:10 +00:00
community[patch]: Databricks SerDe uses cloudpickle instead of pickle (#18607)
- **Description:** The Databricks SerDe now uses cloudpickle instead of pickle to serialize the user-defined function `transform_input_fn`, because pickle does not support functions defined in `__main__`, whereas cloudpickle does.
- **Dependencies:** cloudpickle>=2.0.0

Added a unit test.
This commit is contained in:
parent
f3e28289f6
commit
81985b31e6
@ -1,5 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import pickle
|
|
||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@ -225,7 +224,12 @@ def _is_hex_string(data: str) -> bool:
|
|||||||
def _load_pickled_fn_from_hex_string(data: str) -> Callable:
|
def _load_pickled_fn_from_hex_string(data: str) -> Callable:
|
||||||
"""Loads a pickled function from a hexadecimal string."""
|
"""Loads a pickled function from a hexadecimal string."""
|
||||||
try:
|
try:
|
||||||
return pickle.loads(bytes.fromhex(data))
|
import cloudpickle
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
return cloudpickle.loads(bytes.fromhex(data))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Failed to load the pickled function from a hexadecimal string. Error: {e}"
|
f"Failed to load the pickled function from a hexadecimal string. Error: {e}"
|
||||||
@ -235,7 +239,12 @@ def _load_pickled_fn_from_hex_string(data: str) -> Callable:
|
|||||||
def _pickle_fn_to_hex_string(fn: Callable) -> str:
|
def _pickle_fn_to_hex_string(fn: Callable) -> str:
|
||||||
"""Pickles a function and returns the hexadecimal string."""
|
"""Pickles a function and returns the hexadecimal string."""
|
||||||
try:
|
try:
|
||||||
return pickle.dumps(fn).hex()
|
import cloudpickle
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
return cloudpickle.dumps(fn).hex()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Failed to pickle the function: {e}")
|
raise ValueError(f"Failed to pickle the function: {e}")
|
||||||
|
|
||||||
|
8
libs/community/poetry.lock
generated
8
libs/community/poetry.lock
generated
@ -3650,7 +3650,7 @@ files = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-core"
|
name = "langchain-core"
|
||||||
version = "0.1.28"
|
version = "0.1.29"
|
||||||
description = "Building applications with LLMs through composability"
|
description = "Building applications with LLMs through composability"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
@ -3687,7 +3687,7 @@ develop = true
|
|||||||
langchain-core = "^0.1.28"
|
langchain-core = "^0.1.28"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
extended-testing = []
|
extended-testing = ["lxml (>=5.1.0,<6.0.0)"]
|
||||||
|
|
||||||
[package.source]
|
[package.source]
|
||||||
type = "directory"
|
type = "directory"
|
||||||
@ -9176,9 +9176,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
|||||||
|
|
||||||
[extras]
|
[extras]
|
||||||
cli = ["typer"]
|
cli = ["typer"]
|
||||||
extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "xata", "xmltodict", "zhipuai"]
|
extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "xata", "xmltodict", "zhipuai"]
|
||||||
|
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
content-hash = "d64381a1891a09e6215818c25ba7ca7b14a8708351695feab9ae53f4485f3b3e"
|
content-hash = "d110eaaa4ecba8f6ed7faa2577b058c1f7c74171a6dbc53bc880f3c8598fc34b"
|
||||||
|
@ -81,6 +81,7 @@ hologres-vector = {version = "^0.0.6", optional = true}
|
|||||||
praw = {version = "^7.7.1", optional = true}
|
praw = {version = "^7.7.1", optional = true}
|
||||||
msal = {version = "^1.25.0", optional = true}
|
msal = {version = "^1.25.0", optional = true}
|
||||||
databricks-vectorsearch = {version = "^0.21", optional = true}
|
databricks-vectorsearch = {version = "^0.21", optional = true}
|
||||||
|
cloudpickle = {version = ">=2.0.0", optional = true}
|
||||||
dgml-utils = {version = "^0.3.0", optional = true}
|
dgml-utils = {version = "^0.3.0", optional = true}
|
||||||
datasets = {version = "^2.15.0", optional = true}
|
datasets = {version = "^2.15.0", optional = true}
|
||||||
tree-sitter = {version = "^0.20.2", optional = true}
|
tree-sitter = {version = "^0.20.2", optional = true}
|
||||||
@ -249,6 +250,7 @@ extended_testing = [
|
|||||||
"hologres-vector",
|
"hologres-vector",
|
||||||
"praw",
|
"praw",
|
||||||
"databricks-vectorsearch",
|
"databricks-vectorsearch",
|
||||||
|
"cloudpickle",
|
||||||
"dgml-utils",
|
"dgml-utils",
|
||||||
"cohere",
|
"cohere",
|
||||||
"tree-sitter",
|
"tree-sitter",
|
||||||
@ -260,7 +262,8 @@ extended_testing = [
|
|||||||
"elasticsearch",
|
"elasticsearch",
|
||||||
"hdbcli",
|
"hdbcli",
|
||||||
"oci",
|
"oci",
|
||||||
"rdflib"
|
"rdflib",
|
||||||
|
"cloudpickle",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
"""test Databricks LLM"""
|
"""test Databricks LLM"""
|
||||||
import pickle
|
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import pytest
|
||||||
from pytest import MonkeyPatch
|
from pytest import MonkeyPatch
|
||||||
|
|
||||||
from langchain_community.llms.databricks import Databricks
|
from langchain_community.llms.databricks import (
|
||||||
|
Databricks,
|
||||||
|
_load_pickled_fn_from_hex_string,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MockDatabricksServingEndpointClient:
|
class MockDatabricksServingEndpointClient:
|
||||||
@ -29,7 +32,10 @@ def transform_input(**request: Any) -> Dict[str, Any]:
|
|||||||
return request
|
return request
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("cloudpickle")
|
||||||
def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
|
def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
|
||||||
|
import cloudpickle
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"langchain_community.llms.databricks._DatabricksServingEndpointClient",
|
"langchain_community.llms.databricks._DatabricksServingEndpointClient",
|
||||||
MockDatabricksServingEndpointClient,
|
MockDatabricksServingEndpointClient,
|
||||||
@ -42,5 +48,9 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
|
|||||||
transform_input_fn=transform_input,
|
transform_input_fn=transform_input,
|
||||||
)
|
)
|
||||||
params = llm._default_params
|
params = llm._default_params
|
||||||
pickled_string = pickle.dumps(transform_input).hex()
|
pickled_string = cloudpickle.dumps(transform_input).hex()
|
||||||
assert params["transform_input_fn"] == pickled_string
|
assert params["transform_input_fn"] == pickled_string
|
||||||
|
|
||||||
|
request = {"prompt": "What is the meaning of life?"}
|
||||||
|
fn = _load_pickled_fn_from_hex_string(params["transform_input_fn"])
|
||||||
|
assert fn(**request) == transform_input(**request)
|
||||||
|
Loading…
Reference in New Issue
Block a user