feat: merge main branch

aries_ckt committed on 2023-11-03 10:27:48 +08:00
commit 6fe7bfd63d
37 changed files with 1550 additions and 166 deletions

CODE_OF_CONDUCT Normal file
View File

@ -0,0 +1,126 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or advances of
any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
[INSERT CONTACT METHOD].
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
*Community Impact*: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
*Consequence*: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
*Community Impact*: A violation through a single incident or series of
actions.
*Consequence*: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.
### 3. Temporary Ban
*Community Impact*: A serious violation of community standards, including
sustained inappropriate behavior.
*Consequence*: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
*Community Impact*: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
*Consequence*: A permanent ban from any sort of public interaction within the
community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].

View File

@ -43,8 +43,8 @@ DB-GPT is an experimental open-source project that uses localized GPT large mode
## Contents
- [install](#install)
- [demo](#demo)
- [Install](#install)
- [Demo](#demo)
- [introduction](#introduction)
- [features](#features)
- [contribution](#contribution)
@ -177,7 +177,7 @@ Currently, we have released multiple key features, which are listed below to dem
| [StarRocks](https://github.com/StarRocks/starrocks) | No | TODO |
## Introduction
Is the architecture of the entire DB-GPT shown in the following figure:
The architecture of the entire DB-GPT is shown in the following figure:
<p align="center">
<img src="./assets/DB-GPT.png" width="800" />
@ -213,7 +213,7 @@ The core capabilities mainly consist of the following parts:
## Contribution
- Please run `black .` before submitting the code. Contributing guidelines, [how to contribution](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING.md)
- Please run `black .` before submitting the code. Contributing guidelines, [how to contribute](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING.md)
## RoadMap
@ -330,7 +330,7 @@ As of October 10, 2023, by fine-tuning an open-source model of 13 billion parame
The MIT License (MIT)
## Contact Information
We are working on building a community, if you have any ideas about building the community, feel free to contact us.
We are working on building a community; if you have any ideas for building the community, feel free to contact us.
[![](https://dcbadge.vercel.app/api/server/nASQyBjvY?compact=true&style=flat)](https://discord.gg/nASQyBjvY)
<p align="center">

Binary file not shown.
Before: Size 229 KiB | After: Size 202 KiB

View File

@ -7,6 +7,16 @@ services:
restart: unless-stopped
networks:
- dbgptnet
api-server:
image: eosphorosai/dbgpt:latest
command: dbgpt start apiserver --controller_addr http://controller:8000
restart: unless-stopped
depends_on:
- controller
networks:
- dbgptnet
ports:
- 8100:8100/tcp
llm-worker:
image: eosphorosai/dbgpt:latest
command: dbgpt start worker --model_name vicuna-13b-v1.5 --model_path /app/models/vicuna-13b-v1.5 --port 8001 --controller_addr http://controller:8000
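With the new `api-server` service in place, bringing the stack up with `docker compose up` should expose OpenAI-compatible endpoints on host port 8100. A minimal smoke test, assuming the default compose settings above (stdlib only; no API key is configured, so the bearer token is ignored):

```python
# List the models served by the api-server container.
import json
import urllib.request

req = urllib.request.Request(
    "http://127.0.0.1:8100/api/v1/models",
    headers={"Authorization": "Bearer EMPTY"},
)
with urllib.request.urlopen(req) as resp:
    print(json.dumps(json.load(resp), indent=2))
```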

View File

@ -32,7 +32,7 @@ extensions = [
"sphinx_panels",
"sphinx_tabs.tabs",
"IPython.sphinxext.ipython_console_highlighting",
'sphinx.ext.autosectionlabel'
"sphinx.ext.autosectionlabel",
]
source_suffix = [".ipynb", ".html", ".md", ".rst"]

View File

@ -77,3 +77,4 @@ By analyzing this information, we can identify performance bottlenecks in model
./vms/standalone.md
./vms/index.md
./openai.md

View File

@ -0,0 +1,51 @@
OpenAI-Compatible RESTful APIs
==================================
(openai-apis-index)=
### Prerequisites
You must [deploy DB-GPT cluster](https://db-gpt.readthedocs.io/en/latest/getting_started/install/cluster/vms/index.html) first.
### Launch Model API Server
```bash
dbgpt start apiserver --controller_addr http://127.0.0.1:8000 --api_keys EMPTY
```
By default, the Model API Server starts on port 8100.
### Validate with cURL
#### List models
```bash
curl http://127.0.0.1:8100/api/v1/models \
-H "Authorization: Bearer EMPTY" \
-H "Content-Type: application/json"
```
#### Chat completions
```bash
curl http://127.0.0.1:8100/api/v1/chat/completions \
-H "Authorization: Bearer EMPTY" \
-H "Content-Type: application/json" \
-d '{"model": "vicuna-13b-v1.5", "messages": [{"role": "user", "content": "hello"}]}'
```
### Validate with OpenAI Official SDK
#### Chat completions
```python
import openai
openai.api_key = "EMPTY"
openai.api_base = "http://127.0.0.1:8100/api/v1"
model = "vicuna-13b-v1.5"
completion = openai.ChatCompletion.create(
model=model,
messages=[{"role": "user", "content": "hello"}]
)
# print the completion
print(completion.choices[0].message.content)
```
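#### Chat completions (streaming)

A streaming variant with the same 0.x `openai` SDK is sketched below; the dict-style `delta` access mirrors how this SDK exposes streaming chunks:

```python
import openai

openai.api_key = "EMPTY"
openai.api_base = "http://127.0.0.1:8100/api/v1"
model = "vicuna-13b-v1.5"

for chunk in openai.ChatCompletion.create(
    model=model,
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
):
    # "content" may be absent in the first (role-only) and final (finish) chunks.
    print(chunk.choices[0]["delta"].get("content", ""), end="", flush=True)
```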

View File

@ -24,9 +24,12 @@ PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions
#Azure
LLM_MODEL=chatgpt_proxyllm
OPENAI_API_TYPE=azure
PROXY_API_KEY={your-openai-sk}
PROXY_SERVER_URL=https://xx.openai.azure.com/v1/chat/completions
PROXY_API_KEY={your-azure-sk}
PROXY_API_BASE=https://{your domain}.openai.azure.com/
PROXY_API_TYPE=azure
PROXY_SERVER_URL=xxxx
PROXY_API_VERSION=2023-05-15
PROXYLLM_BACKEND=gpt-35-turbo
#Aliyun tongyi
LLM_MODEL=tongyi_proxyllm
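For reference, the Azure proxy settings above correspond roughly to the following 0.x `openai` SDK configuration (a sketch with the placeholder values kept; `PROXYLLM_BACKEND` is the Azure deployment name):

```python
import openai

openai.api_type = "azure"                                    # OPENAI_API_TYPE / PROXY_API_TYPE
openai.api_key = "{your-azure-sk}"                           # PROXY_API_KEY
openai.api_base = "https://{your domain}.openai.azure.com/"  # PROXY_API_BASE
openai.api_version = "2023-05-15"                            # PROXY_API_VERSION

resp = openai.ChatCompletion.create(
    engine="gpt-35-turbo",  # PROXYLLM_BACKEND: the Azure deployment name
    messages=[{"role": "user", "content": "hello"}],
)
print(resp.choices[0].message.content)
```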

View File

@ -0,0 +1,71 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) 2023, csunny
# This file is distributed under the same license as the DB-GPT package.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.4.0\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-11-02 21:09+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
"Language-Team: zh_CN <LL@li.org>\n"
"Plural-Forms: nplurals=1; plural=0;\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"
#: ../../getting_started/install/cluster/openai.md:1
#: 01f4e2bf853341198633b367efec1522
msgid "OpenAI-Compatible RESTful APIs"
msgstr "OpenAI RESTful 兼容接口"
#: ../../getting_started/install/cluster/openai.md:5
#: d8717e42335e4027bf4e76b3d28768ee
msgid "Install Prepare"
msgstr "安装准备"
#: ../../getting_started/install/cluster/openai.md:7
#: 9a48d8ee116942468de4c6faf9a64758
msgid ""
"You must [deploy DB-GPT cluster](https://db-"
"gpt.readthedocs.io/en/latest/getting_started/install/cluster/vms/index.html)"
" first."
msgstr "你必须先部署 [DB-GPT 集群]"
"(https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh-cn/latest/getting_started/install/cluster/vms/index.html)。"
#: ../../getting_started/install/cluster/openai.md:9
#: 7673a7121f004f7ca6b1a94a7e238fa3
msgid "Launch Model API Server"
msgstr "启动模型 API Server"
#: ../../getting_started/install/cluster/openai.md:14
#: 84a925c2cbcd4e4895a1d2d2fe8f720f
msgid "By default, the Model API Server starts on port 8100."
msgstr "默认情况下,模型 API Server 使用 8100 端口启动。"
#: ../../getting_started/install/cluster/openai.md:16
#: e53ed41977cd4721becd51eba05c6609
msgid "Validate with cURL"
msgstr "通过 cURL 验证"
#: ../../getting_started/install/cluster/openai.md:18
#: 7c883b410b5c4e53a256bf17c1ded80d
msgid "List models"
msgstr "列出模型"
#: ../../getting_started/install/cluster/openai.md:26
#: ../../getting_started/install/cluster/openai.md:37
#: 7cf0ed13f0754f149ec085cd6cf7a45a 990d5d5ed5d64ab49550e68495b9e7a0
msgid "Chat completions"
msgstr ""
#: ../../getting_started/install/cluster/openai.md:35
#: 81583edd22df44e091d18a0832278131
msgid "Validate with OpenAI Official SDK"
msgstr "通过 OpenAI 官方 SDK 验证"

View File

@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.5\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-11-02 10:10+0800\n"
"POT-Creation-Date: 2023-11-02 21:04+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
@ -20,292 +20,292 @@ msgstr ""
"Generated-By: Babel 2.12.1\n"
#: ../../getting_started/install/environment/environment.md:1
#: 28d2f84fc8884e78afad8118cd59c654
#: a17719d2f4374285a7beb4d1db470146
#, fuzzy
msgid "Environment Parameter"
msgstr "环境变量说明"
#: ../../getting_started/install/environment/environment.md:4
#: c83fbb5e1aa643cdb09fffe7f3d1a3c5
#: 9a62e6fff7914eeaa2d195ddef4fcb61
msgid "LLM MODEL Config"
msgstr "模型配置"
#: ../../getting_started/install/environment/environment.md:5
#: eb675965ae57407e8d8bf90fed8e9e2a
#: 90e3991538324ecfac8cac7ef2103ac2
msgid "LLM Model Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG"
msgstr "LLM Model Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG"
#: ../../getting_started/install/environment/environment.md:6
#: 5d28d35126d849ea9b0d963fd1ba8699
#: 1f45af01100c4586acbc05469e3006bc
msgid "LLM_MODEL=vicuna-13b"
msgstr "LLM_MODEL=vicuna-13b"
#: ../../getting_started/install/environment/environment.md:8
#: 01955b2d0fbe4d94939ebf2cbb380bdd
#: bed14b704f154c2db525f7fafd3aa5a4
msgid "MODEL_SERVER_ADDRESS"
msgstr "MODEL_SERVER_ADDRESS"
#: ../../getting_started/install/environment/environment.md:9
#: 4eaaa9ab59854c0386b28b3111c82784
#: ea42946cfe4f4ad996bf82c1996e7344
msgid "MODEL_SERVER=http://127.0.0.1:8000 LIMIT_MODEL_CONCURRENCY"
msgstr "MODEL_SERVER=http://127.0.0.1:8000 LIMIT_MODEL_CONCURRENCY"
#: ../../getting_started/install/environment/environment.md:12
#: 5c2dd05e16834443b7451c2541b59757
#: 021c261231f342fdba34098b1baa06fd
msgid "LIMIT_MODEL_CONCURRENCY=5"
msgstr "LIMIT_MODEL_CONCURRENCY=5"
#: ../../getting_started/install/environment/environment.md:14
#: 7707836c2fb04e7da13d2d59b5f9566f
#: afaf0ba7fd09463d8ff74b514ed7264c
msgid "MAX_POSITION_EMBEDDINGS"
msgstr "MAX_POSITION_EMBEDDINGS"
#: ../../getting_started/install/environment/environment.md:16
#: ee24a7d3d8384e61b715ef3bd362b965
#: e4517a942bca4361a64a00408f993f5b
msgid "MAX_POSITION_EMBEDDINGS=4096"
msgstr "MAX_POSITION_EMBEDDINGS=4096"
#: ../../getting_started/install/environment/environment.md:18
#: 90b51aa4e46b4d1298c672e0052c2f68
#: 78d2ef04ed4548b9b7b0fb8ae35c9d5c
msgid "QUANTIZE_QLORA"
msgstr "QUANTIZE_QLORA"
#: ../../getting_started/install/environment/environment.md:20
#: 7de7a8eb431e4973ae00f68ca0686281
#: bfa65db03c6d46bba293331f03ab15ac
msgid "QUANTIZE_QLORA=True"
msgstr "QUANTIZE_QLORA=True"
#: ../../getting_started/install/environment/environment.md:22
#: e331ca016a474f4aa4e9182165a2693a
#: 1947d45a7f184821910b4834ad5f1897
msgid "QUANTIZE_8bit"
msgstr "QUANTIZE_8bit"
#: ../../getting_started/install/environment/environment.md:24
#: 519ccce5a0884778be2719c437a17bd4
#: 4a2ee2919d0e4bdaa13c9d92eefd2aac
msgid "QUANTIZE_8bit=True"
msgstr "QUANTIZE_8bit=True"
#: ../../getting_started/install/environment/environment.md:27
#: 1c0586d070f046de8d0f9f94a6b508b4
#: 348dc1e411b54ab09414f40a20e934e4
msgid "LLM PROXY Settings"
msgstr "LLM PROXY Settings"
#: ../../getting_started/install/environment/environment.md:28
#: c208c3f4b13f4b39962de814e5be6ab9
#: a692e78425a040f5828ab54ff9a33f77
msgid "OPENAI Key"
msgstr "OPENAI Key"
#: ../../getting_started/install/environment/environment.md:30
#: 9228bbee2faa4467b1d24f1125faaac8
#: 940d00e25a424acf92951a314a64e5ea
msgid "PROXY_API_KEY={your-openai-sk}"
msgstr "PROXY_API_KEY={your-openai-sk}"
#: ../../getting_started/install/environment/environment.md:31
#: 759ae581883348019c1ba79e8954728a
#: 4bd27547ae6041679e91f2a363cd1deb
msgid "PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions"
msgstr "PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions"
#: ../../getting_started/install/environment/environment.md:33
#: 83f3952917d34aab80bd34119f7d1e20
#: cfa3071afb0b47baad6bd729d4a02cb9
msgid "from https://bard.google.com/ f12-> application-> __Secure-1PSID"
msgstr "from https://bard.google.com/ f12-> application-> __Secure-1PSID"
#: ../../getting_started/install/environment/environment.md:35
#: 1d70707ca82749bb90b2bed1aee44d62
#: a17efa03b10f47f68afac9e865982a75
msgid "BARD_PROXY_API_KEY={your-bard-token}"
msgstr "BARD_PROXY_API_KEY={your-bard-token}"
#: ../../getting_started/install/environment/environment.md:38
#: 38a2091fa223493ea23cb9bbb33cf58e
#: 6bcfe90574da4d82a459e8e11bf73cba
msgid "DATABASE SETTINGS"
msgstr "DATABASE SETTINGS"
#: ../../getting_started/install/environment/environment.md:39
#: 5134180d7a5945b48b072a1eb92b27ba
#: 2b1e62d9bf5d4af5a22f68c8248eaafb
msgid "SQLite database (Current default database)"
msgstr "SQLite database (Current default database)"
#: ../../getting_started/install/environment/environment.md:40
#: 6875e2300e094668a45fa4f2551e0d30
#: 8a909ac3b3c943da8dbc4e8dd596c80c
msgid "LOCAL_DB_PATH=data/default_sqlite.db"
msgstr "LOCAL_DB_PATH=data/default_sqlite.db"
#: ../../getting_started/install/environment/environment.md:41
#: 034e8f06f24f44af9d8184563f99b4b3
#: 90ae6507932f4815b6e180051738bb93
msgid "LOCAL_DB_TYPE=sqlite # Database Type default:sqlite"
msgstr "LOCAL_DB_TYPE=sqlite # Database Type default:sqlite"
#: ../../getting_started/install/environment/environment.md:43
#: f688149a97f740269f80b79775236ce9
#: d2ce34e0dcf44ccf9e8007d548ba7b0a
msgid "MYSQL database"
msgstr "MYSQL database"
#: ../../getting_started/install/environment/environment.md:44
#: 6db0b305137d45a3aa036e4f2262f460
#: c07159d63c334f6cbb95fcc30bfb7ea5
msgid "LOCAL_DB_TYPE=mysql"
msgstr "LOCAL_DB_TYPE=mysql"
#: ../../getting_started/install/environment/environment.md:45
#: b6d662ce8d5f44f0b54a7f6e7c66f5a5
#: e16700b2ea8d411e91d010c1cde7aecc
msgid "LOCAL_DB_USER=root"
msgstr "LOCAL_DB_USER=root"
#: ../../getting_started/install/environment/environment.md:46
#: cd7493d61ac9415283640dc6c018d2f4
#: bfc2dce1bf374121b6861e677b4e1ffa
msgid "LOCAL_DB_PASSWORD=aa12345678"
msgstr "LOCAL_DB_PASSWORD=aa12345678"
#: ../../getting_started/install/environment/environment.md:47
#: 4ea2a622b23f4342a4c2ab7f8d9c4e8d
#: bc384739f5b04e21a34d0d2b78e7906c
msgid "LOCAL_DB_HOST=127.0.0.1"
msgstr "LOCAL_DB_HOST=127.0.0.1"
#: ../../getting_started/install/environment/environment.md:48
#: 936db95a0ab246098028f4dbb596cd17
#: e5253d452e0d42b7ac308fe6fbfb5017
msgid "LOCAL_DB_PORT=3306"
msgstr "LOCAL_DB_PORT=3306"
#: ../../getting_started/install/environment/environment.md:51
#: d9255f25989840ea9c9e7b34f3947c87
#: 9ca8f6fe06ed4cbab390f94be252e165
msgid "EMBEDDING SETTINGS"
msgstr "EMBEDDING SETTINGS"
#: ../../getting_started/install/environment/environment.md:52
#: b09291d32aca43928a981e873476a985
#: 76c7c260293c4b49bae057143fd48377
msgid "EMBEDDING MODEL Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG"
msgstr "EMBEDDING模型, 参考see /pilot/configs/model_config.LLM_MODEL_CONFIG"
#: ../../getting_started/install/environment/environment.md:53
#: 63de573b03a54413b997f18a1ccee279
#: f1d63a0128ce493cae37d34f1976bcca
msgid "EMBEDDING_MODEL=text2vec"
msgstr "EMBEDDING_MODEL=text2vec"
#: ../../getting_started/install/environment/environment.md:55
#: 0ef8defbab544bd0b9475a036f278489
#: b8fbb99109d04781b2dd5bc5d6efa5bd
msgid "Embedding Chunk size, default 500"
msgstr "Embedding 切片大小, 默认500"
#: ../../getting_started/install/environment/environment.md:57
#: 33dbc7941d054baa8c6ecfc0bf1ce271
#: bf8256576ea34f6a9c5f261ab9aab676
msgid "KNOWLEDGE_CHUNK_SIZE=500"
msgstr "KNOWLEDGE_CHUNK_SIZE=500"
#: ../../getting_started/install/environment/environment.md:59
#: e6ee9f2620ab45ecbc8e9c0642f5ca42
#: 9b156c6b599b4c02a58ce023b4ff25f2
msgid "Embedding Chunk Overlap, default 100"
msgstr "Embedding chunk Overlap, 文本块之间的最大重叠量。保留一些重叠可以保持文本块之间的连续性(例如使用滑动窗口),默认100"
#: ../../getting_started/install/environment/environment.md:60
#: fcddf64340a04df4ab95176fc2fc67a6
#: dcafd903c36041ac85ac99a14dbee512
msgid "KNOWLEDGE_CHUNK_OVERLAP=100"
msgstr "KNOWLEDGE_CHUNK_OVERLAP=100"
#: ../../getting_started/install/environment/environment.md:62
#: 61272200194b4461a921581feb1273da
#: 6c3244b7e5e24b0188c7af4bb52e9134
#, fuzzy
msgid "embedding recall top k,5"
msgstr "embedding 召回topk, 默认5"
#: ../../getting_started/install/environment/environment.md:64
#: b433091f055542b1b89ff2d525ac99e4
#: f4a2f30551cf4fe1a7ff3c7c74ec77be
msgid "KNOWLEDGE_SEARCH_TOP_SIZE=5"
msgstr "KNOWLEDGE_SEARCH_TOP_SIZE=5"
#: ../../getting_started/install/environment/environment.md:66
#: 1db0de41aebd4caa8cc2eaecb4cacd6a
#: 593f2512362f467e92fdaa60dd5903a0
#, fuzzy
msgid "embedding recall max token ,2000"
msgstr "embedding向量召回最大token, 默认2000"
#: ../../getting_started/install/environment/environment.md:68
#: 81b9d862e58941a4b09680a7520cdabe
#: 83d6d28914be4d6282d457272e508ddc
msgid "KNOWLEDGE_SEARCH_MAX_TOKEN=5"
msgstr "KNOWLEDGE_SEARCH_MAX_TOKEN=5"
#: ../../getting_started/install/environment/environment.md:71
#: ../../getting_started/install/environment/environment.md:87
#: cac73575d54544778bdee09b18532fd9 f78a509949a64f03aa330f31901e2e7a
#: 6bc1b9d995e74294a1c78e783c550db7 d33c77ded834438e9f4a2df06e7e041a
msgid "Vector Store SETTINGS"
msgstr "Vector Store SETTINGS"
#: ../../getting_started/install/environment/environment.md:72
#: ../../getting_started/install/environment/environment.md:88
#: 5ebba1cb047b4b09849000244237dfbb 7e9285e91bcb4b2d9413909c0d0a06a7
#: 9cafa06e2d584f70afd848184e0fa52a f01057251b8b4ffea806192dfe1048ed
msgid "Chroma"
msgstr "Chroma"
#: ../../getting_started/install/environment/environment.md:73
#: ../../getting_started/install/environment/environment.md:89
#: 05625cfcc23c4745ae1fa0d94ce5450c 3a8615f1507f4fc49d1adda5100a4edf
#: e6c16fab37484769b819aeecbc13e6db faad299722e5400e95ec6ac3c1e018b8
msgid "VECTOR_STORE_TYPE=Chroma"
msgstr "VECTOR_STORE_TYPE=Chroma"
#: ../../getting_started/install/environment/environment.md:74
#: ../../getting_started/install/environment/environment.md:90
#: 5b559376aea44f159262e6d4b75c7ec1 e954782861404b10b4e893e01cf74452
#: 4eca3a51716d406f8ffd49c06550e871 581ee9dd38064b119660c44bdd00cbaa
msgid "MILVUS"
msgstr "MILVUS"
#: ../../getting_started/install/environment/environment.md:75
#: ../../getting_started/install/environment/environment.md:91
#: 55ee8199c97a4929aeefd32370f2b92d 8f40c02543ea4a2ca9632dd9e8a08c2e
#: 814c93048bed46589358a854d6c99683 b72b1269a2224f5f961214e41c019f21
msgid "VECTOR_STORE_TYPE=Milvus"
msgstr "VECTOR_STORE_TYPE=Milvus"
#: ../../getting_started/install/environment/environment.md:76
#: ../../getting_started/install/environment/environment.md:92
#: 528a01d25720491c8e086bf43a62ad92 ba1386d551d7494a85681a2803081a6f
#: 73ae665f1db9402883662734588fd02c c4da20319c994e83ba5a7706db967178
msgid "MILVUS_URL=127.0.0.1"
msgstr "MILVUS_URL=127.0.0.1"
#: ../../getting_started/install/environment/environment.md:77
#: ../../getting_started/install/environment/environment.md:93
#: b031950dafcd4d4783c120dc933c4178 c2e9c8cdd41741e3aba01e59a6ef245d
#: e30c5288516d42aa858a485db50490c1 f843b2e58bcb4e4594e3c28499c341d0
msgid "MILVUS_PORT=19530"
msgstr "MILVUS_PORT=19530"
#: ../../getting_started/install/environment/environment.md:78
#: ../../getting_started/install/environment/environment.md:94
#: 27b0a64af6434cb2840373e2b38c9bd5 d0e4d79af7954b129ffff7303a1ec3ce
#: 158669efcc7d4bcaac1c8dd01b499029 24e88ffd32f242f281c56c0ec3ad2639
msgid "MILVUS_USERNAME"
msgstr "MILVUS_USERNAME"
#: ../../getting_started/install/environment/environment.md:79
#: ../../getting_started/install/environment/environment.md:95
#: 27aa1a5b61e64dd6bfe29124e274809e 5c58892498ce4f46a59f54b2887822d4
#: 111a985297184c8aa5a0dd8e14a58445 6602093a6bb24d6792548e2392105c82
msgid "MILVUS_PASSWORD"
msgstr "MILVUS_PASSWORD"
#: ../../getting_started/install/environment/environment.md:80
#: ../../getting_started/install/environment/environment.md:96
#: 009e57d4acc5434da2146f0545911c85 bac8888dcbff47fbb0ea8ae685445aac
#: 47bdfcd78fbe4ccdb5f49b717a6d01a6 b96c0545b2044926a8a8190caf94ad25
msgid "MILVUS_SECURE="
msgstr "MILVUS_SECURE="
#: ../../getting_started/install/environment/environment.md:82
#: ../../getting_started/install/environment/environment.md:98
#: a6eeb16ab5274045bee88ecc3d93e09e eb341774403d47658b9b7e94c4c16d5c
#: 755c32b5d6c54607907a138b5474c0ec ff4f2a7ddaa14f089dda7a14e1062c36
msgid "WEAVIATE"
msgstr "WEAVIATE"
#: ../../getting_started/install/environment/environment.md:83
#: fbd97522d8da4824b41b99298fd41069
#: 23b2ce83385d40a589a004709f9864be
msgid "VECTOR_STORE_TYPE=Weaviate"
msgstr "VECTOR_STORE_TYPE=Weaviate"
#: ../../getting_started/install/environment/environment.md:84
#: ../../getting_started/install/environment/environment.md:99
#: 341785b4abfe42b5af1c2e04497261f4 a81cc2aabc8240f3ac1f674d9350bff4
#: 9acef304d89a448a9e734346705ba872 cf5151b6c1594ccd8beb1c3f77769acb
msgid "WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network"
msgstr "WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network"
#: ../../getting_started/install/environment/environment.md:102
#: 5bb9e5daa36241d499089c1b1910f729
#: c3003516b2364051bf34f8c3086e348a
msgid "Multi-GPU Setting"
msgstr "Multi-GPU Setting"
#: ../../getting_started/install/environment/environment.md:103
#: 30df45b7f1f7423c9f18c6360f0b7600
#: ade8fc381c5e438aa29d159c10041713
msgid ""
"See https://developer.nvidia.com/blog/cuda-pro-tip-control-gpu-"
"visibility-cuda_visible_devices/ If CUDA_VISIBLE_DEVICES is not "
@ -315,49 +315,49 @@ msgstr ""
"cuda_visible_devices/ 如果 CUDA_VISIBLE_DEVICES没有设置, 会使用所有可用的gpu"
#: ../../getting_started/install/environment/environment.md:106
#: 8631ea968dfb4d90a7ae6bdb2acdfdce
#: e137bd19be5e410ba6709027dbf2923a
msgid "CUDA_VISIBLE_DEVICES=0"
msgstr "CUDA_VISIBLE_DEVICES=0"
#: ../../getting_started/install/environment/environment.md:108
#: 0010422280dd4fe79326ebceb2a66f0e
#: 7669947acbdc4b1d92bcc029a8353a5d
msgid ""
"Optionally, you can also specify the gpu ID to use before the starting "
"command"
msgstr "你也可以通过启动命令设置gpu ID"
#: ../../getting_started/install/environment/environment.md:110
#: 00106f7341304fbd9425721ea8e6a261
#: 751743d1753b4051beea46371278d793
msgid "CUDA_VISIBLE_DEVICES=3,4,5,6"
msgstr "CUDA_VISIBLE_DEVICES=3,4,5,6"
#: ../../getting_started/install/environment/environment.md:112
#: 720aa8b3478744d78e4b10dfeccb50b4
#: 3acc3de0af0d4df2bb575e161e377f85
msgid "You can configure the maximum memory used by each GPU."
msgstr "可以设置GPU的最大内存"
#: ../../getting_started/install/environment/environment.md:114
#: f9639ac96a244296832c75bbcbdae2af
#: 67f1d9b172b84294a44ecace5436e6e0
msgid "MAX_GPU_MEMORY=16Gib"
msgstr "MAX_GPU_MEMORY=16Gib"
#: ../../getting_started/install/environment/environment.md:117
#: fc4d955fdb3e4256af5c8f29b042dcd6
#: 3c69dfe48bcf46b89b76cac1e7849a66
msgid "Other Setting"
msgstr "Other Setting"
#: ../../getting_started/install/environment/environment.md:118
#: 66b14a834e884339be2d48392e884933
#: d5015b70f4fe4d20a63de9d87f86957a
msgid "Language Settings(influence prompt language)"
msgstr "Language Settings(涉及prompt语言以及知识切片方式)"
#: ../../getting_started/install/environment/environment.md:119
#: 5c9f05174eb84edd9e1316cc0721a840
#: 5543c28bb8e34c9fb3bb6b063c2b1750
msgid "LANGUAGE=en"
msgstr "LANGUAGE=en"
#: ../../getting_started/install/environment/environment.md:120
#: 7f6d62117d024c51bba9255fa4fcf151
#: cb4ed5b892ee41068c1ca76cb29aa400
msgid "LANGUAGE=zh"
msgstr "LANGUAGE=zh"

View File

@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 0.3.0\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-11-02 10:10+0800\n"
"POT-Creation-Date: 2023-11-02 21:04+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
@ -19,11 +19,11 @@ msgstr ""
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"
#: ../../modules/knowledge.md:1 436b94d3a8374ed18feb5c14893a84e6
#: ../../modules/knowledge.md:1 b94b3b15cb2441ed9d78abd222a717b7
msgid "Knowledge"
msgstr "知识"
#: ../../modules/knowledge.md:3 918a3747cbed42d18b8c9c4547e67b14
#: ../../modules/knowledge.md:3 c6d6e308a6ce42948d29e928136ef561
#, fuzzy
msgid ""
"As the knowledge base is currently the most significant user demand "
@ -34,15 +34,15 @@ msgstr ""
"由于知识库是当前用户需求最显著的场景,我们原生支持知识库的构建和处理。同时,我们还在本项目中提供了多种知识库管理策略,如:pdf,md , "
"txt, word, ppt"
#: ../../modules/knowledge.md:4 d4d4b5d57918485aafa457bb9fdcf626
#: ../../modules/knowledge.md:4 268abc408d40410ba90cf5f121dc5270
msgid "Default built-in knowledge base"
msgstr ""
#: ../../modules/knowledge.md:5 d4d4b5d57918485aafa457bb9fdcf626
#: ../../modules/knowledge.md:5 558c3364c38b458a8ebf81030efc2a48
msgid "Custom addition of knowledge bases"
msgstr ""
#: ../../modules/knowledge.md:6 984361ce835c4c3492e29e1fb897348a
#: ../../modules/knowledge.md:6 9cb3ce62da1440579c095848c7aef88c
msgid ""
"Various usage scenarios such as constructing knowledge bases through "
"plugin capabilities and web crawling. Users only need to organize the "
@ -50,53 +50,53 @@ msgid ""
"the knowledge base required for the large model."
msgstr ""
#: ../../modules/knowledge.md:9 746e4fbd3212460198be51b90caee2c8
#: ../../modules/knowledge.md:9 b8ca6bc4dd9845baa56e36eea7fac2a2
#, fuzzy
msgid "Create your own knowledge repository"
msgstr "创建你自己的知识库"
#: ../../modules/knowledge.md:11 1c46b33b0532417c824efbaa3e687c3f
#: ../../modules/knowledge.md:11 17d7178a67924f43aa5b6293707ef041
msgid ""
"1.Place personal knowledge files or folders in the pilot/datasets "
"directory."
msgstr ""
#: ../../modules/knowledge.md:13 3b16f387b5354947a89d6df77bd65bdb
#: ../../modules/knowledge.md:13 31c31f14bf444981939689f9a9fb038a
#, fuzzy
msgid ""
"We currently support many document formats: txt, pdf, md, html, doc, ppt,"
" and url."
msgstr "当前支持txt, pdf, md, html, doc, ppt, url文档格式"
#: ../../modules/knowledge.md:15 09ec337d7da4418db854e58afb6c0980
#: ../../modules/knowledge.md:15 9ad2f2e05f8842a9b9d8469a3704df23
msgid "before execution:"
msgstr "开始前"
#: ../../modules/knowledge.md:22 c09b3decb018485f8e56830ddc156194
#: ../../modules/knowledge.md:22 6fd2775914b641c4b8e486417b558ea6
msgid ""
"2.Update your .env, set your vector store type, VECTOR_STORE_TYPE=Chroma "
"(now only support Chroma and Milvus, if you set Milvus, please set "
"MILVUS_URL and MILVUS_PORT)"
msgstr ""
#: ../../modules/knowledge.md:25 74460ec7709441d5945ce9f745a26d20
#: ../../modules/knowledge.md:25 131c5f58898a4682940910980edb2043
msgid "2.Run the knowledge repository initialization command"
msgstr ""
#: ../../modules/knowledge.md:31 4498ec4e46ff4e24b45dd855e829bd32
#: ../../modules/knowledge.md:31 2cf550f17881497bb881b19efcc18c23
msgid ""
"Optionally, you can run `dbgpt knowledge load --help` command to see more"
" usage."
msgstr ""
#: ../../modules/knowledge.md:33 5048ac3289e540f2a2b5fd0e5ed043f5
#: ../../modules/knowledge.md:33 c8a2ea571b944bdfbcad48fa8b54fcc9
msgid ""
"3.Add the knowledge repository in the interface by entering the name of "
"your knowledge repository (if not specified, enter \"default\") so you "
"can use it for Q&A based on your knowledge base."
msgstr ""
#: ../../modules/knowledge.md:35 deeccff20f7f453dad0881b63dae2a18
#: ../../modules/knowledge.md:35 b701170ad75e49dea7d7734c15681e0f
msgid ""
"Note that the default vector model used is text2vec-large-chinese (which "
"is a large model, so if your personal computer configuration is not "

View File

@ -46,6 +46,8 @@ class ComponentType(str, Enum):
WORKER_MANAGER = "dbgpt_worker_manager"
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
MODEL_CONTROLLER = "dbgpt_model_controller"
MODEL_REGISTRY = "dbgpt_model_registry"
MODEL_API_SERVER = "dbgpt_model_api_server"
AGENT_HUB = "dbgpt_agent_hub"
EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
TRACER = "dbgpt_tracer"
@ -69,7 +71,6 @@ class BaseComponent(LifeCycle, ABC):
This method needs to be implemented by every component to define how it integrates
with the main system app.
"""
pass
T = TypeVar("T", bound=BaseComponent)
@ -91,13 +92,28 @@ class SystemApp(LifeCycle):
"""Returns the internal ASGI app."""
return self._asgi_app
def register(self, component: Type[BaseComponent], *args, **kwargs):
"""Register a new component by its type."""
def register(self, component: Type[BaseComponent], *args, **kwargs) -> T:
"""Register a new component by its type.
Args:
component (Type[BaseComponent]): The component class to register
Returns:
T: The instance of the registered component
"""
instance = component(self, *args, **kwargs)
self.register_instance(instance)
return instance
def register_instance(self, instance: T):
"""Register an already initialized component."""
def register_instance(self, instance: T) -> T:
"""Register an already initialized component.
Args:
instance (T): The component instance to register
Returns:
T: The instance of the registered component
"""
name = instance.name
if isinstance(name, ComponentType):
name = name.value
@ -108,18 +124,34 @@ class SystemApp(LifeCycle):
logger.info(f"Register component with name {name} and instance: {instance}")
self.components[name] = instance
instance.init_app(self)
return instance
def get_component(
self,
name: Union[str, ComponentType],
component_type: Type[T],
default_component=_EMPTY_DEFAULT_COMPONENT,
or_register_component: Type[BaseComponent] = None,
*args,
**kwargs,
) -> T:
"""Retrieve a registered component by its name and type."""
"""Retrieve a registered component by its name and type.
Args:
name (Union[str, ComponentType]): Component name
component_type (Type[T]): The type of the component to retrieve
default_component : The default component instance to return if no component is found by name
or_register_component (Type[BaseComponent]): A component class to register and return if no component is found by name
Returns:
T: The instance retrieved by component name
"""
if isinstance(name, ComponentType):
name = name.value
component = self.components.get(name)
if not component:
if or_register_component:
return self.register(or_register_component, *args, **kwargs)
if default_component != _EMPTY_DEFAULT_COMPONENT:
return default_component
raise ValueError(f"No component found with name {name}")
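Taken together: `register` and `register_instance` now return the instance they registered, and `get_component` can lazily register a fallback through `or_register_component`. A minimal sketch of the new flow (the `CacheComponent` class and its `"my_cache"` name are hypothetical):

```python
from fastapi import FastAPI

from pilot.component import BaseComponent, SystemApp


class CacheComponent(BaseComponent):
    """A hypothetical component, for illustration only."""

    name = "my_cache"

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app


system_app = SystemApp(FastAPI())
cache = system_app.register(CacheComponent)  # register() now returns the instance
# Fetch it back; if it were missing, or_register_component would create it.
same = system_app.get_component(
    "my_cache", CacheComponent, or_register_component=CacheComponent
)
assert cache is same
```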

View File

@ -78,6 +78,10 @@ LLM_MODEL_CONFIG = {
"internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
"internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
"internlm-20b": os.path.join(MODEL_PATH, "internlm-chat-20b"),
"codellama-7b": os.path.join(MODEL_PATH, "CodeLlama-7b-Instruct-hf"),
"codellama-7b-sql-sft": os.path.join(MODEL_PATH, "codellama-7b-sql-sft"),
"codellama-13b": os.path.join(MODEL_PATH, "CodeLlama-13b-Instruct-hf"),
"codellama-13b-sql-sft": os.path.join(MODEL_PATH, "codellama-13b-sql-sft"),
# For test now
"opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
}

View File

@ -320,6 +320,19 @@ class Llama2Adapter(BaseLLMAdaper):
return model, tokenizer
class CodeLlamaAdapter(BaseLLMAdaper):
"""The model adapter for codellama"""
def match(self, model_path: str):
return "codellama" in model_path.lower()
def loader(self, model_path: str, from_pretrained_kwargs: dict):
model, tokenizer = super().loader(model_path, from_pretrained_kwargs)
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
return model, tokenizer
class BaichuanAdapter(BaseLLMAdaper):
"""The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-13B-Chat)"""
@ -420,6 +433,7 @@ register_llm_model_adapters(FalconAdapater)
register_llm_model_adapters(GorillaAdapter)
register_llm_model_adapters(GPT4AllAdapter)
register_llm_model_adapters(Llama2Adapter)
register_llm_model_adapters(CodeLlamaAdapter)
register_llm_model_adapters(BaichuanAdapter)
register_llm_model_adapters(WizardLMAdapter)
register_llm_model_adapters(LlamaCppAdapater)
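Adapter selection is path-based, so `CodeLlamaAdapter` fires for any model path containing `codellama`, covering both the upstream `CodeLlama-*-Instruct-hf` directories and the `codellama-*-sql-sft` fine-tunes registered in `model_config.py`. A quick illustration (assuming the adapter module is importable as `pilot.model.adapter`):

```python
from pilot.model.adapter import CodeLlamaAdapter

adapter = CodeLlamaAdapter()
assert adapter.match("/app/models/CodeLlama-13b-Instruct-hf")  # case-insensitive match
assert adapter.match("/app/models/codellama-7b-sql-sft")
assert not adapter.match("/app/models/vicuna-13b-v1.5")
```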

View File

@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
from enum import Enum
from typing import TypedDict, Optional, Dict, List
from typing import TypedDict, Optional, Dict, List, Any
from dataclasses import dataclass, asdict
from datetime import datetime
from pilot.utils.parameter_utils import ParameterDescription
@ -52,6 +52,8 @@ class ModelOutput:
text: str
error_code: int
model_context: Dict = None
finish_reason: str = None
usage: Dict[str, Any] = None
def to_dict(self) -> Dict:
return asdict(self)
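With the two new optional fields, a worker can report why generation stopped and how many tokens were used. For illustration (made-up values; assuming `ModelOutput` lives in `pilot.model.base`):

```python
from pilot.model.base import ModelOutput

out = ModelOutput(
    text="Hello!",
    error_code=0,
    finish_reason="stop",
    usage={"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7},
)
print(out.to_dict())  # asdict() now includes finish_reason and usage
```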

View File

@ -8,6 +8,7 @@ from pilot.configs.model_config import LOGDIR
from pilot.model.base import WorkerApplyType
from pilot.model.parameter import (
ModelControllerParameters,
ModelAPIServerParameters,
ModelWorkerParameters,
ModelParameters,
BaseParameters,
@ -441,15 +442,27 @@ def stop_model_worker(port: int):
@click.command(name="apiserver")
@EnvArgumentParser.create_click_option(ModelAPIServerParameters)
def start_apiserver(**kwargs):
"""Start apiserver(TODO)"""
raise NotImplementedError
"""Start apiserver"""
if kwargs["daemon"]:
log_file = os.path.join(LOGDIR, "model_apiserver_uvicorn.log")
_run_current_with_daemon("ModelAPIServer", log_file)
else:
from pilot.model.cluster import run_apiserver
run_apiserver()
@click.command(name="apiserver")
def stop_apiserver(**kwargs):
"""Start apiserver(TODO)"""
raise NotImplementedError
@add_stop_server_options
def stop_apiserver(port: int):
"""Stop apiserver"""
name = "ModelAPIServer"
if port:
name = f"{name}-{port}"
_stop_service("apiserver", name, port=port)
def _stop_all_model_server(**kwargs):

View File

@ -21,6 +21,7 @@ from pilot.model.cluster.controller.controller import (
run_model_controller,
BaseModelController,
)
from pilot.model.cluster.apiserver.api import run_apiserver
from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager
@ -40,4 +41,5 @@ __all__ = [
"ModelRegistryClient",
"RemoteWorkerManager",
"run_model_controller",
"run_apiserver",
]

View File

@ -0,0 +1,443 @@
"""A server that provides OpenAI-compatible RESTful APIs. It supports:
- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat)
Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
"""
from typing import Optional, List, Dict, Any, Generator
import logging
import asyncio
import shortuuid
import json
from fastapi import APIRouter, FastAPI
from fastapi import Depends, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseSettings
from fastchat.protocol.openai_api_protocol import (
ChatCompletionResponse,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
ChatCompletionResponseChoice,
DeltaMessage,
EmbeddingsRequest,
EmbeddingsResponse,
ErrorResponse,
ModelCard,
ModelList,
ModelPermission,
UsageInfo,
)
from fastchat.protocol.api_protocol import (
APIChatCompletionRequest,
APITokenCheckRequest,
APITokenCheckResponse,
APITokenCheckResponseItem,
)
from fastchat.serve.openai_api_server import create_error_response, check_requests
from fastchat.constants import ErrorCode
from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.utils.parameter_utils import EnvArgumentParser
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.model.base import ModelInstance, ModelOutput
from pilot.model.parameter import ModelAPIServerParameters, WorkerType
from pilot.model.cluster import ModelRegistry, ModelRegistryClient
from pilot.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
from pilot.utils.utils import setup_logging
logger = logging.getLogger(__name__)
class APIServerException(Exception):
def __init__(self, code: int, message: str):
self.code = code
self.message = message
class APISettings(BaseSettings):
api_keys: Optional[List[str]] = None
api_settings = APISettings()
get_bearer_token = HTTPBearer(auto_error=False)
async def check_api_key(
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
) -> str:
if api_settings.api_keys:
if auth is None or (token := auth.credentials) not in api_settings.api_keys:
raise HTTPException(
status_code=401,
detail={
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
},
)
return token
else:
# api_keys not set; allow all
return None
class APIServer(BaseComponent):
name = ComponentType.MODEL_API_SERVER
def init_app(self, system_app: SystemApp):
self.system_app = system_app
def get_worker_manager(self) -> WorkerManager:
"""Get the worker manager component instance
Raises:
APIServerException: If the worker manager component cannot be retrieved
"""
worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
if not worker_manager:
raise APIServerException(
ErrorCode.INTERNAL_ERROR,
f"Could not get component {ComponentType.WORKER_MANAGER_FACTORY} from system_app",
)
return worker_manager
def get_model_registry(self) -> ModelRegistry:
"""Get the model registry component instance
Raises:
APIServerException: If the model registry component cannot be retrieved
"""
controller = self.system_app.get_component(
ComponentType.MODEL_REGISTRY, ModelRegistry
)
if not controller:
raise APIServerException(
ErrorCode.INTERNAL_ERROR,
f"Could not get component {ComponentType.MODEL_REGISTRY} from system_app",
)
return controller
async def get_model_instances_or_raise(
self, model_name: str
) -> List[ModelInstance]:
"""Get healthy model instances with request model name
Args:
model_name (str): Model name
Raises:
APIServerException: If no healthy model instance exists for the requested model name
"""
registry = self.get_model_registry()
registry_model_name = f"{model_name}@llm"
model_instances = await registry.get_all_instances(
registry_model_name, healthy_only=True
)
if not model_instances:
all_instances = await registry.get_all_model_instances(healthy_only=True)
models = [
ins.model_name.split("@llm")[0]
for ins in all_instances
if ins.model_name.endswith("@llm")
]
if models:
models = "&&".join(models)
message = f"Only {models} allowed now, your model {model_name}"
else:
message = f"No models allowed now, your model {model_name}"
raise APIServerException(ErrorCode.INVALID_MODEL, message)
return model_instances
async def get_available_models(self) -> ModelList:
"""Return available models
Only LLM and embedding models are included.
Returns:
List[ModelList]: The list of models.
"""
registry = self.get_model_registry()
model_instances = await registry.get_all_model_instances(healthy_only=True)
model_name_set = set()
for inst in model_instances:
name, worker_type = WorkerType.parse_worker_key(inst.model_name)
if worker_type == WorkerType.LLM or worker_type == WorkerType.TEXT2VEC:
model_name_set.add(name)
models = list(model_name_set)
models.sort()
# TODO: return real model permission details
model_cards = []
for m in models:
model_cards.append(
ModelCard(
id=m, root=m, owned_by="DB-GPT", permission=[ModelPermission()]
)
)
return ModelList(data=model_cards)
async def chat_completion_stream_generator(
self, model_name: str, params: Dict[str, Any], n: int
) -> Generator[str, Any, None]:
"""Chat stream completion generator
Args:
model_name (str): Model name
params (Dict[str, Any]): The parameters passed to the model worker
n (int): How many completions to generate for each prompt.
"""
worker_manager = self.get_worker_manager()
id = f"chatcmpl-{shortuuid.random()}"
finish_stream_events = []
for i in range(n):
# First chunk with role
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role="assistant"),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=id, choices=[choice_data], model=model_name
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
previous_text = ""
async for model_output in worker_manager.generate_stream(params):
model_output: ModelOutput = model_output
if model_output.error_code != 0:
yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return
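# The worker streams cumulative text: strip the U+FFFD placeholder left by
# partially decoded bytes, then diff against what was already sent to get
# the incremental delta for this chunk.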
decoded_unicode = model_output.text.replace("\ufffd", "")
delta_text = decoded_unicode[len(previous_text) :]
previous_text = (
decoded_unicode
if len(decoded_unicode) > len(previous_text)
else previous_text
)
if len(delta_text) == 0:
delta_text = None
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
finish_reason=model_output.finish_reason,
)
chunk = ChatCompletionStreamResponse(
id=id, choices=[choice_data], model=model_name
)
if delta_text is None:
if model_output.finish_reason is not None:
finish_stream_events.append(chunk)
continue
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
# The last delta message has no "content" field, so use exclude_none to drop it.
for finish_chunk in finish_stream_events:
yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
async def chat_completion_generate(
self, model_name: str, params: Dict[str, Any], n: int
) -> ChatCompletionResponse:
"""Generate completion
Args:
model_name (str): Model name
params (Dict[str, Any]): The parameters passed to the model worker
n (int): How many completions to generate for each prompt.
"""
worker_manager: WorkerManager = self.get_worker_manager()
choices = []
chat_completions = []
for i in range(n):
model_output = asyncio.create_task(worker_manager.generate(params))
chat_completions.append(model_output)
try:
all_tasks = await asyncio.gather(*chat_completions)
except Exception as e:
return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
usage = UsageInfo()
for i, model_output in enumerate(all_tasks):
model_output: ModelOutput = model_output
if model_output.error_code != 0:
return create_error_response(model_output.error_code, model_output.text)
choices.append(
ChatCompletionResponseChoice(
index=i,
message=ChatMessage(role="assistant", content=model_output.text),
finish_reason=model_output.finish_reason or "stop",
)
)
if model_output.usage:
task_usage = UsageInfo.parse_obj(model_output.usage)
for usage_key, usage_value in task_usage.dict().items():
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
def get_api_server() -> APIServer:
api_server = global_system_app.get_component(
ComponentType.MODEL_API_SERVER, APIServer, default_component=None
)
if not api_server:
global_system_app.register(APIServer)
return global_system_app.get_component(ComponentType.MODEL_API_SERVER, APIServer)
router = APIRouter()
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
async def get_available_models(api_server: APIServer = Depends(get_api_server)):
return await api_server.get_available_models()
@router.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
async def create_chat_completion(
request: APIChatCompletionRequest, api_server: APIServer = Depends(get_api_server)
):
await api_server.get_model_instances_or_raise(request.model)
error_check_ret = check_requests(request)
if error_check_ret is not None:
return error_check_ret
params = {
"model": request.model,
"messages": ModelMessage.to_dict_list(
ModelMessage.from_openai_messages(request.messages)
),
"echo": False,
}
if request.temperature:
params["temperature"] = request.temperature
if request.top_p:
params["top_p"] = request.top_p
if request.max_tokens:
params["max_new_tokens"] = request.max_tokens
if request.stop:
params["stop"] = request.stop
if request.user:
params["user"] = request.user
# TODO check token length
if request.stream:
generator = api_server.chat_completion_stream_generator(
request.model, params, request.n
)
return StreamingResponse(generator, media_type="text/event-stream")
return await api_server.chat_completion_generate(request.model, params, request.n)
def _initialize_all(controller_addr: str, system_app: SystemApp):
from pilot.model.cluster import RemoteWorkerManager, ModelRegistryClient
from pilot.model.cluster.worker.manager import _DefaultWorkerManagerFactory
if not system_app.get_component(
ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None
):
# Register model registry if not exist
registry = ModelRegistryClient(controller_addr)
registry.name = ComponentType.MODEL_REGISTRY.value
system_app.register_instance(registry)
registry = system_app.get_component(
ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None
)
worker_manager = RemoteWorkerManager(registry)
# Register worker manager component if not exist
system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY,
WorkerManagerFactory,
or_register_component=_DefaultWorkerManagerFactory,
worker_manager=worker_manager,
)
# Register api server component if not exist
system_app.get_component(
ComponentType.MODEL_API_SERVER, APIServer, or_register_component=APIServer
)
def initialize_apiserver(
controller_addr: str,
app=None,
system_app: SystemApp = None,
host: str = None,
port: int = None,
api_keys: List[str] = None,
):
global global_system_app
global api_settings
embedded_mod = True
if not app:
embedded_mod = False
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
if not system_app:
system_app = SystemApp(app)
global_system_app = system_app
if api_keys:
api_settings.api_keys = api_keys
app.include_router(router, prefix="/api", tags=["APIServer"])
@app.exception_handler(APIServerException)
async def validation_apiserver_exception_handler(request, exc: APIServerException):
return create_error_response(exc.code, exc.message)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))
_initialize_all(controller_addr, system_app)
if not embedded_mod:
import uvicorn
uvicorn.run(app, host=host, port=port, log_level="info")
def run_apiserver():
parser = EnvArgumentParser()
env_prefix = "apiserver_"
apiserver_params: ModelAPIServerParameters = parser.parse_args_into_dataclass(
ModelAPIServerParameters,
env_prefixes=[env_prefix],
)
setup_logging(
"pilot",
logging_level=apiserver_params.log_level,
logger_filename=apiserver_params.log_file,
)
api_keys = None
if apiserver_params.api_keys:
api_keys = apiserver_params.api_keys.strip().split(",")
initialize_apiserver(
apiserver_params.controller_addr,
host=apiserver_params.host,
port=apiserver_params.port,
api_keys=api_keys,
)
if __name__ == "__main__":
run_apiserver()

View File

@ -0,0 +1,248 @@
import pytest
import pytest_asyncio
from aioresponses import aioresponses
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from httpx import AsyncClient, HTTPError
from pilot.component import SystemApp
from pilot.utils.openai_utils import chat_completion_stream, chat_completion
from pilot.model.cluster.apiserver.api import (
api_settings,
initialize_apiserver,
ModelList,
UsageInfo,
ChatCompletionResponse,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
ChatCompletionResponseChoice,
DeltaMessage,
)
from pilot.model.cluster.tests.conftest import _new_cluster
from pilot.model.cluster.worker.manager import _DefaultWorkerManagerFactory
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
@pytest_asyncio.fixture
async def system_app():
return SystemApp(app)
@pytest_asyncio.fixture
async def client(request, system_app: SystemApp):
param = getattr(request, "param", {})
api_keys = param.get("api_keys", [])
client_api_key = param.get("client_api_key")
if "num_workers" not in param:
param["num_workers"] = 2
if "api_keys" in param:
del param["api_keys"]
headers = {}
if client_api_key:
headers["Authorization"] = "Bearer " + client_api_key
print(f"param: {param}")
if api_settings:
# Clear global api keys
api_settings.api_keys = []
async with AsyncClient(app=app, base_url="http://test", headers=headers) as client:
async with _new_cluster(**param) as cluster:
worker_manager, model_registry = cluster
system_app.register(_DefaultWorkerManagerFactory, worker_manager)
system_app.register_instance(model_registry)
# print(f"Instances {model_registry.registry}")
initialize_apiserver(None, app, system_app, api_keys=api_keys)
yield client
@pytest.mark.asyncio
async def test_get_all_models(client: AsyncClient):
res = await client.get("/api/v1/models")
assert res.status_code == 200
model_lists = ModelList.parse_obj(res.json())
print(f"model list json: {res.json()}")
assert model_lists.object == "list"
assert len(model_lists.data) == 2
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, expected_messages",
[
({"stream_messags": ["Hello", " world."]}, "Hello world."),
({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
],
indirect=["client"],
)
async def test_chat_completions(client: AsyncClient, expected_messages):
chat_data = {
"model": "test-model-name-0",
"messages": [{"role": "user", "content": "Hello"}],
"stream": True,
}
full_text = ""
async for text in chat_completion_stream(
"/api/v1/chat/completions", chat_data, client
):
full_text += text
assert full_text == expected_messages
assert (
await chat_completion("/api/v1/chat/completions", chat_data, client)
== expected_messages
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, expected_messages, client_api_key",
[
(
{"stream_messags": ["Hello", " world."], "api_keys": ["abc"]},
"Hello world.",
"abc",
),
({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
],
indirect=["client"],
)
async def test_chat_completions_with_openai_lib_async_no_stream(
client: AsyncClient, expected_messages: str, client_api_key: str
):
import openai
openai.api_key = client_api_key
openai.api_base = "http://test/api/v1"
model_name = "test-model-name-0"
with aioresponses() as mocked:
mock_message = {"text": expected_messages}
one_res = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=expected_messages),
finish_reason="stop",
)
data = ChatCompletionResponse(
model=model_name, choices=[one_res], usage=UsageInfo()
)
mock_message = f"{data.json(exclude_unset=True, ensure_ascii=False)}\n\n"
# Mock http request
mocked.post(
"http://test/api/v1/chat/completions", status=200, body=mock_message
)
completion = await openai.ChatCompletion.acreate(
model=model_name,
messages=[{"role": "user", "content": "Hello! What is your name?"}],
)
assert completion.choices[0].message.content == expected_messages
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, expected_messages, client_api_key",
[
(
{"stream_messags": ["Hello", " world."], "api_keys": ["abc"]},
"Hello world.",
"abc",
),
({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
],
indirect=["client"],
)
async def test_chat_completions_with_openai_lib_async_stream(
client: AsyncClient, expected_messages: str, client_api_key: str
):
import openai
openai.api_key = client_api_key
openai.api_base = "http://test/api/v1"
model_name = "test-model-name-0"
with aioresponses() as mocked:
mock_message = {"text": expected_messages}
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=expected_messages),
finish_reason="stop",
)
chunk = ChatCompletionStreamResponse(
id=0, choices=[choice_data], model=model_name
)
mock_message = f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
mocked.post(
"http://test/api/v1/chat/completions",
status=200,
body=mock_message,
content_type="text/event-stream",
)
stream_stream_resp = ""
async for stream_resp in await openai.ChatCompletion.acreate(
model=model_name,
messages=[{"role": "user", "content": "Hello! What is your name?"}],
stream=True,
):
stream_stream_resp = stream_resp.choices[0]["delta"].get("content", "")
assert stream_stream_resp == expected_messages
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, expected_messages, api_key_is_error",
[
(
{
"stream_messags": ["Hello", " world."],
"api_keys": ["abc", "xx"],
"client_api_key": "abc",
},
"Hello world.",
False,
),
({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。", False),
(
{"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc", "xx"]},
"你好,我是张三。",
True,
),
(
{
"stream_messags": ["你好,我是", "张三。"],
"api_keys": ["abc", "xx"],
"client_api_key": "error_api_key",
},
"你好,我是张三。",
True,
),
],
indirect=["client"],
)
async def test_chat_completions_with_api_keys(
client: AsyncClient, expected_messages: str, api_key_is_error: bool
):
chat_data = {
"model": "test-model-name-0",
"messages": [{"role": "user", "content": "Hello"}],
"stream": True,
}
if api_key_is_error:
with pytest.raises(HTTPError):
await chat_completion("/api/v1/chat/completions", chat_data, client)
else:
assert (
await chat_completion("/api/v1/chat/completions", chat_data, client)
== expected_messages
)

View File

@ -66,7 +66,9 @@ class LocalModelController(BaseModelController):
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
)
if not model_name:
return await self.registry.get_all_model_instances()
return await self.registry.get_all_model_instances(
healthy_only=healthy_only
)
else:
return await self.registry.get_all_instances(model_name, healthy_only)
@ -98,8 +100,10 @@ class _RemoteModelController(BaseModelController):
class ModelRegistryClient(_RemoteModelController, ModelRegistry):
async def get_all_model_instances(self) -> List[ModelInstance]:
return await self.get_all_instances()
async def get_all_model_instances(
self, healthy_only: bool = False
) -> List[ModelInstance]:
return await self.get_all_instances(healthy_only=healthy_only)
@sync_api_remote(path="/api/controller/models")
def sync_get_all_instances(

View File

@ -1,22 +1,37 @@
import random
import threading
import time
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import itertools
from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.model.base import ModelInstance
class ModelRegistry(ABC):
logger = logging.getLogger(__name__)
class ModelRegistry(BaseComponent, ABC):
"""
Abstract base class for a model registry. It provides an interface
for registering, deregistering, and fetching model instances, and for
sending instance heartbeats.
"""
name = ComponentType.MODEL_REGISTRY
def __init__(self, system_app: SystemApp | None = None):
self.system_app = system_app
super().__init__(system_app)
def init_app(self, system_app: SystemApp):
"""Initialize the component with the main application."""
self.system_app = system_app
@abstractmethod
async def register_instance(self, instance: ModelInstance) -> bool:
"""
@ -65,9 +80,11 @@ class ModelRegistry(ABC):
"""Fetch all instances of a given model. Optionally, fetch only the healthy instances."""
@abstractmethod
async def get_all_model_instances(self) -> List[ModelInstance]:
async def get_all_model_instances(
self, healthy_only: bool = False
) -> List[ModelInstance]:
"""
Fetch all instances of all models
Fetch all instances of all models. Optionally, fetch only the healthy instances.
Returns:
- List[ModelInstance]: A list of instances across all models.
@ -105,8 +122,12 @@ class ModelRegistry(ABC):
class EmbeddedModelRegistry(ModelRegistry):
def __init__(
self, heartbeat_interval_secs: int = 60, heartbeat_timeout_secs: int = 120
self,
system_app: SystemApp | None = None,
heartbeat_interval_secs: int = 60,
heartbeat_timeout_secs: int = 120,
):
super().__init__(system_app)
self.registry: Dict[str, List[ModelInstance]] = defaultdict(list)
self.heartbeat_interval_secs = heartbeat_interval_secs
self.heartbeat_timeout_secs = heartbeat_timeout_secs
@ -180,9 +201,14 @@ class EmbeddedModelRegistry(ModelRegistry):
instances = [ins for ins in instances if ins.healthy]
return instances
async def get_all_model_instances(self) -> List[ModelInstance]:
print(self.registry)
return list(itertools.chain(*self.registry.values()))
async def get_all_model_instances(
self, healthy_only: bool = False
) -> List[ModelInstance]:
logger.debug("Current registry metadata:\n{self.registry}")
instances = list(itertools.chain(*self.registry.values()))
if healthy_only:
instances = [ins for ins in instances if ins.healthy == True]
return instances
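# A minimal usage sketch of the new healthy_only flag (a sketch, assuming the
# ModelInstance fields shown elsewhere in this change):
#
#     registry = EmbeddedModelRegistry()
#     await registry.register_instance(
#         ModelInstance(model_name="chatglm2-6b@llm", host="127.0.0.1", port=8001, healthy=True)
#     )
#     everything = await registry.get_all_model_instances()
#     healthy = await registry.get_all_model_instances(healthy_only=True)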
async def send_heartbeat(self, instance: ModelInstance) -> bool:
_, exist_ins = self._get_instances(

View File

View File

@ -6,6 +6,7 @@ from pilot.model.parameter import ModelParameters, ModelWorkerParameters, Worker
from pilot.model.base import ModelOutput
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.worker.manager import (
WorkerManager,
LocalWorkerManager,
RegisterFunc,
DeregisterFunc,
@ -13,6 +14,23 @@ from pilot.model.cluster.worker.manager import (
ApplyFunction,
)
from pilot.model.base import ModelInstance
from pilot.model.cluster.registry import ModelRegistry, EmbeddedModelRegistry
@pytest.fixture
def model_registry(request):
return EmbeddedModelRegistry()
@pytest.fixture
def model_instance():
return ModelInstance(
model_name="test_model",
host="192.168.1.1",
port=5000,
)
class MockModelWorker(ModelWorker):
def __init__(
@ -51,8 +69,10 @@ class MockModelWorker(ModelWorker):
raise Exception("Stop worker error for mock")
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
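# Stream cumulative output: each yielded ModelOutput carries everything
# generated so far, mirroring how real workers report streaming text.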
full_text = ""
for msg in self.stream_messags:
yield ModelOutput(text=msg, error_code=0)
full_text += msg
yield ModelOutput(text=full_text, error_code=0)
def generate(self, params: Dict) -> ModelOutput:
output = None
@ -67,6 +87,8 @@ class MockModelWorker(ModelWorker):
_TEST_MODEL_NAME = "vicuna-13b-v1.5"
_TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5"
ClusterType = Tuple[WorkerManager, ModelRegistry]
def _new_worker_params(
model_name: str = _TEST_MODEL_NAME,
@ -85,7 +107,9 @@ def _create_workers(
worker_type: str = WorkerType.LLM.value,
stream_messags: List[str] = None,
embeddings: List[List[float]] = None,
) -> List[Tuple[ModelWorker, ModelWorkerParameters]]:
host: str = "127.0.0.1",
start_port=8001,
) -> List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]]:
workers = []
for i in range(num_workers):
model_name = f"test-model-name-{i}"
@ -98,10 +122,16 @@ def _create_workers(
stream_messags=stream_messags,
embeddings=embeddings,
)
model_instance = ModelInstance(
model_name=WorkerType.to_worker_key(model_name, worker_type),
host=host,
port=start_port + i,
healthy=True,
)
worker_params = _new_worker_params(
model_name, model_path, worker_type=worker_type
)
workers.append((worker, worker_params))
workers.append((worker, worker_params, model_instance))
return workers
@ -127,12 +157,12 @@ async def _start_worker_manager(**kwargs):
model_registry=model_registry,
)
for worker, worker_params in _create_workers(
for worker, worker_params, model_instance in _create_workers(
num_workers, error_worker, stop_error, stream_messags, embeddings
):
worker_manager.add_worker(worker, worker_params)
if workers:
for worker, worker_params in workers:
for worker, worker_params, model_instance in workers:
worker_manager.add_worker(worker, worker_params)
if start:
@ -143,6 +173,15 @@ async def _start_worker_manager(**kwargs):
await worker_manager.stop()
async def _create_model_registry(
workers: List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]]
) -> ModelRegistry:
registry = EmbeddedModelRegistry()
for _, _, inst in workers:
assert await registry.register_instance(inst)
return registry
@pytest_asyncio.fixture
async def manager_2_workers(request):
param = getattr(request, "param", {})
@ -166,3 +205,27 @@ async def manager_2_embedding_workers(request):
)
async with _start_worker_manager(workers=workers, **param) as worker_manager:
yield (worker_manager, workers)
@asynccontextmanager
async def _new_cluster(**kwargs) -> ClusterType:
num_workers = kwargs.get("num_workers", 0)
workers = _create_workers(
num_workers, stream_messags=kwargs.get("stream_messags", [])
)
if "num_workers" in kwargs:
del kwargs["num_workers"]
registry = await _create_model_registry(
workers,
)
async with _start_worker_manager(workers=workers, **kwargs) as worker_manager:
yield (worker_manager, registry)
@pytest_asyncio.fixture
async def cluster_2_workers(request):
param = getattr(request, "param", {})
workers = _create_workers(2)
registry = await _create_model_registry(workers)
async with _start_worker_manager(workers=workers, **param) as worker_manager:
yield (worker_manager, registry)

View File

@ -256,15 +256,22 @@ class DefaultModelWorker(ModelWorker):
return params, model_context, generate_stream_func, model_span
def _handle_output(self, output, previous_response, model_context):
finish_reason = None
usage = None
if isinstance(output, dict):
finish_reason = output.get("finish_reason")
usage = output.get("usage")
output = output["text"]
if finish_reason is not None:
logger.info(f"finish_reason: {finish_reason}")
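# `output` is the cumulative text; slice off what was already sent to get the new delta.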
incremental_output = output[len(previous_response) :]
print(incremental_output, end="", flush=True)
model_output = ModelOutput(
text=output, error_code=0, model_context=model_context
text=output,
error_code=0,
model_context=model_context,
finish_reason=finish_reason,
usage=usage,
)
return model_output, incremental_output, output

View File

@ -99,9 +99,7 @@ class LocalWorkerManager(WorkerManager):
)
def _worker_key(self, worker_type: str, model_name: str) -> str:
if isinstance(worker_type, WorkerType):
worker_type = worker_type.value
return f"{model_name}@{worker_type}"
return WorkerType.to_worker_key(model_name, worker_type)
async def run_blocking_func(self, func, *args):
if asyncio.iscoroutinefunction(func):

View File

@ -3,7 +3,7 @@ import pytest
from typing import List, Iterator, Dict, Tuple
from dataclasses import asdict
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from pilot.model.base import ModelOutput, WorkerApplyType
from pilot.model.base import ModelOutput, WorkerApplyType, ModelInstance
from pilot.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.manager_base import WorkerRunData
@ -14,7 +14,7 @@ from pilot.model.cluster.worker.manager import (
SendHeartbeatFunc,
ApplyFunction,
)
from pilot.model.cluster.worker.tests.base_tests import (
from pilot.model.cluster.tests.conftest import (
MockModelWorker,
manager_2_workers,
manager_with_2_workers,
@ -216,7 +216,7 @@ async def test__remove_worker():
workers = _create_workers(3)
async with _start_worker_manager(workers=workers, stop=False) as manager:
assert len(manager.workers) == 3
for _, worker_params in workers:
for _, worker_params, _ in workers:
manager._remove_worker(worker_params)
not_exist_parmas = _new_worker_params(
model_name="this is a not exist worker params"
@ -229,7 +229,7 @@ async def test__remove_worker():
async def test_model_startup(mock_build_worker):
async with _start_worker_manager() as manager:
workers = _create_workers(1)
worker, worker_params = workers[0]
worker, worker_params, model_instance = workers[0]
mock_build_worker.return_value = worker
req = WorkerStartupRequest(
@ -245,7 +245,7 @@ async def test_model_startup(mock_build_worker):
async with _start_worker_manager() as manager:
workers = _create_workers(1, error_worker=True)
worker, worker_params = workers[0]
worker, worker_params, model_instance = workers[0]
mock_build_worker.return_value = worker
req = WorkerStartupRequest(
host="127.0.0.1",
@ -263,7 +263,7 @@ async def test_model_startup(mock_build_worker):
async def test_model_shutdown(mock_build_worker):
async with _start_worker_manager(start=False, stop=False) as manager:
workers = _create_workers(1)
worker, worker_params = workers[0]
worker, worker_params, model_instance = workers[0]
mock_build_worker.return_value = worker
req = WorkerStartupRequest(
@ -298,7 +298,7 @@ async def test_get_model_instances(is_async):
workers = _create_workers(3)
async with _start_worker_manager(workers=workers, stop=False) as manager:
assert len(manager.workers) == 3
for _, worker_params in workers:
for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
if is_async:
@ -326,7 +326,7 @@ async def test__simple_select(
]
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
instances = await manager.get_model_instances(worker_type, model_name)
@ -351,7 +351,7 @@ async def test_select_one_instance(
],
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
if is_async:
@ -376,7 +376,7 @@ async def test__get_model(
],
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = {"model": model_name}
@ -403,13 +403,13 @@ async def test_generate_stream(
expected_messages: str,
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = {"model": model_name}
text = ""
async for out in manager.generate_stream(params):
text += out.text
text = out.text
assert text == expected_messages
@ -417,8 +417,8 @@ async def test_generate_stream(
@pytest.mark.parametrize(
"manager_with_2_workers, expected_messages",
[
({"stream_messags": ["Hello", " world."]}, " world."),
({"stream_messags": ["你好,我是", "张三。"]}, "张三。"),
({"stream_messags": ["Hello", " world."]}, "Hello world."),
({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
],
indirect=["manager_with_2_workers"],
)
@ -429,7 +429,7 @@ async def test_generate(
expected_messages: str,
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = {"model": model_name}
@ -454,7 +454,7 @@ async def test_embeddings(
is_async: bool,
):
manager, workers = manager_2_embedding_workers
for _, worker_params in workers:
for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = {"model": model_name, "input": ["hello", "world"]}
@ -472,7 +472,7 @@ async def test_parameter_descriptions(
]
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = await manager.parameter_descriptions(worker_type, model_name)

View File

@ -339,6 +339,27 @@ register_conv_template(
)
)
# codellama template
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
# reference2 : https://github.com/eosphoros-ai/DB-GPT-Hub/blob/main/README.zh.md
register_conv_template(
Conversation(
name="codellama",
system="<s>[INST] <<SYS>>\nI want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request."
"If you don't know the answer to the request, please don't share false information.\n<</SYS>>\n\n",
roles=("[INST]", "[/INST]"),
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA2,
sep=" ",
sep2=" </s><s>",
stop_token_ids=[2],
system_formatter=lambda msg: f"<s>[INST] <<SYS>>\n{msg}\n<</SYS>>\n\n",
)
)
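# A rough sketch of the prompt this template is expected to render, assuming the
# standard Llama-2 separator behavior (the exact layout is produced by
# Conversation.get_prompt, which is not shown in this diff):
#
#     <s>[INST] <<SYS>>
#     {system prompt}
#     <</SYS>>
#
#     {user message} [/INST] {assistant reply} </s><s>[INST] {next user message} [/INST]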
# Alpaca default template
register_conv_template(
Conversation(

View File

@ -45,6 +45,10 @@ _OLD_MODELS = [
"llama-cpp",
"proxyllm",
"gptj-6b",
"codellama-13b-sql-sft",
"codellama-7b",
"codellama-7b-sql-sft",
"codellama-13b",
]
@ -148,8 +152,12 @@ class LLMModelAdaper:
conv.append_message(conv.roles[1], content)
else:
raise ValueError(f"Unknown role: {role}")
if system_messages:
conv.set_system_message("".join(system_messages))
if isinstance(conv, Conversation):
conv.set_system_message("".join(system_messages))
else:
conv.update_system_message("".join(system_messages))
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
@ -459,7 +467,8 @@ register_conv_template(
sep="\n",
sep2="</s>",
stop_str=["</s>", "[UNK]"],
)
),
override=True,
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
register_conv_template(
@ -474,7 +483,8 @@ register_conv_template(
sep="###",
sep2="</s>",
stop_str=["</s>", "[UNK]"],
)
),
override=True,
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
register_conv_template(
@ -487,5 +497,6 @@ register_conv_template(
sep="",
sep2="</s>",
stop_str=["</s>", "<|endoftext|>"],
)
),
override=True,
)

View File

@ -1,9 +1,10 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional
from typing import Dict, Optional, Union, Tuple
from pilot.model.conversation import conv_templates
from pilot.utils.parameter_utils import BaseParameters
@ -19,6 +20,35 @@ class WorkerType(str, Enum):
def values():
return [item.value for item in WorkerType]
@staticmethod
def to_worker_key(worker_name, worker_type: Union[str, "WorkerType"]) -> str:
"""Generate worker key from worker name and worker type
Args:
worker_name (str): Worker name (e.g., chatglm2-6b)
worker_type (Union[str, "WorkerType"]): Worker type (e.g., 'llm' or [`WorkerType.LLM`])
Returns:
str: Generated worker key
"""
if "@" in worker_name:
raise ValueError(f"Invaild symbol '@' in your worker name {worker_name}")
if isinstance(worker_type, WorkerType):
worker_type = worker_type.value
return f"{worker_name}@{worker_type}"
@staticmethod
def parse_worker_key(worker_key: str) -> Tuple[str, str]:
"""Parse worker name and worker type from worker key
Args:
worker_key (str): Worker key generated by [`WorkerType.to_worker_key`]
Returns:
Tuple[str, str]: Worker name and worker type
"""
return tuple(worker_key.split("@"))
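# A minimal round-trip sketch for the helpers above, assuming
# WorkerType.LLM.value == "llm":
#
#     key = WorkerType.to_worker_key("chatglm2-6b", WorkerType.LLM)  # -> "chatglm2-6b@llm"
#     name, worker_type = WorkerType.parse_worker_key(key)           # -> ("chatglm2-6b", "llm")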
@dataclass
class ModelControllerParameters(BaseParameters):
@ -60,6 +90,56 @@ class ModelControllerParameters(BaseParameters):
)
@dataclass
class ModelAPIServerParameters(BaseParameters):
host: Optional[str] = field(
default="0.0.0.0", metadata={"help": "Model API server deploy host"}
)
port: Optional[int] = field(
default=8100, metadata={"help": "Model API server deploy port"}
)
daemon: Optional[bool] = field(
default=False, metadata={"help": "Run Model API server in background"}
)
controller_addr: Optional[str] = field(
default="http://127.0.0.1:8000",
metadata={"help": "The Model controller address to connect"},
)
api_keys: Optional[str] = field(
default=None,
metadata={"help": "Optional list of comma separated API keys"},
)
log_level: Optional[str] = field(
default=None,
metadata={
"help": "Logging level",
"valid_values": [
"FATAL",
"ERROR",
"WARNING",
"WARNING",
"INFO",
"DEBUG",
"NOTSET",
],
},
)
log_file: Optional[str] = field(
default="dbgpt_model_apiserver.log",
metadata={
"help": "The filename to store log",
},
)
tracer_file: Optional[str] = field(
default="dbgpt_model_apiserver_tracer.jsonl",
metadata={
"help": "The filename to store tracer span records",
},
)
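# A minimal construction sketch; every field above has a default, and per the
# help text api_keys is a single comma-separated string:
#
#     params = ModelAPIServerParameters(port=8100, api_keys="abc,xyz")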
@dataclass
class BaseModelParameters(BaseParameters):
model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple, Optional
from typing import Any, Dict, List, Tuple, Optional, Union
from pydantic import BaseModel, Field, root_validator
@ -70,14 +70,6 @@ class SystemMessage(BaseMessage):
return "system"
class ModelMessage(BaseModel):
"""Type of message that interaction between dbgpt-server and llm-server"""
"""Similar to openai's message format"""
role: str
content: str
class ModelMessageRoleType:
""" "Type of ModelMessage role"""
@ -87,6 +79,45 @@ class ModelMessageRoleType:
VIEW = "view"
class ModelMessage(BaseModel):
"""Type of message that interaction between dbgpt-server and llm-server"""
"""Similar to openai's message format"""
role: str
content: str
@staticmethod
def from_openai_messages(
messages: Union[str, List[Dict[str, str]]]
) -> List["ModelMessage"]:
"""Openai message format to current ModelMessage format"""
if isinstance(messages, str):
return [ModelMessage(role=ModelMessageRoleType.HUMAN, content=messages)]
result = []
for message in messages:
msg_role = message["role"]
content = message["content"]
if msg_role == "system":
result.append(
ModelMessage(role=ModelMessageRoleType.SYSTEM, content=content)
)
elif msg_role == "user":
result.append(
ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
)
elif msg_role == "assistant":
result.append(
ModelMessage(role=ModelMessageRoleType.AI, content=content)
)
else:
raise ValueError(f"Unknown role: {msg_role}")
return result
@staticmethod
def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
return list(map(lambda m: m.dict(), messages))
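# A minimal conversion sketch, assuming ModelMessageRoleType.HUMAN == "human":
#
#     msgs = ModelMessage.from_openai_messages(
#         [
#             {"role": "system", "content": "You are a helpful assistant."},
#             {"role": "user", "content": "Hi"},
#         ]
#     )
#     ModelMessage.to_dict_list(msgs)
#     # -> [{"role": "system", ...}, {"role": "human", ...}]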
class Generation(BaseModel):
"""Output of a single generation."""

View File

@ -1,3 +1,4 @@
import json
import os
from typing import Dict, List
@ -110,12 +111,14 @@ class ChatKnowledge(BaseChat):
list(map(lambda doc: doc.metadata, docs)), "source"
)
self.current_message.knowledge_source = self.sources
if not docs:
raise ValueError(
"you have no knowledge space, please add your knowledge space"
)
context = [d.page_content for d in docs]
if not docs:
print("no relevant docs to retrieve")
context = "no relevant docs to retrieve"
# raise ValueError(
# "you have no knowledge space, please add your knowledge space"
# )
else:
context = [d.page_content for d in docs]
context = context[: self.max_token]
relations = list(
set([os.path.basename(str(d.metadata.get("source", ""))) for d in docs])
@ -128,17 +131,26 @@ class ChatKnowledge(BaseChat):
return input_values
def parse_source_view(self, sources: List):
html_title = f"##### **References:**"
lines = ""
"""
build knowledge reference view message to web
{
"title":"References",
"reference":{
"name":"aa.pdf",
"pages":["1","2","3"]
},
}
"""
references = {"title": "References", "reference": {}}
for item in sources:
source = item["source"] if "source" in item else ""
pages = ",".join(item["pages"]) if "pages" in item else ""
lines += f"{source}"
references["reference"]["name"] = source
pages = item["pages"] if "pages" in item else []
if len(pages) > 0:
lines += f", **pages**:{pages}\n\n"
else:
lines += "\n\n"
html = f"""{html_title}\n{lines}"""
references["reference"]["pages"] = pages
html = (
f"""<references>{json.dumps(references, ensure_ascii=False)}</references>"""
)
return html
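# For a single PDF source, the rendered message would look roughly like this
# (a sketch; the actual payload depends on the sources list):
#
#     <references>{"title": "References", "reference": {"name": "aa.pdf", "pages": ["1", "2", "3"]}}</references>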
def merge_by_key(self, data, key):

View File

@ -215,6 +215,16 @@ class Llama2ChatAdapter(BaseChatAdpter):
return get_conv_template("llama-2")
class CodeLlamaChatAdapter(BaseChatAdpter):
"""The model ChatAdapter for codellama ."""
def match(self, model_path: str):
return "codellama" in model_path.lower()
def get_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("codellama")
class BaichuanChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "baichuan" in model_path.lower()
@ -268,6 +278,7 @@ register_llm_model_chat_adapter(FalconChatAdapter)
register_llm_model_chat_adapter(GorillaChatAdapter)
register_llm_model_chat_adapter(GPT4AllChatAdapter)
register_llm_model_chat_adapter(Llama2ChatAdapter)
register_llm_model_chat_adapter(CodeLlamaChatAdapter)
register_llm_model_chat_adapter(BaichuanChatAdapter)
register_llm_model_chat_adapter(WizardLMChatAdapter)
register_llm_model_chat_adapter(LlamaCppChatAdapter)

View File

@ -0,0 +1,99 @@
from typing import Dict, Any, Awaitable, AsyncIterator, Callable, Optional
import httpx
import asyncio
import logging
import json
logger = logging.getLogger(__name__)
MessageCaller = Callable[[str], Awaitable[None]]
async def _do_chat_completion(
url: str,
chat_data: Dict[str, Any],
client: httpx.AsyncClient,
headers: Dict[str, Any] = {},
timeout: int = 60,
caller: Optional[MessageCaller] = None,
) -> Iterator[str]:
async with client.stream(
"POST",
url,
headers=headers,
json=chat_data,
timeout=timeout,
) as res:
if res.status_code != 200:
error_message = await res.aread()
if error_message:
error_message = error_message.decode("utf-8")
logger.error(
f"Request failed with status {res.status_code}. Error: {error_message}"
)
raise httpx.RequestError(
f"Request failed with status {res.status_code}",
request=res.request,
)
async for line in res.aiter_lines():
if line:
if not line.startswith("data: "):
if caller:
await caller(line)
yield line
else:
decoded_line = line.split("data: ", 1)[1]
if decoded_line.lower().strip() != "[DONE]".lower():
obj = json.loads(decoded_line)
if obj["choices"][0]["delta"].get("content") is not None:
text = obj["choices"][0]["delta"].get("content")
if caller:
await caller(text)
yield text
await asyncio.sleep(0.02)
async def chat_completion_stream(
url: str,
chat_data: Dict[str, Any],
client: Optional[httpx.AsyncClient] = None,
headers: Dict[str, Any] = {},
timeout: int = 60,
caller: Optional[MessageCaller] = None,
) -> Iterator[str]:
if client:
async for text in _do_chat_completion(
url,
chat_data,
client=client,
headers=headers,
timeout=timeout,
caller=caller,
):
yield text
else:
async with httpx.AsyncClient() as client:
async for text in _do_chat_completion(
url,
chat_data,
client=client,
headers=headers,
timeout=timeout,
caller=caller,
):
yield text
async def chat_completion(
url: str,
chat_data: Dict[str, Any],
client: Optional[httpx.AsyncClient] = None,
headers: Dict[str, Any] = {},
timeout: int = 60,
caller: Optional[MessageCaller] = None,
) -> str:
full_text = ""
async for text in chat_completion_stream(
url, chat_data, client, headers=headers, timeout=timeout, caller=caller
):
full_text += text
return full_text
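# A minimal usage sketch (URL, payload, and header values are illustrative):
#
#     chat_data = {
#         "model": "test-model-name-0",
#         "messages": [{"role": "user", "content": "Hello"}],
#         "stream": True,
#     }
#     full_text = await chat_completion(
#         "http://127.0.0.1:8100/api/v1/chat/completions", chat_data
#     )
#     async for chunk in chat_completion_stream(
#         "http://127.0.0.1:8100/api/v1/chat/completions",
#         chat_data,
#         headers={"Authorization": "Bearer abc"},
#     ):
#         print(chunk, end="", flush=True)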

View File

@ -8,6 +8,7 @@ pytest-integration
pytest-mock
pytest-recording
pytesseract==0.3.10
aioresponses
# python code format
black
# for git hooks