From a7c1ce2b3ff0e015c8b996337810511b9db4cd35 Mon Sep 17 00:00:00 2001
From: Serena Ruan <82044803+serena-ruan@users.noreply.github.com>
Date: Wed, 9 Oct 2024 14:31:48 +0800
Subject: [PATCH] [community] Add timeout control and retry for UC tool
execution (#26645)
Add timeout at client side for UCFunctionToolkit and add retry logic.
Users could specify environment variable
`UC_TOOL_CLIENT_EXECUTION_TIMEOUT` to increase the timeout value for
retrying to get the execution response if the status is pending. Default
timeout value is 120s.
- [ ] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.
Tested in Databricks:
- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/
Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.
If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
---------
Signed-off-by: serena-ruan
Co-authored-by: Erick Friis
---
docs/docs/integrations/tools/databricks.ipynb | 18 +++++++
.../tools/databricks/_execution.py | 53 ++++++++++++++++---
2 files changed, 65 insertions(+), 6 deletions(-)
diff --git a/docs/docs/integrations/tools/databricks.ipynb b/docs/docs/integrations/tools/databricks.ipynb
index 823ab803e1f..bb44a716f58 100644
--- a/docs/docs/integrations/tools/databricks.ipynb
+++ b/docs/docs/integrations/tools/databricks.ipynb
@@ -74,6 +74,24 @@
")"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "(Optional) To increase the retry time for getting a function execution response, set environment variable UC_TOOL_CLIENT_EXECUTION_TIMEOUT. Default retry time value is 120s."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.environ[\"UC_TOOL_CLIENT_EXECUTION_TIMEOUT\"] = \"200\""
+ ]
+ },
{
"cell_type": "code",
"execution_count": 4,
diff --git a/libs/community/langchain_community/tools/databricks/_execution.py b/libs/community/langchain_community/tools/databricks/_execution.py
index 48401d55fd8..62e8414fe49 100644
--- a/libs/community/langchain_community/tools/databricks/_execution.py
+++ b/libs/community/langchain_community/tools/databricks/_execution.py
@@ -1,5 +1,8 @@
import inspect
import json
+import logging
+import os
+import time
from dataclasses import dataclass
from io import StringIO
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
@@ -7,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
if TYPE_CHECKING:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.catalog import FunctionInfo
- from databricks.sdk.service.sql import StatementParameterListItem
+ from databricks.sdk.service.sql import StatementParameterListItem, StatementState
EXECUTE_FUNCTION_ARG_NAME = "__execution_args__"
DEFAULT_EXECUTE_FUNCTION_ARGS = {
@@ -15,6 +18,9 @@ DEFAULT_EXECUTE_FUNCTION_ARGS = {
"row_limit": 100,
"byte_limit": 4096,
}
+UC_TOOL_CLIENT_EXECUTION_TIMEOUT = "UC_TOOL_CLIENT_EXECUTION_TIMEOUT"
+DEFAULT_UC_TOOL_CLIENT_EXECUTION_TIMEOUT = "120"
+_logger = logging.getLogger(__name__)
def is_scalar(function: "FunctionInfo") -> bool:
@@ -174,13 +180,42 @@ def execute_function(
parameters=parametrized_statement.parameters,
**execute_statement_args, # type: ignore
)
- status = response.status
- assert status is not None, f"Statement execution failed: {response}"
- if status.state != StatementState.SUCCEEDED:
- error = status.error
+ if response.status and job_pending(response.status.state) and response.statement_id:
+ statement_id = response.statement_id
+ wait_time = 0
+ retry_cnt = 0
+ client_execution_timeout = int(
+ os.environ.get(
+ UC_TOOL_CLIENT_EXECUTION_TIMEOUT,
+ DEFAULT_UC_TOOL_CLIENT_EXECUTION_TIMEOUT,
+ )
+ )
+ while wait_time < client_execution_timeout:
+ wait = min(2**retry_cnt, client_execution_timeout - wait_time)
+ _logger.debug(
+ f"Retrying {retry_cnt} time to get statement execution "
+ f"status after {wait} seconds."
+ )
+ time.sleep(wait)
+ response = ws.statement_execution.get_statement(statement_id) # type: ignore
+ if response.status is None or not job_pending(response.status.state):
+ break
+ wait_time += wait
+ retry_cnt += 1
+ if response.status and job_pending(response.status.state):
+ return FunctionExecutionResult(
+ error=f"Statement execution is still pending after {wait_time} "
+ "seconds. Please increase the wait_timeout argument for executing "
+ f"the function or increase {UC_TOOL_CLIENT_EXECUTION_TIMEOUT} "
+ "environment variable for increasing retrying time, default is "
+ f"{DEFAULT_UC_TOOL_CLIENT_EXECUTION_TIMEOUT} seconds."
+ )
+ assert response.status is not None, f"Statement execution failed: {response}"
+ if response.status.state != StatementState.SUCCEEDED:
+ error = response.status.error
assert (
error is not None
- ), "Statement execution failed but no error message was provided."
+ ), f"Statement execution failed but no error message was provided: {response}"
return FunctionExecutionResult(error=f"{error.error_code}: {error.message}")
manifest = response.manifest
assert manifest is not None
@@ -211,3 +246,9 @@ def execute_function(
return FunctionExecutionResult(
format="CSV", value=csv_buffer.getvalue(), truncated=truncated
)
+
+
+def job_pending(state: Optional["StatementState"]) -> bool:
+ from databricks.sdk.service.sql import StatementState
+
+ return state in (StatementState.PENDING, StatementState.RUNNING)