diff --git a/CODE_OF_CONDUCT b/CODE_OF_CONDUCT new file mode 100644 index 000000000..b7efcc0b3 --- /dev/null +++ b/CODE_OF_CONDUCT @@ -0,0 +1,126 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, caste, color, religion, or sexual +identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall + community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of + any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, + without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +[INSERT CONTACT METHOD]. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +*Community Impact*: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +*Consequence*: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. 
A public apology may be requested. + +### 2. Warning + +*Community Impact*: A violation through a single incident or series of +actions. + +*Consequence*: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or permanent +ban. + +### 3. Temporary Ban + +*Community Impact*: A serious violation of community standards, including +sustained inappropriate behavior. + +*Consequence*: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +*Community Impact*: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +*Consequence*: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.1, available at +[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder][Mozilla CoC]. + +For answers to common questions about this code of conduct, see the FAQ at +[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at +[https://www.contributor-covenant.org/translations][translations]. diff --git a/README.md b/README.md index 7d7af852a..3f1e8109f 100644 --- a/README.md +++ b/README.md @@ -43,8 +43,8 @@ DB-GPT is an experimental open-source project that uses localized GPT large mode ## Contents -- [install](#install) -- [demo](#demo) +- [Install](#install) +- [Demo](#demo) - [introduction](#introduction) - [features](#features) - [contribution](#contribution) @@ -177,7 +177,7 @@ Currently, we have released multiple key features, which are listed below to dem ## Introduction -Is the architecture of the entire DB-GPT shown in the following figure: +The architecture of the entire DB-GPT is shown in the following figure:

@@ -213,7 +213,7 @@ The core capabilities mainly consist of the following parts: ## Contribution -- Please run `black .` before submitting the code. Contributing guidelines, [how to contribution](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING.md) +- Please run `black .` before submitting the code. See the contributing guidelines: [how to contribute](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING.md) ## RoadMap @@ -330,7 +330,7 @@ As of October 10, 2023, by fine-tuning an open-source model of 13 billion parame The MIT License (MIT) ## Contact Information -We are working on building a community, if you have any ideas about building the community, feel free to contact us. +We are working on building a community. If you have any ideas for building it, feel free to contact us. [![](https://dcbadge.vercel.app/api/server/nASQyBjvY?compact=true&style=flat)](https://discord.gg/nASQyBjvY)

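For contributors, the `black .` step called out in the README's Contribution section above is the only formatting gate in this change; a typical pre-submit sequence (a sketch, assuming `black` is installed from PyPI) might look like:

```bash
pip install black    # one-time setup
black .              # reformat the tree in place
git diff --stat      # review what the formatter touched before committing
```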
diff --git a/assets/wechat.jpg b/assets/wechat.jpg index c7de6223b..ec465785c 100644 Binary files a/assets/wechat.jpg and b/assets/wechat.jpg differ diff --git a/docker/compose_examples/cluster-docker-compose.yml b/docker/compose_examples/cluster-docker-compose.yml index b41033458..0ad6be9ae 100644 --- a/docker/compose_examples/cluster-docker-compose.yml +++ b/docker/compose_examples/cluster-docker-compose.yml @@ -7,6 +7,16 @@ services: restart: unless-stopped networks: - dbgptnet + api-server: + image: eosphorosai/dbgpt:latest + command: dbgpt start apiserver --controller_addr http://controller:8000 + restart: unless-stopped + depends_on: + - controller + networks: + - dbgptnet + ports: + - 8100:8100/tcp llm-worker: image: eosphorosai/dbgpt:latest command: dbgpt start worker --model_name vicuna-13b-v1.5 --model_path /app/models/vicuna-13b-v1.5 --port 8001 --controller_addr http://controller:8000 diff --git a/docs/conf.py b/docs/conf.py index 8601627ca..437aeab9b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -32,7 +32,7 @@ extensions = [ "sphinx_panels", "sphinx_tabs.tabs", "IPython.sphinxext.ipython_console_highlighting", - 'sphinx.ext.autosectionlabel' + "sphinx.ext.autosectionlabel", ] source_suffix = [".ipynb", ".html", ".md", ".rst"] diff --git a/docs/getting_started/install/cluster/cluster.rst b/docs/getting_started/install/cluster/cluster.rst index 93660d0a4..17895e7bc 100644 --- a/docs/getting_started/install/cluster/cluster.rst +++ b/docs/getting_started/install/cluster/cluster.rst @@ -77,3 +77,4 @@ By analyzing this information, we can identify performance bottlenecks in model ./vms/standalone.md ./vms/index.md + ./openai.md diff --git a/docs/getting_started/install/cluster/openai.md b/docs/getting_started/install/cluster/openai.md new file mode 100644 index 000000000..8f23ba0fa --- /dev/null +++ b/docs/getting_started/install/cluster/openai.md @@ -0,0 +1,51 @@ +OpenAI-Compatible RESTful APIs +================================== +(openai-apis-index)= + +### Install Prepare + +You must [deploy DB-GPT cluster](https://db-gpt.readthedocs.io/en/latest/getting_started/install/cluster/vms/index.html) first. + +### Launch Model API Server + +```bash +dbgpt start apiserver --controller_addr http://127.0.0.1:8000 --api_keys EMPTY +``` +By default, the Model API Server starts on port 8100. 
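+
+The listen address, port, and accepted API keys can be changed at launch. The options are generated from the `ModelAPIServerParameters` dataclass added in this change, so treat the exact flag spellings below as an assumption and confirm them with `dbgpt start apiserver --help`:
+
+```bash
+# Assumed flag names (derived from ModelAPIServerParameters); verify with --help
+dbgpt start apiserver --controller_addr http://127.0.0.1:8000 \
+    --host 0.0.0.0 --port 8100 \
+    --api_keys key1,key2
+```
+
+API keys are comma-separated; when any are set, clients must send an `Authorization: Bearer <key>` header with every request.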
+ +### Validate with cURL + +#### List models + +```bash +curl http://127.0.0.1:8100/api/v1/models \ +-H "Authorization: Bearer EMPTY" \ +-H "Content-Type: application/json" +``` + +#### Chat completions + +```bash +curl http://127.0.0.1:8100/api/v1/chat/completions \ +-H "Authorization: Bearer EMPTY" \ +-H "Content-Type: application/json" \ +-d '{"model": "vicuna-13b-v1.5", "messages": [{"role": "user", "content": "hello"}]}' +``` + +### Validate with OpenAI Official SDK + +#### Chat completions + +```python +import openai +openai.api_key = "EMPTY" +openai.api_base = "http://127.0.0.1:8100/api/v1" +model = "vicuna-13b-v1.5" + +completion = openai.ChatCompletion.create( + model=model, + messages=[{"role": "user", "content": "hello"}] +) +# print the completion +print(completion.choices[0].message.content) +``` \ No newline at end of file diff --git a/docs/getting_started/install/llm/proxyllm/proxyllm.md b/docs/getting_started/install/llm/proxyllm/proxyllm.md index a04252d2f..fae549dd3 100644 --- a/docs/getting_started/install/llm/proxyllm/proxyllm.md +++ b/docs/getting_started/install/llm/proxyllm/proxyllm.md @@ -24,9 +24,12 @@ PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions #Azure LLM_MODEL=chatgpt_proxyllm -OPENAI_API_TYPE=azure -PROXY_API_KEY={your-openai-sk} -PROXY_SERVER_URL=https://xx.openai.azure.com/v1/chat/completions +PROXY_API_KEY={your-azure-sk} +PROXY_API_BASE=https://{your domain}.openai.azure.com/ +PROXY_API_TYPE=azure +PROXY_SERVER_URL=xxxx +PROXY_API_VERSION=2023-05-15 +PROXYLLM_BACKEND=gpt-35-turbo #Aliyun tongyi LLM_MODEL=tongyi_proxyllm diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/openai.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/openai.po new file mode 100644 index 000000000..0ef41aa6d --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/openai.po @@ -0,0 +1,71 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023, csunny +# This file is distributed under the same license as the DB-GPT package. +# FIRST AUTHOR , 2023. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: DB-GPT 👏👏 0.4.0\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2023-11-02 21:09+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language: zh_CN\n" +"Language-Team: zh_CN \n" +"Plural-Forms: nplurals=1; plural=0;\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.12.1\n" + +#: ../../getting_started/install/cluster/openai.md:1 +#: 01f4e2bf853341198633b367efec1522 +msgid "OpenAI-Compatible RESTful APIs" +msgstr "OpenAI RESTful 兼容接口" + +#: ../../getting_started/install/cluster/openai.md:5 +#: d8717e42335e4027bf4e76b3d28768ee +msgid "Install Prepare" +msgstr "安装准备" + +#: ../../getting_started/install/cluster/openai.md:7 +#: 9a48d8ee116942468de4c6faf9a64758 +msgid "" +"You must [deploy DB-GPT cluster](https://db-" +"gpt.readthedocs.io/en/latest/getting_started/install/cluster/vms/index.html)" +" first." +msgstr "你必须先部署 [DB-GPT 集群]" +"(https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh-cn/latest/getting_started/install/cluster/vms/index.html)。" + +#: ../../getting_started/install/cluster/openai.md:9 +#: 7673a7121f004f7ca6b1a94a7e238fa3 +msgid "Launch Model API Server" +msgstr "启动模型 API Server" + +#: ../../getting_started/install/cluster/openai.md:14 +#: 84a925c2cbcd4e4895a1d2d2fe8f720f +msgid "By default, the Model API Server starts on port 8100." 
+msgstr "默认情况下,模型 API Server 使用 8100 端口启动。" + +#: ../../getting_started/install/cluster/openai.md:16 +#: e53ed41977cd4721becd51eba05c6609 +msgid "Validate with cURL" +msgstr "通过 cURL 验证" + +#: ../../getting_started/install/cluster/openai.md:18 +#: 7c883b410b5c4e53a256bf17c1ded80d +msgid "List models" +msgstr "列出模型" + +#: ../../getting_started/install/cluster/openai.md:26 +#: ../../getting_started/install/cluster/openai.md:37 +#: 7cf0ed13f0754f149ec085cd6cf7a45a 990d5d5ed5d64ab49550e68495b9e7a0 +msgid "Chat completions" +msgstr "" + +#: ../../getting_started/install/cluster/openai.md:35 +#: 81583edd22df44e091d18a0832278131 +msgid "Validate with OpenAI Official SDK" +msgstr "通过 OpenAI 官方 SDK 验证" + diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/environment/environment.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/environment/environment.po index 1f448ee4e..addf53bc6 100644 --- a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/environment/environment.po +++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/environment/environment.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: DB-GPT 👏👏 0.3.5\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2023-11-02 10:10+0800\n" +"POT-Creation-Date: 2023-11-02 21:04+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language: zh_CN\n" @@ -20,292 +20,292 @@ msgstr "" "Generated-By: Babel 2.12.1\n" #: ../../getting_started/install/environment/environment.md:1 -#: 28d2f84fc8884e78afad8118cd59c654 +#: a17719d2f4374285a7beb4d1db470146 #, fuzzy msgid "Environment Parameter" msgstr "环境变量说明" #: ../../getting_started/install/environment/environment.md:4 -#: c83fbb5e1aa643cdb09fffe7f3d1a3c5 +#: 9a62e6fff7914eeaa2d195ddef4fcb61 msgid "LLM MODEL Config" msgstr "模型配置" #: ../../getting_started/install/environment/environment.md:5 -#: eb675965ae57407e8d8bf90fed8e9e2a +#: 90e3991538324ecfac8cac7ef2103ac2 msgid "LLM Model Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG" msgstr "LLM Model Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG" #: ../../getting_started/install/environment/environment.md:6 -#: 5d28d35126d849ea9b0d963fd1ba8699 +#: 1f45af01100c4586acbc05469e3006bc msgid "LLM_MODEL=vicuna-13b" msgstr "LLM_MODEL=vicuna-13b" #: ../../getting_started/install/environment/environment.md:8 -#: 01955b2d0fbe4d94939ebf2cbb380bdd +#: bed14b704f154c2db525f7fafd3aa5a4 msgid "MODEL_SERVER_ADDRESS" msgstr "MODEL_SERVER_ADDRESS" #: ../../getting_started/install/environment/environment.md:9 -#: 4eaaa9ab59854c0386b28b3111c82784 +#: ea42946cfe4f4ad996bf82c1996e7344 msgid "MODEL_SERVER=http://127.0.0.1:8000 LIMIT_MODEL_CONCURRENCY" msgstr "MODEL_SERVER=http://127.0.0.1:8000 LIMIT_MODEL_CONCURRENCY" #: ../../getting_started/install/environment/environment.md:12 -#: 5c2dd05e16834443b7451c2541b59757 +#: 021c261231f342fdba34098b1baa06fd msgid "LIMIT_MODEL_CONCURRENCY=5" msgstr "LIMIT_MODEL_CONCURRENCY=5" #: ../../getting_started/install/environment/environment.md:14 -#: 7707836c2fb04e7da13d2d59b5f9566f +#: afaf0ba7fd09463d8ff74b514ed7264c msgid "MAX_POSITION_EMBEDDINGS" msgstr "MAX_POSITION_EMBEDDINGS" #: ../../getting_started/install/environment/environment.md:16 -#: ee24a7d3d8384e61b715ef3bd362b965 +#: e4517a942bca4361a64a00408f993f5b msgid "MAX_POSITION_EMBEDDINGS=4096" msgstr "MAX_POSITION_EMBEDDINGS=4096" #: ../../getting_started/install/environment/environment.md:18 -#: 90b51aa4e46b4d1298c672e0052c2f68 +#: 78d2ef04ed4548b9b7b0fb8ae35c9d5c msgid "QUANTIZE_QLORA" msgstr 
"QUANTIZE_QLORA" #: ../../getting_started/install/environment/environment.md:20 -#: 7de7a8eb431e4973ae00f68ca0686281 +#: bfa65db03c6d46bba293331f03ab15ac msgid "QUANTIZE_QLORA=True" msgstr "QUANTIZE_QLORA=True" #: ../../getting_started/install/environment/environment.md:22 -#: e331ca016a474f4aa4e9182165a2693a +#: 1947d45a7f184821910b4834ad5f1897 msgid "QUANTIZE_8bit" msgstr "QUANTIZE_8bit" #: ../../getting_started/install/environment/environment.md:24 -#: 519ccce5a0884778be2719c437a17bd4 +#: 4a2ee2919d0e4bdaa13c9d92eefd2aac msgid "QUANTIZE_8bit=True" msgstr "QUANTIZE_8bit=True" #: ../../getting_started/install/environment/environment.md:27 -#: 1c0586d070f046de8d0f9f94a6b508b4 +#: 348dc1e411b54ab09414f40a20e934e4 msgid "LLM PROXY Settings" msgstr "LLM PROXY Settings" #: ../../getting_started/install/environment/environment.md:28 -#: c208c3f4b13f4b39962de814e5be6ab9 +#: a692e78425a040f5828ab54ff9a33f77 msgid "OPENAI Key" msgstr "OPENAI Key" #: ../../getting_started/install/environment/environment.md:30 -#: 9228bbee2faa4467b1d24f1125faaac8 +#: 940d00e25a424acf92951a314a64e5ea msgid "PROXY_API_KEY={your-openai-sk}" msgstr "PROXY_API_KEY={your-openai-sk}" #: ../../getting_started/install/environment/environment.md:31 -#: 759ae581883348019c1ba79e8954728a +#: 4bd27547ae6041679e91f2a363cd1deb msgid "PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions" msgstr "PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions" #: ../../getting_started/install/environment/environment.md:33 -#: 83f3952917d34aab80bd34119f7d1e20 +#: cfa3071afb0b47baad6bd729d4a02cb9 msgid "from https://bard.google.com/ f12-> application-> __Secure-1PSID" msgstr "from https://bard.google.com/ f12-> application-> __Secure-1PSID" #: ../../getting_started/install/environment/environment.md:35 -#: 1d70707ca82749bb90b2bed1aee44d62 +#: a17efa03b10f47f68afac9e865982a75 msgid "BARD_PROXY_API_KEY={your-bard-token}" msgstr "BARD_PROXY_API_KEY={your-bard-token}" #: ../../getting_started/install/environment/environment.md:38 -#: 38a2091fa223493ea23cb9bbb33cf58e +#: 6bcfe90574da4d82a459e8e11bf73cba msgid "DATABASE SETTINGS" msgstr "DATABASE SETTINGS" #: ../../getting_started/install/environment/environment.md:39 -#: 5134180d7a5945b48b072a1eb92b27ba +#: 2b1e62d9bf5d4af5a22f68c8248eaafb msgid "SQLite database (Current default database)" msgstr "SQLite database (Current default database)" #: ../../getting_started/install/environment/environment.md:40 -#: 6875e2300e094668a45fa4f2551e0d30 +#: 8a909ac3b3c943da8dbc4e8dd596c80c msgid "LOCAL_DB_PATH=data/default_sqlite.db" msgstr "LOCAL_DB_PATH=data/default_sqlite.db" #: ../../getting_started/install/environment/environment.md:41 -#: 034e8f06f24f44af9d8184563f99b4b3 +#: 90ae6507932f4815b6e180051738bb93 msgid "LOCAL_DB_TYPE=sqlite # Database Type default:sqlite" msgstr "LOCAL_DB_TYPE=sqlite # Database Type default:sqlite" #: ../../getting_started/install/environment/environment.md:43 -#: f688149a97f740269f80b79775236ce9 +#: d2ce34e0dcf44ccf9e8007d548ba7b0a msgid "MYSQL database" msgstr "MYSQL database" #: ../../getting_started/install/environment/environment.md:44 -#: 6db0b305137d45a3aa036e4f2262f460 +#: c07159d63c334f6cbb95fcc30bfb7ea5 msgid "LOCAL_DB_TYPE=mysql" msgstr "LOCAL_DB_TYPE=mysql" #: ../../getting_started/install/environment/environment.md:45 -#: b6d662ce8d5f44f0b54a7f6e7c66f5a5 +#: e16700b2ea8d411e91d010c1cde7aecc msgid "LOCAL_DB_USER=root" msgstr "LOCAL_DB_USER=root" #: ../../getting_started/install/environment/environment.md:46 -#: cd7493d61ac9415283640dc6c018d2f4 +#: 
bfc2dce1bf374121b6861e677b4e1ffa msgid "LOCAL_DB_PASSWORD=aa12345678" msgstr "LOCAL_DB_PASSWORD=aa12345678" #: ../../getting_started/install/environment/environment.md:47 -#: 4ea2a622b23f4342a4c2ab7f8d9c4e8d +#: bc384739f5b04e21a34d0d2b78e7906c msgid "LOCAL_DB_HOST=127.0.0.1" msgstr "LOCAL_DB_HOST=127.0.0.1" #: ../../getting_started/install/environment/environment.md:48 -#: 936db95a0ab246098028f4dbb596cd17 +#: e5253d452e0d42b7ac308fe6fbfb5017 msgid "LOCAL_DB_PORT=3306" msgstr "LOCAL_DB_PORT=3306" #: ../../getting_started/install/environment/environment.md:51 -#: d9255f25989840ea9c9e7b34f3947c87 +#: 9ca8f6fe06ed4cbab390f94be252e165 msgid "EMBEDDING SETTINGS" msgstr "EMBEDDING SETTINGS" #: ../../getting_started/install/environment/environment.md:52 -#: b09291d32aca43928a981e873476a985 +#: 76c7c260293c4b49bae057143fd48377 msgid "EMBEDDING MODEL Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG" msgstr "EMBEDDING模型, 参考see /pilot/configs/model_config.LLM_MODEL_CONFIG" #: ../../getting_started/install/environment/environment.md:53 -#: 63de573b03a54413b997f18a1ccee279 +#: f1d63a0128ce493cae37d34f1976bcca msgid "EMBEDDING_MODEL=text2vec" msgstr "EMBEDDING_MODEL=text2vec" #: ../../getting_started/install/environment/environment.md:55 -#: 0ef8defbab544bd0b9475a036f278489 +#: b8fbb99109d04781b2dd5bc5d6efa5bd msgid "Embedding Chunk size, default 500" msgstr "Embedding 切片大小, 默认500" #: ../../getting_started/install/environment/environment.md:57 -#: 33dbc7941d054baa8c6ecfc0bf1ce271 +#: bf8256576ea34f6a9c5f261ab9aab676 msgid "KNOWLEDGE_CHUNK_SIZE=500" msgstr "KNOWLEDGE_CHUNK_SIZE=500" #: ../../getting_started/install/environment/environment.md:59 -#: e6ee9f2620ab45ecbc8e9c0642f5ca42 +#: 9b156c6b599b4c02a58ce023b4ff25f2 msgid "Embedding Chunk Overlap, default 100" msgstr "Embedding chunk Overlap, 文本块之间的最大重叠量。保留一些重叠可以保持文本块之间的连续性(例如使用滑动窗口),默认100" #: ../../getting_started/install/environment/environment.md:60 -#: fcddf64340a04df4ab95176fc2fc67a6 +#: dcafd903c36041ac85ac99a14dbee512 msgid "KNOWLEDGE_CHUNK_OVERLAP=100" msgstr "KNOWLEDGE_CHUNK_OVERLAP=100" #: ../../getting_started/install/environment/environment.md:62 -#: 61272200194b4461a921581feb1273da +#: 6c3244b7e5e24b0188c7af4bb52e9134 #, fuzzy msgid "embedding recall top k,5" msgstr "embedding 召回topk, 默认5" #: ../../getting_started/install/environment/environment.md:64 -#: b433091f055542b1b89ff2d525ac99e4 +#: f4a2f30551cf4fe1a7ff3c7c74ec77be msgid "KNOWLEDGE_SEARCH_TOP_SIZE=5" msgstr "KNOWLEDGE_SEARCH_TOP_SIZE=5" #: ../../getting_started/install/environment/environment.md:66 -#: 1db0de41aebd4caa8cc2eaecb4cacd6a +#: 593f2512362f467e92fdaa60dd5903a0 #, fuzzy msgid "embedding recall max token ,2000" msgstr "embedding向量召回最大token, 默认2000" #: ../../getting_started/install/environment/environment.md:68 -#: 81b9d862e58941a4b09680a7520cdabe +#: 83d6d28914be4d6282d457272e508ddc msgid "KNOWLEDGE_SEARCH_MAX_TOKEN=5" msgstr "KNOWLEDGE_SEARCH_MAX_TOKEN=5" #: ../../getting_started/install/environment/environment.md:71 #: ../../getting_started/install/environment/environment.md:87 -#: cac73575d54544778bdee09b18532fd9 f78a509949a64f03aa330f31901e2e7a +#: 6bc1b9d995e74294a1c78e783c550db7 d33c77ded834438e9f4a2df06e7e041a msgid "Vector Store SETTINGS" msgstr "Vector Store SETTINGS" #: ../../getting_started/install/environment/environment.md:72 #: ../../getting_started/install/environment/environment.md:88 -#: 5ebba1cb047b4b09849000244237dfbb 7e9285e91bcb4b2d9413909c0d0a06a7 +#: 9cafa06e2d584f70afd848184e0fa52a f01057251b8b4ffea806192dfe1048ed msgid "Chroma" msgstr 
"Chroma" #: ../../getting_started/install/environment/environment.md:73 #: ../../getting_started/install/environment/environment.md:89 -#: 05625cfcc23c4745ae1fa0d94ce5450c 3a8615f1507f4fc49d1adda5100a4edf +#: e6c16fab37484769b819aeecbc13e6db faad299722e5400e95ec6ac3c1e018b8 msgid "VECTOR_STORE_TYPE=Chroma" msgstr "VECTOR_STORE_TYPE=Chroma" #: ../../getting_started/install/environment/environment.md:74 #: ../../getting_started/install/environment/environment.md:90 -#: 5b559376aea44f159262e6d4b75c7ec1 e954782861404b10b4e893e01cf74452 +#: 4eca3a51716d406f8ffd49c06550e871 581ee9dd38064b119660c44bdd00cbaa msgid "MILVUS" msgstr "MILVUS" #: ../../getting_started/install/environment/environment.md:75 #: ../../getting_started/install/environment/environment.md:91 -#: 55ee8199c97a4929aeefd32370f2b92d 8f40c02543ea4a2ca9632dd9e8a08c2e +#: 814c93048bed46589358a854d6c99683 b72b1269a2224f5f961214e41c019f21 msgid "VECTOR_STORE_TYPE=Milvus" msgstr "VECTOR_STORE_TYPE=Milvus" #: ../../getting_started/install/environment/environment.md:76 #: ../../getting_started/install/environment/environment.md:92 -#: 528a01d25720491c8e086bf43a62ad92 ba1386d551d7494a85681a2803081a6f +#: 73ae665f1db9402883662734588fd02c c4da20319c994e83ba5a7706db967178 msgid "MILVUS_URL=127.0.0.1" msgstr "MILVUS_URL=127.0.0.1" #: ../../getting_started/install/environment/environment.md:77 #: ../../getting_started/install/environment/environment.md:93 -#: b031950dafcd4d4783c120dc933c4178 c2e9c8cdd41741e3aba01e59a6ef245d +#: e30c5288516d42aa858a485db50490c1 f843b2e58bcb4e4594e3c28499c341d0 msgid "MILVUS_PORT=19530" msgstr "MILVUS_PORT=19530" #: ../../getting_started/install/environment/environment.md:78 #: ../../getting_started/install/environment/environment.md:94 -#: 27b0a64af6434cb2840373e2b38c9bd5 d0e4d79af7954b129ffff7303a1ec3ce +#: 158669efcc7d4bcaac1c8dd01b499029 24e88ffd32f242f281c56c0ec3ad2639 msgid "MILVUS_USERNAME" msgstr "MILVUS_USERNAME" #: ../../getting_started/install/environment/environment.md:79 #: ../../getting_started/install/environment/environment.md:95 -#: 27aa1a5b61e64dd6bfe29124e274809e 5c58892498ce4f46a59f54b2887822d4 +#: 111a985297184c8aa5a0dd8e14a58445 6602093a6bb24d6792548e2392105c82 msgid "MILVUS_PASSWORD" msgstr "MILVUS_PASSWORD" #: ../../getting_started/install/environment/environment.md:80 #: ../../getting_started/install/environment/environment.md:96 -#: 009e57d4acc5434da2146f0545911c85 bac8888dcbff47fbb0ea8ae685445aac +#: 47bdfcd78fbe4ccdb5f49b717a6d01a6 b96c0545b2044926a8a8190caf94ad25 msgid "MILVUS_SECURE=" msgstr "MILVUS_SECURE=" #: ../../getting_started/install/environment/environment.md:82 #: ../../getting_started/install/environment/environment.md:98 -#: a6eeb16ab5274045bee88ecc3d93e09e eb341774403d47658b9b7e94c4c16d5c +#: 755c32b5d6c54607907a138b5474c0ec ff4f2a7ddaa14f089dda7a14e1062c36 msgid "WEAVIATE" msgstr "WEAVIATE" #: ../../getting_started/install/environment/environment.md:83 -#: fbd97522d8da4824b41b99298fd41069 +#: 23b2ce83385d40a589a004709f9864be msgid "VECTOR_STORE_TYPE=Weaviate" msgstr "VECTOR_STORE_TYPE=Weaviate" #: ../../getting_started/install/environment/environment.md:84 #: ../../getting_started/install/environment/environment.md:99 -#: 341785b4abfe42b5af1c2e04497261f4 a81cc2aabc8240f3ac1f674d9350bff4 +#: 9acef304d89a448a9e734346705ba872 cf5151b6c1594ccd8beb1c3f77769acb msgid "WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network" msgstr "WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network" #: ../../getting_started/install/environment/environment.md:102 -#: 
5bb9e5daa36241d499089c1b1910f729 +#: c3003516b2364051bf34f8c3086e348a msgid "Multi-GPU Setting" msgstr "Multi-GPU Setting" #: ../../getting_started/install/environment/environment.md:103 -#: 30df45b7f1f7423c9f18c6360f0b7600 +#: ade8fc381c5e438aa29d159c10041713 msgid "" "See https://developer.nvidia.com/blog/cuda-pro-tip-control-gpu-" "visibility-cuda_visible_devices/ If CUDA_VISIBLE_DEVICES is not " @@ -315,49 +315,49 @@ msgstr "" "cuda_visible_devices/ 如果 CUDA_VISIBLE_DEVICES没有设置, 会使用所有可用的gpu" #: ../../getting_started/install/environment/environment.md:106 -#: 8631ea968dfb4d90a7ae6bdb2acdfdce +#: e137bd19be5e410ba6709027dbf2923a msgid "CUDA_VISIBLE_DEVICES=0" msgstr "CUDA_VISIBLE_DEVICES=0" #: ../../getting_started/install/environment/environment.md:108 -#: 0010422280dd4fe79326ebceb2a66f0e +#: 7669947acbdc4b1d92bcc029a8353a5d msgid "" "Optionally, you can also specify the gpu ID to use before the starting " "command" msgstr "你也可以通过启动命令设置gpu ID" #: ../../getting_started/install/environment/environment.md:110 -#: 00106f7341304fbd9425721ea8e6a261 +#: 751743d1753b4051beea46371278d793 msgid "CUDA_VISIBLE_DEVICES=3,4,5,6" msgstr "CUDA_VISIBLE_DEVICES=3,4,5,6" #: ../../getting_started/install/environment/environment.md:112 -#: 720aa8b3478744d78e4b10dfeccb50b4 +#: 3acc3de0af0d4df2bb575e161e377f85 msgid "You can configure the maximum memory used by each GPU." msgstr "可以设置GPU的最大内存" #: ../../getting_started/install/environment/environment.md:114 -#: f9639ac96a244296832c75bbcbdae2af +#: 67f1d9b172b84294a44ecace5436e6e0 msgid "MAX_GPU_MEMORY=16Gib" msgstr "MAX_GPU_MEMORY=16Gib" #: ../../getting_started/install/environment/environment.md:117 -#: fc4d955fdb3e4256af5c8f29b042dcd6 +#: 3c69dfe48bcf46b89b76cac1e7849a66 msgid "Other Setting" msgstr "Other Setting" #: ../../getting_started/install/environment/environment.md:118 -#: 66b14a834e884339be2d48392e884933 +#: d5015b70f4fe4d20a63de9d87f86957a msgid "Language Settings(influence prompt language)" msgstr "Language Settings(涉及prompt语言以及知识切片方式)" #: ../../getting_started/install/environment/environment.md:119 -#: 5c9f05174eb84edd9e1316cc0721a840 +#: 5543c28bb8e34c9fb3bb6b063c2b1750 msgid "LANGUAGE=en" msgstr "LANGUAGE=en" #: ../../getting_started/install/environment/environment.md:120 -#: 7f6d62117d024c51bba9255fa4fcf151 +#: cb4ed5b892ee41068c1ca76cb29aa400 msgid "LANGUAGE=zh" msgstr "LANGUAGE=zh" diff --git a/docs/locales/zh_CN/LC_MESSAGES/modules/knowledge.po b/docs/locales/zh_CN/LC_MESSAGES/modules/knowledge.po index 032e81609..bb2dae7af 100644 --- a/docs/locales/zh_CN/LC_MESSAGES/modules/knowledge.po +++ b/docs/locales/zh_CN/LC_MESSAGES/modules/knowledge.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: DB-GPT 0.3.0\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2023-11-02 10:10+0800\n" +"POT-Creation-Date: 2023-11-02 21:04+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language: zh_CN\n" @@ -19,11 +19,11 @@ msgstr "" "Content-Transfer-Encoding: 8bit\n" "Generated-By: Babel 2.12.1\n" -#: ../../modules/knowledge.md:1 436b94d3a8374ed18feb5c14893a84e6 +#: ../../modules/knowledge.md:1 b94b3b15cb2441ed9d78abd222a717b7 msgid "Knowledge" msgstr "知识" -#: ../../modules/knowledge.md:3 918a3747cbed42d18b8c9c4547e67b14 +#: ../../modules/knowledge.md:3 c6d6e308a6ce42948d29e928136ef561 #, fuzzy msgid "" "As the knowledge base is currently the most significant user demand " @@ -34,15 +34,15 @@ msgstr "" "由于知识库是当前用户需求最显著的场景,我们原生支持知识库的构建和处理。同时,我们还在本项目中提供了多种知识库管理策略,如:pdf,md , " "txt, word, ppt" -#: 
../../modules/knowledge.md:4 d4d4b5d57918485aafa457bb9fdcf626 +#: ../../modules/knowledge.md:4 268abc408d40410ba90cf5f121dc5270 msgid "Default built-in knowledge base" msgstr "" -#: ../../modules/knowledge.md:5 d4d4b5d57918485aafa457bb9fdcf626 +#: ../../modules/knowledge.md:5 558c3364c38b458a8ebf81030efc2a48 msgid "Custom addition of knowledge bases" msgstr "" -#: ../../modules/knowledge.md:6 984361ce835c4c3492e29e1fb897348a +#: ../../modules/knowledge.md:6 9cb3ce62da1440579c095848c7aef88c msgid "" "Various usage scenarios such as constructing knowledge bases through " "plugin capabilities and web crawling. Users only need to organize the " @@ -50,53 +50,53 @@ msgid "" "the knowledge base required for the large model." msgstr "" -#: ../../modules/knowledge.md:9 746e4fbd3212460198be51b90caee2c8 +#: ../../modules/knowledge.md:9 b8ca6bc4dd9845baa56e36eea7fac2a2 #, fuzzy msgid "Create your own knowledge repository" msgstr "创建你自己的知识库" -#: ../../modules/knowledge.md:11 1c46b33b0532417c824efbaa3e687c3f +#: ../../modules/knowledge.md:11 17d7178a67924f43aa5b6293707ef041 msgid "" "1.Place personal knowledge files or folders in the pilot/datasets " "directory." msgstr "" -#: ../../modules/knowledge.md:13 3b16f387b5354947a89d6df77bd65bdb +#: ../../modules/knowledge.md:13 31c31f14bf444981939689f9a9fb038a #, fuzzy msgid "" "We currently support many document formats: txt, pdf, md, html, doc, ppt," " and url." msgstr "当前支持txt, pdf, md, html, doc, ppt, url文档格式" -#: ../../modules/knowledge.md:15 09ec337d7da4418db854e58afb6c0980 +#: ../../modules/knowledge.md:15 9ad2f2e05f8842a9b9d8469a3704df23 msgid "before execution:" msgstr "开始前" -#: ../../modules/knowledge.md:22 c09b3decb018485f8e56830ddc156194 +#: ../../modules/knowledge.md:22 6fd2775914b641c4b8e486417b558ea6 msgid "" "2.Update your .env, set your vector store type, VECTOR_STORE_TYPE=Chroma " "(now only support Chroma and Milvus, if you set Milvus, please set " "MILVUS_URL and MILVUS_PORT)" msgstr "" -#: ../../modules/knowledge.md:25 74460ec7709441d5945ce9f745a26d20 +#: ../../modules/knowledge.md:25 131c5f58898a4682940910980edb2043 msgid "2.Run the knowledge repository initialization command" msgstr "" -#: ../../modules/knowledge.md:31 4498ec4e46ff4e24b45dd855e829bd32 +#: ../../modules/knowledge.md:31 2cf550f17881497bb881b19efcc18c23 msgid "" "Optionally, you can run `dbgpt knowledge load --help` command to see more" " usage." msgstr "" -#: ../../modules/knowledge.md:33 5048ac3289e540f2a2b5fd0e5ed043f5 +#: ../../modules/knowledge.md:33 c8a2ea571b944bdfbcad48fa8b54fcc9 msgid "" "3.Add the knowledge repository in the interface by entering the name of " "your knowledge repository (if not specified, enter \"default\") so you " "can use it for Q&A based on your knowledge base." 
msgstr "" -#: ../../modules/knowledge.md:35 deeccff20f7f453dad0881b63dae2a18 +#: ../../modules/knowledge.md:35 b701170ad75e49dea7d7734c15681e0f msgid "" "Note that the default vector model used is text2vec-large-chinese (which " "is a large model, so if your personal computer configuration is not " diff --git a/pilot/base_modules/agent/db/__init__.py b/pilot/base_modules/agent/db/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/component.py b/pilot/component.py index 3826a5e48..16013ee17 100644 --- a/pilot/component.py +++ b/pilot/component.py @@ -46,6 +46,8 @@ class ComponentType(str, Enum): WORKER_MANAGER = "dbgpt_worker_manager" WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory" MODEL_CONTROLLER = "dbgpt_model_controller" + MODEL_REGISTRY = "dbgpt_model_registry" + MODEL_API_SERVER = "dbgpt_model_api_server" AGENT_HUB = "dbgpt_agent_hub" EXECUTOR_DEFAULT = "dbgpt_thread_pool_default" TRACER = "dbgpt_tracer" @@ -69,7 +71,6 @@ class BaseComponent(LifeCycle, ABC): This method needs to be implemented by every component to define how it integrates with the main system app. """ - pass T = TypeVar("T", bound=BaseComponent) @@ -91,13 +92,28 @@ class SystemApp(LifeCycle): """Returns the internal ASGI app.""" return self._asgi_app - def register(self, component: Type[BaseComponent], *args, **kwargs): - """Register a new component by its type.""" + def register(self, component: Type[BaseComponent], *args, **kwargs) -> T: + """Register a new component by its type. + + Args: + component (Type[BaseComponent]): The component class to register + + Returns: + T: The instance of registered component + """ instance = component(self, *args, **kwargs) self.register_instance(instance) + return instance - def register_instance(self, instance: T): - """Register an already initialized component.""" + def register_instance(self, instance: T) -> T: + """Register an already initialized component. + + Args: + instance (T): The component instance to register + + Returns: + T: The instance of registered component + """ name = instance.name if isinstance(name, ComponentType): name = name.value @@ -108,18 +124,34 @@ class SystemApp(LifeCycle): logger.info(f"Register component with name {name} and instance: {instance}") self.components[name] = instance instance.init_app(self) + return instance def get_component( self, name: Union[str, ComponentType], component_type: Type[T], default_component=_EMPTY_DEFAULT_COMPONENT, + or_register_component: Type[BaseComponent] = None, + *args, + **kwargs, ) -> T: - """Retrieve a registered component by its name and type.""" + """Retrieve a registered component by its name and type. 
+ + Args: + name (Union[str, ComponentType]): Component name + component_type (Type[T]): The type of the component to retrieve + default_component: The default component instance to return if no component is found by name + or_register_component (Type[BaseComponent]): The component class to register and return if no component is found by name + + Returns: + T: The component instance retrieved by name + """ if isinstance(name, ComponentType): name = name.value component = self.components.get(name) if not component: + if or_register_component: + return self.register(or_register_component, *args, **kwargs) if default_component != _EMPTY_DEFAULT_COMPONENT: return default_component raise ValueError(f"No component found with name {name}") diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index e1575ea03..0e1fb3d40 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -78,6 +78,10 @@ LLM_MODEL_CONFIG = { "internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"), "internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"), "internlm-20b": os.path.join(MODEL_PATH, "internlm-chat-20b"), + "codellama-7b": os.path.join(MODEL_PATH, "CodeLlama-7b-Instruct-hf"), + "codellama-7b-sql-sft": os.path.join(MODEL_PATH, "codellama-7b-sql-sft"), + "codellama-13b": os.path.join(MODEL_PATH, "CodeLlama-13b-Instruct-hf"), + "codellama-13b-sql-sft": os.path.join(MODEL_PATH, "codellama-13b-sql-sft"), # For test now "opt-125m": os.path.join(MODEL_PATH, "opt-125m"), } diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 69b159a13..5ce5b2173 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -320,6 +320,19 @@ class Llama2Adapter(BaseLLMAdaper): return model, tokenizer +class CodeLlamaAdapter(BaseLLMAdaper): + """The model adapter for CodeLlama""" + + def match(self, model_path: str): + return "codellama" in model_path.lower() + + def loader(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().loader(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + class BaichuanAdapter(BaseLLMAdaper): """The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-13B-Chat)""" @@ -420,6 +433,7 @@ register_llm_model_adapters(FalconAdapater) register_llm_model_adapters(GorillaAdapter) register_llm_model_adapters(GPT4AllAdapter) register_llm_model_adapters(Llama2Adapter) +register_llm_model_adapters(CodeLlamaAdapter) register_llm_model_adapters(BaichuanAdapter) register_llm_model_adapters(WizardLMAdapter) register_llm_model_adapters(LlamaCppAdapater) diff --git a/pilot/model/base.py b/pilot/model/base.py index e89b243c9..48480b94b 100644 --- a/pilot/model/base.py +++ b/pilot/model/base.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- from enum import Enum -from typing import TypedDict, Optional, Dict, List +from typing import TypedDict, Optional, Dict, List, Any from dataclasses import dataclass, asdict from datetime import datetime from pilot.utils.parameter_utils import ParameterDescription @@ -52,6 +52,8 @@ class ModelOutput: text: str error_code: int model_context: Dict = None + finish_reason: str = None + usage: Dict[str, Any] = None def to_dict(self) -> Dict: return asdict(self) diff --git a/pilot/model/cli.py b/pilot/model/cli.py index 1030adfc2..79b47db82 100644 --- a/pilot/model/cli.py +++ b/pilot/model/cli.py @@ -8,6 +8,7 @@ from pilot.configs.model_config import LOGDIR from pilot.model.base import WorkerApplyType from
pilot.model.parameter import ( ModelControllerParameters, + ModelAPIServerParameters, ModelWorkerParameters, ModelParameters, BaseParameters, @@ -441,15 +442,27 @@ def stop_model_worker(port: int): @click.command(name="apiserver") +@EnvArgumentParser.create_click_option(ModelAPIServerParameters) def start_apiserver(**kwargs): - """Start apiserver(TODO)""" - raise NotImplementedError + """Start apiserver""" + + if kwargs["daemon"]: + log_file = os.path.join(LOGDIR, "model_apiserver_uvicorn.log") + _run_current_with_daemon("ModelAPIServer", log_file) + else: + from pilot.model.cluster import run_apiserver + + run_apiserver() @click.command(name="apiserver") -def stop_apiserver(**kwargs): - """Start apiserver(TODO)""" - raise NotImplementedError +@add_stop_server_options +def stop_apiserver(port: int): + """Stop apiserver""" + name = "ModelAPIServer" + if port: + name = f"{name}-{port}" + _stop_service("apiserver", name, port=port) def _stop_all_model_server(**kwargs): diff --git a/pilot/model/cluster/__init__.py b/pilot/model/cluster/__init__.py index 9937ffa0b..a777a8d4b 100644 --- a/pilot/model/cluster/__init__.py +++ b/pilot/model/cluster/__init__.py @@ -21,6 +21,7 @@ from pilot.model.cluster.controller.controller import ( run_model_controller, BaseModelController, ) +from pilot.model.cluster.apiserver.api import run_apiserver from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager @@ -40,4 +41,5 @@ __all__ = [ "ModelRegistryClient", "RemoteWorkerManager", "run_model_controller", + "run_apiserver", ] diff --git a/pilot/model/cluster/apiserver/__init__.py b/pilot/model/cluster/apiserver/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/model/cluster/apiserver/api.py b/pilot/model/cluster/apiserver/api.py new file mode 100644 index 000000000..148a51eed --- /dev/null +++ b/pilot/model/cluster/apiserver/api.py @@ -0,0 +1,443 @@ +"""A server that provides OpenAI-compatible RESTful APIs. It supports: +- Chat Completions. 
(Reference: https://platform.openai.com/docs/api-reference/chat) + +Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py +""" +from typing import Optional, List, Dict, Any, Generator + +import logging +import asyncio +import shortuuid +import json +from fastapi import APIRouter, FastAPI +from fastapi import Depends, HTTPException +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer + +from pydantic import BaseSettings + +from fastchat.protocol.openai_api_protocol import ( + ChatCompletionResponse, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + ChatCompletionResponseChoice, + DeltaMessage, + EmbeddingsRequest, + EmbeddingsResponse, + ErrorResponse, + ModelCard, + ModelList, + ModelPermission, + UsageInfo, +) +from fastchat.protocol.api_protocol import ( + APIChatCompletionRequest, + APITokenCheckRequest, + APITokenCheckResponse, + APITokenCheckResponseItem, +) +from fastchat.serve.openai_api_server import create_error_response, check_requests +from fastchat.constants import ErrorCode + +from pilot.component import BaseComponent, ComponentType, SystemApp +from pilot.utils.parameter_utils import EnvArgumentParser +from pilot.scene.base_message import ModelMessage, ModelMessageRoleType +from pilot.model.base import ModelInstance, ModelOutput +from pilot.model.parameter import ModelAPIServerParameters, WorkerType +from pilot.model.cluster import ModelRegistry, ModelRegistryClient +from pilot.model.cluster.manager_base import WorkerManager, WorkerManagerFactory +from pilot.utils.utils import setup_logging + +logger = logging.getLogger(__name__) + + +class APIServerException(Exception): + def __init__(self, code: int, message: str): + self.code = code + self.message = message + + +class APISettings(BaseSettings): + api_keys: Optional[List[str]] = None + + +api_settings = APISettings() +get_bearer_token = HTTPBearer(auto_error=False) + + +async def check_api_key( + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), +) -> str: + if api_settings.api_keys: + if auth is None or (token := auth.credentials) not in api_settings.api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + +class APIServer(BaseComponent): + name = ComponentType.MODEL_API_SERVER + + def init_app(self, system_app: SystemApp): + self.system_app = system_app + + def get_worker_manager(self) -> WorkerManager: + """Get the worker manager component instance + + Raises: + APIServerException: If can't get worker manager component instance + """ + worker_manager = self.system_app.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() + if not worker_manager: + raise APIServerException( + ErrorCode.INTERNAL_ERROR, + f"Could not get component {ComponentType.WORKER_MANAGER_FACTORY} from system_app", + ) + return worker_manager + + def get_model_registry(self) -> ModelRegistry: + """Get the model registry component instance + + Raises: + APIServerException: If can't get model registry component instance + """ + + controller = self.system_app.get_component( + ComponentType.MODEL_REGISTRY, ModelRegistry + ) + if not controller: + 
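+            # No MODEL_REGISTRY component is registered on this system app, so surface an internal error to the caller.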
raise APIServerException( + ErrorCode.INTERNAL_ERROR, + f"Could not get component {ComponentType.MODEL_REGISTRY} from system_app", + ) + return controller + + async def get_model_instances_or_raise( + self, model_name: str + ) -> List[ModelInstance]: + """Get healthy model instances for the requested model name + + Args: + model_name (str): Model name + + Raises: + APIServerException: If no healthy model instances exist for the requested model name + """ + registry = self.get_model_registry() + registry_model_name = f"{model_name}@llm" + model_instances = await registry.get_all_instances( + registry_model_name, healthy_only=True + ) + if not model_instances: + all_instances = await registry.get_all_model_instances(healthy_only=True) + models = [ + ins.model_name.split("@llm")[0] + for ins in all_instances + if ins.model_name.endswith("@llm") + ] + if models: + models = "&&".join(models) + message = f"Only {models} are allowed now, your model {model_name} is not available" + else: + message = f"No models are allowed now, your model {model_name} is not available" + raise APIServerException(ErrorCode.INVALID_MODEL, message) + return model_instances + + async def get_available_models(self) -> ModelList: + """Return the available models + + Only LLM and embedding models are included. + + Returns: + ModelList: The list of available model cards. + """ + registry = self.get_model_registry() + model_instances = await registry.get_all_model_instances(healthy_only=True) + model_name_set = set() + for inst in model_instances: + name, worker_type = WorkerType.parse_worker_key(inst.model_name) + if worker_type == WorkerType.LLM or worker_type == WorkerType.TEXT2VEC: + model_name_set.add(name) + models = list(model_name_set) + models.sort() + # TODO: return real model permission details + model_cards = [] + for m in models: + model_cards.append( + ModelCard( + id=m, root=m, owned_by="DB-GPT", permission=[ModelPermission()] + ) + ) + return ModelList(data=model_cards) + + async def chat_completion_stream_generator( + self, model_name: str, params: Dict[str, Any], n: int + ) -> Generator[str, Any, None]: + """Chat stream completion generator + + Args: + model_name (str): Model name + params (Dict[str, Any]): The parameters passed to the model worker + n (int): How many completions to generate for each prompt.
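+
+        Yields:
+            str: Server-sent event chunks of the form ``data: {json}\n\n``; the stream is terminated by ``data: [DONE]\n\n``.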
+ """ + worker_manager = self.get_worker_manager() + id = f"chatcmpl-{shortuuid.random()}" + finish_stream_events = [] + for i in range(n): + # First chunk with role + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(role="assistant"), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse( + id=id, choices=[choice_data], model=model_name + ) + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + + previous_text = "" + async for model_output in worker_manager.generate_stream(params): + model_output: ModelOutput = model_output + if model_output.error_code != 0: + yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + return + decoded_unicode = model_output.text.replace("\ufffd", "") + delta_text = decoded_unicode[len(previous_text) :] + previous_text = ( + decoded_unicode + if len(decoded_unicode) > len(previous_text) + else previous_text + ) + + if len(delta_text) == 0: + delta_text = None + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + finish_reason=model_output.finish_reason, + ) + chunk = ChatCompletionStreamResponse( + id=id, choices=[choice_data], model=model_name + ) + if delta_text is None: + if model_output.finish_reason is not None: + finish_stream_events.append(chunk) + continue + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + # There is not "content" field in the last delta message, so exclude_none to exclude field "content". + for finish_chunk in finish_stream_events: + yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + + async def chat_completion_generate( + self, model_name: str, params: Dict[str, Any], n: int + ) -> ChatCompletionResponse: + """Generate completion + Args: + model_name (str): Model name + params (Dict[str, Any]): The parameters pass to model worker + n (int): How many completions to generate for each prompt. 
+ """ + worker_manager: WorkerManager = self.get_worker_manager() + choices = [] + chat_completions = [] + for i in range(n): + model_output = asyncio.create_task(worker_manager.generate(params)) + chat_completions.append(model_output) + try: + all_tasks = await asyncio.gather(*chat_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + usage = UsageInfo() + for i, model_output in enumerate(all_tasks): + model_output: ModelOutput = model_output + if model_output.error_code != 0: + return create_error_response(model_output.error_code, model_output.text) + choices.append( + ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role="assistant", content=model_output.text), + finish_reason=model_output.finish_reason or "stop", + ) + ) + if model_output.usage: + task_usage = UsageInfo.parse_obj(model_output.usage) + for usage_key, usage_value in task_usage.dict().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return ChatCompletionResponse(model=model_name, choices=choices, usage=usage) + + +def get_api_server() -> APIServer: + api_server = global_system_app.get_component( + ComponentType.MODEL_API_SERVER, APIServer, default_component=None + ) + if not api_server: + global_system_app.register(APIServer) + return global_system_app.get_component(ComponentType.MODEL_API_SERVER, APIServer) + + +router = APIRouter() + + +@router.get("/v1/models", dependencies=[Depends(check_api_key)]) +async def get_available_models(api_server: APIServer = Depends(get_api_server)): + return await api_server.get_available_models() + + +@router.post("/v1/chat/completions", dependencies=[Depends(check_api_key)]) +async def create_chat_completion( + request: APIChatCompletionRequest, api_server: APIServer = Depends(get_api_server) +): + await api_server.get_model_instances_or_raise(request.model) + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + params = { + "model": request.model, + "messages": ModelMessage.to_dict_list( + ModelMessage.from_openai_messages(request.messages) + ), + "echo": False, + } + if request.temperature: + params["temperature"] = request.temperature + if request.top_p: + params["top_p"] = request.top_p + if request.max_tokens: + params["max_new_tokens"] = request.max_tokens + if request.stop: + params["stop"] = request.stop + if request.user: + params["user"] = request.user + + # TODO check token length + if request.stream: + generator = api_server.chat_completion_stream_generator( + request.model, params, request.n + ) + return StreamingResponse(generator, media_type="text/event-stream") + return await api_server.chat_completion_generate(request.model, params, request.n) + + +def _initialize_all(controller_addr: str, system_app: SystemApp): + from pilot.model.cluster import RemoteWorkerManager, ModelRegistryClient + from pilot.model.cluster.worker.manager import _DefaultWorkerManagerFactory + + if not system_app.get_component( + ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None + ): + # Register model registry if not exist + registry = ModelRegistryClient(controller_addr) + registry.name = ComponentType.MODEL_REGISTRY.value + system_app.register_instance(registry) + + registry = system_app.get_component( + ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None + ) + worker_manager = RemoteWorkerManager(registry) + + # Register worker manager component if not exist + system_app.get_component( + 
ComponentType.WORKER_MANAGER_FACTORY, + WorkerManagerFactory, + or_register_component=_DefaultWorkerManagerFactory, + worker_manager=worker_manager, + ) + # Register api server component if not exist + system_app.get_component( + ComponentType.MODEL_API_SERVER, APIServer, or_register_component=APIServer + ) + + +def initialize_apiserver( + controller_addr: str, + app=None, + system_app: SystemApp = None, + host: str = None, + port: int = None, + api_keys: List[str] = None, +): + global global_system_app + global api_settings + embedded_mod = True + if not app: + embedded_mod = False + app = FastAPI() + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], + allow_headers=["*"], + ) + + if not system_app: + system_app = SystemApp(app) + global_system_app = system_app + + if api_keys: + api_settings.api_keys = api_keys + + app.include_router(router, prefix="/api", tags=["APIServer"]) + + @app.exception_handler(APIServerException) + async def validation_apiserver_exception_handler(request, exc: APIServerException): + return create_error_response(exc.code, exc.message) + + @app.exception_handler(RequestValidationError) + async def validation_exception_handler(request, exc): + return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc)) + + _initialize_all(controller_addr, system_app) + + if not embedded_mod: + import uvicorn + + uvicorn.run(app, host=host, port=port, log_level="info") + + +def run_apiserver(): + parser = EnvArgumentParser() + env_prefix = "apiserver_" + apiserver_params: ModelAPIServerParameters = parser.parse_args_into_dataclass( + ModelAPIServerParameters, + env_prefixes=[env_prefix], + ) + setup_logging( + "pilot", + logging_level=apiserver_params.log_level, + logger_filename=apiserver_params.log_file, + ) + api_keys = None + if apiserver_params.api_keys: + api_keys = apiserver_params.api_keys.strip().split(",") + + initialize_apiserver( + apiserver_params.controller_addr, + host=apiserver_params.host, + port=apiserver_params.port, + api_keys=api_keys, + ) + + +if __name__ == "__main__": + run_apiserver() diff --git a/pilot/model/cluster/apiserver/tests/__init__.py b/pilot/model/cluster/apiserver/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/model/cluster/apiserver/tests/test_api.py b/pilot/model/cluster/apiserver/tests/test_api.py new file mode 100644 index 000000000..281a8aff6 --- /dev/null +++ b/pilot/model/cluster/apiserver/tests/test_api.py @@ -0,0 +1,248 @@ +import pytest +import pytest_asyncio +from aioresponses import aioresponses +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from httpx import AsyncClient, HTTPError + +from pilot.component import SystemApp +from pilot.utils.openai_utils import chat_completion_stream, chat_completion + +from pilot.model.cluster.apiserver.api import ( + api_settings, + initialize_apiserver, + ModelList, + UsageInfo, + ChatCompletionResponse, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + ChatCompletionResponseChoice, + DeltaMessage, +) +from pilot.model.cluster.tests.conftest import _new_cluster + +from pilot.model.cluster.worker.manager import _DefaultWorkerManagerFactory + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], + allow_headers=["*"], +) + + +@pytest_asyncio.fixture +async 
def system_app(): + return SystemApp(app) + + +@pytest_asyncio.fixture +async def client(request, system_app: SystemApp): + param = getattr(request, "param", {}) + api_keys = param.get("api_keys", []) + client_api_key = param.get("client_api_key") + if "num_workers" not in param: + param["num_workers"] = 2 + if "api_keys" in param: + del param["api_keys"] + headers = {} + if client_api_key: + headers["Authorization"] = "Bearer " + client_api_key + print(f"param: {param}") + if api_settings: + # Clear global api keys + api_settings.api_keys = [] + async with AsyncClient(app=app, base_url="http://test", headers=headers) as client: + async with _new_cluster(**param) as cluster: + worker_manager, model_registry = cluster + system_app.register(_DefaultWorkerManagerFactory, worker_manager) + system_app.register_instance(model_registry) + # print(f"Instances {model_registry.registry}") + initialize_apiserver(None, app, system_app, api_keys=api_keys) + yield client + + +@pytest.mark.asyncio +async def test_get_all_models(client: AsyncClient): + res = await client.get("/api/v1/models") + assert res.status_code == 200 + model_lists = ModelList.parse_obj(res.json()) + print(f"model list json: {res.json()}") + assert model_lists.object == "list" + assert len(model_lists.data) == 2 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client, expected_messages", + [ + ({"stream_messags": ["Hello", " world."]}, "Hello world."), + ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"), + ], + indirect=["client"], +) +async def test_chat_completions(client: AsyncClient, expected_messages): + chat_data = { + "model": "test-model-name-0", + "messages": [{"role": "user", "content": "Hello"}], + "stream": True, + } + full_text = "" + async for text in chat_completion_stream( + "/api/v1/chat/completions", chat_data, client + ): + full_text += text + assert full_text == expected_messages + + assert ( + await chat_completion("/api/v1/chat/completions", chat_data, client) + == expected_messages + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client, expected_messages, client_api_key", + [ + ( + {"stream_messags": ["Hello", " world."], "api_keys": ["abc"]}, + "Hello world.", + "abc", + ), + ({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"), + ], + indirect=["client"], +) +async def test_chat_completions_with_openai_lib_async_no_stream( + client: AsyncClient, expected_messages: str, client_api_key: str +): + import openai + + openai.api_key = client_api_key + openai.api_base = "http://test/api/v1" + + model_name = "test-model-name-0" + + with aioresponses() as mocked: + mock_message = {"text": expected_messages} + one_res = ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=expected_messages), + finish_reason="stop", + ) + data = ChatCompletionResponse( + model=model_name, choices=[one_res], usage=UsageInfo() + ) + mock_message = f"{data.json(exclude_unset=True, ensure_ascii=False)}\n\n" + # Mock http request + mocked.post( + "http://test/api/v1/chat/completions", status=200, body=mock_message + ) + completion = await openai.ChatCompletion.acreate( + model=model_name, + messages=[{"role": "user", "content": "Hello!
What is your name?"}], + ) + assert completion.choices[0].message.content == expected_messages + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client, expected_messages, client_api_key", + [ + ( + {"stream_messags": ["Hello", " world."], "api_keys": ["abc"]}, + "Hello world.", + "abc", + ), + ({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"), + ], + indirect=["client"], +) +async def test_chat_completions_with_openai_lib_async_stream( + client: AsyncClient, expected_messages: str, client_api_key: str +): + import openai + + openai.api_key = client_api_key + openai.api_base = "http://test/api/v1" + + model_name = "test-model-name-0" + + with aioresponses() as mocked: + mock_message = {"text": expected_messages} + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(content=expected_messages), + finish_reason="stop", + ) + chunk = ChatCompletionStreamResponse( + id=0, choices=[choice_data], model=model_name + ) + mock_message = f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + mocked.post( + "http://test/api/v1/chat/completions", + status=200, + body=mock_message, + content_type="text/event-stream", + ) + + stream_stream_resp = "" + async for stream_resp in await openai.ChatCompletion.acreate( + model=model_name, + messages=[{"role": "user", "content": "Hello! What is your name?"}], + stream=True, + ): + stream_stream_resp = stream_resp.choices[0]["delta"].get("content", "") + assert stream_stream_resp == expected_messages + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client, expected_messages, api_key_is_error", + [ + ( + { + "stream_messags": ["Hello", " world."], + "api_keys": ["abc", "xx"], + "client_api_key": "abc", + }, + "Hello world.", + False, + ), + ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。", False), + ( + {"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc", "xx"]}, + "你好,我是张三。", + True, + ), + ( + { + "stream_messags": ["你好,我是", "张三。"], + "api_keys": ["abc", "xx"], + "client_api_key": "error_api_key", + }, + "你好,我是张三。", + True, + ), + ], + indirect=["client"], +) +async def test_chat_completions_with_api_keys( + client: AsyncClient, expected_messages: str, api_key_is_error: bool +): + chat_data = { + "model": "test-model-name-0", + "messages": [{"role": "user", "content": "Hello"}], + "stream": True, + } + if api_key_is_error: + with pytest.raises(HTTPError): + await chat_completion("/api/v1/chat/completions", chat_data, client) + else: + assert ( + await chat_completion("/api/v1/chat/completions", chat_data, client) + == expected_messages + ) diff --git a/pilot/model/cluster/controller/controller.py b/pilot/model/cluster/controller/controller.py index 173c8c019..0006d91a0 100644 --- a/pilot/model/cluster/controller/controller.py +++ b/pilot/model/cluster/controller/controller.py @@ -66,7 +66,9 @@ class LocalModelController(BaseModelController): f"Get all instances with {model_name}, healthy_only: {healthy_only}" ) if not model_name: - return await self.registry.get_all_model_instances() + return await self.registry.get_all_model_instances( + healthy_only=healthy_only + ) else: return await self.registry.get_all_instances(model_name, healthy_only) @@ -98,8 +100,10 @@ class _RemoteModelController(BaseModelController): class ModelRegistryClient(_RemoteModelController, ModelRegistry): - async def get_all_model_instances(self) -> List[ModelInstance]: - return await self.get_all_instances() + async def get_all_model_instances( + self, healthy_only: bool = False + ) -> 
List[ModelInstance]: + return await self.get_all_instances(healthy_only=healthy_only) @sync_api_remote(path="/api/controller/models") def sync_get_all_instances( diff --git a/pilot/model/cluster/registry.py b/pilot/model/cluster/registry.py index 398882eb9..eb5f1e415 100644 --- a/pilot/model/cluster/registry.py +++ b/pilot/model/cluster/registry.py @@ -1,22 +1,37 @@ import random import threading import time +import logging from abc import ABC, abstractmethod from collections import defaultdict from datetime import datetime, timedelta -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import itertools +from pilot.component import BaseComponent, ComponentType, SystemApp from pilot.model.base import ModelInstance -class ModelRegistry(ABC): +logger = logging.getLogger(__name__) + + +class ModelRegistry(BaseComponent, ABC): """ Abstract base class for a model registry. It provides an interface for registering, deregistering, fetching instances, and sending heartbeats for instances. """ + name = ComponentType.MODEL_REGISTRY + + def __init__(self, system_app: SystemApp | None = None): + self.system_app = system_app + super().__init__(system_app) + + def init_app(self, system_app: SystemApp): + """Initialize the component with the main application.""" + self.system_app = system_app + @abstractmethod async def register_instance(self, instance: ModelInstance) -> bool: """ @@ -65,9 +80,11 @@ class ModelRegistry(ABC): """Fetch all instances of a given model. Optionally, fetch only the healthy instances.""" @abstractmethod - async def get_all_model_instances(self) -> List[ModelInstance]: + async def get_all_model_instances( + self, healthy_only: bool = False + ) -> List[ModelInstance]: """ - Fetch all instances of all models + Fetch all instances of all models. Optionally, fetch only the healthy instances. Returns: - List[ModelInstance]: A list of instances for all models. 
@@ -105,8 +122,12 @@ class ModelRegistry(ABC): class EmbeddedModelRegistry(ModelRegistry): def __init__( - self, heartbeat_interval_secs: int = 60, heartbeat_timeout_secs: int = 120 + self, + system_app: SystemApp | None = None, + heartbeat_interval_secs: int = 60, + heartbeat_timeout_secs: int = 120, ): + super().__init__(system_app) self.registry: Dict[str, List[ModelInstance]] = defaultdict(list) self.heartbeat_interval_secs = heartbeat_interval_secs self.heartbeat_timeout_secs = heartbeat_timeout_secs @@ -180,9 +201,14 @@ class EmbeddedModelRegistry(ModelRegistry): instances = [ins for ins in instances if ins.healthy == True] return instances - async def get_all_model_instances(self) -> List[ModelInstance]: - print(self.registry) - return list(itertools.chain(*self.registry.values())) + async def get_all_model_instances( + self, healthy_only: bool = False + ) -> List[ModelInstance]: + logger.debug(f"Current registry metadata:\n{self.registry}") + instances = list(itertools.chain(*self.registry.values())) + if healthy_only: + instances = [ins for ins in instances if ins.healthy == True] + return instances async def send_heartbeat(self, instance: ModelInstance) -> bool: _, exist_ins = self._get_instances( diff --git a/pilot/model/cluster/tests/__init__.py b/pilot/model/cluster/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/model/cluster/worker/tests/base_tests.py b/pilot/model/cluster/tests/conftest.py similarity index 71% rename from pilot/model/cluster/worker/tests/base_tests.py rename to pilot/model/cluster/tests/conftest.py index 21821d9f9..f614387ac 100644 --- a/pilot/model/cluster/worker/tests/base_tests.py +++ b/pilot/model/cluster/tests/conftest.py @@ -6,6 +6,7 @@ from pilot.model.parameter import ModelParameters, ModelWorkerParameters, Worker from pilot.model.base import ModelOutput from pilot.model.cluster.worker_base import ModelWorker from pilot.model.cluster.worker.manager import ( + WorkerManager, LocalWorkerManager, RegisterFunc, DeregisterFunc, @@ -13,6 +14,23 @@ from pilot.model.cluster.worker.manager import ( ApplyFunction, ) +from pilot.model.base import ModelInstance +from pilot.model.cluster.registry import ModelRegistry, EmbeddedModelRegistry + + +@pytest.fixture +def model_registry(request): + return EmbeddedModelRegistry() + + +@pytest.fixture +def model_instance(): + return ModelInstance( + model_name="test_model", + host="192.168.1.1", + port=5000, + ) + class MockModelWorker(ModelWorker): def __init__( @@ -51,8 +69,10 @@ class MockModelWorker(ModelWorker): raise Exception("Stop worker error for mock") def generate_stream(self, params: Dict) -> Iterator[ModelOutput]: + full_text = "" for msg in self.stream_messags: - yield ModelOutput(text=msg, error_code=0) + full_text += msg + yield ModelOutput(text=full_text, error_code=0) def generate(self, params: Dict) -> ModelOutput: output = None @@ -67,6 +87,8 @@ class MockModelWorker(ModelWorker): _TEST_MODEL_NAME = "vicuna-13b-v1.5" _TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5" +ClusterType = Tuple[WorkerManager, ModelRegistry] + def _new_worker_params( model_name: str = _TEST_MODEL_NAME, @@ -85,7 +107,9 @@ def _create_workers( worker_type: str = WorkerType.LLM.value, stream_messags: List[str] = None, embeddings: List[List[float]] = None, -) -> List[Tuple[ModelWorker, ModelWorkerParameters]]: + host: str = "127.0.0.1", + start_port=8001, +) -> List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]]: workers = [] for i in range(num_workers): model_name = 
f"test-model-name-{i}" @@ -98,10 +122,16 @@ def _create_workers( stream_messags=stream_messags, embeddings=embeddings, ) + model_instance = ModelInstance( + model_name=WorkerType.to_worker_key(model_name, worker_type), + host=host, + port=start_port + i, + healthy=True, + ) worker_params = _new_worker_params( model_name, model_path, worker_type=worker_type ) - workers.append((worker, worker_params)) + workers.append((worker, worker_params, model_instance)) return workers @@ -127,12 +157,12 @@ async def _start_worker_manager(**kwargs): model_registry=model_registry, ) - for worker, worker_params in _create_workers( + for worker, worker_params, model_instance in _create_workers( num_workers, error_worker, stop_error, stream_messags, embeddings ): worker_manager.add_worker(worker, worker_params) if workers: - for worker, worker_params in workers: + for worker, worker_params, model_instance in workers: worker_manager.add_worker(worker, worker_params) if start: @@ -143,6 +173,15 @@ async def _start_worker_manager(**kwargs): await worker_manager.stop() +async def _create_model_registry( + workers: List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]] +) -> ModelRegistry: + registry = EmbeddedModelRegistry() + for _, _, inst in workers: + assert await registry.register_instance(inst) == True + return registry + + @pytest_asyncio.fixture async def manager_2_workers(request): param = getattr(request, "param", {}) @@ -166,3 +205,27 @@ async def manager_2_embedding_workers(request): ) async with _start_worker_manager(workers=workers, **param) as worker_manager: yield (worker_manager, workers) + + +@asynccontextmanager +async def _new_cluster(**kwargs) -> ClusterType: + num_workers = kwargs.get("num_workers", 0) + workers = _create_workers( + num_workers, stream_messags=kwargs.get("stream_messags", []) + ) + if "num_workers" in kwargs: + del kwargs["num_workers"] + registry = await _create_model_registry( + workers, + ) + async with _start_worker_manager(workers=workers, **kwargs) as worker_manager: + yield (worker_manager, registry) + + +@pytest_asyncio.fixture +async def cluster_2_workers(request): + param = getattr(request, "param", {}) + workers = _create_workers(2) + registry = await _create_model_registry(workers) + async with _start_worker_manager(workers=workers, **param) as worker_manager: + yield (worker_manager, registry) diff --git a/pilot/model/cluster/worker/default_worker.py b/pilot/model/cluster/worker/default_worker.py index 5caa2ee7e..44a476f20 100644 --- a/pilot/model/cluster/worker/default_worker.py +++ b/pilot/model/cluster/worker/default_worker.py @@ -256,15 +256,22 @@ class DefaultModelWorker(ModelWorker): return params, model_context, generate_stream_func, model_span def _handle_output(self, output, previous_response, model_context): + finish_reason = None + usage = None if isinstance(output, dict): finish_reason = output.get("finish_reason") + usage = output.get("usage") output = output["text"] if finish_reason is not None: logger.info(f"finish_reason: {finish_reason}") incremental_output = output[len(previous_response) :] print(incremental_output, end="", flush=True) model_output = ModelOutput( - text=output, error_code=0, model_context=model_context + text=output, + error_code=0, + model_context=model_context, + finish_reason=finish_reason, + usage=usage, ) return model_output, incremental_output, output diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py index a76fa6685..2dcfb086e 100644 --- 
a/pilot/model/cluster/worker/manager.py +++ b/pilot/model/cluster/worker/manager.py @@ -99,9 +99,7 @@ class LocalWorkerManager(WorkerManager): ) def _worker_key(self, worker_type: str, model_name: str) -> str: - if isinstance(worker_type, WorkerType): - worker_type = worker_type.value - return f"{model_name}@{worker_type}" + return WorkerType.to_worker_key(model_name, worker_type) async def run_blocking_func(self, func, *args): if asyncio.iscoroutinefunction(func): diff --git a/pilot/model/cluster/worker/tests/test_manager.py b/pilot/model/cluster/worker/tests/test_manager.py index 919e64f99..681fb49a3 100644 --- a/pilot/model/cluster/worker/tests/test_manager.py +++ b/pilot/model/cluster/worker/tests/test_manager.py @@ -3,7 +3,7 @@ import pytest from typing import List, Iterator, Dict, Tuple from dataclasses import asdict from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType -from pilot.model.base import ModelOutput, WorkerApplyType +from pilot.model.base import ModelOutput, WorkerApplyType, ModelInstance from pilot.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest from pilot.model.cluster.worker_base import ModelWorker from pilot.model.cluster.manager_base import WorkerRunData @@ -14,7 +14,7 @@ from pilot.model.cluster.worker.manager import ( SendHeartbeatFunc, ApplyFunction, ) -from pilot.model.cluster.worker.tests.base_tests import ( +from pilot.model.cluster.tests.conftest import ( MockModelWorker, manager_2_workers, manager_with_2_workers, @@ -216,7 +216,7 @@ async def test__remove_worker(): workers = _create_workers(3) async with _start_worker_manager(workers=workers, stop=False) as manager: assert len(manager.workers) == 3 - for _, worker_params in workers: + for _, worker_params, _ in workers: manager._remove_worker(worker_params) not_exist_parmas = _new_worker_params( model_name="this is a not exist worker params" @@ -229,7 +229,7 @@ async def test__remove_worker(): async def test_model_startup(mock_build_worker): async with _start_worker_manager() as manager: workers = _create_workers(1) - worker, worker_params = workers[0] + worker, worker_params, model_instance = workers[0] mock_build_worker.return_value = worker req = WorkerStartupRequest( @@ -245,7 +245,7 @@ async def test_model_startup(mock_build_worker): async with _start_worker_manager() as manager: workers = _create_workers(1, error_worker=True) - worker, worker_params = workers[0] + worker, worker_params, model_instance = workers[0] mock_build_worker.return_value = worker req = WorkerStartupRequest( host="127.0.0.1", @@ -263,7 +263,7 @@ async def test_model_startup(mock_build_worker): async def test_model_shutdown(mock_build_worker): async with _start_worker_manager(start=False, stop=False) as manager: workers = _create_workers(1) - worker, worker_params = workers[0] + worker, worker_params, model_instance = workers[0] mock_build_worker.return_value = worker req = WorkerStartupRequest( @@ -298,7 +298,7 @@ async def test_get_model_instances(is_async): workers = _create_workers(3) async with _start_worker_manager(workers=workers, stop=False) as manager: assert len(manager.workers) == 3 - for _, worker_params in workers: + for _, worker_params, _ in workers: model_name = worker_params.model_name worker_type = worker_params.worker_type if is_async: @@ -326,7 +326,7 @@ async def test__simple_select( ] ): manager, workers = manager_with_2_workers - for _, worker_params in workers: + for _, worker_params, _ in workers: model_name = worker_params.model_name worker_type = 
worker_params.worker_type instances = await manager.get_model_instances(worker_type, model_name) @@ -351,7 +351,7 @@ async def test_select_one_instance( ], ): manager, workers = manager_with_2_workers - for _, worker_params in workers: + for _, worker_params, _ in workers: model_name = worker_params.model_name worker_type = worker_params.worker_type if is_async: @@ -376,7 +376,7 @@ async def test__get_model( ], ): manager, workers = manager_with_2_workers - for _, worker_params in workers: + for _, worker_params, _ in workers: model_name = worker_params.model_name worker_type = worker_params.worker_type params = {"model": model_name} @@ -403,13 +403,13 @@ async def test_generate_stream( expected_messages: str, ): manager, workers = manager_with_2_workers - for _, worker_params in workers: + for _, worker_params, _ in workers: model_name = worker_params.model_name worker_type = worker_params.worker_type params = {"model": model_name} text = "" async for out in manager.generate_stream(params): - text += out.text + text = out.text assert text == expected_messages @@ -417,8 +417,8 @@ async def test_generate_stream( @pytest.mark.parametrize( "manager_with_2_workers, expected_messages", [ - ({"stream_messags": ["Hello", " world."]}, " world."), - ({"stream_messags": ["你好,我是", "张三。"]}, "张三。"), + ({"stream_messags": ["Hello", " world."]}, "Hello world."), + ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"), ], indirect=["manager_with_2_workers"], ) @@ -429,7 +429,7 @@ async def test_generate( expected_messages: str, ): manager, workers = manager_with_2_workers - for _, worker_params in workers: + for _, worker_params, _ in workers: model_name = worker_params.model_name worker_type = worker_params.worker_type params = {"model": model_name} @@ -454,7 +454,7 @@ async def test_embeddings( is_async: bool, ): manager, workers = manager_2_embedding_workers - for _, worker_params in workers: + for _, worker_params, _ in workers: model_name = worker_params.model_name worker_type = worker_params.worker_type params = {"model": model_name, "input": ["hello", "world"]} @@ -472,7 +472,7 @@ async def test_parameter_descriptions( ] ): manager, workers = manager_with_2_workers - for _, worker_params in workers: + for _, worker_params, _ in workers: model_name = worker_params.model_name worker_type = worker_params.worker_type params = await manager.parameter_descriptions(worker_type, model_name) diff --git a/pilot/model/conversation.py b/pilot/model/conversation.py index b3674e946..5d4309d9f 100644 --- a/pilot/model/conversation.py +++ b/pilot/model/conversation.py @@ -339,6 +339,27 @@ register_conv_template( ) ) + +# codellama template +# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212 +# reference2: https://github.com/eosphoros-ai/DB-GPT-Hub/blob/main/README.zh.md +register_conv_template( + Conversation( + name="codellama", + system="[INST] <<SYS>>\nI want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me. Below is an instruction that describes a task, Write a response that appropriately completes the request." 
+ "If you don't know the answer to the request, please don't share false information.\n<>\n\n", + roles=("[INST]", "[/INST]"), + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + stop_token_ids=[2], + system_formatter=lambda msg: f"[INST] <>\n{msg}\n<>\n\n", + ) +) + + # Alpaca default template register_conv_template( Conversation( diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py index 1580e8863..e2deeaa02 100644 --- a/pilot/model/model_adapter.py +++ b/pilot/model/model_adapter.py @@ -45,6 +45,10 @@ _OLD_MODELS = [ "llama-cpp", "proxyllm", "gptj-6b", + "codellama-13b-sql-sft", + "codellama-7b", + "codellama-7b-sql-sft", + "codellama-13b", ] @@ -148,8 +152,12 @@ class LLMModelAdaper: conv.append_message(conv.roles[1], content) else: raise ValueError(f"Unknown role: {role}") + if system_messages: - conv.set_system_message("".join(system_messages)) + if isinstance(conv, Conversation): + conv.set_system_message("".join(system_messages)) + else: + conv.update_system_message("".join(system_messages)) # Add a blank message for the assistant. conv.append_message(conv.roles[1], None) @@ -459,7 +467,8 @@ register_conv_template( sep="\n", sep2="", stop_str=["", "[UNK]"], - ) + ), + override=True, ) # source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227 register_conv_template( @@ -474,7 +483,8 @@ register_conv_template( sep="###", sep2="", stop_str=["", "[UNK]"], - ) + ), + override=True, ) # source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242 register_conv_template( @@ -487,5 +497,6 @@ register_conv_template( sep="", sep2="", stop_str=["", "<|endoftext|>"], - ) + ), + override=True, ) diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py index ea81ec091..e21de1c42 100644 --- a/pilot/model/parameter.py +++ b/pilot/model/parameter.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- + import os from dataclasses import dataclass, field from enum import Enum -from typing import Dict, Optional +from typing import Dict, Optional, Union, Tuple from pilot.model.conversation import conv_templates from pilot.utils.parameter_utils import BaseParameters @@ -19,6 +20,35 @@ class WorkerType(str, Enum): def values(): return [item.value for item in WorkerType] + @staticmethod + def to_worker_key(worker_name, worker_type: Union[str, "WorkerType"]) -> str: + """Generate worker key from worker name and worker type + + Args: + worker_name (str): Worker name(eg., chatglm2-6b) + worker_type (Union[str, "WorkerType"]): Worker type(eg., 'llm', or [`WorkerType.LLM`]) + + Returns: + str: Generated worker key + """ + if "@" in worker_name: + raise ValueError(f"Invaild symbol '@' in your worker name {worker_name}") + if isinstance(worker_type, WorkerType): + worker_type = worker_type.value + return f"{worker_name}@{worker_type}" + + @staticmethod + def parse_worker_key(worker_key: str) -> Tuple[str, str]: + """Parse worker name and worker type from worker key + + Args: + worker_key (str): Worker key generated by [`WorkerType.to_worker_key`] + + Returns: + Tuple[str, str]: Worker name and worker type + """ + return tuple(worker_key.split("@")) + @dataclass class ModelControllerParameters(BaseParameters): @@ -60,6 +90,56 @@ class ModelControllerParameters(BaseParameters): ) +@dataclass +class ModelAPIServerParameters(BaseParameters): + host: Optional[str] = field( + default="0.0.0.0", metadata={"help": "Model API server 
deploy host"} + ) + port: Optional[int] = field( + default=8100, metadata={"help": "Model API server deploy port"} + ) + daemon: Optional[bool] = field( + default=False, metadata={"help": "Run Model API server in background"} + ) + controller_addr: Optional[str] = field( + default="http://127.0.0.1:8000", + metadata={"help": "The Model controller address to connect"}, + ) + + api_keys: Optional[str] = field( + default=None, + metadata={"help": "Optional list of comma separated API keys"}, + ) + + log_level: Optional[str] = field( + default=None, + metadata={ + "help": "Logging level", + "valid_values": [ + "FATAL", + "ERROR", + "WARNING", + "WARNING", + "INFO", + "DEBUG", + "NOTSET", + ], + }, + ) + log_file: Optional[str] = field( + default="dbgpt_model_apiserver.log", + metadata={ + "help": "The filename to store log", + }, + ) + tracer_file: Optional[str] = field( + default="dbgpt_model_apiserver_tracer.jsonl", + metadata={ + "help": "The filename to store tracer span records", + }, + ) + + @dataclass class BaseModelParameters(BaseParameters): model_name: str = field(metadata={"help": "Model name", "tags": "fixed"}) diff --git a/pilot/scene/base_message.py b/pilot/scene/base_message.py index eeb42a285..12a72e909 100644 --- a/pilot/scene/base_message.py +++ b/pilot/scene/base_message.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict, List, Tuple, Optional +from typing import Any, Dict, List, Tuple, Optional, Union from pydantic import BaseModel, Field, root_validator @@ -70,14 +70,6 @@ class SystemMessage(BaseMessage): return "system" -class ModelMessage(BaseModel): - """Type of message that interaction between dbgpt-server and llm-server""" - - """Similar to openai's message format""" - role: str - content: str - - class ModelMessageRoleType: """ "Type of ModelMessage role""" @@ -87,6 +79,45 @@ class ModelMessageRoleType: VIEW = "view" +class ModelMessage(BaseModel): + """Type of message that interaction between dbgpt-server and llm-server""" + + """Similar to openai's message format""" + role: str + content: str + + @staticmethod + def from_openai_messages( + messages: Union[str, List[Dict[str, str]]] + ) -> List["ModelMessage"]: + """Openai message format to current ModelMessage format""" + if isinstance(messages, str): + return [ModelMessage(role=ModelMessageRoleType.HUMAN, content=messages)] + result = [] + for message in messages: + msg_role = message["role"] + content = message["content"] + if msg_role == "system": + result.append( + ModelMessage(role=ModelMessageRoleType.SYSTEM, content=content) + ) + elif msg_role == "user": + result.append( + ModelMessage(role=ModelMessageRoleType.HUMAN, content=content) + ) + elif msg_role == "assistant": + result.append( + ModelMessage(role=ModelMessageRoleType.AI, content=content) + ) + else: + raise ValueError(f"Unknown role: {msg_role}") + return result + + @staticmethod + def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]: + return list(map(lambda m: m.dict(), messages)) + + class Generation(BaseModel): """Output of a single generation.""" diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index e6c3d9056..4d40a8eef 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -1,3 +1,4 @@ +import json import os from typing import Dict, List @@ -110,12 +111,14 @@ class ChatKnowledge(BaseChat): list(map(lambda doc: doc.metadata, docs)), "source" ) - 
self.current_message.knowledge_source = self.sources - if not docs: - raise ValueError( - "you have no knowledge space, please add your knowledge space" - ) - context = [d.page_content for d in docs] + if not docs: + print("no relevant docs to retrieve") + context = "no relevant docs to retrieve" + # raise ValueError( + # "you have no knowledge space, please add your knowledge space" + # ) + else: + context = [d.page_content for d in docs] context = context[: self.max_token] relations = list( set([os.path.basename(str(d.metadata.get("source", ""))) for d in docs]) @@ -128,17 +131,26 @@ class ChatKnowledge(BaseChat): return input_values def parse_source_view(self, sources: List): - html_title = f"##### **References:**" - lines = "" + """ + Build the knowledge reference view message for the web UI: + { + "title": "References", + "reference": { + "name": "aa.pdf", + "pages": ["1", "2", "3"] + } + } + """ + references = {"title": "References", "reference": {}} for item in sources: source = item["source"] if "source" in item else "" - pages = ",".join(item["pages"]) if "pages" in item else "" - lines += f"{source}" + references["reference"]["name"] = source + pages = item["pages"] if "pages" in item else [] if len(pages) > 0: - lines += f", **pages**:{pages}\n\n" - else: - lines += "\n\n" - html = f"""{html_title}\n{lines}""" + references["reference"]["pages"] = pages + html = ( + f"""{json.dumps(references, ensure_ascii=False)}""" + ) return html def merge_by_key(self, data, key): diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index cb486021b..64b72739b 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -215,6 +215,16 @@ class Llama2ChatAdapter(BaseChatAdpter): return get_conv_template("llama-2") + +class CodeLlamaChatAdapter(BaseChatAdpter): + """The model ChatAdapter for codellama.""" + + def match(self, model_path: str): + return "codellama" in model_path.lower() + + def get_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("codellama") + + class BaichuanChatAdapter(BaseChatAdpter): def match(self, model_path: str): return "baichuan" in model_path.lower() @@ -268,6 +278,7 @@ register_llm_model_chat_adapter(FalconChatAdapter) register_llm_model_chat_adapter(GorillaChatAdapter) register_llm_model_chat_adapter(GPT4AllChatAdapter) register_llm_model_chat_adapter(Llama2ChatAdapter) +register_llm_model_chat_adapter(CodeLlamaChatAdapter) register_llm_model_chat_adapter(BaichuanChatAdapter) register_llm_model_chat_adapter(WizardLMChatAdapter) register_llm_model_chat_adapter(LlamaCppChatAdapter) diff --git a/pilot/utils/openai_utils.py b/pilot/utils/openai_utils.py new file mode 100644 index 000000000..6577d3abf --- /dev/null +++ b/pilot/utils/openai_utils.py @@ -0,0 +1,99 @@ +from typing import Dict, Any, AsyncIterator, Awaitable, Callable, Optional +import httpx +import asyncio +import logging +import json + +logger = logging.getLogger(__name__) +MessageCaller = Callable[[str], Awaitable[None]] + + +async def _do_chat_completion( + url: str, + chat_data: Dict[str, Any], + client: httpx.AsyncClient, + headers: Dict[str, Any] = {}, + timeout: int = 60, + caller: Optional[MessageCaller] = None, +) -> AsyncIterator[str]: + async with client.stream( + "POST", + url, + headers=headers, + json=chat_data, + timeout=timeout, + ) as res: + if res.status_code != 200: + error_message = await res.aread() + if error_message: + error_message = error_message.decode("utf-8") + logger.error( + f"Request failed with status {res.status_code}. 
Error: {error_message}" + ) + raise httpx.RequestError( + f"Request failed with status {res.status_code}", + request=res.request, + ) + async for line in res.aiter_lines(): + if line: + if not line.startswith("data: "): + if caller: + await caller(line) + yield line + else: + decoded_line = line.split("data: ", 1)[1] + if decoded_line.lower().strip() != "[DONE]".lower(): + obj = json.loads(decoded_line) + if obj["choices"][0]["delta"].get("content") is not None: + text = obj["choices"][0]["delta"].get("content") + if caller: + await caller(text) + yield text + await asyncio.sleep(0.02) + + +async def chat_completion_stream( + url: str, + chat_data: Dict[str, Any], + client: Optional[httpx.AsyncClient] = None, + headers: Dict[str, Any] = {}, + timeout: int = 60, + caller: Optional[MessageCaller] = None, +) -> Iterator[str]: + if client: + async for text in _do_chat_completion( + url, + chat_data, + client=client, + headers=headers, + timeout=timeout, + caller=caller, + ): + yield text + else: + async with httpx.AsyncClient() as client: + async for text in _do_chat_completion( + url, + chat_data, + client=client, + headers=headers, + timeout=timeout, + caller=caller, + ): + yield text + + +async def chat_completion( + url: str, + chat_data: Dict[str, Any], + client: Optional[httpx.AsyncClient] = None, + headers: Dict[str, Any] = {}, + timeout: int = 60, + caller: Optional[MessageCaller] = None, +) -> str: + full_text = "" + async for text in chat_completion_stream( + url, chat_data, client, headers=headers, timeout=timeout, caller=caller + ): + full_text += text + return full_text diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt index 072b527f1..d1a98ed49 100644 --- a/requirements/dev-requirements.txt +++ b/requirements/dev-requirements.txt @@ -8,6 +8,7 @@ pytest-integration pytest-mock pytest-recording pytesseract==0.3.10 +aioresponses # python code format black # for git hooks