diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt
index 8812425498c..0b2cea22981 100644
--- a/libs/community/extended_testing_deps.txt
+++ b/libs/community/extended_testing_deps.txt
@@ -91,3 +91,4 @@ xata>=1.0.0a7,<2
 xmltodict>=0.13.0,<0.14
 nanopq==0.2.1
 mlflow[genai]>=2.14.0
+databricks-sdk>=0.30.0
diff --git a/libs/community/langchain_community/tools/databricks/_execution.py b/libs/community/langchain_community/tools/databricks/_execution.py
index 6cc0c661562..48401d55fd8 100644
--- a/libs/community/langchain_community/tools/databricks/_execution.py
+++ b/libs/community/langchain_community/tools/databricks/_execution.py
@@ -1,3 +1,4 @@
+import inspect
 import json
 from dataclasses import dataclass
 from io import StringIO
@@ -8,6 +9,13 @@ if TYPE_CHECKING:
     from databricks.sdk.service.catalog import FunctionInfo
     from databricks.sdk.service.sql import StatementParameterListItem
 
+EXECUTE_FUNCTION_ARG_NAME = "__execution_args__"
+DEFAULT_EXECUTE_FUNCTION_ARGS = {
+    "wait_timeout": "30s",
+    "row_limit": 100,
+    "byte_limit": 4096,
+}
+
 
 def is_scalar(function: "FunctionInfo") -> bool:
     from databricks.sdk.service.catalog import ColumnTypeName
@@ -122,16 +130,49 @@ def execute_function(
         ) from e
     from databricks.sdk.service.sql import StatementState
 
+    if (
+        function.input_params
+        and function.input_params.parameters
+        and any(
+            p.name == EXECUTE_FUNCTION_ARG_NAME
+            for p in function.input_params.parameters
+        )
+    ):
+        raise ValueError(
+            "Parameter name conflicts with the reserved argument name for executing "
+            f"functions: {EXECUTE_FUNCTION_ARG_NAME}. "
+            f"Please rename the parameter {EXECUTE_FUNCTION_ARG_NAME}."
+        )
+
+    # avoid modifying the original dict
+    execute_statement_args = {**DEFAULT_EXECUTE_FUNCTION_ARGS}
+    allowed_execute_statement_args = inspect.signature(
+        ws.statement_execution.execute_statement
+    ).parameters
+    if not any(
+        p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
+        for p in allowed_execute_statement_args.values()
+    ):
+        invalid_params = set()
+        passed_execute_statement_args = parameters.pop(EXECUTE_FUNCTION_ARG_NAME, {})
+        for k, v in passed_execute_statement_args.items():
+            if k in allowed_execute_statement_args:
+                execute_statement_args[k] = v
+            else:
+                invalid_params.add(k)
+        if invalid_params:
+            raise ValueError(
+                f"Invalid parameters for executing functions: {invalid_params}. "
+                f"Allowed parameters are: {allowed_execute_statement_args.keys()}."
+            )
+
     # TODO: async so we can run functions in parallel
     parametrized_statement = get_execute_function_sql_stmt(function, parameters)
-    # TODO: configurable limits
     response = ws.statement_execution.execute_statement(
         statement=parametrized_statement.statement,
         warehouse_id=warehouse_id,
         parameters=parametrized_statement.parameters,
-        wait_timeout="30s",
-        row_limit=100,
-        byte_limit=4096,
+        **execute_statement_args,  # type: ignore
     )
     status = response.status
     assert status is not None, f"Statement execution failed: {response}"
diff --git a/libs/community/tests/unit_tests/tools/databricks/__init__.py b/libs/community/tests/unit_tests/tools/databricks/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/libs/community/tests/unit_tests/tools/databricks/test_tools.py b/libs/community/tests/unit_tests/tools/databricks/test_tools.py
new file mode 100644
index 00000000000..dc47bee12c2
--- /dev/null
+++ b/libs/community/tests/unit_tests/tools/databricks/test_tools.py
@@ -0,0 +1,91 @@
+from unittest import mock
+
+import pytest
+
+from langchain_community.tools.databricks._execution import (
+    DEFAULT_EXECUTE_FUNCTION_ARGS,
+    EXECUTE_FUNCTION_ARG_NAME,
+    execute_function,
+)
+
+
+@pytest.mark.requires("databricks.sdk")
+@pytest.mark.parametrize(
+    ("parameters", "execute_params"),
+    [
+        ({"a": 1, "b": 2}, DEFAULT_EXECUTE_FUNCTION_ARGS),
+        (
+            {"a": 1, EXECUTE_FUNCTION_ARG_NAME: {"wait_timeout": "10s"}},
+            {**DEFAULT_EXECUTE_FUNCTION_ARGS, "wait_timeout": "10s"},
+        ),
+        (
+            {EXECUTE_FUNCTION_ARG_NAME: {"row_limit": "1000"}},
+            {**DEFAULT_EXECUTE_FUNCTION_ARGS, "row_limit": "1000"},
+        ),
+    ],
+)
+def test_execute_function(parameters: dict, execute_params: dict) -> None:
+    workspace_client = mock.Mock()
+
+    def mock_execute_statement(  # type: ignore
+        statement,
+        warehouse_id,
+        *,
+        byte_limit=None,
+        catalog=None,
+        disposition=None,
+        format=None,
+        on_wait_timeout=None,
+        parameters=None,
+        row_limit=None,
+        schema=None,
+        wait_timeout=None,
+    ):
+        for key, value in execute_params.items():
+            assert locals()[key] == value
+        return mock.Mock()
+
+    workspace_client.statement_execution.execute_statement = mock_execute_statement
+    function = mock.Mock()
+    function.data_type = "TABLE_TYPE"
+    function.input_params.parameters = []
+    execute_function(
+        workspace_client, warehouse_id="id", function=function, parameters=parameters
+    )
+
+
+@pytest.mark.requires("databricks.sdk")
+def test_execute_function_error() -> None:
+    workspace_client = mock.Mock()
+
+    def mock_execute_statement(  # type: ignore
+        statement,
+        warehouse_id,
+        *,
+        byte_limit=None,
+        catalog=None,
+        disposition=None,
+        format=None,
+        on_wait_timeout=None,
+        parameters=None,
+        row_limit=None,
+        schema=None,
+        wait_timeout=None,
+    ):
+        return mock.Mock()
+
+    workspace_client.statement_execution.execute_statement = mock_execute_statement
+    function = mock.Mock()
+    function.data_type = "TABLE_TYPE"
+    function.input_params.parameters = []
+    parameters = {EXECUTE_FUNCTION_ARG_NAME: {"invalid_param": "123"}}
+    with pytest.raises(
+        ValueError,
+        match=r"Invalid parameters for executing functions: {'invalid_param'}. ",
+    ):
+        execute_function(
+            workspace_client,
+            warehouse_id="id",
+            function=function,
+            parameters=parameters,
+        )
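For illustration, a minimal sketch of how a caller could override the default statement-execution limits (wait_timeout="30s", row_limit=100, byte_limit=4096) through the reserved "__execution_args__" key introduced above. It assumes a configured Databricks WorkspaceClient, an existing SQL warehouse id, and a hypothetical UC function main.default.add taking arguments a and b; those names are placeholders, not part of the diff.

    from databricks.sdk import WorkspaceClient

    from langchain_community.tools.databricks._execution import (
        EXECUTE_FUNCTION_ARG_NAME,
        execute_function,
    )

    ws = WorkspaceClient()
    # Hypothetical UC function; replace with a real catalog.schema.function name.
    function = ws.functions.get("main.default.add")

    result = execute_function(
        ws,
        warehouse_id="<warehouse-id>",  # placeholder for a real SQL warehouse id
        function=function,
        parameters={
            "a": 1,
            "b": 2,
            # Reserved key: entries override the execute_statement defaults.
            # Keys that are not valid execute_statement() arguments raise
            # the "Invalid parameters for executing functions" ValueError.
            EXECUTE_FUNCTION_ARG_NAME: {"wait_timeout": "50s", "row_limit": 1000},
        },
    )

The override dict is popped from the function parameters before the SQL statement is built, so only the remaining keys (a and b here) are passed to the UC function itself.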