mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-23 10:20:01 +00:00
feat(model): Support database model registry (#1656)
This commit is contained in:
@@ -1,13 +1,156 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from inspect import signature
|
||||
from typing import List, Optional, Tuple, Type, TypeVar, Union, get_type_hints
|
||||
from typing import List, Literal, Optional, Tuple, Type, TypeVar, Union, get_type_hints
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APIMixin(ABC):
|
||||
"""API mixin class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
urls: Union[str, List[str]],
|
||||
health_check_path: str,
|
||||
health_check_interval_secs: int = 5,
|
||||
health_check_timeout_secs: int = 30,
|
||||
check_health: bool = True,
|
||||
choice_type: Literal["latest_first", "random"] = "latest_first",
|
||||
executor: Optional[Executor] = None,
|
||||
):
|
||||
if isinstance(urls, str):
|
||||
# Split by ","
|
||||
urls = urls.split(",")
|
||||
urls = [url.strip() for url in urls]
|
||||
self._remote_urls = urls
|
||||
self._health_check_path = health_check_path
|
||||
self._health_urls = []
|
||||
self._health_check_interval_secs = health_check_interval_secs
|
||||
self._health_check_timeout_secs = health_check_timeout_secs
|
||||
self._heartbeat_map = {}
|
||||
self._choice_type = choice_type
|
||||
self._heartbeat_thread = threading.Thread(target=self._heartbeat_checker)
|
||||
self._heartbeat_executor = executor or ThreadPoolExecutor(max_workers=3)
|
||||
self._heartbeat_stop_event = threading.Event()
|
||||
|
||||
if check_health:
|
||||
self._heartbeat_thread.daemon = True
|
||||
self._heartbeat_thread.start()
|
||||
|
||||
def _heartbeat_checker(self):
|
||||
logger.debug("Running health check")
|
||||
while not self._heartbeat_stop_event.is_set():
|
||||
try:
|
||||
healthy_urls = self._check_and_update_health()
|
||||
logger.debug(f"Healthy urls: {healthy_urls}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Health check failed, error: {e}")
|
||||
time.sleep(self._health_check_interval_secs)
|
||||
|
||||
def __del__(self):
|
||||
|
||||
self._heartbeat_stop_event.set()
|
||||
|
||||
def _check_health(self, url: str) -> Tuple[bool, str]:
|
||||
try:
|
||||
import requests
|
||||
|
||||
logger.debug(f"Checking health for {url}")
|
||||
req_url = url + self._health_check_path
|
||||
response = requests.get(req_url, timeout=10)
|
||||
return response.status_code == 200, url
|
||||
except Exception as e:
|
||||
logger.warning(f"Health check failed for {url}, error: {e}")
|
||||
return False, url
|
||||
|
||||
def _check_and_update_health(self) -> List[str]:
|
||||
"""Check health of all remote urls and update the health urls list."""
|
||||
check_tasks = []
|
||||
check_results = []
|
||||
for url in self._remote_urls:
|
||||
check_tasks.append(self._heartbeat_executor.submit(self._check_health, url))
|
||||
for task in check_tasks:
|
||||
check_results.append(task.result())
|
||||
now = datetime.now()
|
||||
for is_healthy, url in check_results:
|
||||
if is_healthy:
|
||||
self._heartbeat_map[url] = now
|
||||
healthy_urls = []
|
||||
for url, last_heartbeat in self._heartbeat_map.items():
|
||||
if now - last_heartbeat < timedelta(
|
||||
seconds=self._health_check_interval_secs
|
||||
):
|
||||
healthy_urls.append((url, last_heartbeat))
|
||||
# Sort by last heartbeat time, latest first
|
||||
healthy_urls.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
self._health_urls = [url for url, _ in healthy_urls]
|
||||
return self._health_urls
|
||||
|
||||
async def select_url(self, max_wait_health_timeout_secs: int = 2) -> str:
|
||||
"""Select a healthy url to send request.
|
||||
|
||||
If no healthy urls found, select randomly.
|
||||
"""
|
||||
import random
|
||||
|
||||
def _select(urls: List[str]):
|
||||
if self._choice_type == "latest_first":
|
||||
return urls[0]
|
||||
elif self._choice_type == "random":
|
||||
return random.choice(urls)
|
||||
else:
|
||||
raise ValueError(f"Invalid choice type: {self._choice_type}")
|
||||
|
||||
if self._health_urls:
|
||||
return _select(self._health_urls)
|
||||
elif max_wait_health_timeout_secs > 0:
|
||||
start_time = datetime.now()
|
||||
while datetime.now() - start_time < timedelta(
|
||||
seconds=max_wait_health_timeout_secs
|
||||
):
|
||||
if self._health_urls:
|
||||
return _select(self._health_urls)
|
||||
await asyncio.sleep(0.1)
|
||||
logger.warning("No healthy urls found, selecting randomly")
|
||||
return _select(self._remote_urls)
|
||||
|
||||
def sync_select_url(self, max_wait_health_timeout_secs: int = 2) -> str:
|
||||
"""Synchronous version of select_url."""
|
||||
import random
|
||||
import time
|
||||
|
||||
def _select(urls: List[str]):
|
||||
if self._choice_type == "latest_first":
|
||||
return urls[0]
|
||||
elif self._choice_type == "random":
|
||||
return random.choice(urls)
|
||||
else:
|
||||
raise ValueError(f"Invalid choice type: {self._choice_type}")
|
||||
|
||||
if self._health_urls:
|
||||
return _select(self._health_urls)
|
||||
elif max_wait_health_timeout_secs > 0:
|
||||
start_time = datetime.now()
|
||||
while datetime.now() - start_time < timedelta(
|
||||
seconds=max_wait_health_timeout_secs
|
||||
):
|
||||
if self._health_urls:
|
||||
return _select(self._health_urls)
|
||||
time.sleep(0.1)
|
||||
logger.warning("No healthy urls found, selecting randomly")
|
||||
return _select(self._remote_urls)
|
||||
|
||||
|
||||
def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
|
||||
import typing_inspect
|
||||
|
||||
@@ -17,7 +160,7 @@ def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
|
||||
return None
|
||||
|
||||
|
||||
def _build_request(self, func, path, method, *args, **kwargs):
|
||||
def _build_request(self, base_url, func, path, method, *args, **kwargs):
|
||||
return_type = get_type_hints(func).get("return")
|
||||
if return_type is None:
|
||||
raise TypeError("Return type must be annotated in the decorated function.")
|
||||
@@ -27,7 +170,6 @@ def _build_request(self, func, path, method, *args, **kwargs):
|
||||
if not actual_dataclass:
|
||||
actual_dataclass = return_type
|
||||
sig = signature(func)
|
||||
base_url = self.base_url # Get base_url from class instance
|
||||
|
||||
bound = sig.bind(self, *args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
@@ -61,13 +203,22 @@ def _build_request(self, func, path, method, *args, **kwargs):
|
||||
return return_type, actual_dataclass, request_params
|
||||
|
||||
|
||||
def _api_remote(path, method="GET"):
|
||||
def _api_remote(path: str, method: str = "GET", max_wait_health_timeout_secs: int = 2):
|
||||
def decorator(func):
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
import httpx
|
||||
|
||||
if not isinstance(self, APIMixin):
|
||||
raise TypeError(
|
||||
"The class must inherit from APIMixin to use the @_api_remote "
|
||||
"decorator."
|
||||
)
|
||||
# Found a healthy url to send request
|
||||
base_url = await self.select_url(
|
||||
max_wait_health_timeout_secs=max_wait_health_timeout_secs
|
||||
)
|
||||
return_type, actual_dataclass, request_params = _build_request(
|
||||
self, func, path, method, *args, **kwargs
|
||||
self, base_url, func, path, method, *args, **kwargs
|
||||
)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.request(**request_params)
|
||||
@@ -84,13 +235,24 @@ def _api_remote(path, method="GET"):
|
||||
return decorator
|
||||
|
||||
|
||||
def _sync_api_remote(path, method="GET"):
|
||||
def _sync_api_remote(
|
||||
path: str, method: str = "GET", max_wait_health_timeout_secs: int = 2
|
||||
):
|
||||
def decorator(func):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
import requests
|
||||
|
||||
if not isinstance(self, APIMixin):
|
||||
raise TypeError(
|
||||
"The class must inherit from APIMixin to use the @_sync_api_remote "
|
||||
"decorator."
|
||||
)
|
||||
base_url = self.sync_select_url(
|
||||
max_wait_health_timeout_secs=max_wait_health_timeout_secs
|
||||
)
|
||||
|
||||
return_type, actual_dataclass, request_params = _build_request(
|
||||
self, func, path, method, *args, **kwargs
|
||||
self, base_url, func, path, method, *args, **kwargs
|
||||
)
|
||||
|
||||
response = requests.request(**request_params)
|
||||
|
Reference in New Issue
Block a user