mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 08:11:45 +00:00
feat: Command-line tool design and multi-model integration
This commit is contained in:
parent
05712d39b9
commit
e4dd6060da
219
docs/getting_started/install/llm/cluster/model_cluster.md
Normal file
219
docs/getting_started/install/llm/cluster/model_cluster.md
Normal file
@ -0,0 +1,219 @@
|
||||
Cluster deployment
|
||||
==================================
|
||||
|
||||
## Model cluster deployment
|
||||
|
||||
|
||||
**Installing Command-Line Tool**
|
||||
|
||||
All operations below are performed using the `dbgpt` command. To use the `dbgpt` command, you need to install the DB-GPT project with `pip install -e .`. Alternatively, you can use `python pilot/scripts/cli_scripts.py` as a substitute for the `dbgpt` command.
|
||||
|
||||
### Launch Model Controller
|
||||
|
||||
```bash
|
||||
dbgpt start controller
|
||||
```
|
||||
|
||||
By default, the Model Controller starts on port 8000.
|
||||
|
||||
|
||||
### Launch Model Worker
|
||||
|
||||
If you are starting `chatglm2-6b`:
|
||||
|
||||
```bash
|
||||
dbgpt start worker --model_name chatglm2-6b \
|
||||
--model_path /app/models/chatglm2-6b \
|
||||
--port 8001 \
|
||||
--controller_addr http://127.0.0.1:8000
|
||||
```
|
||||
|
||||
If you are starting `vicuna-13b-v1.5`:
|
||||
|
||||
```bash
|
||||
dbgpt start worker --model_name vicuna-13b-v1.5 \
|
||||
--model_path /app/models/vicuna-13b-v1.5 \
|
||||
--port 8002 \
|
||||
--controller_addr http://127.0.0.1:8000
|
||||
```
|
||||
|
||||
Note: Be sure to use your own model name and model path.
|
||||
|
||||
|
||||
Check your model:
|
||||
|
||||
```bash
|
||||
dbgpt model list
|
||||
```
|
||||
|
||||
You will see the following output:
|
||||
```
|
||||
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
|
||||
| Model Name | Model Type | Host | Port | Healthy | Enabled | Prompt Template | Last Heartbeat |
|
||||
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
|
||||
| chatglm2-6b | llm | 172.17.0.6 | 8001 | True | True | None | 2023-08-31T04:48:45.252939 |
|
||||
| vicuna-13b-v1.5 | llm | 172.17.0.6 | 8002 | True | True | None | 2023-08-31T04:48:55.136676 |
|
||||
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
|
||||
```
|
||||
|
||||
### Connect to the model service in the webserver (dbgpt_server)
|
||||
|
||||
**First, modify the `.env` file to change the model name and the Model Controller connection address.**
|
||||
|
||||
```bash
|
||||
LLM_MODEL=vicuna-13b-v1.5
|
||||
# The current default MODEL_SERVER address is the address of the Model Controller
|
||||
MODEL_SERVER=http://127.0.0.1:8000
|
||||
```
|
||||
|
||||
#### Start the webserver
|
||||
|
||||
```bash
|
||||
python pilot/server/dbgpt_server.py --light
|
||||
```
|
||||
|
||||
`--light` indicates not to start the embedded model service.
|
||||
|
||||
Alternatively, you can prepend the command with `LLM_MODEL=chatglm2-6b` to start:
|
||||
|
||||
```bash
|
||||
LLM_MODEL=chatglm2-6b python pilot/server/dbgpt_server.py --light
|
||||
```
|
||||
|
||||
|
||||
### More Command-Line Usages
|
||||
|
||||
You can view more command-line usages through the help command.
|
||||
|
||||
**View the `dbgpt` help**
|
||||
```bash
|
||||
dbgpt --help
|
||||
```
|
||||
|
||||
You will see the basic command parameters and usage:
|
||||
|
||||
```
|
||||
Usage: dbgpt [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Options:
|
||||
--log-level TEXT Log level
|
||||
--version Show the version and exit.
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
model Clients that manage model serving
|
||||
start Start specific server.
|
||||
stop Stop specific server.
|
||||
```
|
||||
|
||||
**View the `dbgpt start` help**
|
||||
|
||||
```bash
|
||||
dbgpt start --help
|
||||
```
|
||||
|
||||
Here you can see the related commands and usage for start:
|
||||
|
||||
```
|
||||
Usage: dbgpt start [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Start specific server.
|
||||
|
||||
Options:
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
apiserver Start apiserver(TODO)
|
||||
controller Start model controller
|
||||
webserver Start webserver(dbgpt_server.py)
|
||||
worker Start model worker
|
||||
```
|
||||
|
||||
**View the `dbgpt start worker` help**
|
||||
|
||||
```bash
|
||||
dbgpt start worker --help
|
||||
```
|
||||
|
||||
Here you can see the parameters to start Model Worker:
|
||||
|
||||
```
|
||||
Usage: dbgpt start worker [OPTIONS]
|
||||
|
||||
Start model worker
|
||||
|
||||
Options:
|
||||
--model_name TEXT Model name [required]
|
||||
--model_path TEXT Model path [required]
|
||||
--worker_type TEXT Worker type
|
||||
--worker_class TEXT Model worker class, pilot.model.worker.defau
|
||||
lt_worker.DefaultModelWorker
|
||||
--host TEXT Model worker deploy host [default: 0.0.0.0]
|
||||
--port INTEGER Model worker deploy port [default: 8000]
|
||||
--limit_model_concurrency INTEGER
|
||||
Model concurrency limit [default: 5]
|
||||
--standalone Standalone mode. If True, embedded Run
|
||||
ModelController
|
||||
--register Register current worker to model controller
|
||||
[default: True]
|
||||
--worker_register_host TEXT The ip address of current worker to register
|
||||
to ModelController. If None, the address is
|
||||
automatically determined
|
||||
--controller_addr TEXT The Model controller address to register
|
||||
--send_heartbeat Send heartbeat to model controller
|
||||
[default: True]
|
||||
--heartbeat_interval INTEGER The interval for sending heartbeats
|
||||
(seconds) [default: 20]
|
||||
--device TEXT Device to run model. If None, the device is
|
||||
automatically determined
|
||||
--model_type TEXT Model type, huggingface or llama.cpp
|
||||
[default: huggingface]
|
||||
--prompt_template TEXT Prompt template. If None, the prompt
|
||||
template is automatically determined from
|
||||
model path, supported template: zero_shot,vi
|
||||
cuna_v1.1,llama-2,alpaca,baichuan-chat
|
||||
--max_context_size INTEGER Maximum context size [default: 4096]
|
||||
--num_gpus INTEGER The number of gpus you expect to use, if it
|
||||
is empty, use all of them as much as
|
||||
possible
|
||||
--max_gpu_memory TEXT The maximum memory limit of each GPU, only
|
||||
valid in multi-GPU configuration
|
||||
--cpu_offloading CPU offloading
|
||||
--load_8bit 8-bit quantization
|
||||
--load_4bit 4-bit quantization
|
||||
--quant_type TEXT Quantization datatypes, `fp4` (four bit
|
||||
float) and `nf4` (normal four bit float),
|
||||
only valid when load_4bit=True [default:
|
||||
nf4]
|
||||
--use_double_quant Nested quantization, only valid when
|
||||
load_4bit=True [default: True]
|
||||
--compute_dtype TEXT Model compute type
|
||||
--trust_remote_code Trust remote code [default: True]
|
||||
--verbose Show verbose output.
|
||||
--help Show this message and exit.
|
||||
```
|
||||
|
||||
**View the `dbgpt model` help**
|
||||
|
||||
```bash
|
||||
dbgpt model --help
|
||||
```
|
||||
|
||||
The `dbgpt model ` command can connect to the Model Controller via the Model Controller address and then manage a remote model:
|
||||
|
||||
```
|
||||
Usage: dbgpt model [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Clients that manage model serving
|
||||
|
||||
Options:
|
||||
--address TEXT Address of the Model Controller to connect to. Just support
|
||||
light deploy model [default: http://127.0.0.1:8000]
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
list List model instances
|
||||
restart Restart model instances
|
||||
start Start model instances
|
||||
stop Stop model instances
|
||||
```
|
@ -19,6 +19,7 @@ Multi LLMs Support, Supports multiple large language models, currently supportin
|
||||
|
||||
- llama_cpp
|
||||
- quantization
|
||||
- cluster deployment
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
@ -28,3 +29,4 @@ Multi LLMs Support, Supports multiple large language models, currently supportin
|
||||
|
||||
./llama/llama_cpp.md
|
||||
./quantization/quantization.md
|
||||
./cluster/model_cluster.md
|
||||
|
@ -0,0 +1,172 @@
|
||||
# SOME DESCRIPTIVE TITLE.
|
||||
# Copyright (C) 2023, csunny
|
||||
# This file is distributed under the same license as the DB-GPT package.
|
||||
# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
|
||||
#
|
||||
#, fuzzy
|
||||
msgid ""
|
||||
msgstr ""
|
||||
"Project-Id-Version: DB-GPT 👏👏 0.3.6\n"
|
||||
"Report-Msgid-Bugs-To: \n"
|
||||
"POT-Creation-Date: 2023-08-31 16:43+0800\n"
|
||||
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
|
||||
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||
"Language: zh_CN\n"
|
||||
"Language-Team: zh_CN <LL@li.org>\n"
|
||||
"Plural-Forms: nplurals=1; plural=0;\n"
|
||||
"MIME-Version: 1.0\n"
|
||||
"Content-Type: text/plain; charset=utf-8\n"
|
||||
"Content-Transfer-Encoding: 8bit\n"
|
||||
"Generated-By: Babel 2.12.1\n"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:1
|
||||
#: b0614062af2a4c039aa8947e45c4382a
|
||||
msgid "Cluster deployment"
|
||||
msgstr "集群部署"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:4
|
||||
#: 568a2af68de5410ebd63f34b8a3665c5
|
||||
msgid "Model cluster deployment"
|
||||
msgstr "模型的集群部署"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:7
|
||||
#: 41b9c4bf32bd4c95a1a7b733543a322b
|
||||
msgid "**Installing Command-Line Tool**"
|
||||
msgstr "**命令行工具的安装**"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:9
|
||||
#: b340cf9309a742648fb05910dc6ba741
|
||||
msgid ""
|
||||
"All operations below are performed using the `dbgpt` command. To use the "
|
||||
"`dbgpt` command, you need to install the DB-GPT project with `pip install"
|
||||
" -e .`. Alternatively, you can use `python pilot/scripts/cli_scripts.py` "
|
||||
"as a substitute for the `dbgpt` command."
|
||||
msgstr "下面的操作都是使用 `dbgpt` 命令来进行,为了使用 `dbgpt` 命令,你需要使用 `pip install -e .` 来安装 DB-GPT 项目,当然,你也可以使用 `python pilot/scripts/cli_scripts.py` 来替代命令 `dbgpt`。"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:11
|
||||
#: 9299e223f8ec4060aaf40ddba019bfbf
|
||||
msgid "Launch Model Controller"
|
||||
msgstr "启动 Model Controller"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:17
|
||||
#: f858b280e1e54feba6639c7270b512c3
|
||||
msgid "By default, the Model Controller starts on port 8000."
|
||||
msgstr "默认情况下 Model Controller 启动端口是 8000。"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:20
|
||||
#: 51eecd5331ce441abd9658f873132935
|
||||
msgid "Launch Model Worker"
|
||||
msgstr "启动 Model Worker"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:22
|
||||
#: 09b1b7ffe3b740de94163f2ef24e8743
|
||||
msgid "If you are starting `chatglm2-6b`:"
|
||||
msgstr "如果你启动的是 `chatglm2-6b`"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:31
|
||||
#: 47f9ec2b957e4a2a858a8833ce2cd104
|
||||
msgid "If you are starting `vicuna-13b-v1.5`:"
|
||||
msgstr "如果你启动的是 `vicuna-13b-v1.5`"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:40
|
||||
#: 95a04817be5a473e8b4f7086ac8dfe2c
|
||||
msgid "Note: Be sure to use your own model name and model path."
|
||||
msgstr "注意:注意使用你自己的模型名称和模型路径。"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:43
|
||||
#: 644fff6c28ca49efa06e0f2c89f45e27
|
||||
msgid "Check your model:"
|
||||
msgstr "查看你的模型:"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:49
|
||||
#: 69df9c27c17e44b8a6cb010360c56338
|
||||
msgid "You will see the following output:"
|
||||
msgstr "你将会看到下面的输出:"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:59
|
||||
#: a1e1489068d04b51bb239bd32ef3546b
|
||||
msgid "Connect to the model service in the webserver (dbgpt_server)"
|
||||
msgstr "在 webserver(dbgpt_server) 中连接模型服务"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:61
|
||||
#: 0ab444ac06f34ce0994fe1f945abecde
|
||||
msgid ""
|
||||
"**First, modify the `.env` file to change the model name and the Model "
|
||||
"Controller connection address.**"
|
||||
msgstr "**先修改 `.env` 文件,修改模型名称和 Model Controller 连接地址。**"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:69
|
||||
#: 5125b43d70d744a491a70aa1001303be
|
||||
msgid "Start the webserver"
|
||||
msgstr ""
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:75
|
||||
#: f12f463c27c74337ba8cc9c7e4e934b4
|
||||
msgid "`--light` indicates not to start the embedded model service."
|
||||
msgstr "`--light` 表示不启动内嵌的模型服务。"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:77
|
||||
#: de2f7b040fd04c73a89ae453db84108f
|
||||
msgid ""
|
||||
"Alternatively, you can prepend the command with `LLM_MODEL=chatglm2-6b` "
|
||||
"to start:"
|
||||
msgstr "更简单的,你可以在命令行前添加 `LLM_MODEL=chatglm2-6b` 来启动:"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:84
|
||||
#: 83b93195a9fb45aeaeead6b05719f164
|
||||
msgid "More Command-Line Usages"
|
||||
msgstr "命令行的更多用法"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:86
|
||||
#: 8622edf06c1d443188e19a26ed306d58
|
||||
msgid "You can view more command-line usages through the help command."
|
||||
msgstr "你可以通过帮助命令来查看更多的命令行用法。"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:88
|
||||
#: 391c6fa585d045f5a583e26478adda7e
|
||||
msgid "**View the `dbgpt` help**"
|
||||
msgstr "**查看 `dbgpt` 的帮助**"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:93
|
||||
#: 4247e8ee8f084fa780fb523530cc64ab
|
||||
msgid "You will see the basic command parameters and usage:"
|
||||
msgstr "你将会看到基础的命令参数和用法:"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:109
|
||||
#: 87b43469e93e4b62a59488d1ef663b51
|
||||
msgid "**View the `dbgpt start` help**"
|
||||
msgstr "**查看 `dbgpt start` 的帮助**"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:115
|
||||
#: 213f7bfc459d4439b3a839292e2684e6
|
||||
msgid "Here you can see the related commands and usage for start:"
|
||||
msgstr "这里你能看到 start 相关的命令和用法。"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:132
|
||||
#: eb616a87e01d4cdc87dc61e462d83729
|
||||
msgid "**View the `dbgpt start worker`help**"
|
||||
msgstr "**查看 `dbgpt start worker`的帮助**"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:138
|
||||
#: 75918aa07d7c4ec39271358f2b474a63
|
||||
msgid "Here you can see the parameters to start Model Worker:"
|
||||
msgstr "这里你能看启动 Model Worker 的参数:"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:196
|
||||
#: 5ee6ea98af4d4693b8cbf059ad3001d8
|
||||
msgid "**View the `dbgpt model`help**"
|
||||
msgstr "**查看 `dbgpt model`的帮助**"
|
||||
|
||||
#: ../../getting_started/install/llm/cluster/model_cluster.md:202
|
||||
#: a96a1b58987a481fac2d527470226b90
|
||||
msgid ""
|
||||
"The `dbgpt model ` command can connect to the Model Controller via the "
|
||||
"Model Controller address and then manage a remote model:"
|
||||
msgstr "`dbgpt model ` 命令可以通过 Model Controller 地址连接到 Model Controller,然后对远程对某个模型进行管理。"
|
||||
|
||||
#~ msgid ""
|
||||
#~ "First, modify the `.env` file to "
|
||||
#~ "change the model name and the "
|
||||
#~ "Model Controller connection address."
|
||||
#~ msgstr ""
|
||||
|
@ -8,7 +8,7 @@ msgid ""
|
||||
msgstr ""
|
||||
"Project-Id-Version: DB-GPT 👏👏 0.3.5\n"
|
||||
"Report-Msgid-Bugs-To: \n"
|
||||
"POT-Creation-Date: 2023-08-17 23:29+0800\n"
|
||||
"POT-Creation-Date: 2023-08-31 16:38+0800\n"
|
||||
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
|
||||
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||
"Language: zh_CN\n"
|
||||
@ -20,80 +20,85 @@ msgstr ""
|
||||
"Generated-By: Babel 2.12.1\n"
|
||||
|
||||
#: ../../getting_started/install/llm/llm.rst:2
|
||||
#: ../../getting_started/install/llm/llm.rst:23
|
||||
#: 9e4baaff732b49f1baded5fd1bde4bd5
|
||||
#: ../../getting_started/install/llm/llm.rst:24
|
||||
#: b348d4df8ca44dd78b42157a8ff6d33d
|
||||
msgid "LLM Usage"
|
||||
msgstr "LLM使用"
|
||||
|
||||
#: ../../getting_started/install/llm/llm.rst:3 d80e7360f17c4be889f220e2a846a81e
|
||||
#: ../../getting_started/install/llm/llm.rst:3 7f5960a7e5634254b330da27be87594b
|
||||
msgid ""
|
||||
"DB-GPT provides a management and deployment solution for multiple models."
|
||||
" This chapter mainly discusses how to deploy different models."
|
||||
msgstr "DB-GPT提供了多模型的管理和部署方案,本章主要讲解针对不同的模型该怎么部署"
|
||||
|
||||
#: ../../getting_started/install/llm/llm.rst:18
|
||||
#: 1cebd6c84b924eeb877ba31bd520328f
|
||||
#: b844ab204ec740ec9d7d191bb841f09e
|
||||
msgid ""
|
||||
"Multi LLMs Support, Supports multiple large language models, currently "
|
||||
"supporting"
|
||||
msgstr "目前DB-GPT已适配如下模型"
|
||||
|
||||
#: ../../getting_started/install/llm/llm.rst:9 1163d5d3a8834b24bb6a5716d4fcc680
|
||||
#: ../../getting_started/install/llm/llm.rst:9 c141437ddaf84c079360008343041b2f
|
||||
msgid "🔥 Vicuna-v1.5(7b,13b)"
|
||||
msgstr "🔥 Vicuna-v1.5(7b,13b)"
|
||||
|
||||
#: ../../getting_started/install/llm/llm.rst:10
|
||||
#: 318dbc8e8de64bfb9cbc65343d050500
|
||||
#: d32b1e3f114c4eab8782b497097c1b37
|
||||
msgid "🔥 llama-2(7b,13b,70b)"
|
||||
msgstr "🔥 llama-2(7b,13b,70b)"
|
||||
|
||||
#: ../../getting_started/install/llm/llm.rst:11
|
||||
#: f2b1ef2bfbde46889131b6852d1211e8
|
||||
#: 0a417ee4d008421da07fff7add5d05eb
|
||||
msgid "WizardLM-v1.2(13b)"
|
||||
msgstr "WizardLM-v1.2(13b)"
|
||||
|
||||
#: ../../getting_started/install/llm/llm.rst:12
|
||||
#: 2462160e5928434aaa6f58074b585c14
|
||||
#: 199e1a9fe3324dc8a1bcd9cd0b1ef047
|
||||
msgid "Vicuna (7b,13b)"
|
||||
msgstr "Vicuna (7b,13b)"
|
||||
|
||||
#: ../../getting_started/install/llm/llm.rst:13
|
||||
#: 61cc1b9ed1914e379dee87034e9fbaea
|
||||
#: a9e4c5100534450db3a583fa5850e4be
|
||||
msgid "ChatGLM-6b (int4,int8)"
|
||||
msgstr "ChatGLM-6b (int4,int8)"
|
||||
|
||||
#: ../../getting_started/install/llm/llm.rst:14
|
||||
#: 653f5fd3b6ab4c508a2ce449231b5a17
|
||||
#: 943324289eb94042b52fd824189cd93f
|
||||
msgid "ChatGLM2-6b (int4,int8)"
|
||||
msgstr "ChatGLM2-6b (int4,int8)"
|
||||
|
||||
#: ../../getting_started/install/llm/llm.rst:15
|
||||
#: 6a1cf7a271f6496ba48d593db1f756f1
|
||||
#: f1226fdfac3b4e9d88642ffa69d75682
|
||||
msgid "guanaco(7b,13b,33b)"
|
||||
msgstr "guanaco(7b,13b,33b)"
|
||||
|
||||
#: ../../getting_started/install/llm/llm.rst:16
|
||||
#: f15083886362437ba96e1cb9ade738aa
|
||||
#: 3f2457f56eb341b6bc431c9beca8f4df
|
||||
msgid "Gorilla(7b,13b)"
|
||||
msgstr "Gorilla(7b,13b)"
|
||||
|
||||
#: ../../getting_started/install/llm/llm.rst:17
|
||||
#: 5e1ded19d1b24ba485e5d2bf22ce2db7
|
||||
#: 86c8ce37be1c4a7ea3fc382100d77a9c
|
||||
msgid "Baichuan(7b,13b)"
|
||||
msgstr "Baichuan(7b,13b)"
|
||||
|
||||
#: ../../getting_started/install/llm/llm.rst:18
|
||||
#: 30f6f71dd73b48d29a3647d58a9cdcaf
|
||||
#: 538111af95ad414cb2e631a89f9af379
|
||||
msgid "OpenAI"
|
||||
msgstr "OpenAI"
|
||||
|
||||
#: ../../getting_started/install/llm/llm.rst:20
|
||||
#: 9ad21acdaccb41c1bbfa30b5ce114732
|
||||
#: a203325b7ec248f7bff61ae89226a000
|
||||
msgid "llama_cpp"
|
||||
msgstr "llama_cpp"
|
||||
|
||||
#: ../../getting_started/install/llm/llm.rst:21
|
||||
#: b5064649f7c9443c9b9f15ec7fc02434
|
||||
#: 21a50634198047228bc51a03d2c31292
|
||||
msgid "quantization"
|
||||
msgstr "quantization"
|
||||
|
||||
#: ../../getting_started/install/llm/llm.rst:22
|
||||
#: dfaec4b04e6e45ff9c884b41534b1a79
|
||||
msgid "cluster deployment"
|
||||
msgstr ""
|
||||
|
||||
|
@ -166,6 +166,9 @@ class Config(metaclass=Singleton):
|
||||
self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False") == "True"
|
||||
if self.IS_LOAD_8BIT and self.IS_LOAD_4BIT:
|
||||
self.IS_LOAD_8BIT = False
|
||||
# In order to be compatible with the new and old model parameter design
|
||||
os.environ["load_8bit"] = str(self.IS_LOAD_8BIT)
|
||||
os.environ["load_4bit"] = str(self.IS_LOAD_4BIT)
|
||||
|
||||
### EMBEDDING Configuration
|
||||
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
|
||||
|
@ -7,25 +7,33 @@ from pilot.model.worker.manager import (
|
||||
WorkerApplyRequest,
|
||||
WorkerApplyType,
|
||||
)
|
||||
from pilot.model.parameter import (
|
||||
ModelControllerParameters,
|
||||
ModelWorkerParameters,
|
||||
ModelParameters,
|
||||
)
|
||||
from pilot.utils import get_or_create_event_loop
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
|
||||
|
||||
@click.group("model")
|
||||
def model_cli_group():
|
||||
pass
|
||||
|
||||
|
||||
@model_cli_group.command()
|
||||
@click.option(
|
||||
"--address",
|
||||
type=str,
|
||||
default="http://127.0.0.1:8000",
|
||||
required=False,
|
||||
show_default=True,
|
||||
help=(
|
||||
"Address of the Model Controller to connect to."
|
||||
"Address of the Model Controller to connect to. "
|
||||
"Just support light deploy model"
|
||||
),
|
||||
)
|
||||
def model_cli_group():
|
||||
"""Clients that manage model serving"""
|
||||
pass
|
||||
|
||||
|
||||
@model_cli_group.command()
|
||||
@click.option(
|
||||
"--model-name", type=str, default=None, required=False, help=("The name of model")
|
||||
)
|
||||
@ -79,16 +87,6 @@ def list(address: str, model_name: str, model_type: str):
|
||||
|
||||
|
||||
def add_model_options(func):
|
||||
@click.option(
|
||||
"--address",
|
||||
type=str,
|
||||
default="http://127.0.0.1:8000",
|
||||
required=False,
|
||||
help=(
|
||||
"Address of the Model Controller to connect to."
|
||||
"Just support light deploy model"
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--model-name",
|
||||
type=str,
|
||||
@ -149,3 +147,57 @@ def worker_apply(
|
||||
)
|
||||
res = loop.run_until_complete(worker_manager.worker_apply(apply_req))
|
||||
print(res)
|
||||
|
||||
|
||||
@click.command(name="controller")
@EnvArgumentParser.create_click_option(ModelControllerParameters)
def start_model_controller(**kwargs):
    """Start model controller"""
    # Lazy import: keep the CLI fast to load when this command is not invoked.
    from pilot.model.controller.controller import run_model_controller

    # NOTE(review): the click options collected in **kwargs are not forwarded;
    # run_model_controller appears to read its parameters from the environment
    # itself — confirm that create_click_option exports them as env vars.
    run_model_controller()
|
||||
|
||||
|
||||
@click.command(name="controller")
def stop_model_controller(**kwargs):
    """Stop model controller"""
    # Fixed docstring: it previously read "Start model controller", so
    # `dbgpt stop controller --help` showed the wrong description.
    # Not implemented yet; registered so the command exists in the CLI.
    raise NotImplementedError
|
||||
|
||||
|
||||
@click.command(name="worker")
@EnvArgumentParser.create_click_option(ModelWorkerParameters, ModelParameters)
def start_model_worker(**kwargs):
    """Start model worker"""
    # Lazy import: keep the CLI fast to load when this command is not invoked.
    from pilot.model.worker.manager import run_worker_manager

    # NOTE(review): as with the controller command, **kwargs is not forwarded;
    # run_worker_manager presumably re-reads its parameters — confirm.
    run_worker_manager()
|
||||
|
||||
|
||||
@click.command(name="worker")
def stop_model_worker(**kwargs):
    """Stop model worker"""
    # Not implemented yet; registered so `dbgpt stop worker` exists in the CLI.
    raise NotImplementedError
|
||||
|
||||
|
||||
@click.command(name="webserver")
def start_webserver(**kwargs):
    """Start webserver(dbgpt_server.py)"""
    # Not implemented yet; the webserver is currently launched directly via
    # `python pilot/server/dbgpt_server.py` (see docs above).
    raise NotImplementedError
|
||||
|
||||
|
||||
@click.command(name="webserver")
def stop_webserver(**kwargs):
    """Stop webserver(dbgpt_server.py)"""
    # Not implemented yet; registered so the command exists in the CLI.
    raise NotImplementedError
|
||||
|
||||
|
||||
@click.command(name="apiserver")
def start_apiserver(**kwargs):
    """Start apiserver(TODO)"""
    # Not implemented yet; registered so the command exists in the CLI.
    raise NotImplementedError
|
||||
|
||||
|
||||
@click.command(name="apiserver")
def stop_apiserver(**kwargs):
    """Stop apiserver(TODO)"""
    # Fixed two copy-paste defects: the command was registered under
    # name="controller" (colliding with stop_model_controller instead of
    # providing `dbgpt stop apiserver`), and the docstring said "Start".
    raise NotImplementedError
|
||||
|
@ -1,9 +1,11 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from pilot.model.base import ModelInstance
|
||||
from pilot.model.parameter import ModelControllerParameters
|
||||
from pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
|
||||
|
||||
class ModelController:
|
||||
@ -63,3 +65,22 @@ async def api_get_all_instances(model_name: str = None, healthy_only: bool = Fal
|
||||
@router.post("/controller/heartbeat")
|
||||
async def api_model_heartbeat(request: ModelInstance):
|
||||
return await controller.send_heartbeat(request)
|
||||
|
||||
|
||||
def run_model_controller():
    """Parse controller parameters and serve the Model Controller API with uvicorn."""
    import uvicorn

    parser = EnvArgumentParser()
    # Environment variables prefixed with "controller_" override dataclass defaults.
    env_prefix = "controller_"
    controller_params: ModelControllerParameters = parser.parse_args_into_dataclass(
        ModelControllerParameters, env_prefix=env_prefix
    )
    app = FastAPI()
    # All controller endpoints (registry, heartbeat, ...) are exposed under /api.
    app.include_router(router, prefix="/api")
    uvicorn.run(
        app, host=controller_params.host, port=controller_params.port, log_level="info"
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_model_controller()
|
||||
|
@ -8,14 +8,12 @@ from pilot.configs.model_config import DEVICE
|
||||
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
|
||||
from pilot.model.compression import compress_module
|
||||
from pilot.model.parameter import (
|
||||
EnvArgumentParser,
|
||||
ModelParameters,
|
||||
LlamaCppModelParameters,
|
||||
_genenv_ignoring_key_case,
|
||||
)
|
||||
from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations
|
||||
from pilot.singleton import Singleton
|
||||
from pilot.utils import get_gpu_memory
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case
|
||||
from pilot.logs import logger
|
||||
|
||||
|
||||
|
@ -1,139 +1,15 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import argparse
|
||||
import os
|
||||
from dataclasses import dataclass, field, fields
|
||||
from dataclasses import dataclass, field, fields, MISSING
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pilot.model.conversation import conv_templates
|
||||
from pilot.utils.parameter_utils import BaseParameters
|
||||
|
||||
suported_prompt_templates = ",".join(conv_templates.keys())
|
||||
|
||||
|
||||
def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None):
|
||||
"""Get the value from the environment variable, ignoring the case of the key"""
|
||||
if env_prefix:
|
||||
env_key = env_prefix + env_key
|
||||
return os.getenv(
|
||||
env_key, os.getenv(env_key.upper(), os.getenv(env_key.lower(), default_value))
|
||||
)
|
||||
|
||||
|
||||
class EnvArgumentParser:
    """Populates a dataclass from environment variables and command-line arguments.

    Precedence (lowest to highest): dataclass default < caller kwargs <
    environment variable < command-line flag.
    """

    @staticmethod
    def get_env_prefix(env_key: str) -> str:
        # Derive an env-var prefix from a key: "my-model" -> "my_model_".
        # Returns None for a falsy key.
        if not env_key:
            return None
        env_key = env_key.replace("-", "_")
        return env_key + "_"

    def parse_args_into_dataclass(
        self,
        dataclass_type: Type,
        env_prefix: str = None,
        command_args: List[str] = None,
        **kwargs,
    ) -> Any:
        """Parse parameters from environment variables and command lines and populate them into data class"""
        parser = argparse.ArgumentParser()
        for field in fields(dataclass_type):
            # Environment lookup is case-insensitive; try the prefixed key first.
            env_var_value = _genenv_ignoring_key_case(field.name, env_prefix)
            if not env_var_value:
                # Read without env prefix
                env_var_value = _genenv_ignoring_key_case(field.name)

            if env_var_value:
                env_var_value = env_var_value.strip()
                # Coerce the raw env string into the field's declared type.
                if field.type is int or field.type == Optional[int]:
                    env_var_value = int(env_var_value)
                elif field.type is float or field.type == Optional[float]:
                    env_var_value = float(env_var_value)
                elif field.type is bool or field.type == Optional[bool]:
                    env_var_value = env_var_value.lower() == "true"
                elif field.type is str or field.type == Optional[str]:
                    pass
                else:
                    raise ValueError(f"Unsupported parameter type {field.type}")
            # Fall back to caller-supplied kwargs, then to the dataclass default.
            # NOTE(review): truthiness checks mean 0, 0.0, False and "" are
            # treated as "missing" and replaced by the fallback — confirm intended.
            if not env_var_value:
                env_var_value = kwargs.get(field.name)
            if not env_var_value:
                env_var_value = field.default

            # Add a command-line argument for this field
            help_text = field.metadata.get("help", "")
            valid_values = field.metadata.get("valid_values", None)
            parser.add_argument(
                f"--{field.name}",
                type=self._get_argparse_type(field.type),
                help=help_text,
                choices=valid_values,
                default=env_var_value,
            )

        # Parse the command-line arguments
        # (parse_known_args so unrecognized extra args are tolerated)
        cmd_args, cmd_argv = parser.parse_known_args(args=command_args)
        # NOTE(review): leftover debug print — consider removing or using a logger.
        print(f"cmd_args: {cmd_args}")
        for field in fields(dataclass_type):
            # cmd_line_value = getattr(cmd_args, field.name)
            if field.name in cmd_args:
                cmd_line_value = getattr(cmd_args, field.name)
                if cmd_line_value is not None:
                    kwargs[field.name] = cmd_line_value

        return dataclass_type(**kwargs)

    @staticmethod
    def _get_argparse_type(field_type: Type) -> Type:
        # Return the appropriate type for argparse to use based on the field type
        if field_type is int or field_type == Optional[int]:
            return int
        elif field_type is float or field_type == Optional[float]:
            return float
        elif field_type is bool or field_type == Optional[bool]:
            return bool
        elif field_type is str or field_type == Optional[str]:
            return str
        else:
            raise ValueError(f"Unsupported parameter type {field_type}")

    @staticmethod
    def _get_argparse_type_str(field_type: Type) -> str:
        # Map a field type onto the string name used by ParameterDescription.
        argparse_type = EnvArgumentParser._get_argparse_type(field_type)
        if argparse_type is int:
            return "int"
        elif argparse_type is float:
            return "float"
        elif argparse_type is bool:
            return "bool"
        else:
            return "str"
|
||||
|
||||
|
||||
@dataclass
class ParameterDescription:
    """Serializable description of one configurable dataclass field."""

    # Field name as declared on the dataclass.
    param_name: str
    # Type name string produced by EnvArgumentParser._get_argparse_type_str
    # ("int", "float", "bool" or "str").
    param_type: str
    # Help text from the field's metadata; None when absent.
    description: str
    # Declared default; may be dataclasses.MISSING when the field has no default.
    default_value: Optional[Any]
    # Allowed values from the field's metadata, or None when unconstrained.
    valid_values: Optional[List[Any]]
||||
|
||||
|
||||
def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
    """Build a ParameterDescription for every field of *dataclass_type*."""
    return [
        ParameterDescription(
            param_name=dc_field.name,
            param_type=EnvArgumentParser._get_argparse_type_str(dc_field.type),
            description=dc_field.metadata.get("help", None),
            default_value=dc_field.default,  # TODO handle dataclasses._MISSING_TYPE
            valid_values=dc_field.metadata.get("valid_values", None),
        )
        for dc_field in fields(dataclass_type)
    ]
|
||||
|
||||
|
||||
class WorkerType(str, Enum):
|
||||
LLM = "llm"
|
||||
TEXT2VEC = "text2vec"
|
||||
@ -144,58 +20,13 @@ class WorkerType(str, Enum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseParameters:
|
||||
def update_from(self, source: Union["BaseParameters", dict]) -> bool:
|
||||
"""
|
||||
Update the attributes of this object using the values from another object (of the same or parent type) or a dictionary.
|
||||
Only update if the new value is different from the current value and the field is not marked as "fixed" in metadata.
|
||||
|
||||
Args:
|
||||
source (Union[BaseParameters, dict]): The source to update from. Can be another object of the same type or a dictionary.
|
||||
|
||||
Returns:
|
||||
bool: True if at least one field was updated, otherwise False.
|
||||
"""
|
||||
updated = False # Flag to indicate whether any field was updated
|
||||
if isinstance(source, (BaseParameters, dict)):
|
||||
for field_info in fields(self):
|
||||
# Check if the field has a "fixed" tag in metadata
|
||||
tags = field_info.metadata.get("tags")
|
||||
tags = [] if not tags else tags.split(",")
|
||||
if tags and "fixed" in tags:
|
||||
continue # skip this field
|
||||
# Get the new value from source (either another BaseParameters object or a dict)
|
||||
new_value = (
|
||||
getattr(source, field_info.name)
|
||||
if isinstance(source, BaseParameters)
|
||||
else source.get(field_info.name, None)
|
||||
)
|
||||
|
||||
# If the new value is not None and different from the current value, update the field and set the flag
|
||||
if new_value is not None and new_value != getattr(
|
||||
self, field_info.name
|
||||
):
|
||||
setattr(self, field_info.name, new_value)
|
||||
updated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
"Source must be an instance of BaseParameters (or its derived class) or a dictionary."
|
||||
)
|
||||
|
||||
return updated
|
||||
|
||||
def __str__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
parameters = [
|
||||
f"\n\n=========================== {class_name} ===========================\n"
|
||||
]
|
||||
for field_info in fields(self):
|
||||
value = getattr(self, field_info.name)
|
||||
parameters.append(f"{field_info.name}: {value}")
|
||||
parameters.append(
|
||||
"\n======================================================================\n\n"
|
||||
)
|
||||
return "\n".join(parameters)
|
||||
class ModelControllerParameters(BaseParameters):
|
||||
host: Optional[str] = field(
|
||||
default="0.0.0.0", metadata={"help": "Model Controller deploy host"}
|
||||
)
|
||||
port: Optional[int] = field(
|
||||
default=8000, metadata={"help": "Model Controller deploy port"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -209,7 +40,7 @@ class ModelWorkerParameters(BaseParameters):
|
||||
worker_class: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Model worker deploy host, pilot.model.worker.default_worker.DefaultModelWorker"
|
||||
"help": "Model worker class, pilot.model.worker.default_worker.DefaultModelWorker"
|
||||
},
|
||||
)
|
||||
host: Optional[str] = field(
|
||||
|
@ -2,10 +2,9 @@ from abc import ABC, abstractmethod
|
||||
from typing import Dict, Iterator, List, Type
|
||||
|
||||
from pilot.model.base import ModelOutput
|
||||
from pilot.model.parameter import (
|
||||
ModelParameters,
|
||||
from pilot.model.parameter import ModelParameters, WorkerType
|
||||
from pilot.utils.parameter_utils import (
|
||||
ParameterDescription,
|
||||
WorkerType,
|
||||
_get_parameter_descriptions,
|
||||
)
|
||||
|
||||
|
@ -7,10 +7,11 @@ from pilot.configs.model_config import DEVICE
|
||||
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
|
||||
from pilot.model.base import ModelOutput
|
||||
from pilot.model.loader import ModelLoader, _get_model_real_path
|
||||
from pilot.model.parameter import EnvArgumentParser, ModelParameters
|
||||
from pilot.model.parameter import ModelParameters
|
||||
from pilot.model.worker.base import ModelWorker
|
||||
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
|
||||
from pilot.utils.model_utils import _clear_torch_cache
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
|
||||
logger = logging.getLogger("model_worker")
|
||||
|
||||
|
@ -5,11 +5,11 @@ from pilot.configs.model_config import DEVICE
|
||||
from pilot.model.loader import _get_model_real_path
|
||||
from pilot.model.parameter import (
|
||||
EmbeddingModelParameters,
|
||||
EnvArgumentParser,
|
||||
WorkerType,
|
||||
)
|
||||
from pilot.model.worker.base import ModelWorker
|
||||
from pilot.utils.model_utils import _clear_torch_cache
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
|
||||
logger = logging.getLogger("model_worker")
|
||||
|
||||
|
@ -23,15 +23,14 @@ from pilot.model.base import (
|
||||
)
|
||||
from pilot.model.controller.registry import ModelRegistry
|
||||
from pilot.model.parameter import (
|
||||
EnvArgumentParser,
|
||||
ModelParameters,
|
||||
ModelWorkerParameters,
|
||||
WorkerType,
|
||||
ParameterDescription,
|
||||
)
|
||||
from pilot.model.worker.base import ModelWorker
|
||||
from pilot.scene.base_message import ModelMessage
|
||||
from pilot.utils import build_logger
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser, ParameterDescription
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = build_logger("model_worker", LOGDIR + "/model_worker.log")
|
||||
@ -148,6 +147,8 @@ class LocalWorkerManager(WorkerManager):
|
||||
self.model_registry = model_registry
|
||||
|
||||
def _worker_key(self, worker_type: str, model_name: str) -> str:
|
||||
if isinstance(worker_type, WorkerType):
|
||||
worker_type = worker_type.value
|
||||
return f"{model_name}@{worker_type}"
|
||||
|
||||
def add_worker(
|
||||
@ -166,6 +167,9 @@ class LocalWorkerManager(WorkerManager):
|
||||
if not worker_params.worker_type:
|
||||
worker_params.worker_type = worker.worker_type()
|
||||
|
||||
if isinstance(worker_params.worker_type, WorkerType):
|
||||
worker_params.worker_type = worker_params.worker_type.value
|
||||
|
||||
worker_key = self._worker_key(
|
||||
worker_params.worker_type, worker_params.model_name
|
||||
)
|
||||
|
@ -5,7 +5,7 @@ import copy
|
||||
import logging
|
||||
|
||||
sys.path.append(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
)
|
||||
|
||||
|
||||
@ -23,16 +23,53 @@ def cli(log_level: str):
|
||||
logging.basicConfig(level=log_level, encoding="utf-8")
|
||||
|
||||
|
||||
def add_command_alias(command, name: str, hidden: bool = False):
|
||||
def add_command_alias(command, name: str, hidden: bool = False, parent_group=None):
|
||||
if not parent_group:
|
||||
parent_group = cli
|
||||
new_command = copy.deepcopy(command)
|
||||
new_command.hidden = hidden
|
||||
cli.add_command(new_command, name=name)
|
||||
parent_group.add_command(new_command, name=name)
|
||||
|
||||
|
||||
@click.group()
|
||||
def start():
|
||||
"""Start specific server."""
|
||||
pass
|
||||
|
||||
|
||||
@click.group()
|
||||
def stop():
|
||||
"""Start specific server."""
|
||||
pass
|
||||
|
||||
|
||||
cli.add_command(start)
|
||||
cli.add_command(stop)
|
||||
|
||||
try:
|
||||
from pilot.model.cli import model_cli_group
|
||||
from pilot.model.cli import (
|
||||
model_cli_group,
|
||||
start_model_controller,
|
||||
stop_model_controller,
|
||||
start_model_worker,
|
||||
stop_model_worker,
|
||||
start_webserver,
|
||||
stop_webserver,
|
||||
start_apiserver,
|
||||
stop_apiserver,
|
||||
)
|
||||
|
||||
add_command_alias(model_cli_group, name="model", parent_group=cli)
|
||||
add_command_alias(start_model_controller, name="controller", parent_group=start)
|
||||
add_command_alias(start_model_worker, name="worker", parent_group=start)
|
||||
add_command_alias(start_webserver, name="webserver", parent_group=start)
|
||||
add_command_alias(start_apiserver, name="apiserver", parent_group=start)
|
||||
|
||||
add_command_alias(stop_model_controller, name="controller", parent_group=stop)
|
||||
add_command_alias(stop_model_worker, name="worker", parent_group=stop)
|
||||
add_command_alias(stop_webserver, name="webserver", parent_group=stop)
|
||||
add_command_alias(stop_apiserver, name="apiserver", parent_group=stop)
|
||||
|
||||
add_command_alias(model_cli_group, name="model")
|
||||
except ImportError as e:
|
||||
logging.warning(f"Integrating dbgpt model command line tool failed: {e}")
|
||||
|
||||
|
314
pilot/utils/parameter_utils.py
Normal file
314
pilot/utils/parameter_utils.py
Normal file
@ -0,0 +1,314 @@
|
||||
import argparse
|
||||
import os
|
||||
from dataclasses import dataclass, fields, MISSING
|
||||
from typing import Any, List, Optional, Type, Union, Callable
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParameterDescription:
|
||||
param_name: str
|
||||
param_type: str
|
||||
description: str
|
||||
default_value: Optional[Any]
|
||||
valid_values: Optional[List[Any]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseParameters:
|
||||
def update_from(self, source: Union["BaseParameters", dict]) -> bool:
|
||||
"""
|
||||
Update the attributes of this object using the values from another object (of the same or parent type) or a dictionary.
|
||||
Only update if the new value is different from the current value and the field is not marked as "fixed" in metadata.
|
||||
|
||||
Args:
|
||||
source (Union[BaseParameters, dict]): The source to update from. Can be another object of the same type or a dictionary.
|
||||
|
||||
Returns:
|
||||
bool: True if at least one field was updated, otherwise False.
|
||||
"""
|
||||
updated = False # Flag to indicate whether any field was updated
|
||||
if isinstance(source, (BaseParameters, dict)):
|
||||
for field_info in fields(self):
|
||||
# Check if the field has a "fixed" tag in metadata
|
||||
tags = field_info.metadata.get("tags")
|
||||
tags = [] if not tags else tags.split(",")
|
||||
if tags and "fixed" in tags:
|
||||
continue # skip this field
|
||||
# Get the new value from source (either another BaseParameters object or a dict)
|
||||
new_value = (
|
||||
getattr(source, field_info.name)
|
||||
if isinstance(source, BaseParameters)
|
||||
else source.get(field_info.name, None)
|
||||
)
|
||||
|
||||
# If the new value is not None and different from the current value, update the field and set the flag
|
||||
if new_value is not None and new_value != getattr(
|
||||
self, field_info.name
|
||||
):
|
||||
setattr(self, field_info.name, new_value)
|
||||
updated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
"Source must be an instance of BaseParameters (or its derived class) or a dictionary."
|
||||
)
|
||||
|
||||
return updated
|
||||
|
||||
def __str__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
parameters = [
|
||||
f"\n\n=========================== {class_name} ===========================\n"
|
||||
]
|
||||
for field_info in fields(self):
|
||||
value = getattr(self, field_info.name)
|
||||
parameters.append(f"{field_info.name}: {value}")
|
||||
parameters.append(
|
||||
"\n======================================================================\n\n"
|
||||
)
|
||||
return "\n".join(parameters)
|
||||
|
||||
|
||||
def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None):
|
||||
"""Get the value from the environment variable, ignoring the case of the key"""
|
||||
if env_prefix:
|
||||
env_key = env_prefix + env_key
|
||||
return os.getenv(
|
||||
env_key, os.getenv(env_key.upper(), os.getenv(env_key.lower(), default_value))
|
||||
)
|
||||
|
||||
|
||||
class EnvArgumentParser:
|
||||
@staticmethod
|
||||
def get_env_prefix(env_key: str) -> str:
|
||||
if not env_key:
|
||||
return None
|
||||
env_key = env_key.replace("-", "_")
|
||||
return env_key + "_"
|
||||
|
||||
def parse_args_into_dataclass(
|
||||
self,
|
||||
dataclass_type: Type,
|
||||
env_prefix: str = None,
|
||||
command_args: List[str] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Parse parameters from environment variables and command lines and populate them into data class"""
|
||||
parser = argparse.ArgumentParser()
|
||||
for field in fields(dataclass_type):
|
||||
env_var_value = _genenv_ignoring_key_case(field.name, env_prefix)
|
||||
if not env_var_value:
|
||||
# Read without env prefix
|
||||
env_var_value = _genenv_ignoring_key_case(field.name)
|
||||
|
||||
if env_var_value:
|
||||
env_var_value = env_var_value.strip()
|
||||
if field.type is int or field.type == Optional[int]:
|
||||
env_var_value = int(env_var_value)
|
||||
elif field.type is float or field.type == Optional[float]:
|
||||
env_var_value = float(env_var_value)
|
||||
elif field.type is bool or field.type == Optional[bool]:
|
||||
env_var_value = env_var_value.lower() == "true"
|
||||
elif field.type is str or field.type == Optional[str]:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unsupported parameter type {field.type}")
|
||||
if not env_var_value:
|
||||
env_var_value = kwargs.get(field.name)
|
||||
|
||||
# Add a command-line argument for this field
|
||||
help_text = field.metadata.get("help", "")
|
||||
valid_values = field.metadata.get("valid_values", None)
|
||||
|
||||
argument_kwargs = {
|
||||
"type": EnvArgumentParser._get_argparse_type(field.type),
|
||||
"help": help_text,
|
||||
"choices": valid_values,
|
||||
"required": EnvArgumentParser._is_require_type(field.type),
|
||||
}
|
||||
if field.default != MISSING:
|
||||
argument_kwargs["default"] = field.default
|
||||
argument_kwargs["required"] = False
|
||||
if env_var_value:
|
||||
argument_kwargs["default"] = env_var_value
|
||||
argument_kwargs["required"] = False
|
||||
|
||||
parser.add_argument(f"--{field.name}", **argument_kwargs)
|
||||
|
||||
# Parse the command-line arguments
|
||||
cmd_args, cmd_argv = parser.parse_known_args(args=command_args)
|
||||
# cmd_args = parser.parse_args(args=command_args)
|
||||
# print(f"cmd_args: {cmd_args}")
|
||||
for field in fields(dataclass_type):
|
||||
# cmd_line_value = getattr(cmd_args, field.name)
|
||||
if field.name in cmd_args:
|
||||
cmd_line_value = getattr(cmd_args, field.name)
|
||||
if cmd_line_value is not None:
|
||||
kwargs[field.name] = cmd_line_value
|
||||
|
||||
return dataclass_type(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def create_arg_parser(dataclass_type: Type) -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description=dataclass_type.__doc__)
|
||||
for field in fields(dataclass_type):
|
||||
help_text = field.metadata.get("help", "")
|
||||
valid_values = field.metadata.get("valid_values", None)
|
||||
argument_kwargs = {
|
||||
"type": EnvArgumentParser._get_argparse_type(field.type),
|
||||
"help": help_text,
|
||||
"choices": valid_values,
|
||||
"required": EnvArgumentParser._is_require_type(field.type),
|
||||
}
|
||||
if field.default != MISSING:
|
||||
argument_kwargs["default"] = field.default
|
||||
argument_kwargs["required"] = False
|
||||
parser.add_argument(f"--{field.name}", **argument_kwargs)
|
||||
return parser
|
||||
|
||||
@staticmethod
|
||||
def create_click_option(
|
||||
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
|
||||
):
|
||||
import click
|
||||
import functools
|
||||
from collections import OrderedDict
|
||||
|
||||
# TODO dynamic configuration
|
||||
# pre_args = _SimpleArgParser('model_name', 'model_path')
|
||||
# pre_args.parse()
|
||||
# print(pre_args)
|
||||
|
||||
combined_fields = OrderedDict()
|
||||
if _dynamic_factory:
|
||||
_types = _dynamic_factory()
|
||||
if _types:
|
||||
dataclass_types = list(_types)
|
||||
for dataclass_type in dataclass_types:
|
||||
for field in fields(dataclass_type):
|
||||
if field.name not in combined_fields:
|
||||
combined_fields[field.name] = field
|
||||
|
||||
def decorator(func):
|
||||
for field_name, field in reversed(combined_fields.items()):
|
||||
help_text = field.metadata.get("help", "")
|
||||
valid_values = field.metadata.get("valid_values", None)
|
||||
cli_params = {
|
||||
"default": None if field.default is MISSING else field.default,
|
||||
"help": help_text,
|
||||
"show_default": True,
|
||||
"required": field.default is MISSING,
|
||||
}
|
||||
if valid_values:
|
||||
cli_params["type"] = click.Choice(valid_values)
|
||||
real_type = EnvArgumentParser._get_argparse_type(field.type)
|
||||
if real_type is int:
|
||||
cli_params["type"] = click.INT
|
||||
elif real_type is float:
|
||||
cli_params["type"] = click.FLOAT
|
||||
elif real_type is str:
|
||||
cli_params["type"] = click.STRING
|
||||
elif real_type is bool:
|
||||
cli_params["is_flag"] = True
|
||||
|
||||
option_decorator = click.option(
|
||||
# f"--{field_name.replace('_', '-')}", **cli_params
|
||||
f"--{field_name}",
|
||||
**cli_params,
|
||||
)
|
||||
func = option_decorator(func)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@staticmethod
|
||||
def _get_argparse_type(field_type: Type) -> Type:
|
||||
# Return the appropriate type for argparse to use based on the field type
|
||||
if field_type is int or field_type == Optional[int]:
|
||||
return int
|
||||
elif field_type is float or field_type == Optional[float]:
|
||||
return float
|
||||
elif field_type is bool or field_type == Optional[bool]:
|
||||
return bool
|
||||
elif field_type is str or field_type == Optional[str]:
|
||||
return str
|
||||
else:
|
||||
raise ValueError(f"Unsupported parameter type {field_type}")
|
||||
|
||||
@staticmethod
|
||||
def _get_argparse_type_str(field_type: Type) -> str:
|
||||
argparse_type = EnvArgumentParser._get_argparse_type(field_type)
|
||||
if argparse_type is int:
|
||||
return "int"
|
||||
elif argparse_type is float:
|
||||
return "float"
|
||||
elif argparse_type is bool:
|
||||
return "bool"
|
||||
else:
|
||||
return "str"
|
||||
|
||||
@staticmethod
|
||||
def _is_require_type(field_type: Type) -> str:
|
||||
return field_type not in [Optional[int], Optional[float], Optional[bool]]
|
||||
|
||||
|
||||
def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
|
||||
descriptions = []
|
||||
for field in fields(dataclass_type):
|
||||
descriptions.append(
|
||||
ParameterDescription(
|
||||
param_name=field.name,
|
||||
param_type=EnvArgumentParser._get_argparse_type_str(field.type),
|
||||
description=field.metadata.get("help", None),
|
||||
default_value=field.default, # TODO handle dataclasses._MISSING_TYPE
|
||||
valid_values=field.metadata.get("valid_values", None),
|
||||
)
|
||||
)
|
||||
return descriptions
|
||||
|
||||
|
||||
class _SimpleArgParser:
|
||||
def __init__(self, *args):
|
||||
self.params = {arg.replace("_", "-"): None for arg in args}
|
||||
|
||||
def parse(self, args=None):
|
||||
import sys
|
||||
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
else:
|
||||
args = list(args)
|
||||
prev_arg = None
|
||||
for arg in args:
|
||||
if arg.startswith("--"):
|
||||
if prev_arg:
|
||||
self.params[prev_arg] = None
|
||||
prev_arg = arg[2:]
|
||||
else:
|
||||
if prev_arg:
|
||||
self.params[prev_arg] = arg
|
||||
prev_arg = None
|
||||
|
||||
if prev_arg:
|
||||
self.params[prev_arg] = None
|
||||
|
||||
def _get_param(self, key):
|
||||
return self.params.get(key.replace("_", "-"), None)
|
||||
|
||||
def __getattr__(self, item):
|
||||
return self._get_param(item)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._get_param(key)
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self._get_param(key) or default
|
||||
|
||||
def __str__(self):
|
||||
return "\n".join(
|
||||
[f'{key.replace("-", "_")}: {value}' for key, value in self.params.items()]
|
||||
)
|
Loading…
Reference in New Issue
Block a user