community[patch]: Support passing extra params for executing functions in UCFunctionToolkit (#25652)

Thank you for contributing to LangChain!

- [x] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core,
experimental, etc. is being modified. Use "docs: ..." for purely docs
changes, "templates: ..." for template changes, "infra: ..." for CI
changes.
  - Example: "community: add foobar LLM"


Support passing extra params when executing UC functions:
The extra params should be passed as a dictionary stored under the key
EXECUTE_FUNCTION_ARG_NAME in the function parameters. The assumption is
that the function itself doesn't declare a parameter with that name (it
starts and ends with double underscores); if it does, we raise an
Exception.
If invalid params are passed for execute_statement, we raise an
Exception as well.


- [x] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Serena Ruan 2024-08-29 09:47:32 +08:00 committed by GitHub
parent 3555882a0d
commit 850bf89e48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 137 additions and 4 deletions

View File

@ -91,3 +91,4 @@ xata>=1.0.0a7,<2
xmltodict>=0.13.0,<0.14
nanopq==0.2.1
mlflow[genai]>=2.14.0
databricks-sdk>=0.30.0

View File

@ -1,3 +1,4 @@
import inspect
import json
from dataclasses import dataclass
from io import StringIO
@ -8,6 +9,13 @@ if TYPE_CHECKING:
from databricks.sdk.service.catalog import FunctionInfo
from databricks.sdk.service.sql import StatementParameterListItem
# Reserved key in the `parameters` dict given to `execute_function`. Its value
# is a dict of extra keyword arguments for `execute_statement`; a UC function
# that declares an input parameter with this name is rejected.
EXECUTE_FUNCTION_ARG_NAME = "__execution_args__"

# Baseline keyword arguments for `execute_statement`. `execute_function`
# copies this dict on every call and lets caller-supplied values override it.
DEFAULT_EXECUTE_FUNCTION_ARGS = {
    "wait_timeout": "30s",
    "row_limit": 100,
    "byte_limit": 4096,
}
def is_scalar(function: "FunctionInfo") -> bool:
from databricks.sdk.service.catalog import ColumnTypeName
@ -122,16 +130,49 @@ def execute_function(
) from e
from databricks.sdk.service.sql import StatementState
if (
function.input_params
and function.input_params.parameters
and any(
p.name == EXECUTE_FUNCTION_ARG_NAME
for p in function.input_params.parameters
)
):
raise ValueError(
"Parameter name conflicts with the reserved argument name for executing "
f"functions: {EXECUTE_FUNCTION_ARG_NAME}. "
f"Please rename the parameter {EXECUTE_FUNCTION_ARG_NAME}."
)
# avoid modifying the original dict
execute_statement_args = {**DEFAULT_EXECUTE_FUNCTION_ARGS}
allowed_execute_statement_args = inspect.signature(
ws.statement_execution.execute_statement
).parameters
if not any(
p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
for p in allowed_execute_statement_args.values()
):
invalid_params = set()
passed_execute_statement_args = parameters.pop(EXECUTE_FUNCTION_ARG_NAME, {})
for k, v in passed_execute_statement_args.items():
if k in allowed_execute_statement_args:
execute_statement_args[k] = v
else:
invalid_params.add(k)
if invalid_params:
raise ValueError(
f"Invalid parameters for executing functions: {invalid_params}. "
f"Allowed parameters are: {allowed_execute_statement_args.keys()}."
)
# TODO: async so we can run functions in parallel
parametrized_statement = get_execute_function_sql_stmt(function, parameters)
# TODO: configurable limits
response = ws.statement_execution.execute_statement(
statement=parametrized_statement.statement,
warehouse_id=warehouse_id,
parameters=parametrized_statement.parameters,
wait_timeout="30s",
row_limit=100,
byte_limit=4096,
**execute_statement_args, # type: ignore
)
status = response.status
assert status is not None, f"Statement execution failed: {response}"

View File

@ -0,0 +1,91 @@
from unittest import mock
import pytest
from langchain_community.tools.databricks._execution import (
DEFAULT_EXECUTE_FUNCTION_ARGS,
EXECUTE_FUNCTION_ARG_NAME,
execute_function,
)
@pytest.mark.requires("databricks.sdk")
@pytest.mark.parametrize(
    ("parameters", "execute_params"),
    [
        # No extra execution args: the defaults are forwarded unchanged.
        ({"a": 1, "b": 2}, DEFAULT_EXECUTE_FUNCTION_ARGS),
        # A single default is overridden; the others stay in place.
        (
            {"a": 1, EXECUTE_FUNCTION_ARG_NAME: {"wait_timeout": "10s"}},
            {**DEFAULT_EXECUTE_FUNCTION_ARGS, "wait_timeout": "10s"},
        ),
        # Only execution args are supplied, with no function parameters.
        (
            {EXECUTE_FUNCTION_ARG_NAME: {"row_limit": "1000"}},
            {**DEFAULT_EXECUTE_FUNCTION_ARGS, "row_limit": "1000"},
        ),
    ],
)
def test_execute_function(parameters: dict, execute_params: dict) -> None:
    """Check that `execute_function` forwards the expected execution args.

    The stand-in for `execute_statement` asserts that every entry of
    `execute_params` arrived as a keyword argument with the same value.
    """

    # The explicit keyword-only signature matters: `execute_function` inspects
    # it to decide which execution args are allowed.
    def mock_execute_statement(  # type: ignore
        statement,
        warehouse_id,
        *,
        byte_limit=None,
        catalog=None,
        disposition=None,
        format=None,
        on_wait_timeout=None,
        parameters=None,
        row_limit=None,
        schema=None,
        wait_timeout=None,
    ):
        # Snapshot the received arguments before any other local is bound.
        received = dict(locals())
        for arg_name, expected in execute_params.items():
            assert received[arg_name] == expected
        return mock.Mock()

    uc_function = mock.Mock()
    uc_function.data_type = "TABLE_TYPE"
    uc_function.input_params.parameters = []

    ws_client = mock.Mock()
    ws_client.statement_execution.execute_statement = mock_execute_statement

    execute_function(
        ws_client, warehouse_id="id", function=uc_function, parameters=parameters
    )
@pytest.mark.requires("databricks.sdk")
def test_execute_function_error() -> None:
    """Check that an unknown execution arg makes `execute_function` raise.

    `invalid_param` is not a parameter of the stand-in `execute_statement`,
    so a `ValueError` naming it is expected.
    """

    # The explicit keyword-only signature matters: `execute_function` inspects
    # it to decide which execution args are allowed.
    def mock_execute_statement(  # type: ignore
        statement,
        warehouse_id,
        *,
        byte_limit=None,
        catalog=None,
        disposition=None,
        format=None,
        on_wait_timeout=None,
        parameters=None,
        row_limit=None,
        schema=None,
        wait_timeout=None,
    ):
        return mock.Mock()

    uc_function = mock.Mock()
    uc_function.data_type = "TABLE_TYPE"
    uc_function.input_params.parameters = []

    ws_client = mock.Mock()
    ws_client.statement_execution.execute_statement = mock_execute_statement

    bad_parameters = {EXECUTE_FUNCTION_ARG_NAME: {"invalid_param": "123"}}
    expected_message = (
        r"Invalid parameters for executing functions: {'invalid_param'}. "
    )

    with pytest.raises(ValueError, match=expected_message):
        execute_function(
            ws_client,
            warehouse_id="id",
            function=uc_function,
            parameters=bad_parameters,
        )