DB-GPT/dbgpt/util/api_utils.py
明天 b124ecc10b
feat: (0.6)New UI (#1855)
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com>
Co-authored-by: csunny <cfqsunny@163.com>
2024-08-21 17:37:45 +08:00

289 lines
11 KiB
Python

import asyncio
import logging
import threading
import time
from abc import ABC
from concurrent.futures import Executor, ThreadPoolExecutor
from dataclasses import asdict, is_dataclass
from datetime import datetime, timedelta
from inspect import signature
from typing import List, Literal, Optional, Tuple, Type, TypeVar, Union, get_type_hints
# Generic type variable used in the signature of the type-hint helper below.
T = TypeVar("T")
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class APIMixin(ABC):
    """Mixin that tracks a pool of remote base URLs and picks one per request.

    When ``check_health`` is true, a daemon thread periodically sends a GET to
    ``url + health_check_path`` for every configured URL and records the URLs
    whose check returned HTTP 200. ``select_url`` / ``sync_select_url`` pick
    from that list and fall back to the full configured list when no URL is
    currently known to be healthy.
    """

    def __init__(
        self,
        urls: Union[str, List[str]],
        health_check_path: str,
        health_check_interval_secs: int = 5,
        health_check_timeout_secs: int = 30,
        check_health: bool = True,
        choice_type: Literal["latest_first", "random"] = "latest_first",
        executor: Optional[Executor] = None,
    ):
        """Store the configuration and optionally start the heartbeat thread.

        Args:
            urls: Base URLs, either as a list or as one comma-separated string.
            health_check_path: Path appended to each base URL for the check.
            health_check_interval_secs: Pause between two health-check rounds;
                also the maximum age of a heartbeat that still counts as
                healthy.
            health_check_timeout_secs: Stored on the instance.
                NOTE(review): not read by any method visible in this class
                (the HTTP check uses a fixed 10s timeout) - confirm intent.
            check_health: Start the background heartbeat thread when true.
            choice_type: ``"latest_first"`` picks the first URL of the
                candidate list, ``"random"`` picks a random one.
            executor: Executor used to run the checks; a 3-worker thread pool
                is created when omitted.
        """
        if isinstance(urls, str):
            # Split by ","
            urls = urls.split(",")
        # Strip whitespace and drop empty entries so inputs such as
        # "http://a, http://b," never yield an unusable "" base URL.
        self._remote_urls = [url.strip() for url in urls if url and url.strip()]
        self._health_check_path = health_check_path
        self._health_urls = []
        self._health_check_interval_secs = health_check_interval_secs
        self._health_check_timeout_secs = health_check_timeout_secs
        self._heartbeat_map = {}
        self._choice_type = choice_type
        self._heartbeat_thread = threading.Thread(target=self._heartbeat_checker)
        self._heartbeat_executor = executor or ThreadPoolExecutor(max_workers=3)
        self._heartbeat_stop_event = threading.Event()

        if check_health:
            self._heartbeat_thread.daemon = True
            self._heartbeat_thread.start()

    def _heartbeat_checker(self):
        """Run health-check rounds until the stop event is set."""
        logger.debug("Running health check")
        while not self._heartbeat_stop_event.is_set():
            try:
                healthy_urls = self._check_and_update_health()
                logger.debug(f"Healthy urls: {healthy_urls}")
            except Exception as e:
                # Best effort: a failed round must not kill the checker thread.
                logger.warning(f"Health check failed, error: {e}")
            # Unlike time.sleep, Event.wait returns as soon as a stop is
            # requested instead of blocking for the whole interval.
            self._heartbeat_stop_event.wait(self._health_check_interval_secs)

    def __del__(self):
        # __init__ may have raised before the event was created, so the
        # attribute is not guaranteed to exist here.
        stop_event = getattr(self, "_heartbeat_stop_event", None)
        if stop_event is not None:
            stop_event.set()

    def _check_health(self, url: str) -> Tuple[bool, str]:
        """Probe one base URL.

        Returns:
            ``(is_healthy, url)`` where ``is_healthy`` is true only when the
            GET request returned status code 200. Any exception is logged and
            reported as unhealthy.
        """
        try:
            import requests

            logger.debug(f"Checking health for {url}")
            req_url = url + self._health_check_path
            response = requests.get(req_url, timeout=10)
            return response.status_code == 200, url
        except Exception as e:
            logger.warning(f"Health check failed for {url}, error: {e}")
            return False, url

    def _check_and_update_health(self) -> List[str]:
        """Check health of all remote urls and update the health urls list.

        Returns:
            The healthy URLs, most recent heartbeat first.
        """
        # Submit every probe first so they run concurrently, then collect.
        pending = [
            self._heartbeat_executor.submit(self._check_health, url)
            for url in self._remote_urls
        ]
        check_results = [task.result() for task in pending]

        now = datetime.now()
        for is_healthy, url in check_results:
            if is_healthy:
                self._heartbeat_map[url] = now

        # A URL counts as healthy while its last successful check is younger
        # than one check interval.
        max_age = timedelta(seconds=self._health_check_interval_secs)
        healthy_urls = [
            (url, last_heartbeat)
            for url, last_heartbeat in self._heartbeat_map.items()
            if now - last_heartbeat < max_age
        ]
        # Sort by last heartbeat time, latest first
        healthy_urls.sort(key=lambda x: x[1], reverse=True)

        # Publish by rebinding the attribute so readers always see a
        # complete list, never a partially built one.
        self._health_urls = [url for url, _ in healthy_urls]
        return self._health_urls

    def _select_from(self, urls: List[str]) -> str:
        """Pick one URL from ``urls`` according to the configured choice type.

        Raises:
            ValueError: If the configured choice type is not supported.
            IndexError: If ``urls`` is empty.
        """
        import random

        if self._choice_type == "latest_first":
            return urls[0]
        if self._choice_type == "random":
            return random.choice(urls)
        raise ValueError(f"Invalid choice type: {self._choice_type}")

    async def select_url(self, max_wait_health_timeout_secs: int = 2) -> str:
        """Select a healthy url to send request.

        Waits up to ``max_wait_health_timeout_secs`` seconds (polling every
        0.1s) for a healthy URL to appear. If none is found, selects from all
        configured URLs instead.
        """
        # Snapshot the list: the heartbeat thread may replace the attribute
        # between the emptiness check and the selection.
        healthy = self._health_urls
        if healthy:
            return self._select_from(healthy)
        if max_wait_health_timeout_secs > 0:
            deadline = time.monotonic() + max_wait_health_timeout_secs
            while time.monotonic() < deadline:
                healthy = self._health_urls
                if healthy:
                    return self._select_from(healthy)
                await asyncio.sleep(0.1)
        logger.warning("No healthy urls found, selecting randomly")
        return self._select_from(self._remote_urls)

    def sync_select_url(self, max_wait_health_timeout_secs: int = 2) -> str:
        """Synchronous version of select_url."""
        # Snapshot the list: the heartbeat thread may replace the attribute
        # between the emptiness check and the selection.
        healthy = self._health_urls
        if healthy:
            return self._select_from(healthy)
        if max_wait_health_timeout_secs > 0:
            deadline = time.monotonic() + max_wait_health_timeout_secs
            while time.monotonic() < deadline:
                healthy = self._health_urls
                if healthy:
                    return self._select_from(healthy)
                time.sleep(0.1)
        logger.warning("No healthy urls found, selecting randomly")
        return self._select_from(self._remote_urls)
def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
    """Return the first type argument of a generic type hint.

    For a hint such as ``List[SomeDataclass]`` this yields ``SomeDataclass``.
    ``None`` is returned when ``typing_inspect`` does not report the hint as a
    generic type or when it carries no type arguments.

    NOTE(review): whether hints like ``Optional[X]`` count as generic here
    depends on ``typing_inspect.is_generic_type`` - confirm before relying
    on it.
    """
    import typing_inspect

    if not typing_inspect.is_generic_type(type_hint):
        return None
    type_args = typing_inspect.get_args(type_hint)
    if type_args:
        return type_args[0]
    return None
def _build_request(self, base_url, func, path, method, *args, **kwargs):
    """Assemble the HTTP request description for a decorated API method.

    Args:
        self: Instance the decorated method was called on; it is bound as the
            first argument of ``func`` and is available to ``path.format``.
        base_url: Base URL the formatted ``path`` is appended to.
        func: The decorated function; only its signature and return
            annotation are inspected.
        path: URL path template, formatted with the bound call arguments.
        method: HTTP method name.
        *args: Positional arguments of the call (without ``self``).
        **kwargs: Keyword arguments of the call.

    Returns:
        A tuple ``(return_type, actual_dataclass, request_params)`` where
        ``request_params`` holds ``method``, ``url`` and either ``json``
        (POST/PUT/PATCH) or ``params`` (every other method).

    Raises:
        TypeError: If ``func`` has no return annotation.
    """
    return_type = get_type_hints(func).get("return")
    if return_type is None:
        raise TypeError("Return type must be annotated in the decorated function.")

    actual_dataclass = _extract_dataclass_from_generic(return_type)
    logger.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
    # Nothing extracted from the hint: the annotation itself is the candidate.
    actual_dataclass = actual_dataclass or return_type

    func_sig = signature(func)
    call_binding = func_sig.bind(self, *args, **kwargs)
    call_binding.apply_defaults()
    formatted = base_url + path.format(**call_binding.arguments)

    # Parameter names without the leading "self", paired with the positional
    # values; keyword arguments are layered on top.
    param_names = list(func_sig.parameters)[1:]
    supplied = {**dict(zip(param_names, args)), **kwargs}

    payload = {}
    for name, value in supplied.items():
        if not is_dataclass(value):
            payload[name] = value
            continue
        # A dataclass argument is not nested under its name: its dict form
        # replaces the payload collected so far.
        payload = asdict(value)

    # Body-carrying methods send JSON; everything else (GET, DELETE, ...)
    # sends query parameters.
    payload_slot = "json" if method in ("POST", "PUT", "PATCH") else "params"
    request_params = {"method": method, "url": formatted, payload_slot: payload}

    logger.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
    return return_type, actual_dataclass, request_params
def _api_remote(path: str, method: str = "GET", max_wait_health_timeout_secs: int = 2):
def decorator(func):
async def wrapper(self, *args, **kwargs):
import httpx
if not isinstance(self, APIMixin):
raise TypeError(
"The class must inherit from APIMixin to use the @_api_remote "
"decorator."
)
# Found a healthy url to send request
base_url = await self.select_url(
max_wait_health_timeout_secs=max_wait_health_timeout_secs
)
return_type, actual_dataclass, request_params = _build_request(
self, base_url, func, path, method, *args, **kwargs
)
async with httpx.AsyncClient() as client:
response = await client.request(**request_params)
if response.status_code == 200:
return _parse_response(
response.json(), return_type, actual_dataclass
)
else:
error_msg = f"Remote request error, error code: {response.status_code}, error msg: {response.text}"
raise Exception(error_msg)
return wrapper
return decorator
def _sync_api_remote(
path: str, method: str = "GET", max_wait_health_timeout_secs: int = 2
):
def decorator(func):
def wrapper(self, *args, **kwargs):
import requests
if not isinstance(self, APIMixin):
raise TypeError(
"The class must inherit from APIMixin to use the @_sync_api_remote "
"decorator."
)
base_url = self.sync_select_url(
max_wait_health_timeout_secs=max_wait_health_timeout_secs
)
return_type, actual_dataclass, request_params = _build_request(
self, base_url, func, path, method, *args, **kwargs
)
response = requests.request(**request_params)
if response.status_code == 200:
return _parse_response(response.json(), return_type, actual_dataclass)
else:
error_msg = f"Remote request error, error code: {response.status_code}, error msg: {response.text}"
raise Exception(error_msg)
return wrapper
return decorator
def _parse_response(json_response, return_type, actual_dataclass):
# print(f'return_type.__origin__: {return_type.__origin__}, actual_dataclass: {actual_dataclass}, json_response: {json_response}')
if is_dataclass(actual_dataclass):
if return_type.__origin__ is list: # for List[dataclass]
if isinstance(json_response, list):
return [actual_dataclass(**item) for item in json_response]
else:
raise TypeError(
f"Expected list in response but got {type(json_response)}"
)
else:
if isinstance(json_response, dict):
return actual_dataclass(**json_response)
else:
raise TypeError(
f"Expected dictionary in response but got {type(json_response)}"
)
else:
return json_response