feat(model): multi-model support for embedding models, plus a simple component design implementation

FangYin Cheng
2023-09-13 12:14:03 +08:00
parent 68d30dd4bb
commit 581cf361bf
47 changed files with 1050 additions and 211 deletions

View File

@@ -4,9 +4,10 @@ services:
controller:
image: eosphorosai/dbgpt:latest
command: dbgpt start controller
restart: unless-stopped
networks:
- dbgptnet
worker:
llm-worker:
image: eosphorosai/dbgpt:latest
command: dbgpt start worker --model_name vicuna-13b-v1.5 --model_path /app/models/vicuna-13b-v1.5 --port 8001 --controller_addr http://controller:8000
environment:
@@ -17,6 +18,27 @@ services:
- /data:/data
# Please modify it to your own model directory
- /data/models:/app/models
restart: unless-stopped
networks:
- dbgptnet
deploy:
resources:
reservations:
devices:
- driver: nvidia
capabilities: [gpu]
embedding-worker:
image: eosphorosai/dbgpt:latest
command: dbgpt start worker --model_name text2vec --worker_type text2vec --model_path /app/models/text2vec-large-chinese --port 8002 --controller_addr http://controller:8000
environment:
- DBGPT_LOG_LEVEL=DEBUG
depends_on:
- controller
volumes:
- /data:/data
# Please modify it to your own model directory
- /data/models:/app/models
restart: unless-stopped
networks:
- dbgptnet
deploy:
@@ -37,7 +59,8 @@ services:
- MODEL_SERVER=http://controller:8000
depends_on:
- controller
- worker
- llm-worker
- embedding-worker
volumes:
- /data:/data
# Please modify it to your own model directory

Binary file not shown (new image, 361 KiB)

View File

@@ -9,6 +9,7 @@ DB-GPT product is a Web application that you can chat database, chat knowledge,
- docker
- docker_compose
- environment
- cluster deployment
- deploy_faq
.. toctree::
@@ -20,6 +21,7 @@ DB-GPT product is a Web application that you can chat database, chat knowledge,
./install/deploy/deploy.md
./install/docker/docker.md
./install/docker_compose/docker_compose.md
./install/cluster/cluster.rst
./install/llm/llm.rst
./install/environment/environment.md
./install/faq/deploy_faq.md

View File

@@ -0,0 +1,19 @@
Cluster deployment
==================================
To deploy DB-GPT across multiple nodes, you can deploy a cluster. The cluster architecture diagram is as follows:
.. raw:: html
<img src="../../../_static/img/muti-model-cluster-overview.png" />
* :ref:`Deploying on a local machine <local-cluster-index>`: local cluster deployment.
.. toctree::
:maxdepth: 2
:caption: Cluster deployment
:name: cluster_deploy
:hidden:
./vms/index.md

View File

@@ -0,0 +1,3 @@
Kubernetes cluster deployment
==================================
(kubernetes-cluster-index)=

View File

@@ -1,6 +1,6 @@
Cluster deployment
Local cluster deployment
==================================
(local-cluster-index)=
## Model cluster deployment
@@ -17,7 +17,7 @@ dbgpt start controller
By default, the Model Controller starts on port 8000.
### Launch Model Worker
### Launch LLM Model Worker
If you are starting `chatglm2-6b`:
@@ -39,6 +39,18 @@ dbgpt start worker --model_name vicuna-13b-v1.5 \
Note: Be sure to use your own model name and model path.
### Launch Embedding Model Worker
```bash
dbgpt start worker --model_name text2vec \
--model_path /app/models/text2vec-large-chinese \
--worker_type text2vec \
--port 8003 \
--controller_addr http://127.0.0.1:8000
```
Note: Be sure to use your own model name and model path.
Check your model:
@@ -51,8 +63,12 @@ You will see the following output:
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
| Model Name | Model Type | Host | Port | Healthy | Enabled | Prompt Template | Last Heartbeat |
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
| chatglm2-6b | llm | 172.17.0.6 | 8001 | True | True | None | 2023-08-31T04:48:45.252939 |
| vicuna-13b-v1.5 | llm | 172.17.0.6 | 8002 | True | True | None | 2023-08-31T04:48:55.136676 |
| chatglm2-6b | llm | 172.17.0.2 | 8001 | True | True | | 2023-09-12T23:04:31.287654 |
| WorkerManager | service | 172.17.0.2 | 8001 | True | True | | 2023-09-12T23:04:31.286668 |
| WorkerManager | service | 172.17.0.2 | 8003 | True | True | | 2023-09-12T23:04:29.845617 |
| WorkerManager | service | 172.17.0.2 | 8002 | True | True | | 2023-09-12T23:04:24.598439 |
| text2vec | text2vec | 172.17.0.2 | 8003 | True | True | | 2023-09-12T23:04:29.844796 |
| vicuna-13b-v1.5 | llm | 172.17.0.2 | 8002 | True | True | | 2023-09-12T23:04:24.597775 |
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
```
@@ -69,7 +85,7 @@ MODEL_SERVER=http://127.0.0.1:8000
#### Start the webserver
```bash
python pilot/server/dbgpt_server.py --light
dbgpt start webserver --light
```
`--light` indicates not to start the embedded model service.
@@ -77,7 +93,7 @@ python pilot/server/dbgpt_server.py --light
Alternatively, you can prepend the command with `LLM_MODEL=chatglm2-6b` to start:
```bash
LLM_MODEL=chatglm2-6b python pilot/server/dbgpt_server.py --light
LLM_MODEL=chatglm2-6b dbgpt start webserver --light
```
@@ -101,9 +117,11 @@ Options:
--help Show this message and exit.
Commands:
model Clients that manage model serving
start Start specific server.
stop Start specific server.
install Install dependencies, plugins, etc.
knowledge Knowledge command line tool
model Clients that manage model serving
start Start specific server.
stop Stop specific server.
```
**View the `dbgpt start` help**
@@ -146,10 +164,11 @@ Options:
--model_name TEXT Model name [required]
--model_path TEXT Model path [required]
--worker_type TEXT Worker type
--worker_class TEXT Model worker class, pilot.model.worker.defau
lt_worker.DefaultModelWorker
--worker_class TEXT Model worker class,
pilot.model.cluster.DefaultModelWorker
--host TEXT Model worker deploy host [default: 0.0.0.0]
--port INTEGER Model worker deploy port [default: 8000]
--port INTEGER Model worker deploy port [default: 8001]
--daemon Run Model Worker in background
--limit_model_concurrency INTEGER
Model concurrency limit [default: 5]
--standalone Standalone mode. If True, embedded Run
@@ -166,7 +185,7 @@ Options:
(seconds) [default: 20]
--device TEXT Device to run model. If None, the device is
automatically determined
--model_type TEXT Model type, huggingface or llama.cpp
--model_type TEXT Model type, huggingface, llama.cpp and proxy
[default: huggingface]
--prompt_template TEXT Prompt template. If None, the prompt
template is automatically determined from
@@ -208,10 +227,13 @@ Usage: dbgpt model [OPTIONS] COMMAND [ARGS]...
Options:
--address TEXT Address of the Model Controller to connect to. Just support
light deploy model [default: http://127.0.0.1:8000]
light deploy model. If the environment variable
CONTROLLER_ADDRESS is configured, it is read from the
environment variable
--help Show this message and exit.
Commands:
chat Interact with your bot from the command line
list List model instances
restart Restart model instances
start Start model instances

View File

@@ -6,6 +6,7 @@ DB-GPT provides a management and deployment solution for multiple models. This c
Multi LLMs Support, Supports multiple large language models, currently supporting
- 🔥 Baichuan2(7b,13b)
- 🔥 Vicuna-v1.5(7b,13b)
- 🔥 llama-2(7b,13b,70b)
- WizardLM-v1.2(13b)
@@ -19,7 +20,6 @@ Multi LLMs Support, Supports multiple large language models, currently supportin
- llama_cpp
- quantization
- cluster deployment
.. toctree::
:maxdepth: 2
@@ -29,4 +29,3 @@ Multi LLMs Support, Supports multiple large language models, currently supportin
./llama/llama_cpp.md
./quantization/quantization.md
./cluster/model_cluster.md

View File

@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.5\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-08-16 18:31+0800\n"
"POT-Creation-Date: 2023-09-13 09:06+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
@@ -19,34 +19,38 @@ msgstr ""
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"
#: ../../getting_started/install.rst:2 ../../getting_started/install.rst:14
#: 2861085e63144eaca1bb825e5f05d089
#: ../../getting_started/install.rst:2 ../../getting_started/install.rst:15
#: e2c13385046b4da6b6838db6ba2ea59c
msgid "Install"
msgstr "Install"
#: ../../getting_started/install.rst:3 01a6603d91fa4520b0f839379d4eda23
#: ../../getting_started/install.rst:3 3cb6cd251ed440dabe5d4f556435f405
msgid ""
"DB-GPT product is a Web application that you can chat database, chat "
"knowledge, text2dashboard."
msgstr "DB-GPT 是一个 Web 应用,您可以用它进行数据库问答、知识库问答和智能报表生成。"
#: ../../getting_started/install.rst:8 beca85cddc9b4406aecf83d5dfcce1f7
#: ../../getting_started/install.rst:8 6fe8104b70d24f5fbfe2ad9ebf3bc3ba
msgid "deploy"
msgstr "部署"
#: ../../getting_started/install.rst:9 601e9b9eb91f445fb07d2f1c807f0370
#: ../../getting_started/install.rst:9 e67974b3672346809febf99a3b9a55d3
msgid "docker"
msgstr "docker"
#: ../../getting_started/install.rst:10 6d1e094ac9284458a32a3e7fa6241c81
#: ../../getting_started/install.rst:10 64de16a047c74598966e19a656bf6c4f
msgid "docker_compose"
msgstr "docker_compose"
#: ../../getting_started/install.rst:11 ff1d1c60bbdc4e8ca82b7a9f303dd167
#: ../../getting_started/install.rst:11 9f87d65e8675435b87cb9376a5bfd85c
msgid "environment"
msgstr "environment"
#: ../../getting_started/install.rst:12 33bfbe8defd74244bfc24e8fbfd640f6
#: ../../getting_started/install.rst:12 e60fa13bb24544ed9d4f902337093ebc
msgid "cluster deployment"
msgstr "集群部署"
#: ../../getting_started/install.rst:13 7451712679c2412e858e7d3e2af6b174
msgid "deploy_faq"
msgstr "deploy_faq"

View File

@@ -0,0 +1,42 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) 2023, csunny
# This file is distributed under the same license as the DB-GPT package.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.6\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-09-13 10:11+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
"Language-Team: zh_CN <LL@li.org>\n"
"Plural-Forms: nplurals=1; plural=0;\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"
#: ../../getting_started/install/cluster/cluster.rst:2
#: ../../getting_started/install/cluster/cluster.rst:13
#: 69804208b580447798d6946150da7bdf
msgid "Cluster deployment"
msgstr "集群部署"
#: ../../getting_started/install/cluster/cluster.rst:4
#: fa3e4e0ae60a45eb836bcd256baa9d91
msgid ""
"In order to deploy DB-GPT to multiple nodes, you can deploy a cluster. "
"The cluster architecture diagram is as follows:"
msgstr "为了能将 DB-GPT 部署到多个节点上,你可以部署一个集群,集群的架构图如下:"
#: ../../getting_started/install/cluster/cluster.rst:11
#: e739449099ca43cabe9883233ca7e572
#, fuzzy
msgid ""
"On :ref:`Deploying on local machine <local-cluster-index>`. Local cluster"
" deployment."
msgstr "关于 :ref:`在本地机器上部署 <local-cluster-index>`。本地集群部署。"

View File

@@ -0,0 +1,26 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) 2023, csunny
# This file is distributed under the same license as the DB-GPT package.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.6\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-09-13 09:06+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
"Language-Team: zh_CN <LL@li.org>\n"
"Plural-Forms: nplurals=1; plural=0;\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"
#: ../../getting_started/install/cluster/kubernetes/index.md:1
#: 48e6f08f27c74f31a8b12758fe33dc24
msgid "Kubernetes cluster deployment"
msgstr "Kubernetes 集群部署"

View File

@@ -0,0 +1,176 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) 2023, csunny
# This file is distributed under the same license as the DB-GPT package.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.6\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-09-13 09:06+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
"Language-Team: zh_CN <LL@li.org>\n"
"Plural-Forms: nplurals=1; plural=0;\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"
#: ../../getting_started/install/cluster/vms/index.md:1
#: 2d2e04ba49364eae9b8493bb274765a6
msgid "Local cluster deployment"
msgstr "本地集群部署"
#: ../../getting_started/install/cluster/vms/index.md:4
#: e405d0e7ad8c4b2da4b4ca27c77f5fea
msgid "Model cluster deployment"
msgstr "模型集群部署"
#: ../../getting_started/install/cluster/vms/index.md:7
#: bba397ddac754a2bab8edca163875b65
msgid "**Installing Command-Line Tool**"
msgstr "**安装命令行工具**"
#: ../../getting_started/install/cluster/vms/index.md:9
#: bc45851124354522af8c9bb9748ff1fa
msgid ""
"All operations below are performed using the `dbgpt` command. To use the "
"`dbgpt` command, you need to install the DB-GPT project with `pip install"
" -e .`. Alternatively, you can use `python pilot/scripts/cli_scripts.py` "
"as a substitute for the `dbgpt` command."
msgstr ""
"以下所有操作都使用 `dbgpt` 命令完成。要使用 `dbgpt` 命令,您需要安装 DB-GPT 项目,方法是使用 `pip install -e .`。或者,您可以使用 `python pilot/scripts/cli_scripts.py` 作为 `dbgpt` 命令的替代。"
#: ../../getting_started/install/cluster/vms/index.md:11
#: 9d11f7807fd140c8949b634700adc966
msgid "Launch Model Controller"
msgstr "启动 Model Controller"
#: ../../getting_started/install/cluster/vms/index.md:17
#: 97716be92ba64ce9a215433bddf77add
msgid "By default, the Model Controller starts on port 8000."
msgstr "默认情况下,Model Controller 启动在 8000 端口。"
#: ../../getting_started/install/cluster/vms/index.md:20
#: 3f65e6a1e59248a59c033891d1ab7ba8
msgid "Launch LLM Model Worker"
msgstr "启动 LLM Model Worker"
#: ../../getting_started/install/cluster/vms/index.md:22
#: 60241d97573e4265b7fb150c378c4a08
msgid "If you are starting `chatglm2-6b`:"
msgstr "如果您启动的是 `chatglm2-6b`:"
#: ../../getting_started/install/cluster/vms/index.md:31
#: 18bbeb1de110438fa96dd5c736b9a7b1
msgid "If you are starting `vicuna-13b-v1.5`:"
msgstr "如果您启动的是 `vicuna-13b-v1.5`:"
#: ../../getting_started/install/cluster/vms/index.md:40
#: ../../getting_started/install/cluster/vms/index.md:53
#: 24b1a27313c64224aaeab6cbfad1fe19 fc94a698a7904c6893eef7e7a6e52972
msgid "Note: Be sure to use your own model name and model path."
msgstr "注意:确保使用您自己的模型名称和模型路径。"
#: ../../getting_started/install/cluster/vms/index.md:42
#: 19746195e85f4784bf66a9e67378c04b
msgid "Launch Embedding Model Worker"
msgstr "启动 Embedding Model Worker"
#: ../../getting_started/install/cluster/vms/index.md:55
#: e93ce68091f64d0294b3f912a66cc18b
msgid "Check your model:"
msgstr "检查您的模型:"
#: ../../getting_started/install/cluster/vms/index.md:61
#: fa0b8f3a18fe4bab88fbf002bf26d32e
msgid "You will see the following output:"
msgstr "您将看到以下输出:"
#: ../../getting_started/install/cluster/vms/index.md:75
#: 695262fb4f224101902bc7865ac7871f
msgid "Connect to the model service in the webserver (dbgpt_server)"
msgstr "在 webserver (dbgpt_server) 中连接到模型服务"
#: ../../getting_started/install/cluster/vms/index.md:77
#: 73bf4c2ae5c64d938e3b7e77c06fa21e
msgid ""
"**First, modify the `.env` file to change the model name and the Model "
"Controller connection address.**"
msgstr ""
"**首先,修改 `.env` 文件以更改模型名称和模型控制器连接地址。**"
#: ../../getting_started/install/cluster/vms/index.md:85
#: 8ab126fd72ed4368a79b821ba50e62c8
msgid "Start the webserver"
msgstr "启动 webserver"
#: ../../getting_started/install/cluster/vms/index.md:91
#: 5a7e25c84ca2412bb64310bfad9e2403
msgid "`--light` indicates not to start the embedded model service."
msgstr "`--light` 表示不启动嵌入式模型服务。"
#: ../../getting_started/install/cluster/vms/index.md:93
#: 8cd9ec4fa9cb4c0fa8ff05c05a85ea7f
msgid ""
"Alternatively, you can prepend the command with `LLM_MODEL=chatglm2-6b` "
"to start:"
msgstr ""
"或者,您可以在命令前加上 `LLM_MODEL=chatglm2-6b` 来启动:"
#: ../../getting_started/install/cluster/vms/index.md:100
#: 13ed16758a104860b5fc982d36638b17
msgid "More Command-Line Usages"
msgstr "更多命令行用法"
#: ../../getting_started/install/cluster/vms/index.md:102
#: 175f614d547a4391bab9a77762f9174e
msgid "You can view more command-line usages through the help command."
msgstr "您可以通过帮助命令查看更多命令行用法。"
#: ../../getting_started/install/cluster/vms/index.md:104
#: 6a4475d271c347fbbb35f2936a86823f
msgid "**View the `dbgpt` help**"
msgstr "**查看 `dbgpt` 帮助**"
#: ../../getting_started/install/cluster/vms/index.md:109
#: 3eb11234cf504cc9ac369d8462daa14b
msgid "You will see the basic command parameters and usage:"
msgstr "您将看到基本的命令参数和用法:"
#: ../../getting_started/install/cluster/vms/index.md:127
#: 6eb47aecceec414e8510fe022b6fddbd
msgid "**View the `dbgpt start` help**"
msgstr "**查看 `dbgpt start` 帮助**"
#: ../../getting_started/install/cluster/vms/index.md:133
#: 1f4c0a4ce0704ca8ac33178bd13c69ad
msgid "Here you can see the related commands and usage for start:"
msgstr "在这里,您可以看到启动的相关命令和用法:"
#: ../../getting_started/install/cluster/vms/index.md:150
#: 22e8e67bc55244e79764d091f334560b
msgid "**View the `dbgpt start worker`help**"
msgstr "**查看 `dbgpt start worker` 帮助**"
#: ../../getting_started/install/cluster/vms/index.md:156
#: 5631b83fda714780855e99e90d4eb542
msgid "Here you can see the parameters to start Model Worker:"
msgstr "在这里,您可以看到启动 Model Worker 的参数:"
#: ../../getting_started/install/cluster/vms/index.md:215
#: cf4a31fd3368481cba1b3ab382615f53
msgid "**View the `dbgpt model`help**"
msgstr "**查看 `dbgpt model` 帮助**"
#: ../../getting_started/install/cluster/vms/index.md:221
#: 3740774ec4b240f2882b5b59da224d55
msgid ""
"The `dbgpt model ` command can connect to the Model Controller via the "
"Model Controller address and then manage a remote model:"
msgstr ""
"`dbgpt model` 命令可以通过 Model Controller 地址连接到 Model Controller,然后管理远程模型:"

View File

@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.5\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-08-31 16:38+0800\n"
"POT-Creation-Date: 2023-09-13 10:46+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
@@ -21,84 +21,88 @@ msgstr ""
#: ../../getting_started/install/llm/llm.rst:2
#: ../../getting_started/install/llm/llm.rst:24
#: b348d4df8ca44dd78b42157a8ff6d33d
#: e693a8d3769b4d9e99c4442ca77dc43c
msgid "LLM Usage"
msgstr "LLM使用"
#: ../../getting_started/install/llm/llm.rst:3 7f5960a7e5634254b330da27be87594b
#: ../../getting_started/install/llm/llm.rst:3 0a73562d18ba455bab04277b715c3840
msgid ""
"DB-GPT provides a management and deployment solution for multiple models."
" This chapter mainly discusses how to deploy different models."
msgstr "DB-GPT 提供了多模型的管理和部署方案,本章主要讲解针对不同的模型该怎么部署"
#: ../../getting_started/install/llm/llm.rst:18
#: b844ab204ec740ec9d7d191bb841f09e
#: ../../getting_started/install/llm/llm.rst:19
#: d7e4de2a7e004888897204ec76b6030b
msgid ""
"Multi LLMs Support, Supports multiple large language models, currently "
"supporting"
msgstr "目前DB-GPT已适配如下模型"
#: ../../getting_started/install/llm/llm.rst:9 c141437ddaf84c079360008343041b2f
#: ../../getting_started/install/llm/llm.rst:9 4616886b8b2244bd93355e871356d89e
#, fuzzy
msgid "🔥 Baichuan2(7b,13b)"
msgstr "Baichuan(7b,13b)"
#: ../../getting_started/install/llm/llm.rst:10
#: ad0e4793d4e744c1bdf59f5a3d9c84be
msgid "🔥 Vicuna-v1.5(7b,13b)"
msgstr "🔥 Vicuna-v1.5(7b,13b)"
#: ../../getting_started/install/llm/llm.rst:10
#: d32b1e3f114c4eab8782b497097c1b37
#: ../../getting_started/install/llm/llm.rst:11
#: d291e58001ae487bbbf2a1f9f889f5fd
msgid "🔥 llama-2(7b,13b,70b)"
msgstr "🔥 llama-2(7b,13b,70b)"
#: ../../getting_started/install/llm/llm.rst:11
#: 0a417ee4d008421da07fff7add5d05eb
#: ../../getting_started/install/llm/llm.rst:12
#: 1e49702ee40b4655945a2a13efaad536
msgid "WizardLM-v1.2(13b)"
msgstr "WizardLM-v1.2(13b)"
#: ../../getting_started/install/llm/llm.rst:12
#: 199e1a9fe3324dc8a1bcd9cd0b1ef047
#: ../../getting_started/install/llm/llm.rst:13
#: 4ef5913ddfe840d7a12289e6e1d4cb60
msgid "Vicuna (7b,13b)"
msgstr "Vicuna (7b,13b)"
#: ../../getting_started/install/llm/llm.rst:13
#: a9e4c5100534450db3a583fa5850e4be
#: ../../getting_started/install/llm/llm.rst:14
#: ea46c2211257459285fa48083cb59561
msgid "ChatGLM-6b (int4,int8)"
msgstr "ChatGLM-6b (int4,int8)"
#: ../../getting_started/install/llm/llm.rst:14
#: 943324289eb94042b52fd824189cd93f
#: ../../getting_started/install/llm/llm.rst:15
#: 90688302bae4452a84f14e8ecb7f1a21
msgid "ChatGLM2-6b (int4,int8)"
msgstr "ChatGLM2-6b (int4,int8)"
#: ../../getting_started/install/llm/llm.rst:15
#: f1226fdfac3b4e9d88642ffa69d75682
#: ../../getting_started/install/llm/llm.rst:16
#: ee1469545a314696a36e7296c7b71960
msgid "guanaco(7b,13b,33b)"
msgstr "guanaco(7b,13b,33b)"
#: ../../getting_started/install/llm/llm.rst:16
#: 3f2457f56eb341b6bc431c9beca8f4df
#: ../../getting_started/install/llm/llm.rst:17
#: 25abad241f4d4eee970d5938bf71311f
msgid "Gorilla(7b,13b)"
msgstr "Gorilla(7b,13b)"
#: ../../getting_started/install/llm/llm.rst:17
#: 86c8ce37be1c4a7ea3fc382100d77a9c
#: ../../getting_started/install/llm/llm.rst:18
#: 8e3d0399431a4c6a9065a8ae0ad3c8ac
msgid "Baichuan(7b,13b)"
msgstr "Baichuan(7b,13b)"
#: ../../getting_started/install/llm/llm.rst:18
#: 538111af95ad414cb2e631a89f9af379
#: ../../getting_started/install/llm/llm.rst:19
#: c285fa7c9c6c4e3e9840761a09955348
msgid "OpenAI"
msgstr "OpenAI"
#: ../../getting_started/install/llm/llm.rst:20
#: a203325b7ec248f7bff61ae89226a000
#: ../../getting_started/install/llm/llm.rst:21
#: 4ac13a21f323455982750bd2e0243b72
msgid "llama_cpp"
msgstr "llama_cpp"
#: ../../getting_started/install/llm/llm.rst:21
#: 21a50634198047228bc51a03d2c31292
#: ../../getting_started/install/llm/llm.rst:22
#: 7231edceef584724a6f569c6b363e083
msgid "quantization"
msgstr "quantization"
#: ../../getting_started/install/llm/llm.rst:22
#: dfaec4b04e6e45ff9c884b41534b1a79
msgid "cluster deployment"
msgstr ""
#~ msgid "cluster deployment"
#~ msgstr ""

View File

@@ -1,4 +1,4 @@
autodoc_pydantic==1.8.0
autodoc_pydantic
myst_parser
nbsphinx==0.8.9
sphinx==4.5.0

pilot/componet.py Normal file
View File

@@ -0,0 +1,143 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Type, Dict, TypeVar, Optional, TYPE_CHECKING
import asyncio
# Checking for type hints during runtime
if TYPE_CHECKING:
from fastapi import FastAPI
class LifeCycle:
"""This class defines hooks for lifecycle events of a component."""
def before_start(self):
"""Called before the component starts."""
pass
async def async_before_start(self):
"""Asynchronous version of before_start."""
pass
def after_start(self):
"""Called after the component has started."""
pass
async def async_after_start(self):
"""Asynchronous version of after_start."""
pass
def before_stop(self):
"""Called before the component stops."""
pass
async def async_before_stop(self):
"""Asynchronous version of before_stop."""
pass
class BaseComponet(LifeCycle, ABC):
"""Abstract Base Component class. All custom components should extend this."""
name = "base_dbgpt_componet"
def __init__(self, system_app: Optional[SystemApp] = None):
if system_app is not None:
self.init_app(system_app)
@abstractmethod
def init_app(self, system_app: SystemApp):
"""Initialize the component with the main application.
This method needs to be implemented by every component to define how it integrates
with the main system app.
"""
pass
T = TypeVar("T", bound=BaseComponet)
class SystemApp(LifeCycle):
"""Main System Application class that manages the lifecycle and registration of components."""
def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None:
self.componets: Dict[
str, BaseComponet
] = {} # Dictionary to store registered components.
self._asgi_app = asgi_app
@property
def app(self) -> Optional["FastAPI"]:
"""Returns the internal ASGI app."""
return self._asgi_app
def register(self, componet: Type[BaseComponet], *args, **kwargs):
"""Register a new component by its type."""
instance = componet(self, *args, **kwargs)
self.register_instance(instance)
def register_instance(self, instance: T):
"""Register an already initialized component."""
self.componets[instance.name] = instance
instance.init_app(self)
def get_componet(self, name: str, componet_type: Type[T]) -> T:
"""Retrieve a registered component by its name and type."""
component = self.componets.get(name)
if not component:
raise ValueError(f"No component found with name {name}")
if not isinstance(component, componet_type):
raise TypeError(f"Component {name} is not of type {componet_type}")
return component
def before_start(self):
"""Invoke the before_start hooks for all registered components."""
for _, v in self.componets.items():
v.before_start()
async def async_before_start(self):
"""Asynchronously invoke the before_start hooks for all registered components."""
tasks = [v.async_before_start() for _, v in self.componets.items()]
await asyncio.gather(*tasks)
def after_start(self):
"""Invoke the after_start hooks for all registered components."""
for _, v in self.componets.items():
v.after_start()
async def async_after_start(self):
"""Asynchronously invoke the after_start hooks for all registered components."""
tasks = [v.async_after_start() for _, v in self.componets.items()]
await asyncio.gather(*tasks)
def before_stop(self):
"""Invoke the before_stop hooks for all registered components."""
for _, v in self.componets.items():
try:
v.before_stop()
except Exception as e:
pass
async def async_before_stop(self):
"""Asynchronously invoke the before_stop hooks for all registered components."""
tasks = [v.async_before_stop() for _, v in self.componets.items()]
await asyncio.gather(*tasks)
def _build(self):
"""Integrate lifecycle events with the internal ASGI app if available."""
if not self.app:
return
@self.app.on_event("startup")
async def startup_event():
"""ASGI app startup event handler."""
asyncio.create_task(self.async_after_start())
self.after_start()
@self.app.on_event("shutdown")
async def shutdown_event():
"""ASGI app shutdown event handler."""
await self.async_before_stop()
self.before_stop()
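
As a quick orientation, here is a minimal usage sketch of this component system; the `AuditComponet` class is hypothetical, invented only for illustration:

```python
from fastapi import FastAPI

from pilot.componet import BaseComponet, SystemApp


class AuditComponet(BaseComponet):
    """Hypothetical component that only logs lifecycle events."""

    name = "audit_componet"

    def init_app(self, system_app: SystemApp):
        # Integration point with the main application; nothing to wire here.
        self.system_app = system_app

    def after_start(self):
        print(f"{self.name} started")


system_app = SystemApp(FastAPI())
system_app.register(AuditComponet)  # instantiated as AuditComponet(system_app)
system_app.after_start()  # runs the after_start hook of every component

# Retrieve the instance later by name, with a type check.
audit = system_app.get_componet("audit_componet", AuditComponet)
```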

View File

@@ -189,6 +189,10 @@ class Config(metaclass=Singleton):
### Log level
self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")
from pilot.componet import SystemApp
self.SYSTEM_APP: SystemApp = None
def set_debug_mode(self, value: bool) -> None:
"""Set the debug mode value"""
self.debug_mode = value

View File

@@ -4,6 +4,7 @@ import asyncio
from pilot.configs.config import Config
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
from pilot.common.schema import DBType
from pilot.componet import SystemApp
from pilot.connections.rdbms.conn_mysql import MySQLConnect
from pilot.connections.base import BaseConnect
@@ -46,9 +47,9 @@ class ConnectManager:
raise ValueError("Unsupported DB type: " + db_type)
return result
def __init__(self):
def __init__(self, system_app: SystemApp):
self.storage = DuckdbConnectConfig()
self.db_summary_client = DBSummaryClient()
self.db_summary_client = DBSummaryClient(system_app)
self.__load_config_db()
def __load_config_db(self):

View File

@@ -2,6 +2,6 @@ from pilot.configs.config import Config
from pilot.connections.manages.connection_manager import ConnectManager
if __name__ == "__main__":
mange = ConnectManager()
mange = ConnectManager(system_app=None)
types = mange.get_all_completed_types()
print(str(types))

View File

@@ -1,9 +1,12 @@
from typing import Optional
from chromadb.errors import NotEnoughElementsException
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import TextSplitter
from pilot.embedding_engine.embedding_factory import (
EmbeddingFactory,
DefaultEmbeddingFactory,
)
from pilot.embedding_engine.knowledge_type import get_knowledge_embedding, KnowledgeType
from pilot.vector_store.connector import VectorStoreConnector
@@ -24,13 +27,16 @@ class EmbeddingEngine:
knowledge_source: Optional[str] = None,
source_reader: Optional = None,
text_splitter: Optional[TextSplitter] = None,
embedding_factory: EmbeddingFactory = None,
):
"""Initialize with knowledge embedding client, model_name, vector_store_config, knowledge_type, knowledge_source"""
self.knowledge_source = knowledge_source
self.model_name = model_name
self.vector_store_config = vector_store_config
self.knowledge_type = knowledge_type
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
if not embedding_factory:
embedding_factory = DefaultEmbeddingFactory()
self.embeddings = embedding_factory.create(model_name=self.model_name)
self.vector_store_config["embeddings"] = self.embeddings
self.source_reader = source_reader
self.text_splitter = text_splitter
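
With the new `embedding_factory` parameter, `EmbeddingEngine` no longer hard-codes `HuggingFaceEmbeddings`. A minimal construction sketch; the model path and store-config values are placeholders:

```python
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.embedding_factory import DefaultEmbeddingFactory

engine = EmbeddingEngine(
    model_name="/app/models/text2vec-large-chinese",  # placeholder path
    vector_store_config={
        "vector_store_type": "Chroma",
        "chroma_persist_path": "/data/knowledge",  # placeholder path
    },
    embedding_factory=DefaultEmbeddingFactory(),  # this is also the default
)
```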

View File

@@ -0,0 +1,39 @@
from abc import ABC, abstractmethod
from typing import Any, Type, TYPE_CHECKING
from pilot.componet import BaseComponet
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
class EmbeddingFactory(BaseComponet, ABC):
name = "embedding_factory"
@abstractmethod
def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings":
"""Create embedding"""
class DefaultEmbeddingFactory(EmbeddingFactory):
def __init__(self, system_app=None, model_name: str = None, **kwargs: Any) -> None:
super().__init__(system_app=system_app)
self._default_model_name = model_name
self.kwargs = kwargs
def init_app(self, system_app):
pass
def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings":
if not model_name:
model_name = self._default_model_name
if embedding_cls:
return embedding_cls(model_name=model_name, **self.kwargs)
else:
from langchain.embeddings import HuggingFaceEmbeddings
return HuggingFaceEmbeddings(model_name=model_name, **self.kwargs)
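
A usage sketch for the default factory (the model path is a placeholder); passing `embedding_cls` instead would build any langchain `Embeddings` class that accepts `model_name`:

```python
from pilot.embedding_engine.embedding_factory import DefaultEmbeddingFactory

factory = DefaultEmbeddingFactory(model_name="/app/models/text2vec-large-chinese")

embeddings = factory.create()  # HuggingFaceEmbeddings with the default model name
vectors = embeddings.embed_documents(["hello", "world"])
```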

View File

@@ -96,13 +96,18 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
def _dynamic_model_parser() -> Callable[[None], List[Type]]:
from pilot.utils.parameter_utils import _SimpleArgParser
from pilot.model.parameter import EmbeddingModelParameters, WorkerType
pre_args = _SimpleArgParser("model_name", "model_path")
pre_args = _SimpleArgParser("model_name", "model_path", "worker_type")
pre_args.parse()
model_name = pre_args.get("model_name")
model_path = pre_args.get("model_path")
worker_type = pre_args.get("worker_type")
if model_name is None:
return None
if worker_type == WorkerType.TEXT2VEC:
return [EmbeddingModelParameters]
llm_adapter = get_llm_model_adapter(model_name, model_path)
param_class = llm_adapter.model_param_class()
return [param_class]

View File

@@ -167,7 +167,6 @@ def stop(model_name: str, model_type: str, host: str, port: int):
def _remote_model_dynamic_factory() -> Callable[[None], List[Type]]:
from pilot.model.adapter import _dynamic_model_parser
from pilot.utils.parameter_utils import _SimpleArgParser
from pilot.model.cluster import RemoteWorkerManager
from pilot.model.parameter import WorkerType

View File

@@ -5,6 +5,9 @@ from pilot.model.cluster.base import (
WorkerParameterRequest,
WorkerStartupRequest,
)
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.worker.default_worker import DefaultModelWorker
from pilot.model.cluster.worker.manager import (
initialize_worker_manager_in_client,
run_worker_manager,
@@ -23,11 +26,15 @@ __all__ = [
"EmbeddingsRequest",
"PromptRequest",
"WorkerApplyRequest",
"WorkerParameterRequest"
"WorkerStartupRequest"
"worker_manager"
"WorkerParameterRequest",
"WorkerStartupRequest",
"ModelWorker",
"DefaultModelWorker",
"worker_manager",
"run_worker_manager",
"initialize_worker_manager_in_client",
"ModelRegistry",
"ModelRegistryClient" "RemoteWorkerManager" "run_model_controller",
"ModelRegistryClient",
"RemoteWorkerManager",
"run_model_controller",
]

View File

@@ -8,7 +8,10 @@ from pilot.model.base import ModelInstance
from pilot.model.parameter import ModelControllerParameters
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
from pilot.utils.parameter_utils import EnvArgumentParser
from pilot.utils.api_utils import _api_remote as api_remote
from pilot.utils.api_utils import (
_api_remote as api_remote,
_sync_api_remote as sync_api_remote,
)
class BaseModelController(ABC):
@@ -89,6 +92,12 @@ class ModelRegistryClient(_RemoteModelController, ModelRegistry):
async def get_all_model_instances(self) -> List[ModelInstance]:
return await self.get_all_instances()
@sync_api_remote(path="/api/controller/models")
def sync_get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
pass
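
The new synchronous client method lets code without an event loop query the controller, which is what `RemoteWorkerManager.sync_get_model_instances` relies on below. A sketch; the controller address follows the docs above, and using the plain model name as the registry key is an assumption:

```python
from pilot.model.cluster.controller.controller import ModelRegistryClient

client = ModelRegistryClient("http://127.0.0.1:8000")
# Blocking call; no asyncio event loop required.
instances = client.sync_get_all_instances("text2vec", healthy_only=True)
for ins in instances:
    print(ins.host, ins.port)
```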
class ModelControllerAdapter(BaseModelController):
def __init__(self, backend: BaseModelController = None) -> None:

View File

@@ -0,0 +1,28 @@
from typing import List
from langchain.embeddings.base import Embeddings
from pilot.model.cluster.manager_base import WorkerManager
class RemoteEmbeddings(Embeddings):
def __init__(self, model_name: str, worker_manager: WorkerManager) -> None:
self.model_name = model_name
self.worker_manager = worker_manager
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
params = {"model": self.model_name, "input": texts}
return self.worker_manager.sync_embeddings(params)
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
return self.embed_documents([text])[0]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronously embed search docs."""
params = {"model": self.model_name, "input": texts}
return await self.worker_manager.embeddings(params)
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronously embed query text."""
return (await self.aembed_documents([text]))[0]
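
Since `RemoteEmbeddings` implements langchain's `Embeddings` interface, it can stand in wherever a local embedding client is expected. A minimal sketch using the global `worker_manager` exported by `pilot.model.cluster`:

```python
from pilot.model.cluster import worker_manager
from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings

# "text2vec" matches the worker started with --worker_type text2vec.
embeddings = RemoteEmbeddings("text2vec", worker_manager)
vectors = embeddings.embed_documents(["hello", "world"])  # computed by a remote worker
```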

View File

@@ -1,6 +1,6 @@
import asyncio
from dataclasses import dataclass
from typing import List, Optional, Dict, Iterator
from typing import List, Optional, Dict, Iterator, Callable
from abc import ABC, abstractmethod
from datetime import datetime
from concurrent.futures import Future
@@ -35,15 +35,31 @@ class WorkerManager(ABC):
async def stop(self):
"""Stop worker manager"""
@abstractmethod
def after_start(self, listener: Callable[["WorkerManager"], None]):
"""Add a listener after WorkerManager startup"""
@abstractmethod
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
"""Asynchronously get model instances by worker type and model name"""
@abstractmethod
def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
"""Get model instances by worker type and model name"""
@abstractmethod
async def select_one_instance(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> WorkerRunData:
"""Asynchronously select one instance"""
@abstractmethod
def sync_select_one_instance(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> WorkerRunData:
"""Select one instance"""
@@ -69,7 +85,15 @@ class WorkerManager(ABC):
@abstractmethod
async def embeddings(self, params: Dict) -> List[List[float]]:
"""Embed input"""
"""Asynchronously embed input"""
@abstractmethod
def sync_embeddings(self, params: Dict) -> List[List[float]]:
"""Embed input.
This method may be handed to third-party code that calls it synchronously,
so a synchronous version must be provided.
"""
@abstractmethod
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
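
The new `after_start` hook registers a callback that fires once the manager has started; a sketch against the global `worker_manager`:

```python
from pilot.model.cluster import worker_manager


def on_started(wm):
    # Runs after the WorkerManager is up; a safe place to initialize
    # anything that depends on live model workers.
    print("WorkerManager ready:", wm)


worker_manager.after_start(on_started)
```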

View File

@@ -58,6 +58,12 @@ class ModelRegistry(ABC):
- List[ModelInstance]: A list of instances for the given model.
"""
@abstractmethod
def sync_get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
"""Fetch all instances of a given model. Optionally, fetch only the healthy instances."""
@abstractmethod
async def get_all_model_instances(self) -> List[ModelInstance]:
"""
@@ -163,6 +169,11 @@ class EmbeddedModelRegistry(ModelRegistry):
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
return self.sync_get_all_instances(model_name, healthy_only)
def sync_get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
instances = self.registry[model_name]
if healthy_only:
@@ -179,7 +190,7 @@ class EmbeddedModelRegistry(ModelRegistry):
)
if not exist_ins:
# register new install from heartbeat
self.register_instance(instance)
await self.register_instance(instance)
return True
ins = exist_ins[0]

View File

@@ -24,7 +24,7 @@ class EmbeddingsModelWorker(ModelWorker):
"Could not import langchain.embeddings.HuggingFaceEmbeddings python package. "
"Please install it with `pip install langchain`."
) from exc
self.embeddings: Embeddings = None
self._embeddings_impl: Embeddings = None
self._model_params = None
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
@@ -75,16 +75,16 @@ class EmbeddingsModelWorker(ModelWorker):
kwargs = model_params.build_kwargs(model_name=model_params.model_path)
logger.info(f"Start HuggingFaceEmbeddings with kwargs: {kwargs}")
self.embeddings = HuggingFaceEmbeddings(**kwargs)
self._embeddings_impl = HuggingFaceEmbeddings(**kwargs)
def __del__(self):
self.stop()
def stop(self) -> None:
if not self.embeddings:
if not self._embeddings_impl:
return
del self.embeddings
self.embeddings = None
del self._embeddings_impl
self._embeddings_impl = None
_clear_torch_cache(self._model_params.device)
def generate_stream(self, params: Dict):
@@ -96,5 +96,7 @@ class EmbeddingsModelWorker(ModelWorker):
raise NotImplementedError("Not supported generate for embeddings model")
def embeddings(self, params: Dict) -> List[List[float]]:
model = params.get("model")
logger.info(f"Receive embeddings request, model: {model}")
input: List[str] = params["input"]
return self.embeddings.embed_documents(input)
return self._embeddings_impl.embed_documents(input)

View File

@@ -72,6 +72,7 @@ class LocalWorkerManager(WorkerManager):
self.model_registry = model_registry
self.host = host
self.port = port
self.start_listeners = []
self.run_data = WorkerRunData(
host=self.host,
@@ -105,6 +106,8 @@ class LocalWorkerManager(WorkerManager):
asyncio.create_task(
_async_heartbeat_sender(self.run_data, 20, self.send_heartbeat_func)
)
for listener in self.start_listeners:
listener(self)
async def stop(self):
if not self.run_data.stop_event.is_set():
@@ -116,6 +119,9 @@ class LocalWorkerManager(WorkerManager):
stop_tasks.append(self.deregister_func(self.run_data))
await asyncio.gather(*stop_tasks)
def after_start(self, listener: Callable[["WorkerManager"], None]):
self.start_listeners.append(listener)
def add_worker(
self,
worker: ModelWorker,
@@ -137,14 +143,7 @@ class LocalWorkerManager(WorkerManager):
worker_key = self._worker_key(
worker_params.worker_type, worker_params.model_name
)
host = worker_params.host
port = worker_params.port
instances = self.workers.get(worker_key)
if not instances:
instances = []
self.workers[worker_key] = instances
logger.info(f"Init empty instances list for {worker_key}")
# Load model params from persist storage
model_params = worker.parse_parameters(command_args=command_args)
@@ -159,14 +158,15 @@ class LocalWorkerManager(WorkerManager):
semaphore=asyncio.Semaphore(worker_params.limit_model_concurrency),
command_args=command_args,
)
exist_instances = [
ins for ins in instances if ins.host == host and ins.port == port
]
if not exist_instances:
instances.append(worker_run_data)
instances = self.workers.get(worker_key)
if not instances:
instances = [worker_run_data]
self.workers[worker_key] = instances
logger.info(f"Init empty instances list for {worker_key}")
return True
else:
# TODO Update worker
logger.warn(f"Instance {worker_key} already exists")
return False
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
@@ -222,16 +222,18 @@ class LocalWorkerManager(WorkerManager):
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
return self.sync_get_model_instances(worker_type, model_name, healthy_only)
def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
worker_key = self._worker_key(worker_type, model_name)
return self.workers.get(worker_key)
async def select_one_instance(
self, worker_type: str, model_name: str, healthy_only: bool = True
def _simple_select(
self, worker_type: str, model_name: str, worker_instances: List[WorkerRunData]
) -> WorkerRunData:
worker_instances = await self.get_model_instances(
worker_type, model_name, healthy_only
)
if not worker_instances:
raise Exception(
f"Could not find worker instances for model name {model_name} and worker type {worker_type}"
@@ -239,12 +241,34 @@ class LocalWorkerManager(WorkerManager):
worker_run_data = random.choice(worker_instances)
return worker_run_data
async def select_one_instance(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> WorkerRunData:
worker_instances = await self.get_model_instances(
worker_type, model_name, healthy_only
)
return self._simple_select(worker_type, model_name, worker_instances)
def sync_select_one_instance(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> WorkerRunData:
worker_instances = self.sync_get_model_instances(
worker_type, model_name, healthy_only
)
return self._simple_select(worker_type, model_name, worker_instances)
async def _get_model(self, params: Dict, worker_type: str = "llm") -> WorkerRunData:
model = params.get("model")
if not model:
raise Exception("Model name cannot be empty")
return await self.select_one_instance(worker_type, model, healthy_only=True)
def _sync_get_model(self, params: Dict, worker_type: str = "llm") -> WorkerRunData:
model = params.get("model")
if not model:
raise Exception("Model name cannot be empty")
return self.sync_select_one_instance(worker_type, model, healthy_only=True)
async def generate_stream(
self, params: Dict, async_wrapper=None, **kwargs
) -> Iterator[ModelOutput]:
@@ -304,6 +328,10 @@ class LocalWorkerManager(WorkerManager):
worker_run_data.worker.embeddings, params
)
def sync_embeddings(self, params: Dict) -> List[List[float]]:
worker_run_data = self._sync_get_model(params, worker_type="text2vec")
return worker_run_data.worker.embeddings(params)
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
apply_func: Callable[[WorkerApplyRequest], Awaitable[str]] = None
if apply_req.apply_type == WorkerApplyType.START:
@@ -458,6 +486,10 @@ class WorkerManagerAdapter(WorkerManager):
async def stop(self):
return await self.worker_manager.stop()
def after_start(self, listener: Callable[["WorkerManager"], None]):
if listener is not None:
self.worker_manager.after_start(listener)
async def supported_models(self) -> List[WorkerSupportedModel]:
return await self.worker_manager.supported_models()
@@ -474,6 +506,13 @@ class WorkerManagerAdapter(WorkerManager):
worker_type, model_name, healthy_only
)
def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
return self.worker_manager.sync_get_model_instances(
worker_type, model_name, healthy_only
)
async def select_one_instance(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> WorkerRunData:
@@ -481,6 +520,13 @@ class WorkerManagerAdapter(WorkerManager):
worker_type, model_name, healthy_only
)
def sync_select_one_instance(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> WorkerRunData:
return self.worker_manager.sync_select_one_instance(
worker_type, model_name, healthy_only
)
async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
async for output in self.worker_manager.generate_stream(params, **kwargs):
yield output
@@ -491,6 +537,9 @@ class WorkerManagerAdapter(WorkerManager):
async def embeddings(self, params: Dict) -> List[List[float]]:
return await self.worker_manager.embeddings(params)
def sync_embeddings(self, params: Dict) -> List[List[float]]:
return self.worker_manager.sync_embeddings(params)
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
return await self.worker_manager.worker_apply(apply_req)
@@ -586,11 +635,11 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
@app.on_event("startup")
async def startup_event():
asyncio.create_task(worker_manager.worker_manager.start())
asyncio.create_task(worker_manager.start())
@app.on_event("shutdown")
async def shutdown_event():
await worker_manager.worker_manager.stop()
await worker_manager.stop()
return app
@@ -666,29 +715,60 @@ def _create_local_model_manager(
def _build_worker(worker_params: ModelWorkerParameters):
if worker_params.worker_class:
worker_class = worker_params.worker_class
if worker_class:
from pilot.utils.module_utils import import_from_checked_string
worker_cls = import_from_checked_string(worker_params.worker_class, ModelWorker)
logger.info(
f"Import worker class from {worker_params.worker_class} successfully"
)
worker: ModelWorker = worker_cls()
worker_cls = import_from_checked_string(worker_class, ModelWorker)
logger.info(f"Import worker class from {worker_class} successfully")
else:
from pilot.model.cluster.worker.default_worker import DefaultModelWorker
if (
worker_params.worker_type is None
or worker_params.worker_type == WorkerType.LLM
):
from pilot.model.cluster.worker.default_worker import DefaultModelWorker
worker = DefaultModelWorker()
return worker
worker_cls = DefaultModelWorker
elif worker_params.worker_type == WorkerType.TEXT2VEC:
from pilot.model.cluster.worker.embedding_worker import (
EmbeddingsModelWorker,
)
worker_cls = EmbeddingsModelWorker
else:
raise Exception(f"Unsupported worker type: {worker_params.worker_type}")
return worker_cls()
def _start_local_worker(
worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters
):
worker = _build_worker(worker_params)
worker_manager.worker_manager = _create_local_model_manager(worker_params)
if not worker_manager.worker_manager:
worker_manager.worker_manager = _create_local_model_manager(worker_params)
worker_manager.worker_manager.add_worker(worker, worker_params)
def _start_local_embedding_worker(
worker_manager: WorkerManagerAdapter,
embedding_model_name: str = None,
embedding_model_path: str = None,
):
if not embedding_model_name or not embedding_model_path:
return
embedding_worker_params = ModelWorkerParameters(
model_name=embedding_model_name,
model_path=embedding_model_path,
worker_type=WorkerType.TEXT2VEC,
worker_class="pilot.model.cluster.worker.embedding_worker.EmbeddingsModelWorker",
)
logger.info(
f"Start local embedding worker with embedding parameters\n{embedding_worker_params}"
)
_start_local_worker(worker_manager, embedding_worker_params)
def initialize_worker_manager_in_client(
app=None,
include_router: bool = True,
@@ -697,6 +777,9 @@ def initialize_worker_manager_in_client(
run_locally: bool = True,
controller_addr: str = None,
local_port: int = 5000,
embedding_model_name: str = None,
embedding_model_path: str = None,
start_listener: Callable[["WorkerManager"], None] = None,
):
"""Initialize WorkerManager in client.
If run_locally is True:
@@ -728,6 +811,10 @@ def initialize_worker_manager_in_client(
logger.info(f"Worker params: {worker_params}")
_setup_fastapi(worker_params, app)
_start_local_worker(worker_manager, worker_params)
worker_manager.after_start(start_listener)
_start_local_embedding_worker(
worker_manager, embedding_model_name, embedding_model_path
)
else:
from pilot.model.cluster.controller.controller import (
ModelRegistryClient,
@@ -741,9 +828,12 @@ def initialize_worker_manager_in_client(
logger.info(f"Worker params: {worker_params}")
client = ModelRegistryClient(worker_params.controller_addr)
worker_manager.worker_manager = RemoteWorkerManager(client)
worker_manager.after_start(start_listener)
initialize_controller(
app=app, remote_controller_addr=worker_params.controller_addr
)
loop = asyncio.get_event_loop()
loop.run_until_complete(worker_manager.start())
if include_router and app:
# mount WorkerManager router
@@ -757,6 +847,8 @@ def run_worker_manager(
model_path: str = None,
standalone: bool = False,
port: int = None,
embedding_model_name: str = None,
embedding_model_path: str = None,
):
global worker_manager
@@ -765,15 +857,22 @@ def run_worker_manager(
)
embedded_mod = True
logger.info(f"Worker params: {worker_params}")
if not app:
# Run worker manager independently
embedded_mod = False
app = _setup_fastapi(worker_params)
_start_local_worker(worker_manager, worker_params)
_start_local_embedding_worker(
worker_manager, embedding_model_name, embedding_model_path
)
else:
_start_local_worker(worker_manager, worker_params)
_start_local_embedding_worker(
worker_manager, embedding_model_name, embedding_model_path
)
loop = asyncio.get_event_loop()
loop.run_until_complete(worker_manager.worker_manager.start())
loop.run_until_complete(worker_manager.start())
if include_router:
app.include_router(router, prefix="/api")
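
With the new `embedding_model_name`/`embedding_model_path` parameters, a single worker-manager process can host both an LLM worker and an embedding worker. A standalone sketch; the model names and paths are placeholders:

```python
from pilot.model.cluster import run_worker_manager

if __name__ == "__main__":
    run_worker_manager(
        model_name="vicuna-13b-v1.5",
        model_path="/app/models/vicuna-13b-v1.5",
        standalone=True,  # standalone mode, per the --standalone CLI flag above
        port=8001,
        embedding_model_name="text2vec",
        embedding_model_path="/app/models/text2vec-large-chinese",
    )
```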

View File

@@ -1,14 +1,12 @@
from typing import Callable, Any
import httpx
import asyncio
from typing import Any, Callable
import httpx
from pilot.model.base import ModelInstance, WorkerApplyOutput, WorkerSupportedModel
from pilot.model.cluster.base import *
from pilot.model.cluster.registry import ModelRegistry
from pilot.model.cluster.worker.manager import LocalWorkerManager, WorkerRunData, logger
from pilot.model.cluster.base import *
from pilot.model.base import (
ModelInstance,
WorkerApplyOutput,
WorkerSupportedModel,
)
from pilot.model.cluster.worker.remote_worker import RemoteModelWorker
class RemoteWorkerManager(LocalWorkerManager):
@@ -16,7 +14,8 @@ class RemoteWorkerManager(LocalWorkerManager):
super().__init__(model_registry=model_registry)
async def start(self):
pass
for listener in self.start_listeners:
listener(self)
async def stop(self):
pass
@@ -125,15 +124,9 @@ class RemoteWorkerManager(LocalWorkerManager):
success_handler=lambda x: True,
)
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
def _build_worker_instances(
self, model_name: str, instances: List[ModelInstance]
) -> List[WorkerRunData]:
from pilot.model.cluster.worker.remote_worker import RemoteModelWorker
worker_key = self._worker_key(worker_type, model_name)
instances: List[ModelInstance] = await self.model_registry.get_all_instances(
worker_key, healthy_only
)
worker_instances = []
for ins in instances:
worker = RemoteModelWorker()
@@ -151,6 +144,24 @@ class RemoteWorkerManager(LocalWorkerManager):
worker_instances.append(wr)
return worker_instances
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
worker_key = self._worker_key(worker_type, model_name)
instances: List[ModelInstance] = await self.model_registry.get_all_instances(
worker_key, healthy_only
)
return self._build_worker_instances(model_name, instances)
def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
worker_key = self._worker_key(worker_type, model_name)
instances: List[ModelInstance] = self.model_registry.sync_get_all_instances(
worker_key, healthy_only
)
return self._build_worker_instances(model_name, instances)
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
async def _remote_apply_func(worker_run_data: WorkerRunData):
return await self._fetch_from_worker(

View File

@@ -87,7 +87,15 @@ class RemoteModelWorker(ModelWorker):
def embeddings(self, params: Dict) -> List[List[float]]:
"""Get embeddings for input"""
raise NotImplementedError
import requests
response = requests.post(
self.worker_addr + "/embeddings",
headers=self.headers,
json=params,
timeout=self.timeout,
)
return response.json()
async def async_embeddings(self, params: Dict) -> List[List[float]]:
"""Asynchronously get embeddings for input"""

View File

@@ -46,9 +46,7 @@ class ModelWorkerParameters(BaseModelParameters):
)
worker_class: Optional[str] = field(
default=None,
metadata={
"help": "Model worker class, pilot.model.worker.default_worker.DefaultModelWorker"
},
metadata={"help": "Model worker class, pilot.model.cluster.DefaultModelWorker"},
)
host: Optional[str] = field(
default="0.0.0.0", metadata={"help": "Model worker deploy host"}

View File

@@ -111,7 +111,7 @@ async def db_connect_delete(db_name: str = None):
async def async_db_summary_embedding(db_name, db_type):
# Execute the code that needs to run asynchronously here
db_summary_client = DBSummaryClient()
db_summary_client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
db_summary_client.db_summary_embedding(db_name, db_type)

View File

@@ -61,7 +61,7 @@ class ChatDashboard(BaseChat):
except ImportError:
raise ValueError("Could not import DBSummaryClient. ")
client = DBSummaryClient()
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
try:
table_infos = client.get_similar_tables(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k

View File

@@ -35,7 +35,7 @@ class ChatWithDbAutoExecute(BaseChat):
from pilot.summary.db_summary_client import DBSummaryClient
except ImportError:
raise ValueError("Could not import DBSummaryClient. ")
client = DBSummaryClient()
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
try:
table_infos = client.get_db_summary(
dbname=self.db_name,

View File

@@ -41,7 +41,7 @@ class ChatWithDbQA(BaseChat):
except ImportError:
raise ValueError("Could not import DBSummaryClient. ")
if self.db_name:
client = DBSummaryClient()
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
try:
table_infos = client.get_db_summary(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k

View File

@@ -23,6 +23,7 @@ class ChatKnowledge(BaseChat):
def __init__(self, chat_session_id, user_input, select_param: str = None):
""" """
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
self.knowledge_space = select_param
super().__init__(
@@ -47,9 +48,13 @@ class ChatKnowledge(BaseChat):
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
embedding_factory = CFG.SYSTEM_APP.get_componet(
"embedding_factory", EmbeddingFactory
)
self.knowledge_embedding_client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
embedding_factory=embedding_factory,
)
def generate_input_values(self):

View File

@@ -2,10 +2,11 @@ import signal
import os
import threading
import sys
from typing import Optional
from typing import Optional, Any
from dataclasses import dataclass, field
from pilot.configs.config import Config
from pilot.componet import SystemApp
from pilot.utils.parameter_utils import BaseParameters
@@ -18,30 +19,28 @@ def signal_handler(sig, frame):
os._exit(0)
def async_db_summery():
def async_db_summery(system_app: SystemApp):
from pilot.summary.db_summary_client import DBSummaryClient
client = DBSummaryClient()
client = DBSummaryClient(system_app=system_app)
thread = threading.Thread(target=client.init_db_summary)
thread.start()
def server_init(args):
def server_init(args, system_app: SystemApp):
from pilot.commands.command_mange import CommandRegistry
from pilot.connections.manages.connection_manager import ConnectManager
from pilot.common.plugins import scan_plugins
# logger.info(f"args: {args}")
# init config
cfg = Config()
# init connect manage
conn_manage = ConnectManager()
cfg.LOCAL_DB_MANAGE = conn_manage
cfg.SYSTEM_APP = system_app
# load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)
async_db_summery()
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
# Loader plugins and commands
@@ -70,6 +69,22 @@ def server_init(args):
cfg.command_disply = command_disply_registry
def _create_model_start_listener(system_app: SystemApp):
from pilot.connections.manages.connection_manager import ConnectManager
from pilot.model.cluster import worker_manager
cfg = Config()
def startup_event(wh):
# init connect manage
print("begin run _add_app_startup_event")
conn_manage = ConnectManager(system_app)
cfg.LOCAL_DB_MANAGE = conn_manage
async_db_summery(system_app)
return startup_event
@dataclass
class WebWerverParameters(BaseParameters):
host: Optional[str] = field(

View File

@@ -0,0 +1,38 @@
from typing import Any, Type, TYPE_CHECKING
from pilot.componet import SystemApp
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
def initialize_componets(system_app: SystemApp, embedding_model_name: str):
from pilot.model.cluster import worker_manager
system_app.register(
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
)
class RemoteEmbeddingFactory(EmbeddingFactory):
def __init__(
self, system_app, worker_manager, model_name: str = None, **kwargs: Any
) -> None:
super().__init__(system_app=system_app)
self._worker_manager = worker_manager
self._default_model_name = model_name
self.kwargs = kwargs
def init_app(self, system_app):
pass
def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings":
from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings
if embedding_cls:
raise NotImplementedError
# Ignore model_name args
return RemoteEmbeddings(self._default_model_name, self._worker_manager)
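This new module is the seam for cluster mode: `initialize_componets` registers a `RemoteEmbeddingFactory`, whose `create` hands back a `RemoteEmbeddings` proxy that forwards embedding calls through the worker manager to the embedding worker, while still presenting the langchain `Embeddings` interface to callers. A hedged usage sketch (assuming `RemoteEmbeddings` implements the standard `embed_documents`/`embed_query` methods):

```python
# Sketch: once initialize_componets() has run, consumers obtain embeddings
# without knowing whether they are computed locally or by a remote worker.
factory = system_app.get_componet("embedding_factory", EmbeddingFactory)
embeddings = factory.create()  # model_name is ignored; the default is used
doc_vectors = embeddings.embed_documents(["hello", "world"])
query_vector = embeddings.embed_query("hello")
```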

View File

@@ -7,9 +7,15 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
sys.path.append(ROOT_PATH)
import signal
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
from pilot.componet import SystemApp
from pilot.server.base import server_init, WebWerverParameters
from pilot.server.base import (
server_init,
WebWerverParameters,
_create_model_start_listener,
)
from pilot.server.componet_configs import initialize_componets
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, applications
@@ -48,6 +54,8 @@ def swagger_monkey_patch(*args, **kwargs):
applications.get_swagger_ui_html = swagger_monkey_patch
app = FastAPI()
system_app = SystemApp(app)
origins = ["*"]
# Add CORS middleware
@@ -98,7 +106,12 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
param = WebWerverParameters(**vars(parser.parse_args(args=args)))
setup_logging(logging_level=param.log_level)
server_init(param)
# Before start
system_app.before_start()
server_init(param, system_app)
model_start_listener = _create_model_start_listener(system_app)
initialize_componets(system_app, CFG.EMBEDDING_MODEL)
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
if not param.light:
@@ -108,6 +121,9 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
model_name=CFG.LLM_MODEL,
model_path=model_path,
local_port=param.port,
embedding_model_name=CFG.EMBEDDING_MODEL,
embedding_model_path=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
start_listener=model_start_listener,
)
CFG.NEW_SERVER_MODE = True
@@ -120,6 +136,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
run_locally=False,
controller_addr=CFG.MODEL_SERVER,
local_port=param.port,
start_listener=model_start_listener,
)
CFG.SERVER_LIGHT_MODE = True
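Both startup paths now receive the same `start_listener`, but only the non-light path loads the embedding model in-process; in light mode embeddings are expected to come from the remote cluster behind `CFG.MODEL_SERVER`. A sketch of the branch this hunk completes (the comments state the intent; this is not the real code):

```python
# Sketch of the two deployment shapes implied above.
if not param.light:
    # Standalone: this process hosts the webserver, the LLM worker and the
    # new embedding worker, so EMBEDDING_MODEL is loaded locally.
    ...
else:
    # Light: webserver only; chat and embedding traffic is routed to the
    # controller configured in CFG.MODEL_SERVER.
    ...
```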

View File

@@ -4,8 +4,6 @@ import tempfile
from fastapi import APIRouter, File, UploadFile, Form
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config
from pilot.configs.model_config import (
EMBEDDING_MODEL_CONFIG,
@@ -14,6 +12,7 @@ from pilot.configs.model_config import (
from pilot.openapi.api_view_model import Result
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from pilot.server.knowledge.service import KnowledgeService
from pilot.server.knowledge.request.request import (
@@ -32,10 +31,6 @@ CFG = Config()
router = APIRouter()
embeddings = HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
knowledge_space_service = KnowledgeService()
@@ -186,8 +181,13 @@ def document_list(space_name: str, query_request: ChunkQueryRequest):
@router.post("/knowledge/{vector_name}/query")
def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
print(f"Received params: {space_name}, {query_request}")
embedding_factory = CFG.SYSTEM_APP.get_componet(
"embedding_factory", EmbeddingFactory
)
client = EmbeddingEngine(
model_name=embeddings, vector_store_config={"vector_store_name": space_name}
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config={"vector_store_name": space_name},
embedding_factory=embedding_factory,
)
docs = client.similar_search(query_request.query, query_request.top_k)
res = [
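The module-level `HuggingFaceEmbeddings` singleton is removed and `similar_query` now builds its `EmbeddingEngine` per request from the registered factory. Incidentally, the route declares `{vector_name}` while the handler takes `space_name`, so FastAPI will bind `space_name` as a query parameter and discard the path value. A hedged sketch of calling the endpoint (host, port, and any route prefix are assumptions):

```python
import requests

# Sketch: query a knowledge space for similar chunks. The space name is
# passed as a query parameter because of the path/handler name mismatch.
resp = requests.post(
    "http://127.0.0.1:5000/knowledge/ignored/query",
    params={"space_name": "my_space"},
    json={"query": "what is a vector store?", "top_k": 5},
)
print(resp.json())
```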

View File

@@ -154,6 +154,7 @@ class KnowledgeService:
def sync_knowledge_document(self, space_name, doc_ids):
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from langchain.text_splitter import (
RecursiveCharacterTextSplitter,
SpacyTextSplitter,
@@ -204,6 +205,9 @@ class KnowledgeService:
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
embedding_factory = CFG.SYSTEM_APP.get_componet(
"embedding_factory", EmbeddingFactory
)
client = EmbeddingEngine(
knowledge_source=doc.content,
knowledge_type=doc.doc_type.upper(),
@@ -214,6 +218,7 @@ class KnowledgeService:
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
},
text_splitter=text_splitter,
embedding_factory=embedding_factory,
)
chunk_docs = client.read()
# update document status
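`sync_knowledge_document` keeps its config-driven choice between the two langchain splitters and now also threads the factory into the engine. For reference, a runnable sketch of the recursive splitter with illustrative chunk settings:

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Sketch: split raw text into overlapping chunks before embedding
# (chunk_size/chunk_overlap values are illustrative, not the defaults).
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,     # max characters per chunk
    chunk_overlap=100,  # characters shared between adjacent chunks
)
chunks = text_splitter.split_text("some long document text ... " * 50)
print(len(chunks))
```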

View File

@@ -8,7 +8,7 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
sys.path.append(ROOT_PATH)
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
from pilot.model.cluster import run_worker_manager
CFG = Config()
@@ -21,4 +21,6 @@ if __name__ == "__main__":
model_path=model_path,
standalone=True,
port=CFG.MODEL_PORT,
embedding_model_name=CFG.EMBEDDING_MODEL,
embedding_model_path=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
)
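The standalone entrypoint now boots the embedding model next to the LLM inside one worker manager. The same call with concrete values (paths and port are illustrative; the keyword names come from the hunk):

```python
# Sketch: controller + LLM worker + embedding worker in a single process.
run_worker_manager(
    model_name="vicuna-13b-v1.5",
    model_path="/app/models/vicuna-13b-v1.5",
    standalone=True,
    port=8000,
    embedding_model_name="text2vec",
    embedding_model_path="/app/models/text2vec-large-chinese",
)
```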

View File

@@ -2,6 +2,7 @@ import json
import uuid
from pilot.common.schema import DBType
from pilot.componet import SystemApp
from pilot.configs.config import Config
from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
@@ -26,16 +27,19 @@ class DBSummaryClient:
, get_similar_tables method(get user query related tables info)
"""
def __init__(self):
pass
def __init__(self, system_app: SystemApp):
self.system_app = system_app
def db_summary_embedding(self, dbname, db_type):
"""put db profile and table profile summary into vector store"""
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.embedding_engine.string_embedding import StringEmbedding
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
db_summary_client = RdbmsSummary(dbname, db_type)
embeddings = HuggingFaceEmbeddings(
embedding_factory = self.system_app.get_componet(
"embedding_factory", EmbeddingFactory
)
embeddings = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
vector_store_config = {
@@ -83,15 +87,20 @@ class DBSummaryClient:
def get_db_summary(self, dbname, query, topk):
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
vector_store_config = {
"vector_store_name": dbname + "_profile",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
embedding_factory = CFG.SYSTEM_APP.get_componet(
"embedding_factory", EmbeddingFactory
)
knowledge_embedding_client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
embedding_factory=embedding_factory,
)
table_docs = knowledge_embedding_client.similar_search(query, topk)
ans = [d.page_content for d in table_docs]
@@ -100,6 +109,7 @@ class DBSummaryClient:
def get_similar_tables(self, dbname, query, topk):
"""get user query related tables info"""
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
vector_store_config = {
"vector_store_name": dbname + "_summary",
@@ -107,9 +117,13 @@ class DBSummaryClient:
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
embedding_factory = CFG.SYSTEM_APP.get_componet(
"embedding_factory", EmbeddingFactory
)
knowledge_embedding_client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
embedding_factory=embedding_factory,
)
if CFG.SUMMARY_CONFIG == "FAST":
table_docs = knowledge_embedding_client.similar_search(query, topk)
@@ -136,6 +150,7 @@ class DBSummaryClient:
knowledge_embedding_client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
embedding_factory=embedding_factory,
)
table_summery = knowledge_embedding_client.similar_search(query, 1)
related_table_summaries.append(table_summery[0].page_content)
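With the factory injected everywhere, this file drops its direct `HuggingFaceEmbeddings` import entirely. One inconsistency worth flagging: `db_summary_embedding` resolves the factory from the injected `self.system_app`, while `get_db_summary` and `get_similar_tables` reach for the `CFG.SYSTEM_APP` global, so those methods still silently depend on `server_init` having populated it. A sketch of the repeated lookup-and-build pattern:

```python
# Sketch of the pattern used throughout DBSummaryClient; consistently using
# the injected self.system_app instead of the CFG.SYSTEM_APP global would
# make the dependency explicit in all three methods.
embedding_factory = self.system_app.get_componet(
    "embedding_factory", EmbeddingFactory
)
engine = EmbeddingEngine(
    model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
    vector_store_config=vector_store_config,
    embedding_factory=embedding_factory,
)
table_docs = engine.similar_search(query, topk)
```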

View File

@@ -15,60 +15,60 @@ def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
return None
def _build_request(self, func, path, method, *args, **kwargs):
return_type = get_type_hints(func).get("return")
if return_type is None:
raise TypeError("Return type must be annotated in the decorated function.")
actual_dataclass = _extract_dataclass_from_generic(return_type)
logging.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
if not actual_dataclass:
actual_dataclass = return_type
sig = signature(func)
base_url = self.base_url # Get base_url from class instance
bound = sig.bind(self, *args, **kwargs)
bound.apply_defaults()
formatted_url = base_url + path.format(**bound.arguments)
# Extract args names from signature, except "self"
arg_names = list(sig.parameters.keys())[1:]
# Combine args and kwargs into a single dictionary
combined_args = dict(zip(arg_names, args))
combined_args.update(kwargs)
request_data = {}
for key, value in combined_args.items():
if is_dataclass(value):
# Here, instead of adding it as a nested dictionary,
# we set request_data directly to its dictionary representation.
request_data = asdict(value)
else:
request_data[key] = value
request_params = {"method": method, "url": formatted_url}
if method in ["POST", "PUT", "PATCH"]:
request_params["json"] = request_data
else: # For GET, DELETE, etc.
request_params["params"] = request_data
logging.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
return return_type, actual_dataclass, request_params
def _api_remote(path, method="GET"):
def decorator(func):
return_type = get_type_hints(func).get("return")
if return_type is None:
raise TypeError("Return type must be annotated in the decorated function.")
actual_dataclass = _extract_dataclass_from_generic(return_type)
logging.debug(
f"return_type: {return_type}, actual_dataclass: {actual_dataclass}"
)
if not actual_dataclass:
actual_dataclass = return_type
sig = signature(func)
async def wrapper(self, *args, **kwargs):
import httpx
base_url = self.base_url # Get base_url from class instance
bound = sig.bind(self, *args, **kwargs)
bound.apply_defaults()
formatted_url = base_url + path.format(**bound.arguments)
# Extract args names from signature, except "self"
arg_names = list(sig.parameters.keys())[1:]
# Combine args and kwargs into a single dictionary
combined_args = dict(zip(arg_names, args))
combined_args.update(kwargs)
request_data = {}
for key, value in combined_args.items():
if is_dataclass(value):
# Here, instead of adding it as a nested dictionary,
# we set request_data directly to its dictionary representation.
request_data = asdict(value)
else:
request_data[key] = value
request_params = {"method": method, "url": formatted_url}
if method in ["POST", "PUT", "PATCH"]:
request_params["json"] = request_data
else: # For GET, DELETE, etc.
request_params["params"] = request_data
logging.info(
    f"request_params: {request_params}, args: {args}, kwargs: {kwargs}"
)
return_type, actual_dataclass, request_params = _build_request(
    self, func, path, method, *args, **kwargs
)
async with httpx.AsyncClient() as client:
response = await client.request(**request_params)
if response.status_code == 200:
return _parse_response(
response.json(), return_type, actual_dataclass
@@ -82,6 +82,28 @@ def _api_remote(path, method="GET"):
return decorator
def _sync_api_remote(path, method="GET"):
def decorator(func):
def wrapper(self, *args, **kwargs):
import requests
return_type, actual_dataclass, request_params = _build_request(
self, func, path, method, *args, **kwargs
)
response = requests.request(**request_params)
if response.status_code == 200:
return _parse_response(response.json(), return_type, actual_dataclass)
else:
error_msg = f"Remote request error, error code: {response.status_code}, error msg: {response.text}"
raise Exception(error_msg)
return wrapper
return decorator
def _parse_response(json_response, return_type, actual_dataclass):
# print(f'return_type.__origin__: {return_type.__origin__}, actual_dataclass: {actual_dataclass}, json_response: {json_response}')
if is_dataclass(actual_dataclass):
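The refactor hoists the request construction shared with the async decorator into `_build_request`, letting the new `_sync_api_remote` be a thin `requests`-based mirror of the `httpx` path. A hedged sketch of a client built on it (the class, endpoint, and dataclass are invented for illustration; only the decorator contract comes from the diff):

```python
from dataclasses import dataclass
from typing import List

@dataclass
class ModelInstanceInfo:
    model_name: str
    host: str
    port: int

class ControllerApiClient:
    """Hypothetical client for illustration only."""

    def __init__(self, base_url: str):
        self.base_url = base_url  # _build_request reads this attribute

    @_sync_api_remote(path="/api/instances/{model_name}", method="GET")
    def get_instances(self, model_name: str) -> List[ModelInstanceInfo]:
        # Never executed: the decorator formats the path from the bound
        # arguments, issues the request, and parses the JSON response into
        # the annotated return type.
        ...
```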

View File

@@ -30,8 +30,9 @@ def knownledge_tovec_st(filename):
https://github.com/UKPLab/sentence-transformers
"""
from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
from pilot.embedding_engine.embedding_factory import DefaultEmbeddingFactory
embeddings = HuggingFaceEmbeddings(
embeddings = DefaultEmbeddingFactory().create(
model_name=EMBEDDING_MODEL_CONFIG["sentence-transforms"]
)
@@ -58,8 +59,9 @@ def load_knownledge_from_doc():
)
from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
from pilot.embedding_engine.embedding_factory import DefaultEmbeddingFactory
embeddings = HuggingFaceEmbeddings(
embeddings = DefaultEmbeddingFactory().create(
model_name=EMBEDDING_MODEL_CONFIG["sentence-transforms"]
)

View File

@@ -44,7 +44,11 @@ class KnownLedge2Vector:
def __init__(self, model_name=None) -> None:
if not model_name:
# use default embedding model
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
from pilot.embedding_engine.embedding_factory import DefaultEmbeddingFactory
self.embeddings = DefaultEmbeddingFactory().create(
model_name=self.model_name
)
def init_vector_store(self):
persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
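The legacy knowledge helpers get the same seam: locally they fall back to `DefaultEmbeddingFactory` (assumed here to wrap `HuggingFaceEmbeddings`), and cluster deployments can swap in `RemoteEmbeddingFactory` without touching these call sites. A sketch under that assumption:

```python
from pilot.embedding_engine.embedding_factory import DefaultEmbeddingFactory

# Sketch: local embedding creation through the factory seam (model path is
# illustrative; embed_query is the standard langchain Embeddings method).
embeddings = DefaultEmbeddingFactory().create(
    model_name="/app/models/text2vec-large-chinese"
)
vector = embeddings.embed_query("hello world")
print(len(vector))  # embedding dimensionality
```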