mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 11:01:09 +00:00
feat(model): Support deploy rerank model (#1522)
This commit is contained in:
@@ -952,7 +952,10 @@ def _create_local_model_manager(
|
||||
)
|
||||
|
||||
|
||||
def _build_worker(worker_params: ModelWorkerParameters):
|
||||
def _build_worker(
|
||||
worker_params: ModelWorkerParameters,
|
||||
ext_worker_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
worker_class = worker_params.worker_class
|
||||
if worker_class:
|
||||
from dbgpt.util.module_utils import import_from_checked_string
|
||||
@@ -976,11 +979,16 @@ def _build_worker(worker_params: ModelWorkerParameters):
|
||||
else:
|
||||
raise Exception("Unsupported worker type: {worker_params.worker_type}")
|
||||
|
||||
return worker_cls()
|
||||
if ext_worker_kwargs:
|
||||
return worker_cls(**ext_worker_kwargs)
|
||||
else:
|
||||
return worker_cls()
|
||||
|
||||
|
||||
def _start_local_worker(
|
||||
worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters
|
||||
worker_manager: WorkerManagerAdapter,
|
||||
worker_params: ModelWorkerParameters,
|
||||
ext_worker_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
with root_tracer.start_span(
|
||||
"WorkerManager._start_local_worker",
|
||||
@@ -991,7 +999,7 @@ def _start_local_worker(
|
||||
"sys_infos": _get_dict_from_obj(get_system_info()),
|
||||
},
|
||||
):
|
||||
worker = _build_worker(worker_params)
|
||||
worker = _build_worker(worker_params, ext_worker_kwargs=ext_worker_kwargs)
|
||||
if not worker_manager.worker_manager:
|
||||
worker_manager.worker_manager = _create_local_model_manager(worker_params)
|
||||
worker_manager.worker_manager.add_worker(worker, worker_params)
|
||||
@@ -1001,6 +1009,7 @@ def _start_local_embedding_worker(
|
||||
worker_manager: WorkerManagerAdapter,
|
||||
embedding_model_name: str = None,
|
||||
embedding_model_path: str = None,
|
||||
ext_worker_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
if not embedding_model_name or not embedding_model_path:
|
||||
return
|
||||
@@ -1013,21 +1022,25 @@ def _start_local_embedding_worker(
|
||||
logger.info(
|
||||
f"Start local embedding worker with embedding parameters\n{embedding_worker_params}"
|
||||
)
|
||||
_start_local_worker(worker_manager, embedding_worker_params)
|
||||
_start_local_worker(
|
||||
worker_manager, embedding_worker_params, ext_worker_kwargs=ext_worker_kwargs
|
||||
)
|
||||
|
||||
|
||||
def initialize_worker_manager_in_client(
|
||||
app=None,
|
||||
include_router: bool = True,
|
||||
model_name: str = None,
|
||||
model_path: str = None,
|
||||
model_name: Optional[str] = None,
|
||||
model_path: Optional[str] = None,
|
||||
run_locally: bool = True,
|
||||
controller_addr: str = None,
|
||||
controller_addr: Optional[str] = None,
|
||||
local_port: int = 5670,
|
||||
embedding_model_name: str = None,
|
||||
embedding_model_path: str = None,
|
||||
start_listener: Callable[["WorkerManager"], None] = None,
|
||||
system_app: SystemApp = None,
|
||||
embedding_model_name: Optional[str] = None,
|
||||
embedding_model_path: Optional[str] = None,
|
||||
rerank_model_name: Optional[str] = None,
|
||||
rerank_model_path: Optional[str] = None,
|
||||
start_listener: Optional[Callable[["WorkerManager"], None]] = None,
|
||||
system_app: Optional[SystemApp] = None,
|
||||
):
|
||||
"""Initialize WorkerManager in client.
|
||||
If run_locally is True:
|
||||
@@ -1063,6 +1076,12 @@ def initialize_worker_manager_in_client(
|
||||
_start_local_embedding_worker(
|
||||
worker_manager, embedding_model_name, embedding_model_path
|
||||
)
|
||||
_start_local_embedding_worker(
|
||||
worker_manager,
|
||||
rerank_model_name,
|
||||
rerank_model_path,
|
||||
ext_worker_kwargs={"rerank_model": True},
|
||||
)
|
||||
else:
|
||||
from dbgpt.model.cluster.controller.controller import (
|
||||
ModelRegistryClient,
|
||||
@@ -1072,7 +1091,6 @@ def initialize_worker_manager_in_client(
|
||||
|
||||
if not worker_params.controller_addr:
|
||||
raise ValueError("Controller can`t be None")
|
||||
controller_addr = worker_params.controller_addr
|
||||
logger.info(f"Worker params: {worker_params}")
|
||||
client = ModelRegistryClient(worker_params.controller_addr)
|
||||
worker_manager.worker_manager = RemoteWorkerManager(client)
|
||||
|
Reference in New Issue
Block a user