mirror of
				https://github.com/csunny/DB-GPT.git
				synced 2025-10-22 17:39:02 +00:00 
			
		
		
		
	Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com> Co-authored-by: csunny <cfqsunny@163.com>
		
			
				
	
	
		
			289 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			289 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import asyncio
 | |
| import logging
 | |
| import threading
 | |
| import time
 | |
| from abc import ABC
 | |
| from concurrent.futures import Executor, ThreadPoolExecutor
 | |
| from dataclasses import asdict, is_dataclass
 | |
| from datetime import datetime, timedelta
 | |
| from inspect import signature
 | |
| from typing import List, Literal, Optional, Tuple, Type, TypeVar, Union, get_type_hints
 | |
| 
 | |
# Generic type variable; used in the annotations of the type-hint helper
# defined later in this module.
T = TypeVar("T")

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
class APIMixin(ABC):
    """Mixin that keeps a list of remote base URLs and selects one per request.

    A background heartbeat thread (optional) periodically requests
    ``url + health_check_path`` for every configured URL and maintains the
    list of URLs that answered with HTTP 200 recently. ``select_url`` and
    ``sync_select_url`` pick from that list, falling back to the full list of
    configured URLs when no healthy URL is known.
    """

    def __init__(
        self,
        urls: Union[str, List[str]],
        health_check_path: str,
        health_check_interval_secs: int = 5,
        health_check_timeout_secs: int = 30,
        check_health: bool = True,
        choice_type: Literal["latest_first", "random"] = "latest_first",
        executor: Optional[Executor] = None,
    ):
        """Store the configuration and optionally start the heartbeat thread.

        Args:
            urls: Base URLs, either as a list or as one comma-separated string.
                Surrounding whitespace of every entry is stripped.
            health_check_path: Path appended to each base URL for the probe.
            health_check_interval_secs: Pause between two health check rounds;
                also the maximum age of a heartbeat for a URL to count as
                healthy.
            health_check_timeout_secs: Stored on the instance.
                NOTE(review): nothing in this class reads it (the probe uses a
                fixed 10 second request timeout) - confirm the intended use.
            check_health: Start the daemon heartbeat thread when True.
            choice_type: ``"latest_first"`` picks the first URL of the candidate
                list, ``"random"`` picks a random one.
            executor: Executor used to run the probes; a 3-worker thread pool
                is created when omitted.
        """
        if isinstance(urls, str):
            # Split by ","
            urls = urls.split(",")
        self._remote_urls = [url.strip() for url in urls]
        self._health_check_path = health_check_path
        self._health_urls = []
        self._health_check_interval_secs = health_check_interval_secs
        self._health_check_timeout_secs = health_check_timeout_secs
        self._heartbeat_map = {}
        self._choice_type = choice_type
        self._heartbeat_stop_event = threading.Event()
        self._heartbeat_executor = executor or ThreadPoolExecutor(max_workers=3)
        self._heartbeat_thread = threading.Thread(target=self._heartbeat_checker)

        if check_health:
            self._heartbeat_thread.daemon = True
            self._heartbeat_thread.start()

    def _heartbeat_checker(self):
        """Run health check rounds until the stop event is set."""
        logger.debug("Running health check")
        while not self._heartbeat_stop_event.is_set():
            try:
                healthy_urls = self._check_and_update_health()
                logger.debug(f"Healthy urls: {healthy_urls}")
            except Exception as e:
                # Best effort: a failed round must not kill the heartbeat loop.
                logger.warning(f"Health check failed, error: {e}")
            # Wait on the event instead of sleeping so that setting the stop
            # event ends the loop immediately rather than after a full interval.
            self._heartbeat_stop_event.wait(self._health_check_interval_secs)

    def __del__(self):
        # The event may be missing when __init__ raised early or the instance
        # was created without __init__; a finalizer must not raise in that case.
        # Note that the heartbeat thread's target is a bound method of this
        # instance, so the thread itself references the instance while running.
        stop_event = getattr(self, "_heartbeat_stop_event", None)
        if stop_event is not None:
            stop_event.set()

    def _check_health(self, url: str) -> Tuple[bool, str]:
        """Probe one base URL.

        Returns:
            ``(True, url)`` when the health endpoint answered with status 200,
            ``(False, url)`` on any other status or on any exception.
        """
        try:
            import requests

            logger.debug(f"Checking health for {url}")
            req_url = url + self._health_check_path
            response = requests.get(req_url, timeout=10)
            return response.status_code == 200, url
        except Exception as e:
            logger.warning(f"Health check failed for {url}, error: {e}")
            return False, url

    def _check_and_update_health(self) -> List[str]:
        """Check health of all remote urls and update the health urls list.

        Returns:
            The new list of healthy URLs, most recent heartbeat first.
        """
        # Submit every probe first so they run concurrently, then collect.
        check_tasks = [
            self._heartbeat_executor.submit(self._check_health, url)
            for url in self._remote_urls
        ]
        check_results = [task.result() for task in check_tasks]

        now = datetime.now()
        for is_healthy, url in check_results:
            if is_healthy:
                self._heartbeat_map[url] = now

        # A URL counts as healthy while its last successful heartbeat is
        # younger than one check interval.
        max_age = timedelta(seconds=self._health_check_interval_secs)
        healthy_urls = [
            (url, last_heartbeat)
            for url, last_heartbeat in self._heartbeat_map.items()
            if now - last_heartbeat < max_age
        ]
        # Sort by last heartbeat time, latest first
        healthy_urls.sort(key=lambda x: x[1], reverse=True)

        # Rebind (never mutate in place) so readers holding the previous list
        # keep a consistent snapshot.
        self._health_urls = [url for url, _ in healthy_urls]
        return self._health_urls

    def _pick_url(self, urls: List[str]) -> str:
        """Pick one URL from ``urls`` according to the configured choice type.

        Raises:
            ValueError: If the configured choice type is not supported.
        """
        import random

        if self._choice_type == "latest_first":
            return urls[0]
        if self._choice_type == "random":
            return random.choice(urls)
        raise ValueError(f"Invalid choice type: {self._choice_type}")

    async def select_url(self, max_wait_health_timeout_secs: int = 2) -> str:
        """Select a healthy url to send request.

        Waits up to ``max_wait_health_timeout_secs`` seconds (polling every
        0.1 seconds) for a healthy URL to appear. If no healthy urls found,
        select from all configured urls instead.
        """
        # Take a snapshot: the heartbeat thread may rebind the attribute to an
        # empty list between the truthiness check and the selection.
        healthy = self._health_urls
        if healthy:
            return self._pick_url(healthy)
        if max_wait_health_timeout_secs > 0:
            deadline = time.monotonic() + max_wait_health_timeout_secs
            while time.monotonic() < deadline:
                healthy = self._health_urls
                if healthy:
                    return self._pick_url(healthy)
                await asyncio.sleep(0.1)
        logger.warning("No healthy urls found, selecting randomly")
        return self._pick_url(self._remote_urls)

    def sync_select_url(self, max_wait_health_timeout_secs: int = 2) -> str:
        """Synchronous version of select_url."""
        # Snapshot for the same reason as in select_url.
        healthy = self._health_urls
        if healthy:
            return self._pick_url(healthy)
        if max_wait_health_timeout_secs > 0:
            deadline = time.monotonic() + max_wait_health_timeout_secs
            while time.monotonic() < deadline:
                healthy = self._health_urls
                if healthy:
                    return self._pick_url(healthy)
                time.sleep(0.1)
        logger.warning("No healthy urls found, selecting randomly")
        return self._pick_url(self._remote_urls)
 | |
| 
 | |
| 
 | |
| def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
 | |
|     import typing_inspect
 | |
| 
 | |
|     """Extract actual dataclass from generic type hints like List[dataclass], Optional[dataclass], etc."""
 | |
|     if typing_inspect.is_generic_type(type_hint) and typing_inspect.get_args(type_hint):
 | |
|         return typing_inspect.get_args(type_hint)[0]
 | |
|     return None
 | |
| 
 | |
| 
 | |
def _build_request(self, base_url, func, path, method, *args, **kwargs):
    """Assemble the keyword arguments for an HTTP request from a decorated call.

    Args:
        self: The instance the decorated function was called on.
        base_url: Base URL the formatted path is appended to.
        func: The decorated function; its signature and return annotation
            drive the request construction.
        path: Path template, formatted with the bound call arguments.
        method: HTTP method name; "POST", "PUT" and "PATCH" send the data as
            JSON body, every other value sends it as query parameters.
        *args: Positional arguments of the call (without ``self``).
        **kwargs: Keyword arguments of the call.

    Returns:
        A tuple ``(return_type, actual_dataclass, request_params)``.

    Raises:
        TypeError: If ``func`` has no return annotation.
    """
    return_type = get_type_hints(func).get("return")
    if return_type is None:
        raise TypeError("Return type must be annotated in the decorated function.")

    # Fall back to the annotation itself when it carries no type argument.
    actual_dataclass = _extract_dataclass_from_generic(return_type)
    logger.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
    if not actual_dataclass:
        actual_dataclass = return_type

    func_signature = signature(func)
    bound_call = func_signature.bind(self, *args, **kwargs)
    bound_call.apply_defaults()
    formatted_url = base_url + path.format(**bound_call.arguments)

    # Parameter names without the leading "self", paired with the positional
    # values; keyword arguments are merged in afterwards.
    param_names = list(func_signature.parameters)[1:]
    call_values = {**dict(zip(param_names, args)), **kwargs}

    payload = {}
    for name, value in call_values.items():
        if is_dataclass(value):
            # A dataclass argument is not nested under its name: its dict form
            # replaces the payload collected so far.
            payload = asdict(value)
        else:
            payload[name] = value

    # Body-carrying methods send JSON; GET, DELETE, etc. send query parameters.
    data_key = "json" if method in ("POST", "PUT", "PATCH") else "params"
    request_params = {"method": method, "url": formatted_url, data_key: payload}

    logger.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
    return return_type, actual_dataclass, request_params
 | |
| 
 | |
| 
 | |
def _api_remote(path: str, method: str = "GET", max_wait_health_timeout_secs: int = 2):
    """Decorator factory turning a method into an async remote HTTP call.

    The decorated function's body is not executed; its signature and return
    annotation are used to build the request and to parse the response.

    Args:
        path: Path template appended to the selected base URL.
        method: HTTP method used for the request.
        max_wait_health_timeout_secs: Passed to ``select_url`` as the maximum
            time to wait for a healthy URL.

    Returns:
        A decorator producing an async wrapper. The wrapper raises TypeError
        when the instance is not an ``APIMixin`` and Exception when the remote
        answers with a status code other than 200.
    """
    import functools

    def decorator(func):
        # Preserve the decorated function's name, docstring and metadata.
        @functools.wraps(func)
        async def wrapper(self, *args, **kwargs):
            import httpx

            if not isinstance(self, APIMixin):
                raise TypeError(
                    "The class must inherit from APIMixin to use the @_api_remote "
                    "decorator."
                )
            # Found a healthy url to send request
            base_url = await self.select_url(
                max_wait_health_timeout_secs=max_wait_health_timeout_secs
            )
            return_type, actual_dataclass, request_params = _build_request(
                self, base_url, func, path, method, *args, **kwargs
            )
            async with httpx.AsyncClient() as client:
                response = await client.request(**request_params)
                if response.status_code == 200:
                    return _parse_response(
                        response.json(), return_type, actual_dataclass
                    )
                else:
                    error_msg = f"Remote request error, error code: {response.status_code}, error msg: {response.text}"
                    raise Exception(error_msg)

        return wrapper

    return decorator
 | |
| 
 | |
| 
 | |
def _sync_api_remote(
    path: str, method: str = "GET", max_wait_health_timeout_secs: int = 2
):
    """Decorator factory turning a method into a blocking remote HTTP call.

    The decorated function's body is not executed; its signature and return
    annotation are used to build the request and to parse the response.

    Args:
        path: Path template appended to the selected base URL.
        method: HTTP method used for the request.
        max_wait_health_timeout_secs: Passed to ``sync_select_url`` as the
            maximum time to wait for a healthy URL.

    Returns:
        A decorator producing a synchronous wrapper. The wrapper raises
        TypeError when the instance is not an ``APIMixin`` and Exception when
        the remote answers with a status code other than 200.
    """
    import functools

    def decorator(func):
        # Preserve the decorated function's name, docstring and metadata.
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            import requests

            if not isinstance(self, APIMixin):
                raise TypeError(
                    "The class must inherit from APIMixin to use the @_sync_api_remote "
                    "decorator."
                )
            base_url = self.sync_select_url(
                max_wait_health_timeout_secs=max_wait_health_timeout_secs
            )

            return_type, actual_dataclass, request_params = _build_request(
                self, base_url, func, path, method, *args, **kwargs
            )

            # NOTE(review): no request timeout is passed here - confirm whether
            # an unbounded wait on the remote is intended.
            response = requests.request(**request_params)

            if response.status_code == 200:
                return _parse_response(response.json(), return_type, actual_dataclass)
            else:
                error_msg = f"Remote request error, error code: {response.status_code}, error msg: {response.text}"
                raise Exception(error_msg)

        return wrapper

    return decorator
 | |
| 
 | |
| 
 | |
def _parse_response(json_response, return_type, actual_dataclass):
    """Convert a decoded JSON response into the annotated return value.

    Args:
        json_response: The decoded JSON payload of the response.
        return_type: The return annotation of the decorated function.
        actual_dataclass: The dataclass to instantiate, or any other type when
            no conversion is wanted.

    Returns:
        A list of ``actual_dataclass`` instances when ``return_type`` is a
        list generic, a single instance for any other dataclass return type,
        or ``json_response`` unchanged when ``actual_dataclass`` is not a
        dataclass.

    Raises:
        TypeError: If the payload is not a list / dict where one is expected.
    """
    if not is_dataclass(actual_dataclass):
        return json_response

    # A plain (non-generic) dataclass annotation has no ``__origin__``, so the
    # attribute must be looked up defensively.
    if getattr(return_type, "__origin__", None) is list:  # for List[dataclass]
        if not isinstance(json_response, list):
            raise TypeError(
                f"Expected list in response but got {type(json_response)}"
            )
        return [actual_dataclass(**item) for item in json_response]

    if not isinstance(json_response, dict):
        raise TypeError(
            f"Expected dictionary in response but got {type(json_response)}"
        )
    return actual_dataclass(**json_response)
 |