mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-29 05:18:47 +00:00
204 lines
5.2 KiB
Python
204 lines
5.2 KiB
Python
import click
|
|
import functools
|
|
|
|
from pilot.model.controller.registry import ModelRegistryClient
|
|
from pilot.model.worker.manager import (
|
|
RemoteWorkerManager,
|
|
WorkerApplyRequest,
|
|
WorkerApplyType,
|
|
)
|
|
from pilot.model.parameter import (
|
|
ModelControllerParameters,
|
|
ModelWorkerParameters,
|
|
ModelParameters,
|
|
)
|
|
from pilot.utils import get_or_create_event_loop
|
|
from pilot.utils.parameter_utils import EnvArgumentParser
|
|
|
|
|
|
@click.group("model")
@click.option(
    "--address",
    type=str,
    default="http://127.0.0.1:8000",
    required=False,
    show_default=True,
    help=(
        "Address of the Model Controller to connect to. "
        "Just support light deploy model"
    ),
)
@click.pass_context
def model_cli_group(ctx: click.Context, address: str):
    """Clients that manage model serving"""
    # NOTE: the docstring above is rendered by click as the group's help text,
    # so implementation notes live in comments instead.
    #
    # click passes a group's options to the *group* callback only; they are
    # never forwarded to subcommand callbacks. The callback therefore has to
    # accept ``address`` itself and publish it on the context, from where the
    # subcommands below pick it up through ``click.pass_obj``.
    ctx.obj = address


# The function name doubles as the CLI command name ("list"); it shadows the
# ``list`` builtin at module level, so do not call ``list(...)`` as a builtin
# further down in this module.
@model_cli_group.command()
@click.option(
    "--model-name", type=str, default=None, required=False, help=("The name of model")
)
@click.option(
    "--model-type", type=str, default="llm", required=False, help=("The type of model")
)
@click.pass_obj
def list(address: str, model_name: str, model_type: str):
    """List model instances"""
    # ``address`` is the controller address stored on the context by the
    # ``model`` group; ``model_name`` / ``model_type`` come from the options.
    from prettytable import PrettyTable

    loop = get_or_create_event_loop()
    registry = ModelRegistryClient(address)

    if not model_name:
        # No filter given: show every registered instance.
        instances = loop.run_until_complete(registry.get_all_model_instances())
    else:
        if not model_type:
            model_type = "llm"
        # Instances are looked up under the "<name>@<type>" key.
        register_model_name = f"{model_name}@{model_type}"
        instances = loop.run_until_complete(
            registry.get_all_instances(register_model_name)
        )

    table = PrettyTable()
    table.field_names = [
        "Model Name",
        "Model Type",
        "Host",
        "Port",
        "Healthy",
        "Enabled",
        "Prompt Template",
        "Last Heartbeat",
    ]
    for instance in instances:
        # Split on the LAST "@" so a model name that itself contains "@" does
        # not break the listing; fall back to an empty type when the registered
        # name carries no "@" at all instead of raising on tuple unpacking.
        name, separator, worker_type = instance.model_name.rpartition("@")
        if not separator:
            name, worker_type = instance.model_name, ""
        table.add_row(
            [
                name,
                worker_type,
                instance.host,
                instance.port,
                instance.healthy,
                instance.enabled,
                instance.prompt_template,
                instance.last_heartbeat,
            ]
        )

    print(table)


def add_model_options(func):
    """Attach the shared ``--model-name`` / ``--model-type`` options to ``func``.

    ``--model-name`` is required; ``--model-type`` is optional and defaults to
    ``"llm"``. The returned wrapper forwards all arguments to ``func``
    unchanged.
    """

    @click.option(
        "--model-name",
        type=str,
        default=None,
        required=True,
        help=("The name of model"),
    )
    @click.option(
        "--model-type",
        type=str,
        default="llm",
        required=False,
        help=("The type of model"),
    )
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


@model_cli_group.command()
@add_model_options
@click.pass_obj
def stop(address: str, model_name: str, model_type: str):
    """Stop model instances"""
    # ``address`` is injected from the ``model`` group via ``click.pass_obj``.
    worker_apply(address, model_name, model_type, WorkerApplyType.STOP)


@model_cli_group.command()
@add_model_options
@click.pass_obj
def start(address: str, model_name: str, model_type: str):
    """Start model instances"""
    # ``address`` is injected from the ``model`` group via ``click.pass_obj``.
    worker_apply(address, model_name, model_type, WorkerApplyType.START)


@model_cli_group.command()
@add_model_options
@click.pass_obj
def restart(address: str, model_name: str, model_type: str):
    """Restart model instances"""
    # ``address`` is injected from the ``model`` group via ``click.pass_obj``.
    worker_apply(address, model_name, model_type, WorkerApplyType.RESTART)
|
|
|
|
|
|
# NOTE: disabled command, kept for reference.
# @model_cli_group.command()
# @add_model_options
# def modify(address: str, model_name: str, model_type: str):
#     """Modify model instances"""
#     worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS)
|
|
|
|
|
|
def worker_apply(
    address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
):
    """Send a worker apply request through a remote worker manager and print the result.

    Args:
        address: Address handed to ``ModelRegistryClient``.
        model_name: Value sent as the request's ``model`` field.
        model_type: Value sent as the request's ``worker_type`` field.
        apply_type: The ``WorkerApplyType`` operation to request.
    """
    event_loop = get_or_create_event_loop()
    manager = RemoteWorkerManager(ModelRegistryClient(address))
    request = WorkerApplyRequest(
        model=model_name,
        worker_type=model_type,
        apply_type=apply_type,
    )
    # The manager call is a coroutine; drive it to completion synchronously.
    outcome = event_loop.run_until_complete(manager.worker_apply(request))
    print(outcome)
|
|
|
|
|
|
@click.command(name="controller")
@EnvArgumentParser.create_click_option(ModelControllerParameters)
def start_model_controller(**kwargs):
    """Start model controller"""
    # The docstring above is shown by click as the command help text.
    # Import lazily so the controller module is only loaded when this command
    # actually runs.
    import pilot.model.controller.controller as controller_module

    # NOTE(review): the click-parsed ``kwargs`` are not forwarded; presumably
    # run_model_controller() reads its own configuration -- confirm.
    controller_module.run_model_controller()
|
|
|
|
|
|
@click.command(name="controller")
def stop_model_controller(**kwargs):
    """Stop model controller"""
    # The docstring is rendered by click as the command's help text; it
    # previously read "Start ..." although this is the stop command.
    # Placeholder: stopping the controller is not implemented.
    raise NotImplementedError
|
|
|
|
|
|
@click.command(name="worker")
@EnvArgumentParser.create_click_option(ModelWorkerParameters, ModelParameters)
def start_model_worker(**kwargs):
    """Start model worker"""
    # The docstring above is shown by click as the command help text.
    # Import lazily so the worker manager module is only loaded when this
    # command actually runs.
    import pilot.model.worker.manager as worker_manager_module

    # NOTE(review): the click-parsed ``kwargs`` are not forwarded; presumably
    # run_worker_manager() reads its own configuration -- confirm.
    worker_manager_module.run_worker_manager()
|
|
|
|
|
|
@click.command(name="worker")
def stop_model_worker(**kwargs):
    """Stop model worker"""
    # Placeholder command: accepts any keyword arguments and always raises,
    # because stopping a worker is not implemented.
    raise NotImplementedError()
|
|
|
|
|
|
@click.command(name="webserver")
def start_webserver(**kwargs):
    """Start webserver(dbgpt_server.py)"""
    # Placeholder command: accepts any keyword arguments and always raises,
    # because starting the webserver from the CLI is not implemented.
    raise NotImplementedError()
|
|
|
|
|
|
@click.command(name="webserver")
def stop_webserver(**kwargs):
    """Stop webserver(dbgpt_server.py)"""
    # Placeholder command: accepts any keyword arguments and always raises,
    # because stopping the webserver from the CLI is not implemented.
    raise NotImplementedError()
|
|
|
|
|
|
@click.command(name="apiserver")
def start_apiserver(**kwargs):
    """Start apiserver(TODO)"""
    # Placeholder command: accepts any keyword arguments and always raises,
    # because the API server command is not implemented.
    raise NotImplementedError()
|
|
|
|
|
|
# Registered as "apiserver" to mirror ``start_apiserver``. It was previously
# named "controller", which collides with ``stop_model_controller`` when both
# commands are added to the same click group (click keys commands by name).
@click.command(name="apiserver")
def stop_apiserver(**kwargs):
    """Stop apiserver(TODO)"""
    # The docstring is rendered by click as the help text; it previously read
    # "Start ..." although this is the stop command.
    # Placeholder: stopping the API server is not implemented.
    raise NotImplementedError
|