Merge remote-tracking branch 'origin/main' into dev

# Conflicts:
#	pilot/configs/config.py
#	pilot/connections/mysql.py
#	pilot/conversation.py
#	pilot/server/webserver.py
Author: yhjun1026
Date: 2023-05-25 10:22:38 +08:00
87 changed files with 2152 additions and 874 deletions


@@ -81,3 +81,14 @@ DENYLISTED_PLUGINS=
#*******************************************************************#
# CHAT_MESSAGES_ENABLED - Enable chat messages (Default: False)
# CHAT_MESSAGES_ENABLED=False
#*******************************************************************#
#** VECTOR STORE SETTINGS **#
#*******************************************************************#
VECTOR_STORE_TYPE=Chroma
#MILVUS_URL=127.0.0.1
#MILVUS_PORT=19530
#MILVUS_USERNAME
#MILVUS_PASSWORD
#MILVUS_SECURE=
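For reference, these variables are consumed by the new vector-store block added to pilot/configs/config.py later in this commit; the sketch below simply mirrors those `os.getenv` calls (the exact string checked for Milvus is an assumption).

```
import os

# Defaults mirror the values added to pilot/configs/config.py in this commit.
VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")
MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1")
MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")
MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)

if VECTOR_STORE_TYPE != "Chroma":
    # The Milvus settings only matter when Chroma is not the selected store.
    print(f"Connecting to Milvus at {MILVUS_URL}:{MILVUS_PORT}")
```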

.github/ISSUE_TEMPLATE/bug_report.md (new file, 38 lines)

@@ -0,0 +1,38 @@
---
name: Bug report
about: Create a report to help us improve
title: "[BUG]: "
labels: ''
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Desktop (please complete the following information):**
- OS: [e.g. iOS]
- Browser [e.g. chrome, safari]
- Version [e.g. 22]
**Smartphone (please complete the following information):**
- Device: [e.g. iPhone6]
- OS: [e.g. iOS8.1]
- Browser [e.g. stock browser, safari]
- Version [e.g. 22]
**Additional context**
Add any other context about the problem here.


@@ -0,0 +1,10 @@
---
name: Documentation Related
about: Report a problem or gap in the documentation
title: "[Doc]: "
labels: ''
assignees: ''
---


@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: "[Feature]:"
labels: ''
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.


@@ -1,6 +1,15 @@
name: Pylint
on: [push]
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
workflow_dispatch:
concurrency:
group: ${{ github.event.number || github.run_id }}
cancel-in-progress: true
jobs:
build:
@@ -17,7 +26,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pylint
- name: Analysing the code with pylint
pip install -U black isort
- name: check the code lint
run: |
pylint $(git ls-files '*.py')
black . --check

.gitignore (4 lines changed)

@@ -25,6 +25,7 @@ lib/
lib64/
parts/
sdist/
models
var/
wheels/
models/
@@ -138,4 +139,5 @@ dmypy.json
.DS_Store
logs
nltk_data
.vectordb
.vectordb
pilot/data/


@@ -29,6 +29,10 @@ Currently, we have released multiple key features, which are listed below to dem
- Unified vector storage/indexing of knowledge base
- Support for unstructured data such as PDF, Markdown, CSV, and WebURL
- Multi LLMs Support
- Supports multiple large language models, currently supporting Vicuna (7b, 13b), ChatGLM-6b (int4, int8)
- TODO: codegen2, codet5p
## Demo
@@ -103,6 +107,7 @@ As the knowledge base is currently the most significant user demand scenario, we
2. Custom addition of knowledge bases
3. Various usage scenarios such as constructing knowledge bases through plugin capabilities and web crawling. Users only need to organize the knowledge documents, and they can use our existing capabilities to build the knowledge base required for the large model.
### LLMs Management
In the underlying large model integration, we have designed an open interface that supports integration with various large models. At the same time, we have a very strict control and evaluation mechanism for the effectiveness of the integrated models: in terms of accuracy, an integrated model needs to align with ChatGPT's capability at a level of 85% or higher. We use these higher standards to select models, hoping to spare users the cumbersome testing and evaluation work.
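As a rough illustration of that open interface, the sketch below follows the adapter pattern added in this commit (`BaseLLMAdaper` with `match`/`loader` and `register_llm_model_adapters`); the adapter module path, class name, and model path are hypothetical.

```
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical import path -- the adapter file's location is not shown in this diff.
from pilot.model.adapter import BaseLLMAdaper, register_llm_model_adapters


class MyLLMAdapter(BaseLLMAdaper):
    """Hypothetical adapter for an additional model family."""

    def match(self, model_path: str):
        # The registry returns the first adapter whose match() is True.
        return "my-llm" in model_path

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
        )
        return model, tokenizer


register_llm_model_adapters(MyLLMAdapter)
```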
@@ -153,16 +158,6 @@ conda create -n dbgpt_env python=3.10
conda activate dbgpt_env
pip install -r requirements.txt
```
Alternatively, you can use the following command:
```
cd DB-GPT
conda env create -f environment.yml
```
It is recommended to set the Python package path to avoid runtime errors caused by packages not being found.
```
echo "/root/workspace/DB-GPT" > /root/miniconda3/env/dbgpt_env/lib/python3.10/site-packages/dbgpt.pth
```
Notice: You need to replace the path with your own.
### 3. Run
You can refer to this document to obtain the Vicuna weights: [Vicuna](https://github.com/lm-sys/FastChat/blob/main/README.md#model-weights) .
@@ -182,9 +177,31 @@ $ python pilot/server/webserver.py
Notice: the webserver needs to connect to the llmserver, so you must update the .env file. Change MODEL_SERVER = "http://127.0.0.1:8000" to your address. This is very important.
## Usage Instructions
We provide a user interface for Gradio, which allows you to use DB-GPT through our user interface. Additionally, we have prepared several reference articles (written in Chinese) that introduce the code and principles related to our project.
- [LLM Practical In Action Series (1) — Combined Langchain-Vicuna Application Practical](https://medium.com/@cfqcsunny/llm-practical-in-action-series-1-combined-langchain-vicuna-application-practical-701cd0413c9f)
### Multi LLMs Usage
To use multiple models, modify the LLM_MODEL parameter in the .env configuration file to switch between the models.
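A minimal sketch of how the selected model name is resolved, based on the `Config` and `LLM_MODEL_CONFIG` definitions that appear later in this commit:

```
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG

CFG = Config()  # LLM_MODEL is read from .env (default: "vicuna-13b")

# e.g. LLM_MODEL=chatglm-6b resolves to <repo>/models/chatglm-6b
print(CFG.LLM_MODEL, "->", LLM_MODEL_CONFIG[CFG.LLM_MODEL])
```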
#### Create your own knowledge repository:
1. Place personal knowledge files or folders in the pilot/datasets directory.
2. Run the knowledge repository script in the tools directory.
```
python tools/knowledge_init.py
    --vector_name : your vector store name (default: default)
    --append : append mode; True = append, False = do not append (default: False)
```
3. Add the knowledge repository in the interface by entering the name of your knowledge repository (if not specified, enter "default"), so you can use it for Q&A based on your knowledge base.
Note that the default vector model used is text2vec-large-chinese (which is a large model; if your personal computer's configuration is not sufficient, it is recommended to use text2vec-base-chinese). Therefore, ensure that you download the model and place it in the models directory.
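For orientation, here is a rough sketch (not the actual `tools/knowledge_init.py`) of what ingesting a document with the text2vec model into a Chroma store could look like; the file paths, collection name, and persist directory are illustrative only.

```
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

# Embedding model downloaded into the models directory, as described above.
embeddings = HuggingFaceEmbeddings(model_name="models/text2vec-large-chinese")

with open("pilot/datasets/my_doc.md", encoding="utf-8") as f:
    chunks = CharacterTextSplitter(chunk_size=100, chunk_overlap=0).split_text(f.read())

store = Chroma.from_texts(
    chunks,
    embeddings,
    collection_name="default",              # the knowledge repository name
    persist_directory="pilot/data/default",  # illustrative location
)
store.persist()
```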
## Acknowledgement
The achievements of this project are thanks to the technical community, especially the following projects:
@@ -198,6 +215,10 @@ The achievements of this project are thanks to the technical community, especial
- [ChatGLM](https://github.com/THUDM/ChatGLM-6B) as the base model
- [llama_index](https://github.com/jerryjliu/llama_index) for enhancing database-related knowledge using [in-context learning](https://arxiv.org/abs/2301.00234) based on existing knowledge bases.
## Contribution
- Please run `black .` before submitting the code.
<!-- GITCONTRIBUTOR_START -->
## Contributors
@@ -206,7 +227,7 @@ The achievements of this project are thanks to the technical community, especial
| :---: | :---: | :---: | :---: |:---: |
This project follows the git-contributor [spec](https://github.com/xudafeng/git-contributor), auto updated at `Sun May 14 2023 23:02:43 GMT+0800`.
This project follows the git-contributor [spec](https://github.com/xudafeng/git-contributor), auto updated at `Fri May 19 2023 00:24:18 GMT+0800`.
<!-- GITCONTRIBUTOR_END -->


@@ -26,6 +26,10 @@ DB-GPT is an open-source, database-centric GPT experiment project that uses local
- Unified vector storage/indexing of the knowledge base
- Support for unstructured data such as PDF, Markdown, CSV, and WebURL
- Multi LLMs Support
- Supports multiple large language models; currently Vicuna (7b, 13b) and ChatGLM-6b (int4, int8) are supported
- TODO: codet5p, codegen2
## Demo
The demo runs on an RTX 4090 GPU. [YouTube link](https://www.youtube.com/watch?v=1PWI6F89LPo)
@@ -110,6 +114,8 @@ DB-GPT builds its large-model serving on [FastChat](https://github.com/lm-sys/FastChat)
Users only need to organize their knowledge documents, and they can use our existing capabilities to build the knowledge base required by the large model.
### LLMs Management
In the underlying LLM integration, an open interface is designed to support multiple large models. At the same time, the effectiveness of integrated models is subject to strict control and review: in terms of accuracy, an integrated model must reach at least 85% capability alignment with ChatGPT. We use these higher standards to select models, hoping to spare users the cumbersome testing and evaluation work.
@@ -151,15 +157,6 @@ conda create -n dbgpt_env python=3.10
conda activate dbgpt_env
pip install -r requirements.txt
```
Alternatively, you can use the following command:
```
cd DB-GPT
conda env create -f environment.yml
```
You also need to set the Python package path to avoid packages not being found at runtime:
```
echo "/root/workspace/DB-GPT" > /root/miniconda3/env/dbgpt_env/lib/python3.10/site-packages/dbgpt.pth
```
### 3. Run the Model
@@ -187,6 +184,26 @@ $ python webserver.py
2. [LLM Practical In Action Series (2): DB-GPT Deployment Guide on Alibaba Cloud](https://zhuanlan.zhihu.com/p/629467580)
3. [LLM Practical In Action Series (3): DB-GPT Plugin Model Principles and Usage](https://zhuanlan.zhihu.com/p/629623125)
### Multi LLMs Usage
Modify the LLM_MODEL parameter in the .env configuration file to switch between the models.
#### Create your own knowledge repository:
1. Place personal knowledge files or folders in the pilot/datasets directory.
2. Run the knowledge ingestion script in the tools directory.
```
python tools/knowledge_init.py
    --vector_name : your vector store name (default: default)
    --append : append mode; True = append, False = do not append (default: False)
```
3. Add the knowledge repository in the interface by entering its name (enter "default" if none was specified); you can then ask questions against your own knowledge base.
Note: the default vector model is text2vec-large-chinese (a large model; if your personal computer's configuration is not sufficient, text2vec-base-chinese is recommended), so make sure to download the model and place it in the models directory.
## Acknowledgement
The achievements of this project are thanks to the technical community, especially the following projects:
@@ -201,14 +218,20 @@ $ python webserver.py
- [ChatGLM](https://github.com/THUDM/ChatGLM-6B) as the base model
- [llama-index](https://github.com/jerryjliu/llama_index) for enhancing database-related knowledge via [In-Context Learning](https://arxiv.org/abs/2301.00234) on existing knowledge bases.
# Contribution
- Please run `black .` before submitting code.
<!-- GITCONTRIBUTOR_START -->
## 贡献者
## Contributors
|[<img src="https://avatars.githubusercontent.com/u/17919400?v=4" width="100px;"/><br/><sub><b>csunny</b></sub>](https://github.com/csunny)<br/>|[<img src="https://avatars.githubusercontent.com/u/1011681?v=4" width="100px;"/><br/><sub><b>xudafeng</b></sub>](https://github.com/xudafeng)<br/>|[<img src="https://avatars.githubusercontent.com/u/7636723?s=96&v=4" width="100px;"/><br/><sub><b>明天</b></sub>](https://github.com/yhjun1026)<br/> | [<img src="https://avatars.githubusercontent.com/u/13723926?v=4" width="100px;"/><br/><sub><b>Aries-ckt</b></sub>](https://github.com/Aries-ckt)<br/>|[<img src="https://avatars.githubusercontent.com/u/95130644?v=4" width="100px;"/><br/><sub><b>thebigbone</b></sub>](https://github.com/thebigbone)<br/>|
| :---: | :---: | :---: | :---: |:---: |
[git-contributor spec](https://github.com/xudafeng/git-contributor), auto-generated at `Fri May 19 2023 00:24:18 GMT+0800`.
This project follows the git-contributor [spec](https://github.com/xudafeng/git-contributor), auto updated at `Sun May 14 2023 23:02:43 GMT+0800`.
<!-- GITCONTRIBUTOR_END -->

(Binary image asset updated; size changed from 257 KiB to 157 KiB.)


@@ -1,68 +0,0 @@
name: db_pgt
channels:
- pytorch
- defaults
- anaconda
dependencies:
- python=3.10
- cudatoolkit
- pip
- pytorch-mutex=1.0=cuda
- pip:
- pytorch
- accelerate==0.16.0
- aiohttp==3.8.4
- aiosignal==1.3.1
- async-timeout==4.0.2
- attrs==22.2.0
- bitsandbytes==0.37.0
- cchardet==2.1.7
- chardet==5.1.0
- contourpy==1.0.7
- cycler==0.11.0
- filelock==3.9.0
- fonttools==4.38.0
- frozenlist==1.3.3
- huggingface-hub==0.13.4
- importlib-resources==5.12.0
- kiwisolver==1.4.4
- matplotlib==3.7.0
- multidict==6.0.4
- packaging==23.0
- psutil==5.9.4
- pycocotools==2.0.6
- pyparsing==3.0.9
- python-dateutil==2.8.2
- pyyaml==6.0
- regex==2022.10.31
- tokenizers==0.13.2
- tqdm==4.64.1
- transformers==4.28.0
- timm==0.6.13
- spacy==3.5.1
- webdataset==0.2.48
- scikit-learn==1.2.2
- scipy==1.10.1
- yarl==1.8.2
- zipp==3.14.0
- omegaconf==2.3.0
- opencv-python==4.7.0.72
- iopath==0.1.10
- tenacity==8.2.2
- peft
- pycocoevalcap
- sentence-transformers
- umap-learn
- notebook
- gradio==3.23
- gradio-client==0.0.8
- wandb
- llama-index==0.5.27
- pymysql
- unstructured==0.6.3
- pytesseract==0.3.10
- markdown2
- chromadb
- colorama
- playsound
- distro


@@ -2,24 +2,28 @@
# -*- coding:utf-8 -*-
import gradio as gr
from langchain.agents import (
load_tools,
initialize_agent,
AgentType
)
from pilot.model.vicuna_llm import VicunaRequestLLM, VicunaEmbeddingLLM
from llama_index import LLMPredictor, LangchainEmbedding, ServiceContext
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import Document, GPTSimpleVectorIndex
from llama_index import (
Document,
GPTSimpleVectorIndex,
LangchainEmbedding,
LLMPredictor,
ServiceContext,
)
from pilot.model.vicuna_llm import VicunaEmbeddingLLM, VicunaRequestLLM
def agent_demo():
llm = VicunaRequestLLM()
tools = load_tools(['python_repl'], llm=llm)
agent = initialize_agent(tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True)
agent.run(
"Write a SQL script that Query 'select count(1)!'"
tools = load_tools(["python_repl"], llm=llm)
agent = initialize_agent(
tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
agent.run("Write a SQL script that Query 'select count(1)!'")
def knowledged_qa_demo(text_list):
llm_predictor = LLMPredictor(llm=VicunaRequestLLM())
@@ -27,27 +31,34 @@ def knowledged_qa_demo(text_list):
embed_model = LangchainEmbedding(hfemb)
documents = [Document(t) for t in text_list]
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, embed_model=embed_model)
index = GPTSimpleVectorIndex.from_documents(documents, service_context=service_context)
service_context = ServiceContext.from_defaults(
llm_predictor=llm_predictor, embed_model=embed_model
)
index = GPTSimpleVectorIndex.from_documents(
documents, service_context=service_context
)
return index
def get_answer(q):
base_knowledge = """ """
base_knowledge = """ """
text_list = [base_knowledge]
index = knowledged_qa_demo(text_list)
response = index.query(q)
return response.response
def get_similar(q):
from pilot.vector_store.extract_tovec import knownledge_tovec, knownledge_tovec_st
docsearch = knownledge_tovec_st("./datasets/plan.md")
docs = docsearch.similarity_search_with_score(q, k=1)
for doc in docs:
dc, s = doc
dc, s = doc
print(s)
yield dc.page_content
yield dc.page_content
if __name__ == "__main__":
# agent_demo()
@@ -58,8 +69,7 @@ if __name__ == "__main__":
text_input = gr.TextArea()
text_output = gr.TextArea()
text_button = gr.Button()
text_button.click(get_similar, inputs=text_input, outputs=text_output)
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")


@@ -1,61 +1,73 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import requests
import json
import time
import uuid
import os
import sys
from urllib.parse import urljoin
import gradio as gr
from pilot.configs.config import Config
from pilot.conversation import conv_qa_prompt_template, conv_templates
import requests
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(ROOT_PATH)
from langchain.prompts import PromptTemplate
from pilot.configs.config import Config
from pilot.conversation import conv_qa_prompt_template, conv_templates
vicuna_stream_path = "generate_stream"
llmstream_stream_path = "generate_stream"
CFG = Config()
def generate(query):
def generate(query):
template_name = "conv_one_shot"
state = conv_templates[template_name].copy()
pt = PromptTemplate(
template=conv_qa_prompt_template,
input_variables=["context", "question"]
)
# pt = PromptTemplate(
# template=conv_qa_prompt_template,
# input_variables=["context", "question"]
# )
result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
question=query)
# result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
# question=query)
print(result)
# print(result)
state.append_message(state.roles[0], result)
state.append_message(state.roles[0], query)
state.append_message(state.roles[1], None)
prompt = state.get_prompt()
params = {
"model": "vicuna-13b",
"model": "chatglm-6b",
"prompt": prompt,
"temperature": 0.7,
"temperature": 1.0,
"max_new_tokens": 1024,
"stop": "###"
"stop": "###",
}
response = requests.post(
url=urljoin(CFG.MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
url=urljoin(CFG.MODEL_SERVER, llmstream_stream_path), data=json.dumps(params)
)
skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
if data["error_code"] == 0:
output = data["text"][skip_echo_len:].strip()
if "vicuna" in CFG.LLM_MODEL:
output = data["text"][skip_echo_len:].strip()
else:
output = data["text"].strip()
state.messages[-1][-1] = output + ""
yield(output)
yield (output)
if __name__ == "__main__":
print(CFG.LLM_MODEL)
with gr.Blocks() as demo:
@@ -64,10 +76,7 @@ if __name__ == "__main__":
text_input = gr.TextArea()
text_output = gr.TextArea()
text_button = gr.Button("提交")
text_button.click(generate, inputs=text_input, outputs=text_output)
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")


@@ -1,19 +1,19 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import logging
import sys
from llama_index import SimpleDirectoryReader, GPTSimpleVectorIndex
from llama_index import GPTSimpleVectorIndex, SimpleDirectoryReader
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
# read the document of data dir
documents = SimpleDirectoryReader("data").load_data()
# split the document to chunk, max token size=500, convert chunk to vector
# split the document to chunk, max token size=500, convert chunk to vector
index = GPTSimpleVectorIndex(documents)
# save index
index.save_to_disk("index.json")
index.save_to_disk("index.json")


@@ -3,17 +3,19 @@
import gradio as gr
def change_tab():
return gr.Tabs.update(selected=1)
with gr.Blocks() as demo:
with gr.Tabs() as tabs:
with gr.TabItem("Train", id=0):
t = gr.Textbox()
with gr.TabItem("Inference", id=1):
i = gr.Image()
btn = gr.Button()
btn.click(change_tab, None, tabs)
demo.launch()
demo.launch()


@@ -1,5 +1,3 @@
from pilot.source_embedding.csv_embedding import CSVEmbedding
# path = "/Users/chenketing/Downloads/share_ireserve双写数据异常2.xlsx"
@@ -8,6 +6,13 @@ model_name = "your_path/all-MiniLM-L6-v2"
vector_store_path = "your_path/"
pdf_embedding = CSVEmbedding(file_path=path, model_name=model_name, vector_store_config={"vector_store_name": "url", "vector_store_path": "vector_store_path"})
pdf_embedding = CSVEmbedding(
file_path=path,
model_name=model_name,
vector_store_config={
"vector_store_name": "url",
"vector_store_path": "vector_store_path",
},
)
pdf_embedding.source_embedding()
print("success")
print("success")


@@ -6,6 +6,13 @@ model_name = "your_path/all-MiniLM-L6-v2"
vector_store_path = "your_path/"
pdf_embedding = PDFEmbedding(file_path=path, model_name=model_name, vector_store_config={"vector_store_name": "ob-pdf", "vector_store_path": vector_store_path})
pdf_embedding = PDFEmbedding(
file_path=path,
model_name=model_name,
vector_store_config={
"vector_store_name": "ob-pdf",
"vector_store_path": vector_store_path,
},
)
pdf_embedding.source_embedding()
print("success")
print("success")


@@ -5,6 +5,13 @@ model_name = "your_path/all-MiniLM-L6-v2"
vector_store_path = "your_path"
pdf_embedding = URLEmbedding(file_path=path, model_name=model_name, vector_store_config={"vector_store_name": "url", "vector_store_path": "vector_store_path"})
pdf_embedding = URLEmbedding(
file_path=path,
model_name=model_name,
vector_store_config={
"vector_store_name": "url",
"vector_store_path": "vector_store_path",
},
)
pdf_embedding.source_embedding()
print("success")
print("success")


@@ -1,19 +1,28 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from llama_index import SimpleDirectoryReader, LangchainEmbedding, GPTListIndex, GPTSimpleVectorIndex, PromptHelper
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import LLMPredictor
import torch
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.llms.base import LLM
from llama_index import (
GPTListIndex,
GPTSimpleVectorIndex,
LangchainEmbedding,
LLMPredictor,
PromptHelper,
SimpleDirectoryReader,
)
from transformers import pipeline
class FlanLLM(LLM):
model_name = "google/flan-t5-large"
pipeline = pipeline("text2text-generation", model=model_name, device=0, model_kwargs={
"torch_dtype": torch.bfloat16
})
pipeline = pipeline(
"text2text-generation",
model=model_name,
device=0,
model_kwargs={"torch_dtype": torch.bfloat16},
)
def _call(self, prompt, stop=None):
return self.pipeline(prompt, max_length=9999)[0]["generated_text"]
@@ -24,6 +33,7 @@ class FlanLLM(LLM):
def _llm_type(self):
return "custome"
llm_predictor = LLMPredictor(llm=FlanLLM())
hfemb = HuggingFaceEmbeddings()
embed_model = LangchainEmbedding(hfemb)
@@ -214,9 +224,10 @@ OceanBase 数据库 EXPLAIN 命令输出的第一部分是执行计划的树形
回答: nlj也是左表的表是驱动表这个要了解下计划执行方面的基本原理取左表的一行数据再遍历右表一旦满足连接条件就可以返回数据
anti/semi只是因为not exists/exist的语义只是返回左表数据改成anti join是一种计划优化连接的方式比子查询更优
"""
"""
from llama_index import Document
text_list = [text1]
documents = [Document(t) for t in text_list]
@@ -226,12 +237,18 @@ max_input_size = 512
max_chunk_overlap = 20
prompt_helper = PromptHelper(max_input_size, num_output, max_chunk_overlap)
index = GPTListIndex(documents, embed_model=embed_model, llm_predictor=llm_predictor, prompt_helper=prompt_helper)
index = GPTListIndex(
documents,
embed_model=embed_model,
llm_predictor=llm_predictor,
prompt_helper=prompt_helper,
)
index.save_to_disk("index.json")
if __name__ == "__main__":
import logging
logging.getLogger().setLevel(logging.CRITICAL)
for d in documents:
print(d)


@@ -1,7 +1,3 @@
from pilot.source_embedding import (SourceEmbedding, register)
from pilot.source_embedding import SourceEmbedding, register
__all__ = [
"SourceEmbedding",
"register"
]
__all__ = ["SourceEmbedding", "register"]


@@ -3,10 +3,10 @@
class Agent:
"""Agent class for interacting with DB-GPT
Attributes:
"""Agent class for interacting with DB-GPT
Attributes:
"""
def __init__(self) -> None:
pass
pass


@@ -4,10 +4,8 @@
from __future__ import annotations
from pilot.configs.config import Config
from pilot.singleton import Singleton
from pilot.configs.config import Config
from typing import List
from pilot.model.base import Message
from pilot.singleton import Singleton
class AgentManager(metaclass=Singleton):
@@ -17,6 +15,7 @@ class AgentManager(metaclass=Singleton):
self.next_key = 0
self.agents = {} # key, (task, full_message_history, model)
self.cfg = Config()
"""Agent manager for managing DB-GPT agents
In order to compatible auto gpt plugins,
we use the same template with it.
@@ -28,7 +27,7 @@ class AgentManager(metaclass=Singleton):
def __init__(self) -> None:
self.next_key = 0
self.agents = {} #TODO need to define
self.agents = {} # TODO need to define
self.cfg = Config()
# Create new GPT agent
@@ -46,7 +45,6 @@ class AgentManager(metaclass=Singleton):
The key of the new agent
"""
def message_agent(self, key: str | int, message: str) -> str:
"""Send a message to an agent and return its response
@@ -58,7 +56,6 @@ class AgentManager(metaclass=Singleton):
The agent's response
"""
def list_agents(self) -> list[tuple[str | int, str]]:
"""Return a list of all agents


@@ -1,18 +1,22 @@
import contextlib
import json
from typing import Any, Dict
import contextlib
from colorama import Fore
from regex import regex
from pilot.configs.config import Config
from pilot.json_utils.json_fix_general import (
add_quotes_to_property_names,
balance_braces,
fix_invalid_escape,
)
from pilot.logs import logger
from pilot.speech import say_text
from pilot.json_utils.json_fix_general import fix_invalid_escape,add_quotes_to_property_names,balance_braces
CFG = Config()
def fix_and_parse_json(
json_to_load: str, try_to_fix_with_gpt: bool = True
) -> Dict[Any, Any]:
@@ -48,7 +52,7 @@ def fix_and_parse_json(
maybe_fixed_json = maybe_fixed_json[: last_brace_index + 1]
return json.loads(maybe_fixed_json)
except (json.JSONDecodeError, ValueError) as e:
logger.error("参数解析错误", e)
logger.error("参数解析错误", e)
def fix_json_using_multiple_techniques(assistant_reply: str) -> Dict[Any, Any]:


@@ -1,2 +1,2 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# -*- coding:utf-8 -*-


@@ -1,2 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-


@@ -1,15 +1,14 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pilot.prompts.generator import PromptGenerator
from typing import Dict, List, NoReturn, Union
from pilot.configs.config import Config
from pilot.speech import say_text
import json
from typing import Dict
from pilot.agent.json_fix_llm import fix_json_using_multiple_techniques
from pilot.commands.exception_not_commands import NotCommands
import json
from pilot.configs.config import Config
from pilot.prompts.generator import PromptGenerator
from pilot.speech import say_text
def _resolve_pathlike_command_args(command_args):
@@ -25,9 +24,9 @@ def _resolve_pathlike_command_args(command_args):
def execute_ai_response_json(
prompt: PromptGenerator,
ai_response: str,
user_input: str = None,
prompt: PromptGenerator,
ai_response: str,
user_input: str = None,
) -> str:
"""
@@ -52,18 +51,14 @@ def execute_ai_response_json(
arguments = _resolve_pathlike_command_args(arguments)
# Execute command
if command_name is not None and command_name.lower().startswith("error"):
result = (
f"Command {command_name} threw the following error: {arguments}"
)
result = f"Command {command_name} threw the following error: {arguments}"
elif command_name == "human_feedback":
result = f"Human feedback: {user_input}"
else:
for plugin in cfg.plugins:
if not plugin.can_handle_pre_command():
continue
command_name, arguments = plugin.pre_command(
command_name, arguments
)
command_name, arguments = plugin.pre_command(command_name, arguments)
command_result = execute_command(
command_name,
arguments,
@@ -74,9 +69,9 @@ def execute_ai_response_json(
def execute_command(
command_name: str,
arguments,
prompt: PromptGenerator,
command_name: str,
arguments,
prompt: PromptGenerator,
):
"""Execute the command and return the result
@@ -102,13 +97,15 @@ def execute_command(
else:
for command in prompt.commands:
if (
command_name == command["label"].lower()
or command_name == command["name"].lower()
command_name == command["label"].lower()
or command_name == command["name"].lower()
):
try:
# Remove arguments that are not defined for this command
diff_ags = list(set(arguments.keys()).difference(set(command['args'].keys())))
for arg_name in diff_ags:
diff_ags = list(
set(arguments.keys()).difference(set(command["args"].keys()))
)
for arg_name in diff_ags:
del arguments[arg_name]
print(str(arguments))
return command["function"](**arguments)


@@ -1,19 +1,21 @@
from typing import Optional
from pilot.configs.config import Config
from pilot.prompts.generator import PromptGenerator
from typing import Any, Optional, Type
from pilot.prompts.prompt import build_default_prompt_generator
class CommandsLoad:
"""
Load Plugins Commands Info , help build system prompt!
Load Plugins Commands Info , help build system prompt!
"""
def __init__(self)->None:
def __init__(self) -> None:
self.command_registry = None
def getCommandInfos(self, prompt_generator: Optional[PromptGenerator] = None)-> str:
def getCommandInfos(
self, prompt_generator: Optional[PromptGenerator] = None
) -> str:
cfg = Config()
if prompt_generator is None:
prompt_generator = build_default_prompt_generator()
@@ -24,4 +26,4 @@ class CommandsLoad:
self.prompt_generator = prompt_generator
command_infos = ""
command_infos += f"\n\n{prompt_generator.commands()}"
return command_infos
return command_infos


@@ -1,5 +1,4 @@
class NotCommands(Exception):
def __init__(self, message):
super().__init__(message)
self.message = message
self.message = message


@@ -25,7 +25,7 @@ def generate_image(prompt: str, size: int = 256) -> str:
str: The filename of the image
"""
filename = f"{CFG.workspace_path}/{str(uuid.uuid4())}.jpg"
# HuggingFace
if CFG.image_provider == "huggingface":
return generate_image_with_hf(prompt, filename)
@@ -72,6 +72,7 @@ def generate_image_with_hf(prompt: str, filename: str) -> str:
return f"Saved to disk:{filename}"
def generate_image_with_sd_webui(
prompt: str,
filename: str,

pilot/configs/__init__.py (new file, 14 lines)

@@ -0,0 +1,14 @@
import os
import random
import sys
from dotenv import load_dotenv
if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
print("Setting random seed to 42")
random.seed(42)
# Load the users .env file into environment variables
load_dotenv(verbose=True, override=True)
del load_dotenv
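In practice this means that simply importing the package makes the `.env` values visible to the rest of the code; a small sketch:

```
import os

import pilot.configs  # importing the package runs load_dotenv(verbose=True, override=True)

# After the import, values from the repository's .env file are available via os.getenv.
print(os.getenv("LLM_MODEL", "vicuna-13b"))
```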


@@ -7,13 +7,13 @@ from __future__ import annotations
import os
import platform
from pathlib import Path
from typing import Any, Optional, Type
from typing import Optional
import distro
import yaml
from pilot.prompts.generator import PromptGenerator
from pilot.configs.config import Config
from pilot.prompts.generator import PromptGenerator
from pilot.prompts.prompt import build_default_prompt_generator
# Soon this will go in a folder where it remembers more stuff about the run(s)
@@ -88,7 +88,7 @@ class AIConfig:
for goal in config_params.get("ai_goals", [])
]
api_budget = config_params.get("api_budget", 0.0)
# type: Type[AIConfig]
# type is Type[AIConfig]
return AIConfig(ai_name, ai_role, ai_goals, api_budget)
def save(self, config_file: str = SAVE_FILE) -> None:
@@ -133,8 +133,6 @@ class AIConfig:
""
)
cfg = Config()
if prompt_generator is None:
prompt_generator = build_default_prompt_generator()


@@ -2,16 +2,18 @@
# -*- coding: utf-8 -*-
import os
import nltk
from typing import List
import nltk
from auto_gpt_plugin_template import AutoGPTPluginTemplate
from pilot.singleton import Singleton
from pilot.common.sql_database import Database
class Config(metaclass=Singleton):
"""Configuration class to store the state of bools for different scripts access"""
def __init__(self) -> None:
"""Initialize the Config class"""
@@ -19,7 +21,6 @@ class Config(metaclass=Singleton):
self.skip_reprompt = False
self.temperature = float(os.getenv("TEMPERATURE", 0.7))
self.execute_local_commands = (
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
)
@@ -96,30 +97,39 @@ class Config(metaclass=Singleton):
else:
self.plugins_denylist = []
### Local database connection configuration
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
### TODO Adapt to multiple types of libraries
self.local_db = Database.from_uri("mysql+pymysql://" + self.LOCAL_DB_USER +":"+ self.LOCAL_DB_PASSWORD +"@" +self.LOCAL_DB_HOST + ":" + str(self.LOCAL_DB_PORT) ,
engine_args ={"pool_size": 10, "pool_recycle": 3600, "echo": True})
### LLM Model Service Configuration
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
self.MODEL_SERVER = os.getenv("MODEL_SERVER", "http://121.41.167.183:8000")
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
self.MODEL_PORT = os.getenv("MODEL_PORT", 8000)
self.MODEL_SERVER = os.getenv(
"MODEL_SERVER", "http://127.0.0.1" + ":" + str(self.MODEL_PORT)
)
self.ISLOAD_8BIT = os.getenv("ISLOAD_8BIT", "True") == "True"
### Vector Store Configuration
self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")
self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1")
self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")
self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
def set_debug_mode(self, value: bool) -> None:
"""Set the debug mode value"""
self.debug_mode = value
def set_plugins(self, value: list) -> None:
"""Set the plugins value. """
"""Set the plugins value."""
self.plugins = value
def set_templature(self, value: int) -> None:
@@ -132,4 +142,4 @@ class Config(metaclass=Singleton):
def set_last_plugin_return(self, value: bool) -> None:
"""Set the speak mode value."""
self.last_plugin_return = value
self.last_plugin_return = value
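Since `Config` uses the `Singleton` metaclass, every call returns the same shared object; a minimal sketch of what that implies for callers:

```
from pilot.configs.config import Config

a = Config()
b = Config()
assert a is b  # Singleton metaclass: one shared configuration object per process

a.set_debug_mode(True)
print(b.debug_mode)  # True -- state set through one reference is visible everywhere
```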


@@ -1,10 +1,10 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import torch
import os
import nltk
import nltk
import torch
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
MODEL_PATH = os.path.join(ROOT_PATH, "models")
@@ -16,24 +16,43 @@ DATA_DIR = os.path.join(PILOT_PATH, "data")
nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
LLM_MODEL_CONFIG = {
"flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
"vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
"vicuna-7b": os.path.join(MODEL_PATH, "vicuna-7b"),
"text2vec": os.path.join(MODEL_PATH, "text2vec-large-chinese"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"codegen2-1b": os.path.join(MODEL_PATH, "codegen2-1B"),
"codet5p-2b": os.path.join(MODEL_PATH, "codet5p-2b"),
"chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
"chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
"text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
}
VECTOR_SEARCH_TOP_K = 20
LLM_MODEL = "vicuna-13b"
LIMIT_MODEL_CONCURRENCY = 5
MAX_POSITION_EMBEDDINGS = 4096
# VICUNA_MODEL_SERVER = "http://121.41.227.141:8000"
VICUNA_MODEL_SERVER = "http://120.79.27.110:8000"
# Load model config
ISLOAD_8BIT = True
ISDEBUG = False
VECTOR_SEARCH_TOP_K = 3
# LLM_MODEL = "vicuna-13b"
# LIMIT_MODEL_CONCURRENCY = 5
# MAX_POSITION_EMBEDDINGS = 4096
# VICUNA_MODEL_SERVER = "http://121.41.167.183:8000"
VECTOR_SEARCH_TOP_K = 10
VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge")
KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "data"
)
KNOWLEDGE_CHUNK_SPLIT_SIZE = 100


@@ -3,6 +3,6 @@
"""We need to design a base class. That other connector can Write with this"""
class BaseConnection:
pass


@@ -4,4 +4,5 @@
class ClickHouseConnector:
"""ClickHouseConnector"""
pass
pass


@@ -4,4 +4,5 @@
class ElasticSearchConnector:
"""ElasticSearchConnector"""
pass
pass


@@ -1,6 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
class MongoConnector:
"""MongoConnector is a class which connect to mongo and chat with LLM"""
pass
pass


@@ -4,26 +4,25 @@
import pymysql
class MySQLOperator:
"""Connect MySQL Database fetch MetaData For LLM Prompt
Args:
"""Connect MySQL Database fetch MetaData For LLM Prompt
Args:
Usage:
Usage:
"""
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
def __init__(self, user, password, host="localhost", port=3306) -> None:
self.conn = pymysql.connect(
host=host,
user=user,
port=port,
passwd=password,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor
cursorclass=pymysql.cursors.DictCursor,
)
def get_schema(self, schema_name):
with self.conn.cursor() as cursor:
_sql = f"""
select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{schema_name}" group by TABLE_NAME;
@@ -32,6 +31,7 @@ class MySQLOperator:
results = cursor.fetchall()
return results
def run_sql(self, db_name:str, sql:str, fetch: str = "all"):
with self.conn.cursor() as cursor:
cursor.execute("USE " + db_name)
@@ -55,10 +55,10 @@ class MySQLOperator:
cursor.execute(_sql)
results = cursor.fetchall()
dbs = [d["Database"] for d in results if d["Database"] not in self.default_db]
dbs = [
d["Database"] for d in results if d["Database"] not in self.default_db
]
return dbs
def get_meta(self, schema_name):
pass
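A hedged usage sketch of the connector above; the credentials mirror the LOCAL_DB_* defaults added to pilot/configs/config.py and should be replaced with your own local MySQL settings.

```
from pilot.connections.mysql import MySQLOperator

mo = MySQLOperator(user="root", password="aa123456", host="127.0.0.1", port=3306)

# One "table(col1,col2,...)" row per table -- the schema metadata fed into the LLM prompt.
print(mo.get_schema("mysql"))

# run_sql switches to the given database and returns the fetched rows.
print(mo.run_sql("mysql", "SELECT count(1) AS n FROM user"))
```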


@@ -1,6 +1,8 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
class OracleConnector:
"""OracleConnector"""
pass
pass


@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
class PostgresConnector:
"""PostgresConnector is a class which Connector to chat with LLM"""
pass
pass


@@ -4,4 +4,5 @@
class RedisConnector:
"""RedisConnector"""
pass
pass


@@ -5,26 +5,32 @@ import dataclasses
import uuid
from enum import auto, Enum
from typing import List, Any
from pilot.configs.config import Config
CFG = Config()
DB_SETTINGS = {
"user": CFG.LOCAL_DB_USER,
"password": CFG.LOCAL_DB_PASSWORD,
"password": CFG.LOCAL_DB_PASSWORD,
"host": CFG.LOCAL_DB_HOST,
"port": CFG.LOCAL_DB_PORT
"port": CFG.LOCAL_DB_PORT,
}
ROLE_USER = "USER"
ROLE_ASSISTANT = "Assistant"
class SeparatorStyle(Enum):
SINGLE = auto()
TWO = auto()
THREE = auto()
FOUR = auto()
@ dataclasses.dataclass
@dataclasses.dataclass
class Conversation:
"""This class keeps all conversation history. """
"""This class keeps all conversation history."""
system: str
roles: List[str]
@@ -65,7 +71,7 @@ class Conversation:
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append([msg, None])
else:
@@ -93,15 +99,14 @@ class Conversation:
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
"conv_id": self.conv_id
"conv_id": self.conv_id,
}
def gen_sqlgen_conversation(dbname):
from pilot.connections.mysql import MySQLOperator
mo = MySQLOperator(
**(DB_SETTINGS)
)
mo = MySQLOperator(**(DB_SETTINGS))
message = ""
@@ -113,7 +118,7 @@ def gen_sqlgen_conversation(dbname):
conv_one_shot = Conversation(
system="A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. "
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
roles=("USER", "Assistant"),
messages=(
(
@@ -134,20 +139,19 @@ conv_one_shot = Conversation(
"whereas PostgreSQL is known for its robustness and reliability.\n"
"5. Licensing: MySQL is licensed under the GPL (General Public License), which means that it is free and open-source software, "
"whereas PostgreSQL is licensed under the PostgreSQL License, which is also free and open-source but with different terms.\n"
"Ultimately, the choice between MySQL and PostgreSQL depends on the specific needs and requirements of your application. "
"Both are excellent database management systems, and choosing the right one "
"for your project requires careful consideration of your application's requirements, performance needs, and scalability."
"for your project requires careful consideration of your application's requirements, performance needs, and scalability.",
),
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###"
sep="###",
)
conv_vicuna_v1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. "
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
roles=("USER", "ASSISTANT"),
messages=(),
offset=0,
@@ -158,7 +162,7 @@ conv_vicuna_v1 = Conversation(
auto_dbgpt_one_shot = Conversation(
system="You are DB-GPT, an AI designed to answer questions about HackerNews by query `hackerbews` database in MySQL. "
"Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
"Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
roles=("USER", "ASSISTANT"),
messages=(
(
@@ -201,7 +205,7 @@ auto_dbgpt_one_shot = Conversation(
}
}
}
"""
""",
),
(
"ASSISTANT",
@@ -221,8 +225,8 @@ auto_dbgpt_one_shot = Conversation(
}
}
}
"""
)
""",
),
),
offset=0,
sep_style=SeparatorStyle.SINGLE,
@@ -231,7 +235,7 @@ auto_dbgpt_one_shot = Conversation(
auto_dbgpt_without_shot = Conversation(
system="You are DB-GPT, an AI designed to answer questions about users by query `users` database in MySQL. "
"Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
"Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
roles=("USER", "ASSISTANT"),
messages=(),
offset=0,
@@ -248,11 +252,18 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、简要的回
{question}
"""
# conv_qa_prompt_template = """ Please provide the known information so that I can professionally and briefly answer the user's question. If the answer cannot be obtained from the provided content,
# please say: "The information provided in the knowledge base is insufficient to answer this question." Fabrication is prohibited.。
# known information:
# {context}
# question:
# {question}
# """
default_conversation = conv_one_shot
conversation_sql_mode ={
conversation_sql_mode = {
"auto_execute_ai_response": "直接执行结果",
"dont_execute_ai_response": "不直接执行结果"
"dont_execute_ai_response": "不直接执行结果",
}
conversation_types = {
@@ -265,7 +276,7 @@ conversation_types = {
conv_templates = {
"conv_one_shot": conv_one_shot,
"vicuna_v1": conv_vicuna_v1,
"auto_dbgpt_one_shot": auto_dbgpt_one_shot
"auto_dbgpt_one_shot": auto_dbgpt_one_shot,
}
if __name__ == "__main__":
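For context, this is roughly how these templates are consumed elsewhere in this commit (for example in the gradio demo and pilot/server/webserver.py): copy a template, append the user turn, leave the assistant slot open, and send `get_prompt()` to the model server.

```
from pilot.conversation import conv_templates

conv = conv_templates["conv_one_shot"].copy()
conv.append_message(conv.roles[0], "Which database fits analytics better, MySQL or PostgreSQL?")
conv.append_message(conv.roles[1], None)

# get_prompt() joins the system text and messages with the "###" separator,
# ending with an empty Assistant turn for the model to complete.
prompt = conv.get_prompt()
```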


@@ -8,8 +8,8 @@ import re
from typing import Optional
from pilot.configs.config import Config
from pilot.logs import logger
from pilot.json_utils.utilities import extract_char_position
from pilot.logs import logger
CFG = Config()


@@ -84,7 +84,7 @@ class Logger(metaclass=Singleton):
self.chat_plugins = []
def typewriter_log(
self, title="", title_color="", content="", speak_text=False, level=logging.INFO
self, title="", title_color="", content="", speak_text=False, level=logging.INFO
):
if speak_text and self.speak_mode:
say_text(f"{title}. {content}")
@@ -103,26 +103,26 @@ class Logger(metaclass=Singleton):
)
def debug(
self,
message,
title="",
title_color="",
self,
message,
title="",
title_color="",
):
self._log(title, title_color, message, logging.DEBUG)
def info(
self,
message,
title="",
title_color="",
self,
message,
title="",
title_color="",
):
self._log(title, title_color, message, logging.INFO)
def warn(
self,
message,
title="",
title_color="",
self,
message,
title="",
title_color="",
):
self._log(title, title_color, message, logging.WARN)
@@ -130,11 +130,11 @@ class Logger(metaclass=Singleton):
self._log(title, Fore.RED, message, logging.ERROR)
def _log(
self,
title: str = "",
title_color: str = "",
message: str = "",
level=logging.INFO,
self,
title: str = "",
title_color: str = "",
message: str = "",
level=logging.INFO,
):
if message:
if isinstance(message, list):
@@ -178,10 +178,12 @@ class Logger(metaclass=Singleton):
log_dir = os.path.join(this_files_dir_path, "../logs")
return os.path.abspath(log_dir)
"""
Output stream to console using simulated typing
"""
class TypingConsoleHandler(logging.StreamHandler):
def emit(self, record):
min_typing_speed = 0.05
@@ -203,6 +205,7 @@ class TypingConsoleHandler(logging.StreamHandler):
except Exception:
self.handleError(record)
class ConsoleHandler(logging.StreamHandler):
def emit(self, record) -> None:
msg = self.format(record)
@@ -221,10 +224,10 @@ class DbGptFormatter(logging.Formatter):
def format(self, record: LogRecord) -> str:
if hasattr(record, "color"):
record.title_color = (
getattr(record, "color")
+ getattr(record, "title", "")
+ " "
+ Style.RESET_ALL
getattr(record, "color")
+ getattr(record, "title", "")
+ " "
+ Style.RESET_ALL
)
else:
record.title_color = getattr(record, "title", "")
@@ -248,9 +251,9 @@ logger = Logger()
def print_assistant_thoughts(
ai_name: object,
assistant_reply_json_valid: object,
speak_mode: bool = False,
ai_name: object,
assistant_reply_json_valid: object,
speak_mode: bool = False,
) -> None:
assistant_thoughts_reasoning = None
assistant_thoughts_plan = None


@@ -1,17 +1,16 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List
from functools import cache
from typing import List
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
from pilot.configs.model_config import DEVICE
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoModel
)
class BaseLLMAdaper:
"""The Base class for multi model, in our project.
We will support those model, which performance resemble ChatGPT """
We will support those model, which performance resemble ChatGPT"""
def match(self, model_path: str):
return True
@@ -24,7 +23,8 @@ class BaseLLMAdaper:
return model, tokenizer
llm_model_adapters = List[BaseLLMAdaper] = []
llm_model_adapters: List[BaseLLMAdaper] = []
# Register llm models to adapters, by this we can use multi models.
def register_llm_model_adapters(cls):
@@ -37,60 +37,91 @@ def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
for adapter in llm_model_adapters:
if adapter.match(model_path):
return adapter
raise ValueError(f"Invalid model adapter for {model_path}")
# TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?
class VicunaLLMAdapater(BaseLLMAdaper):
"""Vicuna Adapter """
"""Vicuna Adapter"""
def match(self, model_path: str):
return "vicuna" in model_path
return "vicuna" in model_path
def loader(self, model_path: str, from_pretrained_kwagrs: dict):
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
**from_pretrained_kwagrs
model_path, low_cpu_mem_usage=True, **from_pretrained_kwagrs
)
return model, tokenizer
class ChatGLMAdapater(BaseLLMAdaper):
"""LLM Adatpter for THUDM/chatglm-6b"""
def match(self, model_path: str):
return "chatglm" in model_path
def loader(self, model_path: str, from_pretrained_kwargs: dict):
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(
model_path, trust_remote_code=True, **from_pretrained_kwargs
).half().cuda()
return model, tokenizer
if DEVICE != "cuda":
model = AutoModel.from_pretrained(
model_path, trust_remote_code=True, **from_pretrained_kwargs
).float()
return model, tokenizer
else:
model = (
AutoModel.from_pretrained(
model_path, trust_remote_code=True, **from_pretrained_kwargs
)
.half()
.cuda()
)
return model, tokenizer
class CodeGenAdapter(BaseLLMAdaper):
pass
class StarCoderAdapter(BaseLLMAdaper):
pass
class T5CodeAdapter(BaseLLMAdaper):
pass
class KoalaLLMAdapter(BaseLLMAdaper):
"""Koala LLM Adapter which Based LLaMA """
"""Koala LLM Adapter which Based LLaMA"""
def match(self, model_path: str):
return "koala" in model_path
class RWKV4LLMAdapter(BaseLLMAdaper):
"""LLM Adapter for RwKv4 """
"""LLM Adapter for RwKv4"""
def match(self, model_path: str):
return "RWKV-4" in model_path
def loader(self, model_path: str, from_pretrained_kwargs: dict):
# TODO
pass
class GPT4AllAdapter(BaseLLMAdaper):
"""A light version for someone who want practise LLM use laptop."""
def match(self, model_path: str):
return "gpt4all" in model_path
register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
# TODO Default support vicuna, other model need to tests and Evaluate
register_llm_model_adapters(BaseLLMAdaper)
register_llm_model_adapters(BaseLLMAdaper)


@@ -1,11 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List, TypedDict
from typing import TypedDict
class Message(TypedDict):
"""LLM Message object containing usually like (role: content) """
"""LLM Message object containing usually like (role: content)"""
role: str
content: str


@@ -1,3 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-


@@ -0,0 +1,58 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import torch
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
@torch.inference_mode()
def chatglm_generate_stream(
model, tokenizer, params, device, context_len=2048, stream_interval=2
):
"""Generate text using chatglm model's chat api"""
prompt = params["prompt"]
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
stop = params.get("stop", "###")
echo = params.get("echo", False)
generate_kwargs = {
"do_sample": True if temperature > 1e-5 else False,
"top_p": top_p,
"repetition_penalty": 1.0,
"logits_processor": None,
}
if temperature > 1e-5:
generate_kwargs["temperature"] = temperature
# TODO, Fix this
hist = []
messages = prompt.split(stop)
# Add history chat to hist for model.
for i in range(1, len(messages) - 2, 2):
hist.append(
(
messages[i].split(ROLE_USER + ":")[1],
messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
)
)
query = messages[-2].split(ROLE_USER + ":")[1]
print("Query Message: ", query)
output = ""
i = 0
for i, (response, new_hist) in enumerate(
model.stream_chat(tokenizer, query, hist, **generate_kwargs)
):
if echo:
output = query + " " + response
else:
output = response
yield output
yield output
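A hedged usage sketch of the streaming function above; the import path is hypothetical (this diff shows the new file's contents but not its location), and the model path and prompt layout follow the conventions used elsewhere in this commit (the models/ directory, the "###" separator, and the USER:/Assistant: roles).

```
from transformers import AutoModel, AutoTokenizer

# Hypothetical import path for the function defined above.
from pilot.model.chatglm_llms import chatglm_generate_stream

tokenizer = AutoTokenizer.from_pretrained("models/chatglm-6b", trust_remote_code=True)
model = (
    AutoModel.from_pretrained("models/chatglm-6b", trust_remote_code=True).half().cuda()
)

params = {
    "prompt": "USER: Briefly introduce DB-GPT###Assistant:",
    "temperature": 0.7,
    "stop": "###",
}

output = ""
for output in chatglm_generate_stream(model, tokenizer, params, device="cuda"):
    pass  # each iteration yields the full response generated so far
print(output)
```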


@@ -3,14 +3,15 @@
import dataclasses
import torch
from torch import Tensor
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
@dataclasses.dataclass
class CompressionConfig:
"""Group-wise quantization."""
num_bits: int
group_size: int
group_dim: int
@@ -19,7 +20,8 @@ class CompressionConfig:
default_compression_config = CompressionConfig(
num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True)
num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True
)
class CLinear(nn.Module):
@@ -40,8 +42,11 @@ def compress_module(module, target_device):
for attr_str in dir(module):
target_attr = getattr(module, attr_str)
if type(target_attr) == torch.nn.Linear:
setattr(module, attr_str,
CLinear(target_attr.weight, target_attr.bias, target_device))
setattr(
module,
attr_str,
CLinear(target_attr.weight, target_attr.bias, target_device),
)
for name, child in module.named_children():
compress_module(child, target_device)
@@ -52,22 +57,31 @@ def compress(tensor, config):
return tensor
group_size, num_bits, group_dim, symmetric = (
config.group_size, config.num_bits, config.group_dim, config.symmetric)
config.group_size,
config.num_bits,
config.group_dim,
config.symmetric,
)
assert num_bits <= 8
original_shape = tensor.shape
num_groups = (original_shape[group_dim] + group_size - 1) // group_size
new_shape = (original_shape[:group_dim] + (num_groups, group_size) +
original_shape[group_dim+1:])
new_shape = (
original_shape[:group_dim]
+ (num_groups, group_size)
+ original_shape[group_dim + 1 :]
)
# Pad
pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
if pad_len != 0:
pad_shape = original_shape[:group_dim] + (pad_len,) + original_shape[group_dim+1:]
tensor = torch.cat([
tensor,
torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
dim=group_dim)
pad_shape = (
original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1 :]
)
tensor = torch.cat(
[tensor, torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
dim=group_dim,
)
data = tensor.view(new_shape)
# Quantize
@@ -78,7 +92,7 @@ def compress(tensor, config):
data = data.clamp_(-B, B).round_().to(torch.int8)
return data, scale, original_shape
else:
B = 2 ** num_bits - 1
B = 2**num_bits - 1
mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0]
mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0]
@@ -96,7 +110,11 @@ def decompress(packed_data, config):
return packed_data
group_size, num_bits, group_dim, symmetric = (
config.group_size, config.num_bits, config.group_dim, config.symmetric)
config.group_size,
config.num_bits,
config.group_dim,
config.symmetric,
)
# Dequantize
if symmetric:
@@ -111,9 +129,10 @@ def decompress(packed_data, config):
pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
if pad_len:
padded_original_shape = (
original_shape[:group_dim] +
(original_shape[group_dim] + pad_len,) +
original_shape[group_dim+1:])
original_shape[:group_dim]
+ (original_shape[group_dim] + pad_len,)
+ original_shape[group_dim + 1 :]
)
data = data.reshape(padded_original_shape)
indices = [slice(0, x) for x in original_shape]
return data[indices].contiguous()
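A small round-trip sketch of the group-wise quantization above; the import path is hypothetical, since the module's location is not shown in this diff.

```
import torch

# Hypothetical import path for the helpers defined above.
from pilot.model.compression import compress, decompress, default_compression_config

x = torch.randn(4, 512)
packed = compress(x, default_compression_config)      # int8 groups plus per-group scale
x_hat = decompress(packed, default_compression_config)

# 8-bit group-wise quantization is lossy, but the reconstruction error stays small.
print((x - x_hat).abs().max().item())
```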


@@ -3,11 +3,12 @@
import torch
@torch.inference_mode()
def generate_stream(model, tokenizer, params, device,
context_len=4096, stream_interval=2):
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
@torch.inference_mode()
def generate_stream(
model, tokenizer, params, device, context_len=4096, stream_interval=2
):
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py"""
prompt = params["prompt"]
l_prompt = len(prompt)
temperature = float(params.get("temperature", 1.0))
@@ -22,17 +23,19 @@ def generate_stream(model, tokenizer, params, device,
for i in range(max_new_tokens):
if i == 0:
out = model(
torch.as_tensor([input_ids], device=device), use_cache=True)
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else:
attention_mask = torch.ones(
1, past_key_values[0][0].shape[-2] + 1, device=device)
out = model(input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values)
1, past_key_values[0][0].shape[-2] + 1, device=device
)
out = model(
input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values,
)
logits = out.logits
past_key_values = out.past_key_values
@@ -68,9 +71,12 @@ def generate_stream(model, tokenizer, params, device,
del past_key_values
@torch.inference_mode()
def generate_output(model, tokenizer, params, device, context_len=4096, stream_interval=2):
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
def generate_output(
model, tokenizer, params, device, context_len=4096, stream_interval=2
):
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py"""
prompt = params["prompt"]
l_prompt = len(prompt)
@@ -78,7 +84,6 @@ def generate_output(model, tokenizer, params, device, context_len=4096, stream_i
max_new_tokens = int(params.get("max_new_tokens", 2048))
stop_str = params.get("stop", None)
input_ids = tokenizer(prompt).input_ids
output_ids = list(input_ids)
@@ -87,17 +92,19 @@ def generate_output(model, tokenizer, params, device, context_len=4096, stream_i
for i in range(max_new_tokens):
if i == 0:
out = model(
torch.as_tensor([input_ids], device=device), use_cache=True)
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else:
attention_mask = torch.ones(
1, past_key_values[0][0].shape[-2] + 1, device=device)
out = model(input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values)
1, past_key_values[0][0].shape[-2] + 1, device=device
)
out = model(
input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values,
)
logits = out.logits
past_key_values = out.past_key_values
@@ -120,7 +127,6 @@ def generate_output(model, tokenizer, params, device, context_len=4096, stream_i
else:
stopped = False
if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
output = tokenizer.decode(output_ids, skip_special_tokens=True)
pos = output.rfind(stop_str, l_prompt)
@@ -133,8 +139,11 @@ def generate_output(model, tokenizer, params, device, context_len=4096, stream_i
break
del past_key_values
@torch.inference_mode()
def generate_output_ex(model, tokenizer, params, device, context_len=2048, stream_interval=2):
def generate_output_ex(
model, tokenizer, params, device, context_len=2048, stream_interval=2
):
prompt = params["prompt"]
temperature = float(params.get("temperature", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 2048))
@@ -161,20 +170,20 @@ def generate_output_ex(model, tokenizer, params, device, context_len=2048, strea
for i in range(max_new_tokens):
if i == 0:
out = model(
torch.as_tensor([input_ids], device=device), use_cache=True)
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else:
out = model(input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
past_key_values=past_key_values)
out = model(
input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
past_key_values=past_key_values,
)
logits = out.logits
past_key_values = out.past_key_values
last_token_logits = logits[0][-1]
if temperature < 1e-4:
token = int(torch.argmax(last_token_logits))
else:
@@ -188,7 +197,6 @@ def generate_output_ex(model, tokenizer, params, device, context_len=2048, strea
else:
stopped = False
output = tokenizer.decode(output_ids, skip_special_tokens=True)
# print("Partial output:", output)
for stop_str in stop_strings:
@@ -211,7 +219,7 @@ def generate_output_ex(model, tokenizer, params, device, context_len=2048, strea
del past_key_values
if pos != -1:
return output[:pos]
return output
return output
@torch.inference_mode()

View File

@@ -1,12 +1,12 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from dataclasses import dataclass, field
from typing import List, TypedDict
from dataclasses import dataclass
from typing import TypedDict
class Message(TypedDict):
"""Vicuna Message object containing a role and the message content """
"""Vicuna Message object containing a role and the message content"""
role: str
content: str
@@ -18,12 +18,15 @@ class ModelInfo:
Would be lovely to eventually get this directly from APIs
"""
name: str
max_tokens: int
@dataclass
class LLMResponse:
"""Standard response struct for a response from a LLM model."""
model_info = ModelInfo
@@ -31,4 +34,4 @@ class LLMResponse:
class ChatModelResponse(LLMResponse):
"""Standard response struct for a response from an LLM model."""
content: str = None
content: str = None

View File

@@ -2,35 +2,39 @@
# -*- coding: utf-8 -*-
import abc
import time
import functools
from typing import List, Optional
from pilot.model.llm.base import Message
from pilot.conversation import conv_templates, Conversation, conv_one_shot, auto_dbgpt_one_shot
import time
from typing import Optional
from pilot.configs.config import Config
from pilot.conversation import (
Conversation,
auto_dbgpt_one_shot,
conv_one_shot,
conv_templates,
)
from pilot.model.llm.base import Message
# TODO Rewrite this
def retry_stream_api(
num_retries: int = 10,
backoff_base: float = 2.0,
warn_user: bool = True
):
num_retries: int = 10, backoff_base: float = 2.0, warn_user: bool = True
):
"""Retry an Vicuna Server call.
Args:
num_retries int: Number of retries. Defaults to 10.
backoff_base float: Base for exponential backoff. Defaults to 2.
warn_user bool: Whether to warn the user. Defaults to True.
Args:
num_retries int: Number of retries. Defaults to 10.
backoff_base float: Base for exponential backoff. Defaults to 2.
warn_user bool: Whether to warn the user. Defaults to True.
"""
retry_limit_msg = f"Error: Reached rate limit, passing..."
backoff_msg = (f"Error: API Bad gateway. Waiting {{backoff}} seconds...")
backoff_msg = f"Error: API Bad gateway. Waiting {{backoff}} seconds..."
def _wrapper(func):
@functools.wraps(func)
def _wrapped(*args, **kwargs):
user_warned = not warn_user
num_attempts = num_retries + 1 # +1 for the first attempt
num_attempts = num_retries + 1 # +1 for the first attempt
for attempt in range(1, num_attempts + 1):
try:
return func(*args, **kwargs)
@@ -39,10 +43,13 @@ def retry_stream_api(
raise
backoff = backoff_base ** (attempt + 2)
time.sleep(backoff)
time.sleep(backoff)
return _wrapped
return _wrapper
# Overly simple abstraction util we create something better
# simple retry mechanism when getting a rate error or a bad gateway
def create_chat_competion(
@@ -52,15 +59,15 @@ def create_chat_competion(
max_new_tokens: Optional[int] = None,
) -> str:
"""Create a chat completion using the Vicuna-13b
Args:
messages(List[Message]): The messages to send to the chat completion
model (str, optional): The model to use. Default to None.
temperature (float, optional): The temperature to use. Defaults to 0.7.
max_tokens (int, optional): The max tokens to use. Defaults to None.
Returns:
str: The response from the chat completion
Args:
messages(List[Message]): The messages to send to the chat completion
model (str, optional): The model to use. Default to None.
temperature (float, optional): The temperature to use. Defaults to 0.7.
max_tokens (int, optional): The max tokens to use. Defaults to None.
Returns:
str: The response from the chat completion
"""
cfg = Config()
if temperature is None:
@@ -77,7 +84,7 @@ class ChatIO(abc.ABC):
@abc.abstractmethod
def prompt_for_input(self, role: str) -> str:
"""Prompt for input from a role."""
@abc.abstractmethod
def prompt_for_output(self, role: str) -> str:
"""Prompt for output from a role."""
@@ -105,4 +112,3 @@ class SimpleChatIO(ChatIO):
print(" ".join(outputs[pre:]), flush=True)
return " ".join(outputs)

View File

@@ -0,0 +1,125 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import math
from typing import Optional, Tuple
import torch
import transformers
from torch import nn
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2].clone()
x2 = x[..., x.shape[-1] // 2 :].clone()
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
self.head_dim
)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
query_states.dtype
)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def replace_llama_attn_with_non_inplace_operations():
"""Avoid bugs in mps backend by not using in-place operations."""
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
import transformers
def replace_llama_attn_with_non_inplace_operations():
"""Avoid bugs in mps backend by not using in-place operations."""
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
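
For intuition about the rotary helpers above: rotate_half swaps and negates the two halves of the last dimension before the sin/cos mix in apply_rotary_pos_emb. A tiny self-contained check, with the function re-declared so the snippet runs on its own:

import torch

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2].clone()
    x2 = x[..., x.shape[-1] // 2 :].clone()
    return torch.cat((-x2, x1), dim=-1)

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half(x))  # tensor([-3., -4.,  1.,  2.])
# Applying it twice negates the input, which is what makes the rotation well-behaved.
assert torch.equal(rotate_half(rotate_half(x)), -x)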

View File

@@ -2,31 +2,33 @@
# -*- coding:utf-8 -*-
from typing import List, Optional
from pilot.model.base import Message
from pilot.configs.config import Config
from pilot.model.base import Message
from pilot.server.llmserver import generate_output
def create_chat_completion(
messages: List[Message], # type: ignore
messages: List[Message], # type: ignore
model: Optional[str] = None,
temperature: float = None,
max_tokens: Optional[int] = None,
) -> str:
"""Create a chat completion using the vicuna local model
"""Create a chat completion using the vicuna local model
Args:
messages(List[Message]): The messages to send to the chat completion
model (str, optional): The model to use. Defaults to None.
temperature (float, optional): The temperature to use. Defaults to 0.7.
max_tokens (int, optional): The max tokens to use. Defaults to None
Returns:
str: The response from chat completion
Args:
messages(List[Message]): The messages to send to the chat completion
model (str, optional): The model to use. Defaults to None.
temperature (float, optional): The temperature to use. Defaults to 0.7.
max_tokens (int, optional): The max tokens to use. Defaults to None
Returns:
str: The response from chat completion
"""
cfg = Config()
if temperature is None:
temperature = cfg.temperature
for plugin in cfg.plugins:
if plugin.can_handle_chat_completion(
messages=messages,

View File

@@ -1,55 +1,102 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import sys
import warnings
from pilot.singleton import Singleton
from typing import Optional
from pilot.model.compression import compress_module
import torch
from pilot.configs.model_config import DEVICE
from pilot.model.adapter import get_llm_model_adapter
from pilot.model.compression import compress_module
from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations
from pilot.singleton import Singleton
from pilot.utils import get_gpu_memory
def raise_warning_for_incompatible_cpu_offloading_configuration(
device: str, load_8bit: bool, cpu_offloading: bool
):
if cpu_offloading:
if not load_8bit:
warnings.warn(
"The cpu-offloading feature can only be used while also using 8-bit-quantization.\n"
"Use '--load-8bit' to enable 8-bit-quantization\n"
"Continuing without cpu-offloading enabled\n"
)
return False
if not "linux" in sys.platform:
warnings.warn(
"CPU-offloading is only supported on linux-systems due to the limited compatability with the bitsandbytes-package\n"
"Continuing without cpu-offloading enabled\n"
)
return False
if device != "cuda":
warnings.warn(
"CPU-offloading is only enabled when using CUDA-devices\n"
"Continuing without cpu-offloading enabled\n"
)
return False
return cpu_offloading
class ModelLoader(metaclass=Singleton):
"""Model loader is a class for model load
Args: model_path
TODO: multi model support.
TODO: multi model support.
"""
kwargs = {}
def __init__(self,
model_path) -> None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model_path = model_path
def __init__(self, model_path) -> None:
self.device = DEVICE
self.model_path = model_path
self.kwargs = {
"torch_dtype": torch.float16,
"device_map": "auto",
}
# TODO multi gpu support
def loader(self, num_gpus, load_8bit=False, debug=False):
def loader(
self,
num_gpus,
load_8bit=False,
debug=False,
cpu_offloading=False,
max_gpu_memory: Optional[str] = None,
):
if self.device == "cpu":
kwargs = {}
kwargs = {"torch_dtype": torch.float32}
elif self.device == "cuda":
kwargs = {"torch_dtype": torch.float16}
if num_gpus == "auto":
num_gpus = int(num_gpus)
if num_gpus != 1:
kwargs["device_map"] = "auto"
if max_gpu_memory is None:
kwargs["device_map"] = "sequential"
available_gpu_memory = get_gpu_memory(num_gpus)
kwargs["max_memory"] = {
i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
for i in range(num_gpus)
}
else:
num_gpus = int(num_gpus)
if num_gpus != 1:
kwargs.update({
"device_map": "auto",
"max_memory": {i: "13GiB" for i in range(num_gpus)},
})
kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
elif self.device == "mps":
kwargs = {"torch_dtype": torch.float16}
replace_llama_attn_with_non_inplace_operations()
else:
# Todo Support mps for practise
raise ValueError(f"Invalid device: {self.device}")
# TODO when cpu loading, need use quantization config
llm_adapter = get_llm_model_adapter(self.model_path)
model, tokenizer = llm_adapter.loader(self.model_path, kwargs)
@@ -59,13 +106,14 @@ class ModelLoader(metaclass=Singleton):
"8-bit quantization is not supported for multi-gpu inference"
)
else:
compress_module(model, self.device)
compress_module(model, self.device)
if (self.device == "cuda" and num_gpus == 1):
if (
self.device == "cuda" and num_gpus == 1 and not cpu_offloading
) or self.device == "mps":
model.to(self.device)
if debug:
print(model)
return model, tokenizer
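
A minimal usage sketch of ModelLoader, mirroring how the ModelWorker further below drives it. The model path is a placeholder, and the snippet assumes the weights and the pilot package are available locally.

from pilot.model.loader import ModelLoader

ml = ModelLoader(model_path="/data/models/vicuna-13b")  # placeholder path
model, tokenizer = ml.loader(num_gpus=1, load_8bit=False, debug=False)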

View File

@@ -2,25 +2,34 @@
# -*- coding:utf-8 -*-
import json
import requests
from typing import Any, List, Mapping, Optional
from urllib.parse import urljoin
import requests
from langchain.embeddings.base import Embeddings
from pydantic import BaseModel
from typing import Any, Mapping, Optional, List
from langchain.llms.base import LLM
from pydantic import BaseModel
from pilot.configs.config import Config
CFG = Config()
class VicunaLLM(LLM):
vicuna_generate_path = "generate_stream"
def _call(self, prompt: str, temperature: float, max_new_tokens: int, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
temperature: float,
max_new_tokens: int,
stop: Optional[List[str]] = None,
) -> str:
params = {
"prompt": prompt,
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"stop": stop
"stop": stop,
}
response = requests.post(
url=urljoin(CFG.MODEL_SERVER, self.vicuna_generate_path),
@@ -41,10 +50,9 @@ class VicunaLLM(LLM):
def _identifying_params(self) -> Mapping[str, Any]:
return {}
class VicunaEmbeddingLLM(BaseModel, Embeddings):
vicuna_embedding_path = "embedding"
def _call(self, prompt: str) -> str:
@@ -53,15 +61,13 @@ class VicunaEmbeddingLLM(BaseModel, Embeddings):
response = requests.post(
url=urljoin(CFG.MODEL_SERVER, self.vicuna_embedding_path),
json={
"prompt": p
}
json={"prompt": p},
)
response.raise_for_status()
return response.json()["response"]
def embed_documents(self, texts: List[str]) -> List[List[float]]:
""" Call out to Vicuna's server embedding endpoint for embedding search docs.
"""Call out to Vicuna's server embedding endpoint for embedding search docs.
Args:
texts: The list of text to embed
@@ -73,17 +79,15 @@ class VicunaEmbeddingLLM(BaseModel, Embeddings):
for text in texts:
response = self.embed_query(text)
results.append(response)
return results
return results
def embed_query(self, text: str) -> List[float]:
""" Call out to Vicuna's server embedding endpoint for embedding query text.
Args:
"""Call out to Vicuna's server embedding endpoint for embedding query text.
Args:
text: The text to embed.
Returns:
Embedding for the text
"""
embedding = self._call(text)
return embedding
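
VicunaEmbeddingLLM above simply POSTs {"prompt": ...} to the server's embedding route and reads the "response" field. Below is a hedged sketch of that call without the langchain wrapper; the server address is an assumed example, since in the project it comes from CFG.MODEL_SERVER.

from typing import List
from urllib.parse import urljoin

import requests

MODEL_SERVER = "http://127.0.0.1:8000"  # assumed example; configured via CFG.MODEL_SERVER

def embed(text: str) -> List[float]:
    # Mirrors VicunaEmbeddingLLM._call: POST the prompt to the embedding endpoint
    resp = requests.post(url=urljoin(MODEL_SERVER, "embedding"), json={"prompt": text})
    resp.raise_for_status()
    return resp.json()["response"]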

View File

@@ -1,11 +1,10 @@
"""加载组件"""
import importlib
import json
import os
import zipfile
from pathlib import Path
from typing import List, Optional, Tuple
from typing import List
from urllib.parse import urlparse
from zipimport import zipimporter
@@ -15,6 +14,7 @@ from auto_gpt_plugin_template import AutoGPTPluginTemplate
from pilot.configs.config import Config
from pilot.logs import logger
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
"""
Loader zip plugin file. Native support Auto_gpt_plugin
@@ -36,6 +36,7 @@ def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
logger.debug(f"Module '__init__.py' not found in the zipfile @ {zip_path}.")
return result
def write_dict_to_json_file(data: dict, file_path: str) -> None:
"""
Write a dictionary to a JSON file.
@@ -46,6 +47,7 @@ def write_dict_to_json_file(data: dict, file_path: str) -> None:
with open(file_path, "w") as file:
json.dump(data, file, indent=4)
def create_directory_if_not_exists(directory_path: str) -> bool:
"""
Create a directory if it does not exist.
@@ -66,6 +68,7 @@ def create_directory_if_not_exists(directory_path: str) -> bool:
logger.info(f"Directory {directory_path} already exists")
return True
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
"""Scan the plugins directory for plugins and loads them.

View File

@@ -1,19 +1,21 @@
from pilot.prompts.generator import PromptGenerator
from typing import Any, Optional, Type
import os
import platform
from pathlib import Path
from typing import Optional
import distro
import yaml
from pilot.configs.config import Config
from pilot.prompts.prompt import build_default_prompt_generator, DEFAULT_PROMPT_OHTER, DEFAULT_TRIGGERING_PROMPT
from pilot.prompts.generator import PromptGenerator
from pilot.prompts.prompt import (
DEFAULT_PROMPT_OHTER,
DEFAULT_TRIGGERING_PROMPT,
build_default_prompt_generator,
)
class AutoModePrompt:
"""
""" """
"""
def __init__(
self,
ai_goals: list | None = None,
@@ -36,23 +38,21 @@ class AutoModePrompt:
self.command_registry = None
def construct_follow_up_prompt(
self,
user_input:[str],
last_auto_return: str = None,
prompt_generator: Optional[PromptGenerator] = None
)-> str:
self,
user_input: [str],
last_auto_return: str = None,
prompt_generator: Optional[PromptGenerator] = None,
) -> str:
"""
Build complete prompt information based on subsequent dialogue information entered by the user
Args:
self:
prompt_generator:
Build complete prompt information based on subsequent dialogue information entered by the user
Args:
self:
prompt_generator:
Returns:
Returns:
"""
prompt_start = (
DEFAULT_PROMPT_OHTER
)
"""
prompt_start = DEFAULT_PROMPT_OHTER
if prompt_generator is None:
prompt_generator = build_default_prompt_generator()
prompt_generator.goals = user_input
@@ -64,12 +64,13 @@ class AutoModePrompt:
continue
prompt_generator = plugin.post_prompt(prompt_generator)
full_prompt = f"{prompt_start}\n\nGOALS:\n\n"
if not self.ai_goals :
if not self.ai_goals:
self.ai_goals = user_input
for i, goal in enumerate(self.ai_goals):
full_prompt += f"{i+1}.According to the provided Schema information, {goal}\n"
full_prompt += (
f"{i+1}.According to the provided Schema information, {goal}\n"
)
# if last_auto_return == None:
# full_prompt += f"{cfg.last_plugin_return}\n\n"
# else:
@@ -82,10 +83,10 @@ class AutoModePrompt:
return full_prompt
def construct_first_prompt(
self,
fisrt_message: [str]=[],
db_schemes: str=None,
prompt_generator: Optional[PromptGenerator] = None
self,
fisrt_message: [str] = [],
db_schemes: str = None,
prompt_generator: Optional[PromptGenerator] = None,
) -> str:
"""
Build complete prompt information based on the initial dialogue information entered by the user
@@ -125,16 +126,18 @@ class AutoModePrompt:
# Construct full prompt
full_prompt = f"{prompt_start}\n\nGOALS:\n\n"
if not self.ai_goals :
if not self.ai_goals:
self.ai_goals = fisrt_message
for i, goal in enumerate(self.ai_goals):
full_prompt += f"{i+1}.According to the provided Schema information,{goal}\n"
if db_schemes:
full_prompt += f"\nSchema:\n\n"
full_prompt += (
f"{i+1}.According to the provided Schema information,{goal}\n"
)
if db_schemes:
full_prompt += f"\nSchema:\n\n"
full_prompt += f"{db_schemes}"
# if self.api_budget > 0.0:
# full_prompt += f"\nIt takes money to let you run. Your API budget is ${self.api_budget:.3f}"
self.prompt_generator = prompt_generator
full_prompt += f"\n\n{prompt_generator.generate_prompt_string()}"
return full_prompt
return full_prompt

View File

@@ -149,7 +149,7 @@ class PromptGenerator:
f"Resources:\n{self._generate_numbered_list(self.resources)}\n\n"
"Performance Evaluation:\n"
f"{self._generate_numbered_list(self.performance_evaluation)}\n\n"
"You should only respond in JSON format as described below and ensure the"
"You should only respond in JSON format as described below and ensure the"
"response can be parsed by Python json.loads \nResponse"
f" Format: \n{formatted_response_format}"
)

View File

@@ -1,17 +1,14 @@
from pilot.configs.config import Config
from pilot.prompts.generator import PromptGenerator
CFG = Config()
DEFAULT_TRIGGERING_PROMPT = (
"Determine which next command to use, and respond using the format specified above"
)
DEFAULT_PROMPT_OHTER = (
    "Previous response was excellent. Please respond according to the requirements based on the new goal"
)
DEFAULT_PROMPT_OHTER = "Previous response was excellent. Please respond according to the requirements based on the new goal"
def build_default_prompt_generator() -> PromptGenerator:
"""
@@ -36,17 +33,15 @@ def build_default_prompt_generator() -> PromptGenerator:
)
# prompt_generator.add_constraint("No user assistance")
prompt_generator.add_constraint(
'Only output one correct JSON response at a time'
)
prompt_generator.add_constraint("Only output one correct JSON response at a time")
prompt_generator.add_constraint(
'Exclusively use the commands listed in double quotes e.g. "command name"'
)
prompt_generator.add_constraint(
'If there is SQL in the args parameter, ensure to use the database and table definitions in Schema, and ensure that the fields and table names are in the definition'
"If there is SQL in the args parameter, ensure to use the database and table definitions in Schema, and ensure that the fields and table names are in the definition"
)
prompt_generator.add_constraint(
'The generated command args need to comply with the definition of the command'
"The generated command args need to comply with the definition of the command"
)
# Add resources to the PromptGenerator object

View File

@@ -1,26 +1,24 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import json
import transformers
from transformers import LlamaTokenizer, LlamaForCausalLM
import os
from typing import List
import pandas as pd
import torch
import transformers
from datasets import load_dataset
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
)
from transformers import LlamaForCausalLM, LlamaTokenizer
import torch
from datasets import load_dataset
import pandas as pd
from pilot.configs.config import Config
from pilot.configs.model_config import DATA_DIR, LLM_MODEL_CONFIG
device = "cuda" if torch.cuda.is_available() else "cpu"
CUTOFF_LEN = 50
@@ -28,6 +26,7 @@ df = pd.read_csv(os.path.join(DATA_DIR, "BTC_Tweets_Updated.csv"))
CFG = Config()
def sentiment_score_to_name(score: float):
if score > 0:
return "Positive"
@@ -40,16 +39,18 @@ dataset_data = [
{
"instruction": "Detect the sentiment of the tweet.",
"input": row_dict["Tweet"],
"output": sentiment_score_to_name(row_dict["New_Sentiment_State"])
}
"output": sentiment_score_to_name(row_dict["New_Sentiment_State"]),
}
for row_dict in df.to_dict(orient="records")
]
with open(os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"), "w") as f:
json.dump(dataset_data, f)
json.dump(dataset_data, f)
data = load_dataset("json", data_files=os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"))
data = load_dataset(
"json", data_files=os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json")
)
print(data["train"])
BASE_MODEL = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
@@ -57,13 +58,14 @@ model = LlamaForCausalLM.from_pretrained(
BASE_MODEL,
torch_dtype=torch.float16,
device_map="auto",
offload_folder=os.path.join(DATA_DIR, "vicuna-lora")
)
offload_folder=os.path.join(DATA_DIR, "vicuna-lora"),
)
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token_id = (0)
tokenizer.pad_token_id = 0
tokenizer.padding_side = "left"
def generate_prompt(data_point):
return f"""Blow is an instruction that describes a task, paired with an input that provide future context.
Write a response that appropriately completes the request. #noqa:
@@ -76,6 +78,7 @@ def generate_prompt(data_point):
{data_point["output"]}
"""
def tokenize(prompt, add_eos_token=True):
result = tokenizer(
prompt,
@@ -85,30 +88,29 @@ def tokenize(prompt, add_eos_token=True):
return_tensors=None,
)
if (result["input_ids"][-1] != tokenizer.eos_token_id and len(result["input_ids"]) < CUTOFF_LEN and add_eos_token):
if (
result["input_ids"][-1] != tokenizer.eos_token_id
and len(result["input_ids"]) < CUTOFF_LEN
and add_eos_token
):
result["input_ids"].append(tokenizer.eos_token_id)
result["attention_mask"].append(1)
result["labels"] = result["input_ids"].copy()
return result
def generate_and_tokenize_prompt(data_point):
full_prompt = generate_prompt(data_point)
tokenized_full_prompt = tokenize(full_prompt)
return tokenized_full_prompt
train_val = data["train"].train_test_split(
test_size=200, shuffle=True, seed=42
)
train_val = data["train"].train_test_split(test_size=200, shuffle=True, seed=42)
train_data = (
train_val["train"].map(generate_and_tokenize_prompt)
)
train_data = train_val["train"].map(generate_and_tokenize_prompt)
val_data = (
train_val["test"].map(generate_and_tokenize_prompt)
)
val_data = train_val["test"].map(generate_and_tokenize_prompt)
# Training
LORA_R = 8
@@ -129,7 +131,7 @@ OUTPUT_DIR = "experiments"
# We can now prepare model for training
model = prepare_model_for_int8_training(model)
config = LoraConfig(
r = LORA_R,
r=LORA_R,
lora_alpha=LORA_ALPHA,
target_modules=LORA_TARGET_MODULES,
lora_dropout=LORA_DROPOUT,
@@ -156,7 +158,7 @@ training_arguments = transformers.TrainingArguments(
output_dir=OUTPUT_DIR,
save_total_limit=3,
load_best_model_at_end=True,
report_to="tensorboard"
report_to="tensorboard",
)
data_collector = transformers.DataCollatorForSeq2Seq(
@@ -168,15 +170,13 @@ trainer = transformers.Trainer(
train_dataset=train_data,
eval_dataset=val_data,
args=training_arguments,
data_collector=data_collector
data_collector=data_collector,
)
model.config.use_cache = False
old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(
self, old_state_dict()
)
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))
trainer.train()

View File

@@ -0,0 +1,90 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from functools import cache
from typing import List
from pilot.model.inference import generate_stream
class BaseChatAdpter:
"""The Base class for chat with llm models. it will match the model,
and fetch output from model"""
def match(self, model_path: str):
return True
def get_generate_stream_func(self):
"""Return the generate stream handler func"""
pass
llm_model_chat_adapters: List[BaseChatAdpter] = []
def register_llm_model_chat_adapter(cls):
"""Register a chat adapter"""
llm_model_chat_adapters.append(cls())
@cache
def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
"""Get a chat generate func for a model"""
for adapter in llm_model_chat_adapters:
if adapter.match(model_path):
return adapter
raise ValueError(f"Invalid model for chat adapter {model_path}")
class VicunaChatAdapter(BaseChatAdpter):
"""Model chat Adapter for vicuna"""
def match(self, model_path: str):
return "vicuna" in model_path
def get_generate_stream_func(self):
return generate_stream
class ChatGLMChatAdapter(BaseChatAdpter):
"""Model chat Adapter for ChatGLM"""
def match(self, model_path: str):
return "chatglm" in model_path
def get_generate_stream_func(self):
from pilot.model.chatglm_llm import chatglm_generate_stream
return chatglm_generate_stream
class CodeT5ChatAdapter(BaseChatAdpter):
"""Model chat adapter for CodeT5"""
def match(self, model_path: str):
return "codet5" in model_path
def get_generate_stream_func(self):
# TODO
pass
class CodeGenChatAdapter(BaseChatAdpter):
"""Model chat adapter for CodeGen"""
def match(self, model_path: str):
return "codegen" in model_path
def get_generate_stream_func(self):
# TODO
pass
register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(BaseChatAdpter)
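
Putting the adapter registry above together: get_llm_chat_adapter walks the registered adapters in order and returns the first whose match() accepts the model path. A short usage sketch follows; the model path is a placeholder, and it assumes the pilot package and its dependencies import cleanly.

from pilot.server.chat_adapter import get_llm_chat_adapter

adapter = get_llm_chat_adapter("/data/models/vicuna-13b")  # any path containing "vicuna"
stream_func = adapter.get_generate_stream_func()           # -> pilot.model.inference.generate_stream

Note that BaseChatAdpter is registered last as a catch-all (its match always returns True), so unrecognized model paths fall through to it even though its stream function is still a TODO.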

View File

@@ -1,8 +1,7 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
code_highlight_css = (
"""
code_highlight_css = """
#chatbot .hll { background-color: #ffffcc }
#chatbot .c { color: #408080; font-style: italic }
#chatbot .err { border: 1px solid #FF0000 }
@@ -71,6 +70,5 @@ code_highlight_css = (
#chatbot .vi { color: #19177C }
#chatbot .vm { color: #19177C }
#chatbot .il { color: #666666 }
""")
#.highlight { background: #f8f8f8; }
"""
# .highlight { background: #f8f8f8; }

View File

@@ -49,7 +49,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
warnings.warn(
"The 'color_map' parameter has been deprecated.",
)
#self.md = utils.get_markdown_parser()
# self.md = utils.get_markdown_parser()
self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])
self.select: EventListenerMethod
"""
@@ -112,7 +112,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
): # This happens for previously processed messages
return chat_message
elif isinstance(chat_message, str):
#return self.md.render(chat_message)
# return self.md.render(chat_message)
return str(self.md.convert(chat_message))
else:
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
@@ -141,9 +141,10 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
processed_messages.append(
(
#self._process_chat_messages(message_pair[0]),
'<pre style="font-family: var(--font)">' +
message_pair[0] + "</pre>",
# self._process_chat_messages(message_pair[0]),
'<pre style="font-family: var(--font)">'
+ message_pair[0]
+ "</pre>",
self._process_chat_messages(message_pair[1]),
)
)
@@ -163,5 +164,3 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
**kwargs,
)
return self

View File

@@ -1,40 +1,91 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import uvicorn
import asyncio
import json
from typing import Optional, List
from fastapi import FastAPI, Request, BackgroundTasks
import os
import sys
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse
from pilot.model.inference import generate_stream
from pydantic import BaseModel
from pilot.model.inference import generate_output, get_embeddings
from pilot.model.loader import ModelLoader
from pilot.configs.model_config import *
from pilot.configs.config import Config
CFG = Config()
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
global_counter = 0
model_semaphore = None
ml = ModelLoader(model_path=model_path)
model, tokenizer = ml.loader(num_gpus=1, load_8bit=ISLOAD_8BIT, debug=ISDEBUG)
#model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
from pilot.configs.config import Config
from pilot.configs.model_config import *
from pilot.model.inference import generate_output, generate_stream, get_embeddings
from pilot.model.loader import ModelLoader
from pilot.server.chat_adapter import get_llm_chat_adapter
CFG = Config()
class ModelWorker:
def __init__(self):
pass
def __init__(self, model_path, model_name, device, num_gpus=1):
if model_path.endswith("/"):
model_path = model_path[:-1]
self.model_name = model_name or model_path.split("/")[-1]
self.device = device
self.ml = ModelLoader(model_path=model_path)
self.model, self.tokenizer = self.ml.loader(
num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
)
if hasattr(self.model.config, "max_sequence_length"):
self.context_len = self.model.config.max_sequence_length
elif hasattr(self.model.config, "max_position_embeddings"):
self.context_len = self.model.config.max_position_embeddings
else:
self.context_len = 2048
self.llm_chat_adapter = get_llm_chat_adapter(model_path)
self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func()
def get_queue_length(self):
if (
model_semaphore is None
or model_semaphore._value is None
or model_semaphore._waiters is None
):
return 0
else:
return (
CFG.LIMIT_MODEL_CONCURRENCY
- model_semaphore._value
+ len(model_semaphore._waiters)
)
def generate_stream_gate(self, params):
try:
for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
):
print("output: ", output)
ret = {
"text": output,
"error_code": 0,
}
yield json.dumps(ret).encode() + b"\0"
except torch.cuda.CudaError:
ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
yield json.dumps(ret).encode() + b"\0"
def get_embeddings(self, prompt):
return get_embeddings(self.model, self.tokenizer, prompt)
# TODO
app = FastAPI()
class PromptRequest(BaseModel):
prompt: str
temperature: float
@@ -42,6 +93,7 @@ class PromptRequest(BaseModel):
model: str
stop: str = None
class StreamRequest(BaseModel):
model: str
prompt: str
@@ -49,64 +101,43 @@ class StreamRequest(BaseModel):
max_new_tokens: int
stop: str
class EmbeddingRequest(BaseModel):
prompt: str
def release_model_semaphore():
model_semaphore.release()
def generate_stream_gate(params):
try:
for output in generate_stream(
model,
tokenizer,
params,
DEVICE,
CFG.MAX_POSITION_EMBEDDINGS,
):
print("output: ", output)
ret = {
"text": output,
"error_code": 0,
}
yield json.dumps(ret).encode() + b"\0"
except torch.cuda.CudaError:
ret = {
"text": "**GPU OutOfMemory, Please Refresh.**",
"error_code": 0
}
yield json.dumps(ret).encode() + b"\0"
@app.post("/generate_stream")
async def api_generate_stream(request: Request):
global model_semaphore, global_counter
global_counter += 1
params = await request.json()
print(model, tokenizer, params, DEVICE)
if model_semaphore is None:
model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
await model_semaphore.acquire()
await model_semaphore.acquire()
generator = generate_stream_gate(params)
generator = worker.generate_stream_gate(params)
background_tasks = BackgroundTasks()
background_tasks.add_task(release_model_semaphore)
return StreamingResponse(generator, background=background_tasks)
@app.post("/generate")
def generate(prompt_request: PromptRequest):
params = {
"prompt": prompt_request.prompt,
"temperature": prompt_request.temperature,
"max_new_tokens": prompt_request.max_new_tokens,
"stop": prompt_request.stop
"stop": prompt_request.stop,
}
response = []
response = []
rsp_str = ""
output = generate_stream_gate(params)
output = worker.generate_stream_gate(params)
for rsp in output:
# rsp = rsp.decode("utf-8")
rsp_str = str(rsp, "utf-8")
@@ -114,15 +145,22 @@ def generate(prompt_request: PromptRequest):
response.append(rsp_str)
return {"response": rsp_str}
@app.post("/embedding")
def embeddings(prompt_request: EmbeddingRequest):
params = {"prompt": prompt_request.prompt}
print("Received prompt: ", params["prompt"])
output = get_embeddings(model, tokenizer, params["prompt"])
output = worker.get_embeddings(params["prompt"])
return {"response": [float(x) for x in output]}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", log_level="info")
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
print(model_path, DEVICE)
worker = ModelWorker(
model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE, num_gpus=1
)
uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info")

View File

@@ -1,29 +1,30 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pilot.vector_store.file_loader import KnownLedge2Vector
from langchain.prompts import PromptTemplate
from pilot.conversation import conv_qa_prompt_template
from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
from pilot.conversation import conv_qa_prompt_template
from pilot.model.vicuna_llm import VicunaLLM
from pilot.vector_store.file_loader import KnownLedge2Vector
class KnownLedgeBaseQA:
def __init__(self) -> None:
k2v = KnownLedge2Vector()
self.vector_store = k2v.init_vector_store()
self.llm = VicunaLLM()
def get_similar_answer(self, query):
prompt = PromptTemplate(
template=conv_qa_prompt_template,
input_variables=["context", "question"]
template=conv_qa_prompt_template, input_variables=["context", "question"]
)
retriever = self.vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K})
retriever = self.vector_store.as_retriever(
search_kwargs={"k": VECTOR_SEARCH_TOP_K}
)
docs = retriever.get_relevant_documents(query=query)
context = [d.page_content for d in docs]
context = [d.page_content for d in docs]
result = prompt.format(context="\n".join(context), question=query)
return result
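
KnownLedgeBaseQA above retrieves the top-K similar chunks and renders them into conv_qa_prompt_template. A minimal usage sketch, assuming the default vector store and the embedding model are already initialized:

from pilot.server.vectordb_qa import KnownLedgeBaseQA

kb_qa = KnownLedgeBaseQA()
prompt = kb_qa.get_similar_answer("How do I configure the vector store?")
print(prompt)  # the QA template filled with the retrieved context and the question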

View File

@@ -2,50 +2,63 @@
# -*- coding: utf-8 -*-
import argparse
import datetime
import json
import os
import shutil
import uuid
import json
import sys
import time
import gradio as gr
import datetime
import requests
import uuid
from urllib.parse import urljoin
import gradio as gr
import requests
from langchain import PromptTemplate
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.connections.mysql import MySQLOperator
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
from pilot.plugins import scan_plugins
from pilot.configs.config import Config
from pilot.commands.command import execute_ai_response_json
from pilot.commands.command_mange import CommandRegistry
from pilot.prompts.auto_mode_prompt import AutoModePrompt
from pilot.prompts.generator import PromptGenerator
from pilot.scene.base_chat import BaseChat
from pilot.commands.exception_not_commands import NotCommands
from pilot.configs.config import Config
from pilot.configs.model_config import (
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
VECTOR_SEARCH_TOP_K,
)
from pilot.connections.mysql import MySQLOperator
from pilot.conversation import (
default_conversation,
SeparatorStyle,
conv_qa_prompt_template,
conv_templates,
conversation_types,
conversation_sql_mode,
SeparatorStyle, conv_qa_prompt_template
conversation_types,
default_conversation,
)
from pilot.utils import (
build_logger,
server_error_msg,
)
from pilot.plugins import scan_plugins
from pilot.prompts.auto_mode_prompt import AutoModePrompt
from pilot.prompts.generator import PromptGenerator
from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.utils import build_logger, server_error_msg
from pilot.vector_store.extract_tovec import (
get_vector_storelist,
knownledge_tovec_st,
load_knownledge_from_doc,
)
from pilot.commands.command import execute_ai_response_json
from pilot.scene.base import ChatScene
@@ -66,9 +79,7 @@ autogpt = False
vector_store_client = None
vector_store_name = {"vs_name": ""}
priority = {
"vicuna-13b": "aaa"
}
priority = {"vicuna-13b": "aaa"}
# Load plugins
CFG = Config()
@@ -76,10 +87,12 @@ CHAT_FACTORY = ChatFactory()
DB_SETTINGS = {
"user": CFG.LOCAL_DB_USER,
"password": CFG.LOCAL_DB_PASSWORD,
"password": CFG.LOCAL_DB_PASSWORD,
"host": CFG.LOCAL_DB_HOST,
"port": CFG.LOCAL_DB_PORT
"port": CFG.LOCAL_DB_PORT,
}
def get_simlar(q):
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
docs = docsearch.similarity_search_with_score(q, k=1)
@@ -89,9 +102,7 @@ def get_simlar(q):
def gen_sqlgen_conversation(dbname):
mo = MySQLOperator(
**DB_SETTINGS
)
mo = MySQLOperator(**DB_SETTINGS)
message = ""
@@ -334,8 +345,8 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
block_css = (
code_highlight_css
+ """
code_highlight_css
+ """
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
@@ -372,7 +383,7 @@ def build_single_model_ui():
notice_markdown = """
# DB-GPT
[DB-GPT](https://github.com/csunny/DB-GPT) 是一个实验性的开源应用程序,它基于[FastChat](https://github.com/lm-sys/FastChat)并使用vicuna-13b作为基础模型。此外此程序结合了langchain和llama-index基于现有知识库进行In-Context Learning来对其进行数据库相关知识的增强。它可以进行SQL生成、SQL诊断、数据库知识问答等一系列的工作。 总的来说它是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施DB-GPT有任何具体问题请联系我, 我会尽力提供帮助, 同时也欢迎大家参与到项目建设中, 做一些有趣的事情
[DB-GPT](https://github.com/csunny/DB-GPT) 是一个开源的以数据库为基础的GPT实验项目使用本地化的GPT大模型与您的数据和环境进行交互无数据泄露风险100% 私密100% 安全
"""
learn_more_markdown = """
### Licence
@@ -396,7 +407,7 @@ def build_single_model_ui():
max_output_tokens = gr.Slider(
minimum=0,
maximum=1024,
value=1024,
value=512,
step=64,
interactive=True,
label="最大输出Token数",
@@ -412,7 +423,8 @@ def build_single_model_ui():
choices=dbs,
value=dbs[0] if len(models) > 0 else "",
interactive=True,
show_label=True).style(container=False)
show_label=True,
).style(container=False)
sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
@@ -420,7 +432,9 @@ def build_single_model_ui():
tab_qa = gr.TabItem("知识问答", elem_id="QA")
with tab_qa:
mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
mode = gr.Radio(
["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话"
)
vs_setting = gr.Accordion("配置知识库", open=False)
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
with vs_setting:
@@ -429,18 +443,22 @@ def build_single_model_ui():
with gr.Column() as doc2vec:
gr.Markdown("向知识库中添加文件")
with gr.Tab("上传文件"):
files = gr.File(label="添加文件",
file_types=[".txt", ".md", ".docx", ".pdf"],
file_count="multiple",
show_label=False
)
files = gr.File(
label="添加文件",
file_types=[".txt", ".md", ".docx", ".pdf"],
file_count="multiple",
allow_flagged_uploads=True,
show_label=False,
)
load_file_button = gr.Button("上传并加载到知识库")
with gr.Tab("上传文件夹"):
folder_files = gr.File(label="添加文件夹",
accept_multiple_files=True,
file_count="directory",
show_label=False)
folder_files = gr.File(
label="添加文件夹",
accept_multiple_files=True,
file_count="directory",
show_label=False,
)
load_folder_button = gr.Button("上传并加载到知识库")
with gr.Blocks():
@@ -481,28 +499,32 @@ def build_single_model_ui():
).then(
http_bot,
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list
[state, chatbot] + btn_list,
)
vs_add.click(
fn=save_vs_name, show_progress=True, inputs=[vs_name], outputs=[vs_name]
)
load_file_button.click(
fn=knowledge_embedding_store,
show_progress=True,
inputs=[vs_name, files],
outputs=[vs_name],
)
load_folder_button.click(
fn=knowledge_embedding_store,
show_progress=True,
inputs=[vs_name, folder_files],
outputs=[vs_name],
)
vs_add.click(fn=save_vs_name, show_progress=True,
inputs=[vs_name],
outputs=[vs_name])
load_file_button.click(fn=knowledge_embedding_store,
show_progress=True,
inputs=[vs_name, files],
outputs=[vs_name])
load_folder_button.click(fn=knowledge_embedding_store,
show_progress=True,
inputs=[vs_name, folder_files],
outputs=[vs_name])
return state, chatbot, textbox, send_btn, button_row, parameter_row
def build_webdemo():
with gr.Blocks(
title="数据库智能助手",
# theme=gr.themes.Base(),
theme=gr.themes.Default(),
css=block_css,
title="数据库智能助手",
# theme=gr.themes.Base(),
theme=gr.themes.Default(),
css=block_css,
) as demo:
url_params = gr.JSON(visible=False)
(
@@ -544,15 +566,21 @@ def knowledge_embedding_store(vs_id, files):
os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id))
for file in files:
filename = os.path.split(file.name)[-1]
shutil.move(file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename))
shutil.move(
file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)
)
knowledge_embedding_client = KnowledgeEmbedding(
file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
model_name=LLM_MODEL_CONFIG["sentence-transforms"],
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config={
"vector_store_name": vector_store_name["vs_name"],
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
},
)
knowledge_embedding_client.knowledge_embedding()
logger.info("knowledge embedding success")
return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")
@@ -596,5 +624,8 @@ if __name__ == "__main__":
demo.queue(
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
).launch(
server_name=args.host, server_port=args.port, share=args.share, max_threads=200,
server_name=args.host,
server_port=args.port,
share=args.share,
max_threads=200,
)

View File

@@ -5,10 +5,12 @@
import abc
from typing import Any
class Singleton(abc.ABCMeta, type):
""" Singleton metaclass for ensuring only one instance of a class"""
"""Singleton metaclass for ensuring only one instance of a class"""
_instances = {}
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
"""Call method for the singleton metaclass"""
if cls not in cls._instances:
@@ -18,4 +20,5 @@ class Singleton(abc.ABCMeta, type):
class AbstractSingleton(abc.ABC, metaclass=Singleton):
"""Abstract singleton class for ensuring only one instance of a class"""
pass
pass
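
The Singleton metaclass above caches one instance per class. A self-contained check of that behaviour, using a hypothetical DemoConfig class:

import abc
from typing import Any

class Singleton(abc.ABCMeta, type):
    """Same pattern as above: cache and reuse a single instance per class."""
    _instances = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]

class DemoConfig(metaclass=Singleton):
    def __init__(self) -> None:
        self.debug = False

assert DemoConfig() is DemoConfig()  # the second call returns the cached instance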

View File

@@ -1,8 +1,3 @@
from pilot.source_embedding.source_embedding import SourceEmbedding
from pilot.source_embedding.source_embedding import register
from pilot.source_embedding.source_embedding import SourceEmbedding, register
__all__ = [
"SourceEmbedding",
"register"
]
__all__ = ["SourceEmbedding", "register"]

View File

@@ -0,0 +1,55 @@
import re
from typing import List
from langchain.text_splitter import CharacterTextSplitter
class CHNDocumentSplitter(CharacterTextSplitter):
def __init__(self, pdf: bool = False, sentence_size: int = None, **kwargs):
super().__init__(**kwargs)
self.pdf = pdf
self.sentence_size = sentence_size
def split_text(self, text: str) -> List[str]:
if self.pdf:
text = re.sub(r"\n{3,}", r"\n", text)
text = re.sub("\s", " ", text)
text = re.sub("\n\n", "", text)
text = re.sub(r"([;.!?。!?\?])([^”’])", r"\1\n\2", text)
text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)
text = re.sub(r'(\{2})([^"’”」』])', r"\1\n\2", text)
text = re.sub(r'([;!?。!?\?]["’”」』]{0,2})([^;!?,。!?\?])', r"\1\n\2", text)
text = text.rstrip()
ls = [i for i in text.split("\n") if i]
for ele in ls:
if len(ele) > self.sentence_size:
ele1 = re.sub(r'([,.]["’”」』]{0,2})([^,.])', r"\1\n\2", ele)
ele1_ls = ele1.split("\n")
for ele_ele1 in ele1_ls:
if len(ele_ele1) > self.sentence_size:
ele_ele2 = re.sub(
r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r"\1\n\2", ele_ele1
)
ele2_ls = ele_ele2.split("\n")
for ele_ele2 in ele2_ls:
if len(ele_ele2) > self.sentence_size:
ele_ele3 = re.sub(
'( ["’”」』]{0,2})([^ ])', r"\1\n\2", ele_ele2
)
ele2_id = ele2_ls.index(ele_ele2)
ele2_ls = (
ele2_ls[:ele2_id]
+ [i for i in ele_ele3.split("\n") if i]
+ ele2_ls[ele2_id + 1 :]
)
ele_id = ele1_ls.index(ele_ele1)
ele1_ls = (
ele1_ls[:ele_id]
+ [i for i in ele2_ls if i]
+ ele1_ls[ele_id + 1 :]
)
id = ls.index(ele)
ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1 :]
return ls
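
CHNDocumentSplitter above splits Chinese text on sentence-ending punctuation and then recursively re-splits any piece longer than sentence_size. A small usage sketch (sentence_size=100 is an arbitrary example, and langchain must be installed):

from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter

splitter = CHNDocumentSplitter(pdf=False, sentence_size=100)
chunks = splitter.split_text("DB-GPT支持本地化部署。它可以进行SQL生成、SQL诊断、数据库知识问答等工作!")
print(chunks)  # one chunk per sentence, split on 。/!/? style punctuation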

View File

@@ -1,14 +1,21 @@
from typing import List, Optional, Dict
from pilot.source_embedding import SourceEmbedding, register
from typing import Dict, List, Optional
from langchain.document_loaders import CSVLoader
from langchain.schema import Document
from pilot.source_embedding import SourceEmbedding, register
class CSVEmbedding(SourceEmbedding):
"""csv embedding for read csv document."""
def __init__(self, file_path, model_name, vector_store_config, embedding_args: Optional[Dict] = None):
def __init__(
self,
file_path,
model_name,
vector_store_config,
embedding_args: Optional[Dict] = None,
):
"""Initialize with csv path."""
super().__init__(file_path, model_name, vector_store_config)
self.file_path = file_path
@@ -29,6 +36,3 @@ class CSVEmbedding(SourceEmbedding):
documents[i].page_content = d.page_content.replace("\n", "")
i += 1
return documents

View File

@@ -1,34 +1,126 @@
import os
import markdown
from bs4 import BeautifulSoup
from langchain.document_loaders import PyPDFLoader, TextLoader, markdown
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config
from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
from pilot.source_embedding.csv_embedding import CSVEmbedding
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
from pilot.source_embedding.pdf_embedding import PDFEmbedding
from pilot.vector_store.connector import VectorStoreConnector
CFG = Config()
class KnowledgeEmbedding:
def __init__(self, file_path, model_name, vector_store_config):
def __init__(self, file_path, model_name, vector_store_config, local_persist=True):
"""Initialize with Loader url, model_name, vector_store_config"""
self.file_path = file_path
self.model_name = model_name
self.vector_store_config = vector_store_config
self.vector_store_type = "default"
self.knowledge_embedding_client = self.init_knowledge_embedding()
self.file_type = "default"
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
self.vector_store_config["embeddings"] = self.embeddings
self.local_persist = local_persist
if not self.local_persist:
self.knowledge_embedding_client = self.init_knowledge_embedding()
def knowledge_embedding(self):
self.knowledge_embedding_client.source_embedding()
def knowledge_embedding_batch(self):
self.knowledge_embedding_client.batch_embedding()
def init_knowledge_embedding(self):
if self.file_path.endswith(".pdf"):
embedding = PDFEmbedding(file_path=self.file_path, model_name=self.model_name,
vector_store_config=self.vector_store_config)
embedding = PDFEmbedding(
file_path=self.file_path,
model_name=self.model_name,
vector_store_config=self.vector_store_config,
)
elif self.file_path.endswith(".md"):
embedding = MarkdownEmbedding(file_path=self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config)
embedding = MarkdownEmbedding(
file_path=self.file_path,
model_name=self.model_name,
vector_store_config=self.vector_store_config,
)
elif self.file_path.endswith(".csv"):
embedding = CSVEmbedding(file_path=self.file_path, model_name=self.model_name,
vector_store_config=self.vector_store_config)
elif self.vector_store_type == "default":
embedding = MarkdownEmbedding(file_path=self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config)
embedding = CSVEmbedding(
file_path=self.file_path,
model_name=self.model_name,
vector_store_config=self.vector_store_config,
)
elif self.file_type == "default":
embedding = MarkdownEmbedding(
file_path=self.file_path,
model_name=self.model_name,
vector_store_config=self.vector_store_config,
)
return embedding
def similar_search(self, text, topk):
return self.knowledge_embedding_client.similar_search(text, topk)
return self.knowledge_embedding_client.similar_search(text, topk)
def knowledge_persist_initialization(self, append_mode):
documents = self._load_knownlege(self.file_path)
self.vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, self.vector_store_config
)
self.vector_client.load_document(documents)
return self.vector_client
def _load_knownlege(self, path):
docments = []
for root, _, files in os.walk(path, topdown=False):
for file in files:
filename = os.path.join(root, file)
docs = self._load_file(filename)
new_docs = []
for doc in docs:
doc.metadata = {
"source": doc.metadata["source"].replace(DATASETS_DIR, "")
}
print("doc is embedding...", doc.metadata)
new_docs.append(doc)
docments += new_docs
return docments
def _load_file(self, filename):
if filename.lower().endswith(".md"):
loader = TextLoader(filename)
text_splitter = CHNDocumentSplitter(
pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
)
docs = loader.load_and_split(text_splitter)
i = 0
for d in docs:
content = markdown.markdown(d.page_content)
soup = BeautifulSoup(content, "html.parser")
for tag in soup(["!doctype", "meta", "i.fa"]):
tag.extract()
docs[i].page_content = soup.get_text()
docs[i].page_content = docs[i].page_content.replace("\n", " ")
i += 1
elif filename.lower().endswith(".pdf"):
loader = PyPDFLoader(filename)
textsplitter = CHNDocumentSplitter(
pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
)
docs = loader.load_and_split(textsplitter)
i = 0
for d in docs:
docs[i].page_content = d.page_content.replace("\n", " ").replace(
"<EFBFBD>", ""
)
i += 1
else:
loader = TextLoader(filename)
text_splitor = CHNDocumentSplitter(sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE)
docs = loader.load_and_split(text_splitor)
return docs

View File

@@ -1,13 +1,16 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from typing import List
import markdown
from bs4 import BeautifulSoup
from langchain.document_loaders import TextLoader
from langchain.schema import Document
import markdown
from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
from pilot.source_embedding import SourceEmbedding, register
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
class MarkdownEmbedding(SourceEmbedding):
@@ -24,20 +27,42 @@ class MarkdownEmbedding(SourceEmbedding):
def read(self):
"""Load from markdown path."""
loader = TextLoader(self.file_path)
return loader.load()
text_splitter = CHNDocumentSplitter(
pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
)
return loader.load_and_split(text_splitter)
@register
def read_batch(self):
"""Load from markdown path."""
docments = []
for root, _, files in os.walk(self.file_path, topdown=False):
for file in files:
filename = os.path.join(root, file)
loader = TextLoader(filename)
# text_splitor = CHNDocumentSplitter(chunk_size=1000, chunk_overlap=20, length_function=len)
# docs = loader.load_and_split()
docs = loader.load()
# Update the document metadata
new_docs = []
for doc in docs:
doc.metadata = {
"source": doc.metadata["source"].replace(self.file_path, "")
}
print("doc is embedding ... ", doc.metadata)
new_docs.append(doc)
docments += new_docs
return docments
@register
def data_process(self, documents: List[Document]):
i = 0
for d in documents:
content = markdown.markdown(d.page_content)
soup = BeautifulSoup(content, 'html.parser')
for tag in soup(['!doctype', 'meta', 'i.fa']):
soup = BeautifulSoup(content, "html.parser")
for tag in soup(["!doctype", "meta", "i.fa"]):
tag.extract()
documents[i].page_content = soup.get_text()
documents[i].page_content = documents[i].page_content.replace(" ", "").replace("\n", " ")
documents[i].page_content = documents[i].page_content.replace("\n", " ")
i += 1
return documents

View File

@@ -5,7 +5,9 @@ from typing import List
from langchain.document_loaders import PyPDFLoader
from langchain.schema import Document
from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
from pilot.source_embedding import SourceEmbedding, register
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
class PDFEmbedding(SourceEmbedding):
@@ -17,22 +19,21 @@ class PDFEmbedding(SourceEmbedding):
self.file_path = file_path
self.model_name = model_name
self.vector_store_config = vector_store_config
# SourceEmbedding(file_path =file_path, );
SourceEmbedding(file_path, model_name, vector_store_config)
@register
def read(self):
"""Load from pdf path."""
# loader = UnstructuredPaddlePDFLoader(self.file_path)
loader = PyPDFLoader(self.file_path)
return loader.load()
textsplitter = CHNDocumentSplitter(
pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
)
return loader.load_and_split(textsplitter)
@register
def data_process(self, documents: List[Document]):
i = 0
for d in documents:
documents[i].page_content = d.page_content.replace(" ", "").replace("\n", "")
documents[i].page_content = d.page_content.replace("\n", "")
i += 1
return documents

View File

@@ -0,0 +1,55 @@
"""Loader that loads image files."""
import os
from typing import List
import fitz
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from paddleocr import PaddleOCR
class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):
"""Loader that uses unstructured to load image files, such as PNGs and JPGs."""
def _get_elements(self) -> List:
def pdf_ocr_txt(filepath, dir_path="tmp_files"):
full_dir_path = os.path.join(os.path.dirname(filepath), dir_path)
if not os.path.exists(full_dir_path):
os.makedirs(full_dir_path)
filename = os.path.split(filepath)[-1]
ocr = PaddleOCR(lang="ch", use_gpu=False, show_log=False)
doc = fitz.open(filepath)
txt_file_path = os.path.join(full_dir_path, "%s.txt" % (filename))
img_name = os.path.join(full_dir_path, ".tmp.png")
with open(txt_file_path, "w", encoding="utf-8") as fout:
for i in range(doc.page_count):
page = doc[i]
text = page.get_text("")
fout.write(text)
fout.write("\n")
img_list = page.get_images()
for img in img_list:
pix = fitz.Pixmap(doc, img[0])
pix.save(img_name)
result = ocr.ocr(img_name)
ocr_result = [i[1][0] for line in result for i in line]
fout.write("\n".join(ocr_result))
os.remove(img_name)
return txt_file_path
txt_file_path = pdf_ocr_txt(self.file_path)
from unstructured.partition.text import partition_text
return partition_text(filename=txt_file_path, **self.unstructured_kwargs)
if __name__ == "__main__":
filepath = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test.pdf"
)
loader = UnstructuredPaddlePDFLoader(filepath, mode="elements")
docs = loader.load()
for doc in docs:
print(doc)

View File

@@ -50,7 +50,7 @@
#
# # text_embeddings = Text2Vectors()
# mivuls = MilvusStore(cfg={"url": "127.0.0.1", "port": "19530", "alias": "default", "table_name": "test_k"})
#
#
# mivuls.insert(["textc","tezt2"])
# print("success")
# ct
@@ -58,4 +58,4 @@
# # docs,
# # embedding=embeddings,
# # connection_args={"host": "127.0.0.1", "port": "19530", "alias": "default"}
# # )
# # )

View File

@@ -1,14 +1,15 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from typing import List, Optional, Dict
from pilot.configs.config import Config
from pilot.vector_store.connector import VectorStoreConnector
registered_methods = []
CFG = Config()
def register(method):
@@ -22,21 +23,30 @@ class SourceEmbedding(ABC):
Implementations should implement the method
"""
def __init__(self, file_path, model_name, vector_store_config, embedding_args: Optional[Dict] = None):
def __init__(
self,
file_path,
model_name,
vector_store_config,
embedding_args: Optional[Dict] = None,
):
"""Initialize with Loader url, model_name, vector_store_config"""
self.file_path = file_path
self.model_name = model_name
self.vector_store_config = vector_store_config
self.embedding_args = embedding_args
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
self.vector_store_config["vector_store_name"] + ".vectordb")
self.vector_store_client = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
vector_store_config["embeddings"] = self.embeddings
self.vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, vector_store_config
)
@abstractmethod
@register
def read(self) -> List[ABC]:
"""read datasource into document objects."""
@register
def data_process(self, text):
"""pre process data."""
@@ -54,25 +64,33 @@ class SourceEmbedding(ABC):
@register
def index_to_store(self, docs):
"""index to vector store"""
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
self.vector_store_config["vector_store_name"] + ".vectordb")
self.vector_store = Chroma.from_documents(docs, self.embeddings, persist_directory=persist_dir)
self.vector_store.persist()
self.vector_client.load_document(docs)
@register
def similar_search(self, doc, topk):
"""vector store similarity_search"""
return self.vector_store_client.similarity_search(doc, topk)
return self.vector_client.similar_search(doc, topk)
def source_embedding(self):
if 'read' in registered_methods:
if "read" in registered_methods:
text = self.read()
if 'data_process' in registered_methods:
if "data_process" in registered_methods:
text = self.data_process(text)
if 'text_split' in registered_methods:
if "text_split" in registered_methods:
self.text_split(text)
if 'text_to_vector' in registered_methods:
if "text_to_vector" in registered_methods:
self.text_to_vector(text)
if 'index_to_store' in registered_methods:
if "index_to_store" in registered_methods:
self.index_to_store(text)
def batch_embedding(self):
if "read_batch" in registered_methods:
text = self.read_batch()
if "data_process" in registered_methods:
text = self.data_process(text)
if "text_split" in registered_methods:
self.text_split(text)
if "text_to_vector" in registered_methods:
self.text_to_vector(text)
if "index_to_store" in registered_methods:
self.index_to_store(text)
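
New sources follow the same pattern as the loaders above; a purely hypothetical CSV source under the connector-backed base class could look like this (CSVLoader comes from langchain, everything else is illustrative):

from typing import List

from langchain.document_loaders import CSVLoader
from langchain.schema import Document

from pilot.source_embedding import SourceEmbedding, register


class CSVEmbedding(SourceEmbedding):
    """Hypothetical CSV source; not part of this change."""

    @register
    def read(self) -> List[Document]:
        """Load each CSV row as a document."""
        return CSVLoader(self.file_path).load()

    @register
    def data_process(self, documents: List[Document]):
        """Normalize whitespace before indexing."""
        for doc in documents:
            doc.page_content = doc.page_content.replace("\n", " ")
        return documents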

View File

@@ -1,10 +1,11 @@
from typing import List
from pilot.source_embedding import SourceEmbedding, register
from bs4 import BeautifulSoup
from langchain.document_loaders import WebBaseLoader
from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter
from pilot.source_embedding import SourceEmbedding, register
class URLEmbedding(SourceEmbedding):
@@ -20,19 +21,19 @@ class URLEmbedding(SourceEmbedding):
def read(self):
"""Load from url path."""
loader = WebBaseLoader(web_path=self.file_path)
return loader.load()
text_splitor = CharacterTextSplitter(
chunk_size=1000, chunk_overlap=20, length_function=len
)
return loader.load_and_split(text_splitor)
@register
def data_process(self, documents: List[Document]):
i = 0
for d in documents:
content = d.page_content.replace("\n", "")
soup = BeautifulSoup(content, 'html.parser')
for tag in soup(['!doctype', 'meta']):
soup = BeautifulSoup(content, "html.parser")
for tag in soup(["!doctype", "meta"]):
tag.extract()
documents[i].page_content = soup.get_text()
i += 1
return documents

View File

@@ -1,27 +1,28 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import torch
import datetime
import logging
import logging.handlers
import os
import sys
import requests
import torch
from pilot.configs.model_config import LOGDIR
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
server_error_msg = (
"**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
)
handler = None
def get_gpu_memory(max_gpus=None):
gpu_memory = []
num_gpus = (
torch.cuda.device_count()
if max_gpus is None
if max_gpus is None
else min(max_gpus, torch.cuda.device_count())
)
@@ -29,12 +30,11 @@ def get_gpu_memory(max_gpus=None):
with torch.cuda.device(gpu_id):
device = torch.cuda.current_device()
gpu_properties = torch.cuda.get_device_properties(device)
total_memory = gpu_properties.total_memory / (1024 ** 3)
allocated_memory = torch.cuda.memory_allocated() / (1024 ** 3)
total_memory = gpu_properties.total_memory / (1024**3)
allocated_memory = torch.cuda.memory_allocated() / (1024**3)
available_memory = total_memory - allocated_memory
gpu_memory.append(available_memory)
return gpu_memory
return gpu_memory
def build_logger(logger_name, logger_filename):
@@ -47,7 +47,7 @@ def build_logger(logger_name, logger_filename):
# Set the format of root handlers
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO, encoding='utf-8')
logging.basicConfig(level=logging.INFO, encoding="utf-8")
logging.getLogger().handlers[0].setFormatter(formatter)
# Redirect stdout and stderr to loggers
@@ -70,7 +70,8 @@ def build_logger(logger_name, logger_filename):
os.makedirs(LOGDIR, exist_ok=True)
filename = os.path.join(LOGDIR, logger_filename)
handler = logging.handlers.TimedRotatingFileHandler(
filename, when='D', utc=True)
filename, when="D", utc=True
)
handler.setFormatter(formatter)
for name, item in logging.root.manager.loggerDict.items():
@@ -84,35 +85,36 @@ class StreamToLogger(object):
"""
Fake file-like stream object that redirects writes to a logger instance.
"""
def __init__(self, logger, log_level=logging.INFO):
self.terminal = sys.stdout
self.logger = logger
self.log_level = log_level
self.linebuf = ''
self.linebuf = ""
def __getattr__(self, attr):
return getattr(self.terminal, attr)
def write(self, buf):
temp_linebuf = self.linebuf + buf
self.linebuf = ''
self.linebuf = ""
for line in temp_linebuf.splitlines(True):
# From the io.TextIOWrapper docs:
# On output, if newline is None, any '\n' characters written
# are translated to the system default line separator.
# By default sys.stdout.write() expects '\n' newlines and then
# translates them so this is still cross platform.
if line[-1] == '\n':
encoded_message = line.encode('utf-8', 'ignore').decode('utf-8')
if line[-1] == "\n":
encoded_message = line.encode("utf-8", "ignore").decode("utf-8")
self.logger.log(self.log_level, encoded_message.rstrip())
else:
self.linebuf += line
def flush(self):
if self.linebuf != '':
encoded_message = self.linebuf.encode('utf-8', 'ignore').decode('utf-8')
if self.linebuf != "":
encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
self.logger.log(self.log_level, encoded_message.rstrip())
self.linebuf = ''
self.linebuf = ""
def disable_torch_init():
@@ -120,6 +122,7 @@ def disable_torch_init():
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
@@ -128,4 +131,3 @@ def pretty_print_semaphore(semaphore):
if semaphore is None:
return "None"
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"

View File

@@ -0,0 +1,32 @@
import os
from langchain.vectorstores import Chroma
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.logs import logger
from pilot.vector_store.vector_store_base import VectorStoreBase
class ChromaStore(VectorStoreBase):
"""chroma database"""
def __init__(self, ctx: {}) -> None:
self.ctx = ctx
self.embeddings = ctx["embeddings"]
self.persist_dir = os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, ctx["vector_store_name"] + ".vectordb"
)
self.vector_store_client = Chroma(
persist_directory=self.persist_dir, embedding_function=self.embeddings
)
def similar_search(self, text, topk) -> None:
logger.info("ChromaStore similar search")
return self.vector_store_client.similarity_search(text, topk)
def load_document(self, documents):
logger.info("ChromaStore load document")
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
self.vector_store_client.add_texts(texts=texts, metadatas=metadatas)
self.vector_store_client.persist()
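
A minimal sketch of using ChromaStore on its own (the embedding model path and store name are placeholders; persistence goes to KNOWLEDGE_UPLOAD_ROOT_PATH/<name>.vectordb):

from langchain.embeddings import HuggingFaceEmbeddings

from pilot.vector_store.chroma_store import ChromaStore

ctx = {
    "vector_store_name": "example_kb",
    "embeddings": HuggingFaceEmbeddings(model_name="/models/text2vec-large-chinese"),
}
store = ChromaStore(ctx)
store.load_document(documents)   # documents: a list of langchain Document objects
hits = store.similar_search("what is db-gpt?", 4)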

View File

@@ -0,0 +1,19 @@
from pilot.vector_store.chroma_store import ChromaStore
from pilot.vector_store.milvus_store import MilvusStore
connector = {"Chroma": ChromaStore, "Milvus": MilvusStore}
class VectorStoreConnector:
"""vector store connector, can connect different vector db provided load document api and similar search api"""
def __init__(self, vector_store_type, ctx: {}) -> None:
self.ctx = ctx
self.connector_class = connector[vector_store_type]
self.client = self.connector_class(ctx)
def load_document(self, docs):
self.client.load_document(docs)
def similar_search(self, docs, topk):
return self.client.similar_search(docs, topk)
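
Callers normally go through this facade rather than a concrete store; a hedged sketch, reusing the ctx dict from the ChromaStore example above:

from pilot.vector_store.connector import VectorStoreConnector

client = VectorStoreConnector("Chroma", ctx)   # "Milvus" works too, given MILVUS_* in .env
client.load_document(documents)                # documents: a list of langchain Document objects
related = client.similar_search("how is the knowledge base built?", 5)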

View File

@@ -3,14 +3,16 @@
import os
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from pilot.configs.model_config import DATASETS_DIR, VECTORE_PATH
from pilot.model.vicuna_llm import VicunaEmbeddingLLM
from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR
from langchain.embeddings import HuggingFaceEmbeddings
embeddings = VicunaEmbeddingLLM()
def knownledge_tovec(filename):
with open(filename, "r") as f:
knownledge = f.read()
@@ -22,48 +24,64 @@ def knownledge_tovec(filename):
)
return docsearch
def knownledge_tovec_st(filename):
""" Use sentence transformers to embedding the document.
https://github.com/UKPLab/sentence-transformers
"""Use sentence transformers to embedding the document.
https://github.com/UKPLab/sentence-transformers
"""
from pilot.configs.model_config import LLM_MODEL_CONFIG
embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
embeddings = HuggingFaceEmbeddings(
model_name=LLM_MODEL_CONFIG["sentence-transforms"]
)
with open(filename, "r") as f:
knownledge = f.read()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_text(knownledge)
docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])
docsearch = Chroma.from_texts(
texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))]
)
return docsearch
def load_knownledge_from_doc():
"""Loader Knownledge from current datasets
# TODO if the vector store is exists, just use it.
# TODO if the vector store is exists, just use it.
"""
if not os.path.exists(DATASETS_DIR):
print("Not Exists Local DataSets, We will answers the Question use model default.")
print(
"No local datasets found; questions will be answered with the model's default knowledge."
)
from pilot.configs.model_config import LLM_MODEL_CONFIG
embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
embeddings = HuggingFaceEmbeddings(
model_name=LLM_MODEL_CONFIG["sentence-transforms"]
)
files = os.listdir(DATASETS_DIR)
for file in files:
if not os.path.isdir(file):
if not os.path.isdir(file):
filename = os.path.join(DATASETS_DIR, file)
with open(filename, "r") as f:
knownledge = f.read()
knownledge = f.read()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_text(knownledge)
docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))],
persist_directory=os.path.join(VECTORE_PATH, ".vectore"))
docsearch = Chroma.from_texts(
texts,
embeddings,
metadatas=[{"source": str(i)} for i in range(len(texts))],
persist_directory=os.path.join(VECTORE_PATH, ".vectore"),
)
return docsearch
def get_vector_storelist():
if not os.path.exists(VECTORE_PATH):
return []
return os.listdir(VECTORE_PATH)
return os.listdir(VECTORE_PATH)

View File

@@ -2,58 +2,72 @@
# -*- coding: utf-8 -*-
import os
import copy
from typing import Optional, List, Dict
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader, TextLoader
from langchain.chains import VectorDBQA
from langchain.document_loaders import (
TextLoader,
UnstructuredFileLoader,
UnstructuredPDFLoader,
)
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from pilot.configs.model_config import (
DATASETS_DIR,
LLM_MODEL_CONFIG,
VECTOR_SEARCH_TOP_K,
VECTORE_PATH,
)
class KnownLedge2Vector:
"""KnownLedge2Vector class is order to load document to vector
"""KnownLedge2Vector class is order to load document to vector
and persist to vector store.
Args:
Args:
- model_name
Usage:
k2v = KnownLedge2Vector()
persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
print(persist_dir)
for s, dc in k2v.query("what is oceanbase?"):
print(s, dc.page_content, dc.metadata)
"""
embeddings: object = None
embeddings: object = None
model_name = LLM_MODEL_CONFIG["sentence-transforms"]
top_k: int = VECTOR_SEARCH_TOP_K
def __init__(self, model_name=None) -> None:
if not model_name:
# use default embedding model
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
def init_vector_store(self):
persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
print("Vector store Persist address is: ", persist_dir)
if os.path.exists(persist_dir):
# Loader from local file.
print("Loader data from local persist vector file...")
vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
vector_store = Chroma(
persist_directory=persist_dir, embedding_function=self.embeddings
)
# vector_store.add_documents(documents=documents)
else:
documents = self.load_knownlege()
# reinit
vector_store = Chroma.from_documents(documents=documents,
embedding=self.embeddings,
persist_directory=persist_dir)
# reinit
vector_store = Chroma.from_documents(
documents=documents,
embedding=self.embeddings,
persist_directory=persist_dir,
)
vector_store.persist()
return vector_store
return vector_store
def load_knownlege(self):
docments = []
@@ -61,10 +75,12 @@ class KnownLedge2Vector:
for file in files:
filename = os.path.join(root, file)
docs = self._load_file(filename)
# update metadata.
new_docs = []
# update metadata.
new_docs = []
for doc in docs:
doc.metadata = {"source": doc.metadata["source"].replace(DATASETS_DIR, "")}
doc.metadata = {
"source": doc.metadata["source"].replace(DATASETS_DIR, "")
}
print("Documents to vector running, please wait...", doc.metadata)
new_docs.append(doc)
docments += new_docs
@@ -73,7 +89,7 @@ class KnownLedge2Vector:
def _load_file(self, filename):
# Loader file
if filename.lower().endswith(".pdf"):
loader = UnstructuredFileLoader(filename)
loader = UnstructuredFileLoader(filename)
text_splitor = CharacterTextSplitter()
docs = loader.load_and_split(text_splitor)
else:
@@ -86,13 +102,10 @@ class KnownLedge2Vector:
"""Load data from url address"""
pass
def query(self, q):
"""Query similar doc from Vector """
"""Query similar doc from Vector"""
vector_store = self.init_vector_store()
docs = vector_store.similarity_search_with_score(q, k=self.top_k)
for doc in docs:
dc, s = doc
yield s, dc

View File

@@ -1,33 +1,55 @@
from typing import Any, Iterable, List, Optional, Tuple
from pymilvus import DataType, FieldSchema, CollectionSchema, connections, Collection
from langchain.docstore.document import Document
from pymilvus import Collection, DataType, connections
from pilot.configs.config import Config
from pilot.vector_store.vector_store_base import VectorStoreBase
CFG = Config()
class MilvusStore(VectorStoreBase):
def __init__(self, cfg: {}) -> None:
"""Construct a milvus memory storage connection.
"""Milvus database"""
def __init__(self, ctx: {}) -> None:
"""init a milvus storage connection.
Args:
cfg (Config): Auto-GPT global config.
ctx ({}): MilvusStore global config.
"""
# self.configure(cfg)
connect_kwargs = {}
self.uri = None
self.uri = cfg["url"]
self.port = cfg["port"]
self.username = cfg.get("username", None)
self.password = cfg.get("password", None)
self.collection_name = cfg["table_name"]
self.password = cfg.get("secure", None)
self.uri = CFG.MILVUS_URL
self.port = CFG.MILVUS_PORT
self.username = CFG.MILVUS_USERNAME
self.password = CFG.MILVUS_PASSWORD
self.collection_name = ctx.get("vector_store_name", None)
self.secure = ctx.get("secure", None)
self.embedding = ctx.get("embeddings", None)
self.fields = []
# use HNSW by default.
self.index_params = {
"metric_type": "IP",
"metric_type": "L2",
"index_type": "HNSW",
"params": {"M": 8, "efConstruction": 64},
}
# use HNSW by default.
self.index_params_map = {
"IVF_FLAT": {"params": {"nprobe": 10}},
"IVF_SQ8": {"params": {"nprobe": 10}},
"IVF_PQ": {"params": {"nprobe": 10}},
"HNSW": {"params": {"ef": 10}},
"RHNSW_FLAT": {"params": {"ef": 10}},
"RHNSW_SQ": {"params": {"ef": 10}},
"RHNSW_PQ": {"params": {"ef": 10}},
"IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
"ANNOY": {"params": {"search_k": 10}},
}
self.text_field = "content"
if (self.username is None) != (self.password is None):
raise ValueError(
@@ -38,54 +60,263 @@ class MilvusStore(VectorStoreBase):
connect_kwargs["password"] = self.password
connections.connect(
**connect_kwargs,
host=self.uri or "127.0.0.1",
port=self.port or "19530",
alias="default"
# secure=self.secure,
)
self.init_schema()
def init_schema(self) -> None:
"""Initialize collection in milvus database."""
fields = [
FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=384),
FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
]
# create collection if not exist and load it.
self.schema = CollectionSchema(fields, "db-gpt memory storage")
self.collection = Collection(self.collection_name, self.schema)
self.index_params = {
"metric_type": "IP",
"index_type": "HNSW",
"params": {"M": 8, "efConstruction": 64},
}
# create index if not exist.
if not self.collection.has_index():
self.collection.release()
self.collection.create_index(
"vector",
self.index_params,
index_name="vector",
def init_schema_and_load(self, vector_name, documents):
"""Create a Milvus collection, indexes it with HNSW, load document.
Args:
vector_name (str): your collection name.
documents (List[str]): Text to insert.
Returns:
VectorStore: The MilvusStore vector store.
"""
try:
from pymilvus import (
Collection,
CollectionSchema,
DataType,
FieldSchema,
connections,
)
self.collection.load()
from pymilvus.orm.types import infer_dtype_bydata
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
if not connections.has_connection("default"):
connections.connect(
host=self.uri or "127.0.0.1",
port=self.port or "19530",
alias="default"
# secure=self.secure,
)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
embeddings = self.embedding.embed_query(texts[0])
dim = len(embeddings)
# Generate unique names
primary_field = "pk_id"
vector_field = "vector"
text_field = "content"
self.text_field = text_field
collection_name = vector_name
fields = []
# Determine metadata schema
# if metadatas:
# # Check if all metadata keys line up
# key = metadatas[0].keys()
# for x in metadatas:
# if key != x.keys():
# raise ValueError(
# "Mismatched metadata. "
# "Make sure all metadata has the same keys and datatype."
# )
# # Create FieldSchema for each entry in singular metadata.
# for key, value in metadatas[0].items():
# # Infer the corresponding datatype of the metadata
# dtype = infer_dtype_bydata(value)
# if dtype == DataType.UNKNOWN:
# raise ValueError(f"Unrecognized datatype for {key}.")
# elif dtype == DataType.VARCHAR:
# # Find out max length text based metadata
# max_length = 0
# for subvalues in metadatas:
# max_length = max(max_length, len(subvalues[key]))
# fields.append(
# FieldSchema(key, DataType.VARCHAR, max_length=max_length + 1)
# )
# else:
# fields.append(FieldSchema(key, dtype))
# def add(self, data) -> str:
# Find out max length of texts
max_length = 0
for y in texts:
max_length = max(max_length, len(y))
# Create the text field
fields.append(
FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 1)
)
# primary key field
fields.append(
FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
)
# vector field
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
# milvus the schema for the collection
schema = CollectionSchema(fields)
# Create the collection
collection = Collection(collection_name, schema)
self.col = collection
# index parameters for the collection
index = self.index_params
# milvus index
collection.create_index(vector_field, index)
schema = collection.schema
for x in schema.fields:
self.fields.append(x.name)
if x.auto_id:
self.fields.remove(x.name)
if x.is_primary:
self.primary_field = x.name
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name
self._add_texts(texts, metadatas)
return self.collection_name
# def init_schema(self) -> None:
# """Initialize collection in milvus database."""
# fields = [
# FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
# FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.model_config["dim"]),
# FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
# ]
#
# # create collection if not exist and load it.
# self.schema = CollectionSchema(fields, "db-gpt memory storage")
# self.collection = Collection(self.collection_name, self.schema)
# self.index_params_map = {
# "IVF_FLAT": {"params": {"nprobe": 10}},
# "IVF_SQ8": {"params": {"nprobe": 10}},
# "IVF_PQ": {"params": {"nprobe": 10}},
# "HNSW": {"params": {"ef": 10}},
# "RHNSW_FLAT": {"params": {"ef": 10}},
# "RHNSW_SQ": {"params": {"ef": 10}},
# "RHNSW_PQ": {"params": {"ef": 10}},
# "IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
# "ANNOY": {"params": {"search_k": 10}},
# }
#
# self.index_params = {
# "metric_type": "IP",
# "index_type": "HNSW",
# "params": {"M": 8, "efConstruction": 64},
# }
# # create index if not exist.
# if not self.collection.has_index():
# self.collection.release()
# self.collection.create_index(
# "vector",
# self.index_params,
# index_name="vector",
# )
# info = self.collection.describe()
# self.collection.load()
# def insert(self, text, model_config) -> str:
# """Add an embedding of data into milvus.
#
# Args:
# data (str): The raw text to construct embedding index.
#
# text (str): The raw text to construct embedding index.
# Returns:
# str: log.
# """
# embedding = get_ada_embedding(data)
# result = self.collection.insert([[embedding], [data]])
# # embedding = get_ada_embedding(data)
# embeddings = HuggingFaceEmbeddings(model_name=self.model_config["model_name"])
# result = self.collection.insert([embeddings.embed_documents(text), text])
# _text = (
# "Inserting data into memory at primary key: "
# f"{result.primary_keys[0]}:\n data: {data}"
# f"{result.primary_keys[0]}:\n data: {text}"
# )
# return _text
# return _text
def _add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
partition_name: Optional[str] = None,
timeout: Optional[int] = None,
) -> List[str]:
"""add text data into Milvus."""
insert_dict: Any = {self.text_field: list(texts)}
try:
insert_dict[self.vector_field] = self.embedding.embed_documents(list(texts))
except NotImplementedError:
insert_dict[self.vector_field] = [
self.embedding.embed_query(x) for x in texts
]
# Collect the metadata into the insert dict.
if len(self.fields) > 2 and metadatas is not None:
for d in metadatas:
for key, value in d.items():
if key in self.fields:
insert_dict.setdefault(key, []).append(value)
# Convert dict to list of lists for insertion
insert_list = [insert_dict[x] for x in self.fields]
# Insert into the collection.
res = self.col.insert(
insert_list, partition_name=partition_name, timeout=timeout
)
# make sure data is searchable.
self.col.flush()
return res.primary_keys
def load_document(self, documents) -> None:
"""load document in vector database."""
self.init_schema_and_load(self.collection_name, documents)
def similar_search(self, text, topk) -> None:
"""similar_search in vector database."""
self.col = Collection(self.collection_name)
schema = self.col.schema
for x in schema.fields:
self.fields.append(x.name)
if x.auto_id:
self.fields.remove(x.name)
if x.is_primary:
self.primary_field = x.name
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name
_, docs_and_scores = self._search(text, topk)
return [doc for doc, _, _ in docs_and_scores]
def _search(
self,
query: str,
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
partition_names: Optional[List[str]] = None,
round_decimal: int = -1,
timeout: Optional[int] = None,
**kwargs: Any,
) -> Tuple[List[float], List[Tuple[Document, Any, Any]]]:
self.col.load()
# use default index params.
if param is None:
index_type = self.col.indexes[0].params["index_type"]
param = self.index_params_map[index_type]
# query text embedding.
data = [self.embedding.embed_query(query)]
# Determine result metadata fields.
output_fields = self.fields[:]
output_fields.remove(self.vector_field)
# milvus search.
res = self.col.search(
data,
self.vector_field,
param,
k,
expr=expr,
output_fields=output_fields,
partition_names=partition_names,
round_decimal=round_decimal,
timeout=timeout,
**kwargs,
)
ret = []
for result in res[0]:
meta = {x: result.entity.get(x) for x in output_fields}
ret.append(
(
Document(page_content=meta.pop(self.text_field), metadata=meta),
result.distance,
result.id,
)
)
return data[0], ret
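
A rough usage sketch for MilvusStore (connection details come from Config, i.e. MILVUS_URL/MILVUS_PORT and optional credentials in .env; the model path and store name below are placeholders):

from langchain.embeddings import HuggingFaceEmbeddings

from pilot.vector_store.milvus_store import MilvusStore

store = MilvusStore(
    {
        "vector_store_name": "example_kb",
        "embeddings": HuggingFaceEmbeddings(model_name="/models/text2vec-large-chinese"),
    }
)
store.load_document(documents)   # builds the collection and HNSW index, then inserts
results = store.similar_search("what is db-gpt?", 4)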

View File

@@ -2,8 +2,14 @@ from abc import ABC, abstractmethod
class VectorStoreBase(ABC):
"""base class for vector store database"""
@abstractmethod
def init_schema(self) -> None:
def load_document(self, documents) -> None:
"""load document in vector database."""
pass
@abstractmethod
def similar_search(self, text, topk) -> None:
"""Initialize schema in vector database."""
pass
pass
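
Any additional backend only needs these two methods plus an entry in the connector map; a purely illustrative in-memory skeleton:

from pilot.vector_store.vector_store_base import VectorStoreBase


class InMemoryStore(VectorStoreBase):
    """Toy backend for illustration only."""

    def __init__(self, ctx: dict) -> None:
        self.docs = []

    def load_document(self, documents) -> None:
        self.docs.extend(documents)

    def similar_search(self, text, topk):
        # naive keyword filter instead of a real vector search
        return [d for d in self.docs if text.lower() in d.page_content.lower()][:topk]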

View File

@@ -42,6 +42,7 @@ tenacity==8.2.2
peft
pycocoevalcap
sentence-transformers
cpm_kernels
umap-learn
notebook
gradio==3.23
@@ -60,6 +61,15 @@ gTTS==2.3.1
langchain
nltk
python-dotenv==1.0.0
pymilvus==2.2.1
vcrpy
chromadb
markdown2
colorama
playsound
distro
pypdf
milvus-cli==0.3.2
# Testing dependencies
pytest
@@ -69,10 +79,4 @@ pytest-benchmark
pytest-cov
pytest-integration
pytest-mock
vcrpy
pytest-recording
chromadb
markdown2
colorama
playsound
distro
pytest-recording

View File

@@ -1,6 +1,6 @@
import pytest
import os
import pytest
from pilot.configs.config import Config
from pilot.plugins import (
@@ -15,10 +15,13 @@ PLUGIN_TEST_ZIP_FILE = "Auto-GPT-Plugin-Test-master.zip"
PLUGIN_TEST_INIT_PY = "Auto-GPT-Plugin-Test-master/src/auto_gpt_vicuna/__init__.py"
PLUGIN_TEST_OPENAI = "https://weathergpt.vercel.app/"
def test_inspect_zip_for_modules():
current_dir = os.getcwd()
print(current_dir)
result = inspect_zip_for_modules(str(f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/{PLUGIN_TEST_ZIP_FILE}"))
result = inspect_zip_for_modules(
str(f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/{PLUGIN_TEST_ZIP_FILE}")
)
assert result == [PLUGIN_TEST_INIT_PY]
@@ -99,6 +102,7 @@ def mock_config_openai_plugin():
class MockConfig:
"""Mock config object for testing the scan_plugins function"""
current_dir = os.getcwd()
plugins_dir = f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/"
plugins_openai = [PLUGIN_TEST_OPENAI]

tools/knowlege_init.py Normal file
View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
from pilot.configs.config import Config
from pilot.configs.model_config import (
DATASETS_DIR,
LLM_MODEL_CONFIG,
VECTOR_SEARCH_TOP_K,
)
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
CFG = Config()
class LocalKnowledgeInit:
embeddings: object = None
model_name = LLM_MODEL_CONFIG["text2vec"]
top_k: int = VECTOR_SEARCH_TOP_K
def __init__(self, vector_store_config) -> None:
self.vector_store_config = vector_store_config
def knowledge_persist(self, file_path, append_mode):
"""knowledge persist"""
kv = KnowledgeEmbedding(
file_path=file_path,
model_name=LLM_MODEL_CONFIG["text2vec"],
vector_store_config=self.vector_store_config,
)
vector_store = kv.knowledge_persist_initialization(append_mode)
return vector_store
def query(self, q):
"""Query similar doc from Vector"""
vector_store = self.init_vector_store()
docs = vector_store.similarity_search_with_score(q, k=self.top_k)
for doc in docs:
dc, s = doc
yield s, dc
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--vector_name", type=str, default="default")
parser.add_argument("--append", type=bool, default=False)
parser.add_argument("--store_type", type=str, default="Chroma")
args = parser.parse_args()
vector_name = args.vector_name
append_mode = args.append
store_type = CFG.VECTOR_STORE_TYPE
vector_store_config = {"vector_store_name": vector_name}
print(vector_store_config)
kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
print("your knowledge embedding success...")