import asyncio
from typing import Any, Callable, List, Optional

from dbgpt.model.base import ModelInstance, WorkerApplyOutput, WorkerSupportedModel
from dbgpt.model.cluster.base import *
from dbgpt.model.cluster.registry import ModelRegistry
from dbgpt.model.cluster.worker.manager import (
    LocalWorkerManager,
    WorkerRunData,
    logger,
)
from dbgpt.model.cluster.worker.remote_worker import RemoteModelWorker


class RemoteWorkerManager(LocalWorkerManager):
    """Worker manager that forwards requests to remote model workers over HTTP."""

    def __init__(self, model_registry: Optional[ModelRegistry] = None) -> None:
        super().__init__(model_registry=model_registry)

    async def start(self):
        for listener in self.start_listeners:
            if asyncio.iscoroutinefunction(listener):
                await listener(self)
            else:
                listener(self)

    async def stop(self, ignore_exception: bool = False):
        # Nothing to stop locally: remote workers are managed by their own
        # worker manager processes.
        pass

    async def _fetch_from_worker(
        self,
        worker_run_data: WorkerRunData,
        endpoint: str,
        method: str = "GET",
        json: Optional[dict] = None,
        params: Optional[dict] = None,
        additional_headers: Optional[dict] = None,
        success_handler: Optional[Callable] = None,
        error_handler: Optional[Callable] = None,
    ) -> Any:
        """Send an HTTP request to a remote worker and return the parsed result.

        On a non-200 response, ``error_handler`` is called if provided, otherwise
        an exception is raised. On success, ``success_handler`` is called if
        provided, otherwise the JSON body is returned.
        """
        # Lazy import to avoid the high import cost of httpx at startup
        import httpx

        url = worker_run_data.worker.worker_addr + endpoint
        headers = {**worker_run_data.worker.headers, **(additional_headers or {})}
        timeout = worker_run_data.worker.timeout

        async with httpx.AsyncClient() as client:
            request = client.build_request(
                method,
                url,
                json=json,  # send the payload as application/json
                params=params,
                headers=headers,
                timeout=timeout,
            )

            response = await client.send(request)
            if response.status_code != 200:
                if error_handler:
                    return error_handler(response)
                else:
                    error_msg = f"Request to {url} failed, error: {response.text}"
                    raise Exception(error_msg)
            if success_handler:
                return success_handler(response)
            return response.json()

    async def _apply_to_worker_manager_instances(self):
        pass

    async def supported_models(self) -> List[WorkerSupportedModel]:
        worker_instances = await self.get_model_instances(
            WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME
        )

        async def get_supported_models(worker_run_data) -> List[WorkerSupportedModel]:
            def handler(response):
                return [WorkerSupportedModel.from_dict(m) for m in response.json()]

            return await self._fetch_from_worker(
                worker_run_data, "/models/supports", success_handler=handler
            )

        models = []
        results = await asyncio.gather(
            *(get_supported_models(worker) for worker in worker_instances)
        )
        for res in results:
            models += res
        return models

    async def _get_worker_service_instance(
        self, host: Optional[str] = None, port: Optional[int] = None
    ) -> List[WorkerRunData]:
        worker_instances = await self.get_model_instances(
            WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME
        )
        error_msg = "Could not find worker instances"
        if host and port:
            worker_instances = [
                ins for ins in worker_instances if ins.host == host and ins.port == port
            ]
            error_msg = f"Could not find worker instances for host {host} port {port}"
        if not worker_instances:
            raise Exception(error_msg)
        return worker_instances

    async def model_startup(self, startup_req: WorkerStartupRequest):
        worker_instances = await self._get_worker_service_instance(
            startup_req.host, startup_req.port
        )
        worker_run_data = worker_instances[0]
        logger.info(f"Starting remote model, startup_req: {startup_req}")
        return await self._fetch_from_worker(
            worker_run_data,
            "/models/startup",
            method="POST",
            json=startup_req.dict(),
            success_handler=lambda x: None,
        )

    async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
        worker_instances = await self._get_worker_service_instance(
            shutdown_req.host, shutdown_req.port
        )
        worker_run_data = worker_instances[0]
        logger.info(f"Shutting down remote model, shutdown_req: {shutdown_req}")
        return await self._fetch_from_worker(
            worker_run_data,
            "/models/shutdown",
            method="POST",
            json=shutdown_req.dict(),
            success_handler=lambda x: None,
        )

    def _build_worker_instances(
        self, model_name: str, instances: List[ModelInstance]
    ) -> List[WorkerRunData]:
        return [
            self._build_single_worker_instance(model_name, instance)
            for instance in instances
        ]

    def _build_single_worker_instance(self, model_name: str, instance: ModelInstance):
        worker = RemoteModelWorker()
        worker.load_worker(
            model_name, model_name, host=instance.host, port=instance.port
        )
        return WorkerRunData(
            host=instance.host,
            port=instance.port,
            worker_key=instance.model_name,
            worker=worker,
            worker_params=None,
            model_params=None,
            stop_event=asyncio.Event(),
            semaphore=asyncio.Semaphore(100),  # Effectively unlimited on the client side
        )

    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        worker_key = self._worker_key(worker_type, model_name)
        instances: List[ModelInstance] = await self.model_registry.get_all_instances(
            worker_key, healthy_only
        )
        return self._build_worker_instances(model_name, instances)

    async def get_all_model_instances(
        self, worker_type: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        instances: List[
            ModelInstance
        ] = await self.model_registry.get_all_model_instances(healthy_only=healthy_only)
        result = []
        for instance in instances:
            name, wt = WorkerType.parse_worker_key(instance.model_name)
            if wt != worker_type:
                continue
            result.append(self._build_single_worker_instance(name, instance))
        return result

    def sync_get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        worker_key = self._worker_key(worker_type, model_name)
        instances: List[ModelInstance] = self.model_registry.sync_get_all_instances(
            worker_key, healthy_only
        )
        return self._build_worker_instances(model_name, instances)

    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        async def _remote_apply_func(worker_run_data: WorkerRunData):
            return await self._fetch_from_worker(
                worker_run_data,
                "/apply",
                method="POST",
                json=apply_req.dict(),
                success_handler=lambda res: WorkerApplyOutput(**res.json()),
                error_handler=lambda res: WorkerApplyOutput(
                    message=res.text, success=False
                ),
            )

        results = await self._apply_worker(apply_req, _remote_apply_func)
        if results:
            return results[0]
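

# Minimal, illustrative sketch (not part of the upstream module) of how this
# manager is typically wired up: attach a ``ModelRegistry`` implementation,
# start the manager, then fan out to the registered remote worker managers.
# ``EmbeddedModelRegistry`` is assumed to be the in-process registry shipped in
# ``dbgpt.model.cluster.registry``; substitute your own registry if your
# deployment uses a different one.
if __name__ == "__main__":
    from dbgpt.model.cluster.registry import EmbeddedModelRegistry  # assumed import

    async def _example():
        registry = EmbeddedModelRegistry()
        manager = RemoteWorkerManager(model_registry=registry)
        await manager.start()
        # Lists the models reported by every registered remote worker manager.
        print(await manager.supported_models())

    asyncio.run(_example())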