mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-11 22:09:44 +00:00
refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
471
dbgpt/model/cli.py
Normal file
471
dbgpt/model/cli.py
Normal file
@@ -0,0 +1,471 @@
|
||||
import click
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, List, Type, Optional
|
||||
|
||||
from dbgpt.configs.model_config import LOGDIR
|
||||
from dbgpt.model.base import WorkerApplyType
|
||||
from dbgpt.model.parameter import (
|
||||
ModelControllerParameters,
|
||||
ModelAPIServerParameters,
|
||||
ModelWorkerParameters,
|
||||
ModelParameters,
|
||||
BaseParameters,
|
||||
)
|
||||
from dbgpt.util import get_or_create_event_loop
|
||||
from dbgpt.util.parameter_utils import (
|
||||
EnvArgumentParser,
|
||||
_build_parameter_class,
|
||||
build_lazy_click_command,
|
||||
)
|
||||
from dbgpt.util.command_utils import (
|
||||
_run_current_with_daemon,
|
||||
_stop_service,
|
||||
_detect_controller_address,
|
||||
)
|
||||
|
||||
|
||||
MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"
|
||||
|
||||
logger = logging.getLogger("dbgpt_cli")
|
||||
|
||||
|
||||
def _get_worker_manager(address: str):
|
||||
from dbgpt.model.cluster import RemoteWorkerManager, ModelRegistryClient
|
||||
|
||||
registry = ModelRegistryClient(address)
|
||||
worker_manager = RemoteWorkerManager(registry)
|
||||
return worker_manager
|
||||
|
||||
|
||||
@click.group("model")
|
||||
@click.option(
|
||||
"--address",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
show_default=True,
|
||||
help=(
|
||||
"Address of the Model Controller to connect to. "
|
||||
"Just support light deploy model, If the environment variable CONTROLLER_ADDRESS is configured, read from the environment variable"
|
||||
),
|
||||
)
|
||||
def model_cli_group(address: str):
|
||||
"""Clients that manage model serving"""
|
||||
global MODEL_CONTROLLER_ADDRESS
|
||||
if not address:
|
||||
MODEL_CONTROLLER_ADDRESS = _detect_controller_address()
|
||||
else:
|
||||
MODEL_CONTROLLER_ADDRESS = address
|
||||
|
||||
|
||||
@model_cli_group.command()
|
||||
@click.option(
|
||||
"--model_name", type=str, default=None, required=False, help=("The name of model")
|
||||
)
|
||||
@click.option(
|
||||
"--model_type", type=str, default="llm", required=False, help=("The type of model")
|
||||
)
|
||||
def list(model_name: str, model_type: str):
|
||||
"""List model instances"""
|
||||
from prettytable import PrettyTable
|
||||
from dbgpt.model.cluster import ModelRegistryClient
|
||||
|
||||
loop = get_or_create_event_loop()
|
||||
registry = ModelRegistryClient(MODEL_CONTROLLER_ADDRESS)
|
||||
|
||||
if not model_name:
|
||||
instances = loop.run_until_complete(registry.get_all_model_instances())
|
||||
else:
|
||||
if not model_type:
|
||||
model_type = "llm"
|
||||
register_model_name = f"{model_name}@{model_type}"
|
||||
instances = loop.run_until_complete(
|
||||
registry.get_all_instances(register_model_name)
|
||||
)
|
||||
table = PrettyTable()
|
||||
|
||||
table.field_names = [
|
||||
"Model Name",
|
||||
"Model Type",
|
||||
"Host",
|
||||
"Port",
|
||||
"Healthy",
|
||||
"Enabled",
|
||||
"Prompt Template",
|
||||
"Last Heartbeat",
|
||||
]
|
||||
for instance in instances:
|
||||
model_name, model_type = instance.model_name.split("@")
|
||||
table.add_row(
|
||||
[
|
||||
model_name,
|
||||
model_type,
|
||||
instance.host,
|
||||
instance.port,
|
||||
instance.healthy,
|
||||
instance.enabled,
|
||||
instance.prompt_template if instance.prompt_template else "",
|
||||
instance.last_heartbeat,
|
||||
]
|
||||
)
|
||||
|
||||
print(table)
|
||||
|
||||
|
||||
def add_model_options(func):
|
||||
@click.option(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help=("The name of model"),
|
||||
)
|
||||
@click.option(
|
||||
"--model_type",
|
||||
type=str,
|
||||
default="llm",
|
||||
required=False,
|
||||
help=("The type of model"),
|
||||
)
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@model_cli_group.command()
|
||||
@add_model_options
|
||||
@click.option(
|
||||
"--host",
|
||||
type=str,
|
||||
required=True,
|
||||
help=("The remote host to stop model"),
|
||||
)
|
||||
@click.option(
|
||||
"--port",
|
||||
type=int,
|
||||
required=True,
|
||||
help=("The remote port to stop model"),
|
||||
)
|
||||
def stop(model_name: str, model_type: str, host: str, port: int):
|
||||
"""Stop model instances"""
|
||||
from dbgpt.model.cluster import WorkerStartupRequest, RemoteWorkerManager
|
||||
|
||||
worker_manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS)
|
||||
req = WorkerStartupRequest(
|
||||
host=host,
|
||||
port=port,
|
||||
worker_type=model_type,
|
||||
model=model_name,
|
||||
params={},
|
||||
)
|
||||
loop = get_or_create_event_loop()
|
||||
res = loop.run_until_complete(worker_manager.model_shutdown(req))
|
||||
print(res)
|
||||
|
||||
|
||||
def _remote_model_dynamic_factory() -> Callable[[None], List[Type]]:
|
||||
from dbgpt.util.parameter_utils import _SimpleArgParser
|
||||
from dbgpt.model.cluster import RemoteWorkerManager
|
||||
from dbgpt.model.parameter import WorkerType
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
pre_args = _SimpleArgParser("model_name", "address", "host", "port")
|
||||
pre_args.parse()
|
||||
model_name = pre_args.get("model_name")
|
||||
address = pre_args.get("address")
|
||||
host = pre_args.get("host")
|
||||
port = pre_args.get("port")
|
||||
if port:
|
||||
port = int(port)
|
||||
|
||||
if not address:
|
||||
address = _detect_controller_address()
|
||||
|
||||
worker_manager: RemoteWorkerManager = _get_worker_manager(address)
|
||||
loop = get_or_create_event_loop()
|
||||
models = loop.run_until_complete(worker_manager.supported_models())
|
||||
|
||||
fields_dict = {}
|
||||
fields_dict["model_name"] = (
|
||||
str,
|
||||
field(default=None, metadata={"help": "The model name to deploy"}),
|
||||
)
|
||||
fields_dict["host"] = (
|
||||
str,
|
||||
field(default=None, metadata={"help": "The remote host to deploy model"}),
|
||||
)
|
||||
fields_dict["port"] = (
|
||||
int,
|
||||
field(default=None, metadata={"help": "The remote port to deploy model"}),
|
||||
)
|
||||
result_class = dataclass(
|
||||
type("RemoteModelWorkerParameters", (object,), fields_dict)
|
||||
)
|
||||
|
||||
if not models:
|
||||
return [result_class]
|
||||
|
||||
valid_models = []
|
||||
valid_model_cls = []
|
||||
for model in models:
|
||||
if host and host != model.host:
|
||||
continue
|
||||
if port and port != model.port:
|
||||
continue
|
||||
valid_models += [m.model for m in model.models]
|
||||
valid_model_cls += [
|
||||
(m, _build_parameter_class(m.params)) for m in model.models if m.params
|
||||
]
|
||||
real_model, real_params_cls = valid_model_cls[0]
|
||||
real_path = None
|
||||
real_worker_type = "llm"
|
||||
if model_name:
|
||||
params_cls_list = [m for m in valid_model_cls if m[0].model == model_name]
|
||||
if not params_cls_list:
|
||||
raise ValueError(f"Not supported model with model name: {model_name}")
|
||||
real_model, real_params_cls = params_cls_list[0]
|
||||
real_path = real_model.path
|
||||
real_worker_type = real_model.worker_type
|
||||
|
||||
@dataclass
|
||||
class RemoteModelWorkerParameters(BaseParameters):
|
||||
model_name: str = field(
|
||||
metadata={"valid_values": valid_models, "help": "The model name to deploy"}
|
||||
)
|
||||
model_path: Optional[str] = field(
|
||||
default=real_path, metadata={"help": "The model path to deploy"}
|
||||
)
|
||||
host: Optional[str] = field(
|
||||
default=models[0].host,
|
||||
metadata={
|
||||
"valid_values": [model.host for model in models],
|
||||
"help": "The remote host to deploy model",
|
||||
},
|
||||
)
|
||||
|
||||
port: Optional[int] = field(
|
||||
default=models[0].port,
|
||||
metadata={
|
||||
"valid_values": [model.port for model in models],
|
||||
"help": "The remote port to deploy model",
|
||||
},
|
||||
)
|
||||
worker_type: Optional[str] = field(
|
||||
default=real_worker_type,
|
||||
metadata={
|
||||
"valid_values": WorkerType.values(),
|
||||
"help": "Worker type",
|
||||
},
|
||||
)
|
||||
|
||||
return [RemoteModelWorkerParameters, real_params_cls]
|
||||
|
||||
|
||||
@model_cli_group.command(
|
||||
cls=build_lazy_click_command(_dynamic_factory=_remote_model_dynamic_factory)
|
||||
)
|
||||
def start(**kwargs):
|
||||
"""Start model instances"""
|
||||
from dbgpt.model.cluster import WorkerStartupRequest, RemoteWorkerManager
|
||||
|
||||
worker_manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS)
|
||||
req = WorkerStartupRequest(
|
||||
host=kwargs["host"],
|
||||
port=kwargs["port"],
|
||||
worker_type=kwargs["worker_type"],
|
||||
model=kwargs["model_name"],
|
||||
params={},
|
||||
)
|
||||
del kwargs["host"]
|
||||
del kwargs["port"]
|
||||
del kwargs["worker_type"]
|
||||
req.params = kwargs
|
||||
loop = get_or_create_event_loop()
|
||||
res = loop.run_until_complete(worker_manager.model_startup(req))
|
||||
print(res)
|
||||
|
||||
|
||||
@model_cli_group.command()
|
||||
@add_model_options
|
||||
def restart(model_name: str, model_type: str):
|
||||
"""Restart model instances"""
|
||||
worker_apply(
|
||||
MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.RESTART
|
||||
)
|
||||
|
||||
|
||||
@model_cli_group.command()
|
||||
@click.option(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help=("The name of model"),
|
||||
)
|
||||
@click.option(
|
||||
"--system",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=("System prompt"),
|
||||
)
|
||||
def chat(model_name: str, system: str):
|
||||
"""Interact with your bot from the command line"""
|
||||
_cli_chat(MODEL_CONTROLLER_ADDRESS, model_name, system)
|
||||
|
||||
|
||||
def worker_apply(
|
||||
address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
|
||||
):
|
||||
from dbgpt.model.cluster import WorkerApplyRequest
|
||||
|
||||
loop = get_or_create_event_loop()
|
||||
worker_manager = _get_worker_manager(address)
|
||||
apply_req = WorkerApplyRequest(
|
||||
model=model_name, worker_type=model_type, apply_type=apply_type
|
||||
)
|
||||
res = loop.run_until_complete(worker_manager.worker_apply(apply_req))
|
||||
print(res)
|
||||
|
||||
|
||||
def _cli_chat(address: str, model_name: str, system_prompt: str = None):
|
||||
loop = get_or_create_event_loop()
|
||||
worker_manager = worker_manager = _get_worker_manager(address)
|
||||
loop.run_until_complete(_chat_stream(worker_manager, model_name, system_prompt))
|
||||
|
||||
|
||||
async def _chat_stream(worker_manager, model_name: str, system_prompt: str = None):
|
||||
from dbgpt.model.cluster import PromptRequest
|
||||
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
|
||||
|
||||
print(f"Chatbot started with model {model_name}. Type 'exit' to leave the chat.")
|
||||
hist = []
|
||||
previous_response = ""
|
||||
if system_prompt:
|
||||
hist.append(
|
||||
ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_prompt)
|
||||
)
|
||||
while True:
|
||||
previous_response = ""
|
||||
user_input = input("\n\nYou: ")
|
||||
if user_input.lower().strip() == "exit":
|
||||
break
|
||||
hist.append(ModelMessage(role=ModelMessageRoleType.HUMAN, content=user_input))
|
||||
request = PromptRequest(messages=hist, model=model_name, prompt="", echo=False)
|
||||
request = request.dict(exclude_none=True)
|
||||
print("Bot: ", end="")
|
||||
async for response in worker_manager.generate_stream(request):
|
||||
incremental_output = response.text[len(previous_response) :]
|
||||
print(incremental_output, end="", flush=True)
|
||||
previous_response = response.text
|
||||
hist.append(
|
||||
ModelMessage(role=ModelMessageRoleType.AI, content=previous_response)
|
||||
)
|
||||
|
||||
|
||||
def add_stop_server_options(func):
|
||||
@click.option(
|
||||
"--port",
|
||||
type=int,
|
||||
default=None,
|
||||
required=False,
|
||||
help=("The port to stop"),
|
||||
)
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@click.command(name="controller")
|
||||
@EnvArgumentParser.create_click_option(ModelControllerParameters)
|
||||
def start_model_controller(**kwargs):
|
||||
"""Start model controller"""
|
||||
|
||||
if kwargs["daemon"]:
|
||||
log_file = os.path.join(LOGDIR, "model_controller_uvicorn.log")
|
||||
_run_current_with_daemon("ModelController", log_file)
|
||||
else:
|
||||
from dbgpt.model.cluster import run_model_controller
|
||||
|
||||
run_model_controller()
|
||||
|
||||
|
||||
@click.command(name="controller")
|
||||
@add_stop_server_options
|
||||
def stop_model_controller(port: int):
|
||||
"""Start model controller"""
|
||||
# Command fragments to check against running processes
|
||||
_stop_service("controller", "ModelController", port=port)
|
||||
|
||||
|
||||
def _model_dynamic_factory() -> Callable[[None], List[Type]]:
|
||||
from dbgpt.model.model_adapter import _dynamic_model_parser
|
||||
|
||||
param_class = _dynamic_model_parser()
|
||||
fix_class = [ModelWorkerParameters]
|
||||
if not param_class:
|
||||
param_class = [ModelParameters]
|
||||
fix_class += param_class
|
||||
return fix_class
|
||||
|
||||
|
||||
@click.command(
|
||||
name="worker", cls=build_lazy_click_command(_dynamic_factory=_model_dynamic_factory)
|
||||
)
|
||||
def start_model_worker(**kwargs):
|
||||
"""Start model worker"""
|
||||
if kwargs["daemon"]:
|
||||
port = kwargs["port"]
|
||||
model_type = kwargs.get("worker_type") or "llm"
|
||||
log_file = os.path.join(LOGDIR, f"model_worker_{model_type}_{port}_uvicorn.log")
|
||||
_run_current_with_daemon("ModelWorker", log_file)
|
||||
else:
|
||||
from dbgpt.model.cluster import run_worker_manager
|
||||
|
||||
run_worker_manager()
|
||||
|
||||
|
||||
@click.command(name="worker")
|
||||
@add_stop_server_options
|
||||
def stop_model_worker(port: int):
|
||||
"""Stop model worker"""
|
||||
name = "ModelWorker"
|
||||
if port:
|
||||
name = f"{name}-{port}"
|
||||
_stop_service("worker", name, port=port)
|
||||
|
||||
|
||||
@click.command(name="apiserver")
|
||||
@EnvArgumentParser.create_click_option(ModelAPIServerParameters)
|
||||
def start_apiserver(**kwargs):
|
||||
"""Start apiserver"""
|
||||
|
||||
if kwargs["daemon"]:
|
||||
log_file = os.path.join(LOGDIR, "model_apiserver_uvicorn.log")
|
||||
_run_current_with_daemon("ModelAPIServer", log_file)
|
||||
else:
|
||||
from dbgpt.model.cluster import run_apiserver
|
||||
|
||||
run_apiserver()
|
||||
|
||||
|
||||
@click.command(name="apiserver")
|
||||
@add_stop_server_options
|
||||
def stop_apiserver(port: int):
|
||||
"""Stop apiserver"""
|
||||
name = "ModelAPIServer"
|
||||
if port:
|
||||
name = f"{name}-{port}"
|
||||
_stop_service("apiserver", name, port=port)
|
||||
|
||||
|
||||
def _stop_all_model_server(**kwargs):
|
||||
"""Stop all server"""
|
||||
_stop_service("worker", "ModelWorker")
|
||||
_stop_service("controller", "ModelController")
|
Reference in New Issue
Block a user