diff --git a/.env.template b/.env.template
index e03650033..ba7f752db 100644
--- a/.env.template
+++ b/.env.template
@@ -55,6 +55,17 @@ QUANTIZE_8bit=True
## Model path
# llama_cpp_model_path=/data/models/TheBloke/vicuna-13B-v1.5-GGUF/vicuna-13b-v1.5.Q4_K_M.gguf
+### LLM cache
+## Enable Model cache
+# MODEL_CACHE_ENABLE=True
+## The storage type of model cache, now supports: memory, disk
+# MODEL_CACHE_STORAGE_TYPE=disk
+## The max cache data in memory, we always store cache data in memory fist for high speed.
+# MODEL_CACHE_MAX_MEMORY_MB=256
+## The dir to save cache data, this configuration is only valid when MODEL_CACHE_STORAGE_TYPE=disk
+## The default dir is pilot/data/model_cache
+# MODEL_CACHE_STORAGE_DISK_DIR=
+
#*******************************************************************#
#** EMBEDDING SETTINGS **#
#*******************************************************************#
diff --git a/README.md b/README.md
index 61fdf994d..09d67bd13 100644
--- a/README.md
+++ b/README.md
@@ -13,9 +13,6 @@
-
-
-
@@ -34,12 +31,22 @@
-[**简体中文**](README.zh.md) |[**Discord**](https://discord.gg/nASQyBjvY) |[**Documents**](https://db-gpt.readthedocs.io/en/latest/)|[**Wechat**](https://github.com/eosphoros-ai/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC)|[**Community**](https://github.com/eosphoros-ai/community)
+[**简体中文**](README.zh.md) | [**Discord**](https://discord.gg/nASQyBjvY) | [**Documents**](https://db-gpt.readthedocs.io/en/latest/) | [**Wechat**](https://github.com/eosphoros-ai/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC) | [**Community**](https://github.com/eosphoros-ai/community)
## What is DB-GPT?
-DB-GPT is an experimental open-source project that uses localized GPT large models to interact with your data and environment. With this solution, you can be assured that there is no risk of data leakage, and your data is 100% private and secure.
+DB-GPT is an open-source framework designed for the realm of large language models (LLMs) within the database field. Its primary purpose is to provide infrastructure that simplifies and streamlines the development of database-related applications. This is accomplished through the development of various technical capabilities, including:
+
+1. **SMMF(Service-oriented Multi-model Management Framework)**
+2. **Text2SQL Fine-tuning**
+3. **RAG(Retrieval Augmented Generation) framework and optimization**
+4. **Data-Driven Agents framework collaboration**
+5. **GBI(Generative Business intelligence)**
+
+DB-GPT simplifies the creation of these applications based on large language models (LLMs) and databases.
+
+In the era of Data 3.0, enterprises and developers can take the ability to create customized applications with minimal coding, which harnesses the power of large language models (LLMs) and databases.
## Contents
@@ -57,16 +64,6 @@ DB-GPT is an experimental open-source project that uses localized GPT large mode
Run on an RTX 4090 GPU.
##### Chat Excel

-##### Chat Plugin
-
-##### LLM Management
-
-##### FastChat && vLLM
-
-##### Trace
-
-##### Chat Knowledge
-
## Install

@@ -96,26 +93,26 @@ Run on an RTX 4090 GPU.
## Features
-Currently, we have released multiple key features, which are listed below to demonstrate our current capabilities:
-- Private KBQA & data processing
+At present, we have introduced several key features to showcase our current capabilities:
+- **Private Domain Q&A & Data Processing**
- The DB-GPT project offers a range of features to enhance knowledge base construction and enable efficient storage and retrieval of both structured and unstructured data. These include built-in support for uploading multiple file formats, the ability to integrate plug-ins for custom data extraction, and unified vector storage and retrieval capabilities for managing large volumes of information.
+ The DB-GPT project offers a range of functionalities designed to improve knowledge base construction and enable efficient storage and retrieval of both structured and unstructured data. These functionalities include built-in support for uploading multiple file formats, the ability to integrate custom data extraction plug-ins, and unified vector storage and retrieval capabilities for effectively managing large volumes of information.
-- Multiple data sources & visualization
-
- The DB-GPT project enables seamless natural language interaction with various data sources, including Excel, databases, and data warehouses. It facilitates effortless querying and retrieval of information from these sources, allowing users to engage in intuitive conversations and obtain insights. Additionally, DB-GPT supports the generation of analysis reports, providing users with valuable summaries and interpretations of the data.
+- **Multi-Data Source & GBI(Generative Business intelligence)**
-- Multi-Agents&Plugins
+ The DB-GPT project facilitates seamless natural language interaction with diverse data sources, including Excel, databases, and data warehouses. It simplifies the process of querying and retrieving information from these sources, empowering users to engage in intuitive conversations and gain insights. Moreover, DB-GPT supports the generation of analytical reports, providing users with valuable data summaries and interpretations.
- It supports custom plug-ins to perform tasks, natively supports the Auto-GPT plug-in model, and the Agents protocol adopts the Agent Protocol standard.
+- **Multi-Agents&Plugins**
-- Fine-tuning text2SQL
+ It offers support for custom plug-ins to perform various tasks and natively integrates the Auto-GPT plug-in model. The Agents protocol adheres to the Agent Protocol standard.
- An automated fine-tuning lightweight framework built around large language models, Text2SQL data sets, LoRA/QLoRA/Pturning, and other fine-tuning methods, making TextSQL fine-tuning as convenient as an assembly line. [DB-GPT-Hub](https://github.com/eosphoros-ai/DB-GPT-Hub)
+- **Automated Fine-tuning text2SQL**
-- Multi LLMs Support, Supports multiple large language models, currently supporting
+ We've also developed an automated fine-tuning lightweight framework centred on large language models (LLMs), Text2SQL datasets, LoRA/QLoRA/Pturning, and other fine-tuning methods. This framework simplifies Text-to-SQL fine-tuning, making it as straightforward as an assembly line process. [DB-GPT-Hub](https://github.com/eosphoros-ai/DB-GPT-Hub)
- Massive model support, including dozens of large language models such as open source and API agents. Such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, etc.
+- **SMMF(Service-oriented Multi-model Management Framework)**
+
+ We offer extensive model support, including dozens of large language models (LLMs) from both open-source and API agents, such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, and many more.
- [Vicuna](https://huggingface.co/Tribbiani/vicuna-13b)
- [vicuna-13b-v1.5](https://huggingface.co/lmsys/vicuna-13b-v1.5)
- [LLama2](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
@@ -126,22 +123,6 @@ Currently, we have released multiple key features, which are listed below to dem
- [falcon-40b](https://huggingface.co/tiiuae/falcon-40b)
- [internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b)
- [Qwen-7B-Chat/Qwen-14B-Chat](https://huggingface.co/Qwen/)
- - [RWKV-4-Raven](https://huggingface.co/BlinkDL/rwkv-4-raven)
- - [CAMEL-13B-Combined-Data](https://huggingface.co/camel-ai/CAMEL-13B-Combined-Data)
- - [dolly-v2-12b](https://huggingface.co/databricks/dolly-v2-12b)
- - [h2ogpt-gm-oasst1-en-2048-open-llama-7b](https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b)
- - [fastchat-t5-3b-v1.0](https://huggingface.co/lmsys/fastchat-t5)
- - [mpt-7b-chat](https://huggingface.co/mosaicml/mpt-7b-chat)
- - [gpt4all-13b-snoozy](https://huggingface.co/nomic-ai/gpt4all-13b-snoozy)
- - [Nous-Hermes-13b](https://huggingface.co/NousResearch/Nous-Hermes-13b)
- - [codet5p-6b](https://huggingface.co/Salesforce/codet5p-6b)
- - [guanaco-33b-merged](https://huggingface.co/timdettmers/guanaco-33b-merged)
- - [WizardLM-13B-V1.0](https://huggingface.co/WizardLM/WizardLM-13B-V1.0)
- - [WizardLM/WizardCoder-15B-V1.0](https://huggingface.co/WizardLM/WizardCoder-15B-V1.0)
- - [Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat)
- - [OpenLLaMa OpenInstruct](https://huggingface.co/VMware/open-llama-7b-open-instruct)
-
- Etc.
- Support API Proxy LLMs
- [x] [ChatGPT](https://api.openai.com/)
@@ -149,9 +130,9 @@ Currently, we have released multiple key features, which are listed below to dem
- [x] [Wenxin](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
- [x] [ChatGLM](http://open.bigmodel.cn/)
-- Privacy and security
+- **Privacy and Security**
- The privacy and security of data are ensured through various technologies, such as privatized large models and proxy desensitization.
+ We ensure the privacy and security of data through the implementation of various technologies, including privatized large models and proxy desensitization.
- Support Datasources
@@ -177,43 +158,37 @@ Currently, we have released multiple key features, which are listed below to dem
| [StarRocks](https://github.com/StarRocks/starrocks) | No | TODO |
## Introduction
-The architecture of the entire DB-GPT is shown.
+The architecture of DB-GPT is shown in the following figure:
-The core capabilities mainly consist of the following parts:
-1. Multi-Models: Support multi-LLMs, such as LLaMA/LLaMA2、CodeLLaMA、ChatGLM, QWen、Vicuna and proxy model ChatGPT、Baichuan、tongyi、wenxin etc
-2. Knowledge-Based QA: You can perform high-quality intelligent Q&A based on local documents such as PDF, word, excel, and other data.
-3. Embedding: Unified data vector storage and indexing, Embed data as vectors and store them in vector databases, providing content similarity search.
-4. Multi-Datasources: Used to connect different modules and data sources to achieve data flow and interaction.
-5. Multi-Agents: Provides Agent and plugin mechanisms, allowing users to customize and enhance the system's behavior.
-6. Privacy & Secure: You can be assured that there is no risk of data leakage, and your data is 100% private and secure.
-7. Text2SQL: We enhance the Text-to-SQL performance by applying Supervised Fine-Tuning (SFT) on large language models
-
-### RAG-IN-Action
-
-
-
+The core capabilities primarily consist of the following components:
+1. Multi-Models: We support multiple Large Language Models (LLMs) such as LLaMA/LLaMA2, CodeLLaMA, ChatGLM, QWen, Vicuna, and proxy models like ChatGPT, Baichuan, Tongyi, Wenxin, and more.
+2. Knowledge-Based QA: Our system enables high-quality intelligent Q&A based on local documents such as PDFs, Word documents, Excel files, and other data sources.
+3. Embedding: We offer unified data vector storage and indexing. Data is embedded as vectors and stored in vector databases, allowing for content similarity search.
+4. Multi-Datasources: This feature connects different modules and data sources, facilitating data flow and interaction.
+5. Multi-Agents: Our platform provides Agent and plugin mechanisms, empowering users to customize and enhance the system's behaviour.
+6. Privacy & Security: Rest assured that there is no risk of data leakage, and your data is 100% private and secure.
+7. Text2SQL: We enhance Text-to-SQL performance through Supervised Fine-Tuning (SFT) applied to Large Language Models (LLMs).
### SubModule
-- [DB-GPT-Hub](https://github.com/eosphoros-ai/DB-GPT-Hub) Text-to-SQL performance by applying Supervised Fine-Tuning (SFT) on large language models.
-- [DB-GPT-Plugins](https://github.com/eosphoros-ai/DB-GPT-Plugins) DB-GPT Plugins Can run autogpt plugin directly
+- [DB-GPT-Hub](https://github.com/eosphoros-ai/DB-GPT-Hub) Text-to-SQL workflow with high performance by applying Supervised Fine-Tuning (SFT) on Large Language Models (LLMs).
+- [DB-GPT-Plugins](https://github.com/eosphoros-ai/DB-GPT-Plugins) DB-GPT Plugins that can run Auto-GPT plugin directly
- [DB-GPT-Web](https://github.com/eosphoros-ai/DB-GPT-Web) ChatUI for DB-GPT
## Image
🌐 [AutoDL Image](https://www.codewithgpu.com/i/eosphoros-ai/DB-GPT/dbgpt)
-
-
### Language Switching
In the .env configuration file, modify the LANGUAGE parameter to switch to different languages. The default is English (Chinese: zh, English: en, other languages to be added later).
## Contribution
-- Please run `black .` before submitting the code. Contributing guidelines, [how to contribute](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING.md)
+- Please run `black .` before submitting the code.
+- To check detailed guidelines for new contributions, please refer [how to contribute](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING.md)
## RoadMap
@@ -310,18 +285,7 @@ The core capabilities mainly consist of the following parts:
- [x] ChatGLM2
- SFT Accuracy
-
-As of October 10, 2023, by fine-tuning an open-source model of 13 billion parameters using this project, the execution accuracy on the Spider evaluation dataset has surpassed that of GPT-4!
-
-| name | Execution Accuracy | reference |
-| ----------------------------------| ------------------ | ------------------------------------------------------------------------------------------------------------------------------ |
-| **GPT-4** | **0.762** | [numbersstation-eval-res](https://www.numbersstation.ai/post/nsql-llama-2-7b) |
-| ChatGPT | 0.728 | [numbersstation-eval-res](https://www.numbersstation.ai/post/nsql-llama-2-7b) |
-| **CodeLlama-13b-Instruct-hf_lora**| **0.789** | sft train by our this project,only used spider train dataset ,the same eval way in this project with lora SFT |
-| CodeLlama-13b-Instruct-hf_qlora | 0.774 | sft train by our this project,only used spider train dataset ,the same eval way in this project with qlora and nf4,bit4 SFT |
-| wizardcoder | 0.610 | [text-to-sql-wizardcoder](https://github.com/cuplv/text-to-sql-wizardcoder/tree/main) |
-| CodeLlama-13b-Instruct-hf | 0.556 | eval in this project default param |
-| llama2_13b_hf_lora_best | 0.744 | sft train by our this project,only used spider train dataset ,the same eval way in this project |
+As of October 10, 2023, through the fine-tuning of an open-source model with 13 billion parameters using this project, we have achieved execution accuracy on the Spider dataset that surpasses even GPT-4!
[More Information about Text2SQL finetune](https://github.com/eosphoros-ai/DB-GPT-Hub)
diff --git a/README.zh.md b/README.zh.md
index 115427e9e..0ab05ab4a 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -34,12 +34,9 @@
## DB-GPT 是什么?
+DB-GPT是一个开源的数据库领域大模型框架。目的是构建大模型领域的基础设施,通过开发多模型管理、Text2SQL效果优化、RAG框架以及优化、Multi-Agents框架协作等多种技术能力,让围绕数据库构建大模型应用更简单,更方便。
-随着大模型的发布迭代,大模型变得越来越智能,在使用大模型的过程当中,遇到极大的数据安全与隐私挑战。在利用大模型能力的过程中我们的私密数据跟环境需要掌握自己的手里,完全可控,避免任何的数据隐私泄露以及安全风险。基于此,我们发起了DB-GPT项目,为所有以数据库为基础的场景,构建一套完整的私有大模型解决方案。 此方案因为支持本地部署,所以不仅仅可以应用于独立私有环境,而且还可以根据业务模块独立部署隔离,让大模型的能力绝对私有、安全、可控。我们的愿景是让围绕数据库构建大模型应用更简单,更方便。
-
-DB-GPT 是一个开源的以数据库为基础的GPT实验项目,使用本地化的GPT大模型与您的数据和环境进行交互,无数据泄露风险,100% 私密
-
-
+数据3.0 时代,基于模型、数据库,企业/开发者可以用更少的代码搭建自己的专属应用。
## 目录
@@ -59,19 +56,8 @@ DB-GPT 是一个开源的以数据库为基础的GPT实验项目,使用本地
##### Chat Excel

-#### Chat Plugin
-
-#### LLM Management
-
-#### FastChat && vLLM
-
-#### Trace
-
-#### Chat Knowledge
-
#### 根据自然语言对话生成分析图表
-
@@ -80,10 +66,6 @@ DB-GPT 是一个开源的以数据库为基础的GPT实验项目,使用本地
-
-
-
-
## 安装

@@ -111,26 +93,23 @@ DB-GPT 是一个开源的以数据库为基础的GPT实验项目,使用本地
- [**FAQ**](https://db-gpt.readthedocs.io/en/latest/getting_started/faq/deploy/deploy_faq.html)
## 特性一览
-
-目前我们已经发布了多种关键的特性,这里一一列举展示一下当前发布的能力。
-
-- 私域问答&数据处理
+- **私域问答&数据处理&RAG**
支持内置、多文件格式上传、插件自抓取等方式自定义构建知识库,对海量结构化,非结构化数据做统一向量存储与检索
-
-- 多数据源&可视化
+
+- **多数据源&GBI**
支持自然语言与Excel、数据库、数仓等多种数据源交互,并支持分析报告。
-- 自动化微调
+- **自动化微调**
围绕大语言模型、Text2SQL数据集、LoRA/QLoRA/Pturning等微调方法构建的自动化微调轻量框架, 让TextSQL微调像流水线一样方便。详见: [DB-GPT-Hub](https://github.com/eosphoros-ai/DB-GPT-Hub)
-- Multi-Agents&Plugins
+- **Data-Driven Multi-Agents&Plugins**
支持自定义插件执行任务,原生支持Auto-GPT插件模型,Agents协议采用Agent Protocol标准
-- 多模型支持与管理
+- **多模型支持与管理**
海量模型支持,包括开源、API代理等几十种大语言模型。如LLaMA/LLaMA2、Baichuan、ChatGLM、文心、通义、智谱等。
- 支持多种大语言模型, 当前已支持如下模型:
@@ -141,30 +120,14 @@ DB-GPT 是一个开源的以数据库为基础的GPT实验项目,使用本地
- [baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
- [chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
- [chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
- - [falcon-40b](https://huggingface.co/tiiuae/falcon-40b)
- - [internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b)
- - [Qwen-7B-Chat/Qwen-14B-Chat](https://huggingface.co/Qwen/)
- - [RWKV-4-Raven](https://huggingface.co/BlinkDL/rwkv-4-raven)
- - [CAMEL-13B-Combined-Data](https://huggingface.co/camel-ai/CAMEL-13B-Combined-Data)
- - [dolly-v2-12b](https://huggingface.co/databricks/dolly-v2-12b)
- - [h2ogpt-gm-oasst1-en-2048-open-llama-7b](https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b)
- - [fastchat-t5-3b-v1.0](https://huggingface.co/lmsys/fastchat-t5)
- - [mpt-7b-chat](https://huggingface.co/mosaicml/mpt-7b-chat)
- - [gpt4all-13b-snoozy](https://huggingface.co/nomic-ai/gpt4all-13b-snoozy)
- - [Nous-Hermes-13b](https://huggingface.co/NousResearch/Nous-Hermes-13b)
- - [codet5p-6b](https://huggingface.co/Salesforce/codet5p-6b)
- - [guanaco-33b-merged](https://huggingface.co/timdettmers/guanaco-33b-merged)
- - [WizardLM-13B-V1.0](https://huggingface.co/WizardLM/WizardLM-13B-V1.0)
- - [WizardLM/WizardCoder-15B-V1.0](https://huggingface.co/WizardLM/WizardCoder-15B-V1.0)
- - [Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat)
- - [OpenLLaMa OpenInstruct](https://huggingface.co/VMware/open-llama-7b-open-instruct)
+
- 支持在线代理模型
- [x] [ChatGPT](https://api.openai.com/)
- [x] [Tongyi](https://www.aliyun.com/product/dashscope)
- [x] [Wenxin](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
- [x] [ChatGLM](http://open.bigmodel.cn/)
-- 隐私安全
+- **隐私安全**
通过私有化大模型、代理脱敏等多种技术保障数据的隐私安全。
@@ -192,22 +155,23 @@ DB-GPT 是一个开源的以数据库为基础的GPT实验项目,使用本地
| [StarRocks](https://github.com/StarRocks/starrocks) | No | TODO |
## 架构方案
-DB-GPT基于 [FastChat](https://github.com/lm-sys/FastChat) 构建大模型运行环境。此外,我们通过LangChain提供私域知识库问答能力。同时我们支持插件模式, 在设计上原生支持Auto-GPT插件。我们的愿景是让围绕数据库和LLM构建应用程序更加简便和便捷。
-
整个DB-GPT的架构,如下图所示
-
-核心能力主要有以下几个部分。
-1. 多模型:支持多LLM,如LLaMA/LLaMA2、CodeLLaMA、ChatGLM、QWen、Vicuna以及代理模型ChatGPT、Baichuan、tongyi、wenxin等
-2. 私域知识库问答: 可以根据本地文档(如pdf、word、excel等数据)进行高质量的智能问答。
-3. 统一数据向量存储和索引: 将数据嵌入为向量并存储在向量数据库中,提供内容相似性搜索。
-4. 多数据源: 用于连接不同的模块和数据源,实现数据的流动和交互。
-5. Agent与插件: 提供Agent和插件机制,使得用户可以自定义并增强系统的行为。
-6. 隐私和安全: 您可以放心,没有数据泄露的风险,您的数据100%私密和安全。
-7. Text2SQL: 我们通过在大型语言模型监督微调(SFT)来增强文本到SQL的性能
+核心能力主要有以下几个部分:
+- **RAG(Retrieval Augmented Generation)**,RAG是当下落地实践最多,也是最迫切的领域,DB-GPT目前已经实现了一套基于RAG的框架,用户可以基于DB-GPT的RAG能力构建知识类应用。
+
+- **GBI**:生成式BI是DB-GPT项目的核心能力之一,为构建企业报表分析、业务洞察提供基础的数智化技术保障。
+
+- **Fine-tune框架**: 模型微调是任何一个企业在垂直、细分领域落地不可或缺的能力,DB-GPT提供了完整的微调框架,实现与DB-GPT项目的无缝打通,在最近的微调中,基于spider的准确率已经做到了82.5%
+
+- **数据驱动的Multi-Agents框架**: DB-GPT提供了数据驱动的自进化微调框架,目标是可以持续基于数据做决策与执行。
+
+- **数据工厂**: 数据工厂主要是在大模型时代,做可信知识、数据的清洗加工。
+
+- **数据源**: 对接各类数据源,实现生产业务数据无缝对接到DB-GPT核心能力。
### RAG生产落地实践架构
@@ -345,16 +309,6 @@ The MIT License (MIT)
- SFT模型准确率
截止20231010,我们利用本项目基于开源的13B大小的模型微调后,在Spider的评估集上的执行准确率,已经超越GPT-4!
-| 模型名称 | 执行准确率 | 说明 |
-| ----------------------------------| ------------------ | ------------------------------------------------------------------------------------------------------------------------------ |
-| **GPT-4** | **0.762** | [numbersstation-eval-res](https://www.numbersstation.ai/post/nsql-llama-2-7b) |
-| ChatGPT | 0.728 | [numbersstation-eval-res](https://www.numbersstation.ai/post/nsql-llama-2-7b) |
-| **CodeLlama-13b-Instruct-hf_lora**| **0.789** | sft train by our this project,only used spider train dataset ,the same eval way in this project with lora SFT |
-| CodeLlama-13b-Instruct-hf_qlora | 0.774 | sft train by our this project,only used spider train dataset ,the same eval way in this project with qlora and nf4,bit4 SFT |
-| wizardcoder | 0.610 | [text-to-sql-wizardcoder](https://github.com/cuplv/text-to-sql-wizardcoder/tree/main) |
-| CodeLlama-13b-Instruct-hf | 0.556 | eval in this project default param |
-| llama2_13b_hf_lora_best | 0.744 | sft train by our this project,only used spider train dataset ,the same eval way in this project |
-
[More Information about Text2SQL finetune](https://github.com/eosphoros-ai/DB-GPT-Hub)
## 联系我们
diff --git a/docker/examples/benchmarks/benchmarks_llm_11k_prompt.txt b/docker/examples/benchmarks/benchmarks_llm_11k_prompt.txt
new file mode 100644
index 000000000..25072fa18
--- /dev/null
+++ b/docker/examples/benchmarks/benchmarks_llm_11k_prompt.txt
@@ -0,0 +1,214 @@
+A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge.
+The assistant gives helpful, detailed, professional and polite answers to the user's questions. 基于以下已知的信息, 专业、简要的回答用户的问题,
+ 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造, 回答的时候最好按照1.2.3.点进行总结。
+ 已知内容:
+
+OceanBase 数据库(OceanBase Database)是一款完全自研的企业级原生分布式数据库,在普通硬件上实现金融级高可用,首创“三地五中心”城市级故障自动无损容灾新标准,刷新 TPC-C 标准测试,单集群规模超过 1500 节点,具有云原生、强一致性、高度兼容 Oracle/MySQL 等特性。
+
+核心特性
+高可用
+独创 “三地五中心” 容灾架构方案,建立金融行业无损容灾新标准。支持同城/异地容灾,可实现多地多活,满足金融行业 6 级容灾标准(RPO=0,RTO< 8s),数据零丢失。
+高兼容
+高度兼容 Oracle 和 MySQL,覆盖绝大多数常见功能,支持过程语言、触发器等高级特性,提供自动迁移工具,支持迁移评估和反向同步以保障数据迁移安全,可支撑金融、政府、运营商等关键行业核心场景替代。
+水平扩展
+实现透明水平扩展,支持业务快速的扩容缩容,同时通过准内存处理架构实现高性能。支持集群节点超过数千个,单集群最大数据量超过 3PB,最大单表行数达万亿级。
+低成本
+基于 LSM-Tree 的高压缩引擎,存储成本降低 70% - 90%;原生支持多租户架构,同集群可为多个独立业务提供服务,租户间数据隔离,降低部署和运维成本。
+实时 HTAP
+基于“同一份数据,同一个引擎”,同时支持在线实时交易及实时分析两种场景,“一份数据”的多个副本可以存储成多种形态,用于不同工作负载,从根本上保持数据一致性。
+安全可靠
+12 年完全自主研发,代码级可控,自研单机分布式一体化架构,大规模金融核心场景 9 年可靠性验证;完备的角色权限管理体系,数据存储和通信全链路透明加密,支持国密算法,通过等保三级专项合规检测。
+深入了解 OceanBase 数据库
+您可以通过以下内容更深入地了解 OceanBase 数据库:
+
+OceanBase 使用通用服务器硬件,依赖本地存储,分布式部署使用的多个服务器也是对等的,没有特殊的硬件要求。OceanBase 的分布式数据库处理采用 Shared Nothing 架构,数据库内的 SQL 执行引擎具有分布式执行能力。
+
+OceanBase 在服务器上会运行叫做 observer 的单进程程序作为数据库的运行实例,使用本地的文件存储数据和事务 Redo 日志。
+
+OceanBase 集群部署需要配置可用区(Zone),由若干个服务器组成。可用区是一个逻辑概念,表示集群内具有相似硬件可用性的一组节点,它在不同的部署模式下代表不同的含义。例如,当整个集群部署在同一个数据中心(IDC)内的时候,一个可用区的节点可以属于同一个机架,同一个交换机等。当集群分布在多个数据中心的时候,每个可用区可以对应于一个数据中心。
+
+用户存储的数据在分布式集群内部可以存储多个副本,用于故障容灾,也可以用于分散读取压力。在一个可用区内部数据只有一个副本,不同的可用区可以存储同一个数据的多个副本,副本之间由共识协议保证数据的一致性。
+
+OceanBase 内置多租户特性,每个租户对于使用者是一个独立的数据库,一个租户能够在租户级别设置租户的分布式部署方式。租户之间 CPU、内存和 IO 都是隔离的。
+
+OceanBase的数据库实例内部由不同的组件相互协作,这些组件从底层向上由存储层、复制层、均衡层、事务层、SQL 层、接入层组成。
+
+存储层
+存储层以一张表或者一个分区为粒度提供数据存储与访问,每个分区对应一个用于存储数据的Tablet(分片),用户定义的非分区表也会对应一个 Tablet。
+
+Tablet 的内部是分层存储的结构,总共有 4 层。DML 操作插入、更新、删除等首先写入 MemTable,等到 MemTable 达到一定大小时转储到磁盘成为 L0 SSTable。L0 SSTable 个数达到阈值后会将多个 L0 SSTable 合并成一个 L1 SSTable。在每天配置的业务低峰期,系统会将所有的 MemTable、L0 SSTable 和 L1 SSTable 合并成一个 Major SSTable。
+
+每个 SSTable 内部是以 2MB 定长宏块为基本单位,每个宏块内部由多个不定长微块组成。
+
+Major SSTable 的微块会在合并过程中用编码方式进行格式转换,微块内的数据会按照列维度分别进行列内的编码,编码规则包括字典/游程/常量/差值等,每一列压缩结束后,还会进一步对多列进行列间等值/子串等规则编码。编码能对数据大幅压缩,同时提炼的列内特征信息还能进一步加速后续的查询速度。
+
+在编码压缩之后,还可以根据用户指定的通用压缩算法进行无损压缩,进一步提升数据压缩率。
+
+复制层
+复制层使用日志流(LS、Log Stream)在多副本之间同步状态。每个 Tablet 都会对应一个确定的日志流,DML 操作写入 Tablet 的数据所产生的 Redo 日志会持久化在日志流中。日志流的多个副本会分布在不同的可用区中,多个副本之间维持了共识算法,选择其中一个副本作为主副本,其他的副本皆为从副本。Tablet 的 DML 和强一致性查询只在其对应的日志流的主副本上进行。
+
+通常情况下,每个租户在每台机器上只会有一个日志流的主副本,可能存在多个其他日志流的从副本。租户的总日志流个数取决于 Primary Zone 和 Locality 的配置。
+
+日志流使用自研的 Paxos 协议将 Redo 日志在本服务器持久化,同时通过网络发送给日志流的从副本,从副本在完成各自持久化后应答主副本,主副本在确认有多数派副本都持久化成功后确认对应的 Redo 日志持久化成功。从副本利用 Redo 日志的内容实时回放,保证自己的状态与主副本一致。
+
+日志流的主副本在被选举成为主后会获得租约(Lease),正常工作的主副本在租约有效期内会不停的通过选举协议延长租约期。主副本只会在租约有效时执行主的工作,租约机制保证了数据库异常处理的能力。
+
+复制层能够自动应对服务器故障,保障数据库服务的持续可用。如果出现少于半数的从副本所在服务器出现问题,因为还有多于半数的副本正常工作,数据库的服务不受影响。如果主副本所在服务器出现问题,其租约会得不到延续,待其租约失效后,其他从副本会通过选举协议选举出新的主副本并授予新的租约,之后即可恢复数据库的服务。
+
+均衡层
+新建表和新增分区时,系统会按照均衡原则选择合适的日志流创建 Tablet。当租户的属性发生变更,新增了机器资源,或者经过长时间使用后,Tablet 在各台机器上不再均衡时,均衡层通过日志流的分裂和合并操作,并在这个过程中配合日志流副本的移动,让数据和服务在多个服务器之间再次均衡。
+
+当租户有扩容操作,获得更多服务器资源时,均衡层会将租户内已有的日志流进行分裂,并选择合适数量的 Tablet 一同分裂到新的日志流中,再将新日志流迁移到新增的服务器上,以充分利用扩容后的资源。当租户有缩容操作时,均衡层会把需要缩减的服务器上的日志流迁移到其他服务器上,并和其他服务器上已有的日志流进行合并,以缩减机器的资源占用。
+
+当数据库长期使用后,随着持续创建删除表格,并且写入更多的数据,即使没有服务器资源数量变化,原本均衡的情况可能被破坏。最常见的情况是,当用户删除了一批表格后,删除的表格可能原本聚集在某一些机器上,删除后这些机器上的 Tablet 数量就变少了,应该把其他机器的 Tablet 均衡一些到这些少的机器上。均衡层会定期生成均衡计划,将 Tablet 多的服务器上日志流分裂出临时日志流并携带需要移动的 Tablet,临时日志流迁移到目的服务器后再和目的服务器上的日志流进行合并,以达成均衡的效果。
+
+事务层
+事务层保证了单个日志流和多个日志流DML操作提交的原子性,也保证了并发事务之间的多版本隔离能力。
+
+原子性
+一个日志流上事务的修改,即使涉及多个 Tablet,通过日志流的 write-ahead log 可以保证事务提交的原子性。事务的修改涉及多个日志流时,每个日志流会产生并持久化各自的write-ahead log,事务层通过优化的两阶段提交协议来保证提交的原子性。
+
+事务层会选择一个事务修改的一个日志流产生协调者状态机,协调者会与事务修改的所有日志流通信,判断 write-ahead log 是否持久化,当所有日志流都完成持久化后,事务进入提交状态,协调者会再驱动所有日志流写下这个事务的 Commit 日志,表示事务最终的提交状态。当从副本回放或者数据库重启时,已经完成提交的事务都会通过 Commit 日志确定各自日志流事务的状态。
+
+宕机重启场景下,宕机前还未完成的事务,会出现写完 write-ahead log 但是还没有Commit 日志的情况,每个日志流的 write-ahead log 都会包含事务的所有日志流列表,通过此信息可以重新确定哪个日志流是协调者并恢复协调者的状态,再次推进两阶段状态机,直到事务最终的 Commit 或 Abort 状态。
+
+隔离性
+GTS 服务是一个租户内产生连续增长的时间戳的服务,其通过多副本保证可用性,底层机制与上面复制层所描述的日志流副本同步机制是一样的。
+
+每个事务在提交时会从 GTS 获取一个时间戳作为事务的提交版本号并持久化在日志流的write-ahead log 中,事务内所有修改的数据都以此提交版本号标记。
+
+每个语句开始时(对于 Read Committed 隔离级别)或者每个事务开始时(对于Repeatable Read 和 Serializable 隔离级别)会从 GTS 获取一个时间戳作为语句或事务的读取版本号。在读取数据时,会跳过事务版本号比读取版本号大的数据,通过这种方式为读取操作提供了统一的全局数据快照。
+
+SQL 层
+SQL 层将用户的 SQL 请求转化成对一个或多个 Tablet 的数据访问。
+
+SQL 层组件
+SQL 层处理一个请求的执行流程是:Parser、Resolver、Transformer、Optimizer、Code Generator、Executor。
+
+Parser 负责词法/语法解析,Parser 会将用户的 SQL 分成一个个的 "Token",并根据预先设定好的语法规则解析整个请求,转换成语法树(Syntax Tree)。
+
+Resolver 负责语义解析,将根据数据库元信息将 SQL 请求中的 Token 翻译成对应的对象(例如库、表、列、索引等),生成的数据结构叫做 Statement Tree。
+
+Transformer 负责逻辑改写,根据内部的规则或代价模型,将 SQL 改写为与之等价的其他形式,并将其提供给后续的优化器做进一步的优化。Transformer 的工作方式是在原Statement Tree 上做等价变换,变换的结果仍然是一棵 Statement Tree。
+
+Optimizer(优化器)为 SQL 请求生成最佳的执行计划,需要综合考虑 SQL 请求的语义、对象数据特征、对象物理分布等多方面因素,解决访问路径选择、联接顺序选择、联接算法选择、分布式计划生成等问题,最终生成执行计划。
+
+Code Generator(代码生成器)将执行计划转换为可执行的代码,但是不做任何优化选择。
+
+Executor(执行器)启动 SQL 的执行过程。
+
+在标准的 SQL 流程之外,SQL 层还有 Plan Cache 能力,将历史的执行计划缓存在内存中,后续的执行可以反复执行这个计划,避免了重复查询优化的过程。配合 Fast-parser 模块,仅使用词法分析对文本串直接参数化,获取参数化后的文本及常量参数,让 SQL 直接命中 Plan Cache,加速频繁执行的 SQL。
+
+多种计划
+SQL 层的执行计划分为本地、远程和分布式三种。本地执行计划只访问本服务器的数据。远程执行计划只访问非本地的一台服务器的数据。分布式计划会访问超过一台服务器的数据,执行计划会分成多个子计划在多个服务器上执行。
+
+SQL 层并行化执行能力可以将执行计划分解成多个部分,由多个执行线程执行,通过一定的调度的方式,实现执行计划的并行处理。并行化执行可以充分发挥服务器 CPU 和 IO 处理能力,缩短单个查询的响应时间。并行查询技术可以用于分布式执行计划,也可以用于本地执行计划。
+
+接入层
+obproxy 是 OceanBase 数据库的接入层,负责将用户的请求转发到合适的 OceanBase 实例上进行处理。
+
+obproxy 是独立的进程实例,独立于 OceanBase 的数据库实例部署。obproxy 监听网络端口,兼容 MySQL 网络协议,支持使用 MySQL 驱动的应用直接连接 OceanBase。
+
+obproxy 能够自动发现 OceanBase 集群的数据分布信息,对于代理的每一条 SQL 语句,会尽可能识别出语句将访问的数据,并将语句直接转发到数据所在服务器的 OceanBase 实例。
+
+obproxy 有两种部署方式,一种是部署在每一个需要访问数据库的应用服务器上,另一种是部署在与 OceanBase 相同的机器上。第一种部署方式下,应用程序直接连接部署在同一台服务器上的 obproxy,所有的请求会由 obproxy 发送到合适的 OceanBase 服务器。第二种部署方式下,需要使用网络负载均衡服务将多个 obproxy 聚合成同一个对应用提供服务的入口地址。
+
+OceanBase 数据库采用 Shared-Nothing 架构,各个节点之间完全对等,每个节点都有自己的 SQL 引擎、存储引擎、事务引擎,运行在普通 PC 服务器组成的集群之上,具备高可扩展性、高可用性、高性能、低成本、与主流数据库高兼容等核心特性。
+
+OceanBase 数据库的一个集群由若干个节点组成。这些节点分属于若干个可用区(Zone),每个节点属于一个可用区。可用区是一个逻辑概念,表示集群内具有相似硬件可用性的一组节点,它在不同的部署模式下代表不同的含义。例如,当整个集群部署在同一个数据中心(IDC)内的时候,一个可用区的节点可以属于同一个机架,同一个交换机等。当集群分布在多个数据中心的时候,每个可用区可以对应于一个数据中心。每个可用区具有 IDC 和地域(Region)两个属性,描述该可用区所在的 IDC 及 IDC 所属的地域。一般地,地域指 IDC 所在的城市。可用区的 IDC 和 Region 属性需要反映部署时候的实际情况,以便集群内的自动容灾处理和优化策略能更好地工作。根据业务对数据库系统不同的高可用性需求,OceanBase 集群提供了多种部署模式,参见 高可用架构概述。
+
+在 OceanBase 数据库中,一个表的数据可以按照某种划分规则水平拆分为多个分片,每个分片叫做一个表分区,简称分区(Partition)。某行数据属于且只属于一个分区。分区的规则由用户在建表的时候指定,包括hash、range、list等类型的分区,还支持二级分区。例如,交易库中的订单表,可以先按照用户 ID 划分为若干一级分区,再按照月份把每个一级分区划分为若干二级分区。对于二级分区表,第二级的每个子分区是一个物理分区,而第一级分区只是逻辑概念。一个表的若干个分区可以分布在一个可用区内的多个节点上。每个物理分区有一个用于存储数据的存储层对象,叫做 Tablet ,用于存储有序的数据记录。
+
+当用户对 Tablet 中记录进行修改的时候,为了保证数据持久化,需要记录重做日志(REDO)到 Tablet 对应的日志流(Log Stream)里。每个日志流服务了其所在节点上的多个 Tablet。为了能够保护数据,并在节点发生故障的时候不中断服务,每个日志流及其所属的 Tablet 有多个副本。一般来说,多个副本分散在多个不同的可用区里。多个副本中有且只有一个副本接受修改操作,叫做主副本(Leader),其他副本叫做从副本(Follower)。主从副本之间通过基于 Multi-Paxos 的分布式共识协议实现了副本之间数据的一致性。当主副本所在节点发生故障的时候,一个从副本会被选举为新的主副本并继续提供服务。
+
+在集群的每个节点上会运行一个叫做 observer 的服务进程,它内部包含多个操作系统线程。节点的功能都是对等的。每个服务负责自己所在节点上分区数据的存取,也负责路由到本机的 SQL 语句的解析和执行。这些服务进程之间通过 TCP/IP 协议进行通信。同时,每个服务会监听来自外部应用的连接请求,建立连接和数据库会话,并提供数据库服务。关于 observer 服务进程的更多信息,参见 线程简介。
+
+为了简化大规模部署多个业务数据库的管理并降低资源成本,OceanBase 数据库提供了独特的多租户特性。在一个 OceanBase 集群内,可以创建很多个互相之间隔离的数据库"实例",叫做一个租户。从应用程序的视角来看,每个租户是一个独立的数据库。不仅如此,每个租户可以选择 MySQL 或 Oracle 兼容模式。应用连接到 MySQL 租户后,可以在租户下创建用户、database,与一个独立的 MySQL 库的使用体验是一样的。同样的,应用连接到 Oracle 租户后,可以在租户下创建 schema、管理角色等,与一个独立的 Oracle 库的使用体验是一样的。一个新的集群初始化之后,就会存在一个特殊的名为 sys 的租户,叫做系统租户。系统租户中保存了集群的元数据,是一个 MySQL 兼容模式的租户。
+
+为了隔离租户的资源,每个 observer 进程内可以有多个属于不同租户的虚拟容器,叫做资源单元(UNIT)。每个租户在多个节点上的资源单元组成一个资源池。资源单元包括 CPU 和内存资源。
+
+为了使 OceanBase 数据库对应用程序屏蔽内部分区和副本分布等细节,使应用访问分布式数据库像访问单机数据库一样简单,我们提供了 obproxy 代理服务。应用程序并不会直接与 OBServer 建立连接,而是连接obproxy,然后由 obproxy 转发 SQL 请求到合适的 OBServer 节点。obproxy 是无状态的服务,多个 obproxy 节点通过网络负载均衡(SLB)对应用提供统一的网络地址。
+
+
+OceanBase 数据库是随着阿里巴巴电商业务的发展孕育而生,随着蚂蚁集团移动支付业务的发展而壮大,经过十多年各类业务的使用和打磨才终于破茧成蝶,推向了外部市场。本章节简述 OceanBase 数据库发展过程中一些里程碑意义的事件。
+
+诞生
+
+2010 年,OceanBase 创始人阳振坤博士带领初创团队启动了 OceanBase 项目。第一个应用是淘宝的收藏夹业务。如今收藏夹依然是 OceanBase 的客户。收藏夹单表数据量非常大,OceanBase 用独创的方法解决了其高并发的大表连接小表的需求。
+
+关系数据库
+
+早期的版本中,应用通过定制的 API 库访问 OceanBase 数据库。2012 年,OceanBase 数据库发布了支持 SQL 的版本,初步成为一个功能完整的通用关系数据库。
+
+初试金融业务
+
+OceanBase 进入支付宝(后来的蚂蚁集团),开始应用于金融级的业务场景。2014 年"双 11"大促活动,OceanBase 开始承担交易库部分流量。此后,新成立的网商银行把所有核心交易库都运行在 OceanBase 数据库上。
+
+金融级核心库
+
+2016 年,OceanBase 数据库发布了架构重新设计后的 1.0 版本,支持了分布式事务,提升了高并发写业务中的扩展,同时实现了多租户架构,这个整体架构延续至今。同时,到 2016 年"双 11"时,支付宝全部核心库的业务流量 100% 运行在 OceanBase 数据库上,包括交易、支付、会员和最重要的账务库。
+
+走向外部市场
+
+2017 年,OceanBase 数据库开始试点外部业务,成功应用于南京银行。
+
+商业化加速
+
+2018 年,OceanBase 数据库发布 2.0 版本,开始支持 Oracle 兼容模式。这一特性降低应用改造适配成本,在外部客户中快速推广开来。
+
+勇攀高峰
+
+2019 年,OceanBase 数据库 V2.2 版本参加代表 OLTP 数据库最权威的 TPC-C 评测,以 6000 万 tpmC 的成绩登顶世界第一。随后,在 2020 年,又以 7 亿 tpmC 刷新纪录,截止目前依然稳居第一。这充分证明了 OceanBase 数据库优秀的扩展性和稳定性。OceanBase 数据库是第一个也是截止目前唯一一个上榜 TPC-C 的中国数据库产品。
+
+HTAP 混合负载
+
+2021 年,OceanBase 数据库 V3.0 基于全新的向量化执行引擎,在 TPC-H 30000GB 的评测中以 1526 万 QphH 的成绩刷新了评测榜单。这标志着 OceanBase 数据库一套引擎处理 AP 和 TP 混合负载的能力取得了基础性的突破。
+
+开源开放
+
+2021 年六一儿童节,OceanBase 数据库宣布全面开源,开放合作,共建生态。
+
+OceanBase 数据库采用了单集群多租户设计,天然支持云数据库架构,支持公有云、私有云、混合云等多种部署形式。
+
+架构
+
+OceanBase 数据库通过租户实现资源隔离,让每个数据库服务的实例不感知其他实例的存在,并通过权限控制确保租户数据的安全性,配合 OceanBase 数据库强大的可扩展性,能够提供安全、灵活的 DBaaS 服务。
+
+租户是一个逻辑概念。在 OceanBase 数据库中,租户是资源分配的单位,是数据库对象管理和资源管理的基础,对于系统运维,尤其是对于云数据库的运维有着重要的影响。租户在一定程度上相当于传统数据库的"实例"概念。租户之间是完全隔离的。在数据安全方面,OceanBase 数据库不允许跨租户的数据访问,以确保用户的数据资产没有被其他租户窃取的风险。在资源使用方面,OceanBase 数据库表现为租户"独占"其资源配额。总体上来说,租户(tenant)既是各类数据库对象的容器,又是资源(CPU、Memory、IO 等)的容器。
+
+OceanBase 数据库在一个系统中可同时支持 MySQL 模式和 Oracle 模式两种模式的租户。用户在创建租户时,可选择创建 MySQL 兼容模式的租户或 Oracle 兼容模式的租户,租户的兼容模式一经确定就无法更改,所有数据类型、SQL 功能、视图等相应地与 MySQL 数据库或 Oracle 数据库保持一致。
+
+
+MySQL 模式
+MySQL 模式是为降低 MySQL 数据库迁移至 OceanBase 数据库所引发的业务系统改造成本,同时使业务数据库设计人员、开发人员、数据库管理员等可复用积累的 MySQL 数据库技术知识经验,并能快速上手 OceanBase 数据库而支持的一种租户类型功能。OceanBase 数据库的 MySQL 模式兼容 MySQL 5.7 的绝大部分功能和语法,兼容 MySQL 5.7 版本的全量以及 8.0 版本的部分 JSON 函数,基于 MySQL 的应用能够平滑迁移。
+
+Oracle 模式
+OceanBase 数据库从 V2.x.x 版本开始支持 Oracle 兼容模式。Oracle 模式是为降低 Oracle 数据库迁移 OceanBase 数据库的业务系统改造成本,同时使业务数据库设计开发人员、数据库管理员等可复用积累的 Oracle 数据库技术知识经验,并能快速上手 OceanBase 数据库而支持的一种租户类型功能。Oracle 模式目前能够支持绝大部分的 Oracle 语法和过程性语言功能,可以做到大部分的 Oracle 业务进行少量修改后的自动迁移。
+
+OceanBase 数据库是多租户架构。在 V4.0.0 版本之前,仅支持两种类型的租户:系统租户和用户租户。从 V4.0.0 版本开始,引入了 Meta 租户概念。因此,当前版本对用户可见的租户有三种类型:系统租户、用户租户以及 Meta 租户。
+
+系统租户
+系统租户是集群默认创建的租户,与集群的生命周期一致,负责管理集群和所有租户的生命周期。系统租户仅有一个 1 号日志流,仅支持单点写入,不具备扩展能力。
+
+系统租户可以创建用户表,所有的用户表和系统表数据均由 1 号日志流服务。系统租户的数据是集群私有的,不支持主备集群物理同步和物理备份恢复。
+
+用户租户
+用户租户是由用户创建的租户,对外提供完整的数据库功能,支持 MySQL 和 Oracle 两种兼容模式。用户租户支持服务能力水平扩展到多台机器上,支持动态扩容和缩容,内部会根据用户的配置自动创建和删除日志流。
+
+用户租户的数据有更强的数据保护和可用性要求,支持跨集群物理同步和物理备份恢复,典型数据包括:Schema 数据、用户表数据及事务数据等。
+Meta 租户
+Meta 租户是 OceanBase 数据库内部自管理的租户,每创建一个用户租户系统就会自动创建一个对应的 Meta 租户,其生命期与用户租户保持一致。
+
+Meta 租户用于存储和管理用户租户的集群私有数据,这部分数据不需要进行跨库物理同步以及物理备份恢复,这些数据包括:配置项、位置信息、副本信息、日志流状态、备份恢复相关信息、合并信息等。
+
+租户对比
+从用户角度来看,系统租户、用户租户和 Meta 租户的差异性如下表所示。
+OceanBase 数据库是多租户的数据库系统,一个集群内可包含多个相互独立的租户,每个租户提供独立的数据库服务。在 OceanBase 数据库中,使用资源配置(unit_config)、资源池(Resource Pool)和资源单元(Unit)三个概念,对各租户的可用资源进行管理。
+
+
+创建租户前,需首先确定租户的资源配置、使用资源范围等。租户创建的通用流程如下:
+
+资源配置是描述资源池的配置信息,用来描述资源池中每个资源单元可用的 CPU、内存、存储空间和 IOPS 等的规格。修改资源配置可动态调整资源单元的规格。这里需要注意,资源配置指定的是对应资源单元能够提供的服务能力,而不是资源单元的实时负载。 创建资源配置的示例语句如下:
+
+问题:
+请你基于上述内容对 OceanBase 的介绍进行总结,不少于2000字。
\ No newline at end of file
diff --git a/docs/index.rst b/docs/index.rst
index 9215b21a2..56f851227 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -3,48 +3,58 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
-Welcome to DB-GPT!
-==================================
-| As large models are released and iterated upon, they are becoming increasingly intelligent. However, in the process of using large models, we face significant challenges in data security and privacy. We need to ensure that our sensitive data and environments remain completely controlled and avoid any data privacy leaks or security risks. Based on this, we have launched the DB-GPT project to build a complete private large model solution for all database-based scenarios. This solution supports local deployment, allowing it to be applied not only in independent private environments but also to be independently deployed and isolated according to business modules, ensuring that the ability of large models is absolutely private, secure, and controllable.
+Overview
+------------------
-| **DB-GPT** is an experimental open-source project that uses localized GPT large models to interact with your data and environment. With this solution, you can be assured that there is no risk of data leakage, and your data is 100% private and secure.
+| DB-GPT is an open-source framework for large models in the databases fields. It's purpose is to build infrastructure for the domain of large models, making it easier and more convenient to develop applications around databases. By developing various technical capabilities such as:
-| **Features**
-Currently, we have released multiple key features, which are listed below to demonstrate our current capabilities:
+1. **SMMF(Service-oriented Multi-model Management Framework)**
+2. **Text2SQL Fine-tuning**
+3. **RAG(Retrieval Augmented Generation) framework and optimization**
+4. **Data-Driven Agents framework collaboration**
+5. **GBI(Generative Business intelligence)**
-- SQL language capabilities
- - SQL generation
- - SQL diagnosis
+etc, DB-GPT simplifies the construction of large model applications based on databases.
-- Private domain Q&A and data processing
- - Database knowledge Q&A
- - Data processing
+| In the era of Data 3.0, enterprises and developers can build their own customized applications with less code, leveraging models and databases.
-- Plugins
- - Support custom plugin execution tasks and natively support the Auto-GPT plugin, such as:
+Features
+^^^^^^^^^^^
-- Unified vector storage/indexing of knowledge base
- - Support for unstructured data such as PDF, Markdown, CSV, and WebURL
+| **1. Private Domain Q&A & Data Processing**
+| Supports custom construction of knowledge bases through methods such as built-in, multi-file format uploads, and plugin-based web scraping. Enables unified vector storage and retrieval of massive structured and unstructured data.
+
+| **2.Multi-Data Source & GBI(Generative Business intelligence)**
+| Supports interaction between natural language and various data sources such as Excel, databases, and data warehouses. Also supports analysis reporting.
+
+| **3.SMMF(Service-oriented Multi-model Management Framework)**
+| Supports a wide range of models, including dozens of large language models such as open-source models and API proxies. Examples include LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, Xinghuo, etc.
+
+| **4.Automated Fine-tuning**
+| A lightweight framework for automated fine-tuning built around large language models, Text2SQL datasets, and methods like LoRA/QLoRA/Pturning. Makes TextSQL fine-tuning as convenient as a production line.
+
+| **5.Data-Driven Multi-Agents & Plugins**
+| Supports executing tasks through custom plugins and natively supports the Auto-GPT plugin model. Agents protocol follows the Agent Protocol standard.
+
+| **6.Privacy and Security**
+| Ensures data privacy and security through techniques such as privatizing large models and proxy de-identification.
-- Multi LLMs Support
- - Supports multiple large language models, currently supporting Vicuna (7b, 13b), ChatGLM-6b (int4, int8)
- - TODO: codegen2, codet5p
Getting Started
------------------
-| How to get started using DB-GPT to interact with your data and environment.
-- `Quickstart Guide <./getting_started/getting_started.html>`_
+^^^^^^^^^^^^^^^^^
+
+| Quickstart
+
+- `Quickstart Guide <./getting_started/getting_started.html>`_
| Concepts and terminology
- `Concepts and Terminology <./getting_started/concepts.html>`_
-| Coming soon...
-
-- `Tutorials <.getting_started/tutorials.html>`_
.. toctree::
:maxdepth: 2
:caption: Getting Started
+ :name: getting_started
:hidden:
getting_started/install.rst
@@ -57,10 +67,9 @@ Getting Started
Modules
----------
+^^^^^^^^^
-| These modules are the core abstractions with which we can interact with data and environment smoothly.
-It's very important for DB-GPT, DB-GPT also provide standard, extendable interfaces.
+| These modules are the core abstractions with which we can interact with data and environment smoothly. It's very important for DB-GPT, DB-GPT also provide standard, extendable interfaces.
| The docs for each module contain quickstart examples, how to guides, reference docs, and conceptual guides.
@@ -78,35 +87,23 @@ It's very important for DB-GPT, DB-GPT also provide standard, extendable interfa
- `Vector <./modules/vector.html>`_: Supported multi vector database.
+-------------
+
.. toctree::
:maxdepth: 2
:caption: Modules
:name: modules
:hidden:
- ./modules/llms.md
- ./modules/prompts.md
- ./modules/plugins.md
- ./modules/connections.rst
- ./modules/knowledge.rst
- ./modules/vector.rst
-
-
-Reference
------------
-| Full documentation on all methods, classes, installation methods, and integration setups for DB-GPT.
-
-.. toctree::
- :maxdepth: 1
- :caption: Reference
- :name: reference
- :hidden:
-
- ./reference.md
-
+ modules/llms.md
+ modules/prompts.md
+ modules/plugins.md
+ modules/connections.rst
+ modules/knowledge.rst
+ modules/vector.rst
Resources
-----------
+-----------------
| Additional resources we think may be useful as you develop your application!
diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/environment/environment.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/environment/environment.po
index addf53bc6..bd4d6fa0b 100644
--- a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/environment/environment.po
+++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/environment/environment.po
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.5\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2023-11-02 21:04+0800\n"
+"POT-Creation-Date: 2023-11-14 16:08+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME \n"
"Language: zh_CN\n"
@@ -20,292 +20,287 @@ msgstr ""
"Generated-By: Babel 2.12.1\n"
#: ../../getting_started/install/environment/environment.md:1
-#: a17719d2f4374285a7beb4d1db470146
+#: e4787ab6eacc4362802752528bb786ec
#, fuzzy
msgid "Environment Parameter"
msgstr "环境变量说明"
#: ../../getting_started/install/environment/environment.md:4
-#: 9a62e6fff7914eeaa2d195ddef4fcb61
+#: 4682a0734a034e0e9f2c22fa061b889e
msgid "LLM MODEL Config"
msgstr "模型配置"
#: ../../getting_started/install/environment/environment.md:5
-#: 90e3991538324ecfac8cac7ef2103ac2
+#: c148f178b2964344a570bb2b3713fba3
msgid "LLM Model Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG"
msgstr "LLM Model Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG"
#: ../../getting_started/install/environment/environment.md:6
-#: 1f45af01100c4586acbc05469e3006bc
+#: 9ab8d82fb338439a8c0042b92ad2f7c4
msgid "LLM_MODEL=vicuna-13b"
msgstr "LLM_MODEL=vicuna-13b"
#: ../../getting_started/install/environment/environment.md:8
-#: bed14b704f154c2db525f7fafd3aa5a4
+#: 76fb3b1299694730852f120db6fec7f9
msgid "MODEL_SERVER_ADDRESS"
msgstr "MODEL_SERVER_ADDRESS"
-#: ../../getting_started/install/environment/environment.md:9
-#: ea42946cfe4f4ad996bf82c1996e7344
-msgid "MODEL_SERVER=http://127.0.0.1:8000 LIMIT_MODEL_CONCURRENCY"
+#: ../../getting_started/install/environment/environment.md:10
+#: 7476a0ee342f4517bbf999abecec029e
+#, fuzzy
+msgid "MODEL_SERVER=http://127.0.0.1:8000"
msgstr "MODEL_SERVER=http://127.0.0.1:8000 LIMIT_MODEL_CONCURRENCY"
#: ../../getting_started/install/environment/environment.md:12
-#: 021c261231f342fdba34098b1baa06fd
-msgid "LIMIT_MODEL_CONCURRENCY=5"
+#: fb3c73990a6443e8b63c35d61175e467
+#, fuzzy
+msgid "LIMIT_MODEL_CONCURRENCY"
msgstr "LIMIT_MODEL_CONCURRENCY=5"
#: ../../getting_started/install/environment/environment.md:14
-#: afaf0ba7fd09463d8ff74b514ed7264c
+#: 0eb187fffa3643dbac4bbe7237d2e011
+msgid "LIMIT_MODEL_CONCURRENCY=5"
+msgstr "LIMIT_MODEL_CONCURRENCY=5"
+
+#: ../../getting_started/install/environment/environment.md:16
+#: 1d7b8bf89c1b44e9871d9d0c382db114
msgid "MAX_POSITION_EMBEDDINGS"
msgstr "MAX_POSITION_EMBEDDINGS"
-#: ../../getting_started/install/environment/environment.md:16
-#: e4517a942bca4361a64a00408f993f5b
+#: ../../getting_started/install/environment/environment.md:18
+#: 50d0b3f760fd4ff9829cd1ba0653fd79
msgid "MAX_POSITION_EMBEDDINGS=4096"
msgstr "MAX_POSITION_EMBEDDINGS=4096"
-#: ../../getting_started/install/environment/environment.md:18
-#: 78d2ef04ed4548b9b7b0fb8ae35c9d5c
+#: ../../getting_started/install/environment/environment.md:20
+#: d07c4bbcde214f5993d73ac2bfb1bf9e
msgid "QUANTIZE_QLORA"
msgstr "QUANTIZE_QLORA"
-#: ../../getting_started/install/environment/environment.md:20
-#: bfa65db03c6d46bba293331f03ab15ac
+#: ../../getting_started/install/environment/environment.md:22
+#: 6bceef51780f45d9805270d16847ddc2
msgid "QUANTIZE_QLORA=True"
msgstr "QUANTIZE_QLORA=True"
-#: ../../getting_started/install/environment/environment.md:22
-#: 1947d45a7f184821910b4834ad5f1897
+#: ../../getting_started/install/environment/environment.md:24
+#: df9d560f69334e4aa3f6803e40a7f38d
msgid "QUANTIZE_8bit"
msgstr "QUANTIZE_8bit"
-#: ../../getting_started/install/environment/environment.md:24
-#: 4a2ee2919d0e4bdaa13c9d92eefd2aac
+#: ../../getting_started/install/environment/environment.md:26
+#: ac433b8574574432add7315558b845ea
msgid "QUANTIZE_8bit=True"
msgstr "QUANTIZE_8bit=True"
-#: ../../getting_started/install/environment/environment.md:27
-#: 348dc1e411b54ab09414f40a20e934e4
+#: ../../getting_started/install/environment/environment.md:29
+#: 7b1c407517984bff9f4d509c5f45b92e
msgid "LLM PROXY Settings"
msgstr "LLM PROXY Settings"
-#: ../../getting_started/install/environment/environment.md:28
-#: a692e78425a040f5828ab54ff9a33f77
+#: ../../getting_started/install/environment/environment.md:30
+#: ba7d52c0e95143ebb973e7eda69f0bc1
msgid "OPENAI Key"
msgstr "OPENAI Key"
-#: ../../getting_started/install/environment/environment.md:30
-#: 940d00e25a424acf92951a314a64e5ea
+#: ../../getting_started/install/environment/environment.md:32
+#: 0f0bd20a7a60461e8bcfc91297cc3666
msgid "PROXY_API_KEY={your-openai-sk}"
msgstr "PROXY_API_KEY={your-openai-sk}"
-#: ../../getting_started/install/environment/environment.md:31
-#: 4bd27547ae6041679e91f2a363cd1deb
+#: ../../getting_started/install/environment/environment.md:33
+#: d9c03e0b3316415eb2ca59ad9c419b8c
msgid "PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions"
msgstr "PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions"
-#: ../../getting_started/install/environment/environment.md:33
-#: cfa3071afb0b47baad6bd729d4a02cb9
+#: ../../getting_started/install/environment/environment.md:35
+#: 45883f99c1fd494ea513f3c0f92562a3
msgid "from https://bard.google.com/ f12-> application-> __Secure-1PSID"
msgstr "from https://bard.google.com/ f12-> application-> __Secure-1PSID"
-#: ../../getting_started/install/environment/environment.md:35
-#: a17efa03b10f47f68afac9e865982a75
+#: ../../getting_started/install/environment/environment.md:37
+#: 70665dbe72c545a3b61c6efe37dfa7d5
msgid "BARD_PROXY_API_KEY={your-bard-token}"
msgstr "BARD_PROXY_API_KEY={your-bard-token}"
-#: ../../getting_started/install/environment/environment.md:38
-#: 6bcfe90574da4d82a459e8e11bf73cba
+#: ../../getting_started/install/environment/environment.md:40
+#: 782f8a9c9cd745a4990542ba8130c66a
msgid "DATABASE SETTINGS"
msgstr "DATABASE SETTINGS"
-#: ../../getting_started/install/environment/environment.md:39
-#: 2b1e62d9bf5d4af5a22f68c8248eaafb
+#: ../../getting_started/install/environment/environment.md:41
+#: 50ad9eae827a407c8c77692f48b9d423
msgid "SQLite database (Current default database)"
msgstr "SQLite database (Current default database)"
-#: ../../getting_started/install/environment/environment.md:40
-#: 8a909ac3b3c943da8dbc4e8dd596c80c
+#: ../../getting_started/install/environment/environment.md:42
+#: 410041683b664cabbe7ce6cb2050c629
msgid "LOCAL_DB_PATH=data/default_sqlite.db"
msgstr "LOCAL_DB_PATH=data/default_sqlite.db"
-#: ../../getting_started/install/environment/environment.md:41
-#: 90ae6507932f4815b6e180051738bb93
+#: ../../getting_started/install/environment/environment.md:43
+#: 0fcf0f9da84d4e4a8a1503a96dd6734b
msgid "LOCAL_DB_TYPE=sqlite # Database Type default:sqlite"
msgstr "LOCAL_DB_TYPE=sqlite # Database Type default:sqlite"
-#: ../../getting_started/install/environment/environment.md:43
-#: d2ce34e0dcf44ccf9e8007d548ba7b0a
+#: ../../getting_started/install/environment/environment.md:45
+#: 15fb9cdc51e44b71a1a375e49fb7bc6d
msgid "MYSQL database"
msgstr "MYSQL database"
-#: ../../getting_started/install/environment/environment.md:44
-#: c07159d63c334f6cbb95fcc30bfb7ea5
+#: ../../getting_started/install/environment/environment.md:46
+#: c8cc4cb61d1c44cd9ef3546455929ef6
msgid "LOCAL_DB_TYPE=mysql"
msgstr "LOCAL_DB_TYPE=mysql"
-#: ../../getting_started/install/environment/environment.md:45
-#: e16700b2ea8d411e91d010c1cde7aecc
+#: ../../getting_started/install/environment/environment.md:47
+#: a6caf3cabc4041b5879ec3af25c85139
msgid "LOCAL_DB_USER=root"
msgstr "LOCAL_DB_USER=root"
-#: ../../getting_started/install/environment/environment.md:46
-#: bfc2dce1bf374121b6861e677b4e1ffa
+#: ../../getting_started/install/environment/environment.md:48
+#: b839bde122374e299086f120fce0144c
msgid "LOCAL_DB_PASSWORD=aa12345678"
msgstr "LOCAL_DB_PASSWORD=aa12345678"
-#: ../../getting_started/install/environment/environment.md:47
-#: bc384739f5b04e21a34d0d2b78e7906c
+#: ../../getting_started/install/environment/environment.md:49
+#: 52cdbfdefda142b4a3b5cb3b060916a8
msgid "LOCAL_DB_HOST=127.0.0.1"
msgstr "LOCAL_DB_HOST=127.0.0.1"
-#: ../../getting_started/install/environment/environment.md:48
-#: e5253d452e0d42b7ac308fe6fbfb5017
+#: ../../getting_started/install/environment/environment.md:50
+#: 492db6e5c13b40898f38063980c5897c
msgid "LOCAL_DB_PORT=3306"
msgstr "LOCAL_DB_PORT=3306"
-#: ../../getting_started/install/environment/environment.md:51
-#: 9ca8f6fe06ed4cbab390f94be252e165
+#: ../../getting_started/install/environment/environment.md:53
+#: 20b101603f054c70af633439abddefec
msgid "EMBEDDING SETTINGS"
msgstr "EMBEDDING SETTINGS"
-#: ../../getting_started/install/environment/environment.md:52
-#: 76c7c260293c4b49bae057143fd48377
+#: ../../getting_started/install/environment/environment.md:54
+#: 3463a5a74cea494c8442100c0069285c
msgid "EMBEDDING MODEL Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG"
msgstr "EMBEDDING模型, 参考see /pilot/configs/model_config.LLM_MODEL_CONFIG"
-#: ../../getting_started/install/environment/environment.md:53
-#: f1d63a0128ce493cae37d34f1976bcca
+#: ../../getting_started/install/environment/environment.md:55
+#: 4c8adbf52110474bbfcd3b63cf2839f6
msgid "EMBEDDING_MODEL=text2vec"
msgstr "EMBEDDING_MODEL=text2vec"
-#: ../../getting_started/install/environment/environment.md:55
-#: b8fbb99109d04781b2dd5bc5d6efa5bd
+#: ../../getting_started/install/environment/environment.md:57
+#: 8a85a75151e64827971b1a367b31ecfa
msgid "Embedding Chunk size, default 500"
msgstr "Embedding 切片大小, 默认500"
-#: ../../getting_started/install/environment/environment.md:57
-#: bf8256576ea34f6a9c5f261ab9aab676
+#: ../../getting_started/install/environment/environment.md:59
+#: 947939b0fa7e46de97d48eadf5c443d2
msgid "KNOWLEDGE_CHUNK_SIZE=500"
msgstr "KNOWLEDGE_CHUNK_SIZE=500"
-#: ../../getting_started/install/environment/environment.md:59
-#: 9b156c6b599b4c02a58ce023b4ff25f2
+#: ../../getting_started/install/environment/environment.md:61
+#: 2785ad6bb0de4534a6523ac420f2c84c
msgid "Embedding Chunk Overlap, default 100"
msgstr "Embedding chunk Overlap, 文本块之间的最大重叠量。保留一些重叠可以保持文本块之间的连续性(例如使用滑动窗口),默认100"
-#: ../../getting_started/install/environment/environment.md:60
-#: dcafd903c36041ac85ac99a14dbee512
+#: ../../getting_started/install/environment/environment.md:62
+#: 40b6a8f57ee14ec1ab73143ba1516e78
msgid "KNOWLEDGE_CHUNK_OVERLAP=100"
msgstr "KNOWLEDGE_CHUNK_OVERLAP=100"
-#: ../../getting_started/install/environment/environment.md:62
-#: 6c3244b7e5e24b0188c7af4bb52e9134
+#: ../../getting_started/install/environment/environment.md:64
+#: e410faa1087c45639ee210be99cf9336
#, fuzzy
msgid "embedding recall top k,5"
msgstr "embedding 召回topk, 默认5"
-#: ../../getting_started/install/environment/environment.md:64
-#: f4a2f30551cf4fe1a7ff3c7c74ec77be
+#: ../../getting_started/install/environment/environment.md:66
+#: abfca38fe2a04161a11259588fa4d205
msgid "KNOWLEDGE_SEARCH_TOP_SIZE=5"
msgstr "KNOWLEDGE_SEARCH_TOP_SIZE=5"
-#: ../../getting_started/install/environment/environment.md:66
-#: 593f2512362f467e92fdaa60dd5903a0
+#: ../../getting_started/install/environment/environment.md:68
+#: 31182c38607b4c3bbc657b5fe5b7a4f6
#, fuzzy
msgid "embedding recall max token ,2000"
msgstr "embedding向量召回最大token, 默认2000"
-#: ../../getting_started/install/environment/environment.md:68
-#: 83d6d28914be4d6282d457272e508ddc
+#: ../../getting_started/install/environment/environment.md:70
+#: 96cd042635bc468e90c792fd9d1a7f4d
msgid "KNOWLEDGE_SEARCH_MAX_TOKEN=5"
msgstr "KNOWLEDGE_SEARCH_MAX_TOKEN=5"
-#: ../../getting_started/install/environment/environment.md:71
-#: ../../getting_started/install/environment/environment.md:87
-#: 6bc1b9d995e74294a1c78e783c550db7 d33c77ded834438e9f4a2df06e7e041a
+#: ../../getting_started/install/environment/environment.md:73
+#: d43b408ad9bc46f2b3c97aa91627f6b3
msgid "Vector Store SETTINGS"
msgstr "Vector Store SETTINGS"
-#: ../../getting_started/install/environment/environment.md:72
-#: ../../getting_started/install/environment/environment.md:88
-#: 9cafa06e2d584f70afd848184e0fa52a f01057251b8b4ffea806192dfe1048ed
+#: ../../getting_started/install/environment/environment.md:74
+#: b1fcbf6049af4eeea91edd3de58c8512
msgid "Chroma"
msgstr "Chroma"
-#: ../../getting_started/install/environment/environment.md:73
-#: ../../getting_started/install/environment/environment.md:89
-#: e6c16fab37484769b819aeecbc13e6db faad299722e5400e95ec6ac3c1e018b8
+#: ../../getting_started/install/environment/environment.md:75
+#: 2fb31575b274448fb945d47ee0eb108c
msgid "VECTOR_STORE_TYPE=Chroma"
msgstr "VECTOR_STORE_TYPE=Chroma"
-#: ../../getting_started/install/environment/environment.md:74
-#: ../../getting_started/install/environment/environment.md:90
-#: 4eca3a51716d406f8ffd49c06550e871 581ee9dd38064b119660c44bdd00cbaa
+#: ../../getting_started/install/environment/environment.md:76
+#: 601b87cc6f1d4732b935747e907cba5a
msgid "MILVUS"
msgstr "MILVUS"
-#: ../../getting_started/install/environment/environment.md:75
-#: ../../getting_started/install/environment/environment.md:91
-#: 814c93048bed46589358a854d6c99683 b72b1269a2224f5f961214e41c019f21
+#: ../../getting_started/install/environment/environment.md:77
+#: fde6cf6982764020aa1174f7fe3a5b3e
msgid "VECTOR_STORE_TYPE=Milvus"
msgstr "VECTOR_STORE_TYPE=Milvus"
-#: ../../getting_started/install/environment/environment.md:76
-#: ../../getting_started/install/environment/environment.md:92
-#: 73ae665f1db9402883662734588fd02c c4da20319c994e83ba5a7706db967178
+#: ../../getting_started/install/environment/environment.md:78
+#: 40c6206c7a614edf9b0af82c2c76f518
msgid "MILVUS_URL=127.0.0.1"
msgstr "MILVUS_URL=127.0.0.1"
-#: ../../getting_started/install/environment/environment.md:77
-#: ../../getting_started/install/environment/environment.md:93
-#: e30c5288516d42aa858a485db50490c1 f843b2e58bcb4e4594e3c28499c341d0
+#: ../../getting_started/install/environment/environment.md:79
+#: abde3c75269442cbb94a59c657d847a9
msgid "MILVUS_PORT=19530"
msgstr "MILVUS_PORT=19530"
-#: ../../getting_started/install/environment/environment.md:78
-#: ../../getting_started/install/environment/environment.md:94
-#: 158669efcc7d4bcaac1c8dd01b499029 24e88ffd32f242f281c56c0ec3ad2639
+#: ../../getting_started/install/environment/environment.md:80
+#: 375a837cbf6d4d65891612a7f073414a
msgid "MILVUS_USERNAME"
msgstr "MILVUS_USERNAME"
-#: ../../getting_started/install/environment/environment.md:79
-#: ../../getting_started/install/environment/environment.md:95
-#: 111a985297184c8aa5a0dd8e14a58445 6602093a6bb24d6792548e2392105c82
+#: ../../getting_started/install/environment/environment.md:81
+#: f785a796c8d3452c802d9a637f34cb57
msgid "MILVUS_PASSWORD"
msgstr "MILVUS_PASSWORD"
-#: ../../getting_started/install/environment/environment.md:80
-#: ../../getting_started/install/environment/environment.md:96
-#: 47bdfcd78fbe4ccdb5f49b717a6d01a6 b96c0545b2044926a8a8190caf94ad25
+#: ../../getting_started/install/environment/environment.md:82
+#: 18cd17a50dc14add9b31f6b4c55069ef
msgid "MILVUS_SECURE="
msgstr "MILVUS_SECURE="
-#: ../../getting_started/install/environment/environment.md:82
-#: ../../getting_started/install/environment/environment.md:98
-#: 755c32b5d6c54607907a138b5474c0ec ff4f2a7ddaa14f089dda7a14e1062c36
+#: ../../getting_started/install/environment/environment.md:84
+#: a4783d775bf2444788b758a71bd5a7e7
msgid "WEAVIATE"
msgstr "WEAVIATE"
-#: ../../getting_started/install/environment/environment.md:83
-#: 23b2ce83385d40a589a004709f9864be
+#: ../../getting_started/install/environment/environment.md:85
+#: 3cc5ca99670947e6868e27db588031e0
msgid "VECTOR_STORE_TYPE=Weaviate"
msgstr "VECTOR_STORE_TYPE=Weaviate"
-#: ../../getting_started/install/environment/environment.md:84
-#: ../../getting_started/install/environment/environment.md:99
-#: 9acef304d89a448a9e734346705ba872 cf5151b6c1594ccd8beb1c3f77769acb
+#: ../../getting_started/install/environment/environment.md:86
+#: 141a3da2e36e40ffaa0fb863081a4c07
msgid "WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network"
msgstr "WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network"
-#: ../../getting_started/install/environment/environment.md:102
-#: c3003516b2364051bf34f8c3086e348a
+#: ../../getting_started/install/environment/environment.md:89
+#: fde1941617ec4148b33c298bebeb45e4
msgid "Multi-GPU Setting"
msgstr "Multi-GPU Setting"
-#: ../../getting_started/install/environment/environment.md:103
-#: ade8fc381c5e438aa29d159c10041713
+#: ../../getting_started/install/environment/environment.md:90
+#: fe162354e15e42cda54f6c9322409321
msgid ""
"See https://developer.nvidia.com/blog/cuda-pro-tip-control-gpu-"
"visibility-cuda_visible_devices/ If CUDA_VISIBLE_DEVICES is not "
@@ -314,50 +309,50 @@ msgstr ""
"参考 https://developer.nvidia.com/blog/cuda-pro-tip-control-gpu-visibility-"
"cuda_visible_devices/ 如果 CUDA_VISIBLE_DEVICES没有设置, 会使用所有可用的gpu"
-#: ../../getting_started/install/environment/environment.md:106
-#: e137bd19be5e410ba6709027dbf2923a
+#: ../../getting_started/install/environment/environment.md:93
+#: c8a83b09bfc94dab8226840b275ca034
msgid "CUDA_VISIBLE_DEVICES=0"
msgstr "CUDA_VISIBLE_DEVICES=0"
-#: ../../getting_started/install/environment/environment.md:108
-#: 7669947acbdc4b1d92bcc029a8353a5d
+#: ../../getting_started/install/environment/environment.md:95
+#: a1d33bd2492a4a80bd8b679c1331280a
msgid ""
"Optionally, you can also specify the gpu ID to use before the starting "
"command"
msgstr "你也可以通过启动命令设置gpu ID"
-#: ../../getting_started/install/environment/environment.md:110
-#: 751743d1753b4051beea46371278d793
+#: ../../getting_started/install/environment/environment.md:97
+#: 961087a5cf1b45168c7439e3a2103253
msgid "CUDA_VISIBLE_DEVICES=3,4,5,6"
msgstr "CUDA_VISIBLE_DEVICES=3,4,5,6"
-#: ../../getting_started/install/environment/environment.md:112
-#: 3acc3de0af0d4df2bb575e161e377f85
+#: ../../getting_started/install/environment/environment.md:99
+#: 545b438ecb9d46edacbd8b4cc95886f9
msgid "You can configure the maximum memory used by each GPU."
msgstr "可以设置GPU的最大内存"
-#: ../../getting_started/install/environment/environment.md:114
-#: 67f1d9b172b84294a44ecace5436e6e0
+#: ../../getting_started/install/environment/environment.md:101
+#: a78dc8082fa04e13a7a3e43302830c26
msgid "MAX_GPU_MEMORY=16Gib"
msgstr "MAX_GPU_MEMORY=16Gib"
-#: ../../getting_started/install/environment/environment.md:117
-#: 3c69dfe48bcf46b89b76cac1e7849a66
+#: ../../getting_started/install/environment/environment.md:104
+#: eaebcb1784be4047b739ff1b8a78faa1
msgid "Other Setting"
msgstr "Other Setting"
-#: ../../getting_started/install/environment/environment.md:118
-#: d5015b70f4fe4d20a63de9d87f86957a
+#: ../../getting_started/install/environment/environment.md:105
+#: 21f524662fa34bfa9cfb8855bc191cc7
msgid "Language Settings(influence prompt language)"
msgstr "Language Settings(涉及prompt语言以及知识切片方式)"
-#: ../../getting_started/install/environment/environment.md:119
-#: 5543c28bb8e34c9fb3bb6b063c2b1750
+#: ../../getting_started/install/environment/environment.md:106
+#: bb5ce4a6ee794f0e910363673e54055a
msgid "LANGUAGE=en"
msgstr "LANGUAGE=en"
-#: ../../getting_started/install/environment/environment.md:120
-#: cb4ed5b892ee41068c1ca76cb29aa400
+#: ../../getting_started/install/environment/environment.md:107
+#: 862f113d63b94084b89bfef29f8ab48d
msgid "LANGUAGE=zh"
msgstr "LANGUAGE=zh"
diff --git a/docs/locales/zh_CN/LC_MESSAGES/index.po b/docs/locales/zh_CN/LC_MESSAGES/index.po
index 1badfc0e4..3a4c57dbc 100644
--- a/docs/locales/zh_CN/LC_MESSAGES/index.po
+++ b/docs/locales/zh_CN/LC_MESSAGES/index.po
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 0.3.0\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2023-11-06 19:00+0800\n"
+"POT-Creation-Date: 2023-11-14 17:55+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME \n"
"Language: zh_CN\n"
@@ -19,151 +19,176 @@ msgstr ""
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"
-#: ../../index.rst:34 ../../index.rst:45 8bc3a47457a34995816985436034e233
+#: ../../index.rst:44 ../../index.rst:54 fb98707559574eb29f01bd8f6ebfac60
msgid "Getting Started"
msgstr "开始"
-#: ../../index.rst:60 ../../index.rst:81 1a4e8a5dc7754967a0af9fb3d2e53017
+#: ../../index.rst:70 ../../index.rst:92 6d9603b978d44e54a257fb359c871867
msgid "Modules"
msgstr "模块"
-#: ../../index.rst:96 ../../index.rst:99 c815772ae8514f0c9b26911b0dd73f54
-msgid "Reference"
-msgstr "参考"
-
-#: ../../index.rst:109 ../../index.rst:115 dabe4c3409df489f84e4ec588f2b34a5
+#: ../../index.rst:106 ../../index.rst:112 9df06739ca4446bc86ec2ff6907763ce
msgid "Resources"
msgstr "资源"
-#: ../../index.rst:7 7626b01b253546ac83ca0cf130dfa091
-msgid "Welcome to DB-GPT!"
-msgstr "欢迎来到DB-GPT中文文档"
+#: ../../index.rst:7 7de875cfbc764937ab7f8b362d997952
+msgid "Overview"
+msgstr "概览"
-#: ../../index.rst:8 6037e5e0d7f7428ba92315a91ccfd53f
+#: ../../index.rst:9 770a756bd0b640ef863fd72b8d7e882a
msgid ""
-"As large models are released and iterated upon, they are becoming "
-"increasingly intelligent. However, in the process of using large models, "
-"we face significant challenges in data security and privacy. We need to "
-"ensure that our sensitive data and environments remain completely "
-"controlled and avoid any data privacy leaks or security risks. Based on "
-"this, we have launched the DB-GPT project to build a complete private "
-"large model solution for all database-based scenarios. This solution "
-"supports local deployment, allowing it to be applied not only in "
-"independent private environments but also to be independently deployed "
-"and isolated according to business modules, ensuring that the ability of "
-"large models is absolutely private, secure, and controllable."
-msgstr ""
-"随着大型模型的发布和迭代,它们变得越来越智能。然而,在使用大型模型的过程中,我们在数据安全和隐私方面面临着重大挑战。我们需要确保我们的敏感数据和环境得到完全控制,避免任何数据隐私泄露或安全风险。基于此"
-",我们启动了DB-"
-"GPT项目,为所有基于数据库的场景构建一个完整的私有大模型解决方案。该方案“”支持本地部署,既可应用于“独立私有环境”,又可根据业务模块进行“独立部署”和“隔离”,确保“大模型”的能力绝对私有、安全、可控。"
+"DB-GPT is an open-source framework for large models in the database "
+"field. Its purpose is to build infrastructure for the domain of large "
+"models, making it easier and more convenient to develop applications "
+"around databases. By developing various technical capabilities such as:"
+msgstr "DB-GPT是一个开源的数据库领域大模型框架。目的是构建大模型领域的基础设施,通过开发如"
-#: ../../index.rst:10 ab2a181d517047e6992171786c83f8e3
+#: ../../index.rst:11 8774a5ad5ce14baf9eae35fefd62e40b
+msgid "**SMMF(Service-oriented Multi-model Management Framework)**"
+msgstr "**服务化多模型管理**"
+
+#: ../../index.rst:12 b2ba120fc994436db7066486c9acd6ad
+msgid "**Text2SQL Fine-tuning**"
+msgstr "**Text2SQL微调**"
+
+#: ../../index.rst:13 d55efe86dd6b40ebbe63079edb60e421
+msgid "**RAG(Retrieval Augmented Generation) framework and optimization**"
+msgstr "**检索增强**"
+
+#: ../../index.rst:14 3eca943c44464c9cb9bbc5724c27ad1c
+msgid "**Data-Driven Agents framework collaboration**"
+msgstr "**数据驱动的Agents协作框架**"
+
+#: ../../index.rst:15 bf41d57cbc474e2c9829f09d6b983ae1
+msgid "**GBI(Generative Business intelligence)**"
+msgstr "**生成式报表分析**"
+
+#: ../../index.rst:17 36630469cc064317a1c196dd377c3d93
msgid ""
-"**DB-GPT** is an experimental open-source project that uses localized GPT"
-" large models to interact with your data and environment. With this "
-"solution, you can be assured that there is no risk of data leakage, and "
-"your data is 100% private and secure."
-msgstr ""
-"DB-GPT 是一个开源的以数据库为基础的GPT实验项目,使用本地化的GPT大模型与您的数据和环境进行交互,无数据泄露风险100% 私密,100%"
-" 安全。"
+"etc, DB-GPT simplifies the construction of large model applications based"
+" on databases."
+msgstr "等能力, 让围绕数据库构建大模型应用更简单,更方便。"
-#: ../../index.rst:12 9cfb7515430d49af8a1ca47f60264a58
-msgid "**Features**"
+#: ../../index.rst:19 82f03535c6914ebfa8b3adad34eeed2f
+msgid ""
+"In the era of Data 3.0, enterprises and developers can build their own "
+"customized applications with less code, leveraging models and databases."
+msgstr "*数据3.0 时代,基于模型、数据库,企业/开发者可以用更少的代码搭建自己的专属应用*。"
+
+#: ../../index.rst:22 daf64ec39c28458087d542879d106d1b
+msgid "Features"
msgstr "特性"
-#: ../../index.rst:13 2a1f84e455c84d9ca66c65f92e5b0d78
+#: ../../index.rst:24 7ceb41b710f847e683479dc892baa3d5
+msgid "**1. Private Domain Q&A & Data Processing**"
+msgstr "**1. 私域问答&数据处理**"
+
+#: ../../index.rst:25 3f480e259ee9432b934ee6474bc8de79
msgid ""
-"Currently, we have released multiple key features, which are listed below"
-" to demonstrate our current capabilities:"
-msgstr "目前我们已经发布了多种关键的特性,这里一一列举展示一下当前发布的能力。"
+"Supports custom construction of knowledge bases through methods such as "
+"built-in, multi-file format uploads, and plugin-based web scraping. "
+"Enables unified vector storage and retrieval of massive structured and "
+"unstructured data."
+msgstr "支持内置、多文件格式上传、插件自抓取等方式自定义构建知识库,对海量结构化,非结构化数据做统一向量存储与检索"
-#: ../../index.rst:15 43de30ce92da4c3cbe43ae4e4c9f1869
-msgid "SQL language capabilities - SQL generation - SQL diagnosis"
-msgstr "SQL语言能力 - SQL生成 - SQL诊断"
+#: ../../index.rst:27 1f9f12be761a4a6c996788051a3fa4dd
+msgid "**2.Multi-Data Source & GBI(Generative Business intelligence)**"
+msgstr "**2.多数据源与可视化**"
-#: ../../index.rst:19 edfeef5284e7426a9e551e782bc5702c
+#: ../../index.rst:28 e597e6c2d4ad4d1bbcc440b3afb7c0fa
msgid ""
-"Private domain Q&A and data processing - Database knowledge Q&A - Data "
-"processing"
-msgstr "私有领域问答与数据处理 - 数据库知识问答 - 数据处理"
+"Supports interaction between natural language and various data sources "
+"such as Excel, databases, and data warehouses. Also supports analysis "
+"reporting."
+msgstr "支持自然语言与Excel、数据库、数仓等多种数据源交互,并支持分析报告。"
-#: ../../index.rst:23 7a42f17049b943f88dd8f17baa440144
+#: ../../index.rst:30 9c63ecf927874f9ea79f1ef5c1535e67
+msgid "**3.SMMF(Service-oriented Multi-model Management Framework)**"
+msgstr "**3.多模型管理**"
+
+#: ../../index.rst:31 d6cfb9b69f9743d083c4644c90fd6108
msgid ""
-"Plugins - Support custom plugin execution tasks and natively support the "
-"Auto-GPT plugin, such as:"
-msgstr "插件模型 - 支持自定义插件执行任务,并原生支持Auto-GPT插件,例如:* SQL自动执行,获取查询结果 * 自动爬取学习知识"
+"Supports a wide range of models, including dozens of large language "
+"models such as open-source models and API proxies. Examples include "
+"LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, Xinghuo, etc."
+msgstr "海量模型支持,包括开源、API代理等几十种大语言模型。如LLaMA/LLaMA2、Baichuan、ChatGLM、文心、通义、智谱、星火等。"
-#: ../../index.rst:26 8b48d7b60bbc439da50a624c4048e6f6
+#: ../../index.rst:33 dda6cec4316e48f2afe77005baa53a06
+msgid "**4.Automated Fine-tuning**"
+msgstr "**4.自动化微调**"
+
+#: ../../index.rst:34 7cf1654a9779444ab3982435887d087b
msgid ""
-"Unified vector storage/indexing of knowledge base - Support for "
-"unstructured data such as PDF, Markdown, CSV, and WebURL"
-msgstr "知识库统一向量存储/索引 - 非结构化数据支持包括PDF、MarkDown、CSV、WebURL"
+"A lightweight framework for automated fine-tuning built around large "
+"language models, Text2SQL datasets, and methods like LoRA/QLoRA/Pturning."
+" Makes TextSQL fine-tuning as convenient as a production line."
+msgstr ""
+"围绕大语言模型、Text2SQL数据集、LoRA/QLoRA/Pturning等微调方法构建的自动化微调轻量框架, "
+"让TextSQL微调像流水线一样方便。"
-#: ../../index.rst:29 97df482893924bd18e9a101922e7c374
-#, fuzzy
+#: ../../index.rst:36 f58f114546f04b658aaa67fd895fba2b
+msgid "**5.Data-Driven Multi-Agents & Plugins**"
+msgstr "**5.数据驱动的插件模型**"
+
+#: ../../index.rst:37 a93fdca3de054cb0812d7f5ca3d12375
msgid ""
-"Multi LLMs Support - Supports multiple large language models, currently "
-"supporting Vicuna (7b, 13b), ChatGLM-6b (int4, int8) - TODO: codegen2, "
-"codet5p"
-msgstr "多模型支持 - 支持多种大语言模型, 当前已支持Vicuna(7b,13b), ChatGLM-6b(int4, int8)"
+"Supports executing tasks through custom plugins and natively supports the"
+" Auto-GPT plugin model. Agents protocol follows the Agent Protocol "
+"standard."
+msgstr "支持自定义插件执行任务,原生支持Auto-GPT插件模型,Agents协议采用Agent Protocol标准"
-#: ../../index.rst:35 1ef26ead30ed4b7fb966c8a17307cdc5
+#: ../../index.rst:39 3a0e89b151694e4b8e87646efe313568
+msgid "**6.Privacy and Security**"
+msgstr "**6.隐私安全**"
+
+#: ../../index.rst:40 aa50fc40f22f4fae8225a0a0a97c17dc
msgid ""
-"How to get started using DB-GPT to interact with your data and "
-"environment."
-msgstr "开始使用DB-GPT与您的数据环境进行交互。"
+"Ensures data privacy and security through techniques such as privatizing "
+"large models and proxy de-identification."
+msgstr "通过私有化大模型、代理脱敏等多种技术保障数据的隐私安全"
-#: ../../index.rst:36 3b44ab3576944bf6aa221f35bc051f4e
+#: ../../index.rst:46 d8bf21a7abd749608cddcdb2e358f3be
+msgid "Quickstart"
+msgstr "快速开始"
+
+#: ../../index.rst:48 d1f117a7cbb94c80afc0660e899d8154
#, fuzzy
msgid "`Quickstart Guide <./getting_started/getting_started.html>`_"
msgstr "`使用指南 <./getting_started/getting_started.html>`_"
-#: ../../index.rst:38 430cb239cdce42a0b62db46aba3f3bdb
+#: ../../index.rst:50 5fd56979f31b4a0b93082004f1cb90c7
msgid "Concepts and terminology"
msgstr "相关概念"
-#: ../../index.rst:40 ded4d9f80066498e90ba6214520013f7
+#: ../../index.rst:52 09c6889d02fa417c9ffde312211726f0
#, fuzzy
msgid "`Concepts and Terminology <./getting_started/concepts.html>`_"
msgstr "`相关概念 <./getting_started/concepts.html>`_"
-#: ../../index.rst:42 cd662e53621e474d901146813c750044
-msgid "Coming soon..."
-msgstr ""
-
-#: ../../index.rst:44 15edba57f1de44af8aff76735a2593de
-msgid "`Tutorials <.getting_started/tutorials.html>`_"
-msgstr "`教程 <.getting_started/tutorials.html>`_"
-
-#: ../../index.rst:62 779454b29d8e4e6eb21497025922d1b8
+#: ../../index.rst:72 5bd727134fc94cfb88abb755ccceac03
msgid ""
"These modules are the core abstractions with which we can interact with "
-"data and environment smoothly."
-msgstr "这些模块是我们可以与数据和环境顺利地进行交互的核心组成。"
+"data and environment smoothly. It's very important for DB-GPT, DB-GPT "
+"also provide standard, extendable interfaces."
+msgstr "这些模块是我们能够与数据和环境顺利交互的核心抽象。这对于DB-GPT来说非常重要,DB-GPT还提供了标准的、可扩展的接口。"
-#: ../../index.rst:63 bcd0e8c88c7b4807a91dd442416bec19
-msgid ""
-"It's very important for DB-GPT, DB-GPT also provide standard, extendable "
-"interfaces."
-msgstr "DB-GPT还提供了标准的、可扩展的接口。"
-
-#: ../../index.rst:65 1e785dc6925045e8ba106cf4a3b17cac
+#: ../../index.rst:74 1a5eb0b7cb884309be3431112c8f38e5
msgid ""
"The docs for each module contain quickstart examples, how to guides, "
"reference docs, and conceptual guides."
msgstr "每个模块的文档都包含快速入门的例子、操作指南、参考文档和相关概念等内容。"
-#: ../../index.rst:67 9c9fddd14bfd40339889f5d1f0b04163
+#: ../../index.rst:76 24aa8c08d1dc460ab23d69a5bb9c8fc3
msgid "The modules are as follows"
msgstr "组成模块如下:"
-#: ../../index.rst:69 4a19083cadd04b8e8b649a622e0ceccd
+#: ../../index.rst:78 9f4280cca1f743cb9b868cc67e3f3ce7
msgid ""
"`LLMs <./modules/llms.html>`_: Supported multi models management and "
"integrations."
msgstr "`LLMs <./modules/llms.html>`_:基于FastChat提供大模型的运行环境。支持多模型管理和集成。 "
-#: ../../index.rst:71 436a139225574aa5b066a1835d38238d
+#: ../../index.rst:80 d357811f110f40e79f0c20ef9cb60d0c
msgid ""
"`Prompts <./modules/prompts.html>`_: Prompt management, optimization, and"
" serialization for multi database."
@@ -171,41 +196,35 @@ msgstr ""
"`Prompt自动生成与优化 <./modules/prompts.html>`_: 自动化生成高质量的Prompt "
",并进行优化,提高系统的响应效率"
-#: ../../index.rst:73 6c53edfb2e494c5fba6efb5ade48c310
+#: ../../index.rst:82 3cb9acc9f11a46638e6687f743d6b7f3
msgid "`Plugins <./modules/plugins.html>`_: Plugins management, scheduler."
msgstr "`Agent与插件: <./modules/plugins.html>`_:提供Agent和插件机制,使得用户可以自定义并增强系统的行为。"
-#: ../../index.rst:75 6328760e8faf4e8296f3e1edd486316c
+#: ../../index.rst:84 b24c462cb5364890a6ca990f09f48cfc
#, fuzzy
msgid ""
"`Knowledge <./modules/knowledge.html>`_: Knowledge management, embedding,"
" and search."
msgstr "`知识库能力: <./modules/knowledge.html>`_: 支持私域知识库问答能力, "
-#: ../../index.rst:77 da272ccf56e3498d92009ac7101b0c45
+#: ../../index.rst:86 7448b231fe8745f1965a1f48ffc5444a
msgid ""
"`Connections <./modules/connections.html>`_: Supported multi databases "
"connection. management connections and interact with this."
msgstr "`连接模块 <./modules/connections.html>`_: 用于连接不同的模块和数据源,实现数据的流转和交互 "
-#: ../../index.rst:79 1a0551f62d9d418a9dec267fbcb49af0
+#: ../../index.rst:88 c677fb24869347ff907f1529ef333b6b
#, fuzzy
msgid "`Vector <./modules/vector.html>`_: Supported multi vector database."
msgstr "`LLMs <./modules/llms.html>`_:基于FastChat提供大模型的运行环境。支持多模型管理和集成。 "
-#: ../../index.rst:97 9aceee0dbe1e4f7da499ac6aab23aea2
-msgid ""
-"Full documentation on all methods, classes, installation methods, and "
-"integration setups for DB-GPT."
-msgstr "关于DB-GPT的所有方法、类、安装方法和集成设置的完整文档。"
-
-#: ../../index.rst:111 c9a729f4e1964894bae215793647ab75
+#: ../../index.rst:108 2e56f2cb1a8b40dda9465c0a1af94196
msgid ""
"Additional resources we think may be useful as you develop your "
"application!"
-msgstr "“我们认为在您开发应用程序时可能有用的其他资源!”"
+msgstr "我们认为在您开发应用程序时可能有用的其他资源!"
-#: ../../index.rst:113 06e6e4b7776c405fa94ae7b59253162d
+#: ../../index.rst:110 590362cb3b7442d49eafa58cb323e127
msgid ""
"`Discord `_: if your have some problem or "
"ideas, you can talk from discord."
@@ -278,3 +297,270 @@ msgstr "`Discord `_:如果您有任何问题,可
#~ "autonomoly."
#~ msgstr "`插件工具 <./use_cases/tool_use_with_plugin>`_: 根据插件使用工具自主管理数据库。"
+#~ msgid "Reference"
+#~ msgstr "参考"
+
+#~ msgid "Welcome to DB-GPT!"
+#~ msgstr "欢迎来到DB-GPT中文文档"
+
+#~ msgid ""
+#~ "As large models are released and "
+#~ "iterated upon, they are becoming "
+#~ "increasingly intelligent. However, in the "
+#~ "process of using large models, we "
+#~ "face significant challenges in data "
+#~ "security and privacy. We need to "
+#~ "ensure that our sensitive data and "
+#~ "environments remain completely controlled and"
+#~ " avoid any data privacy leaks or "
+#~ "security risks. Based on this, we "
+#~ "have launched the DB-GPT project "
+#~ "to build a complete private large "
+#~ "model solution for all database-based"
+#~ " scenarios. This solution supports local"
+#~ " deployment, allowing it to be "
+#~ "applied not only in independent private"
+#~ " environments but also to be "
+#~ "independently deployed and isolated according"
+#~ " to business modules, ensuring that "
+#~ "the ability of large models is "
+#~ "absolutely private, secure, and controllable."
+#~ msgstr ""
+#~ "随着大型模型的发布和迭代,它们变得越来越智能。然而,在使用大型模型的过程中,我们在数据安全和隐私方面面临着重大挑战。我们需要确保我们的敏感数据和环境得到完全控制,避免任何数据隐私泄露或安全风险。基于此"
+#~ ",我们启动了DB-"
+#~ "GPT项目,为所有基于数据库的场景构建一个完整的私有大模型解决方案。该方案“”支持本地部署,既可应用于“独立私有环境”,又可根据业务模块进行“独立部署”和“隔离”,确保“大模型”的能力绝对私有、安全、可控。"
+
+#~ msgid ""
+#~ "**DB-GPT** is an experimental open-"
+#~ "source project that uses localized GPT"
+#~ " large models to interact with your"
+#~ " data and environment. With this "
+#~ "solution, you can be assured that "
+#~ "there is no risk of data leakage,"
+#~ " and your data is 100% private "
+#~ "and secure."
+#~ msgstr ""
+#~ "DB-GPT "
+#~ "是一个开源的以数据库为基础的GPT实验项目,使用本地化的GPT大模型与您的数据和环境进行交互,无数据泄露风险100% "
+#~ "私密,100% 安全。"
+
+#~ msgid ""
+#~ "Currently, we have released multiple key"
+#~ " features, which are listed below to"
+#~ " demonstrate our current capabilities:"
+#~ msgstr "目前我们已经发布了多种关键的特性,这里一一列举展示一下当前发布的能力。"
+
+#~ msgid "SQL language capabilities - SQL generation - SQL diagnosis"
+#~ msgstr "SQL语言能力 - SQL生成 - SQL诊断"
+
+#~ msgid ""
+#~ "Private domain Q&A and data processing"
+#~ " - Database knowledge Q&A - Data "
+#~ "processing"
+#~ msgstr "私有领域问答与数据处理 - 数据库知识问答 - 数据处理"
+
+#~ msgid ""
+#~ "Plugins - Support custom plugin "
+#~ "execution tasks and natively support the"
+#~ " Auto-GPT plugin, such as:"
+#~ msgstr "插件模型 - 支持自定义插件执行任务,并原生支持Auto-GPT插件,例如:* SQL自动执行,获取查询结果 * 自动爬取学习知识"
+
+#~ msgid ""
+#~ "Unified vector storage/indexing of knowledge"
+#~ " base - Support for unstructured data"
+#~ " such as PDF, Markdown, CSV, and "
+#~ "WebURL"
+#~ msgstr "知识库统一向量存储/索引 - 非结构化数据支持包括PDF、MarkDown、CSV、WebURL"
+
+#~ msgid ""
+#~ "Multi LLMs Support - Supports multiple"
+#~ " large language models, currently "
+#~ "supporting Vicuna (7b, 13b), ChatGLM-6b"
+#~ " (int4, int8) - TODO: codegen2, "
+#~ "codet5p"
+#~ msgstr "多模型支持 - 支持多种大语言模型, 当前已支持Vicuna(7b,13b), ChatGLM-6b(int4, int8)"
+
+#~ msgid ""
+#~ "Full documentation on all methods, "
+#~ "classes, installation methods, and integration"
+#~ " setups for DB-GPT."
+#~ msgstr "关于DB-GPT的所有方法、类、安装方法和集成设置的完整文档。"
+
+#~ msgid ""
+#~ "**DB-GPT** is an open-source "
+#~ "framework for large models in the "
+#~ "database field. Its purpose is to "
+#~ "build infrastructure for the domain of"
+#~ " large models, making it easier and"
+#~ " more convenient to develop applications"
+#~ " around databases."
+#~ msgstr ""
+
+#~ msgid "By developing various technical capabilities such as"
+#~ msgstr ""
+
+#~ msgid "SMMF(Service-oriented Multi-model Management Framework)"
+#~ msgstr ""
+
+#~ msgid "Text2SQL Fine-tuning"
+#~ msgstr ""
+
+#~ msgid "RAG(Retrieval Augmented Generation) framework and optimization"
+#~ msgstr ""
+
+#~ msgid "Data-Driven Agents framework collaboration"
+#~ msgstr ""
+
+#~ msgid ""
+#~ "5. GBI(Generative Business intelligence) etc,"
+#~ " DB-GPT simplifies the construction "
+#~ "of large model applications based on "
+#~ "databases."
+#~ msgstr ""
+
+#~ msgid ""
+#~ "**1. Private Domain Q&A & Data "
+#~ "Processing** Supports custom construction of"
+#~ " knowledge bases through methods such "
+#~ "as built-in, multi-file format "
+#~ "uploads, and plugin-based web scraping."
+#~ " Enables unified vector storage and "
+#~ "retrieval of massive structured and "
+#~ "unstructured data."
+#~ msgstr ""
+
+#~ msgid ""
+#~ "**2.Multi-Data Source & GBI(Generative "
+#~ "Business intelligence)** Supports interaction "
+#~ "between natural language and various "
+#~ "data sources such as Excel, databases,"
+#~ " and data warehouses. Also supports "
+#~ "analysis reporting."
+#~ msgstr ""
+
+#~ msgid ""
+#~ "**3.SMMF(Service-oriented Multi-model "
+#~ "Management Framework)** Supports a wide "
+#~ "range of models, including dozens of "
+#~ "large language models such as open-"
+#~ "source models and API proxies. Examples"
+#~ " include LLaMA/LLaMA2, Baichuan, ChatGLM, "
+#~ "Wenxin, Tongyi, Zhipu, Xinghuo, etc."
+#~ msgstr ""
+
+#~ msgid ""
+#~ "**4.Automated Fine-tuning** A lightweight "
+#~ "framework for automated fine-tuning "
+#~ "built around large language models, "
+#~ "Text2SQL datasets, and methods like "
+#~ "LoRA/QLoRA/Pturning. Makes TextSQL fine-tuning"
+#~ " as convenient as a production line."
+#~ msgstr ""
+
+#~ msgid ""
+#~ "**5.Data-Driven Multi-Agents & Plugins**"
+#~ " Supports executing tasks through custom"
+#~ " plugins and natively supports the "
+#~ "Auto-GPT plugin model. Agents protocol "
+#~ "follows the Agent Protocol standard."
+#~ msgstr ""
+
+#~ msgid ""
+#~ "**6.Privacy and Security** Ensures data "
+#~ "privacy and security through techniques "
+#~ "such as privatizing large models and "
+#~ "proxy de-identification."
+#~ msgstr ""
+
+#~ msgid "Coming soon..."
+#~ msgstr ""
+
+#~ msgid "`Tutorials <.getting_started/tutorials.html>`_"
+#~ msgstr "`教程 <.getting_started/tutorials.html>`_"
+
+#~ msgid ""
+#~ "DB-GPT is an open-source framework"
+#~ " for large models in the database "
+#~ "field. Its purpose is to build "
+#~ "infrastructure for the domain of large"
+#~ " models, making it easier and more"
+#~ " convenient to develop applications around"
+#~ " databases. By developing various technical"
+#~ " capabilities such as **1. SMMF(Service-"
+#~ "oriented Multi-model Management Framework)**"
+#~ " **2. Text2SQL Fine-tuning** **3. "
+#~ "RAG(Retrieval Augmented Generation) framework "
+#~ "and optimization** **4. Data-Driven "
+#~ "Agents framework collaboration** **5. "
+#~ "GBI(Generative Business intelligence)** etc, "
+#~ "DB-GPT simplifies the construction of "
+#~ "large model applications based on "
+#~ "databases."
+#~ msgstr ""
+
+#~ msgid ""
+#~ "**1. Private Domain Q&A & Data "
+#~ "Processing** ::Supports custom construction of"
+#~ " knowledge bases through methods such "
+#~ "as built-in, multi-file format "
+#~ "uploads, and plugin-based web scraping."
+#~ " Enables unified vector storage and "
+#~ "retrieval of massive structured and "
+#~ "unstructured data."
+#~ msgstr ""
+
+#~ msgid ""
+#~ "**2.Multi-Data Source & GBI(Generative "
+#~ "Business intelligence)** ::Supports interaction "
+#~ "between natural language and various "
+#~ "data sources such as Excel, databases,"
+#~ " and data warehouses. Also supports "
+#~ "analysis reporting."
+#~ msgstr ""
+
+#~ msgid ""
+#~ "**3.SMMF(Service-oriented Multi-model "
+#~ "Management Framework)** ::Supports a wide "
+#~ "range of models, including dozens of "
+#~ "large language models such as open-"
+#~ "source models and API proxies. Examples"
+#~ " include LLaMA/LLaMA2, Baichuan, ChatGLM, "
+#~ "Wenxin, Tongyi, Zhipu, Xinghuo, etc."
+#~ msgstr ""
+
+#~ msgid ""
+#~ "**4.Automated Fine-tuning** ::A lightweight"
+#~ " framework for automated fine-tuning "
+#~ "built around large language models, "
+#~ "Text2SQL datasets, and methods like "
+#~ "LoRA/QLoRA/Pturning. Makes TextSQL fine-tuning"
+#~ " as convenient as a production line."
+#~ msgstr ""
+
+#~ msgid ""
+#~ "**5.Data-Driven Multi-Agents & Plugins**"
+#~ " ::Supports executing tasks through custom"
+#~ " plugins and natively supports the "
+#~ "Auto-GPT plugin model. Agents protocol "
+#~ "follows the Agent Protocol standard."
+#~ msgstr ""
+
+#~ msgid ""
+#~ "**6.Privacy and Security** ::Ensures data "
+#~ "privacy and security through techniques "
+#~ "such as privatizing large models and "
+#~ "proxy de-identification."
+#~ msgstr ""
+
+#~ msgid ""
+#~ "How to get started using DB-GPT"
+#~ " to interact with your data and "
+#~ "environment."
+#~ msgstr "开始使用DB-GPT与您的数据环境进行交互。"
+
+#~ msgid ""
+#~ "It's very important for DB-GPT, "
+#~ "DB-GPT also provide standard, extendable"
+#~ " interfaces."
+#~ msgstr "DB-GPT还提供了标准的、可扩展的接口。"
+
diff --git a/docs/reference.md b/docs/reference.md
deleted file mode 100644
index 4a938e09d..000000000
--- a/docs/reference.md
+++ /dev/null
@@ -1 +0,0 @@
-# Reference
\ No newline at end of file
diff --git a/pilot/awel/__init__.py b/pilot/awel/__init__.py
new file mode 100644
index 000000000..6c5313b5d
--- /dev/null
+++ b/pilot/awel/__init__.py
@@ -0,0 +1,60 @@
+"""Agentic Workflow Expression Language (AWEL)"""
+
+from .dag.base import DAGContext, DAG
+
+from .operator.base import BaseOperator, WorkflowRunner, initialize_awel
+from .operator.common_operator import (
+ JoinOperator,
+ ReduceStreamOperator,
+ MapOperator,
+ BranchOperator,
+ InputOperator,
+ BranchFunc,
+)
+
+from .operator.stream_operator import (
+ StreamifyAbsOperator,
+ UnstreamifyAbsOperator,
+ TransformStreamAbsOperator,
+)
+
+from .task.base import TaskState, TaskOutput, TaskContext, InputContext, InputSource
+from .task.task_impl import (
+ SimpleInputSource,
+ SimpleCallDataInputSource,
+ DefaultTaskContext,
+ DefaultInputContext,
+ SimpleTaskOutput,
+ SimpleStreamTaskOutput,
+ _is_async_iterator,
+)
+from .runner.local_runner import DefaultWorkflowRunner
+
+__all__ = [
+ "initialize_awel",
+ "DAGContext",
+ "DAG",
+ "BaseOperator",
+ "JoinOperator",
+ "ReduceStreamOperator",
+ "MapOperator",
+ "BranchOperator",
+ "InputOperator",
+ "BranchFunc",
+ "WorkflowRunner",
+ "TaskState",
+ "TaskOutput",
+ "TaskContext",
+ "InputContext",
+ "InputSource",
+ "DefaultWorkflowRunner",
+ "SimpleInputSource",
+ "SimpleCallDataInputSource",
+ "DefaultTaskContext",
+ "DefaultInputContext",
+ "SimpleTaskOutput",
+ "SimpleStreamTaskOutput",
+ "StreamifyAbsOperator",
+ "UnstreamifyAbsOperator",
+ "TransformStreamAbsOperator",
+]
diff --git a/pilot/awel/dag/__init__.py b/pilot/awel/dag/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/awel/dag/base.py b/pilot/awel/dag/base.py
new file mode 100644
index 000000000..a6ad08990
--- /dev/null
+++ b/pilot/awel/dag/base.py
@@ -0,0 +1,270 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Dict, List, Sequence, Union, Any
+import uuid
+import contextvars
+import threading
+import asyncio
+from collections import deque
+
+from ..resource.base import ResourceGroup
+from ..task.base import TaskContext
+
+DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]
+
+
+def _is_async_context():
+ try:
+ loop = asyncio.get_running_loop()
+ return asyncio.current_task(loop=loop) is not None
+ except RuntimeError:
+ return False
+
+
+class DependencyMixin(ABC):
+ @abstractmethod
+ def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
+ """Set one or more upstream nodes for this node.
+
+ Args:
+ nodes (DependencyType): Upstream nodes to be set to current node.
+
+ Returns:
+ DependencyMixin: Returns self to allow method chaining.
+
+ Raises:
+ ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin.
+ """
+
+ @abstractmethod
+ def set_downstream(self, nodes: DependencyType) -> "DependencyMixin":
+ """Set one or more downstream nodes for this node.
+
+ Args:
+ nodes (DependencyType): Downstream nodes to be set to current node.
+
+ Returns:
+ DependencyMixin: Returns self to allow method chaining.
+
+ Raises:
+ ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin.
+ """
+
+ def __lshift__(self, nodes: DependencyType) -> DependencyType:
+ """Implements self << nodes
+
+ Example:
+
+ .. code-block:: python
+
+ # means node.set_upstream(input_node)
+ node << input_node
+
+ # means node2.set_upstream([input_node])
+ node2 << [input_node]
+ """
+ self.set_upstream(nodes)
+ return nodes
+
+ def __rshift__(self, nodes: DependencyType) -> DependencyType:
+ """Implements self >> nodes
+
+ Example:
+
+ .. code-block:: python
+
+ # means node.set_downstream(next_node)
+ node >> next_node
+
+ # means node2.set_downstream([next_node])
+ node2 >> [next_node]
+
+ """
+ self.set_downstream(nodes)
+ return nodes
+
+ def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
+ """Implements [node] >> self"""
+ self.__lshift__(nodes)
+ return self
+
+ def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
+ """Implements [node] << self"""
+ self.__rshift__(nodes)
+ return self
+
+
+class DAGVar:
+ _thread_local = threading.local()
+ _async_local = contextvars.ContextVar("current_dag_stack", default=deque())
+
+ @classmethod
+ def enter_dag(cls, dag) -> None:
+ is_async = _is_async_context()
+ if is_async:
+ stack = cls._async_local.get()
+ stack.append(dag)
+ cls._async_local.set(stack)
+ else:
+ if not hasattr(cls._thread_local, "current_dag_stack"):
+ cls._thread_local.current_dag_stack = deque()
+ cls._thread_local.current_dag_stack.append(dag)
+
+ @classmethod
+ def exit_dag(cls) -> None:
+ is_async = _is_async_context()
+ if is_async:
+ stack = cls._async_local.get()
+ if stack:
+ stack.pop()
+ cls._async_local.set(stack)
+ else:
+ if (
+ hasattr(cls._thread_local, "current_dag_stack")
+ and cls._thread_local.current_dag_stack
+ ):
+ cls._thread_local.current_dag_stack.pop()
+
+ @classmethod
+ def get_current_dag(cls) -> Optional["DAG"]:
+ is_async = _is_async_context()
+ if is_async:
+ stack = cls._async_local.get()
+ return stack[-1] if stack else None
+ else:
+ if (
+ hasattr(cls._thread_local, "current_dag_stack")
+ and cls._thread_local.current_dag_stack
+ ):
+ return cls._thread_local.current_dag_stack[-1]
+ return None
+
+
+class DAGNode(DependencyMixin, ABC):
+ resource_group: Optional[ResourceGroup] = None
+ """The resource group of current DAGNode"""
+
+ def __init__(
+ self, dag: Optional["DAG"] = None, node_id: str = None, node_name: str = None
+ ) -> None:
+ super().__init__()
+ self._upstream: List["DAGNode"] = []
+ self._downstream: List["DAGNode"] = []
+ self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag()
+ if not node_id and self._dag:
+ node_id = self._dag._new_node_id()
+ self._node_id: str = node_id
+ self._node_name: str = node_name
+
+ @property
+ def node_id(self) -> str:
+ return self._node_id
+
+ def set_node_id(self, node_id: str) -> None:
+ self._node_id = node_id
+
+ def __hash__(self) -> int:
+ if self.node_id:
+ return hash(self.node_id)
+ else:
+ return super().__hash__()
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, DAGNode):
+ return False
+ return self.node_id == other.node_id
+
+ @property
+ def node_name(self) -> str:
+ return self._node_name
+
+ @property
+ def dag(self) -> "DAGNode":
+ return self._dag
+
+ def set_upstream(self, nodes: DependencyType) -> "DAGNode":
+ self.set_dependency(nodes)
+
+ def set_downstream(self, nodes: DependencyType) -> "DAGNode":
+ self.set_dependency(nodes, is_upstream=False)
+
+ @property
+ def upstream(self) -> List["DAGNode"]:
+ return self._upstream
+
+ @property
+ def downstream(self) -> List["DAGNode"]:
+ return self._downstream
+
+ def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
+ if not isinstance(nodes, Sequence):
+ nodes = [nodes]
+ if not all(isinstance(node, DAGNode) for node in nodes):
+ raise ValueError(
+ "all nodes to set dependency to current node must be instance of 'DAGNode'"
+ )
+ nodes: Sequence[DAGNode] = nodes
+ dags = set([node.dag for node in nodes if node.dag])
+ if self.dag:
+ dags.add(self.dag)
+ if not dags:
+ raise ValueError("set dependency to current node must in a DAG context")
+ if len(dags) != 1:
+ raise ValueError(
+ "set dependency to current node just support in one DAG context"
+ )
+ dag = dags.pop()
+ self._dag = dag
+
+ dag._append_node(self)
+ for node in nodes:
+ if is_upstream and node not in self.upstream:
+ node._dag = dag
+ dag._append_node(node)
+
+ self._upstream.append(node)
+ node._downstream.append(self)
+ elif node not in self._downstream:
+ node._dag = dag
+ dag._append_node(node)
+
+ self._downstream.append(node)
+ node._upstream.append(self)
+
+
+class DAGContext:
+ def __init__(self) -> None:
+ self._curr_task_ctx = None
+ self._share_data: Dict[str, Any] = {}
+
+ @property
+ def current_task_context(self) -> TaskContext:
+ return self._curr_task_ctx
+
+ def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
+ self._curr_task_ctx = _curr_task_ctx
+
+ async def get_share_data(self, key: str) -> Any:
+ return self._share_data.get(key)
+
+ async def save_to_share_data(self, key: str, data: Any) -> None:
+ self._share_data[key] = data
+
+
+class DAG:
+ def __init__(
+ self, dag_id: str, resource_group: Optional[ResourceGroup] = None
+ ) -> None:
+ self.node_map: Dict[str, DAGNode] = {}
+
+ def _append_node(self, node: DAGNode) -> None:
+ self.node_map[node.node_id] = node
+
+ def _new_node_id(self) -> str:
+ return str(uuid.uuid4())
+
+ def __enter__(self):
+ DAGVar.enter_dag(self)
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ DAGVar.exit_dag()
diff --git a/pilot/awel/dag/tests/__init__.py b/pilot/awel/dag/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/awel/dag/tests/test_dag.py b/pilot/awel/dag/tests/test_dag.py
new file mode 100644
index 000000000..c30530dc8
--- /dev/null
+++ b/pilot/awel/dag/tests/test_dag.py
@@ -0,0 +1,51 @@
+import pytest
+import threading
+import asyncio
+from ..dag import DAG, DAGContext
+
+
+def test_dag_context_sync():
+ dag1 = DAG("dag1")
+ dag2 = DAG("dag2")
+
+ with dag1:
+ assert DAGContext.get_current_dag() == dag1
+ with dag2:
+ assert DAGContext.get_current_dag() == dag2
+ assert DAGContext.get_current_dag() == dag1
+ assert DAGContext.get_current_dag() is None
+
+
+def test_dag_context_threading():
+ def thread_function(dag):
+ DAGContext.enter_dag(dag)
+ assert DAGContext.get_current_dag() == dag
+ DAGContext.exit_dag()
+
+ dag1 = DAG("dag1")
+ dag2 = DAG("dag2")
+
+ thread1 = threading.Thread(target=thread_function, args=(dag1,))
+ thread2 = threading.Thread(target=thread_function, args=(dag2,))
+
+ thread1.start()
+ thread2.start()
+ thread1.join()
+ thread2.join()
+
+ assert DAGContext.get_current_dag() is None
+
+
+@pytest.mark.asyncio
+async def test_dag_context_async():
+ async def async_function(dag):
+ DAGContext.enter_dag(dag)
+ assert DAGContext.get_current_dag() == dag
+ DAGContext.exit_dag()
+
+ dag1 = DAG("dag1")
+ dag2 = DAG("dag2")
+
+ await asyncio.gather(async_function(dag1), async_function(dag2))
+
+ assert DAGContext.get_current_dag() is None
diff --git a/pilot/awel/operator/__init__.py b/pilot/awel/operator/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/awel/operator/base.py b/pilot/awel/operator/base.py
new file mode 100644
index 000000000..b6d1a4e14
--- /dev/null
+++ b/pilot/awel/operator/base.py
@@ -0,0 +1,177 @@
+from abc import ABC, abstractmethod, ABCMeta
+
+from types import FunctionType
+from typing import (
+ List,
+ Generic,
+ TypeVar,
+ AsyncIterator,
+ Union,
+ Any,
+ Dict,
+ Optional,
+ cast,
+)
+import functools
+from inspect import signature
+
+from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
+from ..task.base import (
+ TaskContext,
+ TaskOutput,
+ TaskState,
+ OUT,
+ T,
+ InputContext,
+ InputSource,
+)
+
+F = TypeVar("F", bound=FunctionType)
+
+CALL_DATA = Union[Dict, Dict[str, Dict]]
+
+
+class WorkflowRunner(ABC, Generic[T]):
+ """Abstract base class representing a runner for executing workflows in a DAG.
+
+ This class defines the interface for executing workflows within the DAG,
+ handling the flow from one DAG node to another.
+ """
+
+ @abstractmethod
+ async def execute_workflow(
+ self, node: "BaseOperator", call_data: Optional[CALL_DATA] = None
+ ) -> DAGContext:
+ """Execute the workflow starting from a given operator.
+
+ Args:
+ node (RunnableDAGNode): The starting node of the workflow to be executed.
+ call_data (CALL_DATA): The data pass to root operator node.
+
+ Returns:
+ DAGContext: The context after executing the workflow, containing the final state and data.
+ """
+
+
+default_runner: WorkflowRunner = None
+
+
+class BaseOperatorMeta(ABCMeta):
+ """Metaclass of BaseOperator."""
+
+ @classmethod
+ def _apply_defaults(cls, func: F) -> F:
+ sig_cache = signature(func)
+
+ @functools.wraps(func)
+ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
+ dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag()
+ task_id: Optional[str] = kwargs.get("task_id")
+ if not task_id and dag:
+ task_id = dag._new_node_id()
+ runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner
+ # print(f"self: {self}, kwargs dag: {kwargs.get('dag')}, kwargs: {kwargs}")
+ # for arg in sig_cache.parameters:
+ # if arg not in kwargs:
+ # kwargs[arg] = default_args[arg]
+ if not kwargs.get("dag"):
+ kwargs["dag"] = dag
+ if not kwargs.get("task_id"):
+ kwargs["task_id"] = task_id
+ if not kwargs.get("runner"):
+ kwargs["runner"] = runner
+ real_obj = func(self, *args, **kwargs)
+ return real_obj
+
+ return cast(T, apply_defaults)
+
+ def __new__(cls, name, bases, namespace, **kwargs):
+ new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
+ new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
+ return new_cls
+
+
+class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
+ """Abstract base class for operator nodes that can be executed within a workflow.
+
+ This class extends DAGNode by adding execution capabilities.
+ """
+
+ def __init__(
+ self,
+ task_id: Optional[str] = None,
+ task_name: Optional[str] = None,
+ dag: Optional[DAG] = None,
+ runner: WorkflowRunner = None,
+ **kwargs,
+ ) -> None:
+ """Initializes a BaseOperator with an optional workflow runner.
+
+ Args:
+ runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None.
+ """
+ super().__init__(node_id=task_id, node_name=task_name, dag=dag, **kwargs)
+ if not runner:
+ from pilot.awel import DefaultWorkflowRunner
+
+ runner = DefaultWorkflowRunner()
+
+ self._runner: WorkflowRunner = runner
+ self._dag_ctx: DAGContext = None
+
+ @property
+ def current_dag_context(self) -> DAGContext:
+ return self._dag_ctx
+
+ async def _run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
+ if not self.node_id:
+ raise ValueError(f"The DAG Node ID can't be empty, current node {self}")
+ self._dag_ctx = dag_ctx
+ return await self._do_run(dag_ctx)
+
+ @abstractmethod
+ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
+ """
+ Abstract method to run the task within the DAG node.
+
+ Args:
+ dag_ctx (DAGContext): The context of the DAG when this node is run.
+
+ Returns:
+ TaskOutput[OUT]: The task output after this node has been run.
+ """
+
+ async def call(self, call_data: Optional[CALL_DATA] = None) -> OUT:
+ """Execute the node and return the output.
+
+ This method is a high-level wrapper for executing the node.
+
+ Args:
+ call_data (CALL_DATA): The data pass to root operator node.
+
+ Returns:
+ OUT: The output of the node after execution.
+ """
+ out_ctx = await self._runner.execute_workflow(self, call_data)
+ return out_ctx.current_task_context.task_output.output
+
+ async def call_stream(
+ self, call_data: Optional[CALL_DATA] = None
+ ) -> AsyncIterator[OUT]:
+ """Execute the node and return the output as a stream.
+
+ This method is used for nodes where the output is a stream.
+
+ Args:
+ call_data (CALL_DATA): The data pass to root operator node.
+
+ Returns:
+ AsyncIterator[OUT]: An asynchronous iterator over the output stream.
+ """
+ out_ctx = await self._runner.execute_workflow(self, call_data)
+ return out_ctx.current_task_context.task_output.output_stream
+
+
+def initialize_awel(runner: WorkflowRunner):
+ global default_runner
+ default_runner = runner
diff --git a/pilot/awel/operator/common_operator.py b/pilot/awel/operator/common_operator.py
new file mode 100644
index 000000000..6d12565aa
--- /dev/null
+++ b/pilot/awel/operator/common_operator.py
@@ -0,0 +1,239 @@
+from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable
+import asyncio
+import logging
+
+from ..dag.base import DAGContext
+from ..task.base import (
+ TaskContext,
+ TaskOutput,
+ IN,
+ OUT,
+ InputContext,
+ InputSource,
+)
+
+from .base import BaseOperator
+
+
+logger = logging.getLogger(__name__)
+
+
+class JoinOperator(BaseOperator, Generic[OUT]):
+ """Operator that joins inputs using a custom combine function.
+
+ This node type is useful for combining the outputs of upstream nodes.
+ """
+
+ def __init__(self, combine_function, **kwargs):
+ super().__init__(**kwargs)
+ if not callable(combine_function):
+ raise ValueError("combine_function must be callable")
+ self.combine_function = combine_function
+
+ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
+ """Run the join operation on the DAG context's inputs.
+ Args:
+ dag_ctx (DAGContext): The current context of the DAG.
+
+ Returns:
+ TaskOutput[OUT]: The task output after this node has been run.
+ """
+ curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
+ input_ctx: InputContext = await curr_task_ctx.task_input.map_all(
+ self.combine_function
+ )
+ # All join result store in the first parent output
+ join_output = input_ctx.parent_outputs[0].task_output
+ curr_task_ctx.set_task_output(join_output)
+ return join_output
+
+
+class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
+ def __init__(self, reduce_function=None, **kwargs):
+ """Initializes a ReduceStreamOperator with a combine function.
+
+ Args:
+ combine_function: A function that defines how to combine inputs.
+
+ Raises:
+ ValueError: If the combine_function is not callable.
+ """
+ super().__init__(**kwargs)
+ if reduce_function and not callable(reduce_function):
+ raise ValueError("reduce_function must be callable")
+ self.reduce_function = reduce_function
+
+ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
+ """Run the join operation on the DAG context's inputs.
+
+ Args:
+ dag_ctx (DAGContext): The current context of the DAG.
+
+ Returns:
+ TaskOutput[OUT]: The task output after this node has been run.
+ """
+ curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
+ task_input = curr_task_ctx.task_input
+ if not task_input.check_stream():
+ raise ValueError("ReduceStreamOperator expects stream data")
+ if not task_input.check_single_parent():
+ raise ValueError("ReduceStreamOperator expects single parent")
+
+ reduce_function = self.reduce_function or self.reduce
+
+ input_ctx: InputContext = await task_input.reduce(reduce_function)
+ # All join result store in the first parent output
+ reduce_output = input_ctx.parent_outputs[0].task_output
+ curr_task_ctx.set_task_output(reduce_output)
+ return reduce_output
+
+ async def reduce(self, input_value: AsyncIterator[IN]) -> OUT:
+ raise NotImplementedError
+
+
+class MapOperator(BaseOperator, Generic[IN, OUT]):
+ """Map operator that applies a mapping function to its inputs.
+
+ This operator transforms its input data using a provided mapping function and
+ passes the transformed data downstream.
+ """
+
+ def __init__(self, map_function=None, **kwargs):
+ """Initializes a MapDAGNode with a mapping function.
+
+ Args:
+ map_function: A function that defines how to map the input data.
+
+ Raises:
+ ValueError: If the map_function is not callable.
+ """
+ super().__init__(**kwargs)
+ if map_function and not callable(map_function):
+ raise ValueError("map_function must be callable")
+ self.map_function = map_function
+
+ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
+ """Run the mapping operation on the DAG context's inputs.
+
+ This method applies the mapping function to the input context and updates
+ the DAG context with the new data.
+
+ Args:
+ dag_ctx (DAGContext[IN]): The current context of the DAG.
+
+ Returns:
+ TaskOutput[OUT]: The task output after this node has been run.
+
+ Raises:
+ ValueError: If not a single parent or the map_function is not callable
+ """
+ curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
+ if not curr_task_ctx.task_input.check_single_parent():
+ num_parents = len(curr_task_ctx.task_input.parent_outputs)
+ raise ValueError(
+ f"task {curr_task_ctx.task_id} MapDAGNode expects single parent, now number of parents: {num_parents}"
+ )
+ map_function = self.map_function or self.map
+
+ input_ctx: InputContext = await curr_task_ctx.task_input.map(map_function)
+ # All join result store in the first parent output
+ reduce_output = input_ctx.parent_outputs[0].task_output
+ curr_task_ctx.set_task_output(reduce_output)
+ return reduce_output
+
+ async def map(self, input_value: IN) -> OUT:
+ raise NotImplementedError
+
+
+BranchFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]
+
+
+class BranchOperator(BaseOperator, Generic[IN, OUT]):
+ """Operator node that branches the workflow based on a provided function.
+
+ This node filters its input data using a branching function and
+ allows for conditional paths in the workflow.
+ """
+
+ def __init__(
+ self, branches: Dict[BranchFunc[IN], Union[BaseOperator, str]], **kwargs
+ ):
+ """
+ Initializes a BranchDAGNode with a branching function.
+
+ Args:
+ branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]): Dict of function that defines the branching condition.
+
+ Raises:
+ ValueError: If the branch_function is not callable.
+ """
+ super().__init__(**kwargs)
+ if branches:
+ for branch_function, value in branches.items():
+ if not callable(branch_function):
+ raise ValueError("branch_function must be callable")
+ if isinstance(value, BaseOperator):
+ branches[branch_function] = value.node_name or value.node_name
+ self._branches = branches
+
+ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
+ """Run the branching operation on the DAG context's inputs.
+
+ This method applies the branching function to the input context to determine
+ the path of execution in the workflow.
+
+ Args:
+ dag_ctx (DAGContext[IN]): The current context of the DAG.
+
+ Returns:
+ TaskOutput[OUT]: The task output after this node has been run.
+ """
+ curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
+ task_input = curr_task_ctx.task_input
+ if task_input.check_stream():
+ raise ValueError("BranchDAGNode expects no stream data")
+ if not task_input.check_single_parent():
+ raise ValueError("BranchDAGNode expects single parent")
+
+ branches = self._branches
+ if not branches:
+ branches = await self.branchs()
+
+ branch_func_tasks = []
+ branch_nodes: List[str] = []
+ for func, node_name in branches.items():
+ branch_nodes.append(node_name)
+ branch_func_tasks.append(
+ curr_task_ctx.task_input.predicate_map(func, failed_value=None)
+ )
+
+ branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks)
+ parent_output = task_input.parent_outputs[0].task_output
+ curr_task_ctx.set_task_output(parent_output)
+ skip_node_names = []
+ for i, ctx in enumerate(branch_input_ctxs):
+ node_name = branch_nodes[i]
+ branch_out = ctx.parent_outputs[0].task_output
+ logger.info(
+ f"branch_input_ctxs {i} result {branch_out.output}, is_empty: {branch_out.is_empty}"
+ )
+ if ctx.parent_outputs[0].task_output.is_empty:
+ logger.info(f"Skip node name {node_name}")
+ skip_node_names.append(node_name)
+ curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
+ return parent_output
+
+ async def branchs(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
+ raise NotImplementedError
+
+
+class InputOperator(BaseOperator, Generic[OUT]):
+ def __init__(self, input_source: InputSource[OUT], **kwargs) -> None:
+ super().__init__(**kwargs)
+ self._input_source = input_source
+
+ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
+ curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
+ task_output = await self._input_source.read(curr_task_ctx)
+ curr_task_ctx.set_task_output(task_output)
+ return task_output
diff --git a/pilot/awel/operator/stream_operator.py b/pilot/awel/operator/stream_operator.py
new file mode 100644
index 000000000..7de916a83
--- /dev/null
+++ b/pilot/awel/operator/stream_operator.py
@@ -0,0 +1,90 @@
+from abc import ABC, abstractmethod
+from typing import Generic, AsyncIterator
+from ..task.base import OUT, IN, TaskOutput, TaskContext
+from ..dag.base import DAGContext
+from .base import BaseOperator
+
+
+class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
+ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
+ curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
+ output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify(
+ self.streamify
+ )
+ curr_task_ctx.set_task_output(output)
+ return output
+
+ @abstractmethod
+ async def streamify(self, input_value: IN) -> AsyncIterator[OUT]:
+ """Convert a value of IN to an AsyncIterator[OUT]
+
+ Args:
+ input_value (IN): The data of parent operator's output
+
+ Example:
+
+ .. code-block:: python
+
+ class MyStreamOperator(StreamifyAbsOperator[int, int]):
+ async def streamify(self, input_value: int) -> AsyncIterator[int]
+ for i in range(input_value):
+ yield i
+ """
+
+
+class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
+ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
+ curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
+ output = await curr_task_ctx.task_input.parent_outputs[
+ 0
+ ].task_output.unstreamify(self.unstreamify)
+ curr_task_ctx.set_task_output(output)
+ return output
+
+ @abstractmethod
+ async def unstreamify(self, input_value: AsyncIterator[IN]) -> OUT:
+ """Convert a value of AsyncIterator[IN] to an OUT.
+
+ Args:
+ input_value (AsyncIterator[IN])): The data of parent operator's output
+
+ Example:
+
+ .. code-block:: python
+
+ class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
+ async def unstreamify(self, input_value: AsyncIterator[int]) -> int
+ value_cnt = 0
+ async for v in input_value:
+ value_cnt += 1
+ return value_cnt
+ """
+
+
+class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
+ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
+ curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
+ output = await curr_task_ctx.task_input.parent_outputs[
+ 0
+ ].task_output.transform_stream(self.transform_stream)
+ curr_task_ctx.set_task_output(output)
+ return output
+
+ @abstractmethod
+ async def transform_stream(
+ self, input_value: AsyncIterator[IN]
+ ) -> AsyncIterator[OUT]:
+ """Transform an AsyncIterator[IN] to another AsyncIterator[OUT] using a given function.
+
+ Args:
+ input_value (AsyncIterator[IN])): The data of parent operator's output
+
+ Example:
+
+ .. code-block:: python
+
+ class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
+ async def unstreamify(self, input_value: AsyncIterator[int]) -> AsyncIterator[int]
+ async for v in input_value:
+ yield v + 1
+ """
diff --git a/pilot/awel/resource/__init__.py b/pilot/awel/resource/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/awel/resource/base.py b/pilot/awel/resource/base.py
new file mode 100644
index 000000000..97fefbbc3
--- /dev/null
+++ b/pilot/awel/resource/base.py
@@ -0,0 +1,8 @@
+from abc import ABC, abstractmethod
+
+
+class ResourceGroup(ABC):
+ @property
+ @abstractmethod
+ def name(self) -> str:
+ """The name of current resource group"""
diff --git a/pilot/awel/runner/__init__.py b/pilot/awel/runner/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/awel/runner/job_manager.py b/pilot/awel/runner/job_manager.py
new file mode 100644
index 000000000..7a1d12ead
--- /dev/null
+++ b/pilot/awel/runner/job_manager.py
@@ -0,0 +1,82 @@
+from typing import List, Set, Optional, Dict
+import uuid
+import logging
+from ..dag.base import DAG
+
+from ..operator.base import BaseOperator, CALL_DATA
+
+logger = logging.getLogger(__name__)
+
+
+class DAGNodeInstance:
+ def __init__(self, node_instance: DAG) -> None:
+ pass
+
+
+class DAGInstance:
+ def __init__(self, dag: DAG) -> None:
+ self._dag = dag
+
+
+class JobManager:
+ def __init__(
+ self,
+ root_nodes: List[BaseOperator],
+ all_nodes: List[BaseOperator],
+ end_node: BaseOperator,
+ id2call_data: Dict[str, Dict],
+ ) -> None:
+ self._root_nodes = root_nodes
+ self._all_nodes = all_nodes
+ self._end_node = end_node
+ self._id2node_data = id2call_data
+
+ @staticmethod
+ def build_from_end_node(
+ end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
+ ) -> "JobManager":
+ nodes = _build_from_end_node(end_node)
+ root_nodes = _get_root_nodes(nodes)
+ id2call_data = _save_call_data(root_nodes, call_data)
+ return JobManager(root_nodes, nodes, end_node, id2call_data)
+
+ def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
+ return self._id2node_data.get(node_id)
+
+
+def _save_call_data(
+ root_nodes: List[BaseOperator], call_data: CALL_DATA
+) -> Dict[str, Dict]:
+ id2call_data = {}
+ logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
+ if not call_data:
+ return id2call_data
+ if len(root_nodes) == 1:
+ node = root_nodes[0]
+ logger.info(f"Save call data to node {node.node_id}, call_data: {call_data}")
+ id2call_data[node.node_id] = call_data
+ else:
+ for node in root_nodes:
+ node_id = node.node_id
+ logger.info(
+ f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
+ )
+ id2call_data[node_id] = call_data.get(node_id)
+ return id2call_data
+
+
+def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
+ nodes = []
+ if isinstance(end_node, BaseOperator):
+ task_id = end_node.node_id
+ if not task_id:
+ task_id = str(uuid.uuid4())
+ end_node.set_node_id(task_id)
+ nodes.append(end_node)
+ for node in end_node.upstream:
+ nodes += _build_from_end_node(node)
+ return nodes
+
+
+def _get_root_nodes(nodes: List[BaseOperator]) -> List[BaseOperator]:
+ return list(set(filter(lambda x: not x.upstream, nodes)))
diff --git a/pilot/awel/runner/local_runner.py b/pilot/awel/runner/local_runner.py
new file mode 100644
index 000000000..769223212
--- /dev/null
+++ b/pilot/awel/runner/local_runner.py
@@ -0,0 +1,106 @@
+from typing import Dict, Optional, Set, List
+import logging
+
+from ..dag.base import DAGContext
+from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
+from ..operator.common_operator import BranchOperator, JoinOperator
+from ..task.base import TaskContext, TaskState
+from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
+from .job_manager import JobManager
+
+logger = logging.getLogger(__name__)
+
+
+class DefaultWorkflowRunner(WorkflowRunner):
+ async def execute_workflow(
+ self, node: BaseOperator, call_data: Optional[CALL_DATA] = None
+ ) -> DAGContext:
+ # Create DAG context
+ dag_ctx = DAGContext()
+ job_manager = JobManager.build_from_end_node(node, call_data)
+ logger.info(
+ f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
+ )
+ dag = node.dag
+ # Save node output
+ node_outputs: Dict[str, TaskContext] = {}
+ skip_node_ids = set()
+ await self._execute_node(
+ job_manager, node, dag_ctx, node_outputs, skip_node_ids
+ )
+
+ return dag_ctx
+
+ async def _execute_node(
+ self,
+ job_manager: JobManager,
+ node: BaseOperator,
+ dag_ctx: DAGContext,
+ node_outputs: Dict[str, TaskContext],
+ skip_node_ids: Set[str],
+ ):
+ # Skip run node
+ if node.node_id in node_outputs:
+ return
+
+ # Run all upstream node
+ for upstream_node in node.upstream:
+ if isinstance(upstream_node, BaseOperator):
+ await self._execute_node(
+ job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
+ )
+
+ inputs = [
+ node_outputs[upstream_node.node_id] for upstream_node in node.upstream
+ ]
+ input_ctx = DefaultInputContext(inputs)
+ task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None)
+ task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id))
+
+ task_ctx.set_task_input(input_ctx)
+ dag_ctx.set_current_task_context(task_ctx)
+ task_ctx.set_current_state(TaskState.RUNNING)
+
+ if node.node_id in skip_node_ids:
+ task_ctx.set_current_state(TaskState.SKIP)
+ task_ctx.set_task_output(SimpleTaskOutput(None))
+ node_outputs[node.node_id] = task_ctx
+ return
+ try:
+ logger.info(
+ f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
+ )
+ await node._run(dag_ctx)
+ node_outputs[node.node_id] = dag_ctx.current_task_context
+ task_ctx.set_current_state(TaskState.SUCCESS)
+
+ if isinstance(node, BranchOperator):
+ skip_nodes = task_ctx.metadata.get("skip_node_names", [])
+ logger.info(
+ f"Current is branch operator, skip node names: {skip_nodes}"
+ )
+ _skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
+ except Exception as e:
+ logger.info(f"Run operator {node.node_id} error, error message: {str(e)}")
+ task_ctx.set_current_state(TaskState.FAILED)
+ raise e
+
+
+def _skip_current_downstream_by_node_name(
+ branch_node: BranchOperator, skip_nodes: List[str], skip_node_ids: Set[str]
+):
+ if not skip_nodes:
+ return
+ for child in branch_node.downstream:
+ if child.node_name in skip_nodes:
+ logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
+ _skip_downstream_by_id(child, skip_node_ids)
+
+
+def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
+ if isinstance(node, JoinOperator):
+ # Not skip join node
+ return
+ skip_node_ids.add(node.node_id)
+ for child in node.downstream:
+ _skip_downstream_by_id(child, skip_node_ids)
diff --git a/pilot/awel/task/__init__.py b/pilot/awel/task/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/awel/task/base.py b/pilot/awel/task/base.py
new file mode 100644
index 000000000..88b0df343
--- /dev/null
+++ b/pilot/awel/task/base.py
@@ -0,0 +1,367 @@
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import (
+ TypeVar,
+ Generic,
+ Optional,
+ AsyncIterator,
+ Union,
+ Callable,
+ Any,
+ Dict,
+ List,
+)
+
+IN = TypeVar("IN")
+OUT = TypeVar("OUT")
+T = TypeVar("T")
+
+
+class TaskState(str, Enum):
+ """Enumeration representing the state of a task in the workflow.
+
+ This Enum defines various states a task can be in during its lifecycle in the DAG.
+ """
+
+ INIT = "init" # Initial state of the task, not yet started
+ SKIP = "skip" # State indicating the task was skipped
+ RUNNING = "running" # State indicating the task is currently running
+ SUCCESS = "success" # State indicating the task completed successfully
+ FAILED = "failed" # State indicating the task failed during execution
+
+
+class TaskOutput(ABC, Generic[T]):
+ """Abstract base class representing the output of a task.
+
+ This class encapsulates the output of a task and provides methods to access the output data.
+ It can be subclassed to implement specific output behaviors.
+ """
+
+ @property
+ def is_stream(self) -> bool:
+ """Check if the output is a stream.
+
+ Returns:
+ bool: True if the output is a stream, False otherwise.
+ """
+ return False
+
+ @property
+ def is_empty(self) -> bool:
+ """Check if the output is empty.
+
+ Returns:
+ bool: True if the output is empty, False otherwise.
+ """
+ return False
+
+ @property
+ def output(self) -> Optional[T]:
+ """Return the output of the task.
+
+ Returns:
+ T: The output of the task. None if the output is empty.
+ """
+ raise NotImplementedError
+
+ @property
+ def output_stream(self) -> Optional[AsyncIterator[T]]:
+ """Return the output of the task as an asynchronous stream.
+
+ Returns:
+ AsyncIterator[T]: An asynchronous iterator over the output. None if the output is empty.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
+ """Set the output data to current object.
+
+ Args:
+ output_data (Union[T, AsyncIterator[T]]): Output data.
+ """
+
+ @abstractmethod
+ def new_output(self) -> "TaskOutput[T]":
+ """Create new output object"""
+
+ async def map(self, map_func) -> "TaskOutput[T]":
+ """Apply a mapping function to the task's output.
+
+ Args:
+ map_func: A function to apply to the task's output.
+
+ Returns:
+ TaskOutput[T]: The result of applying the mapping function.
+ """
+ raise NotImplementedError
+
+ async def reduce(self, reduce_func) -> "TaskOutput[T]":
+ """Apply a reducing function to the task's output.
+
+ Stream TaskOutput to Nonstream TaskOutput.
+
+ Args:
+ reduce_func: A reducing function to apply to the task's output.
+
+ Returns:
+ TaskOutput[T]: The result of applying the reducing function.
+ """
+ raise NotImplementedError
+
+ async def streamify(
+ self, transform_func: Callable[[T], AsyncIterator[T]]
+ ) -> "TaskOutput[T]":
+ """Convert a value of type T to an AsyncIterator[T] using a transform function.
+
+ Args:
+ transform_func (Callable[[T], AsyncIterator[T]]): Function to transform a T value into an AsyncIterator[T].
+
+ Returns:
+ TaskOutput[T]: The result of applying the reducing function.
+ """
+ raise NotImplementedError
+
+ async def transform_stream(
+ self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
+ ) -> "TaskOutput[T]":
+ """Transform an AsyncIterator[T] to another AsyncIterator[T] using a given function.
+
+ Args:
+ transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to apply to the AsyncIterator[T].
+
+ Returns:
+ TaskOutput[T]: The result of applying the reducing function.
+ """
+ raise NotImplementedError
+
+ async def unstreamify(
+ self, transform_func: Callable[[AsyncIterator[T]], T]
+ ) -> "TaskOutput[T]":
+ """Convert an AsyncIterator[T] to a value of type T using a transform function.
+
+ Args:
+ transform_func (Callable[[AsyncIterator[T]], T]): Function to transform an AsyncIterator[T] into a T value.
+
+ Returns:
+ TaskOutput[T]: The result of applying the reducing function.
+ """
+ raise NotImplementedError
+
+ async def check_condition(self, condition_func) -> bool:
+ """Check if current output meets a given condition.
+
+ Args:
+ condition_func: A function to determine if the condition is met.
+ Returns:
+ bool: True if current output meet the condition, False otherwise.
+ """
+ raise NotImplementedError
+
+
+class TaskContext(ABC, Generic[T]):
+ """Abstract base class representing the context of a task within a DAG.
+
+ This class provides the interface for accessing task-related information
+ and manipulating task output.
+ """
+
+ @property
+ @abstractmethod
+ def task_id(self) -> str:
+ """Return the unique identifier of the task.
+
+ Returns:
+ str: The unique identifier of the task.
+ """
+
+ @property
+ @abstractmethod
+ def task_input(self) -> "InputContext":
+ """Return the InputContext of current task.
+
+ Returns:
+ InputContext: The InputContext of current task.
+ """
+
+ @abstractmethod
+ def set_task_input(self, input_ctx: "InputContext") -> None:
+ """Set the InputContext object to current task.
+
+ Args:
+ input_ctx (InputContext): The InputContext of current task
+ """
+
+ @property
+ @abstractmethod
+ def task_output(self) -> TaskOutput[T]:
+ """Return the output object of the task.
+
+ Returns:
+ TaskOutput[T]: The output object of the task.
+ """
+
+ @abstractmethod
+ def set_task_output(self, task_output: TaskOutput[T]) -> None:
+ """Set the output object to current task."""
+
+ @property
+ @abstractmethod
+ def current_state(self) -> TaskState:
+ """Get the current state of the task.
+
+ Returns:
+ TaskState: The current state of the task.
+ """
+
+ @abstractmethod
+ def set_current_state(self, task_state: TaskState) -> None:
+ """Set current task state
+
+ Args:
+ task_state (TaskState): The task state to be set.
+ """
+
+ @abstractmethod
+ def new_ctx(self) -> "TaskContext":
+ """Create new task context
+
+ Returns:
+ TaskContext: A new instance of a TaskContext.
+ """
+
+ @property
+ @abstractmethod
+ def metadata(self) -> Dict[str, Any]:
+ """Get the metadata of current task
+
+ Returns:
+ Dict[str, Any]: The metadata
+ """
+
+ def update_metadata(self, key: str, value: Any) -> None:
+ """Update metadata with key and value
+
+ Args:
+ key (str): The key of metadata
+ value (str): The value to be add to metadata
+ """
+ self.metadata[key] = value
+
+ @property
+ def call_data(self) -> Optional[Dict]:
+ """Get the call data for current data"""
+ return self.metadata.get("call_data")
+
+ def set_call_data(self, call_data: Dict) -> None:
+ """Set call data for current task"""
+ self.update_metadata("call_data", call_data)
+
+
+class InputContext(ABC):
+ """Abstract base class representing the context of inputs to a operator node.
+
+ This class defines methods to manipulate and access the inputs for a operator node.
+ """
+
+ @property
+ @abstractmethod
+ def parent_outputs(self) -> List[TaskContext]:
+ """Get the outputs from the parent nodes.
+
+ Returns:
+ List[TaskContext]: A list of contexts of the parent nodes' outputs.
+ """
+
+ @abstractmethod
+ async def map(self, map_func: Callable[[Any], Any]) -> "InputContext":
+ """Apply a mapping function to the inputs.
+
+ Args:
+ map_func (Callable[[Any], Any]): A function to be applied to the inputs.
+
+ Returns:
+ InputContext: A new InputContext instance with the mapped inputs.
+ """
+
+ @abstractmethod
+ async def map_all(self, map_func: Callable[..., Any]) -> "InputContext":
+ """Apply a mapping function to all inputs.
+
+ Args:
+ map_func (Callable[..., Any]): A function to be applied to all inputs.
+
+ Returns:
+ InputContext: A new InputContext instance with the mapped inputs.
+ """
+
+ @abstractmethod
+ async def reduce(self, reduce_func: Callable[[Any], Any]) -> "InputContext":
+ """Apply a reducing function to the inputs.
+
+ Args:
+ reduce_func (Callable[[Any], Any]): A function that reduces the inputs.
+
+ Returns:
+ InputContext: A new InputContext instance with the reduced inputs.
+ """
+
+ @abstractmethod
+ async def filter(self, filter_func: Callable[[Any], bool]) -> "InputContext":
+ """Filter the inputs based on a provided function.
+
+ Args:
+ filter_func (Callable[[Any], bool]): A function that returns True for inputs to keep.
+
+ Returns:
+ InputContext: A new InputContext instance with the filtered inputs.
+ """
+
+ @abstractmethod
+ async def predicate_map(
+ self, predicate_func: Callable[[Any], bool], failed_value: Any = None
+ ) -> "InputContext":
+ """Predicate the inputs based on a provided function.
+
+ Args:
+ predicate_func (Callable[[Any], bool]): A function that returns True for inputs is predicate True.
+ failed_value (Any): The value to be set if the return value of predicate function is False
+ Returns:
+ InputContext: A new InputContext instance with the predicate inputs.
+ """
+
+ def check_single_parent(self) -> bool:
+ """Check if there is only a single parent output.
+
+ Returns:
+ bool: True if there is only one parent output, False otherwise.
+ """
+ return len(self.parent_outputs) == 1
+
+ def check_stream(self, skip_empty: bool = False) -> bool:
+ """Check if all parent outputs are streams.
+
+ Args:
+ skip_empty (bool): Skip empty output or not.
+
+ Returns:
+ bool: True if all parent outputs are streams, False otherwise.
+ """
+ for out in self.parent_outputs:
+ if out.task_output.is_empty and skip_empty:
+ continue
+ if not (out.task_output and out.task_output.is_stream):
+ return False
+ return True
+
+
+class InputSource(ABC, Generic[T]):
+ """Abstract base class representing the source of inputs to a DAG node."""
+
+ @abstractmethod
+ async def read(self, task_ctx: TaskContext) -> TaskOutput[T]:
+ """Read the data from current input source.
+
+ Returns:
+ TaskOutput[T]: The output object read from current source
+ """
diff --git a/pilot/awel/task/task_impl.py b/pilot/awel/task/task_impl.py
new file mode 100644
index 000000000..f969c135c
--- /dev/null
+++ b/pilot/awel/task/task_impl.py
@@ -0,0 +1,339 @@
+from abc import ABC, abstractmethod
+from typing import (
+ Callable,
+ Coroutine,
+ Iterator,
+ AsyncIterator,
+ List,
+ Generic,
+ TypeVar,
+ Any,
+ Tuple,
+ Dict,
+ Union,
+)
+import asyncio
+import logging
+from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T
+
+
+logger = logging.getLogger(__name__)
+
+
+async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
+ # Init accumulator
+ try:
+ accumulator = await stream.__anext__()
+ except StopAsyncIteration:
+ raise ValueError("Stream is empty")
+ is_async = asyncio.iscoroutinefunction(reduce_function)
+ async for element in stream:
+ if is_async:
+ accumulator = await reduce_function(accumulator, element)
+ else:
+ accumulator = reduce_function(accumulator, element)
+ return accumulator
+
+
+class SimpleTaskOutput(TaskOutput[T], Generic[T]):
+ def __init__(self, data: T) -> None:
+ super().__init__()
+ self._data = data
+
+ @property
+ def output(self) -> T:
+ return self._data
+
+ def set_output(self, output_data: T | AsyncIterator[T]) -> None:
+ self._data = output_data
+
+ def new_output(self) -> TaskOutput[T]:
+ return SimpleTaskOutput(None)
+
+ @property
+ def is_empty(self) -> bool:
+ return not self._data
+
+ async def _apply_func(self, func) -> Any:
+ if asyncio.iscoroutinefunction(func):
+ out = await func(self._data)
+ else:
+ out = func(self._data)
+ return out
+
+ async def map(self, map_func) -> TaskOutput[T]:
+ out = await self._apply_func(map_func)
+ return SimpleTaskOutput(out)
+
+ async def check_condition(self, condition_func) -> bool:
+ return await self._apply_func(condition_func)
+
+ async def streamify(
+ self, transform_func: Callable[[T], AsyncIterator[T]]
+ ) -> TaskOutput[T]:
+ out = await self._apply_func(transform_func)
+ return SimpleStreamTaskOutput(out)
+
+
+class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
+ def __init__(self, data: AsyncIterator[T]) -> None:
+ super().__init__()
+ self._data = data
+
+ @property
+ def is_stream(self) -> bool:
+ return True
+
+ @property
+ def is_empty(self) -> bool:
+ return not self._data
+
+ @property
+ def output_stream(self) -> AsyncIterator[T]:
+ return self._data
+
+ def set_output(self, output_data: T | AsyncIterator[T]) -> None:
+ self._data = output_data
+
+ def new_output(self) -> TaskOutput[T]:
+ return SimpleStreamTaskOutput(None)
+
+ async def map(self, map_func) -> TaskOutput[T]:
+ is_async = asyncio.iscoroutinefunction(map_func)
+
+ async def new_iter() -> AsyncIterator[T]:
+ async for out in self._data:
+ if is_async:
+ out = await map_func(out)
+ else:
+ out = map_func(out)
+ yield out
+
+ return SimpleStreamTaskOutput(new_iter())
+
+ async def reduce(self, reduce_func) -> TaskOutput[T]:
+ out = await _reduce_stream(self._data, reduce_func)
+ return SimpleTaskOutput(out)
+
+ async def unstreamify(
+ self, transform_func: Callable[[AsyncIterator[T]], T]
+ ) -> TaskOutput[T]:
+ if asyncio.iscoroutinefunction(transform_func):
+ out = await transform_func(self._data)
+ else:
+ out = transform_func(self._data)
+ return SimpleTaskOutput(out)
+
+ async def transform_stream(
+ self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
+ ) -> TaskOutput[T]:
+ if asyncio.iscoroutinefunction(transform_func):
+ out = await transform_func(self._data)
+ else:
+ out = transform_func(self._data)
+ return SimpleStreamTaskOutput(out)
+
+
+def _is_async_iterator(obj):
+ return (
+ hasattr(obj, "__anext__")
+ and callable(getattr(obj, "__anext__", None))
+ and hasattr(obj, "__aiter__")
+ and callable(getattr(obj, "__aiter__", None))
+ )
+
+
+class BaseInputSource(InputSource, ABC):
+ def __init__(self) -> None:
+ super().__init__()
+ self._is_read = False
+
+ @abstractmethod
+ def _read_data(self, task_ctx: TaskContext) -> Any:
+ """Read data with task context"""
+
+ async def read(self, task_ctx: TaskContext) -> Coroutine[Any, Any, TaskOutput]:
+ data = self._read_data(task_ctx)
+ if _is_async_iterator(data):
+ if self._is_read:
+ raise ValueError(f"Input iterator {data} has been read!")
+ output = SimpleStreamTaskOutput(data)
+ else:
+ output = SimpleTaskOutput(data)
+ self._is_read = True
+ return output
+
+
+class SimpleInputSource(BaseInputSource):
+ def __init__(self, data: Any) -> None:
+ super().__init__()
+ self._data = data
+
+ def _read_data(self, task_ctx: TaskContext) -> Any:
+ return self._data
+
+
+class SimpleCallDataInputSource(BaseInputSource):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def _read_data(self, task_ctx: TaskContext) -> Any:
+ call_data = task_ctx.call_data
+ data = call_data.get("data") if call_data else None
+ if not (call_data and data):
+ raise ValueError("No call data for current SimpleCallDataInputSource")
+ return data
+
+
+class DefaultTaskContext(TaskContext, Generic[T]):
+ def __init__(
+ self, task_id: str, task_state: TaskState, task_output: TaskOutput[T]
+ ) -> None:
+ super().__init__()
+ self._task_id = task_id
+ self._task_state = task_state
+ self._output = task_output
+ self._task_input = None
+ self._metadata = {}
+
+ @property
+ def task_id(self) -> str:
+ return self._task_id
+
+ @property
+ def task_input(self) -> InputContext:
+ return self._task_input
+
+ def set_task_input(self, input_ctx: "InputContext") -> None:
+ self._task_input = input_ctx
+
+ @property
+ def task_output(self) -> TaskOutput:
+ return self._output
+
+ def set_task_output(self, task_output: TaskOutput) -> None:
+ self._output = task_output
+
+ @property
+ def current_state(self) -> TaskState:
+ return self._task_state
+
+ def set_current_state(self, task_state: TaskState) -> None:
+ self._task_state = task_state
+
+ def new_ctx(self) -> TaskContext:
+ new_output = self._output.new_output()
+ return DefaultTaskContext(self._task_id, self._task_state, new_output)
+
+ @property
+ def metadata(self) -> Dict[str, Any]:
+ return self._metadata
+
+
+class DefaultInputContext(InputContext):
+ def __init__(self, outputs: List[TaskContext]) -> None:
+ super().__init__()
+ self._outputs = outputs
+
+ @property
+ def parent_outputs(self) -> List[TaskContext]:
+ return self._outputs
+
+ async def _apply_func(
+ self, func: Callable[[Any], Any], apply_type: str = "map"
+ ) -> Tuple[List[TaskContext], List[TaskOutput]]:
+ new_outputs: List[TaskContext] = []
+ map_tasks = []
+ for out in self._outputs:
+ new_outputs.append(out.new_ctx())
+ result = None
+ if apply_type == "map":
+ result = out.task_output.map(func)
+ elif apply_type == "reduce":
+ result = out.task_output.reduce(func)
+ elif apply_type == "check_condition":
+ result = out.task_output.check_condition(func)
+ else:
+ raise ValueError(f"Unsupport apply type {apply_type}")
+ map_tasks.append(result)
+ results = await asyncio.gather(*map_tasks)
+ return new_outputs, results
+
+ async def map(self, map_func: Callable[[Any], Any]) -> InputContext:
+ new_outputs, results = await self._apply_func(map_func)
+ for i, task_ctx in enumerate(new_outputs):
+ task_ctx: TaskContext = task_ctx
+ task_ctx.set_task_output(results[i])
+ return DefaultInputContext(new_outputs)
+
+ async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
+ if not self._outputs:
+ return DefaultInputContext([])
+ # Some parent may be empty
+ not_empty_idx = 0
+ for i, p in enumerate(self._outputs):
+ if p.task_output.is_empty:
+ continue
+ not_empty_idx = i
+ break
+ # All output is empty?
+ is_steam = self._outputs[not_empty_idx].task_output.is_stream
+ if is_steam:
+ if not self.check_stream(skip_empty=True):
+ raise ValueError(
+ "The output in all tasks must has same output format to map_all"
+ )
+ outputs = []
+ for out in self._outputs:
+ if out.task_output.is_stream:
+ outputs.append(out.task_output.output_stream)
+ else:
+ outputs.append(out.task_output.output)
+ if asyncio.iscoroutinefunction(map_func):
+ map_res = await map_func(*outputs)
+ else:
+ map_res = map_func(*outputs)
+ single_output: TaskContext = self._outputs[not_empty_idx].new_ctx()
+ single_output.task_output.set_output(map_res)
+ logger.debug(
+ f"Current map_all map_res: {map_res}, is steam: {single_output.task_output.is_stream}"
+ )
+ return DefaultInputContext([single_output])
+
+ async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
+ if not self.check_stream():
+ raise ValueError(
+ "The output in all tasks must has same output format of stream to apply reduce function"
+ )
+ new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
+ for i, task_ctx in enumerate(new_outputs):
+ task_ctx: TaskContext = task_ctx
+ task_ctx.set_task_output(results[i])
+ return DefaultInputContext(new_outputs)
+
+ async def filter(self, filter_func: Callable[[Any], bool]) -> InputContext:
+ new_outputs, results = await self._apply_func(
+ filter_func, apply_type="check_condition"
+ )
+ result_outputs = []
+ for i, task_ctx in enumerate(new_outputs):
+ if results[i]:
+ result_outputs.append(task_ctx)
+ return DefaultInputContext(result_outputs)
+
+ async def predicate_map(
+ self, predicate_func: Callable[[Any], bool], failed_value: Any = None
+ ) -> "InputContext":
+ new_outputs, results = await self._apply_func(
+ predicate_func, apply_type="check_condition"
+ )
+ result_outputs = []
+ for i, task_ctx in enumerate(new_outputs):
+ task_ctx: TaskContext = task_ctx
+ if results[i]:
+ task_ctx.task_output.set_output(True)
+ result_outputs.append(task_ctx)
+ else:
+ task_ctx.task_output.set_output(failed_value)
+ result_outputs.append(task_ctx)
+ return DefaultInputContext(result_outputs)
diff --git a/pilot/awel/tests/__init__.py b/pilot/awel/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/awel/tests/conftest.py b/pilot/awel/tests/conftest.py
new file mode 100644
index 000000000..2279cceba
--- /dev/null
+++ b/pilot/awel/tests/conftest.py
@@ -0,0 +1,102 @@
+import pytest
+import pytest_asyncio
+from typing import AsyncIterator, List
+from contextlib import contextmanager, asynccontextmanager
+from .. import (
+ WorkflowRunner,
+ InputOperator,
+ DAGContext,
+ TaskState,
+ DefaultWorkflowRunner,
+ SimpleInputSource,
+)
+from ..task.task_impl import _is_async_iterator
+
+
+@pytest.fixture
+def runner():
+ return DefaultWorkflowRunner()
+
+
+def _create_stream(num_nodes) -> List[AsyncIterator[int]]:
+ iters = []
+ for _ in range(num_nodes):
+
+ async def stream_iter():
+ for i in range(10):
+ yield i
+
+ stream_iter = stream_iter()
+ assert _is_async_iterator(stream_iter)
+ iters.append(stream_iter)
+ return iters
+
+
+def _create_stream_from(output_streams: List[List[int]]) -> List[AsyncIterator[int]]:
+ iters = []
+ for single_stream in output_streams:
+
+ async def stream_iter():
+ for i in single_stream:
+ yield i
+
+ stream_iter = stream_iter()
+ assert _is_async_iterator(stream_iter)
+ iters.append(stream_iter)
+ return iters
+
+
+@asynccontextmanager
+async def _create_input_node(**kwargs):
+ num_nodes = kwargs.get("num_nodes")
+ is_stream = kwargs.get("is_stream", False)
+ if is_stream:
+ outputs = kwargs.get("output_streams")
+ if outputs:
+ if num_nodes and num_nodes != len(outputs):
+ raise ValueError(
+ f"num_nodes {num_nodes} != the length of output_streams {len(outputs)}"
+ )
+ outputs = _create_stream_from(outputs)
+ else:
+ num_nodes = num_nodes or 1
+ outputs = _create_stream(num_nodes)
+ else:
+ outputs = kwargs.get("outputs", ["Hello."])
+ nodes = []
+ for output in outputs:
+ print(f"output: {output}")
+ input_source = SimpleInputSource(output)
+ input_node = InputOperator(input_source)
+ nodes.append(input_node)
+ yield nodes
+
+
+@pytest_asyncio.fixture
+async def input_node(request):
+ param = getattr(request, "param", {})
+ async with _create_input_node(**param) as input_nodes:
+ yield input_nodes[0]
+
+
+@pytest_asyncio.fixture
+async def stream_input_node(request):
+ param = getattr(request, "param", {})
+ param["is_stream"] = True
+ async with _create_input_node(**param) as input_nodes:
+ yield input_nodes[0]
+
+
+@pytest_asyncio.fixture
+async def input_nodes(request):
+ param = getattr(request, "param", {})
+ async with _create_input_node(**param) as input_nodes:
+ yield input_nodes
+
+
+@pytest_asyncio.fixture
+async def stream_input_nodes(request):
+ param = getattr(request, "param", {})
+ param["is_stream"] = True
+ async with _create_input_node(**param) as input_nodes:
+ yield input_nodes
diff --git a/pilot/awel/tests/test_http_operator.py b/pilot/awel/tests/test_http_operator.py
new file mode 100644
index 000000000..c57e70fe1
--- /dev/null
+++ b/pilot/awel/tests/test_http_operator.py
@@ -0,0 +1,51 @@
+import pytest
+from typing import List
+from .. import (
+ DAG,
+ WorkflowRunner,
+ DAGContext,
+ TaskState,
+ InputOperator,
+ MapOperator,
+ JoinOperator,
+ BranchOperator,
+ ReduceStreamOperator,
+ SimpleInputSource,
+)
+from .conftest import (
+ runner,
+ input_node,
+ input_nodes,
+ stream_input_node,
+ stream_input_nodes,
+ _is_async_iterator,
+)
+
+
+def _register_dag_to_fastapi_app(dag):
+ # TODO
+ pass
+
+
+@pytest.mark.asyncio
+async def test_http_operator(runner: WorkflowRunner, stream_input_node: InputOperator):
+ with DAG("test_map") as dag:
+ pass
+ # http_req_task = HttpRequestOperator(endpoint="/api/completions")
+ # db_task = DBQueryOperator(table_name="user_info")
+ # prompt_task = PromptTemplateOperator(
+ # system_prompt="You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."
+ # )
+ # llm_task = ChatGPTLLMOperator(model="chagpt-3.5")
+ # output_parser_task = CommonOutputParserOperator()
+ # http_res_task = HttpResponseOperator()
+ # (
+ # http_req_task
+ # >> db_task
+ # >> prompt_task
+ # >> llm_task
+ # >> output_parser_task
+ # >> http_res_task
+ # )
+
+ _register_dag_to_fastapi_app(dag)
diff --git a/pilot/awel/tests/test_run_dag.py b/pilot/awel/tests/test_run_dag.py
new file mode 100644
index 000000000..c0ea8e7ad
--- /dev/null
+++ b/pilot/awel/tests/test_run_dag.py
@@ -0,0 +1,141 @@
+import pytest
+from typing import List
+from .. import (
+ DAG,
+ WorkflowRunner,
+ DAGContext,
+ TaskState,
+ InputOperator,
+ MapOperator,
+ JoinOperator,
+ BranchOperator,
+ ReduceStreamOperator,
+ SimpleInputSource,
+)
+from .conftest import (
+ runner,
+ input_node,
+ input_nodes,
+ stream_input_node,
+ stream_input_nodes,
+ _is_async_iterator,
+)
+
+
+@pytest.mark.asyncio
+async def test_input_node(runner: WorkflowRunner):
+ input_node = InputOperator(SimpleInputSource("hello"))
+ res: DAGContext[str] = await runner.execute_workflow(input_node)
+ assert res.current_task_context.current_state == TaskState.SUCCESS
+ assert res.current_task_context.task_output.output == "hello"
+
+ async def new_steam_iter(n: int):
+ for i in range(n):
+ yield i
+
+ num_iter = 10
+ steam_input_node = InputOperator(SimpleInputSource(new_steam_iter(num_iter)))
+ res: DAGContext[str] = await runner.execute_workflow(steam_input_node)
+ assert res.current_task_context.current_state == TaskState.SUCCESS
+ output_steam = res.current_task_context.task_output.output_stream
+ assert output_steam
+ assert _is_async_iterator(output_steam)
+ i = 0
+ async for x in output_steam:
+ assert x == i
+ i += 1
+
+
+@pytest.mark.asyncio
+async def test_map_node(runner: WorkflowRunner, stream_input_node: InputOperator):
+ with DAG("test_map") as dag:
+ map_node = MapOperator(lambda x: x * 2)
+ stream_input_node >> map_node
+ res: DAGContext[int] = await runner.execute_workflow(map_node)
+ output_steam = res.current_task_context.task_output.output_stream
+ assert output_steam
+ i = 0
+ async for x in output_steam:
+ assert x == i * 2
+ i += 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "stream_input_node, expect_sum",
+ [
+ ({"output_streams": [[0, 1, 2, 3]]}, 6),
+ ({"output_streams": [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]}, 55),
+ ],
+ indirect=["stream_input_node"],
+)
+async def test_reduce_node(
+ runner: WorkflowRunner, stream_input_node: InputOperator, expect_sum: int
+):
+ with DAG("test_reduce_node") as dag:
+ reduce_node = ReduceStreamOperator(lambda x, y: x + y)
+ stream_input_node >> reduce_node
+ res: DAGContext[int] = await runner.execute_workflow(reduce_node)
+ assert res.current_task_context.current_state == TaskState.SUCCESS
+ assert not res.current_task_context.task_output.is_stream
+ assert res.current_task_context.task_output.output == expect_sum
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "input_nodes",
+ [
+ ({"outputs": [0, 1, 2]}),
+ ],
+ indirect=["input_nodes"],
+)
+async def test_join_node(runner: WorkflowRunner, input_nodes: List[InputOperator]):
+ def join_func(p1, p2, p3) -> int:
+ return p1 + p2 + p3
+
+ with DAG("test_join_node") as dag:
+ join_node = JoinOperator(join_func)
+ for input_node in input_nodes:
+ input_node >> join_node
+ res: DAGContext[int] = await runner.execute_workflow(join_node)
+ assert res.current_task_context.current_state == TaskState.SUCCESS
+ assert not res.current_task_context.task_output.is_stream
+ assert res.current_task_context.task_output.output == 3
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "input_node, is_odd",
+ [
+ ({"outputs": [0]}, False),
+ ({"outputs": [1]}, True),
+ ],
+ indirect=["input_node"],
+)
+async def test_branch_node(
+ runner: WorkflowRunner, input_node: InputOperator, is_odd: bool
+):
+ def join_func(o1, o2) -> int:
+ print(f"join func result, o1: {o1}, o2: {o2}")
+ return o1 or o2
+
+ with DAG("test_join_node") as dag:
+ odd_node = MapOperator(
+ lambda x: 999, task_id="odd_node", task_name="odd_node_name"
+ )
+ even_node = MapOperator(
+ lambda x: 888, task_id="even_node", task_name="even_node_name"
+ )
+ join_node = JoinOperator(join_func)
+ branch_node = BranchOperator(
+ {lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
+ )
+ branch_node >> odd_node >> join_node
+ branch_node >> even_node >> join_node
+
+ input_node >> branch_node
+
+ res: DAGContext[int] = await runner.execute_workflow(join_node)
+ assert res.current_task_context.current_state == TaskState.SUCCESS
+ expect_res = 999 if is_odd else 888
+ assert res.current_task_context.task_output.output == expect_res
diff --git a/pilot/cache/__init__.py b/pilot/cache/__init__.py
new file mode 100644
index 000000000..65f768a7e
--- /dev/null
+++ b/pilot/cache/__init__.py
@@ -0,0 +1,10 @@
+from pilot.cache.llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
+from pilot.cache.manager import CacheManager, initialize_cache
+
+__all__ = [
+ "LLMCacheKey",
+ "LLMCacheValue",
+ "LLMCacheClient",
+ "CacheManager",
+ "initialize_cache",
+]
diff --git a/pilot/cache/base.py b/pilot/cache/base.py
new file mode 100644
index 000000000..feb135288
--- /dev/null
+++ b/pilot/cache/base.py
@@ -0,0 +1,161 @@
+from abc import ABC, abstractmethod, abstractclassmethod
+from typing import Any, TypeVar, Generic, Optional, Type, Dict
+from dataclasses import dataclass
+from enum import Enum
+
+T = TypeVar("T", bound="Serializable")
+
+K = TypeVar("K")
+V = TypeVar("V")
+
+
+class Serializable(ABC):
+ @abstractmethod
+ def serialize(self) -> bytes:
+ """Convert the object into bytes for storage or transmission.
+
+ Returns:
+ bytes: The byte array after serialization
+ """
+
+ @abstractmethod
+ def to_dict(self) -> Dict:
+ """Convert the object's state to a dictionary."""
+
+ # @staticmethod
+ # @abstractclassmethod
+ # def from_dict(cls: Type["Serializable"], obj_dict: Dict) -> "Serializable":
+ # """Deserialize a dictionary to an Serializable object.
+ # """
+
+
+class RetrievalPolicy(str, Enum):
+ EXACT_MATCH = "exact_match"
+ SIMILARITY_MATCH = "similarity_match"
+
+
+class CachePolicy(str, Enum):
+ LRU = "lru"
+ FIFO = "fifo"
+
+
+@dataclass
+class CacheConfig:
+ retrieval_policy: Optional[RetrievalPolicy] = RetrievalPolicy.EXACT_MATCH
+ cache_policy: Optional[CachePolicy] = CachePolicy.LRU
+
+
+class CacheKey(Serializable, ABC, Generic[K]):
+ """The key of the cache. Must be hashable and comparable.
+
+ Supported cache keys:
+ - The LLM cache key: Include user prompt and the parameters to LLM.
+ - The embedding model cache key: Include the texts to embedding and the parameters to embedding model.
+ """
+
+ @abstractmethod
+ def __hash__(self) -> int:
+ """Return the hash value of the key."""
+
+ @abstractmethod
+ def __eq__(self, other: Any) -> bool:
+ """Check equality with another key."""
+
+ @abstractmethod
+ def get_hash_bytes(self) -> bytes:
+ """Return the byte array of hash value."""
+
+ @abstractmethod
+ def get_value(self) -> K:
+ """Get the underlying value of the cache key.
+
+ Returns:
+ K: The real object of current cache key
+ """
+
+
+class CacheValue(Serializable, ABC, Generic[V]):
+ """Cache value abstract class."""
+
+ @abstractmethod
+ def get_value(self) -> V:
+ """Get the underlying real value."""
+
+
+class Serializer(ABC):
+ """The serializer abstract class for serializing cache keys and values."""
+
+ @abstractmethod
+ def serialize(self, obj: Serializable) -> bytes:
+ """Serialize a cache object.
+
+ Args:
+ obj (Serializable): The object to serialize
+ """
+
+ @abstractmethod
+ def deserialize(self, data: bytes, cls: Type[Serializable]) -> Serializable:
+ """Deserialize data back into a cache object of the specified type.
+
+ Args:
+ data (bytes): The byte array to deserialize
+ cls (Type[Serializable]): The type of current object
+
+ Returns:
+ Serializable: The serializable object
+ """
+
+
+class CacheClient(ABC, Generic[K, V]):
+ """The cache client interface."""
+
+ @abstractmethod
+ async def get(
+ self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
+ ) -> Optional[CacheValue[V]]:
+ """Retrieve a value from the cache using the provided key.
+
+ Args:
+ key (CacheKey[K]): The key to get cache
+ cache_config (Optional[CacheConfig]): Cache config
+
+ Returns:
+ Optional[CacheValue[V]]: The value retrieved according to key. If cache key not exist, return None.
+ """
+
+ @abstractmethod
+ async def set(
+ self,
+ key: CacheKey[K],
+ value: CacheValue[V],
+ cache_config: Optional[CacheConfig] = None,
+ ) -> None:
+ """Set a value in the cache for the provided key.
+
+ Args:
+ key (CacheKey[K]): The key to set to cache
+ value (CacheValue[V]): The value to set to cache
+ cache_config (Optional[CacheConfig]): Cache config
+ """
+
+ @abstractmethod
+ async def exists(
+ self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
+ ) -> bool:
+ """Check if a key exists in the cache.
+
+ Args:
+ key (CacheKey[K]): The key to set to cache
+ cache_config (Optional[CacheConfig]): Cache config
+
+ Return:
+ bool: True if the key in the cache, otherwise is False
+ """
+
+ @abstractmethod
+ def new_key(self, **kwargs) -> CacheKey[K]:
+ """Create a cache key with params"""
+
+ @abstractmethod
+ def new_value(self, **kwargs) -> CacheValue[K]:
+ """Create a cache key with params"""
diff --git a/pilot/cache/embedding_cache.py b/pilot/cache/embedding_cache.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/cache/llm_cache.py b/pilot/cache/llm_cache.py
new file mode 100644
index 000000000..ad559df03
--- /dev/null
+++ b/pilot/cache/llm_cache.py
@@ -0,0 +1,148 @@
+from typing import Optional, Dict, Any, Union, List
+from dataclasses import dataclass, asdict
+import json
+import hashlib
+
+from pilot.cache.base import CacheKey, CacheValue, Serializer, CacheClient, CacheConfig
+from pilot.cache.manager import CacheManager
+from pilot.cache.storage.base import CacheStorage
+from pilot.model.base import ModelType, ModelOutput
+
+
+@dataclass
+class LLMCacheKeyData:
+ prompt: str
+ model_name: str
+ temperature: Optional[float] = 0.7
+ max_new_tokens: Optional[int] = None
+ top_p: Optional[float] = 1.0
+ model_type: Optional[str] = ModelType.HF
+
+
+CacheOutputType = Union[ModelOutput, List[ModelOutput]]
+
+
+@dataclass
+class LLMCacheValueData:
+ output: CacheOutputType
+ user: Optional[str] = None
+ _is_list: Optional[bool] = False
+
+ @staticmethod
+ def from_dict(**kwargs) -> "LLMCacheValueData":
+ output = kwargs.get("output")
+ if not output:
+ raise ValueError("Can't new LLMCacheValueData object, output is None")
+ if isinstance(output, dict):
+ output = ModelOutput(**output)
+ elif isinstance(output, list):
+ kwargs["_is_list"] = True
+ output_list = []
+ for out in output:
+ if isinstance(out, dict):
+ out = ModelOutput(**out)
+ output_list.append(out)
+ output = output_list
+ kwargs["output"] = output
+ return LLMCacheValueData(**kwargs)
+
+ def to_dict(self) -> Dict:
+ output = self.output
+ is_list = False
+ if isinstance(output, list):
+ output_list = []
+ is_list = True
+ for out in output:
+ output_list.append(out.to_dict())
+ output = output_list
+ else:
+ output = output.to_dict()
+ return {"output": output, "_is_list": is_list, "user": self.user}
+
+ @property
+ def is_list(self) -> bool:
+ return self._is_list
+
+ def __str__(self) -> str:
+ if not isinstance(self.output, list):
+ return f"user: {self.user}, output: {self.output}"
+ else:
+ return f"user: {self.user}, output(last two item): {self.output[-2:]}"
+
+
+class LLMCacheKey(CacheKey[LLMCacheKeyData]):
+ def __init__(self, serializer: Serializer = None, **kwargs) -> None:
+ super().__init__()
+ self._serializer = serializer
+ self.config = LLMCacheKeyData(**kwargs)
+
+ def __hash__(self) -> int:
+ serialize_bytes = self.serialize()
+ return int(hashlib.sha256(serialize_bytes).hexdigest(), 16)
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, LLMCacheKey):
+ return False
+ return self.config == other.config
+
+ def get_hash_bytes(self) -> bytes:
+ serialize_bytes = self.serialize()
+ return hashlib.sha256(serialize_bytes).digest()
+
+ def to_dict(self) -> Dict:
+ return asdict(self.config)
+
+ def serialize(self) -> bytes:
+ return self._serializer.serialize(self)
+
+ def get_value(self) -> LLMCacheKeyData:
+ return self.config
+
+
+class LLMCacheValue(CacheValue[LLMCacheValueData]):
+ def __init__(self, serializer: Serializer = None, **kwargs) -> None:
+ super().__init__()
+ self._serializer = serializer
+ self.value = LLMCacheValueData.from_dict(**kwargs)
+
+ def to_dict(self) -> Dict:
+ return self.value.to_dict()
+
+ def serialize(self) -> bytes:
+ return self._serializer.serialize(self)
+
+ def get_value(self) -> LLMCacheValueData:
+ return self.value
+
+ def __str__(self) -> str:
+ return f"vaue: {str(self.value)}"
+
+
+class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
+ def __init__(self, cache_manager: CacheManager) -> None:
+ super().__init__()
+ self._cache_manager: CacheManager = cache_manager
+
+ async def get(
+ self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
+ ) -> Optional[LLMCacheValue]:
+ return await self._cache_manager.get(key, LLMCacheValue, cache_config)
+
+ async def set(
+ self,
+ key: LLMCacheKey,
+ value: LLMCacheValue,
+ cache_config: Optional[CacheConfig] = None,
+ ) -> None:
+ return await self._cache_manager.set(key, value, cache_config)
+
+ async def exists(
+ self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
+ ) -> bool:
+ return await self.get(key, cache_config) is not None
+
+ def new_key(self, **kwargs) -> LLMCacheKey:
+ return LLMCacheKey(serializer=self._cache_manager.serializer, **kwargs)
+
+ def new_value(self, **kwargs) -> LLMCacheValue:
+ return LLMCacheValue(serializer=self._cache_manager.serializer, **kwargs)
diff --git a/pilot/cache/manager.py b/pilot/cache/manager.py
new file mode 100644
index 000000000..0e76df0b3
--- /dev/null
+++ b/pilot/cache/manager.py
@@ -0,0 +1,126 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Type
+import logging
+from concurrent.futures import Executor
+from pilot.cache.storage.base import CacheStorage, StorageItem
+from pilot.cache.base import (
+ K,
+ V,
+ CacheKey,
+ CacheValue,
+ CacheConfig,
+ Serializer,
+ Serializable,
+)
+from pilot.component import BaseComponent, ComponentType, SystemApp
+from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
+
+logger = logging.getLogger(__name__)
+
+
+class CacheManager(BaseComponent, ABC):
+ name = ComponentType.MODEL_CACHE_MANAGER
+
+ def __init__(self, system_app: SystemApp | None = None):
+ super().__init__(system_app)
+
+ def init_app(self, system_app: SystemApp):
+ self.system_app = system_app
+
+ @abstractmethod
+ async def set(
+ self,
+ key: CacheKey[K],
+ value: CacheValue[V],
+ cache_config: Optional[CacheConfig] = None,
+ ):
+ """Set cache"""
+
+ @abstractmethod
+ async def get(
+ self,
+ key: CacheKey[K],
+ cls: Type[Serializable],
+ cache_config: Optional[CacheConfig] = None,
+ ) -> CacheValue[V]:
+ """Get cache with key"""
+
+ @property
+ @abstractmethod
+ def serializer(self) -> Serializer:
+ """Get cache serializer"""
+
+
+class LocalCacheManager(CacheManager):
+ def __init__(
+ self, system_app: SystemApp, serializer: Serializer, storage: CacheStorage
+ ) -> None:
+ super().__init__(system_app)
+ self._serializer = serializer
+ self._storage = storage
+
+ @property
+ def executor(self) -> Executor:
+ """Return executor to submit task"""
+ self._executor = self.system_app.get_component(
+ ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
+ ).create()
+
+ async def set(
+ self,
+ key: CacheKey[K],
+ value: CacheValue[V],
+ cache_config: Optional[CacheConfig] = None,
+ ):
+ if self._storage.support_async():
+ await self._storage.aset(key, value, cache_config)
+ else:
+ await blocking_func_to_async(
+ self.executor, self._storage.set, key, value, cache_config
+ )
+
+ async def get(
+ self,
+ key: CacheKey[K],
+ cls: Type[Serializable],
+ cache_config: Optional[CacheConfig] = None,
+ ) -> CacheValue[V]:
+ if self._storage.support_async():
+ item_bytes = await self._storage.aget(key, cache_config)
+ else:
+ item_bytes = await blocking_func_to_async(
+ self.executor, self._storage.get, key, cache_config
+ )
+ if not item_bytes:
+ return None
+ return self._serializer.deserialize(item_bytes.value_data, cls)
+
+ @property
+ def serializer(self) -> Serializer:
+ return self._serializer
+
+
+def initialize_cache(
+ system_app: SystemApp, storage_type: str, max_memory_mb: int, persist_dir: str
+):
+ from pilot.cache.protocal.json_protocal import JsonSerializer
+ from pilot.cache.storage.base import MemoryCacheStorage
+
+ cache_storage = None
+ if storage_type == "disk":
+ try:
+ from pilot.cache.storage.disk.disk_storage import DiskCacheStorage
+
+ cache_storage = DiskCacheStorage(
+ persist_dir, mem_table_buffer_mb=max_memory_mb
+ )
+ except ImportError as e:
+ logger.warn(
+ f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error message: {str(e)}"
+ )
+ cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
+ else:
+ cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
+ system_app.register(
+ LocalCacheManager, serializer=JsonSerializer(), storage=cache_storage
+ )
diff --git a/pilot/cache/protocal/__init__.py b/pilot/cache/protocal/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/cache/protocal/json_protocal.py b/pilot/cache/protocal/json_protocal.py
new file mode 100644
index 000000000..6f73fef3f
--- /dev/null
+++ b/pilot/cache/protocal/json_protocal.py
@@ -0,0 +1,44 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Type
+import json
+
+from pilot.cache.base import Serializable, Serializer
+
+JSON_ENCODING = "utf-8"
+
+
+class JsonSerializable(Serializable, ABC):
+ @abstractmethod
+ def to_dict(self) -> Dict:
+ """Return the dict of current serializable object"""
+
+ def serialize(self) -> bytes:
+ """Convert the object into bytes for storage or transmission."""
+ return json.dumps(self.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)
+
+
+class JsonSerializer(Serializer):
+ """The serializer abstract class for serializing cache keys and values."""
+
+ def serialize(self, obj: Serializable) -> bytes:
+ """Serialize a cache object.
+
+ Args:
+ obj (Serializable): The object to serialize
+ """
+ return json.dumps(obj.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)
+
+ def deserialize(self, data: bytes, cls: Type[Serializable]) -> Serializable:
+ """Deserialize data back into a cache object of the specified type.
+
+ Args:
+ data (bytes): The byte array to deserialize
+ cls (Type[Serializable]): The type of current object
+
+ Returns:
+ Serializable: The serializable object
+ """
+ # Convert bytes back to JSON and then to the specified class
+ json_data = json.loads(data.decode(JSON_ENCODING))
+ # Assume that the cls has an __init__ that accepts a dictionary
+ return cls(**json_data)
diff --git a/pilot/cache/storage/__init__.py b/pilot/cache/storage/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/cache/storage/base.py b/pilot/cache/storage/base.py
new file mode 100644
index 000000000..ea07bfacf
--- /dev/null
+++ b/pilot/cache/storage/base.py
@@ -0,0 +1,252 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+from dataclasses import dataclass
+from collections import OrderedDict
+import msgpack
+import logging
+
+from pilot.cache.base import (
+ K,
+ V,
+ CacheKey,
+ CacheValue,
+ CacheClient,
+ CacheConfig,
+ RetrievalPolicy,
+ CachePolicy,
+)
+from pilot.utils.memory_utils import _get_object_bytes
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class StorageItem:
+ """
+ A class representing a storage item.
+
+ This class encapsulates data related to a storage item, such as its length,
+ the hash of the key, and the data for both the key and value.
+
+ Parameters:
+ length (int): The bytes length of the storage item.
+ key_hash (bytes): The hash value of the storage item's key.
+ key_data (bytes): The data of the storage item's key, represented in bytes.
+ value_data (bytes): The data of the storage item's value, also in bytes.
+ """
+
+ length: int # The bytes length of the storage item
+ key_hash: bytes # The hash value of the storage item's key
+ key_data: bytes # The data of the storage item's key
+ value_data: bytes # The data of the storage item's value
+
+ @staticmethod
+ def build_from(
+ key_hash: bytes, key_data: bytes, value_data: bytes
+ ) -> "StorageItem":
+ length = (
+ 32
+ + _get_object_bytes(key_hash)
+ + _get_object_bytes(key_data)
+ + _get_object_bytes(value_data)
+ )
+ return StorageItem(
+ length=length, key_hash=key_hash, key_data=key_data, value_data=value_data
+ )
+
+ @staticmethod
+ def build_from_kv(key: CacheKey[K], value: CacheValue[V]) -> "StorageItem":
+ key_hash = key.get_hash_bytes()
+ key_data = key.serialize()
+ value_data = value.serialize()
+ return StorageItem.build_from(key_hash, key_data, value_data)
+
+ def serialize(self) -> bytes:
+ """Serialize the StorageItem into a byte stream using MessagePack.
+
+ This method packs the object data into a dictionary, marking the
+ key_data and value_data fields as raw binary data to avoid re-serialization.
+
+ Returns:
+ bytes: The serialized bytes.
+ """
+ obj = {
+ "length": self.length,
+ "key_hash": msgpack.ExtType(1, self.key_hash),
+ "key_data": msgpack.ExtType(2, self.key_data),
+ "value_data": msgpack.ExtType(3, self.value_data),
+ }
+ return msgpack.packb(obj)
+
+ @staticmethod
+ def deserialize(data: bytes) -> "StorageItem":
+ """Deserialize bytes back into a StorageItem using MessagePack.
+
+ This extracts the fields from the MessagePack dict back into
+ a StorageItem object.
+
+ Args:
+ data (bytes): Serialized bytes
+
+ Returns:
+ StorageItem: Deserialized StorageItem object.
+ """
+ obj = msgpack.unpackb(data)
+ key_hash = obj["key_hash"].data
+ key_data = obj["key_data"].data
+ value_data = obj["value_data"].data
+
+ return StorageItem(
+ length=obj["length"],
+ key_hash=key_hash,
+ key_data=key_data,
+ value_data=value_data,
+ )
+
+
+class CacheStorage(ABC):
+ @abstractmethod
+ def check_config(
+ self,
+ cache_config: Optional[CacheConfig] = None,
+ raise_error: Optional[bool] = True,
+ ) -> bool:
+ """Check whether the CacheConfig is legal.
+
+ Args:
+ cache_config (Optional[CacheConfig]): Cache config.
+ raise_error (Optional[bool]): Whether raise error if illegal.
+
+ Returns:
+ ValueError: Error when raise_error is True and config is illegal.
+ """
+
+ def support_async(self) -> bool:
+ return False
+
+ @abstractmethod
+ def get(
+ self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
+ ) -> Optional[StorageItem]:
+ """Retrieve a storage item from the cache using the provided key.
+
+ Args:
+ key (CacheKey[K]): The key to get cache
+ cache_config (Optional[CacheConfig]): Cache config
+
+ Returns:
+ Optional[StorageItem]: The storage item retrieved according to key. If cache key not exist, return None.
+ """
+
+ async def aget(
+ self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
+ ) -> Optional[StorageItem]:
+ """Retrieve a storage item from the cache using the provided key asynchronously.
+
+ Args:
+ key (CacheKey[K]): The key to get cache
+ cache_config (Optional[CacheConfig]): Cache config
+
+ Returns:
+ Optional[StorageItem]: The storage item of bytes retrieved according to key. If cache key not exist, return None.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def set(
+ self,
+ key: CacheKey[K],
+ value: CacheValue[V],
+ cache_config: Optional[CacheConfig] = None,
+ ) -> None:
+ """Set a value in the cache for the provided key asynchronously.
+
+ Args:
+ key (CacheKey[K]): The key to set to cache
+ value (CacheValue[V]): The value to set to cache
+ cache_config (Optional[CacheConfig]): Cache config
+ """
+
+ async def aset(
+ self,
+ key: CacheKey[K],
+ value: CacheValue[V],
+ cache_config: Optional[CacheConfig] = None,
+ ) -> None:
+ """Set a value in the cache for the provided key asynchronously.
+
+ Args:
+ key (CacheKey[K]): The key to set to cache
+ value (CacheValue[V]): The value to set to cache
+ cache_config (Optional[CacheConfig]): Cache config
+ """
+ raise NotImplementedError
+
+
+class MemoryCacheStorage(CacheStorage):
+ def __init__(self, max_memory_mb: int = 256):
+ self.cache = OrderedDict()
+ self.max_memory = max_memory_mb * 1024 * 1024
+ self.current_memory_usage = 0
+
+ def check_config(
+ self,
+ cache_config: Optional[CacheConfig] = None,
+ raise_error: Optional[bool] = True,
+ ) -> bool:
+ if (
+ cache_config
+ and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
+ ):
+ if raise_error:
+ raise ValueError(
+ "MemoryCacheStorage only supports 'EXACT_MATCH' retrieval policy"
+ )
+ return False
+ return True
+
+ def get(
+ self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
+ ) -> Optional[StorageItem]:
+ self.check_config(cache_config, raise_error=True)
+ # Exact match retrieval
+ key_hash = hash(key)
+ item: StorageItem = self.cache.get(key_hash)
+ logger.debug(f"MemoryCacheStorage get key {key}, hash {key_hash}, item: {item}")
+
+ if not item:
+ return None
+ # Move the item to the end of the OrderedDict to signify recent use.
+ self.cache.move_to_end(key_hash)
+ return item
+
+ def set(
+ self,
+ key: CacheKey[K],
+ value: CacheValue[V],
+ cache_config: Optional[CacheConfig] = None,
+ ) -> None:
+ key_hash = hash(key)
+ item = StorageItem.build_from_kv(key, value)
+ # Calculate memory size of the new entry
+ new_entry_size = _get_object_bytes(item)
+ # Evict entries if necessary
+ while self.current_memory_usage + new_entry_size > self.max_memory:
+ self._apply_cache_policy(cache_config)
+
+ # Store the item in the cache.
+ self.cache[key_hash] = item
+ self.current_memory_usage += new_entry_size
+ logger.debug(f"MemoryCacheStorage set key {key}, hash {key_hash}, item: {item}")
+
+ def exists(
+ self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
+ ) -> bool:
+ return self.get(key, cache_config) is not None
+
+ def _apply_cache_policy(self, cache_config: Optional[CacheConfig] = None):
+ # Remove the oldest/newest item based on the cache policy.
+ if cache_config and cache_config.cache_policy == CachePolicy.FIFO:
+ self.cache.popitem(last=False)
+ else: # Default is LRU
+ self.cache.popitem(last=True)
diff --git a/pilot/cache/storage/disk/__init__.py b/pilot/cache/storage/disk/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/cache/storage/disk/disk_storage.py b/pilot/cache/storage/disk/disk_storage.py
new file mode 100644
index 000000000..04fb19c6e
--- /dev/null
+++ b/pilot/cache/storage/disk/disk_storage.py
@@ -0,0 +1,93 @@
+from typing import Optional
+import logging
+from pilot.cache.base import (
+ K,
+ V,
+ CacheKey,
+ CacheValue,
+ CacheConfig,
+ RetrievalPolicy,
+ CachePolicy,
+)
+from pilot.cache.storage.base import StorageItem, CacheStorage
+from rocksdict import Rdict
+from rocksdict import Rdict, Options, SliceTransform, PlainTableFactoryOptions
+
+
+logger = logging.getLogger(__name__)
+
+
+def db_options(
+ mem_table_buffer_mb: Optional[int] = 256, background_threads: Optional[int] = 2
+):
+ opt = Options()
+ # create table
+ opt.create_if_missing(True)
+ # config to more jobs, default 2
+ opt.set_max_background_jobs(background_threads)
+ # configure mem-table to a large value
+ opt.set_write_buffer_size(mem_table_buffer_mb * 1024 * 1024)
+ # opt.set_write_buffer_size(1024)
+ # opt.set_level_zero_file_num_compaction_trigger(4)
+ # configure l0 and l1 size, let them have the same size (1 GB)
+ # opt.set_max_bytes_for_level_base(0x40000000)
+ # 256 MB file size
+ # opt.set_target_file_size_base(0x10000000)
+ # use a smaller compaction multiplier
+ # opt.set_max_bytes_for_level_multiplier(4.0)
+ # use 8-byte prefix (2 ^ 64 is far enough for transaction counts)
+ # opt.set_prefix_extractor(SliceTransform.create_max_len_prefix(8))
+ # set to plain-table
+ # opt.set_plain_table_factory(PlainTableFactoryOptions())
+ return opt
+
+
+class DiskCacheStorage(CacheStorage):
+ def __init__(
+ self, persist_dir: str, mem_table_buffer_mb: Optional[int] = 256
+ ) -> None:
+ super().__init__()
+ self.db: Rdict = Rdict(
+ persist_dir, db_options(mem_table_buffer_mb=mem_table_buffer_mb)
+ )
+
+ def check_config(
+ self,
+ cache_config: Optional[CacheConfig] = None,
+ raise_error: Optional[bool] = True,
+ ) -> bool:
+ if (
+ cache_config
+ and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
+ ):
+ if raise_error:
+ raise ValueError(
+ "DiskCacheStorage only supports 'EXACT_MATCH' retrieval policy"
+ )
+ return False
+ return True
+
+ def get(
+ self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
+ ) -> Optional[StorageItem]:
+ self.check_config(cache_config, raise_error=True)
+
+ # Exact match retrieval
+ key_hash = key.get_hash_bytes()
+ item_bytes = self.db.get(key_hash)
+ if not item_bytes:
+ return None
+ item = StorageItem.deserialize(item_bytes)
+ logger.debug(f"Read file cache, key: {key}, storage item: {item}")
+ return item
+
+ def set(
+ self,
+ key: CacheKey[K],
+ value: CacheValue[V],
+ cache_config: Optional[CacheConfig] = None,
+ ) -> None:
+ item = StorageItem.build_from_kv(key, value)
+ key_hash = item.key_hash
+ self.db[key_hash] = item.serialize()
+ logger.debug(f"Save file cache, key: {key}, value: {value}")
diff --git a/pilot/cache/storage/tests/__init__.py b/pilot/cache/storage/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/cache/storage/tests/test_storage.py b/pilot/cache/storage/tests/test_storage.py
new file mode 100644
index 000000000..489873d08
--- /dev/null
+++ b/pilot/cache/storage/tests/test_storage.py
@@ -0,0 +1,53 @@
+import pytest
+from ..base import StorageItem
+from pilot.utils.memory_utils import _get_object_bytes
+
+
+def test_build_from():
+ key_hash = b"key_hash"
+ key_data = b"key_data"
+ value_data = b"value_data"
+ item = StorageItem.build_from(key_hash, key_data, value_data)
+
+ assert item.key_hash == key_hash
+ assert item.key_data == key_data
+ assert item.value_data == value_data
+ assert item.length == 32 + _get_object_bytes(key_hash) + _get_object_bytes(
+ key_data
+ ) + _get_object_bytes(value_data)
+
+
+def test_build_from_kv():
+ class MockCacheKey:
+ def get_hash_bytes(self):
+ return b"key_hash"
+
+ def serialize(self):
+ return b"key_data"
+
+ class MockCacheValue:
+ def serialize(self):
+ return b"value_data"
+
+ key = MockCacheKey()
+ value = MockCacheValue()
+ item = StorageItem.build_from_kv(key, value)
+
+ assert item.key_hash == key.get_hash_bytes()
+ assert item.key_data == key.serialize()
+ assert item.value_data == value.serialize()
+
+
+def test_serialize_deserialize():
+ key_hash = b"key_hash"
+ key_data = b"key_data"
+ value_data = b"value_data"
+ item = StorageItem.build_from(key_hash, key_data, value_data)
+
+ serialized = item.serialize()
+ deserialized = StorageItem.deserialize(serialized)
+
+ assert deserialized.key_hash == item.key_hash
+ assert deserialized.key_data == item.key_data
+ assert deserialized.value_data == item.value_data
+ assert deserialized.length == item.length
diff --git a/pilot/component.py b/pilot/component.py
index 16013ee17..d79a8d395 100644
--- a/pilot/component.py
+++ b/pilot/component.py
@@ -48,6 +48,7 @@ class ComponentType(str, Enum):
MODEL_CONTROLLER = "dbgpt_model_controller"
MODEL_REGISTRY = "dbgpt_model_registry"
MODEL_API_SERVER = "dbgpt_model_api_server"
+ MODEL_CACHE_MANAGER = "dbgpt_model_cache_manager"
AGENT_HUB = "dbgpt_agent_hub"
EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
TRACER = "dbgpt_tracer"
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index b263b46c4..f93cd7b83 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -253,6 +253,19 @@ class Config(metaclass=Singleton):
### Temporary configuration
self.USE_FASTCHAT: bool = os.getenv("USE_FASTCHAT", "True").lower() == "true"
+ self.MODEL_CACHE_ENABLE: bool = (
+ os.getenv("MODEL_CACHE_ENABLE", "True").lower() == "true"
+ )
+ self.MODEL_CACHE_STORAGE_TYPE: str = os.getenv(
+ "MODEL_CACHE_STORAGE_TYPE", "disk"
+ )
+ self.MODEL_CACHE_MAX_MEMORY_MB: int = int(
+ os.getenv("MODEL_CACHE_MAX_MEMORY_MB", 256)
+ )
+ self.MODEL_CACHE_STORAGE_DISK_DIR: str = os.getenv(
+ "MODEL_CACHE_STORAGE_DISK_DIR"
+ )
+
def set_debug_mode(self, value: bool) -> None:
"""Set the debug mode value"""
self.debug_mode = value
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 0e1fb3d40..fec343f2a 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -2,6 +2,7 @@
# -*- coding:utf-8 -*-
import os
+from functools import cache
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
MODEL_PATH = os.path.join(ROOT_PATH, "models")
@@ -14,6 +15,7 @@ DATA_DIR = os.path.join(PILOT_PATH, "data")
# nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
+MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache")
current_directory = os.getcwd()
@@ -21,6 +23,7 @@ new_directory = PILOT_PATH
os.chdir(new_directory)
+@cache
def get_device() -> str:
try:
import torch
diff --git a/pilot/model/base.py b/pilot/model/base.py
index 48480b94b..d54ac6d57 100644
--- a/pilot/model/base.py
+++ b/pilot/model/base.py
@@ -3,7 +3,9 @@
from enum import Enum
from typing import TypedDict, Optional, Dict, List, Any
+
from dataclasses import dataclass, asdict
+import time
from datetime import datetime
from pilot.utils.parameter_utils import ParameterDescription
@@ -47,6 +49,79 @@ class WorkerApplyType(str, Enum):
UPDATE_PARAMS = "update_params"
+@dataclass
+class ModelInferenceMetrics:
+ """A class to represent metrics for assessing the inference performance of a LLM."""
+
+ start_time_ms: Optional[int] = None
+ """The timestamp (in milliseconds) when the model inference starts."""
+
+ end_time_ms: Optional[int] = None
+ """The timestamp (in milliseconds) when the model inference ends."""
+
+ current_time_ms: Optional[int] = None
+ """The current timestamp (in milliseconds) when the model inference return partially output(stream)."""
+
+ first_token_time_ms: Optional[int] = None
+ """The timestamp (in milliseconds) when the first token is generated."""
+
+ first_completion_time_ms: Optional[int] = None
+ """The timestamp (in milliseconds) when the first completion is generated."""
+
+ first_completion_tokens: Optional[int] = None
+ """The number of tokens when the first completion is generated."""
+
+ prompt_tokens: Optional[int] = None
+ """The number of tokens in the input prompt."""
+
+ completion_tokens: Optional[int] = None
+ """The number of tokens in the generated completion."""
+
+ total_tokens: Optional[int] = None
+ """The total number of tokens (prompt plus completion)."""
+
+ speed_per_second: Optional[float] = None
+ """The average number of tokens generated per second."""
+
+ @staticmethod
+ def create_metrics(
+ last_metrics: Optional["ModelInferenceMetrics"] = None,
+ ) -> "ModelInferenceMetrics":
+ start_time_ms = last_metrics.start_time_ms if last_metrics else None
+ first_token_time_ms = last_metrics.first_token_time_ms if last_metrics else None
+ first_completion_time_ms = (
+ last_metrics.first_completion_time_ms if last_metrics else None
+ )
+ first_completion_tokens = (
+ last_metrics.first_completion_tokens if last_metrics else None
+ )
+ prompt_tokens = last_metrics.prompt_tokens if last_metrics else None
+ completion_tokens = last_metrics.completion_tokens if last_metrics else None
+ total_tokens = last_metrics.total_tokens if last_metrics else None
+ speed_per_second = last_metrics.speed_per_second if last_metrics else None
+
+ if not start_time_ms:
+ start_time_ms = time.time_ns() // 1_000_000
+ current_time_ms = time.time_ns() // 1_000_000
+ end_time_ms = current_time_ms
+
+ return ModelInferenceMetrics(
+ start_time_ms=start_time_ms,
+ end_time_ms=end_time_ms,
+ current_time_ms=current_time_ms,
+ first_token_time_ms=first_token_time_ms,
+ first_completion_time_ms=first_completion_time_ms,
+ first_completion_tokens=first_completion_tokens,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ speed_per_second=speed_per_second,
+ )
+
+ def to_dict(self) -> Dict:
+ return asdict(self)
+
+
@dataclass
class ModelOutput:
text: str
@@ -54,6 +129,9 @@ class ModelOutput:
model_context: Dict = None
finish_reason: str = None
usage: Dict[str, Any] = None
+ metrics: Optional[ModelInferenceMetrics] = None
+
+ """Some metrics for model inference"""
def to_dict(self) -> Dict:
return asdict(self)
diff --git a/pilot/model/cluster/base.py b/pilot/model/cluster/base.py
index 9d22161b1..45e46ab3e 100644
--- a/pilot/model/cluster/base.py
+++ b/pilot/model/cluster/base.py
@@ -17,8 +17,12 @@ class PromptRequest(BaseModel):
temperature: float = None
max_new_tokens: int = None
stop: str = None
+ stop_token_ids: List[int] = []
+ context_len: int = None
echo: bool = True
span_id: str = None
+ metrics: bool = False
+ """Whether to return metrics of inference"""
class EmbeddingsRequest(BaseModel):
diff --git a/pilot/model/cluster/worker/default_worker.py b/pilot/model/cluster/worker/default_worker.py
index 44a476f20..c798e3075 100644
--- a/pilot/model/cluster/worker/default_worker.py
+++ b/pilot/model/cluster/worker/default_worker.py
@@ -1,10 +1,13 @@
import os
import logging
-from typing import Dict, Iterator, List
+
+from typing import Dict, Iterator, List, Optional
+import time
+import traceback
from pilot.configs.model_config import get_device
from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
-from pilot.model.base import ModelOutput
+from pilot.model.base import ModelOutput, ModelInferenceMetrics
from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker
@@ -60,7 +63,7 @@ class DefaultModelWorker(ModelWorker):
self.ml: ModelLoader = ModelLoader(
model_path=self.model_path, model_name=self.model_name
)
- # TODO read context len from model config
+ # Default model context len
self.context_len = 2048
def model_param_class(self) -> ModelParameters:
@@ -111,6 +114,12 @@ class DefaultModelWorker(ModelWorker):
self.model, self.tokenizer = self.ml.loader_with_params(
model_params, self.llm_adapter
)
+ model_max_length = _parse_model_max_length(self.model, self.tokenizer)
+ if model_max_length:
+ logger.info(
+ f"Parse model max length {model_max_length} from model {self.model_name}."
+ )
+ self.context_len = model_max_length
def stop(self) -> None:
if not self.model:
@@ -138,14 +147,29 @@ class DefaultModelWorker(ModelWorker):
)
previous_response = ""
+ last_metrics = ModelInferenceMetrics.create_metrics()
+ is_first_generate = True
+ context_len = params.get("context_len") or self.context_len
for output in generate_stream_func(
- self.model, self.tokenizer, params, get_device(), self.context_len
+ self.model, self.tokenizer, params, get_device(), context_len
):
- model_output, incremental_output, output_str = self._handle_output(
- output, previous_response, model_context
+ (
+ model_output,
+ incremental_output,
+ output_str,
+ current_metrics,
+ ) = self._handle_output(
+ output,
+ previous_response,
+ model_context,
+ last_metrics,
+ is_first_generate,
)
+ if is_first_generate:
+ is_first_generate = False
previous_response = output_str
+ last_metrics = current_metrics
yield model_output
print(
f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
@@ -183,14 +207,30 @@ class DefaultModelWorker(ModelWorker):
)
previous_response = ""
+ context_len = params.get("context_len") or self.context_len
+ last_metrics = ModelInferenceMetrics.create_metrics()
+ is_first_generate = True
async for output in generate_stream_func(
- self.model, self.tokenizer, params, get_device(), self.context_len
+ self.model, self.tokenizer, params, get_device(), context_len
):
- model_output, incremental_output, output_str = self._handle_output(
- output, previous_response, model_context
+ (
+ model_output,
+ incremental_output,
+ output_str,
+ current_metrics,
+ ) = self._handle_output(
+ output,
+ previous_response,
+ model_context,
+ last_metrics,
+ is_first_generate,
)
+ if is_first_generate:
+ is_first_generate = False
+
previous_response = output_str
+ last_metrics = current_metrics
yield model_output
print(
f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
@@ -255,7 +295,14 @@ class DefaultModelWorker(ModelWorker):
return params, model_context, generate_stream_func, model_span
- def _handle_output(self, output, previous_response, model_context):
+ def _handle_output(
+ self,
+ output,
+ previous_response,
+ model_context,
+ last_metrics: ModelInferenceMetrics,
+ is_first_generate: bool,
+ ):
finish_reason = None
usage = None
if isinstance(output, dict):
@@ -266,24 +313,91 @@ class DefaultModelWorker(ModelWorker):
logger.info(f"finish_reason: {finish_reason}")
incremental_output = output[len(previous_response) :]
print(incremental_output, end="", flush=True)
+
+ metrics = _new_metrics_from_model_output(last_metrics, is_first_generate, usage)
model_output = ModelOutput(
text=output,
error_code=0,
model_context=model_context,
finish_reason=finish_reason,
usage=usage,
+ metrics=metrics,
)
- return model_output, incremental_output, output
+ return model_output, incremental_output, output, metrics
def _handle_exception(self, e):
# Check if the exception is a torch.cuda.CudaError and if torch was imported.
if _torch_imported and isinstance(e, torch.cuda.CudaError):
model_output = ModelOutput(
- text="**GPU OutOfMemory, Please Refresh.**", error_code=0
+ text="**GPU OutOfMemory, Please Refresh.**", error_code=1
)
else:
+ msg = traceback.format_exc()
+ logger.error(f"Model inference error, detail: {msg}")
model_output = ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
- error_code=0,
+ error_code=1,
)
return model_output
+
+
+def _parse_model_max_length(model, tokenizer) -> Optional[int]:
+ if not (tokenizer or model):
+ return None
+ try:
+ if tokenizer and hasattr(tokenizer, "model_max_length"):
+ return tokenizer.model_max_length
+ if model and hasattr(model, "config"):
+ model_config = model.config
+ if hasattr(model_config, "max_sequence_length"):
+ return model_config.max_sequence_length
+ if hasattr(model_config, "max_position_embeddings"):
+ return model_config.max_position_embeddings
+ except Exception:
+ return None
+
+
+def _new_metrics_from_model_output(
+ last_metric: ModelInferenceMetrics,
+ is_first_generate: bool,
+ usage: Optional[Dict] = None,
+) -> ModelInferenceMetrics:
+ metrics = ModelInferenceMetrics.create_metrics(last_metric)
+ if is_first_generate:
+ logger.info(f"is_first_generate, usage: {usage}")
+ metrics.first_completion_time_ms = time.time_ns() // 1_000_000
+
+ if not usage or not isinstance(usage, dict):
+ return metrics
+ prompt_tokens = usage.get("prompt_tokens")
+ completion_tokens = usage.get("completion_tokens")
+ total_tokens = usage.get("total_tokens")
+
+ if prompt_tokens is None:
+ prompt_tokens = metrics.prompt_tokens
+ if completion_tokens is None:
+ completion_tokens = metrics.completion_tokens
+ if total_tokens is None:
+ total_tokens = metrics.total_tokens
+
+ if is_first_generate and (completion_tokens is not None):
+ # completion_tokens == 0 is prefill
+ metrics.first_completion_tokens = completion_tokens
+ if completion_tokens == 1:
+ metrics.first_token_time_ms = metrics.first_completion_time_ms
+
+ if prompt_tokens:
+ metrics.prompt_tokens = prompt_tokens
+ if completion_tokens:
+ metrics.completion_tokens = completion_tokens
+ if total_tokens:
+ metrics.total_tokens = total_tokens
+ elif prompt_tokens and completion_tokens:
+ total_tokens = prompt_tokens + completion_tokens
+ metrics.total_tokens = total_tokens
+
+ if total_tokens:
+ # time cost(seconds)
+ duration = (metrics.current_time_ms - metrics.start_time_ms) / 1000.0
+ metrics.speed_per_second = total_tokens / duration
+ return metrics
diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py
index d67519f59..9c1c4f7b3 100644
--- a/pilot/model/cluster/worker/manager.py
+++ b/pilot/model/cluster/worker/manager.py
@@ -119,7 +119,10 @@ class LocalWorkerManager(WorkerManager):
_async_heartbeat_sender(self.run_data, 20, self.send_heartbeat_func)
)
for listener in self.start_listeners:
- listener(self)
+ if asyncio.iscoroutinefunction(listener):
+ await listener(self)
+ else:
+ listener(self)
async def stop(self, ignore_exception: bool = False):
if not self.run_data.stop_event.is_set():
@@ -325,7 +328,7 @@ class LocalWorkerManager(WorkerManager):
except Exception as e:
yield ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
- error_code=0,
+ error_code=1,
)
return
async with worker_run_data.semaphore:
@@ -355,7 +358,7 @@ class LocalWorkerManager(WorkerManager):
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
- error_code=0,
+ error_code=1,
)
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
@@ -996,11 +999,17 @@ def run_worker_manager(
port: int = None,
embedding_model_name: str = None,
embedding_model_path: str = None,
+ start_listener: Callable[["WorkerManager"], None] = None,
+ **kwargs,
):
global worker_manager
worker_params: ModelWorkerParameters = _parse_worker_params(
- model_name=model_name, model_path=model_path, standalone=standalone, port=port
+ model_name=model_name,
+ model_path=model_path,
+ standalone=standalone,
+ port=port,
+ **kwargs,
)
setup_logging(
@@ -1029,6 +1038,8 @@ def run_worker_manager(
worker_manager, embedding_model_name, embedding_model_path
)
+ worker_manager.after_start(start_listener)
+
if include_router:
app.include_router(router, prefix="/api")
diff --git a/pilot/model/cluster/worker/remote_manager.py b/pilot/model/cluster/worker/remote_manager.py
index 61b608cc7..4047f428e 100644
--- a/pilot/model/cluster/worker/remote_manager.py
+++ b/pilot/model/cluster/worker/remote_manager.py
@@ -15,7 +15,10 @@ class RemoteWorkerManager(LocalWorkerManager):
async def start(self):
for listener in self.start_listeners:
- listener(self)
+ if asyncio.iscoroutinefunction(listener):
+ await listener(self)
+ else:
+ listener(self)
async def stop(self, ignore_exception: bool = False):
pass
diff --git a/pilot/model/llm_out/vllm_llm.py b/pilot/model/llm_out/vllm_llm.py
index 07d43dc74..de108c87c 100644
--- a/pilot/model/llm_out/vllm_llm.py
+++ b/pilot/model/llm_out/vllm_llm.py
@@ -1,9 +1,13 @@
from typing import Dict
+import os
from vllm import AsyncLLMEngine
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams
+_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
+
+
async def generate_stream(
model: AsyncLLMEngine, tokenizer, params: Dict, device: str, context_len: int
):
@@ -37,15 +41,29 @@ async def generate_stream(
top_p = max(top_p, 1e-5)
if temperature <= 1e-5:
top_p = 1.0
+ gen_params = {
+ "stop": list(stop),
+ "ignore_eos": False,
+ }
+ prompt_token_ids = None
+ if _IS_BENCHMARK:
+ gen_params["stop"] = []
+ gen_params["ignore_eos"] = True
+ prompt_len = context_len - max_new_tokens - 2
+ prompt_token_ids = tokenizer([prompt]).input_ids[0]
+ prompt_token_ids = prompt_token_ids[-prompt_len:]
sampling_params = SamplingParams(
n=1,
temperature=temperature,
top_p=top_p,
use_beam_search=False,
- stop=list(stop),
max_tokens=max_new_tokens,
+ **gen_params
+ )
+
+ results_generator = model.generate(
+ prompt, sampling_params, request_id, prompt_token_ids=prompt_token_ids
)
- results_generator = model.generate(prompt, sampling_params, request_id)
async for request_output in results_generator:
prompt = request_output.prompt
if echo:
@@ -53,4 +71,25 @@ async def generate_stream(
else:
text_outputs = [output.text for output in request_output.outputs]
text_outputs = " ".join(text_outputs)
- yield {"text": text_outputs, "error_code": 0, "usage": {}}
+
+ # Note: usage is not supported yet
+ prompt_tokens = len(request_output.prompt_token_ids)
+ completion_tokens = sum(
+ len(output.token_ids) for output in request_output.outputs
+ )
+ usage = {
+ "prompt_tokens": prompt_tokens,
+ "completion_tokens": completion_tokens,
+ "total_tokens": prompt_tokens + completion_tokens,
+ }
+ finish_reason = (
+ request_output.outputs[0].finish_reason
+ if len(request_output.outputs) == 1
+ else [output.finish_reason for output in request_output.outputs]
+ )
+ yield {
+ "text": text_outputs,
+ "error_code": 0,
+ "usage": usage,
+ "finish_reason": finish_reason,
+ }
diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py
index cc8c07143..6ec0f7774 100644
--- a/pilot/model/model_adapter.py
+++ b/pilot/model/model_adapter.py
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
thread_local = threading.local()
-
+_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
_OLD_MODELS = [
"llama-cpp",
@@ -187,9 +187,12 @@ class LLMModelAdaper:
model_context["has_format_prompt"] = True
params["prompt"] = new_prompt
- # Overwrite model params:
- params["stop"] = conv.stop_str
- params["stop_token_ids"] = conv.stop_token_ids
+ custom_stop = params.get("stop")
+ custom_stop_token_ids = params.get("stop_token_ids")
+
+ # Prefer the value passed in from the input parameter
+ params["stop"] = custom_stop or conv.stop_str
+ params["stop_token_ids"] = custom_stop_token_ids or conv.stop_token_ids
return params, model_context
@@ -242,9 +245,16 @@ class FastChatLLMModelAdaperWrapper(LLMModelAdaper):
return self._adapter.load_model(model_path, from_pretrained_kwargs)
def get_generate_stream_function(self, model: "TorchNNModule", model_path: str):
- from fastchat.model.model_adapter import get_generate_stream_function
+ if _IS_BENCHMARK:
+ from pilot.utils.benchmarks.llm.fastchat_benchmarks_inference import (
+ generate_stream,
+ )
- return get_generate_stream_function(model, model_path)
+ return generate_stream
+ else:
+ from fastchat.model.model_adapter import get_generate_stream_function
+
+ return get_generate_stream_function(model, model_path)
def get_default_conv_template(
self, model_name: str, model_path: str
diff --git a/pilot/model/operator/__init__.py b/pilot/model/operator/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/model/operator/model_operator.py b/pilot/model/operator/model_operator.py
new file mode 100644
index 000000000..6486e8373
--- /dev/null
+++ b/pilot/model/operator/model_operator.py
@@ -0,0 +1,300 @@
+from typing import AsyncIterator, Dict, List, Union
+import logging
+from pilot.awel import (
+ BranchFunc,
+ StreamifyAbsOperator,
+ BranchOperator,
+ MapOperator,
+ TransformStreamAbsOperator,
+)
+from pilot.awel.operator.base import BaseOperator
+from pilot.model.base import ModelOutput
+from pilot.model.cluster import WorkerManager
+from pilot.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue
+
+logger = logging.getLogger(__name__)
+
+_LLM_MODEL_INPUT_VALUE_KEY = "llm_model_input_value"
+_LLM_MODEL_OUTPUT_CACHE_KEY = "llm_model_output_cache"
+
+
+class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
+ """Operator for streaming processing of model outputs.
+
+ Args:
+ worker_manager (WorkerManager): The manager that handles worker processes for model inference.
+ **kwargs: Additional keyword arguments.
+
+ Methods:
+ streamify: Asynchronously processes a stream of inputs, yielding model outputs.
+ """
+
+ def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.worker_manager = worker_manager
+
+ async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]:
+ """Process inputs as a stream and yield model outputs.
+
+ Args:
+ input_value (Dict): The input value for the model.
+
+ Returns:
+ AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs.
+ """
+ async for out in self.worker_manager.generate_stream(input_value):
+ yield out
+
+
+class ModelOperator(MapOperator[Dict, ModelOutput]):
+ """Operator for map-based processing of model outputs.
+
+ Args:
+ worker_manager (WorkerManager): Manager for handling worker processes.
+ **kwargs: Additional keyword arguments.
+
+ Methods:
+ map: Asynchronously processes a single input and returns the model output.
+ """
+
+ def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
+ self.worker_manager = worker_manager
+ super().__init__(**kwargs)
+
+ async def map(self, input_value: Dict) -> ModelOutput:
+ """Process a single input and return the model output.
+
+ Args:
+ input_value (Dict): The input value for the model.
+
+ Returns:
+ ModelOutput: The output from the model.
+ """
+ return await self.worker_manager.generate(input_value)
+
+
+class CachedModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
+ """Operator for streaming processing of model outputs with caching.
+
+ Args:
+ cache_manager (CacheManager): The cache manager to handle caching operations.
+ **kwargs: Additional keyword arguments.
+
+ Methods:
+ streamify: Processes a stream of inputs with cache support, yielding model outputs.
+ """
+
+ def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self._cache_manager = cache_manager
+ self._client = LLMCacheClient(cache_manager)
+
+ async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]:
+ """Process inputs as a stream with cache support and yield model outputs.
+
+ Args:
+ input_value (Dict): The input value for the model.
+
+ Returns:
+ AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs.
+ """
+ cache_dict = _parse_cache_key_dict(input_value)
+ llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
+ llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
+ logger.info(f"llm_cache_value: {llm_cache_value}")
+ for out in llm_cache_value.get_value().output:
+ yield out
+
+
+class CachedModelOperator(MapOperator[Dict, ModelOutput]):
+ """Operator for map-based processing of model outputs with caching.
+
+ Args:
+ cache_manager (CacheManager): Manager for caching operations.
+ **kwargs: Additional keyword arguments.
+
+ Methods:
+ map: Processes a single input with cache support and returns the model output.
+ """
+
+ def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self._cache_manager = cache_manager
+ self._client = LLMCacheClient(cache_manager)
+
+ async def map(self, input_value: Dict) -> ModelOutput:
+ """Process a single input with cache support and return the model output.
+
+ Args:
+ input_value (Dict): The input value for the model.
+
+ Returns:
+ ModelOutput: The output from the model.
+ """
+ cache_dict = _parse_cache_key_dict(input_value)
+ llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
+ llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
+ logger.info(f"llm_cache_value: {llm_cache_value}")
+ return llm_cache_value.get_value().output
+
+
+class ModelCacheBranchOperator(BranchOperator[Dict, Dict]):
+ """
+ A branch operator that decides whether to use cached data or to process data using the model.
+
+ Args:
+ cache_manager (CacheManager): The cache manager for managing cache operations.
+ model_task_name (str): The name of the task to process data using the model.
+ cache_task_name (str): The name of the task to process data using the cache.
+ **kwargs: Additional keyword arguments.
+ """
+
+ def __init__(
+ self,
+ cache_manager: CacheManager,
+ model_task_name: str,
+ cache_task_name: str,
+ **kwargs,
+ ):
+ super().__init__(branches=None, **kwargs)
+ self._cache_manager = cache_manager
+ self._client = LLMCacheClient(cache_manager)
+ self._model_task_name = model_task_name
+ self._cache_task_name = cache_task_name
+
+ async def branchs(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]:
+ """Defines branch logic based on cache availability.
+
+ Returns:
+ Dict[BranchFunc[Dict], Union[BaseOperator, str]]: A dictionary mapping branch functions to task names.
+ """
+
+ async def check_cache_true(input_value: Dict) -> bool:
+ # Check if the cache contains the result for the given input
+ cache_dict = _parse_cache_key_dict(input_value)
+ cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
+ cache_value = await self._client.get(cache_key)
+ logger.debug(
+ f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: {cache_value}"
+ )
+ await self.current_dag_context.save_to_share_data(
+ _LLM_MODEL_INPUT_VALUE_KEY, cache_key
+ )
+ return True if cache_value else False
+
+ async def check_cache_false(input_value: Dict):
+ # Inverse of check_cache_true
+ return not await check_cache_true(input_value)
+
+ return {
+ check_cache_true: self._cache_task_name,
+ check_cache_false: self._model_task_name,
+ }
+
+
+class ModelStreamSaveCacheOperator(
+ TransformStreamAbsOperator[ModelOutput, ModelOutput]
+):
+ """An operator to save the stream of model outputs to cache.
+
+ Args:
+ cache_manager (CacheManager): The cache manager for handling cache operations.
+ **kwargs: Additional keyword arguments.
+ """
+
+ def __init__(self, cache_manager: CacheManager, **kwargs):
+ self._cache_manager = cache_manager
+ self._client = LLMCacheClient(cache_manager)
+ super().__init__(**kwargs)
+
+ async def transform_stream(
+ self, input_value: AsyncIterator[ModelOutput]
+ ) -> AsyncIterator[ModelOutput]:
+ """Transforms the input stream by saving the outputs to cache.
+
+ Args:
+ input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of model outputs.
+
+ Returns:
+ AsyncIterator[ModelOutput]: The same input iterator, but the outputs are saved to cache.
+ """
+ llm_cache_key: LLMCacheKey = None
+ outputs = []
+ async for out in input_value:
+ if not llm_cache_key:
+ llm_cache_key = await self.current_dag_context.get_share_data(
+ _LLM_MODEL_INPUT_VALUE_KEY
+ )
+ outputs.append(out)
+ yield out
+ if llm_cache_key and _is_success_model_output(outputs):
+ llm_cache_value: LLMCacheValue = self._client.new_value(output=outputs)
+ await self._client.set(llm_cache_key, llm_cache_value)
+
+
+class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
+ """An operator to save a single model output to cache.
+
+ Args:
+ cache_manager (CacheManager): The cache manager for handling cache operations.
+ **kwargs: Additional keyword arguments.
+ """
+
+ def __init__(self, cache_manager: CacheManager, **kwargs):
+ self._cache_manager = cache_manager
+ self._client = LLMCacheClient(cache_manager)
+ super().__init__(**kwargs)
+
+ async def map(self, input_value: ModelOutput) -> ModelOutput:
+ """Saves a single model output to cache and returns it.
+
+ Args:
+ input_value (ModelOutput): The output from the model to be cached.
+
+ Returns:
+ ModelOutput: The same input model output.
+ """
+ llm_cache_key: LLMCacheKey = await self.current_dag_context.get_share_data(
+ _LLM_MODEL_INPUT_VALUE_KEY
+ )
+ llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value)
+ if llm_cache_key and _is_success_model_output(input_value):
+ await self._client.set(llm_cache_key, llm_cache_value)
+ return input_value
+
+
+def _parse_cache_key_dict(input_value: Dict) -> Dict:
+ """Parses and extracts relevant fields from input to form a cache key dictionary.
+
+ Args:
+ input_value (Dict): The input dictionary containing model and prompt parameters.
+
+ Returns:
+ Dict: A dictionary used for generating cache keys.
+ """
+ prompt: str = input_value.get("prompt")
+ if prompt:
+ prompt = prompt.strip()
+ return {
+ "prompt": prompt,
+ "model_name": input_value.get("model"),
+ "temperature": input_value.get("temperature"),
+ "max_new_tokens": input_value.get("max_new_tokens"),
+ "top_p": input_value.get("top_p", "1.0"),
+ # TODO pass model_type
+ "model_type": input_value.get("model_type", "huggingface"),
+ }
+
+
+def _is_success_model_output(out: Union[Dict, ModelOutput, List[ModelOutput]]) -> bool:
+ if not out:
+ return False
+ if isinstance(out, list):
+ # check last model output
+ out = out[-1]
+ error_code = 0
+ if isinstance(out, ModelOutput):
+ error_code = out.error_code
+ else:
+ error_code = int(out.get("error_code", 0))
+ return error_code == 0
diff --git a/pilot/model/proxy/llms/chatgpt.py b/pilot/model/proxy/llms/chatgpt.py
index 4557599ee..9e6d1a20a 100644
--- a/pilot/model/proxy/llms/chatgpt.py
+++ b/pilot/model/proxy/llms/chatgpt.py
@@ -4,10 +4,11 @@
import os
from typing import List
import logging
-
+import importlib.metadata as metadata
from pilot.model.proxy.llms.proxy_model import ProxyModel
from pilot.model.parameter import ProxyModelParameters
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+import httpx
logger = logging.getLogger(__name__)
@@ -57,40 +58,37 @@ def _initialize_openai(params: ProxyModelParameters):
return openai_params
-def __convert_2_gpt_messages(messages: List[ModelMessage]):
+def _initialize_openai_v1(params: ProxyModelParameters):
+ try:
+ from openai import OpenAI
+ except ImportError as exc:
+ raise ValueError(
+ "Could not import python package: openai "
+ "Please install openai by command `pip install openai"
+ )
- chat_round = 0
- gpt_messages = []
+ api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
- last_usr_message = ""
- system_messages = []
+ base_url = params.proxy_api_base or os.getenv(
+ "OPENAI_API_TYPE",
+ os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
+ )
+ api_key = params.proxy_api_key or os.getenv(
+ "OPENAI_API_KEY",
+ os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
+ )
+ api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION")
- for message in messages:
- if message.role == ModelMessageRoleType.HUMAN:
- last_usr_message = message.content
- elif message.role == ModelMessageRoleType.SYSTEM:
- system_messages.append(message.content)
- elif message.role == ModelMessageRoleType.AI:
- last_ai_message = message.content
- gpt_messages.append({"role": "user", "content": last_usr_message})
- gpt_messages.append({"role": "assistant", "content": last_ai_message})
+ if not base_url and params.proxy_server_url:
+ # Adapt previous proxy_server_url configuration
+ base_url = params.proxy_server_url.split("/chat/completions")[0]
- # build last user messge
-
- if len(system_messages) >0:
- if len(system_messages) > 1:
- end_message = system_messages[-1]
- else:
- last_message = messages[-1]
- if last_message.role == ModelMessageRoleType.HUMAN:
- end_message = system_messages[-1] + "\n" + last_message.content
- else:
- end_message = system_messages[-1]
- else:
- last_message = messages[-1]
- end_message = last_message.content
- gpt_messages.append({"role": "user", "content": end_message})
- return gpt_messages, system_messages
+ proxies = params.http_proxy
+ openai_params = {
+ "api_key": api_key,
+ "base_url": base_url,
+ }
+ return openai_params, api_type, api_version, proxies
def _build_request(model: ProxyModel, params):
@@ -99,8 +97,6 @@ def _build_request(model: ProxyModel, params):
model_params = model.get_params()
logger.info(f"Model: {model}, model_params: {model_params}")
- openai_params = _initialize_openai(model_params)
-
messages: List[ModelMessage] = params["messages"]
# Add history conversation
for message in messages:
@@ -131,13 +127,21 @@ def _build_request(model: ProxyModel, params):
}
proxyllm_backend = model_params.proxyllm_backend
- if openai_params["api_type"] == "azure":
- # engine = "deployment_name".
- proxyllm_backend = proxyllm_backend or "gpt-35-turbo"
- payloads["engine"] = proxyllm_backend
- else:
+ if metadata.version("openai") >= "1.0.0":
+ openai_params, api_type, api_version, proxies = _initialize_openai_v1(
+ model_params
+ )
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
payloads["model"] = proxyllm_backend
+ else:
+ openai_params = _initialize_openai(model_params)
+ if openai_params["api_type"] == "azure":
+ # engine = "deployment_name".
+ proxyllm_backend = proxyllm_backend or "gpt-35-turbo"
+ payloads["engine"] = proxyllm_backend
+ else:
+ proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
+ payloads["model"] = proxyllm_backend
logger.info(
f"Send request to real model {proxyllm_backend}, openai_params: {openai_params}"
@@ -148,32 +152,90 @@ def _build_request(model: ProxyModel, params):
def chatgpt_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
- import openai
+ if metadata.version("openai") >= "1.0.0":
+ model_params = model.get_params()
+ openai_params, api_type, api_version, proxies = _initialize_openai_v1(
+ model_params
+ )
+ history, payloads = _build_request(model, params)
+ if api_type == "azure":
+ from openai import AzureOpenAI
- history, payloads = _build_request(model, params)
+ client = AzureOpenAI(
+ api_key=openai_params["api_key"],
+ api_version=api_version,
+ azure_endpoint=openai_params["base_url"],
+ http_client=httpx.Client(proxies=proxies),
+ )
+ else:
+ from openai import OpenAI
- res = openai.ChatCompletion.create(messages=history, **payloads)
+ client = OpenAI(**openai_params, http_client=httpx.Client(proxies=proxies))
+ res = client.chat.completions.create(messages=history, **payloads)
+ text = ""
+ for r in res:
+ if r.choices[0].delta.content is not None:
+ content = r.choices[0].delta.content
+ text += content
+ yield text
- text = ""
- for r in res:
- if r["choices"][0]["delta"].get("content") is not None:
- content = r["choices"][0]["delta"]["content"]
- text += content
- yield text
+ else:
+ import openai
+
+ history, payloads = _build_request(model, params)
+
+ res = openai.ChatCompletion.create(messages=history, **payloads)
+
+ text = ""
+ for r in res:
+ if r["choices"][0]["delta"].get("content") is not None:
+ content = r["choices"][0]["delta"]["content"]
+ text += content
+ yield text
async def async_chatgpt_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
- import openai
+ if metadata.version("openai") >= "1.0.0":
+ model_params = model.get_params()
+ openai_params, api_type, api_version, proxies = _initialize_openai_v1(
+ model_params
+ )
+ history, payloads = _build_request(model, params)
+ if api_type == "azure":
+ from openai import AsyncAzureOpenAI
- history, payloads = _build_request(model, params)
+ client = AsyncAzureOpenAI(
+ api_key=openai_params["api_key"],
+ api_version=api_version,
+ azure_endpoint=openai_params["base_url"],
+ http_client=httpx.AsyncClient(proxies=proxies),
+ )
+ else:
+ from openai import AsyncOpenAI
- res = await openai.ChatCompletion.acreate(messages=history, **payloads)
+ client = AsyncOpenAI(
+ **openai_params, http_client=httpx.AsyncClient(proxies=proxies)
+ )
- text = ""
- async for r in res:
- if r["choices"][0]["delta"].get("content") is not None:
- content = r["choices"][0]["delta"]["content"]
- text += content
- yield text
+ res = await client.chat.completions.create(messages=history, **payloads)
+ text = ""
+ for r in res:
+ if r.choices[0].delta.content is not None:
+ content = r.choices[0].delta.content
+ text += content
+ yield text
+ else:
+ import openai
+
+ history, payloads = _build_request(model, params)
+
+ res = await openai.ChatCompletion.acreate(messages=history, **payloads)
+
+ text = ""
+ async for r in res:
+ if r["choices"][0]["delta"].get("content") is not None:
+ content = r["choices"][0]["delta"]["content"]
+ text += content
+ yield text
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 5a23b0fd8..117aac263 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -16,6 +16,8 @@ from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace
from pydantic import Extra
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
+from pilot.awel import BaseOperator, SimpleCallDataInputSource, InputOperator, DAG
+from pilot.model.operator.model_operator import ModelOperator, ModelStreamOperator
logger = logging.getLogger(__name__)
headers = {"User-Agent": "dbgpt Client"}
@@ -88,6 +90,11 @@ class BaseChat(ABC):
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
+ self._model_operator: BaseOperator = _build_model_operator()
+ self._model_stream_operator: BaseOperator = _build_model_operator(
+ is_stream=True, dag_name="llm_stream_model_dag"
+ )
+
class Config:
"""Configuration for this pydantic object."""
@@ -170,7 +177,7 @@ class BaseChat(ABC):
"messages": llm_messages,
"temperature": float(self.prompt_template.temperature),
"max_new_tokens": int(self.prompt_template.max_new_tokens),
- "stop": self.prompt_template.sep,
+ # "stop": self.prompt_template.sep,
"echo": self.llm_echo,
}
return payload
@@ -208,14 +215,9 @@ class BaseChat(ABC):
)
payload["span_id"] = span.span_id
try:
- from pilot.model.cluster import WorkerManagerFactory
-
- worker_manager = CFG.SYSTEM_APP.get_component(
- ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
- ).create()
- msg =""
- view_msg=""
- async for output in worker_manager.generate_stream(payload):
+ async for output in await self._model_stream_operator.call_stream(
+ call_data={"data": payload}
+ ):
### Plug-in research in result generation
msg = self.prompt_template.output_parser.parse_model_stream_resp_ex(
output, self.skip_echo_len
@@ -246,14 +248,10 @@ class BaseChat(ABC):
)
payload["span_id"] = span.span_id
try:
- from pilot.model.cluster import WorkerManagerFactory
-
- worker_manager = CFG.SYSTEM_APP.get_component(
- ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
- ).create()
-
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
- model_output = await worker_manager.generate(payload)
+ model_output = await self._model_operator.call(
+ call_data={"data": payload}
+ )
### output parse
ai_response_text = (
@@ -317,14 +315,7 @@ class BaseChat(ABC):
logger.info(f"Request: \n{payload}")
ai_response_text = ""
try:
- from pilot.model.cluster import WorkerManagerFactory
-
- worker_manager = CFG.SYSTEM_APP.get_component(
- ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
- ).create()
-
- model_output = await worker_manager.generate(payload)
-
+ model_output = await self._model_operator.call(call_data={"data": payload})
### output parse
ai_response_text = (
self.prompt_template.output_parser.parse_model_nostream_resp(
@@ -578,3 +569,88 @@ class BaseChat(ABC):
)
else:
return prompt_define_response
+
+
+def _build_model_operator(
+ is_stream: bool = False, dag_name: str = "llm_model_dag"
+) -> BaseOperator:
+ """Builds and returns a model processing workflow (DAG) operator.
+
+ This function constructs a Directed Acyclic Graph (DAG) for processing data using a model.
+ It includes caching and branching logic to either fetch results from a cache or process
+ data using the model. It supports both streaming and non-streaming modes.
+
+ .. code-block:: python
+ input_node >> cache_check_branch_node
+ cache_check_branch_node >> model_node >> save_cached_node >> join_node
+ cache_check_branch_node >> cached_node >> join_node
+
+ equivalent to::
+
+ -> model_node -> save_cached_node ->
+ / \
+ input_node -> cache_check_branch_node ---> join_node
+ \ /
+ -> cached_node ------------------- ->
+
+ Args:
+ is_stream (bool): Flag to determine if the operator should process data in streaming mode.
+ dag_name (str): Name of the DAG.
+
+ Returns:
+ BaseOperator: The final operator in the constructed DAG, typically a join node.
+ """
+ from pilot.model.cluster import WorkerManagerFactory
+ from pilot.awel import JoinOperator
+ from pilot.model.operator.model_operator import (
+ ModelCacheBranchOperator,
+ CachedModelStreamOperator,
+ CachedModelOperator,
+ ModelSaveCacheOperator,
+ ModelStreamSaveCacheOperator,
+ )
+ from pilot.cache import CacheManager
+
+ # Fetch worker and cache managers from the system configuration
+ worker_manager = CFG.SYSTEM_APP.get_component(
+ ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
+ ).create()
+ cache_manager: CacheManager = CFG.SYSTEM_APP.get_component(
+ ComponentType.MODEL_CACHE_MANAGER, CacheManager
+ )
+ # Define task names for the model and cache nodes
+ model_task_name = "llm_model_node"
+ cache_task_name = "llm_model_cache_node"
+
+ with DAG(dag_name):
+ # Create an input node
+ input_node = InputOperator(SimpleCallDataInputSource())
+ # Determine if the workflow should operate in streaming mode
+ if is_stream:
+ model_node = ModelStreamOperator(worker_manager, task_name=model_task_name)
+ cached_node = CachedModelStreamOperator(
+ cache_manager, task_name=cache_task_name
+ )
+ save_cached_node = ModelStreamSaveCacheOperator(cache_manager)
+ else:
+ model_node = ModelOperator(worker_manager, task_name=model_task_name)
+ cached_node = CachedModelOperator(cache_manager, task_name=cache_task_name)
+ save_cached_node = ModelSaveCacheOperator(cache_manager)
+
+ # Create a branch node to decide between fetching from cache or processing with the model
+ cache_check_branch_node = ModelCacheBranchOperator(
+ cache_manager,
+ model_task_name="llm_model_node",
+ cache_task_name="llm_model_cache_node",
+ )
+ # Create a join node to merge outputs from the model and cache nodes, just keep the first not empty output
+ join_node = JoinOperator(
+ combine_function=lambda model_out, cache_out: cache_out or model_out
+ )
+
+ # Define the workflow structure using the >> operator
+ input_node >> cache_check_branch_node
+ cache_check_branch_node >> model_node >> save_cached_node >> join_node
+ cache_check_branch_node >> cached_node >> join_node
+
+ return join_node
diff --git a/pilot/server/component_configs.py b/pilot/server/component_configs.py
index d127120cc..58269385b 100644
--- a/pilot/server/component_configs.py
+++ b/pilot/server/component_configs.py
@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Any, Type
import os
from pilot.component import ComponentType, SystemApp
+from pilot.configs.config import Config
+from pilot.configs.model_config import MODEL_DISK_CACHE_DIR
from pilot.utils.executor_utils import DefaultExecutorFactory
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from pilot.server.base import WebWerverParameters
@@ -15,6 +17,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+CFG = Config()
+
def initialize_components(
param: WebWerverParameters,
@@ -40,6 +44,7 @@ def initialize_components(
_initialize_embedding_model(
param, system_app, embedding_model_name, embedding_model_path
)
+ _initialize_model_cache(system_app)
def _initialize_embedding_model(
@@ -131,3 +136,16 @@ class LocalEmbeddingFactory(EmbeddingFactory):
loader = EmbeddingLoader()
# Ignore model_name args
return loader.load(self._default_model_name, model_params)
+
+
+def _initialize_model_cache(system_app: SystemApp):
+ from pilot.cache import initialize_cache
+
+ if not CFG.MODEL_CACHE_ENABLE:
+ logger.info("Model cache is not enable")
+ return
+
+ storage_type = CFG.MODEL_CACHE_STORAGE_TYPE or "disk"
+ max_memory_mb = CFG.MODEL_CACHE_MAX_MEMORY_MB or 256
+ persist_dir = CFG.MODEL_CACHE_STORAGE_DISK_DIR or MODEL_DISK_CACHE_DIR
+ initialize_cache(system_app, storage_type, max_memory_mb, persist_dir)
diff --git a/pilot/utils/benchmarks/__init__.py b/pilot/utils/benchmarks/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/utils/benchmarks/llm/__init__.py b/pilot/utils/benchmarks/llm/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/utils/benchmarks/llm/fastchat_benchmarks_inference.py b/pilot/utils/benchmarks/llm/fastchat_benchmarks_inference.py
new file mode 100644
index 000000000..cb05ab33f
--- /dev/null
+++ b/pilot/utils/benchmarks/llm/fastchat_benchmarks_inference.py
@@ -0,0 +1,296 @@
+"""
+Adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py.
+For benchmarks.
+
+"""
+import gc
+from typing import Iterable, Dict
+
+import torch
+from transformers.generation.logits_process import (
+ LogitsProcessorList,
+ RepetitionPenaltyLogitsProcessor,
+ TemperatureLogitsWarper,
+ TopKLogitsWarper,
+ TopPLogitsWarper,
+)
+
+
+from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length
+
+
+def prepare_logits_processor(
+ temperature: float, repetition_penalty: float, top_p: float, top_k: int
+) -> LogitsProcessorList:
+ processor_list = LogitsProcessorList()
+ # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
+ if temperature >= 1e-5 and temperature != 1.0:
+ processor_list.append(TemperatureLogitsWarper(temperature))
+ if repetition_penalty > 1.0:
+ processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
+ if 1e-8 <= top_p < 1.0:
+ processor_list.append(TopPLogitsWarper(top_p))
+ if top_k > 0:
+ processor_list.append(TopKLogitsWarper(top_k))
+ return processor_list
+
+
+@torch.inference_mode()
+def generate_stream(
+ model,
+ tokenizer,
+ params: Dict,
+ device: str,
+ context_len: int,
+ stream_interval: int = 2,
+ judge_sent_end: bool = False,
+):
+ if hasattr(model, "device"):
+ device = model.device
+
+ # Read parameters
+ prompt = params["prompt"]
+ len_prompt = len(prompt)
+ temperature = float(params.get("temperature", 1.0))
+ repetition_penalty = float(params.get("repetition_penalty", 1.0))
+ top_p = float(params.get("top_p", 1.0))
+ top_k = int(params.get("top_k", -1)) # -1 means disable
+ max_new_tokens = int(params.get("max_new_tokens", 256))
+ logprobs = params.get("logprobs", None) # FIXME: Support logprobs>1.
+ echo = bool(params.get("echo", True))
+ stop_str = params.get("stop", None)
+ stop_token_ids = params.get("stop_token_ids", None) or []
+ if tokenizer.eos_token_id not in stop_token_ids:
+ stop_token_ids.append(tokenizer.eos_token_id)
+
+ logits_processor = prepare_logits_processor(
+ temperature, repetition_penalty, top_p, top_k
+ )
+ input_ids = tokenizer(prompt).input_ids
+
+ if model.config.is_encoder_decoder:
+ max_src_len = context_len
+ else: # truncate
+ max_src_len = context_len - max_new_tokens - 1
+
+ input_ids = input_ids[-max_src_len:]
+ output_ids = list(input_ids)
+ input_echo_len = len(input_ids)
+
+ # Don't stop generate until max_new_tokens is reached.
+ stop_token_ids = []
+ stop_str = None
+
+ if model.config.is_encoder_decoder:
+ if logprobs is not None: # FIXME: Support logprobs for encoder-decoder models.
+ raise NotImplementedError
+ encoder_output = model.encoder(
+ input_ids=torch.as_tensor([input_ids], device=device)
+ )[0]
+ start_ids = torch.as_tensor(
+ [[model.generation_config.decoder_start_token_id]],
+ dtype=torch.int64,
+ device=device,
+ )
+ else:
+ start_ids = torch.as_tensor([input_ids], device=device)
+
+ past_key_values = out = None
+ token_logprobs = [None] # The first token has no logprobs.
+ sent_interrupt = False
+ finish_reason = None
+ for i in range(max_new_tokens):
+ if i == 0: # prefill
+ if model.config.is_encoder_decoder:
+ out = model.decoder(
+ input_ids=start_ids,
+ encoder_hidden_states=encoder_output,
+ use_cache=True,
+ )
+ logits = model.lm_head(out[0])
+ else:
+ out = model(input_ids=start_ids, use_cache=True)
+ logits = out.logits
+ past_key_values = out.past_key_values
+
+ if logprobs is not None:
+ # Prefull logprobs for the prompt.
+ shift_input_ids = start_ids[..., 1:].contiguous()
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
+ for label_id, logit in zip(
+ shift_input_ids[0].tolist(), shift_logits[0]
+ ):
+ token_logprobs.append(logit[label_id])
+ else: # decoding
+ if model.config.is_encoder_decoder:
+ out = model.decoder(
+ input_ids=torch.as_tensor(
+ [[token] if not sent_interrupt else output_ids],
+ device=device,
+ ),
+ encoder_hidden_states=encoder_output,
+ use_cache=True,
+ past_key_values=past_key_values if not sent_interrupt else None,
+ )
+ sent_interrupt = False
+
+ logits = model.lm_head(out[0])
+ else:
+ out = model(
+ input_ids=torch.as_tensor(
+ [[token] if not sent_interrupt else output_ids],
+ device=device,
+ ),
+ use_cache=True,
+ past_key_values=past_key_values if not sent_interrupt else None,
+ )
+ sent_interrupt = False
+ logits = out.logits
+ past_key_values = out.past_key_values
+
+ if logits_processor:
+ if repetition_penalty > 1.0:
+ tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
+ else:
+ tmp_output_ids = None
+ last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
+ else:
+ last_token_logits = logits[0, -1, :]
+
+ if device == "mps":
+ # Switch to CPU by avoiding some bugs in mps backend.
+ last_token_logits = last_token_logits.float().to("cpu")
+
+ if temperature < 1e-5 or top_p < 1e-8: # greedy
+ _, indices = torch.topk(last_token_logits, 2)
+ tokens = [int(index) for index in indices.tolist()]
+ else:
+ probs = torch.softmax(last_token_logits, dim=-1)
+ indices = torch.multinomial(probs, num_samples=2)
+ tokens = [int(token) for token in indices.tolist()]
+ token = tokens[0]
+ output_ids.append(token)
+ if logprobs is not None:
+ # Cannot use last_token_logits because logprobs is based on raw logits.
+ token_logprobs.append(
+ torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
+ )
+
+ if token in stop_token_ids:
+ stopped = True
+ else:
+ stopped = False
+
+ # Yield the output tokens
+ if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
+ if echo:
+ tmp_output_ids = output_ids
+ rfind_start = len_prompt
+ else:
+ tmp_output_ids = output_ids[input_echo_len:]
+ rfind_start = 0
+
+ output = tokenizer.decode(
+ tmp_output_ids,
+ skip_special_tokens=True,
+ spaces_between_special_tokens=False,
+ clean_up_tokenization_spaces=True,
+ )
+ ret_logprobs = None
+ if logprobs is not None:
+ ret_logprobs = {
+ "text_offset": [],
+ "tokens": [
+ tokenizer.decode(token)
+ for token in (
+ output_ids if echo else output_ids[input_echo_len:]
+ )
+ ],
+ "token_logprobs": token_logprobs
+ if echo
+ else token_logprobs[input_echo_len:],
+ "top_logprobs": [{}]
+ * len(token_logprobs if echo else token_logprobs[input_echo_len:]),
+ }
+ # Compute text_offset
+ curr_pos = 0
+ for text in ret_logprobs["tokens"]:
+ ret_logprobs["text_offset"].append(curr_pos)
+ curr_pos += len(text)
+
+ # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way
+ if judge_sent_end and stopped and not is_sentence_complete(output):
+ if len(tokens) > 1:
+ token = tokens[1]
+ output_ids[-1] = token
+ else:
+ output_ids.pop()
+ stopped = False
+ sent_interrupt = True
+
+ partially_stopped = False
+ if stop_str:
+ if isinstance(stop_str, str):
+ pos = output.rfind(stop_str, rfind_start)
+ if pos != -1:
+ output = output[:pos]
+ stopped = True
+ else:
+ partially_stopped = is_partial_stop(output, stop_str)
+ elif isinstance(stop_str, Iterable):
+ for each_stop in stop_str:
+ pos = output.rfind(each_stop, rfind_start)
+ if pos != -1:
+ output = output[:pos]
+ stopped = True
+ break
+ else:
+ partially_stopped = is_partial_stop(output, each_stop)
+ if partially_stopped:
+ break
+ else:
+ raise ValueError("Invalid stop field type.")
+
+ # Prevent yielding partial stop sequence
+ if not partially_stopped:
+ yield {
+ "text": output,
+ "logprobs": ret_logprobs,
+ "usage": {
+ "prompt_tokens": input_echo_len,
+ "completion_tokens": i,
+ "total_tokens": input_echo_len + i,
+ },
+ "finish_reason": None,
+ }
+
+ if stopped:
+ break
+
+ # Finish stream event, which contains finish reason
+ else:
+ finish_reason = "length"
+
+ if stopped:
+ finish_reason = "stop"
+
+ yield {
+ "text": output,
+ "logprobs": ret_logprobs,
+ "usage": {
+ "prompt_tokens": input_echo_len,
+ "completion_tokens": i,
+ "total_tokens": input_echo_len + i,
+ },
+ "finish_reason": finish_reason,
+ }
+
+ # Clean
+ del past_key_values, out
+ gc.collect()
+ torch.cuda.empty_cache()
+ if device == "xpu":
+ torch.xpu.empty_cache()
+ if device == "npu":
+ torch.npu.empty_cache()
diff --git a/pilot/utils/benchmarks/llm/llm_benchmarks.py b/pilot/utils/benchmarks/llm/llm_benchmarks.py
new file mode 100644
index 000000000..b70742670
--- /dev/null
+++ b/pilot/utils/benchmarks/llm/llm_benchmarks.py
@@ -0,0 +1,243 @@
+from typing import Dict, List
+import asyncio
+import os
+import sys
+import time
+import csv
+import argparse
+import logging
+import traceback
+from pilot.configs.model_config import ROOT_PATH, LLM_MODEL_CONFIG
+
+from pilot.model.cluster.worker.manager import (
+ run_worker_manager,
+ initialize_worker_manager_in_client,
+ worker_manager,
+ WorkerManager,
+)
+
+from pilot.model.base import ModelOutput, ModelInferenceMetrics
+from pilot.model.cluster import PromptRequest
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+
+
+model_name = "vicuna-7b-v1.5"
+model_path = LLM_MODEL_CONFIG[model_name]
+# or vllm
+model_type = "huggingface"
+
+controller_addr = "http://127.0.0.1:5000"
+
+result_csv_file = None
+
+parallel_nums = [1, 2, 4, 16, 32]
+# parallel_nums = [1, 2, 4]
+
+
+def get_result_csv_file() -> str:
+ return os.path.join(
+ ROOT_PATH, f"pilot/data/{model_name}_{model_type}_benchmarks_llm.csv"
+ )
+
+
+input_lens = [64, 64]
+output_lens = [256, 512]
+
+
+prompt_file_map = {
+ "11k": os.path.join(
+ ROOT_PATH, "docker/examples/benchmarks/benchmarks_llm_11k_prompt.txt"
+ )
+}
+
+METRICS_HEADERS = [
+ # Params
+ "model_name",
+ "parallel_nums",
+ "input_length",
+ "output_length",
+ # Merge parallel result
+ "test_time_cost_ms",
+ "test_total_tokens",
+ "test_speed_per_second", # (tokens / s)
+ # Detail for each task
+ "start_time_ms",
+ "end_time_ms",
+ "current_time_ms",
+ "first_token_time_ms",
+ "first_completion_time_ms",
+ "first_completion_tokens",
+ "prompt_tokens",
+ "completion_tokens",
+ "total_tokens",
+ "speed_per_second",
+]
+
+
+def read_prompt_from_file(file_key: str) -> str:
+ full_path = prompt_file_map[file_key]
+ with open(full_path, "r+", encoding="utf-8") as f:
+ return f.read()
+
+
+def build_param(
+ input_len: int,
+ output_len: int,
+ user_input: str,
+ system_prompt: str = None,
+) -> Dict:
+ hist = []
+ if system_prompt is not None:
+ hist.append(
+ ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_prompt)
+ )
+ hist.append(ModelMessage(role=ModelMessageRoleType.HUMAN, content=user_input))
+ hist = list(h.dict() for h in hist)
+ context_len = input_len + output_len + 2
+ params = {
+ "prompt": user_input,
+ "messages": hist,
+ "model": model_name,
+ "echo": False,
+ "max_new_tokens": output_len,
+ "context_len": context_len,
+ }
+ return params
+
+
+async def run_batch(
+ wh, input_len: int, output_len: int, parallel_num: int, output_file: str
+):
+ tasks = []
+ prompt = read_prompt_from_file("11k")
+ if model_type == "vllm":
+ max_input_str_len = input_len
+ if "baichuan" in model_name:
+ # TODO prompt handle first
+ max_input_str_len *= 2
+ prompt = prompt[-max_input_str_len:]
+
+ for _ in range(parallel_num):
+ params = build_param(input_len, output_len, prompt, system_prompt="")
+ tasks.append(wh.generate(params))
+ print(
+ f"Begin run benchmarks, model name: {model_name}, input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, save result to {output_file}"
+ )
+ start_time_ms = time.time_ns() // 1_000_000
+ results: List[ModelOutput] = await asyncio.gather(*tasks)
+ end_time_ms = time.time_ns() // 1_000_000
+
+ test_time_cost_ms = end_time_ms - start_time_ms
+ test_total_tokens = 0
+ rows = []
+ for r in results:
+ metrics = r.metrics
+ if isinstance(metrics, dict):
+ metrics = ModelInferenceMetrics(**metrics)
+ print(r)
+ test_total_tokens += metrics.total_tokens
+ row_data = metrics.to_dict()
+ rows.append(row_data)
+ test_speed_per_second = test_total_tokens / (test_time_cost_ms / 1000.0)
+
+ with open(output_file, "a", newline="", encoding="utf-8") as f:
+ writer = csv.DictWriter(f, fieldnames=METRICS_HEADERS)
+ if f.tell() == 0:
+ # Fist time
+ writer.writeheader()
+ for row in rows:
+ row["model_name"] = model_name
+ row["parallel_nums"] = parallel_num
+ row["input_length"] = input_len
+ row["output_length"] = output_len
+ row["test_time_cost_ms"] = test_time_cost_ms
+ row["test_total_tokens"] = test_total_tokens
+ row["test_speed_per_second"] = test_speed_per_second
+ writer.writerow(row)
+ print(
+ f"input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, save result to {output_file}"
+ )
+
+
+async def run_model(wh: WorkerManager) -> None:
+ global result_csv_file
+ if not result_csv_file:
+ result_csv_file = get_result_csv_file()
+ if os.path.exists(result_csv_file):
+ os.rename(result_csv_file, f"{result_csv_file}.bak.csv")
+ for parallel_num in parallel_nums:
+ for input_len, output_len in zip(input_lens, output_lens):
+ try:
+ await run_batch(
+ wh, input_len, output_len, parallel_num, result_csv_file
+ )
+ except Exception:
+ msg = traceback.format_exc()
+ logging.error(
+ f"Run benchmarks error, input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, error message: {msg}"
+ )
+
+ sys.exit(0)
+
+
+def startup_llm_env():
+ from fastapi import FastAPI
+
+ app = FastAPI()
+ initialize_worker_manager_in_client(
+ app=app,
+ model_name=model_name,
+ model_path=model_path,
+ run_locally=False,
+ controller_addr=controller_addr,
+ local_port=6000,
+ start_listener=run_model,
+ )
+
+
+def connect_to_remote_model():
+ startup_llm_env()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_name", type=str, default=model_name)
+ parser.add_argument("--model_path", type=str, default=None)
+ parser.add_argument("--model_type", type=str, default="huggingface")
+ parser.add_argument("--result_csv_file", type=str, default=None)
+ parser.add_argument("--input_lens", type=str, default="8,8,256,1024")
+ parser.add_argument("--output_lens", type=str, default="256,512,1024,1024")
+ parser.add_argument("--parallel_nums", type=str, default="1,2,4,16,32")
+ parser.add_argument(
+ "--remote_model", type=bool, default=False, help="Connect to remote model"
+ )
+ parser.add_argument("--controller_addr", type=str, default="http://127.0.0.1:8000")
+ parser.add_argument("--limit_model_concurrency", type=int, default=200)
+
+ args = parser.parse_args()
+ print(f"args: {args}")
+ model_name = args.model_name
+ model_path = args.model_path or LLM_MODEL_CONFIG[model_name]
+ result_csv_file = args.result_csv_file
+ input_lens = [int(i) for i in args.input_lens.strip().split(",")]
+ output_lens = [int(i) for i in args.output_lens.strip().split(",")]
+ parallel_nums = [int(i) for i in args.parallel_nums.strip().split(",")]
+ remote_model = args.remote_model
+ controller_addr = args.controller_addr
+ limit_model_concurrency = args.limit_model_concurrency
+ model_type = args.model_type
+ if len(input_lens) != len(output_lens):
+ raise ValueError("input_lens size must equal output_lens size")
+
+ if remote_model:
+ # Connect to remote model and run benchmarks
+ connect_to_remote_model()
+ else:
+ # Start worker manager and run benchmarks
+ run_worker_manager(
+ model_name=model_name,
+ model_path=model_path,
+ start_listener=run_model,
+ limit_model_concurrency=limit_model_concurrency,
+ model_type=model_type,
+ )
diff --git a/pilot/utils/memory_utils.py b/pilot/utils/memory_utils.py
new file mode 100644
index 000000000..cb0427c08
--- /dev/null
+++ b/pilot/utils/memory_utils.py
@@ -0,0 +1,11 @@
+from typing import Any
+from pympler import asizeof
+
+
+def _get_object_bytes(obj: Any) -> int:
+ """Get the bytes of a object in memory
+
+ Args:
+ obj (Any): The object to return the bytes
+ """
+ return asizeof.asizeof(obj)
diff --git a/scripts/run_llm_benchmarks.sh b/scripts/run_llm_benchmarks.sh
new file mode 100755
index 000000000..ffa9ac6da
--- /dev/null
+++ b/scripts/run_llm_benchmarks.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+default_input_lens="64,64,64,512,1024,1024,2048"
+default_output_lens="256,512,1024,1024,1024,2048,2048"
+default_parallel_nums="1,2,4,16,32"
+
+input_lens=${1:-$default_input_lens}
+output_lens=${2:-$default_output_lens}
+parallel_nums=${3:-$default_parallel_nums}
+
+run_benchmark() {
+ local model_name=$1
+ local model_type=$2
+ DB_GPT_MODEL_BENCHMARK=true python pilot/utils/benchmarks/llm/llm_benchmarks.py --model_name ${model_name} --model_type ${model_type} --input_lens ${input_lens} --output_lens ${output_lens} --parallel_nums ${parallel_nums}
+}
+
+run_benchmark "vicuna-7b-v1.5" "huggingface"
+run_benchmark "vicuna-7b-v1.5" "vllm"
+run_benchmark "baichuan2-7b" "huggingface"
+run_benchmark "baichuan2-7b" "vllm"
diff --git a/setup.py b/setup.py
index 6043b3840..185a255bd 100644
--- a/setup.py
+++ b/setup.py
@@ -317,9 +317,12 @@ def core_requires():
# TODO move transformers to default
"transformers>=4.31.0",
"alembic==1.12.0",
+ # for excel
"openpyxl==3.1.2",
"chardet==5.1.0",
"xlrd==2.0.1",
+ # for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit.
+ "pympler",
]
@@ -364,6 +367,8 @@ def quantization_requires():
)
pkgs = [f"bitsandbytes @ {local_pkg}"]
print(pkgs)
+ # For chatglm2-6b-int4
+ pkgs += ["cpm_kernels"]
setup_spec.extras["quantization"] = pkgs
@@ -409,6 +414,13 @@ def vllm_requires():
setup_spec.extras["vllm"] = ["vllm"]
+def cache_requires():
+ """
+ pip install "db-gpt[cache]"
+ """
+ setup_spec.extras["cache"] = ["rocksdict", "msgpack"]
+
+
# def chat_scene():
# setup_spec.extras["chat"] = [
# ""
@@ -459,6 +471,7 @@ all_datasource_requires()
openai_requires()
gpt4all_requires()
vllm_requires()
+cache_requires()
# must be last
default_requires()