community: support Databricks Unity Catalog functions as LangChain tools (#22555)

This PR adds support for using Databricks Unity Catalog functions as
LangChain tools, which runs inside a Databricks SQL warehouse.

* An example notebook is provided.
This commit is contained in:
Xiangrui Meng
2024-06-06 09:38:50 -07:00
committed by GitHub
parent c1ef731503
commit f26ab93df8
4 changed files with 544 additions and 0 deletions

View File

@@ -0,0 +1,168 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Databricks Unity Catalog (UC)\n",
"\n",
"This notebook shows how to use UC functions as LangChain tools.\n",
"\n",
"See Databricks documentation ([AWS](https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-ddl-create-sql-function.html)|[Azure](https://learn.microsoft.com/en-us/azure/databricks/sql/language-manual/sql-ref-syntax-ddl-create-sql-function)|[GCP](https://docs.gcp.databricks.com/en/sql/language-manual/sql-ref-syntax-ddl-create-sql-function.html)) to learn how to create SQL or Python functions in UC. Do not skip function and parameter comments, which are critical for LLMs to call functions properly.\n",
"\n",
"In this example notebook, we create a simple Python function that executes arbitary code and use it as a LangChain tool:\n",
"\n",
"```sql\n",
"CREATE FUNCTION main.tools.python_exec (\n",
" code STRING COMMENT 'Python code to execute. Remember to print the final result to stdout.'\n",
")\n",
"RETURNS STRING\n",
"LANGUAGE PYTHON\n",
"COMMENT 'Executes Python code and returns its stdout.'\n",
"AS $$\n",
" import sys\n",
" from io import StringIO\n",
" stdout = StringIO()\n",
" sys.stdout = stdout\n",
" exec(code)\n",
" return stdout.getvalue()\n",
"$$\n",
"```\n",
"\n",
"It runs in a secure and isolated environment within a Databricks SQL warehouse."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade --quiet databricks-sdk langchain-community langchain-openai"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.tools.databricks import UCFunctionToolkit\n",
"\n",
"tools = (\n",
" UCFunctionToolkit(\n",
" # You can find the SQL warehouse ID in its UI after creation.\n",
" warehouse_id=\"xxxx123456789\"\n",
" )\n",
" .include(\n",
" # Include functions as tools using their qualified names.\n",
" # You can use \"{catalog_name}.{schema_name}.*\" to get all functions in a schema.\n",
" \"main.tools.python_exec\",\n",
" )\n",
" .get_tools()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import AgentExecutor, create_tool_calling_agent\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant. Make sure to use tool for information.\",\n",
" ),\n",
" (\"placeholder\", \"{chat_history}\"),\n",
" (\"human\", \"{input}\"),\n",
" (\"placeholder\", \"{agent_scratchpad}\"),\n",
" ]\n",
")\n",
"\n",
"agent = create_tool_calling_agent(llm, tools, prompt)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `main__tools__python_exec` with `{'code': 'print(36939 * 8922.4)'}`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m{\"format\": \"SCALAR\", \"value\": \"329584533.59999996\\n\", \"truncated\": false}\u001b[0m\u001b[32;1m\u001b[1;3mThe result of the multiplication 36939 * 8922.4 is 329,584,533.60.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': '36939 * 8922.4',\n",
" 'output': 'The result of the multiplication 36939 * 8922.4 is 329,584,533.60.'}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)\n",
"agent_executor.invoke({\"input\": \"36939 * 8922.4\"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "llm",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}