feat(model): Add support for multiple versions of the same model (#1246)

This commit is contained in:
2024-03-08 18:40:10 +08:00 committed by GitHub
parent 7446817340
commit 04b6402720
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 2 deletions

View File

@ -217,7 +217,6 @@ class LocalWorkerManager(WorkerManager):
) )
if not worker_params.model_name: if not worker_params.model_name:
worker_params.model_name = model_name worker_params.model_name = model_name
assert model_name == worker_params.model_name
worker = _build_worker(worker_params) worker = _build_worker(worker_params)
command_args = _dict_to_command_args(params) command_args = _dict_to_command_args(params)
success = await self.run_blocking_func( success = await self.run_blocking_func(
@ -235,7 +234,9 @@ class LocalWorkerManager(WorkerManager):
f"Unsupported worker type: {worker_type}, now supported worker type: {supported_types}" f"Unsupported worker type: {worker_type}, now supported worker type: {supported_types}"
) )
start_apply_req = WorkerApplyRequest( start_apply_req = WorkerApplyRequest(
model=model_name, apply_type=WorkerApplyType.START, worker_type=worker_type model=worker_params.model_name,
apply_type=WorkerApplyType.START,
worker_type=worker_type,
) )
out: WorkerApplyOutput = None out: WorkerApplyOutput = None
try: try:
@ -895,6 +896,8 @@ def _parse_worker_params(
**kwargs, **kwargs,
) )
worker_params.update_from(new_worker_params) worker_params.update_from(new_worker_params)
if worker_params.model_alias:
worker_params.model_name = worker_params.model_alias
# logger.info(f"Worker params: {worker_params}") # logger.info(f"Worker params: {worker_params}")
return worker_params return worker_params

View File

@ -164,6 +164,10 @@ class ModelWorkerParameters(BaseModelParameters):
default=None, default=None,
metadata={"valid_values": WorkerType.values(), "help": "Worker type"}, metadata={"valid_values": WorkerType.values(), "help": "Worker type"},
) )
model_alias: Optional[str] = field(
default=None,
metadata={"help": "model alias"},
)
worker_class: Optional[str] = field( worker_class: Optional[str] = field(
default=None, default=None,
metadata={"help": "Model worker class, dbgpt.model.cluster.DefaultModelWorker"}, metadata={"help": "Model worker class, dbgpt.model.cluster.DefaultModelWorker"},