mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-16 22:51:24 +00:00
feat(model): Support database model registry (#1656)
This commit is contained in:
@@ -1,13 +1,156 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from inspect import signature
|
||||
from typing import List, Optional, Tuple, Type, TypeVar, Union, get_type_hints
|
||||
from typing import List, Literal, Optional, Tuple, Type, TypeVar, Union, get_type_hints
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APIMixin(ABC):
    """API mixin class.

    Keeps a list of remote base URLs and, optionally, a background daemon
    thread that periodically probes ``<url> + health_check_path`` for each of
    them. URLs whose last successful probe (HTTP 200) is recent enough are
    exposed, most recent first, to :meth:`select_url` and
    :meth:`sync_select_url`.
    """

    def __init__(
        self,
        urls: Union[str, List[str]],
        health_check_path: str,
        health_check_interval_secs: int = 5,
        health_check_timeout_secs: int = 30,
        check_health: bool = True,
        choice_type: Literal["latest_first", "random"] = "latest_first",
        executor: Optional[Executor] = None,
    ):
        """Store the configuration and optionally start the heartbeat thread.

        Args:
            urls: Remote base URLs, either a list or a comma separated string
                (entries of a string are stripped of surrounding whitespace).
            health_check_path: Path appended to each base URL when probing.
            health_check_interval_secs: Pause between two probe rounds; also
                the maximum age of a successful probe for a URL to still count
                as healthy.
            health_check_timeout_secs: Stored on the instance.
                NOTE(review): this value is not read anywhere in this class;
                confirm whether it was meant to be the staleness window or the
                per-request timeout.
            check_health: Start the background heartbeat thread when true.
            choice_type: ``"latest_first"`` picks the first candidate,
                ``"random"`` picks a random one.
            executor: Executor used to run the probes; a three-worker
                ``ThreadPoolExecutor`` is created when omitted.
        """
        # Created first so that __del__ stays safe if a later step raises.
        self._heartbeat_stop_event = threading.Event()

        if isinstance(urls, str):
            # Split by ","
            urls = [url.strip() for url in urls.split(",")]
        self._remote_urls = urls
        self._health_check_path = health_check_path
        self._health_urls = []
        self._health_check_interval_secs = health_check_interval_secs
        self._health_check_timeout_secs = health_check_timeout_secs
        self._heartbeat_map = {}
        self._choice_type = choice_type
        self._heartbeat_executor = executor or ThreadPoolExecutor(max_workers=3)
        self._heartbeat_thread = threading.Thread(target=self._heartbeat_checker)

        if check_health:
            self._heartbeat_thread.daemon = True
            self._heartbeat_thread.start()

    def _heartbeat_checker(self):
        """Run probe rounds until the stop event is set."""
        logger.debug("Running health check")
        while not self._heartbeat_stop_event.is_set():
            try:
                healthy_urls = self._check_and_update_health()
                logger.debug(f"Healthy urls: {healthy_urls}")
            except Exception as e:
                logger.warning(f"Health check failed, error: {e}")
            # Wait on the event instead of sleeping so that a stop request
            # ends the loop immediately rather than after a full interval.
            self._heartbeat_stop_event.wait(self._health_check_interval_secs)

    def __del__(self):
        # The event may be missing if __init__ never ran or failed early.
        stop_event = getattr(self, "_heartbeat_stop_event", None)
        if stop_event is not None:
            stop_event.set()

    def _check_health(self, url: str) -> Tuple[bool, str]:
        """Probe one base URL.

        Returns:
            ``(is_healthy, url)`` where ``is_healthy`` is true only for an
            HTTP 200 response; any exception is logged and reported as
            unhealthy.
        """
        try:
            import requests

            logger.debug(f"Checking health for {url}")
            req_url = url + self._health_check_path
            response = requests.get(req_url, timeout=10)
            return response.status_code == 200, url
        except Exception as e:
            logger.warning(f"Health check failed for {url}, error: {e}")
            return False, url

    def _check_and_update_health(self) -> List[str]:
        """Check health of all remote urls and update the health urls list.

        Returns:
            The healthy URLs ordered by last successful probe, latest first.
        """
        # Submit every probe before waiting so they run concurrently.
        futures = [
            self._heartbeat_executor.submit(self._check_health, url)
            for url in self._remote_urls
        ]
        results = [future.result() for future in futures]

        now = datetime.now()
        for is_healthy, url in results:
            if is_healthy:
                self._heartbeat_map[url] = now

        max_age = timedelta(seconds=self._health_check_interval_secs)
        fresh = [
            (url, last_heartbeat)
            for url, last_heartbeat in self._heartbeat_map.items()
            if now - last_heartbeat < max_age
        ]
        # Sort by last heartbeat time, latest first
        fresh.sort(key=lambda item: item[1], reverse=True)

        # Publish a brand-new list in one assignment; readers snapshot it.
        self._health_urls = [url for url, _ in fresh]
        return self._health_urls

    def _choose_url(self, urls: List[str]) -> str:
        """Pick one entry of ``urls`` according to the configured choice type.

        Raises:
            ValueError: If the choice type is neither ``"latest_first"`` nor
                ``"random"``.
        """
        import random

        if self._choice_type == "latest_first":
            return urls[0]
        if self._choice_type == "random":
            return random.choice(urls)
        raise ValueError(f"Invalid choice type: {self._choice_type}")

    async def select_url(self, max_wait_health_timeout_secs: int = 2) -> str:
        """Select a healthy url to send request.

        Waits up to ``max_wait_health_timeout_secs`` seconds (polling every
        0.1s without blocking the event loop) for a healthy URL to appear.
        If no healthy urls found, select from all configured urls instead.
        """
        # Read the attribute once: the heartbeat thread may replace the list
        # between a truthiness check and the actual selection.
        candidates = self._health_urls
        if candidates:
            return self._choose_url(candidates)

        if max_wait_health_timeout_secs > 0:
            deadline = time.monotonic() + max_wait_health_timeout_secs
            while time.monotonic() < deadline:
                candidates = self._health_urls
                if candidates:
                    return self._choose_url(candidates)
                await asyncio.sleep(0.1)

        logger.warning("No healthy urls found, selecting randomly")
        return self._choose_url(self._remote_urls)

    def sync_select_url(self, max_wait_health_timeout_secs: int = 2) -> str:
        """Synchronous version of select_url."""
        # Same single-read snapshot as in select_url.
        candidates = self._health_urls
        if candidates:
            return self._choose_url(candidates)

        if max_wait_health_timeout_secs > 0:
            deadline = time.monotonic() + max_wait_health_timeout_secs
            while time.monotonic() < deadline:
                candidates = self._health_urls
                if candidates:
                    return self._choose_url(candidates)
                time.sleep(0.1)

        logger.warning("No healthy urls found, selecting randomly")
        return self._choose_url(self._remote_urls)
|
||||
|
||||
|
||||
def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
|
||||
import typing_inspect
|
||||
|
||||
@@ -17,7 +160,7 @@ def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
|
||||
return None
|
||||
|
||||
|
||||
def _build_request(self, func, path, method, *args, **kwargs):
|
||||
def _build_request(self, base_url, func, path, method, *args, **kwargs):
|
||||
return_type = get_type_hints(func).get("return")
|
||||
if return_type is None:
|
||||
raise TypeError("Return type must be annotated in the decorated function.")
|
||||
@@ -27,7 +170,6 @@ def _build_request(self, func, path, method, *args, **kwargs):
|
||||
if not actual_dataclass:
|
||||
actual_dataclass = return_type
|
||||
sig = signature(func)
|
||||
base_url = self.base_url # Get base_url from class instance
|
||||
|
||||
bound = sig.bind(self, *args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
@@ -61,13 +203,22 @@ def _build_request(self, func, path, method, *args, **kwargs):
|
||||
return return_type, actual_dataclass, request_params
|
||||
|
||||
|
||||
def _api_remote(path, method="GET"):
|
||||
def _api_remote(path: str, method: str = "GET", max_wait_health_timeout_secs: int = 2):
|
||||
def decorator(func):
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
import httpx
|
||||
|
||||
if not isinstance(self, APIMixin):
|
||||
raise TypeError(
|
||||
"The class must inherit from APIMixin to use the @_api_remote "
|
||||
"decorator."
|
||||
)
|
||||
# Found a healthy url to send request
|
||||
base_url = await self.select_url(
|
||||
max_wait_health_timeout_secs=max_wait_health_timeout_secs
|
||||
)
|
||||
return_type, actual_dataclass, request_params = _build_request(
|
||||
self, func, path, method, *args, **kwargs
|
||||
self, base_url, func, path, method, *args, **kwargs
|
||||
)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.request(**request_params)
|
||||
@@ -84,13 +235,24 @@ def _api_remote(path, method="GET"):
|
||||
return decorator
|
||||
|
||||
|
||||
def _sync_api_remote(path, method="GET"):
|
||||
def _sync_api_remote(
|
||||
path: str, method: str = "GET", max_wait_health_timeout_secs: int = 2
|
||||
):
|
||||
def decorator(func):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
import requests
|
||||
|
||||
if not isinstance(self, APIMixin):
|
||||
raise TypeError(
|
||||
"The class must inherit from APIMixin to use the @_sync_api_remote "
|
||||
"decorator."
|
||||
)
|
||||
base_url = self.sync_select_url(
|
||||
max_wait_health_timeout_secs=max_wait_health_timeout_secs
|
||||
)
|
||||
|
||||
return_type, actual_dataclass, request_params = _build_request(
|
||||
self, func, path, method, *args, **kwargs
|
||||
self, base_url, func, path, method, *args, **kwargs
|
||||
)
|
||||
|
||||
response = requests.request(**request_params)
|
||||
|
105
dbgpt/util/tests/test_api_utils.py
Normal file
105
dbgpt/util/tests/test_api_utils.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ..api_utils import APIMixin
|
||||
|
||||
|
||||
# Mock requests.get
|
||||
@pytest.fixture
def mock_requests_get():
    """Replace ``requests.get`` with a mock for the duration of one test."""
    patcher = patch("requests.get")
    mocked_get = patcher.start()
    try:
        yield mocked_get
    finally:
        # Always restore the real function, mirroring the ``with`` form.
        patcher.stop()
|
||||
|
||||
|
||||
@pytest.fixture
def apimixin():
    """Provide an ``APIMixin`` for two example URLs without background probing.

    The heartbeat thread is disabled so that tests neither issue real HTTP
    requests nor race with a thread mutating the instance state they set up.
    """
    urls = "http://example.com,http://example2.com"
    health_check_path = "/health"
    mixin = APIMixin(urls, health_check_path, check_health=False)
    try:
        yield mixin
    finally:
        # Signal any heartbeat loop to stop before closing the executor, so
        # nothing tries to submit work to an executor that is already shut down.
        mixin._heartbeat_stop_event.set()
        # Ensure the executor is properly shut down after tests
        mixin._heartbeat_executor.shutdown(wait=False)
|
||||
|
||||
|
||||
def test_apimixin_initialization(apimixin):
    """The constructor stores the parsed URLs, the path and the defaults."""
    expected_state = {
        "_remote_urls": ["http://example.com", "http://example2.com"],
        "_health_check_path": "/health",
        "_health_check_interval_secs": 5,
        "_health_check_timeout_secs": 30,
        "_choice_type": "latest_first",
    }
    for attr_name, expected_value in expected_state.items():
        assert getattr(apimixin, attr_name) == expected_value
    assert isinstance(apimixin._heartbeat_executor, ThreadPoolExecutor)
|
||||
|
||||
|
||||
def test_health_check(apimixin, mock_requests_get):
    """``_check_health`` reports success on HTTP 200 and failure on errors."""
    target = "http://example.com"

    # A 200 response marks the URL as healthy.
    ok_response = MagicMock()
    ok_response.status_code = 200
    mock_requests_get.return_value = ok_response

    healthy, reported_url = apimixin._check_health(target)
    assert healthy
    assert reported_url == target

    # An exception raised by the request marks the URL as unhealthy.
    mock_requests_get.side_effect = Exception("Connection error")

    healthy, reported_url = apimixin._check_health(target)
    assert not healthy
    assert reported_url == target
|
||||
|
||||
|
||||
def test_check_and_update_health(apimixin, mock_requests_get):
    """Only URLs with a fresh successful probe end up in the healthy list."""
    apimixin._heartbeat_map = {
        "http://example.com": datetime.now() - timedelta(seconds=3),
        "http://example2.com": datetime.now() - timedelta(seconds=10),
    }

    # Status code returned for each probed health endpoint.
    status_by_endpoint = {
        "http://example.com/health": 200,
        "http://example2.com/health": 500,
    }

    def fake_get(url, timeout):
        response = MagicMock()
        if url in status_by_endpoint:
            response.status_code = status_by_endpoint[url]
        return response

    mock_requests_get.side_effect = fake_get

    reported = apimixin._check_and_update_health()
    assert "http://example.com" in reported
    assert "http://example2.com" not in reported
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_select_url(apimixin, mock_requests_get):
    """``select_url`` prefers healthy URLs and falls back to configured ones."""
    apimixin._health_urls = ["http://example.com"]

    chosen = await apimixin.select_url()
    assert chosen == "http://example.com"

    # With no healthy URL the selection comes from the configured list.
    apimixin._health_urls = []
    configured = ["http://example.com", "http://example2.com"]
    chosen = await apimixin.select_url(max_wait_health_timeout_secs=1)
    assert chosen in configured
|
||||
|
||||
|
||||
def test_sync_select_url(apimixin, mock_requests_get):
    """``sync_select_url`` prefers healthy URLs and falls back otherwise."""
    apimixin._health_urls = ["http://example.com"]

    chosen = apimixin.sync_select_url()
    assert chosen == "http://example.com"

    # With no healthy URL the selection comes from the configured list.
    apimixin._health_urls = []
    configured = ["http://example.com", "http://example2.com"]
    chosen = apimixin.sync_select_url(max_wait_health_timeout_secs=1)
    assert chosen in configured
|
@@ -172,7 +172,7 @@ def setup_http_service_logging(exclude_paths: Optional[List[str]] = None):
|
||||
"""
|
||||
if not exclude_paths:
|
||||
# Not show heartbeat log
|
||||
exclude_paths = ["/api/controller/heartbeat"]
|
||||
exclude_paths = ["/api/controller/heartbeat", "/api/health"]
|
||||
uvicorn_logger = logging.getLogger("uvicorn.access")
|
||||
if uvicorn_logger:
|
||||
for path in exclude_paths:
|
||||
|
Reference in New Issue
Block a user