Mirror of https://github.com/csunny/DB-GPT.git — synced 2025-08-15 23:13:15 +00:00

Commit 6fe7bfd63d — feat: merge main branch

CODE_OF_CONDUCT (new file, 126 lines)
@@ -0,0 +1,126 @@
# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
  community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or advances of
  any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
  without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
[INSERT CONTACT METHOD].
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

*Community Impact*: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

*Consequence*: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

*Community Impact*: A violation through a single incident or series of
actions.

*Consequence*: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.

### 3. Temporary Ban

*Community Impact*: A serious violation of community standards, including
sustained inappropriate behavior.

*Consequence*: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

*Community Impact*: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

*Consequence*: A permanent ban from any sort of public interaction within the
community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].

Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].

For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].
README.md (10 lines changed)

@@ -43,8 +43,8 @@ DB-GPT is an experimental open-source project that uses localized GPT large mode

## Contents
-- [install](#install)
-- [demo](#demo)
+- [Install](#install)
+- [Demo](#demo)
- [introduction](#introduction)
- [features](#features)
- [contribution](#contribution)
@@ -177,7 +177,7 @@ Currently, we have released multiple key features, which are listed below to dem
| [StarRocks](https://github.com/StarRocks/starrocks) | No | TODO |

## Introduction
-Is the architecture of the entire DB-GPT shown in the following figure:
+The architecture of the entire DB-GPT is shown in the following figure:

<p align="center">
  <img src="./assets/DB-GPT.png" width="800" />
@@ -213,7 +213,7 @@ The core capabilities mainly consist of the following parts:

## Contribution

-- Please run `black .` before submitting the code. Contributing guidelines, [how to contribution](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING.md)
+- Please run `black .` before submitting the code. Contributing guidelines, [how to contribute](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING.md)

## RoadMap

@@ -330,7 +330,7 @@ As of October 10, 2023, by fine-tuning an open-source model of 13 billion parame
The MIT License (MIT)

## Contact Information
-We are working on building a community, if you have any ideas about building the community, feel free to contact us.
+We are working on building a community, if you have any ideas for building the community, feel free to contact us.
[](https://discord.gg/nASQyBjvY)

<p align="center">
Binary file not shown (image asset changed: before 229 KiB, after 202 KiB).
@@ -7,6 +7,16 @@ services:
    restart: unless-stopped
    networks:
      - dbgptnet
+  api-server:
+    image: eosphorosai/dbgpt:latest
+    command: dbgpt start apiserver --controller_addr http://controller:8000
+    restart: unless-stopped
+    depends_on:
+      - controller
+    networks:
+      - dbgptnet
+    ports:
+      - 8100:8100/tcp
  llm-worker:
    image: eosphorosai/dbgpt:latest
    command: dbgpt start worker --model_name vicuna-13b-v1.5 --model_path /app/models/vicuna-13b-v1.5 --port 8001 --controller_addr http://controller:8000
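Once the stack is up, the new `api-server` service can be smoke-tested from the host. A minimal sketch, not part of this commit, using only the Python standard library and the port mapping above:

```python
# Minimal smoke test for the api-server exposed on host port 8100.
# The compose command above does not pass --api_keys, so the server accepts
# any request; the Bearer header is harmless and matches the docs below.
import json
import urllib.request

req = urllib.request.Request(
    "http://127.0.0.1:8100/api/v1/models",
    headers={
        "Authorization": "Bearer EMPTY",
        "Content-Type": "application/json",
    },
)
with urllib.request.urlopen(req) as resp:
    models = json.load(resp)
    # Each entry is a ModelCard; print the ids of registered worker models.
    print([m["id"] for m in models["data"]])
```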
@@ -32,7 +32,7 @@ extensions = [
    "sphinx_panels",
    "sphinx_tabs.tabs",
    "IPython.sphinxext.ipython_console_highlighting",
-    'sphinx.ext.autosectionlabel'
+    "sphinx.ext.autosectionlabel",
]

source_suffix = [".ipynb", ".html", ".md", ".rst"]
@@ -77,3 +77,4 @@ By analyzing this information, we can identify performance bottlenecks in model

./vms/standalone.md
./vms/index.md
+./openai.md
docs/getting_started/install/cluster/openai.md (new file, 51 lines)

@@ -0,0 +1,51 @@
OpenAI-Compatible RESTful APIs
==================================
(openai-apis-index)=

### Install Prepare

You must [deploy DB-GPT cluster](https://db-gpt.readthedocs.io/en/latest/getting_started/install/cluster/vms/index.html) first.

### Launch Model API Server

```bash
dbgpt start apiserver --controller_addr http://127.0.0.1:8000 --api_keys EMPTY
```

By default, the Model API Server starts on port 8100.

### Validate with cURL

#### List models

```bash
curl http://127.0.0.1:8100/api/v1/models \
-H "Authorization: Bearer EMPTY" \
-H "Content-Type: application/json"
```

#### Chat completions

```bash
curl http://127.0.0.1:8100/api/v1/chat/completions \
-H "Authorization: Bearer EMPTY" \
-H "Content-Type: application/json" \
-d '{"model": "vicuna-13b-v1.5", "messages": [{"role": "user", "content": "hello"}]}'
```

### Validate with OpenAI Official SDK

#### Chat completions

```python
import openai
openai.api_key = "EMPTY"
openai.api_base = "http://127.0.0.1:8100/api/v1"
model = "vicuna-13b-v1.5"

completion = openai.ChatCompletion.create(
    model=model,
    messages=[{"role": "user", "content": "hello"}]
)
# print the completion
print(completion.choices[0].message.content)
```
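The server also streams: `create_chat_completion` in the new api.py (further down in this commit) returns a `StreamingResponse` when `stream` is set. A hedged sketch with the same pre-1.0 `openai` SDK as above, not part of this commit:

```python
import openai

openai.api_key = "EMPTY"
openai.api_base = "http://127.0.0.1:8100/api/v1"

# stream=True makes the pre-1.0 SDK iterate over server-sent event chunks;
# each chunk carries a DeltaMessage-style `delta` with optional `content`.
for chunk in openai.ChatCompletion.create(
    model="vicuna-13b-v1.5",
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
):
    delta = chunk.choices[0].delta
    if "content" in delta:
        print(delta["content"], end="", flush=True)
print()
```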
@@ -24,9 +24,12 @@ PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions

#Azure
LLM_MODEL=chatgpt_proxyllm
-OPENAI_API_TYPE=azure
-PROXY_API_KEY={your-openai-sk}
-PROXY_SERVER_URL=https://xx.openai.azure.com/v1/chat/completions
+PROXY_API_KEY={your-azure-sk}
+PROXY_API_BASE=https://{your domain}.openai.azure.com/
+PROXY_API_TYPE=azure
+PROXY_SERVER_URL=xxxx
+PROXY_API_VERSION=2023-05-15
+PROXYLLM_BACKEND=gpt-35-turbo

#Aliyun tongyi
LLM_MODEL=tongyi_proxyllm
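For illustration only: the renamed Azure variables map naturally onto the pre-1.0 `openai` SDK's Azure mode. How DB-GPT's `chatgpt_proxyllm` actually consumes them is not shown in this diff, so this is a hedged sketch of one plausible mapping:

```python
import os
import openai

# Purely illustrative mapping of the new .env template onto the pre-1.0
# openai SDK; the proxy client inside pilot/ may read these differently.
openai.api_type = os.getenv("PROXY_API_TYPE", "azure")
openai.api_key = os.getenv("PROXY_API_KEY")
openai.api_base = os.getenv("PROXY_API_BASE")        # https://{your domain}.openai.azure.com/
openai.api_version = os.getenv("PROXY_API_VERSION")  # e.g. 2023-05-15

resp = openai.ChatCompletion.create(
    engine=os.getenv("PROXYLLM_BACKEND", "gpt-35-turbo"),  # Azure deployment name
    messages=[{"role": "user", "content": "hello"}],
)
print(resp.choices[0].message.content)
```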
@@ -0,0 +1,71 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) 2023, csunny
# This file is distributed under the same license as the DB-GPT package.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.4.0\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-11-02 21:09+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
"Language-Team: zh_CN <LL@li.org>\n"
"Plural-Forms: nplurals=1; plural=0;\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"

#: ../../getting_started/install/cluster/openai.md:1
#: 01f4e2bf853341198633b367efec1522
msgid "OpenAI-Compatible RESTful APIs"
msgstr "OpenAI RESTful 兼容接口"

#: ../../getting_started/install/cluster/openai.md:5
#: d8717e42335e4027bf4e76b3d28768ee
msgid "Install Prepare"
msgstr "安装准备"

#: ../../getting_started/install/cluster/openai.md:7
#: 9a48d8ee116942468de4c6faf9a64758
msgid ""
"You must [deploy DB-GPT cluster](https://db-"
"gpt.readthedocs.io/en/latest/getting_started/install/cluster/vms/index.html)"
" first."
msgstr "你必须先部署 [DB-GPT 集群]"
"(https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh-cn/latest/getting_started/install/cluster/vms/index.html)。"

#: ../../getting_started/install/cluster/openai.md:9
#: 7673a7121f004f7ca6b1a94a7e238fa3
msgid "Launch Model API Server"
msgstr "启动模型 API Server"

#: ../../getting_started/install/cluster/openai.md:14
#: 84a925c2cbcd4e4895a1d2d2fe8f720f
msgid "By default, the Model API Server starts on port 8100."
msgstr "默认情况下,模型 API Server 使用 8100 端口启动。"

#: ../../getting_started/install/cluster/openai.md:16
#: e53ed41977cd4721becd51eba05c6609
msgid "Validate with cURL"
msgstr "通过 cURL 验证"

#: ../../getting_started/install/cluster/openai.md:18
#: 7c883b410b5c4e53a256bf17c1ded80d
msgid "List models"
msgstr "列出模型"

#: ../../getting_started/install/cluster/openai.md:26
#: ../../getting_started/install/cluster/openai.md:37
#: 7cf0ed13f0754f149ec085cd6cf7a45a 990d5d5ed5d64ab49550e68495b9e7a0
msgid "Chat completions"
msgstr ""

#: ../../getting_started/install/cluster/openai.md:35
#: 81583edd22df44e091d18a0832278131
msgid "Validate with OpenAI Official SDK"
msgstr "通过 OpenAI 官方 SDK 验证"
@ -8,7 +8,7 @@ msgid ""
|
||||
msgstr ""
|
||||
"Project-Id-Version: DB-GPT 👏👏 0.3.5\n"
|
||||
"Report-Msgid-Bugs-To: \n"
|
||||
"POT-Creation-Date: 2023-11-02 10:10+0800\n"
|
||||
"POT-Creation-Date: 2023-11-02 21:04+0800\n"
|
||||
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
|
||||
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||
"Language: zh_CN\n"
|
||||
@ -20,292 +20,292 @@ msgstr ""
|
||||
"Generated-By: Babel 2.12.1\n"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:1
|
||||
#: 28d2f84fc8884e78afad8118cd59c654
|
||||
#: a17719d2f4374285a7beb4d1db470146
|
||||
#, fuzzy
|
||||
msgid "Environment Parameter"
|
||||
msgstr "环境变量说明"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:4
|
||||
#: c83fbb5e1aa643cdb09fffe7f3d1a3c5
|
||||
#: 9a62e6fff7914eeaa2d195ddef4fcb61
|
||||
msgid "LLM MODEL Config"
|
||||
msgstr "模型配置"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:5
|
||||
#: eb675965ae57407e8d8bf90fed8e9e2a
|
||||
#: 90e3991538324ecfac8cac7ef2103ac2
|
||||
msgid "LLM Model Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG"
|
||||
msgstr "LLM Model Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:6
|
||||
#: 5d28d35126d849ea9b0d963fd1ba8699
|
||||
#: 1f45af01100c4586acbc05469e3006bc
|
||||
msgid "LLM_MODEL=vicuna-13b"
|
||||
msgstr "LLM_MODEL=vicuna-13b"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:8
|
||||
#: 01955b2d0fbe4d94939ebf2cbb380bdd
|
||||
#: bed14b704f154c2db525f7fafd3aa5a4
|
||||
msgid "MODEL_SERVER_ADDRESS"
|
||||
msgstr "MODEL_SERVER_ADDRESS"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:9
|
||||
#: 4eaaa9ab59854c0386b28b3111c82784
|
||||
#: ea42946cfe4f4ad996bf82c1996e7344
|
||||
msgid "MODEL_SERVER=http://127.0.0.1:8000 LIMIT_MODEL_CONCURRENCY"
|
||||
msgstr "MODEL_SERVER=http://127.0.0.1:8000 LIMIT_MODEL_CONCURRENCY"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:12
|
||||
#: 5c2dd05e16834443b7451c2541b59757
|
||||
#: 021c261231f342fdba34098b1baa06fd
|
||||
msgid "LIMIT_MODEL_CONCURRENCY=5"
|
||||
msgstr "LIMIT_MODEL_CONCURRENCY=5"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:14
|
||||
#: 7707836c2fb04e7da13d2d59b5f9566f
|
||||
#: afaf0ba7fd09463d8ff74b514ed7264c
|
||||
msgid "MAX_POSITION_EMBEDDINGS"
|
||||
msgstr "MAX_POSITION_EMBEDDINGS"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:16
|
||||
#: ee24a7d3d8384e61b715ef3bd362b965
|
||||
#: e4517a942bca4361a64a00408f993f5b
|
||||
msgid "MAX_POSITION_EMBEDDINGS=4096"
|
||||
msgstr "MAX_POSITION_EMBEDDINGS=4096"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:18
|
||||
#: 90b51aa4e46b4d1298c672e0052c2f68
|
||||
#: 78d2ef04ed4548b9b7b0fb8ae35c9d5c
|
||||
msgid "QUANTIZE_QLORA"
|
||||
msgstr "QUANTIZE_QLORA"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:20
|
||||
#: 7de7a8eb431e4973ae00f68ca0686281
|
||||
#: bfa65db03c6d46bba293331f03ab15ac
|
||||
msgid "QUANTIZE_QLORA=True"
|
||||
msgstr "QUANTIZE_QLORA=True"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:22
|
||||
#: e331ca016a474f4aa4e9182165a2693a
|
||||
#: 1947d45a7f184821910b4834ad5f1897
|
||||
msgid "QUANTIZE_8bit"
|
||||
msgstr "QUANTIZE_8bit"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:24
|
||||
#: 519ccce5a0884778be2719c437a17bd4
|
||||
#: 4a2ee2919d0e4bdaa13c9d92eefd2aac
|
||||
msgid "QUANTIZE_8bit=True"
|
||||
msgstr "QUANTIZE_8bit=True"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:27
|
||||
#: 1c0586d070f046de8d0f9f94a6b508b4
|
||||
#: 348dc1e411b54ab09414f40a20e934e4
|
||||
msgid "LLM PROXY Settings"
|
||||
msgstr "LLM PROXY Settings"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:28
|
||||
#: c208c3f4b13f4b39962de814e5be6ab9
|
||||
#: a692e78425a040f5828ab54ff9a33f77
|
||||
msgid "OPENAI Key"
|
||||
msgstr "OPENAI Key"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:30
|
||||
#: 9228bbee2faa4467b1d24f1125faaac8
|
||||
#: 940d00e25a424acf92951a314a64e5ea
|
||||
msgid "PROXY_API_KEY={your-openai-sk}"
|
||||
msgstr "PROXY_API_KEY={your-openai-sk}"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:31
|
||||
#: 759ae581883348019c1ba79e8954728a
|
||||
#: 4bd27547ae6041679e91f2a363cd1deb
|
||||
msgid "PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions"
|
||||
msgstr "PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:33
|
||||
#: 83f3952917d34aab80bd34119f7d1e20
|
||||
#: cfa3071afb0b47baad6bd729d4a02cb9
|
||||
msgid "from https://bard.google.com/ f12-> application-> __Secure-1PSID"
|
||||
msgstr "from https://bard.google.com/ f12-> application-> __Secure-1PSID"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:35
|
||||
#: 1d70707ca82749bb90b2bed1aee44d62
|
||||
#: a17efa03b10f47f68afac9e865982a75
|
||||
msgid "BARD_PROXY_API_KEY={your-bard-token}"
|
||||
msgstr "BARD_PROXY_API_KEY={your-bard-token}"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:38
|
||||
#: 38a2091fa223493ea23cb9bbb33cf58e
|
||||
#: 6bcfe90574da4d82a459e8e11bf73cba
|
||||
msgid "DATABASE SETTINGS"
|
||||
msgstr "DATABASE SETTINGS"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:39
|
||||
#: 5134180d7a5945b48b072a1eb92b27ba
|
||||
#: 2b1e62d9bf5d4af5a22f68c8248eaafb
|
||||
msgid "SQLite database (Current default database)"
|
||||
msgstr "SQLite database (Current default database)"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:40
|
||||
#: 6875e2300e094668a45fa4f2551e0d30
|
||||
#: 8a909ac3b3c943da8dbc4e8dd596c80c
|
||||
msgid "LOCAL_DB_PATH=data/default_sqlite.db"
|
||||
msgstr "LOCAL_DB_PATH=data/default_sqlite.db"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:41
|
||||
#: 034e8f06f24f44af9d8184563f99b4b3
|
||||
#: 90ae6507932f4815b6e180051738bb93
|
||||
msgid "LOCAL_DB_TYPE=sqlite # Database Type default:sqlite"
|
||||
msgstr "LOCAL_DB_TYPE=sqlite # Database Type default:sqlite"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:43
|
||||
#: f688149a97f740269f80b79775236ce9
|
||||
#: d2ce34e0dcf44ccf9e8007d548ba7b0a
|
||||
msgid "MYSQL database"
|
||||
msgstr "MYSQL database"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:44
|
||||
#: 6db0b305137d45a3aa036e4f2262f460
|
||||
#: c07159d63c334f6cbb95fcc30bfb7ea5
|
||||
msgid "LOCAL_DB_TYPE=mysql"
|
||||
msgstr "LOCAL_DB_TYPE=mysql"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:45
|
||||
#: b6d662ce8d5f44f0b54a7f6e7c66f5a5
|
||||
#: e16700b2ea8d411e91d010c1cde7aecc
|
||||
msgid "LOCAL_DB_USER=root"
|
||||
msgstr "LOCAL_DB_USER=root"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:46
|
||||
#: cd7493d61ac9415283640dc6c018d2f4
|
||||
#: bfc2dce1bf374121b6861e677b4e1ffa
|
||||
msgid "LOCAL_DB_PASSWORD=aa12345678"
|
||||
msgstr "LOCAL_DB_PASSWORD=aa12345678"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:47
|
||||
#: 4ea2a622b23f4342a4c2ab7f8d9c4e8d
|
||||
#: bc384739f5b04e21a34d0d2b78e7906c
|
||||
msgid "LOCAL_DB_HOST=127.0.0.1"
|
||||
msgstr "LOCAL_DB_HOST=127.0.0.1"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:48
|
||||
#: 936db95a0ab246098028f4dbb596cd17
|
||||
#: e5253d452e0d42b7ac308fe6fbfb5017
|
||||
msgid "LOCAL_DB_PORT=3306"
|
||||
msgstr "LOCAL_DB_PORT=3306"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:51
|
||||
#: d9255f25989840ea9c9e7b34f3947c87
|
||||
#: 9ca8f6fe06ed4cbab390f94be252e165
|
||||
msgid "EMBEDDING SETTINGS"
|
||||
msgstr "EMBEDDING SETTINGS"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:52
|
||||
#: b09291d32aca43928a981e873476a985
|
||||
#: 76c7c260293c4b49bae057143fd48377
|
||||
msgid "EMBEDDING MODEL Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG"
|
||||
msgstr "EMBEDDING模型, 参考see /pilot/configs/model_config.LLM_MODEL_CONFIG"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:53
|
||||
#: 63de573b03a54413b997f18a1ccee279
|
||||
#: f1d63a0128ce493cae37d34f1976bcca
|
||||
msgid "EMBEDDING_MODEL=text2vec"
|
||||
msgstr "EMBEDDING_MODEL=text2vec"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:55
|
||||
#: 0ef8defbab544bd0b9475a036f278489
|
||||
#: b8fbb99109d04781b2dd5bc5d6efa5bd
|
||||
msgid "Embedding Chunk size, default 500"
|
||||
msgstr "Embedding 切片大小, 默认500"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:57
|
||||
#: 33dbc7941d054baa8c6ecfc0bf1ce271
|
||||
#: bf8256576ea34f6a9c5f261ab9aab676
|
||||
msgid "KNOWLEDGE_CHUNK_SIZE=500"
|
||||
msgstr "KNOWLEDGE_CHUNK_SIZE=500"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:59
|
||||
#: e6ee9f2620ab45ecbc8e9c0642f5ca42
|
||||
#: 9b156c6b599b4c02a58ce023b4ff25f2
|
||||
msgid "Embedding Chunk Overlap, default 100"
|
||||
msgstr "Embedding chunk Overlap, 文本块之间的最大重叠量。保留一些重叠可以保持文本块之间的连续性(例如使用滑动窗口),默认100"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:60
|
||||
#: fcddf64340a04df4ab95176fc2fc67a6
|
||||
#: dcafd903c36041ac85ac99a14dbee512
|
||||
msgid "KNOWLEDGE_CHUNK_OVERLAP=100"
|
||||
msgstr "KNOWLEDGE_CHUNK_OVERLAP=100"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:62
|
||||
#: 61272200194b4461a921581feb1273da
|
||||
#: 6c3244b7e5e24b0188c7af4bb52e9134
|
||||
#, fuzzy
|
||||
msgid "embedding recall top k,5"
|
||||
msgstr "embedding 召回topk, 默认5"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:64
|
||||
#: b433091f055542b1b89ff2d525ac99e4
|
||||
#: f4a2f30551cf4fe1a7ff3c7c74ec77be
|
||||
msgid "KNOWLEDGE_SEARCH_TOP_SIZE=5"
|
||||
msgstr "KNOWLEDGE_SEARCH_TOP_SIZE=5"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:66
|
||||
#: 1db0de41aebd4caa8cc2eaecb4cacd6a
|
||||
#: 593f2512362f467e92fdaa60dd5903a0
|
||||
#, fuzzy
|
||||
msgid "embedding recall max token ,2000"
|
||||
msgstr "embedding向量召回最大token, 默认2000"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:68
|
||||
#: 81b9d862e58941a4b09680a7520cdabe
|
||||
#: 83d6d28914be4d6282d457272e508ddc
|
||||
msgid "KNOWLEDGE_SEARCH_MAX_TOKEN=5"
|
||||
msgstr "KNOWLEDGE_SEARCH_MAX_TOKEN=5"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:71
|
||||
#: ../../getting_started/install/environment/environment.md:87
|
||||
#: cac73575d54544778bdee09b18532fd9 f78a509949a64f03aa330f31901e2e7a
|
||||
#: 6bc1b9d995e74294a1c78e783c550db7 d33c77ded834438e9f4a2df06e7e041a
|
||||
msgid "Vector Store SETTINGS"
|
||||
msgstr "Vector Store SETTINGS"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:72
|
||||
#: ../../getting_started/install/environment/environment.md:88
|
||||
#: 5ebba1cb047b4b09849000244237dfbb 7e9285e91bcb4b2d9413909c0d0a06a7
|
||||
#: 9cafa06e2d584f70afd848184e0fa52a f01057251b8b4ffea806192dfe1048ed
|
||||
msgid "Chroma"
|
||||
msgstr "Chroma"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:73
|
||||
#: ../../getting_started/install/environment/environment.md:89
|
||||
#: 05625cfcc23c4745ae1fa0d94ce5450c 3a8615f1507f4fc49d1adda5100a4edf
|
||||
#: e6c16fab37484769b819aeecbc13e6db faad299722e5400e95ec6ac3c1e018b8
|
||||
msgid "VECTOR_STORE_TYPE=Chroma"
|
||||
msgstr "VECTOR_STORE_TYPE=Chroma"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:74
|
||||
#: ../../getting_started/install/environment/environment.md:90
|
||||
#: 5b559376aea44f159262e6d4b75c7ec1 e954782861404b10b4e893e01cf74452
|
||||
#: 4eca3a51716d406f8ffd49c06550e871 581ee9dd38064b119660c44bdd00cbaa
|
||||
msgid "MILVUS"
|
||||
msgstr "MILVUS"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:75
|
||||
#: ../../getting_started/install/environment/environment.md:91
|
||||
#: 55ee8199c97a4929aeefd32370f2b92d 8f40c02543ea4a2ca9632dd9e8a08c2e
|
||||
#: 814c93048bed46589358a854d6c99683 b72b1269a2224f5f961214e41c019f21
|
||||
msgid "VECTOR_STORE_TYPE=Milvus"
|
||||
msgstr "VECTOR_STORE_TYPE=Milvus"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:76
|
||||
#: ../../getting_started/install/environment/environment.md:92
|
||||
#: 528a01d25720491c8e086bf43a62ad92 ba1386d551d7494a85681a2803081a6f
|
||||
#: 73ae665f1db9402883662734588fd02c c4da20319c994e83ba5a7706db967178
|
||||
msgid "MILVUS_URL=127.0.0.1"
|
||||
msgstr "MILVUS_URL=127.0.0.1"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:77
|
||||
#: ../../getting_started/install/environment/environment.md:93
|
||||
#: b031950dafcd4d4783c120dc933c4178 c2e9c8cdd41741e3aba01e59a6ef245d
|
||||
#: e30c5288516d42aa858a485db50490c1 f843b2e58bcb4e4594e3c28499c341d0
|
||||
msgid "MILVUS_PORT=19530"
|
||||
msgstr "MILVUS_PORT=19530"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:78
|
||||
#: ../../getting_started/install/environment/environment.md:94
|
||||
#: 27b0a64af6434cb2840373e2b38c9bd5 d0e4d79af7954b129ffff7303a1ec3ce
|
||||
#: 158669efcc7d4bcaac1c8dd01b499029 24e88ffd32f242f281c56c0ec3ad2639
|
||||
msgid "MILVUS_USERNAME"
|
||||
msgstr "MILVUS_USERNAME"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:79
|
||||
#: ../../getting_started/install/environment/environment.md:95
|
||||
#: 27aa1a5b61e64dd6bfe29124e274809e 5c58892498ce4f46a59f54b2887822d4
|
||||
#: 111a985297184c8aa5a0dd8e14a58445 6602093a6bb24d6792548e2392105c82
|
||||
msgid "MILVUS_PASSWORD"
|
||||
msgstr "MILVUS_PASSWORD"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:80
|
||||
#: ../../getting_started/install/environment/environment.md:96
|
||||
#: 009e57d4acc5434da2146f0545911c85 bac8888dcbff47fbb0ea8ae685445aac
|
||||
#: 47bdfcd78fbe4ccdb5f49b717a6d01a6 b96c0545b2044926a8a8190caf94ad25
|
||||
msgid "MILVUS_SECURE="
|
||||
msgstr "MILVUS_SECURE="
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:82
|
||||
#: ../../getting_started/install/environment/environment.md:98
|
||||
#: a6eeb16ab5274045bee88ecc3d93e09e eb341774403d47658b9b7e94c4c16d5c
|
||||
#: 755c32b5d6c54607907a138b5474c0ec ff4f2a7ddaa14f089dda7a14e1062c36
|
||||
msgid "WEAVIATE"
|
||||
msgstr "WEAVIATE"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:83
|
||||
#: fbd97522d8da4824b41b99298fd41069
|
||||
#: 23b2ce83385d40a589a004709f9864be
|
||||
msgid "VECTOR_STORE_TYPE=Weaviate"
|
||||
msgstr "VECTOR_STORE_TYPE=Weaviate"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:84
|
||||
#: ../../getting_started/install/environment/environment.md:99
|
||||
#: 341785b4abfe42b5af1c2e04497261f4 a81cc2aabc8240f3ac1f674d9350bff4
|
||||
#: 9acef304d89a448a9e734346705ba872 cf5151b6c1594ccd8beb1c3f77769acb
|
||||
msgid "WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network"
|
||||
msgstr "WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:102
|
||||
#: 5bb9e5daa36241d499089c1b1910f729
|
||||
#: c3003516b2364051bf34f8c3086e348a
|
||||
msgid "Multi-GPU Setting"
|
||||
msgstr "Multi-GPU Setting"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:103
|
||||
#: 30df45b7f1f7423c9f18c6360f0b7600
|
||||
#: ade8fc381c5e438aa29d159c10041713
|
||||
msgid ""
|
||||
"See https://developer.nvidia.com/blog/cuda-pro-tip-control-gpu-"
|
||||
"visibility-cuda_visible_devices/ If CUDA_VISIBLE_DEVICES is not "
|
||||
@ -315,49 +315,49 @@ msgstr ""
|
||||
"cuda_visible_devices/ 如果 CUDA_VISIBLE_DEVICES没有设置, 会使用所有可用的gpu"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:106
|
||||
#: 8631ea968dfb4d90a7ae6bdb2acdfdce
|
||||
#: e137bd19be5e410ba6709027dbf2923a
|
||||
msgid "CUDA_VISIBLE_DEVICES=0"
|
||||
msgstr "CUDA_VISIBLE_DEVICES=0"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:108
|
||||
#: 0010422280dd4fe79326ebceb2a66f0e
|
||||
#: 7669947acbdc4b1d92bcc029a8353a5d
|
||||
msgid ""
|
||||
"Optionally, you can also specify the gpu ID to use before the starting "
|
||||
"command"
|
||||
msgstr "你也可以通过启动命令设置gpu ID"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:110
|
||||
#: 00106f7341304fbd9425721ea8e6a261
|
||||
#: 751743d1753b4051beea46371278d793
|
||||
msgid "CUDA_VISIBLE_DEVICES=3,4,5,6"
|
||||
msgstr "CUDA_VISIBLE_DEVICES=3,4,5,6"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:112
|
||||
#: 720aa8b3478744d78e4b10dfeccb50b4
|
||||
#: 3acc3de0af0d4df2bb575e161e377f85
|
||||
msgid "You can configure the maximum memory used by each GPU."
|
||||
msgstr "可以设置GPU的最大内存"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:114
|
||||
#: f9639ac96a244296832c75bbcbdae2af
|
||||
#: 67f1d9b172b84294a44ecace5436e6e0
|
||||
msgid "MAX_GPU_MEMORY=16Gib"
|
||||
msgstr "MAX_GPU_MEMORY=16Gib"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:117
|
||||
#: fc4d955fdb3e4256af5c8f29b042dcd6
|
||||
#: 3c69dfe48bcf46b89b76cac1e7849a66
|
||||
msgid "Other Setting"
|
||||
msgstr "Other Setting"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:118
|
||||
#: 66b14a834e884339be2d48392e884933
|
||||
#: d5015b70f4fe4d20a63de9d87f86957a
|
||||
msgid "Language Settings(influence prompt language)"
|
||||
msgstr "Language Settings(涉及prompt语言以及知识切片方式)"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:119
|
||||
#: 5c9f05174eb84edd9e1316cc0721a840
|
||||
#: 5543c28bb8e34c9fb3bb6b063c2b1750
|
||||
msgid "LANGUAGE=en"
|
||||
msgstr "LANGUAGE=en"
|
||||
|
||||
#: ../../getting_started/install/environment/environment.md:120
|
||||
#: 7f6d62117d024c51bba9255fa4fcf151
|
||||
#: cb4ed5b892ee41068c1ca76cb29aa400
|
||||
msgid "LANGUAGE=zh"
|
||||
msgstr "LANGUAGE=zh"
|
||||
|
||||
|
@ -8,7 +8,7 @@ msgid ""
|
||||
msgstr ""
|
||||
"Project-Id-Version: DB-GPT 0.3.0\n"
|
||||
"Report-Msgid-Bugs-To: \n"
|
||||
"POT-Creation-Date: 2023-11-02 10:10+0800\n"
|
||||
"POT-Creation-Date: 2023-11-02 21:04+0800\n"
|
||||
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
|
||||
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||
"Language: zh_CN\n"
|
||||
@ -19,11 +19,11 @@ msgstr ""
|
||||
"Content-Transfer-Encoding: 8bit\n"
|
||||
"Generated-By: Babel 2.12.1\n"
|
||||
|
||||
#: ../../modules/knowledge.md:1 436b94d3a8374ed18feb5c14893a84e6
|
||||
#: ../../modules/knowledge.md:1 b94b3b15cb2441ed9d78abd222a717b7
|
||||
msgid "Knowledge"
|
||||
msgstr "知识"
|
||||
|
||||
#: ../../modules/knowledge.md:3 918a3747cbed42d18b8c9c4547e67b14
|
||||
#: ../../modules/knowledge.md:3 c6d6e308a6ce42948d29e928136ef561
|
||||
#, fuzzy
|
||||
msgid ""
|
||||
"As the knowledge base is currently the most significant user demand "
|
||||
@ -34,15 +34,15 @@ msgstr ""
|
||||
"由于知识库是当前用户需求最显著的场景,我们原生支持知识库的构建和处理。同时,我们还在本项目中提供了多种知识库管理策略,如:pdf,md , "
|
||||
"txt, word, ppt"
|
||||
|
||||
#: ../../modules/knowledge.md:4 d4d4b5d57918485aafa457bb9fdcf626
|
||||
#: ../../modules/knowledge.md:4 268abc408d40410ba90cf5f121dc5270
|
||||
msgid "Default built-in knowledge base"
|
||||
msgstr ""
|
||||
|
||||
#: ../../modules/knowledge.md:5 d4d4b5d57918485aafa457bb9fdcf626
|
||||
#: ../../modules/knowledge.md:5 558c3364c38b458a8ebf81030efc2a48
|
||||
msgid "Custom addition of knowledge bases"
|
||||
msgstr ""
|
||||
|
||||
#: ../../modules/knowledge.md:6 984361ce835c4c3492e29e1fb897348a
|
||||
#: ../../modules/knowledge.md:6 9cb3ce62da1440579c095848c7aef88c
|
||||
msgid ""
|
||||
"Various usage scenarios such as constructing knowledge bases through "
|
||||
"plugin capabilities and web crawling. Users only need to organize the "
|
||||
@ -50,53 +50,53 @@ msgid ""
|
||||
"the knowledge base required for the large model."
|
||||
msgstr ""
|
||||
|
||||
#: ../../modules/knowledge.md:9 746e4fbd3212460198be51b90caee2c8
|
||||
#: ../../modules/knowledge.md:9 b8ca6bc4dd9845baa56e36eea7fac2a2
|
||||
#, fuzzy
|
||||
msgid "Create your own knowledge repository"
|
||||
msgstr "创建你自己的知识库"
|
||||
|
||||
#: ../../modules/knowledge.md:11 1c46b33b0532417c824efbaa3e687c3f
|
||||
#: ../../modules/knowledge.md:11 17d7178a67924f43aa5b6293707ef041
|
||||
msgid ""
|
||||
"1.Place personal knowledge files or folders in the pilot/datasets "
|
||||
"directory."
|
||||
msgstr ""
|
||||
|
||||
#: ../../modules/knowledge.md:13 3b16f387b5354947a89d6df77bd65bdb
|
||||
#: ../../modules/knowledge.md:13 31c31f14bf444981939689f9a9fb038a
|
||||
#, fuzzy
|
||||
msgid ""
|
||||
"We currently support many document formats: txt, pdf, md, html, doc, ppt,"
|
||||
" and url."
|
||||
msgstr "当前支持txt, pdf, md, html, doc, ppt, url文档格式"
|
||||
|
||||
#: ../../modules/knowledge.md:15 09ec337d7da4418db854e58afb6c0980
|
||||
#: ../../modules/knowledge.md:15 9ad2f2e05f8842a9b9d8469a3704df23
|
||||
msgid "before execution:"
|
||||
msgstr "开始前"
|
||||
|
||||
#: ../../modules/knowledge.md:22 c09b3decb018485f8e56830ddc156194
|
||||
#: ../../modules/knowledge.md:22 6fd2775914b641c4b8e486417b558ea6
|
||||
msgid ""
|
||||
"2.Update your .env, set your vector store type, VECTOR_STORE_TYPE=Chroma "
|
||||
"(now only support Chroma and Milvus, if you set Milvus, please set "
|
||||
"MILVUS_URL and MILVUS_PORT)"
|
||||
msgstr ""
|
||||
|
||||
#: ../../modules/knowledge.md:25 74460ec7709441d5945ce9f745a26d20
|
||||
#: ../../modules/knowledge.md:25 131c5f58898a4682940910980edb2043
|
||||
msgid "2.Run the knowledge repository initialization command"
|
||||
msgstr ""
|
||||
|
||||
#: ../../modules/knowledge.md:31 4498ec4e46ff4e24b45dd855e829bd32
|
||||
#: ../../modules/knowledge.md:31 2cf550f17881497bb881b19efcc18c23
|
||||
msgid ""
|
||||
"Optionally, you can run `dbgpt knowledge load --help` command to see more"
|
||||
" usage."
|
||||
msgstr ""
|
||||
|
||||
#: ../../modules/knowledge.md:33 5048ac3289e540f2a2b5fd0e5ed043f5
|
||||
#: ../../modules/knowledge.md:33 c8a2ea571b944bdfbcad48fa8b54fcc9
|
||||
msgid ""
|
||||
"3.Add the knowledge repository in the interface by entering the name of "
|
||||
"your knowledge repository (if not specified, enter \"default\") so you "
|
||||
"can use it for Q&A based on your knowledge base."
|
||||
msgstr ""
|
||||
|
||||
#: ../../modules/knowledge.md:35 deeccff20f7f453dad0881b63dae2a18
|
||||
#: ../../modules/knowledge.md:35 b701170ad75e49dea7d7734c15681e0f
|
||||
msgid ""
|
||||
"Note that the default vector model used is text2vec-large-chinese (which "
|
||||
"is a large model, so if your personal computer configuration is not "
|
||||
|
pilot/base_modules/agent/db/__init__.py (new file, 0 lines)
@@ -46,6 +46,8 @@ class ComponentType(str, Enum):
    WORKER_MANAGER = "dbgpt_worker_manager"
    WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
    MODEL_CONTROLLER = "dbgpt_model_controller"
+    MODEL_REGISTRY = "dbgpt_model_registry"
+    MODEL_API_SERVER = "dbgpt_model_api_server"
    AGENT_HUB = "dbgpt_agent_hub"
    EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
    TRACER = "dbgpt_tracer"
@@ -69,7 +71,6 @@ class BaseComponent(LifeCycle, ABC):
        This method needs to be implemented by every component to define how it integrates
        with the main system app.
        """
-        pass


T = TypeVar("T", bound=BaseComponent)
@@ -91,13 +92,28 @@ class SystemApp(LifeCycle):
        """Returns the internal ASGI app."""
        return self._asgi_app

-    def register(self, component: Type[BaseComponent], *args, **kwargs):
-        """Register a new component by its type."""
+    def register(self, component: Type[BaseComponent], *args, **kwargs) -> T:
+        """Register a new component by its type.
+
+        Args:
+            component (Type[BaseComponent]): The component class to register
+
+        Returns:
+            T: The instance of registered component
+        """
        instance = component(self, *args, **kwargs)
        self.register_instance(instance)
+        return instance

-    def register_instance(self, instance: T):
-        """Register an already initialized component."""
+    def register_instance(self, instance: T) -> T:
+        """Register an already initialized component.
+
+        Args:
+            instance (T): The component instance to register
+
+        Returns:
+            T: The instance of registered component
+        """
        name = instance.name
        if isinstance(name, ComponentType):
            name = name.value
@@ -108,18 +124,34 @@ class SystemApp(LifeCycle):
        logger.info(f"Register component with name {name} and instance: {instance}")
        self.components[name] = instance
        instance.init_app(self)
+        return instance

    def get_component(
        self,
        name: Union[str, ComponentType],
        component_type: Type[T],
        default_component=_EMPTY_DEFAULT_COMPONENT,
+        or_register_component: Type[BaseComponent] = None,
+        *args,
+        **kwargs,
    ) -> T:
-        """Retrieve a registered component by its name and type."""
+        """Retrieve a registered component by its name and type.
+
+        Args:
+            name (Union[str, ComponentType]): Component name
+            component_type (Type[T]): The type of current retrieve component
+            default_component : The default component instance if not retrieve by name
+            or_register_component (Type[BaseComponent]): The new component to register if not retrieve by name
+
+        Returns:
+            T: The instance retrieved by component name
+        """
        if isinstance(name, ComponentType):
            name = name.value
        component = self.components.get(name)
        if not component:
+            if or_register_component:
+                return self.register(or_register_component, *args, **kwargs)
            if default_component != _EMPTY_DEFAULT_COMPONENT:
                return default_component
            raise ValueError(f"No component found with name {name}")
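Taken together, `register`, `register_instance`, and the new `or_register_component` argument let callers fetch-or-create a component in one call, which is how `_initialize_all` in the new api.py (below) uses them. A small usage sketch — the `MyCache` component is hypothetical, not part of this commit:

```python
# Hypothetical component for illustration; SystemApp/BaseComponent are the
# classes changed in this hunk. register() instantiates components as
# component(self, ...), so the constructor receives the SystemApp.
class MyCache(BaseComponent):
    name = "my_cache"

    def __init__(self, system_app: SystemApp = None):
        super().__init__()
        self.system_app = system_app

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

system_app = SystemApp(None)  # assuming the ASGI app is optional here
# First call: nothing registered yet, so or_register_component registers
# MyCache and returns the new instance (possible now that register() -> T).
cache = system_app.get_component("my_cache", MyCache, or_register_component=MyCache)
# Second call: returns the same registered instance.
assert cache is system_app.get_component("my_cache", MyCache)
```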
@@ -78,6 +78,10 @@ LLM_MODEL_CONFIG = {
    "internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
    "internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
    "internlm-20b": os.path.join(MODEL_PATH, "internlm-chat-20b"),
+    "codellama-7b": os.path.join(MODEL_PATH, "CodeLlama-7b-Instruct-hf"),
+    "codellama-7b-sql-sft": os.path.join(MODEL_PATH, "codellama-7b-sql-sft"),
+    "codellama-13b": os.path.join(MODEL_PATH, "CodeLlama-13b-Instruct-hf"),
+    "codellama-13b-sql-sft": os.path.join(MODEL_PATH, "codellama-13b-sql-sft"),
    # For test now
    "opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
}
@@ -320,6 +320,19 @@ class Llama2Adapter(BaseLLMAdaper):
        return model, tokenizer


+class CodeLlamaAdapter(BaseLLMAdaper):
+    """The model adapter for codellama"""
+
+    def match(self, model_path: str):
+        return "codellama" in model_path.lower()
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        model, tokenizer = super().loader(model_path, from_pretrained_kwargs)
+        model.config.eos_token_id = tokenizer.eos_token_id
+        model.config.pad_token_id = tokenizer.pad_token_id
+        return model, tokenizer
+
+
class BaichuanAdapter(BaseLLMAdaper):
    """The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-13B-Chat)"""

@@ -420,6 +433,7 @@ register_llm_model_adapters(FalconAdapater)
register_llm_model_adapters(GorillaAdapter)
register_llm_model_adapters(GPT4AllAdapter)
register_llm_model_adapters(Llama2Adapter)
+register_llm_model_adapters(CodeLlamaAdapter)
register_llm_model_adapters(BaichuanAdapter)
register_llm_model_adapters(WizardLMAdapter)
register_llm_model_adapters(LlamaCppAdapater)
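The adapter is selected by a case-insensitive substring test on the model path, so both stock `CodeLlama-*-Instruct-hf` checkpoints and the `codellama-*-sql-sft` fine-tunes added to `LLM_MODEL_CONFIG` above match. A quick illustration — a standalone restatement of the `match` logic, not an import from the repo:

```python
# Standalone restatement of CodeLlamaAdapter.match, for illustration only.
def match(model_path: str) -> bool:
    return "codellama" in model_path.lower()

assert match("/app/models/CodeLlama-13b-Instruct-hf")
assert match("/app/models/codellama-7b-sql-sft")
assert not match("/app/models/vicuna-13b-v1.5")
```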
@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-

from enum import Enum
-from typing import TypedDict, Optional, Dict, List
+from typing import TypedDict, Optional, Dict, List, Any
from dataclasses import dataclass, asdict
from datetime import datetime
from pilot.utils.parameter_utils import ParameterDescription
@@ -52,6 +52,8 @@ class ModelOutput:
    text: str
    error_code: int
    model_context: Dict = None
+    finish_reason: str = None
+    usage: Dict[str, Any] = None

    def to_dict(self) -> Dict:
        return asdict(self)
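With the two new fields, a worker can report why generation stopped and how many tokens were consumed, and `to_dict()` carries both through `asdict`. A minimal sketch, assuming `ModelOutput` is the dataclass shown in this hunk:

```python
# Minimal sketch: the usage dict follows the OpenAI-style token accounting
# that UsageInfo.parse_obj() consumes in the new api.py below.
output = ModelOutput(
    text="hello",
    error_code=0,
    finish_reason="stop",
    usage={"prompt_tokens": 5, "completion_tokens": 1, "total_tokens": 6},
)
print(output.to_dict())  # asdict() now includes finish_reason and usage
```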
@@ -8,6 +8,7 @@ from pilot.configs.model_config import LOGDIR
from pilot.model.base import WorkerApplyType
from pilot.model.parameter import (
    ModelControllerParameters,
+    ModelAPIServerParameters,
    ModelWorkerParameters,
    ModelParameters,
    BaseParameters,
@@ -441,15 +442,27 @@ def stop_model_worker(port: int):


@click.command(name="apiserver")
@EnvArgumentParser.create_click_option(ModelAPIServerParameters)
def start_apiserver(**kwargs):
-    """Start apiserver(TODO)"""
-    raise NotImplementedError
+    """Start apiserver"""
+
+    if kwargs["daemon"]:
+        log_file = os.path.join(LOGDIR, "model_apiserver_uvicorn.log")
+        _run_current_with_daemon("ModelAPIServer", log_file)
+    else:
+        from pilot.model.cluster import run_apiserver
+
+        run_apiserver()


@click.command(name="apiserver")
-def stop_apiserver(**kwargs):
-    """Start apiserver(TODO)"""
-    raise NotImplementedError
+@add_stop_server_options
+def stop_apiserver(port: int):
+    """Stop apiserver"""
+    name = "ModelAPIServer"
+    if port:
+        name = f"{name}-{port}"
+    _stop_service("apiserver", name, port=port)


def _stop_all_model_server(**kwargs):
@@ -21,6 +21,7 @@ from pilot.model.cluster.controller.controller import (
    run_model_controller,
    BaseModelController,
)
+from pilot.model.cluster.apiserver.api import run_apiserver

from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager

@@ -40,4 +41,5 @@ __all__ = [
    "ModelRegistryClient",
    "RemoteWorkerManager",
    "run_model_controller",
+    "run_apiserver",
]
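Besides `dbgpt start apiserver`, the commit also allows embedding the server in an existing FastAPI app via `initialize_apiserver` (defined in the new api.py below). A hedged sketch of embedded mode, under the signature shown in this commit:

```python
# Embedded mode: mount the OpenAI-compatible routes on an existing FastAPI
# app instead of letting run_apiserver() spawn its own uvicorn process.
from fastapi import FastAPI
from pilot.model.cluster.apiserver.api import initialize_apiserver

app = FastAPI()
# Passing `app` leaves embedded_mod=True inside initialize_apiserver, so no
# uvicorn.run() is issued; the /api/v1/* routes are simply included.
initialize_apiserver(
    "http://127.0.0.1:8000",  # controller_addr of a running cluster
    app=app,
    api_keys=["EMPTY"],
)
```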
pilot/model/cluster/apiserver/__init__.py (new file, 0 lines)
pilot/model/cluster/apiserver/api.py (new file, 443 lines)

@@ -0,0 +1,443 @@
"""A server that provides OpenAI-compatible RESTful APIs. It supports:
|
||||
- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat)
|
||||
|
||||
Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
|
||||
"""
|
||||
from typing import Optional, List, Dict, Any, Generator
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import shortuuid
|
||||
import json
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from pydantic import BaseSettings
|
||||
|
||||
from fastchat.protocol.openai_api_protocol import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatMessage,
|
||||
ChatCompletionResponseChoice,
|
||||
DeltaMessage,
|
||||
EmbeddingsRequest,
|
||||
EmbeddingsResponse,
|
||||
ErrorResponse,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
ModelPermission,
|
||||
UsageInfo,
|
||||
)
|
||||
from fastchat.protocol.api_protocol import (
|
||||
APIChatCompletionRequest,
|
||||
APITokenCheckRequest,
|
||||
APITokenCheckResponse,
|
||||
APITokenCheckResponseItem,
|
||||
)
|
||||
from fastchat.serve.openai_api_server import create_error_response, check_requests
|
||||
from fastchat.constants import ErrorCode
|
||||
|
||||
from pilot.component import BaseComponent, ComponentType, SystemApp
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||
from pilot.model.base import ModelInstance, ModelOutput
|
||||
from pilot.model.parameter import ModelAPIServerParameters, WorkerType
|
||||
from pilot.model.cluster import ModelRegistry, ModelRegistryClient
|
||||
from pilot.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
|
||||
from pilot.utils.utils import setup_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APIServerException(Exception):
|
||||
def __init__(self, code: int, message: str):
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
|
||||
class APISettings(BaseSettings):
|
||||
api_keys: Optional[List[str]] = None
|
||||
|
||||
|
||||
api_settings = APISettings()
|
||||
get_bearer_token = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def check_api_key(
|
||||
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
|
||||
) -> str:
|
||||
if api_settings.api_keys:
|
||||
if auth is None or (token := auth.credentials) not in api_settings.api_keys:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": {
|
||||
"message": "",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "invalid_api_key",
|
||||
}
|
||||
},
|
||||
)
|
||||
return token
|
||||
else:
|
||||
# api_keys not set; allow all
|
||||
return None
|
||||
|
||||
|
||||
class APIServer(BaseComponent):
    name = ComponentType.MODEL_API_SERVER

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    def get_worker_manager(self) -> WorkerManager:
        """Get the worker manager component instance

        Raises:
            APIServerException: If can't get worker manager component instance
        """
        worker_manager = self.system_app.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        if not worker_manager:
            raise APIServerException(
                ErrorCode.INTERNAL_ERROR,
                f"Could not get component {ComponentType.WORKER_MANAGER_FACTORY} from system_app",
            )
        return worker_manager

    def get_model_registry(self) -> ModelRegistry:
        """Get the model registry component instance

        Raises:
            APIServerException: If can't get model registry component instance
        """

        controller = self.system_app.get_component(
            ComponentType.MODEL_REGISTRY, ModelRegistry
        )
        if not controller:
            raise APIServerException(
                ErrorCode.INTERNAL_ERROR,
                f"Could not get component {ComponentType.MODEL_REGISTRY} from system_app",
            )
        return controller

    async def get_model_instances_or_raise(
        self, model_name: str
    ) -> List[ModelInstance]:
        """Get healthy model instances with request model name

        Args:
            model_name (str): Model name

        Raises:
            APIServerException: If can't get healthy model instances with request model name
        """
        registry = self.get_model_registry()
        registry_model_name = f"{model_name}@llm"
        model_instances = await registry.get_all_instances(
            registry_model_name, healthy_only=True
        )
        if not model_instances:
            all_instances = await registry.get_all_model_instances(healthy_only=True)
            models = [
                ins.model_name.split("@llm")[0]
                for ins in all_instances
                if ins.model_name.endswith("@llm")
            ]
            if models:
                models = "&&".join(models)
                message = f"Only {models} allowed now, your model {model_name}"
            else:
                message = f"No models allowed now, your model {model_name}"
            raise APIServerException(ErrorCode.INVALID_MODEL, message)
        return model_instances

    async def get_available_models(self) -> ModelList:
        """Return available models

        Just include LLM and embedding models.

        Returns:
            List[ModelList]: The list of models.
        """
        registry = self.get_model_registry()
        model_instances = await registry.get_all_model_instances(healthy_only=True)
        model_name_set = set()
        for inst in model_instances:
            name, worker_type = WorkerType.parse_worker_key(inst.model_name)
            if worker_type == WorkerType.LLM or worker_type == WorkerType.TEXT2VEC:
                model_name_set.add(name)
        models = list(model_name_set)
        models.sort()
        # TODO: return real model permission details
        model_cards = []
        for m in models:
            model_cards.append(
                ModelCard(
                    id=m, root=m, owned_by="DB-GPT", permission=[ModelPermission()]
                )
            )
        return ModelList(data=model_cards)

    async def chat_completion_stream_generator(
        self, model_name: str, params: Dict[str, Any], n: int
    ) -> Generator[str, Any, None]:
        """Chat stream completion generator

        Args:
            model_name (str): Model name
            params (Dict[str, Any]): The parameters pass to model worker
            n (int): How many completions to generate for each prompt.
        """
        worker_manager = self.get_worker_manager()
        id = f"chatcmpl-{shortuuid.random()}"
        finish_stream_events = []
        for i in range(n):
            # First chunk with role
            choice_data = ChatCompletionResponseStreamChoice(
                index=i,
                delta=DeltaMessage(role="assistant"),
                finish_reason=None,
            )
            chunk = ChatCompletionStreamResponse(
                id=id, choices=[choice_data], model=model_name
            )
            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"

            previous_text = ""
            async for model_output in worker_manager.generate_stream(params):
                model_output: ModelOutput = model_output
                if model_output.error_code != 0:
                    yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
                    yield "data: [DONE]\n\n"
                    return
                decoded_unicode = model_output.text.replace("\ufffd", "")
                delta_text = decoded_unicode[len(previous_text) :]
                previous_text = (
                    decoded_unicode
                    if len(decoded_unicode) > len(previous_text)
                    else previous_text
                )

                if len(delta_text) == 0:
                    delta_text = None
                choice_data = ChatCompletionResponseStreamChoice(
                    index=i,
                    delta=DeltaMessage(content=delta_text),
                    finish_reason=model_output.finish_reason,
                )
                chunk = ChatCompletionStreamResponse(
                    id=id, choices=[choice_data], model=model_name
                )
                if delta_text is None:
                    if model_output.finish_reason is not None:
                        finish_stream_events.append(chunk)
                    continue
                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
        # There is not "content" field in the last delta message, so exclude_none to exclude field "content".
        for finish_chunk in finish_stream_events:
            yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
        yield "data: [DONE]\n\n"
async def chat_completion_generate(
|
||||
self, model_name: str, params: Dict[str, Any], n: int
|
||||
) -> ChatCompletionResponse:
|
||||
"""Generate completion
|
||||
Args:
|
||||
model_name (str): Model name
|
||||
params (Dict[str, Any]): The parameters pass to model worker
|
||||
n (int): How many completions to generate for each prompt.
|
||||
"""
|
||||
worker_manager: WorkerManager = self.get_worker_manager()
|
||||
choices = []
|
||||
chat_completions = []
|
||||
for i in range(n):
|
||||
model_output = asyncio.create_task(worker_manager.generate(params))
|
||||
chat_completions.append(model_output)
|
||||
try:
|
||||
all_tasks = await asyncio.gather(*chat_completions)
|
||||
except Exception as e:
|
||||
return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
|
||||
usage = UsageInfo()
|
||||
for i, model_output in enumerate(all_tasks):
|
||||
model_output: ModelOutput = model_output
|
||||
if model_output.error_code != 0:
|
||||
return create_error_response(model_output.error_code, model_output.text)
|
||||
choices.append(
|
||||
ChatCompletionResponseChoice(
|
||||
index=i,
|
||||
message=ChatMessage(role="assistant", content=model_output.text),
|
||||
finish_reason=model_output.finish_reason or "stop",
|
||||
)
|
||||
)
|
||||
if model_output.usage:
|
||||
task_usage = UsageInfo.parse_obj(model_output.usage)
|
||||
for usage_key, usage_value in task_usage.dict().items():
|
||||
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
|
||||
|
||||
return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
|
||||
|
||||
|
||||
def get_api_server() -> APIServer:
|
||||
api_server = global_system_app.get_component(
|
||||
ComponentType.MODEL_API_SERVER, APIServer, default_component=None
|
||||
)
|
||||
if not api_server:
|
||||
global_system_app.register(APIServer)
|
||||
return global_system_app.get_component(ComponentType.MODEL_API_SERVER, APIServer)
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
|
||||
async def get_available_models(api_server: APIServer = Depends(get_api_server)):
|
||||
return await api_server.get_available_models()
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
async def create_chat_completion(
    request: APIChatCompletionRequest, api_server: APIServer = Depends(get_api_server)
):
    await api_server.get_model_instances_or_raise(request.model)
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret
    params = {
        "model": request.model,
        "messages": ModelMessage.to_dict_list(
            ModelMessage.from_openai_messages(request.messages)
        ),
        "echo": False,
    }
    if request.temperature:
        params["temperature"] = request.temperature
    if request.top_p:
        params["top_p"] = request.top_p
    if request.max_tokens:
        params["max_new_tokens"] = request.max_tokens
    if request.stop:
        params["stop"] = request.stop
    if request.user:
        params["user"] = request.user

    # TODO check token length
    if request.stream:
        generator = api_server.chat_completion_stream_generator(
            request.model, params, request.n
        )
        return StreamingResponse(generator, media_type="text/event-stream")
    return await api_server.chat_completion_generate(request.model, params, request.n)

def _initialize_all(controller_addr: str, system_app: SystemApp):
    from pilot.model.cluster import RemoteWorkerManager, ModelRegistryClient
    from pilot.model.cluster.worker.manager import _DefaultWorkerManagerFactory

    if not system_app.get_component(
        ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None
    ):
        # Register the model registry if it does not exist yet
        registry = ModelRegistryClient(controller_addr)
        registry.name = ComponentType.MODEL_REGISTRY.value
        system_app.register_instance(registry)

    registry = system_app.get_component(
        ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None
    )
    worker_manager = RemoteWorkerManager(registry)

    # Register the worker manager component if it does not exist yet
    system_app.get_component(
        ComponentType.WORKER_MANAGER_FACTORY,
        WorkerManagerFactory,
        or_register_component=_DefaultWorkerManagerFactory,
        worker_manager=worker_manager,
    )
    # Register the api server component if it does not exist yet
    system_app.get_component(
        ComponentType.MODEL_API_SERVER, APIServer, or_register_component=APIServer
    )

def initialize_apiserver(
    controller_addr: str,
    app=None,
    system_app: SystemApp = None,
    host: str = None,
    port: int = None,
    api_keys: List[str] = None,
):
    global global_system_app
    global api_settings
    embedded_mod = True
    if not app:
        embedded_mod = False
        app = FastAPI()
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
            allow_headers=["*"],
        )

    if not system_app:
        system_app = SystemApp(app)
    global_system_app = system_app

    if api_keys:
        api_settings.api_keys = api_keys

    app.include_router(router, prefix="/api", tags=["APIServer"])

    @app.exception_handler(APIServerException)
    async def validation_apiserver_exception_handler(request, exc: APIServerException):
        return create_error_response(exc.code, exc.message)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request, exc):
        return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))

    _initialize_all(controller_addr, system_app)

    if not embedded_mod:
        import uvicorn

        uvicorn.run(app, host=host, port=port, log_level="info")

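# Illustrative sketch of the two modes above (arguments are examples): embedded
# mode mounts the router into an existing FastAPI app; standalone mode creates
# its own app and starts uvicorn itself:
#
#   # Embedded (caller owns the server lifecycle):
#   initialize_apiserver("http://127.0.0.1:8000", app=my_app, system_app=my_system_app)
#
#   # Standalone (blocks and serves):
#   initialize_apiserver("http://127.0.0.1:8000", host="0.0.0.0", port=8100)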
def run_apiserver():
    parser = EnvArgumentParser()
    env_prefix = "apiserver_"
    apiserver_params: ModelAPIServerParameters = parser.parse_args_into_dataclass(
        ModelAPIServerParameters,
        env_prefixes=[env_prefix],
    )
    setup_logging(
        "pilot",
        logging_level=apiserver_params.log_level,
        logger_filename=apiserver_params.log_file,
    )
    api_keys = None
    if apiserver_params.api_keys:
        api_keys = apiserver_params.api_keys.strip().split(",")

    initialize_apiserver(
        apiserver_params.controller_addr,
        host=apiserver_params.host,
        port=apiserver_params.port,
        api_keys=api_keys,
    )

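# Illustrative sketch (module path assumed): with the "apiserver_" env prefix
# above, configuration can also come from environment variables, e.g.
#
#   apiserver_host=0.0.0.0 apiserver_port=8100 \
#   apiserver_controller_addr=http://127.0.0.1:8000 \
#   apiserver_api_keys=abc,xyz \
#   python pilot/model/cluster/apiserver/api.py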
if __name__ == "__main__":
    run_apiserver()
0
pilot/model/cluster/apiserver/tests/__init__.py
Normal file
248
pilot/model/cluster/apiserver/tests/test_api.py
Normal file
@@ -0,0 +1,248 @@
import pytest
import pytest_asyncio
from aioresponses import aioresponses
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from httpx import AsyncClient, HTTPError

from pilot.component import SystemApp
from pilot.utils.openai_utils import chat_completion_stream, chat_completion

from pilot.model.cluster.apiserver.api import (
    api_settings,
    initialize_apiserver,
    ModelList,
    UsageInfo,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    ChatCompletionResponseChoice,
    DeltaMessage,
)
from pilot.model.cluster.tests.conftest import _new_cluster

from pilot.model.cluster.worker.manager import _DefaultWorkerManagerFactory

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
    allow_headers=["*"],
)

@pytest_asyncio.fixture
async def system_app():
    return SystemApp(app)


@pytest_asyncio.fixture
async def client(request, system_app: SystemApp):
    param = getattr(request, "param", {})
    api_keys = param.get("api_keys", [])
    client_api_key = param.get("client_api_key")
    if "num_workers" not in param:
        param["num_workers"] = 2
    if "api_keys" in param:
        del param["api_keys"]
    headers = {}
    if client_api_key:
        headers["Authorization"] = "Bearer " + client_api_key
    print(f"param: {param}")
    if api_settings:
        # Clear global api keys
        api_settings.api_keys = []
    async with AsyncClient(app=app, base_url="http://test", headers=headers) as client:
        async with _new_cluster(**param) as cluster:
            worker_manager, model_registry = cluster
            system_app.register(_DefaultWorkerManagerFactory, worker_manager)
            system_app.register_instance(model_registry)
            # print(f"Instances {model_registry.registry}")
            initialize_apiserver(None, app, system_app, api_keys=api_keys)
            yield client

@pytest.mark.asyncio
async def test_get_all_models(client: AsyncClient):
    res = await client.get("/api/v1/models")
    assert res.status_code == 200
    model_lists = ModelList.parse_obj(res.json())
    print(f"model list json: {res.json()}")
    assert model_lists.object == "list"
    assert len(model_lists.data) == 2

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, expected_messages",
    [
        ({"stream_messags": ["Hello", " world."]}, "Hello world."),
        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
    ],
    indirect=["client"],
)
async def test_chat_completions(client: AsyncClient, expected_messages):
    chat_data = {
        "model": "test-model-name-0",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    }
    full_text = ""
    async for text in chat_completion_stream(
        "/api/v1/chat/completions", chat_data, client
    ):
        full_text += text
    assert full_text == expected_messages

    assert (
        await chat_completion("/api/v1/chat/completions", chat_data, client)
        == expected_messages
    )

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, expected_messages, client_api_key",
    [
        (
            {"stream_messags": ["Hello", " world."], "api_keys": ["abc"]},
            "Hello world.",
            "abc",
        ),
        ({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
    ],
    indirect=["client"],
)
async def test_chat_completions_with_openai_lib_async_no_stream(
    client: AsyncClient, expected_messages: str, client_api_key: str
):
    import openai

    openai.api_key = client_api_key
    openai.api_base = "http://test/api/v1"

    model_name = "test-model-name-0"

    with aioresponses() as mocked:
        mock_message = {"text": expected_messages}
        one_res = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role="assistant", content=expected_messages),
            finish_reason="stop",
        )
        data = ChatCompletionResponse(
            model=model_name, choices=[one_res], usage=UsageInfo()
        )
        mock_message = f"{data.json(exclude_unset=True, ensure_ascii=False)}\n\n"
        # Mock http request
        mocked.post(
            "http://test/api/v1/chat/completions", status=200, body=mock_message
        )
        completion = await openai.ChatCompletion.acreate(
            model=model_name,
            messages=[{"role": "user", "content": "Hello! What is your name?"}],
        )
        assert completion.choices[0].message.content == expected_messages

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, expected_messages, client_api_key",
    [
        (
            {"stream_messags": ["Hello", " world."], "api_keys": ["abc"]},
            "Hello world.",
            "abc",
        ),
        ({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
    ],
    indirect=["client"],
)
async def test_chat_completions_with_openai_lib_async_stream(
    client: AsyncClient, expected_messages: str, client_api_key: str
):
    import openai

    openai.api_key = client_api_key
    openai.api_base = "http://test/api/v1"

    model_name = "test-model-name-0"

    with aioresponses() as mocked:
        mock_message = {"text": expected_messages}
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=expected_messages),
            finish_reason="stop",
        )
        chunk = ChatCompletionStreamResponse(
            id=0, choices=[choice_data], model=model_name
        )
        mock_message = f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
        mocked.post(
            "http://test/api/v1/chat/completions",
            status=200,
            body=mock_message,
            content_type="text/event-stream",
        )

        stream_stream_resp = ""
        async for stream_resp in await openai.ChatCompletion.acreate(
            model=model_name,
            messages=[{"role": "user", "content": "Hello! What is your name?"}],
            stream=True,
        ):
            stream_stream_resp = stream_resp.choices[0]["delta"].get("content", "")
        assert stream_stream_resp == expected_messages

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, expected_messages, api_key_is_error",
    [
        (
            {
                "stream_messags": ["Hello", " world."],
                "api_keys": ["abc", "xx"],
                "client_api_key": "abc",
            },
            "Hello world.",
            False,
        ),
        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。", False),
        (
            {"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc", "xx"]},
            "你好,我是张三。",
            True,
        ),
        (
            {
                "stream_messags": ["你好,我是", "张三。"],
                "api_keys": ["abc", "xx"],
                "client_api_key": "error_api_key",
            },
            "你好,我是张三。",
            True,
        ),
    ],
    indirect=["client"],
)
async def test_chat_completions_with_api_keys(
    client: AsyncClient, expected_messages: str, api_key_is_error: bool
):
    chat_data = {
        "model": "test-model-name-0",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    }
    if api_key_is_error:
        with pytest.raises(HTTPError):
            await chat_completion("/api/v1/chat/completions", chat_data, client)
    else:
        assert (
            await chat_completion("/api/v1/chat/completions", chat_data, client)
            == expected_messages
        )
@@ -66,7 +66,9 @@ class LocalModelController(BaseModelController):
             f"Get all instances with {model_name}, healthy_only: {healthy_only}"
         )
         if not model_name:
-            return await self.registry.get_all_model_instances()
+            return await self.registry.get_all_model_instances(
+                healthy_only=healthy_only
+            )
         else:
             return await self.registry.get_all_instances(model_name, healthy_only)

@@ -98,8 +100,10 @@ class _RemoteModelController(BaseModelController):


 class ModelRegistryClient(_RemoteModelController, ModelRegistry):
-    async def get_all_model_instances(self) -> List[ModelInstance]:
-        return await self.get_all_instances()
+    async def get_all_model_instances(
+        self, healthy_only: bool = False
+    ) -> List[ModelInstance]:
+        return await self.get_all_instances(healthy_only=healthy_only)

     @sync_api_remote(path="/api/controller/models")
     def sync_get_all_instances(
@@ -1,22 +1,37 @@
 import random
 import threading
 import time
+import logging
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from datetime import datetime, timedelta
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 import itertools

 from pilot.component import BaseComponent, ComponentType, SystemApp
 from pilot.model.base import ModelInstance


-class ModelRegistry(ABC):
+logger = logging.getLogger(__name__)
+
+
+class ModelRegistry(BaseComponent, ABC):
     """
     Abstract base class for a model registry. It provides an interface
     for registering, deregistering, fetching instances, and sending heartbeats
     for instances.
     """

+    name = ComponentType.MODEL_REGISTRY
+
+    def __init__(self, system_app: SystemApp | None = None):
+        self.system_app = system_app
+        super().__init__(system_app)
+
+    def init_app(self, system_app: SystemApp):
+        """Initialize the component with the main application."""
+        self.system_app = system_app
+
     @abstractmethod
     async def register_instance(self, instance: ModelInstance) -> bool:
         """
@@ -65,9 +80,11 @@ class ModelRegistry(ABC):
         """Fetch all instances of a given model. Optionally, fetch only the healthy instances."""

     @abstractmethod
-    async def get_all_model_instances(self) -> List[ModelInstance]:
+    async def get_all_model_instances(
+        self, healthy_only: bool = False
+    ) -> List[ModelInstance]:
         """
-        Fetch all instances of all models
+        Fetch all instances of all models. Optionally, fetch only the healthy instances.

         Returns:
         - List[ModelInstance]: A list of instances for all models.
@@ -105,8 +122,12 @@ class ModelRegistry(ABC):

 class EmbeddedModelRegistry(ModelRegistry):
     def __init__(
-        self, heartbeat_interval_secs: int = 60, heartbeat_timeout_secs: int = 120
+        self,
+        system_app: SystemApp | None = None,
+        heartbeat_interval_secs: int = 60,
+        heartbeat_timeout_secs: int = 120,
     ):
+        super().__init__(system_app)
         self.registry: Dict[str, List[ModelInstance]] = defaultdict(list)
         self.heartbeat_interval_secs = heartbeat_interval_secs
         self.heartbeat_timeout_secs = heartbeat_timeout_secs
@@ -180,9 +201,14 @@ class EmbeddedModelRegistry(ModelRegistry):
             instances = [ins for ins in instances if ins.healthy == True]
         return instances

-    async def get_all_model_instances(self) -> List[ModelInstance]:
-        print(self.registry)
-        return list(itertools.chain(*self.registry.values()))
+    async def get_all_model_instances(
+        self, healthy_only: bool = False
+    ) -> List[ModelInstance]:
+        logger.debug(f"Current registry metadata:\n{self.registry}")
+        instances = list(itertools.chain(*self.registry.values()))
+        if healthy_only:
+            instances = [ins for ins in instances if ins.healthy == True]
+        return instances

     async def send_heartbeat(self, instance: ModelInstance) -> bool:
         _, exist_ins = self._get_instances(
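# Behavior sketch for the updated method above (illustrative, not part of this
# commit): with two registered instances, one of them unhealthy,
# get_all_model_instances(healthy_only=True) returns only the healthy one,
# while get_all_model_instances() returns both.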
0
pilot/model/cluster/tests/__init__.py
Normal file
@@ -6,6 +6,7 @@ from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
 from pilot.model.base import ModelOutput
 from pilot.model.cluster.worker_base import ModelWorker
 from pilot.model.cluster.worker.manager import (
+    WorkerManager,
     LocalWorkerManager,
     RegisterFunc,
     DeregisterFunc,
@@ -13,6 +14,23 @@ from pilot.model.cluster.worker.manager import (
     ApplyFunction,
 )

+from pilot.model.base import ModelInstance
+from pilot.model.cluster.registry import ModelRegistry, EmbeddedModelRegistry
+
+
+@pytest.fixture
+def model_registry(request):
+    return EmbeddedModelRegistry()
+
+
+@pytest.fixture
+def model_instance():
+    return ModelInstance(
+        model_name="test_model",
+        host="192.168.1.1",
+        port=5000,
+    )
+
+
 class MockModelWorker(ModelWorker):
     def __init__(
@@ -51,8 +69,10 @@ class MockModelWorker(ModelWorker):
         raise Exception("Stop worker error for mock")

     def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
+        full_text = ""
         for msg in self.stream_messags:
-            yield ModelOutput(text=msg, error_code=0)
+            full_text += msg
+            yield ModelOutput(text=full_text, error_code=0)

     def generate(self, params: Dict) -> ModelOutput:
         output = None
@@ -67,6 +87,8 @@ class MockModelWorker(ModelWorker):
 _TEST_MODEL_NAME = "vicuna-13b-v1.5"
 _TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5"

+ClusterType = Tuple[WorkerManager, ModelRegistry]
+

 def _new_worker_params(
     model_name: str = _TEST_MODEL_NAME,
@@ -85,7 +107,9 @@ def _create_workers(
     worker_type: str = WorkerType.LLM.value,
     stream_messags: List[str] = None,
     embeddings: List[List[float]] = None,
-) -> List[Tuple[ModelWorker, ModelWorkerParameters]]:
+    host: str = "127.0.0.1",
+    start_port=8001,
+) -> List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]]:
     workers = []
     for i in range(num_workers):
         model_name = f"test-model-name-{i}"
@@ -98,10 +122,16 @@ def _create_workers(
             stream_messags=stream_messags,
             embeddings=embeddings,
         )
+        model_instance = ModelInstance(
+            model_name=WorkerType.to_worker_key(model_name, worker_type),
+            host=host,
+            port=start_port + i,
+            healthy=True,
+        )
         worker_params = _new_worker_params(
             model_name, model_path, worker_type=worker_type
         )
-        workers.append((worker, worker_params))
+        workers.append((worker, worker_params, model_instance))
     return workers

@@ -127,12 +157,12 @@ async def _start_worker_manager(**kwargs):
         model_registry=model_registry,
     )

-    for worker, worker_params in _create_workers(
+    for worker, worker_params, model_instance in _create_workers(
         num_workers, error_worker, stop_error, stream_messags, embeddings
     ):
         worker_manager.add_worker(worker, worker_params)
     if workers:
-        for worker, worker_params in workers:
+        for worker, worker_params, model_instance in workers:
             worker_manager.add_worker(worker, worker_params)

     if start:
@@ -143,6 +173,15 @@ async def _start_worker_manager(**kwargs):
         await worker_manager.stop()


+async def _create_model_registry(
+    workers: List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]]
+) -> ModelRegistry:
+    registry = EmbeddedModelRegistry()
+    for _, _, inst in workers:
+        assert await registry.register_instance(inst) == True
+    return registry
+
+
 @pytest_asyncio.fixture
 async def manager_2_workers(request):
     param = getattr(request, "param", {})
@@ -166,3 +205,27 @@ async def manager_2_embedding_workers(request):
     )
     async with _start_worker_manager(workers=workers, **param) as worker_manager:
         yield (worker_manager, workers)
+
+
+@asynccontextmanager
+async def _new_cluster(**kwargs) -> ClusterType:
+    num_workers = kwargs.get("num_workers", 0)
+    workers = _create_workers(
+        num_workers, stream_messags=kwargs.get("stream_messags", [])
+    )
+    if "num_workers" in kwargs:
+        del kwargs["num_workers"]
+    registry = await _create_model_registry(
+        workers,
+    )
+    async with _start_worker_manager(workers=workers, **kwargs) as worker_manager:
+        yield (worker_manager, registry)
+
+
+@pytest_asyncio.fixture
+async def cluster_2_workers(request):
+    param = getattr(request, "param", {})
+    workers = _create_workers(2)
+    registry = await _create_model_registry(workers)
+    async with _start_worker_manager(workers=workers, **param) as worker_manager:
+        yield (worker_manager, registry)
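# Usage sketch (assumed test shape, not part of this commit): the new fixture
# hands back both halves of the cluster:
#
#   @pytest.mark.asyncio
#   async def test_cluster(cluster_2_workers):
#       worker_manager, registry = cluster_2_workers
#       instances = await registry.get_all_model_instances(healthy_only=True)
#       assert len(instances) == 2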
@@ -256,15 +256,22 @@ class DefaultModelWorker(ModelWorker):
         return params, model_context, generate_stream_func, model_span

     def _handle_output(self, output, previous_response, model_context):
+        finish_reason = None
+        usage = None
         if isinstance(output, dict):
             finish_reason = output.get("finish_reason")
+            usage = output.get("usage")
             output = output["text"]
             if finish_reason is not None:
                 logger.info(f"finish_reason: {finish_reason}")
         incremental_output = output[len(previous_response) :]
         print(incremental_output, end="", flush=True)
         model_output = ModelOutput(
-            text=output, error_code=0, model_context=model_context
+            text=output,
+            error_code=0,
+            model_context=model_context,
+            finish_reason=finish_reason,
+            usage=usage,
         )
         return model_output, incremental_output, output

@@ -99,9 +99,7 @@ class LocalWorkerManager(WorkerManager):
         )

     def _worker_key(self, worker_type: str, model_name: str) -> str:
-        if isinstance(worker_type, WorkerType):
-            worker_type = worker_type.value
-        return f"{model_name}@{worker_type}"
+        return WorkerType.to_worker_key(model_name, worker_type)

     async def run_blocking_func(self, func, *args):
         if asyncio.iscoroutinefunction(func):
@@ -3,7 +3,7 @@ import pytest
 from typing import List, Iterator, Dict, Tuple
 from dataclasses import asdict
 from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
-from pilot.model.base import ModelOutput, WorkerApplyType
+from pilot.model.base import ModelOutput, WorkerApplyType, ModelInstance
 from pilot.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest
 from pilot.model.cluster.worker_base import ModelWorker
 from pilot.model.cluster.manager_base import WorkerRunData
@@ -14,7 +14,7 @@ from pilot.model.cluster.worker.manager import (
     SendHeartbeatFunc,
     ApplyFunction,
 )
-from pilot.model.cluster.worker.tests.base_tests import (
+from pilot.model.cluster.tests.conftest import (
     MockModelWorker,
     manager_2_workers,
     manager_with_2_workers,
@@ -216,7 +216,7 @@ async def test__remove_worker():
     workers = _create_workers(3)
     async with _start_worker_manager(workers=workers, stop=False) as manager:
         assert len(manager.workers) == 3
-        for _, worker_params in workers:
+        for _, worker_params, _ in workers:
             manager._remove_worker(worker_params)
         not_exist_parmas = _new_worker_params(
             model_name="this is a not exist worker params"
@@ -229,7 +229,7 @@ async def test__remove_worker():
 async def test_model_startup(mock_build_worker):
     async with _start_worker_manager() as manager:
         workers = _create_workers(1)
-        worker, worker_params = workers[0]
+        worker, worker_params, model_instance = workers[0]
         mock_build_worker.return_value = worker

         req = WorkerStartupRequest(
@@ -245,7 +245,7 @@ async def test_model_startup(mock_build_worker):

     async with _start_worker_manager() as manager:
         workers = _create_workers(1, error_worker=True)
-        worker, worker_params = workers[0]
+        worker, worker_params, model_instance = workers[0]
         mock_build_worker.return_value = worker
         req = WorkerStartupRequest(
             host="127.0.0.1",
@@ -263,7 +263,7 @@ async def test_model_startup(mock_build_worker):
 async def test_model_shutdown(mock_build_worker):
     async with _start_worker_manager(start=False, stop=False) as manager:
         workers = _create_workers(1)
-        worker, worker_params = workers[0]
+        worker, worker_params, model_instance = workers[0]
         mock_build_worker.return_value = worker

         req = WorkerStartupRequest(
@@ -298,7 +298,7 @@ async def test_get_model_instances(is_async):
     workers = _create_workers(3)
     async with _start_worker_manager(workers=workers, stop=False) as manager:
         assert len(manager.workers) == 3
-        for _, worker_params in workers:
+        for _, worker_params, _ in workers:
             model_name = worker_params.model_name
             worker_type = worker_params.worker_type
             if is_async:
@@ -326,7 +326,7 @@ async def test__simple_select(
     ]
 ):
     manager, workers = manager_with_2_workers
-    for _, worker_params in workers:
+    for _, worker_params, _ in workers:
         model_name = worker_params.model_name
         worker_type = worker_params.worker_type
         instances = await manager.get_model_instances(worker_type, model_name)
@@ -351,7 +351,7 @@ async def test_select_one_instance(
     ],
 ):
     manager, workers = manager_with_2_workers
-    for _, worker_params in workers:
+    for _, worker_params, _ in workers:
         model_name = worker_params.model_name
         worker_type = worker_params.worker_type
         if is_async:
@@ -376,7 +376,7 @@ async def test__get_model(
     ],
 ):
     manager, workers = manager_with_2_workers
-    for _, worker_params in workers:
+    for _, worker_params, _ in workers:
         model_name = worker_params.model_name
         worker_type = worker_params.worker_type
         params = {"model": model_name}
@@ -403,13 +403,13 @@ async def test_generate_stream(
     expected_messages: str,
 ):
     manager, workers = manager_with_2_workers
-    for _, worker_params in workers:
+    for _, worker_params, _ in workers:
         model_name = worker_params.model_name
         worker_type = worker_params.worker_type
         params = {"model": model_name}
         text = ""
         async for out in manager.generate_stream(params):
-            text += out.text
+            text = out.text
         assert text == expected_messages


@@ -417,8 +417,8 @@ async def test_generate_stream(
 @pytest.mark.parametrize(
     "manager_with_2_workers, expected_messages",
     [
-        ({"stream_messags": ["Hello", " world."]}, " world."),
-        ({"stream_messags": ["你好,我是", "张三。"]}, "张三。"),
+        ({"stream_messags": ["Hello", " world."]}, "Hello world."),
+        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
     ],
     indirect=["manager_with_2_workers"],
 )
@@ -429,7 +429,7 @@ async def test_generate(
     expected_messages: str,
 ):
     manager, workers = manager_with_2_workers
-    for _, worker_params in workers:
+    for _, worker_params, _ in workers:
         model_name = worker_params.model_name
         worker_type = worker_params.worker_type
         params = {"model": model_name}
@@ -454,7 +454,7 @@ async def test_embeddings(
     is_async: bool,
 ):
     manager, workers = manager_2_embedding_workers
-    for _, worker_params in workers:
+    for _, worker_params, _ in workers:
         model_name = worker_params.model_name
         worker_type = worker_params.worker_type
         params = {"model": model_name, "input": ["hello", "world"]}
@@ -472,7 +472,7 @@ async def test_parameter_descriptions(
     ]
 ):
     manager, workers = manager_with_2_workers
-    for _, worker_params in workers:
+    for _, worker_params, _ in workers:
         model_name = worker_params.model_name
         worker_type = worker_params.worker_type
         params = await manager.parameter_descriptions(worker_type, model_name)

@@ -339,6 +339,27 @@ register_conv_template(
     )
 )


+# codellama template
+# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
+# reference 2: https://github.com/eosphoros-ai/DB-GPT-Hub/blob/main/README.zh.md
+register_conv_template(
+    Conversation(
+        name="codellama",
+        system="<s>[INST] <<SYS>>\nI want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me. Below is an instruction that describes a task. Write a response that appropriately completes the request."
+        "If you don't know the answer to the request, please don't share false information.\n<</SYS>>\n\n",
+        roles=("[INST]", "[/INST]"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.LLAMA2,
+        sep=" ",
+        sep2=" </s><s>",
+        stop_token_ids=[2],
+        system_formatter=lambda msg: f"<s>[INST] <<SYS>>\n{msg}\n<</SYS>>\n\n",
+    )
+)
+
+
 # Alpaca default template
 register_conv_template(
     Conversation(
@@ -45,6 +45,10 @@ _OLD_MODELS = [
     "llama-cpp",
     "proxyllm",
     "gptj-6b",
+    "codellama-13b-sql-sft",
+    "codellama-7b",
+    "codellama-7b-sql-sft",
+    "codellama-13b",
 ]

@@ -148,8 +152,12 @@ class LLMModelAdaper:
                 conv.append_message(conv.roles[1], content)
             else:
                 raise ValueError(f"Unknown role: {role}")

         if system_messages:
-            conv.set_system_message("".join(system_messages))
+            if isinstance(conv, Conversation):
+                conv.set_system_message("".join(system_messages))
+            else:
+                conv.update_system_message("".join(system_messages))

         # Add a blank message for the assistant.
         conv.append_message(conv.roles[1], None)
@@ -459,7 +467,8 @@ register_conv_template(
         sep="\n",
         sep2="</s>",
         stop_str=["</s>", "[UNK]"],
-    )
-)
+    ),
+    override=True,
+)
 # source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
 register_conv_template(
@@ -474,7 +483,8 @@ register_conv_template(
         sep="###",
         sep2="</s>",
         stop_str=["</s>", "[UNK]"],
-    )
-)
+    ),
+    override=True,
+)
 # source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
 register_conv_template(
@@ -487,5 +497,6 @@ register_conv_template(
         sep="",
         sep2="</s>",
         stop_str=["</s>", "<|endoftext|>"],
-    )
-)
+    ),
+    override=True,
+)
@@ -1,9 +1,10 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-

 import os
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Dict, Optional
+from typing import Dict, Optional, Union, Tuple

 from pilot.model.conversation import conv_templates
 from pilot.utils.parameter_utils import BaseParameters
@@ -19,6 +20,35 @@ class WorkerType(str, Enum):
     def values():
         return [item.value for item in WorkerType]

+    @staticmethod
+    def to_worker_key(worker_name, worker_type: Union[str, "WorkerType"]) -> str:
+        """Generate worker key from worker name and worker type
+
+        Args:
+            worker_name (str): Worker name (e.g., chatglm2-6b)
+            worker_type (Union[str, "WorkerType"]): Worker type (e.g., 'llm', or [`WorkerType.LLM`])
+
+        Returns:
+            str: Generated worker key
+        """
+        if "@" in worker_name:
+            raise ValueError(f"Invalid symbol '@' in your worker name {worker_name}")
+        if isinstance(worker_type, WorkerType):
+            worker_type = worker_type.value
+        return f"{worker_name}@{worker_type}"
+
+    @staticmethod
+    def parse_worker_key(worker_key: str) -> Tuple[str, str]:
+        """Parse worker name and worker type from worker key
+
+        Args:
+            worker_key (str): Worker key generated by [`WorkerType.to_worker_key`]
+
+        Returns:
+            Tuple[str, str]: Worker name and worker type
+        """
+        return tuple(worker_key.split("@"))
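# Illustrative round trip for the two helpers above (values are examples):
#
#   key = WorkerType.to_worker_key("chatglm2-6b", WorkerType.LLM)  # "chatglm2-6b@llm"
#   name, worker_type = WorkerType.parse_worker_key(key)           # ("chatglm2-6b", "llm")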

 @dataclass
 class ModelControllerParameters(BaseParameters):
@@ -60,6 +90,56 @@ class ModelControllerParameters(BaseParameters):
     )


+@dataclass
+class ModelAPIServerParameters(BaseParameters):
+    host: Optional[str] = field(
+        default="0.0.0.0", metadata={"help": "Model API server deploy host"}
+    )
+    port: Optional[int] = field(
+        default=8100, metadata={"help": "Model API server deploy port"}
+    )
+    daemon: Optional[bool] = field(
+        default=False, metadata={"help": "Run Model API server in background"}
+    )
+    controller_addr: Optional[str] = field(
+        default="http://127.0.0.1:8000",
+        metadata={"help": "The Model controller address to connect"},
+    )
+
+    api_keys: Optional[str] = field(
+        default=None,
+        metadata={"help": "Optional list of comma separated API keys"},
+    )
+
+    log_level: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Logging level",
+            "valid_values": [
+                "FATAL",
+                "ERROR",
+                "WARNING",
+                "INFO",
+                "DEBUG",
+                "NOTSET",
+            ],
+        },
+    )
+    log_file: Optional[str] = field(
+        default="dbgpt_model_apiserver.log",
+        metadata={
+            "help": "The filename to store log",
+        },
+    )
+    tracer_file: Optional[str] = field(
+        default="dbgpt_model_apiserver_tracer.jsonl",
+        metadata={
+            "help": "The filename to store tracer span records",
+        },
+    )
+
+
 @dataclass
 class BaseModelParameters(BaseParameters):
     model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Tuple, Optional
+from typing import Any, Dict, List, Tuple, Optional, Union

 from pydantic import BaseModel, Field, root_validator

@@ -70,14 +70,6 @@ class SystemMessage(BaseMessage):
         return "system"


-class ModelMessage(BaseModel):
-    """Type of message that interaction between dbgpt-server and llm-server"""
-
-    """Similar to openai's message format"""
-
-    role: str
-    content: str
-
-
 class ModelMessageRoleType:
     """Type of ModelMessage role"""

@@ -87,6 +79,45 @@ class ModelMessageRoleType:
     VIEW = "view"


+class ModelMessage(BaseModel):
+    """Type of message for interaction between dbgpt-server and llm-server.
+    Similar to OpenAI's message format."""
+
+    role: str
+    content: str
+
+    @staticmethod
+    def from_openai_messages(
+        messages: Union[str, List[Dict[str, str]]]
+    ) -> List["ModelMessage"]:
+        """Convert OpenAI message format to the current ModelMessage format"""
+        if isinstance(messages, str):
+            return [ModelMessage(role=ModelMessageRoleType.HUMAN, content=messages)]
+        result = []
+        for message in messages:
+            msg_role = message["role"]
+            content = message["content"]
+            if msg_role == "system":
+                result.append(
+                    ModelMessage(role=ModelMessageRoleType.SYSTEM, content=content)
+                )
+            elif msg_role == "user":
+                result.append(
+                    ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
+                )
+            elif msg_role == "assistant":
+                result.append(
+                    ModelMessage(role=ModelMessageRoleType.AI, content=content)
+                )
+            else:
+                raise ValueError(f"Unknown role: {msg_role}")
+        return result
+
+    @staticmethod
+    def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
+        return list(map(lambda m: m.dict(), messages))
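# Illustrative sketch of the conversion above (not part of this commit):
#
#   messages = [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hello"},
#   ]
#   model_messages = ModelMessage.from_openai_messages(messages)
#   # roles map to ModelMessageRoleType: "system" -> "system", "user" -> "human"
#   assert ModelMessage.to_dict_list(model_messages)[1]["content"] == "Hello"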


 class Generation(BaseModel):
     """Output of a single generation."""

@@ -1,3 +1,4 @@
+import json
 import os
 from typing import Dict, List

@@ -110,12 +111,14 @@ class ChatKnowledge(BaseChat):
             list(map(lambda doc: doc.metadata, docs)), "source"
         )

         self.current_message.knowledge_source = self.sources
-        if not docs:
-            raise ValueError(
-                "you have no knowledge space, please add your knowledge space"
-            )
-        context = [d.page_content for d in docs]
+        if not docs or len(docs) == 0:
+            print("no relevant docs to retrieve")
+            context = "no relevant docs to retrieve"
+            # raise ValueError(
+            #     "you have no knowledge space, please add your knowledge space"
+            # )
+        else:
+            context = [d.page_content for d in docs]
         context = context[: self.max_token]
         relations = list(
             set([os.path.basename(str(d.metadata.get("source", ""))) for d in docs])
@@ -128,17 +131,26 @@ class ChatKnowledge(BaseChat):
         return input_values

     def parse_source_view(self, sources: List):
-        html_title = f"##### **References:**"
-        lines = ""
+        """
+        Build the knowledge reference view message for the web UI, e.g.
+        {
+            "title":"References",
+            "reference":{
+                "name":"aa.pdf",
+                "pages":["1","2","3"]
+            },
+        }
+        """
+        references = {"title": "References", "reference": {}}
         for item in sources:
             source = item["source"] if "source" in item else ""
-            pages = ",".join(item["pages"]) if "pages" in item else ""
-            lines += f"{source}"
-            if len(pages) > 0:
-                lines += f", **pages**:{pages}\n\n"
-            else:
-                lines += "\n\n"
-        html = f"""{html_title}\n{lines}"""
+            references["reference"]["name"] = source
+            pages = item["pages"] if "pages" in item else []
+            references["reference"]["pages"] = pages
+        html = (
+            f"""<references>{json.dumps(references, ensure_ascii=False)}</references>"""
+        )
         return html
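# Illustrative output of the rewritten parse_source_view above, for
# sources=[{"source": "aa.pdf", "pages": ["1", "2"]}]:
#
#   <references>{"title": "References", "reference": {"name": "aa.pdf", "pages": ["1", "2"]}}</references>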
|
||||
def merge_by_key(self, data, key):
|
||||
|
@@ -215,6 +215,16 @@ class Llama2ChatAdapter(BaseChatAdpter):
         return get_conv_template("llama-2")


+class CodeLlamaChatAdapter(BaseChatAdpter):
+    """The model ChatAdapter for codellama."""
+
+    def match(self, model_path: str):
+        return "codellama" in model_path.lower()
+
+    def get_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("codellama")
+
+
 class BaichuanChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
         return "baichuan" in model_path.lower()
@@ -268,6 +278,7 @@ register_llm_model_chat_adapter(FalconChatAdapter)
 register_llm_model_chat_adapter(GorillaChatAdapter)
 register_llm_model_chat_adapter(GPT4AllChatAdapter)
 register_llm_model_chat_adapter(Llama2ChatAdapter)
+register_llm_model_chat_adapter(CodeLlamaChatAdapter)
 register_llm_model_chat_adapter(BaichuanChatAdapter)
 register_llm_model_chat_adapter(WizardLMChatAdapter)
 register_llm_model_chat_adapter(LlamaCppChatAdapter)
99
pilot/utils/openai_utils.py
Normal file
@@ -0,0 +1,99 @@
from typing import Dict, Any, Awaitable, Callable, Optional, Iterator
import httpx
import asyncio
import logging
import json

logger = logging.getLogger(__name__)
MessageCaller = Callable[[str], Awaitable[None]]


async def _do_chat_completion(
    url: str,
    chat_data: Dict[str, Any],
    client: httpx.AsyncClient,
    headers: Dict[str, Any] = {},
    timeout: int = 60,
    caller: Optional[MessageCaller] = None,
) -> Iterator[str]:
    async with client.stream(
        "POST",
        url,
        headers=headers,
        json=chat_data,
        timeout=timeout,
    ) as res:
        if res.status_code != 200:
            error_message = await res.aread()
            if error_message:
                error_message = error_message.decode("utf-8")
            logger.error(
                f"Request failed with status {res.status_code}. Error: {error_message}"
            )
            raise httpx.RequestError(
                f"Request failed with status {res.status_code}",
                request=res.request,
            )
        async for line in res.aiter_lines():
            if line:
                if not line.startswith("data: "):
                    if caller:
                        await caller(line)
                    yield line
                else:
                    decoded_line = line.split("data: ", 1)[1]
                    if decoded_line.lower().strip() != "[DONE]".lower():
                        obj = json.loads(decoded_line)
                        if obj["choices"][0]["delta"].get("content") is not None:
                            text = obj["choices"][0]["delta"].get("content")
                            if caller:
                                await caller(text)
                            yield text
            await asyncio.sleep(0.02)

async def chat_completion_stream(
    url: str,
    chat_data: Dict[str, Any],
    client: Optional[httpx.AsyncClient] = None,
    headers: Dict[str, Any] = {},
    timeout: int = 60,
    caller: Optional[MessageCaller] = None,
) -> Iterator[str]:
    if client:
        async for text in _do_chat_completion(
            url,
            chat_data,
            client=client,
            headers=headers,
            timeout=timeout,
            caller=caller,
        ):
            yield text
    else:
        async with httpx.AsyncClient() as client:
            async for text in _do_chat_completion(
                url,
                chat_data,
                client=client,
                headers=headers,
                timeout=timeout,
                caller=caller,
            ):
                yield text

async def chat_completion(
    url: str,
    chat_data: Dict[str, Any],
    client: Optional[httpx.AsyncClient] = None,
    headers: Dict[str, Any] = {},
    timeout: int = 60,
    caller: Optional[MessageCaller] = None,
) -> str:
    full_text = ""
    async for text in chat_completion_stream(
        url, chat_data, client, headers=headers, timeout=timeout, caller=caller
    ):
        full_text += text
    return full_text
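# Minimal usage sketch (URL and payload are illustrative, not from this commit):
#
#   import asyncio
#
#   async def main():
#       text = await chat_completion(
#           "http://127.0.0.1:8100/api/v1/chat/completions",
#           {
#               "model": "vicuna-13b-v1.5",
#               "messages": [{"role": "user", "content": "Hi"}],
#               "stream": True,
#           },
#           headers={"Authorization": "Bearer abc"},
#       )
#       print(text)
#
#   asyncio.run(main())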
@@ -8,6 +8,7 @@ pytest-integration
 pytest-mock
 pytest-recording
 pytesseract==0.3.10
+aioresponses
 # python code format
 black
 # for git hooks