diff --git a/docs/docs/integrations/tools/databricks.ipynb b/docs/docs/integrations/tools/databricks.ipynb index 823ab803e1f..bb44a716f58 100644 --- a/docs/docs/integrations/tools/databricks.ipynb +++ b/docs/docs/integrations/tools/databricks.ipynb @@ -74,6 +74,24 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(Optional) To increase the retry time for getting a function execution response, set environment variable UC_TOOL_CLIENT_EXECUTION_TIMEOUT. Default retry time value is 120s." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"UC_TOOL_CLIENT_EXECUTION_TIMEOUT\"] = \"200\"" + ] + }, { "cell_type": "code", "execution_count": 4, diff --git a/libs/community/langchain_community/tools/databricks/_execution.py b/libs/community/langchain_community/tools/databricks/_execution.py index 48401d55fd8..62e8414fe49 100644 --- a/libs/community/langchain_community/tools/databricks/_execution.py +++ b/libs/community/langchain_community/tools/databricks/_execution.py @@ -1,5 +1,8 @@ import inspect import json +import logging +import os +import time from dataclasses import dataclass from io import StringIO from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional @@ -7,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional if TYPE_CHECKING: from databricks.sdk import WorkspaceClient from databricks.sdk.service.catalog import FunctionInfo - from databricks.sdk.service.sql import StatementParameterListItem + from databricks.sdk.service.sql import StatementParameterListItem, StatementState EXECUTE_FUNCTION_ARG_NAME = "__execution_args__" DEFAULT_EXECUTE_FUNCTION_ARGS = { @@ -15,6 +18,9 @@ DEFAULT_EXECUTE_FUNCTION_ARGS = { "row_limit": 100, "byte_limit": 4096, } +UC_TOOL_CLIENT_EXECUTION_TIMEOUT = "UC_TOOL_CLIENT_EXECUTION_TIMEOUT" +DEFAULT_UC_TOOL_CLIENT_EXECUTION_TIMEOUT = "120" +_logger = logging.getLogger(__name__) def is_scalar(function: "FunctionInfo") -> bool: @@ -174,13 +180,42 @@ def execute_function( parameters=parametrized_statement.parameters, **execute_statement_args, # type: ignore ) - status = response.status - assert status is not None, f"Statement execution failed: {response}" - if status.state != StatementState.SUCCEEDED: - error = status.error + if response.status and job_pending(response.status.state) and response.statement_id: + statement_id = response.statement_id + wait_time = 0 + retry_cnt = 0 + client_execution_timeout = int( + os.environ.get( + UC_TOOL_CLIENT_EXECUTION_TIMEOUT, + DEFAULT_UC_TOOL_CLIENT_EXECUTION_TIMEOUT, + ) + ) + while wait_time < client_execution_timeout: + wait = min(2**retry_cnt, client_execution_timeout - wait_time) + _logger.debug( + f"Retrying {retry_cnt} time to get statement execution " + f"status after {wait} seconds." + ) + time.sleep(wait) + response = ws.statement_execution.get_statement(statement_id) # type: ignore + if response.status is None or not job_pending(response.status.state): + break + wait_time += wait + retry_cnt += 1 + if response.status and job_pending(response.status.state): + return FunctionExecutionResult( + error=f"Statement execution is still pending after {wait_time} " + "seconds. Please increase the wait_timeout argument for executing " + f"the function or increase {UC_TOOL_CLIENT_EXECUTION_TIMEOUT} " + "environment variable for increasing retrying time, default is " + f"{DEFAULT_UC_TOOL_CLIENT_EXECUTION_TIMEOUT} seconds." + ) + assert response.status is not None, f"Statement execution failed: {response}" + if response.status.state != StatementState.SUCCEEDED: + error = response.status.error assert ( error is not None - ), "Statement execution failed but no error message was provided." + ), f"Statement execution failed but no error message was provided: {response}" return FunctionExecutionResult(error=f"{error.error_code}: {error.message}") manifest = response.manifest assert manifest is not None @@ -211,3 +246,9 @@ def execute_function( return FunctionExecutionResult( format="CSV", value=csv_buffer.getvalue(), truncated=truncated ) + + +def job_pending(state: Optional["StatementState"]) -> bool: + from databricks.sdk.service.sql import StatementState + + return state in (StatementState.PENDING, StatementState.RUNNING)