mirror of https://github.com/hwchase17/langchain.git
community: support Databricks Unity Catalog functions as LangChain tools (#22555)
This PR adds support for using Databricks Unity Catalog functions as LangChain tools; tool calls execute inside a Databricks SQL warehouse. An example notebook is provided.
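
Before the diff itself, here is a minimal usage sketch distilled from the example notebook in this PR; the warehouse ID is a placeholder, and Databricks credentials are assumed to be configured via the SDK's default authentication:

from langchain_community.tools.databricks import UCFunctionToolkit

# Wrap UC functions as LangChain tools. Each tool call executes remotely in
# the SQL warehouse identified by warehouse_id, not in this Python process.
tools = (
    UCFunctionToolkit(warehouse_id="<sql-warehouse-id>")  # placeholder ID
    .include("main.tools.python_exec")  # qualified UC function name
    .get_tools()
)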
This commit is contained in:
parent c1ef731503
commit f26ab93df8

168  docs/docs/integrations/tools/databricks.ipynb  Normal file
@@ -0,0 +1,168 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Databricks Unity Catalog (UC)\n",
    "\n",
    "This notebook shows how to use UC functions as LangChain tools.\n",
    "\n",
    "See Databricks documentation ([AWS](https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-ddl-create-sql-function.html)|[Azure](https://learn.microsoft.com/en-us/azure/databricks/sql/language-manual/sql-ref-syntax-ddl-create-sql-function)|[GCP](https://docs.gcp.databricks.com/en/sql/language-manual/sql-ref-syntax-ddl-create-sql-function.html)) to learn how to create SQL or Python functions in UC. Do not skip function and parameter comments, which are critical for LLMs to call functions properly.\n",
    "\n",
    "In this example notebook, we create a simple Python function that executes arbitrary code and use it as a LangChain tool:\n",
    "\n",
    "```sql\n",
    "CREATE FUNCTION main.tools.python_exec (\n",
    "  code STRING COMMENT 'Python code to execute. Remember to print the final result to stdout.'\n",
    ")\n",
    "RETURNS STRING\n",
    "LANGUAGE PYTHON\n",
    "COMMENT 'Executes Python code and returns its stdout.'\n",
    "AS $$\n",
    "  import sys\n",
    "  from io import StringIO\n",
    "  stdout = StringIO()\n",
    "  sys.stdout = stdout\n",
    "  exec(code)\n",
    "  return stdout.getvalue()\n",
    "$$\n",
    "```\n",
    "\n",
    "It runs in a secure and isolated environment within a Databricks SQL warehouse."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --upgrade --quiet databricks-sdk langchain-community langchain-openai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "llm = ChatOpenAI(model=\"gpt-3.5-turbo\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.tools.databricks import UCFunctionToolkit\n",
    "\n",
    "tools = (\n",
    "    UCFunctionToolkit(\n",
    "        # You can find the SQL warehouse ID in its UI after creation.\n",
    "        warehouse_id=\"xxxx123456789\"\n",
    "    )\n",
    "    .include(\n",
    "        # Include functions as tools using their qualified names.\n",
    "        # You can use \"{catalog_name}.{schema_name}.*\" to get all functions in a schema.\n",
    "        \"main.tools.python_exec\",\n",
    "    )\n",
    "    .get_tools()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.agents import AgentExecutor, create_tool_calling_agent\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\n",
    "            \"system\",\n",
    "            \"You are a helpful assistant. Make sure to use tools for information.\",\n",
    "        ),\n",
    "        (\"placeholder\", \"{chat_history}\"),\n",
    "        (\"human\", \"{input}\"),\n",
    "        (\"placeholder\", \"{agent_scratchpad}\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "agent = create_tool_calling_agent(llm, tools, prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3m\n",
      "Invoking: `main__tools__python_exec` with `{'code': 'print(36939 * 8922.4)'}`\n",
      "\n",
      "\n",
      "\u001b[0m\u001b[36;1m\u001b[1;3m{\"format\": \"SCALAR\", \"value\": \"329584533.59999996\\n\", \"truncated\": false}\u001b[0m\u001b[32;1m\u001b[1;3mThe result of the multiplication 36939 * 8922.4 is 329,584,533.60.\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input': '36939 * 8922.4',\n",
       " 'output': 'The result of the multiplication 36939 * 8922.4 is 329,584,533.60.'}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)\n",
    "agent_executor.invoke({\"input\": \"36939 * 8922.4\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
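
Note the tool name in the execution trace above: the agent invokes main__tools__python_exec, not main.tools.python_exec. Here is a minimal sketch of that renaming, mirroring _get_tool_name in tool.py below; the helper itself is hypothetical, written only to illustrate the rule (tool-calling APIs commonly restrict tool names to 64 characters of letters, digits, underscores, and dashes, which dotted UC names would violate):

def uc_name_to_tool_name(full_name: str) -> str:
    # Hypothetical illustration: dots become double underscores and only the
    # last 64 characters are kept, matching _get_tool_name's behavior.
    return full_name.replace(".", "__")[-64:]

assert uc_name_to_tool_name("main.tools.python_exec") == "main__tools__python_exec"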
3  libs/community/langchain_community/tools/databricks/__init__.py  Normal file
@@ -0,0 +1,3 @@
from langchain_community.tools.databricks.tool import UCFunctionToolkit

__all__ = ["UCFunctionToolkit"]
172  libs/community/langchain_community/tools/databricks/_execution.py  Normal file
@@ -0,0 +1,172 @@
import json
from dataclasses import dataclass
from io import StringIO
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional

if TYPE_CHECKING:
    from databricks.sdk import WorkspaceClient
    from databricks.sdk.service.catalog import FunctionInfo
    from databricks.sdk.service.sql import StatementParameterListItem


def is_scalar(function: "FunctionInfo") -> bool:
    from databricks.sdk.service.catalog import ColumnTypeName

    return function.data_type != ColumnTypeName.TABLE_TYPE


@dataclass
class ParameterizedStatement:
    statement: str
    parameters: List["StatementParameterListItem"]


@dataclass
class FunctionExecutionResult:
    """
    Result of executing a function.
    We always use a string to present the result value for the AI model to consume.
    """

    error: Optional[str] = None
    format: Optional[Literal["SCALAR", "CSV"]] = None
    value: Optional[str] = None
    truncated: Optional[bool] = None

    def to_json(self) -> str:
        data = {k: v for (k, v) in self.__dict__.items() if v is not None}
        return json.dumps(data)


def get_execute_function_sql_stmt(
    function: "FunctionInfo", json_params: Dict[str, Any]
) -> ParameterizedStatement:
    from databricks.sdk.service.catalog import ColumnTypeName
    from databricks.sdk.service.sql import StatementParameterListItem

    parts = []
    output_params = []
    if is_scalar(function):
        # TODO: IDENTIFIER(:function) did not work
        parts.append(f"SELECT {function.full_name}(")
    else:
        parts.append(f"SELECT * FROM {function.full_name}(")
    if function.input_params is None or function.input_params.parameters is None:
        assert (
            not json_params
        ), "Function has no parameters but parameters were provided."
    else:
        args = []
        use_named_args = False
        for p in function.input_params.parameters:
            if p.name not in json_params:
                if p.parameter_default is not None:
                    use_named_args = True
                else:
                    raise ValueError(
                        f"Parameter {p.name} is required but not provided."
                    )
            else:
                arg_clause = ""
                if use_named_args:
                    arg_clause += f"{p.name} => "
                json_value = json_params[p.name]
                if p.type_name in (
                    ColumnTypeName.ARRAY,
                    ColumnTypeName.MAP,
                    ColumnTypeName.STRUCT,
                ):
                    # Use from_json to restore values of complex types.
                    json_value_str = json.dumps(json_value)
                    # TODO: parametrize type
                    arg_clause += f"from_json(:{p.name}, '{p.type_text}')"
                    output_params.append(
                        StatementParameterListItem(name=p.name, value=json_value_str)
                    )
                elif p.type_name == ColumnTypeName.BINARY:
                    # Use unbase64 to restore binary values.
                    arg_clause += f"unbase64(:{p.name})"
                    output_params.append(
                        StatementParameterListItem(name=p.name, value=json_value)
                    )
                else:
                    arg_clause += f":{p.name}"
                    output_params.append(
                        StatementParameterListItem(
                            name=p.name, value=json_value, type=p.type_text
                        )
                    )
                args.append(arg_clause)
        parts.append(",".join(args))
    parts.append(")")
    # TODO: check extra params in kwargs
    statement = "".join(parts)
    return ParameterizedStatement(statement=statement, parameters=output_params)


def execute_function(
    ws: "WorkspaceClient",
    warehouse_id: str,
    function: "FunctionInfo",
    parameters: Dict[str, Any],
) -> FunctionExecutionResult:
    """
    Execute a function with the given arguments and return the result.
    """
    try:
        import pandas as pd
    except ImportError as e:
        raise ImportError(
            "Could not import pandas python package. "
            "Please install it with `pip install pandas`."
        ) from e
    from databricks.sdk.service.sql import StatementState

    # TODO: async so we can run functions in parallel
    parametrized_statement = get_execute_function_sql_stmt(function, parameters)
    # TODO: configurable limits
    response = ws.statement_execution.execute_statement(
        statement=parametrized_statement.statement,
        warehouse_id=warehouse_id,
        parameters=parametrized_statement.parameters,
        wait_timeout="30s",
        row_limit=100,
        byte_limit=4096,
    )
    status = response.status
    assert status is not None, f"Statement execution failed: {response}"
    if status.state != StatementState.SUCCEEDED:
        error = status.error
        assert (
            error is not None
        ), "Statement execution failed but no error message was provided."
        return FunctionExecutionResult(error=f"{error.error_code}: {error.message}")
    manifest = response.manifest
    assert manifest is not None
    truncated = manifest.truncated
    result = response.result
    assert (
        result is not None
    ), "Statement execution succeeded but no result was provided."
    data_array = result.data_array
    if is_scalar(function):
        value = None
        if data_array and len(data_array) > 0 and len(data_array[0]) > 0:
            value = str(data_array[0][0])  # type: ignore
        return FunctionExecutionResult(
            format="SCALAR", value=value, truncated=truncated
        )
    else:
        schema = manifest.schema
        assert (
            schema is not None and schema.columns is not None
        ), "Statement execution succeeded but no schema was provided."
        columns = [c.name for c in schema.columns]
        if data_array is None:
            data_array = []
        pdf = pd.DataFrame.from_records(data_array, columns=columns)
        csv_buffer = StringIO()
        pdf.to_csv(csv_buffer, index=False)
        return FunctionExecutionResult(
            format="CSV", value=csv_buffer.getvalue(), truncated=truncated
        )
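
To make the statement construction above concrete: for a scalar function with one STRING parameter named code, such as the notebook's main.tools.python_exec, get_execute_function_sql_stmt produces a SELECT with a named bind marker rather than an inlined value. A self-contained sketch of that shape (illustrative only, no SDK calls):

full_name = "main.tools.python_exec"
param_names = ["code"]
# Values are not interpolated into the SQL text: each parameter becomes a
# :name bind marker, and the value travels separately as a
# StatementParameterListItem, which guards against SQL injection.
statement = f"SELECT {full_name}(" + ",".join(f":{p}" for p in param_names) + ")"
assert statement == "SELECT main.tools.python_exec(:code)"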
201  libs/community/langchain_community/tools/databricks/tool.py  Normal file
@@ -0,0 +1,201 @@
import json
from datetime import date, datetime
from decimal import Decimal
from hashlib import md5
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

from langchain_core.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.tools import BaseTool, BaseToolkit, StructuredTool
from typing_extensions import Self

if TYPE_CHECKING:
    from databricks.sdk import WorkspaceClient
    from databricks.sdk.service.catalog import FunctionInfo

from langchain_community.tools.databricks._execution import execute_function


def _uc_type_to_pydantic_type(uc_type_json: Union[str, Dict[str, Any]]) -> Type:
    mapping = {
        "long": int,
        "binary": bytes,
        "boolean": bool,
        "date": date,
        "double": float,
        "float": float,
        "integer": int,
        "short": int,
        "string": str,
        "timestamp": datetime,
        "timestamp_ntz": datetime,
        "byte": int,
    }
    if isinstance(uc_type_json, str):
        if uc_type_json in mapping:
            return mapping[uc_type_json]
        else:
            if uc_type_json.startswith("decimal"):
                return Decimal
            elif uc_type_json == "void" or uc_type_json.startswith("interval"):
                raise TypeError(f"Type {uc_type_json} is not supported.")
            else:
                raise TypeError(
                    f"Unknown type {uc_type_json}. Try upgrading this package."
                )
    else:
        assert isinstance(uc_type_json, dict)
        tpe = uc_type_json["type"]
        if tpe == "array":
            element_type = _uc_type_to_pydantic_type(uc_type_json["elementType"])
            if uc_type_json["containsNull"]:
                element_type = Optional[element_type]  # type: ignore
            return List[element_type]  # type: ignore
        elif tpe == "map":
            key_type = uc_type_json["keyType"]
            assert key_type == "string", TypeError(
                f"Only support STRING key type for MAP but got {key_type}."
            )
            value_type = _uc_type_to_pydantic_type(uc_type_json["valueType"])
            if uc_type_json["valueContainsNull"]:
                value_type: Type = Optional[value_type]  # type: ignore
            return Dict[str, value_type]  # type: ignore
        elif tpe == "struct":
            fields = {}
            for field in uc_type_json["fields"]:
                field_type = _uc_type_to_pydantic_type(field["type"])
                if field.get("nullable"):
                    field_type = Optional[field_type]  # type: ignore
                comment = (
                    uc_type_json["metadata"].get("comment")
                    if "metadata" in uc_type_json
                    else None
                )
                fields[field["name"]] = (field_type, Field(..., description=comment))
            uc_type_json_str = json.dumps(uc_type_json, sort_keys=True)
            type_hash = md5(uc_type_json_str.encode()).hexdigest()[:8]
            return create_model(f"Struct_{type_hash}", **fields)  # type: ignore
        else:
            raise TypeError(f"Unknown type {uc_type_json}. Try upgrading this package.")


def _generate_args_schema(function: "FunctionInfo") -> Type[BaseModel]:
    if function.input_params is None:
        return BaseModel
    params = function.input_params.parameters
    assert params is not None
    fields = {}
    for p in params:
        assert p.type_json is not None
        type_json = json.loads(p.type_json)["type"]
        pydantic_type = _uc_type_to_pydantic_type(type_json)
        description = p.comment
        default: Any = ...
        if p.parameter_default:
            pydantic_type = Optional[pydantic_type]  # type: ignore
            default = None
            # TODO: Convert default value string to the correct type.
            # We might need to use statement execution API
            # to get the JSON representation of the value.
            default_description = f"(Default: {p.parameter_default})"
            if description:
                description += f" {default_description}"
            else:
                description = default_description
        fields[p.name] = (
            pydantic_type,
            Field(default=default, description=description),
        )
    return create_model(
        f"{function.catalog_name}__{function.schema_name}__{function.name}__params",
        **fields,  # type: ignore
    )


def _get_tool_name(function: "FunctionInfo") -> str:
    tool_name = f"{function.catalog_name}__{function.schema_name}__{function.name}"[
        -64:
    ]
    return tool_name


def _get_default_workspace_client() -> "WorkspaceClient":
    try:
        from databricks.sdk import WorkspaceClient
    except ImportError as e:
        raise ImportError(
            "Could not import databricks-sdk python package. "
            "Please install it with `pip install databricks-sdk`."
        ) from e
    return WorkspaceClient()


class UCFunctionToolkit(BaseToolkit):
    warehouse_id: str = Field(
        description="The ID of a Databricks SQL Warehouse to execute functions."
    )

    workspace_client: "WorkspaceClient" = Field(
        default_factory=_get_default_workspace_client,
        description="Databricks workspace client.",
    )

    tools: Dict[str, BaseTool] = Field(default_factory=dict)

    class Config:
        arbitrary_types_allowed = True

    def include(self, *function_names: str, **kwargs: Any) -> Self:
        """
        Includes UC functions in the toolkit.

        Args:
            function_names: A list of UC function names in the format
                "catalog_name.schema_name.function_name" or
                "catalog_name.schema_name.*".
                If the function name ends with ".*",
                all functions in the schema will be added.
            kwargs: Extra arguments to pass to StructuredTool, e.g., `return_direct`.
        """
        for name in function_names:
            if name.endswith(".*"):
                catalog_name, schema_name = name[:-2].split(".")
                # TODO: handle pagination, warn and truncate if too many
                functions = self.workspace_client.functions.list(
                    catalog_name=catalog_name, schema_name=schema_name
                )
                for f in functions:
                    assert f.full_name is not None
                    self.include(f.full_name, **kwargs)
            else:
                if name not in self.tools:
                    self.tools[name] = self._make_tool(name, **kwargs)
        return self

    def _make_tool(self, function_name: str, **kwargs: Any) -> BaseTool:
        function = self.workspace_client.functions.get(function_name)
        name = _get_tool_name(function)
        description = function.comment or ""
        args_schema = _generate_args_schema(function)

        def func(*args: Any, **kwargs: Any) -> str:
            # TODO: We expect all named args and ignore args.
            # Non-empty args show up when the function has no parameters.
            args_json = json.loads(json.dumps(kwargs, default=str))
            result = execute_function(
                ws=self.workspace_client,
                warehouse_id=self.warehouse_id,
                function=function,
                parameters=args_json,
            )
            return result.to_json()

        return StructuredTool(
            name=name,
            description=description,
            args_schema=args_schema,
            func=func,
            **kwargs,
        )

    def get_tools(self) -> List[BaseTool]:
        return list(self.tools.values())
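
A brief sketch of the type mapping implemented by _uc_type_to_pydantic_type above, using a hand-written type JSON in the shape Unity Catalog reports for parameter types; the literal is illustrative, not captured from a real catalog, and it imports a private helper directly for demonstration only:

from typing import List, Optional

from langchain_community.tools.databricks.tool import _uc_type_to_pydantic_type

# An ARRAY<STRING> whose elements may be NULL maps to List[Optional[str]],
# per the "array" branch of _uc_type_to_pydantic_type.
uc_array_type = {"type": "array", "elementType": "string", "containsNull": True}
assert _uc_type_to_pydantic_type(uc_array_type) == List[Optional[str]]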