mirror of https://github.com/csunny/DB-GPT.git

feat(model): multi-model supports embedding model and simple component design implementation
@@ -4,9 +4,10 @@ services:
  controller:
    image: eosphorosai/dbgpt:latest
    command: dbgpt start controller
    restart: unless-stopped
    networks:
      - dbgptnet
  worker:
  llm-worker:
    image: eosphorosai/dbgpt:latest
    command: dbgpt start worker --model_name vicuna-13b-v1.5 --model_path /app/models/vicuna-13b-v1.5 --port 8001 --controller_addr http://controller:8000
    environment:
@@ -17,6 +18,27 @@ services:
      - /data:/data
      # Please modify it to your own model directory
      - /data/models:/app/models
    restart: unless-stopped
    networks:
      - dbgptnet
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              capabilities: [gpu]
  embedding-worker:
    image: eosphorosai/dbgpt:latest
    command: dbgpt start worker --model_name text2vec --worker_type text2vec --model_path /app/models/text2vec-large-chinese --port 8002 --controller_addr http://controller:8000
    environment:
      - DBGPT_LOG_LEVEL=DEBUG
    depends_on:
      - controller
    volumes:
      - /data:/data
      # Please modify it to your own model directory
      - /data/models:/app/models
    restart: unless-stopped
    networks:
      - dbgptnet
    deploy:
@@ -37,7 +59,8 @@ services:
      - MODEL_SERVER=http://controller:8000
    depends_on:
      - controller
      - worker
      - llm-worker
      - embedding-worker
    volumes:
      - /data:/data
      # Please modify it to your own model directory
BIN  docs/_static/img/muti-model-cluster-overview.png  vendored  Normal file
Binary file not shown. (New image, 361 KiB)
@@ -9,6 +9,7 @@ DB-GPT product is a Web application that you can chat database, chat knowledge,
- docker
- docker_compose
- environment
- cluster deployment
- deploy_faq

.. toctree::
@@ -20,6 +21,7 @@ DB-GPT product is a Web application that you can chat database, chat knowledge,
   ./install/deploy/deploy.md
   ./install/docker/docker.md
   ./install/docker_compose/docker_compose.md
   ./install/cluster/cluster.rst
   ./install/llm/llm.rst
   ./install/environment/environment.md
   ./install/faq/deploy_faq.md
19  docs/getting_started/install/cluster/cluster.rst  Normal file
@@ -0,0 +1,19 @@
Cluster deployment
==================================

In order to deploy DB-GPT to multiple nodes, you can deploy a cluster. The cluster architecture diagram is as follows:

.. raw:: html

    <img src="../../../_static/img/muti-model-cluster-overview.png" />


* On :ref:`Deploying on local machine <local-cluster-index>`. Local cluster deployment.

.. toctree::
   :maxdepth: 2
   :caption: Cluster deployment
   :name: cluster_deploy
   :hidden:

   ./vms/index.md
3  docs/getting_started/install/cluster/kubernetes/index.md  Normal file
@@ -0,0 +1,3 @@
Kubernetes cluster deployment
==================================
(kubernetes-cluster-index)=
@@ -1,6 +1,6 @@
Cluster deployment
Local cluster deployment
==================================

(local-cluster-index)=
## Model cluster deployment


@@ -17,7 +17,7 @@ dbgpt start controller
By default, the Model Controller starts on port 8000.


### Launch Model Worker
### Launch LLM Model Worker

If you are starting `chatglm2-6b`:

@@ -39,6 +39,18 @@ dbgpt start worker --model_name vicuna-13b-v1.5 \

Note: Be sure to use your own model name and model path.

### Launch Embedding Model Worker

```bash
dbgpt start worker --model_name text2vec \
--model_path /app/models/text2vec-large-chinese \
--worker_type text2vec \
--port 8003 \
--controller_addr http://127.0.0.1:8000
```

Note: Be sure to use your own model name and model path.

Check your model:

@@ -51,8 +63,12 @@ You will see the following output:
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
| Model Name      | Model Type | Host       | Port | Healthy | Enabled | Prompt Template | Last Heartbeat             |
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
| chatglm2-6b     | llm        | 172.17.0.6 | 8001 | True    | True    | None            | 2023-08-31T04:48:45.252939 |
| vicuna-13b-v1.5 | llm        | 172.17.0.6 | 8002 | True    | True    | None            | 2023-08-31T04:48:55.136676 |
| chatglm2-6b     | llm        | 172.17.0.2 | 8001 | True    | True    |                 | 2023-09-12T23:04:31.287654 |
| WorkerManager   | service    | 172.17.0.2 | 8001 | True    | True    |                 | 2023-09-12T23:04:31.286668 |
| WorkerManager   | service    | 172.17.0.2 | 8003 | True    | True    |                 | 2023-09-12T23:04:29.845617 |
| WorkerManager   | service    | 172.17.0.2 | 8002 | True    | True    |                 | 2023-09-12T23:04:24.598439 |
| text2vec        | text2vec   | 172.17.0.2 | 8003 | True    | True    |                 | 2023-09-12T23:04:29.844796 |
| vicuna-13b-v1.5 | llm        | 172.17.0.2 | 8002 | True    | True    |                 | 2023-09-12T23:04:24.597775 |
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
```

@@ -69,7 +85,7 @@ MODEL_SERVER=http://127.0.0.1:8000
#### Start the webserver

```bash
python pilot/server/dbgpt_server.py --light
dbgpt start webserver --light
```

`--light` indicates not to start the embedded model service.
@@ -77,7 +93,7 @@ python pilot/server/dbgpt_server.py --light
Alternatively, you can prepend the command with `LLM_MODEL=chatglm2-6b` to start:

```bash
LLM_MODEL=chatglm2-6b python pilot/server/dbgpt_server.py --light
LLM_MODEL=chatglm2-6b dbgpt start webserver --light
```


@@ -101,9 +117,11 @@ Options:
  --help  Show this message and exit.

Commands:
  model      Clients that manage model serving
  start      Start specific server.
  stop       Start specific server.
  install    Install dependencies, plugins, etc.
  knowledge  Knowledge command line tool
  model      Clients that manage model serving
  start      Start specific server.
  stop       Start specific server.
```

**View the `dbgpt start` help**
@@ -146,10 +164,11 @@ Options:
  --model_name TEXT               Model name [required]
  --model_path TEXT               Model path [required]
  --worker_type TEXT              Worker type
  --worker_class TEXT             Model worker class, pilot.model.worker.defau
                                  lt_worker.DefaultModelWorker
  --worker_class TEXT             Model worker class,
                                  pilot.model.cluster.DefaultModelWorker
  --host TEXT                     Model worker deploy host [default: 0.0.0.0]
  --port INTEGER                  Model worker deploy port [default: 8000]
  --port INTEGER                  Model worker deploy port [default: 8001]
  --daemon                        Run Model Worker in background
  --limit_model_concurrency INTEGER
                                  Model concurrency limit [default: 5]
  --standalone                    Standalone mode. If True, embedded Run
@@ -166,7 +185,7 @@ Options:
                                  (seconds) [default: 20]
  --device TEXT                   Device to run model. If None, the device is
                                  automatically determined
  --model_type TEXT               Model type, huggingface or llama.cpp
  --model_type TEXT               Model type, huggingface, llama.cpp and proxy
                                  [default: huggingface]
  --prompt_template TEXT          Prompt template. If None, the prompt
                                  template is automatically determined from
@@ -208,10 +227,13 @@ Usage: dbgpt model [OPTIONS] COMMAND [ARGS]...

Options:
  --address TEXT  Address of the Model Controller to connect to. Just support
                  light deploy model [default: http://127.0.0.1:8000]
                  light deploy model, If the environment variable
                  CONTROLLER_ADDRESS is configured, read from the environment
                  variable
  --help          Show this message and exit.

Commands:
  chat     Interact with your bot from the command line
  list     List model instances
  restart  Restart model instances
  start    Start model instances
@@ -6,6 +6,7 @@ DB-GPT provides a management and deployment solution for multiple models. This c


Multi LLMs Support, Supports multiple large language models, currently supporting
- 🔥 Baichuan2(7b,13b)
- 🔥 Vicuna-v1.5(7b,13b)
- 🔥 llama-2(7b,13b,70b)
- WizardLM-v1.2(13b)
@@ -19,7 +20,6 @@ Multi LLMs Support, Supports multiple large language models, currently supportin

- llama_cpp
- quantization
- cluster deployment

.. toctree::
   :maxdepth: 2
@@ -29,4 +29,3 @@ Multi LLMs Support, Supports multiple large language models, currently supportin

   ./llama/llama_cpp.md
   ./quantization/quantization.md
   ./cluster/model_cluster.md
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.5\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-08-16 18:31+0800\n"
"POT-Creation-Date: 2023-09-13 09:06+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
@@ -19,34 +19,38 @@ msgstr ""
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"

#: ../../getting_started/install.rst:2 ../../getting_started/install.rst:14
#: 2861085e63144eaca1bb825e5f05d089
#: ../../getting_started/install.rst:2 ../../getting_started/install.rst:15
#: e2c13385046b4da6b6838db6ba2ea59c
msgid "Install"
msgstr "Install"

#: ../../getting_started/install.rst:3 01a6603d91fa4520b0f839379d4eda23
#: ../../getting_started/install.rst:3 3cb6cd251ed440dabe5d4f556435f405
msgid ""
"DB-GPT product is a Web application that you can chat database, chat "
"knowledge, text2dashboard."
msgstr "DB-GPT 可以生成sql,智能报表, 知识库问答的产品"

#: ../../getting_started/install.rst:8 beca85cddc9b4406aecf83d5dfcce1f7
#: ../../getting_started/install.rst:8 6fe8104b70d24f5fbfe2ad9ebf3bc3ba
msgid "deploy"
msgstr "部署"

#: ../../getting_started/install.rst:9 601e9b9eb91f445fb07d2f1c807f0370
#: ../../getting_started/install.rst:9 e67974b3672346809febf99a3b9a55d3
msgid "docker"
msgstr "docker"

#: ../../getting_started/install.rst:10 6d1e094ac9284458a32a3e7fa6241c81
#: ../../getting_started/install.rst:10 64de16a047c74598966e19a656bf6c4f
msgid "docker_compose"
msgstr "docker_compose"

#: ../../getting_started/install.rst:11 ff1d1c60bbdc4e8ca82b7a9f303dd167
#: ../../getting_started/install.rst:11 9f87d65e8675435b87cb9376a5bfd85c
msgid "environment"
msgstr "environment"

#: ../../getting_started/install.rst:12 33bfbe8defd74244bfc24e8fbfd640f6
#: ../../getting_started/install.rst:12 e60fa13bb24544ed9d4f902337093ebc
msgid "cluster deployment"
msgstr "集群部署"

#: ../../getting_started/install.rst:13 7451712679c2412e858e7d3e2af6b174
msgid "deploy_faq"
msgstr "deploy_faq"
@@ -0,0 +1,42 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) 2023, csunny
# This file is distributed under the same license as the DB-GPT package.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.6\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-09-13 10:11+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
"Language-Team: zh_CN <LL@li.org>\n"
"Plural-Forms: nplurals=1; plural=0;\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"

#: ../../getting_started/install/cluster/cluster.rst:2
#: ../../getting_started/install/cluster/cluster.rst:13
#: 69804208b580447798d6946150da7bdf
msgid "Cluster deployment"
msgstr "集群部署"

#: ../../getting_started/install/cluster/cluster.rst:4
#: fa3e4e0ae60a45eb836bcd256baa9d91
msgid ""
"In order to deploy DB-GPT to multiple nodes, you can deploy a cluster. "
"The cluster architecture diagram is as follows:"
msgstr "为了能将 DB-GPT 部署到多个节点上,你可以部署一个集群,集群的架构图如下:"

#: ../../getting_started/install/cluster/cluster.rst:11
#: e739449099ca43cabe9883233ca7e572
#, fuzzy
msgid ""
"On :ref:`Deploying on local machine <local-cluster-index>`. Local cluster"
" deployment."
msgstr "关于 :ref:`在本地机器上部署 <local-cluster-index>`。本地集群部署。"
@@ -0,0 +1,26 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) 2023, csunny
# This file is distributed under the same license as the DB-GPT package.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.6\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-09-13 09:06+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
"Language-Team: zh_CN <LL@li.org>\n"
"Plural-Forms: nplurals=1; plural=0;\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"

#: ../../getting_started/install/cluster/kubernetes/index.md:1
#: 48e6f08f27c74f31a8b12758fe33dc24
msgid "Kubernetes cluster deployment"
msgstr "Kubernetes 集群部署"
@@ -0,0 +1,176 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) 2023, csunny
# This file is distributed under the same license as the DB-GPT package.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.6\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-09-13 09:06+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
"Language-Team: zh_CN <LL@li.org>\n"
"Plural-Forms: nplurals=1; plural=0;\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"

#: ../../getting_started/install/cluster/vms/index.md:1
#: 2d2e04ba49364eae9b8493bb274765a6
msgid "Local cluster deployment"
msgstr "本地集群部署"

#: ../../getting_started/install/cluster/vms/index.md:4
#: e405d0e7ad8c4b2da4b4ca27c77f5fea
msgid "Model cluster deployment"
msgstr "模型集群部署"

#: ../../getting_started/install/cluster/vms/index.md:7
#: bba397ddac754a2bab8edca163875b65
msgid "**Installing Command-Line Tool**"
msgstr "**安装命令行工具**"

#: ../../getting_started/install/cluster/vms/index.md:9
#: bc45851124354522af8c9bb9748ff1fa
msgid ""
"All operations below are performed using the `dbgpt` command. To use the "
"`dbgpt` command, you need to install the DB-GPT project with `pip install"
" -e .`. Alternatively, you can use `python pilot/scripts/cli_scripts.py` "
"as a substitute for the `dbgpt` command."
msgstr ""
"以下所有操作都使用 `dbgpt` 命令完成。要使用 `dbgpt` 命令,您需要安装DB-GPT项目,方法是使用`pip install -e .`。或者,您可以使用 `python pilot/scripts/cli_scripts.py` 作为 `dbgpt` 命令的替代。"

#: ../../getting_started/install/cluster/vms/index.md:11
#: 9d11f7807fd140c8949b634700adc966
msgid "Launch Model Controller"
msgstr "启动 Model Controller"

#: ../../getting_started/install/cluster/vms/index.md:17
#: 97716be92ba64ce9a215433bddf77add
msgid "By default, the Model Controller starts on port 8000."
msgstr "默认情况下,Model Controller 启动在 8000 端口。"

#: ../../getting_started/install/cluster/vms/index.md:20
#: 3f65e6a1e59248a59c033891d1ab7ba8
msgid "Launch LLM Model Worker"
msgstr "启动 LLM Model Worker"

#: ../../getting_started/install/cluster/vms/index.md:22
#: 60241d97573e4265b7fb150c378c4a08
msgid "If you are starting `chatglm2-6b`:"
msgstr "如果您启动的是 `chatglm2-6b`:"

#: ../../getting_started/install/cluster/vms/index.md:31
#: 18bbeb1de110438fa96dd5c736b9a7b1
msgid "If you are starting `vicuna-13b-v1.5`:"
msgstr "如果您启动的是 `vicuna-13b-v1.5`:"

#: ../../getting_started/install/cluster/vms/index.md:40
#: ../../getting_started/install/cluster/vms/index.md:53
#: 24b1a27313c64224aaeab6cbfad1fe19 fc94a698a7904c6893eef7e7a6e52972
msgid "Note: Be sure to use your own model name and model path."
msgstr "注意:确保使用您自己的模型名称和模型路径。"

#: ../../getting_started/install/cluster/vms/index.md:42
#: 19746195e85f4784bf66a9e67378c04b
msgid "Launch Embedding Model Worker"
msgstr "启动 Embedding Model Worker"

#: ../../getting_started/install/cluster/vms/index.md:55
#: e93ce68091f64d0294b3f912a66cc18b
msgid "Check your model:"
msgstr "检查您的模型:"

#: ../../getting_started/install/cluster/vms/index.md:61
#: fa0b8f3a18fe4bab88fbf002bf26d32e
msgid "You will see the following output:"
msgstr "您将看到以下输出:"

#: ../../getting_started/install/cluster/vms/index.md:75
#: 695262fb4f224101902bc7865ac7871f
msgid "Connect to the model service in the webserver (dbgpt_server)"
msgstr "在 webserver (dbgpt_server) 中连接到模型服务"

#: ../../getting_started/install/cluster/vms/index.md:77
#: 73bf4c2ae5c64d938e3b7e77c06fa21e
msgid ""
"**First, modify the `.env` file to change the model name and the Model "
"Controller connection address.**"
msgstr ""
"**首先,修改 `.env` 文件以更改模型名称和模型控制器连接地址。**"

#: ../../getting_started/install/cluster/vms/index.md:85
#: 8ab126fd72ed4368a79b821ba50e62c8
msgid "Start the webserver"
msgstr "启动 webserver"

#: ../../getting_started/install/cluster/vms/index.md:91
#: 5a7e25c84ca2412bb64310bfad9e2403
msgid "`--light` indicates not to start the embedded model service."
msgstr "`--light` 表示不启动嵌入式模型服务。"

#: ../../getting_started/install/cluster/vms/index.md:93
#: 8cd9ec4fa9cb4c0fa8ff05c05a85ea7f
msgid ""
"Alternatively, you can prepend the command with `LLM_MODEL=chatglm2-6b` "
"to start:"
msgstr ""
"或者,您可以在命令前加上 `LLM_MODEL=chatglm2-6b` 来启动:"

#: ../../getting_started/install/cluster/vms/index.md:100
#: 13ed16758a104860b5fc982d36638b17
msgid "More Command-Line Usages"
msgstr "更多命令行用法"

#: ../../getting_started/install/cluster/vms/index.md:102
#: 175f614d547a4391bab9a77762f9174e
msgid "You can view more command-line usages through the help command."
msgstr "您可以通过帮助命令查看更多命令行用法。"

#: ../../getting_started/install/cluster/vms/index.md:104
#: 6a4475d271c347fbbb35f2936a86823f
msgid "**View the `dbgpt` help**"
msgstr "**查看 `dbgpt` 帮助**"

#: ../../getting_started/install/cluster/vms/index.md:109
#: 3eb11234cf504cc9ac369d8462daa14b
msgid "You will see the basic command parameters and usage:"
msgstr "您将看到基本的命令参数和用法:"

#: ../../getting_started/install/cluster/vms/index.md:127
#: 6eb47aecceec414e8510fe022b6fddbd
msgid "**View the `dbgpt start` help**"
msgstr "**查看 `dbgpt start` 帮助**"

#: ../../getting_started/install/cluster/vms/index.md:133
#: 1f4c0a4ce0704ca8ac33178bd13c69ad
msgid "Here you can see the related commands and usage for start:"
msgstr "在这里,您可以看到启动的相关命令和用法:"

#: ../../getting_started/install/cluster/vms/index.md:150
#: 22e8e67bc55244e79764d091f334560b
msgid "**View the `dbgpt start worker`help**"
msgstr "**查看 `dbgpt start worker` 帮助**"

#: ../../getting_started/install/cluster/vms/index.md:156
#: 5631b83fda714780855e99e90d4eb542
msgid "Here you can see the parameters to start Model Worker:"
msgstr "在这里,您可以看到启动 Model Worker 的参数:"

#: ../../getting_started/install/cluster/vms/index.md:215
#: cf4a31fd3368481cba1b3ab382615f53
msgid "**View the `dbgpt model`help**"
msgstr "**查看 `dbgpt model` 帮助**"

#: ../../getting_started/install/cluster/vms/index.md:221
#: 3740774ec4b240f2882b5b59da224d55
msgid ""
"The `dbgpt model ` command can connect to the Model Controller via the "
"Model Controller address and then manage a remote model:"
msgstr ""
"`dbgpt model` 命令可以通过 Model Controller 地址连接到 Model Controller,然后管理远程模型:"
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.5\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-08-31 16:38+0800\n"
"POT-Creation-Date: 2023-09-13 10:46+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
@@ -21,84 +21,88 @@ msgstr ""

#: ../../getting_started/install/llm/llm.rst:2
#: ../../getting_started/install/llm/llm.rst:24
#: b348d4df8ca44dd78b42157a8ff6d33d
#: e693a8d3769b4d9e99c4442ca77dc43c
msgid "LLM Usage"
msgstr "LLM使用"

#: ../../getting_started/install/llm/llm.rst:3 7f5960a7e5634254b330da27be87594b
#: ../../getting_started/install/llm/llm.rst:3 0a73562d18ba455bab04277b715c3840
msgid ""
"DB-GPT provides a management and deployment solution for multiple models."
" This chapter mainly discusses how to deploy different models."
msgstr "DB-GPT提供了多模型的管理和部署方案,本章主要讲解针对不同的模型该怎么部署"

#: ../../getting_started/install/llm/llm.rst:18
#: b844ab204ec740ec9d7d191bb841f09e
#: ../../getting_started/install/llm/llm.rst:19
#: d7e4de2a7e004888897204ec76b6030b
msgid ""
"Multi LLMs Support, Supports multiple large language models, currently "
"supporting"
msgstr "目前DB-GPT已适配如下模型"

#: ../../getting_started/install/llm/llm.rst:9 c141437ddaf84c079360008343041b2f
#: ../../getting_started/install/llm/llm.rst:9 4616886b8b2244bd93355e871356d89e
#, fuzzy
msgid "🔥 Baichuan2(7b,13b)"
msgstr "Baichuan(7b,13b)"

#: ../../getting_started/install/llm/llm.rst:10
#: ad0e4793d4e744c1bdf59f5a3d9c84be
msgid "🔥 Vicuna-v1.5(7b,13b)"
msgstr "🔥 Vicuna-v1.5(7b,13b)"

#: ../../getting_started/install/llm/llm.rst:10
#: d32b1e3f114c4eab8782b497097c1b37
#: ../../getting_started/install/llm/llm.rst:11
#: d291e58001ae487bbbf2a1f9f889f5fd
msgid "🔥 llama-2(7b,13b,70b)"
msgstr "🔥 llama-2(7b,13b,70b)"

#: ../../getting_started/install/llm/llm.rst:11
#: 0a417ee4d008421da07fff7add5d05eb
#: ../../getting_started/install/llm/llm.rst:12
#: 1e49702ee40b4655945a2a13efaad536
msgid "WizardLM-v1.2(13b)"
msgstr "WizardLM-v1.2(13b)"

#: ../../getting_started/install/llm/llm.rst:12
#: 199e1a9fe3324dc8a1bcd9cd0b1ef047
#: ../../getting_started/install/llm/llm.rst:13
#: 4ef5913ddfe840d7a12289e6e1d4cb60
msgid "Vicuna (7b,13b)"
msgstr "Vicuna (7b,13b)"

#: ../../getting_started/install/llm/llm.rst:13
#: a9e4c5100534450db3a583fa5850e4be
#: ../../getting_started/install/llm/llm.rst:14
#: ea46c2211257459285fa48083cb59561
msgid "ChatGLM-6b (int4,int8)"
msgstr "ChatGLM-6b (int4,int8)"

#: ../../getting_started/install/llm/llm.rst:14
#: 943324289eb94042b52fd824189cd93f
#: ../../getting_started/install/llm/llm.rst:15
#: 90688302bae4452a84f14e8ecb7f1a21
msgid "ChatGLM2-6b (int4,int8)"
msgstr "ChatGLM2-6b (int4,int8)"

#: ../../getting_started/install/llm/llm.rst:15
#: f1226fdfac3b4e9d88642ffa69d75682
#: ../../getting_started/install/llm/llm.rst:16
#: ee1469545a314696a36e7296c7b71960
msgid "guanaco(7b,13b,33b)"
msgstr "guanaco(7b,13b,33b)"

#: ../../getting_started/install/llm/llm.rst:16
#: 3f2457f56eb341b6bc431c9beca8f4df
#: ../../getting_started/install/llm/llm.rst:17
#: 25abad241f4d4eee970d5938bf71311f
msgid "Gorilla(7b,13b)"
msgstr "Gorilla(7b,13b)"

#: ../../getting_started/install/llm/llm.rst:17
#: 86c8ce37be1c4a7ea3fc382100d77a9c
#: ../../getting_started/install/llm/llm.rst:18
#: 8e3d0399431a4c6a9065a8ae0ad3c8ac
msgid "Baichuan(7b,13b)"
msgstr "Baichuan(7b,13b)"

#: ../../getting_started/install/llm/llm.rst:18
#: 538111af95ad414cb2e631a89f9af379
#: ../../getting_started/install/llm/llm.rst:19
#: c285fa7c9c6c4e3e9840761a09955348
msgid "OpenAI"
msgstr "OpenAI"

#: ../../getting_started/install/llm/llm.rst:20
#: a203325b7ec248f7bff61ae89226a000
#: ../../getting_started/install/llm/llm.rst:21
#: 4ac13a21f323455982750bd2e0243b72
msgid "llama_cpp"
msgstr "llama_cpp"

#: ../../getting_started/install/llm/llm.rst:21
#: 21a50634198047228bc51a03d2c31292
#: ../../getting_started/install/llm/llm.rst:22
#: 7231edceef584724a6f569c6b363e083
msgid "quantization"
msgstr "quantization"

#: ../../getting_started/install/llm/llm.rst:22
#: dfaec4b04e6e45ff9c884b41534b1a79
msgid "cluster deployment"
msgstr ""
#~ msgid "cluster deployment"
#~ msgstr ""
@@ -1,4 +1,4 @@
autodoc_pydantic==1.8.0
autodoc_pydantic
myst_parser
nbsphinx==0.8.9
sphinx==4.5.0
143  pilot/componet.py  Normal file
@@ -0,0 +1,143 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Type, Dict, TypeVar, Optional, TYPE_CHECKING
import asyncio

# Checking for type hints during runtime
if TYPE_CHECKING:
    from fastapi import FastAPI


class LifeCycle:
    """This class defines hooks for lifecycle events of a component."""

    def before_start(self):
        """Called before the component starts."""
        pass

    async def async_before_start(self):
        """Asynchronous version of before_start."""
        pass

    def after_start(self):
        """Called after the component has started."""
        pass

    async def async_after_start(self):
        """Asynchronous version of after_start."""
        pass

    def before_stop(self):
        """Called before the component stops."""
        pass

    async def async_before_stop(self):
        """Asynchronous version of before_stop."""
        pass


class BaseComponet(LifeCycle, ABC):
    """Abstract Base Component class. All custom components should extend this."""

    name = "base_dbgpt_componet"

    def __init__(self, system_app: Optional[SystemApp] = None):
        if system_app is not None:
            self.init_app(system_app)

    @abstractmethod
    def init_app(self, system_app: SystemApp):
        """Initialize the component with the main application.

        This method needs to be implemented by every component to define how it
        integrates with the main system app.
        """
        pass


T = TypeVar("T", bound=BaseComponet)


class SystemApp(LifeCycle):
    """Main System Application class that manages the lifecycle and registration of components."""

    def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None:
        self.componets: Dict[
            str, BaseComponet
        ] = {}  # Dictionary to store registered components.
        self._asgi_app = asgi_app

    @property
    def app(self) -> Optional["FastAPI"]:
        """Returns the internal ASGI app."""
        return self._asgi_app

    def register(self, componet: Type[BaseComponet], *args, **kwargs):
        """Register a new component by its type."""
        instance = componet(self, *args, **kwargs)
        self.register_instance(instance)

    def register_instance(self, instance: T):
        """Register an already initialized component."""
        self.componets[instance.name] = instance
        instance.init_app(self)

    def get_componet(self, name: str, componet_type: Type[T]) -> T:
        """Retrieve a registered component by its name and type."""
        component = self.componets.get(name)
        if not component:
            raise ValueError(f"No component found with name {name}")
        if not isinstance(component, componet_type):
            raise TypeError(f"Component {name} is not of type {componet_type}")
        return component

    def before_start(self):
        """Invoke the before_start hooks for all registered components."""
        for _, v in self.componets.items():
            v.before_start()

    async def async_before_start(self):
        """Asynchronously invoke the before_start hooks for all registered components."""
        tasks = [v.async_before_start() for _, v in self.componets.items()]
        await asyncio.gather(*tasks)

    def after_start(self):
        """Invoke the after_start hooks for all registered components."""
        for _, v in self.componets.items():
            v.after_start()

    async def async_after_start(self):
        """Asynchronously invoke the after_start hooks for all registered components."""
        tasks = [v.async_after_start() for _, v in self.componets.items()]
        await asyncio.gather(*tasks)

    def before_stop(self):
        """Invoke the before_stop hooks for all registered components."""
        for _, v in self.componets.items():
            try:
                v.before_stop()
            except Exception as e:
                pass

    async def async_before_stop(self):
        """Asynchronously invoke the before_stop hooks for all registered components."""
        tasks = [v.async_before_stop() for _, v in self.componets.items()]
        await asyncio.gather(*tasks)

    def _build(self):
        """Integrate lifecycle events with the internal ASGI app if available."""
        if not self.app:
            return

        @self.app.on_event("startup")
        async def startup_event():
            """ASGI app startup event handler."""
            asyncio.create_task(self.async_after_start())
            self.after_start()

        @self.app.on_event("shutdown")
        async def shutdown_event():
            """ASGI app shutdown event handler."""
            await self.async_before_stop()
            self.before_stop()
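To make the component design above concrete, here is a minimal, hypothetical usage sketch; `CacheComponet` and its `get` method are illustrative names, not part of this commit:

```python
from pilot.componet import BaseComponet, SystemApp


class CacheComponet(BaseComponet):
    # Hypothetical component; the class attribute `name` is the registry key.
    name = "cache_componet"

    def __init__(self, system_app: SystemApp = None):
        self._data = {}
        super().__init__(system_app)

    def init_app(self, system_app: SystemApp):
        # Wire the component into the main application here if needed.
        self._system_app = system_app

    def get(self, key: str):
        # Illustrative business method, not part of the commit.
        return self._data.get(key)


system_app = SystemApp()
# register() instantiates the class and passes the SystemApp to its constructor.
system_app.register(CacheComponet)
# Later, look the instance up by its name and expected type.
cache = system_app.get_componet("cache_componet", CacheComponet)
```

Note that `register` both constructs the component and calls `init_app`, so components receive the `SystemApp` without relying on global state.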
@@ -189,6 +189,10 @@ class Config(metaclass=Singleton):
        ### Log level
        self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")

        from pilot.componet import SystemApp

        self.SYSTEM_APP: SystemApp = None

    def set_debug_mode(self, value: bool) -> None:
        """Set the debug mode value"""
        self.debug_mode = value
@@ -4,6 +4,7 @@ import asyncio
from pilot.configs.config import Config
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
from pilot.common.schema import DBType
from pilot.componet import SystemApp
from pilot.connections.rdbms.conn_mysql import MySQLConnect
from pilot.connections.base import BaseConnect

@@ -46,9 +47,9 @@ class ConnectManager:
            raise ValueError("Unsupport Db Type!" + db_type)
        return result

    def __init__(self):
    def __init__(self, system_app: SystemApp):
        self.storage = DuckdbConnectConfig()
        self.db_summary_client = DBSummaryClient()
        self.db_summary_client = DBSummaryClient(system_app)
        self.__load_config_db()

    def __load_config_db(self):
@@ -2,6 +2,6 @@ from pilot.configs.config import Config
from pilot.connections.manages.connection_manager import ConnectManager

if __name__ == "__main__":
    mange = ConnectManager()
    mange = ConnectManager(system_app=None)
    types = mange.get_all_completed_types()
    print(str(types))
@@ -1,9 +1,12 @@
from typing import Optional

from chromadb.errors import NotEnoughElementsException
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import TextSplitter

from pilot.embedding_engine.embedding_factory import (
    EmbeddingFactory,
    DefaultEmbeddingFactory,
)
from pilot.embedding_engine.knowledge_type import get_knowledge_embedding, KnowledgeType
from pilot.vector_store.connector import VectorStoreConnector

@@ -24,13 +27,16 @@ class EmbeddingEngine:
        knowledge_source: Optional[str] = None,
        source_reader: Optional = None,
        text_splitter: Optional[TextSplitter] = None,
        embedding_factory: EmbeddingFactory = None,
    ):
        """Initialize with knowledge embedding client, model_name, vector_store_config, knowledge_type, knowledge_source"""
        self.knowledge_source = knowledge_source
        self.model_name = model_name
        self.vector_store_config = vector_store_config
        self.knowledge_type = knowledge_type
        self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
        if not embedding_factory:
            embedding_factory = DefaultEmbeddingFactory()
        self.embeddings = embedding_factory.create(model_name=self.model_name)
        self.vector_store_config["embeddings"] = self.embeddings
        self.source_reader = source_reader
        self.text_splitter = text_splitter
39  pilot/embedding_engine/embedding_factory.py  Normal file
@@ -0,0 +1,39 @@
from abc import ABC, abstractmethod
from typing import Any, Type, TYPE_CHECKING

from pilot.componet import BaseComponet

if TYPE_CHECKING:
    from langchain.embeddings.base import Embeddings


class EmbeddingFactory(BaseComponet, ABC):
    name = "embedding_factory"

    @abstractmethod
    def create(
        self, model_name: str = None, embedding_cls: Type = None
    ) -> "Embeddings":
        """Create embedding"""


class DefaultEmbeddingFactory(EmbeddingFactory):
    def __init__(self, system_app=None, model_name: str = None, **kwargs: Any) -> None:
        super().__init__(system_app=system_app)
        self._default_model_name = model_name
        self.kwargs = kwargs

    def init_app(self, system_app):
        pass

    def create(
        self, model_name: str = None, embedding_cls: Type = None
    ) -> "Embeddings":
        if not model_name:
            model_name = self._default_model_name
        if embedding_cls:
            return embedding_cls(model_name=model_name, **self.kwargs)
        else:
            from langchain.embeddings import HuggingFaceEmbeddings

            return HuggingFaceEmbeddings(model_name=model_name, **self.kwargs)
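As a rough sketch of the factory in isolation (the model path is a placeholder; substitute your own local directory):

```python
from pilot.embedding_engine.embedding_factory import DefaultEmbeddingFactory

# Placeholder model path; any HuggingFace-compatible embedding model works.
factory = DefaultEmbeddingFactory(model_name="/app/models/text2vec-large-chinese")
embeddings = factory.create()  # falls back to the default model name
vectors = embeddings.embed_documents(["hello DB-GPT"])
```

`EmbeddingEngine` above uses exactly this fallback when no factory is passed in, which preserves the old `HuggingFaceEmbeddings` behavior while letting callers inject a different implementation.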
@@ -96,13 +96,18 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:

def _dynamic_model_parser() -> Callable[[None], List[Type]]:
    from pilot.utils.parameter_utils import _SimpleArgParser
    from pilot.model.parameter import EmbeddingModelParameters, WorkerType

    pre_args = _SimpleArgParser("model_name", "model_path")
    pre_args = _SimpleArgParser("model_name", "model_path", "worker_type")
    pre_args.parse()
    model_name = pre_args.get("model_name")
    model_path = pre_args.get("model_path")
    worker_type = pre_args.get("worker_type")
    if model_name is None:
        return None
    if worker_type == WorkerType.TEXT2VEC:
        return [EmbeddingModelParameters]

    llm_adapter = get_llm_model_adapter(model_name, model_path)
    param_class = llm_adapter.model_param_class()
    return [param_class]
@@ -167,7 +167,6 @@ def stop(model_name: str, model_type: str, host: str, port: int):


def _remote_model_dynamic_factory() -> Callable[[None], List[Type]]:
    from pilot.model.adapter import _dynamic_model_parser
    from pilot.utils.parameter_utils import _SimpleArgParser
    from pilot.model.cluster import RemoteWorkerManager
    from pilot.model.parameter import WorkerType
@@ -5,6 +5,9 @@ from pilot.model.cluster.base import (
    WorkerParameterRequest,
    WorkerStartupRequest,
)
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.worker.default_worker import DefaultModelWorker

from pilot.model.cluster.worker.manager import (
    initialize_worker_manager_in_client,
    run_worker_manager,
@@ -23,11 +26,15 @@ __all__ = [
    "EmbeddingsRequest",
    "PromptRequest",
    "WorkerApplyRequest",
    "WorkerParameterRequest"
    "WorkerStartupRequest"
    "worker_manager"
    "WorkerParameterRequest",
    "WorkerStartupRequest",
    "ModelWorker",
    "DefaultModelWorker",
    "worker_manager",
    "run_worker_manager",
    "initialize_worker_manager_in_client",
    "ModelRegistry",
    "ModelRegistryClient" "RemoteWorkerManager" "run_model_controller",
    "ModelRegistryClient",
    "RemoteWorkerManager",
    "run_model_controller",
]
@@ -8,7 +8,10 @@ from pilot.model.base import ModelInstance
from pilot.model.parameter import ModelControllerParameters
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
from pilot.utils.parameter_utils import EnvArgumentParser
from pilot.utils.api_utils import _api_remote as api_remote
from pilot.utils.api_utils import (
    _api_remote as api_remote,
    _sync_api_remote as sync_api_remote,
)


class BaseModelController(ABC):
@@ -89,6 +92,12 @@ class ModelRegistryClient(_RemoteModelController, ModelRegistry):
    async def get_all_model_instances(self) -> List[ModelInstance]:
        return await self.get_all_instances()

    @sync_api_remote(path="/api/controller/models")
    def sync_get_all_instances(
        self, model_name: str, healthy_only: bool = False
    ) -> List[ModelInstance]:
        pass


class ModelControllerAdapter(BaseModelController):
    def __init__(self, backend: BaseModelController = None) -> None:
0  pilot/model/cluster/embedding/__init__.py  Normal file
28  pilot/model/cluster/embedding/remote_embedding.py  Normal file
@@ -0,0 +1,28 @@
from typing import List
from langchain.embeddings.base import Embeddings

from pilot.model.cluster.manager_base import WorkerManager


class RemoteEmbeddings(Embeddings):
    def __init__(self, model_name: str, worker_manager: WorkerManager) -> None:
        self.model_name = model_name
        self.worker_manager = worker_manager

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
        params = {"model": self.model_name, "input": texts}
        return self.worker_manager.sync_embeddings(params)

    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
        return self.embed_documents([text])[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed search docs."""
        params = {"model": self.model_name, "input": texts}
        return await self.worker_manager.embeddings(params)

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed query text."""
        return (await self.aembed_documents([text]))[0]
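A short, hypothetical sketch of why this adapter exists: it lets a remote embedding worker stand in wherever a synchronous langchain `Embeddings` object is expected. The `worker_manager` below is assumed to be an already started manager with a `text2vec` worker registered (see the worker manager changes later in this commit):

```python
from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings

# Assumption: `worker_manager` is a started WorkerManager with a text2vec worker.
embeddings = RemoteEmbeddings("text2vec", worker_manager)

# Synchronous path, e.g. for third-party code that calls embed_documents() directly:
vectors = embeddings.embed_documents(["DB-GPT supports embedding models"])
query_vector = embeddings.embed_query("which models are supported?")
```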
@@ -1,6 +1,6 @@
import asyncio
from dataclasses import dataclass
from typing import List, Optional, Dict, Iterator
from typing import List, Optional, Dict, Iterator, Callable
from abc import ABC, abstractmethod
from datetime import datetime
from concurrent.futures import Future
@@ -35,15 +35,31 @@ class WorkerManager(ABC):
    async def stop(self):
        """Stop worker manager"""

    @abstractmethod
    def after_start(self, listener: Callable[["WorkerManager"], None]):
        """Add a listener that is called after WorkerManager startup"""

    @abstractmethod
    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Asynchronously get model instances by worker type and model name"""

    @abstractmethod
    def sync_get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Get model instances by worker type and model name"""

    @abstractmethod
    async def select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        """Asynchronously select one instance"""

    @abstractmethod
    def sync_select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        """Select one instance"""

@@ -69,7 +85,15 @@ class WorkerManager(ABC):

    @abstractmethod
    async def embeddings(self, params: Dict) -> List[List[float]]:
        """Embed input"""
        """Asynchronously embed input"""

    @abstractmethod
    def sync_embeddings(self, params: Dict) -> List[List[float]]:
        """Embed input.

        This function may be passed to third-party systems that only make
        synchronous calls, so a synchronous version must be provided.
        """

    @abstractmethod
    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
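For reference, the new `after_start` hook takes a callback that receives the manager once startup has completed; a minimal, hypothetical listener:

```python
# Hypothetical listener; the manager invokes it at the end of start().
def on_started(worker_manager: "WorkerManager"):
    print("WorkerManager started and workers are registered")

worker_manager.after_start(on_started)
```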
@@ -58,6 +58,12 @@ class ModelRegistry(ABC):
        - List[ModelInstance]: A list of instances for the given model.
        """

    @abstractmethod
    def sync_get_all_instances(
        self, model_name: str, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """Fetch all instances of a given model. Optionally, fetch only the healthy instances."""

    @abstractmethod
    async def get_all_model_instances(self) -> List[ModelInstance]:
        """
@@ -163,6 +169,11 @@ class EmbeddedModelRegistry(ModelRegistry):

    async def get_all_instances(
        self, model_name: str, healthy_only: bool = False
    ) -> List[ModelInstance]:
        return self.sync_get_all_instances(model_name, healthy_only)

    def sync_get_all_instances(
        self, model_name: str, healthy_only: bool = False
    ) -> List[ModelInstance]:
        instances = self.registry[model_name]
        if healthy_only:
@@ -179,7 +190,7 @@ class EmbeddedModelRegistry(ModelRegistry):
        )
        if not exist_ins:
            # register a new instance from heartbeat
            self.register_instance(instance)
            await self.register_instance(instance)
            return True

        ins = exist_ins[0]
@@ -24,7 +24,7 @@ class EmbeddingsModelWorker(ModelWorker):
                "Could not import langchain.embeddings.HuggingFaceEmbeddings python package. "
                "Please install it with `pip install langchain`."
            ) from exc
        self.embeddings: Embeddings = None
        self._embeddings_impl: Embeddings = None
        self._model_params = None

    def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
@@ -75,16 +75,16 @@ class EmbeddingsModelWorker(ModelWorker):

        kwargs = model_params.build_kwargs(model_name=model_params.model_path)
        logger.info(f"Start HuggingFaceEmbeddings with kwargs: {kwargs}")
        self.embeddings = HuggingFaceEmbeddings(**kwargs)
        self._embeddings_impl = HuggingFaceEmbeddings(**kwargs)

    def __del__(self):
        self.stop()

    def stop(self) -> None:
        if not self.embeddings:
        if not self._embeddings_impl:
            return
        del self.embeddings
        self.embeddings = None
        del self._embeddings_impl
        self._embeddings_impl = None
        _clear_torch_cache(self._model_params.device)

    def generate_stream(self, params: Dict):
@@ -96,5 +96,7 @@ class EmbeddingsModelWorker(ModelWorker):
        raise NotImplementedError("Not supported generate for embeddings model")

    def embeddings(self, params: Dict) -> List[List[float]]:
        model = params.get("model")
        logger.info(f"Receive embeddings request, model: {model}")
        input: List[str] = params["input"]
        return self.embeddings.embed_documents(input)
        return self._embeddings_impl.embed_documents(input)
@@ -72,6 +72,7 @@ class LocalWorkerManager(WorkerManager):
        self.model_registry = model_registry
        self.host = host
        self.port = port
        self.start_listeners = []

        self.run_data = WorkerRunData(
            host=self.host,
@@ -105,6 +106,8 @@ class LocalWorkerManager(WorkerManager):
        asyncio.create_task(
            _async_heartbeat_sender(self.run_data, 20, self.send_heartbeat_func)
        )
        for listener in self.start_listeners:
            listener(self)

    async def stop(self):
        if not self.run_data.stop_event.is_set():
@@ -116,6 +119,9 @@ class LocalWorkerManager(WorkerManager):
            stop_tasks.append(self.deregister_func(self.run_data))
        await asyncio.gather(*stop_tasks)

    def after_start(self, listener: Callable[["WorkerManager"], None]):
        self.start_listeners.append(listener)

    def add_worker(
        self,
        worker: ModelWorker,
@@ -137,14 +143,7 @@ class LocalWorkerManager(WorkerManager):
        worker_key = self._worker_key(
            worker_params.worker_type, worker_params.model_name
        )
        host = worker_params.host
        port = worker_params.port

        instances = self.workers.get(worker_key)
        if not instances:
            instances = []
            self.workers[worker_key] = instances
            logger.info(f"Init empty instances list for {worker_key}")
        # Load model params from persist storage
        model_params = worker.parse_parameters(command_args=command_args)

@@ -159,14 +158,15 @@ class LocalWorkerManager(WorkerManager):
            semaphore=asyncio.Semaphore(worker_params.limit_model_concurrency),
            command_args=command_args,
        )
        exist_instances = [
            ins for ins in instances if ins.host == host and ins.port == port
        ]
        if not exist_instances:
            instances.append(worker_run_data)
        instances = self.workers.get(worker_key)
        if not instances:
            instances = [worker_run_data]
            self.workers[worker_key] = instances
            logger.info(f"Init empty instances list for {worker_key}")
            return True
        else:
            # TODO Update worker
            logger.warn(f"Instance {worker_key} exist")
            return False

    async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
@@ -222,16 +222,18 @@ class LocalWorkerManager(WorkerManager):

    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        return self.sync_get_model_instances(worker_type, model_name, healthy_only)

    def sync_get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        worker_key = self._worker_key(worker_type, model_name)
        return self.workers.get(worker_key)

    async def select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    def _simple_select(
        self, worker_type: str, model_name: str, worker_instances: List[WorkerRunData]
    ) -> WorkerRunData:
        worker_instances = await self.get_model_instances(
            worker_type, model_name, healthy_only
        )
        if not worker_instances:
            raise Exception(
                f"Could not find worker instances for model name {model_name} and worker type {worker_type}"
@@ -239,12 +241,34 @@ class LocalWorkerManager(WorkerManager):
        worker_run_data = random.choice(worker_instances)
        return worker_run_data

    async def select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        worker_instances = await self.get_model_instances(
            worker_type, model_name, healthy_only
        )
        return self._simple_select(worker_type, model_name, worker_instances)

    def sync_select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        worker_instances = self.sync_get_model_instances(
            worker_type, model_name, healthy_only
        )
        return self._simple_select(worker_type, model_name, worker_instances)

    async def _get_model(self, params: Dict, worker_type: str = "llm") -> WorkerRunData:
        model = params.get("model")
        if not model:
            raise Exception("Model name cannot be empty")
        return await self.select_one_instance(worker_type, model, healthy_only=True)

    def _sync_get_model(self, params: Dict, worker_type: str = "llm") -> WorkerRunData:
        model = params.get("model")
        if not model:
            raise Exception("Model name cannot be empty")
        return self.sync_select_one_instance(worker_type, model, healthy_only=True)

    async def generate_stream(
        self, params: Dict, async_wrapper=None, **kwargs
    ) -> Iterator[ModelOutput]:
@@ -304,6 +328,10 @@ class LocalWorkerManager(WorkerManager):
            worker_run_data.worker.embeddings, params
        )

    def sync_embeddings(self, params: Dict) -> List[List[float]]:
        worker_run_data = self._sync_get_model(params, worker_type="text2vec")
        return worker_run_data.worker.embeddings(params)

    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        apply_func: Callable[[WorkerApplyRequest], Awaitable[str]] = None
        if apply_req.apply_type == WorkerApplyType.START:
@@ -458,6 +486,10 @@ class WorkerManagerAdapter(WorkerManager):
    async def stop(self):
        return await self.worker_manager.stop()

    def after_start(self, listener: Callable[["WorkerManager"], None]):
        if listener is not None:
            self.worker_manager.after_start(listener)

    async def supported_models(self) -> List[WorkerSupportedModel]:
        return await self.worker_manager.supported_models()

@@ -474,6 +506,13 @@ class WorkerManagerAdapter(WorkerManager):
            worker_type, model_name, healthy_only
        )

    def sync_get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        return self.worker_manager.sync_get_model_instances(
            worker_type, model_name, healthy_only
        )

    async def select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
@@ -481,6 +520,13 @@ class WorkerManagerAdapter(WorkerManager):
            worker_type, model_name, healthy_only
        )

    def sync_select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        return self.worker_manager.sync_select_one_instance(
            worker_type, model_name, healthy_only
        )

    async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
        async for output in self.worker_manager.generate_stream(params, **kwargs):
            yield output
@@ -491,6 +537,9 @@ class WorkerManagerAdapter(WorkerManager):
    async def embeddings(self, params: Dict) -> List[List[float]]:
        return await self.worker_manager.embeddings(params)

    def sync_embeddings(self, params: Dict) -> List[List[float]]:
        return self.worker_manager.sync_embeddings(params)

    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        return await self.worker_manager.worker_apply(apply_req)

@@ -586,11 +635,11 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):

    @app.on_event("startup")
    async def startup_event():
        asyncio.create_task(worker_manager.worker_manager.start())
        asyncio.create_task(worker_manager.start())

    @app.on_event("shutdown")
    async def startup_event():
        await worker_manager.worker_manager.stop()
        await worker_manager.stop()

    return app
@@ -666,29 +715,60 @@ def _create_local_model_manager(
|
||||
|
||||
|
||||
def _build_worker(worker_params: ModelWorkerParameters):
|
||||
if worker_params.worker_class:
|
||||
worker_class = worker_params.worker_class
|
||||
if worker_class:
|
||||
from pilot.utils.module_utils import import_from_checked_string
|
||||
|
||||
worker_cls = import_from_checked_string(worker_params.worker_class, ModelWorker)
|
||||
logger.info(
|
||||
f"Import worker class from {worker_params.worker_class} successfully"
|
||||
)
|
||||
worker: ModelWorker = worker_cls()
|
||||
worker_cls = import_from_checked_string(worker_class, ModelWorker)
|
||||
logger.info(f"Import worker class from {worker_class} successfully")
|
||||
else:
|
||||
from pilot.model.cluster.worker.default_worker import DefaultModelWorker
|
||||
if (
|
||||
worker_params.worker_type is None
|
||||
or worker_params.worker_type == WorkerType.LLM
|
||||
):
|
||||
from pilot.model.cluster.worker.default_worker import DefaultModelWorker
|
||||
|
||||
worker = DefaultModelWorker()
|
||||
return worker
|
||||
worker_cls = DefaultModelWorker
|
||||
elif worker_params.worker_type == WorkerType.TEXT2VEC:
|
||||
from pilot.model.cluster.worker.embedding_worker import (
|
||||
EmbeddingsModelWorker,
|
||||
)
|
||||
|
||||
worker_cls = EmbeddingsModelWorker
|
||||
else:
|
||||
raise Exception("Unsupported worker type: {worker_params.worker_type}")
|
||||
|
||||
return worker_cls()
|
||||
|
||||
|
||||
def _start_local_worker(
    worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters
):
    worker = _build_worker(worker_params)
    worker_manager.worker_manager = _create_local_model_manager(worker_params)
    if not worker_manager.worker_manager:
        worker_manager.worker_manager = _create_local_model_manager(worker_params)
    worker_manager.worker_manager.add_worker(worker, worker_params)


def _start_local_embedding_worker(
    worker_manager: WorkerManagerAdapter,
    embedding_model_name: str = None,
    embedding_model_path: str = None,
):
    if not embedding_model_name or not embedding_model_path:
        return
    embedding_worker_params = ModelWorkerParameters(
        model_name=embedding_model_name,
        model_path=embedding_model_path,
        worker_type=WorkerType.TEXT2VEC,
        worker_class="pilot.model.cluster.worker.embedding_worker.EmbeddingsModelWorker",
    )
    logger.info(
        f"Start local embedding worker with embedding parameters\n{embedding_worker_params}"
    )
    _start_local_worker(worker_manager, embedding_worker_params)
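Together, _build_worker and _start_local_embedding_worker let one process host an LLM worker and an embedding worker side by side: the worker_type field decides which ModelWorker subclass is instantiated. A rough sketch of that dispatch (model names and paths are placeholders):

    # Sketch only: placeholder model names and paths.
    llm_params = ModelWorkerParameters(
        model_name="vicuna-13b-v1.5",
        model_path="/app/models/vicuna-13b-v1.5",
    )  # worker_type unset, so _build_worker picks DefaultModelWorker
    embedding_params = ModelWorkerParameters(
        model_name="text2vec",
        model_path="/app/models/text2vec-large-chinese",
        worker_type=WorkerType.TEXT2VEC,
    )  # _build_worker picks EmbeddingsModelWorker
    for params in (llm_params, embedding_params):
        _start_local_worker(worker_manager, params)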
def initialize_worker_manager_in_client(
    app=None,
    include_router: bool = True,
@@ -697,6 +777,9 @@ def initialize_worker_manager_in_client(
    run_locally: bool = True,
    controller_addr: str = None,
    local_port: int = 5000,
    embedding_model_name: str = None,
    embedding_model_path: str = None,
    start_listener: Callable[["WorkerManager"], None] = None,
):
    """Initialize WorkerManager in client.
    If run_locally is True:
@@ -728,6 +811,10 @@ def initialize_worker_manager_in_client(
        logger.info(f"Worker params: {worker_params}")
        _setup_fastapi(worker_params, app)
        _start_local_worker(worker_manager, worker_params)
        worker_manager.after_start(start_listener)
        _start_local_embedding_worker(
            worker_manager, embedding_model_name, embedding_model_path
        )
    else:
        from pilot.model.cluster.controller.controller import (
            ModelRegistryClient,
@@ -741,9 +828,12 @@ def initialize_worker_manager_in_client(
        logger.info(f"Worker params: {worker_params}")
        client = ModelRegistryClient(worker_params.controller_addr)
        worker_manager.worker_manager = RemoteWorkerManager(client)
        worker_manager.after_start(start_listener)
        initialize_controller(
            app=app, remote_controller_addr=worker_params.controller_addr
        )
        loop = asyncio.get_event_loop()
        loop.run_until_complete(worker_manager.start())

    if include_router and app:
        # mount WorkerManager router
@@ -757,6 +847,8 @@ def run_worker_manager(
    model_path: str = None,
    standalone: bool = False,
    port: int = None,
    embedding_model_name: str = None,
    embedding_model_path: str = None,
):
    global worker_manager

@@ -765,15 +857,22 @@ def run_worker_manager(
    )

    embedded_mod = True
    logger.info(f"Worker params: {worker_params}")
    if not app:
        # Run worker manager independently
        embedded_mod = False
        app = _setup_fastapi(worker_params)
        _start_local_worker(worker_manager, worker_params)
        _start_local_embedding_worker(
            worker_manager, embedding_model_name, embedding_model_path
        )
    else:
        _start_local_worker(worker_manager, worker_params)
        _start_local_embedding_worker(
            worker_manager, embedding_model_name, embedding_model_path
        )
    loop = asyncio.get_event_loop()
    loop.run_until_complete(worker_manager.worker_manager.start())
    loop.run_until_complete(worker_manager.start())

    if include_router:
        app.include_router(router, prefix="/api")
@@ -1,14 +1,12 @@
from typing import Callable, Any
import httpx
import asyncio
from typing import Any, Callable

import httpx
from pilot.model.base import ModelInstance, WorkerApplyOutput, WorkerSupportedModel
from pilot.model.cluster.base import *
from pilot.model.cluster.registry import ModelRegistry
from pilot.model.cluster.worker.manager import LocalWorkerManager, WorkerRunData, logger
from pilot.model.cluster.base import *
from pilot.model.base import (
    ModelInstance,
    WorkerApplyOutput,
    WorkerSupportedModel,
)
from pilot.model.cluster.worker.remote_worker import RemoteModelWorker


class RemoteWorkerManager(LocalWorkerManager):
@@ -16,7 +14,8 @@ class RemoteWorkerManager(LocalWorkerManager):
        super().__init__(model_registry=model_registry)

    async def start(self):
        pass
        for listener in self.start_listeners:
            listener(self)

    async def stop(self):
        pass
@@ -125,15 +124,9 @@ class RemoteWorkerManager(LocalWorkerManager):
            success_handler=lambda x: True,
        )

    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    def _build_worker_instances(
        self, model_name: str, instances: List[ModelInstance]
    ) -> List[WorkerRunData]:
        from pilot.model.cluster.worker.remote_worker import RemoteModelWorker

        worker_key = self._worker_key(worker_type, model_name)
        instances: List[ModelInstance] = await self.model_registry.get_all_instances(
            worker_key, healthy_only
        )
        worker_instances = []
        for ins in instances:
            worker = RemoteModelWorker()
@@ -151,6 +144,24 @@ class RemoteWorkerManager(LocalWorkerManager):
            worker_instances.append(wr)
        return worker_instances

    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        worker_key = self._worker_key(worker_type, model_name)
        instances: List[ModelInstance] = await self.model_registry.get_all_instances(
            worker_key, healthy_only
        )
        return self._build_worker_instances(model_name, instances)

    def sync_get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        worker_key = self._worker_key(worker_type, model_name)
        instances: List[ModelInstance] = self.model_registry.sync_get_all_instances(
            worker_key, healthy_only
        )
        return self._build_worker_instances(model_name, instances)

    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        async def _remote_apply_func(worker_run_data: WorkerRunData):
            return await self._fetch_from_worker(
@@ -87,7 +87,15 @@ class RemoteModelWorker(ModelWorker):

    def embeddings(self, params: Dict) -> List[List[float]]:
        """Get embeddings for input"""
        raise NotImplementedError
        import requests

        response = requests.post(
            self.worker_addr + "/embeddings",
            headers=self.headers,
            json=params,
            timeout=self.timeout,
        )
        return response.json()

    async def async_embeddings(self, params: Dict) -> List[List[float]]:
        """Asynchronously get embeddings for input"""
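RemoteModelWorker.embeddings now forwards the request to the worker's HTTP endpoint instead of raising NotImplementedError. Roughly, the proxy issues a request like the following; the base address and payload shape here are illustrative assumptions:

    # Sketch only: assumed worker address and payload shape.
    import requests

    worker_addr = "http://127.0.0.1:8002/api/worker"  # assumed base address
    response = requests.post(
        worker_addr + "/embeddings",
        json={"model": "text2vec", "input": ["some text"]},
        timeout=60,
    )
    vectors = response.json()  # List[List[float]]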
@@ -46,9 +46,7 @@ class ModelWorkerParameters(BaseModelParameters):
    )
    worker_class: Optional[str] = field(
        default=None,
        metadata={
            "help": "Model worker class, pilot.model.worker.default_worker.DefaultModelWorker"
        },
        metadata={"help": "Model worker class, pilot.model.cluster.DefaultModelWorker"},
    )
    host: Optional[str] = field(
        default="0.0.0.0", metadata={"help": "Model worker deploy host"}
@@ -111,7 +111,7 @@ async def db_connect_delete(db_name: str = None):

async def async_db_summary_embedding(db_name, db_type):
    # Run the code that needs to execute asynchronously here
    db_summary_client = DBSummaryClient()
    db_summary_client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
    db_summary_client.db_summary_embedding(db_name, db_type)
@@ -61,7 +61,7 @@ class ChatDashboard(BaseChat):
        except ImportError:
            raise ValueError("Could not import DBSummaryClient. ")

        client = DBSummaryClient()
        client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
        try:
            table_infos = client.get_similar_tables(
                dbname=self.db_name, query=self.current_user_input, topk=self.top_k

@@ -35,7 +35,7 @@ class ChatWithDbAutoExecute(BaseChat):
            from pilot.summary.db_summary_client import DBSummaryClient
        except ImportError:
            raise ValueError("Could not import DBSummaryClient. ")
        client = DBSummaryClient()
        client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
        try:
            table_infos = client.get_db_summary(
                dbname=self.db_name,

@@ -41,7 +41,7 @@ class ChatWithDbQA(BaseChat):
        except ImportError:
            raise ValueError("Could not import DBSummaryClient. ")
        if self.db_name:
            client = DBSummaryClient()
            client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
            try:
                table_infos = client.get_db_summary(
                    dbname=self.db_name, query=self.current_user_input, topk=self.top_k
@@ -23,6 +23,7 @@ class ChatKnowledge(BaseChat):
    def __init__(self, chat_session_id, user_input, select_param: str = None):
        """ """
        from pilot.embedding_engine.embedding_engine import EmbeddingEngine
        from pilot.embedding_engine.embedding_factory import EmbeddingFactory

        self.knowledge_space = select_param
        super().__init__(
@@ -47,9 +48,13 @@ class ChatKnowledge(BaseChat):
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
            "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
        }
        embedding_factory = CFG.SYSTEM_APP.get_componet(
            "embedding_factory", EmbeddingFactory
        )
        self.knowledge_embedding_client = EmbeddingEngine(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
            embedding_factory=embedding_factory,
        )

    def generate_input_values(self):
@@ -2,10 +2,11 @@ import signal
import os
import threading
import sys
from typing import Optional
from typing import Optional, Any
from dataclasses import dataclass, field

from pilot.configs.config import Config
from pilot.componet import SystemApp
from pilot.utils.parameter_utils import BaseParameters


@@ -18,30 +19,28 @@ def signal_handler(sig, frame):
    os._exit(0)


def async_db_summery():
def async_db_summery(system_app: SystemApp):
    from pilot.summary.db_summary_client import DBSummaryClient

    client = DBSummaryClient()
    client = DBSummaryClient(system_app=system_app)
    thread = threading.Thread(target=client.init_db_summary)
    thread.start()


def server_init(args):
def server_init(args, system_app: SystemApp):
    from pilot.commands.command_mange import CommandRegistry
    from pilot.connections.manages.connection_manager import ConnectManager

    from pilot.common.plugins import scan_plugins

    # logger.info(f"args: {args}")

    # init config
    cfg = Config()
    # init connect manage
    conn_manage = ConnectManager()
    cfg.LOCAL_DB_MANAGE = conn_manage
    cfg.SYSTEM_APP = system_app

    # load_native_plugins(cfg)
    signal.signal(signal.SIGINT, signal_handler)
    async_db_summery()

    cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))

    # Load plugins and commands
@@ -70,6 +69,22 @@ def server_init(args):
    cfg.command_disply = command_disply_registry


def _create_model_start_listener(system_app: SystemApp):
    from pilot.connections.manages.connection_manager import ConnectManager
    from pilot.model.cluster import worker_manager

    cfg = Config()

    def startup_event(wh):
        # init connect manage
        print("begin run _add_app_startup_event")
        conn_manage = ConnectManager(system_app)
        cfg.LOCAL_DB_MANAGE = conn_manage
        async_db_summery(system_app)

    return startup_event


@dataclass
class WebWerverParameters(BaseParameters):
    host: Optional[str] = field(
38 pilot/server/componet_configs.py Normal file
@@ -0,0 +1,38 @@
from typing import Any, Type, TYPE_CHECKING

from pilot.componet import SystemApp
from pilot.embedding_engine.embedding_factory import EmbeddingFactory

if TYPE_CHECKING:
    from langchain.embeddings.base import Embeddings


def initialize_componets(system_app: SystemApp, embedding_model_name: str):
    from pilot.model.cluster import worker_manager

    system_app.register(
        RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
    )


class RemoteEmbeddingFactory(EmbeddingFactory):
    def __init__(
        self, system_app, worker_manager, model_name: str = None, **kwargs: Any
    ) -> None:
        super().__init__(system_app=system_app)
        self._worker_manager = worker_manager
        self._default_model_name = model_name
        self.kwargs = kwargs

    def init_app(self, system_app):
        pass

    def create(
        self, model_name: str = None, embedding_cls: Type = None
    ) -> "Embeddings":
        from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings

        if embedding_cls:
            raise NotImplementedError
        # Ignore model_name args
        return RemoteEmbeddings(self._default_model_name, self._worker_manager)
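This new module wires the embedding factory into the component system: initialize_componets registers RemoteEmbeddingFactory under the "embedding_factory" key so callers can later resolve it from the SystemApp. A short sketch of the lookup-and-use pattern, assuming RemoteEmbeddings implements the LangChain Embeddings interface (embed_query/embed_documents):

    # Sketch only: resolving and using the registered component.
    embedding_factory = CFG.SYSTEM_APP.get_componet(
        "embedding_factory", EmbeddingFactory
    )
    embeddings = embedding_factory.create()  # model_name is ignored by the remote factory
    vector = embeddings.embed_query("show me all tables")  # assumes LangChain interface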
@@ -7,9 +7,15 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
sys.path.append(ROOT_PATH)
import signal
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
from pilot.componet import SystemApp

from pilot.server.base import server_init, WebWerverParameters
from pilot.server.base import (
    server_init,
    WebWerverParameters,
    _create_model_start_listener,
)
from pilot.server.componet_configs import initialize_componets

from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, applications
@@ -48,6 +54,8 @@ def swagger_monkey_patch(*args, **kwargs):
applications.get_swagger_ui_html = swagger_monkey_patch

app = FastAPI()
system_app = SystemApp(app)

origins = ["*"]

# Add CORS middleware
@@ -98,7 +106,12 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
        param = WebWerverParameters(**vars(parser.parse_args(args=args)))

    setup_logging(logging_level=param.log_level)
    server_init(param)
    # Before start
    system_app.before_start()

    server_init(param, system_app)
    model_start_listener = _create_model_start_listener(system_app)
    initialize_componets(system_app, CFG.EMBEDDING_MODEL)

    model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
    if not param.light:
@@ -108,6 +121,9 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
            model_name=CFG.LLM_MODEL,
            model_path=model_path,
            local_port=param.port,
            embedding_model_name=CFG.EMBEDDING_MODEL,
            embedding_model_path=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            start_listener=model_start_listener,
        )

        CFG.NEW_SERVER_MODE = True
@@ -120,6 +136,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
            run_locally=False,
            controller_addr=CFG.MODEL_SERVER,
            local_port=param.port,
            start_listener=model_start_listener,
        )
        CFG.SERVER_LIGHT_MODE = True
@@ -4,8 +4,6 @@ import tempfile

from fastapi import APIRouter, File, UploadFile, Form

from langchain.embeddings import HuggingFaceEmbeddings

from pilot.configs.config import Config
from pilot.configs.model_config import (
    EMBEDDING_MODEL_CONFIG,
@@ -14,6 +12,7 @@ from pilot.configs.model_config import (

from pilot.openapi.api_view_model import Result
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.embedding_factory import EmbeddingFactory

from pilot.server.knowledge.service import KnowledgeService
from pilot.server.knowledge.request.request import (
@@ -32,10 +31,6 @@ CFG = Config()
router = APIRouter()


embeddings = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)

knowledge_space_service = KnowledgeService()


@@ -186,8 +181,13 @@ def document_list(space_name: str, query_request: ChunkQueryRequest):
@router.post("/knowledge/{vector_name}/query")
def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
    print(f"Received params: {space_name}, {query_request}")
    embedding_factory = CFG.SYSTEM_APP.get_componet(
        "embedding_factory", EmbeddingFactory
    )
    client = EmbeddingEngine(
        model_name=embeddings, vector_store_config={"vector_store_name": space_name}
        model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
        vector_store_config={"vector_store_name": space_name},
        embedding_factory=embedding_factory,
    )
    docs = client.similar_search(query_request.query, query_request.top_k)
    res = [
@@ -154,6 +154,7 @@ class KnowledgeService:

    def sync_knowledge_document(self, space_name, doc_ids):
        from pilot.embedding_engine.embedding_engine import EmbeddingEngine
        from pilot.embedding_engine.embedding_factory import EmbeddingFactory
        from langchain.text_splitter import (
            RecursiveCharacterTextSplitter,
            SpacyTextSplitter,
@@ -204,6 +205,9 @@ class KnowledgeService:
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        embedding_factory = CFG.SYSTEM_APP.get_componet(
            "embedding_factory", EmbeddingFactory
        )
        client = EmbeddingEngine(
            knowledge_source=doc.content,
            knowledge_type=doc.doc_type.upper(),
@@ -214,6 +218,7 @@ class KnowledgeService:
                "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
            },
            text_splitter=text_splitter,
            embedding_factory=embedding_factory,
        )
        chunk_docs = client.read()
        # update document status
@@ -8,7 +8,7 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
sys.path.append(ROOT_PATH)

from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
from pilot.model.cluster import run_worker_manager

CFG = Config()
@@ -21,4 +21,6 @@ if __name__ == "__main__":
        model_path=model_path,
        standalone=True,
        port=CFG.MODEL_PORT,
        embedding_model_name=CFG.EMBEDDING_MODEL,
        embedding_model_path=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
    )
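With the two new arguments, a single llmserver process now serves the LLM and the embedding model together. A sketch of the equivalent standalone invocation (the paths are placeholders for your own model directory):

    # Sketch only: placeholder model paths.
    run_worker_manager(
        model_name="vicuna-13b-v1.5",
        model_path="/app/models/vicuna-13b-v1.5",
        standalone=True,
        port=8000,
        embedding_model_name="text2vec",
        embedding_model_path="/app/models/text2vec-large-chinese",
    )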
@@ -2,6 +2,7 @@ import json
import uuid

from pilot.common.schema import DBType
from pilot.componet import SystemApp
from pilot.configs.config import Config
from pilot.configs.model_config import (
    KNOWLEDGE_UPLOAD_ROOT_PATH,
@@ -26,16 +27,19 @@ class DBSummaryClient:
    , get_similar_tables method(get user query related tables info)
    """

    def __init__(self):
        pass
    def __init__(self, system_app: SystemApp):
        self.system_app = system_app

    def db_summary_embedding(self, dbname, db_type):
        """put db profile and table profile summary into vector store"""
        from langchain.embeddings import HuggingFaceEmbeddings
        from pilot.embedding_engine.string_embedding import StringEmbedding
        from pilot.embedding_engine.embedding_factory import EmbeddingFactory

        db_summary_client = RdbmsSummary(dbname, db_type)
        embeddings = HuggingFaceEmbeddings(
        embedding_factory = self.system_app.get_componet(
            "embedding_factory", EmbeddingFactory
        )
        embeddings = embedding_factory.create(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
        )
        vector_store_config = {
@@ -83,15 +87,20 @@ class DBSummaryClient:

    def get_db_summary(self, dbname, query, topk):
        from pilot.embedding_engine.embedding_engine import EmbeddingEngine
        from pilot.embedding_engine.embedding_factory import EmbeddingFactory

        vector_store_config = {
            "vector_store_name": dbname + "_profile",
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
            "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
        }
        embedding_factory = CFG.SYSTEM_APP.get_componet(
            "embedding_factory", EmbeddingFactory
        )
        knowledge_embedding_client = EmbeddingEngine(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
            embedding_factory=embedding_factory,
        )
        table_docs = knowledge_embedding_client.similar_search(query, topk)
        ans = [d.page_content for d in table_docs]
@@ -100,6 +109,7 @@ class DBSummaryClient:
    def get_similar_tables(self, dbname, query, topk):
        """get user query related tables info"""
        from pilot.embedding_engine.embedding_engine import EmbeddingEngine
        from pilot.embedding_engine.embedding_factory import EmbeddingFactory

        vector_store_config = {
            "vector_store_name": dbname + "_summary",
@@ -107,9 +117,13 @@ class DBSummaryClient:
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
            "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
        }
        embedding_factory = CFG.SYSTEM_APP.get_componet(
            "embedding_factory", EmbeddingFactory
        )
        knowledge_embedding_client = EmbeddingEngine(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
            embedding_factory=embedding_factory,
        )
        if CFG.SUMMARY_CONFIG == "FAST":
            table_docs = knowledge_embedding_client.similar_search(query, topk)
@@ -136,6 +150,7 @@ class DBSummaryClient:
            knowledge_embedding_client = EmbeddingEngine(
                model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                vector_store_config=vector_store_config,
                embedding_factory=embedding_factory,
            )
            table_summery = knowledge_embedding_client.similar_search(query, 1)
            related_table_summaries.append(table_summery[0].page_content)
@@ -15,60 +15,60 @@ def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
    return None


def _build_request(self, func, path, method, *args, **kwargs):
    return_type = get_type_hints(func).get("return")
    if return_type is None:
        raise TypeError("Return type must be annotated in the decorated function.")

    actual_dataclass = _extract_dataclass_from_generic(return_type)
    logging.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
    if not actual_dataclass:
        actual_dataclass = return_type
    sig = signature(func)
    base_url = self.base_url  # Get base_url from class instance

    bound = sig.bind(self, *args, **kwargs)
    bound.apply_defaults()

    formatted_url = base_url + path.format(**bound.arguments)

    # Extract args names from signature, except "self"
    arg_names = list(sig.parameters.keys())[1:]

    # Combine args and kwargs into a single dictionary
    combined_args = dict(zip(arg_names, args))
    combined_args.update(kwargs)

    request_data = {}
    for key, value in combined_args.items():
        if is_dataclass(value):
            # Here, instead of adding it as a nested dictionary,
            # we set request_data directly to its dictionary representation.
            request_data = asdict(value)
        else:
            request_data[key] = value

    request_params = {"method": method, "url": formatted_url}

    if method in ["POST", "PUT", "PATCH"]:
        request_params["json"] = request_data
    else:  # For GET, DELETE, etc.
        request_params["params"] = request_data

    logging.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
    return return_type, actual_dataclass, request_params


def _api_remote(path, method="GET"):
    def decorator(func):
        return_type = get_type_hints(func).get("return")
        if return_type is None:
            raise TypeError("Return type must be annotated in the decorated function.")

        actual_dataclass = _extract_dataclass_from_generic(return_type)
        logging.debug(
            f"return_type: {return_type}, actual_dataclass: {actual_dataclass}"
        )
        if not actual_dataclass:
            actual_dataclass = return_type
        sig = signature(func)

        async def wrapper(self, *args, **kwargs):
            import httpx

            base_url = self.base_url  # Get base_url from class instance

            bound = sig.bind(self, *args, **kwargs)
            bound.apply_defaults()

            formatted_url = base_url + path.format(**bound.arguments)

            # Extract args names from signature, except "self"
            arg_names = list(sig.parameters.keys())[1:]

            # Combine args and kwargs into a single dictionary
            combined_args = dict(zip(arg_names, args))
            combined_args.update(kwargs)

            request_data = {}
            for key, value in combined_args.items():
                if is_dataclass(value):
                    # Here, instead of adding it as a nested dictionary,
                    # we set request_data directly to its dictionary representation.
                    request_data = asdict(value)
                else:
                    request_data[key] = value

            request_params = {"method": method, "url": formatted_url}

            if method in ["POST", "PUT", "PATCH"]:
                request_params["json"] = request_data
            else:  # For GET, DELETE, etc.
                request_params["params"] = request_data

            logging.info(
                f"request_params: {request_params}, args: {args}, kwargs: {kwargs}"
            )
            return_type, actual_dataclass, request_params = _build_request(
                self, func, path, method, *args, **kwargs
            )

            async with httpx.AsyncClient() as client:
                response = await client.request(**request_params)

                if response.status_code == 200:
                    return _parse_response(
                        response.json(), return_type, actual_dataclass
                    )
@@ -82,6 +82,28 @@ def _api_remote(path, method="GET"):
    return decorator


def _sync_api_remote(path, method="GET"):
    def decorator(func):
        def wrapper(self, *args, **kwargs):
            import requests

            return_type, actual_dataclass, request_params = _build_request(
                self, func, path, method, *args, **kwargs
            )

            response = requests.request(**request_params)

            if response.status_code == 200:
                return _parse_response(response.json(), return_type, actual_dataclass)
            else:
                error_msg = f"Remote request error, error code: {response.status_code}, error msg: {response.text}"
                raise Exception(error_msg)

        return wrapper

    return decorator
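The refactor moves the request-building logic shared by the async and sync paths into _build_request, so _api_remote and _sync_api_remote stay thin. A sketch of how a client class might use the two decorators; the class, endpoint, and method names here are hypothetical:

    # Sketch only: hypothetical client using the decorators.
    class ExampleRegistryClient:
        def __init__(self, base_url: str):
            self.base_url = base_url

        @_api_remote("/instances", method="GET")
        async def get_instances(self, model_name: str) -> List[ModelInstance]:
            ...  # body never runs; the decorator performs the HTTP call

        @_sync_api_remote("/instances", method="GET")
        def sync_get_instances(self, model_name: str) -> List[ModelInstance]:
            ...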
def _parse_response(json_response, return_type, actual_dataclass):
    # print(f'return_type.__origin__: {return_type.__origin__}, actual_dataclass: {actual_dataclass}, json_response: {json_response}')
    if is_dataclass(actual_dataclass):
@@ -30,8 +30,9 @@ def knownledge_tovec_st(filename):
    https://github.com/UKPLab/sentence-transformers
    """
    from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
    from pilot.embedding_engine.embedding_factory import DefaultEmbeddingFactory

    embeddings = HuggingFaceEmbeddings(
    embeddings = DefaultEmbeddingFactory().create(
        model_name=EMBEDDING_MODEL_CONFIG["sentence-transforms"]
    )

@@ -58,8 +59,9 @@ def load_knownledge_from_doc():
    )

    from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
    from pilot.embedding_engine.embedding_factory import DefaultEmbeddingFactory

    embeddings = HuggingFaceEmbeddings(
    embeddings = DefaultEmbeddingFactory().create(
        model_name=EMBEDDING_MODEL_CONFIG["sentence-transforms"]
    )

@@ -44,7 +44,11 @@ class KnownLedge2Vector:
    def __init__(self, model_name=None) -> None:
        if not model_name:
            # use default embedding model
            self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
            from pilot.embedding_engine.embedding_factory import DefaultEmbeddingFactory

            self.embeddings = DefaultEmbeddingFactory().create(
                model_name=self.model_name
            )

    def init_vector_store(self):
        persist_dir = os.path.join(VECTORE_PATH, ".vectordb")