mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-03 09:34:04 +00:00
1195 lines
45 KiB
Python
1195 lines
45 KiB
Python
import asyncio
import itertools
import json
import logging
import os
import random
import sys
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
)

from fastapi import APIRouter
from fastapi.responses import StreamingResponse

from dbgpt.component import SystemApp
from dbgpt.configs.model_config import LOGDIR
from dbgpt.core import ModelMetadata, ModelOutput
from dbgpt.model.base import ModelInstance, WorkerApplyOutput, WorkerSupportedModel
from dbgpt.model.cluster.base import *
from dbgpt.model.cluster.manager_base import (
    WorkerManager,
    WorkerManagerFactory,
    WorkerRunData,
)
from dbgpt.model.cluster.registry import ModelRegistry
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.parameter import ModelWorkerParameters, WorkerType
from dbgpt.model.utils.llm_utils import list_supported_models
from dbgpt.util.fastapi import create_app, register_event_handler
from dbgpt.util.parameter_utils import (
    EnvArgumentParser,
    ParameterDescription,
    _dict_to_command_args,
    _get_dict_from_obj,
)
from dbgpt.util.system_utils import get_system_info
from dbgpt.util.tracer import SpanType, SpanTypeRunName, initialize_tracer, root_tracer
from dbgpt.util.utils import setup_http_service_logging, setup_logging
|
|
|
|
# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)

# Every worker lifecycle hook has the same shape: an async callable that
# receives the run data of one worker and returns nothing.
_WorkerCallback = Callable[[WorkerRunData], Awaitable[None]]

RegisterFunc = _WorkerCallback
DeregisterFunc = _WorkerCallback
SendHeartbeatFunc = _WorkerCallback
ApplyFunction = _WorkerCallback
|
|
|
|
|
|
async def _async_heartbeat_sender(
|
|
worker_run_data: WorkerRunData,
|
|
heartbeat_interval,
|
|
send_heartbeat_func: SendHeartbeatFunc,
|
|
):
|
|
while not worker_run_data.stop_event.is_set():
|
|
try:
|
|
await send_heartbeat_func(worker_run_data)
|
|
except Exception as e:
|
|
logger.warn(f"Send heartbeat func error: {str(e)}")
|
|
finally:
|
|
await asyncio.sleep(heartbeat_interval)
|
|
|
|
|
|
class LocalWorkerManager(WorkerManager):
    """Worker manager that keeps and drives model workers in this process.

    Workers are stored in ``self.workers``, keyed by the worker key built from
    the worker type and the model name. Optional callbacks let the manager
    register, deregister and send heartbeats for itself and for each worker.
    """

    # Seconds between two heartbeats sent on behalf of the manager itself.
    _MANAGER_HEARTBEAT_INTERVAL = 20

    def __init__(
        self,
        register_func: Optional[RegisterFunc] = None,
        deregister_func: Optional[DeregisterFunc] = None,
        send_heartbeat_func: Optional[SendHeartbeatFunc] = None,
        model_registry: Optional[ModelRegistry] = None,
        host: Optional[str] = None,
        port: Optional[int] = None,
    ) -> None:
        """Create an empty manager.

        Args:
            register_func: Async callback used to register the manager and
                its workers; skipped when ``None``.
            deregister_func: Async callback used to deregister them.
            send_heartbeat_func: Async callback that sends one heartbeat.
            model_registry: Stored on the instance; not used in this class.
            host: Host recorded in every ``WorkerRunData`` created here.
            port: Port recorded in every ``WorkerRunData`` created here.
        """
        self.workers: Dict[str, List[WorkerRunData]] = dict()
        # os.cpu_count() may return None when the count cannot be determined.
        self.executor = ThreadPoolExecutor(max_workers=(os.cpu_count() or 1) * 5)
        self.register_func = register_func
        self.deregister_func = deregister_func
        self.send_heartbeat_func = send_heartbeat_func
        self.model_registry = model_registry
        self.host = host
        self.port = port
        self.start_listeners = []
        # Strong references to fire-and-forget tasks (heartbeat senders); the
        # event loop itself only keeps weak references to tasks.
        self._background_tasks = set()

        # Run data describing the manager itself (no worker attached).
        self.run_data = WorkerRunData(
            host=self.host,
            port=self.port,
            worker_key=self._worker_key(
                WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME
            ),
            worker=None,
            worker_params=None,
            model_params=None,
            stop_event=asyncio.Event(),
            semaphore=None,
            command_args=None,
        )

    def _worker_key(self, worker_type: str, model_name: str) -> str:
        """Build the key under which a worker is stored in ``self.workers``."""
        return WorkerType.to_worker_key(model_name, worker_type)

    async def run_blocking_func(self, func, *args):
        """Run a blocking callable in the manager's thread pool.

        Raises:
            ValueError: If ``func`` is a coroutine function.
        """
        if asyncio.iscoroutinefunction(func):
            raise ValueError(f"The function {func} is not blocking function")
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, func, *args)

    def _spawn_background_task(self, coro):
        """Schedule ``coro`` as a task and keep it referenced until done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def _safe_deregister(self, run_data: WorkerRunData) -> None:
        """Call ``self.deregister_func`` and log, instead of raise, on error."""
        try:
            await self.deregister_func(run_data)
        except Exception as e:
            logger.warning(f"Stop worker, ignored exception from deregister_func: {e}")

    async def start(self):
        """Start all known workers, register the manager, notify listeners.

        Raises:
            Exception: If starting the workers did not succeed.
        """
        if len(self.workers) > 0:
            out = await self._start_all_worker(apply_req=None)
            if not out.success:
                raise Exception(out.message)
        # stop() sets this event; clear it so a manager that is started again
        # gets a heartbeat loop that actually runs.
        self.run_data.stop_event.clear()
        if self.register_func:
            await self.register_func(self.run_data)
        if self.send_heartbeat_func:
            self._spawn_background_task(
                _async_heartbeat_sender(
                    self.run_data,
                    self._MANAGER_HEARTBEAT_INTERVAL,
                    self.send_heartbeat_func,
                )
            )
        for listener in self.start_listeners:
            if asyncio.iscoroutinefunction(listener):
                await listener(self)
            else:
                listener(self)

    async def stop(self, ignore_exception: bool = False):
        """Stop all workers and deregister the manager.

        Does nothing when the manager's stop event is already set.

        Args:
            ignore_exception: When True, deregistration errors are logged and
                a failed worker stop does not raise.

        Raises:
            Exception: If stopping the workers failed and ``ignore_exception``
                is False.
        """
        if self.run_data.stop_event.is_set():
            return
        logger.info("Stop all workers")
        # Setting (not clearing) the event is what ends the manager heartbeat
        # loop and makes a repeated stop() a no-op.
        self.run_data.stop_event.set()
        stop_tasks = [
            self._stop_all_worker(apply_req=None, ignore_exception=ignore_exception)
        ]
        if self.deregister_func:
            if ignore_exception:
                stop_tasks.append(self._safe_deregister(self.run_data))
            else:
                stop_tasks.append(self.deregister_func(self.run_data))

        results = await asyncio.gather(*stop_tasks)
        # results[0] is the output of _stop_all_worker.
        if not results[0].success and not ignore_exception:
            raise Exception(results[0].message)

    def after_start(self, listener: Callable[["WorkerManager"], None]):
        """Register a listener that ``start`` calls with this manager."""
        self.start_listeners.append(listener)

    def add_worker(
        self,
        worker: ModelWorker,
        worker_params: ModelWorkerParameters,
        command_args: List[str] = None,
    ) -> bool:
        """Load a worker and record it under its worker key.

        Args:
            worker: Worker to load and register.
            worker_params: Parameters of the worker; ``worker_type`` is filled
                from the worker when empty and normalised to its string value.
            command_args: Arguments used to parse the model parameters;
                defaults to ``sys.argv[1:]`` when empty.

        Returns:
            True when the worker was added, False when an instance already
            exists under the same key.
        """
        if not command_args:
            command_args = sys.argv[1:]
        worker.load_worker(**asdict(worker_params))

        if not worker_params.worker_type:
            worker_params.worker_type = worker.worker_type()

        if isinstance(worker_params.worker_type, WorkerType):
            worker_params.worker_type = worker_params.worker_type.value

        worker_key = self._worker_key(
            worker_params.worker_type, worker_params.model_name
        )

        # Load model params from persist storage
        model_params = worker.parse_parameters(command_args=command_args)

        worker_run_data = WorkerRunData(
            host=self.host,
            port=self.port,
            worker_key=worker_key,
            worker=worker,
            worker_params=worker_params,
            model_params=model_params,
            stop_event=asyncio.Event(),
            semaphore=asyncio.Semaphore(worker_params.limit_model_concurrency),
            command_args=command_args,
        )
        instances = self.workers.get(worker_key)
        if not instances:
            instances = [worker_run_data]
            self.workers[worker_key] = instances
            logger.info(f"Init empty instances list for {worker_key}")
            return True
        else:
            # TODO Update worker
            logger.warning(f"Instance {worker_key} exist")
            return False

    def _remove_worker(self, worker_params: ModelWorkerParameters) -> None:
        """Drop the instances stored for the worker described by the params."""
        worker_key = self._worker_key(
            worker_params.worker_type, worker_params.model_name
        )
        instances = self.workers.get(worker_key)
        if instances:
            del self.workers[worker_key]

    async def model_startup(self, startup_req: WorkerStartupRequest):
        """Start model"""
        model_name = startup_req.model
        worker_type = startup_req.worker_type
        params = startup_req.params
        logger.debug(
            f"start model, model name {model_name}, worker type {worker_type}, params: {params}"
        )
        worker_params: ModelWorkerParameters = ModelWorkerParameters.from_dict(
            params, ignore_extra_fields=True
        )
        if not worker_params.model_name:
            worker_params.model_name = model_name
        worker = _build_worker(worker_params)
        command_args = _dict_to_command_args(params)
        success = await self.run_blocking_func(
            self.add_worker, worker, worker_params, command_args
        )
        if not success:
            msg = f"Add worker {model_name}@{worker_type}, worker instances is exist"
            logger.warning(f"{msg}, worker_params: {worker_params}")
            # Nothing was added, so nothing is removed here: the entry under
            # this key belongs to the instance that already exists.
            raise Exception(msg)
        supported_types = WorkerType.values()
        if worker_type not in supported_types:
            self._remove_worker(worker_params)
            raise ValueError(
                f"Unsupported worker type: {worker_type}, now supported worker type: {supported_types}"
            )
        start_apply_req = WorkerApplyRequest(
            model=worker_params.model_name,
            apply_type=WorkerApplyType.START,
            worker_type=worker_type,
        )
        try:
            out: WorkerApplyOutput = await self.worker_apply(start_apply_req)
        except Exception:
            # Roll back the entry added above before propagating the error.
            self._remove_worker(worker_params)
            raise
        if not out.success:
            self._remove_worker(worker_params)
            raise Exception(out.message)

    async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
        """Stop the workers of the requested model.

        Raises:
            Exception: If stopping did not succeed.
        """
        logger.info(f"Begin shutdown model, shutdown_req: {shutdown_req}")
        apply_req = WorkerApplyRequest(
            model=shutdown_req.model,
            apply_type=WorkerApplyType.STOP,
            worker_type=shutdown_req.worker_type,
        )
        out = await self._stop_all_worker(apply_req)
        if not out.success:
            raise Exception(out.message)

    async def supported_models(self) -> List[WorkerSupportedModel]:
        """Return the supported models, listed in the thread pool."""
        models = await self.run_blocking_func(list_supported_models)
        return [WorkerSupportedModel(host=self.host, port=self.port, models=models)]

    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Async variant of ``sync_get_model_instances``."""
        return self.sync_get_model_instances(worker_type, model_name, healthy_only)

    async def get_all_model_instances(
        self, worker_type: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Return every instance of ``worker_type``.

        Instances whose ``stopped`` attribute is truthy are skipped when
        ``healthy_only`` is True.
        """
        result = []
        for instance in itertools.chain.from_iterable(self.workers.values()):
            _, instance_type = WorkerType.parse_worker_key(instance.worker_key)
            if instance_type != worker_type or (healthy_only and instance.stopped):
                continue
            result.append(instance)
        return result

    def sync_get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Return the instances stored for the given worker type and model.

        Returns an empty list when nothing is stored under the key.
        """
        # NOTE(review): ``healthy_only`` is accepted but not applied here,
        # unlike get_all_model_instances -- confirm whether that is intended.
        worker_key = self._worker_key(worker_type, model_name)
        return self.workers.get(worker_key, [])

    def _simple_select(
        self, worker_type: str, model_name: str, worker_instances: List[WorkerRunData]
    ) -> WorkerRunData:
        """Pick one instance at random.

        Raises:
            Exception: If ``worker_instances`` is empty.
        """
        if not worker_instances:
            raise Exception(
                f"Cound not found worker instances for model name {model_name} and worker type {worker_type}"
            )
        return random.choice(worker_instances)

    async def select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        """Select one instance of the given worker type and model."""
        worker_instances = await self.get_model_instances(
            worker_type, model_name, healthy_only
        )
        return self._simple_select(worker_type, model_name, worker_instances)

    def sync_select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        """Synchronous variant of ``select_one_instance``."""
        worker_instances = self.sync_get_model_instances(
            worker_type, model_name, healthy_only
        )
        return self._simple_select(worker_type, model_name, worker_instances)

    async def _get_model(self, params: Dict, worker_type: str = "llm") -> WorkerRunData:
        """Select an instance for ``params["model"]``.

        Raises:
            Exception: If ``params`` has no model name.
        """
        model = params.get("model")
        if not model:
            raise Exception("Model name count not be empty")
        return await self.select_one_instance(worker_type, model, healthy_only=True)

    def _sync_get_model(self, params: Dict, worker_type: str = "llm") -> WorkerRunData:
        """Synchronous variant of ``_get_model``."""
        model = params.get("model")
        if not model:
            raise Exception("Model name count not be empty")
        return self.sync_select_one_instance(worker_type, model, healthy_only=True)

    async def generate_stream(
        self, params: Dict, async_wrapper=None, **kwargs
    ) -> AsyncIterator[ModelOutput]:
        """Generate stream result, chat scene

        Yields a single ``ModelOutput`` with ``error_code=1`` when no worker
        can be selected. ``params["span_id"]`` is overwritten with the id of
        the span opened here.
        """
        with root_tracer.start_span(
            "WorkerManager.generate_stream", params.get("span_id")
        ) as span:
            params["span_id"] = span.span_id
            try:
                worker_run_data = await self._get_model(params)
            except Exception as e:
                yield ModelOutput(
                    text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                    error_code=1,
                )
                return
            async with worker_run_data.semaphore:
                if worker_run_data.worker.support_async():
                    async for output in worker_run_data.worker.async_generate_stream(
                        params
                    ):
                        yield output
                else:
                    if not async_wrapper:
                        from starlette.concurrency import iterate_in_threadpool

                        async_wrapper = iterate_in_threadpool
                    # The wrapper turns the worker's sync generator into an
                    # async iterator.
                    async for output in async_wrapper(
                        worker_run_data.worker.generate_stream(params)
                    ):
                        yield output

    async def generate(self, params: Dict) -> ModelOutput:
        """Generate non stream result

        Returns a ``ModelOutput`` with ``error_code=1`` when no worker can be
        selected.
        """
        with root_tracer.start_span(
            "WorkerManager.generate", params.get("span_id")
        ) as span:
            params["span_id"] = span.span_id
            try:
                worker_run_data = await self._get_model(params)
            except Exception as e:
                return ModelOutput(
                    text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                    error_code=1,
                )
            async with worker_run_data.semaphore:
                if worker_run_data.worker.support_async():
                    return await worker_run_data.worker.async_generate(params)
                else:
                    return await self.run_blocking_func(
                        worker_run_data.worker.generate, params
                    )

    async def embeddings(self, params: Dict) -> List[List[float]]:
        """Embed input"""
        with root_tracer.start_span(
            "WorkerManager.embeddings", params.get("span_id")
        ) as span:
            params["span_id"] = span.span_id
            # Selection errors propagate to the caller unchanged.
            worker_run_data = await self._get_model(params, worker_type="text2vec")
            async with worker_run_data.semaphore:
                if worker_run_data.worker.support_async():
                    return await worker_run_data.worker.async_embeddings(params)
                else:
                    return await self.run_blocking_func(
                        worker_run_data.worker.embeddings, params
                    )

    def sync_embeddings(self, params: Dict) -> List[List[float]]:
        """Embed input synchronously on a ``text2vec`` worker."""
        worker_run_data = self._sync_get_model(params, worker_type="text2vec")
        return worker_run_data.worker.embeddings(params)

    async def count_token(self, params: Dict) -> int:
        """Count token of prompt"""
        with root_tracer.start_span(
            "WorkerManager.count_token", params.get("span_id")
        ) as span:
            params["span_id"] = span.span_id
            # Selection errors propagate to the caller unchanged.
            worker_run_data = await self._get_model(params)
            prompt = params.get("prompt")
            async with worker_run_data.semaphore:
                if worker_run_data.worker.support_async():
                    return await worker_run_data.worker.async_count_token(prompt)
                else:
                    return await self.run_blocking_func(
                        worker_run_data.worker.count_token, prompt
                    )

    async def get_model_metadata(self, params: Dict) -> ModelMetadata:
        """Get model metadata"""
        with root_tracer.start_span(
            "WorkerManager.get_model_metadata", params.get("span_id")
        ) as span:
            params["span_id"] = span.span_id
            # Selection errors propagate to the caller unchanged.
            worker_run_data = await self._get_model(params)
            async with worker_run_data.semaphore:
                if worker_run_data.worker.support_async():
                    return await worker_run_data.worker.async_get_model_metadata(params)
                else:
                    return await self.run_blocking_func(
                        worker_run_data.worker.get_model_metadata, params
                    )

    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        """Dispatch an apply request to the handler of its apply type.

        Raises:
            ValueError: If the apply type is not supported.
        """
        if apply_req.apply_type == WorkerApplyType.START:
            apply_func = self._start_all_worker
        elif apply_req.apply_type == WorkerApplyType.STOP:
            apply_func = self._stop_all_worker
        elif apply_req.apply_type == WorkerApplyType.RESTART:
            apply_func = self._restart_all_worker
        elif apply_req.apply_type == WorkerApplyType.UPDATE_PARAMS:
            apply_func = self._update_all_worker_params
        else:
            raise ValueError(f"Unsupported apply type {apply_req.apply_type}")
        return await apply_func(apply_req)

    async def parameter_descriptions(
        self, worker_type: str, model_name: str
    ) -> List[ParameterDescription]:
        """Return the parameter descriptions of the first matching worker.

        Raises:
            Exception: If no instance exists for the worker type and model.
        """
        worker_instances = await self.get_model_instances(worker_type, model_name)
        if not worker_instances:
            raise Exception(
                f"Not worker instances for model name {model_name} worker type {worker_type}"
            )
        worker_run_data = worker_instances[0]
        return worker_run_data.worker.parameter_descriptions()

    async def _apply_worker(
        self, apply_req: WorkerApplyRequest, apply_func: ApplyFunction
    ) -> list:
        """Apply function to worker instances in parallel

        Args:
            apply_req (WorkerApplyRequest): Worker apply request; a falsy
                value applies the function to all workers.
            apply_func (ApplyFunction): Function to apply to worker instances, now function is async function

        Returns:
            The results of ``apply_func``, one per instance, as gathered by
            ``asyncio.gather``.

        Raises:
            Exception: If ``apply_req`` matches no worker instance.
        """
        logger.info(f"Apply req: {apply_req}, apply_func: {apply_func}")
        if apply_req:
            worker_type = apply_req.worker_type.value
            model_name = apply_req.model
            worker_instances = await self.get_model_instances(
                worker_type, model_name, healthy_only=False
            )
            if not worker_instances:
                raise Exception(
                    f"No worker instance found for the model {model_name} worker type {worker_type}"
                )
        else:
            # Apply to all workers
            worker_instances = list(itertools.chain(*self.workers.values()))
            logger.info("Apply to all workers")
        return await asyncio.gather(
            *(apply_func(worker) for worker in worker_instances)
        )

    async def _start_all_worker(
        self, apply_req: WorkerApplyRequest
    ) -> WorkerApplyOutput:
        """Start the selected workers; register them and start heartbeats.

        Failures of a single worker are reported in its output rather than
        raised; all outputs are reduced into the returned one.
        """
        from httpx import TimeoutException, TransportError

        # TODO avoid start twice
        start_time = time.time()
        logger.info(f"Begin start all worker, apply_req: {apply_req}")

        async def _start_worker(worker_run_data: WorkerRunData):
            _start_time = time.time()
            info = worker_run_data._to_print_key()
            out = WorkerApplyOutput("")
            try:
                await self.run_blocking_func(
                    worker_run_data.worker.start,
                    worker_run_data.model_params,
                    worker_run_data.command_args,
                )
                worker_run_data.stop_event.clear()
                if worker_run_data.worker_params.register and self.register_func:
                    # Register worker to controller
                    await self.register_func(worker_run_data)
                    if (
                        worker_run_data.worker_params.send_heartbeat
                        and self.send_heartbeat_func
                    ):
                        self._spawn_background_task(
                            _async_heartbeat_sender(
                                worker_run_data,
                                worker_run_data.worker_params.heartbeat_interval,
                                self.send_heartbeat_func,
                            )
                        )
                out.message = f"{info} start successfully"
            except TimeoutException:
                out.success = False
                out.message = (
                    f"{info} start failed for network timeout, please make "
                    f"sure your port is available, if you are using global network "
                    f"proxy, please close it"
                )
            except TransportError:
                out.success = False
                out.message = (
                    f"{info} start failed for network error, please make "
                    f"sure your port is available, if you are using global network "
                    "proxy, please close it"
                )
            except Exception:
                err_msg = traceback.format_exc()
                out.success = False
                out.message = f"{info} start failed, {err_msg}"
            finally:
                out.timecost = time.time() - _start_time
            return out

        outs = await self._apply_worker(apply_req, _start_worker)
        out = WorkerApplyOutput.reduce(outs)
        out.timecost = time.time() - start_time
        return out

    async def _stop_all_worker(
        self, apply_req: WorkerApplyRequest, ignore_exception: bool = False
    ) -> WorkerApplyOutput:
        """Stop the selected workers, deregister them and drop their entries.

        Args:
            apply_req: Selects the workers; a falsy value selects all.
            ignore_exception: When True, deregistration errors are only
                logged.
        """
        start_time = time.time()

        async def _stop_worker(worker_run_data: WorkerRunData):
            _start_time = time.time()
            info = worker_run_data._to_print_key()
            out = WorkerApplyOutput("")
            try:
                await self.run_blocking_func(worker_run_data.worker.stop)
                # Set stop event
                worker_run_data.stop_event.set()
                if worker_run_data._heartbeat_future:
                    # Wait thread finish
                    worker_run_data._heartbeat_future.result()
                    worker_run_data._heartbeat_future = None
                if (
                    worker_run_data.worker_params.register
                    and self.register_func
                    and self.deregister_func
                ):
                    if ignore_exception:
                        await self._safe_deregister(worker_run_data)
                    else:
                        await self.deregister_func(worker_run_data)
                # Remove metadata
                self._remove_worker(worker_run_data.worker_params)
                out.message = f"{info} stop successfully"
            except Exception as e:
                out.success = False
                out.message = f"{info} stop failed, {str(e)}"
            finally:
                out.timecost = time.time() - _start_time
            return out

        outs = await self._apply_worker(apply_req, _stop_worker)
        out = WorkerApplyOutput.reduce(outs)
        out.timecost = time.time() - start_time
        return out

    async def _restart_all_worker(
        self, apply_req: WorkerApplyRequest
    ) -> WorkerApplyOutput:
        """Stop then start the selected workers.

        Returns the stop output without starting when stopping failed.
        """
        out = await self._stop_all_worker(apply_req, ignore_exception=True)
        if not out.success:
            return out
        return await self._start_all_worker(apply_req)

    async def _update_all_worker_params(
        self, apply_req: WorkerApplyRequest
    ) -> WorkerApplyOutput:
        """Update the model params of the selected workers.

        The workers are restarted when any ``update_from`` call returns a
        truthy value; a failed restart is returned as the result instead of
        being reported as success.
        """
        start_time = time.time()
        need_restart = False

        async def update_params(worker_run_data: WorkerRunData):
            nonlocal need_restart
            new_params = apply_req.params
            if not new_params:
                return
            if worker_run_data.model_params.update_from(new_params):
                need_restart = True

        await self._apply_worker(apply_req, update_params)
        message = "Update worker params successfully"
        timecost = time.time() - start_time
        if need_restart:
            logger.info("Model params update successfully, begin restart worker")
            restart_out = await self._restart_all_worker(apply_req)
            timecost = time.time() - start_time
            if not restart_out.success:
                restart_out.timecost = timecost
                return restart_out
            message = "Update worker params and restart successfully"
        return WorkerApplyOutput(message=message, timecost=timecost)
|
|
|
|
|
|
class WorkerManagerAdapter(WorkerManager):
    """Worker manager that forwards every call to a wrapped manager.

    The wrapped manager is the public ``worker_manager`` attribute, so it can
    be assigned after the adapter has been created.
    """

    def __init__(self, worker_manager: WorkerManager = None) -> None:
        """Store the manager that all calls are forwarded to."""
        self.worker_manager = worker_manager

    async def start(self):
        """Forward to the wrapped manager's ``start``."""
        target = self.worker_manager
        return await target.start()

    async def stop(self, ignore_exception: bool = False):
        """Forward to the wrapped manager's ``stop``."""
        target = self.worker_manager
        return await target.stop(ignore_exception=ignore_exception)

    def after_start(self, listener: Callable[["WorkerManager"], None]):
        """Register ``listener`` on the wrapped manager; ``None`` is ignored."""
        if listener is None:
            return
        self.worker_manager.after_start(listener)

    async def supported_models(self) -> List[WorkerSupportedModel]:
        """Forward to the wrapped manager's ``supported_models``."""
        target = self.worker_manager
        return await target.supported_models()

    async def model_startup(self, startup_req: WorkerStartupRequest):
        """Forward to the wrapped manager's ``model_startup``."""
        target = self.worker_manager
        return await target.model_startup(startup_req)

    async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
        """Forward to the wrapped manager's ``model_shutdown``."""
        target = self.worker_manager
        return await target.model_shutdown(shutdown_req)

    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Forward to the wrapped manager's ``get_model_instances``."""
        target = self.worker_manager
        return await target.get_model_instances(worker_type, model_name, healthy_only)

    async def get_all_model_instances(
        self, worker_type: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Forward to the wrapped manager's ``get_all_model_instances``."""
        target = self.worker_manager
        return await target.get_all_model_instances(worker_type, healthy_only)

    def sync_get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Forward to the wrapped manager's ``sync_get_model_instances``."""
        target = self.worker_manager
        return target.sync_get_model_instances(worker_type, model_name, healthy_only)

    async def select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        """Forward to the wrapped manager's ``select_one_instance``."""
        target = self.worker_manager
        return await target.select_one_instance(worker_type, model_name, healthy_only)

    def sync_select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        """Forward to the wrapped manager's ``sync_select_one_instance``."""
        target = self.worker_manager
        return target.sync_select_one_instance(worker_type, model_name, healthy_only)

    async def generate_stream(
        self, params: Dict, **kwargs
    ) -> AsyncIterator[ModelOutput]:
        """Re-yield every output of the wrapped manager's ``generate_stream``."""
        stream = self.worker_manager.generate_stream(params, **kwargs)
        async for output in stream:
            yield output

    async def generate(self, params: Dict) -> ModelOutput:
        """Forward to the wrapped manager's ``generate``."""
        target = self.worker_manager
        return await target.generate(params)

    async def embeddings(self, params: Dict) -> List[List[float]]:
        """Forward to the wrapped manager's ``embeddings``."""
        target = self.worker_manager
        return await target.embeddings(params)

    def sync_embeddings(self, params: Dict) -> List[List[float]]:
        """Forward to the wrapped manager's ``sync_embeddings``."""
        target = self.worker_manager
        return target.sync_embeddings(params)

    async def count_token(self, params: Dict) -> int:
        """Forward to the wrapped manager's ``count_token``."""
        target = self.worker_manager
        return await target.count_token(params)

    async def get_model_metadata(self, params: Dict) -> ModelMetadata:
        """Forward to the wrapped manager's ``get_model_metadata``."""
        target = self.worker_manager
        return await target.get_model_metadata(params)

    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        """Forward to the wrapped manager's ``worker_apply``."""
        target = self.worker_manager
        return await target.worker_apply(apply_req)

    async def parameter_descriptions(
        self, worker_type: str, model_name: str
    ) -> List[ParameterDescription]:
        """Forward to the wrapped manager's ``parameter_descriptions``."""
        target = self.worker_manager
        return await target.parameter_descriptions(worker_type, model_name)
|
|
|
|
|
|
class _DefaultWorkerManagerFactory(WorkerManagerFactory):
    """Factory that always hands out the worker manager it was given."""

    def __init__(
        self,
        system_app: Optional[SystemApp] = None,
        worker_manager: WorkerManager = None,
    ):
        """Store the manager returned by ``create``.

        Args:
            system_app: Passed through to the base factory.
            worker_manager: Manager instance returned by ``create``.
        """
        # ``Optional[...]`` matches the annotation style used by
        # ``_setup_fastapi`` and, unlike ``SystemApp | None``, does not need
        # the ``|`` operator to be supported on classes at definition time.
        super().__init__(system_app)
        self.worker_manager = worker_manager

    def create(self) -> WorkerManager:
        """Return the stored worker manager."""
        return self.worker_manager
|
|
|
|
|
|
# Module-wide adapter used by the API endpoints defined in this module; it is
# created without a wrapped manager.
worker_manager: WorkerManagerAdapter = WorkerManagerAdapter()

# Router that collects the worker endpoints declared in this module.
router: APIRouter = APIRouter()
|
|
|
|
|
|
async def generate_json_stream(params):
    """Stream model outputs as NUL-terminated JSON byte chunks.

    Each output of ``worker_manager.generate_stream`` is converted with
    ``asdict``, dumped as JSON without ASCII escaping, encoded to bytes and
    followed by a ``b"\\0"`` delimiter.
    """
    from starlette.concurrency import iterate_in_threadpool

    outputs = worker_manager.generate_stream(
        params, async_wrapper=iterate_in_threadpool
    )
    async for output in outputs:
        payload = json.dumps(asdict(output), ensure_ascii=False)
        yield payload.encode() + b"\0"
|
|
|
|
|
|
def _params_with_span_id(request) -> dict:
    """Dump a request model to a dict and attach the current span id.

    ``None`` fields are excluded. The tracer's current span id is added under
    ``"span_id"`` only when the request did not carry one and the tracer
    reports a non-empty id.
    """
    params = request.dict(exclude_none=True)
    span_id = root_tracer.get_current_span_id()
    if "span_id" not in params and span_id:
        params["span_id"] = span_id
    return params


@router.post("/worker/generate_stream")
async def api_generate_stream(request: PromptRequest):
    """Stream generation results as a ``StreamingResponse``."""
    params = _params_with_span_id(request)
    generator = generate_json_stream(params)
    return StreamingResponse(generator)


@router.post("/worker/generate")
async def api_generate(request: PromptRequest):
    """Return the non-streaming generation result."""
    return await worker_manager.generate(_params_with_span_id(request))


@router.post("/worker/embeddings")
async def api_embeddings(request: EmbeddingsRequest):
    """Return the embeddings computed for the request."""
    return await worker_manager.embeddings(_params_with_span_id(request))


@router.post("/worker/count_token")
async def api_count_token(request: CountTokenRequest):
    """Return the token count computed for the request."""
    return await worker_manager.count_token(_params_with_span_id(request))


@router.post("/worker/model_metadata")
async def api_get_model_metadata(request: ModelMetadataRequest):
    """Return the metadata of the requested model."""
    return await worker_manager.get_model_metadata(_params_with_span_id(request))
|
|
|
|
|
|
@router.post("/worker/apply")
async def api_worker_apply(request: WorkerApplyRequest):
    """Forward an apply request to the module-level worker manager."""
    apply_output = await worker_manager.worker_apply(request)
    return apply_output
|
|
|
|
|
|
@router.get("/worker/parameter/descriptions")
async def api_worker_parameter_descs(
    model: str, worker_type: str = WorkerType.LLM.value
):
    """Return the parameter descriptions for a model and worker type."""
    descriptions = await worker_manager.parameter_descriptions(worker_type, model)
    return descriptions
|
|
|
|
|
|
@router.get("/worker/models/supports")
async def api_supported_models():
    """Get all supported models.

    This method reads all models from the configuration file and tries to
    perform some basic checks on the model (like if the path exists).

    If it's a RemoteWorkerManager, this method returns the list of models
    supported by the entire cluster.
    """
    supported = await worker_manager.supported_models()
    return supported
|
|
|
|
|
|
@router.post("/worker/models/startup")
async def api_model_startup(request: WorkerStartupRequest):
    """Start up a specific model."""
    startup_result = await worker_manager.model_startup(request)
    return startup_result
|
|
|
|
|
|
@router.post("/worker/models/shutdown")
async def api_model_shutdown(request: WorkerStartupRequest):
    """Shut down a specific model."""
    shutdown_result = await worker_manager.model_shutdown(request)
    return shutdown_result
|
|
|
|
|
|
def _setup_fastapi(
    worker_params: ModelWorkerParameters,
    app=None,
    ignore_exception: bool = False,
    system_app: Optional[SystemApp] = None,
):
    """Prepare the FastAPI app that hosts the worker manager.

    Args:
        worker_params: Worker parameters; in standalone mode without a
            controller address, ``controller_addr`` is set to the local port.
        app: Existing app to use; a new one is created (and HTTP service
            logging set up) when falsy.
        ignore_exception: Passed to ``worker_manager.stop`` on shutdown.
        system_app: When given, its ``_asgi_app`` is set to the app.

    Returns:
        The app with startup and shutdown handlers registered.
    """
    if not app:
        app = create_app()
        setup_http_service_logging()

    if system_app:
        system_app._asgi_app = app

    if worker_params.standalone:
        from dbgpt.model.cluster.controller.controller import initialize_controller
        from dbgpt.model.cluster.controller.controller import (
            router as controller_router,
        )

        if not worker_params.controller_addr:
            # if we have http_proxy or https_proxy in env, the server can not start
            # so set it to empty here
            os.environ["http_proxy"] = ""
            os.environ["https_proxy"] = ""
            worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"
        logger.info(
            f"Run WorkerManager with standalone mode, controller_addr: {worker_params.controller_addr}"
        )
        initialize_controller(app=app, system_app=system_app)
        app.include_router(controller_router, prefix="/api")

    # Strong references to the startup task: the event loop only keeps weak
    # references, so an unreferenced task may be collected before it finishes.
    background_tasks = set()

    async def startup_event():
        async def start_worker_manager():
            try:
                await worker_manager.start()
            except Exception as e:
                import signal

                logger.error(f"Error starting worker manager: {str(e)}")
                # Interrupt the current process when the manager cannot start.
                os.kill(os.getpid(), signal.SIGINT)

        # It cannot be blocked here because the startup of worker_manager depends on
        # the fastapi app (registered to the controller)
        task = asyncio.create_task(start_worker_manager())
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

    async def shutdown_event():
        await worker_manager.stop(ignore_exception=ignore_exception)

    register_event_handler(app, "startup", startup_event)
    register_event_handler(app, "shutdown", shutdown_event)
    return app
|
|
|
|
|
|
def _parse_worker_params(
    model_name: str = None, model_path: str = None, **kwargs
) -> ModelWorkerParameters:
    """Parse ``ModelWorkerParameters`` in two passes.

    The first pass uses the env prefix of ``model_name`` (or ``None`` when no
    name is given). The second pass parses again with the prefix of the model
    name resolved by the first pass and is merged into the first result via
    ``update_from``. A non-empty ``model_alias`` finally replaces
    ``model_name``.
    """
    parser = EnvArgumentParser()

    def _parse(prefix, name, path) -> ModelWorkerParameters:
        # One parsing pass for the given env prefix and model identity.
        return parser.parse_args_into_dataclass(
            ModelWorkerParameters,
            env_prefixes=[prefix],
            model_name=name,
            model_path=path,
            **kwargs,
        )

    first_prefix = EnvArgumentParser.get_env_prefix(model_name) if model_name else None
    worker_params: ModelWorkerParameters = _parse(first_prefix, model_name, model_path)

    # Read parameters again with prefix of model name.
    resolved_prefix = EnvArgumentParser.get_env_prefix(worker_params.model_name)
    refreshed = _parse(
        resolved_prefix, worker_params.model_name, worker_params.model_path
    )
    worker_params.update_from(refreshed)

    alias = worker_params.model_alias
    if alias:
        worker_params.model_name = alias

    return worker_params
|
|
|
|
|
|
def _create_local_model_manager(
    worker_params: ModelWorkerParameters,
) -> LocalWorkerManager:
    """Create a ``LocalWorkerManager`` for the given worker parameters.

    The advertised host is ``worker_params.worker_register_host`` when set,
    otherwise the value returned by ``_get_ip_address()``. If registration is
    disabled or no controller address is configured, the manager is created
    without any registry callbacks. Otherwise the manager receives register,
    deregister and heartbeat callbacks that call a ``ModelRegistryClient``
    built from ``worker_params.controller_addr``.

    Args:
        worker_params: Parameters supplying host, port, the register flag and
            the controller address.

    Returns:
        The newly created ``LocalWorkerManager``.
    """
    from dbgpt.util.net_utils import _get_ip_address

    host = worker_params.worker_register_host or _get_ip_address()
    port = worker_params.port

    if not (worker_params.register and worker_params.controller_addr):
        logger.info(
            f"Not register current to controller, register: {worker_params.register}, controller_addr: {worker_params.controller_addr}"
        )
        return LocalWorkerManager(host=host, port=port)

    from dbgpt.model.cluster.controller.controller import ModelRegistryClient

    registry_client = ModelRegistryClient(worker_params.controller_addr)

    def _to_instance(run_data: WorkerRunData) -> ModelInstance:
        # Every callback describes the worker to the registry the same way.
        return ModelInstance(model_name=run_data.worker_key, host=host, port=port)

    async def register_func(worker_run_data: WorkerRunData):
        instance = _to_instance(worker_run_data)
        return await registry_client.register_instance(instance)

    async def deregister_func(worker_run_data: WorkerRunData):
        instance = _to_instance(worker_run_data)
        return await registry_client.deregister_instance(instance)

    async def send_heartbeat_func(worker_run_data: WorkerRunData):
        instance = _to_instance(worker_run_data)
        return await registry_client.send_heartbeat(instance)

    return LocalWorkerManager(
        register_func=register_func,
        deregister_func=deregister_func,
        send_heartbeat_func=send_heartbeat_func,
        host=host,
        port=port,
    )
|
|
|
|
|
|
def _build_worker(
    worker_params: ModelWorkerParameters,
    ext_worker_kwargs: Optional[Dict[str, Any]] = None,
):
    """Instantiate the worker implementation selected by ``worker_params``.

    Selection order:
      1. ``worker_params.worker_class`` (a dotted import string), when set, is
         imported via ``import_from_checked_string`` with ``ModelWorker`` as
         the expected type.
      2. Otherwise ``worker_params.worker_type`` decides: ``None`` or
         ``WorkerType.LLM`` selects ``DefaultModelWorker``;
         ``WorkerType.TEXT2VEC`` selects ``EmbeddingsModelWorker``.

    Args:
        worker_params: Parameters carrying ``worker_class`` / ``worker_type``.
        ext_worker_kwargs: Optional keyword arguments passed to the worker
            class constructor. ``None`` or an empty dict means no arguments.

    Returns:
        A new instance of the selected worker class.

    Raises:
        Exception: If no ``worker_class`` is given and ``worker_type`` is not
            one of the supported values.
    """
    worker_class = worker_params.worker_class
    worker_type = worker_params.worker_type

    if worker_class:
        from dbgpt.util.module_utils import import_from_checked_string

        worker_cls = import_from_checked_string(worker_class, ModelWorker)
        logger.info(f"Import worker class from {worker_class} successfully")
    elif worker_type is None or worker_type == WorkerType.LLM:
        from dbgpt.model.cluster.worker.default_worker import DefaultModelWorker

        worker_cls = DefaultModelWorker
    elif worker_type == WorkerType.TEXT2VEC:
        from dbgpt.model.cluster.worker.embedding_worker import (
            EmbeddingsModelWorker,
        )

        worker_cls = EmbeddingsModelWorker
    else:
        # The message must be an f-string; previously the placeholder was
        # emitted literally instead of the offending worker type.
        raise Exception(f"Unsupported worker type: {worker_type}")

    # A falsy ext_worker_kwargs (None or {}) means "construct with defaults".
    return worker_cls(**(ext_worker_kwargs or {}))
|
|
|
|
|
|
def _start_local_worker(
    worker_manager: WorkerManagerAdapter,
    worker_params: ModelWorkerParameters,
    ext_worker_kwargs: Optional[Dict[str, Any]] = None,
):
    """Build a worker and add it to the adapter's underlying manager.

    The work runs inside a tracer span named
    ``WorkerManager._start_local_worker`` whose metadata records the worker
    parameters and the system information. If the adapter has no underlying
    manager yet, one is created with ``_create_local_model_manager`` first.

    Args:
        worker_manager: Adapter whose ``worker_manager`` attribute receives
            the new worker.
        worker_params: Parameters used to build and register the worker.
        ext_worker_kwargs: Optional keyword arguments forwarded to
            ``_build_worker``.
    """
    span_metadata = {
        "run_service": SpanTypeRunName.WORKER_MANAGER,
        "params": _get_dict_from_obj(worker_params),
        "sys_infos": _get_dict_from_obj(get_system_info()),
    }
    with root_tracer.start_span(
        "WorkerManager._start_local_worker",
        span_type=SpanType.RUN,
        metadata=span_metadata,
    ):
        new_worker = _build_worker(worker_params, ext_worker_kwargs=ext_worker_kwargs)
        if not worker_manager.worker_manager:
            # Create the local manager on first use.
            worker_manager.worker_manager = _create_local_model_manager(worker_params)
        worker_manager.worker_manager.add_worker(new_worker, worker_params)
|
|
|
|
|
|
def _start_local_embedding_worker(
    worker_manager: WorkerManagerAdapter,
    embedding_model_name: str = None,
    embedding_model_path: str = None,
    ext_worker_kwargs: Optional[Dict[str, Any]] = None,
):
    """Start a local ``TEXT2VEC`` worker when both name and path are given.

    Does nothing if either ``embedding_model_name`` or
    ``embedding_model_path`` is empty. Otherwise builds
    ``ModelWorkerParameters`` for the ``EmbeddingsModelWorker`` class, logs
    them and delegates to ``_start_local_worker``.

    Args:
        worker_manager: Adapter that receives the new worker.
        embedding_model_name: Name of the model to serve.
        embedding_model_path: Path of the model to serve.
        ext_worker_kwargs: Optional keyword arguments forwarded to
            ``_start_local_worker``.
    """
    if not (embedding_model_name and embedding_model_path):
        return

    embedding_worker_params = ModelWorkerParameters(
        worker_class="dbgpt.model.cluster.worker.embedding_worker.EmbeddingsModelWorker",
        worker_type=WorkerType.TEXT2VEC,
        model_path=embedding_model_path,
        model_name=embedding_model_name,
    )
    logger.info(
        f"Start local embedding worker with embedding parameters\n{embedding_worker_params}"
    )
    _start_local_worker(
        worker_manager,
        embedding_worker_params,
        ext_worker_kwargs=ext_worker_kwargs,
    )
|
|
|
|
|
|
def initialize_worker_manager_in_client(
    app=None,
    include_router: bool = True,
    model_name: Optional[str] = None,
    model_path: Optional[str] = None,
    run_locally: bool = True,
    controller_addr: Optional[str] = None,
    local_port: int = 5670,
    embedding_model_name: Optional[str] = None,
    embedding_model_path: Optional[str] = None,
    rerank_model_name: Optional[str] = None,
    rerank_model_path: Optional[str] = None,
    start_listener: Optional[Callable[["WorkerManager"], None]] = None,
    system_app: Optional[SystemApp] = None,
):
    """Initialize WorkerManager in client.

    If run_locally is True:
        1. Start ModelController
        2. Start LocalWorkerManager
        3. Start worker in LocalWorkerManager
        4. Register worker to ModelController

    otherwise:
        1. Build ModelRegistryClient with controller address
        2. Start RemoteWorkerManager

    Args:
        app: The FastAPI application to attach to. Required.
        include_router: Whether to mount the WorkerManager router under
            ``/api`` on ``app``.
        model_name: Name of the model, used when parsing worker parameters.
        model_path: Path of the model, used when parsing worker parameters.
        run_locally: Selects the local (standalone) or the remote mode
            described above.
        controller_addr: Controller address, used when parsing worker
            parameters.
        local_port: Port assigned to the worker parameters in local mode.
        embedding_model_name: Name of an embedding model to start in local
            mode (only started when the path is given as well).
        embedding_model_path: Path of that embedding model.
        rerank_model_name: Name of a rerank model to start in local mode; it
            is started through the embedding worker helper with
            ``rerank_model=True``.
        rerank_model_path: Path of that rerank model.
        start_listener: Callback handed to ``worker_manager.after_start``.
        system_app: When given, the default WorkerManager factory is
            registered on it.

    Raises:
        Exception: If ``app`` is falsy.
        ValueError: In remote mode, if no controller address is available
            after parsing the worker parameters.
    """
    global worker_manager

    if not app:
        raise Exception("app can't be None")

    if system_app:
        logger.info(f"Register WorkerManager {_DefaultWorkerManagerFactory.name}")
        system_app.register(_DefaultWorkerManagerFactory, worker_manager)

    worker_params: ModelWorkerParameters = _parse_worker_params(
        model_name=model_name, model_path=model_path, controller_addr=controller_addr
    )

    if run_locally:
        # TODO start ModelController
        worker_params.standalone = True
        worker_params.register = True
        worker_params.port = local_port
        logger.info(f"Worker params: {worker_params}")
        _setup_fastapi(worker_params, app, ignore_exception=True, system_app=system_app)
        _start_local_worker(worker_manager, worker_params)
        worker_manager.after_start(start_listener)
        _start_local_embedding_worker(
            worker_manager, embedding_model_name, embedding_model_path
        )
        # Rerank models reuse the embedding worker, flagged via ext kwargs.
        _start_local_embedding_worker(
            worker_manager,
            rerank_model_name,
            rerank_model_path,
            ext_worker_kwargs={"rerank_model": True},
        )
    else:
        from dbgpt.model.cluster.controller.controller import (
            ModelRegistryClient,
            initialize_controller,
        )
        from dbgpt.model.cluster.worker.remote_manager import RemoteWorkerManager

        if not worker_params.controller_addr:
            raise ValueError("Controller can`t be None")
        logger.info(f"Worker params: {worker_params}")
        client = ModelRegistryClient(worker_params.controller_addr)
        worker_manager.worker_manager = RemoteWorkerManager(client)
        worker_manager.after_start(start_listener)
        initialize_controller(
            app=app,
            remote_controller_addr=worker_params.controller_addr,
            system_app=system_app,
        )
        loop = asyncio.get_event_loop()
        loop.run_until_complete(worker_manager.start())

    # ``app`` was validated above, so only the flag needs checking here.
    if include_router:
        # mount WorkerManager router
        app.include_router(router, prefix="/api")
|
|
|
|
|
|
def run_worker_manager(
    app=None,
    include_router: bool = True,
    model_name: str = None,
    model_path: str = None,
    standalone: bool = False,
    port: int = None,
    embedding_model_name: str = None,
    embedding_model_path: str = None,
    start_listener: Callable[["WorkerManager"], None] = None,
    **kwargs,
):
    """Start the worker manager, either standalone or embedded in ``app``.

    When no ``app`` is supplied, a FastAPI app is created with
    ``_setup_fastapi`` and served with uvicorn on the host and port from the
    parsed worker parameters. When an ``app`` is supplied (embedded mode), the
    worker manager is started on the current event loop instead.

    Args:
        app: Existing FastAPI application to embed into, or None to run
            independently.
        include_router: Whether to mount the WorkerManager router under
            ``/api``.
        model_name: Model name used when parsing worker parameters.
        model_path: Model path used when parsing worker parameters.
        standalone: Forwarded to the worker parameter parsing.
        port: Forwarded to the worker parameter parsing.
        embedding_model_name: Name of an optional local embedding model.
        embedding_model_path: Path of that embedding model.
        start_listener: Callback handed to ``worker_manager.after_start``.
        **kwargs: Extra values forwarded to the worker parameter parsing.
    """
    global worker_manager

    worker_params: ModelWorkerParameters = _parse_worker_params(
        model_name=model_name,
        model_path=model_path,
        standalone=standalone,
        port=port,
        **kwargs,
    )

    setup_logging(
        "dbgpt",
        logging_level=worker_params.log_level,
        logger_filename=worker_params.log_file,
    )
    logger.info(f"Worker params: {worker_params}")

    system_app = SystemApp()
    # Without a host application the manager creates and serves its own app.
    standalone_run = not app
    if standalone_run:
        # Run worker manager independently
        app = _setup_fastapi(worker_params, system_app=system_app)
        system_app._asgi_app = app

    tracer_options = {
        "system_app": system_app,
        "root_operation_name": "DB-GPT-ModelWorker",
        "tracer_storage_cls": worker_params.tracer_storage_cls,
        "enable_open_telemetry": worker_params.tracer_to_open_telemetry,
        "otlp_endpoint": worker_params.otel_exporter_otlp_traces_endpoint,
        "otlp_insecure": worker_params.otel_exporter_otlp_traces_insecure,
        "otlp_timeout": worker_params.otel_exporter_otlp_traces_timeout,
    }
    initialize_tracer(
        os.path.join(LOGDIR, worker_params.tracer_file), **tracer_options
    )

    _start_local_worker(worker_manager, worker_params)
    _start_local_embedding_worker(
        worker_manager, embedding_model_name, embedding_model_path
    )

    worker_manager.after_start(start_listener)

    if include_router:
        app.include_router(router, prefix="/api")

    if standalone_run:
        import uvicorn

        uvicorn.run(
            app, host=worker_params.host, port=worker_params.port, log_level="info"
        )
        return

    # Embedded mod, start worker manager
    event_loop = asyncio.get_event_loop()
    event_loop.run_until_complete(worker_manager.start())
|
|
|
|
|
|
# When executed as a script, run the worker manager with its default arguments.
if __name__ == "__main__":
    run_worker_manager()
|