DB-GPT/dbgpt/model/cli.py

475 lines
14 KiB
Python

import functools
import logging
import os
from typing import Callable, List, Optional, Type
import click
from dbgpt.configs.model_config import LOGDIR
from dbgpt.model.base import WorkerApplyType
from dbgpt.model.parameter import (
BaseParameters,
ModelAPIServerParameters,
ModelControllerParameters,
ModelParameters,
ModelWorkerParameters,
)
from dbgpt.util import get_or_create_event_loop
from dbgpt.util.command_utils import (
_detect_controller_address,
_run_current_with_daemon,
_stop_service,
)
from dbgpt.util.parameter_utils import (
EnvArgumentParser,
_build_parameter_class,
build_lazy_click_command,
)
# You can set the environment variable CONTROLLER_ADDRESS to set the default address.
# This module-level value is reassigned whenever the `model` click group callback
# (`model_cli_group`) runs: from --address when given, otherwise from
# `_detect_controller_address()`.
MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"
# Logger for this CLI module.
logger = logging.getLogger("dbgpt_cli")
def _get_worker_manager(address: str):
    """Create a ``RemoteWorkerManager`` backed by a registry client for ``address``.

    Args:
        address: Address passed to ``ModelRegistryClient``.

    Returns:
        A ``RemoteWorkerManager`` wrapping the registry client.
    """
    # Imported lazily, inside the function, as in the rest of this module.
    from dbgpt.model.cluster import ModelRegistryClient, RemoteWorkerManager

    return RemoteWorkerManager(ModelRegistryClient(address))
@click.group("model")
@click.option(
    "--address",
    type=str,
    required=False,
    default=None,
    show_default=True,
    help=(
        "Address of the Model Controller to connect to. "
        "Just support light deploy model, If the environment variable CONTROLLER_ADDRESS is configured, read from the environment variable"
    ),
)
def model_cli_group(address: str):
    """Clients that manage model serving"""
    global MODEL_CONTROLLER_ADDRESS
    # An explicit --address wins; otherwise ask the helper to detect one.
    # The result is stored module-wide so the sub-commands can read it.
    MODEL_CONTROLLER_ADDRESS = address if address else _detect_controller_address()
@model_cli_group.command()
@click.option(
    "--model_name", type=str, default=None, required=False, help=("The name of model")
)
@click.option(
    "--model_type", type=str, default="llm", required=False, help=("The type of model")
)
def list(model_name: str, model_type: str):
    """List model instances"""
    # NOTE: the function name shadows the builtin ``list``. It is kept because
    # click derives the sub-command name from the function name.
    from prettytable import PrettyTable

    from dbgpt.model.cluster import ModelRegistryClient

    loop = get_or_create_event_loop()
    registry = ModelRegistryClient(MODEL_CONTROLLER_ADDRESS)
    if not model_name:
        # No filter: fetch every registered instance.
        instances = loop.run_until_complete(registry.get_all_model_instances())
    else:
        # Instances are looked up under "<model_name>@<model_type>".
        register_model_name = f"{model_name}@{model_type or 'llm'}"
        instances = loop.run_until_complete(
            registry.get_all_instances(register_model_name)
        )

    table = PrettyTable()
    table.field_names = [
        "Model Name",
        "Model Type",
        "Host",
        "Port",
        "Healthy",
        "Enabled",
        "Prompt Template",
        "Last Heartbeat",
    ]
    for instance in instances:
        # Split on the LAST "@" so a model name that itself contains "@" still
        # unpacks correctly; a name without "@" is shown with an empty type
        # instead of raising ValueError.
        name, sep, worker_type = instance.model_name.rpartition("@")
        if not sep:
            name, worker_type = instance.model_name, ""
        table.add_row(
            [
                name,
                worker_type,
                instance.host,
                instance.port,
                instance.healthy,
                instance.enabled,
                instance.prompt_template if instance.prompt_template else "",
                instance.last_heartbeat,
            ]
        )
    print(table)
def add_model_options(func):
    """Decorate ``func`` with the shared ``--model_name`` and ``--model_type`` options.

    Args:
        func: The command callback to wrap.

    Returns:
        A wrapper around ``func`` carrying both click options.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    type_option = click.option(
        "--model_type",
        type=str,
        default="llm",
        required=False,
        help="The type of model",
    )
    name_option = click.option(
        "--model_name",
        type=str,
        default=None,
        required=True,
        help="The name of model",
    )
    # Apply innermost-first (--model_type, then --model_name), which is the
    # order stacked decorators would be applied in.
    return name_option(type_option(wrapper))
@model_cli_group.command()
@add_model_options
@click.option(
    "--host",
    required=True,
    type=str,
    help="The remote host to stop model",
)
@click.option(
    "--port",
    required=True,
    type=int,
    help="The remote port to stop model",
)
def stop(model_name: str, model_type: str, host: str, port: int):
    """Stop model instances"""
    from dbgpt.model.cluster import RemoteWorkerManager, WorkerStartupRequest

    manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS)
    # The shutdown call takes the same request type as startup.
    shutdown_request = WorkerStartupRequest(
        host=host,
        port=port,
        worker_type=model_type,
        model=model_name,
        params={},
    )
    event_loop = get_or_create_event_loop()
    outcome = event_loop.run_until_complete(manager.model_shutdown(shutdown_request))
    print(outcome)
def _remote_model_dynamic_factory() -> List[Type]:
    """Build the parameter dataclasses used by the ``model start`` command.

    The factory pre-parses ``model_name``, ``address``, ``host`` and ``port``
    via ``_SimpleArgParser`` (presumably from the process command line --
    confirm against ``_SimpleArgParser``), asks the worker manager which
    models are supported, and derives the option classes from the answer.

    Returns:
        A list of dataclass types:

        * when no supported models are reported, a single generic class
          exposing ``model_name``, ``host`` and ``port``;
        * otherwise ``RemoteModelWorkerParameters`` followed, when one is
          available, by the parameter class built from the selected model's
          ``params``.

    Raises:
        ValueError: If ``model_name`` was given but no matching model with
            parameters exists on the selected workers.
    """
    from dataclasses import dataclass, field, make_dataclass

    from dbgpt.model.cluster import RemoteWorkerManager
    from dbgpt.model.parameter import WorkerType
    from dbgpt.util.parameter_utils import _SimpleArgParser

    pre_args = _SimpleArgParser("model_name", "address", "host", "port")
    pre_args.parse()
    model_name = pre_args.get("model_name")
    address = pre_args.get("address")
    host = pre_args.get("host")
    port = pre_args.get("port")
    if port:
        # Converted so it can be compared with ``model.port`` below.
        port = int(port)
    if not address:
        address = _detect_controller_address()

    worker_manager: RemoteWorkerManager = _get_worker_manager(address)
    loop = get_or_create_event_loop()
    models = loop.run_until_complete(worker_manager.supported_models())

    if not models:
        # Generic fallback. ``make_dataclass`` is required here: building the
        # class with ``type(...)`` and (type, field) tuples creates no
        # annotations, so ``dataclass`` would see zero fields.
        fallback_cls = make_dataclass(
            "RemoteModelWorkerParameters",
            [
                (
                    "model_name",
                    str,
                    field(default=None, metadata={"help": "The model name to deploy"}),
                ),
                (
                    "host",
                    str,
                    field(
                        default=None,
                        metadata={"help": "The remote host to deploy model"},
                    ),
                ),
                (
                    "port",
                    int,
                    field(
                        default=None,
                        metadata={"help": "The remote port to deploy model"},
                    ),
                ),
            ],
        )
        return [fallback_cls]

    valid_models = []
    valid_model_cls = []
    for model in models:
        # Honour the optional --host / --port filters.
        if host and host != model.host:
            continue
        if port and port != model.port:
            continue
        valid_models.extend(m.model for m in model.models)
        valid_model_cls.extend(
            (m, _build_parameter_class(m.params)) for m in model.models if m.params
        )

    # Default to the first model that has parameters, if any. The list can be
    # empty (filters matched nothing, or no model has params); indexing it
    # unconditionally would raise IndexError.
    real_params_cls = valid_model_cls[0][1] if valid_model_cls else None
    real_path = None
    real_worker_type = "llm"
    if model_name:
        params_cls_list = [m for m in valid_model_cls if m[0].model == model_name]
        if not params_cls_list:
            raise ValueError(f"Not supported model with model name: {model_name}")
        real_model, real_params_cls = params_cls_list[0]
        real_path = real_model.path
        real_worker_type = real_model.worker_type

    known_hosts = [model.host for model in models]
    known_ports = [model.port for model in models]

    @dataclass
    class RemoteModelWorkerParameters(BaseParameters):
        model_name: str = field(
            metadata={"valid_values": valid_models, "help": "The model name to deploy"}
        )
        model_path: Optional[str] = field(
            default=real_path, metadata={"help": "The model path to deploy"}
        )
        host: Optional[str] = field(
            default=models[0].host,
            metadata={
                "valid_values": known_hosts,
                "help": "The remote host to deploy model",
            },
        )
        port: Optional[int] = field(
            default=models[0].port,
            metadata={
                "valid_values": known_ports,
                "help": "The remote port to deploy model",
            },
        )
        worker_type: Optional[str] = field(
            default=real_worker_type,
            metadata={
                "valid_values": WorkerType.values(),
                "help": "Worker type",
            },
        )

    param_classes = [RemoteModelWorkerParameters]
    if real_params_cls is not None:
        param_classes.append(real_params_cls)
    return param_classes
@model_cli_group.command(
    cls=build_lazy_click_command(_dynamic_factory=_remote_model_dynamic_factory)
)
def start(**kwargs):
    """Start model instances"""
    from dbgpt.model.cluster import RemoteWorkerManager, WorkerStartupRequest

    manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS)
    # Routing fields are removed from kwargs; everything left (including
    # model_name) is forwarded as the request params.
    target_host = kwargs.pop("host")
    target_port = kwargs.pop("port")
    worker_type = kwargs.pop("worker_type")
    startup_request = WorkerStartupRequest(
        host=target_host,
        port=target_port,
        worker_type=worker_type,
        model=kwargs["model_name"],
        params={},
    )
    startup_request.params = kwargs
    event_loop = get_or_create_event_loop()
    outcome = event_loop.run_until_complete(manager.model_startup(startup_request))
    print(outcome)
@model_cli_group.command()
@add_model_options
def restart(model_name: str, model_type: str):
    """Restart model instances"""
    # Delegates to worker_apply with the RESTART apply type.
    worker_apply(
        address=MODEL_CONTROLLER_ADDRESS,
        model_name=model_name,
        model_type=model_type,
        apply_type=WorkerApplyType.RESTART,
    )
@model_cli_group.command()
@click.option(
    "--model_name",
    required=True,
    type=str,
    default=None,
    help="The name of model",
)
@click.option(
    "--system",
    required=False,
    type=str,
    default=None,
    help="System prompt",
)
def chat(model_name: str, system: str):
    """Interact with your bot from the command line"""
    # The --system value becomes the system prompt of the chat session.
    _cli_chat(
        address=MODEL_CONTROLLER_ADDRESS,
        model_name=model_name,
        system_prompt=system,
    )
def worker_apply(
    address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
):
    """Send a ``WorkerApplyRequest`` to the worker manager and print the result.

    Args:
        address: Address used to build the worker manager.
        model_name: Value for the request's ``model`` field.
        model_type: Value for the request's ``worker_type`` field.
        apply_type: The operation to apply (for example ``WorkerApplyType.RESTART``).
    """
    from dbgpt.model.cluster import WorkerApplyRequest

    event_loop = get_or_create_event_loop()
    manager = _get_worker_manager(address)
    request = WorkerApplyRequest(
        model=model_name,
        worker_type=model_type,
        apply_type=apply_type,
    )
    outcome = event_loop.run_until_complete(manager.worker_apply(request))
    print(outcome)
def _cli_chat(address: str, model_name: str, system_prompt: Optional[str] = None):
    """Run the interactive chat loop to completion on the event loop.

    Args:
        address: Address used to build the worker manager.
        model_name: Name of the model to chat with.
        system_prompt: Optional system prompt passed on to the chat loop.
    """
    loop = get_or_create_event_loop()
    worker_manager = _get_worker_manager(address)
    loop.run_until_complete(_chat_stream(worker_manager, model_name, system_prompt))
async def _chat_stream(
    worker_manager, model_name: str, system_prompt: Optional[str] = None
):
    """Interactive read-generate-print loop against ``worker_manager``.

    Reads a line from stdin, appends it to the message history, streams the
    model's answer to stdout, and repeats until the user types ``exit`` or
    stdin reaches end-of-file.

    Args:
        worker_manager: Object providing an async ``generate_stream(request)``.
        model_name: Name of the model to send prompts to.
        system_prompt: Optional system message placed first in the history.
    """
    from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
    from dbgpt.model.cluster import PromptRequest

    print(f"Chatbot started with model {model_name}. Type 'exit' to leave the chat.")
    hist = []
    if system_prompt:
        hist.append(
            ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_prompt)
        )
    while True:
        # Reset per turn: holds the text already printed for this answer.
        previous_response = ""
        try:
            user_input = input("\n\nYou: ")
        except EOFError:
            # stdin was closed (Ctrl-D or piped input ran out): leave the
            # chat cleanly instead of crashing with a traceback.
            print()
            break
        if user_input.lower().strip() == "exit":
            break
        hist.append(ModelMessage(role=ModelMessageRoleType.HUMAN, content=user_input))
        request = PromptRequest(messages=hist, model=model_name, prompt="", echo=False)
        request = request.dict(exclude_none=True)
        print("Bot: ", end="")
        async for response in worker_manager.generate_stream(request):
            # Each response.text is treated as the cumulative answer so far,
            # so only the new suffix is printed.
            incremental_output = response.text[len(previous_response) :]
            print(incremental_output, end="", flush=True)
            previous_response = response.text
        # Keep the full answer in the history for the next turn.
        hist.append(
            ModelMessage(role=ModelMessageRoleType.AI, content=previous_response)
        )
def add_stop_server_options(func):
    """Decorate ``func`` with the optional ``--port`` option shared by stop commands.

    Args:
        func: The command callback to wrap.

    Returns:
        A wrapper around ``func`` carrying the ``--port`` click option.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    port_option = click.option(
        "--port",
        type=int,
        default=None,
        required=False,
        help="The port to stop",
    )
    return port_option(wrapper)
@click.command(name="controller")
@EnvArgumentParser.create_click_option(ModelControllerParameters)
def start_model_controller(**kwargs):
    """Start model controller"""
    if not kwargs["daemon"]:
        # Foreground mode: import lazily and run in this process.
        from dbgpt.model.cluster import run_model_controller

        run_model_controller()
        return
    # Daemon mode: hand over to the daemon helper with the uvicorn log path.
    daemon_log = os.path.join(LOGDIR, "model_controller_uvicorn.log")
    _run_current_with_daemon("ModelController", daemon_log)
@click.command(name="controller")
@add_stop_server_options
def stop_model_controller(port: int):
    """Stop model controller"""
    # The docstring above is what click shows as this command's help text, so
    # it must describe stopping, not starting.
    # "controller" is the key and "ModelController" the full name passed to
    # _stop_service; ``port`` (may be None) is forwarded unchanged.
    _stop_service("controller", "ModelController", port=port)
def _model_dynamic_factory() -> List[Type]:
    """Return the parameter classes used to build the ``worker`` command options.

    Returns:
        ``ModelWorkerParameters`` followed by the classes reported by
        ``_dynamic_model_parser()``, or by ``ModelParameters`` when the
        parser yields nothing.
    """
    from dbgpt.model.adapter.model_adapter import _dynamic_model_parser

    param_class = _dynamic_model_parser()
    if not param_class:
        # Fall back to the generic model parameters.
        param_class = [ModelParameters]
    return [ModelWorkerParameters, *param_class]
@click.command(
    name="worker", cls=build_lazy_click_command(_dynamic_factory=_model_dynamic_factory)
)
def start_model_worker(**kwargs):
    """Start model worker"""
    if not kwargs["daemon"]:
        # Foreground mode: import lazily and run in this process.
        from dbgpt.model.cluster import run_worker_manager

        run_worker_manager()
        return
    # Daemon mode: the log file name carries the worker type and port.
    worker_port = kwargs["port"]
    worker_kind = kwargs.get("worker_type") or "llm"
    daemon_log = os.path.join(
        LOGDIR, f"model_worker_{worker_kind}_{worker_port}_uvicorn.log"
    )
    _run_current_with_daemon("ModelWorker", daemon_log)
@click.command(name="worker")
@add_stop_server_options
def stop_model_worker(port: int):
    """Stop model worker"""
    # Append the port to the full name only when one was supplied.
    full_name = f"ModelWorker-{port}" if port else "ModelWorker"
    _stop_service("worker", full_name, port=port)
@click.command(name="apiserver")
@EnvArgumentParser.create_click_option(ModelAPIServerParameters)
def start_apiserver(**kwargs):
    """Start apiserver"""
    if not kwargs["daemon"]:
        # Foreground mode: import lazily and run in this process.
        from dbgpt.model.cluster import run_apiserver

        run_apiserver()
        return
    # Daemon mode: hand over to the daemon helper with the uvicorn log path.
    daemon_log = os.path.join(LOGDIR, "model_apiserver_uvicorn.log")
    _run_current_with_daemon("ModelAPIServer", daemon_log)
@click.command(name="apiserver")
@add_stop_server_options
def stop_apiserver(port: int):
    """Stop apiserver"""
    # Append the port to the full name only when one was supplied.
    full_name = f"ModelAPIServer-{port}" if port else "ModelAPIServer"
    _stop_service("apiserver", full_name, port=port)
def _stop_all_model_server(**kwargs):
    """Stop all server"""
    # Stops the worker first, then the controller; ``kwargs`` is accepted but unused.
    # NOTE(review): the apiserver is not stopped here -- confirm this is intended.
    services = (("worker", "ModelWorker"), ("controller", "ModelController"))
    for key, full_name in services:
        _stop_service(key, full_name)