{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Databricks\n", "\n", "The [Databricks](https://www.databricks.com/) Lakehouse Platform unifies data, analytics, and AI on one platform.\n", "\n", "This example notebook shows how to wrap Databricks endpoints as LLMs in LangChain.\n", "Two endpoint types are supported:\n", "\n", "* Serving endpoint, recommended for production and development.\n", "* Cluster driver proxy app, recommended for interactive development." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Installation\n", "\n", "`mlflow>=2.9` is required to run the code in this notebook. If it's not installed, install it with the following command (the quotes prevent the shell from treating `>` as an output redirect):\n", "\n", "```\n", "pip install \"mlflow>=2.9\"\n", "```\n", "\n", "The secret-manager example below also uses `dbutils`, which is available by default in Databricks notebooks and does not need to be installed. (The `dbutils` package on PyPI is an unrelated library.)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Wrapping a serving endpoint: External model\n", "\n", "Prerequisite: Register an OpenAI API key as a secret, replacing `<scope>` with the name of your secret scope:\n", "\n", " ```bash\n", " databricks secrets create-scope <scope>\n", " databricks secrets put-secret <scope> openai-api-key --string-value $OPENAI_API_KEY\n", " ```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following code creates a new serving endpoint with OpenAI's GPT-4 model for chat and generates a response using the endpoint." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "content='Hello! 
How can I assist you today?'\n" ] } ], "source": [ "from langchain_community.chat_models import ChatDatabricks\n", "from langchain_core.messages import HumanMessage\n", "from mlflow.deployments import get_deploy_client\n", "\n", "client = get_deploy_client(\"databricks\")\n", "\n", "secret = \"secrets/<scope>/openai-api-key\"  # replace `<scope>` with your scope\n", "name = \"my-chat\"  # rename this if my-chat already exists\n", "client.create_endpoint(\n", "    name=name,\n", "    config={\n", "        \"served_entities\": [\n", "            {\n", "                \"name\": \"my-chat\",\n", "                \"external_model\": {\n", "                    \"name\": \"gpt-4\",\n", "                    \"provider\": \"openai\",\n", "                    \"task\": \"llm/v1/chat\",\n", "                    \"openai_config\": {\n", "                        \"openai_api_key\": \"{{\" + secret + \"}}\",\n", "                    },\n", "                },\n", "            }\n", "        ],\n", "    },\n", ")\n", "\n", "chat = ChatDatabricks(\n", "    target_uri=\"databricks\",\n", "    endpoint=name,\n", "    temperature=0.1,\n", ")\n", "chat([HumanMessage(content=\"hello\")])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Wrapping a serving endpoint: Foundation model\n", "\n", "The following code uses the `databricks-bge-large-en` serving endpoint (no endpoint creation is required) to generate embeddings from input text."
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.051055908203125, 0.007221221923828125, 0.003879547119140625]\n" ] } ], "source": [ "from langchain_community.embeddings import DatabricksEmbeddings\n", "\n", "embeddings = DatabricksEmbeddings(endpoint=\"databricks-bge-large-en\")\n", "embeddings.embed_query(\"hello\")[:3]" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Wrapping a serving endpoint: Custom model\n", "\n", "Prerequisites:\n", "\n", "* An LLM has been registered and deployed to [a Databricks serving endpoint](https://docs.databricks.com/machine-learning/model-serving/index.html).\n", "* You have [\"Can Query\" permission](https://docs.databricks.com/security/auth-authz/access-control/serving-endpoint-acl.html) to the endpoint.\n", "\n", "The expected MLflow model signature is:\n", "\n", " * inputs: `[{\"name\": \"prompt\", \"type\": \"string\"}, {\"name\": \"stop\", \"type\": \"list[string]\"}]`\n", " * outputs: `[{\"type\": \"string\"}]`\n", "\n", "If the model signature is incompatible or you want to add extra configuration, you can set `transform_input_fn` and `transform_output_fn` accordingly."
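, "\n", "\n", "As an illustration, suppose a custom model expected a `text` field instead of `prompt` and returned a dictionary instead of a plain string. This schema is hypothetical, not one defined by LangChain or Databricks, and `transform_input`/`transform_output` are illustrative names. A minimal sketch of the two adapter functions, which can be checked locally without an endpoint:\n", "\n", "```python\n", "def transform_input(**request):\n", "    # `request` holds the arguments LangChain builds, such as `prompt` and `stop`.\n", "    # Map them onto the hypothetical schema the endpoint expects.\n", "    return {\"text\": request[\"prompt\"], \"stop_sequences\": request.get(\"stop\") or []}\n", "\n", "\n", "def transform_output(response):\n", "    # Hypothetical response shape: {\"generated_text\": \"...\"}\n", "    return response[\"generated_text\"]\n", "\n", "\n", "print(transform_input(prompt=\"How are you?\", stop=[\".\"]))\n", "# {'text': 'How are you?', 'stop_sequences': ['.']}\n", "print(transform_output({\"generated_text\": \"I am fine.\"}))\n", "# I am fine.\n", "```\n", "\n", "The functions would then be passed to the wrapper as `transform_input_fn=transform_input` and `transform_output_fn=transform_output`."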
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'I am happy to hear that you are in good health and as always, you are appreciated.'" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from langchain_community.llms import Databricks\n", "\n", "# If running a Databricks notebook attached to an interactive cluster in \"single user\"\n", "# or \"no isolation shared\" mode, you only need to specify the endpoint name to create\n", "# a `Databricks` instance to query a serving endpoint in the same workspace.\n", "llm = Databricks(endpoint_name=\"dolly\")\n", "\n", "llm(\"How are you?\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Good'" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "llm(\"How are you?\", stop=[\".\"])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'I am fine. 
Thank you!'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Otherwise, you can manually specify the Databricks workspace hostname and personal access token\n", "# or set the `DATABRICKS_HOST` and `DATABRICKS_TOKEN` environment variables, respectively.\n", "# See https://docs.databricks.com/dev-tools/auth.html#databricks-personal-access-tokens\n", "# We strongly recommend not exposing the API token explicitly inside a notebook.\n", "# You can use the Databricks secret manager to store your API token securely.\n", "# See https://docs.databricks.com/dev-tools/databricks-utils.html#secrets-utility-dbutilssecrets\n", "\n", "import os\n", "\n", "# `dbutils` is predefined in Databricks notebooks, so no import is needed.\n", "os.environ[\"DATABRICKS_TOKEN\"] = dbutils.secrets.get(\"myworkspace\", \"api_token\")\n", "\n", "llm = Databricks(host=\"myworkspace.cloud.databricks.com\", endpoint_name=\"dolly\")\n", "\n", "llm(\"How are you?\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'I am fine.'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# If the serving endpoint accepts extra parameters like `temperature`,\n", "# you can set them in `model_kwargs`.\n", "llm = Databricks(endpoint_name=\"dolly\", model_kwargs={\"temperature\": 0.1})\n", "\n", "llm(\"How are you?\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'I’m Excellent. 
You?'" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Use `transform_input_fn` and `transform_output_fn` if the serving endpoint\n", "# expects a different input schema and does not return a JSON string,\n", "# respectively, or you want to apply a prompt template on top.\n", "\n", "\n", "def transform_input(**request):\n", "    full_prompt = f\"\"\"{request[\"prompt\"]}\n", "    Be Concise.\n", "    \"\"\"\n", "    request[\"prompt\"] = full_prompt\n", "    return request\n", "\n", "\n", "llm = Databricks(endpoint_name=\"dolly\", transform_input_fn=transform_input)\n", "\n", "llm(\"How are you?\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { }, "source": [ "## Wrapping a cluster driver proxy app\n", "\n", "Prerequisites:\n", "\n", "* An LLM loaded on a Databricks interactive cluster in \"single user\" or \"no isolation shared\" mode.\n", "* A local HTTP server running on the driver node to serve the model at `\"/\"` using HTTP POST with JSON input/output.\n", "* The server uses a port number in the range `[3000, 8000]` and listens on the driver IP address, or simply `0.0.0.0`, rather than on localhost only.\n", "* You have \"Can Attach To\" permission to the cluster.\n", "\n", "The expected server schema (using JSON schema) is:\n", "\n", "* inputs:\n", " ```json\n", " {\"type\": \"object\",\n", " \"properties\": {\n", " \"prompt\": {\"type\": \"string\"},\n", " \"stop\": {\"type\": \"array\", \"items\": {\"type\": \"string\"}}},\n", " \"required\": [\"prompt\"]}\n", " ```\n", "* outputs: `{\"type\": \"string\"}`\n", "\n", "If the server schema is incompatible or you want to add extra configuration, you can use `transform_input_fn` and `transform_output_fn` accordingly.\n", "\n", "The following is a minimal example for running a driver proxy app to serve an LLM:\n", "\n", "```python\n", "from flask import Flask, request, jsonify\n", "import torch\n", "from transformers import pipeline, AutoTokenizer, StoppingCriteria\n", "\n", "model 
= \"databricks/dolly-v2-3b\"\n", "tokenizer = AutoTokenizer.from_pretrained(model, padding_side=\"left\")\n", "dolly = pipeline(model=model, tokenizer=tokenizer, trust_remote_code=True, device_map=\"auto\")\n", "device = dolly.device\n", "\n", "class CheckStop(StoppingCriteria):\n", "    def __init__(self, stop=None):\n", "        super().__init__()\n", "        self.stop = stop or []\n", "        self.matched = \"\"\n", "        self.stop_ids = [tokenizer.encode(s, return_tensors='pt').to(device) for s in self.stop]\n", "\n", "    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):\n", "        for i, s in enumerate(self.stop_ids):\n", "            if torch.all((s == input_ids[0][-s.shape[1]:])).item():\n", "                self.matched = self.stop[i]\n", "                return True\n", "        return False\n", "\n", "def llm(prompt, stop=None, **kwargs):\n", "    check_stop = CheckStop(stop)\n", "    result = dolly(prompt, stopping_criteria=[check_stop], **kwargs)\n", "    # Remove the matched stop sequence as a suffix; `str.rstrip` would strip any of its characters.\n", "    return result[0][\"generated_text\"].removesuffix(check_stop.matched)\n", "\n", "app = Flask(\"dolly\")\n", "\n", "@app.route('/', methods=['POST'])\n", "def serve_llm():\n", "    resp = llm(**request.json)\n", "    return jsonify(resp)\n", "\n", "app.run(host=\"0.0.0.0\", port=7777)\n", "```\n", "\n", "Once the server is running, you can create a `Databricks` instance to wrap it as an LLM." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Hello, thank you for asking. It is wonderful to hear that you are well.'" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# If running a Databricks notebook attached to the same cluster that runs the app,\n", "# you only need to specify the driver port to create a `Databricks` instance.\n", "llm = Databricks(cluster_driver_port=\"7777\")\n", "\n", "llm(\"How are you?\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'I am well. 
You?'" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Otherwise, you can manually specify the cluster ID to use,\n", "# as well as the Databricks workspace hostname and personal access token.\n", "\n", "llm = Databricks(cluster_id=\"0000-000000-xxxxxxxx\", cluster_driver_port=\"7777\")\n", "\n", "llm(\"How are you?\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'I am very well. It is a pleasure to meet you.'" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# If the app accepts extra parameters like `temperature`,\n", "# you can set them in `model_kwargs`.\n", "llm = Databricks(cluster_driver_port=\"7777\", model_kwargs={\"temperature\": 0.1})\n", "\n", "llm(\"How are you?\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'I AM DOING GREAT THANK YOU.'" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Use `transform_input_fn` and `transform_output_fn` if the app\n", "# expects a different input schema and does not return a JSON string,\n", "# respectively, or you want to apply a prompt template on top.\n", "\n", "\n", "def transform_input(**request):\n", "    full_prompt = f\"\"\"{request[\"prompt\"]}\n", "    Be Concise.\n", "    \"\"\"\n", "    request[\"prompt\"] = full_prompt\n", "    return request\n", "\n", "\n", "def transform_output(response):\n", "    return response.upper()\n", "\n", "\n", "llm = Databricks(\n", "    cluster_driver_port=\"7777\",\n", "    transform_input_fn=transform_input,\n", "    transform_output_fn=transform_output,\n", ")\n", "\n", "llm(\"How are you?\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": 
"text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }