diff --git a/CODE_OF_CONDUCT b/CODE_OF_CONDUCT
new file mode 100644
index 000000000..b7efcc0b3
--- /dev/null
+++ b/CODE_OF_CONDUCT
@@ -0,0 +1,126 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+We as members, contributors, and leaders pledge to make participation in our
+community a harassment-free experience for everyone, regardless of age, body
+size, visible or invisible disability, ethnicity, sex characteristics, gender
+identity and expression, level of experience, education, socio-economic status,
+nationality, personal appearance, race, caste, color, religion, or sexual
+identity and orientation.
+
+We pledge to act and interact in ways that contribute to an open, welcoming,
+diverse, inclusive, and healthy community.
+
+## Our Standards
+
+Examples of behavior that contributes to a positive environment for our
+community include:
+
+* Demonstrating empathy and kindness toward other people
+* Being respectful of differing opinions, viewpoints, and experiences
+* Giving and gracefully accepting constructive feedback
+* Accepting responsibility and apologizing to those affected by our mistakes,
+ and learning from the experience
+* Focusing on what is best not just for us as individuals, but for the overall
+ community
+
+Examples of unacceptable behavior include:
+
+* The use of sexualized language or imagery, and sexual attention or advances of
+ any kind
+* Trolling, insulting or derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or email address,
+ without their explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Enforcement Responsibilities
+
+Community leaders are responsible for clarifying and enforcing our standards of
+acceptable behavior and will take appropriate and fair corrective action in
+response to any behavior that they deem inappropriate, threatening, offensive,
+or harmful.
+
+Community leaders have the right and responsibility to remove, edit, or reject
+comments, commits, code, wiki edits, issues, and other contributions that are
+not aligned to this Code of Conduct, and will communicate reasons for moderation
+decisions when appropriate.
+
+## Scope
+
+This Code of Conduct applies within all community spaces, and also applies when
+an individual is officially representing the community in public spaces.
+Examples of representing our community include using an official e-mail address,
+posting via an official social media account, or acting as an appointed
+representative at an online or offline event.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported to the community leaders responsible for enforcement at
+[INSERT CONTACT METHOD].
+All complaints will be reviewed and investigated promptly and fairly.
+
+All community leaders are obligated to respect the privacy and security of the
+reporter of any incident.
+
+## Enforcement Guidelines
+
+Community leaders will follow these Community Impact Guidelines in determining
+the consequences for any action they deem in violation of this Code of Conduct:
+
+### 1. Correction
+
+*Community Impact*: Use of inappropriate language or other behavior deemed
+unprofessional or unwelcome in the community.
+
+*Consequence*: A private, written warning from community leaders, providing
+clarity around the nature of the violation and an explanation of why the
+behavior was inappropriate. A public apology may be requested.
+
+### 2. Warning
+
+*Community Impact*: A violation through a single incident or series of
+actions.
+
+*Consequence*: A warning with consequences for continued behavior. No
+interaction with the people involved, including unsolicited interaction with
+those enforcing the Code of Conduct, for a specified period of time. This
+includes avoiding interactions in community spaces as well as external channels
+like social media. Violating these terms may lead to a temporary or permanent
+ban.
+
+### 3. Temporary Ban
+
+*Community Impact*: A serious violation of community standards, including
+sustained inappropriate behavior.
+
+*Consequence*: A temporary ban from any sort of interaction or public
+communication with the community for a specified period of time. No public or
+private interaction with the people involved, including unsolicited interaction
+with those enforcing the Code of Conduct, is allowed during this period.
+Violating these terms may lead to a permanent ban.
+
+### 4. Permanent Ban
+
+*Community Impact*: Demonstrating a pattern of violation of community
+standards, including sustained inappropriate behavior, harassment of an
+individual, or aggression toward or disparagement of classes of individuals.
+
+*Consequence*: A permanent ban from any sort of public interaction within the
+community.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage],
+version 2.1, available at
+[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
+
+Community Impact Guidelines were inspired by
+[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
+
+For answers to common questions about this code of conduct, see the FAQ at
+[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
+[https://www.contributor-covenant.org/translations][translations].
diff --git a/README.md b/README.md
index 7d7af852a..3f1e8109f 100644
--- a/README.md
+++ b/README.md
@@ -43,8 +43,8 @@ DB-GPT is an experimental open-source project that uses localized GPT large mode
## Contents
-- [install](#install)
-- [demo](#demo)
+- [Install](#install)
+- [Demo](#demo)
- [introduction](#introduction)
- [features](#features)
- [contribution](#contribution)
@@ -177,7 +177,7 @@ Currently, we have released multiple key features, which are listed below to dem
| [StarRocks](https://github.com/StarRocks/starrocks) | No | TODO |
## Introduction
-Is the architecture of the entire DB-GPT shown in the following figure:
+The architecture of the entire DB-GPT is shown.
@@ -213,7 +213,7 @@ The core capabilities mainly consist of the following parts:
## Contribution
-- Please run `black .` before submitting the code. Contributing guidelines, [how to contribution](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING.md)
+- Please run `black .` before submitting the code. Contributing guidelines, [how to contribute](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING.md)
## RoadMap
@@ -330,7 +330,7 @@ As of October 10, 2023, by fine-tuning an open-source model of 13 billion parame
The MIT License (MIT)
## Contact Information
-We are working on building a community, if you have any ideas about building the community, feel free to contact us.
+We are working on building a community, if you have any ideas for building the community, feel free to contact us.
[](https://discord.gg/nASQyBjvY)
diff --git a/assets/wechat.jpg b/assets/wechat.jpg
index c7de6223b..ec465785c 100644
Binary files a/assets/wechat.jpg and b/assets/wechat.jpg differ
diff --git a/docker/compose_examples/cluster-docker-compose.yml b/docker/compose_examples/cluster-docker-compose.yml
index b41033458..0ad6be9ae 100644
--- a/docker/compose_examples/cluster-docker-compose.yml
+++ b/docker/compose_examples/cluster-docker-compose.yml
@@ -7,6 +7,16 @@ services:
restart: unless-stopped
networks:
- dbgptnet
+ api-server:
+ image: eosphorosai/dbgpt:latest
+ command: dbgpt start apiserver --controller_addr http://controller:8000
+ restart: unless-stopped
+ depends_on:
+ - controller
+ networks:
+ - dbgptnet
+ ports:
+ - 8100:8100/tcp
llm-worker:
image: eosphorosai/dbgpt:latest
command: dbgpt start worker --model_name vicuna-13b-v1.5 --model_path /app/models/vicuna-13b-v1.5 --port 8001 --controller_addr http://controller:8000
diff --git a/docs/conf.py b/docs/conf.py
index 8601627ca..437aeab9b 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -32,7 +32,7 @@ extensions = [
"sphinx_panels",
"sphinx_tabs.tabs",
"IPython.sphinxext.ipython_console_highlighting",
- 'sphinx.ext.autosectionlabel'
+ "sphinx.ext.autosectionlabel",
]
source_suffix = [".ipynb", ".html", ".md", ".rst"]
diff --git a/docs/getting_started/install/cluster/cluster.rst b/docs/getting_started/install/cluster/cluster.rst
index 93660d0a4..17895e7bc 100644
--- a/docs/getting_started/install/cluster/cluster.rst
+++ b/docs/getting_started/install/cluster/cluster.rst
@@ -77,3 +77,4 @@ By analyzing this information, we can identify performance bottlenecks in model
./vms/standalone.md
./vms/index.md
+ ./openai.md
diff --git a/docs/getting_started/install/cluster/openai.md b/docs/getting_started/install/cluster/openai.md
new file mode 100644
index 000000000..8f23ba0fa
--- /dev/null
+++ b/docs/getting_started/install/cluster/openai.md
@@ -0,0 +1,51 @@
+OpenAI-Compatible RESTful APIs
+==================================
+(openai-apis-index)=
+
+### Install Prepare
+
+You must [deploy DB-GPT cluster](https://db-gpt.readthedocs.io/en/latest/getting_started/install/cluster/vms/index.html) first.
+
+### Launch Model API Server
+
+```bash
+dbgpt start apiserver --controller_addr http://127.0.0.1:8000 --api_keys EMPTY
+```
+By default, the Model API Server starts on port 8100.
+
+### Validate with cURL
+
+#### List models
+
+```bash
+curl http://127.0.0.1:8100/api/v1/models \
+-H "Authorization: Bearer EMPTY" \
+-H "Content-Type: application/json"
+```
+
+#### Chat completions
+
+```bash
+curl http://127.0.0.1:8100/api/v1/chat/completions \
+-H "Authorization: Bearer EMPTY" \
+-H "Content-Type: application/json" \
+-d '{"model": "vicuna-13b-v1.5", "messages": [{"role": "user", "content": "hello"}]}'
+```
+
+### Validate with OpenAI Official SDK
+
+#### Chat completions
+
+```python
+import openai
+openai.api_key = "EMPTY"
+openai.api_base = "http://127.0.0.1:8100/api/v1"
+model = "vicuna-13b-v1.5"
+
+completion = openai.ChatCompletion.create(
+ model=model,
+ messages=[{"role": "user", "content": "hello"}]
+)
+# print the completion
+print(completion.choices[0].message.content)
+```
\ No newline at end of file
diff --git a/docs/getting_started/install/llm/proxyllm/proxyllm.md b/docs/getting_started/install/llm/proxyllm/proxyllm.md
index a04252d2f..fae549dd3 100644
--- a/docs/getting_started/install/llm/proxyllm/proxyllm.md
+++ b/docs/getting_started/install/llm/proxyllm/proxyllm.md
@@ -24,9 +24,12 @@ PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions
#Azure
LLM_MODEL=chatgpt_proxyllm
-OPENAI_API_TYPE=azure
-PROXY_API_KEY={your-openai-sk}
-PROXY_SERVER_URL=https://xx.openai.azure.com/v1/chat/completions
+PROXY_API_KEY={your-azure-sk}
+PROXY_API_BASE=https://{your domain}.openai.azure.com/
+PROXY_API_TYPE=azure
+PROXY_SERVER_URL=xxxx
+PROXY_API_VERSION=2023-05-15
+PROXYLLM_BACKEND=gpt-35-turbo
#Aliyun tongyi
LLM_MODEL=tongyi_proxyllm
diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/openai.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/openai.po
new file mode 100644
index 000000000..0ef41aa6d
--- /dev/null
+++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/openai.po
@@ -0,0 +1,71 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) 2023, csunny
+# This file is distributed under the same license as the DB-GPT package.
+# FIRST AUTHOR , 2023.
+#
+#, fuzzy
+msgid ""
+msgstr ""
+"Project-Id-Version: DB-GPT 👏👏 0.4.0\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2023-11-02 21:09+0800\n"
+"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
+"Last-Translator: FULL NAME \n"
+"Language: zh_CN\n"
+"Language-Team: zh_CN \n"
+"Plural-Forms: nplurals=1; plural=0;\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=utf-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Generated-By: Babel 2.12.1\n"
+
+#: ../../getting_started/install/cluster/openai.md:1
+#: 01f4e2bf853341198633b367efec1522
+msgid "OpenAI-Compatible RESTful APIs"
+msgstr "OpenAI RESTful 兼容接口"
+
+#: ../../getting_started/install/cluster/openai.md:5
+#: d8717e42335e4027bf4e76b3d28768ee
+msgid "Install Prepare"
+msgstr "安装准备"
+
+#: ../../getting_started/install/cluster/openai.md:7
+#: 9a48d8ee116942468de4c6faf9a64758
+msgid ""
+"You must [deploy DB-GPT cluster](https://db-"
+"gpt.readthedocs.io/en/latest/getting_started/install/cluster/vms/index.html)"
+" first."
+msgstr "你必须先部署 [DB-GPT 集群]"
+"(https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh-cn/latest/getting_started/install/cluster/vms/index.html)。"
+
+#: ../../getting_started/install/cluster/openai.md:9
+#: 7673a7121f004f7ca6b1a94a7e238fa3
+msgid "Launch Model API Server"
+msgstr "启动模型 API Server"
+
+#: ../../getting_started/install/cluster/openai.md:14
+#: 84a925c2cbcd4e4895a1d2d2fe8f720f
+msgid "By default, the Model API Server starts on port 8100."
+msgstr "默认情况下,模型 API Server 使用 8100 端口启动。"
+
+#: ../../getting_started/install/cluster/openai.md:16
+#: e53ed41977cd4721becd51eba05c6609
+msgid "Validate with cURL"
+msgstr "通过 cURL 验证"
+
+#: ../../getting_started/install/cluster/openai.md:18
+#: 7c883b410b5c4e53a256bf17c1ded80d
+msgid "List models"
+msgstr "列出模型"
+
+#: ../../getting_started/install/cluster/openai.md:26
+#: ../../getting_started/install/cluster/openai.md:37
+#: 7cf0ed13f0754f149ec085cd6cf7a45a 990d5d5ed5d64ab49550e68495b9e7a0
+msgid "Chat completions"
+msgstr ""
+
+#: ../../getting_started/install/cluster/openai.md:35
+#: 81583edd22df44e091d18a0832278131
+msgid "Validate with OpenAI Official SDK"
+msgstr "通过 OpenAI 官方 SDK 验证"
+
diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/environment/environment.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/environment/environment.po
index 1f448ee4e..addf53bc6 100644
--- a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/environment/environment.po
+++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/environment/environment.po
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.5\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2023-11-02 10:10+0800\n"
+"POT-Creation-Date: 2023-11-02 21:04+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME \n"
"Language: zh_CN\n"
@@ -20,292 +20,292 @@ msgstr ""
"Generated-By: Babel 2.12.1\n"
#: ../../getting_started/install/environment/environment.md:1
-#: 28d2f84fc8884e78afad8118cd59c654
+#: a17719d2f4374285a7beb4d1db470146
#, fuzzy
msgid "Environment Parameter"
msgstr "环境变量说明"
#: ../../getting_started/install/environment/environment.md:4
-#: c83fbb5e1aa643cdb09fffe7f3d1a3c5
+#: 9a62e6fff7914eeaa2d195ddef4fcb61
msgid "LLM MODEL Config"
msgstr "模型配置"
#: ../../getting_started/install/environment/environment.md:5
-#: eb675965ae57407e8d8bf90fed8e9e2a
+#: 90e3991538324ecfac8cac7ef2103ac2
msgid "LLM Model Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG"
msgstr "LLM Model Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG"
#: ../../getting_started/install/environment/environment.md:6
-#: 5d28d35126d849ea9b0d963fd1ba8699
+#: 1f45af01100c4586acbc05469e3006bc
msgid "LLM_MODEL=vicuna-13b"
msgstr "LLM_MODEL=vicuna-13b"
#: ../../getting_started/install/environment/environment.md:8
-#: 01955b2d0fbe4d94939ebf2cbb380bdd
+#: bed14b704f154c2db525f7fafd3aa5a4
msgid "MODEL_SERVER_ADDRESS"
msgstr "MODEL_SERVER_ADDRESS"
#: ../../getting_started/install/environment/environment.md:9
-#: 4eaaa9ab59854c0386b28b3111c82784
+#: ea42946cfe4f4ad996bf82c1996e7344
msgid "MODEL_SERVER=http://127.0.0.1:8000 LIMIT_MODEL_CONCURRENCY"
msgstr "MODEL_SERVER=http://127.0.0.1:8000 LIMIT_MODEL_CONCURRENCY"
#: ../../getting_started/install/environment/environment.md:12
-#: 5c2dd05e16834443b7451c2541b59757
+#: 021c261231f342fdba34098b1baa06fd
msgid "LIMIT_MODEL_CONCURRENCY=5"
msgstr "LIMIT_MODEL_CONCURRENCY=5"
#: ../../getting_started/install/environment/environment.md:14
-#: 7707836c2fb04e7da13d2d59b5f9566f
+#: afaf0ba7fd09463d8ff74b514ed7264c
msgid "MAX_POSITION_EMBEDDINGS"
msgstr "MAX_POSITION_EMBEDDINGS"
#: ../../getting_started/install/environment/environment.md:16
-#: ee24a7d3d8384e61b715ef3bd362b965
+#: e4517a942bca4361a64a00408f993f5b
msgid "MAX_POSITION_EMBEDDINGS=4096"
msgstr "MAX_POSITION_EMBEDDINGS=4096"
#: ../../getting_started/install/environment/environment.md:18
-#: 90b51aa4e46b4d1298c672e0052c2f68
+#: 78d2ef04ed4548b9b7b0fb8ae35c9d5c
msgid "QUANTIZE_QLORA"
msgstr "QUANTIZE_QLORA"
#: ../../getting_started/install/environment/environment.md:20
-#: 7de7a8eb431e4973ae00f68ca0686281
+#: bfa65db03c6d46bba293331f03ab15ac
msgid "QUANTIZE_QLORA=True"
msgstr "QUANTIZE_QLORA=True"
#: ../../getting_started/install/environment/environment.md:22
-#: e331ca016a474f4aa4e9182165a2693a
+#: 1947d45a7f184821910b4834ad5f1897
msgid "QUANTIZE_8bit"
msgstr "QUANTIZE_8bit"
#: ../../getting_started/install/environment/environment.md:24
-#: 519ccce5a0884778be2719c437a17bd4
+#: 4a2ee2919d0e4bdaa13c9d92eefd2aac
msgid "QUANTIZE_8bit=True"
msgstr "QUANTIZE_8bit=True"
#: ../../getting_started/install/environment/environment.md:27
-#: 1c0586d070f046de8d0f9f94a6b508b4
+#: 348dc1e411b54ab09414f40a20e934e4
msgid "LLM PROXY Settings"
msgstr "LLM PROXY Settings"
#: ../../getting_started/install/environment/environment.md:28
-#: c208c3f4b13f4b39962de814e5be6ab9
+#: a692e78425a040f5828ab54ff9a33f77
msgid "OPENAI Key"
msgstr "OPENAI Key"
#: ../../getting_started/install/environment/environment.md:30
-#: 9228bbee2faa4467b1d24f1125faaac8
+#: 940d00e25a424acf92951a314a64e5ea
msgid "PROXY_API_KEY={your-openai-sk}"
msgstr "PROXY_API_KEY={your-openai-sk}"
#: ../../getting_started/install/environment/environment.md:31
-#: 759ae581883348019c1ba79e8954728a
+#: 4bd27547ae6041679e91f2a363cd1deb
msgid "PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions"
msgstr "PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions"
#: ../../getting_started/install/environment/environment.md:33
-#: 83f3952917d34aab80bd34119f7d1e20
+#: cfa3071afb0b47baad6bd729d4a02cb9
msgid "from https://bard.google.com/ f12-> application-> __Secure-1PSID"
msgstr "from https://bard.google.com/ f12-> application-> __Secure-1PSID"
#: ../../getting_started/install/environment/environment.md:35
-#: 1d70707ca82749bb90b2bed1aee44d62
+#: a17efa03b10f47f68afac9e865982a75
msgid "BARD_PROXY_API_KEY={your-bard-token}"
msgstr "BARD_PROXY_API_KEY={your-bard-token}"
#: ../../getting_started/install/environment/environment.md:38
-#: 38a2091fa223493ea23cb9bbb33cf58e
+#: 6bcfe90574da4d82a459e8e11bf73cba
msgid "DATABASE SETTINGS"
msgstr "DATABASE SETTINGS"
#: ../../getting_started/install/environment/environment.md:39
-#: 5134180d7a5945b48b072a1eb92b27ba
+#: 2b1e62d9bf5d4af5a22f68c8248eaafb
msgid "SQLite database (Current default database)"
msgstr "SQLite database (Current default database)"
#: ../../getting_started/install/environment/environment.md:40
-#: 6875e2300e094668a45fa4f2551e0d30
+#: 8a909ac3b3c943da8dbc4e8dd596c80c
msgid "LOCAL_DB_PATH=data/default_sqlite.db"
msgstr "LOCAL_DB_PATH=data/default_sqlite.db"
#: ../../getting_started/install/environment/environment.md:41
-#: 034e8f06f24f44af9d8184563f99b4b3
+#: 90ae6507932f4815b6e180051738bb93
msgid "LOCAL_DB_TYPE=sqlite # Database Type default:sqlite"
msgstr "LOCAL_DB_TYPE=sqlite # Database Type default:sqlite"
#: ../../getting_started/install/environment/environment.md:43
-#: f688149a97f740269f80b79775236ce9
+#: d2ce34e0dcf44ccf9e8007d548ba7b0a
msgid "MYSQL database"
msgstr "MYSQL database"
#: ../../getting_started/install/environment/environment.md:44
-#: 6db0b305137d45a3aa036e4f2262f460
+#: c07159d63c334f6cbb95fcc30bfb7ea5
msgid "LOCAL_DB_TYPE=mysql"
msgstr "LOCAL_DB_TYPE=mysql"
#: ../../getting_started/install/environment/environment.md:45
-#: b6d662ce8d5f44f0b54a7f6e7c66f5a5
+#: e16700b2ea8d411e91d010c1cde7aecc
msgid "LOCAL_DB_USER=root"
msgstr "LOCAL_DB_USER=root"
#: ../../getting_started/install/environment/environment.md:46
-#: cd7493d61ac9415283640dc6c018d2f4
+#: bfc2dce1bf374121b6861e677b4e1ffa
msgid "LOCAL_DB_PASSWORD=aa12345678"
msgstr "LOCAL_DB_PASSWORD=aa12345678"
#: ../../getting_started/install/environment/environment.md:47
-#: 4ea2a622b23f4342a4c2ab7f8d9c4e8d
+#: bc384739f5b04e21a34d0d2b78e7906c
msgid "LOCAL_DB_HOST=127.0.0.1"
msgstr "LOCAL_DB_HOST=127.0.0.1"
#: ../../getting_started/install/environment/environment.md:48
-#: 936db95a0ab246098028f4dbb596cd17
+#: e5253d452e0d42b7ac308fe6fbfb5017
msgid "LOCAL_DB_PORT=3306"
msgstr "LOCAL_DB_PORT=3306"
#: ../../getting_started/install/environment/environment.md:51
-#: d9255f25989840ea9c9e7b34f3947c87
+#: 9ca8f6fe06ed4cbab390f94be252e165
msgid "EMBEDDING SETTINGS"
msgstr "EMBEDDING SETTINGS"
#: ../../getting_started/install/environment/environment.md:52
-#: b09291d32aca43928a981e873476a985
+#: 76c7c260293c4b49bae057143fd48377
msgid "EMBEDDING MODEL Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG"
msgstr "EMBEDDING模型, 参考see /pilot/configs/model_config.LLM_MODEL_CONFIG"
#: ../../getting_started/install/environment/environment.md:53
-#: 63de573b03a54413b997f18a1ccee279
+#: f1d63a0128ce493cae37d34f1976bcca
msgid "EMBEDDING_MODEL=text2vec"
msgstr "EMBEDDING_MODEL=text2vec"
#: ../../getting_started/install/environment/environment.md:55
-#: 0ef8defbab544bd0b9475a036f278489
+#: b8fbb99109d04781b2dd5bc5d6efa5bd
msgid "Embedding Chunk size, default 500"
msgstr "Embedding 切片大小, 默认500"
#: ../../getting_started/install/environment/environment.md:57
-#: 33dbc7941d054baa8c6ecfc0bf1ce271
+#: bf8256576ea34f6a9c5f261ab9aab676
msgid "KNOWLEDGE_CHUNK_SIZE=500"
msgstr "KNOWLEDGE_CHUNK_SIZE=500"
#: ../../getting_started/install/environment/environment.md:59
-#: e6ee9f2620ab45ecbc8e9c0642f5ca42
+#: 9b156c6b599b4c02a58ce023b4ff25f2
msgid "Embedding Chunk Overlap, default 100"
msgstr "Embedding chunk Overlap, 文本块之间的最大重叠量。保留一些重叠可以保持文本块之间的连续性(例如使用滑动窗口),默认100"
#: ../../getting_started/install/environment/environment.md:60
-#: fcddf64340a04df4ab95176fc2fc67a6
+#: dcafd903c36041ac85ac99a14dbee512
msgid "KNOWLEDGE_CHUNK_OVERLAP=100"
msgstr "KNOWLEDGE_CHUNK_OVERLAP=100"
#: ../../getting_started/install/environment/environment.md:62
-#: 61272200194b4461a921581feb1273da
+#: 6c3244b7e5e24b0188c7af4bb52e9134
#, fuzzy
msgid "embedding recall top k,5"
msgstr "embedding 召回topk, 默认5"
#: ../../getting_started/install/environment/environment.md:64
-#: b433091f055542b1b89ff2d525ac99e4
+#: f4a2f30551cf4fe1a7ff3c7c74ec77be
msgid "KNOWLEDGE_SEARCH_TOP_SIZE=5"
msgstr "KNOWLEDGE_SEARCH_TOP_SIZE=5"
#: ../../getting_started/install/environment/environment.md:66
-#: 1db0de41aebd4caa8cc2eaecb4cacd6a
+#: 593f2512362f467e92fdaa60dd5903a0
#, fuzzy
msgid "embedding recall max token ,2000"
msgstr "embedding向量召回最大token, 默认2000"
#: ../../getting_started/install/environment/environment.md:68
-#: 81b9d862e58941a4b09680a7520cdabe
+#: 83d6d28914be4d6282d457272e508ddc
msgid "KNOWLEDGE_SEARCH_MAX_TOKEN=5"
msgstr "KNOWLEDGE_SEARCH_MAX_TOKEN=5"
#: ../../getting_started/install/environment/environment.md:71
#: ../../getting_started/install/environment/environment.md:87
-#: cac73575d54544778bdee09b18532fd9 f78a509949a64f03aa330f31901e2e7a
+#: 6bc1b9d995e74294a1c78e783c550db7 d33c77ded834438e9f4a2df06e7e041a
msgid "Vector Store SETTINGS"
msgstr "Vector Store SETTINGS"
#: ../../getting_started/install/environment/environment.md:72
#: ../../getting_started/install/environment/environment.md:88
-#: 5ebba1cb047b4b09849000244237dfbb 7e9285e91bcb4b2d9413909c0d0a06a7
+#: 9cafa06e2d584f70afd848184e0fa52a f01057251b8b4ffea806192dfe1048ed
msgid "Chroma"
msgstr "Chroma"
#: ../../getting_started/install/environment/environment.md:73
#: ../../getting_started/install/environment/environment.md:89
-#: 05625cfcc23c4745ae1fa0d94ce5450c 3a8615f1507f4fc49d1adda5100a4edf
+#: e6c16fab37484769b819aeecbc13e6db faad299722e5400e95ec6ac3c1e018b8
msgid "VECTOR_STORE_TYPE=Chroma"
msgstr "VECTOR_STORE_TYPE=Chroma"
#: ../../getting_started/install/environment/environment.md:74
#: ../../getting_started/install/environment/environment.md:90
-#: 5b559376aea44f159262e6d4b75c7ec1 e954782861404b10b4e893e01cf74452
+#: 4eca3a51716d406f8ffd49c06550e871 581ee9dd38064b119660c44bdd00cbaa
msgid "MILVUS"
msgstr "MILVUS"
#: ../../getting_started/install/environment/environment.md:75
#: ../../getting_started/install/environment/environment.md:91
-#: 55ee8199c97a4929aeefd32370f2b92d 8f40c02543ea4a2ca9632dd9e8a08c2e
+#: 814c93048bed46589358a854d6c99683 b72b1269a2224f5f961214e41c019f21
msgid "VECTOR_STORE_TYPE=Milvus"
msgstr "VECTOR_STORE_TYPE=Milvus"
#: ../../getting_started/install/environment/environment.md:76
#: ../../getting_started/install/environment/environment.md:92
-#: 528a01d25720491c8e086bf43a62ad92 ba1386d551d7494a85681a2803081a6f
+#: 73ae665f1db9402883662734588fd02c c4da20319c994e83ba5a7706db967178
msgid "MILVUS_URL=127.0.0.1"
msgstr "MILVUS_URL=127.0.0.1"
#: ../../getting_started/install/environment/environment.md:77
#: ../../getting_started/install/environment/environment.md:93
-#: b031950dafcd4d4783c120dc933c4178 c2e9c8cdd41741e3aba01e59a6ef245d
+#: e30c5288516d42aa858a485db50490c1 f843b2e58bcb4e4594e3c28499c341d0
msgid "MILVUS_PORT=19530"
msgstr "MILVUS_PORT=19530"
#: ../../getting_started/install/environment/environment.md:78
#: ../../getting_started/install/environment/environment.md:94
-#: 27b0a64af6434cb2840373e2b38c9bd5 d0e4d79af7954b129ffff7303a1ec3ce
+#: 158669efcc7d4bcaac1c8dd01b499029 24e88ffd32f242f281c56c0ec3ad2639
msgid "MILVUS_USERNAME"
msgstr "MILVUS_USERNAME"
#: ../../getting_started/install/environment/environment.md:79
#: ../../getting_started/install/environment/environment.md:95
-#: 27aa1a5b61e64dd6bfe29124e274809e 5c58892498ce4f46a59f54b2887822d4
+#: 111a985297184c8aa5a0dd8e14a58445 6602093a6bb24d6792548e2392105c82
msgid "MILVUS_PASSWORD"
msgstr "MILVUS_PASSWORD"
#: ../../getting_started/install/environment/environment.md:80
#: ../../getting_started/install/environment/environment.md:96
-#: 009e57d4acc5434da2146f0545911c85 bac8888dcbff47fbb0ea8ae685445aac
+#: 47bdfcd78fbe4ccdb5f49b717a6d01a6 b96c0545b2044926a8a8190caf94ad25
msgid "MILVUS_SECURE="
msgstr "MILVUS_SECURE="
#: ../../getting_started/install/environment/environment.md:82
#: ../../getting_started/install/environment/environment.md:98
-#: a6eeb16ab5274045bee88ecc3d93e09e eb341774403d47658b9b7e94c4c16d5c
+#: 755c32b5d6c54607907a138b5474c0ec ff4f2a7ddaa14f089dda7a14e1062c36
msgid "WEAVIATE"
msgstr "WEAVIATE"
#: ../../getting_started/install/environment/environment.md:83
-#: fbd97522d8da4824b41b99298fd41069
+#: 23b2ce83385d40a589a004709f9864be
msgid "VECTOR_STORE_TYPE=Weaviate"
msgstr "VECTOR_STORE_TYPE=Weaviate"
#: ../../getting_started/install/environment/environment.md:84
#: ../../getting_started/install/environment/environment.md:99
-#: 341785b4abfe42b5af1c2e04497261f4 a81cc2aabc8240f3ac1f674d9350bff4
+#: 9acef304d89a448a9e734346705ba872 cf5151b6c1594ccd8beb1c3f77769acb
msgid "WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network"
msgstr "WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network"
#: ../../getting_started/install/environment/environment.md:102
-#: 5bb9e5daa36241d499089c1b1910f729
+#: c3003516b2364051bf34f8c3086e348a
msgid "Multi-GPU Setting"
msgstr "Multi-GPU Setting"
#: ../../getting_started/install/environment/environment.md:103
-#: 30df45b7f1f7423c9f18c6360f0b7600
+#: ade8fc381c5e438aa29d159c10041713
msgid ""
"See https://developer.nvidia.com/blog/cuda-pro-tip-control-gpu-"
"visibility-cuda_visible_devices/ If CUDA_VISIBLE_DEVICES is not "
@@ -315,49 +315,49 @@ msgstr ""
"cuda_visible_devices/ 如果 CUDA_VISIBLE_DEVICES没有设置, 会使用所有可用的gpu"
#: ../../getting_started/install/environment/environment.md:106
-#: 8631ea968dfb4d90a7ae6bdb2acdfdce
+#: e137bd19be5e410ba6709027dbf2923a
msgid "CUDA_VISIBLE_DEVICES=0"
msgstr "CUDA_VISIBLE_DEVICES=0"
#: ../../getting_started/install/environment/environment.md:108
-#: 0010422280dd4fe79326ebceb2a66f0e
+#: 7669947acbdc4b1d92bcc029a8353a5d
msgid ""
"Optionally, you can also specify the gpu ID to use before the starting "
"command"
msgstr "你也可以通过启动命令设置gpu ID"
#: ../../getting_started/install/environment/environment.md:110
-#: 00106f7341304fbd9425721ea8e6a261
+#: 751743d1753b4051beea46371278d793
msgid "CUDA_VISIBLE_DEVICES=3,4,5,6"
msgstr "CUDA_VISIBLE_DEVICES=3,4,5,6"
#: ../../getting_started/install/environment/environment.md:112
-#: 720aa8b3478744d78e4b10dfeccb50b4
+#: 3acc3de0af0d4df2bb575e161e377f85
msgid "You can configure the maximum memory used by each GPU."
msgstr "可以设置GPU的最大内存"
#: ../../getting_started/install/environment/environment.md:114
-#: f9639ac96a244296832c75bbcbdae2af
+#: 67f1d9b172b84294a44ecace5436e6e0
msgid "MAX_GPU_MEMORY=16Gib"
msgstr "MAX_GPU_MEMORY=16Gib"
#: ../../getting_started/install/environment/environment.md:117
-#: fc4d955fdb3e4256af5c8f29b042dcd6
+#: 3c69dfe48bcf46b89b76cac1e7849a66
msgid "Other Setting"
msgstr "Other Setting"
#: ../../getting_started/install/environment/environment.md:118
-#: 66b14a834e884339be2d48392e884933
+#: d5015b70f4fe4d20a63de9d87f86957a
msgid "Language Settings(influence prompt language)"
msgstr "Language Settings(涉及prompt语言以及知识切片方式)"
#: ../../getting_started/install/environment/environment.md:119
-#: 5c9f05174eb84edd9e1316cc0721a840
+#: 5543c28bb8e34c9fb3bb6b063c2b1750
msgid "LANGUAGE=en"
msgstr "LANGUAGE=en"
#: ../../getting_started/install/environment/environment.md:120
-#: 7f6d62117d024c51bba9255fa4fcf151
+#: cb4ed5b892ee41068c1ca76cb29aa400
msgid "LANGUAGE=zh"
msgstr "LANGUAGE=zh"
diff --git a/docs/locales/zh_CN/LC_MESSAGES/modules/knowledge.po b/docs/locales/zh_CN/LC_MESSAGES/modules/knowledge.po
index 032e81609..bb2dae7af 100644
--- a/docs/locales/zh_CN/LC_MESSAGES/modules/knowledge.po
+++ b/docs/locales/zh_CN/LC_MESSAGES/modules/knowledge.po
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 0.3.0\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2023-11-02 10:10+0800\n"
+"POT-Creation-Date: 2023-11-02 21:04+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME \n"
"Language: zh_CN\n"
@@ -19,11 +19,11 @@ msgstr ""
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"
-#: ../../modules/knowledge.md:1 436b94d3a8374ed18feb5c14893a84e6
+#: ../../modules/knowledge.md:1 b94b3b15cb2441ed9d78abd222a717b7
msgid "Knowledge"
msgstr "知识"
-#: ../../modules/knowledge.md:3 918a3747cbed42d18b8c9c4547e67b14
+#: ../../modules/knowledge.md:3 c6d6e308a6ce42948d29e928136ef561
#, fuzzy
msgid ""
"As the knowledge base is currently the most significant user demand "
@@ -34,15 +34,15 @@ msgstr ""
"由于知识库是当前用户需求最显著的场景,我们原生支持知识库的构建和处理。同时,我们还在本项目中提供了多种知识库管理策略,如:pdf,md , "
"txt, word, ppt"
-#: ../../modules/knowledge.md:4 d4d4b5d57918485aafa457bb9fdcf626
+#: ../../modules/knowledge.md:4 268abc408d40410ba90cf5f121dc5270
msgid "Default built-in knowledge base"
msgstr ""
-#: ../../modules/knowledge.md:5 d4d4b5d57918485aafa457bb9fdcf626
+#: ../../modules/knowledge.md:5 558c3364c38b458a8ebf81030efc2a48
msgid "Custom addition of knowledge bases"
msgstr ""
-#: ../../modules/knowledge.md:6 984361ce835c4c3492e29e1fb897348a
+#: ../../modules/knowledge.md:6 9cb3ce62da1440579c095848c7aef88c
msgid ""
"Various usage scenarios such as constructing knowledge bases through "
"plugin capabilities and web crawling. Users only need to organize the "
@@ -50,53 +50,53 @@ msgid ""
"the knowledge base required for the large model."
msgstr ""
-#: ../../modules/knowledge.md:9 746e4fbd3212460198be51b90caee2c8
+#: ../../modules/knowledge.md:9 b8ca6bc4dd9845baa56e36eea7fac2a2
#, fuzzy
msgid "Create your own knowledge repository"
msgstr "创建你自己的知识库"
-#: ../../modules/knowledge.md:11 1c46b33b0532417c824efbaa3e687c3f
+#: ../../modules/knowledge.md:11 17d7178a67924f43aa5b6293707ef041
msgid ""
"1.Place personal knowledge files or folders in the pilot/datasets "
"directory."
msgstr ""
-#: ../../modules/knowledge.md:13 3b16f387b5354947a89d6df77bd65bdb
+#: ../../modules/knowledge.md:13 31c31f14bf444981939689f9a9fb038a
#, fuzzy
msgid ""
"We currently support many document formats: txt, pdf, md, html, doc, ppt,"
" and url."
msgstr "当前支持txt, pdf, md, html, doc, ppt, url文档格式"
-#: ../../modules/knowledge.md:15 09ec337d7da4418db854e58afb6c0980
+#: ../../modules/knowledge.md:15 9ad2f2e05f8842a9b9d8469a3704df23
msgid "before execution:"
msgstr "开始前"
-#: ../../modules/knowledge.md:22 c09b3decb018485f8e56830ddc156194
+#: ../../modules/knowledge.md:22 6fd2775914b641c4b8e486417b558ea6
msgid ""
"2.Update your .env, set your vector store type, VECTOR_STORE_TYPE=Chroma "
"(now only support Chroma and Milvus, if you set Milvus, please set "
"MILVUS_URL and MILVUS_PORT)"
msgstr ""
-#: ../../modules/knowledge.md:25 74460ec7709441d5945ce9f745a26d20
+#: ../../modules/knowledge.md:25 131c5f58898a4682940910980edb2043
msgid "2.Run the knowledge repository initialization command"
msgstr ""
-#: ../../modules/knowledge.md:31 4498ec4e46ff4e24b45dd855e829bd32
+#: ../../modules/knowledge.md:31 2cf550f17881497bb881b19efcc18c23
msgid ""
"Optionally, you can run `dbgpt knowledge load --help` command to see more"
" usage."
msgstr ""
-#: ../../modules/knowledge.md:33 5048ac3289e540f2a2b5fd0e5ed043f5
+#: ../../modules/knowledge.md:33 c8a2ea571b944bdfbcad48fa8b54fcc9
msgid ""
"3.Add the knowledge repository in the interface by entering the name of "
"your knowledge repository (if not specified, enter \"default\") so you "
"can use it for Q&A based on your knowledge base."
msgstr ""
-#: ../../modules/knowledge.md:35 deeccff20f7f453dad0881b63dae2a18
+#: ../../modules/knowledge.md:35 b701170ad75e49dea7d7734c15681e0f
msgid ""
"Note that the default vector model used is text2vec-large-chinese (which "
"is a large model, so if your personal computer configuration is not "
diff --git a/pilot/base_modules/agent/db/__init__.py b/pilot/base_modules/agent/db/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/component.py b/pilot/component.py
index 3826a5e48..16013ee17 100644
--- a/pilot/component.py
+++ b/pilot/component.py
@@ -46,6 +46,8 @@ class ComponentType(str, Enum):
WORKER_MANAGER = "dbgpt_worker_manager"
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
MODEL_CONTROLLER = "dbgpt_model_controller"
+ MODEL_REGISTRY = "dbgpt_model_registry"
+ MODEL_API_SERVER = "dbgpt_model_api_server"
AGENT_HUB = "dbgpt_agent_hub"
EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
TRACER = "dbgpt_tracer"
@@ -69,7 +71,6 @@ class BaseComponent(LifeCycle, ABC):
This method needs to be implemented by every component to define how it integrates
with the main system app.
"""
- pass
T = TypeVar("T", bound=BaseComponent)
@@ -91,13 +92,28 @@ class SystemApp(LifeCycle):
"""Returns the internal ASGI app."""
return self._asgi_app
- def register(self, component: Type[BaseComponent], *args, **kwargs):
- """Register a new component by its type."""
+ def register(self, component: Type[BaseComponent], *args, **kwargs) -> T:
+ """Register a new component by its type.
+
+ Args:
+ component (Type[BaseComponent]): The component class to register
+
+ Returns:
+ T: The instance of registered component
+ """
instance = component(self, *args, **kwargs)
self.register_instance(instance)
+ return instance
- def register_instance(self, instance: T):
- """Register an already initialized component."""
+ def register_instance(self, instance: T) -> T:
+ """Register an already initialized component.
+
+ Args:
+ instance (T): The component instance to register
+
+ Returns:
+ T: The instance of registered component
+ """
name = instance.name
if isinstance(name, ComponentType):
name = name.value
@@ -108,18 +124,34 @@ class SystemApp(LifeCycle):
logger.info(f"Register component with name {name} and instance: {instance}")
self.components[name] = instance
instance.init_app(self)
+ return instance
def get_component(
self,
name: Union[str, ComponentType],
component_type: Type[T],
default_component=_EMPTY_DEFAULT_COMPONENT,
+ or_register_component: Type[BaseComponent] = None,
+ *args,
+ **kwargs,
) -> T:
- """Retrieve a registered component by its name and type."""
+ """Retrieve a registered component by its name and type.
+
+ Args:
+ name (Union[str, ComponentType]): Component name
+ component_type (Type[T]): The type of current retrieve component
+ default_component : The default component instance if not retrieve by name
+ or_register_component (Type[BaseComponent]): The new component to register if not retrieve by name
+
+ Returns:
+ T: The instance retrieved by component name
+ """
if isinstance(name, ComponentType):
name = name.value
component = self.components.get(name)
if not component:
+ if or_register_component:
+ return self.register(or_register_component, *args, **kwargs)
if default_component != _EMPTY_DEFAULT_COMPONENT:
return default_component
raise ValueError(f"No component found with name {name}")
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index e1575ea03..0e1fb3d40 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -78,6 +78,10 @@ LLM_MODEL_CONFIG = {
"internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
"internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
"internlm-20b": os.path.join(MODEL_PATH, "internlm-chat-20b"),
+ "codellama-7b": os.path.join(MODEL_PATH, "CodeLlama-7b-Instruct-hf"),
+ "codellama-7b-sql-sft": os.path.join(MODEL_PATH, "codellama-7b-sql-sft"),
+ "codellama-13b": os.path.join(MODEL_PATH, "CodeLlama-13b-Instruct-hf"),
+ "codellama-13b-sql-sft": os.path.join(MODEL_PATH, "codellama-13b-sql-sft"),
# For test now
"opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
}
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 69b159a13..5ce5b2173 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -320,6 +320,19 @@ class Llama2Adapter(BaseLLMAdaper):
return model, tokenizer
+class CodeLlamaAdapter(BaseLLMAdaper):
+ """The model adapter for codellama"""
+
+ def match(self, model_path: str):
+ return "codellama" in model_path.lower()
+
+ def loader(self, model_path: str, from_pretrained_kwargs: dict):
+ model, tokenizer = super().loader(model_path, from_pretrained_kwargs)
+ model.config.eos_token_id = tokenizer.eos_token_id
+ model.config.pad_token_id = tokenizer.pad_token_id
+ return model, tokenizer
+
+
class BaichuanAdapter(BaseLLMAdaper):
"""The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-13B-Chat)"""
@@ -420,6 +433,7 @@ register_llm_model_adapters(FalconAdapater)
register_llm_model_adapters(GorillaAdapter)
register_llm_model_adapters(GPT4AllAdapter)
register_llm_model_adapters(Llama2Adapter)
+register_llm_model_adapters(CodeLlamaAdapter)
register_llm_model_adapters(BaichuanAdapter)
register_llm_model_adapters(WizardLMAdapter)
register_llm_model_adapters(LlamaCppAdapater)
diff --git a/pilot/model/base.py b/pilot/model/base.py
index e89b243c9..48480b94b 100644
--- a/pilot/model/base.py
+++ b/pilot/model/base.py
@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
from enum import Enum
-from typing import TypedDict, Optional, Dict, List
+from typing import TypedDict, Optional, Dict, List, Any
from dataclasses import dataclass, asdict
from datetime import datetime
from pilot.utils.parameter_utils import ParameterDescription
@@ -52,6 +52,8 @@ class ModelOutput:
text: str
error_code: int
model_context: Dict = None
+ finish_reason: str = None
+ usage: Dict[str, Any] = None
def to_dict(self) -> Dict:
return asdict(self)
diff --git a/pilot/model/cli.py b/pilot/model/cli.py
index 1030adfc2..79b47db82 100644
--- a/pilot/model/cli.py
+++ b/pilot/model/cli.py
@@ -8,6 +8,7 @@ from pilot.configs.model_config import LOGDIR
from pilot.model.base import WorkerApplyType
from pilot.model.parameter import (
ModelControllerParameters,
+ ModelAPIServerParameters,
ModelWorkerParameters,
ModelParameters,
BaseParameters,
@@ -441,15 +442,27 @@ def stop_model_worker(port: int):
@click.command(name="apiserver")
+@EnvArgumentParser.create_click_option(ModelAPIServerParameters)
def start_apiserver(**kwargs):
- """Start apiserver(TODO)"""
- raise NotImplementedError
+ """Start apiserver"""
+
+ if kwargs["daemon"]:
+ log_file = os.path.join(LOGDIR, "model_apiserver_uvicorn.log")
+ _run_current_with_daemon("ModelAPIServer", log_file)
+ else:
+ from pilot.model.cluster import run_apiserver
+
+ run_apiserver()
@click.command(name="apiserver")
-def stop_apiserver(**kwargs):
- """Start apiserver(TODO)"""
- raise NotImplementedError
+@add_stop_server_options
+def stop_apiserver(port: int):
+ """Stop apiserver"""
+ name = "ModelAPIServer"
+ if port:
+ name = f"{name}-{port}"
+ _stop_service("apiserver", name, port=port)
def _stop_all_model_server(**kwargs):
diff --git a/pilot/model/cluster/__init__.py b/pilot/model/cluster/__init__.py
index 9937ffa0b..a777a8d4b 100644
--- a/pilot/model/cluster/__init__.py
+++ b/pilot/model/cluster/__init__.py
@@ -21,6 +21,7 @@ from pilot.model.cluster.controller.controller import (
run_model_controller,
BaseModelController,
)
+from pilot.model.cluster.apiserver.api import run_apiserver
from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager
@@ -40,4 +41,5 @@ __all__ = [
"ModelRegistryClient",
"RemoteWorkerManager",
"run_model_controller",
+ "run_apiserver",
]
diff --git a/pilot/model/cluster/apiserver/__init__.py b/pilot/model/cluster/apiserver/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/model/cluster/apiserver/api.py b/pilot/model/cluster/apiserver/api.py
new file mode 100644
index 000000000..148a51eed
--- /dev/null
+++ b/pilot/model/cluster/apiserver/api.py
@@ -0,0 +1,443 @@
+"""A server that provides OpenAI-compatible RESTful APIs. It supports:
+- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat)
+
+Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
+"""
+from typing import Optional, List, Dict, Any, Generator
+
+import logging
+import asyncio
+import shortuuid
+import json
+from fastapi import APIRouter, FastAPI
+from fastapi import Depends, HTTPException
+from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
+from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
+
+from pydantic import BaseSettings
+
+from fastchat.protocol.openai_api_protocol import (
+ ChatCompletionResponse,
+ ChatCompletionResponseStreamChoice,
+ ChatCompletionStreamResponse,
+ ChatMessage,
+ ChatCompletionResponseChoice,
+ DeltaMessage,
+ EmbeddingsRequest,
+ EmbeddingsResponse,
+ ErrorResponse,
+ ModelCard,
+ ModelList,
+ ModelPermission,
+ UsageInfo,
+)
+from fastchat.protocol.api_protocol import (
+ APIChatCompletionRequest,
+ APITokenCheckRequest,
+ APITokenCheckResponse,
+ APITokenCheckResponseItem,
+)
+from fastchat.serve.openai_api_server import create_error_response, check_requests
+from fastchat.constants import ErrorCode
+
+from pilot.component import BaseComponent, ComponentType, SystemApp
+from pilot.utils.parameter_utils import EnvArgumentParser
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+from pilot.model.base import ModelInstance, ModelOutput
+from pilot.model.parameter import ModelAPIServerParameters, WorkerType
+from pilot.model.cluster import ModelRegistry, ModelRegistryClient
+from pilot.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
+from pilot.utils.utils import setup_logging
+
+logger = logging.getLogger(__name__)
+
+
+class APIServerException(Exception):
+ def __init__(self, code: int, message: str):
+ self.code = code
+ self.message = message
+
+
+class APISettings(BaseSettings):
+ api_keys: Optional[List[str]] = None
+
+
+api_settings = APISettings()
+get_bearer_token = HTTPBearer(auto_error=False)
+
+
+async def check_api_key(
+ auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
+) -> str:
+ if api_settings.api_keys:
+ if auth is None or (token := auth.credentials) not in api_settings.api_keys:
+ raise HTTPException(
+ status_code=401,
+ detail={
+ "error": {
+ "message": "",
+ "type": "invalid_request_error",
+ "param": None,
+ "code": "invalid_api_key",
+ }
+ },
+ )
+ return token
+ else:
+ # api_keys not set; allow all
+ return None
+
+
+class APIServer(BaseComponent):
+ name = ComponentType.MODEL_API_SERVER
+
+ def init_app(self, system_app: SystemApp):
+ self.system_app = system_app
+
+ def get_worker_manager(self) -> WorkerManager:
+ """Get the worker manager component instance
+
+ Raises:
+ APIServerException: If can't get worker manager component instance
+ """
+ worker_manager = self.system_app.get_component(
+ ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
+ ).create()
+ if not worker_manager:
+ raise APIServerException(
+ ErrorCode.INTERNAL_ERROR,
+ f"Could not get component {ComponentType.WORKER_MANAGER_FACTORY} from system_app",
+ )
+ return worker_manager
+
+ def get_model_registry(self) -> ModelRegistry:
+ """Get the model registry component instance
+
+ Raises:
+ APIServerException: If can't get model registry component instance
+ """
+
+ controller = self.system_app.get_component(
+ ComponentType.MODEL_REGISTRY, ModelRegistry
+ )
+ if not controller:
+ raise APIServerException(
+ ErrorCode.INTERNAL_ERROR,
+ f"Could not get component {ComponentType.MODEL_REGISTRY} from system_app",
+ )
+ return controller
+
+ async def get_model_instances_or_raise(
+ self, model_name: str
+ ) -> List[ModelInstance]:
+ """Get healthy model instances with request model name
+
+ Args:
+ model_name (str): Model name
+
+ Raises:
+ APIServerException: If can't get healthy model instances with request model name
+ """
+ registry = self.get_model_registry()
+ registry_model_name = f"{model_name}@llm"
+ model_instances = await registry.get_all_instances(
+ registry_model_name, healthy_only=True
+ )
+ if not model_instances:
+ all_instances = await registry.get_all_model_instances(healthy_only=True)
+ models = [
+ ins.model_name.split("@llm")[0]
+ for ins in all_instances
+ if ins.model_name.endswith("@llm")
+ ]
+ if models:
+ models = "&&".join(models)
+ message = f"Only {models} allowed now, your model {model_name}"
+ else:
+ message = f"No models allowed now, your model {model_name}"
+ raise APIServerException(ErrorCode.INVALID_MODEL, message)
+ return model_instances
+
+ async def get_available_models(self) -> ModelList:
+ """Return available models
+
+ Just include LLM and embedding models.
+
+ Returns:
+ List[ModelList]: The list of models.
+ """
+ registry = self.get_model_registry()
+ model_instances = await registry.get_all_model_instances(healthy_only=True)
+ model_name_set = set()
+ for inst in model_instances:
+ name, worker_type = WorkerType.parse_worker_key(inst.model_name)
+ if worker_type == WorkerType.LLM or worker_type == WorkerType.TEXT2VEC:
+ model_name_set.add(name)
+ models = list(model_name_set)
+ models.sort()
+ # TODO: return real model permission details
+ model_cards = []
+ for m in models:
+ model_cards.append(
+ ModelCard(
+ id=m, root=m, owned_by="DB-GPT", permission=[ModelPermission()]
+ )
+ )
+ return ModelList(data=model_cards)
+
+ async def chat_completion_stream_generator(
+ self, model_name: str, params: Dict[str, Any], n: int
+ ) -> Generator[str, Any, None]:
+ """Chat stream completion generator
+
+ Args:
+ model_name (str): Model name
+ params (Dict[str, Any]): The parameters pass to model worker
+ n (int): How many completions to generate for each prompt.
+ """
+ worker_manager = self.get_worker_manager()
+ id = f"chatcmpl-{shortuuid.random()}"
+ finish_stream_events = []
+ for i in range(n):
+ # First chunk with role
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=DeltaMessage(role="assistant"),
+ finish_reason=None,
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=id, choices=[choice_data], model=model_name
+ )
+ yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+
+ previous_text = ""
+ async for model_output in worker_manager.generate_stream(params):
+ model_output: ModelOutput = model_output
+ if model_output.error_code != 0:
+ yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
+ yield "data: [DONE]\n\n"
+ return
+ decoded_unicode = model_output.text.replace("\ufffd", "")
+ delta_text = decoded_unicode[len(previous_text) :]
+ previous_text = (
+ decoded_unicode
+ if len(decoded_unicode) > len(previous_text)
+ else previous_text
+ )
+
+ if len(delta_text) == 0:
+ delta_text = None
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=DeltaMessage(content=delta_text),
+ finish_reason=model_output.finish_reason,
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=id, choices=[choice_data], model=model_name
+ )
+ if delta_text is None:
+ if model_output.finish_reason is not None:
+ finish_stream_events.append(chunk)
+ continue
+ yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+ # There is not "content" field in the last delta message, so exclude_none to exclude field "content".
+ for finish_chunk in finish_stream_events:
+ yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
+ yield "data: [DONE]\n\n"
+
+ async def chat_completion_generate(
+ self, model_name: str, params: Dict[str, Any], n: int
+ ) -> ChatCompletionResponse:
+ """Generate completion
+ Args:
+ model_name (str): Model name
+ params (Dict[str, Any]): The parameters pass to model worker
+ n (int): How many completions to generate for each prompt.
+ """
+ worker_manager: WorkerManager = self.get_worker_manager()
+ choices = []
+ chat_completions = []
+ for i in range(n):
+ model_output = asyncio.create_task(worker_manager.generate(params))
+ chat_completions.append(model_output)
+ try:
+ all_tasks = await asyncio.gather(*chat_completions)
+ except Exception as e:
+ return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
+ usage = UsageInfo()
+ for i, model_output in enumerate(all_tasks):
+ model_output: ModelOutput = model_output
+ if model_output.error_code != 0:
+ return create_error_response(model_output.error_code, model_output.text)
+ choices.append(
+ ChatCompletionResponseChoice(
+ index=i,
+ message=ChatMessage(role="assistant", content=model_output.text),
+ finish_reason=model_output.finish_reason or "stop",
+ )
+ )
+ if model_output.usage:
+ task_usage = UsageInfo.parse_obj(model_output.usage)
+ for usage_key, usage_value in task_usage.dict().items():
+ setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+
+ return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
+
+
+def get_api_server() -> APIServer:
+ api_server = global_system_app.get_component(
+ ComponentType.MODEL_API_SERVER, APIServer, default_component=None
+ )
+ if not api_server:
+ global_system_app.register(APIServer)
+ return global_system_app.get_component(ComponentType.MODEL_API_SERVER, APIServer)
+
+
+router = APIRouter()
+
+
+@router.get("/v1/models", dependencies=[Depends(check_api_key)])
+async def get_available_models(api_server: APIServer = Depends(get_api_server)):
+ return await api_server.get_available_models()
+
+
+@router.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
+async def create_chat_completion(
+ request: APIChatCompletionRequest, api_server: APIServer = Depends(get_api_server)
+):
+ await api_server.get_model_instances_or_raise(request.model)
+ error_check_ret = check_requests(request)
+ if error_check_ret is not None:
+ return error_check_ret
+ params = {
+ "model": request.model,
+ "messages": ModelMessage.to_dict_list(
+ ModelMessage.from_openai_messages(request.messages)
+ ),
+ "echo": False,
+ }
+ if request.temperature:
+ params["temperature"] = request.temperature
+ if request.top_p:
+ params["top_p"] = request.top_p
+ if request.max_tokens:
+ params["max_new_tokens"] = request.max_tokens
+ if request.stop:
+ params["stop"] = request.stop
+ if request.user:
+ params["user"] = request.user
+
+ # TODO check token length
+ if request.stream:
+ generator = api_server.chat_completion_stream_generator(
+ request.model, params, request.n
+ )
+ return StreamingResponse(generator, media_type="text/event-stream")
+ return await api_server.chat_completion_generate(request.model, params, request.n)
+
+
+def _initialize_all(controller_addr: str, system_app: SystemApp):
+ from pilot.model.cluster import RemoteWorkerManager, ModelRegistryClient
+ from pilot.model.cluster.worker.manager import _DefaultWorkerManagerFactory
+
+ if not system_app.get_component(
+ ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None
+ ):
+ # Register model registry if not exist
+ registry = ModelRegistryClient(controller_addr)
+ registry.name = ComponentType.MODEL_REGISTRY.value
+ system_app.register_instance(registry)
+
+ registry = system_app.get_component(
+ ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None
+ )
+ worker_manager = RemoteWorkerManager(registry)
+
+ # Register worker manager component if not exist
+ system_app.get_component(
+ ComponentType.WORKER_MANAGER_FACTORY,
+ WorkerManagerFactory,
+ or_register_component=_DefaultWorkerManagerFactory,
+ worker_manager=worker_manager,
+ )
+ # Register api server component if not exist
+ system_app.get_component(
+ ComponentType.MODEL_API_SERVER, APIServer, or_register_component=APIServer
+ )
+
+
+def initialize_apiserver(
+ controller_addr: str,
+ app=None,
+ system_app: SystemApp = None,
+ host: str = None,
+ port: int = None,
+ api_keys: List[str] = None,
+):
+ global global_system_app
+ global api_settings
+ embedded_mod = True
+ if not app:
+ embedded_mod = False
+ app = FastAPI()
+ app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
+ allow_headers=["*"],
+ )
+
+ if not system_app:
+ system_app = SystemApp(app)
+ global_system_app = system_app
+
+ if api_keys:
+ api_settings.api_keys = api_keys
+
+ app.include_router(router, prefix="/api", tags=["APIServer"])
+
+ @app.exception_handler(APIServerException)
+ async def validation_apiserver_exception_handler(request, exc: APIServerException):
+ return create_error_response(exc.code, exc.message)
+
+ @app.exception_handler(RequestValidationError)
+ async def validation_exception_handler(request, exc):
+ return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))
+
+ _initialize_all(controller_addr, system_app)
+
+ if not embedded_mod:
+ import uvicorn
+
+ uvicorn.run(app, host=host, port=port, log_level="info")
+
+
+def run_apiserver():
+ parser = EnvArgumentParser()
+ env_prefix = "apiserver_"
+ apiserver_params: ModelAPIServerParameters = parser.parse_args_into_dataclass(
+ ModelAPIServerParameters,
+ env_prefixes=[env_prefix],
+ )
+ setup_logging(
+ "pilot",
+ logging_level=apiserver_params.log_level,
+ logger_filename=apiserver_params.log_file,
+ )
+ api_keys = None
+ if apiserver_params.api_keys:
+ api_keys = apiserver_params.api_keys.strip().split(",")
+
+ initialize_apiserver(
+ apiserver_params.controller_addr,
+ host=apiserver_params.host,
+ port=apiserver_params.port,
+ api_keys=api_keys,
+ )
+
+
+if __name__ == "__main__":
+ run_apiserver()
diff --git a/pilot/model/cluster/apiserver/tests/__init__.py b/pilot/model/cluster/apiserver/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/model/cluster/apiserver/tests/test_api.py b/pilot/model/cluster/apiserver/tests/test_api.py
new file mode 100644
index 000000000..281a8aff6
--- /dev/null
+++ b/pilot/model/cluster/apiserver/tests/test_api.py
@@ -0,0 +1,248 @@
+import pytest
+import pytest_asyncio
+from aioresponses import aioresponses
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from httpx import AsyncClient, HTTPError
+
+from pilot.component import SystemApp
+from pilot.utils.openai_utils import chat_completion_stream, chat_completion
+
+from pilot.model.cluster.apiserver.api import (
+ api_settings,
+ initialize_apiserver,
+ ModelList,
+ UsageInfo,
+ ChatCompletionResponse,
+ ChatCompletionResponseStreamChoice,
+ ChatCompletionStreamResponse,
+ ChatMessage,
+ ChatCompletionResponseChoice,
+ DeltaMessage,
+)
+from pilot.model.cluster.tests.conftest import _new_cluster
+
+from pilot.model.cluster.worker.manager import _DefaultWorkerManagerFactory
+
+app = FastAPI()
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
+ allow_headers=["*"],
+)
+
+
+@pytest_asyncio.fixture
+async def system_app():
+ return SystemApp(app)
+
+
+@pytest_asyncio.fixture
+async def client(request, system_app: SystemApp):
+ param = getattr(request, "param", {})
+ api_keys = param.get("api_keys", [])
+ client_api_key = param.get("client_api_key")
+ if "num_workers" not in param:
+ param["num_workers"] = 2
+ if "api_keys" in param:
+ del param["api_keys"]
+ headers = {}
+ if client_api_key:
+ headers["Authorization"] = "Bearer " + client_api_key
+ print(f"param: {param}")
+ if api_settings:
+ # Clear global api keys
+ api_settings.api_keys = []
+ async with AsyncClient(app=app, base_url="http://test", headers=headers) as client:
+ async with _new_cluster(**param) as cluster:
+ worker_manager, model_registry = cluster
+ system_app.register(_DefaultWorkerManagerFactory, worker_manager)
+ system_app.register_instance(model_registry)
+ # print(f"Instances {model_registry.registry}")
+ initialize_apiserver(None, app, system_app, api_keys=api_keys)
+ yield client
+
+
+@pytest.mark.asyncio
+async def test_get_all_models(client: AsyncClient):
+ res = await client.get("/api/v1/models")
+ res.status_code == 200
+ model_lists = ModelList.parse_obj(res.json())
+ print(f"model list json: {res.json()}")
+ assert model_lists.object == "list"
+ assert len(model_lists.data) == 2
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "client, expected_messages",
+ [
+ ({"stream_messags": ["Hello", " world."]}, "Hello world."),
+ ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
+ ],
+ indirect=["client"],
+)
+async def test_chat_completions(client: AsyncClient, expected_messages):
+ chat_data = {
+ "model": "test-model-name-0",
+ "messages": [{"role": "user", "content": "Hello"}],
+ "stream": True,
+ }
+ full_text = ""
+ async for text in chat_completion_stream(
+ "/api/v1/chat/completions", chat_data, client
+ ):
+ full_text += text
+ assert full_text == expected_messages
+
+ assert (
+ await chat_completion("/api/v1/chat/completions", chat_data, client)
+ == expected_messages
+ )
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "client, expected_messages, client_api_key",
+ [
+ (
+ {"stream_messags": ["Hello", " world."], "api_keys": ["abc"]},
+ "Hello world.",
+ "abc",
+ ),
+ ({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
+ ],
+ indirect=["client"],
+)
+async def test_chat_completions_with_openai_lib_async_no_stream(
+ client: AsyncClient, expected_messages: str, client_api_key: str
+):
+ import openai
+
+ openai.api_key = client_api_key
+ openai.api_base = "http://test/api/v1"
+
+ model_name = "test-model-name-0"
+
+ with aioresponses() as mocked:
+ mock_message = {"text": expected_messages}
+ one_res = ChatCompletionResponseChoice(
+ index=0,
+ message=ChatMessage(role="assistant", content=expected_messages),
+ finish_reason="stop",
+ )
+ data = ChatCompletionResponse(
+ model=model_name, choices=[one_res], usage=UsageInfo()
+ )
+ mock_message = f"{data.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+ # Mock http request
+ mocked.post(
+ "http://test/api/v1/chat/completions", status=200, body=mock_message
+ )
+ completion = await openai.ChatCompletion.acreate(
+ model=model_name,
+ messages=[{"role": "user", "content": "Hello! What is your name?"}],
+ )
+ assert completion.choices[0].message.content == expected_messages
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "client, expected_messages, client_api_key",
+ [
+ (
+ {"stream_messags": ["Hello", " world."], "api_keys": ["abc"]},
+ "Hello world.",
+ "abc",
+ ),
+ ({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
+ ],
+ indirect=["client"],
+)
+async def test_chat_completions_with_openai_lib_async_stream(
+ client: AsyncClient, expected_messages: str, client_api_key: str
+):
+ import openai
+
+ openai.api_key = client_api_key
+ openai.api_base = "http://test/api/v1"
+
+ model_name = "test-model-name-0"
+
+ with aioresponses() as mocked:
+ mock_message = {"text": expected_messages}
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=0,
+ delta=DeltaMessage(content=expected_messages),
+ finish_reason="stop",
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=0, choices=[choice_data], model=model_name
+ )
+ mock_message = f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+ mocked.post(
+ "http://test/api/v1/chat/completions",
+ status=200,
+ body=mock_message,
+ content_type="text/event-stream",
+ )
+
+ stream_stream_resp = ""
+ async for stream_resp in await openai.ChatCompletion.acreate(
+ model=model_name,
+ messages=[{"role": "user", "content": "Hello! What is your name?"}],
+ stream=True,
+ ):
+ stream_stream_resp = stream_resp.choices[0]["delta"].get("content", "")
+ assert stream_stream_resp == expected_messages
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "client, expected_messages, api_key_is_error",
+ [
+ (
+ {
+ "stream_messags": ["Hello", " world."],
+ "api_keys": ["abc", "xx"],
+ "client_api_key": "abc",
+ },
+ "Hello world.",
+ False,
+ ),
+ ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。", False),
+ (
+ {"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc", "xx"]},
+ "你好,我是张三。",
+ True,
+ ),
+ (
+ {
+ "stream_messags": ["你好,我是", "张三。"],
+ "api_keys": ["abc", "xx"],
+ "client_api_key": "error_api_key",
+ },
+ "你好,我是张三。",
+ True,
+ ),
+ ],
+ indirect=["client"],
+)
+async def test_chat_completions_with_api_keys(
+ client: AsyncClient, expected_messages: str, api_key_is_error: bool
+):
+ chat_data = {
+ "model": "test-model-name-0",
+ "messages": [{"role": "user", "content": "Hello"}],
+ "stream": True,
+ }
+ if api_key_is_error:
+ with pytest.raises(HTTPError):
+ await chat_completion("/api/v1/chat/completions", chat_data, client)
+ else:
+ assert (
+ await chat_completion("/api/v1/chat/completions", chat_data, client)
+ == expected_messages
+ )
diff --git a/pilot/model/cluster/controller/controller.py b/pilot/model/cluster/controller/controller.py
index 173c8c019..0006d91a0 100644
--- a/pilot/model/cluster/controller/controller.py
+++ b/pilot/model/cluster/controller/controller.py
@@ -66,7 +66,9 @@ class LocalModelController(BaseModelController):
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
)
if not model_name:
- return await self.registry.get_all_model_instances()
+ return await self.registry.get_all_model_instances(
+ healthy_only=healthy_only
+ )
else:
return await self.registry.get_all_instances(model_name, healthy_only)
@@ -98,8 +100,10 @@ class _RemoteModelController(BaseModelController):
class ModelRegistryClient(_RemoteModelController, ModelRegistry):
- async def get_all_model_instances(self) -> List[ModelInstance]:
- return await self.get_all_instances()
+ async def get_all_model_instances(
+ self, healthy_only: bool = False
+ ) -> List[ModelInstance]:
+ return await self.get_all_instances(healthy_only=healthy_only)
@sync_api_remote(path="/api/controller/models")
def sync_get_all_instances(
diff --git a/pilot/model/cluster/registry.py b/pilot/model/cluster/registry.py
index 398882eb9..eb5f1e415 100644
--- a/pilot/model/cluster/registry.py
+++ b/pilot/model/cluster/registry.py
@@ -1,22 +1,37 @@
import random
import threading
import time
+import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timedelta
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
import itertools
+from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.model.base import ModelInstance
-class ModelRegistry(ABC):
+logger = logging.getLogger(__name__)
+
+
+class ModelRegistry(BaseComponent, ABC):
"""
Abstract base class for a model registry. It provides an interface
for registering, deregistering, fetching instances, and sending heartbeats
for instances.
"""
+ name = ComponentType.MODEL_REGISTRY
+
+ def __init__(self, system_app: SystemApp | None = None):
+ self.system_app = system_app
+ super().__init__(system_app)
+
+ def init_app(self, system_app: SystemApp):
+ """Initialize the component with the main application."""
+ self.system_app = system_app
+
@abstractmethod
async def register_instance(self, instance: ModelInstance) -> bool:
"""
@@ -65,9 +80,11 @@ class ModelRegistry(ABC):
"""Fetch all instances of a given model. Optionally, fetch only the healthy instances."""
@abstractmethod
- async def get_all_model_instances(self) -> List[ModelInstance]:
+ async def get_all_model_instances(
+ self, healthy_only: bool = False
+ ) -> List[ModelInstance]:
"""
- Fetch all instances of all models
+ Fetch all instances of all models, Optionally, fetch only the healthy instances.
Returns:
- List[ModelInstance]: A list of instances for the all models.
@@ -105,8 +122,12 @@ class ModelRegistry(ABC):
class EmbeddedModelRegistry(ModelRegistry):
def __init__(
- self, heartbeat_interval_secs: int = 60, heartbeat_timeout_secs: int = 120
+ self,
+ system_app: SystemApp | None = None,
+ heartbeat_interval_secs: int = 60,
+ heartbeat_timeout_secs: int = 120,
):
+ super().__init__(system_app)
self.registry: Dict[str, List[ModelInstance]] = defaultdict(list)
self.heartbeat_interval_secs = heartbeat_interval_secs
self.heartbeat_timeout_secs = heartbeat_timeout_secs
@@ -180,9 +201,14 @@ class EmbeddedModelRegistry(ModelRegistry):
instances = [ins for ins in instances if ins.healthy == True]
return instances
- async def get_all_model_instances(self) -> List[ModelInstance]:
- print(self.registry)
- return list(itertools.chain(*self.registry.values()))
+ async def get_all_model_instances(
+ self, healthy_only: bool = False
+ ) -> List[ModelInstance]:
+ logger.debug("Current registry metadata:\n{self.registry}")
+ instances = list(itertools.chain(*self.registry.values()))
+ if healthy_only:
+ instances = [ins for ins in instances if ins.healthy == True]
+ return instances
async def send_heartbeat(self, instance: ModelInstance) -> bool:
_, exist_ins = self._get_instances(
diff --git a/pilot/model/cluster/tests/__init__.py b/pilot/model/cluster/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/model/cluster/worker/tests/base_tests.py b/pilot/model/cluster/tests/conftest.py
similarity index 71%
rename from pilot/model/cluster/worker/tests/base_tests.py
rename to pilot/model/cluster/tests/conftest.py
index 21821d9f9..f614387ac 100644
--- a/pilot/model/cluster/worker/tests/base_tests.py
+++ b/pilot/model/cluster/tests/conftest.py
@@ -6,6 +6,7 @@ from pilot.model.parameter import ModelParameters, ModelWorkerParameters, Worker
from pilot.model.base import ModelOutput
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.worker.manager import (
+ WorkerManager,
LocalWorkerManager,
RegisterFunc,
DeregisterFunc,
@@ -13,6 +14,23 @@ from pilot.model.cluster.worker.manager import (
ApplyFunction,
)
+from pilot.model.base import ModelInstance
+from pilot.model.cluster.registry import ModelRegistry, EmbeddedModelRegistry
+
+
+@pytest.fixture
+def model_registry(request):
+ return EmbeddedModelRegistry()
+
+
+@pytest.fixture
+def model_instance():
+ return ModelInstance(
+ model_name="test_model",
+ host="192.168.1.1",
+ port=5000,
+ )
+
class MockModelWorker(ModelWorker):
def __init__(
@@ -51,8 +69,10 @@ class MockModelWorker(ModelWorker):
raise Exception("Stop worker error for mock")
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
+ full_text = ""
for msg in self.stream_messags:
- yield ModelOutput(text=msg, error_code=0)
+ full_text += msg
+ yield ModelOutput(text=full_text, error_code=0)
def generate(self, params: Dict) -> ModelOutput:
output = None
@@ -67,6 +87,8 @@ class MockModelWorker(ModelWorker):
_TEST_MODEL_NAME = "vicuna-13b-v1.5"
_TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5"
+ClusterType = Tuple[WorkerManager, ModelRegistry]
+
def _new_worker_params(
model_name: str = _TEST_MODEL_NAME,
@@ -85,7 +107,9 @@ def _create_workers(
worker_type: str = WorkerType.LLM.value,
stream_messags: List[str] = None,
embeddings: List[List[float]] = None,
-) -> List[Tuple[ModelWorker, ModelWorkerParameters]]:
+ host: str = "127.0.0.1",
+ start_port=8001,
+) -> List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]]:
workers = []
for i in range(num_workers):
model_name = f"test-model-name-{i}"
@@ -98,10 +122,16 @@ def _create_workers(
stream_messags=stream_messags,
embeddings=embeddings,
)
+ model_instance = ModelInstance(
+ model_name=WorkerType.to_worker_key(model_name, worker_type),
+ host=host,
+ port=start_port + i,
+ healthy=True,
+ )
worker_params = _new_worker_params(
model_name, model_path, worker_type=worker_type
)
- workers.append((worker, worker_params))
+ workers.append((worker, worker_params, model_instance))
return workers
@@ -127,12 +157,12 @@ async def _start_worker_manager(**kwargs):
model_registry=model_registry,
)
- for worker, worker_params in _create_workers(
+ for worker, worker_params, model_instance in _create_workers(
num_workers, error_worker, stop_error, stream_messags, embeddings
):
worker_manager.add_worker(worker, worker_params)
if workers:
- for worker, worker_params in workers:
+ for worker, worker_params, model_instance in workers:
worker_manager.add_worker(worker, worker_params)
if start:
@@ -143,6 +173,15 @@ async def _start_worker_manager(**kwargs):
await worker_manager.stop()
+async def _create_model_registry(
+ workers: List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]]
+) -> ModelRegistry:
+ registry = EmbeddedModelRegistry()
+ for _, _, inst in workers:
+ assert await registry.register_instance(inst) == True
+ return registry
+
+
@pytest_asyncio.fixture
async def manager_2_workers(request):
param = getattr(request, "param", {})
@@ -166,3 +205,27 @@ async def manager_2_embedding_workers(request):
)
async with _start_worker_manager(workers=workers, **param) as worker_manager:
yield (worker_manager, workers)
+
+
+@asynccontextmanager
+async def _new_cluster(**kwargs) -> ClusterType:
+ num_workers = kwargs.get("num_workers", 0)
+ workers = _create_workers(
+ num_workers, stream_messags=kwargs.get("stream_messags", [])
+ )
+ if "num_workers" in kwargs:
+ del kwargs["num_workers"]
+ registry = await _create_model_registry(
+ workers,
+ )
+ async with _start_worker_manager(workers=workers, **kwargs) as worker_manager:
+ yield (worker_manager, registry)
+
+
+@pytest_asyncio.fixture
+async def cluster_2_workers(request):
+ param = getattr(request, "param", {})
+ workers = _create_workers(2)
+ registry = await _create_model_registry(workers)
+ async with _start_worker_manager(workers=workers, **param) as worker_manager:
+ yield (worker_manager, registry)
diff --git a/pilot/model/cluster/worker/default_worker.py b/pilot/model/cluster/worker/default_worker.py
index 5caa2ee7e..44a476f20 100644
--- a/pilot/model/cluster/worker/default_worker.py
+++ b/pilot/model/cluster/worker/default_worker.py
@@ -256,15 +256,22 @@ class DefaultModelWorker(ModelWorker):
return params, model_context, generate_stream_func, model_span
def _handle_output(self, output, previous_response, model_context):
+ finish_reason = None
+ usage = None
if isinstance(output, dict):
finish_reason = output.get("finish_reason")
+ usage = output.get("usage")
output = output["text"]
if finish_reason is not None:
logger.info(f"finish_reason: {finish_reason}")
incremental_output = output[len(previous_response) :]
print(incremental_output, end="", flush=True)
model_output = ModelOutput(
- text=output, error_code=0, model_context=model_context
+ text=output,
+ error_code=0,
+ model_context=model_context,
+ finish_reason=finish_reason,
+ usage=usage,
)
return model_output, incremental_output, output
diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py
index a76fa6685..2dcfb086e 100644
--- a/pilot/model/cluster/worker/manager.py
+++ b/pilot/model/cluster/worker/manager.py
@@ -99,9 +99,7 @@ class LocalWorkerManager(WorkerManager):
)
def _worker_key(self, worker_type: str, model_name: str) -> str:
- if isinstance(worker_type, WorkerType):
- worker_type = worker_type.value
- return f"{model_name}@{worker_type}"
+ return WorkerType.to_worker_key(model_name, worker_type)
async def run_blocking_func(self, func, *args):
if asyncio.iscoroutinefunction(func):
diff --git a/pilot/model/cluster/worker/tests/test_manager.py b/pilot/model/cluster/worker/tests/test_manager.py
index 919e64f99..681fb49a3 100644
--- a/pilot/model/cluster/worker/tests/test_manager.py
+++ b/pilot/model/cluster/worker/tests/test_manager.py
@@ -3,7 +3,7 @@ import pytest
from typing import List, Iterator, Dict, Tuple
from dataclasses import asdict
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
-from pilot.model.base import ModelOutput, WorkerApplyType
+from pilot.model.base import ModelOutput, WorkerApplyType, ModelInstance
from pilot.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.manager_base import WorkerRunData
@@ -14,7 +14,7 @@ from pilot.model.cluster.worker.manager import (
SendHeartbeatFunc,
ApplyFunction,
)
-from pilot.model.cluster.worker.tests.base_tests import (
+from pilot.model.cluster.tests.conftest import (
MockModelWorker,
manager_2_workers,
manager_with_2_workers,
@@ -216,7 +216,7 @@ async def test__remove_worker():
workers = _create_workers(3)
async with _start_worker_manager(workers=workers, stop=False) as manager:
assert len(manager.workers) == 3
- for _, worker_params in workers:
+ for _, worker_params, _ in workers:
manager._remove_worker(worker_params)
not_exist_parmas = _new_worker_params(
model_name="this is a not exist worker params"
@@ -229,7 +229,7 @@ async def test__remove_worker():
async def test_model_startup(mock_build_worker):
async with _start_worker_manager() as manager:
workers = _create_workers(1)
- worker, worker_params = workers[0]
+ worker, worker_params, model_instance = workers[0]
mock_build_worker.return_value = worker
req = WorkerStartupRequest(
@@ -245,7 +245,7 @@ async def test_model_startup(mock_build_worker):
async with _start_worker_manager() as manager:
workers = _create_workers(1, error_worker=True)
- worker, worker_params = workers[0]
+ worker, worker_params, model_instance = workers[0]
mock_build_worker.return_value = worker
req = WorkerStartupRequest(
host="127.0.0.1",
@@ -263,7 +263,7 @@ async def test_model_startup(mock_build_worker):
async def test_model_shutdown(mock_build_worker):
async with _start_worker_manager(start=False, stop=False) as manager:
workers = _create_workers(1)
- worker, worker_params = workers[0]
+ worker, worker_params, model_instance = workers[0]
mock_build_worker.return_value = worker
req = WorkerStartupRequest(
@@ -298,7 +298,7 @@ async def test_get_model_instances(is_async):
workers = _create_workers(3)
async with _start_worker_manager(workers=workers, stop=False) as manager:
assert len(manager.workers) == 3
- for _, worker_params in workers:
+ for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
if is_async:
@@ -326,7 +326,7 @@ async def test__simple_select(
]
):
manager, workers = manager_with_2_workers
- for _, worker_params in workers:
+ for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
instances = await manager.get_model_instances(worker_type, model_name)
@@ -351,7 +351,7 @@ async def test_select_one_instance(
],
):
manager, workers = manager_with_2_workers
- for _, worker_params in workers:
+ for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
if is_async:
@@ -376,7 +376,7 @@ async def test__get_model(
],
):
manager, workers = manager_with_2_workers
- for _, worker_params in workers:
+ for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = {"model": model_name}
@@ -403,13 +403,13 @@ async def test_generate_stream(
expected_messages: str,
):
manager, workers = manager_with_2_workers
- for _, worker_params in workers:
+ for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = {"model": model_name}
text = ""
async for out in manager.generate_stream(params):
- text += out.text
+ text = out.text
assert text == expected_messages
@@ -417,8 +417,8 @@ async def test_generate_stream(
@pytest.mark.parametrize(
"manager_with_2_workers, expected_messages",
[
- ({"stream_messags": ["Hello", " world."]}, " world."),
- ({"stream_messags": ["你好,我是", "张三。"]}, "张三。"),
+ ({"stream_messags": ["Hello", " world."]}, "Hello world."),
+ ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
],
indirect=["manager_with_2_workers"],
)
@@ -429,7 +429,7 @@ async def test_generate(
expected_messages: str,
):
manager, workers = manager_with_2_workers
- for _, worker_params in workers:
+ for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = {"model": model_name}
@@ -454,7 +454,7 @@ async def test_embeddings(
is_async: bool,
):
manager, workers = manager_2_embedding_workers
- for _, worker_params in workers:
+ for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = {"model": model_name, "input": ["hello", "world"]}
@@ -472,7 +472,7 @@ async def test_parameter_descriptions(
]
):
manager, workers = manager_with_2_workers
- for _, worker_params in workers:
+ for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = await manager.parameter_descriptions(worker_type, model_name)
diff --git a/pilot/model/conversation.py b/pilot/model/conversation.py
index b3674e946..5d4309d9f 100644
--- a/pilot/model/conversation.py
+++ b/pilot/model/conversation.py
@@ -339,6 +339,27 @@ register_conv_template(
)
)
+
+# codellama template
+# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
+# reference2 : https://github.com/eosphoros-ai/DB-GPT-Hub/blob/main/README.zh.md
+register_conv_template(
+ Conversation(
+ name="codellama",
+ system="[INST] <>\nI want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request."
+ "If you don't know the answer to the request, please don't share false information.\n<>\n\n",
+ roles=("[INST]", "[/INST]"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA2,
+ sep=" ",
+ sep2=" ",
+ stop_token_ids=[2],
+ system_formatter=lambda msg: f"[INST] <>\n{msg}\n<>\n\n",
+ )
+)
+
+
# Alpaca default template
register_conv_template(
Conversation(
diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py
index 1580e8863..e2deeaa02 100644
--- a/pilot/model/model_adapter.py
+++ b/pilot/model/model_adapter.py
@@ -45,6 +45,10 @@ _OLD_MODELS = [
"llama-cpp",
"proxyllm",
"gptj-6b",
+ "codellama-13b-sql-sft",
+ "codellama-7b",
+ "codellama-7b-sql-sft",
+ "codellama-13b",
]
@@ -148,8 +152,12 @@ class LLMModelAdaper:
conv.append_message(conv.roles[1], content)
else:
raise ValueError(f"Unknown role: {role}")
+
if system_messages:
- conv.set_system_message("".join(system_messages))
+ if isinstance(conv, Conversation):
+ conv.set_system_message("".join(system_messages))
+ else:
+ conv.update_system_message("".join(system_messages))
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
@@ -459,7 +467,8 @@ register_conv_template(
sep="\n",
sep2="",
stop_str=["", "[UNK]"],
- )
+ ),
+ override=True,
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
register_conv_template(
@@ -474,7 +483,8 @@ register_conv_template(
sep="###",
sep2="",
stop_str=["", "[UNK]"],
- )
+ ),
+ override=True,
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
register_conv_template(
@@ -487,5 +497,6 @@ register_conv_template(
sep="",
sep2="",
stop_str=["", "<|endoftext|>"],
- )
+ ),
+ override=True,
)
diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py
index ea81ec091..e21de1c42 100644
--- a/pilot/model/parameter.py
+++ b/pilot/model/parameter.py
@@ -1,9 +1,10 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
+
import os
from dataclasses import dataclass, field
from enum import Enum
-from typing import Dict, Optional
+from typing import Dict, Optional, Union, Tuple
from pilot.model.conversation import conv_templates
from pilot.utils.parameter_utils import BaseParameters
@@ -19,6 +20,35 @@ class WorkerType(str, Enum):
def values():
return [item.value for item in WorkerType]
+ @staticmethod
+ def to_worker_key(worker_name, worker_type: Union[str, "WorkerType"]) -> str:
+ """Generate worker key from worker name and worker type
+
+ Args:
+ worker_name (str): Worker name(eg., chatglm2-6b)
+ worker_type (Union[str, "WorkerType"]): Worker type(eg., 'llm', or [`WorkerType.LLM`])
+
+ Returns:
+ str: Generated worker key
+ """
+ if "@" in worker_name:
+ raise ValueError(f"Invaild symbol '@' in your worker name {worker_name}")
+ if isinstance(worker_type, WorkerType):
+ worker_type = worker_type.value
+ return f"{worker_name}@{worker_type}"
+
+ @staticmethod
+ def parse_worker_key(worker_key: str) -> Tuple[str, str]:
+ """Parse worker name and worker type from worker key
+
+ Args:
+ worker_key (str): Worker key generated by [`WorkerType.to_worker_key`]
+
+ Returns:
+ Tuple[str, str]: Worker name and worker type
+ """
+ return tuple(worker_key.split("@"))
+
@dataclass
class ModelControllerParameters(BaseParameters):
@@ -60,6 +90,56 @@ class ModelControllerParameters(BaseParameters):
)
+@dataclass
+class ModelAPIServerParameters(BaseParameters):
+ host: Optional[str] = field(
+ default="0.0.0.0", metadata={"help": "Model API server deploy host"}
+ )
+ port: Optional[int] = field(
+ default=8100, metadata={"help": "Model API server deploy port"}
+ )
+ daemon: Optional[bool] = field(
+ default=False, metadata={"help": "Run Model API server in background"}
+ )
+ controller_addr: Optional[str] = field(
+ default="http://127.0.0.1:8000",
+ metadata={"help": "The Model controller address to connect"},
+ )
+
+ api_keys: Optional[str] = field(
+ default=None,
+ metadata={"help": "Optional list of comma separated API keys"},
+ )
+
+ log_level: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "Logging level",
+ "valid_values": [
+ "FATAL",
+ "ERROR",
+ "WARNING",
+ "WARNING",
+ "INFO",
+ "DEBUG",
+ "NOTSET",
+ ],
+ },
+ )
+ log_file: Optional[str] = field(
+ default="dbgpt_model_apiserver.log",
+ metadata={
+ "help": "The filename to store log",
+ },
+ )
+ tracer_file: Optional[str] = field(
+ default="dbgpt_model_apiserver_tracer.jsonl",
+ metadata={
+ "help": "The filename to store tracer span records",
+ },
+ )
+
+
@dataclass
class BaseModelParameters(BaseParameters):
model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
diff --git a/pilot/scene/base_message.py b/pilot/scene/base_message.py
index eeb42a285..12a72e909 100644
--- a/pilot/scene/base_message.py
+++ b/pilot/scene/base_message.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Tuple, Optional
+from typing import Any, Dict, List, Tuple, Optional, Union
from pydantic import BaseModel, Field, root_validator
@@ -70,14 +70,6 @@ class SystemMessage(BaseMessage):
return "system"
-class ModelMessage(BaseModel):
- """Type of message that interaction between dbgpt-server and llm-server"""
-
- """Similar to openai's message format"""
- role: str
- content: str
-
-
class ModelMessageRoleType:
""" "Type of ModelMessage role"""
@@ -87,6 +79,45 @@ class ModelMessageRoleType:
VIEW = "view"
+class ModelMessage(BaseModel):
+ """Type of message that interaction between dbgpt-server and llm-server"""
+
+ """Similar to openai's message format"""
+ role: str
+ content: str
+
+ @staticmethod
+ def from_openai_messages(
+ messages: Union[str, List[Dict[str, str]]]
+ ) -> List["ModelMessage"]:
+ """Openai message format to current ModelMessage format"""
+ if isinstance(messages, str):
+ return [ModelMessage(role=ModelMessageRoleType.HUMAN, content=messages)]
+ result = []
+ for message in messages:
+ msg_role = message["role"]
+ content = message["content"]
+ if msg_role == "system":
+ result.append(
+ ModelMessage(role=ModelMessageRoleType.SYSTEM, content=content)
+ )
+ elif msg_role == "user":
+ result.append(
+ ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
+ )
+ elif msg_role == "assistant":
+ result.append(
+ ModelMessage(role=ModelMessageRoleType.AI, content=content)
+ )
+ else:
+ raise ValueError(f"Unknown role: {msg_role}")
+ return result
+
+ @staticmethod
+ def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
+ return list(map(lambda m: m.dict(), messages))
+
+
class Generation(BaseModel):
"""Output of a single generation."""
diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py
index e6c3d9056..4d40a8eef 100644
--- a/pilot/scene/chat_knowledge/v1/chat.py
+++ b/pilot/scene/chat_knowledge/v1/chat.py
@@ -1,3 +1,4 @@
+import json
import os
from typing import Dict, List
@@ -110,12 +111,14 @@ class ChatKnowledge(BaseChat):
list(map(lambda doc: doc.metadata, docs)), "source"
)
- self.current_message.knowledge_source = self.sources
- if not docs:
- raise ValueError(
- "you have no knowledge space, please add your knowledge space"
- )
- context = [d.page_content for d in docs]
+ if not docs or len(docs) == 0:
+ print("no relevant docs to retrieve")
+ context = "no relevant docs to retrieve"
+ # raise ValueError(
+ # "you have no knowledge space, please add your knowledge space"
+ # )
+ else:
+ context = [d.page_content for d in docs]
context = context[: self.max_token]
relations = list(
set([os.path.basename(str(d.metadata.get("source", ""))) for d in docs])
@@ -128,17 +131,26 @@ class ChatKnowledge(BaseChat):
return input_values
def parse_source_view(self, sources: List):
- html_title = f"##### **References:**"
- lines = ""
+ """
+ build knowledge reference view message to web
+ {
+ "title":"References",
+ "reference":{
+ "name":"aa.pdf",
+ "pages":["1","2","3"]
+ },
+ }
+ """
+ references = {"title": "References", "reference": {}}
for item in sources:
source = item["source"] if "source" in item else ""
- pages = ",".join(item["pages"]) if "pages" in item else ""
- lines += f"{source}"
+ references["reference"]["name"] = source
+ pages = item["pages"] if "pages" in item else []
if len(pages) > 0:
- lines += f", **pages**:{pages}\n\n"
- else:
- lines += "\n\n"
- html = f"""{html_title}\n{lines}"""
+ references["reference"]["pages"] = pages
+ html = (
+ f"""{json.dumps(references, ensure_ascii=False)}"""
+ )
return html
def merge_by_key(self, data, key):
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index cb486021b..64b72739b 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -215,6 +215,16 @@ class Llama2ChatAdapter(BaseChatAdpter):
return get_conv_template("llama-2")
+class CodeLlamaChatAdapter(BaseChatAdpter):
+ """The model ChatAdapter for codellama ."""
+
+ def match(self, model_path: str):
+ return "codellama" in model_path.lower()
+
+ def get_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("codellama")
+
+
class BaichuanChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "baichuan" in model_path.lower()
@@ -268,6 +278,7 @@ register_llm_model_chat_adapter(FalconChatAdapter)
register_llm_model_chat_adapter(GorillaChatAdapter)
register_llm_model_chat_adapter(GPT4AllChatAdapter)
register_llm_model_chat_adapter(Llama2ChatAdapter)
+register_llm_model_chat_adapter(CodeLlamaChatAdapter)
register_llm_model_chat_adapter(BaichuanChatAdapter)
register_llm_model_chat_adapter(WizardLMChatAdapter)
register_llm_model_chat_adapter(LlamaCppChatAdapter)
diff --git a/pilot/utils/openai_utils.py b/pilot/utils/openai_utils.py
new file mode 100644
index 000000000..6577d3abf
--- /dev/null
+++ b/pilot/utils/openai_utils.py
@@ -0,0 +1,99 @@
+from typing import Dict, Any, Awaitable, Callable, Optional, Iterator
+import httpx
+import asyncio
+import logging
+import json
+
+logger = logging.getLogger(__name__)
+MessageCaller = Callable[[str], Awaitable[None]]
+
+
+async def _do_chat_completion(
+ url: str,
+ chat_data: Dict[str, Any],
+ client: httpx.AsyncClient,
+ headers: Dict[str, Any] = {},
+ timeout: int = 60,
+ caller: Optional[MessageCaller] = None,
+) -> Iterator[str]:
+ async with client.stream(
+ "POST",
+ url,
+ headers=headers,
+ json=chat_data,
+ timeout=timeout,
+ ) as res:
+ if res.status_code != 200:
+ error_message = await res.aread()
+ if error_message:
+ error_message = error_message.decode("utf-8")
+ logger.error(
+ f"Request failed with status {res.status_code}. Error: {error_message}"
+ )
+ raise httpx.RequestError(
+ f"Request failed with status {res.status_code}",
+ request=res.request,
+ )
+ async for line in res.aiter_lines():
+ if line:
+ if not line.startswith("data: "):
+ if caller:
+ await caller(line)
+ yield line
+ else:
+ decoded_line = line.split("data: ", 1)[1]
+ if decoded_line.lower().strip() != "[DONE]".lower():
+ obj = json.loads(decoded_line)
+ if obj["choices"][0]["delta"].get("content") is not None:
+ text = obj["choices"][0]["delta"].get("content")
+ if caller:
+ await caller(text)
+ yield text
+ await asyncio.sleep(0.02)
+
+
+async def chat_completion_stream(
+ url: str,
+ chat_data: Dict[str, Any],
+ client: Optional[httpx.AsyncClient] = None,
+ headers: Dict[str, Any] = {},
+ timeout: int = 60,
+ caller: Optional[MessageCaller] = None,
+) -> Iterator[str]:
+ if client:
+ async for text in _do_chat_completion(
+ url,
+ chat_data,
+ client=client,
+ headers=headers,
+ timeout=timeout,
+ caller=caller,
+ ):
+ yield text
+ else:
+ async with httpx.AsyncClient() as client:
+ async for text in _do_chat_completion(
+ url,
+ chat_data,
+ client=client,
+ headers=headers,
+ timeout=timeout,
+ caller=caller,
+ ):
+ yield text
+
+
+async def chat_completion(
+ url: str,
+ chat_data: Dict[str, Any],
+ client: Optional[httpx.AsyncClient] = None,
+ headers: Dict[str, Any] = {},
+ timeout: int = 60,
+ caller: Optional[MessageCaller] = None,
+) -> str:
+ full_text = ""
+ async for text in chat_completion_stream(
+ url, chat_data, client, headers=headers, timeout=timeout, caller=caller
+ ):
+ full_text += text
+ return full_text
diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt
index 072b527f1..d1a98ed49 100644
--- a/requirements/dev-requirements.txt
+++ b/requirements/dev-requirements.txt
@@ -8,6 +8,7 @@ pytest-integration
pytest-mock
pytest-recording
pytesseract==0.3.10
+aioresponses
# python code format
black
# for git hooks