feat(model): Support database model registry (#1656)

This commit is contained in:
Fangyin Cheng
2024-06-24 19:07:10 +08:00
committed by GitHub
parent c57ee0289b
commit 47d205f676
35 changed files with 2014 additions and 792 deletions

View File

@@ -1,13 +1,156 @@
import asyncio
import logging
import threading
import time
from abc import ABC
from concurrent.futures import Executor, ThreadPoolExecutor
from dataclasses import asdict, is_dataclass
from datetime import datetime, timedelta
from inspect import signature
from typing import List, Optional, Tuple, Type, TypeVar, Union, get_type_hints
from typing import List, Literal, Optional, Tuple, Type, TypeVar, Union, get_type_hints
T = TypeVar("T")
logger = logging.getLogger(__name__)
class APIMixin(ABC):
"""API mixin class."""
def __init__(
self,
urls: Union[str, List[str]],
health_check_path: str,
health_check_interval_secs: int = 5,
health_check_timeout_secs: int = 30,
check_health: bool = True,
choice_type: Literal["latest_first", "random"] = "latest_first",
executor: Optional[Executor] = None,
):
if isinstance(urls, str):
# Split by ","
urls = urls.split(",")
urls = [url.strip() for url in urls]
self._remote_urls = urls
self._health_check_path = health_check_path
self._health_urls = []
self._health_check_interval_secs = health_check_interval_secs
self._health_check_timeout_secs = health_check_timeout_secs
self._heartbeat_map = {}
self._choice_type = choice_type
self._heartbeat_thread = threading.Thread(target=self._heartbeat_checker)
self._heartbeat_executor = executor or ThreadPoolExecutor(max_workers=3)
self._heartbeat_stop_event = threading.Event()
if check_health:
self._heartbeat_thread.daemon = True
self._heartbeat_thread.start()
def _heartbeat_checker(self):
logger.debug("Running health check")
while not self._heartbeat_stop_event.is_set():
try:
healthy_urls = self._check_and_update_health()
logger.debug(f"Healthy urls: {healthy_urls}")
except Exception as e:
logger.warning(f"Health check failed, error: {e}")
time.sleep(self._health_check_interval_secs)
def __del__(self):
self._heartbeat_stop_event.set()
def _check_health(self, url: str) -> Tuple[bool, str]:
try:
import requests
logger.debug(f"Checking health for {url}")
req_url = url + self._health_check_path
response = requests.get(req_url, timeout=10)
return response.status_code == 200, url
except Exception as e:
logger.warning(f"Health check failed for {url}, error: {e}")
return False, url
def _check_and_update_health(self) -> List[str]:
"""Check health of all remote urls and update the health urls list."""
check_tasks = []
check_results = []
for url in self._remote_urls:
check_tasks.append(self._heartbeat_executor.submit(self._check_health, url))
for task in check_tasks:
check_results.append(task.result())
now = datetime.now()
for is_healthy, url in check_results:
if is_healthy:
self._heartbeat_map[url] = now
healthy_urls = []
for url, last_heartbeat in self._heartbeat_map.items():
if now - last_heartbeat < timedelta(
seconds=self._health_check_interval_secs
):
healthy_urls.append((url, last_heartbeat))
# Sort by last heartbeat time, latest first
healthy_urls.sort(key=lambda x: x[1], reverse=True)
self._health_urls = [url for url, _ in healthy_urls]
return self._health_urls
async def select_url(self, max_wait_health_timeout_secs: int = 2) -> str:
"""Select a healthy url to send request.
If no healthy urls found, select randomly.
"""
import random
def _select(urls: List[str]):
if self._choice_type == "latest_first":
return urls[0]
elif self._choice_type == "random":
return random.choice(urls)
else:
raise ValueError(f"Invalid choice type: {self._choice_type}")
if self._health_urls:
return _select(self._health_urls)
elif max_wait_health_timeout_secs > 0:
start_time = datetime.now()
while datetime.now() - start_time < timedelta(
seconds=max_wait_health_timeout_secs
):
if self._health_urls:
return _select(self._health_urls)
await asyncio.sleep(0.1)
logger.warning("No healthy urls found, selecting randomly")
return _select(self._remote_urls)
def sync_select_url(self, max_wait_health_timeout_secs: int = 2) -> str:
"""Synchronous version of select_url."""
import random
import time
def _select(urls: List[str]):
if self._choice_type == "latest_first":
return urls[0]
elif self._choice_type == "random":
return random.choice(urls)
else:
raise ValueError(f"Invalid choice type: {self._choice_type}")
if self._health_urls:
return _select(self._health_urls)
elif max_wait_health_timeout_secs > 0:
start_time = datetime.now()
while datetime.now() - start_time < timedelta(
seconds=max_wait_health_timeout_secs
):
if self._health_urls:
return _select(self._health_urls)
time.sleep(0.1)
logger.warning("No healthy urls found, selecting randomly")
return _select(self._remote_urls)
def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
import typing_inspect
@@ -17,7 +160,7 @@ def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
return None
def _build_request(self, func, path, method, *args, **kwargs):
def _build_request(self, base_url, func, path, method, *args, **kwargs):
return_type = get_type_hints(func).get("return")
if return_type is None:
raise TypeError("Return type must be annotated in the decorated function.")
@@ -27,7 +170,6 @@ def _build_request(self, func, path, method, *args, **kwargs):
if not actual_dataclass:
actual_dataclass = return_type
sig = signature(func)
base_url = self.base_url # Get base_url from class instance
bound = sig.bind(self, *args, **kwargs)
bound.apply_defaults()
@@ -61,13 +203,22 @@ def _build_request(self, func, path, method, *args, **kwargs):
return return_type, actual_dataclass, request_params
def _api_remote(path, method="GET"):
def _api_remote(path: str, method: str = "GET", max_wait_health_timeout_secs: int = 2):
def decorator(func):
async def wrapper(self, *args, **kwargs):
import httpx
if not isinstance(self, APIMixin):
raise TypeError(
"The class must inherit from APIMixin to use the @_api_remote "
"decorator."
)
# Found a healthy url to send request
base_url = await self.select_url(
max_wait_health_timeout_secs=max_wait_health_timeout_secs
)
return_type, actual_dataclass, request_params = _build_request(
self, func, path, method, *args, **kwargs
self, base_url, func, path, method, *args, **kwargs
)
async with httpx.AsyncClient() as client:
response = await client.request(**request_params)
@@ -84,13 +235,24 @@ def _api_remote(path, method="GET"):
return decorator
def _sync_api_remote(path, method="GET"):
def _sync_api_remote(
path: str, method: str = "GET", max_wait_health_timeout_secs: int = 2
):
def decorator(func):
def wrapper(self, *args, **kwargs):
import requests
if not isinstance(self, APIMixin):
raise TypeError(
"The class must inherit from APIMixin to use the @_sync_api_remote "
"decorator."
)
base_url = self.sync_select_url(
max_wait_health_timeout_secs=max_wait_health_timeout_secs
)
return_type, actual_dataclass, request_params = _build_request(
self, func, path, method, *args, **kwargs
self, base_url, func, path, method, *args, **kwargs
)
response = requests.request(**request_params)

View File

@@ -0,0 +1,105 @@
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
import pytest
from ..api_utils import APIMixin
# Mock requests.get
@pytest.fixture
def mock_requests_get():
with patch("requests.get") as mock_get:
yield mock_get
@pytest.fixture
def apimixin():
urls = "http://example.com,http://example2.com"
health_check_path = "/health"
apimixin = APIMixin(urls, health_check_path)
yield apimixin
# Ensure the executor is properly shut down after tests
apimixin._heartbeat_executor.shutdown(wait=False)
def test_apimixin_initialization(apimixin):
"""Test APIMixin initialization with various parameters."""
assert apimixin._remote_urls == ["http://example.com", "http://example2.com"]
assert apimixin._health_check_path == "/health"
assert apimixin._health_check_interval_secs == 5
assert apimixin._health_check_timeout_secs == 30
assert apimixin._choice_type == "latest_first"
assert isinstance(apimixin._heartbeat_executor, ThreadPoolExecutor)
def test_health_check(apimixin, mock_requests_get):
"""Test the _check_health method."""
url = "http://example.com"
# Mocking a successful response
mock_response = MagicMock()
mock_response.status_code = 200
mock_requests_get.return_value = mock_response
is_healthy, checked_url = apimixin._check_health(url)
assert is_healthy
assert checked_url == url
# Mocking a failed response
mock_requests_get.side_effect = Exception("Connection error")
is_healthy, checked_url = apimixin._check_health(url)
assert not is_healthy
assert checked_url == url
def test_check_and_update_health(apimixin, mock_requests_get):
"""Test the _check_and_update_health method."""
apimixin._heartbeat_map = {
"http://example.com": datetime.now() - timedelta(seconds=3),
"http://example2.com": datetime.now() - timedelta(seconds=10),
}
# Mocking responses
def side_effect(url, timeout):
mock_response = MagicMock()
if url == "http://example.com/health":
mock_response.status_code = 200
elif url == "http://example2.com/health":
mock_response.status_code = 500
return mock_response
mock_requests_get.side_effect = side_effect
health_urls = apimixin._check_and_update_health()
assert "http://example.com" in health_urls
assert "http://example2.com" not in health_urls
@pytest.mark.asyncio
async def test_select_url(apimixin, mock_requests_get):
"""Test the async select_url method."""
apimixin._health_urls = ["http://example.com"]
selected_url = await apimixin.select_url()
assert selected_url == "http://example.com"
# Test with no healthy URLs
apimixin._health_urls = []
selected_url = await apimixin.select_url(max_wait_health_timeout_secs=1)
assert selected_url in ["http://example.com", "http://example2.com"]
def test_sync_select_url(apimixin, mock_requests_get):
"""Test the synchronous sync_select_url method."""
apimixin._health_urls = ["http://example.com"]
selected_url = apimixin.sync_select_url()
assert selected_url == "http://example.com"
# Test with no healthy URLs
apimixin._health_urls = []
selected_url = apimixin.sync_select_url(max_wait_health_timeout_secs=1)
assert selected_url in ["http://example.com", "http://example2.com"]

View File

@@ -172,7 +172,7 @@ def setup_http_service_logging(exclude_paths: Optional[List[str]] = None):
"""
if not exclude_paths:
# Not show heartbeat log
exclude_paths = ["/api/controller/heartbeat"]
exclude_paths = ["/api/controller/heartbeat", "/api/health"]
uvicorn_logger = logging.getLogger("uvicorn.access")
if uvicorn_logger:
for path in exclude_paths: