DB-GPT/dbgpt/model/cluster/worker/manager.py
2024-11-26 19:47:28 +08:00

1195 lines
45 KiB
Python

import asyncio
import itertools
import json
import logging
import os
import random
import sys
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import AsyncIterator, Awaitable, Callable, Iterator
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from dbgpt.component import SystemApp
from dbgpt.configs.model_config import LOGDIR
from dbgpt.core import ModelMetadata, ModelOutput
from dbgpt.model.base import ModelInstance, WorkerApplyOutput, WorkerSupportedModel
from dbgpt.model.cluster.base import *
from dbgpt.model.cluster.manager_base import (
WorkerManager,
WorkerManagerFactory,
WorkerRunData,
)
from dbgpt.model.cluster.registry import ModelRegistry
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.parameter import ModelWorkerParameters, WorkerType
from dbgpt.model.utils.llm_utils import list_supported_models
from dbgpt.util.fastapi import create_app, register_event_handler
from dbgpt.util.parameter_utils import (
EnvArgumentParser,
ParameterDescription,
_dict_to_command_args,
_get_dict_from_obj,
)
from dbgpt.util.system_utils import get_system_info
from dbgpt.util.tracer import SpanType, SpanTypeRunName, initialize_tracer, root_tracer
from dbgpt.util.utils import setup_http_service_logging, setup_logging
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Async callback aliases: each receives the WorkerRunData of the instance being
# acted on. The first three are the hooks accepted by LocalWorkerManager.
RegisterFunc = Callable[[WorkerRunData], Awaitable[None]]
DeregisterFunc = Callable[[WorkerRunData], Awaitable[None]]
SendHeartbeatFunc = Callable[[WorkerRunData], Awaitable[None]]
# Per-instance coroutine function handed to LocalWorkerManager._apply_worker.
# NOTE(review): the start/stop functions passed there return a WorkerApplyOutput,
# so the Awaitable[None] result type looks narrower than actual usage - confirm.
ApplyFunction = Callable[[WorkerRunData], Awaitable[None]]
async def _async_heartbeat_sender(
    worker_run_data: WorkerRunData,
    heartbeat_interval,
    send_heartbeat_func: SendHeartbeatFunc,
):
    """Send heartbeats for one instance until its stop event is set.

    Args:
        worker_run_data: Run data of the instance; its ``stop_event`` ends the loop.
        heartbeat_interval: Seconds to sleep after every attempt (passed to
            ``asyncio.sleep``).
        send_heartbeat_func: Async callback invoked with ``worker_run_data``.
    """
    while not worker_run_data.stop_event.is_set():
        try:
            await send_heartbeat_func(worker_run_data)
        except Exception as e:
            # A failed heartbeat must not kill the loop; log it and try again.
            # ``Logger.warn`` is deprecated, use ``Logger.warning``.
            logger.warning(f"Send heartbeat func error: {str(e)}")
        finally:
            # Sleep after both success and failure so a failing callback
            # cannot turn this into a busy loop.
            await asyncio.sleep(heartbeat_interval)
class LocalWorkerManager(WorkerManager):
    """Worker manager that hosts model workers inside the current process.

    Workers are stored in ``self.workers``, keyed by the worker key built from
    the worker type and the model name. Optional async callbacks register,
    deregister and heartbeat each instance (and the manager itself, through
    ``self.run_data``).
    """

    def __init__(
        self,
        register_func: RegisterFunc = None,
        deregister_func: DeregisterFunc = None,
        send_heartbeat_func: SendHeartbeatFunc = None,
        model_registry: ModelRegistry = None,
        host: str = None,
        port: int = None,
    ) -> None:
        """Create an empty manager.

        Args:
            register_func: Async callback used to register an instance.
            deregister_func: Async callback used to deregister an instance.
            send_heartbeat_func: Async callback used to send a heartbeat.
            model_registry: Stored on the instance; not used by this class.
            host: Host recorded in every ``WorkerRunData``.
            port: Port recorded in every ``WorkerRunData``.
        """
        self.workers: Dict[str, List[WorkerRunData]] = dict()
        # os.cpu_count() returns None when the count cannot be determined.
        self.executor = ThreadPoolExecutor(max_workers=(os.cpu_count() or 1) * 5)
        self.register_func = register_func
        self.deregister_func = deregister_func
        self.send_heartbeat_func = send_heartbeat_func
        self.model_registry = model_registry
        self.host = host
        self.port = port
        self.start_listeners = []
        # Strong references to fire-and-forget tasks: the event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected.
        self._background_tasks = set()
        # Run data describing the manager service itself (no worker attached).
        self.run_data = WorkerRunData(
            host=self.host,
            port=self.port,
            worker_key=self._worker_key(
                WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME
            ),
            worker=None,
            worker_params=None,
            model_params=None,
            stop_event=asyncio.Event(),
            semaphore=None,
            command_args=None,
        )

    def _spawn_background_task(self, coro):
        """Schedule ``coro`` as a task and keep it referenced until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def _worker_key(self, worker_type: str, model_name: str) -> str:
        """Build the key under which instances are stored in ``self.workers``."""
        return WorkerType.to_worker_key(model_name, worker_type)

    async def run_blocking_func(self, func, *args):
        """Run a blocking callable on the manager's thread pool.

        Raises:
            ValueError: If ``func`` is a coroutine function.
        """
        if asyncio.iscoroutinefunction(func):
            raise ValueError(f"The function {func} is not blocking function")
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self.executor, func, *args)

    async def start(self):
        """Start all added workers, register the manager and notify listeners.

        Raises:
            Exception: If starting the workers does not succeed.
        """
        if self.workers:
            out = await self._start_all_worker(apply_req=None)
            if not out.success:
                raise Exception(out.message)
        # Re-arm the manager heartbeat in case a previous stop() set the event.
        self.run_data.stop_event.clear()
        if self.register_func:
            await self.register_func(self.run_data)
        if self.send_heartbeat_func:
            self._spawn_background_task(
                _async_heartbeat_sender(self.run_data, 20, self.send_heartbeat_func)
            )
        for listener in self.start_listeners:
            if asyncio.iscoroutinefunction(listener):
                await listener(self)
            else:
                listener(self)

    async def stop(self, ignore_exception: bool = False):
        """Stop all workers and deregister the manager.

        Does nothing when the manager stop event is already set.

        Args:
            ignore_exception: When True, failures while stopping workers or
                deregistering are logged/ignored instead of raised.

        Raises:
            Exception: If stopping the workers fails and ``ignore_exception``
                is False.
        """
        if self.run_data.stop_event.is_set():
            return
        logger.info("Stop all workers")
        # Signal the manager heartbeat loop to finish. The original code called
        # clear() here, which left the event unset so the loop never stopped.
        self.run_data.stop_event.set()
        stop_tasks = [
            self._stop_all_worker(apply_req=None, ignore_exception=ignore_exception)
        ]
        if self.deregister_func:
            # If ignore_exception is True, swallow (and log) any exception
            # raised by self.deregister_func.
            if ignore_exception:

                async def safe_deregister_func(run_data):
                    try:
                        await self.deregister_func(run_data)
                    except Exception as e:
                        logger.warning(
                            f"Stop worker, ignored exception from deregister_func: {e}"
                        )

                stop_tasks.append(safe_deregister_func(self.run_data))
            else:
                stop_tasks.append(self.deregister_func(self.run_data))

        results = await asyncio.gather(*stop_tasks)
        # results[0] is the output of _stop_all_worker.
        if not results[0].success and not ignore_exception:
            raise Exception(results[0].message)

    def after_start(self, listener: Callable[["WorkerManager"], None]):
        """Register a callable (sync or async) invoked at the end of ``start``."""
        # A None listener would raise TypeError when start() calls it.
        if listener is not None:
            self.start_listeners.append(listener)

    def add_worker(
        self,
        worker: ModelWorker,
        worker_params: ModelWorkerParameters,
        command_args: List[str] = None,
    ) -> bool:
        """Load ``worker`` and store it under its worker key.

        Args:
            worker: Worker to load and store.
            worker_params: Worker parameters; ``worker_type`` is filled in from
                the worker when empty and normalised to its string value.
            command_args: Arguments used to parse the model parameters;
                defaults to ``sys.argv[1:]``.

        Returns:
            True when the worker was stored, False when an instance with the
            same key already exists.
        """
        if not command_args:
            command_args = sys.argv[1:]
        worker.load_worker(**asdict(worker_params))

        if not worker_params.worker_type:
            worker_params.worker_type = worker.worker_type()

        if isinstance(worker_params.worker_type, WorkerType):
            worker_params.worker_type = worker_params.worker_type.value

        worker_key = self._worker_key(
            worker_params.worker_type, worker_params.model_name
        )

        # Load model params from persist storage
        model_params = worker.parse_parameters(command_args=command_args)

        worker_run_data = WorkerRunData(
            host=self.host,
            port=self.port,
            worker_key=worker_key,
            worker=worker,
            worker_params=worker_params,
            model_params=model_params,
            stop_event=asyncio.Event(),
            semaphore=asyncio.Semaphore(worker_params.limit_model_concurrency),
            command_args=command_args,
        )
        if self.workers.get(worker_key):
            # TODO Update worker
            logger.warning(f"Instance {worker_key} exist")
            return False
        self.workers[worker_key] = [worker_run_data]
        logger.info(f"Init empty instances list for {worker_key}")
        return True

    def _remove_worker(self, worker_params: ModelWorkerParameters) -> None:
        """Drop the instances stored for ``worker_params`` (no-op if absent)."""
        worker_key = self._worker_key(
            worker_params.worker_type, worker_params.model_name
        )
        if self.workers.get(worker_key):
            del self.workers[worker_key]

    async def model_startup(self, startup_req: WorkerStartupRequest):
        """Build, add and start the model described by ``startup_req``.

        Raises:
            ValueError: If the requested worker type is not supported.
            Exception: If an instance already exists or the start fails.
        """
        model_name = startup_req.model
        worker_type = startup_req.worker_type
        params = startup_req.params
        logger.debug(
            f"start model, model name {model_name}, worker type {worker_type}, params: {params}"
        )
        # Validate before building/loading the worker so an unsupported type
        # does not pay for a load that is then thrown away.
        supported_types = WorkerType.values()
        if worker_type not in supported_types:
            raise ValueError(
                f"Unsupported worker type: {worker_type}, now supported worker type: {supported_types}"
            )
        worker_params: ModelWorkerParameters = ModelWorkerParameters.from_dict(
            params, ignore_extra_fields=True
        )
        if not worker_params.model_name:
            worker_params.model_name = model_name
        worker = _build_worker(worker_params)
        command_args = _dict_to_command_args(params)
        success = await self.run_blocking_func(
            self.add_worker, worker, worker_params, command_args
        )
        if not success:
            # Nothing was added, so nothing must be removed: removing here
            # would delete the instance that is already registered.
            msg = f"Add worker {model_name}@{worker_type}, worker instances is exist"
            logger.warning(f"{msg}, worker_params: {worker_params}")
            raise Exception(msg)
        start_apply_req = WorkerApplyRequest(
            model=worker_params.model_name,
            apply_type=WorkerApplyType.START,
            worker_type=worker_type,
        )
        try:
            out: WorkerApplyOutput = await self.worker_apply(start_apply_req)
        except Exception:
            # Roll back the registration made above, keep the original traceback.
            self._remove_worker(worker_params)
            raise
        if not out.success:
            self._remove_worker(worker_params)
            raise Exception(out.message)

    async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
        """Stop the model described by ``shutdown_req``.

        Raises:
            Exception: If stopping does not succeed.
        """
        logger.info(f"Begin shutdown model, shutdown_req: {shutdown_req}")
        apply_req = WorkerApplyRequest(
            model=shutdown_req.model,
            apply_type=WorkerApplyType.STOP,
            worker_type=shutdown_req.worker_type,
        )
        out = await self._stop_all_worker(apply_req)
        if not out.success:
            raise Exception(out.message)

    async def supported_models(self) -> List[WorkerSupportedModel]:
        """Return the models reported by ``list_supported_models`` for this host."""
        models = await self.run_blocking_func(list_supported_models)
        return [WorkerSupportedModel(host=self.host, port=self.port, models=models)]

    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Async wrapper around ``sync_get_model_instances``."""
        return self.sync_get_model_instances(worker_type, model_name, healthy_only)

    async def get_all_model_instances(
        self, worker_type: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Return every instance of ``worker_type``.

        When ``healthy_only`` is True, instances whose ``stopped`` flag is set
        are skipped.
        """
        result = []
        for instance in itertools.chain(*self.workers.values()):
            _, wt = WorkerType.parse_worker_key(instance.worker_key)
            if wt != worker_type or (healthy_only and instance.stopped):
                continue
            result.append(instance)
        return result

    def sync_get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Return the instances stored for the given type and model name.

        ``healthy_only`` is accepted for interface compatibility but is not
        applied by this implementation.
        """
        worker_key = self._worker_key(worker_type, model_name)
        return self.workers.get(worker_key, [])

    def _simple_select(
        self, worker_type: str, model_name: str, worker_instances: List[WorkerRunData]
    ) -> WorkerRunData:
        """Pick one instance at random.

        Raises:
            Exception: If ``worker_instances`` is empty.
        """
        if not worker_instances:
            raise Exception(
                f"Cound not found worker instances for model name {model_name} and worker type {worker_type}"
            )
        return random.choice(worker_instances)

    async def select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        """Select one instance for the given worker type and model name."""
        worker_instances = await self.get_model_instances(
            worker_type, model_name, healthy_only
        )
        return self._simple_select(worker_type, model_name, worker_instances)

    def sync_select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        """Synchronous variant of ``select_one_instance``."""
        worker_instances = self.sync_get_model_instances(
            worker_type, model_name, healthy_only
        )
        return self._simple_select(worker_type, model_name, worker_instances)

    async def _get_model(self, params: Dict, worker_type: str = "llm") -> WorkerRunData:
        """Select an instance for ``params["model"]``.

        Raises:
            Exception: If the model name is missing or no instance exists.
        """
        model = params.get("model")
        if not model:
            raise Exception("Model name count not be empty")
        return await self.select_one_instance(worker_type, model, healthy_only=True)

    def _sync_get_model(self, params: Dict, worker_type: str = "llm") -> WorkerRunData:
        """Synchronous variant of ``_get_model``."""
        model = params.get("model")
        if not model:
            raise Exception("Model name count not be empty")
        return self.sync_select_one_instance(worker_type, model, healthy_only=True)

    async def generate_stream(
        self, params: Dict, async_wrapper=None, **kwargs
    ) -> AsyncIterator[ModelOutput]:
        """Generate stream result, chat scene.

        A failure to select a model is reported as a single ``ModelOutput``
        with ``error_code=1`` instead of raising.

        Args:
            params: Request parameters; ``span_id`` is overwritten with the id
                of the span opened here.
            async_wrapper: Callable that turns the worker's blocking iterator
                into an async iterator; defaults to starlette's
                ``iterate_in_threadpool``.
        """
        with root_tracer.start_span(
            "WorkerManager.generate_stream", params.get("span_id")
        ) as span:
            params["span_id"] = span.span_id
            try:
                worker_run_data = await self._get_model(params)
            except Exception as e:
                yield ModelOutput(
                    text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                    error_code=1,
                )
                return
            # The per-instance semaphore bounds concurrent requests.
            async with worker_run_data.semaphore:
                if worker_run_data.worker.support_async():
                    async for output in worker_run_data.worker.async_generate_stream(
                        params
                    ):
                        yield output
                else:
                    if not async_wrapper:
                        from starlette.concurrency import iterate_in_threadpool

                        async_wrapper = iterate_in_threadpool
                    async for output in async_wrapper(
                        worker_run_data.worker.generate_stream(params)
                    ):
                        yield output

    async def generate(self, params: Dict) -> ModelOutput:
        """Generate non stream result.

        A failure to select a model is returned as a ``ModelOutput`` with
        ``error_code=1`` instead of raising.
        """
        with root_tracer.start_span(
            "WorkerManager.generate", params.get("span_id")
        ) as span:
            params["span_id"] = span.span_id
            try:
                worker_run_data = await self._get_model(params)
            except Exception as e:
                return ModelOutput(
                    text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                    error_code=1,
                )
            async with worker_run_data.semaphore:
                if worker_run_data.worker.support_async():
                    return await worker_run_data.worker.async_generate(params)
                return await self.run_blocking_func(
                    worker_run_data.worker.generate, params
                )

    async def embeddings(self, params: Dict) -> List[List[float]]:
        """Embed input with a ``text2vec`` worker."""
        with root_tracer.start_span(
            "WorkerManager.embeddings", params.get("span_id")
        ) as span:
            params["span_id"] = span.span_id
            worker_run_data = await self._get_model(params, worker_type="text2vec")
            async with worker_run_data.semaphore:
                if worker_run_data.worker.support_async():
                    return await worker_run_data.worker.async_embeddings(params)
                return await self.run_blocking_func(
                    worker_run_data.worker.embeddings, params
                )

    def sync_embeddings(self, params: Dict) -> List[List[float]]:
        """Embed input synchronously; the instance semaphore is not taken."""
        worker_run_data = self._sync_get_model(params, worker_type="text2vec")
        return worker_run_data.worker.embeddings(params)

    async def count_token(self, params: Dict) -> int:
        """Count token of prompt (``params["prompt"]``)."""
        with root_tracer.start_span(
            "WorkerManager.count_token", params.get("span_id")
        ) as span:
            params["span_id"] = span.span_id
            worker_run_data = await self._get_model(params)
            prompt = params.get("prompt")
            async with worker_run_data.semaphore:
                if worker_run_data.worker.support_async():
                    return await worker_run_data.worker.async_count_token(prompt)
                return await self.run_blocking_func(
                    worker_run_data.worker.count_token, prompt
                )

    async def get_model_metadata(self, params: Dict) -> ModelMetadata:
        """Get model metadata from the selected worker."""
        with root_tracer.start_span(
            "WorkerManager.get_model_metadata", params.get("span_id")
        ) as span:
            params["span_id"] = span.span_id
            worker_run_data = await self._get_model(params)
            async with worker_run_data.semaphore:
                if worker_run_data.worker.support_async():
                    return await worker_run_data.worker.async_get_model_metadata(params)
                return await self.run_blocking_func(
                    worker_run_data.worker.get_model_metadata, params
                )

    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        """Dispatch ``apply_req`` to the start/stop/restart/update handler.

        Raises:
            ValueError: If the apply type is not one of the four handled types.
        """
        apply_type = apply_req.apply_type
        if apply_type == WorkerApplyType.START:
            apply_func = self._start_all_worker
        elif apply_type == WorkerApplyType.STOP:
            apply_func = self._stop_all_worker
        elif apply_type == WorkerApplyType.RESTART:
            apply_func = self._restart_all_worker
        elif apply_type == WorkerApplyType.UPDATE_PARAMS:
            apply_func = self._update_all_worker_params
        else:
            raise ValueError(f"Unsupported apply type {apply_req.apply_type}")
        return await apply_func(apply_req)

    async def parameter_descriptions(
        self, worker_type: str, model_name: str
    ) -> List[ParameterDescription]:
        """Return the parameter descriptions of the first matching instance.

        Raises:
            Exception: If no instance exists for the type and model name.
        """
        worker_instances = await self.get_model_instances(worker_type, model_name)
        if not worker_instances:
            raise Exception(
                f"Not worker instances for model name {model_name} worker type {worker_type}"
            )
        return worker_instances[0].worker.parameter_descriptions()

    async def _apply_worker(
        self, apply_req: WorkerApplyRequest, apply_func: ApplyFunction
    ) -> List:
        """Apply function to worker instances in parallel

        Args:
            apply_req (WorkerApplyRequest): Worker apply request; when falsy the
                function is applied to every stored instance.
            apply_func (ApplyFunction): Function to apply to worker instances, now function is async function

        Returns:
            The results of ``apply_func``, one per instance, in gather order.

        Raises:
            Exception: If ``apply_req`` is given and matches no instance.
        """
        logger.info(f"Apply req: {apply_req}, apply_func: {apply_func}")
        if apply_req:
            worker_type = apply_req.worker_type.value
            model_name = apply_req.model
            worker_instances = await self.get_model_instances(
                worker_type, model_name, healthy_only=False
            )
            if not worker_instances:
                raise Exception(
                    f"No worker instance found for the model {model_name} worker type {worker_type}"
                )
        else:
            # Apply to all workers
            worker_instances = list(itertools.chain(*self.workers.values()))
            logger.info("Apply to all workers")
        return await asyncio.gather(
            *(apply_func(worker) for worker in worker_instances)
        )

    async def _start_all_worker(
        self, apply_req: WorkerApplyRequest
    ) -> WorkerApplyOutput:
        """Start the instances selected by ``apply_req`` (all when None).

        Each instance is started on the thread pool, then registered and given
        a heartbeat task when its parameters ask for it. Per-instance failures
        are captured in the outputs, which are combined with
        ``WorkerApplyOutput.reduce``.
        """
        from httpx import TimeoutException, TransportError

        # TODO avoid start twice
        start_time = time.time()
        logger.info(f"Begin start all worker, apply_req: {apply_req}")

        async def _start_worker(worker_run_data: WorkerRunData):
            _start_time = time.time()
            info = worker_run_data._to_print_key()
            out = WorkerApplyOutput("")
            try:
                await self.run_blocking_func(
                    worker_run_data.worker.start,
                    worker_run_data.model_params,
                    worker_run_data.command_args,
                )
                worker_run_data.stop_event.clear()
                if worker_run_data.worker_params.register and self.register_func:
                    # Register worker to controller
                    await self.register_func(worker_run_data)
                    if (
                        worker_run_data.worker_params.send_heartbeat
                        and self.send_heartbeat_func
                    ):
                        self._spawn_background_task(
                            _async_heartbeat_sender(
                                worker_run_data,
                                worker_run_data.worker_params.heartbeat_interval,
                                self.send_heartbeat_func,
                            )
                        )
                out.message = f"{info} start successfully"
            except TimeoutException:
                out.success = False
                out.message = (
                    f"{info} start failed for network timeout, please make "
                    f"sure your port is available, if you are using global network "
                    f"proxy, please close it"
                )
            except TransportError:
                out.success = False
                out.message = (
                    f"{info} start failed for network error, please make "
                    f"sure your port is available, if you are using global network "
                    "proxy, please close it"
                )
            except Exception:
                err_msg = traceback.format_exc()
                out.success = False
                out.message = f"{info} start failed, {err_msg}"
            finally:
                out.timecost = time.time() - _start_time
            return out

        outs = await self._apply_worker(apply_req, _start_worker)
        out = WorkerApplyOutput.reduce(outs)
        out.timecost = time.time() - start_time
        return out

    async def _stop_all_worker(
        self, apply_req: WorkerApplyRequest, ignore_exception: bool = False
    ) -> WorkerApplyOutput:
        """Stop the instances selected by ``apply_req`` (all when None).

        A stopped instance has its stop event set, is deregistered when it was
        registered, and is removed from ``self.workers``.

        Args:
            apply_req: Selects the instances; None means all.
            ignore_exception: When True, deregistration errors are logged and
                ignored.
        """
        start_time = time.time()

        async def _stop_worker(worker_run_data: WorkerRunData):
            _start_time = time.time()
            info = worker_run_data._to_print_key()
            out = WorkerApplyOutput("")
            try:
                await self.run_blocking_func(worker_run_data.worker.stop)
                # Set stop event
                worker_run_data.stop_event.set()
                if worker_run_data._heartbeat_future:
                    # Wait thread finish
                    worker_run_data._heartbeat_future.result()
                    worker_run_data._heartbeat_future = None
                if (
                    worker_run_data.worker_params.register
                    and self.register_func
                    and self.deregister_func
                ):
                    _deregister_func = self.deregister_func
                    if ignore_exception:

                        async def safe_deregister_func(run_data):
                            try:
                                await self.deregister_func(run_data)
                            except Exception as e:
                                logger.warning(
                                    f"Stop worker, ignored exception from deregister_func: {e}"
                                )

                        _deregister_func = safe_deregister_func
                    await _deregister_func(worker_run_data)
                # Remove metadata
                self._remove_worker(worker_run_data.worker_params)
                out.message = f"{info} stop successfully"
            except Exception as e:
                out.success = False
                out.message = f"{info} stop failed, {str(e)}"
            finally:
                out.timecost = time.time() - _start_time
            return out

        outs = await self._apply_worker(apply_req, _stop_worker)
        out = WorkerApplyOutput.reduce(outs)
        out.timecost = time.time() - start_time
        return out

    async def _restart_all_worker(
        self, apply_req: WorkerApplyRequest
    ) -> WorkerApplyOutput:
        """Stop and then start the instances selected by ``apply_req``.

        Returns the stop output unchanged when stopping fails.
        """
        # Stopping removes the instances from ``self.workers``; remember them
        # so they can be re-added, otherwise the start step finds nothing.
        if apply_req:
            selected = await self.get_model_instances(
                apply_req.worker_type.value, apply_req.model, healthy_only=False
            )
        else:
            selected = itertools.chain(*self.workers.values())
        snapshot = list(selected)

        out = await self._stop_all_worker(apply_req, ignore_exception=True)
        if not out.success:
            return out

        for run_data in snapshot:
            instances = self.workers.setdefault(run_data.worker_key, [])
            if not any(existing is run_data for existing in instances):
                instances.append(run_data)
        return await self._start_all_worker(apply_req)

    async def _update_all_worker_params(
        self, apply_req: WorkerApplyRequest
    ) -> WorkerApplyOutput:
        """Update model parameters and restart when an update requires it.

        Returns:
            A success output, or the restart output when the restart fails.
        """
        start_time = time.time()
        need_restart = False

        async def update_params(worker_run_data: WorkerRunData):
            nonlocal need_restart
            new_params = apply_req.params
            if not new_params:
                return
            if worker_run_data.model_params.update_from(new_params):
                need_restart = True

        await self._apply_worker(apply_req, update_params)
        message = "Update worker params successfully"
        timecost = time.time() - start_time
        if need_restart:
            logger.info("Model params update successfully, begin restart worker")
            restart_out = await self._restart_all_worker(apply_req)
            timecost = time.time() - start_time
            if not restart_out.success:
                # Do not report success when the restart did not succeed.
                restart_out.timecost = timecost
                return restart_out
            message = "Update worker params and restart successfully"
        return WorkerApplyOutput(message=message, timecost=timecost)
class WorkerManagerAdapter(WorkerManager):
    """Thin proxy that forwards every call to a wrapped ``WorkerManager``.

    The wrapped manager is held in ``self.worker_manager`` and may be assigned
    after construction.
    """

    def __init__(self, worker_manager: WorkerManager = None) -> None:
        self.worker_manager = worker_manager

    async def start(self):
        """Forward to the wrapped manager's ``start``."""
        delegate = self.worker_manager
        return await delegate.start()

    async def stop(self, ignore_exception: bool = False):
        """Forward to the wrapped manager's ``stop``."""
        delegate = self.worker_manager
        return await delegate.stop(ignore_exception=ignore_exception)

    def after_start(self, listener: Callable[["WorkerManager"], None]):
        """Forward a start listener; a ``None`` listener is ignored."""
        if listener is None:
            return
        self.worker_manager.after_start(listener)

    async def supported_models(self) -> List[WorkerSupportedModel]:
        """Forward to the wrapped manager's ``supported_models``."""
        delegate = self.worker_manager
        return await delegate.supported_models()

    async def model_startup(self, startup_req: WorkerStartupRequest):
        """Forward to the wrapped manager's ``model_startup``."""
        delegate = self.worker_manager
        return await delegate.model_startup(startup_req)

    async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
        """Forward to the wrapped manager's ``model_shutdown``."""
        delegate = self.worker_manager
        return await delegate.model_shutdown(shutdown_req)

    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Forward to the wrapped manager's ``get_model_instances``."""
        delegate = self.worker_manager
        return await delegate.get_model_instances(
            worker_type, model_name, healthy_only
        )

    async def get_all_model_instances(
        self, worker_type: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Forward to the wrapped manager's ``get_all_model_instances``."""
        delegate = self.worker_manager
        return await delegate.get_all_model_instances(worker_type, healthy_only)

    def sync_get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Forward to the wrapped manager's ``sync_get_model_instances``."""
        delegate = self.worker_manager
        return delegate.sync_get_model_instances(
            worker_type, model_name, healthy_only
        )

    async def select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        """Forward to the wrapped manager's ``select_one_instance``."""
        delegate = self.worker_manager
        return await delegate.select_one_instance(
            worker_type, model_name, healthy_only
        )

    def sync_select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        """Forward to the wrapped manager's ``sync_select_one_instance``."""
        delegate = self.worker_manager
        return delegate.sync_select_one_instance(
            worker_type, model_name, healthy_only
        )

    async def generate_stream(
        self, params: Dict, **kwargs
    ) -> AsyncIterator[ModelOutput]:
        """Re-yield every output of the wrapped manager's ``generate_stream``."""
        stream = self.worker_manager.generate_stream(params, **kwargs)
        async for output in stream:
            yield output

    async def generate(self, params: Dict) -> ModelOutput:
        """Forward to the wrapped manager's ``generate``."""
        delegate = self.worker_manager
        return await delegate.generate(params)

    async def embeddings(self, params: Dict) -> List[List[float]]:
        """Forward to the wrapped manager's ``embeddings``."""
        delegate = self.worker_manager
        return await delegate.embeddings(params)

    def sync_embeddings(self, params: Dict) -> List[List[float]]:
        """Forward to the wrapped manager's ``sync_embeddings``."""
        delegate = self.worker_manager
        return delegate.sync_embeddings(params)

    async def count_token(self, params: Dict) -> int:
        """Forward to the wrapped manager's ``count_token``."""
        delegate = self.worker_manager
        return await delegate.count_token(params)

    async def get_model_metadata(self, params: Dict) -> ModelMetadata:
        """Forward to the wrapped manager's ``get_model_metadata``."""
        delegate = self.worker_manager
        return await delegate.get_model_metadata(params)

    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        """Forward to the wrapped manager's ``worker_apply``."""
        delegate = self.worker_manager
        return await delegate.worker_apply(apply_req)

    async def parameter_descriptions(
        self, worker_type: str, model_name: str
    ) -> List[ParameterDescription]:
        """Forward to the wrapped manager's ``parameter_descriptions``."""
        delegate = self.worker_manager
        return await delegate.parameter_descriptions(worker_type, model_name)
class _DefaultWorkerManagerFactory(WorkerManagerFactory):
    """Factory that always hands out the worker manager it was built with."""

    def __init__(
        self,
        system_app: Optional[SystemApp] = None,
        worker_manager: WorkerManager = None,
    ):
        """Store the manager to return from ``create``.

        ``Optional[SystemApp]`` replaces ``SystemApp | None``: the annotation is
        evaluated when the method is defined, and the ``|`` form fails on
        Python versions before 3.10. The rest of this module already uses
        ``Optional``.
        """
        super().__init__(system_app)
        self.worker_manager = worker_manager

    def create(self) -> WorkerManager:
        """Return the stored worker manager."""
        return self.worker_manager
# Module-level adapter used by the /worker/* routes below. It is created without
# a wrapped manager; _start_local_worker assigns one when it is still missing.
worker_manager = WorkerManagerAdapter()
# Router on which the /worker/* endpoints below are registered.
router = APIRouter()
async def generate_json_stream(params):
    """Yield each streamed model output as UTF-8 JSON bytes ending in a NUL byte."""
    from starlette.concurrency import iterate_in_threadpool

    stream = worker_manager.generate_stream(
        params, async_wrapper=iterate_in_threadpool
    )
    async for output in stream:
        payload = json.dumps(asdict(output), ensure_ascii=False)
        # The trailing b"\0" delimits the individual JSON documents.
        yield payload.encode() + b"\0"
def _request_params_with_span(request) -> Dict:
    """Dump a request model to a dict and attach the current tracer span id.

    ``span_id`` is only added when the request did not carry one and the
    tracer reports a current span. This was duplicated in each of the five
    endpoints below.
    """
    params = request.dict(exclude_none=True)
    span_id = root_tracer.get_current_span_id()
    if "span_id" not in params and span_id:
        params["span_id"] = span_id
    return params


@router.post("/worker/generate_stream")
async def api_generate_stream(request: PromptRequest):
    """Stream generation outputs as NUL-delimited JSON documents."""
    params = _request_params_with_span(request)
    generator = generate_json_stream(params)
    return StreamingResponse(generator)


@router.post("/worker/generate")
async def api_generate(request: PromptRequest):
    """Return the non-streaming generation result."""
    return await worker_manager.generate(_request_params_with_span(request))


@router.post("/worker/embeddings")
async def api_embeddings(request: EmbeddingsRequest):
    """Return the embeddings for the request input."""
    return await worker_manager.embeddings(_request_params_with_span(request))


@router.post("/worker/count_token")
async def api_count_token(request: CountTokenRequest):
    """Return the token count for the request prompt."""
    return await worker_manager.count_token(_request_params_with_span(request))


@router.post("/worker/model_metadata")
async def api_get_model_metadata(request: ModelMetadataRequest):
    """Return the metadata of the requested model."""
    return await worker_manager.get_model_metadata(_request_params_with_span(request))
@router.post("/worker/apply")
async def api_worker_apply(request: WorkerApplyRequest):
    """Pass the apply request to the worker manager and return its output."""
    apply_output = await worker_manager.worker_apply(request)
    return apply_output
@router.get("/worker/parameter/descriptions")
async def api_worker_parameter_descs(
    model: str, worker_type: str = WorkerType.LLM.value
):
    """Return the parameter descriptions for ``model`` of ``worker_type``."""
    descriptions = await worker_manager.parameter_descriptions(worker_type, model)
    return descriptions
@router.get("/worker/models/supports")
async def api_supported_models():
    """Get all supported models.

    This method reads all models from the configuration file and tries to
    perform some basic checks on the model (like if the path exists).

    If it's a RemoteWorkerManager, this method returns the list of models
    supported by the entire cluster.
    """
    supported = await worker_manager.supported_models()
    return supported
@router.post("/worker/models/startup")
async def api_model_startup(request: WorkerStartupRequest):
    """Start the model named in the request through the worker manager."""
    startup_result = await worker_manager.model_startup(request)
    return startup_result
@router.post("/worker/models/shutdown")
async def api_model_shutdown(request: WorkerStartupRequest):
    """Shut down the model named in the request through the worker manager."""
    shutdown_result = await worker_manager.model_shutdown(request)
    return shutdown_result
def _setup_fastapi(
    worker_params: ModelWorkerParameters,
    app=None,
    ignore_exception: bool = False,
    system_app: Optional[SystemApp] = None,
):
    """Prepare the FastAPI app that hosts the worker manager.

    Args:
        worker_params: Worker parameters; in standalone mode without a
            controller address, ``controller_addr`` is set to the local port.
        app: Existing app to use; a new one is created when falsy.
        ignore_exception: Passed to ``worker_manager.stop`` on shutdown.
        system_app: When given, its ``_asgi_app`` is set to the app.

    Returns:
        The app with startup and shutdown handlers registered.
    """
    if not app:
        app = create_app()
        setup_http_service_logging()

    if system_app:
        system_app._asgi_app = app

    if worker_params.standalone:
        from dbgpt.model.cluster.controller.controller import initialize_controller
        from dbgpt.model.cluster.controller.controller import (
            router as controller_router,
        )

        if not worker_params.controller_addr:
            # if we have http_proxy or https_proxy in env, the server can not start
            # so set it to empty here
            os.environ["http_proxy"] = ""
            os.environ["https_proxy"] = ""
            worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"
        logger.info(
            f"Run WorkerManager with standalone mode, controller_addr: {worker_params.controller_addr}"
        )
        initialize_controller(app=app, system_app=system_app)
        app.include_router(controller_router, prefix="/api")

    # Strong references to the startup task: the event loop only keeps weak
    # references, so an unreferenced task may be garbage-collected before the
    # worker manager has finished starting.
    background_tasks = set()

    async def startup_event():
        async def start_worker_manager():
            try:
                await worker_manager.start()
            except Exception as e:
                import signal

                logger.error(f"Error starting worker manager: {str(e)}")
                # Ask the current process to shut down when startup fails.
                os.kill(os.getpid(), signal.SIGINT)

        # It cannot be blocked here because the startup of worker_manager depends on
        # the fastapi app (registered to the controller)
        task = asyncio.create_task(start_worker_manager())
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

    async def shutdown_event():
        await worker_manager.stop(ignore_exception=ignore_exception)

    register_event_handler(app, "startup", startup_event)
    register_event_handler(app, "shutdown", shutdown_event)
    return app
def _parse_worker_params(
    model_name: str = None, model_path: str = None, **kwargs
) -> ModelWorkerParameters:
    """Parse ``ModelWorkerParameters`` in two passes.

    The first pass uses the env prefix of ``model_name`` (``None`` when no name
    is given). The second pass parses again with the prefix of the resolved
    model name and is merged into the first result. When a model alias is set
    it replaces the model name.
    """
    parser = EnvArgumentParser()
    first_prefix = EnvArgumentParser.get_env_prefix(model_name) if model_name else None
    worker_params: ModelWorkerParameters = parser.parse_args_into_dataclass(
        ModelWorkerParameters,
        env_prefixes=[first_prefix],
        model_name=model_name,
        model_path=model_path,
        **kwargs,
    )

    # Read parameters again with prefix of model name.
    resolved_prefix = EnvArgumentParser.get_env_prefix(worker_params.model_name)
    reparsed = parser.parse_args_into_dataclass(
        ModelWorkerParameters,
        env_prefixes=[resolved_prefix],
        model_name=worker_params.model_name,
        model_path=worker_params.model_path,
        **kwargs,
    )
    worker_params.update_from(reparsed)

    alias = worker_params.model_alias
    if alias:
        worker_params.model_name = alias
    return worker_params
def _create_local_model_manager(
    worker_params: ModelWorkerParameters,
) -> LocalWorkerManager:
    """Build a ``LocalWorkerManager`` for ``worker_params``.

    Without ``register`` or ``controller_addr`` the manager gets no callbacks.
    Otherwise it gets register/deregister/heartbeat callbacks that call a
    ``ModelRegistryClient`` pointed at the controller address.
    """
    from dbgpt.util.net_utils import _get_ip_address

    host = worker_params.worker_register_host or _get_ip_address()
    port = worker_params.port

    if not (worker_params.register and worker_params.controller_addr):
        logger.info(
            f"Not register current to controller, register: {worker_params.register}, controller_addr: {worker_params.controller_addr}"
        )
        return LocalWorkerManager(host=host, port=port)

    from dbgpt.model.cluster.controller.controller import ModelRegistryClient

    client = ModelRegistryClient(worker_params.controller_addr)

    def _as_instance(worker_run_data: WorkerRunData) -> ModelInstance:
        # The worker key is what the registry sees as the model name.
        return ModelInstance(
            model_name=worker_run_data.worker_key, host=host, port=port
        )

    async def register_func(worker_run_data: WorkerRunData):
        return await client.register_instance(_as_instance(worker_run_data))

    async def deregister_func(worker_run_data: WorkerRunData):
        return await client.deregister_instance(_as_instance(worker_run_data))

    async def send_heartbeat_func(worker_run_data: WorkerRunData):
        return await client.send_heartbeat(_as_instance(worker_run_data))

    return LocalWorkerManager(
        register_func=register_func,
        deregister_func=deregister_func,
        send_heartbeat_func=send_heartbeat_func,
        host=host,
        port=port,
    )
def _build_worker(
    worker_params: ModelWorkerParameters,
    ext_worker_kwargs: Optional[Dict[str, Any]] = None,
):
    """Instantiate the worker class selected by ``worker_params``.

    An explicit ``worker_class`` import path wins; otherwise the class is
    chosen from ``worker_type`` (LLM or unset, or TEXT2VEC).

    Args:
        worker_params: Parameters providing ``worker_class`` / ``worker_type``.
        ext_worker_kwargs: Optional keyword arguments for the worker constructor.

    Returns:
        A new worker instance.

    Raises:
        Exception: If no ``worker_class`` is given and the worker type is
            neither LLM nor TEXT2VEC.
    """
    worker_class = worker_params.worker_class
    if worker_class:
        from dbgpt.util.module_utils import import_from_checked_string

        worker_cls = import_from_checked_string(worker_class, ModelWorker)
        logger.info(f"Import worker class from {worker_class} successfully")
    elif (
        worker_params.worker_type is None
        or worker_params.worker_type == WorkerType.LLM
    ):
        from dbgpt.model.cluster.worker.default_worker import DefaultModelWorker

        worker_cls = DefaultModelWorker
    elif worker_params.worker_type == WorkerType.TEXT2VEC:
        from dbgpt.model.cluster.worker.embedding_worker import (
            EmbeddingsModelWorker,
        )

        worker_cls = EmbeddingsModelWorker
    else:
        # The original literal lacked the ``f`` prefix, so the placeholder was
        # printed verbatim instead of the offending worker type.
        raise Exception(f"Unsupported worker type: {worker_params.worker_type}")
    if ext_worker_kwargs:
        return worker_cls(**ext_worker_kwargs)
    return worker_cls()
def _start_local_worker(
    worker_manager: WorkerManagerAdapter,
    worker_params: ModelWorkerParameters,
    ext_worker_kwargs: Optional[Dict[str, Any]] = None,
):
    """Build a worker and add it to the adapter's local manager.

    The wrapped manager is created from ``worker_params`` when the adapter
    does not have one yet. The work runs inside a RUN tracer span.
    """
    span_metadata = {
        "run_service": SpanTypeRunName.WORKER_MANAGER,
        "params": _get_dict_from_obj(worker_params),
        "sys_infos": _get_dict_from_obj(get_system_info()),
    }
    with root_tracer.start_span(
        "WorkerManager._start_local_worker",
        span_type=SpanType.RUN,
        metadata=span_metadata,
    ):
        new_worker = _build_worker(worker_params, ext_worker_kwargs=ext_worker_kwargs)
        if not worker_manager.worker_manager:
            worker_manager.worker_manager = _create_local_model_manager(worker_params)
        local_manager = worker_manager.worker_manager
        local_manager.add_worker(new_worker, worker_params)
def _start_local_embedding_worker(
    worker_manager: WorkerManagerAdapter,
    embedding_model_name: str = None,
    embedding_model_path: str = None,
    ext_worker_kwargs: Optional[Dict[str, Any]] = None,
):
    """Add a ``TEXT2VEC`` worker using ``EmbeddingsModelWorker``.

    Does nothing unless both the model name and the model path are provided.
    """
    if not (embedding_model_name and embedding_model_path):
        return

    params = ModelWorkerParameters(
        model_name=embedding_model_name,
        model_path=embedding_model_path,
        worker_type=WorkerType.TEXT2VEC,
        worker_class="dbgpt.model.cluster.worker.embedding_worker.EmbeddingsModelWorker",
    )
    logger.info(
        f"Start local embedding worker with embedding parameters\n{params}"
    )
    _start_local_worker(worker_manager, params, ext_worker_kwargs=ext_worker_kwargs)
def initialize_worker_manager_in_client(
    app=None,
    include_router: bool = True,
    model_name: Optional[str] = None,
    model_path: Optional[str] = None,
    run_locally: bool = True,
    controller_addr: Optional[str] = None,
    local_port: int = 5670,
    embedding_model_name: Optional[str] = None,
    embedding_model_path: Optional[str] = None,
    rerank_model_name: Optional[str] = None,
    rerank_model_path: Optional[str] = None,
    start_listener: Optional[Callable[["WorkerManager"], None]] = None,
    system_app: Optional[SystemApp] = None,
):
    """Initialize the module-level ``worker_manager`` inside a client app.

    When ``run_locally`` is true:
        1. Mark the worker parameters as standalone/registering on
           ``local_port`` and set up FastAPI hooks via ``_setup_fastapi``.
        2. Start the main local worker.
        3. Start the optional embedding worker and the optional rerank worker
           (the latter with ``ext_worker_kwargs={"rerank_model": True}``).
        (Starting a ModelController is still a TODO in this path.)

    Otherwise:
        1. Build a ``ModelRegistryClient`` from the controller address.
        2. Install a ``RemoteWorkerManager`` backed by that client.
        3. Initialize the controller for the remote address and start the
           worker manager on the current asyncio event loop.

    In both modes ``start_listener`` is handed to ``worker_manager.after_start``
    and, when ``include_router`` is true, ``router`` is mounted under ``/api``.

    Args:
        app: The application to attach to. Must be truthy.
        include_router: Whether to mount the WorkerManager router on ``app``.
        model_name: Forwarded to ``_parse_worker_params``.
        model_path: Forwarded to ``_parse_worker_params``.
        run_locally: Selects local mode (True) or remote mode (False).
        controller_addr: Forwarded to ``_parse_worker_params``.
        local_port: Port assigned to the worker parameters in local mode.
        embedding_model_name: Name for the optional local embedding worker.
        embedding_model_path: Path for the optional local embedding worker.
        rerank_model_name: Name for the optional local rerank worker.
        rerank_model_path: Path for the optional local rerank worker.
        start_listener: Callback passed to ``worker_manager.after_start``.
        system_app: When truthy, ``worker_manager`` is registered on it under
            ``_DefaultWorkerManagerFactory``.

    Raises:
        Exception: If ``app`` is falsy.
        ValueError: In remote mode, if the parsed parameters carry no
            controller address.
    """
    global worker_manager

    if not app:
        raise Exception("app can't be None")

    if system_app:
        logger.info(f"Register WorkerManager {_DefaultWorkerManagerFactory.name}")
        system_app.register(_DefaultWorkerManagerFactory, worker_manager)

    worker_params: ModelWorkerParameters = _parse_worker_params(
        model_name=model_name, model_path=model_path, controller_addr=controller_addr
    )
    # The raw argument is cleared here; from this point on only
    # ``worker_params.controller_addr`` is read.
    controller_addr = None

    if not run_locally:
        # Imported lazily so the controller/remote modules are only loaded
        # when remote mode is actually requested.
        from dbgpt.model.cluster.controller.controller import (
            ModelRegistryClient,
            initialize_controller,
        )
        from dbgpt.model.cluster.worker.remote_manager import RemoteWorkerManager

        if not worker_params.controller_addr:
            raise ValueError("Controller can`t be None")

        logger.info(f"Worker params: {worker_params}")
        registry_client = ModelRegistryClient(worker_params.controller_addr)
        worker_manager.worker_manager = RemoteWorkerManager(registry_client)
        worker_manager.after_start(start_listener)
        initialize_controller(
            app=app,
            remote_controller_addr=worker_params.controller_addr,
            system_app=system_app,
        )
        event_loop = asyncio.get_event_loop()
        event_loop.run_until_complete(worker_manager.start())
    else:
        # TODO start ModelController
        worker_params.standalone = True
        worker_params.register = True
        worker_params.port = local_port
        logger.info(f"Worker params: {worker_params}")
        _setup_fastapi(worker_params, app, ignore_exception=True, system_app=system_app)
        _start_local_worker(worker_manager, worker_params)
        worker_manager.after_start(start_listener)
        _start_local_embedding_worker(
            worker_manager, embedding_model_name, embedding_model_path
        )
        # The rerank model reuses the embedding-worker path, flagged via kwargs.
        _start_local_embedding_worker(
            worker_manager,
            rerank_model_name,
            rerank_model_path,
            ext_worker_kwargs={"rerank_model": True},
        )

    if include_router and app:
        # mount WorkerManager router
        app.include_router(router, prefix="/api")
def run_worker_manager(
    app=None,
    include_router: bool = True,
    model_name: str = None,
    model_path: str = None,
    standalone: bool = False,
    port: int = None,
    embedding_model_name: str = None,
    embedding_model_path: str = None,
    start_listener: Callable[["WorkerManager"], None] = None,
    **kwargs,
):
    """Start the worker manager, either standalone or embedded in ``app``.

    When ``app`` is falsy, a new app is created with ``_setup_fastapi`` and
    served with ``uvicorn.run`` on ``worker_params.host``/``worker_params.port``.
    When ``app`` is given, the worker manager is started on the current
    asyncio event loop instead.

    Args:
        app: Existing application to embed into, or a falsy value to run
            independently.
        include_router: Whether to mount ``router`` on the app under ``/api``.
        model_name: Forwarded to ``_parse_worker_params``.
        model_path: Forwarded to ``_parse_worker_params``.
        standalone: Forwarded to ``_parse_worker_params``.
        port: Forwarded to ``_parse_worker_params``.
        embedding_model_name: Name for the optional local embedding worker.
        embedding_model_path: Path for the optional local embedding worker.
        start_listener: Callback passed to ``worker_manager.after_start``.
        **kwargs: Extra keyword arguments forwarded to ``_parse_worker_params``.
    """
    global worker_manager

    worker_params: ModelWorkerParameters = _parse_worker_params(
        model_name=model_name,
        model_path=model_path,
        standalone=standalone,
        port=port,
        **kwargs,
    )
    setup_logging(
        "dbgpt",
        logging_level=worker_params.log_level,
        logger_filename=worker_params.log_file,
    )
    logger.info(f"Worker params: {worker_params}")

    system_app = SystemApp()
    # No app supplied means we own the app and must serve it ourselves.
    run_standalone = not app
    if run_standalone:
        # Run worker manager independently
        app = _setup_fastapi(worker_params, system_app=system_app)
        system_app._asgi_app = app

    tracer_path = os.path.join(LOGDIR, worker_params.tracer_file)
    initialize_tracer(
        tracer_path,
        system_app=system_app,
        root_operation_name="DB-GPT-ModelWorker",
        tracer_storage_cls=worker_params.tracer_storage_cls,
        enable_open_telemetry=worker_params.tracer_to_open_telemetry,
        otlp_endpoint=worker_params.otel_exporter_otlp_traces_endpoint,
        otlp_insecure=worker_params.otel_exporter_otlp_traces_insecure,
        otlp_timeout=worker_params.otel_exporter_otlp_traces_timeout,
    )

    _start_local_worker(worker_manager, worker_params)
    _start_local_embedding_worker(
        worker_manager, embedding_model_name, embedding_model_path
    )
    worker_manager.after_start(start_listener)

    if include_router:
        app.include_router(router, prefix="/api")

    if run_standalone:
        import uvicorn

        uvicorn.run(
            app, host=worker_params.host, port=worker_params.port, log_level="info"
        )
    else:
        # Embedded mod, start worker manager
        event_loop = asyncio.get_event_loop()
        event_loop.run_until_complete(worker_manager.start())
# Script entry point. ``run_worker_manager`` is called with no arguments, so
# ``app`` is None and it takes its standalone path: it creates its own app via
# ``_setup_fastapi`` and serves it with ``uvicorn.run``.
if __name__ == "__main__":
    run_worker_manager()