feat(model): Support deploy rerank model (#1522)

Author: Fangyin Cheng
Date: 2024-05-16 14:50:16 +08:00
Committed by: GitHub
Parent: 559affe87d
Commit: 593e974405

29 changed files with 814 additions and 75 deletions


@@ -952,7 +952,10 @@ def _create_local_model_manager(
     )


-def _build_worker(worker_params: ModelWorkerParameters):
+def _build_worker(
+    worker_params: ModelWorkerParameters,
+    ext_worker_kwargs: Optional[Dict[str, Any]] = None,
+):
     worker_class = worker_params.worker_class
     if worker_class:
         from dbgpt.util.module_utils import import_from_checked_string
@@ -976,11 +979,16 @@ def _build_worker(worker_params: ModelWorkerParameters):
     else:
         raise Exception(f"Unsupported worker type: {worker_params.worker_type}")
-    return worker_cls()
+    if ext_worker_kwargs:
+        return worker_cls(**ext_worker_kwargs)
+    else:
+        return worker_cls()


 def _start_local_worker(
-    worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters
+    worker_manager: WorkerManagerAdapter,
+    worker_params: ModelWorkerParameters,
+    ext_worker_kwargs: Optional[Dict[str, Any]] = None,
 ):
     with root_tracer.start_span(
         "WorkerManager._start_local_worker",
@@ -991,7 +999,7 @@ def _start_local_worker(
"sys_infos": _get_dict_from_obj(get_system_info()),
},
):
worker = _build_worker(worker_params)
worker = _build_worker(worker_params, ext_worker_kwargs=ext_worker_kwargs)
if not worker_manager.worker_manager:
worker_manager.worker_manager = _create_local_model_manager(worker_params)
worker_manager.worker_manager.add_worker(worker, worker_params)
@@ -1001,6 +1009,7 @@ def _start_local_embedding_worker(
     worker_manager: WorkerManagerAdapter,
     embedding_model_name: str = None,
     embedding_model_path: str = None,
+    ext_worker_kwargs: Optional[Dict[str, Any]] = None,
 ):
     if not embedding_model_name or not embedding_model_path:
         return
@@ -1013,21 +1022,25 @@ def _start_local_embedding_worker(
     logger.info(
         f"Start local embedding worker with embedding parameters\n{embedding_worker_params}"
     )
-    _start_local_worker(worker_manager, embedding_worker_params)
+    _start_local_worker(
+        worker_manager, embedding_worker_params, ext_worker_kwargs=ext_worker_kwargs
+    )


 def initialize_worker_manager_in_client(
     app=None,
     include_router: bool = True,
-    model_name: str = None,
-    model_path: str = None,
+    model_name: Optional[str] = None,
+    model_path: Optional[str] = None,
     run_locally: bool = True,
-    controller_addr: str = None,
+    controller_addr: Optional[str] = None,
     local_port: int = 5670,
-    embedding_model_name: str = None,
-    embedding_model_path: str = None,
-    start_listener: Callable[["WorkerManager"], None] = None,
-    system_app: SystemApp = None,
+    embedding_model_name: Optional[str] = None,
+    embedding_model_path: Optional[str] = None,
+    rerank_model_name: Optional[str] = None,
+    rerank_model_path: Optional[str] = None,
+    start_listener: Optional[Callable[["WorkerManager"], None]] = None,
+    system_app: Optional[SystemApp] = None,
 ):
     """Initialize WorkerManager in client.
     If run_locally is True:
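
With the widened signature, a local deployment can start a rerank worker next to the LLM and embedding workers. A hedged usage sketch (the model names and paths below are illustrative, not taken from this commit):

    # Assumes initialize_worker_manager_in_client is importable from
    # dbgpt.model.cluster, as elsewhere in the codebase.
    from dbgpt.model.cluster import initialize_worker_manager_in_client

    initialize_worker_manager_in_client(
        model_name="vicuna-13b-v1.5",
        model_path="/data/models/vicuna-13b-v1.5",
        embedding_model_name="text2vec",
        embedding_model_path="/data/models/text2vec-large-chinese",
        rerank_model_name="bge-reranker-base",
        rerank_model_path="/data/models/bge-reranker-base",
    )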
@@ -1063,6 +1076,12 @@ def initialize_worker_manager_in_client(
         _start_local_embedding_worker(
             worker_manager, embedding_model_name, embedding_model_path
         )
+        _start_local_embedding_worker(
+            worker_manager,
+            rerank_model_name,
+            rerank_model_path,
+            ext_worker_kwargs={"rerank_model": True},
+        )
     else:
         from dbgpt.model.cluster.controller.controller import (
             ModelRegistryClient,
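
Note that the rerank model is deployed through the same `_start_local_embedding_worker` path; the only difference is the `rerank_model=True` flag forwarded via `ext_worker_kwargs`. A sketch of how a worker constructor could consume that flag (hypothetical class, not this commit's actual embedding worker):

    class EmbeddingLikeWorker:
        # Hypothetical stand-in: one constructor flag switches the
        # worker between embedding mode and rerank mode, matching
        # ext_worker_kwargs={"rerank_model": True} above.
        def __init__(self, rerank_model: bool = False):
            self.rerank_model = rerank_model
            self.mode = "rerank" if rerank_model else "embedding"

    rerank_worker = EmbeddingLikeWorker(rerank_model=True)
    assert rerank_worker.mode == "rerank"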
@@ -1072,7 +1091,6 @@ def initialize_worker_manager_in_client(
         if not worker_params.controller_addr:
             raise ValueError("Controller can't be None")
         controller_addr = worker_params.controller_addr
-        logger.info(f"Worker params: {worker_params}")
         client = ModelRegistryClient(worker_params.controller_addr)
         worker_manager.worker_manager = RemoteWorkerManager(client)