Files
DB-GPT/pilot/model/cli.py

204 lines
5.2 KiB
Python

import click
import functools
from pilot.model.controller.registry import ModelRegistryClient
from pilot.model.worker.manager import (
RemoteWorkerManager,
WorkerApplyRequest,
WorkerApplyType,
)
from pilot.model.parameter import (
ModelControllerParameters,
ModelWorkerParameters,
ModelParameters,
)
from pilot.utils import get_or_create_event_loop
from pilot.utils.parameter_utils import EnvArgumentParser
# Controller address used when none is supplied on the command line.
_DEFAULT_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"


@click.group("model")
@click.option(
    "--address",
    type=str,
    default=_DEFAULT_CONTROLLER_ADDRESS,
    required=False,
    show_default=True,
    help=(
        "Address of the Model Controller to connect to. "
        "Just support light deploy model"
    ),
)
@click.pass_context
def model_cli_group(ctx: click.Context, address: str):
    """Clients that manage model serving"""
    # Click passes every group-level option to this callback, so it must
    # accept ``address`` (a zero-argument callback raises TypeError on every
    # invocation). The subcommands do not declare --address themselves, so
    # the value is stored on the context object where they can look it up.
    ctx.ensure_object(dict)["address"] = address


def _pass_controller_address(func):
    """Decorator: call *func* with the ``model`` group's ``--address`` first.

    This keeps the subcommand callbacks' ``(address, model_name, model_type)``
    signature while the address value comes from the parent group.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        settings = click.get_current_context().find_object(dict) or {}
        address = settings.get("address", _DEFAULT_CONTROLLER_ADDRESS)
        return func(address, *args, **kwargs)

    return wrapper


def _split_registered_name(registered_name: str):
    """Split a registry name of the form ``"<model_name>@<model_type>"``.

    The split happens on the last ``@``, so model names that themselves
    contain ``@`` survive. A name without ``@`` yields an empty model type
    instead of the ValueError that a two-way unpack of ``split("@")`` raises.
    """
    name, sep, model_type = registered_name.rpartition("@")
    if not sep:
        return registered_name, ""
    return name, model_type


# NOTE: the function is named ``list`` so that click derives the subcommand
# name "list" from it; this shadows the builtin ``list`` at module scope.
@model_cli_group.command()
@click.option(
    "--model-name", type=str, default=None, required=False, help=("The name of model")
)
@click.option(
    "--model-type", type=str, default="llm", required=False, help=("The type of model")
)
@_pass_controller_address
def list(address: str, model_name: str, model_type: str):
    """List model instances"""
    from prettytable import PrettyTable

    loop = get_or_create_event_loop()
    registry = ModelRegistryClient(address)
    if not model_name:
        # No filter: ask the registry for every registered instance.
        instances = loop.run_until_complete(registry.get_all_model_instances())
    else:
        # Instances are registered under "<model_name>@<model_type>".
        register_model_name = f"{model_name}@{model_type or 'llm'}"
        instances = loop.run_until_complete(
            registry.get_all_instances(register_model_name)
        )

    table = PrettyTable()
    table.field_names = [
        "Model Name",
        "Model Type",
        "Host",
        "Port",
        "Healthy",
        "Enabled",
        "Prompt Template",
        "Last Heartbeat",
    ]
    for instance in instances:
        instance_model_name, instance_model_type = _split_registered_name(
            instance.model_name
        )
        table.add_row(
            [
                instance_model_name,
                instance_model_type,
                instance.host,
                instance.port,
                instance.healthy,
                instance.enabled,
                instance.prompt_template,
                instance.last_heartbeat,
            ]
        )
    print(table)


def add_model_options(func):
    """Attach the shared ``--model-name`` / ``--model-type`` options to *func*."""
    # Applied innermost-first, matching the original stacked-decorator order.
    func = click.option(
        "--model-type",
        type=str,
        default="llm",
        required=False,
        help=("The type of model"),
    )(func)
    func = click.option(
        "--model-name",
        type=str,
        default=None,
        required=True,
        help=("The name of model"),
    )(func)
    return func


@model_cli_group.command()
@add_model_options
@_pass_controller_address
def stop(address: str, model_name: str, model_type: str):
    """Stop model instances"""
    worker_apply(address, model_name, model_type, WorkerApplyType.STOP)


@model_cli_group.command()
@add_model_options
@_pass_controller_address
def start(address: str, model_name: str, model_type: str):
    """Start model instances"""
    worker_apply(address, model_name, model_type, WorkerApplyType.START)


@model_cli_group.command()
@add_model_options
@_pass_controller_address
def restart(address: str, model_name: str, model_type: str):
    """Restart model instances"""
    worker_apply(address, model_name, model_type, WorkerApplyType.RESTART)
# @model_cli_group.command()
# @add_model_options
# def modify(address: str, model_name: str, model_type: str):
#     """Modify model instance parameters"""
#     worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS)
def worker_apply(
    address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
):
    """Send a single apply request to the workers and print the result.

    Args:
        address: Model Controller address, used to build the registry client.
        model_name: Name of the model the request targets.
        model_type: Worker type, sent as ``worker_type`` on the request.
        apply_type: The operation to apply (e.g. START / STOP / RESTART).
    """
    event_loop = get_or_create_event_loop()
    manager = RemoteWorkerManager(ModelRegistryClient(address))
    request = WorkerApplyRequest(
        model=model_name,
        worker_type=model_type,
        apply_type=apply_type,
    )
    # The manager API is async; drive it to completion on the shared loop.
    outcome = loop_result = event_loop.run_until_complete(
        manager.worker_apply(request)
    )
    print(outcome if outcome is loop_result else loop_result)
@click.command("controller")
@EnvArgumentParser.create_click_option(ModelControllerParameters)
def start_model_controller(**kwargs):
    """Start model controller"""
    # The import lives inside the command so it only runs when the command is
    # invoked, not when this CLI module is loaded.
    # NOTE(review): the parsed options in ``kwargs`` are not forwarded;
    # presumably run_model_controller reads its own parameters — confirm.
    from pilot.model.controller.controller import run_model_controller

    launch = run_model_controller
    launch()
@click.command(name="controller")
def stop_model_controller(**kwargs):
    """Stop model controller"""
    # Placeholder command: stopping the controller is not implemented yet.
    # (The docstring doubles as click help text; it previously said "Start".)
    raise NotImplementedError
@click.command("worker")
@EnvArgumentParser.create_click_option(ModelWorkerParameters, ModelParameters)
def start_model_worker(**kwargs):
    """Start model worker"""
    # The import lives inside the command so it only runs when the command is
    # invoked, not when this CLI module is loaded.
    # NOTE(review): the parsed options in ``kwargs`` are not forwarded;
    # presumably run_worker_manager reads its own parameters — confirm.
    from pilot.model.worker.manager import run_worker_manager

    launch = run_worker_manager
    launch()
@click.command("worker")
def stop_model_worker(**kwargs):
    """Stop model worker"""
    # Placeholder command: not implemented yet.
    raise NotImplementedError()
@click.command("webserver")
def start_webserver(**kwargs):
    """Start webserver(dbgpt_server.py)"""
    # Placeholder command: not implemented yet.
    raise NotImplementedError()
@click.command("webserver")
def stop_webserver(**kwargs):
    """Stop webserver(dbgpt_server.py)"""
    # Placeholder command: not implemented yet.
    raise NotImplementedError()
@click.command("apiserver")
def start_apiserver(**kwargs):
    """Start apiserver(TODO)"""
    # Placeholder command: not implemented yet.
    raise NotImplementedError()
@click.command(name="apiserver")
def stop_apiserver(**kwargs):
    """Stop apiserver(TODO)"""
    # Placeholder command: not implemented yet. It was previously registered
    # under the name "controller", which clashed with stop_model_controller,
    # and its help text said "Start".
    raise NotImplementedError