mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 11:08:55 +00:00
community[patch]: Support SerDe transform functions in Databricks LLM (#16752)
**Description:** Databricks LLM does not support SerDe the transform_input_fn and transform_output_fn. After saving and loading, the LLM will be broken. This PR serialize these functions into a hex string using pickle, and saving the hex string in the yaml file. Using pickle to serialize a function can be flaky, but this is a simple workaround that unblocks many use cases. If more sophisticated SerDe is needed, we can improve it later. Test: Added a simple unit test. I did manual test on Databricks and it works well. The saved yaml looks like: ``` llm: _type: databricks cluster_driver_port: null cluster_id: null databricks_uri: databricks endpoint_name: databricks-mixtral-8x7b-instruct extra_params: {} host: e2-dogfood.staging.cloud.databricks.com max_tokens: null model_kwargs: null n: 1 stop: null task: null temperature: 0.0 transform_input_fn: 80049520000000000000008c085f5f6d61696e5f5f948c0f7472616e73666f726d5f696e7075749493942e transform_output_fn: null ``` @baskaryan ```python from langchain_community.embeddings import DatabricksEmbeddings from langchain_community.llms import Databricks from langchain.chains import RetrievalQA from langchain.document_loaders import TextLoader from langchain.text_splitter import CharacterTextSplitter from langchain.vectorstores import FAISS import mlflow embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en") def transform_input(**request): request["messages"] = [ { "role": "user", "content": request["prompt"] } ] del request["prompt"] return request llm = Databricks(endpoint_name="databricks-mixtral-8x7b-instruct", transform_input_fn=transform_input) persist_dir = "faiss_databricks_embedding" # Create the vector db, persist the db to a local fs folder loader = TextLoader("state_of_the_union.txt") documents = loader.load() text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) docs = text_splitter.split_documents(documents) db = FAISS.from_documents(docs, embeddings) db.save_local(persist_dir) def load_retriever(persist_directory): embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en") vectorstore = FAISS.load_local(persist_directory, embeddings) return vectorstore.as_retriever() retriever = load_retriever(persist_dir) retrievalQA = RetrievalQA.from_llm(llm=llm, retriever=retriever) with mlflow.start_run() as run: logged_model = mlflow.langchain.log_model( retrievalQA, artifact_path="retrieval_qa", loader_fn=load_retriever, persist_dir=persist_dir, ) # Load the retrievalQA chain loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri) print(loaded_model.predict([{"query": "What did the president say about Ketanji Brown Jackson"}])) ```
This commit is contained in:
parent
ce22e10c4b
commit
7306600e2f
@ -1,4 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Dict, List, Mapping, Optional
|
from typing import Any, Callable, Dict, List, Mapping, Optional
|
||||||
@ -212,6 +214,32 @@ def get_default_api_token() -> str:
|
|||||||
return api_token
|
return api_token
|
||||||
|
|
||||||
|
|
||||||
|
def _is_hex_string(data: str) -> bool:
|
||||||
|
"""Checks if a data is a valid hexadecimal string using a regular expression."""
|
||||||
|
if not isinstance(data, str):
|
||||||
|
return False
|
||||||
|
pattern = r"^[0-9a-fA-F]+$"
|
||||||
|
return bool(re.match(pattern, data))
|
||||||
|
|
||||||
|
|
||||||
|
def _load_pickled_fn_from_hex_string(data: str) -> Callable:
|
||||||
|
"""Loads a pickled function from a hexadecimal string."""
|
||||||
|
try:
|
||||||
|
return pickle.loads(bytes.fromhex(data))
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to load the pickled function from a hexadecimal string. Error: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _pickle_fn_to_hex_string(fn: Callable) -> str:
|
||||||
|
"""Pickles a function and returns the hexadecimal string."""
|
||||||
|
try:
|
||||||
|
return pickle.dumps(fn).hex()
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Failed to pickle the function: {e}")
|
||||||
|
|
||||||
|
|
||||||
class Databricks(LLM):
|
class Databricks(LLM):
|
||||||
|
|
||||||
"""Databricks serving endpoint or a cluster driver proxy app for LLM.
|
"""Databricks serving endpoint or a cluster driver proxy app for LLM.
|
||||||
@ -398,6 +426,17 @@ class Databricks(LLM):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
def __init__(self, **data: Any):
|
def __init__(self, **data: Any):
|
||||||
|
if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
|
||||||
|
data["transform_input_fn"] = _load_pickled_fn_from_hex_string(
|
||||||
|
data["transform_input_fn"]
|
||||||
|
)
|
||||||
|
if "transform_output_fn" in data and _is_hex_string(
|
||||||
|
data["transform_output_fn"]
|
||||||
|
):
|
||||||
|
data["transform_output_fn"] = _load_pickled_fn_from_hex_string(
|
||||||
|
data["transform_output_fn"]
|
||||||
|
)
|
||||||
|
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
if self.model_kwargs is not None and self.extra_params is not None:
|
if self.model_kwargs is not None and self.extra_params is not None:
|
||||||
raise ValueError("Cannot set both extra_params and extra_params.")
|
raise ValueError("Cannot set both extra_params and extra_params.")
|
||||||
@ -443,9 +482,12 @@ class Databricks(LLM):
|
|||||||
"max_tokens": self.max_tokens,
|
"max_tokens": self.max_tokens,
|
||||||
"extra_params": self.extra_params,
|
"extra_params": self.extra_params,
|
||||||
"task": self.task,
|
"task": self.task,
|
||||||
# TODO: Support saving transform_input_fn and transform_output_fn
|
"transform_input_fn": None
|
||||||
# "transform_input_fn": self.transform_input_fn,
|
if self.transform_input_fn is None
|
||||||
# "transform_output_fn": self.transform_output_fn,
|
else _pickle_fn_to_hex_string(self.transform_input_fn),
|
||||||
|
"transform_output_fn": None
|
||||||
|
if self.transform_output_fn is None
|
||||||
|
else _pickle_fn_to_hex_string(self.transform_output_fn),
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
46
libs/community/tests/unit_tests/llms/test_databricks.py
Normal file
46
libs/community/tests/unit_tests/llms/test_databricks.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
"""test Databricks LLM"""
|
||||||
|
import pickle
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from pytest import MonkeyPatch
|
||||||
|
|
||||||
|
from langchain_community.llms.databricks import Databricks
|
||||||
|
|
||||||
|
|
||||||
|
class MockDatabricksServingEndpointClient:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
api_token: str,
|
||||||
|
endpoint_name: str,
|
||||||
|
databricks_uri: str,
|
||||||
|
task: str,
|
||||||
|
):
|
||||||
|
self.host = host
|
||||||
|
self.api_token = api_token
|
||||||
|
self.endpoint_name = endpoint_name
|
||||||
|
self.databricks_uri = databricks_uri
|
||||||
|
self.task = task
|
||||||
|
|
||||||
|
|
||||||
|
def transform_input(**request: Any) -> Dict[str, Any]:
|
||||||
|
request["messages"] = [{"role": "user", "content": request["prompt"]}]
|
||||||
|
del request["prompt"]
|
||||||
|
return request
|
||||||
|
|
||||||
|
|
||||||
|
def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"langchain_community.llms.databricks._DatabricksServingEndpointClient",
|
||||||
|
MockDatabricksServingEndpointClient,
|
||||||
|
)
|
||||||
|
monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
|
||||||
|
monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
|
||||||
|
|
||||||
|
llm = Databricks(
|
||||||
|
endpoint_name="databricks-mixtral-8x7b-instruct",
|
||||||
|
transform_input_fn=transform_input,
|
||||||
|
)
|
||||||
|
params = llm._default_params
|
||||||
|
pickled_string = pickle.dumps(transform_input).hex()
|
||||||
|
assert params["transform_input_fn"] == pickled_string
|
Loading…
Reference in New Issue
Block a user