community[patch]: Support passing extra params for executing functions in UCFunctionToolkit (#25652)
Thank you for contributing to LangChain!

- [x] **PR title**: "package: description"
  - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes.
  - Example: "community: add foobar LLM"

Support passing extra params when executing UC functions: the params should be a dictionary with the key EXECUTE_FUNCTION_ARG_NAME. The assumption is that the function itself does not use such a variable name (starting and ending with double underscores); if it does, we raise an exception. If invalid params are passed to `execute_statement`, we raise an exception as well.

- [x] **Add tests and docs**: If you're adding a new integration, please include
  1. a test for the integration, preferably unit tests that do not rely on network access,
  2. an example notebook showing its use. It lives in the `docs/docs/integrations` directory.
- [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in langchain.

If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
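For illustration only (not part of the commit), a minimal caller-side sketch of the new behaviour, assuming a configured Databricks `WorkspaceClient`; the UC function name and warehouse id below are placeholders, and the `ws.functions.get` lookup is assumed to follow the databricks-sdk Functions API:

```python
from databricks.sdk import WorkspaceClient

from langchain_community.tools.databricks._execution import (
    EXECUTE_FUNCTION_ARG_NAME,  # "__execution_args__"
    execute_function,
)

ws = WorkspaceClient()  # assumes credentials are already configured in the environment
function = ws.functions.get("main.default.my_function")  # placeholder UC function name

# Regular function arguments, plus the reserved key carrying execute_statement options.
result = execute_function(
    ws,
    warehouse_id="warehouse-id",  # placeholder
    function=function,
    parameters={
        "a": 1,
        "b": 2,
        EXECUTE_FUNCTION_ARG_NAME: {"wait_timeout": "10s", "row_limit": 1000},
    },
)
```

Keys under `__execution_args__` that `statement_execution.execute_statement` does not accept raise a `ValueError`, and a UC function that itself declares a parameter named `__execution_args__` is rejected with a `ValueError` as well.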
parent 3555882a0d · commit 850bf89e48
```diff
@@ -91,3 +91,4 @@ xata>=1.0.0a7,<2
 xmltodict>=0.13.0,<0.14
 nanopq==0.2.1
 mlflow[genai]>=2.14.0
+databricks-sdk>=0.30.0
```
In `langchain_community/tools/databricks/_execution.py`:

```diff
@@ -1,3 +1,4 @@
+import inspect
 import json
 from dataclasses import dataclass
 from io import StringIO
```
```diff
@@ -8,6 +9,13 @@ if TYPE_CHECKING:
     from databricks.sdk.service.catalog import FunctionInfo
     from databricks.sdk.service.sql import StatementParameterListItem
 
+EXECUTE_FUNCTION_ARG_NAME = "__execution_args__"
+DEFAULT_EXECUTE_FUNCTION_ARGS = {
+    "wait_timeout": "30s",
+    "row_limit": 100,
+    "byte_limit": 4096,
+}
+
 
 def is_scalar(function: "FunctionInfo") -> bool:
     from databricks.sdk.service.catalog import ColumnTypeName
```
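For reference (not part of the diff): the intent of these defaults is that caller-supplied values win while unspecified limits keep their defaults, as the unit tests below also check. A tiny sketch of that merge:

```python
DEFAULT_EXECUTE_FUNCTION_ARGS = {"wait_timeout": "30s", "row_limit": 100, "byte_limit": 4096}

# Caller overrides only wait_timeout; the other limits fall back to the defaults.
merged = {**DEFAULT_EXECUTE_FUNCTION_ARGS, "wait_timeout": "10s"}
assert merged == {"wait_timeout": "10s", "row_limit": 100, "byte_limit": 4096}
```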
```diff
@@ -122,16 +130,49 @@ def execute_function(
         ) from e
     from databricks.sdk.service.sql import StatementState
+
+    if (
+        function.input_params
+        and function.input_params.parameters
+        and any(
+            p.name == EXECUTE_FUNCTION_ARG_NAME
+            for p in function.input_params.parameters
+        )
+    ):
+        raise ValueError(
+            "Parameter name conflicts with the reserved argument name for executing "
+            f"functions: {EXECUTE_FUNCTION_ARG_NAME}. "
+            f"Please rename the parameter {EXECUTE_FUNCTION_ARG_NAME}."
+        )
+
+    # avoid modifying the original dict
+    execute_statement_args = {**DEFAULT_EXECUTE_FUNCTION_ARGS}
+    allowed_execute_statement_args = inspect.signature(
+        ws.statement_execution.execute_statement
+    ).parameters
+    if not any(
+        p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
+        for p in allowed_execute_statement_args.values()
+    ):
+        invalid_params = set()
+        passed_execute_statement_args = parameters.pop(EXECUTE_FUNCTION_ARG_NAME, {})
+        for k, v in passed_execute_statement_args.items():
+            if k in allowed_execute_statement_args:
+                execute_statement_args[k] = v
+            else:
+                invalid_params.add(k)
+        if invalid_params:
+            raise ValueError(
+                f"Invalid parameters for executing functions: {invalid_params}. "
+                f"Allowed parameters are: {allowed_execute_statement_args.keys()}."
+            )
 
     # TODO: async so we can run functions in parallel
     parametrized_statement = get_execute_function_sql_stmt(function, parameters)
-    # TODO: configurable limits
     response = ws.statement_execution.execute_statement(
         statement=parametrized_statement.statement,
         warehouse_id=warehouse_id,
         parameters=parametrized_statement.parameters,
-        wait_timeout="30s",
-        row_limit=100,
-        byte_limit=4096,
+        **execute_statement_args,  # type: ignore
     )
     status = response.status
     assert status is not None, f"Statement execution failed: {response}"
```
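As a standalone sketch of the validation pattern used above (an illustration, not the module's API): `inspect.signature` lists a callable's accepted parameters, unknown keyword names are rejected, and validation is skipped when the callable takes `*args`/`**kwargs`. The `filter_kwargs` helper and `run` stand-in below are hypothetical:

```python
import inspect


def filter_kwargs(func, passed: dict, defaults: dict) -> dict:
    """Merge `passed` over `defaults`, rejecting names `func` does not accept."""
    allowed = inspect.signature(func).parameters
    # If func takes *args/**kwargs, any keyword name is acceptable, so skip validation.
    if not any(p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD) for p in allowed.values()):
        invalid = set(passed) - set(allowed)
        if invalid:
            raise ValueError(f"Invalid parameters: {invalid}. Allowed: {list(allowed)}.")
    return {**defaults, **passed}


def run(statement: str, *, wait_timeout: str = "30s", row_limit: int = 100) -> None:
    """Stand-in for execute_statement with keyword-only options."""


print(filter_kwargs(run, {"row_limit": 10}, {"wait_timeout": "50s"}))
# -> {'wait_timeout': '50s', 'row_limit': 10}

try:
    filter_kwargs(run, {"byte_limit": 1}, {})
except ValueError as err:
    print(err)  # byte_limit is not an accepted parameter of run()
```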
@@ -0,0 +1,91 @@ (new unit test file for `execute_function`)

```python
from unittest import mock

import pytest

from langchain_community.tools.databricks._execution import (
    DEFAULT_EXECUTE_FUNCTION_ARGS,
    EXECUTE_FUNCTION_ARG_NAME,
    execute_function,
)


@pytest.mark.requires("databricks.sdk")
@pytest.mark.parametrize(
    ("parameters", "execute_params"),
    [
        ({"a": 1, "b": 2}, DEFAULT_EXECUTE_FUNCTION_ARGS),
        (
            {"a": 1, EXECUTE_FUNCTION_ARG_NAME: {"wait_timeout": "10s"}},
            {**DEFAULT_EXECUTE_FUNCTION_ARGS, "wait_timeout": "10s"},
        ),
        (
            {EXECUTE_FUNCTION_ARG_NAME: {"row_limit": "1000"}},
            {**DEFAULT_EXECUTE_FUNCTION_ARGS, "row_limit": "1000"},
        ),
    ],
)
def test_execute_function(parameters: dict, execute_params: dict) -> None:
    workspace_client = mock.Mock()

    def mock_execute_statement(  # type: ignore
        statement,
        warehouse_id,
        *,
        byte_limit=None,
        catalog=None,
        disposition=None,
        format=None,
        on_wait_timeout=None,
        parameters=None,
        row_limit=None,
        schema=None,
        wait_timeout=None,
    ):
        for key, value in execute_params.items():
            assert locals()[key] == value
        return mock.Mock()

    workspace_client.statement_execution.execute_statement = mock_execute_statement
    function = mock.Mock()
    function.data_type = "TABLE_TYPE"
    function.input_params.parameters = []
    execute_function(
        workspace_client, warehouse_id="id", function=function, parameters=parameters
    )


@pytest.mark.requires("databricks.sdk")
def test_execute_function_error() -> None:
    workspace_client = mock.Mock()

    def mock_execute_statement(  # type: ignore
        statement,
        warehouse_id,
        *,
        byte_limit=None,
        catalog=None,
        disposition=None,
        format=None,
        on_wait_timeout=None,
        parameters=None,
        row_limit=None,
        schema=None,
        wait_timeout=None,
    ):
        return mock.Mock()

    workspace_client.statement_execution.execute_statement = mock_execute_statement
    function = mock.Mock()
    function.data_type = "TABLE_TYPE"
    function.input_params.parameters = []
    parameters = {EXECUTE_FUNCTION_ARG_NAME: {"invalid_param": "123"}}
    with pytest.raises(
        ValueError,
        match=r"Invalid parameters for executing functions: {'invalid_param'}. ",
    ):
        execute_function(
            workspace_client,
            warehouse_id="id",
            function=function,
            parameters=parameters,
        )
```