feat: Command-line tool design and multi-model integration

This commit is contained in:
FangYin Cheng 2023-08-31 17:21:38 +08:00
parent 05712d39b9
commit e4dd6060da
15 changed files with 887 additions and 229 deletions

View File

@ -0,0 +1,219 @@
Cluster deployment
==================================
## Model cluster deployment
**Installing Command-Line Tool**
All operations below are performed using the `dbgpt` command. To use the `dbgpt` command, you need to install the DB-GPT project with `pip install -e .`. Alternatively, you can use `python pilot/scripts/cli_scripts.py` as a substitute for the `dbgpt` command.
### Launch Model Controller
```bash
dbgpt start controller
```
By default, the Model Controller starts on port 8000.
### Launch Model Worker
If you are starting `chatglm2-6b`:
```bash
dbgpt start worker --model_name chatglm2-6b \
--model_path /app/models/chatglm2-6b \
--port 8001 \
--controller_addr http://127.0.0.1:8000
```
If you are starting `vicuna-13b-v1.5`:
```bash
dbgpt start worker --model_name vicuna-13b-v1.5 \
--model_path /app/models/vicuna-13b-v1.5 \
--port 8002 \
--controller_addr http://127.0.0.1:8000
```
Note: Be sure to use your own model name and model path.
Check your model:
```bash
dbgpt model list
```
You will see the following output:
```
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
| Model Name | Model Type | Host | Port | Healthy | Enabled | Prompt Template | Last Heartbeat |
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
| chatglm2-6b | llm | 172.17.0.6 | 8001 | True | True | None | 2023-08-31T04:48:45.252939 |
| vicuna-13b-v1.5 | llm | 172.17.0.6 | 8002 | True | True | None | 2023-08-31T04:48:55.136676 |
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
```
### Connect to the model service in the webserver (dbgpt_server)
**First, modify the `.env` file to change the model name and the Model Controller connection address.**
```bash
LLM_MODEL=vicuna-13b-v1.5
# The current default MODEL_SERVER address is the address of the Model Controller
MODEL_SERVER=http://127.0.0.1:8000
```
#### Start the webserver
```bash
python pilot/server/dbgpt_server.py --light
```
`--light` indicates not to start the embedded model service.
Alternatively, you can set the model inline by prepending `LLM_MODEL=chatglm2-6b` to the command:
```bash
LLM_MODEL=chatglm2-6b python pilot/server/dbgpt_server.py --light
```
### More Command-Line Usages
You can view more command-line usages through the help command.
**View the `dbgpt` help**
```bash
dbgpt --help
```
You will see the basic command parameters and usage:
```
Usage: dbgpt [OPTIONS] COMMAND [ARGS]...
Options:
--log-level TEXT Log level
--version Show the version and exit.
--help Show this message and exit.
Commands:
model Clients that manage model serving
start Start specific server.
stop Stop specific server.
```
**View the `dbgpt start` help**
```bash
dbgpt start --help
```
Here you can see the related commands and usage for start:
```
Usage: dbgpt start [OPTIONS] COMMAND [ARGS]...
Start specific server.
Options:
--help Show this message and exit.
Commands:
apiserver Start apiserver(TODO)
controller Start model controller
webserver Start webserver(dbgpt_server.py)
worker Start model worker
```
**View the `dbgpt start worker` help**
```bash
dbgpt start worker --help
```
Here you can see the parameters to start Model Worker:
```
Usage: dbgpt start worker [OPTIONS]
Start model worker
Options:
--model_name TEXT Model name [required]
--model_path TEXT Model path [required]
--worker_type TEXT Worker type
--worker_class TEXT Model worker class, pilot.model.worker.defau
lt_worker.DefaultModelWorker
--host TEXT Model worker deploy host [default: 0.0.0.0]
--port INTEGER Model worker deploy port [default: 8000]
--limit_model_concurrency INTEGER
Model concurrency limit [default: 5]
--standalone Standalone mode. If True, run an embedded
ModelController
--register Register current worker to model controller
[default: True]
--worker_register_host TEXT The ip address of current worker to register
to ModelController. If None, the address is
automatically determined
--controller_addr TEXT The Model controller address to register
--send_heartbeat Send heartbeat to model controller
[default: True]
--heartbeat_interval INTEGER The interval for sending heartbeats
(seconds) [default: 20]
--device TEXT Device to run model. If None, the device is
automatically determined
--model_type TEXT Model type, huggingface or llama.cpp
[default: huggingface]
--prompt_template TEXT Prompt template. If None, the prompt
template is automatically determined from
model path, supported template: zero_shot,vi
cuna_v1.1,llama-2,alpaca,baichuan-chat
--max_context_size INTEGER Maximum context size [default: 4096]
--num_gpus INTEGER The number of gpus you expect to use, if it
is empty, use all of them as much as
possible
--max_gpu_memory TEXT The maximum memory limit of each GPU, only
valid in multi-GPU configuration
--cpu_offloading CPU offloading
--load_8bit 8-bit quantization
--load_4bit 4-bit quantization
--quant_type TEXT Quantization datatypes, `fp4` (four bit
float) and `nf4` (normal four bit float),
only valid when load_4bit=True [default:
nf4]
--use_double_quant Nested quantization, only valid when
load_4bit=True [default: True]
--compute_dtype TEXT Model compute type
--trust_remote_code Trust remote code [default: True]
--verbose Show verbose output.
--help Show this message and exit.
```
**View the `dbgpt model` help**
```bash
dbgpt model --help
```
The `dbgpt model` command connects to the Model Controller at the given address and manages the remote model instances:
```
Usage: dbgpt model [OPTIONS] COMMAND [ARGS]...
Clients that manage model serving
Options:
--address TEXT Address of the Model Controller to connect to. Only supported
in the light deployment mode [default: http://127.0.0.1:8000]
--help Show this message and exit.
Commands:
list List model instances
restart Restart model instances
start Start model instances
stop Stop model instances
```

View File

@ -19,6 +19,7 @@ Multi LLMs Support, Supports multiple large language models, currently supportin
- llama_cpp
- quantization
- cluster deployment
.. toctree::
:maxdepth: 2
@ -28,3 +29,4 @@ Multi LLMs Support, Supports multiple large language models, currently supportin
./llama/llama_cpp.md
./quantization/quantization.md
./cluster/model_cluster.md

View File

@ -0,0 +1,172 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) 2023, csunny
# This file is distributed under the same license as the DB-GPT package.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.6\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-08-31 16:43+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
"Language-Team: zh_CN <LL@li.org>\n"
"Plural-Forms: nplurals=1; plural=0;\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"
#: ../../getting_started/install/llm/cluster/model_cluster.md:1
#: b0614062af2a4c039aa8947e45c4382a
msgid "Cluster deployment"
msgstr "集群部署"
#: ../../getting_started/install/llm/cluster/model_cluster.md:4
#: 568a2af68de5410ebd63f34b8a3665c5
msgid "Model cluster deployment"
msgstr "模型的集群部署"
#: ../../getting_started/install/llm/cluster/model_cluster.md:7
#: 41b9c4bf32bd4c95a1a7b733543a322b
msgid "**Installing Command-Line Tool**"
msgstr "**命令行工具的安装**"
#: ../../getting_started/install/llm/cluster/model_cluster.md:9
#: b340cf9309a742648fb05910dc6ba741
msgid ""
"All operations below are performed using the `dbgpt` command. To use the "
"`dbgpt` command, you need to install the DB-GPT project with `pip install"
" -e .`. Alternatively, you can use `python pilot/scripts/cli_scripts.py` "
"as a substitute for the `dbgpt` command."
msgstr "下面的操作都是使用 `dbgpt` 命令来进行,为了使用 `dbgpt` 命令,你需要使用 `pip install -e .` 来安装 DB-GPT 项目,当然,你也可以使用 `python pilot/scripts/cli_scripts.py` 来替代命令 `dbgpt`。"
#: ../../getting_started/install/llm/cluster/model_cluster.md:11
#: 9299e223f8ec4060aaf40ddba019bfbf
msgid "Launch Model Controller"
msgstr "启动 Model Controller"
#: ../../getting_started/install/llm/cluster/model_cluster.md:17
#: f858b280e1e54feba6639c7270b512c3
msgid "By default, the Model Controller starts on port 8000."
msgstr "默认情况下 Model Controller 启动端口是 8000。"
#: ../../getting_started/install/llm/cluster/model_cluster.md:20
#: 51eecd5331ce441abd9658f873132935
msgid "Launch Model Worker"
msgstr "启动 Model Worker"
#: ../../getting_started/install/llm/cluster/model_cluster.md:22
#: 09b1b7ffe3b740de94163f2ef24e8743
msgid "If you are starting `chatglm2-6b`:"
msgstr "如果你启动的是 `chatglm2-6b`"
#: ../../getting_started/install/llm/cluster/model_cluster.md:31
#: 47f9ec2b957e4a2a858a8833ce2cd104
msgid "If you are starting `vicuna-13b-v1.5`:"
msgstr "如果你启动的是 `vicuna-13b-v1.5`"
#: ../../getting_started/install/llm/cluster/model_cluster.md:40
#: 95a04817be5a473e8b4f7086ac8dfe2c
msgid "Note: Be sure to use your own model name and model path."
msgstr "注意:请使用你自己的模型名称和模型路径。"
#: ../../getting_started/install/llm/cluster/model_cluster.md:43
#: 644fff6c28ca49efa06e0f2c89f45e27
msgid "Check your model:"
msgstr "查看你的模型:"
#: ../../getting_started/install/llm/cluster/model_cluster.md:49
#: 69df9c27c17e44b8a6cb010360c56338
msgid "You will see the following output:"
msgstr "你将会看到下面的输出:"
#: ../../getting_started/install/llm/cluster/model_cluster.md:59
#: a1e1489068d04b51bb239bd32ef3546b
msgid "Connect to the model service in the webserver (dbgpt_server)"
msgstr "在 webserver(dbgpt_server) 中连接模型服务"
#: ../../getting_started/install/llm/cluster/model_cluster.md:61
#: 0ab444ac06f34ce0994fe1f945abecde
msgid ""
"**First, modify the `.env` file to change the model name and the Model "
"Controller connection address.**"
msgstr "**先修改 `.env` 文件,修改模型名称和 Model Controller 连接地址。**"
#: ../../getting_started/install/llm/cluster/model_cluster.md:69
#: 5125b43d70d744a491a70aa1001303be
msgid "Start the webserver"
msgstr "启动 webserver"
#: ../../getting_started/install/llm/cluster/model_cluster.md:75
#: f12f463c27c74337ba8cc9c7e4e934b4
msgid "`--light` indicates not to start the embedded model service."
msgstr "`--light` 表示不启动内嵌的模型服务。"
#: ../../getting_started/install/llm/cluster/model_cluster.md:77
#: de2f7b040fd04c73a89ae453db84108f
msgid ""
"Alternatively, you can prepend the command with `LLM_MODEL=chatglm2-6b` "
"to start:"
msgstr "更简单的,你可以在命令行前添加 `LLM_MODEL=chatglm2-6b` 来启动:"
#: ../../getting_started/install/llm/cluster/model_cluster.md:84
#: 83b93195a9fb45aeaeead6b05719f164
msgid "More Command-Line Usages"
msgstr "命令行的更多用法"
#: ../../getting_started/install/llm/cluster/model_cluster.md:86
#: 8622edf06c1d443188e19a26ed306d58
msgid "You can view more command-line usages through the help command."
msgstr "你可以通过帮助命令来查看更多的命令行用法。"
#: ../../getting_started/install/llm/cluster/model_cluster.md:88
#: 391c6fa585d045f5a583e26478adda7e
msgid "**View the `dbgpt` help**"
msgstr "**查看 `dbgpt` 的帮助**"
#: ../../getting_started/install/llm/cluster/model_cluster.md:93
#: 4247e8ee8f084fa780fb523530cc64ab
msgid "You will see the basic command parameters and usage:"
msgstr "你将会看到基础的命令参数和用法:"
#: ../../getting_started/install/llm/cluster/model_cluster.md:109
#: 87b43469e93e4b62a59488d1ef663b51
msgid "**View the `dbgpt start` help**"
msgstr "**查看 `dbgpt start` 的帮助**"
#: ../../getting_started/install/llm/cluster/model_cluster.md:115
#: 213f7bfc459d4439b3a839292e2684e6
msgid "Here you can see the related commands and usage for start:"
msgstr "这里你能看到 start 相关的命令和用法。"
#: ../../getting_started/install/llm/cluster/model_cluster.md:132
#: eb616a87e01d4cdc87dc61e462d83729
msgid "**View the `dbgpt start worker`help**"
msgstr "**查看 `dbgpt start worker`的帮助**"
#: ../../getting_started/install/llm/cluster/model_cluster.md:138
#: 75918aa07d7c4ec39271358f2b474a63
msgid "Here you can see the parameters to start Model Worker:"
msgstr "这里你能看启动 Model Worker 的参数:"
#: ../../getting_started/install/llm/cluster/model_cluster.md:196
#: 5ee6ea98af4d4693b8cbf059ad3001d8
msgid "**View the `dbgpt model`help**"
msgstr "**查看 `dbgpt model`的帮助**"
#: ../../getting_started/install/llm/cluster/model_cluster.md:202
#: a96a1b58987a481fac2d527470226b90
msgid ""
"The `dbgpt model ` command can connect to the Model Controller via the "
"Model Controller address and then manage a remote model:"
msgstr "`dbgpt model ` 命令可以通过 Model Controller 地址连接到 Model Controller,然后对远程的某个模型进行管理。"
#~ msgid ""
#~ "First, modify the `.env` file to "
#~ "change the model name and the "
#~ "Model Controller connection address."
#~ msgstr ""

View File

@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.5\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-08-17 23:29+0800\n"
"POT-Creation-Date: 2023-08-31 16:38+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
@ -20,80 +20,85 @@ msgstr ""
"Generated-By: Babel 2.12.1\n"
#: ../../getting_started/install/llm/llm.rst:2
#: ../../getting_started/install/llm/llm.rst:23
#: 9e4baaff732b49f1baded5fd1bde4bd5
#: ../../getting_started/install/llm/llm.rst:24
#: b348d4df8ca44dd78b42157a8ff6d33d
msgid "LLM Usage"
msgstr "LLM使用"
#: ../../getting_started/install/llm/llm.rst:3 d80e7360f17c4be889f220e2a846a81e
#: ../../getting_started/install/llm/llm.rst:3 7f5960a7e5634254b330da27be87594b
msgid ""
"DB-GPT provides a management and deployment solution for multiple models."
" This chapter mainly discusses how to deploy different models."
msgstr "DB-GPT提供了多模型的管理和部署方案本章主要讲解针对不同的模型该怎么部署"
#: ../../getting_started/install/llm/llm.rst:18
#: 1cebd6c84b924eeb877ba31bd520328f
#: b844ab204ec740ec9d7d191bb841f09e
msgid ""
"Multi LLMs Support, Supports multiple large language models, currently "
"supporting"
msgstr "目前DB-GPT已适配如下模型"
#: ../../getting_started/install/llm/llm.rst:9 1163d5d3a8834b24bb6a5716d4fcc680
#: ../../getting_started/install/llm/llm.rst:9 c141437ddaf84c079360008343041b2f
msgid "🔥 Vicuna-v1.5(7b,13b)"
msgstr "🔥 Vicuna-v1.5(7b,13b)"
#: ../../getting_started/install/llm/llm.rst:10
#: 318dbc8e8de64bfb9cbc65343d050500
#: d32b1e3f114c4eab8782b497097c1b37
msgid "🔥 llama-2(7b,13b,70b)"
msgstr "🔥 llama-2(7b,13b,70b)"
#: ../../getting_started/install/llm/llm.rst:11
#: f2b1ef2bfbde46889131b6852d1211e8
#: 0a417ee4d008421da07fff7add5d05eb
msgid "WizardLM-v1.2(13b)"
msgstr "WizardLM-v1.2(13b)"
#: ../../getting_started/install/llm/llm.rst:12
#: 2462160e5928434aaa6f58074b585c14
#: 199e1a9fe3324dc8a1bcd9cd0b1ef047
msgid "Vicuna (7b,13b)"
msgstr "Vicuna (7b,13b)"
#: ../../getting_started/install/llm/llm.rst:13
#: 61cc1b9ed1914e379dee87034e9fbaea
#: a9e4c5100534450db3a583fa5850e4be
msgid "ChatGLM-6b (int4,int8)"
msgstr "ChatGLM-6b (int4,int8)"
#: ../../getting_started/install/llm/llm.rst:14
#: 653f5fd3b6ab4c508a2ce449231b5a17
#: 943324289eb94042b52fd824189cd93f
msgid "ChatGLM2-6b (int4,int8)"
msgstr "ChatGLM2-6b (int4,int8)"
#: ../../getting_started/install/llm/llm.rst:15
#: 6a1cf7a271f6496ba48d593db1f756f1
#: f1226fdfac3b4e9d88642ffa69d75682
msgid "guanaco(7b,13b,33b)"
msgstr "guanaco(7b,13b,33b)"
#: ../../getting_started/install/llm/llm.rst:16
#: f15083886362437ba96e1cb9ade738aa
#: 3f2457f56eb341b6bc431c9beca8f4df
msgid "Gorilla(7b,13b)"
msgstr "Gorilla(7b,13b)"
#: ../../getting_started/install/llm/llm.rst:17
#: 5e1ded19d1b24ba485e5d2bf22ce2db7
#: 86c8ce37be1c4a7ea3fc382100d77a9c
msgid "Baichuan(7b,13b)"
msgstr "Baichuan(7b,13b)"
#: ../../getting_started/install/llm/llm.rst:18
#: 30f6f71dd73b48d29a3647d58a9cdcaf
#: 538111af95ad414cb2e631a89f9af379
msgid "OpenAI"
msgstr "OpenAI"
#: ../../getting_started/install/llm/llm.rst:20
#: 9ad21acdaccb41c1bbfa30b5ce114732
#: a203325b7ec248f7bff61ae89226a000
msgid "llama_cpp"
msgstr "llama_cpp"
#: ../../getting_started/install/llm/llm.rst:21
#: b5064649f7c9443c9b9f15ec7fc02434
#: 21a50634198047228bc51a03d2c31292
msgid "quantization"
msgstr "quantization"
#: ../../getting_started/install/llm/llm.rst:22
#: dfaec4b04e6e45ff9c884b41534b1a79
msgid "cluster deployment"
msgstr "集群部署"

View File

@ -166,6 +166,9 @@ class Config(metaclass=Singleton):
self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False") == "True"
if self.IS_LOAD_8BIT and self.IS_LOAD_4BIT:
self.IS_LOAD_8BIT = False
# In order to be compatible with the new and old model parameter design
os.environ["load_8bit"] = str(self.IS_LOAD_8BIT)
os.environ["load_4bit"] = str(self.IS_LOAD_4BIT)
### EMBEDDING Configuration
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")

View File

@ -7,25 +7,33 @@ from pilot.model.worker.manager import (
WorkerApplyRequest,
WorkerApplyType,
)
from pilot.model.parameter import (
ModelControllerParameters,
ModelWorkerParameters,
ModelParameters,
)
from pilot.utils import get_or_create_event_loop
from pilot.utils.parameter_utils import EnvArgumentParser
@click.group("model")
def model_cli_group():
pass
@model_cli_group.command()
@click.option(
"--address",
type=str,
default="http://127.0.0.1:8000",
required=False,
show_default=True,
help=(
"Address of the Model Controller to connect to."
"Address of the Model Controller to connect to. "
"Only supported in the light deployment mode"
),
)
def model_cli_group():
"""Clients that manage model serving"""
pass
@model_cli_group.command()
@click.option(
"--model-name", type=str, default=None, required=False, help=("The name of model")
)
@ -79,16 +87,6 @@ def list(address: str, model_name: str, model_type: str):
def add_model_options(func):
@click.option(
"--address",
type=str,
default="http://127.0.0.1:8000",
required=False,
help=(
"Address of the Model Controller to connect to."
"Just support light deploy model"
),
)
@click.option(
"--model-name",
type=str,
@ -149,3 +147,57 @@ def worker_apply(
)
res = loop.run_until_complete(worker_manager.worker_apply(apply_req))
print(res)
@click.command(name="controller")
@EnvArgumentParser.create_click_option(ModelControllerParameters)
def start_model_controller(**kwargs):
"""Start model controller"""
from pilot.model.controller.controller import run_model_controller
run_model_controller()
@click.command(name="controller")
def stop_model_controller(**kwargs):
"""Stop model controller"""
raise NotImplementedError
@click.command(name="worker")
@EnvArgumentParser.create_click_option(ModelWorkerParameters, ModelParameters)
def start_model_worker(**kwargs):
"""Start model worker"""
from pilot.model.worker.manager import run_worker_manager
run_worker_manager()
@click.command(name="worker")
def stop_model_worker(**kwargs):
"""Stop model worker"""
raise NotImplementedError
@click.command(name="webserver")
def start_webserver(**kwargs):
"""Start webserver(dbgpt_server.py)"""
raise NotImplementedError
@click.command(name="webserver")
def stop_webserver(**kwargs):
"""Stop webserver(dbgpt_server.py)"""
raise NotImplementedError
@click.command(name="apiserver")
def start_apiserver(**kwargs):
"""Start apiserver(TODO)"""
raise NotImplementedError
@click.command(name="apiserver")
def stop_apiserver(**kwargs):
"""Stop apiserver(TODO)"""
raise NotImplementedError

View File

@ -1,9 +1,11 @@
import logging
from typing import List
from fastapi import APIRouter
from fastapi import APIRouter, FastAPI
from pilot.model.base import ModelInstance
from pilot.model.parameter import ModelControllerParameters
from pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry
from pilot.utils.parameter_utils import EnvArgumentParser
class ModelController:
@ -63,3 +65,22 @@ async def api_get_all_instances(model_name: str = None, healthy_only: bool = Fal
@router.post("/controller/heartbeat")
async def api_model_heartbeat(request: ModelInstance):
return await controller.send_heartbeat(request)
def run_model_controller():
import uvicorn
parser = EnvArgumentParser()
env_prefix = "controller_"
controller_params: ModelControllerParameters = parser.parse_args_into_dataclass(
ModelControllerParameters, env_prefix=env_prefix
)
app = FastAPI()
app.include_router(router, prefix="/api")
uvicorn.run(
app, host=controller_params.host, port=controller_params.port, log_level="info"
)
if __name__ == "__main__":
run_model_controller()
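The controller above serves registry lookups over FastAPI (`api_get_all_instances`, `/controller/heartbeat`). As a hypothetical in-memory sketch of what such a registry tracks — not the actual `EmbeddedModelRegistry` implementation — instances can be keyed by model/host/port, with heartbeats gating the `healthy_only` filter:

```python
import time
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ModelInstance:
    model_name: str
    host: str
    port: int
    healthy: bool = True
    last_heartbeat: float = field(default_factory=time.time)


class InMemoryRegistry:
    """Hypothetical sketch, not DB-GPT's EmbeddedModelRegistry."""

    def __init__(self, heartbeat_timeout: float = 60.0):
        self._instances = {}
        self._timeout = heartbeat_timeout

    def register(self, inst: ModelInstance) -> None:
        # One entry per (model, host, port) triple, so re-registering
        # the same worker simply refreshes its record.
        self._instances[(inst.model_name, inst.host, inst.port)] = inst

    def heartbeat(self, model_name: str, host: str, port: int) -> None:
        inst = self._instances.get((model_name, host, port))
        if inst is not None:
            inst.last_heartbeat = time.time()
            inst.healthy = True

    def all_instances(
        self, model_name: Optional[str] = None, healthy_only: bool = False
    ) -> List[ModelInstance]:
        found = [
            i
            for i in self._instances.values()
            if model_name is None or i.model_name == model_name
        ]
        if healthy_only:
            now = time.time()
            found = [
                i for i in found
                if i.healthy and now - i.last_heartbeat < self._timeout
            ]
        return found


registry = InMemoryRegistry()
registry.register(ModelInstance("chatglm2-6b", "172.17.0.6", 8001))
registry.register(ModelInstance("vicuna-13b-v1.5", "172.17.0.6", 8002))
```

This is also the data behind the `dbgpt model list` table shown in the documentation: each row is one registered instance plus its health and last-heartbeat timestamp.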

View File

@ -8,14 +8,12 @@ from pilot.configs.model_config import DEVICE
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
from pilot.model.compression import compress_module
from pilot.model.parameter import (
EnvArgumentParser,
ModelParameters,
LlamaCppModelParameters,
_genenv_ignoring_key_case,
)
from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations
from pilot.singleton import Singleton
from pilot.utils import get_gpu_memory
from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case
from pilot.logs import logger

View File

@ -1,139 +1,15 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import os
from dataclasses import dataclass, field, fields
from dataclasses import dataclass, field, fields, MISSING
from enum import Enum
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Dict, Optional
from pilot.model.conversation import conv_templates
from pilot.utils.parameter_utils import BaseParameters
suported_prompt_templates = ",".join(conv_templates.keys())
def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None):
"""Get the value from the environment variable, ignoring the case of the key"""
if env_prefix:
env_key = env_prefix + env_key
return os.getenv(
env_key, os.getenv(env_key.upper(), os.getenv(env_key.lower(), default_value))
)
class EnvArgumentParser:
@staticmethod
def get_env_prefix(env_key: str) -> str:
if not env_key:
return None
env_key = env_key.replace("-", "_")
return env_key + "_"
def parse_args_into_dataclass(
self,
dataclass_type: Type,
env_prefix: str = None,
command_args: List[str] = None,
**kwargs,
) -> Any:
"""Parse parameters from environment variables and command lines and populate them into data class"""
parser = argparse.ArgumentParser()
for field in fields(dataclass_type):
env_var_value = _genenv_ignoring_key_case(field.name, env_prefix)
if not env_var_value:
# Read without env prefix
env_var_value = _genenv_ignoring_key_case(field.name)
if env_var_value:
env_var_value = env_var_value.strip()
if field.type is int or field.type == Optional[int]:
env_var_value = int(env_var_value)
elif field.type is float or field.type == Optional[float]:
env_var_value = float(env_var_value)
elif field.type is bool or field.type == Optional[bool]:
env_var_value = env_var_value.lower() == "true"
elif field.type is str or field.type == Optional[str]:
pass
else:
raise ValueError(f"Unsupported parameter type {field.type}")
if not env_var_value:
env_var_value = kwargs.get(field.name)
if not env_var_value:
env_var_value = field.default
# Add a command-line argument for this field
help_text = field.metadata.get("help", "")
valid_values = field.metadata.get("valid_values", None)
parser.add_argument(
f"--{field.name}",
type=self._get_argparse_type(field.type),
help=help_text,
choices=valid_values,
default=env_var_value,
)
# Parse the command-line arguments
cmd_args, cmd_argv = parser.parse_known_args(args=command_args)
print(f"cmd_args: {cmd_args}")
for field in fields(dataclass_type):
# cmd_line_value = getattr(cmd_args, field.name)
if field.name in cmd_args:
cmd_line_value = getattr(cmd_args, field.name)
if cmd_line_value is not None:
kwargs[field.name] = cmd_line_value
return dataclass_type(**kwargs)
@staticmethod
def _get_argparse_type(field_type: Type) -> Type:
# Return the appropriate type for argparse to use based on the field type
if field_type is int or field_type == Optional[int]:
return int
elif field_type is float or field_type == Optional[float]:
return float
elif field_type is bool or field_type == Optional[bool]:
return bool
elif field_type is str or field_type == Optional[str]:
return str
else:
raise ValueError(f"Unsupported parameter type {field_type}")
@staticmethod
def _get_argparse_type_str(field_type: Type) -> str:
argparse_type = EnvArgumentParser._get_argparse_type(field_type)
if argparse_type is int:
return "int"
elif argparse_type is float:
return "float"
elif argparse_type is bool:
return "bool"
else:
return "str"
@dataclass
class ParameterDescription:
param_name: str
param_type: str
description: str
default_value: Optional[Any]
valid_values: Optional[List[Any]]
def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
descriptions = []
for field in fields(dataclass_type):
descriptions.append(
ParameterDescription(
param_name=field.name,
param_type=EnvArgumentParser._get_argparse_type_str(field.type),
description=field.metadata.get("help", None),
default_value=field.default, # TODO handle dataclasses._MISSING_TYPE
valid_values=field.metadata.get("valid_values", None),
)
)
return descriptions
class WorkerType(str, Enum):
LLM = "llm"
TEXT2VEC = "text2vec"
@ -144,58 +20,13 @@ class WorkerType(str, Enum):
@dataclass
class BaseParameters:
def update_from(self, source: Union["BaseParameters", dict]) -> bool:
"""
Update the attributes of this object using the values from another object (of the same or parent type) or a dictionary.
Only update if the new value is different from the current value and the field is not marked as "fixed" in metadata.
Args:
source (Union[BaseParameters, dict]): The source to update from. Can be another object of the same type or a dictionary.
Returns:
bool: True if at least one field was updated, otherwise False.
"""
updated = False # Flag to indicate whether any field was updated
if isinstance(source, (BaseParameters, dict)):
for field_info in fields(self):
# Check if the field has a "fixed" tag in metadata
tags = field_info.metadata.get("tags")
tags = [] if not tags else tags.split(",")
if tags and "fixed" in tags:
continue # skip this field
# Get the new value from source (either another BaseParameters object or a dict)
new_value = (
getattr(source, field_info.name)
if isinstance(source, BaseParameters)
else source.get(field_info.name, None)
)
# If the new value is not None and different from the current value, update the field and set the flag
if new_value is not None and new_value != getattr(
self, field_info.name
):
setattr(self, field_info.name, new_value)
updated = True
else:
raise ValueError(
"Source must be an instance of BaseParameters (or its derived class) or a dictionary."
)
return updated
def __str__(self) -> str:
class_name = self.__class__.__name__
parameters = [
f"\n\n=========================== {class_name} ===========================\n"
]
for field_info in fields(self):
value = getattr(self, field_info.name)
parameters.append(f"{field_info.name}: {value}")
parameters.append(
"\n======================================================================\n\n"
)
return "\n".join(parameters)
class ModelControllerParameters(BaseParameters):
host: Optional[str] = field(
default="0.0.0.0", metadata={"help": "Model Controller deploy host"}
)
port: Optional[int] = field(
default=8000, metadata={"help": "Model Controller deploy port"}
)
@dataclass
@ -209,7 +40,7 @@ class ModelWorkerParameters(BaseParameters):
worker_class: Optional[str] = field(
default=None,
metadata={
"help": "Model worker deploy host, pilot.model.worker.default_worker.DefaultModelWorker"
"help": "Model worker class, pilot.model.worker.default_worker.DefaultModelWorker"
},
)
host: Optional[str] = field(

View File

@ -2,10 +2,9 @@ from abc import ABC, abstractmethod
from typing import Dict, Iterator, List, Type
from pilot.model.base import ModelOutput
from pilot.model.parameter import (
ModelParameters,
from pilot.model.parameter import ModelParameters, WorkerType
from pilot.utils.parameter_utils import (
ParameterDescription,
WorkerType,
_get_parameter_descriptions,
)

View File

@ -7,10 +7,11 @@ from pilot.configs.model_config import DEVICE
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
from pilot.model.base import ModelOutput
from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import EnvArgumentParser, ModelParameters
from pilot.model.parameter import ModelParameters
from pilot.model.worker.base import ModelWorker
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.parameter_utils import EnvArgumentParser
logger = logging.getLogger("model_worker")

View File

@ -5,11 +5,11 @@ from pilot.configs.model_config import DEVICE
from pilot.model.loader import _get_model_real_path
from pilot.model.parameter import (
EmbeddingModelParameters,
EnvArgumentParser,
WorkerType,
)
from pilot.model.worker.base import ModelWorker
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.parameter_utils import EnvArgumentParser
logger = logging.getLogger("model_worker")

View File

@ -23,15 +23,14 @@ from pilot.model.base import (
)
from pilot.model.controller.registry import ModelRegistry
from pilot.model.parameter import (
EnvArgumentParser,
ModelParameters,
ModelWorkerParameters,
WorkerType,
ParameterDescription,
)
from pilot.model.worker.base import ModelWorker
from pilot.scene.base_message import ModelMessage
from pilot.utils import build_logger
from pilot.utils.parameter_utils import EnvArgumentParser, ParameterDescription
from pydantic import BaseModel
logger = build_logger("model_worker", LOGDIR + "/model_worker.log")
@ -148,6 +147,8 @@ class LocalWorkerManager(WorkerManager):
self.model_registry = model_registry
def _worker_key(self, worker_type: str, model_name: str) -> str:
if isinstance(worker_type, WorkerType):
worker_type = worker_type.value
return f"{model_name}@{worker_type}"
def add_worker(
@ -166,6 +167,9 @@ class LocalWorkerManager(WorkerManager):
if not worker_params.worker_type:
worker_params.worker_type = worker.worker_type()
if isinstance(worker_params.worker_type, WorkerType):
worker_params.worker_type = worker_params.worker_type.value
worker_key = self._worker_key(
worker_params.worker_type, worker_params.model_name
)
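The two normalization checks added above exist because `worker_type` may arrive either as a plain string or as a `WorkerType` enum member; both must yield the same `{model_name}@{worker_type}` registry key. A small standalone sketch of that key-building step:

```python
from enum import Enum


class WorkerType(str, Enum):
    LLM = "llm"
    TEXT2VEC = "text2vec"


def worker_key(worker_type, model_name: str) -> str:
    # Normalize enum members to their plain string value first, so
    # WorkerType.LLM and "llm" produce an identical key.
    if isinstance(worker_type, WorkerType):
        worker_type = worker_type.value
    return f"{model_name}@{worker_type}"


print(worker_key(WorkerType.LLM, "vicuna-13b-v1.5"))  # vicuna-13b-v1.5@llm
```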

View File

@ -5,7 +5,7 @@ import copy
import logging
sys.path.append(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
@ -23,16 +23,53 @@ def cli(log_level: str):
logging.basicConfig(level=log_level, encoding="utf-8")
def add_command_alias(command, name: str, hidden: bool = False):
def add_command_alias(command, name: str, hidden: bool = False, parent_group=None):
if not parent_group:
parent_group = cli
new_command = copy.deepcopy(command)
new_command.hidden = hidden
cli.add_command(new_command, name=name)
parent_group.add_command(new_command, name=name)
@click.group()
def start():
"""Start specific server."""
pass
@click.group()
def stop():
"""Stop specific server."""
pass
cli.add_command(start)
cli.add_command(stop)
try:
from pilot.model.cli import model_cli_group
from pilot.model.cli import (
model_cli_group,
start_model_controller,
stop_model_controller,
start_model_worker,
stop_model_worker,
start_webserver,
stop_webserver,
start_apiserver,
stop_apiserver,
)
add_command_alias(model_cli_group, name="model", parent_group=cli)
add_command_alias(start_model_controller, name="controller", parent_group=start)
add_command_alias(start_model_worker, name="worker", parent_group=start)
add_command_alias(start_webserver, name="webserver", parent_group=start)
add_command_alias(start_apiserver, name="apiserver", parent_group=start)
add_command_alias(stop_model_controller, name="controller", parent_group=stop)
add_command_alias(stop_model_worker, name="worker", parent_group=stop)
add_command_alias(stop_webserver, name="webserver", parent_group=stop)
add_command_alias(stop_apiserver, name="apiserver", parent_group=stop)
except ImportError as e:
logging.warning(f"Integrating dbgpt model command line tool failed: {e}")
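The `add_command_alias` change above mounts one click command under several groups by deep-copying it, which is what makes `dbgpt start worker` and `dbgpt stop worker` share machinery. A minimal standalone sketch of the same pattern (the group and command names here are illustrative, not the real dbgpt entry points):

```python
import copy

import click


@click.group()
def cli():
    """Root command group."""


@click.group()
def start():
    """Start specific server."""


@click.group()
def stop():
    """Stop specific server."""


cli.add_command(start)
cli.add_command(stop)


@click.command()
def worker():
    click.echo("worker invoked")


def add_command_alias(command, name, hidden=False, parent_group=None):
    # Deep-copy so hiding or renaming one alias does not affect the others.
    if not parent_group:
        parent_group = cli
    new_command = copy.deepcopy(command)
    new_command.hidden = hidden
    parent_group.add_command(new_command, name=name)


# The same command callback now answers to both "start worker" and "stop worker".
add_command_alias(worker, name="worker", parent_group=start)
add_command_alias(worker, name="worker", parent_group=stop)
```

Invoking either `cli start worker` or `cli stop worker` then dispatches to the same callback.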

View File

@@ -0,0 +1,314 @@
import argparse
import os
from dataclasses import dataclass, fields, MISSING
from typing import Any, List, Optional, Type, Union, Callable
@dataclass
class ParameterDescription:
param_name: str
param_type: str
description: str
default_value: Optional[Any]
valid_values: Optional[List[Any]]
@dataclass
class BaseParameters:
def update_from(self, source: Union["BaseParameters", dict]) -> bool:
"""
Update the attributes of this object using the values from another object (of the same or parent type) or a dictionary.
Only update if the new value is different from the current value and the field is not marked as "fixed" in metadata.
Args:
source (Union[BaseParameters, dict]): The source to update from. Can be another object of the same type or a dictionary.
Returns:
bool: True if at least one field was updated, otherwise False.
"""
updated = False # Flag to indicate whether any field was updated
if isinstance(source, (BaseParameters, dict)):
for field_info in fields(self):
# Check if the field has a "fixed" tag in metadata
tags = field_info.metadata.get("tags")
tags = [] if not tags else tags.split(",")
if tags and "fixed" in tags:
continue # skip this field
# Get the new value from source (either another BaseParameters object or a dict)
new_value = (
getattr(source, field_info.name)
if isinstance(source, BaseParameters)
else source.get(field_info.name, None)
)
# If the new value is not None and different from the current value, update the field and set the flag
if new_value is not None and new_value != getattr(
self, field_info.name
):
setattr(self, field_info.name, new_value)
updated = True
else:
raise ValueError(
"Source must be an instance of BaseParameters (or its derived class) or a dictionary."
)
return updated
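A condensed sketch of `update_from` in action; the method body is repeated here (minus the type-check branch) so the snippet runs standalone, and the dataclass fields and defaults are illustrative:

```python
from dataclasses import dataclass, field, fields
from typing import Union


@dataclass
class BaseParameters:
    def update_from(self, source: Union["BaseParameters", dict]) -> bool:
        updated = False
        for f in fields(self):
            tags = (f.metadata.get("tags") or "").split(",")
            if "fixed" in tags:
                continue  # fields tagged "fixed" are never overwritten
            new_value = (
                getattr(source, f.name)
                if isinstance(source, BaseParameters)
                else source.get(f.name)
            )
            if new_value is not None and new_value != getattr(self, f.name):
                setattr(self, f.name, new_value)
                updated = True
        return updated


@dataclass
class ModelParameters(BaseParameters):
    model_name: str = field(default="chatglm2-6b", metadata={"tags": "fixed"})
    port: int = 8001


params = ModelParameters()
changed = params.update_from({"model_name": "vicuna-13b-v1.5", "port": 8002})
# model_name is tagged "fixed", so only port is updated
```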
def __str__(self) -> str:
class_name = self.__class__.__name__
parameters = [
f"\n\n=========================== {class_name} ===========================\n"
]
for field_info in fields(self):
value = getattr(self, field_info.name)
parameters.append(f"{field_info.name}: {value}")
parameters.append(
"\n======================================================================\n\n"
)
return "\n".join(parameters)
def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None):
"""Get the value from the environment variable, ignoring the case of the key"""
if env_prefix:
env_key = env_prefix + env_key
return os.getenv(
env_key, os.getenv(env_key.upper(), os.getenv(env_key.lower(), default_value))
)
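A quick demonstration of the case-insensitive lookup; the helper is repeated here so the snippet runs on its own, and the environment variable names are made up (the second call assumes no `MODEL_PATH` variable is set):

```python
import os


def _genenv_ignoring_key_case(env_key, env_prefix=None, default_value=None):
    """Try the key as given, then upper-cased, then lower-cased."""
    if env_prefix:
        env_key = env_prefix + env_key
    return os.getenv(
        env_key, os.getenv(env_key.upper(), os.getenv(env_key.lower(), default_value))
    )


os.environ["MODEL_NAME"] = "chatglm2-6b"
found = _genenv_ignoring_key_case("model_name")  # resolved via the upper-cased key
missing = _genenv_ignoring_key_case("model_path", default_value="/app/models/chatglm2-6b")
```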
class EnvArgumentParser:
@staticmethod
def get_env_prefix(env_key: str) -> str:
if not env_key:
return None
env_key = env_key.replace("-", "_")
return env_key + "_"
def parse_args_into_dataclass(
self,
dataclass_type: Type,
env_prefix: str = None,
command_args: List[str] = None,
**kwargs,
) -> Any:
"""Parse parameters from environment variables and command lines and populate them into data class"""
parser = argparse.ArgumentParser()
for field in fields(dataclass_type):
env_var_value = _genenv_ignoring_key_case(field.name, env_prefix)
if not env_var_value:
# Read without env prefix
env_var_value = _genenv_ignoring_key_case(field.name)
if env_var_value:
env_var_value = env_var_value.strip()
if field.type is int or field.type == Optional[int]:
env_var_value = int(env_var_value)
elif field.type is float or field.type == Optional[float]:
env_var_value = float(env_var_value)
elif field.type is bool or field.type == Optional[bool]:
env_var_value = env_var_value.lower() == "true"
elif field.type is str or field.type == Optional[str]:
pass
else:
raise ValueError(f"Unsupported parameter type {field.type}")
if env_var_value is None:  # only a missing value falls back; 0/False are valid
env_var_value = kwargs.get(field.name)
# Add a command-line argument for this field
help_text = field.metadata.get("help", "")
valid_values = field.metadata.get("valid_values", None)
argument_kwargs = {
"type": EnvArgumentParser._get_argparse_type(field.type),
"help": help_text,
"choices": valid_values,
"required": EnvArgumentParser._is_require_type(field.type),
}
if field.default != MISSING:
argument_kwargs["default"] = field.default
argument_kwargs["required"] = False
if env_var_value is not None:
argument_kwargs["default"] = env_var_value
argument_kwargs["required"] = False
parser.add_argument(f"--{field.name}", **argument_kwargs)
# Parse the command-line arguments
cmd_args, cmd_argv = parser.parse_known_args(args=command_args)
# cmd_args = parser.parse_args(args=command_args)
# print(f"cmd_args: {cmd_args}")
for field in fields(dataclass_type):
# cmd_line_value = getattr(cmd_args, field.name)
if field.name in cmd_args:
cmd_line_value = getattr(cmd_args, field.name)
if cmd_line_value is not None:
kwargs[field.name] = cmd_line_value
return dataclass_type(**kwargs)
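The method above layers three sources: explicit command-line flags win over environment variables, which win over keyword defaults. A trimmed standalone sketch of that precedence (the dataclass and field names are illustrative, not the real DB-GPT parameters):

```python
import argparse
import os
from dataclasses import dataclass, fields


@dataclass
class WorkerConfig:
    model_name: str = None
    port: int = 8001


def parse(command_args=None):
    parser = argparse.ArgumentParser()
    for f in fields(WorkerConfig):
        # Environment variables seed the argparse default...
        default = os.getenv(f.name.upper(), f.default)
        if default is not None and f.type is int:
            default = int(default)
        parser.add_argument(f"--{f.name}", type=f.type, default=default)
    # ...and explicit command-line flags override them.
    args, _unknown = parser.parse_known_args(args=command_args)
    return WorkerConfig(**vars(args))


os.environ["PORT"] = "8002"
cfg = parse(["--model_name", "chatglm2-6b"])
# cfg.model_name comes from the flag, cfg.port from the environment
```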
@staticmethod
def create_arg_parser(dataclass_type: Type) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description=dataclass_type.__doc__)
for field in fields(dataclass_type):
help_text = field.metadata.get("help", "")
valid_values = field.metadata.get("valid_values", None)
argument_kwargs = {
"type": EnvArgumentParser._get_argparse_type(field.type),
"help": help_text,
"choices": valid_values,
"required": EnvArgumentParser._is_require_type(field.type),
}
if field.default != MISSING:
argument_kwargs["default"] = field.default
argument_kwargs["required"] = False
parser.add_argument(f"--{field.name}", **argument_kwargs)
return parser
@staticmethod
def create_click_option(
*dataclass_types: Type, _dynamic_factory: Callable[[], List[Type]] = None
):
import click
import functools
from collections import OrderedDict
# TODO dynamic configuration
# pre_args = _SimpleArgParser('model_name', 'model_path')
# pre_args.parse()
# print(pre_args)
combined_fields = OrderedDict()
if _dynamic_factory:
_types = _dynamic_factory()
if _types:
dataclass_types = list(_types)
for dataclass_type in dataclass_types:
for field in fields(dataclass_type):
if field.name not in combined_fields:
combined_fields[field.name] = field
def decorator(func):
for field_name, field in reversed(combined_fields.items()):
help_text = field.metadata.get("help", "")
valid_values = field.metadata.get("valid_values", None)
cli_params = {
"default": None if field.default is MISSING else field.default,
"help": help_text,
"show_default": True,
"required": field.default is MISSING,
}
if valid_values:
cli_params["type"] = click.Choice(valid_values)
real_type = EnvArgumentParser._get_argparse_type(field.type)
if real_type is int:
cli_params["type"] = click.INT
elif real_type is float:
cli_params["type"] = click.FLOAT
elif real_type is str:
cli_params["type"] = click.STRING
elif real_type is bool:
cli_params["is_flag"] = True
option_decorator = click.option(
# f"--{field_name.replace('_', '-')}", **cli_params
f"--{field_name}",
**cli_params,
)
func = option_decorator(func)
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
return decorator
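`create_click_option` is roughly equivalent to stacking one `click.option` per dataclass field, applied in reverse so `--help` lists them in declaration order. A minimal standalone sketch under that reading (the dataclass is illustrative):

```python
from dataclasses import MISSING, dataclass, fields

import click


@dataclass
class ControllerParameters:
    host: str = "0.0.0.0"
    port: int = 8000


def options_from_dataclass(dataclass_type):
    def decorator(func):
        # Apply in reverse so --help shows fields in declaration order.
        for f in reversed(fields(dataclass_type)):
            func = click.option(
                f"--{f.name}",
                type=f.type,
                default=None if f.default is MISSING else f.default,
                show_default=True,
            )(func)
        return func

    return decorator


@click.command()
@options_from_dataclass(ControllerParameters)
def controller(host, port):
    click.echo(f"{host}:{port}")
```

Fields without a default become `default=None` here; the real implementation instead marks them `required`.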
@staticmethod
def _get_argparse_type(field_type: Type) -> Type:
# Return the appropriate type for argparse to use based on the field type
if field_type is int or field_type == Optional[int]:
return int
elif field_type is float or field_type == Optional[float]:
return float
elif field_type is bool or field_type == Optional[bool]:
return bool
elif field_type is str or field_type == Optional[str]:
return str
else:
raise ValueError(f"Unsupported parameter type {field_type}")
@staticmethod
def _get_argparse_type_str(field_type: Type) -> str:
argparse_type = EnvArgumentParser._get_argparse_type(field_type)
if argparse_type is int:
return "int"
elif argparse_type is float:
return "float"
elif argparse_type is bool:
return "bool"
else:
return "str"
@staticmethod
def _is_require_type(field_type: Type) -> bool:
return field_type not in [Optional[int], Optional[float], Optional[bool]]
def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
descriptions = []
for field in fields(dataclass_type):
descriptions.append(
ParameterDescription(
param_name=field.name,
param_type=EnvArgumentParser._get_argparse_type_str(field.type),
description=field.metadata.get("help", None),
default_value=field.default, # TODO handle dataclasses._MISSING_TYPE
valid_values=field.metadata.get("valid_values", None),
)
)
return descriptions
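A standalone sketch of the description-building walk; the worker dataclass is illustrative, and every field is given a default to sidestep the `dataclasses._MISSING_TYPE` case flagged in the TODO:

```python
from dataclasses import dataclass, field, fields
from typing import Any, List, Optional


@dataclass
class ParameterDescription:
    param_name: str
    param_type: str
    description: Optional[str]
    default_value: Any
    valid_values: Optional[List[Any]]


@dataclass
class WorkerParams:
    model_name: str = field(default=None, metadata={"help": "Model name"})
    port: int = field(default=8001, metadata={"help": "Worker port"})


def describe(dataclass_type):
    # One serializable description per dataclass field.
    return [
        ParameterDescription(
            param_name=f.name,
            param_type=f.type.__name__,
            description=f.metadata.get("help"),
            default_value=f.default,
            valid_values=f.metadata.get("valid_values"),
        )
        for f in fields(dataclass_type)
    ]


descs = describe(WorkerParams)
```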
class _SimpleArgParser:
def __init__(self, *args):
self.params = {arg.replace("_", "-"): None for arg in args}
def parse(self, args=None):
import sys
if args is None:
args = sys.argv[1:]
else:
args = list(args)
prev_arg = None
for arg in args:
if arg.startswith("--"):
if prev_arg:
self.params[prev_arg] = None
prev_arg = arg[2:].replace("_", "-")  # normalize to the dash-style keys built in __init__
else:
if prev_arg:
self.params[prev_arg] = arg
prev_arg = None
if prev_arg:
self.params[prev_arg] = None
def _get_param(self, key):
return self.params.get(key.replace("_", "-"), None)
def __getattr__(self, item):
return self._get_param(item)
def __getitem__(self, key):
return self._get_param(key)
def get(self, key, default=None):
return self._get_param(key) or default
def __str__(self):
return "\n".join(
[f'{key.replace("-", "_")}: {value}' for key, value in self.params.items()]
)
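`_SimpleArgParser` does a single forward pass over the argument list: a `--` token opens a key and the next non-flag token (if any) becomes its value. A standalone sketch of the same scan, normalizing flags to the dash-style keys the constructor creates (the class name and sample arguments are illustrative):

```python
class SimpleArgParser:
    def __init__(self, *names):
        # Keys are stored dash-style, as in _SimpleArgParser's constructor.
        self.params = {name.replace("_", "-"): None for name in names}

    def parse(self, args):
        prev = None
        for arg in args:
            if arg.startswith("--"):
                # Normalize so --model_name and --model-name hit the same key.
                prev = arg[2:].replace("_", "-")
            elif prev:
                self.params[prev] = arg
                prev = None

    def get(self, key, default=None):
        return self.params.get(key.replace("_", "-")) or default


pre_args = SimpleArgParser("model_name", "model_path")
pre_args.parse(["--model_name", "vicuna-13b-v1.5", "--port", "8002"])
```

This is why the commented-out `pre_args = _SimpleArgParser('model_name', 'model_path')` in `create_click_option` can peek at a couple of flags before the full click machinery runs.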