diff --git a/docs/getting_started/install/llm/cluster/model_cluster.md b/docs/getting_started/install/llm/cluster/model_cluster.md new file mode 100644 index 000000000..5576dfc1e --- /dev/null +++ b/docs/getting_started/install/llm/cluster/model_cluster.md @@ -0,0 +1,219 @@ +Cluster deployment +================================== + +## Model cluster deployment + + +**Installing Command-Line Tool** + +All operations below are performed using the `dbgpt` command. To use the `dbgpt` command, you need to install the DB-GPT project with `pip install -e .`. Alternatively, you can use `python pilot/scripts/cli_scripts.py` as a substitute for the `dbgpt` command. + +### Launch Model Controller + +```bash +dbgpt start controller +``` + +By default, the Model Controller starts on port 8000. + + +### Launch Model Worker + +If you are starting `chatglm2-6b`: + +```bash +dbgpt start worker --model_name chatglm2-6b \ +--model_path /app/models/chatglm2-6b \ +--port 8001 \ +--controller_addr http://127.0.0.1:8000 +``` + +If you are starting `vicuna-13b-v1.5`: + +```bash +dbgpt start worker --model_name vicuna-13b-v1.5 \ +--model_path /app/models/vicuna-13b-v1.5 \ +--port 8002 \ +--controller_addr http://127.0.0.1:8000 +``` + +Note: Be sure to use your own model name and model path. 
+ + +Check your model: + +```bash +dbgpt model list +``` + +You will see the following output: +``` ++-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+ +| Model Name | Model Type | Host | Port | Healthy | Enabled | Prompt Template | Last Heartbeat | ++-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+ +| chatglm2-6b | llm | 172.17.0.6 | 8001 | True | True | None | 2023-08-31T04:48:45.252939 | +| vicuna-13b-v1.5 | llm | 172.17.0.6 | 8002 | True | True | None | 2023-08-31T04:48:55.136676 | ++-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+ +``` + +### Connect to the model service in the webserver (dbgpt_server) + +**First, modify the `.env` file to change the model name and the Model Controller connection address.** + +```bash +LLM_MODEL=vicuna-13b-v1.5 +# The current default MODEL_SERVER address is the address of the Model Controller +MODEL_SERVER=http://127.0.0.1:8000 +``` + +#### Start the webserver + +```bash +python pilot/server/dbgpt_server.py --light +``` + +`--light` indicates not to start the embedded model service. + +Alternatively, you can prepend the command with `LLM_MODEL=chatglm2-6b` to start: + +```bash +LLM_MODEL=chatglm2-6b python pilot/server/dbgpt_server.py --light +``` + + +### More Command-Line Usages + +You can view more command-line usages through the help command. + +**View the `dbgpt` help** +```bash +dbgpt --help +``` + +You will see the basic command parameters and usage: + +``` +Usage: dbgpt [OPTIONS] COMMAND [ARGS]... + +Options: + --log-level TEXT Log level + --version Show the version and exit. + --help Show this message and exit. + +Commands: + model Clients that manage model serving + start Start specific server. + stop Start specific server. 
+``` + +**View the `dbgpt start` help** + +```bash +dbgpt start --help +``` + +Here you can see the related commands and usage for start: + +``` +Usage: dbgpt start [OPTIONS] COMMAND [ARGS]... + + Start specific server. + +Options: + --help Show this message and exit. + +Commands: + apiserver Start apiserver(TODO) + controller Start model controller + webserver Start webserver(dbgpt_server.py) + worker Start model worker +``` + +**View the `dbgpt start worker`help** + +```bash +dbgpt start worker --help +``` + +Here you can see the parameters to start Model Worker: + +``` +Usage: dbgpt start worker [OPTIONS] + + Start model worker + +Options: + --model_name TEXT Model name [required] + --model_path TEXT Model path [required] + --worker_type TEXT Worker type + --worker_class TEXT Model worker class, pilot.model.worker.defau + lt_worker.DefaultModelWorker + --host TEXT Model worker deploy host [default: 0.0.0.0] + --port INTEGER Model worker deploy port [default: 8000] + --limit_model_concurrency INTEGER + Model concurrency limit [default: 5] + --standalone Standalone mode. If True, embedded Run + ModelController + --register Register current worker to model controller + [default: True] + --worker_register_host TEXT The ip address of current worker to register + to ModelController. If None, the address is + automatically determined + --controller_addr TEXT The Model controller address to register + --send_heartbeat Send heartbeat to model controller + [default: True] + --heartbeat_interval INTEGER The interval for sending heartbeats + (seconds) [default: 20] + --device TEXT Device to run model. If None, the device is + automatically determined + --model_type TEXT Model type, huggingface or llama.cpp + [default: huggingface] + --prompt_template TEXT Prompt template. 
If None, the prompt + template is automatically determined from + model path, supported template: zero_shot,vi + cuna_v1.1,llama-2,alpaca,baichuan-chat + --max_context_size INTEGER Maximum context size [default: 4096] + --num_gpus INTEGER The number of gpus you expect to use, if it + is empty, use all of them as much as + possible + --max_gpu_memory TEXT The maximum memory limit of each GPU, only + valid in multi-GPU configuration + --cpu_offloading CPU offloading + --load_8bit 8-bit quantization + --load_4bit 4-bit quantization + --quant_type TEXT Quantization datatypes, `fp4` (four bit + float) and `nf4` (normal four bit float), + only valid when load_4bit=True [default: + nf4] + --use_double_quant Nested quantization, only valid when + load_4bit=True [default: True] + --compute_dtype TEXT Model compute type + --trust_remote_code Trust remote code [default: True] + --verbose Show verbose output. + --help Show this message and exit. +``` + +**View the `dbgpt model`help** + +```bash +dbgpt model --help +``` + +The `dbgpt model ` command can connect to the Model Controller via the Model Controller address and then manage a remote model: + +``` +Usage: dbgpt model [OPTIONS] COMMAND [ARGS]... + + Clients that manage model serving + +Options: + --address TEXT Address of the Model Controller to connect to. Just support + light deploy model [default: http://127.0.0.1:8000] + --help Show this message and exit. + +Commands: + list List model instances + restart Restart model instances + start Start model instances + stop Stop model instances +``` \ No newline at end of file diff --git a/docs/getting_started/install/llm/llm.rst b/docs/getting_started/install/llm/llm.rst index 1eca893e2..accde0250 100644 --- a/docs/getting_started/install/llm/llm.rst +++ b/docs/getting_started/install/llm/llm.rst @@ -19,6 +19,7 @@ Multi LLMs Support, Supports multiple large language models, currently supportin - llama_cpp - quantization +- cluster deployment .. 
toctree:: :maxdepth: 2 @@ -28,3 +29,4 @@ Multi LLMs Support, Supports multiple large language models, currently supportin ./llama/llama_cpp.md ./quantization/quantization.md + ./cluster/model_cluster.md diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/cluster/model_cluster.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/cluster/model_cluster.po new file mode 100644 index 000000000..8f71a3391 --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/cluster/model_cluster.po @@ -0,0 +1,172 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023, csunny +# This file is distributed under the same license as the DB-GPT package. +# FIRST AUTHOR , 2023. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: DB-GPT 👏👏 0.3.6\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2023-08-31 16:43+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language: zh_CN\n" +"Language-Team: zh_CN \n" +"Plural-Forms: nplurals=1; plural=0;\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.12.1\n" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:1 +#: b0614062af2a4c039aa8947e45c4382a +msgid "Cluster deployment" +msgstr "集群部署" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:4 +#: 568a2af68de5410ebd63f34b8a3665c5 +msgid "Model cluster deployment" +msgstr "模型的集群部署" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:7 +#: 41b9c4bf32bd4c95a1a7b733543a322b +msgid "**Installing Command-Line Tool**" +msgstr "**命令行工具的安装**" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:9 +#: b340cf9309a742648fb05910dc6ba741 +msgid "" +"All operations below are performed using the `dbgpt` command. To use the " +"`dbgpt` command, you need to install the DB-GPT project with `pip install" +" -e .`. 
Alternatively, you can use `python pilot/scripts/cli_scripts.py` " +"as a substitute for the `dbgpt` command." +msgstr "下面的操作都是使用 `dbgpt` 命令来进行,为了使用 `dbgpt` 命令,你需要使用 `pip install -e .` 来安装 DB-GPT 项目,当然,你也可以使用 `python pilot/scripts/cli_scripts.py` 来替代命令 `dbgpt`。" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:11 +#: 9299e223f8ec4060aaf40ddba019bfbf +msgid "Launch Model Controller" +msgstr "启动 Model Controller" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:17 +#: f858b280e1e54feba6639c7270b512c3 +msgid "By default, the Model Controller starts on port 8000." +msgstr "默认情况下 Model Controller 启动端口是 8000。" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:20 +#: 51eecd5331ce441abd9658f873132935 +msgid "Launch Model Worker" +msgstr "启动 Model Worker" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:22 +#: 09b1b7ffe3b740de94163f2ef24e8743 +msgid "If you are starting `chatglm2-6b`:" +msgstr "如果你启动的是 `chatglm2-6b`" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:31 +#: 47f9ec2b957e4a2a858a8833ce2cd104 +msgid "If you are starting `vicuna-13b-v1.5`:" +msgstr "如果你启动的是 `vicuna-13b-v1.5`" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:40 +#: 95a04817be5a473e8b4f7086ac8dfe2c +msgid "Note: Be sure to use your own model name and model path." 
+msgstr "注意:注意使用你自己的模型名称和模型路径。" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:43 +#: 644fff6c28ca49efa06e0f2c89f45e27 +msgid "Check your model:" +msgstr "查看你的模型:" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:49 +#: 69df9c27c17e44b8a6cb010360c56338 +msgid "You will see the following output:" +msgstr "你将会看到下面的输出:" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:59 +#: a1e1489068d04b51bb239bd32ef3546b +msgid "Connect to the model service in the webserver (dbgpt_server)" +msgstr "在 webserver(dbgpt_server) 中连接模型服务" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:61 +#: 0ab444ac06f34ce0994fe1f945abecde +msgid "" +"**First, modify the `.env` file to change the model name and the Model " +"Controller connection address.**" +msgstr "**先修改 `.env` 文件,修改模型名称和 Model Controller 连接地址。**" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:69 +#: 5125b43d70d744a491a70aa1001303be +msgid "Start the webserver" +msgstr "" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:75 +#: f12f463c27c74337ba8cc9c7e4e934b4 +msgid "`--light` indicates not to start the embedded model service." +msgstr "`--light` 表示不启动内嵌的模型服务。" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:77 +#: de2f7b040fd04c73a89ae453db84108f +msgid "" +"Alternatively, you can prepend the command with `LLM_MODEL=chatglm2-6b` " +"to start:" +msgstr "更简单的,你可以在命令行前添加 `LLM_MODEL=chatglm2-6b` 来启动:" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:84 +#: 83b93195a9fb45aeaeead6b05719f164 +msgid "More Command-Line Usages" +msgstr "命令行的更多用法" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:86 +#: 8622edf06c1d443188e19a26ed306d58 +msgid "You can view more command-line usages through the help command." 
+msgstr "你可以通过帮助命令来查看更多的命令行用法。" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:88 +#: 391c6fa585d045f5a583e26478adda7e +msgid "**View the `dbgpt` help**" +msgstr "**查看 `dbgpt` 的帮助**" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:93 +#: 4247e8ee8f084fa780fb523530cc64ab +msgid "You will see the basic command parameters and usage:" +msgstr "你将会看到基础的命令参数和用法:" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:109 +#: 87b43469e93e4b62a59488d1ef663b51 +msgid "**View the `dbgpt start` help**" +msgstr "**查看 `dbgpt start` 的帮助**" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:115 +#: 213f7bfc459d4439b3a839292e2684e6 +msgid "Here you can see the related commands and usage for start:" +msgstr "这里你能看到 start 相关的命令和用法。" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:132 +#: eb616a87e01d4cdc87dc61e462d83729 +msgid "**View the `dbgpt start worker`help**" +msgstr "**查看 `dbgpt start worker`的帮助**" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:138 +#: 75918aa07d7c4ec39271358f2b474a63 +msgid "Here you can see the parameters to start Model Worker:" +msgstr "这里你能看启动 Model Worker 的参数:" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:196 +#: 5ee6ea98af4d4693b8cbf059ad3001d8 +msgid "**View the `dbgpt model`help**" +msgstr "**查看 `dbgpt model`的帮助**" + +#: ../../getting_started/install/llm/cluster/model_cluster.md:202 +#: a96a1b58987a481fac2d527470226b90 +msgid "" +"The `dbgpt model ` command can connect to the Model Controller via the " +"Model Controller address and then manage a remote model:" +msgstr "`dbgpt model ` 命令可以通过 Model Controller 地址连接到 Model Controller,然后对远程对某个模型进行管理。" + +#~ msgid "" +#~ "First, modify the `.env` file to " +#~ "change the model name and the " +#~ "Model Controller connection address." 
+#~ msgstr "" + diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/llm.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/llm.po index 01c1607ca..b5531e9f5 100644 --- a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/llm.po +++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/llm.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: DB-GPT 👏👏 0.3.5\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2023-08-17 23:29+0800\n" +"POT-Creation-Date: 2023-08-31 16:38+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language: zh_CN\n" @@ -20,80 +20,85 @@ msgstr "" "Generated-By: Babel 2.12.1\n" #: ../../getting_started/install/llm/llm.rst:2 -#: ../../getting_started/install/llm/llm.rst:23 -#: 9e4baaff732b49f1baded5fd1bde4bd5 +#: ../../getting_started/install/llm/llm.rst:24 +#: b348d4df8ca44dd78b42157a8ff6d33d msgid "LLM Usage" msgstr "LLM使用" -#: ../../getting_started/install/llm/llm.rst:3 d80e7360f17c4be889f220e2a846a81e +#: ../../getting_started/install/llm/llm.rst:3 7f5960a7e5634254b330da27be87594b msgid "" "DB-GPT provides a management and deployment solution for multiple models." " This chapter mainly discusses how to deploy different models." 
msgstr "DB-GPT提供了多模型的管理和部署方案,本章主要讲解针对不同的模型该怎么部署" #: ../../getting_started/install/llm/llm.rst:18 -#: 1cebd6c84b924eeb877ba31bd520328f +#: b844ab204ec740ec9d7d191bb841f09e msgid "" "Multi LLMs Support, Supports multiple large language models, currently " "supporting" msgstr "目前DB-GPT已适配如下模型" -#: ../../getting_started/install/llm/llm.rst:9 1163d5d3a8834b24bb6a5716d4fcc680 +#: ../../getting_started/install/llm/llm.rst:9 c141437ddaf84c079360008343041b2f msgid "🔥 Vicuna-v1.5(7b,13b)" msgstr "🔥 Vicuna-v1.5(7b,13b)" #: ../../getting_started/install/llm/llm.rst:10 -#: 318dbc8e8de64bfb9cbc65343d050500 +#: d32b1e3f114c4eab8782b497097c1b37 msgid "🔥 llama-2(7b,13b,70b)" msgstr "🔥 llama-2(7b,13b,70b)" #: ../../getting_started/install/llm/llm.rst:11 -#: f2b1ef2bfbde46889131b6852d1211e8 +#: 0a417ee4d008421da07fff7add5d05eb msgid "WizardLM-v1.2(13b)" msgstr "WizardLM-v1.2(13b)" #: ../../getting_started/install/llm/llm.rst:12 -#: 2462160e5928434aaa6f58074b585c14 +#: 199e1a9fe3324dc8a1bcd9cd0b1ef047 msgid "Vicuna (7b,13b)" msgstr "Vicuna (7b,13b)" #: ../../getting_started/install/llm/llm.rst:13 -#: 61cc1b9ed1914e379dee87034e9fbaea +#: a9e4c5100534450db3a583fa5850e4be msgid "ChatGLM-6b (int4,int8)" msgstr "ChatGLM-6b (int4,int8)" #: ../../getting_started/install/llm/llm.rst:14 -#: 653f5fd3b6ab4c508a2ce449231b5a17 +#: 943324289eb94042b52fd824189cd93f msgid "ChatGLM2-6b (int4,int8)" msgstr "ChatGLM2-6b (int4,int8)" #: ../../getting_started/install/llm/llm.rst:15 -#: 6a1cf7a271f6496ba48d593db1f756f1 +#: f1226fdfac3b4e9d88642ffa69d75682 msgid "guanaco(7b,13b,33b)" msgstr "guanaco(7b,13b,33b)" #: ../../getting_started/install/llm/llm.rst:16 -#: f15083886362437ba96e1cb9ade738aa +#: 3f2457f56eb341b6bc431c9beca8f4df msgid "Gorilla(7b,13b)" msgstr "Gorilla(7b,13b)" #: ../../getting_started/install/llm/llm.rst:17 -#: 5e1ded19d1b24ba485e5d2bf22ce2db7 +#: 86c8ce37be1c4a7ea3fc382100d77a9c msgid "Baichuan(7b,13b)" msgstr "Baichuan(7b,13b)" #: ../../getting_started/install/llm/llm.rst:18 -#: 
30f6f71dd73b48d29a3647d58a9cdcaf +#: 538111af95ad414cb2e631a89f9af379 msgid "OpenAI" msgstr "OpenAI" #: ../../getting_started/install/llm/llm.rst:20 -#: 9ad21acdaccb41c1bbfa30b5ce114732 +#: a203325b7ec248f7bff61ae89226a000 msgid "llama_cpp" msgstr "llama_cpp" #: ../../getting_started/install/llm/llm.rst:21 -#: b5064649f7c9443c9b9f15ec7fc02434 +#: 21a50634198047228bc51a03d2c31292 msgid "quantization" msgstr "quantization" +#: ../../getting_started/install/llm/llm.rst:22 +#: dfaec4b04e6e45ff9c884b41534b1a79 +msgid "cluster deployment" +msgstr "" + diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 00d1d5080..d4c3b3f8c 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -166,6 +166,9 @@ class Config(metaclass=Singleton): self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False") == "True" if self.IS_LOAD_8BIT and self.IS_LOAD_4BIT: self.IS_LOAD_8BIT = False + # In order to be compatible with the new and old model parameter design + os.environ["load_8bit"] = str(self.IS_LOAD_8BIT) + os.environ["load_4bit"] = str(self.IS_LOAD_4BIT) ### EMBEDDING Configuration self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec") diff --git a/pilot/model/cli.py b/pilot/model/cli.py index 6109d11ab..9208522d7 100644 --- a/pilot/model/cli.py +++ b/pilot/model/cli.py @@ -7,25 +7,33 @@ from pilot.model.worker.manager import ( WorkerApplyRequest, WorkerApplyType, ) +from pilot.model.parameter import ( + ModelControllerParameters, + ModelWorkerParameters, + ModelParameters, +) from pilot.utils import get_or_create_event_loop +from pilot.utils.parameter_utils import EnvArgumentParser @click.group("model") -def model_cli_group(): - pass - - -@model_cli_group.command() @click.option( "--address", type=str, default="http://127.0.0.1:8000", required=False, + show_default=True, help=( - "Address of the Model Controller to connect to." + "Address of the Model Controller to connect to. 
" "Just support light deploy model" ), ) +def model_cli_group(): + """Clients that manage model serving""" + pass + + +@model_cli_group.command() @click.option( "--model-name", type=str, default=None, required=False, help=("The name of model") ) @@ -79,16 +87,6 @@ def list(address: str, model_name: str, model_type: str): def add_model_options(func): - @click.option( - "--address", - type=str, - default="http://127.0.0.1:8000", - required=False, - help=( - "Address of the Model Controller to connect to." - "Just support light deploy model" - ), - ) @click.option( "--model-name", type=str, @@ -149,3 +147,57 @@ def worker_apply( ) res = loop.run_until_complete(worker_manager.worker_apply(apply_req)) print(res) + + +@click.command(name="controller") +@EnvArgumentParser.create_click_option(ModelControllerParameters) +def start_model_controller(**kwargs): + """Start model controller""" + from pilot.model.controller.controller import run_model_controller + + run_model_controller() + + +@click.command(name="controller") +def stop_model_controller(**kwargs): + """Start model controller""" + raise NotImplementedError + + +@click.command(name="worker") +@EnvArgumentParser.create_click_option(ModelWorkerParameters, ModelParameters) +def start_model_worker(**kwargs): + """Start model worker""" + from pilot.model.worker.manager import run_worker_manager + + run_worker_manager() + + +@click.command(name="worker") +def stop_model_worker(**kwargs): + """Stop model worker""" + raise NotImplementedError + + +@click.command(name="webserver") +def start_webserver(**kwargs): + """Start webserver(dbgpt_server.py)""" + raise NotImplementedError + + +@click.command(name="webserver") +def stop_webserver(**kwargs): + """Stop webserver(dbgpt_server.py)""" + raise NotImplementedError + + +@click.command(name="apiserver") +def start_apiserver(**kwargs): + """Start apiserver(TODO)""" + raise NotImplementedError + + +@click.command(name="controller") +def stop_apiserver(**kwargs): + """Start 
apiserver(TODO)""" + raise NotImplementedError diff --git a/pilot/model/controller/controller.py b/pilot/model/controller/controller.py index 84d4dfb29..51b61826e 100644 --- a/pilot/model/controller/controller.py +++ b/pilot/model/controller/controller.py @@ -1,9 +1,11 @@ import logging from typing import List -from fastapi import APIRouter +from fastapi import APIRouter, FastAPI from pilot.model.base import ModelInstance +from pilot.model.parameter import ModelControllerParameters from pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry +from pilot.utils.parameter_utils import EnvArgumentParser class ModelController: @@ -63,3 +65,22 @@ async def api_get_all_instances(model_name: str = None, healthy_only: bool = Fal @router.post("/controller/heartbeat") async def api_model_heartbeat(request: ModelInstance): return await controller.send_heartbeat(request) + + +def run_model_controller(): + import uvicorn + + parser = EnvArgumentParser() + env_prefix = "controller_" + controller_params: ModelControllerParameters = parser.parse_args_into_dataclass( + ModelControllerParameters, env_prefix=env_prefix + ) + app = FastAPI() + app.include_router(router, prefix="/api") + uvicorn.run( + app, host=controller_params.host, port=controller_params.port, log_level="info" + ) + + +if __name__ == "__main__": + run_model_controller() diff --git a/pilot/model/loader.py b/pilot/model/loader.py index e4478450a..79b4c1049 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -8,14 +8,12 @@ from pilot.configs.model_config import DEVICE from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType from pilot.model.compression import compress_module from pilot.model.parameter import ( - EnvArgumentParser, ModelParameters, LlamaCppModelParameters, - _genenv_ignoring_key_case, ) from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations -from pilot.singleton import Singleton from pilot.utils import 
get_gpu_memory +from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case from pilot.logs import logger diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py index baa646711..c47af4e21 100644 --- a/pilot/model/parameter.py +++ b/pilot/model/parameter.py @@ -1,139 +1,15 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import argparse -import os -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, field, fields, MISSING from enum import Enum -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, Optional from pilot.model.conversation import conv_templates +from pilot.utils.parameter_utils import BaseParameters suported_prompt_templates = ",".join(conv_templates.keys()) -def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None): - """Get the value from the environment variable, ignoring the case of the key""" - if env_prefix: - env_key = env_prefix + env_key - return os.getenv( - env_key, os.getenv(env_key.upper(), os.getenv(env_key.lower(), default_value)) - ) - - -class EnvArgumentParser: - @staticmethod - def get_env_prefix(env_key: str) -> str: - if not env_key: - return None - env_key = env_key.replace("-", "_") - return env_key + "_" - - def parse_args_into_dataclass( - self, - dataclass_type: Type, - env_prefix: str = None, - command_args: List[str] = None, - **kwargs, - ) -> Any: - """Parse parameters from environment variables and command lines and populate them into data class""" - parser = argparse.ArgumentParser() - for field in fields(dataclass_type): - env_var_value = _genenv_ignoring_key_case(field.name, env_prefix) - if not env_var_value: - # Read without env prefix - env_var_value = _genenv_ignoring_key_case(field.name) - - if env_var_value: - env_var_value = env_var_value.strip() - if field.type is int or field.type == Optional[int]: - env_var_value = int(env_var_value) - elif field.type is float or field.type == 
Optional[float]: - env_var_value = float(env_var_value) - elif field.type is bool or field.type == Optional[bool]: - env_var_value = env_var_value.lower() == "true" - elif field.type is str or field.type == Optional[str]: - pass - else: - raise ValueError(f"Unsupported parameter type {field.type}") - if not env_var_value: - env_var_value = kwargs.get(field.name) - if not env_var_value: - env_var_value = field.default - - # Add a command-line argument for this field - help_text = field.metadata.get("help", "") - valid_values = field.metadata.get("valid_values", None) - parser.add_argument( - f"--{field.name}", - type=self._get_argparse_type(field.type), - help=help_text, - choices=valid_values, - default=env_var_value, - ) - - # Parse the command-line arguments - cmd_args, cmd_argv = parser.parse_known_args(args=command_args) - print(f"cmd_args: {cmd_args}") - for field in fields(dataclass_type): - # cmd_line_value = getattr(cmd_args, field.name) - if field.name in cmd_args: - cmd_line_value = getattr(cmd_args, field.name) - if cmd_line_value is not None: - kwargs[field.name] = cmd_line_value - - return dataclass_type(**kwargs) - - @staticmethod - def _get_argparse_type(field_type: Type) -> Type: - # Return the appropriate type for argparse to use based on the field type - if field_type is int or field_type == Optional[int]: - return int - elif field_type is float or field_type == Optional[float]: - return float - elif field_type is bool or field_type == Optional[bool]: - return bool - elif field_type is str or field_type == Optional[str]: - return str - else: - raise ValueError(f"Unsupported parameter type {field_type}") - - @staticmethod - def _get_argparse_type_str(field_type: Type) -> str: - argparse_type = EnvArgumentParser._get_argparse_type(field_type) - if argparse_type is int: - return "int" - elif argparse_type is float: - return "float" - elif argparse_type is bool: - return "bool" - else: - return "str" - - -@dataclass -class ParameterDescription: - 
param_name: str - param_type: str - description: str - default_value: Optional[Any] - valid_values: Optional[List[Any]] - - -def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]: - descriptions = [] - for field in fields(dataclass_type): - descriptions.append( - ParameterDescription( - param_name=field.name, - param_type=EnvArgumentParser._get_argparse_type_str(field.type), - description=field.metadata.get("help", None), - default_value=field.default, # TODO handle dataclasses._MISSING_TYPE - valid_values=field.metadata.get("valid_values", None), - ) - ) - return descriptions - - class WorkerType(str, Enum): LLM = "llm" TEXT2VEC = "text2vec" @@ -144,58 +20,13 @@ class WorkerType(str, Enum): @dataclass -class BaseParameters: - def update_from(self, source: Union["BaseParameters", dict]) -> bool: - """ - Update the attributes of this object using the values from another object (of the same or parent type) or a dictionary. - Only update if the new value is different from the current value and the field is not marked as "fixed" in metadata. - - Args: - source (Union[BaseParameters, dict]): The source to update from. Can be another object of the same type or a dictionary. - - Returns: - bool: True if at least one field was updated, otherwise False. 
- """ - updated = False # Flag to indicate whether any field was updated - if isinstance(source, (BaseParameters, dict)): - for field_info in fields(self): - # Check if the field has a "fixed" tag in metadata - tags = field_info.metadata.get("tags") - tags = [] if not tags else tags.split(",") - if tags and "fixed" in tags: - continue # skip this field - # Get the new value from source (either another BaseParameters object or a dict) - new_value = ( - getattr(source, field_info.name) - if isinstance(source, BaseParameters) - else source.get(field_info.name, None) - ) - - # If the new value is not None and different from the current value, update the field and set the flag - if new_value is not None and new_value != getattr( - self, field_info.name - ): - setattr(self, field_info.name, new_value) - updated = True - else: - raise ValueError( - "Source must be an instance of BaseParameters (or its derived class) or a dictionary." - ) - - return updated - - def __str__(self) -> str: - class_name = self.__class__.__name__ - parameters = [ - f"\n\n=========================== {class_name} ===========================\n" - ] - for field_info in fields(self): - value = getattr(self, field_info.name) - parameters.append(f"{field_info.name}: {value}") - parameters.append( - "\n======================================================================\n\n" - ) - return "\n".join(parameters) +class ModelControllerParameters(BaseParameters): + host: Optional[str] = field( + default="0.0.0.0", metadata={"help": "Model Controller deploy host"} + ) + port: Optional[int] = field( + default=8000, metadata={"help": "Model Controller deploy port"} + ) @dataclass @@ -209,7 +40,7 @@ class ModelWorkerParameters(BaseParameters): worker_class: Optional[str] = field( default=None, metadata={ - "help": "Model worker deploy host, pilot.model.worker.default_worker.DefaultModelWorker" + "help": "Model worker class, pilot.model.worker.default_worker.DefaultModelWorker" }, ) host: Optional[str] = 
field( diff --git a/pilot/model/worker/base.py b/pilot/model/worker/base.py index dfb0186fb..5809e2bc7 100644 --- a/pilot/model/worker/base.py +++ b/pilot/model/worker/base.py @@ -2,10 +2,9 @@ from abc import ABC, abstractmethod from typing import Dict, Iterator, List, Type from pilot.model.base import ModelOutput -from pilot.model.parameter import ( - ModelParameters, +from pilot.model.parameter import ModelParameters, WorkerType +from pilot.utils.parameter_utils import ( ParameterDescription, - WorkerType, _get_parameter_descriptions, ) diff --git a/pilot/model/worker/default_worker.py b/pilot/model/worker/default_worker.py index deea90191..034b3cffa 100644 --- a/pilot/model/worker/default_worker.py +++ b/pilot/model/worker/default_worker.py @@ -7,10 +7,11 @@ from pilot.configs.model_config import DEVICE from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper from pilot.model.base import ModelOutput from pilot.model.loader import ModelLoader, _get_model_real_path -from pilot.model.parameter import EnvArgumentParser, ModelParameters +from pilot.model.parameter import ModelParameters from pilot.model.worker.base import ModelWorker from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter from pilot.utils.model_utils import _clear_torch_cache +from pilot.utils.parameter_utils import EnvArgumentParser logger = logging.getLogger("model_worker") diff --git a/pilot/model/worker/embedding_worker.py b/pilot/model/worker/embedding_worker.py index 0f011dc6b..75bda74bc 100644 --- a/pilot/model/worker/embedding_worker.py +++ b/pilot/model/worker/embedding_worker.py @@ -5,11 +5,11 @@ from pilot.configs.model_config import DEVICE from pilot.model.loader import _get_model_real_path from pilot.model.parameter import ( EmbeddingModelParameters, - EnvArgumentParser, WorkerType, ) from pilot.model.worker.base import ModelWorker from pilot.utils.model_utils import _clear_torch_cache +from pilot.utils.parameter_utils import EnvArgumentParser logger = 
logging.getLogger("model_worker") diff --git a/pilot/model/worker/manager.py b/pilot/model/worker/manager.py index 3d18088e9..89063de8b 100644 --- a/pilot/model/worker/manager.py +++ b/pilot/model/worker/manager.py @@ -23,15 +23,14 @@ from pilot.model.base import ( ) from pilot.model.controller.registry import ModelRegistry from pilot.model.parameter import ( - EnvArgumentParser, ModelParameters, ModelWorkerParameters, WorkerType, - ParameterDescription, ) from pilot.model.worker.base import ModelWorker from pilot.scene.base_message import ModelMessage from pilot.utils import build_logger +from pilot.utils.parameter_utils import EnvArgumentParser, ParameterDescription from pydantic import BaseModel logger = build_logger("model_worker", LOGDIR + "/model_worker.log") @@ -148,6 +147,8 @@ class LocalWorkerManager(WorkerManager): self.model_registry = model_registry def _worker_key(self, worker_type: str, model_name: str) -> str: + if isinstance(worker_type, WorkerType): + worker_type = worker_type.value return f"{model_name}@{worker_type}" def add_worker( @@ -166,6 +167,9 @@ class LocalWorkerManager(WorkerManager): if not worker_params.worker_type: worker_params.worker_type = worker.worker_type() + if isinstance(worker_params.worker_type, WorkerType): + worker_params.worker_type = worker_params.worker_type.value + worker_key = self._worker_key( worker_params.worker_type, worker_params.model_name ) diff --git a/pilot/scripts/cli_scripts.py b/pilot/scripts/cli_scripts.py index 537b0ed25..ba3dd8b4b 100644 --- a/pilot/scripts/cli_scripts.py +++ b/pilot/scripts/cli_scripts.py @@ -5,7 +5,7 @@ import copy import logging sys.path.append( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -23,16 +23,53 @@ def cli(log_level: str): logging.basicConfig(level=log_level, encoding="utf-8") -def add_command_alias(command, name: str, hidden: bool = False): +def 
def add_command_alias(command, name: str, hidden: bool = False, parent_group=None):
    """Register *command* under the alias *name*.

    Args:
        command: The click command (or group) to alias.
        name: Alias to register the command under.
        hidden: Whether the alias is hidden from help output.
        parent_group: Group to attach the alias to; defaults to the top-level
            ``cli`` group.
    """
    if not parent_group:
        parent_group = cli
    # Deep-copy so the alias can carry its own "hidden" flag without mutating
    # the original command object registered elsewhere.
    new_command = copy.deepcopy(command)
    new_command.hidden = hidden
    parent_group.add_command(new_command, name=name)


@click.group()
def start():
    """Start specific server."""
    pass


@click.group()
def stop():
    # FIX: docstring previously said "Start specific server." — this text is
    # what `dbgpt --help` displays for the stop command group.
    """Stop specific server."""
    pass


cli.add_command(start)
cli.add_command(stop)

try:
    from pilot.model.cli import (
        model_cli_group,
        start_model_controller,
        stop_model_controller,
        start_model_worker,
        stop_model_worker,
        start_webserver,
        stop_webserver,
        start_apiserver,
        stop_apiserver,
    )

    add_command_alias(model_cli_group, name="model", parent_group=cli)
    add_command_alias(start_model_controller, name="controller", parent_group=start)
    add_command_alias(start_model_worker, name="worker", parent_group=start)
    add_command_alias(start_webserver, name="webserver", parent_group=start)
    add_command_alias(start_apiserver, name="apiserver", parent_group=start)

    add_command_alias(stop_model_controller, name="controller", parent_group=stop)
    add_command_alias(stop_model_worker, name="worker", parent_group=stop)
    add_command_alias(stop_webserver, name="webserver", parent_group=stop)
    add_command_alias(stop_apiserver, name="apiserver", parent_group=stop)

except ImportError as e:
    logging.warning(f"Integrating dbgpt model command line tool failed: {e}")


# --- pilot/utils/parameter_utils.py (new file) ---

import argparse
import os
from dataclasses import dataclass, fields, MISSING
from typing import Any, List, Optional, Type, Union, Callable


@dataclass
class ParameterDescription:
    param_name: str
    param_type: str
# --- pilot/utils/parameter_utils.py (reconstructed, continued) ---


@dataclass
class ParameterDescription:
    """Serializable description of a single dataclass parameter.

    Used to advertise a component's configurable parameters to callers.
    """

    param_name: str
    param_type: str
    description: str
    default_value: Optional[Any]
    valid_values: Optional[List[Any]]


@dataclass
class BaseParameters:
    def update_from(self, source: Union["BaseParameters", dict]) -> bool:
        """Update this object's fields from *source*.

        Fields whose metadata carries a "fixed" tag are never updated; other
        fields are updated only when the new value is not None and differs
        from the current value.

        Args:
            source: Another BaseParameters instance (same or parent type) or
                a dictionary of field values.

        Returns:
            True if at least one field was updated, otherwise False.

        Raises:
            ValueError: If source is neither a BaseParameters nor a dict.
        """
        updated = False  # Flag to indicate whether any field was updated
        if isinstance(source, (BaseParameters, dict)):
            for field_info in fields(self):
                # "tags" metadata is a comma-separated list; "fixed" freezes
                # the field against updates.
                tags = field_info.metadata.get("tags")
                tags = [] if not tags else tags.split(",")
                if tags and "fixed" in tags:
                    continue  # skip this field
                # Pull the candidate value from the source object or dict.
                new_value = (
                    getattr(source, field_info.name)
                    if isinstance(source, BaseParameters)
                    else source.get(field_info.name, None)
                )

                # None means "not provided" and never clears an existing value.
                if new_value is not None and new_value != getattr(
                    self, field_info.name
                ):
                    setattr(self, field_info.name, new_value)
                    updated = True
        else:
            raise ValueError(
                "Source must be an instance of BaseParameters (or its derived class) or a dictionary."
            )

        return updated

    def __str__(self) -> str:
        """Render all fields as a banner-framed, human-readable listing."""
        class_name = self.__class__.__name__
        parameters = [
            f"\n\n=========================== {class_name} ===========================\n"
        ]
        for field_info in fields(self):
            value = getattr(self, field_info.name)
            parameters.append(f"{field_info.name}: {value}")
        parameters.append(
            "\n======================================================================\n\n"
        )
        return "\n".join(parameters)


def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None):
    """Get the value from the environment variable, ignoring the case of the key."""
    if env_prefix:
        env_key = env_prefix + env_key
    return os.getenv(
        env_key, os.getenv(env_key.upper(), os.getenv(env_key.lower(), default_value))
    )


class EnvArgumentParser:
    """Populate dataclass-based parameter objects from environment variables
    and command-line arguments.

    Precedence: command-line value > environment value > ``**kwargs`` default.
    """

    @staticmethod
    def get_env_prefix(env_key: str) -> str:
        """Derive an env-var prefix from a key, e.g. "my-model" -> "my_model_"."""
        if not env_key:
            return None
        env_key = env_key.replace("-", "_")
        return env_key + "_"

    def parse_args_into_dataclass(
        self,
        dataclass_type: Type,
        env_prefix: str = None,
        command_args: List[str] = None,
        **kwargs,
    ) -> Any:
        """Parse parameters from environment variables and command lines and populate them into data class"""
        parser = argparse.ArgumentParser()
        for field in fields(dataclass_type):
            env_var_value = _genenv_ignoring_key_case(field.name, env_prefix)
            if not env_var_value:
                # Fall back to the un-prefixed environment variable.
                env_var_value = _genenv_ignoring_key_case(field.name)

            if env_var_value:
                env_var_value = env_var_value.strip()
                if field.type is int or field.type == Optional[int]:
                    env_var_value = int(env_var_value)
                elif field.type is float or field.type == Optional[float]:
                    env_var_value = float(env_var_value)
                elif field.type is bool or field.type == Optional[bool]:
                    env_var_value = env_var_value.lower() == "true"
                elif field.type is str or field.type == Optional[str]:
                    pass
                else:
                    raise ValueError(f"Unsupported parameter type {field.type}")
            if not env_var_value:
                # NOTE(review): falsy parsed values ("0", "false", empty
                # string) are treated as unset here and replaced by the kwargs
                # default — confirm this is intended.
                env_var_value = kwargs.get(field.name)

            # Add a command-line argument for this field.
            help_text = field.metadata.get("help", "")
            valid_values = field.metadata.get("valid_values", None)

            argument_kwargs = {
                "type": EnvArgumentParser._get_argparse_type(field.type),
                "help": help_text,
                "choices": valid_values,
                "required": EnvArgumentParser._is_require_type(field.type),
            }
            if field.default != MISSING:
                argument_kwargs["default"] = field.default
                argument_kwargs["required"] = False
            if env_var_value:
                # An env-supplied value becomes the default, so the CLI flag
                # is optional but can still override it.
                argument_kwargs["default"] = env_var_value
                argument_kwargs["required"] = False

            parser.add_argument(f"--{field.name}", **argument_kwargs)

        # parse_known_args so unrelated CLI flags don't abort the process.
        cmd_args, cmd_argv = parser.parse_known_args(args=command_args)
        for field in fields(dataclass_type):
            if field.name in cmd_args:
                cmd_line_value = getattr(cmd_args, field.name)
                if cmd_line_value is not None:
                    kwargs[field.name] = cmd_line_value

        return dataclass_type(**kwargs)

    @staticmethod
    def create_arg_parser(dataclass_type: Type) -> argparse.ArgumentParser:
        """Build an argparse parser mirroring *dataclass_type*'s fields."""
        parser = argparse.ArgumentParser(description=dataclass_type.__doc__)
        for field in fields(dataclass_type):
            help_text = field.metadata.get("help", "")
            valid_values = field.metadata.get("valid_values", None)
            argument_kwargs = {
                "type": EnvArgumentParser._get_argparse_type(field.type),
                "help": help_text,
                "choices": valid_values,
                "required": EnvArgumentParser._is_require_type(field.type),
            }
            if field.default != MISSING:
                argument_kwargs["default"] = field.default
                argument_kwargs["required"] = False
            parser.add_argument(f"--{field.name}", **argument_kwargs)
        return parser

    @staticmethod
    def create_click_option(
        *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
    ):
        """Build a decorator adding one click option per dataclass field.

        Fields appearing in several dataclasses are added once (first wins).
        `_dynamic_factory`, when given, supplies the dataclass list at call
        time and replaces *dataclass_types* entirely.
        """
        import click
        import functools
        from collections import OrderedDict

        # TODO dynamic configuration
        combined_fields = OrderedDict()
        if _dynamic_factory:
            _types = _dynamic_factory()
            if _types:
                dataclass_types = list(_types)
        for dataclass_type in dataclass_types:
            for field in fields(dataclass_type):
                if field.name not in combined_fields:
                    combined_fields[field.name] = field

        def decorator(func):
            # Reverse so that options appear in declaration order in --help
            # (decorators apply bottom-up).
            for field_name, field in reversed(combined_fields.items()):
                help_text = field.metadata.get("help", "")
                valid_values = field.metadata.get("valid_values", None)
                cli_params = {
                    "default": None if field.default is MISSING else field.default,
                    "help": help_text,
                    "show_default": True,
                    "required": field.default is MISSING,
                }
                if valid_values:
                    cli_params["type"] = click.Choice(valid_values)
                real_type = EnvArgumentParser._get_argparse_type(field.type)
                if real_type is int:
                    cli_params["type"] = click.INT
                elif real_type is float:
                    cli_params["type"] = click.FLOAT
                elif real_type is str:
                    cli_params["type"] = click.STRING
                elif real_type is bool:
                    cli_params["is_flag"] = True

                option_decorator = click.option(
                    f"--{field_name}",
                    **cli_params,
                )
                func = option_decorator(func)

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

        return decorator

    @staticmethod
    def _get_argparse_type(field_type: Type) -> Type:
        # Return the appropriate type for argparse to use based on the field
        # type. NOTE(review): argparse's type=bool treats any non-empty string
        # as True ("--flag False" yields True) — callers should pass bools via
        # env vars or flags, not positional string values.
        if field_type is int or field_type == Optional[int]:
            return int
        elif field_type is float or field_type == Optional[float]:
            return float
        elif field_type is bool or field_type == Optional[bool]:
            return bool
        elif field_type is str or field_type == Optional[str]:
            return str
        else:
            raise ValueError(f"Unsupported parameter type {field_type}")

    @staticmethod
    def _get_argparse_type_str(field_type: Type) -> str:
        """Name of the argparse type for *field_type* ("int"/"float"/"bool"/"str")."""
        argparse_type = EnvArgumentParser._get_argparse_type(field_type)
        if argparse_type is int:
            return "int"
        elif argparse_type is float:
            return "float"
        elif argparse_type is bool:
            return "bool"
        else:
            return "str"

    @staticmethod
    def _is_require_type(field_type: Type) -> bool:
        """Whether argparse should mark an option of this type as required.

        FIX: return annotation corrected (bool, not str) and Optional[str]
        added to the non-required list — previously Optional[str] fields were
        wrongly treated as required.
        """
        return field_type not in [
            Optional[int],
            Optional[float],
            Optional[bool],
            Optional[str],
        ]


def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
    """Describe every field of *dataclass_type* as a ParameterDescription."""
    descriptions = []
    for field in fields(dataclass_type):
        descriptions.append(
            ParameterDescription(
                param_name=field.name,
                param_type=EnvArgumentParser._get_argparse_type_str(field.type),
                description=field.metadata.get("help", None),
                default_value=field.default,  # TODO handle dataclasses._MISSING_TYPE
                valid_values=field.metadata.get("valid_values", None),
            )
        )
    return descriptions


class _SimpleArgParser:
    """Minimal --key value pre-parser for a fixed set of parameter names.

    A flag followed by another flag (or by nothing) is recorded as None.
    Keys are stored in dashed form; attribute/index access accepts either
    underscores or dashes.
    """

    def __init__(self, *args):
        self.params = {arg.replace("_", "-"): None for arg in args}

    def parse(self, args=None):
        import sys

        if args is None:
            args = sys.argv[1:]
        else:
            args = list(args)
        prev_arg = None
        for arg in args:
            if arg.startswith("--"):
                if prev_arg:
                    self.params[prev_arg] = None
                # FIX: normalize to the dashed storage key so "--model_name"
                # and "--model-name" both update the same entry; previously
                # underscore-style flags created a separate, never-read key.
                prev_arg = arg[2:].replace("_", "-")
            else:
                if prev_arg:
                    self.params[prev_arg] = arg
                prev_arg = None

        if prev_arg:
            self.params[prev_arg] = None

    def _get_param(self, key):
        return self.params.get(key.replace("_", "-"), None)

    def __getattr__(self, item):
        return self._get_param(item)

    def __getitem__(self, key):
        return self._get_param(key)

    def get(self, key, default=None):
        return self._get_param(key) or default

    def __str__(self):
        return "\n".join(
            [f'{key.replace("-", "_")}: {value}' for key, value in self.params.items()]
        )