DB-GPT/dbgpt/model/cluster/worker/remote_manager.py

207 lines
7.6 KiB
Python

import asyncio
from typing import Any, Callable
from dbgpt.model.base import ModelInstance, WorkerApplyOutput, WorkerSupportedModel
from dbgpt.model.cluster.base import *
from dbgpt.model.cluster.registry import ModelRegistry
from dbgpt.model.cluster.worker.manager import LocalWorkerManager, WorkerRunData, logger
from dbgpt.model.cluster.worker.remote_worker import RemoteModelWorker
class RemoteWorkerManager(LocalWorkerManager):
    """Worker manager that forwards operations to remote worker managers.

    Worker instances are resolved through the model registry, wrapped in
    ``WorkerRunData`` objects backed by ``RemoteModelWorker``, and management
    calls (startup, shutdown, apply, supported models) are sent with ``httpx``
    to the address exposed by the resolved worker.
    """

    def __init__(self, model_registry: ModelRegistry = None) -> None:
        """Initialise the manager on top of ``model_registry``."""
        super().__init__(model_registry=model_registry)

    async def start(self):
        """Invoke every registered start listener with this manager.

        Listeners that are coroutine functions are awaited; any other callable
        is invoked synchronously.
        """
        for listener in self.start_listeners:
            if asyncio.iscoroutinefunction(listener):
                await listener(self)
            else:
                listener(self)

    async def stop(self, ignore_exception: bool = False):
        """No-op: this implementation performs no work on stop."""

    async def _fetch_from_worker(
        self,
        worker_run_data: WorkerRunData,
        endpoint: str,
        method: str = "GET",
        json: dict = None,
        params: dict = None,
        additional_headers: dict = None,
        success_handler: Callable = None,
        error_handler: Callable = None,
    ) -> Any:
        """Send one HTTP request to the worker behind ``worker_run_data``.

        Args:
            worker_run_data: Supplies the target worker; its ``worker_addr``,
                ``headers`` and ``timeout`` attributes are used.
            endpoint: Path appended verbatim to the worker address.
            method: HTTP method name.
            json: Body sent as JSON, if any.
            params: Query parameters, if any.
            additional_headers: Extra headers; they override the worker's
                headers on key collisions.
            success_handler: Called with the response when the status code is
                200; its result is returned.
            error_handler: Called with the response when the status code is
                not 200; its result is returned instead of raising.

        Returns:
            The handler result, or the decoded JSON body when the request
            succeeded and no ``success_handler`` was given.

        Raises:
            Exception: If the status code is not 200 and no ``error_handler``
                was given.
        """
        # Lazy import to avoid high time cost
        import httpx

        worker = worker_run_data.worker
        url = worker.worker_addr + endpoint
        headers = {**worker.headers, **(additional_headers or {})}
        timeout = worker.timeout

        async with httpx.AsyncClient() as client:
            request = client.build_request(
                method,
                url,
                json=json,  # using json for data to ensure it sends as application/json
                params=params,
                headers=headers,
                timeout=timeout,
            )
            response = await client.send(request)
            if response.status_code != 200:
                if error_handler:
                    return error_handler(response)
                error_msg = f"Request to {url} failed, error: {response.text}"
                raise Exception(error_msg)
            if success_handler:
                return success_handler(response)
            return response.json()

    async def _apply_to_worker_manager_instances(self):
        """No-op in this implementation."""

    async def supported_models(self) -> List[WorkerSupportedModel]:
        """Collect the supported models reported by all worker managers.

        Every registered worker manager instance is queried concurrently at
        ``/models/supports`` and the per-instance lists are concatenated in
        instance order.
        """
        worker_instances = await self.get_model_instances(
            WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME
        )

        async def get_supported_models(worker_run_data) -> List[WorkerSupportedModel]:
            def handler(response):
                return [WorkerSupportedModel.from_dict(m) for m in response.json()]

            return await self._fetch_from_worker(
                worker_run_data, "/models/supports", success_handler=handler
            )

        results = await asyncio.gather(
            *(get_supported_models(worker) for worker in worker_instances)
        )
        models = []
        for res in results:
            models.extend(res)
        return models

    async def _get_worker_service_instance(
        self, host: str = None, port: int = None
    ) -> List[WorkerRunData]:
        """Return the worker manager instances, optionally narrowed by address.

        The ``host``/``port`` filter is applied only when both values are
        truthy; otherwise every registered worker manager instance is kept.

        Raises:
            Exception: If no instance remains after the optional filtering.
        """
        worker_instances = await self.get_model_instances(
            WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME
        )
        error_msg = "Could not find worker instances"
        if host and port:
            worker_instances = [
                ins for ins in worker_instances if ins.host == host and ins.port == port
            ]
            error_msg = f"Could not find worker instances for host {host} port {port}"
        if not worker_instances:
            raise Exception(error_msg)
        return worker_instances

    async def model_startup(self, startup_req: WorkerStartupRequest):
        """Ask the worker manager at the request's host/port to start a model.

        The request is POSTed to ``/models/startup`` on the first matching
        instance. Returns ``None`` on success.
        """
        worker_instances = await self._get_worker_service_instance(
            startup_req.host, startup_req.port
        )
        worker_run_data = worker_instances[0]
        logger.info(f"Start model remote, startup_req: {startup_req}")
        return await self._fetch_from_worker(
            worker_run_data,
            "/models/startup",
            method="POST",
            json=startup_req.dict(),
            success_handler=lambda x: None,
        )

    async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
        """Ask the worker manager at the request's host/port to stop a model.

        The request is POSTed to ``/models/shutdown`` on the first matching
        instance. Returns ``None`` on success.
        """
        worker_instances = await self._get_worker_service_instance(
            shutdown_req.host, shutdown_req.port
        )
        worker_run_data = worker_instances[0]
        logger.info(f"Shutdown model remote, shutdown_req: {shutdown_req}")
        return await self._fetch_from_worker(
            worker_run_data,
            "/models/shutdown",
            method="POST",
            json=shutdown_req.dict(),
            success_handler=lambda x: None,
        )

    def _build_worker_instances(
        self, model_name: str, instances: List[ModelInstance]
    ) -> List[WorkerRunData]:
        """Wrap each registry instance in a ``WorkerRunData`` for ``model_name``."""
        return [
            self._build_single_worker_instance(model_name, instance)
            for instance in instances
        ]

    def _build_single_worker_instance(self, model_name: str, instance: ModelInstance):
        """Build the ``WorkerRunData`` for one registry instance.

        A fresh ``RemoteModelWorker`` is loaded with the instance's host and
        port; ``model_name`` is passed for both leading ``load_worker``
        arguments.
        """
        worker = RemoteModelWorker()
        worker.load_worker(
            model_name, model_name, host=instance.host, port=instance.port
        )
        return WorkerRunData(
            host=instance.host,
            port=instance.port,
            worker_key=instance.model_name,
            worker=worker,
            worker_params=None,
            model_params=None,
            stop_event=asyncio.Event(),
            # Not intended as a real limit on the client side.
            semaphore=asyncio.Semaphore(100),
        )

    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Look up the registry instances for a worker type and model name.

        The registry key is produced by ``self._worker_key`` and the matching
        instances are returned wrapped as ``WorkerRunData``.
        """
        worker_key = self._worker_key(worker_type, model_name)
        instances: List[ModelInstance] = await self.model_registry.get_all_instances(
            worker_key, healthy_only
        )
        return self._build_worker_instances(model_name, instances)

    async def get_all_model_instances(
        self, worker_type: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Return every registry instance whose worker type is ``worker_type``.

        Each instance's ``model_name`` is split with
        ``WorkerType.parse_worker_key`` into a name and a worker type;
        instances of any other worker type are skipped.
        """
        instances: List[
            ModelInstance
        ] = await self.model_registry.get_all_model_instances(healthy_only=healthy_only)
        result = []
        for instance in instances:
            name, wt = WorkerType.parse_worker_key(instance.model_name)
            if wt == worker_type:
                result.append(self._build_single_worker_instance(name, instance))
        return result

    def sync_get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Synchronous counterpart of ``get_model_instances``.

        Uses the registry's ``sync_get_all_instances`` instead of awaiting.
        """
        worker_key = self._worker_key(worker_type, model_name)
        instances: List[ModelInstance] = self.model_registry.sync_get_all_instances(
            worker_key, healthy_only
        )
        return self._build_worker_instances(model_name, instances)

    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        """Forward an apply request to the remote worker(s).

        Each targeted worker receives a POST to ``/apply``. A non-200 response
        is converted into a ``WorkerApplyOutput`` with ``success=False`` and
        the response text as message rather than raising.

        Returns:
            The first result produced by ``self._apply_worker``, or ``None``
            when it yields no results.
        """

        async def _remote_apply_func(worker_run_data: WorkerRunData):
            return await self._fetch_from_worker(
                worker_run_data,
                "/apply",
                method="POST",
                json=apply_req.dict(),
                success_handler=lambda res: WorkerApplyOutput(**res.json()),
                error_handler=lambda res: WorkerApplyOutput(
                    message=res.text, success=False
                ),
            )

        results = await self._apply_worker(apply_req, _remote_apply_func)
        if results:
            return results[0]
        # Explicit to make the "no results" outcome visible to readers.
        return None