mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-27 05:47:47 +00:00
475 lines
14 KiB
Python
475 lines
14 KiB
Python
import functools
|
|
import logging
|
|
import os
|
|
from typing import Callable, List, Optional, Type
|
|
|
|
import click
|
|
|
|
from dbgpt.configs.model_config import LOGDIR
|
|
from dbgpt.model.base import WorkerApplyType
|
|
from dbgpt.model.parameter import (
|
|
BaseParameters,
|
|
ModelAPIServerParameters,
|
|
ModelControllerParameters,
|
|
ModelParameters,
|
|
ModelWorkerParameters,
|
|
)
|
|
from dbgpt.util import get_or_create_event_loop
|
|
from dbgpt.util.command_utils import (
|
|
_detect_controller_address,
|
|
_run_current_with_daemon,
|
|
_stop_service,
|
|
)
|
|
from dbgpt.util.parameter_utils import (
|
|
EnvArgumentParser,
|
|
_build_parameter_class,
|
|
build_lazy_click_command,
|
|
)
|
|
|
|
# Fallback Model Controller address. The `model` command group overwrites it at
# runtime with the `--address` option or, when that option is omitted, with the
# result of `_detect_controller_address()` (per the option's help text, the
# CONTROLLER_ADDRESS environment variable is honored there).
MODEL_CONTROLLER_ADDRESS: str = "http://127.0.0.1:8000"

# Logger for the model CLI commands.
logger: logging.Logger = logging.getLogger("dbgpt_cli")
|
|
|
|
|
|
def _get_worker_manager(address: str):
    """Return a ``RemoteWorkerManager`` bound to the controller at *address*.

    The manager reaches the controller through a ``ModelRegistryClient``. Both
    classes are imported lazily, at call time.
    """
    from dbgpt.model.cluster import ModelRegistryClient, RemoteWorkerManager

    return RemoteWorkerManager(ModelRegistryClient(address))
|
|
|
|
|
|
@click.group("model")
@click.option(
    "--address",
    type=str,
    default=None,
    required=False,
    show_default=True,
    help=(
        "Address of the Model Controller to connect to. "
        "Just support light deploy model, If the environment variable CONTROLLER_ADDRESS is configured, read from the environment variable"
    ),
)
def model_cli_group(address: str):
    """Clients that manage model serving"""
    global MODEL_CONTROLLER_ADDRESS
    # An explicit --address wins; otherwise resolve the controller address
    # automatically. Every sub-command reads the module-level value afterwards.
    MODEL_CONTROLLER_ADDRESS = address if address else _detect_controller_address()
|
|
|
|
|
|
@model_cli_group.command()
@click.option(
    "--model_name", type=str, default=None, required=False, help=("The name of model")
)
@click.option(
    "--model_type", type=str, default="llm", required=False, help=("The type of model")
)
def list(model_name: str, model_type: str):
    """List model instances"""
    from prettytable import PrettyTable

    from dbgpt.model.cluster import ModelRegistryClient

    # NOTE(review): this function shadows the builtin ``list`` at module scope.
    # The name is kept because click uses the function name as the sub-command
    # name by default, so renaming it would change the CLI.
    loop = get_or_create_event_loop()
    registry = ModelRegistryClient(MODEL_CONTROLLER_ADDRESS)

    if not model_name:
        instances = loop.run_until_complete(registry.get_all_model_instances())
    else:
        # Instances are looked up as "<model_name>@<model_type>"; an empty
        # --model_type falls back to "llm".
        register_model_name = f"{model_name}@{model_type or 'llm'}"
        instances = loop.run_until_complete(
            registry.get_all_instances(register_model_name)
        )

    table = PrettyTable()
    table.field_names = [
        "Model Name",
        "Model Type",
        "Host",
        "Port",
        "Healthy",
        "Enabled",
        "Prompt Template",
        "Last Heartbeat",
    ]
    for instance in instances:
        # Split on the LAST "@" to invert the "<name>@<type>" construction above.
        # A plain split("@") raises ValueError on unpacking when the model name
        # itself contains "@" or when there is no "@" at all.
        name, sep, instance_type = instance.model_name.rpartition("@")
        if not sep:
            # No "@<type>" suffix: show the raw name and leave the type empty.
            name, instance_type = instance.model_name, ""
        table.add_row(
            [
                name,
                instance_type,
                instance.host,
                instance.port,
                instance.healthy,
                instance.enabled,
                instance.prompt_template if instance.prompt_template else "",
                instance.last_heartbeat,
            ]
        )

    print(table)
|
|
|
|
|
|
def add_model_options(func):
    """Decorate a click command with the shared model-selection options.

    Adds a required ``--model_name`` and an optional ``--model_type``
    (default ``"llm"``) to *func*, which is wrapped with ``functools.wraps``.
    """

    @functools.wraps(func)
    def _forward(*args, **kwargs):
        return func(*args, **kwargs)

    model_type_option = click.option(
        "--model_type",
        type=str,
        default="llm",
        required=False,
        help=("The type of model"),
    )
    model_name_option = click.option(
        "--model_name",
        type=str,
        default=None,
        required=True,
        help=("The name of model"),
    )
    # Same application order as stacked decorators: innermost (--model_type)
    # first, then --model_name.
    return model_name_option(model_type_option(_forward))
|
|
|
|
|
|
@model_cli_group.command()
@add_model_options
@click.option(
    "--host", type=str, required=True, help=("The remote host to stop model")
)
@click.option(
    "--port", type=int, required=True, help=("The remote port to stop model")
)
def stop(model_name: str, model_type: str, host: str, port: int):
    """Stop model instances"""
    from dbgpt.model.cluster import RemoteWorkerManager, WorkerStartupRequest

    manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS)
    # The shutdown call takes the same request type as startup, with no params.
    shutdown_request = WorkerStartupRequest(
        host=host, port=port, worker_type=model_type, model=model_name, params={}
    )
    event_loop = get_or_create_event_loop()
    outcome = event_loop.run_until_complete(manager.model_shutdown(shutdown_request))
    print(outcome)
|
|
|
|
|
|
def _build_fallback_remote_params_class() -> Type:
    """Build the minimal ``start`` parameter dataclass used when no model is known.

    It exposes only ``model_name``, ``host`` and ``port``, each defaulting to None.
    """
    from dataclasses import field, make_dataclass

    # make_dataclass is required here: a plain ``type(...)`` call whose class
    # attributes are ``(type, field)`` tuples defines no annotations, so
    # ``@dataclass`` would produce a class with no fields at all.
    return make_dataclass(
        "RemoteModelWorkerParameters",
        [
            (
                "model_name",
                str,
                field(default=None, metadata={"help": "The model name to deploy"}),
            ),
            (
                "host",
                str,
                field(default=None, metadata={"help": "The remote host to deploy model"}),
            ),
            (
                "port",
                int,
                field(default=None, metadata={"help": "The remote port to deploy model"}),
            ),
        ],
    )


def _remote_model_dynamic_factory() -> List[Type]:
    """Build the parameter classes for the ``model start`` command.

    The current command line is pre-parsed for ``model_name``, ``address``,
    ``host`` and ``port``. The controller (``address``, or the auto-detected
    one) is then asked which models its workers support. ``host``/``port``
    narrow the candidate workers.

    Returns:
        * ``[fallback]`` when the controller reports no models at all;
        * ``[RemoteModelWorkerParameters, params_cls]`` when a candidate model
          with a parameter schema exists (the one named by ``--model_name``,
          otherwise the first candidate);
        * ``[RemoteModelWorkerParameters]`` when no candidate carries a
          parameter schema.

    Raises:
        ValueError: ``--model_name`` was given but no candidate model with
            that name has a parameter schema.
    """
    from dataclasses import dataclass, field

    from dbgpt.model.cluster import RemoteWorkerManager
    from dbgpt.model.parameter import WorkerType
    from dbgpt.util.parameter_utils import _SimpleArgParser

    pre_args = _SimpleArgParser("model_name", "address", "host", "port")
    pre_args.parse()
    model_name = pre_args.get("model_name")
    address = pre_args.get("address")
    host = pre_args.get("host")
    port = pre_args.get("port")
    if port:
        port = int(port)

    if not address:
        address = _detect_controller_address()

    worker_manager: RemoteWorkerManager = _get_worker_manager(address)
    loop = get_or_create_event_loop()
    models = loop.run_until_complete(worker_manager.supported_models())

    if not models:
        return [_build_fallback_remote_params_class()]

    valid_models = []
    valid_model_cls = []
    for model in models:
        # Optional --host / --port restrict which workers are considered.
        if host and host != model.host:
            continue
        if port and port != model.port:
            continue
        valid_models += [m.model for m in model.models]
        valid_model_cls += [
            (m, _build_parameter_class(m.params)) for m in model.models if m.params
        ]

    real_params_cls = None
    real_path = None
    real_worker_type = "llm"
    if model_name:
        params_cls_list = [m for m in valid_model_cls if m[0].model == model_name]
        if not params_cls_list:
            raise ValueError(f"Not supported model with model name: {model_name}")
        real_model, real_params_cls = params_cls_list[0]
        real_path = real_model.path
        real_worker_type = real_model.worker_type
    elif valid_model_cls:
        # No model chosen on the command line: offer the first candidate's params.
        _, real_params_cls = valid_model_cls[0]

    @dataclass
    class RemoteModelWorkerParameters(BaseParameters):
        model_name: str = field(
            metadata={"valid_values": valid_models, "help": "The model name to deploy"}
        )
        model_path: Optional[str] = field(
            default=real_path, metadata={"help": "The model path to deploy"}
        )
        host: Optional[str] = field(
            default=models[0].host,
            metadata={
                "valid_values": [model.host for model in models],
                "help": "The remote host to deploy model",
            },
        )

        port: Optional[int] = field(
            default=models[0].port,
            metadata={
                "valid_values": [model.port for model in models],
                "help": "The remote port to deploy model",
            },
        )
        worker_type: Optional[str] = field(
            default=real_worker_type,
            metadata={
                "valid_values": WorkerType.values(),
                "help": "Worker type",
            },
        )

    if real_params_cls is None:
        # Host/port filtering left no model with a parameter schema; return
        # just the generic remote parameters instead of indexing an empty list.
        return [RemoteModelWorkerParameters]
    return [RemoteModelWorkerParameters, real_params_cls]
|
|
|
|
|
|
@model_cli_group.command(
    cls=build_lazy_click_command(_dynamic_factory=_remote_model_dynamic_factory)
)
def start(**kwargs):
    """Start model instances"""
    from dbgpt.model.cluster import RemoteWorkerManager, WorkerStartupRequest

    manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS)
    startup_request = WorkerStartupRequest(
        host=kwargs["host"],
        port=kwargs["port"],
        worker_type=kwargs["worker_type"],
        model=kwargs["model_name"],
        params={},
    )
    # Forward every option except the routing fields as model parameters
    # (model_name itself is forwarded too).
    routing_keys = ("host", "port", "worker_type")
    startup_request.params = {
        key: value for key, value in kwargs.items() if key not in routing_keys
    }
    event_loop = get_or_create_event_loop()
    outcome = event_loop.run_until_complete(manager.model_startup(startup_request))
    print(outcome)
|
|
|
|
|
|
@model_cli_group.command()
@add_model_options
def restart(model_name: str, model_type: str):
    """Restart model instances"""
    # Restarting is a worker "apply" operation sent through the controller.
    worker_apply(
        address=MODEL_CONTROLLER_ADDRESS,
        model_name=model_name,
        model_type=model_type,
        apply_type=WorkerApplyType.RESTART,
    )
|
|
|
|
|
|
@model_cli_group.command()
@click.option(
    "--model_name", type=str, default=None, required=True, help=("The name of model")
)
@click.option(
    "--system", type=str, default=None, required=False, help=("System prompt")
)
def chat(model_name: str, system: str):
    """Interact with your bot from the command line"""
    # The optional --system text becomes the system prompt of the session.
    _cli_chat(
        address=MODEL_CONTROLLER_ADDRESS, model_name=model_name, system_prompt=system
    )
|
|
|
|
|
|
def worker_apply(
    address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
):
    """Send a ``WorkerApplyRequest`` for one model and print the reply.

    Args:
        address: Model Controller address used to reach the workers.
        model_name: Name of the target model.
        model_type: Worker type of the target model.
        apply_type: Operation to apply (for example ``WorkerApplyType.RESTART``).
    """
    from dbgpt.model.cluster import WorkerApplyRequest

    event_loop = get_or_create_event_loop()
    manager = _get_worker_manager(address)
    request = WorkerApplyRequest(
        model=model_name, worker_type=model_type, apply_type=apply_type
    )
    print(event_loop.run_until_complete(manager.worker_apply(request)))
|
|
|
|
|
|
def _cli_chat(address: str, model_name: str, system_prompt: Optional[str] = None):
    """Run an interactive terminal chat with *model_name* until the user exits.

    Args:
        address: Model Controller address used to reach the workers.
        model_name: Model to chat with.
        system_prompt: Optional system prompt placed first in the history.
    """
    loop = get_or_create_event_loop()
    worker_manager = _get_worker_manager(address)
    loop.run_until_complete(_chat_stream(worker_manager, model_name, system_prompt))
|
|
|
|
|
|
async def _chat_stream(
    worker_manager, model_name: str, system_prompt: Optional[str] = None
):
    """Read user turns from stdin and stream the model's replies to stdout.

    The whole conversation (optional system prompt, user turns, and model
    replies) is resent with every request. The session ends when the user
    types ``exit`` (case-insensitive, surrounding whitespace ignored) or stdin
    reaches end-of-file.

    Args:
        worker_manager: Manager whose ``generate_stream`` produces the replies.
        model_name: Model to chat with.
        system_prompt: Optional system prompt placed first in the history.
    """
    from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
    from dbgpt.model.cluster import PromptRequest

    print(f"Chatbot started with model {model_name}. Type 'exit' to leave the chat.")
    hist = []
    if system_prompt:
        hist.append(
            ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_prompt)
        )

    while True:
        try:
            user_input = input("\n\nYou: ")
        except EOFError:
            # Ctrl-D / closed stdin ends the session instead of raising.
            print()
            break
        if user_input.lower().strip() == "exit":
            break
        hist.append(ModelMessage(role=ModelMessageRoleType.HUMAN, content=user_input))
        request = PromptRequest(messages=hist, model=model_name, prompt="", echo=False)
        request = request.dict(exclude_none=True)

        print("Bot: ", end="")
        previous_response = ""
        async for response in worker_manager.generate_stream(request):
            # Each response's text is treated as cumulative: print only the
            # part that extends what was already shown.
            incremental_output = response.text[len(previous_response) :]
            print(incremental_output, end="", flush=True)
            previous_response = response.text
        hist.append(
            ModelMessage(role=ModelMessageRoleType.AI, content=previous_response)
        )
|
|
|
|
|
|
def add_stop_server_options(func):
    """Decorate a stop command with an optional ``--port`` option (default None).

    *func* is wrapped with ``functools.wraps`` before the option is attached.
    """

    @functools.wraps(func)
    def _forward(*args, **kwargs):
        return func(*args, **kwargs)

    port_option = click.option(
        "--port",
        type=int,
        default=None,
        required=False,
        help=("The port to stop"),
    )
    return port_option(_forward)
|
|
|
|
|
|
@click.command(name="controller")
@EnvArgumentParser.create_click_option(ModelControllerParameters)
def start_model_controller(**kwargs):
    """Start model controller"""
    if not kwargs["daemon"]:
        # Foreground mode: run the controller in this process.
        from dbgpt.model.cluster import run_model_controller

        run_model_controller()
        return

    # Daemon mode: re-launch via _run_current_with_daemon, logging to LOGDIR.
    _run_current_with_daemon(
        "ModelController", os.path.join(LOGDIR, "model_controller_uvicorn.log")
    )
|
|
|
|
|
|
@click.command(name="controller")
@add_stop_server_options
def stop_model_controller(port: int):
    """Stop model controller"""
    # --port (optional) is forwarded to _stop_service to select the process.
    _stop_service("controller", "ModelController", port=port)
|
|
|
|
|
|
def _model_dynamic_factory() -> List[Type]:
    """Return the parameter classes for the ``worker`` start command.

    The list always begins with ``ModelWorkerParameters``. It is followed by
    the classes returned by ``_dynamic_model_parser()`` or, when that yields
    nothing, by ``ModelParameters``.
    """
    from dbgpt.model.adapter.model_adapter import _dynamic_model_parser

    param_class = _dynamic_model_parser() or [ModelParameters]
    return [ModelWorkerParameters, *param_class]
|
|
|
|
|
|
@click.command(
    name="worker", cls=build_lazy_click_command(_dynamic_factory=_model_dynamic_factory)
)
def start_model_worker(**kwargs):
    """Start model worker"""
    if not kwargs["daemon"]:
        # Foreground mode: run the worker manager in this process.
        from dbgpt.model.cluster import run_worker_manager

        run_worker_manager()
        return

    # Daemon mode: the log file name encodes worker type (default "llm") and port.
    worker_port = kwargs["port"]
    worker_kind = kwargs.get("worker_type") or "llm"
    log_name = f"model_worker_{worker_kind}_{worker_port}_uvicorn.log"
    _run_current_with_daemon("ModelWorker", os.path.join(LOGDIR, log_name))
|
|
|
|
|
|
@click.command(name="worker")
@add_stop_server_options
def stop_model_worker(port: int):
    """Stop model worker"""
    # With a port, target the "ModelWorker-<port>" service name.
    service_name = f"ModelWorker-{port}" if port else "ModelWorker"
    _stop_service("worker", service_name, port=port)
|
|
|
|
|
|
@click.command(name="apiserver")
@EnvArgumentParser.create_click_option(ModelAPIServerParameters)
def start_apiserver(**kwargs):
    """Start apiserver"""
    if not kwargs["daemon"]:
        # Foreground mode: run the API server in this process.
        from dbgpt.model.cluster import run_apiserver

        run_apiserver()
        return

    # Daemon mode: re-launch via _run_current_with_daemon, logging to LOGDIR.
    _run_current_with_daemon(
        "ModelAPIServer", os.path.join(LOGDIR, "model_apiserver_uvicorn.log")
    )
|
|
|
|
|
|
@click.command(name="apiserver")
@add_stop_server_options
def stop_apiserver(port: int):
    """Stop apiserver"""
    # With a port, target the "ModelAPIServer-<port>" service name.
    service_name = f"ModelAPIServer-{port}" if port else "ModelAPIServer"
    _stop_service("apiserver", service_name, port=port)
|
|
|
|
|
|
def _stop_all_model_server(**kwargs):
    """Stop the model worker(s), then the model controller.

    Any keyword arguments are accepted and ignored.
    """
    # NOTE(review): the ModelAPIServer is not stopped here; confirm that is intended.
    for service_type, service_name in (
        ("worker", "ModelWorker"),
        ("controller", "ModelController"),
    ):
        _stop_service(service_type, service_name)
|