mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-13 14:06:43 +00:00
feat: Command-line tool design and multi-model integration
This commit is contained in:
parent
05712d39b9
commit
e4dd6060da
219
docs/getting_started/install/llm/cluster/model_cluster.md
Normal file
219
docs/getting_started/install/llm/cluster/model_cluster.md
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
Cluster deployment
|
||||||
|
==================================
|
||||||
|
|
||||||
|
## Model cluster deployment
|
||||||
|
|
||||||
|
|
||||||
|
**Installing Command-Line Tool**
|
||||||
|
|
||||||
|
All operations below are performed using the `dbgpt` command. To use the `dbgpt` command, you need to install the DB-GPT project with `pip install -e .`. Alternatively, you can use `python pilot/scripts/cli_scripts.py` as a substitute for the `dbgpt` command.
|
||||||
|
|
||||||
|
### Launch Model Controller
|
||||||
|
|
||||||
|
```bash
|
||||||
|
dbgpt start controller
|
||||||
|
```
|
||||||
|
|
||||||
|
By default, the Model Controller starts on port 8000.
|
||||||
|
|
||||||
|
|
||||||
|
### Launch Model Worker
|
||||||
|
|
||||||
|
If you are starting `chatglm2-6b`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
dbgpt start worker --model_name chatglm2-6b \
|
||||||
|
--model_path /app/models/chatglm2-6b \
|
||||||
|
--port 8001 \
|
||||||
|
--controller_addr http://127.0.0.1:8000
|
||||||
|
```
|
||||||
|
|
||||||
|
If you are starting `vicuna-13b-v1.5`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
dbgpt start worker --model_name vicuna-13b-v1.5 \
|
||||||
|
--model_path /app/models/vicuna-13b-v1.5 \
|
||||||
|
--port 8002 \
|
||||||
|
--controller_addr http://127.0.0.1:8000
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: Be sure to use your own model name and model path.
|
||||||
|
|
||||||
|
|
||||||
|
Check your model:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
dbgpt model list
|
||||||
|
```
|
||||||
|
|
||||||
|
You will see the following output:
|
||||||
|
```
|
||||||
|
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
|
||||||
|
| Model Name | Model Type | Host | Port | Healthy | Enabled | Prompt Template | Last Heartbeat |
|
||||||
|
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
|
||||||
|
| chatglm2-6b | llm | 172.17.0.6 | 8001 | True | True | None | 2023-08-31T04:48:45.252939 |
|
||||||
|
| vicuna-13b-v1.5 | llm | 172.17.0.6 | 8002 | True | True | None | 2023-08-31T04:48:55.136676 |
|
||||||
|
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
|
||||||
|
```
|
||||||
|
|
||||||
|
### Connect to the model service in the webserver (dbgpt_server)
|
||||||
|
|
||||||
|
**First, modify the `.env` file to change the model name and the Model Controller connection address.**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
LLM_MODEL=vicuna-13b-v1.5
|
||||||
|
# The current default MODEL_SERVER address is the address of the Model Controller
|
||||||
|
MODEL_SERVER=http://127.0.0.1:8000
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Start the webserver
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python pilot/server/dbgpt_server.py --light
|
||||||
|
```
|
||||||
|
|
||||||
|
`--light` indicates not to start the embedded model service.
|
||||||
|
|
||||||
|
Alternatively, you can prepend the command with `LLM_MODEL=chatglm2-6b` to start:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
LLM_MODEL=chatglm2-6b python pilot/server/dbgpt_server.py --light
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### More Command-Line Usages
|
||||||
|
|
||||||
|
You can view more command-line usages through the help command.
|
||||||
|
|
||||||
|
**View the `dbgpt` help**
|
||||||
|
```bash
|
||||||
|
dbgpt --help
|
||||||
|
```
|
||||||
|
|
||||||
|
You will see the basic command parameters and usage:
|
||||||
|
|
||||||
|
```
|
||||||
|
Usage: dbgpt [OPTIONS] COMMAND [ARGS]...
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--log-level TEXT Log level
|
||||||
|
--version Show the version and exit.
|
||||||
|
--help Show this message and exit.
|
||||||
|
|
||||||
|
Commands:
|
||||||
|
model Clients that manage model serving
|
||||||
|
start Start specific server.
|
||||||
|
stop Start specific server.
|
||||||
|
```
|
||||||
|
|
||||||
|
**View the `dbgpt start` help**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
dbgpt start --help
|
||||||
|
```
|
||||||
|
|
||||||
|
Here you can see the related commands and usage for start:
|
||||||
|
|
||||||
|
```
|
||||||
|
Usage: dbgpt start [OPTIONS] COMMAND [ARGS]...
|
||||||
|
|
||||||
|
Start specific server.
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--help Show this message and exit.
|
||||||
|
|
||||||
|
Commands:
|
||||||
|
apiserver Start apiserver(TODO)
|
||||||
|
controller Start model controller
|
||||||
|
webserver Start webserver(dbgpt_server.py)
|
||||||
|
worker Start model worker
|
||||||
|
```
|
||||||
|
|
||||||
|
**View the `dbgpt start worker`help**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
dbgpt start worker --help
|
||||||
|
```
|
||||||
|
|
||||||
|
Here you can see the parameters to start Model Worker:
|
||||||
|
|
||||||
|
```
|
||||||
|
Usage: dbgpt start worker [OPTIONS]
|
||||||
|
|
||||||
|
Start model worker
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--model_name TEXT Model name [required]
|
||||||
|
--model_path TEXT Model path [required]
|
||||||
|
--worker_type TEXT Worker type
|
||||||
|
--worker_class TEXT Model worker class, pilot.model.worker.defau
|
||||||
|
lt_worker.DefaultModelWorker
|
||||||
|
--host TEXT Model worker deploy host [default: 0.0.0.0]
|
||||||
|
--port INTEGER Model worker deploy port [default: 8000]
|
||||||
|
--limit_model_concurrency INTEGER
|
||||||
|
Model concurrency limit [default: 5]
|
||||||
|
--standalone Standalone mode. If True, embedded Run
|
||||||
|
ModelController
|
||||||
|
--register Register current worker to model controller
|
||||||
|
[default: True]
|
||||||
|
--worker_register_host TEXT The ip address of current worker to register
|
||||||
|
to ModelController. If None, the address is
|
||||||
|
automatically determined
|
||||||
|
--controller_addr TEXT The Model controller address to register
|
||||||
|
--send_heartbeat Send heartbeat to model controller
|
||||||
|
[default: True]
|
||||||
|
--heartbeat_interval INTEGER The interval for sending heartbeats
|
||||||
|
(seconds) [default: 20]
|
||||||
|
--device TEXT Device to run model. If None, the device is
|
||||||
|
automatically determined
|
||||||
|
--model_type TEXT Model type, huggingface or llama.cpp
|
||||||
|
[default: huggingface]
|
||||||
|
--prompt_template TEXT Prompt template. If None, the prompt
|
||||||
|
template is automatically determined from
|
||||||
|
model path, supported template: zero_shot,vi
|
||||||
|
cuna_v1.1,llama-2,alpaca,baichuan-chat
|
||||||
|
--max_context_size INTEGER Maximum context size [default: 4096]
|
||||||
|
--num_gpus INTEGER The number of gpus you expect to use, if it
|
||||||
|
is empty, use all of them as much as
|
||||||
|
possible
|
||||||
|
--max_gpu_memory TEXT The maximum memory limit of each GPU, only
|
||||||
|
valid in multi-GPU configuration
|
||||||
|
--cpu_offloading CPU offloading
|
||||||
|
--load_8bit 8-bit quantization
|
||||||
|
--load_4bit 4-bit quantization
|
||||||
|
--quant_type TEXT Quantization datatypes, `fp4` (four bit
|
||||||
|
float) and `nf4` (normal four bit float),
|
||||||
|
only valid when load_4bit=True [default:
|
||||||
|
nf4]
|
||||||
|
--use_double_quant Nested quantization, only valid when
|
||||||
|
load_4bit=True [default: True]
|
||||||
|
--compute_dtype TEXT Model compute type
|
||||||
|
--trust_remote_code Trust remote code [default: True]
|
||||||
|
--verbose Show verbose output.
|
||||||
|
--help Show this message and exit.
|
||||||
|
```
|
||||||
|
|
||||||
|
**View the `dbgpt model`help**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
dbgpt model --help
|
||||||
|
```
|
||||||
|
|
||||||
|
The `dbgpt model ` command can connect to the Model Controller via the Model Controller address and then manage a remote model:
|
||||||
|
|
||||||
|
```
|
||||||
|
Usage: dbgpt model [OPTIONS] COMMAND [ARGS]...
|
||||||
|
|
||||||
|
Clients that manage model serving
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--address TEXT Address of the Model Controller to connect to. Just support
|
||||||
|
light deploy model [default: http://127.0.0.1:8000]
|
||||||
|
--help Show this message and exit.
|
||||||
|
|
||||||
|
Commands:
|
||||||
|
list List model instances
|
||||||
|
restart Restart model instances
|
||||||
|
start Start model instances
|
||||||
|
stop Stop model instances
|
||||||
|
```
|
@ -19,6 +19,7 @@ Multi LLMs Support, Supports multiple large language models, currently supportin
|
|||||||
|
|
||||||
- llama_cpp
|
- llama_cpp
|
||||||
- quantization
|
- quantization
|
||||||
|
- cluster deployment
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
@ -28,3 +29,4 @@ Multi LLMs Support, Supports multiple large language models, currently supportin
|
|||||||
|
|
||||||
./llama/llama_cpp.md
|
./llama/llama_cpp.md
|
||||||
./quantization/quantization.md
|
./quantization/quantization.md
|
||||||
|
./cluster/model_cluster.md
|
||||||
|
@ -0,0 +1,172 @@
|
|||||||
|
# SOME DESCRIPTIVE TITLE.
|
||||||
|
# Copyright (C) 2023, csunny
|
||||||
|
# This file is distributed under the same license as the DB-GPT package.
|
||||||
|
# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
|
||||||
|
#
|
||||||
|
#, fuzzy
|
||||||
|
msgid ""
|
||||||
|
msgstr ""
|
||||||
|
"Project-Id-Version: DB-GPT 👏👏 0.3.6\n"
|
||||||
|
"Report-Msgid-Bugs-To: \n"
|
||||||
|
"POT-Creation-Date: 2023-08-31 16:43+0800\n"
|
||||||
|
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
|
||||||
|
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||||
|
"Language: zh_CN\n"
|
||||||
|
"Language-Team: zh_CN <LL@li.org>\n"
|
||||||
|
"Plural-Forms: nplurals=1; plural=0;\n"
|
||||||
|
"MIME-Version: 1.0\n"
|
||||||
|
"Content-Type: text/plain; charset=utf-8\n"
|
||||||
|
"Content-Transfer-Encoding: 8bit\n"
|
||||||
|
"Generated-By: Babel 2.12.1\n"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:1
|
||||||
|
#: b0614062af2a4c039aa8947e45c4382a
|
||||||
|
msgid "Cluster deployment"
|
||||||
|
msgstr "集群部署"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:4
|
||||||
|
#: 568a2af68de5410ebd63f34b8a3665c5
|
||||||
|
msgid "Model cluster deployment"
|
||||||
|
msgstr "模型的集群部署"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:7
|
||||||
|
#: 41b9c4bf32bd4c95a1a7b733543a322b
|
||||||
|
msgid "**Installing Command-Line Tool**"
|
||||||
|
msgstr "**命令行工具的安装**"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:9
|
||||||
|
#: b340cf9309a742648fb05910dc6ba741
|
||||||
|
msgid ""
|
||||||
|
"All operations below are performed using the `dbgpt` command. To use the "
|
||||||
|
"`dbgpt` command, you need to install the DB-GPT project with `pip install"
|
||||||
|
" -e .`. Alternatively, you can use `python pilot/scripts/cli_scripts.py` "
|
||||||
|
"as a substitute for the `dbgpt` command."
|
||||||
|
msgstr "下面的操作都是使用 `dbgpt` 命令来进行,为了使用 `dbgpt` 命令,你需要使用 `pip install -e .` 来安装 DB-GPT 项目,当然,你也可以使用 `python pilot/scripts/cli_scripts.py` 来替代命令 `dbgpt`。"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:11
|
||||||
|
#: 9299e223f8ec4060aaf40ddba019bfbf
|
||||||
|
msgid "Launch Model Controller"
|
||||||
|
msgstr "启动 Model Controller"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:17
|
||||||
|
#: f858b280e1e54feba6639c7270b512c3
|
||||||
|
msgid "By default, the Model Controller starts on port 8000."
|
||||||
|
msgstr "默认情况下 Model Controller 启动端口是 8000。"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:20
|
||||||
|
#: 51eecd5331ce441abd9658f873132935
|
||||||
|
msgid "Launch Model Worker"
|
||||||
|
msgstr "启动 Model Worker"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:22
|
||||||
|
#: 09b1b7ffe3b740de94163f2ef24e8743
|
||||||
|
msgid "If you are starting `chatglm2-6b`:"
|
||||||
|
msgstr "如果你启动的是 `chatglm2-6b`"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:31
|
||||||
|
#: 47f9ec2b957e4a2a858a8833ce2cd104
|
||||||
|
msgid "If you are starting `vicuna-13b-v1.5`:"
|
||||||
|
msgstr "如果你启动的是 `vicuna-13b-v1.5`"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:40
|
||||||
|
#: 95a04817be5a473e8b4f7086ac8dfe2c
|
||||||
|
msgid "Note: Be sure to use your own model name and model path."
|
||||||
|
msgstr "注意:注意使用你自己的模型名称和模型路径。"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:43
|
||||||
|
#: 644fff6c28ca49efa06e0f2c89f45e27
|
||||||
|
msgid "Check your model:"
|
||||||
|
msgstr "查看你的模型:"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:49
|
||||||
|
#: 69df9c27c17e44b8a6cb010360c56338
|
||||||
|
msgid "You will see the following output:"
|
||||||
|
msgstr "你将会看到下面的输出:"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:59
|
||||||
|
#: a1e1489068d04b51bb239bd32ef3546b
|
||||||
|
msgid "Connect to the model service in the webserver (dbgpt_server)"
|
||||||
|
msgstr "在 webserver(dbgpt_server) 中连接模型服务"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:61
|
||||||
|
#: 0ab444ac06f34ce0994fe1f945abecde
|
||||||
|
msgid ""
|
||||||
|
"**First, modify the `.env` file to change the model name and the Model "
|
||||||
|
"Controller connection address.**"
|
||||||
|
msgstr "**先修改 `.env` 文件,修改模型名称和 Model Controller 连接地址。**"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:69
|
||||||
|
#: 5125b43d70d744a491a70aa1001303be
|
||||||
|
msgid "Start the webserver"
|
||||||
|
msgstr ""
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:75
|
||||||
|
#: f12f463c27c74337ba8cc9c7e4e934b4
|
||||||
|
msgid "`--light` indicates not to start the embedded model service."
|
||||||
|
msgstr "`--light` 表示不启动内嵌的模型服务。"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:77
|
||||||
|
#: de2f7b040fd04c73a89ae453db84108f
|
||||||
|
msgid ""
|
||||||
|
"Alternatively, you can prepend the command with `LLM_MODEL=chatglm2-6b` "
|
||||||
|
"to start:"
|
||||||
|
msgstr "更简单的,你可以在命令行前添加 `LLM_MODEL=chatglm2-6b` 来启动:"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:84
|
||||||
|
#: 83b93195a9fb45aeaeead6b05719f164
|
||||||
|
msgid "More Command-Line Usages"
|
||||||
|
msgstr "命令行的更多用法"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:86
|
||||||
|
#: 8622edf06c1d443188e19a26ed306d58
|
||||||
|
msgid "You can view more command-line usages through the help command."
|
||||||
|
msgstr "你可以通过帮助命令来查看更多的命令行用法。"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:88
|
||||||
|
#: 391c6fa585d045f5a583e26478adda7e
|
||||||
|
msgid "**View the `dbgpt` help**"
|
||||||
|
msgstr "**查看 `dbgpt` 的帮助**"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:93
|
||||||
|
#: 4247e8ee8f084fa780fb523530cc64ab
|
||||||
|
msgid "You will see the basic command parameters and usage:"
|
||||||
|
msgstr "你将会看到基础的命令参数和用法:"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:109
|
||||||
|
#: 87b43469e93e4b62a59488d1ef663b51
|
||||||
|
msgid "**View the `dbgpt start` help**"
|
||||||
|
msgstr "**查看 `dbgpt start` 的帮助**"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:115
|
||||||
|
#: 213f7bfc459d4439b3a839292e2684e6
|
||||||
|
msgid "Here you can see the related commands and usage for start:"
|
||||||
|
msgstr "这里你能看到 start 相关的命令和用法。"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:132
|
||||||
|
#: eb616a87e01d4cdc87dc61e462d83729
|
||||||
|
msgid "**View the `dbgpt start worker`help**"
|
||||||
|
msgstr "**查看 `dbgpt start worker`的帮助**"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:138
|
||||||
|
#: 75918aa07d7c4ec39271358f2b474a63
|
||||||
|
msgid "Here you can see the parameters to start Model Worker:"
|
||||||
|
msgstr "这里你能看启动 Model Worker 的参数:"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:196
|
||||||
|
#: 5ee6ea98af4d4693b8cbf059ad3001d8
|
||||||
|
msgid "**View the `dbgpt model`help**"
|
||||||
|
msgstr "**查看 `dbgpt model`的帮助**"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/cluster/model_cluster.md:202
|
||||||
|
#: a96a1b58987a481fac2d527470226b90
|
||||||
|
msgid ""
|
||||||
|
"The `dbgpt model ` command can connect to the Model Controller via the "
|
||||||
|
"Model Controller address and then manage a remote model:"
|
||||||
|
msgstr "`dbgpt model ` 命令可以通过 Model Controller 地址连接到 Model Controller,然后对远程对某个模型进行管理。"
|
||||||
|
|
||||||
|
#~ msgid ""
|
||||||
|
#~ "First, modify the `.env` file to "
|
||||||
|
#~ "change the model name and the "
|
||||||
|
#~ "Model Controller connection address."
|
||||||
|
#~ msgstr ""
|
||||||
|
|
@ -8,7 +8,7 @@ msgid ""
|
|||||||
msgstr ""
|
msgstr ""
|
||||||
"Project-Id-Version: DB-GPT 👏👏 0.3.5\n"
|
"Project-Id-Version: DB-GPT 👏👏 0.3.5\n"
|
||||||
"Report-Msgid-Bugs-To: \n"
|
"Report-Msgid-Bugs-To: \n"
|
||||||
"POT-Creation-Date: 2023-08-17 23:29+0800\n"
|
"POT-Creation-Date: 2023-08-31 16:38+0800\n"
|
||||||
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
|
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
|
||||||
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||||
"Language: zh_CN\n"
|
"Language: zh_CN\n"
|
||||||
@ -20,80 +20,85 @@ msgstr ""
|
|||||||
"Generated-By: Babel 2.12.1\n"
|
"Generated-By: Babel 2.12.1\n"
|
||||||
|
|
||||||
#: ../../getting_started/install/llm/llm.rst:2
|
#: ../../getting_started/install/llm/llm.rst:2
|
||||||
#: ../../getting_started/install/llm/llm.rst:23
|
#: ../../getting_started/install/llm/llm.rst:24
|
||||||
#: 9e4baaff732b49f1baded5fd1bde4bd5
|
#: b348d4df8ca44dd78b42157a8ff6d33d
|
||||||
msgid "LLM Usage"
|
msgid "LLM Usage"
|
||||||
msgstr "LLM使用"
|
msgstr "LLM使用"
|
||||||
|
|
||||||
#: ../../getting_started/install/llm/llm.rst:3 d80e7360f17c4be889f220e2a846a81e
|
#: ../../getting_started/install/llm/llm.rst:3 7f5960a7e5634254b330da27be87594b
|
||||||
msgid ""
|
msgid ""
|
||||||
"DB-GPT provides a management and deployment solution for multiple models."
|
"DB-GPT provides a management and deployment solution for multiple models."
|
||||||
" This chapter mainly discusses how to deploy different models."
|
" This chapter mainly discusses how to deploy different models."
|
||||||
msgstr "DB-GPT提供了多模型的管理和部署方案,本章主要讲解针对不同的模型该怎么部署"
|
msgstr "DB-GPT提供了多模型的管理和部署方案,本章主要讲解针对不同的模型该怎么部署"
|
||||||
|
|
||||||
#: ../../getting_started/install/llm/llm.rst:18
|
#: ../../getting_started/install/llm/llm.rst:18
|
||||||
#: 1cebd6c84b924eeb877ba31bd520328f
|
#: b844ab204ec740ec9d7d191bb841f09e
|
||||||
msgid ""
|
msgid ""
|
||||||
"Multi LLMs Support, Supports multiple large language models, currently "
|
"Multi LLMs Support, Supports multiple large language models, currently "
|
||||||
"supporting"
|
"supporting"
|
||||||
msgstr "目前DB-GPT已适配如下模型"
|
msgstr "目前DB-GPT已适配如下模型"
|
||||||
|
|
||||||
#: ../../getting_started/install/llm/llm.rst:9 1163d5d3a8834b24bb6a5716d4fcc680
|
#: ../../getting_started/install/llm/llm.rst:9 c141437ddaf84c079360008343041b2f
|
||||||
msgid "🔥 Vicuna-v1.5(7b,13b)"
|
msgid "🔥 Vicuna-v1.5(7b,13b)"
|
||||||
msgstr "🔥 Vicuna-v1.5(7b,13b)"
|
msgstr "🔥 Vicuna-v1.5(7b,13b)"
|
||||||
|
|
||||||
#: ../../getting_started/install/llm/llm.rst:10
|
#: ../../getting_started/install/llm/llm.rst:10
|
||||||
#: 318dbc8e8de64bfb9cbc65343d050500
|
#: d32b1e3f114c4eab8782b497097c1b37
|
||||||
msgid "🔥 llama-2(7b,13b,70b)"
|
msgid "🔥 llama-2(7b,13b,70b)"
|
||||||
msgstr "🔥 llama-2(7b,13b,70b)"
|
msgstr "🔥 llama-2(7b,13b,70b)"
|
||||||
|
|
||||||
#: ../../getting_started/install/llm/llm.rst:11
|
#: ../../getting_started/install/llm/llm.rst:11
|
||||||
#: f2b1ef2bfbde46889131b6852d1211e8
|
#: 0a417ee4d008421da07fff7add5d05eb
|
||||||
msgid "WizardLM-v1.2(13b)"
|
msgid "WizardLM-v1.2(13b)"
|
||||||
msgstr "WizardLM-v1.2(13b)"
|
msgstr "WizardLM-v1.2(13b)"
|
||||||
|
|
||||||
#: ../../getting_started/install/llm/llm.rst:12
|
#: ../../getting_started/install/llm/llm.rst:12
|
||||||
#: 2462160e5928434aaa6f58074b585c14
|
#: 199e1a9fe3324dc8a1bcd9cd0b1ef047
|
||||||
msgid "Vicuna (7b,13b)"
|
msgid "Vicuna (7b,13b)"
|
||||||
msgstr "Vicuna (7b,13b)"
|
msgstr "Vicuna (7b,13b)"
|
||||||
|
|
||||||
#: ../../getting_started/install/llm/llm.rst:13
|
#: ../../getting_started/install/llm/llm.rst:13
|
||||||
#: 61cc1b9ed1914e379dee87034e9fbaea
|
#: a9e4c5100534450db3a583fa5850e4be
|
||||||
msgid "ChatGLM-6b (int4,int8)"
|
msgid "ChatGLM-6b (int4,int8)"
|
||||||
msgstr "ChatGLM-6b (int4,int8)"
|
msgstr "ChatGLM-6b (int4,int8)"
|
||||||
|
|
||||||
#: ../../getting_started/install/llm/llm.rst:14
|
#: ../../getting_started/install/llm/llm.rst:14
|
||||||
#: 653f5fd3b6ab4c508a2ce449231b5a17
|
#: 943324289eb94042b52fd824189cd93f
|
||||||
msgid "ChatGLM2-6b (int4,int8)"
|
msgid "ChatGLM2-6b (int4,int8)"
|
||||||
msgstr "ChatGLM2-6b (int4,int8)"
|
msgstr "ChatGLM2-6b (int4,int8)"
|
||||||
|
|
||||||
#: ../../getting_started/install/llm/llm.rst:15
|
#: ../../getting_started/install/llm/llm.rst:15
|
||||||
#: 6a1cf7a271f6496ba48d593db1f756f1
|
#: f1226fdfac3b4e9d88642ffa69d75682
|
||||||
msgid "guanaco(7b,13b,33b)"
|
msgid "guanaco(7b,13b,33b)"
|
||||||
msgstr "guanaco(7b,13b,33b)"
|
msgstr "guanaco(7b,13b,33b)"
|
||||||
|
|
||||||
#: ../../getting_started/install/llm/llm.rst:16
|
#: ../../getting_started/install/llm/llm.rst:16
|
||||||
#: f15083886362437ba96e1cb9ade738aa
|
#: 3f2457f56eb341b6bc431c9beca8f4df
|
||||||
msgid "Gorilla(7b,13b)"
|
msgid "Gorilla(7b,13b)"
|
||||||
msgstr "Gorilla(7b,13b)"
|
msgstr "Gorilla(7b,13b)"
|
||||||
|
|
||||||
#: ../../getting_started/install/llm/llm.rst:17
|
#: ../../getting_started/install/llm/llm.rst:17
|
||||||
#: 5e1ded19d1b24ba485e5d2bf22ce2db7
|
#: 86c8ce37be1c4a7ea3fc382100d77a9c
|
||||||
msgid "Baichuan(7b,13b)"
|
msgid "Baichuan(7b,13b)"
|
||||||
msgstr "Baichuan(7b,13b)"
|
msgstr "Baichuan(7b,13b)"
|
||||||
|
|
||||||
#: ../../getting_started/install/llm/llm.rst:18
|
#: ../../getting_started/install/llm/llm.rst:18
|
||||||
#: 30f6f71dd73b48d29a3647d58a9cdcaf
|
#: 538111af95ad414cb2e631a89f9af379
|
||||||
msgid "OpenAI"
|
msgid "OpenAI"
|
||||||
msgstr "OpenAI"
|
msgstr "OpenAI"
|
||||||
|
|
||||||
#: ../../getting_started/install/llm/llm.rst:20
|
#: ../../getting_started/install/llm/llm.rst:20
|
||||||
#: 9ad21acdaccb41c1bbfa30b5ce114732
|
#: a203325b7ec248f7bff61ae89226a000
|
||||||
msgid "llama_cpp"
|
msgid "llama_cpp"
|
||||||
msgstr "llama_cpp"
|
msgstr "llama_cpp"
|
||||||
|
|
||||||
#: ../../getting_started/install/llm/llm.rst:21
|
#: ../../getting_started/install/llm/llm.rst:21
|
||||||
#: b5064649f7c9443c9b9f15ec7fc02434
|
#: 21a50634198047228bc51a03d2c31292
|
||||||
msgid "quantization"
|
msgid "quantization"
|
||||||
msgstr "quantization"
|
msgstr "quantization"
|
||||||
|
|
||||||
|
#: ../../getting_started/install/llm/llm.rst:22
|
||||||
|
#: dfaec4b04e6e45ff9c884b41534b1a79
|
||||||
|
msgid "cluster deployment"
|
||||||
|
msgstr ""
|
||||||
|
|
||||||
|
@ -166,6 +166,9 @@ class Config(metaclass=Singleton):
|
|||||||
self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False") == "True"
|
self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False") == "True"
|
||||||
if self.IS_LOAD_8BIT and self.IS_LOAD_4BIT:
|
if self.IS_LOAD_8BIT and self.IS_LOAD_4BIT:
|
||||||
self.IS_LOAD_8BIT = False
|
self.IS_LOAD_8BIT = False
|
||||||
|
# In order to be compatible with the new and old model parameter design
|
||||||
|
os.environ["load_8bit"] = str(self.IS_LOAD_8BIT)
|
||||||
|
os.environ["load_4bit"] = str(self.IS_LOAD_4BIT)
|
||||||
|
|
||||||
### EMBEDDING Configuration
|
### EMBEDDING Configuration
|
||||||
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
|
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
|
||||||
|
@ -7,25 +7,33 @@ from pilot.model.worker.manager import (
|
|||||||
WorkerApplyRequest,
|
WorkerApplyRequest,
|
||||||
WorkerApplyType,
|
WorkerApplyType,
|
||||||
)
|
)
|
||||||
|
from pilot.model.parameter import (
|
||||||
|
ModelControllerParameters,
|
||||||
|
ModelWorkerParameters,
|
||||||
|
ModelParameters,
|
||||||
|
)
|
||||||
from pilot.utils import get_or_create_event_loop
|
from pilot.utils import get_or_create_event_loop
|
||||||
|
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||||
|
|
||||||
|
|
||||||
@click.group("model")
|
@click.group("model")
|
||||||
def model_cli_group():
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@model_cli_group.command()
|
|
||||||
@click.option(
|
@click.option(
|
||||||
"--address",
|
"--address",
|
||||||
type=str,
|
type=str,
|
||||||
default="http://127.0.0.1:8000",
|
default="http://127.0.0.1:8000",
|
||||||
required=False,
|
required=False,
|
||||||
|
show_default=True,
|
||||||
help=(
|
help=(
|
||||||
"Address of the Model Controller to connect to."
|
"Address of the Model Controller to connect to. "
|
||||||
"Just support light deploy model"
|
"Just support light deploy model"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
def model_cli_group():
|
||||||
|
"""Clients that manage model serving"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@model_cli_group.command()
|
||||||
@click.option(
|
@click.option(
|
||||||
"--model-name", type=str, default=None, required=False, help=("The name of model")
|
"--model-name", type=str, default=None, required=False, help=("The name of model")
|
||||||
)
|
)
|
||||||
@ -79,16 +87,6 @@ def list(address: str, model_name: str, model_type: str):
|
|||||||
|
|
||||||
|
|
||||||
def add_model_options(func):
|
def add_model_options(func):
|
||||||
@click.option(
|
|
||||||
"--address",
|
|
||||||
type=str,
|
|
||||||
default="http://127.0.0.1:8000",
|
|
||||||
required=False,
|
|
||||||
help=(
|
|
||||||
"Address of the Model Controller to connect to."
|
|
||||||
"Just support light deploy model"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
@click.option(
|
||||||
"--model-name",
|
"--model-name",
|
||||||
type=str,
|
type=str,
|
||||||
@ -149,3 +147,57 @@ def worker_apply(
|
|||||||
)
|
)
|
||||||
res = loop.run_until_complete(worker_manager.worker_apply(apply_req))
|
res = loop.run_until_complete(worker_manager.worker_apply(apply_req))
|
||||||
print(res)
|
print(res)
|
||||||
|
|
||||||
|
|
||||||
|
@click.command(name="controller")
|
||||||
|
@EnvArgumentParser.create_click_option(ModelControllerParameters)
|
||||||
|
def start_model_controller(**kwargs):
|
||||||
|
"""Start model controller"""
|
||||||
|
from pilot.model.controller.controller import run_model_controller
|
||||||
|
|
||||||
|
run_model_controller()
|
||||||
|
|
||||||
|
|
||||||
|
@click.command(name="controller")
|
||||||
|
def stop_model_controller(**kwargs):
|
||||||
|
"""Start model controller"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@click.command(name="worker")
|
||||||
|
@EnvArgumentParser.create_click_option(ModelWorkerParameters, ModelParameters)
|
||||||
|
def start_model_worker(**kwargs):
|
||||||
|
"""Start model worker"""
|
||||||
|
from pilot.model.worker.manager import run_worker_manager
|
||||||
|
|
||||||
|
run_worker_manager()
|
||||||
|
|
||||||
|
|
||||||
|
@click.command(name="worker")
|
||||||
|
def stop_model_worker(**kwargs):
|
||||||
|
"""Stop model worker"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@click.command(name="webserver")
|
||||||
|
def start_webserver(**kwargs):
|
||||||
|
"""Start webserver(dbgpt_server.py)"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@click.command(name="webserver")
|
||||||
|
def stop_webserver(**kwargs):
|
||||||
|
"""Stop webserver(dbgpt_server.py)"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@click.command(name="apiserver")
|
||||||
|
def start_apiserver(**kwargs):
|
||||||
|
"""Start apiserver(TODO)"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@click.command(name="controller")
|
||||||
|
def stop_apiserver(**kwargs):
|
||||||
|
"""Start apiserver(TODO)"""
|
||||||
|
raise NotImplementedError
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter, FastAPI
|
||||||
from pilot.model.base import ModelInstance
|
from pilot.model.base import ModelInstance
|
||||||
|
from pilot.model.parameter import ModelControllerParameters
|
||||||
from pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry
|
from pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry
|
||||||
|
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||||
|
|
||||||
|
|
||||||
class ModelController:
|
class ModelController:
|
||||||
@ -63,3 +65,22 @@ async def api_get_all_instances(model_name: str = None, healthy_only: bool = Fal
|
|||||||
@router.post("/controller/heartbeat")
|
@router.post("/controller/heartbeat")
|
||||||
async def api_model_heartbeat(request: ModelInstance):
|
async def api_model_heartbeat(request: ModelInstance):
|
||||||
return await controller.send_heartbeat(request)
|
return await controller.send_heartbeat(request)
|
||||||
|
|
||||||
|
|
||||||
|
def run_model_controller():
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
parser = EnvArgumentParser()
|
||||||
|
env_prefix = "controller_"
|
||||||
|
controller_params: ModelControllerParameters = parser.parse_args_into_dataclass(
|
||||||
|
ModelControllerParameters, env_prefix=env_prefix
|
||||||
|
)
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(router, prefix="/api")
|
||||||
|
uvicorn.run(
|
||||||
|
app, host=controller_params.host, port=controller_params.port, log_level="info"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_model_controller()
|
||||||
|
@ -8,14 +8,12 @@ from pilot.configs.model_config import DEVICE
|
|||||||
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
|
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
|
||||||
from pilot.model.compression import compress_module
|
from pilot.model.compression import compress_module
|
||||||
from pilot.model.parameter import (
|
from pilot.model.parameter import (
|
||||||
EnvArgumentParser,
|
|
||||||
ModelParameters,
|
ModelParameters,
|
||||||
LlamaCppModelParameters,
|
LlamaCppModelParameters,
|
||||||
_genenv_ignoring_key_case,
|
|
||||||
)
|
)
|
||||||
from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations
|
from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations
|
||||||
from pilot.singleton import Singleton
|
|
||||||
from pilot.utils import get_gpu_memory
|
from pilot.utils import get_gpu_memory
|
||||||
|
from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case
|
||||||
from pilot.logs import logger
|
from pilot.logs import logger
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,139 +1,15 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import argparse
|
from dataclasses import dataclass, field, fields, MISSING
|
||||||
import os
|
|
||||||
from dataclasses import dataclass, field, fields
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Type, Union
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from pilot.model.conversation import conv_templates
|
from pilot.model.conversation import conv_templates
|
||||||
|
from pilot.utils.parameter_utils import BaseParameters
|
||||||
|
|
||||||
suported_prompt_templates = ",".join(conv_templates.keys())
|
suported_prompt_templates = ",".join(conv_templates.keys())
|
||||||
|
|
||||||
|
|
||||||
def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None):
|
|
||||||
"""Get the value from the environment variable, ignoring the case of the key"""
|
|
||||||
if env_prefix:
|
|
||||||
env_key = env_prefix + env_key
|
|
||||||
return os.getenv(
|
|
||||||
env_key, os.getenv(env_key.upper(), os.getenv(env_key.lower(), default_value))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class EnvArgumentParser:
|
|
||||||
@staticmethod
|
|
||||||
def get_env_prefix(env_key: str) -> str:
|
|
||||||
if not env_key:
|
|
||||||
return None
|
|
||||||
env_key = env_key.replace("-", "_")
|
|
||||||
return env_key + "_"
|
|
||||||
|
|
||||||
def parse_args_into_dataclass(
|
|
||||||
self,
|
|
||||||
dataclass_type: Type,
|
|
||||||
env_prefix: str = None,
|
|
||||||
command_args: List[str] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> Any:
|
|
||||||
"""Parse parameters from environment variables and command lines and populate them into data class"""
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
for field in fields(dataclass_type):
|
|
||||||
env_var_value = _genenv_ignoring_key_case(field.name, env_prefix)
|
|
||||||
if not env_var_value:
|
|
||||||
# Read without env prefix
|
|
||||||
env_var_value = _genenv_ignoring_key_case(field.name)
|
|
||||||
|
|
||||||
if env_var_value:
|
|
||||||
env_var_value = env_var_value.strip()
|
|
||||||
if field.type is int or field.type == Optional[int]:
|
|
||||||
env_var_value = int(env_var_value)
|
|
||||||
elif field.type is float or field.type == Optional[float]:
|
|
||||||
env_var_value = float(env_var_value)
|
|
||||||
elif field.type is bool or field.type == Optional[bool]:
|
|
||||||
env_var_value = env_var_value.lower() == "true"
|
|
||||||
elif field.type is str or field.type == Optional[str]:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported parameter type {field.type}")
|
|
||||||
if not env_var_value:
|
|
||||||
env_var_value = kwargs.get(field.name)
|
|
||||||
if not env_var_value:
|
|
||||||
env_var_value = field.default
|
|
||||||
|
|
||||||
# Add a command-line argument for this field
|
|
||||||
help_text = field.metadata.get("help", "")
|
|
||||||
valid_values = field.metadata.get("valid_values", None)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{field.name}",
|
|
||||||
type=self._get_argparse_type(field.type),
|
|
||||||
help=help_text,
|
|
||||||
choices=valid_values,
|
|
||||||
default=env_var_value,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Parse the command-line arguments
|
|
||||||
cmd_args, cmd_argv = parser.parse_known_args(args=command_args)
|
|
||||||
print(f"cmd_args: {cmd_args}")
|
|
||||||
for field in fields(dataclass_type):
|
|
||||||
# cmd_line_value = getattr(cmd_args, field.name)
|
|
||||||
if field.name in cmd_args:
|
|
||||||
cmd_line_value = getattr(cmd_args, field.name)
|
|
||||||
if cmd_line_value is not None:
|
|
||||||
kwargs[field.name] = cmd_line_value
|
|
||||||
|
|
||||||
return dataclass_type(**kwargs)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_argparse_type(field_type: Type) -> Type:
|
|
||||||
# Return the appropriate type for argparse to use based on the field type
|
|
||||||
if field_type is int or field_type == Optional[int]:
|
|
||||||
return int
|
|
||||||
elif field_type is float or field_type == Optional[float]:
|
|
||||||
return float
|
|
||||||
elif field_type is bool or field_type == Optional[bool]:
|
|
||||||
return bool
|
|
||||||
elif field_type is str or field_type == Optional[str]:
|
|
||||||
return str
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported parameter type {field_type}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_argparse_type_str(field_type: Type) -> str:
|
|
||||||
argparse_type = EnvArgumentParser._get_argparse_type(field_type)
|
|
||||||
if argparse_type is int:
|
|
||||||
return "int"
|
|
||||||
elif argparse_type is float:
|
|
||||||
return "float"
|
|
||||||
elif argparse_type is bool:
|
|
||||||
return "bool"
|
|
||||||
else:
|
|
||||||
return "str"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ParameterDescription:
|
|
||||||
param_name: str
|
|
||||||
param_type: str
|
|
||||||
description: str
|
|
||||||
default_value: Optional[Any]
|
|
||||||
valid_values: Optional[List[Any]]
|
|
||||||
|
|
||||||
|
|
||||||
def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
|
|
||||||
descriptions = []
|
|
||||||
for field in fields(dataclass_type):
|
|
||||||
descriptions.append(
|
|
||||||
ParameterDescription(
|
|
||||||
param_name=field.name,
|
|
||||||
param_type=EnvArgumentParser._get_argparse_type_str(field.type),
|
|
||||||
description=field.metadata.get("help", None),
|
|
||||||
default_value=field.default, # TODO handle dataclasses._MISSING_TYPE
|
|
||||||
valid_values=field.metadata.get("valid_values", None),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return descriptions
|
|
||||||
|
|
||||||
|
|
||||||
class WorkerType(str, Enum):
|
class WorkerType(str, Enum):
|
||||||
LLM = "llm"
|
LLM = "llm"
|
||||||
TEXT2VEC = "text2vec"
|
TEXT2VEC = "text2vec"
|
||||||
@ -144,59 +20,14 @@ class WorkerType(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseParameters:
|
class ModelControllerParameters(BaseParameters):
|
||||||
def update_from(self, source: Union["BaseParameters", dict]) -> bool:
|
host: Optional[str] = field(
|
||||||
"""
|
default="0.0.0.0", metadata={"help": "Model Controller deploy host"}
|
||||||
Update the attributes of this object using the values from another object (of the same or parent type) or a dictionary.
|
|
||||||
Only update if the new value is different from the current value and the field is not marked as "fixed" in metadata.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source (Union[BaseParameters, dict]): The source to update from. Can be another object of the same type or a dictionary.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if at least one field was updated, otherwise False.
|
|
||||||
"""
|
|
||||||
updated = False # Flag to indicate whether any field was updated
|
|
||||||
if isinstance(source, (BaseParameters, dict)):
|
|
||||||
for field_info in fields(self):
|
|
||||||
# Check if the field has a "fixed" tag in metadata
|
|
||||||
tags = field_info.metadata.get("tags")
|
|
||||||
tags = [] if not tags else tags.split(",")
|
|
||||||
if tags and "fixed" in tags:
|
|
||||||
continue # skip this field
|
|
||||||
# Get the new value from source (either another BaseParameters object or a dict)
|
|
||||||
new_value = (
|
|
||||||
getattr(source, field_info.name)
|
|
||||||
if isinstance(source, BaseParameters)
|
|
||||||
else source.get(field_info.name, None)
|
|
||||||
)
|
)
|
||||||
|
port: Optional[int] = field(
|
||||||
# If the new value is not None and different from the current value, update the field and set the flag
|
default=8000, metadata={"help": "Model Controller deploy port"}
|
||||||
if new_value is not None and new_value != getattr(
|
|
||||||
self, field_info.name
|
|
||||||
):
|
|
||||||
setattr(self, field_info.name, new_value)
|
|
||||||
updated = True
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Source must be an instance of BaseParameters (or its derived class) or a dictionary."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return updated
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
class_name = self.__class__.__name__
|
|
||||||
parameters = [
|
|
||||||
f"\n\n=========================== {class_name} ===========================\n"
|
|
||||||
]
|
|
||||||
for field_info in fields(self):
|
|
||||||
value = getattr(self, field_info.name)
|
|
||||||
parameters.append(f"{field_info.name}: {value}")
|
|
||||||
parameters.append(
|
|
||||||
"\n======================================================================\n\n"
|
|
||||||
)
|
|
||||||
return "\n".join(parameters)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelWorkerParameters(BaseParameters):
|
class ModelWorkerParameters(BaseParameters):
|
||||||
@ -209,7 +40,7 @@ class ModelWorkerParameters(BaseParameters):
|
|||||||
worker_class: Optional[str] = field(
|
worker_class: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Model worker deploy host, pilot.model.worker.default_worker.DefaultModelWorker"
|
"help": "Model worker class, pilot.model.worker.default_worker.DefaultModelWorker"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
host: Optional[str] = field(
|
host: Optional[str] = field(
|
||||||
|
@ -2,10 +2,9 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Dict, Iterator, List, Type
|
from typing import Dict, Iterator, List, Type
|
||||||
|
|
||||||
from pilot.model.base import ModelOutput
|
from pilot.model.base import ModelOutput
|
||||||
from pilot.model.parameter import (
|
from pilot.model.parameter import ModelParameters, WorkerType
|
||||||
ModelParameters,
|
from pilot.utils.parameter_utils import (
|
||||||
ParameterDescription,
|
ParameterDescription,
|
||||||
WorkerType,
|
|
||||||
_get_parameter_descriptions,
|
_get_parameter_descriptions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -7,10 +7,11 @@ from pilot.configs.model_config import DEVICE
|
|||||||
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
|
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
|
||||||
from pilot.model.base import ModelOutput
|
from pilot.model.base import ModelOutput
|
||||||
from pilot.model.loader import ModelLoader, _get_model_real_path
|
from pilot.model.loader import ModelLoader, _get_model_real_path
|
||||||
from pilot.model.parameter import EnvArgumentParser, ModelParameters
|
from pilot.model.parameter import ModelParameters
|
||||||
from pilot.model.worker.base import ModelWorker
|
from pilot.model.worker.base import ModelWorker
|
||||||
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
|
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
|
||||||
from pilot.utils.model_utils import _clear_torch_cache
|
from pilot.utils.model_utils import _clear_torch_cache
|
||||||
|
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||||
|
|
||||||
logger = logging.getLogger("model_worker")
|
logger = logging.getLogger("model_worker")
|
||||||
|
|
||||||
|
@ -5,11 +5,11 @@ from pilot.configs.model_config import DEVICE
|
|||||||
from pilot.model.loader import _get_model_real_path
|
from pilot.model.loader import _get_model_real_path
|
||||||
from pilot.model.parameter import (
|
from pilot.model.parameter import (
|
||||||
EmbeddingModelParameters,
|
EmbeddingModelParameters,
|
||||||
EnvArgumentParser,
|
|
||||||
WorkerType,
|
WorkerType,
|
||||||
)
|
)
|
||||||
from pilot.model.worker.base import ModelWorker
|
from pilot.model.worker.base import ModelWorker
|
||||||
from pilot.utils.model_utils import _clear_torch_cache
|
from pilot.utils.model_utils import _clear_torch_cache
|
||||||
|
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||||
|
|
||||||
logger = logging.getLogger("model_worker")
|
logger = logging.getLogger("model_worker")
|
||||||
|
|
||||||
|
@ -23,15 +23,14 @@ from pilot.model.base import (
|
|||||||
)
|
)
|
||||||
from pilot.model.controller.registry import ModelRegistry
|
from pilot.model.controller.registry import ModelRegistry
|
||||||
from pilot.model.parameter import (
|
from pilot.model.parameter import (
|
||||||
EnvArgumentParser,
|
|
||||||
ModelParameters,
|
ModelParameters,
|
||||||
ModelWorkerParameters,
|
ModelWorkerParameters,
|
||||||
WorkerType,
|
WorkerType,
|
||||||
ParameterDescription,
|
|
||||||
)
|
)
|
||||||
from pilot.model.worker.base import ModelWorker
|
from pilot.model.worker.base import ModelWorker
|
||||||
from pilot.scene.base_message import ModelMessage
|
from pilot.scene.base_message import ModelMessage
|
||||||
from pilot.utils import build_logger
|
from pilot.utils import build_logger
|
||||||
|
from pilot.utils.parameter_utils import EnvArgumentParser, ParameterDescription
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
logger = build_logger("model_worker", LOGDIR + "/model_worker.log")
|
logger = build_logger("model_worker", LOGDIR + "/model_worker.log")
|
||||||
@ -148,6 +147,8 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
self.model_registry = model_registry
|
self.model_registry = model_registry
|
||||||
|
|
||||||
def _worker_key(self, worker_type: str, model_name: str) -> str:
|
def _worker_key(self, worker_type: str, model_name: str) -> str:
|
||||||
|
if isinstance(worker_type, WorkerType):
|
||||||
|
worker_type = worker_type.value
|
||||||
return f"{model_name}@{worker_type}"
|
return f"{model_name}@{worker_type}"
|
||||||
|
|
||||||
def add_worker(
|
def add_worker(
|
||||||
@ -166,6 +167,9 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
if not worker_params.worker_type:
|
if not worker_params.worker_type:
|
||||||
worker_params.worker_type = worker.worker_type()
|
worker_params.worker_type = worker.worker_type()
|
||||||
|
|
||||||
|
if isinstance(worker_params.worker_type, WorkerType):
|
||||||
|
worker_params.worker_type = worker_params.worker_type.value
|
||||||
|
|
||||||
worker_key = self._worker_key(
|
worker_key = self._worker_key(
|
||||||
worker_params.worker_type, worker_params.model_name
|
worker_params.worker_type, worker_params.model_name
|
||||||
)
|
)
|
||||||
|
@ -5,7 +5,7 @@ import copy
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
sys.path.append(
|
sys.path.append(
|
||||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
|
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -23,16 +23,53 @@ def cli(log_level: str):
|
|||||||
logging.basicConfig(level=log_level, encoding="utf-8")
|
logging.basicConfig(level=log_level, encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
def add_command_alias(command, name: str, hidden: bool = False):
|
def add_command_alias(command, name: str, hidden: bool = False, parent_group=None):
|
||||||
|
if not parent_group:
|
||||||
|
parent_group = cli
|
||||||
new_command = copy.deepcopy(command)
|
new_command = copy.deepcopy(command)
|
||||||
new_command.hidden = hidden
|
new_command.hidden = hidden
|
||||||
cli.add_command(new_command, name=name)
|
parent_group.add_command(new_command, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
@click.group()
|
||||||
|
def start():
|
||||||
|
"""Start specific server."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@click.group()
|
||||||
|
def stop():
|
||||||
|
"""Start specific server."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
cli.add_command(start)
|
||||||
|
cli.add_command(stop)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from pilot.model.cli import model_cli_group
|
from pilot.model.cli import (
|
||||||
|
model_cli_group,
|
||||||
|
start_model_controller,
|
||||||
|
stop_model_controller,
|
||||||
|
start_model_worker,
|
||||||
|
stop_model_worker,
|
||||||
|
start_webserver,
|
||||||
|
stop_webserver,
|
||||||
|
start_apiserver,
|
||||||
|
stop_apiserver,
|
||||||
|
)
|
||||||
|
|
||||||
|
add_command_alias(model_cli_group, name="model", parent_group=cli)
|
||||||
|
add_command_alias(start_model_controller, name="controller", parent_group=start)
|
||||||
|
add_command_alias(start_model_worker, name="worker", parent_group=start)
|
||||||
|
add_command_alias(start_webserver, name="webserver", parent_group=start)
|
||||||
|
add_command_alias(start_apiserver, name="apiserver", parent_group=start)
|
||||||
|
|
||||||
|
add_command_alias(stop_model_controller, name="controller", parent_group=stop)
|
||||||
|
add_command_alias(stop_model_worker, name="worker", parent_group=stop)
|
||||||
|
add_command_alias(stop_webserver, name="webserver", parent_group=stop)
|
||||||
|
add_command_alias(stop_apiserver, name="apiserver", parent_group=stop)
|
||||||
|
|
||||||
add_command_alias(model_cli_group, name="model")
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logging.warning(f"Integrating dbgpt model command line tool failed: {e}")
|
logging.warning(f"Integrating dbgpt model command line tool failed: {e}")
|
||||||
|
|
||||||
|
314
pilot/utils/parameter_utils.py
Normal file
314
pilot/utils/parameter_utils.py
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, fields, MISSING
|
||||||
|
from typing import Any, List, Optional, Type, Union, Callable
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ParameterDescription:
|
||||||
|
param_name: str
|
||||||
|
param_type: str
|
||||||
|
description: str
|
||||||
|
default_value: Optional[Any]
|
||||||
|
valid_values: Optional[List[Any]]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BaseParameters:
|
||||||
|
def update_from(self, source: Union["BaseParameters", dict]) -> bool:
|
||||||
|
"""
|
||||||
|
Update the attributes of this object using the values from another object (of the same or parent type) or a dictionary.
|
||||||
|
Only update if the new value is different from the current value and the field is not marked as "fixed" in metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source (Union[BaseParameters, dict]): The source to update from. Can be another object of the same type or a dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if at least one field was updated, otherwise False.
|
||||||
|
"""
|
||||||
|
updated = False # Flag to indicate whether any field was updated
|
||||||
|
if isinstance(source, (BaseParameters, dict)):
|
||||||
|
for field_info in fields(self):
|
||||||
|
# Check if the field has a "fixed" tag in metadata
|
||||||
|
tags = field_info.metadata.get("tags")
|
||||||
|
tags = [] if not tags else tags.split(",")
|
||||||
|
if tags and "fixed" in tags:
|
||||||
|
continue # skip this field
|
||||||
|
# Get the new value from source (either another BaseParameters object or a dict)
|
||||||
|
new_value = (
|
||||||
|
getattr(source, field_info.name)
|
||||||
|
if isinstance(source, BaseParameters)
|
||||||
|
else source.get(field_info.name, None)
|
||||||
|
)
|
||||||
|
|
||||||
|
# If the new value is not None and different from the current value, update the field and set the flag
|
||||||
|
if new_value is not None and new_value != getattr(
|
||||||
|
self, field_info.name
|
||||||
|
):
|
||||||
|
setattr(self, field_info.name, new_value)
|
||||||
|
updated = True
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Source must be an instance of BaseParameters (or its derived class) or a dictionary."
|
||||||
|
)
|
||||||
|
|
||||||
|
return updated
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
class_name = self.__class__.__name__
|
||||||
|
parameters = [
|
||||||
|
f"\n\n=========================== {class_name} ===========================\n"
|
||||||
|
]
|
||||||
|
for field_info in fields(self):
|
||||||
|
value = getattr(self, field_info.name)
|
||||||
|
parameters.append(f"{field_info.name}: {value}")
|
||||||
|
parameters.append(
|
||||||
|
"\n======================================================================\n\n"
|
||||||
|
)
|
||||||
|
return "\n".join(parameters)
|
||||||
|
|
||||||
|
|
||||||
|
def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None):
|
||||||
|
"""Get the value from the environment variable, ignoring the case of the key"""
|
||||||
|
if env_prefix:
|
||||||
|
env_key = env_prefix + env_key
|
||||||
|
return os.getenv(
|
||||||
|
env_key, os.getenv(env_key.upper(), os.getenv(env_key.lower(), default_value))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EnvArgumentParser:
|
||||||
|
@staticmethod
|
||||||
|
def get_env_prefix(env_key: str) -> str:
|
||||||
|
if not env_key:
|
||||||
|
return None
|
||||||
|
env_key = env_key.replace("-", "_")
|
||||||
|
return env_key + "_"
|
||||||
|
|
||||||
|
def parse_args_into_dataclass(
|
||||||
|
self,
|
||||||
|
dataclass_type: Type,
|
||||||
|
env_prefix: str = None,
|
||||||
|
command_args: List[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Any:
|
||||||
|
"""Parse parameters from environment variables and command lines and populate them into data class"""
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
for field in fields(dataclass_type):
|
||||||
|
env_var_value = _genenv_ignoring_key_case(field.name, env_prefix)
|
||||||
|
if not env_var_value:
|
||||||
|
# Read without env prefix
|
||||||
|
env_var_value = _genenv_ignoring_key_case(field.name)
|
||||||
|
|
||||||
|
if env_var_value:
|
||||||
|
env_var_value = env_var_value.strip()
|
||||||
|
if field.type is int or field.type == Optional[int]:
|
||||||
|
env_var_value = int(env_var_value)
|
||||||
|
elif field.type is float or field.type == Optional[float]:
|
||||||
|
env_var_value = float(env_var_value)
|
||||||
|
elif field.type is bool or field.type == Optional[bool]:
|
||||||
|
env_var_value = env_var_value.lower() == "true"
|
||||||
|
elif field.type is str or field.type == Optional[str]:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported parameter type {field.type}")
|
||||||
|
if not env_var_value:
|
||||||
|
env_var_value = kwargs.get(field.name)
|
||||||
|
|
||||||
|
# Add a command-line argument for this field
|
||||||
|
help_text = field.metadata.get("help", "")
|
||||||
|
valid_values = field.metadata.get("valid_values", None)
|
||||||
|
|
||||||
|
argument_kwargs = {
|
||||||
|
"type": EnvArgumentParser._get_argparse_type(field.type),
|
||||||
|
"help": help_text,
|
||||||
|
"choices": valid_values,
|
||||||
|
"required": EnvArgumentParser._is_require_type(field.type),
|
||||||
|
}
|
||||||
|
if field.default != MISSING:
|
||||||
|
argument_kwargs["default"] = field.default
|
||||||
|
argument_kwargs["required"] = False
|
||||||
|
if env_var_value:
|
||||||
|
argument_kwargs["default"] = env_var_value
|
||||||
|
argument_kwargs["required"] = False
|
||||||
|
|
||||||
|
parser.add_argument(f"--{field.name}", **argument_kwargs)
|
||||||
|
|
||||||
|
# Parse the command-line arguments
|
||||||
|
cmd_args, cmd_argv = parser.parse_known_args(args=command_args)
|
||||||
|
# cmd_args = parser.parse_args(args=command_args)
|
||||||
|
# print(f"cmd_args: {cmd_args}")
|
||||||
|
for field in fields(dataclass_type):
|
||||||
|
# cmd_line_value = getattr(cmd_args, field.name)
|
||||||
|
if field.name in cmd_args:
|
||||||
|
cmd_line_value = getattr(cmd_args, field.name)
|
||||||
|
if cmd_line_value is not None:
|
||||||
|
kwargs[field.name] = cmd_line_value
|
||||||
|
|
||||||
|
return dataclass_type(**kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_arg_parser(dataclass_type: Type) -> argparse.ArgumentParser:
|
||||||
|
parser = argparse.ArgumentParser(description=dataclass_type.__doc__)
|
||||||
|
for field in fields(dataclass_type):
|
||||||
|
help_text = field.metadata.get("help", "")
|
||||||
|
valid_values = field.metadata.get("valid_values", None)
|
||||||
|
argument_kwargs = {
|
||||||
|
"type": EnvArgumentParser._get_argparse_type(field.type),
|
||||||
|
"help": help_text,
|
||||||
|
"choices": valid_values,
|
||||||
|
"required": EnvArgumentParser._is_require_type(field.type),
|
||||||
|
}
|
||||||
|
if field.default != MISSING:
|
||||||
|
argument_kwargs["default"] = field.default
|
||||||
|
argument_kwargs["required"] = False
|
||||||
|
parser.add_argument(f"--{field.name}", **argument_kwargs)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_click_option(
|
||||||
|
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
|
||||||
|
):
|
||||||
|
import click
|
||||||
|
import functools
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
# TODO dynamic configuration
|
||||||
|
# pre_args = _SimpleArgParser('model_name', 'model_path')
|
||||||
|
# pre_args.parse()
|
||||||
|
# print(pre_args)
|
||||||
|
|
||||||
|
combined_fields = OrderedDict()
|
||||||
|
if _dynamic_factory:
|
||||||
|
_types = _dynamic_factory()
|
||||||
|
if _types:
|
||||||
|
dataclass_types = list(_types)
|
||||||
|
for dataclass_type in dataclass_types:
|
||||||
|
for field in fields(dataclass_type):
|
||||||
|
if field.name not in combined_fields:
|
||||||
|
combined_fields[field.name] = field
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
for field_name, field in reversed(combined_fields.items()):
|
||||||
|
help_text = field.metadata.get("help", "")
|
||||||
|
valid_values = field.metadata.get("valid_values", None)
|
||||||
|
cli_params = {
|
||||||
|
"default": None if field.default is MISSING else field.default,
|
||||||
|
"help": help_text,
|
||||||
|
"show_default": True,
|
||||||
|
"required": field.default is MISSING,
|
||||||
|
}
|
||||||
|
if valid_values:
|
||||||
|
cli_params["type"] = click.Choice(valid_values)
|
||||||
|
real_type = EnvArgumentParser._get_argparse_type(field.type)
|
||||||
|
if real_type is int:
|
||||||
|
cli_params["type"] = click.INT
|
||||||
|
elif real_type is float:
|
||||||
|
cli_params["type"] = click.FLOAT
|
||||||
|
elif real_type is str:
|
||||||
|
cli_params["type"] = click.STRING
|
||||||
|
elif real_type is bool:
|
||||||
|
cli_params["is_flag"] = True
|
||||||
|
|
||||||
|
option_decorator = click.option(
|
||||||
|
# f"--{field_name.replace('_', '-')}", **cli_params
|
||||||
|
f"--{field_name}",
|
||||||
|
**cli_params,
|
||||||
|
)
|
||||||
|
func = option_decorator(func)
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_argparse_type(field_type: Type) -> Type:
|
||||||
|
# Return the appropriate type for argparse to use based on the field type
|
||||||
|
if field_type is int or field_type == Optional[int]:
|
||||||
|
return int
|
||||||
|
elif field_type is float or field_type == Optional[float]:
|
||||||
|
return float
|
||||||
|
elif field_type is bool or field_type == Optional[bool]:
|
||||||
|
return bool
|
||||||
|
elif field_type is str or field_type == Optional[str]:
|
||||||
|
return str
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported parameter type {field_type}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_argparse_type_str(field_type: Type) -> str:
|
||||||
|
argparse_type = EnvArgumentParser._get_argparse_type(field_type)
|
||||||
|
if argparse_type is int:
|
||||||
|
return "int"
|
||||||
|
elif argparse_type is float:
|
||||||
|
return "float"
|
||||||
|
elif argparse_type is bool:
|
||||||
|
return "bool"
|
||||||
|
else:
|
||||||
|
return "str"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_require_type(field_type: Type) -> str:
|
||||||
|
return field_type not in [Optional[int], Optional[float], Optional[bool]]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
|
||||||
|
descriptions = []
|
||||||
|
for field in fields(dataclass_type):
|
||||||
|
descriptions.append(
|
||||||
|
ParameterDescription(
|
||||||
|
param_name=field.name,
|
||||||
|
param_type=EnvArgumentParser._get_argparse_type_str(field.type),
|
||||||
|
description=field.metadata.get("help", None),
|
||||||
|
default_value=field.default, # TODO handle dataclasses._MISSING_TYPE
|
||||||
|
valid_values=field.metadata.get("valid_values", None),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return descriptions
|
||||||
|
|
||||||
|
|
||||||
|
class _SimpleArgParser:
|
||||||
|
def __init__(self, *args):
|
||||||
|
self.params = {arg.replace("_", "-"): None for arg in args}
|
||||||
|
|
||||||
|
def parse(self, args=None):
|
||||||
|
import sys
|
||||||
|
|
||||||
|
if args is None:
|
||||||
|
args = sys.argv[1:]
|
||||||
|
else:
|
||||||
|
args = list(args)
|
||||||
|
prev_arg = None
|
||||||
|
for arg in args:
|
||||||
|
if arg.startswith("--"):
|
||||||
|
if prev_arg:
|
||||||
|
self.params[prev_arg] = None
|
||||||
|
prev_arg = arg[2:]
|
||||||
|
else:
|
||||||
|
if prev_arg:
|
||||||
|
self.params[prev_arg] = arg
|
||||||
|
prev_arg = None
|
||||||
|
|
||||||
|
if prev_arg:
|
||||||
|
self.params[prev_arg] = None
|
||||||
|
|
||||||
|
def _get_param(self, key):
|
||||||
|
return self.params.get(key.replace("_", "-"), None)
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
return self._get_param(item)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self._get_param(key)
|
||||||
|
|
||||||
|
def get(self, key, default=None):
|
||||||
|
return self._get_param(key) or default
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "\n".join(
|
||||||
|
[f'{key.replace("-", "_")}: {value}' for key, value in self.params.items()]
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user