From 850bf89e48a60df59064f510c597f33985c91ebc Mon Sep 17 00:00:00 2001 From: Serena Ruan <82044803+serena-ruan@users.noreply.github.com> Date: Thu, 29 Aug 2024 09:47:32 +0800 Subject: [PATCH] community[patch]: Support passing extra params for executing functions in UCFunctionToolkit (#25652) Thank you for contributing to LangChain! - [x] **PR title**: "package: description" - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" Support passing extra params when executing UC functions: The params should be a dictionary with key EXECUTE_FUNCTION_ARG_NAME, the assumption is that the function itself doesn't use such a variable name (starting and ending with double underscores), and if it does, we raise an Exception. If invalid params are passed to execute_statement, we raise an Exception as well. - [x] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. 
--------- Signed-off-by: Serena Ruan Co-authored-by: Bagatur --- libs/community/extended_testing_deps.txt | 1 + .../tools/databricks/_execution.py | 49 +++++++++- .../unit_tests/tools/databricks/__init__.py | 0 .../unit_tests/tools/databricks/test_tools.py | 91 +++++++++++++++++++ 4 files changed, 137 insertions(+), 4 deletions(-) create mode 100644 libs/community/tests/unit_tests/tools/databricks/__init__.py create mode 100644 libs/community/tests/unit_tests/tools/databricks/test_tools.py diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt index 8812425498c..0b2cea22981 100644 --- a/libs/community/extended_testing_deps.txt +++ b/libs/community/extended_testing_deps.txt @@ -91,3 +91,4 @@ xata>=1.0.0a7,<2 xmltodict>=0.13.0,<0.14 nanopq==0.2.1 mlflow[genai]>=2.14.0 +databricks-sdk>=0.30.0 diff --git a/libs/community/langchain_community/tools/databricks/_execution.py b/libs/community/langchain_community/tools/databricks/_execution.py index 6cc0c661562..48401d55fd8 100644 --- a/libs/community/langchain_community/tools/databricks/_execution.py +++ b/libs/community/langchain_community/tools/databricks/_execution.py @@ -1,3 +1,4 @@ +import inspect import json from dataclasses import dataclass from io import StringIO @@ -8,6 +9,13 @@ if TYPE_CHECKING: from databricks.sdk.service.catalog import FunctionInfo from databricks.sdk.service.sql import StatementParameterListItem +EXECUTE_FUNCTION_ARG_NAME = "__execution_args__" +DEFAULT_EXECUTE_FUNCTION_ARGS = { + "wait_timeout": "30s", + "row_limit": 100, + "byte_limit": 4096, +} + def is_scalar(function: "FunctionInfo") -> bool: from databricks.sdk.service.catalog import ColumnTypeName @@ -122,16 +130,49 @@ def execute_function( ) from e from databricks.sdk.service.sql import StatementState + if ( + function.input_params + and function.input_params.parameters + and any( + p.name == EXECUTE_FUNCTION_ARG_NAME + for p in function.input_params.parameters + ) + ): + raise ValueError( + 
"Parameter name conflicts with the reserved argument name for executing " + f"functions: {EXECUTE_FUNCTION_ARG_NAME}. " + f"Please rename the parameter {EXECUTE_FUNCTION_ARG_NAME}." + ) + + # avoid modifying the original dict + execute_statement_args = {**DEFAULT_EXECUTE_FUNCTION_ARGS} + allowed_execute_statement_args = inspect.signature( + ws.statement_execution.execute_statement + ).parameters + if not any( + p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD) + for p in allowed_execute_statement_args.values() + ): + invalid_params = set() + passed_execute_statement_args = parameters.pop(EXECUTE_FUNCTION_ARG_NAME, {}) + for k, v in passed_execute_statement_args.items(): + if k in allowed_execute_statement_args: + execute_statement_args[k] = v + else: + invalid_params.add(k) + if invalid_params: + raise ValueError( + f"Invalid parameters for executing functions: {invalid_params}. " + f"Allowed parameters are: {allowed_execute_statement_args.keys()}." + ) + # TODO: async so we can run functions in parallel parametrized_statement = get_execute_function_sql_stmt(function, parameters) - # TODO: configurable limits response = ws.statement_execution.execute_statement( statement=parametrized_statement.statement, warehouse_id=warehouse_id, parameters=parametrized_statement.parameters, - wait_timeout="30s", - row_limit=100, - byte_limit=4096, + **execute_statement_args, # type: ignore ) status = response.status assert status is not None, f"Statement execution failed: {response}" diff --git a/libs/community/tests/unit_tests/tools/databricks/__init__.py b/libs/community/tests/unit_tests/tools/databricks/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/community/tests/unit_tests/tools/databricks/test_tools.py b/libs/community/tests/unit_tests/tools/databricks/test_tools.py new file mode 100644 index 00000000000..dc47bee12c2 --- /dev/null +++ b/libs/community/tests/unit_tests/tools/databricks/test_tools.py @@ -0,0 +1,91 @@ +from unittest import mock + 
+import pytest + +from langchain_community.tools.databricks._execution import ( + DEFAULT_EXECUTE_FUNCTION_ARGS, + EXECUTE_FUNCTION_ARG_NAME, + execute_function, +) + + +@pytest.mark.requires("databricks.sdk") +@pytest.mark.parametrize( + ("parameters", "execute_params"), + [ + ({"a": 1, "b": 2}, DEFAULT_EXECUTE_FUNCTION_ARGS), + ( + {"a": 1, EXECUTE_FUNCTION_ARG_NAME: {"wait_timeout": "10s"}}, + {**DEFAULT_EXECUTE_FUNCTION_ARGS, "wait_timeout": "10s"}, + ), + ( + {EXECUTE_FUNCTION_ARG_NAME: {"row_limit": "1000"}}, + {**DEFAULT_EXECUTE_FUNCTION_ARGS, "row_limit": "1000"}, + ), + ], +) +def test_execute_function(parameters: dict, execute_params: dict) -> None: + workspace_client = mock.Mock() + + def mock_execute_statement( # type: ignore + statement, + warehouse_id, + *, + byte_limit=None, + catalog=None, + disposition=None, + format=None, + on_wait_timeout=None, + parameters=None, + row_limit=None, + schema=None, + wait_timeout=None, + ): + for key, value in execute_params.items(): + assert locals()[key] == value + return mock.Mock() + + workspace_client.statement_execution.execute_statement = mock_execute_statement + function = mock.Mock() + function.data_type = "TABLE_TYPE" + function.input_params.parameters = [] + execute_function( + workspace_client, warehouse_id="id", function=function, parameters=parameters + ) + + +@pytest.mark.requires("databricks.sdk") +def test_execute_function_error() -> None: + workspace_client = mock.Mock() + + def mock_execute_statement( # type: ignore + statement, + warehouse_id, + *, + byte_limit=None, + catalog=None, + disposition=None, + format=None, + on_wait_timeout=None, + parameters=None, + row_limit=None, + schema=None, + wait_timeout=None, + ): + return mock.Mock() + + workspace_client.statement_execution.execute_statement = mock_execute_statement + function = mock.Mock() + function.data_type = "TABLE_TYPE" + function.input_params.parameters = [] + parameters = {EXECUTE_FUNCTION_ARG_NAME: {"invalid_param": "123"}} + with 
pytest.raises( + ValueError, + match=r"Invalid parameters for executing functions: {'invalid_param'}. ", + ): + execute_function( + workspace_client, + warehouse_id="id", + function=function, + parameters=parameters, + )