[community] Add timeout control and retry for UC tool execution (#26645)

Add timeout at client side for UCFunctionToolkit and add retry logic.
Users could specify environment variable
`UC_TOOL_CLIENT_EXECUTION_TIMEOUT` to increase the timeout value for
retrying to get the execution response if the status is pending. Default
timeout value is 120s.


- [ ] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

Tested in Databricks:
<img width="1200" alt="image"
src="https://github.com/user-attachments/assets/54ab5dfc-5e57-4941-b7d9-bfe3f8ad3f62">



- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Signed-off-by: serena-ruan <serena.rxy@gmail.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Serena Ruan 2024-10-09 14:31:48 +08:00 committed by GitHub
parent 481bd25d29
commit a7c1ce2b3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 65 additions and 6 deletions

View File

@ -74,6 +74,24 @@
")" ")"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(Optional) To increase the retry time for getting a function execution response, set environment variable UC_TOOL_CLIENT_EXECUTION_TIMEOUT. Default retry time value is 120s."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"UC_TOOL_CLIENT_EXECUTION_TIMEOUT\"] = \"200\""
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 4,

View File

@ -1,5 +1,8 @@
import inspect import inspect
import json import json
import logging
import os
import time
from dataclasses import dataclass from dataclasses import dataclass
from io import StringIO from io import StringIO
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
@ -7,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from databricks.sdk import WorkspaceClient from databricks.sdk import WorkspaceClient
from databricks.sdk.service.catalog import FunctionInfo from databricks.sdk.service.catalog import FunctionInfo
from databricks.sdk.service.sql import StatementParameterListItem from databricks.sdk.service.sql import StatementParameterListItem, StatementState
EXECUTE_FUNCTION_ARG_NAME = "__execution_args__" EXECUTE_FUNCTION_ARG_NAME = "__execution_args__"
DEFAULT_EXECUTE_FUNCTION_ARGS = { DEFAULT_EXECUTE_FUNCTION_ARGS = {
@ -15,6 +18,9 @@ DEFAULT_EXECUTE_FUNCTION_ARGS = {
"row_limit": 100, "row_limit": 100,
"byte_limit": 4096, "byte_limit": 4096,
} }
UC_TOOL_CLIENT_EXECUTION_TIMEOUT = "UC_TOOL_CLIENT_EXECUTION_TIMEOUT"
DEFAULT_UC_TOOL_CLIENT_EXECUTION_TIMEOUT = "120"
_logger = logging.getLogger(__name__)
def is_scalar(function: "FunctionInfo") -> bool: def is_scalar(function: "FunctionInfo") -> bool:
@ -174,13 +180,42 @@ def execute_function(
parameters=parametrized_statement.parameters, parameters=parametrized_statement.parameters,
**execute_statement_args, # type: ignore **execute_statement_args, # type: ignore
) )
status = response.status if response.status and job_pending(response.status.state) and response.statement_id:
assert status is not None, f"Statement execution failed: {response}" statement_id = response.statement_id
if status.state != StatementState.SUCCEEDED: wait_time = 0
error = status.error retry_cnt = 0
client_execution_timeout = int(
os.environ.get(
UC_TOOL_CLIENT_EXECUTION_TIMEOUT,
DEFAULT_UC_TOOL_CLIENT_EXECUTION_TIMEOUT,
)
)
while wait_time < client_execution_timeout:
wait = min(2**retry_cnt, client_execution_timeout - wait_time)
_logger.debug(
f"Retrying {retry_cnt} time to get statement execution "
f"status after {wait} seconds."
)
time.sleep(wait)
response = ws.statement_execution.get_statement(statement_id) # type: ignore
if response.status is None or not job_pending(response.status.state):
break
wait_time += wait
retry_cnt += 1
if response.status and job_pending(response.status.state):
return FunctionExecutionResult(
error=f"Statement execution is still pending after {wait_time} "
"seconds. Please increase the wait_timeout argument for executing "
f"the function or increase {UC_TOOL_CLIENT_EXECUTION_TIMEOUT} "
"environment variable for increasing retrying time, default is "
f"{DEFAULT_UC_TOOL_CLIENT_EXECUTION_TIMEOUT} seconds."
)
assert response.status is not None, f"Statement execution failed: {response}"
if response.status.state != StatementState.SUCCEEDED:
error = response.status.error
assert ( assert (
error is not None error is not None
), "Statement execution failed but no error message was provided." ), f"Statement execution failed but no error message was provided: {response}"
return FunctionExecutionResult(error=f"{error.error_code}: {error.message}") return FunctionExecutionResult(error=f"{error.error_code}: {error.message}")
manifest = response.manifest manifest = response.manifest
assert manifest is not None assert manifest is not None
@ -211,3 +246,9 @@ def execute_function(
return FunctionExecutionResult( return FunctionExecutionResult(
format="CSV", value=csv_buffer.getvalue(), truncated=truncated format="CSV", value=csv_buffer.getvalue(), truncated=truncated
) )
def job_pending(state: Optional["StatementState"]) -> bool:
from databricks.sdk.service.sql import StatementState
return state in (StatementState.PENDING, StatementState.RUNNING)