mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-02 01:27:14 +00:00
Merge remote-tracking branch 'origin/main' into dev
# Conflicts: # pilot/configs/config.py # pilot/connections/mysql.py # pilot/conversation.py # pilot/server/webserver.py
This commit is contained in:
@@ -81,3 +81,14 @@ DENYLISTED_PLUGINS=
|
||||
#*******************************************************************#
|
||||
# CHAT_MESSAGES_ENABLED - Enable chat messages (Default: False)
|
||||
# CHAT_MESSAGES_ENABLED=False
|
||||
|
||||
|
||||
#*******************************************************************#
|
||||
#** VECTOR STORE SETTINGS **#
|
||||
#*******************************************************************#
|
||||
VECTOR_STORE_TYPE=Chroma
|
||||
#MILVUS_URL=127.0.0.1
|
||||
#MILVUS_PORT=19530
|
||||
#MILVUS_USERNAME
|
||||
#MILVUS_PASSWORD
|
||||
#MILVUS_SECURE=
|
||||
|
38
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
38
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
title: "[BUG]: "
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**To Reproduce**
|
||||
Steps to reproduce the behavior:
|
||||
1. Go to '...'
|
||||
2. Click on '....'
|
||||
3. Scroll down to '....'
|
||||
4. See error
|
||||
|
||||
**Expected behavior**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Screenshots**
|
||||
If applicable, add screenshots to help explain your problem.
|
||||
|
||||
**Desktop (please complete the following information):**
|
||||
- OS: [e.g. iOS]
|
||||
- Browser [e.g. chrome, safari]
|
||||
- Version [e.g. 22]
|
||||
|
||||
**Smartphone (please complete the following information):**
|
||||
- Device: [e.g. iPhone6]
|
||||
- OS: [e.g. iOS8.1]
|
||||
- Browser [e.g. stock browser, safari]
|
||||
- Version [e.g. 22]
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
10
.github/ISSUE_TEMPLATE/documentation-related.md
vendored
Normal file
10
.github/ISSUE_TEMPLATE/documentation-related.md
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
---
|
||||
name: Documentation Related
|
||||
about: Describe this issue template's purpose here.
|
||||
title: "[Doc]: "
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
|
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: "[Feature]:"
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Is your feature request related to a problem? Please describe.**
|
||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
**Describe the solution you'd like**
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
**Describe alternatives you've considered**
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
**Additional context**
|
||||
Add any other context or screenshots about the feature request here.
|
17
.github/workflows/pylint.yml
vendored
17
.github/workflows/pylint.yml
vendored
@@ -1,6 +1,15 @@
|
||||
name: Pylint
|
||||
|
||||
on: [push]
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.event.number || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build:
|
||||
@@ -17,7 +26,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pylint
|
||||
- name: Analysing the code with pylint
|
||||
pip install -U black isort
|
||||
- name: check the code lint
|
||||
run: |
|
||||
pylint $(git ls-files '*.py')
|
||||
black . --check
|
||||
|
4
.gitignore
vendored
4
.gitignore
vendored
@@ -25,6 +25,7 @@ lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
models
|
||||
var/
|
||||
wheels/
|
||||
models/
|
||||
@@ -138,4 +139,5 @@ dmypy.json
|
||||
.DS_Store
|
||||
logs
|
||||
nltk_data
|
||||
.vectordb
|
||||
.vectordb
|
||||
pilot/data/
|
43
README.md
43
README.md
@@ -29,6 +29,10 @@ Currently, we have released multiple key features, which are listed below to dem
|
||||
- Unified vector storage/indexing of knowledge base
|
||||
- Support for unstructured data such as PDF, Markdown, CSV, and WebURL
|
||||
|
||||
- Milti LLMs Support
|
||||
- Supports multiple large language models, currently supporting Vicuna (7b, 13b), ChatGLM-6b (int4, int8)
|
||||
- TODO: codegen2, codet5p
|
||||
|
||||
|
||||
## Demo
|
||||
|
||||
@@ -103,6 +107,7 @@ As the knowledge base is currently the most significant user demand scenario, we
|
||||
2. Custom addition of knowledge bases
|
||||
3. Various usage scenarios such as constructing knowledge bases through plugin capabilities and web crawling. Users only need to organize the knowledge documents, and they can use our existing capabilities to build the knowledge base required for the large model.
|
||||
|
||||
|
||||
### LLMs Management
|
||||
|
||||
In the underlying large model integration, we have designed an open interface that supports integration with various large models. At the same time, we have a very strict control and evaluation mechanism for the effectiveness of the integrated models. In terms of accuracy, the integrated models need to align with the capability of ChatGPT at a level of 85% or higher. We use higher standards to select models, hoping to save users the cumbersome testing and evaluation process in the process of use.
|
||||
@@ -153,16 +158,6 @@ conda create -n dbgpt_env python=3.10
|
||||
conda activate dbgpt_env
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
Alternatively, you can use the following command:
|
||||
```
|
||||
cd DB-GPT
|
||||
conda env create -f environment.yml
|
||||
```
|
||||
It is recommended to set the Python package path to avoid runtime errors due to package not found.
|
||||
```
|
||||
echo "/root/workspace/DB-GPT" > /root/miniconda3/env/dbgpt_env/lib/python3.10/site-packages/dbgpt.pth
|
||||
```
|
||||
Notice: You need replace the path to your owner.
|
||||
|
||||
### 3. Run
|
||||
You can refer to this document to obtain the Vicuna weights: [Vicuna](https://github.com/lm-sys/FastChat/blob/main/README.md#model-weights) .
|
||||
@@ -182,9 +177,31 @@ $ python pilot/server/webserver.py
|
||||
Notice: the webserver need to connect llmserver, so you need change the .env file. change the MODEL_SERVER = "http://127.0.0.1:8000" to your address. It's very important.
|
||||
|
||||
## Usage Instructions
|
||||
|
||||
We provide a user interface for Gradio, which allows you to use DB-GPT through our user interface. Additionally, we have prepared several reference articles (written in Chinese) that introduce the code and principles related to our project.
|
||||
- [LLM Practical In Action Series (1) — Combined Langchain-Vicuna Application Practical](https://medium.com/@cfqcsunny/llm-practical-in-action-series-1-combined-langchain-vicuna-application-practical-701cd0413c9f)
|
||||
|
||||
### Multi LLMs Usage
|
||||
|
||||
To use multiple models, modify the LLM_MODEL parameter in the .env configuration file to switch between the models.
|
||||
|
||||
####Create your own knowledge repository:
|
||||
|
||||
1.Place personal knowledge files or folders in the pilot/datasets directory.
|
||||
|
||||
2.Run the knowledge repository script in the tools directory.
|
||||
|
||||
```
|
||||
python tools/knowledge_init.py
|
||||
|
||||
--vector_name : your vector store name default_value:default
|
||||
--append: append mode, True:append, False: not append default_value:False
|
||||
|
||||
```
|
||||
|
||||
3.Add the knowledge repository in the interface by entering the name of your knowledge repository (if not specified, enter "default") so you can use it for Q&A based on your knowledge base.
|
||||
|
||||
Note that the default vector model used is text2vec-large-chinese (which is a large model, so if your personal computer configuration is not enough, it is recommended to use text2vec-base-chinese). Therefore, ensure that you download the model and place it in the models directory.
|
||||
## Acknowledgement
|
||||
|
||||
The achievements of this project are thanks to the technical community, especially the following projects:
|
||||
@@ -198,6 +215,10 @@ The achievements of this project are thanks to the technical community, especial
|
||||
- [ChatGLM](https://github.com/THUDM/ChatGLM-6B) as the base model
|
||||
- [llama_index](https://github.com/jerryjliu/llama_index) for enhancing database-related knowledge using [in-context learning](https://arxiv.org/abs/2301.00234) based on existing knowledge bases.
|
||||
|
||||
## Contribution
|
||||
|
||||
- Please run `black .` before submitting the code.
|
||||
|
||||
<!-- GITCONTRIBUTOR_START -->
|
||||
|
||||
## Contributors
|
||||
@@ -206,7 +227,7 @@ The achievements of this project are thanks to the technical community, especial
|
||||
| :---: | :---: | :---: | :---: |:---: |
|
||||
|
||||
|
||||
This project follows the git-contributor [spec](https://github.com/xudafeng/git-contributor), auto updated at `Sun May 14 2023 23:02:43 GMT+0800`.
|
||||
This project follows the git-contributor [spec](https://github.com/xudafeng/git-contributor), auto updated at `Fri May 19 2023 00:24:18 GMT+0800`.
|
||||
|
||||
<!-- GITCONTRIBUTOR_END -->
|
||||
|
||||
|
41
README.zh.md
41
README.zh.md
@@ -26,6 +26,10 @@ DB-GPT 是一个开源的以数据库为基础的GPT实验项目,使用本地
|
||||
- 知识库统一向量存储/索引
|
||||
- 非结构化数据支持包括PDF、MarkDown、CSV、WebURL
|
||||
|
||||
- 多模型支持
|
||||
- 支持多种大语言模型, 当前已支持Vicuna(7b,13b), ChatGLM-6b(int4, int8)
|
||||
- TODO: codet5p, codegen2
|
||||
|
||||
## 效果演示
|
||||
|
||||
示例通过 RTX 4090 GPU 演示,[YouTube 地址](https://www.youtube.com/watch?v=1PWI6F89LPo)
|
||||
@@ -110,6 +114,8 @@ DB-GPT基于 [FastChat](https://github.com/lm-sys/FastChat) 构建大模型运
|
||||
|
||||
用户只需要整理好知识文档,即可用我们现有的能力构建大模型所需要的知识库能力。
|
||||
|
||||
|
||||
|
||||
### 大模型管理能力
|
||||
在底层大模型接入中,设计了开放的接口,支持对接多种大模型。同时对于接入模型的效果,我们有非常严格的把控与评审机制。对大模型能力上与ChatGPT对比,在准确率上需要满足85%以上的能力对齐。我们用更高的标准筛选模型,是期望在用户使用过程中,可以省去前面繁琐的测试评估环节。
|
||||
|
||||
@@ -151,15 +157,6 @@ conda create -n dbgpt_env python=3.10
|
||||
conda activate dbgpt_env
|
||||
pip install -r requirements.txt
|
||||
|
||||
```
|
||||
或者也可以使用命令:
|
||||
```
|
||||
cd DB-GPT
|
||||
conda env create -f environment.yml
|
||||
```
|
||||
另外需要设置一下python包路径, 避免出现运行时找不到包
|
||||
```
|
||||
echo "/root/workspace/DB-GPT" > /root/miniconda3/env/dbgpt_env/lib/python3.10/site-packages/dbgpt.pth
|
||||
```
|
||||
|
||||
### 3. 运行大模型
|
||||
@@ -187,6 +184,26 @@ $ python webserver.py
|
||||
2. [大模型实战系列(2) —— DB-GPT 阿里云部署指南](https://zhuanlan.zhihu.com/p/629467580)
|
||||
3. [大模型实战系列(3) —— DB-GPT插件模型原理与使用](https://zhuanlan.zhihu.com/p/629623125)
|
||||
|
||||
|
||||
### 多模型使用
|
||||
在.env 配置文件当中, 修改LLM_MODEL参数来切换使用的模型。
|
||||
|
||||
####打造属于你的知识库:
|
||||
|
||||
1、将个人知识文件或者文件夹放入pilot/datasets目录中
|
||||
|
||||
2、在tools目录执行知识入库脚本
|
||||
|
||||
```
|
||||
python tools/knowledge_init.py
|
||||
|
||||
--vector_name : your vector store name default_value:default
|
||||
--append: append mode, True:append, False: not append default_value:False
|
||||
|
||||
```
|
||||
3、在界面上新增知识库输入你的知识库名(如果没指定输入default),就可以根据你的知识库进行问答
|
||||
|
||||
注意,这里默认向量模型是text2vec-large-chinese(模型比较大,如果个人电脑配置不够建议采用text2vec-base-chinese),因此确保需要将模型download下来放到models目录中。
|
||||
## 感谢
|
||||
|
||||
项目取得的成果,需要感谢技术社区,尤其以下项目。
|
||||
@@ -201,14 +218,20 @@ $ python webserver.py
|
||||
- [ChatGLM](https://github.com/THUDM/ChatGLM-6B) 基础模型
|
||||
- [llama-index](https://github.com/jerryjliu/llama_index) 基于现有知识库进行[In-Context Learning](https://arxiv.org/abs/2301.00234)来对其进行数据库相关知识的增强。
|
||||
|
||||
# 贡献
|
||||
|
||||
- 提交代码前请先执行 `black .`
|
||||
|
||||
<!-- GITCONTRIBUTOR_START -->
|
||||
|
||||
## 贡献者
|
||||
## Contributors
|
||||
|
||||
|[<img src="https://avatars.githubusercontent.com/u/17919400?v=4" width="100px;"/><br/><sub><b>csunny</b></sub>](https://github.com/csunny)<br/>|[<img src="https://avatars.githubusercontent.com/u/1011681?v=4" width="100px;"/><br/><sub><b>xudafeng</b></sub>](https://github.com/xudafeng)<br/>|[<img src="https://avatars.githubusercontent.com/u/7636723?s=96&v=4" width="100px;"/><br/><sub><b>明天</b></sub>](https://github.com/yhjun1026)<br/> | [<img src="https://avatars.githubusercontent.com/u/13723926?v=4" width="100px;"/><br/><sub><b>Aries-ckt</b></sub>](https://github.com/Aries-ckt)<br/>|[<img src="https://avatars.githubusercontent.com/u/95130644?v=4" width="100px;"/><br/><sub><b>thebigbone</b></sub>](https://github.com/thebigbone)<br/>|
|
||||
| :---: | :---: | :---: | :---: |:---: |
|
||||
|
||||
|
||||
[git-contributor 说明](https://github.com/xudafeng/git-contributor),自动生成时间:`Fri May 19 2023 00:24:18 GMT+0800`。
|
||||
This project follows the git-contributor [spec](https://github.com/xudafeng/git-contributor), auto updated at `Sun May 14 2023 23:02:43 GMT+0800`.
|
||||
|
||||
<!-- GITCONTRIBUTOR_END -->
|
||||
|
Binary file not shown.
Before Width: | Height: | Size: 257 KiB After Width: | Height: | Size: 157 KiB |
@@ -1,68 +0,0 @@
|
||||
name: db_pgt
|
||||
channels:
|
||||
- pytorch
|
||||
- defaults
|
||||
- anaconda
|
||||
dependencies:
|
||||
- python=3.10
|
||||
- cudatoolkit
|
||||
- pip
|
||||
- pytorch-mutex=1.0=cuda
|
||||
- pip:
|
||||
- pytorch
|
||||
- accelerate==0.16.0
|
||||
- aiohttp==3.8.4
|
||||
- aiosignal==1.3.1
|
||||
- async-timeout==4.0.2
|
||||
- attrs==22.2.0
|
||||
- bitsandbytes==0.37.0
|
||||
- cchardet==2.1.7
|
||||
- chardet==5.1.0
|
||||
- contourpy==1.0.7
|
||||
- cycler==0.11.0
|
||||
- filelock==3.9.0
|
||||
- fonttools==4.38.0
|
||||
- frozenlist==1.3.3
|
||||
- huggingface-hub==0.13.4
|
||||
- importlib-resources==5.12.0
|
||||
- kiwisolver==1.4.4
|
||||
- matplotlib==3.7.0
|
||||
- multidict==6.0.4
|
||||
- packaging==23.0
|
||||
- psutil==5.9.4
|
||||
- pycocotools==2.0.6
|
||||
- pyparsing==3.0.9
|
||||
- python-dateutil==2.8.2
|
||||
- pyyaml==6.0
|
||||
- regex==2022.10.31
|
||||
- tokenizers==0.13.2
|
||||
- tqdm==4.64.1
|
||||
- transformers==4.28.0
|
||||
- timm==0.6.13
|
||||
- spacy==3.5.1
|
||||
- webdataset==0.2.48
|
||||
- scikit-learn==1.2.2
|
||||
- scipy==1.10.1
|
||||
- yarl==1.8.2
|
||||
- zipp==3.14.0
|
||||
- omegaconf==2.3.0
|
||||
- opencv-python==4.7.0.72
|
||||
- iopath==0.1.10
|
||||
- tenacity==8.2.2
|
||||
- peft
|
||||
- pycocoevalcap
|
||||
- sentence-transformers
|
||||
- umap-learn
|
||||
- notebook
|
||||
- gradio==3.23
|
||||
- gradio-client==0.0.8
|
||||
- wandb
|
||||
- llama-index==0.5.27
|
||||
- pymysql
|
||||
- unstructured==0.6.3
|
||||
- pytesseract==0.3.10
|
||||
- markdown2
|
||||
- chromadb
|
||||
- colorama
|
||||
- playsound
|
||||
- distro
|
@@ -2,24 +2,28 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import gradio as gr
|
||||
from langchain.agents import (
|
||||
load_tools,
|
||||
initialize_agent,
|
||||
AgentType
|
||||
)
|
||||
from pilot.model.vicuna_llm import VicunaRequestLLM, VicunaEmbeddingLLM
|
||||
from llama_index import LLMPredictor, LangchainEmbedding, ServiceContext
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from llama_index import Document, GPTSimpleVectorIndex
|
||||
from llama_index import (
|
||||
Document,
|
||||
GPTSimpleVectorIndex,
|
||||
LangchainEmbedding,
|
||||
LLMPredictor,
|
||||
ServiceContext,
|
||||
)
|
||||
|
||||
from pilot.model.vicuna_llm import VicunaEmbeddingLLM, VicunaRequestLLM
|
||||
|
||||
|
||||
def agent_demo():
|
||||
llm = VicunaRequestLLM()
|
||||
|
||||
tools = load_tools(['python_repl'], llm=llm)
|
||||
agent = initialize_agent(tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True)
|
||||
agent.run(
|
||||
"Write a SQL script that Query 'select count(1)!'"
|
||||
tools = load_tools(["python_repl"], llm=llm)
|
||||
agent = initialize_agent(
|
||||
tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
agent.run("Write a SQL script that Query 'select count(1)!'")
|
||||
|
||||
|
||||
def knowledged_qa_demo(text_list):
|
||||
llm_predictor = LLMPredictor(llm=VicunaRequestLLM())
|
||||
@@ -27,27 +31,34 @@ def knowledged_qa_demo(text_list):
|
||||
embed_model = LangchainEmbedding(hfemb)
|
||||
documents = [Document(t) for t in text_list]
|
||||
|
||||
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, embed_model=embed_model)
|
||||
index = GPTSimpleVectorIndex.from_documents(documents, service_context=service_context)
|
||||
service_context = ServiceContext.from_defaults(
|
||||
llm_predictor=llm_predictor, embed_model=embed_model
|
||||
)
|
||||
index = GPTSimpleVectorIndex.from_documents(
|
||||
documents, service_context=service_context
|
||||
)
|
||||
return index
|
||||
|
||||
|
||||
def get_answer(q):
|
||||
base_knowledge = """ """
|
||||
base_knowledge = """ """
|
||||
text_list = [base_knowledge]
|
||||
index = knowledged_qa_demo(text_list)
|
||||
response = index.query(q)
|
||||
return response.response
|
||||
|
||||
|
||||
def get_similar(q):
|
||||
from pilot.vector_store.extract_tovec import knownledge_tovec, knownledge_tovec_st
|
||||
|
||||
docsearch = knownledge_tovec_st("./datasets/plan.md")
|
||||
docs = docsearch.similarity_search_with_score(q, k=1)
|
||||
|
||||
for doc in docs:
|
||||
dc, s = doc
|
||||
dc, s = doc
|
||||
print(s)
|
||||
yield dc.page_content
|
||||
yield dc.page_content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# agent_demo()
|
||||
@@ -58,8 +69,7 @@ if __name__ == "__main__":
|
||||
text_input = gr.TextArea()
|
||||
text_output = gr.TextArea()
|
||||
text_button = gr.Button()
|
||||
|
||||
|
||||
text_button.click(get_similar, inputs=text_input, outputs=text_output)
|
||||
|
||||
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
|
||||
|
||||
|
@@ -1,61 +1,73 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
import os
|
||||
import sys
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import gradio as gr
|
||||
from pilot.configs.config import Config
|
||||
from pilot.conversation import conv_qa_prompt_template, conv_templates
|
||||
import requests
|
||||
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(ROOT_PATH)
|
||||
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.conversation import conv_qa_prompt_template, conv_templates
|
||||
|
||||
vicuna_stream_path = "generate_stream"
|
||||
llmstream_stream_path = "generate_stream"
|
||||
|
||||
CFG = Config()
|
||||
|
||||
def generate(query):
|
||||
|
||||
def generate(query):
|
||||
template_name = "conv_one_shot"
|
||||
state = conv_templates[template_name].copy()
|
||||
|
||||
pt = PromptTemplate(
|
||||
template=conv_qa_prompt_template,
|
||||
input_variables=["context", "question"]
|
||||
)
|
||||
# pt = PromptTemplate(
|
||||
# template=conv_qa_prompt_template,
|
||||
# input_variables=["context", "question"]
|
||||
# )
|
||||
|
||||
result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
|
||||
question=query)
|
||||
# result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
|
||||
# question=query)
|
||||
|
||||
print(result)
|
||||
# print(result)
|
||||
|
||||
state.append_message(state.roles[0], result)
|
||||
state.append_message(state.roles[0], query)
|
||||
state.append_message(state.roles[1], None)
|
||||
|
||||
prompt = state.get_prompt()
|
||||
params = {
|
||||
"model": "vicuna-13b",
|
||||
"model": "chatglm-6b",
|
||||
"prompt": prompt,
|
||||
"temperature": 0.7,
|
||||
"temperature": 1.0,
|
||||
"max_new_tokens": 1024,
|
||||
"stop": "###"
|
||||
"stop": "###",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
url=urljoin(CFG.MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
|
||||
url=urljoin(CFG.MODEL_SERVER, llmstream_stream_path), data=json.dumps(params)
|
||||
)
|
||||
|
||||
skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
|
||||
|
||||
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||
if chunk:
|
||||
data = json.loads(chunk.decode())
|
||||
if data["error_code"] == 0:
|
||||
output = data["text"][skip_echo_len:].strip()
|
||||
if "vicuna" in CFG.LLM_MODEL:
|
||||
output = data["text"][skip_echo_len:].strip()
|
||||
else:
|
||||
output = data["text"].strip()
|
||||
|
||||
state.messages[-1][-1] = output + "▌"
|
||||
yield(output)
|
||||
|
||||
yield (output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(CFG.LLM_MODEL)
|
||||
with gr.Blocks() as demo:
|
||||
@@ -64,10 +76,7 @@ if __name__ == "__main__":
|
||||
text_input = gr.TextArea()
|
||||
text_output = gr.TextArea()
|
||||
text_button = gr.Button("提交")
|
||||
|
||||
|
||||
text_button.click(generate, inputs=text_input, outputs=text_output)
|
||||
|
||||
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
|
||||
|
||||
|
||||
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
|
||||
|
@@ -1,19 +1,19 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from llama_index import SimpleDirectoryReader, GPTSimpleVectorIndex
|
||||
from llama_index import GPTSimpleVectorIndex, SimpleDirectoryReader
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
|
||||
|
||||
# read the document of data dir
|
||||
documents = SimpleDirectoryReader("data").load_data()
|
||||
# split the document to chunk, max token size=500, convert chunk to vector
|
||||
# split the document to chunk, max token size=500, convert chunk to vector
|
||||
|
||||
index = GPTSimpleVectorIndex(documents)
|
||||
|
||||
# save index
|
||||
index.save_to_disk("index.json")
|
||||
index.save_to_disk("index.json")
|
||||
|
@@ -3,17 +3,19 @@
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def change_tab():
|
||||
return gr.Tabs.update(selected=1)
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Tabs() as tabs:
|
||||
with gr.TabItem("Train", id=0):
|
||||
t = gr.Textbox()
|
||||
with gr.TabItem("Inference", id=1):
|
||||
i = gr.Image()
|
||||
|
||||
|
||||
btn = gr.Button()
|
||||
btn.click(change_tab, None, tabs)
|
||||
|
||||
demo.launch()
|
||||
demo.launch()
|
||||
|
@@ -1,5 +1,3 @@
|
||||
|
||||
|
||||
from pilot.source_embedding.csv_embedding import CSVEmbedding
|
||||
|
||||
# path = "/Users/chenketing/Downloads/share_ireserve双写数据异常2.xlsx"
|
||||
@@ -8,6 +6,13 @@ model_name = "your_path/all-MiniLM-L6-v2"
|
||||
vector_store_path = "your_path/"
|
||||
|
||||
|
||||
pdf_embedding = CSVEmbedding(file_path=path, model_name=model_name, vector_store_config={"vector_store_name": "url", "vector_store_path": "vector_store_path"})
|
||||
pdf_embedding = CSVEmbedding(
|
||||
file_path=path,
|
||||
model_name=model_name,
|
||||
vector_store_config={
|
||||
"vector_store_name": "url",
|
||||
"vector_store_path": "vector_store_path",
|
||||
},
|
||||
)
|
||||
pdf_embedding.source_embedding()
|
||||
print("success")
|
||||
print("success")
|
||||
|
@@ -6,6 +6,13 @@ model_name = "your_path/all-MiniLM-L6-v2"
|
||||
vector_store_path = "your_path/"
|
||||
|
||||
|
||||
pdf_embedding = PDFEmbedding(file_path=path, model_name=model_name, vector_store_config={"vector_store_name": "ob-pdf", "vector_store_path": vector_store_path})
|
||||
pdf_embedding = PDFEmbedding(
|
||||
file_path=path,
|
||||
model_name=model_name,
|
||||
vector_store_config={
|
||||
"vector_store_name": "ob-pdf",
|
||||
"vector_store_path": vector_store_path,
|
||||
},
|
||||
)
|
||||
pdf_embedding.source_embedding()
|
||||
print("success")
|
||||
print("success")
|
||||
|
@@ -5,6 +5,13 @@ model_name = "your_path/all-MiniLM-L6-v2"
|
||||
vector_store_path = "your_path"
|
||||
|
||||
|
||||
pdf_embedding = URLEmbedding(file_path=path, model_name=model_name, vector_store_config={"vector_store_name": "url", "vector_store_path": "vector_store_path"})
|
||||
pdf_embedding = URLEmbedding(
|
||||
file_path=path,
|
||||
model_name=model_name,
|
||||
vector_store_config={
|
||||
"vector_store_name": "url",
|
||||
"vector_store_path": "vector_store_path",
|
||||
},
|
||||
)
|
||||
pdf_embedding.source_embedding()
|
||||
print("success")
|
||||
print("success")
|
||||
|
@@ -1,19 +1,28 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from llama_index import SimpleDirectoryReader, LangchainEmbedding, GPTListIndex, GPTSimpleVectorIndex, PromptHelper
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from llama_index import LLMPredictor
|
||||
import torch
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from langchain.llms.base import LLM
|
||||
from llama_index import (
|
||||
GPTListIndex,
|
||||
GPTSimpleVectorIndex,
|
||||
LangchainEmbedding,
|
||||
LLMPredictor,
|
||||
PromptHelper,
|
||||
SimpleDirectoryReader,
|
||||
)
|
||||
from transformers import pipeline
|
||||
|
||||
|
||||
class FlanLLM(LLM):
|
||||
model_name = "google/flan-t5-large"
|
||||
pipeline = pipeline("text2text-generation", model=model_name, device=0, model_kwargs={
|
||||
"torch_dtype": torch.bfloat16
|
||||
})
|
||||
pipeline = pipeline(
|
||||
"text2text-generation",
|
||||
model=model_name,
|
||||
device=0,
|
||||
model_kwargs={"torch_dtype": torch.bfloat16},
|
||||
)
|
||||
|
||||
def _call(self, prompt, stop=None):
|
||||
return self.pipeline(prompt, max_length=9999)[0]["generated_text"]
|
||||
@@ -24,6 +33,7 @@ class FlanLLM(LLM):
|
||||
def _llm_type(self):
|
||||
return "custome"
|
||||
|
||||
|
||||
llm_predictor = LLMPredictor(llm=FlanLLM())
|
||||
hfemb = HuggingFaceEmbeddings()
|
||||
embed_model = LangchainEmbedding(hfemb)
|
||||
@@ -214,9 +224,10 @@ OceanBase 数据库 EXPLAIN 命令输出的第一部分是执行计划的树形
|
||||
|
||||
回答: nlj也是左表的表是驱动表,这个要了解下计划执行方面的基本原理,取左表的一行数据,再遍历右表,一旦满足连接条件,就可以返回数据
|
||||
anti/semi只是因为not exists/exist的语义只是返回左表数据,改成anti join是一种计划优化,连接的方式比子查询更优
|
||||
"""
|
||||
"""
|
||||
|
||||
from llama_index import Document
|
||||
|
||||
text_list = [text1]
|
||||
documents = [Document(t) for t in text_list]
|
||||
|
||||
@@ -226,12 +237,18 @@ max_input_size = 512
|
||||
max_chunk_overlap = 20
|
||||
prompt_helper = PromptHelper(max_input_size, num_output, max_chunk_overlap)
|
||||
|
||||
index = GPTListIndex(documents, embed_model=embed_model, llm_predictor=llm_predictor, prompt_helper=prompt_helper)
|
||||
index = GPTListIndex(
|
||||
documents,
|
||||
embed_model=embed_model,
|
||||
llm_predictor=llm_predictor,
|
||||
prompt_helper=prompt_helper,
|
||||
)
|
||||
index.save_to_disk("index.json")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
|
||||
logging.getLogger().setLevel(logging.CRITICAL)
|
||||
for d in documents:
|
||||
print(d)
|
||||
|
@@ -1,7 +1,3 @@
|
||||
from pilot.source_embedding import (SourceEmbedding, register)
|
||||
from pilot.source_embedding import SourceEmbedding, register
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SourceEmbedding",
|
||||
"register"
|
||||
]
|
||||
__all__ = ["SourceEmbedding", "register"]
|
||||
|
@@ -3,10 +3,10 @@
|
||||
|
||||
|
||||
class Agent:
|
||||
"""Agent class for interacting with DB-GPT
|
||||
|
||||
Attributes:
|
||||
"""Agent class for interacting with DB-GPT
|
||||
|
||||
Attributes:
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
pass
|
||||
|
@@ -4,10 +4,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.singleton import Singleton
|
||||
from pilot.configs.config import Config
|
||||
from typing import List
|
||||
from pilot.model.base import Message
|
||||
from pilot.singleton import Singleton
|
||||
|
||||
|
||||
class AgentManager(metaclass=Singleton):
|
||||
@@ -17,6 +15,7 @@ class AgentManager(metaclass=Singleton):
|
||||
self.next_key = 0
|
||||
self.agents = {} # key, (task, full_message_history, model)
|
||||
self.cfg = Config()
|
||||
|
||||
"""Agent manager for managing DB-GPT agents
|
||||
In order to compatible auto gpt plugins,
|
||||
we use the same template with it.
|
||||
@@ -28,7 +27,7 @@ class AgentManager(metaclass=Singleton):
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.next_key = 0
|
||||
self.agents = {} #TODO need to define
|
||||
self.agents = {} # TODO need to define
|
||||
self.cfg = Config()
|
||||
|
||||
# Create new GPT agent
|
||||
@@ -46,7 +45,6 @@ class AgentManager(metaclass=Singleton):
|
||||
The key of the new agent
|
||||
"""
|
||||
|
||||
|
||||
def message_agent(self, key: str | int, message: str) -> str:
|
||||
"""Send a message to an agent and return its response
|
||||
|
||||
@@ -58,7 +56,6 @@ class AgentManager(metaclass=Singleton):
|
||||
The agent's response
|
||||
"""
|
||||
|
||||
|
||||
def list_agents(self) -> list[tuple[str | int, str]]:
|
||||
"""Return a list of all agents
|
||||
|
||||
|
@@ -1,18 +1,22 @@
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
from typing import Any, Dict
|
||||
import contextlib
|
||||
|
||||
from colorama import Fore
|
||||
from regex import regex
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.json_utils.json_fix_general import (
|
||||
add_quotes_to_property_names,
|
||||
balance_braces,
|
||||
fix_invalid_escape,
|
||||
)
|
||||
from pilot.logs import logger
|
||||
from pilot.speech import say_text
|
||||
|
||||
from pilot.json_utils.json_fix_general import fix_invalid_escape,add_quotes_to_property_names,balance_braces
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
def fix_and_parse_json(
|
||||
json_to_load: str, try_to_fix_with_gpt: bool = True
|
||||
) -> Dict[Any, Any]:
|
||||
@@ -48,7 +52,7 @@ def fix_and_parse_json(
|
||||
maybe_fixed_json = maybe_fixed_json[: last_brace_index + 1]
|
||||
return json.loads(maybe_fixed_json)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.error("参数解析错误", e)
|
||||
logger.error("参数解析错误", e)
|
||||
|
||||
|
||||
def fix_json_using_multiple_techniques(assistant_reply: str) -> Dict[Any, Any]:
|
||||
|
@@ -1,2 +1,2 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
# -*- coding:utf-8 -*-
|
||||
|
@@ -1,2 +1,2 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# -*- coding: utf-8 -*-
|
||||
|
@@ -1,15 +1,14 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
from typing import Dict, List, NoReturn, Union
|
||||
from pilot.configs.config import Config
|
||||
|
||||
from pilot.speech import say_text
|
||||
import json
|
||||
from typing import Dict
|
||||
|
||||
from pilot.agent.json_fix_llm import fix_json_using_multiple_techniques
|
||||
from pilot.commands.exception_not_commands import NotCommands
|
||||
import json
|
||||
from pilot.configs.config import Config
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
from pilot.speech import say_text
|
||||
|
||||
|
||||
def _resolve_pathlike_command_args(command_args):
|
||||
@@ -25,9 +24,9 @@ def _resolve_pathlike_command_args(command_args):
|
||||
|
||||
|
||||
def execute_ai_response_json(
|
||||
prompt: PromptGenerator,
|
||||
ai_response: str,
|
||||
user_input: str = None,
|
||||
prompt: PromptGenerator,
|
||||
ai_response: str,
|
||||
user_input: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
|
||||
@@ -52,18 +51,14 @@ def execute_ai_response_json(
|
||||
arguments = _resolve_pathlike_command_args(arguments)
|
||||
# Execute command
|
||||
if command_name is not None and command_name.lower().startswith("error"):
|
||||
result = (
|
||||
f"Command {command_name} threw the following error: {arguments}"
|
||||
)
|
||||
result = f"Command {command_name} threw the following error: {arguments}"
|
||||
elif command_name == "human_feedback":
|
||||
result = f"Human feedback: {user_input}"
|
||||
else:
|
||||
for plugin in cfg.plugins:
|
||||
if not plugin.can_handle_pre_command():
|
||||
continue
|
||||
command_name, arguments = plugin.pre_command(
|
||||
command_name, arguments
|
||||
)
|
||||
command_name, arguments = plugin.pre_command(command_name, arguments)
|
||||
command_result = execute_command(
|
||||
command_name,
|
||||
arguments,
|
||||
@@ -74,9 +69,9 @@ def execute_ai_response_json(
|
||||
|
||||
|
||||
def execute_command(
|
||||
command_name: str,
|
||||
arguments,
|
||||
prompt: PromptGenerator,
|
||||
command_name: str,
|
||||
arguments,
|
||||
prompt: PromptGenerator,
|
||||
):
|
||||
"""Execute the command and return the result
|
||||
|
||||
@@ -102,13 +97,15 @@ def execute_command(
|
||||
else:
|
||||
for command in prompt.commands:
|
||||
if (
|
||||
command_name == command["label"].lower()
|
||||
or command_name == command["name"].lower()
|
||||
command_name == command["label"].lower()
|
||||
or command_name == command["name"].lower()
|
||||
):
|
||||
try:
|
||||
# 删除非定义参数
|
||||
diff_ags = list(set(arguments.keys()).difference(set(command['args'].keys())))
|
||||
for arg_name in diff_ags:
|
||||
diff_ags = list(
|
||||
set(arguments.keys()).difference(set(command["args"].keys()))
|
||||
)
|
||||
for arg_name in diff_ags:
|
||||
del arguments[arg_name]
|
||||
print(str(arguments))
|
||||
return command["function"](**arguments)
|
||||
|
@@ -1,19 +1,21 @@
|
||||
from typing import Optional
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
from typing import Any, Optional, Type
|
||||
from pilot.prompts.prompt import build_default_prompt_generator
|
||||
|
||||
|
||||
class CommandsLoad:
|
||||
"""
|
||||
Load Plugins Commands Info , help build system prompt!
|
||||
Load Plugins Commands Info , help build system prompt!
|
||||
"""
|
||||
|
||||
def __init__(self)->None:
|
||||
def __init__(self) -> None:
|
||||
self.command_registry = None
|
||||
|
||||
|
||||
def getCommandInfos(self, prompt_generator: Optional[PromptGenerator] = None)-> str:
|
||||
def getCommandInfos(
|
||||
self, prompt_generator: Optional[PromptGenerator] = None
|
||||
) -> str:
|
||||
cfg = Config()
|
||||
if prompt_generator is None:
|
||||
prompt_generator = build_default_prompt_generator()
|
||||
@@ -24,4 +26,4 @@ class CommandsLoad:
|
||||
self.prompt_generator = prompt_generator
|
||||
command_infos = ""
|
||||
command_infos += f"\n\n{prompt_generator.commands()}"
|
||||
return command_infos
|
||||
return command_infos
|
||||
|
@@ -1,5 +1,4 @@
|
||||
|
||||
class NotCommands(Exception):
|
||||
def __init__(self, message):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.message = message
|
||||
|
@@ -25,7 +25,7 @@ def generate_image(prompt: str, size: int = 256) -> str:
|
||||
str: The filename of the image
|
||||
"""
|
||||
filename = f"{CFG.workspace_path}/{str(uuid.uuid4())}.jpg"
|
||||
|
||||
|
||||
# HuggingFace
|
||||
if CFG.image_provider == "huggingface":
|
||||
return generate_image_with_hf(prompt, filename)
|
||||
@@ -72,6 +72,7 @@ def generate_image_with_hf(prompt: str, filename: str) -> str:
|
||||
|
||||
return f"Saved to disk:{filename}"
|
||||
|
||||
|
||||
def generate_image_with_sd_webui(
|
||||
prompt: str,
|
||||
filename: str,
|
||||
|
14
pilot/configs/__init__.py
Normal file
14
pilot/configs/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
|
||||
print("Setting random seed to 42")
|
||||
random.seed(42)
|
||||
|
||||
# Load the users .env file into environment variables
|
||||
load_dotenv(verbose=True, override=True)
|
||||
|
||||
del load_dotenv
|
@@ -7,13 +7,13 @@ from __future__ import annotations
|
||||
import os
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Type
|
||||
from typing import Optional
|
||||
|
||||
import distro
|
||||
import yaml
|
||||
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
from pilot.configs.config import Config
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
from pilot.prompts.prompt import build_default_prompt_generator
|
||||
|
||||
# Soon this will go in a folder where it remembers more stuff about the run(s)
|
||||
@@ -88,7 +88,7 @@ class AIConfig:
|
||||
for goal in config_params.get("ai_goals", [])
|
||||
]
|
||||
api_budget = config_params.get("api_budget", 0.0)
|
||||
# type: Type[AIConfig]
|
||||
# type is Type[AIConfig]
|
||||
return AIConfig(ai_name, ai_role, ai_goals, api_budget)
|
||||
|
||||
def save(self, config_file: str = SAVE_FILE) -> None:
|
||||
@@ -133,8 +133,6 @@ class AIConfig:
|
||||
""
|
||||
)
|
||||
|
||||
|
||||
|
||||
cfg = Config()
|
||||
if prompt_generator is None:
|
||||
prompt_generator = build_default_prompt_generator()
|
||||
|
@@ -2,16 +2,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import nltk
|
||||
from typing import List
|
||||
|
||||
import nltk
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
|
||||
from pilot.singleton import Singleton
|
||||
from pilot.common.sql_database import Database
|
||||
|
||||
|
||||
class Config(metaclass=Singleton):
|
||||
"""Configuration class to store the state of bools for different scripts access"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the Config class"""
|
||||
|
||||
@@ -19,7 +21,6 @@ class Config(metaclass=Singleton):
|
||||
self.skip_reprompt = False
|
||||
self.temperature = float(os.getenv("TEMPERATURE", 0.7))
|
||||
|
||||
|
||||
self.execute_local_commands = (
|
||||
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
|
||||
)
|
||||
@@ -96,30 +97,39 @@ class Config(metaclass=Singleton):
|
||||
else:
|
||||
self.plugins_denylist = []
|
||||
|
||||
|
||||
### Local database connection configuration
|
||||
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
|
||||
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
|
||||
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
|
||||
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
|
||||
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
|
||||
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
|
||||
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
|
||||
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
|
||||
|
||||
### TODO Adapt to multiple types of libraries
|
||||
self.local_db = Database.from_uri("mysql+pymysql://" + self.LOCAL_DB_USER +":"+ self.LOCAL_DB_PASSWORD +"@" +self.LOCAL_DB_HOST + ":" + str(self.LOCAL_DB_PORT) ,
|
||||
engine_args ={"pool_size": 10, "pool_recycle": 3600, "echo": True})
|
||||
|
||||
### LLM Model Service Configuration
|
||||
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
|
||||
self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
|
||||
self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
|
||||
self.MODEL_SERVER = os.getenv("MODEL_SERVER", "http://121.41.167.183:8000")
|
||||
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
|
||||
self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
|
||||
self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
|
||||
self.MODEL_PORT = os.getenv("MODEL_PORT", 8000)
|
||||
self.MODEL_SERVER = os.getenv(
|
||||
"MODEL_SERVER", "http://127.0.0.1" + ":" + str(self.MODEL_PORT)
|
||||
)
|
||||
self.ISLOAD_8BIT = os.getenv("ISLOAD_8BIT", "True") == "True"
|
||||
|
||||
### Vector Store Configuration
|
||||
self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")
|
||||
self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1")
|
||||
self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")
|
||||
self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
|
||||
self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
|
||||
|
||||
def set_debug_mode(self, value: bool) -> None:
|
||||
"""Set the debug mode value"""
|
||||
self.debug_mode = value
|
||||
|
||||
def set_plugins(self, value: list) -> None:
|
||||
"""Set the plugins value. """
|
||||
"""Set the plugins value."""
|
||||
self.plugins = value
|
||||
|
||||
def set_templature(self, value: int) -> None:
|
||||
@@ -132,4 +142,4 @@ class Config(metaclass=Singleton):
|
||||
|
||||
def set_last_plugin_return(self, value: bool) -> None:
|
||||
"""Set the speak mode value."""
|
||||
self.last_plugin_return = value
|
||||
self.last_plugin_return = value
|
||||
|
@@ -1,10 +1,10 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import torch
|
||||
import os
|
||||
import nltk
|
||||
|
||||
import nltk
|
||||
import torch
|
||||
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
MODEL_PATH = os.path.join(ROOT_PATH, "models")
|
||||
@@ -16,24 +16,43 @@ DATA_DIR = os.path.join(PILOT_PATH, "data")
|
||||
|
||||
nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
|
||||
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
DEVICE = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
LLM_MODEL_CONFIG = {
|
||||
"flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
|
||||
"vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
|
||||
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
|
||||
"vicuna-7b": os.path.join(MODEL_PATH, "vicuna-7b"),
|
||||
"text2vec": os.path.join(MODEL_PATH, "text2vec-large-chinese"),
|
||||
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
|
||||
"codegen2-1b": os.path.join(MODEL_PATH, "codegen2-1B"),
|
||||
"codet5p-2b": os.path.join(MODEL_PATH, "codet5p-2b"),
|
||||
"chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
|
||||
"chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
|
||||
"text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
|
||||
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
|
||||
}
|
||||
|
||||
|
||||
VECTOR_SEARCH_TOP_K = 20
|
||||
LLM_MODEL = "vicuna-13b"
|
||||
LIMIT_MODEL_CONCURRENCY = 5
|
||||
MAX_POSITION_EMBEDDINGS = 4096
|
||||
# VICUNA_MODEL_SERVER = "http://121.41.227.141:8000"
|
||||
VICUNA_MODEL_SERVER = "http://120.79.27.110:8000"
|
||||
|
||||
# Load model config
|
||||
ISLOAD_8BIT = True
|
||||
ISDEBUG = False
|
||||
|
||||
|
||||
VECTOR_SEARCH_TOP_K = 3
|
||||
# LLM_MODEL = "vicuna-13b"
|
||||
# LIMIT_MODEL_CONCURRENCY = 5
|
||||
# MAX_POSITION_EMBEDDINGS = 4096
|
||||
# VICUNA_MODEL_SERVER = "http://121.41.167.183:8000"
|
||||
|
||||
|
||||
|
||||
VECTOR_SEARCH_TOP_K = 10
|
||||
VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge")
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "data"
|
||||
)
|
||||
KNOWLEDGE_CHUNK_SPLIT_SIZE = 100
|
||||
|
@@ -3,6 +3,6 @@
|
||||
|
||||
"""We need to design a base class. That other connector can Write with this"""
|
||||
|
||||
|
||||
class BaseConnection:
|
||||
pass
|
||||
|
||||
|
@@ -4,4 +4,5 @@
|
||||
|
||||
class ClickHouseConnector:
|
||||
"""ClickHouseConnector"""
|
||||
pass
|
||||
|
||||
pass
|
||||
|
@@ -4,4 +4,5 @@
|
||||
|
||||
class ElasticSearchConnector:
|
||||
"""ElasticSearchConnector"""
|
||||
pass
|
||||
|
||||
pass
|
||||
|
@@ -1,6 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
|
||||
class MongoConnector:
|
||||
"""MongoConnector is a class which connect to mongo and chat with LLM"""
|
||||
pass
|
||||
|
||||
pass
|
||||
|
@@ -4,26 +4,25 @@
|
||||
import pymysql
|
||||
|
||||
class MySQLOperator:
|
||||
"""Connect MySQL Database fetch MetaData For LLM Prompt
|
||||
Args:
|
||||
"""Connect MySQL Database fetch MetaData For LLM Prompt
|
||||
Args:
|
||||
|
||||
Usage:
|
||||
Usage:
|
||||
"""
|
||||
|
||||
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
||||
|
||||
def __init__(self, user, password, host="localhost", port=3306) -> None:
|
||||
|
||||
self.conn = pymysql.connect(
|
||||
host=host,
|
||||
user=user,
|
||||
port=port,
|
||||
passwd=password,
|
||||
charset="utf8mb4",
|
||||
cursorclass=pymysql.cursors.DictCursor
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
)
|
||||
|
||||
def get_schema(self, schema_name):
|
||||
|
||||
with self.conn.cursor() as cursor:
|
||||
_sql = f"""
|
||||
select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{schema_name}" group by TABLE_NAME;
|
||||
@@ -32,6 +31,7 @@ class MySQLOperator:
|
||||
results = cursor.fetchall()
|
||||
return results
|
||||
|
||||
|
||||
def run_sql(self, db_name:str, sql:str, fetch: str = "all"):
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute("USE " + db_name)
|
||||
@@ -55,10 +55,10 @@ class MySQLOperator:
|
||||
cursor.execute(_sql)
|
||||
results = cursor.fetchall()
|
||||
|
||||
dbs = [d["Database"] for d in results if d["Database"] not in self.default_db]
|
||||
dbs = [
|
||||
d["Database"] for d in results if d["Database"] not in self.default_db
|
||||
]
|
||||
return dbs
|
||||
|
||||
def get_meta(self, schema_name):
|
||||
pass
|
||||
|
||||
|
||||
|
@@ -1,6 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
|
||||
class OracleConnector:
|
||||
"""OracleConnector"""
|
||||
pass
|
||||
|
||||
pass
|
||||
|
@@ -2,7 +2,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
|
||||
|
||||
class PostgresConnector:
|
||||
"""PostgresConnector is a class which Connector to chat with LLM"""
|
||||
pass
|
||||
|
||||
pass
|
||||
|
@@ -4,4 +4,5 @@
|
||||
|
||||
class RedisConnector:
|
||||
"""RedisConnector"""
|
||||
pass
|
||||
|
||||
pass
|
||||
|
@@ -5,26 +5,32 @@ import dataclasses
|
||||
import uuid
|
||||
from enum import auto, Enum
|
||||
from typing import List, Any
|
||||
|
||||
from pilot.configs.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
DB_SETTINGS = {
|
||||
"user": CFG.LOCAL_DB_USER,
|
||||
"password": CFG.LOCAL_DB_PASSWORD,
|
||||
"password": CFG.LOCAL_DB_PASSWORD,
|
||||
"host": CFG.LOCAL_DB_HOST,
|
||||
"port": CFG.LOCAL_DB_PORT
|
||||
"port": CFG.LOCAL_DB_PORT,
|
||||
}
|
||||
|
||||
ROLE_USER = "USER"
|
||||
ROLE_ASSISTANT = "Assistant"
|
||||
|
||||
|
||||
class SeparatorStyle(Enum):
|
||||
SINGLE = auto()
|
||||
TWO = auto()
|
||||
THREE = auto()
|
||||
FOUR = auto()
|
||||
|
||||
@ dataclasses.dataclass
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Conversation:
|
||||
"""This class keeps all conversation history. """
|
||||
"""This class keeps all conversation history."""
|
||||
|
||||
system: str
|
||||
roles: List[str]
|
||||
@@ -65,7 +71,7 @@ class Conversation:
|
||||
|
||||
def to_gradio_chatbot(self):
|
||||
ret = []
|
||||
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
||||
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
||||
if i % 2 == 0:
|
||||
ret.append([msg, None])
|
||||
else:
|
||||
@@ -93,15 +99,14 @@ class Conversation:
|
||||
"offset": self.offset,
|
||||
"sep": self.sep,
|
||||
"sep2": self.sep2,
|
||||
"conv_id": self.conv_id
|
||||
"conv_id": self.conv_id,
|
||||
}
|
||||
|
||||
|
||||
def gen_sqlgen_conversation(dbname):
|
||||
from pilot.connections.mysql import MySQLOperator
|
||||
mo = MySQLOperator(
|
||||
**(DB_SETTINGS)
|
||||
)
|
||||
|
||||
mo = MySQLOperator(**(DB_SETTINGS))
|
||||
|
||||
message = ""
|
||||
|
||||
@@ -113,7 +118,7 @@ def gen_sqlgen_conversation(dbname):
|
||||
|
||||
conv_one_shot = Conversation(
|
||||
system="A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. "
|
||||
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
|
||||
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
|
||||
roles=("USER", "Assistant"),
|
||||
messages=(
|
||||
(
|
||||
@@ -134,20 +139,19 @@ conv_one_shot = Conversation(
|
||||
"whereas PostgreSQL is known for its robustness and reliability.\n"
|
||||
"5. Licensing: MySQL is licensed under the GPL (General Public License), which means that it is free and open-source software, "
|
||||
"whereas PostgreSQL is licensed under the PostgreSQL License, which is also free and open-source but with different terms.\n"
|
||||
|
||||
"Ultimately, the choice between MySQL and PostgreSQL depends on the specific needs and requirements of your application. "
|
||||
"Both are excellent database management systems, and choosing the right one "
|
||||
"for your project requires careful consideration of your application's requirements, performance needs, and scalability."
|
||||
"for your project requires careful consideration of your application's requirements, performance needs, and scalability.",
|
||||
),
|
||||
),
|
||||
offset=2,
|
||||
sep_style=SeparatorStyle.SINGLE,
|
||||
sep="###"
|
||||
sep="###",
|
||||
)
|
||||
|
||||
conv_vicuna_v1 = Conversation(
|
||||
system="A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. "
|
||||
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
|
||||
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
|
||||
roles=("USER", "ASSISTANT"),
|
||||
messages=(),
|
||||
offset=0,
|
||||
@@ -158,7 +162,7 @@ conv_vicuna_v1 = Conversation(
|
||||
|
||||
auto_dbgpt_one_shot = Conversation(
|
||||
system="You are DB-GPT, an AI designed to answer questions about HackerNews by query `hackerbews` database in MySQL. "
|
||||
"Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
|
||||
"Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
|
||||
roles=("USER", "ASSISTANT"),
|
||||
messages=(
|
||||
(
|
||||
@@ -201,7 +205,7 @@ auto_dbgpt_one_shot = Conversation(
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
""",
|
||||
),
|
||||
(
|
||||
"ASSISTANT",
|
||||
@@ -221,8 +225,8 @@ auto_dbgpt_one_shot = Conversation(
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
""",
|
||||
),
|
||||
),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.SINGLE,
|
||||
@@ -231,7 +235,7 @@ auto_dbgpt_one_shot = Conversation(
|
||||
|
||||
auto_dbgpt_without_shot = Conversation(
|
||||
system="You are DB-GPT, an AI designed to answer questions about users by query `users` database in MySQL. "
|
||||
"Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
|
||||
"Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
|
||||
roles=("USER", "ASSISTANT"),
|
||||
messages=(),
|
||||
offset=0,
|
||||
@@ -248,11 +252,18 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、简要的回
|
||||
{question}
|
||||
"""
|
||||
|
||||
# conv_qa_prompt_template = """ Please provide the known information so that I can professionally and briefly answer the user's question. If the answer cannot be obtained from the provided content,
|
||||
# please say: "The information provided in the knowledge base is insufficient to answer this question." Fabrication is prohibited.。
|
||||
# known information:
|
||||
# {context}
|
||||
# question:
|
||||
# {question}
|
||||
# """
|
||||
default_conversation = conv_one_shot
|
||||
|
||||
conversation_sql_mode ={
|
||||
conversation_sql_mode = {
|
||||
"auto_execute_ai_response": "直接执行结果",
|
||||
"dont_execute_ai_response": "不直接执行结果"
|
||||
"dont_execute_ai_response": "不直接执行结果",
|
||||
}
|
||||
|
||||
conversation_types = {
|
||||
@@ -265,7 +276,7 @@ conversation_types = {
|
||||
conv_templates = {
|
||||
"conv_one_shot": conv_one_shot,
|
||||
"vicuna_v1": conv_vicuna_v1,
|
||||
"auto_dbgpt_one_shot": auto_dbgpt_one_shot
|
||||
"auto_dbgpt_one_shot": auto_dbgpt_one_shot,
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -8,8 +8,8 @@ import re
|
||||
from typing import Optional
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.logs import logger
|
||||
from pilot.json_utils.utilities import extract_char_position
|
||||
from pilot.logs import logger
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
@@ -84,7 +84,7 @@ class Logger(metaclass=Singleton):
|
||||
self.chat_plugins = []
|
||||
|
||||
def typewriter_log(
|
||||
self, title="", title_color="", content="", speak_text=False, level=logging.INFO
|
||||
self, title="", title_color="", content="", speak_text=False, level=logging.INFO
|
||||
):
|
||||
if speak_text and self.speak_mode:
|
||||
say_text(f"{title}. {content}")
|
||||
@@ -103,26 +103,26 @@ class Logger(metaclass=Singleton):
|
||||
)
|
||||
|
||||
def debug(
|
||||
self,
|
||||
message,
|
||||
title="",
|
||||
title_color="",
|
||||
self,
|
||||
message,
|
||||
title="",
|
||||
title_color="",
|
||||
):
|
||||
self._log(title, title_color, message, logging.DEBUG)
|
||||
|
||||
def info(
|
||||
self,
|
||||
message,
|
||||
title="",
|
||||
title_color="",
|
||||
self,
|
||||
message,
|
||||
title="",
|
||||
title_color="",
|
||||
):
|
||||
self._log(title, title_color, message, logging.INFO)
|
||||
|
||||
def warn(
|
||||
self,
|
||||
message,
|
||||
title="",
|
||||
title_color="",
|
||||
self,
|
||||
message,
|
||||
title="",
|
||||
title_color="",
|
||||
):
|
||||
self._log(title, title_color, message, logging.WARN)
|
||||
|
||||
@@ -130,11 +130,11 @@ class Logger(metaclass=Singleton):
|
||||
self._log(title, Fore.RED, message, logging.ERROR)
|
||||
|
||||
def _log(
|
||||
self,
|
||||
title: str = "",
|
||||
title_color: str = "",
|
||||
message: str = "",
|
||||
level=logging.INFO,
|
||||
self,
|
||||
title: str = "",
|
||||
title_color: str = "",
|
||||
message: str = "",
|
||||
level=logging.INFO,
|
||||
):
|
||||
if message:
|
||||
if isinstance(message, list):
|
||||
@@ -178,10 +178,12 @@ class Logger(metaclass=Singleton):
|
||||
log_dir = os.path.join(this_files_dir_path, "../logs")
|
||||
return os.path.abspath(log_dir)
|
||||
|
||||
|
||||
"""
|
||||
Output stream to console using simulated typing
|
||||
"""
|
||||
|
||||
|
||||
class TypingConsoleHandler(logging.StreamHandler):
|
||||
def emit(self, record):
|
||||
min_typing_speed = 0.05
|
||||
@@ -203,6 +205,7 @@ class TypingConsoleHandler(logging.StreamHandler):
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
|
||||
class ConsoleHandler(logging.StreamHandler):
|
||||
def emit(self, record) -> None:
|
||||
msg = self.format(record)
|
||||
@@ -221,10 +224,10 @@ class DbGptFormatter(logging.Formatter):
|
||||
def format(self, record: LogRecord) -> str:
|
||||
if hasattr(record, "color"):
|
||||
record.title_color = (
|
||||
getattr(record, "color")
|
||||
+ getattr(record, "title", "")
|
||||
+ " "
|
||||
+ Style.RESET_ALL
|
||||
getattr(record, "color")
|
||||
+ getattr(record, "title", "")
|
||||
+ " "
|
||||
+ Style.RESET_ALL
|
||||
)
|
||||
else:
|
||||
record.title_color = getattr(record, "title", "")
|
||||
@@ -248,9 +251,9 @@ logger = Logger()
|
||||
|
||||
|
||||
def print_assistant_thoughts(
|
||||
ai_name: object,
|
||||
assistant_reply_json_valid: object,
|
||||
speak_mode: bool = False,
|
||||
ai_name: object,
|
||||
assistant_reply_json_valid: object,
|
||||
speak_mode: bool = False,
|
||||
) -> None:
|
||||
assistant_thoughts_reasoning = None
|
||||
assistant_thoughts_plan = None
|
||||
|
@@ -1,17 +1,16 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import List
|
||||
from functools import cache
|
||||
from typing import List
|
||||
|
||||
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from pilot.configs.model_config import DEVICE
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
AutoModel
|
||||
)
|
||||
|
||||
class BaseLLMAdaper:
|
||||
"""The Base class for multi model, in our project.
|
||||
We will support those model, which performance resemble ChatGPT """
|
||||
We will support those model, which performance resemble ChatGPT"""
|
||||
|
||||
def match(self, model_path: str):
|
||||
return True
|
||||
@@ -24,7 +23,8 @@ class BaseLLMAdaper:
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
llm_model_adapters = List[BaseLLMAdaper] = []
|
||||
llm_model_adapters: List[BaseLLMAdaper] = []
|
||||
|
||||
|
||||
# Register llm models to adapters, by this we can use multi models.
|
||||
def register_llm_model_adapters(cls):
|
||||
@@ -37,60 +37,91 @@ def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
|
||||
for adapter in llm_model_adapters:
|
||||
if adapter.match(model_path):
|
||||
return adapter
|
||||
|
||||
|
||||
raise ValueError(f"Invalid model adapter for {model_path}")
|
||||
|
||||
|
||||
# TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?
|
||||
|
||||
|
||||
class VicunaLLMAdapater(BaseLLMAdaper):
|
||||
"""Vicuna Adapter """
|
||||
"""Vicuna Adapter"""
|
||||
|
||||
def match(self, model_path: str):
|
||||
return "vicuna" in model_path
|
||||
return "vicuna" in model_path
|
||||
|
||||
def loader(self, model_path: str, from_pretrained_kwagrs: dict):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
low_cpu_mem_usage=True,
|
||||
**from_pretrained_kwagrs
|
||||
model_path, low_cpu_mem_usage=True, **from_pretrained_kwagrs
|
||||
)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
class ChatGLMAdapater(BaseLLMAdaper):
|
||||
"""LLM Adatpter for THUDM/chatglm-6b"""
|
||||
|
||||
def match(self, model_path: str):
|
||||
return "chatglm" in model_path
|
||||
|
||||
|
||||
def loader(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = AutoModel.from_pretrained(
|
||||
model_path, trust_remote_code=True, **from_pretrained_kwargs
|
||||
).half().cuda()
|
||||
return model, tokenizer
|
||||
|
||||
if DEVICE != "cuda":
|
||||
model = AutoModel.from_pretrained(
|
||||
model_path, trust_remote_code=True, **from_pretrained_kwargs
|
||||
).float()
|
||||
return model, tokenizer
|
||||
else:
|
||||
model = (
|
||||
AutoModel.from_pretrained(
|
||||
model_path, trust_remote_code=True, **from_pretrained_kwargs
|
||||
)
|
||||
.half()
|
||||
.cuda()
|
||||
)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
class CodeGenAdapter(BaseLLMAdaper):
|
||||
pass
|
||||
|
||||
|
||||
class StarCoderAdapter(BaseLLMAdaper):
|
||||
pass
|
||||
|
||||
|
||||
class T5CodeAdapter(BaseLLMAdaper):
|
||||
pass
|
||||
|
||||
|
||||
class KoalaLLMAdapter(BaseLLMAdaper):
|
||||
"""Koala LLM Adapter which Based LLaMA """
|
||||
"""Koala LLM Adapter which Based LLaMA"""
|
||||
|
||||
def match(self, model_path: str):
|
||||
return "koala" in model_path
|
||||
|
||||
|
||||
|
||||
class RWKV4LLMAdapter(BaseLLMAdaper):
|
||||
"""LLM Adapter for RwKv4 """
|
||||
"""LLM Adapter for RwKv4"""
|
||||
|
||||
def match(self, model_path: str):
|
||||
return "RWKV-4" in model_path
|
||||
|
||||
|
||||
def loader(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
|
||||
class GPT4AllAdapter(BaseLLMAdaper):
|
||||
"""A light version for someone who want practise LLM use laptop."""
|
||||
|
||||
def match(self, model_path: str):
|
||||
return "gpt4all" in model_path
|
||||
|
||||
|
||||
register_llm_model_adapters(VicunaLLMAdapater)
|
||||
register_llm_model_adapters(ChatGLMAdapater)
|
||||
# TODO Default support vicuna, other model need to tests and Evaluate
|
||||
|
||||
register_llm_model_adapters(BaseLLMAdaper)
|
||||
register_llm_model_adapters(BaseLLMAdaper)
|
||||
|
@@ -1,11 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import List, TypedDict
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class Message(TypedDict):
|
||||
"""LLM Message object containing usually like (role: content) """
|
||||
"""LLM Message object containing usually like (role: content)"""
|
||||
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
@@ -1,3 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
58
pilot/model/chatglm_llm.py
Normal file
58
pilot/model/chatglm_llm.py
Normal file
@@ -0,0 +1,58 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def chatglm_generate_stream(
|
||||
model, tokenizer, params, device, context_len=2048, stream_interval=2
|
||||
):
|
||||
"""Generate text using chatglm model's chat api"""
|
||||
prompt = params["prompt"]
|
||||
temperature = float(params.get("temperature", 1.0))
|
||||
top_p = float(params.get("top_p", 1.0))
|
||||
stop = params.get("stop", "###")
|
||||
echo = params.get("echo", False)
|
||||
|
||||
generate_kwargs = {
|
||||
"do_sample": True if temperature > 1e-5 else False,
|
||||
"top_p": top_p,
|
||||
"repetition_penalty": 1.0,
|
||||
"logits_processor": None,
|
||||
}
|
||||
|
||||
if temperature > 1e-5:
|
||||
generate_kwargs["temperature"] = temperature
|
||||
|
||||
# TODO, Fix this
|
||||
hist = []
|
||||
|
||||
messages = prompt.split(stop)
|
||||
|
||||
# Add history chat to hist for model.
|
||||
for i in range(1, len(messages) - 2, 2):
|
||||
hist.append(
|
||||
(
|
||||
messages[i].split(ROLE_USER + ":")[1],
|
||||
messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
|
||||
)
|
||||
)
|
||||
|
||||
query = messages[-2].split(ROLE_USER + ":")[1]
|
||||
print("Query Message: ", query)
|
||||
output = ""
|
||||
i = 0
|
||||
for i, (response, new_hist) in enumerate(
|
||||
model.stream_chat(tokenizer, query, hist, **generate_kwargs)
|
||||
):
|
||||
if echo:
|
||||
output = query + " " + response
|
||||
else:
|
||||
output = response
|
||||
|
||||
yield output
|
||||
|
||||
yield output
|
@@ -3,14 +3,15 @@
|
||||
import dataclasses
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CompressionConfig:
|
||||
"""Group-wise quantization."""
|
||||
|
||||
num_bits: int
|
||||
group_size: int
|
||||
group_dim: int
|
||||
@@ -19,7 +20,8 @@ class CompressionConfig:
|
||||
|
||||
|
||||
default_compression_config = CompressionConfig(
|
||||
num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True)
|
||||
num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True
|
||||
)
|
||||
|
||||
|
||||
class CLinear(nn.Module):
|
||||
@@ -40,8 +42,11 @@ def compress_module(module, target_device):
|
||||
for attr_str in dir(module):
|
||||
target_attr = getattr(module, attr_str)
|
||||
if type(target_attr) == torch.nn.Linear:
|
||||
setattr(module, attr_str,
|
||||
CLinear(target_attr.weight, target_attr.bias, target_device))
|
||||
setattr(
|
||||
module,
|
||||
attr_str,
|
||||
CLinear(target_attr.weight, target_attr.bias, target_device),
|
||||
)
|
||||
for name, child in module.named_children():
|
||||
compress_module(child, target_device)
|
||||
|
||||
@@ -52,22 +57,31 @@ def compress(tensor, config):
|
||||
return tensor
|
||||
|
||||
group_size, num_bits, group_dim, symmetric = (
|
||||
config.group_size, config.num_bits, config.group_dim, config.symmetric)
|
||||
config.group_size,
|
||||
config.num_bits,
|
||||
config.group_dim,
|
||||
config.symmetric,
|
||||
)
|
||||
assert num_bits <= 8
|
||||
|
||||
original_shape = tensor.shape
|
||||
num_groups = (original_shape[group_dim] + group_size - 1) // group_size
|
||||
new_shape = (original_shape[:group_dim] + (num_groups, group_size) +
|
||||
original_shape[group_dim+1:])
|
||||
new_shape = (
|
||||
original_shape[:group_dim]
|
||||
+ (num_groups, group_size)
|
||||
+ original_shape[group_dim + 1 :]
|
||||
)
|
||||
|
||||
# Pad
|
||||
pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
|
||||
if pad_len != 0:
|
||||
pad_shape = original_shape[:group_dim] + (pad_len,) + original_shape[group_dim+1:]
|
||||
tensor = torch.cat([
|
||||
tensor,
|
||||
torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
|
||||
dim=group_dim)
|
||||
pad_shape = (
|
||||
original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1 :]
|
||||
)
|
||||
tensor = torch.cat(
|
||||
[tensor, torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
|
||||
dim=group_dim,
|
||||
)
|
||||
data = tensor.view(new_shape)
|
||||
|
||||
# Quantize
|
||||
@@ -78,7 +92,7 @@ def compress(tensor, config):
|
||||
data = data.clamp_(-B, B).round_().to(torch.int8)
|
||||
return data, scale, original_shape
|
||||
else:
|
||||
B = 2 ** num_bits - 1
|
||||
B = 2**num_bits - 1
|
||||
mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0]
|
||||
mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0]
|
||||
|
||||
@@ -96,7 +110,11 @@ def decompress(packed_data, config):
|
||||
return packed_data
|
||||
|
||||
group_size, num_bits, group_dim, symmetric = (
|
||||
config.group_size, config.num_bits, config.group_dim, config.symmetric)
|
||||
config.group_size,
|
||||
config.num_bits,
|
||||
config.group_dim,
|
||||
config.symmetric,
|
||||
)
|
||||
|
||||
# Dequantize
|
||||
if symmetric:
|
||||
@@ -111,9 +129,10 @@ def decompress(packed_data, config):
|
||||
pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
|
||||
if pad_len:
|
||||
padded_original_shape = (
|
||||
original_shape[:group_dim] +
|
||||
(original_shape[group_dim] + pad_len,) +
|
||||
original_shape[group_dim+1:])
|
||||
original_shape[:group_dim]
|
||||
+ (original_shape[group_dim] + pad_len,)
|
||||
+ original_shape[group_dim + 1 :]
|
||||
)
|
||||
data = data.reshape(padded_original_shape)
|
||||
indices = [slice(0, x) for x in original_shape]
|
||||
return data[indices].contiguous()
|
||||
|
@@ -3,11 +3,12 @@
|
||||
|
||||
import torch
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate_stream(model, tokenizer, params, device,
|
||||
context_len=4096, stream_interval=2):
|
||||
|
||||
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
|
||||
@torch.inference_mode()
|
||||
def generate_stream(
|
||||
model, tokenizer, params, device, context_len=4096, stream_interval=2
|
||||
):
|
||||
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py"""
|
||||
prompt = params["prompt"]
|
||||
l_prompt = len(prompt)
|
||||
temperature = float(params.get("temperature", 1.0))
|
||||
@@ -22,17 +23,19 @@ def generate_stream(model, tokenizer, params, device,
|
||||
|
||||
for i in range(max_new_tokens):
|
||||
if i == 0:
|
||||
out = model(
|
||||
torch.as_tensor([input_ids], device=device), use_cache=True)
|
||||
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
|
||||
logits = out.logits
|
||||
past_key_values = out.past_key_values
|
||||
else:
|
||||
attention_mask = torch.ones(
|
||||
1, past_key_values[0][0].shape[-2] + 1, device=device)
|
||||
out = model(input_ids=torch.as_tensor([[token]], device=device),
|
||||
use_cache=True,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values)
|
||||
1, past_key_values[0][0].shape[-2] + 1, device=device
|
||||
)
|
||||
out = model(
|
||||
input_ids=torch.as_tensor([[token]], device=device),
|
||||
use_cache=True,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
logits = out.logits
|
||||
past_key_values = out.past_key_values
|
||||
|
||||
@@ -68,9 +71,12 @@ def generate_stream(model, tokenizer, params, device,
|
||||
|
||||
del past_key_values
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate_output(model, tokenizer, params, device, context_len=4096, stream_interval=2):
|
||||
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
|
||||
def generate_output(
|
||||
model, tokenizer, params, device, context_len=4096, stream_interval=2
|
||||
):
|
||||
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py"""
|
||||
|
||||
prompt = params["prompt"]
|
||||
l_prompt = len(prompt)
|
||||
@@ -78,7 +84,6 @@ def generate_output(model, tokenizer, params, device, context_len=4096, stream_i
|
||||
max_new_tokens = int(params.get("max_new_tokens", 2048))
|
||||
stop_str = params.get("stop", None)
|
||||
|
||||
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
output_ids = list(input_ids)
|
||||
|
||||
@@ -87,17 +92,19 @@ def generate_output(model, tokenizer, params, device, context_len=4096, stream_i
|
||||
|
||||
for i in range(max_new_tokens):
|
||||
if i == 0:
|
||||
out = model(
|
||||
torch.as_tensor([input_ids], device=device), use_cache=True)
|
||||
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
|
||||
logits = out.logits
|
||||
past_key_values = out.past_key_values
|
||||
else:
|
||||
attention_mask = torch.ones(
|
||||
1, past_key_values[0][0].shape[-2] + 1, device=device)
|
||||
out = model(input_ids=torch.as_tensor([[token]], device=device),
|
||||
use_cache=True,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values)
|
||||
1, past_key_values[0][0].shape[-2] + 1, device=device
|
||||
)
|
||||
out = model(
|
||||
input_ids=torch.as_tensor([[token]], device=device),
|
||||
use_cache=True,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
logits = out.logits
|
||||
past_key_values = out.past_key_values
|
||||
|
||||
@@ -120,7 +127,6 @@ def generate_output(model, tokenizer, params, device, context_len=4096, stream_i
|
||||
else:
|
||||
stopped = False
|
||||
|
||||
|
||||
if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
|
||||
output = tokenizer.decode(output_ids, skip_special_tokens=True)
|
||||
pos = output.rfind(stop_str, l_prompt)
|
||||
@@ -133,8 +139,11 @@ def generate_output(model, tokenizer, params, device, context_len=4096, stream_i
|
||||
break
|
||||
del past_key_values
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate_output_ex(model, tokenizer, params, device, context_len=2048, stream_interval=2):
|
||||
def generate_output_ex(
|
||||
model, tokenizer, params, device, context_len=2048, stream_interval=2
|
||||
):
|
||||
prompt = params["prompt"]
|
||||
temperature = float(params.get("temperature", 1.0))
|
||||
max_new_tokens = int(params.get("max_new_tokens", 2048))
|
||||
@@ -161,20 +170,20 @@ def generate_output_ex(model, tokenizer, params, device, context_len=2048, strea
|
||||
|
||||
for i in range(max_new_tokens):
|
||||
if i == 0:
|
||||
out = model(
|
||||
torch.as_tensor([input_ids], device=device), use_cache=True)
|
||||
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
|
||||
logits = out.logits
|
||||
past_key_values = out.past_key_values
|
||||
else:
|
||||
out = model(input_ids=torch.as_tensor([[token]], device=device),
|
||||
use_cache=True,
|
||||
past_key_values=past_key_values)
|
||||
out = model(
|
||||
input_ids=torch.as_tensor([[token]], device=device),
|
||||
use_cache=True,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
logits = out.logits
|
||||
past_key_values = out.past_key_values
|
||||
|
||||
last_token_logits = logits[0][-1]
|
||||
|
||||
|
||||
if temperature < 1e-4:
|
||||
token = int(torch.argmax(last_token_logits))
|
||||
else:
|
||||
@@ -188,7 +197,6 @@ def generate_output_ex(model, tokenizer, params, device, context_len=2048, strea
|
||||
else:
|
||||
stopped = False
|
||||
|
||||
|
||||
output = tokenizer.decode(output_ids, skip_special_tokens=True)
|
||||
# print("Partial output:", output)
|
||||
for stop_str in stop_strings:
|
||||
@@ -211,7 +219,7 @@ def generate_output_ex(model, tokenizer, params, device, context_len=2048, strea
|
||||
del past_key_values
|
||||
if pos != -1:
|
||||
return output[:pos]
|
||||
return output
|
||||
return output
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
|
@@ -1,12 +1,12 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, TypedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class Message(TypedDict):
|
||||
"""Vicuna Message object containing a role and the message content """
|
||||
"""Vicuna Message object containing a role and the message content"""
|
||||
|
||||
role: str
|
||||
content: str
|
||||
@@ -18,12 +18,15 @@ class ModelInfo:
|
||||
|
||||
Would be lovely to eventually get this directly from APIs
|
||||
"""
|
||||
|
||||
name: str
|
||||
max_tokens: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Standard response struct for a response from a LLM model."""
|
||||
|
||||
model_info = ModelInfo
|
||||
|
||||
|
||||
@@ -31,4 +34,4 @@ class LLMResponse:
|
||||
class ChatModelResponse(LLMResponse):
|
||||
"""Standard response struct for a response from an LLM model."""
|
||||
|
||||
content: str = None
|
||||
content: str = None
|
||||
|
@@ -2,35 +2,39 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import abc
|
||||
import time
|
||||
import functools
|
||||
from typing import List, Optional
|
||||
from pilot.model.llm.base import Message
|
||||
from pilot.conversation import conv_templates, Conversation, conv_one_shot, auto_dbgpt_one_shot
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.conversation import (
|
||||
Conversation,
|
||||
auto_dbgpt_one_shot,
|
||||
conv_one_shot,
|
||||
conv_templates,
|
||||
)
|
||||
from pilot.model.llm.base import Message
|
||||
|
||||
|
||||
# TODO Rewrite this
|
||||
def retry_stream_api(
|
||||
num_retries: int = 10,
|
||||
backoff_base: float = 2.0,
|
||||
warn_user: bool = True
|
||||
):
|
||||
num_retries: int = 10, backoff_base: float = 2.0, warn_user: bool = True
|
||||
):
|
||||
"""Retry an Vicuna Server call.
|
||||
|
||||
Args:
|
||||
num_retries int: Number of retries. Defaults to 10.
|
||||
backoff_base float: Base for exponential backoff. Defaults to 2.
|
||||
warn_user bool: Whether to warn the user. Defaults to True.
|
||||
Args:
|
||||
num_retries int: Number of retries. Defaults to 10.
|
||||
backoff_base float: Base for exponential backoff. Defaults to 2.
|
||||
warn_user bool: Whether to warn the user. Defaults to True.
|
||||
"""
|
||||
retry_limit_msg = f"Error: Reached rate limit, passing..."
|
||||
backoff_msg = (f"Error: API Bad gateway. Waiting {{backoff}} seconds...")
|
||||
backoff_msg = f"Error: API Bad gateway. Waiting {{backoff}} seconds..."
|
||||
|
||||
def _wrapper(func):
|
||||
@functools.wraps(func)
|
||||
def _wrapped(*args, **kwargs):
|
||||
user_warned = not warn_user
|
||||
num_attempts = num_retries + 1 # +1 for the first attempt
|
||||
num_attempts = num_retries + 1 # +1 for the first attempt
|
||||
for attempt in range(1, num_attempts + 1):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
@@ -39,10 +43,13 @@ def retry_stream_api(
|
||||
raise
|
||||
|
||||
backoff = backoff_base ** (attempt + 2)
|
||||
time.sleep(backoff)
|
||||
time.sleep(backoff)
|
||||
|
||||
return _wrapped
|
||||
|
||||
return _wrapper
|
||||
|
||||
|
||||
# Overly simple abstraction util we create something better
|
||||
# simple retry mechanism when getting a rate error or a bad gateway
|
||||
def create_chat_competion(
|
||||
@@ -52,15 +59,15 @@ def create_chat_competion(
|
||||
max_new_tokens: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Create a chat completion using the Vicuna-13b
|
||||
|
||||
Args:
|
||||
messages(List[Message]): The messages to send to the chat completion
|
||||
model (str, optional): The model to use. Default to None.
|
||||
temperature (float, optional): The temperature to use. Defaults to 0.7.
|
||||
max_tokens (int, optional): The max tokens to use. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The response from the chat completion
|
||||
Args:
|
||||
messages(List[Message]): The messages to send to the chat completion
|
||||
model (str, optional): The model to use. Default to None.
|
||||
temperature (float, optional): The temperature to use. Defaults to 0.7.
|
||||
max_tokens (int, optional): The max tokens to use. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The response from the chat completion
|
||||
"""
|
||||
cfg = Config()
|
||||
if temperature is None:
|
||||
@@ -77,7 +84,7 @@ class ChatIO(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def prompt_for_input(self, role: str) -> str:
|
||||
"""Prompt for input from a role."""
|
||||
|
||||
|
||||
@abc.abstractmethod
|
||||
def prompt_for_output(self, role: str) -> str:
|
||||
"""Prompt for output from a role."""
|
||||
@@ -105,4 +112,3 @@ class SimpleChatIO(ChatIO):
|
||||
|
||||
print(" ".join(outputs[pre:]), flush=True)
|
||||
return " ".join(outputs)
|
||||
|
||||
|
125
pilot/model/llm/monkey_patch.py
Normal file
125
pilot/model/llm/monkey_patch.py
Normal file
@@ -0,0 +1,125 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from torch import nn
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2].clone()
|
||||
x2 = x[..., x.shape[-1] // 2 :].clone()
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
|
||||
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
|
||||
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
||||
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = (
|
||||
self.q_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
key_states = (
|
||||
self.k_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
value_states = (
|
||||
self.v_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
|
||||
self.head_dim
|
||||
)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(
|
||||
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
||||
)
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
|
||||
query_states.dtype
|
||||
)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def replace_llama_attn_with_non_inplace_operations():
|
||||
"""Avoid bugs in mps backend by not using in-place operations."""
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
||||
|
||||
|
||||
import transformers
|
||||
|
||||
|
||||
def replace_llama_attn_with_non_inplace_operations():
|
||||
"""Avoid bugs in mps backend by not using in-place operations."""
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
@@ -2,31 +2,33 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from typing import List, Optional
|
||||
from pilot.model.base import Message
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.model.base import Message
|
||||
from pilot.server.llmserver import generate_output
|
||||
|
||||
|
||||
def create_chat_completion(
|
||||
messages: List[Message], # type: ignore
|
||||
messages: List[Message], # type: ignore
|
||||
model: Optional[str] = None,
|
||||
temperature: float = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Create a chat completion using the vicuna local model
|
||||
"""Create a chat completion using the vicuna local model
|
||||
|
||||
Args:
|
||||
messages(List[Message]): The messages to send to the chat completion
|
||||
model (str, optional): The model to use. Defaults to None.
|
||||
temperature (float, optional): The temperature to use. Defaults to 0.7.
|
||||
max_tokens (int, optional): The max tokens to use. Defaults to None
|
||||
|
||||
Returns:
|
||||
str: The response from chat completion
|
||||
Args:
|
||||
messages(List[Message]): The messages to send to the chat completion
|
||||
model (str, optional): The model to use. Defaults to None.
|
||||
temperature (float, optional): The temperature to use. Defaults to 0.7.
|
||||
max_tokens (int, optional): The max tokens to use. Defaults to None
|
||||
|
||||
Returns:
|
||||
str: The response from chat completion
|
||||
"""
|
||||
cfg = Config()
|
||||
if temperature is None:
|
||||
temperature = cfg.temperature
|
||||
|
||||
|
||||
for plugin in cfg.plugins:
|
||||
if plugin.can_handle_chat_completion(
|
||||
messages=messages,
|
||||
|
@@ -1,55 +1,102 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import sys
|
||||
import warnings
|
||||
from pilot.singleton import Singleton
|
||||
from typing import Optional
|
||||
|
||||
from pilot.model.compression import compress_module
|
||||
import torch
|
||||
|
||||
from pilot.configs.model_config import DEVICE
|
||||
from pilot.model.adapter import get_llm_model_adapter
|
||||
from pilot.model.compression import compress_module
|
||||
from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations
|
||||
from pilot.singleton import Singleton
|
||||
from pilot.utils import get_gpu_memory
|
||||
|
||||
|
||||
def raise_warning_for_incompatible_cpu_offloading_configuration(
|
||||
device: str, load_8bit: bool, cpu_offloading: bool
|
||||
):
|
||||
if cpu_offloading:
|
||||
if not load_8bit:
|
||||
warnings.warn(
|
||||
"The cpu-offloading feature can only be used while also using 8-bit-quantization.\n"
|
||||
"Use '--load-8bit' to enable 8-bit-quantization\n"
|
||||
"Continuing without cpu-offloading enabled\n"
|
||||
)
|
||||
return False
|
||||
if not "linux" in sys.platform:
|
||||
warnings.warn(
|
||||
"CPU-offloading is only supported on linux-systems due to the limited compatability with the bitsandbytes-package\n"
|
||||
"Continuing without cpu-offloading enabled\n"
|
||||
)
|
||||
return False
|
||||
if device != "cuda":
|
||||
warnings.warn(
|
||||
"CPU-offloading is only enabled when using CUDA-devices\n"
|
||||
"Continuing without cpu-offloading enabled\n"
|
||||
)
|
||||
return False
|
||||
return cpu_offloading
|
||||
|
||||
|
||||
class ModelLoader(metaclass=Singleton):
|
||||
"""Model loader is a class for model load
|
||||
|
||||
|
||||
Args: model_path
|
||||
|
||||
TODO: multi model support.
|
||||
TODO: multi model support.
|
||||
"""
|
||||
|
||||
kwargs = {}
|
||||
|
||||
def __init__(self,
|
||||
model_path) -> None:
|
||||
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.model_path = model_path
|
||||
def __init__(self, model_path) -> None:
|
||||
self.device = DEVICE
|
||||
self.model_path = model_path
|
||||
self.kwargs = {
|
||||
"torch_dtype": torch.float16,
|
||||
"device_map": "auto",
|
||||
}
|
||||
|
||||
# TODO multi gpu support
|
||||
def loader(self, num_gpus, load_8bit=False, debug=False):
|
||||
def loader(
|
||||
self,
|
||||
num_gpus,
|
||||
load_8bit=False,
|
||||
debug=False,
|
||||
cpu_offloading=False,
|
||||
max_gpu_memory: Optional[str] = None,
|
||||
):
|
||||
if self.device == "cpu":
|
||||
kwargs = {}
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
|
||||
elif self.device == "cuda":
|
||||
kwargs = {"torch_dtype": torch.float16}
|
||||
if num_gpus == "auto":
|
||||
num_gpus = int(num_gpus)
|
||||
|
||||
if num_gpus != 1:
|
||||
kwargs["device_map"] = "auto"
|
||||
if max_gpu_memory is None:
|
||||
kwargs["device_map"] = "sequential"
|
||||
|
||||
available_gpu_memory = get_gpu_memory(num_gpus)
|
||||
kwargs["max_memory"] = {
|
||||
i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
|
||||
for i in range(num_gpus)
|
||||
}
|
||||
|
||||
else:
|
||||
num_gpus = int(num_gpus)
|
||||
if num_gpus != 1:
|
||||
kwargs.update({
|
||||
"device_map": "auto",
|
||||
"max_memory": {i: "13GiB" for i in range(num_gpus)},
|
||||
})
|
||||
kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
|
||||
|
||||
elif self.device == "mps":
|
||||
kwargs = kwargs = {"torch_dtype": torch.float16}
|
||||
replace_llama_attn_with_non_inplace_operations()
|
||||
else:
|
||||
# Todo Support mps for practise
|
||||
raise ValueError(f"Invalid device: {self.device}")
|
||||
|
||||
|
||||
# TODO when cpu loading, need use quantization config
|
||||
|
||||
llm_adapter = get_llm_model_adapter(self.model_path)
|
||||
model, tokenizer = llm_adapter.loader(self.model_path, kwargs)
|
||||
|
||||
@@ -59,13 +106,14 @@ class ModelLoader(metaclass=Singleton):
|
||||
"8-bit quantization is not supported for multi-gpu inference"
|
||||
)
|
||||
else:
|
||||
compress_module(model, self.device)
|
||||
compress_module(model, self.device)
|
||||
|
||||
if (self.device == "cuda" and num_gpus == 1):
|
||||
if (
|
||||
self.device == "cuda" and num_gpus == 1 and not cpu_offloading
|
||||
) or self.device == "mps":
|
||||
model.to(self.device)
|
||||
|
||||
if debug:
|
||||
print(model)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
@@ -2,25 +2,34 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import json
|
||||
import requests
|
||||
from typing import Any, List, Mapping, Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Mapping, Optional, List
|
||||
from langchain.llms.base import LLM
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pilot.configs.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class VicunaLLM(LLM):
|
||||
|
||||
vicuna_generate_path = "generate_stream"
|
||||
def _call(self, prompt: str, temperature: float, max_new_tokens: int, stop: Optional[List[str]] = None) -> str:
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
temperature: float,
|
||||
max_new_tokens: int,
|
||||
stop: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"temperature": temperature,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"stop": stop
|
||||
"stop": stop,
|
||||
}
|
||||
response = requests.post(
|
||||
url=urljoin(CFG.MODEL_SERVER, self.vicuna_generate_path),
|
||||
@@ -41,10 +50,9 @@ class VicunaLLM(LLM):
|
||||
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
return {}
|
||||
|
||||
|
||||
|
||||
class VicunaEmbeddingLLM(BaseModel, Embeddings):
|
||||
|
||||
vicuna_embedding_path = "embedding"
|
||||
|
||||
def _call(self, prompt: str) -> str:
|
||||
@@ -53,15 +61,13 @@ class VicunaEmbeddingLLM(BaseModel, Embeddings):
|
||||
|
||||
response = requests.post(
|
||||
url=urljoin(CFG.MODEL_SERVER, self.vicuna_embedding_path),
|
||||
json={
|
||||
"prompt": p
|
||||
}
|
||||
json={"prompt": p},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()["response"]
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
""" Call out to Vicuna's server embedding endpoint for embedding search docs.
|
||||
"""Call out to Vicuna's server embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
texts: The list of text to embed
|
||||
@@ -73,17 +79,15 @@ class VicunaEmbeddingLLM(BaseModel, Embeddings):
|
||||
for text in texts:
|
||||
response = self.embed_query(text)
|
||||
results.append(response)
|
||||
return results
|
||||
|
||||
return results
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
""" Call out to Vicuna's server embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
"""Call out to Vicuna's server embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
Returns:
|
||||
Embedding for the text
|
||||
"""
|
||||
embedding = self._call(text)
|
||||
return embedding
|
||||
|
||||
|
@@ -1,11 +1,10 @@
|
||||
"""加载组件"""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
from zipimport import zipimporter
|
||||
|
||||
@@ -15,6 +14,7 @@ from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.logs import logger
|
||||
|
||||
|
||||
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
|
||||
"""
|
||||
Loader zip plugin file. Native support Auto_gpt_plugin
|
||||
@@ -36,6 +36,7 @@ def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
|
||||
logger.debug(f"Module '__init__.py' not found in the zipfile @ {zip_path}.")
|
||||
return result
|
||||
|
||||
|
||||
def write_dict_to_json_file(data: dict, file_path: str) -> None:
|
||||
"""
|
||||
Write a dictionary to a JSON file.
|
||||
@@ -46,6 +47,7 @@ def write_dict_to_json_file(data: dict, file_path: str) -> None:
|
||||
with open(file_path, "w") as file:
|
||||
json.dump(data, file, indent=4)
|
||||
|
||||
|
||||
def create_directory_if_not_exists(directory_path: str) -> bool:
|
||||
"""
|
||||
Create a directory if it does not exist.
|
||||
@@ -66,6 +68,7 @@ def create_directory_if_not_exists(directory_path: str) -> bool:
|
||||
logger.info(f"Directory {directory_path} already exists")
|
||||
return True
|
||||
|
||||
|
||||
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
|
||||
"""Scan the plugins directory for plugins and loads them.
|
||||
|
||||
|
@@ -1,19 +1,21 @@
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
from typing import Any, Optional, Type
|
||||
import os
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import distro
|
||||
import yaml
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.prompts.prompt import build_default_prompt_generator, DEFAULT_PROMPT_OHTER, DEFAULT_TRIGGERING_PROMPT
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
from pilot.prompts.prompt import (
|
||||
DEFAULT_PROMPT_OHTER,
|
||||
DEFAULT_TRIGGERING_PROMPT,
|
||||
build_default_prompt_generator,
|
||||
)
|
||||
|
||||
|
||||
class AutoModePrompt:
|
||||
"""
|
||||
""" """
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
ai_goals: list | None = None,
|
||||
@@ -36,23 +38,21 @@ class AutoModePrompt:
|
||||
self.command_registry = None
|
||||
|
||||
def construct_follow_up_prompt(
|
||||
self,
|
||||
user_input:[str],
|
||||
last_auto_return: str = None,
|
||||
prompt_generator: Optional[PromptGenerator] = None
|
||||
)-> str:
|
||||
self,
|
||||
user_input: [str],
|
||||
last_auto_return: str = None,
|
||||
prompt_generator: Optional[PromptGenerator] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build complete prompt information based on subsequent dialogue information entered by the user
|
||||
Args:
|
||||
self:
|
||||
prompt_generator:
|
||||
Build complete prompt information based on subsequent dialogue information entered by the user
|
||||
Args:
|
||||
self:
|
||||
prompt_generator:
|
||||
|
||||
Returns:
|
||||
Returns:
|
||||
|
||||
"""
|
||||
prompt_start = (
|
||||
DEFAULT_PROMPT_OHTER
|
||||
)
|
||||
"""
|
||||
prompt_start = DEFAULT_PROMPT_OHTER
|
||||
if prompt_generator is None:
|
||||
prompt_generator = build_default_prompt_generator()
|
||||
prompt_generator.goals = user_input
|
||||
@@ -64,12 +64,13 @@ class AutoModePrompt:
|
||||
continue
|
||||
prompt_generator = plugin.post_prompt(prompt_generator)
|
||||
|
||||
|
||||
full_prompt = f"{prompt_start}\n\nGOALS:\n\n"
|
||||
if not self.ai_goals :
|
||||
if not self.ai_goals:
|
||||
self.ai_goals = user_input
|
||||
for i, goal in enumerate(self.ai_goals):
|
||||
full_prompt += f"{i+1}.According to the provided Schema information, {goal}\n"
|
||||
full_prompt += (
|
||||
f"{i+1}.According to the provided Schema information, {goal}\n"
|
||||
)
|
||||
# if last_auto_return == None:
|
||||
# full_prompt += f"{cfg.last_plugin_return}\n\n"
|
||||
# else:
|
||||
@@ -82,10 +83,10 @@ class AutoModePrompt:
|
||||
return full_prompt
|
||||
|
||||
def construct_first_prompt(
|
||||
self,
|
||||
fisrt_message: [str]=[],
|
||||
db_schemes: str=None,
|
||||
prompt_generator: Optional[PromptGenerator] = None
|
||||
self,
|
||||
fisrt_message: [str] = [],
|
||||
db_schemes: str = None,
|
||||
prompt_generator: Optional[PromptGenerator] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build complete prompt information based on the initial dialogue information entered by the user
|
||||
@@ -125,16 +126,18 @@ class AutoModePrompt:
|
||||
|
||||
# Construct full prompt
|
||||
full_prompt = f"{prompt_start}\n\nGOALS:\n\n"
|
||||
if not self.ai_goals :
|
||||
if not self.ai_goals:
|
||||
self.ai_goals = fisrt_message
|
||||
for i, goal in enumerate(self.ai_goals):
|
||||
full_prompt += f"{i+1}.According to the provided Schema information,{goal}\n"
|
||||
if db_schemes:
|
||||
full_prompt += f"\nSchema:\n\n"
|
||||
full_prompt += (
|
||||
f"{i+1}.According to the provided Schema information,{goal}\n"
|
||||
)
|
||||
if db_schemes:
|
||||
full_prompt += f"\nSchema:\n\n"
|
||||
full_prompt += f"{db_schemes}"
|
||||
|
||||
# if self.api_budget > 0.0:
|
||||
# full_prompt += f"\nIt takes money to let you run. Your API budget is ${self.api_budget:.3f}"
|
||||
self.prompt_generator = prompt_generator
|
||||
full_prompt += f"\n\n{prompt_generator.generate_prompt_string()}"
|
||||
return full_prompt
|
||||
return full_prompt
|
||||
|
@@ -149,7 +149,7 @@ class PromptGenerator:
|
||||
f"Resources:\n{self._generate_numbered_list(self.resources)}\n\n"
|
||||
"Performance Evaluation:\n"
|
||||
f"{self._generate_numbered_list(self.performance_evaluation)}\n\n"
|
||||
"You should only respond in JSON format as described below and ensure the"
|
||||
"You should only respond in JSON format as described below and ensure the"
|
||||
"response can be parsed by Python json.loads \nResponse"
|
||||
f" Format: \n{formatted_response_format}"
|
||||
)
|
||||
|
@@ -1,17 +1,14 @@
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
|
||||
|
||||
CFG = Config()
|
||||
|
||||
DEFAULT_TRIGGERING_PROMPT = (
|
||||
"Determine which next command to use, and respond using the format specified above"
|
||||
)
|
||||
|
||||
DEFAULT_PROMPT_OHTER = (
|
||||
"Previous response was excellent. Please response according to the requirements based on the new goal"
|
||||
)
|
||||
DEFAULT_PROMPT_OHTER = "Previous response was excellent. Please response according to the requirements based on the new goal"
|
||||
|
||||
|
||||
def build_default_prompt_generator() -> PromptGenerator:
|
||||
"""
|
||||
@@ -36,17 +33,15 @@ def build_default_prompt_generator() -> PromptGenerator:
|
||||
)
|
||||
# prompt_generator.add_constraint("No user assistance")
|
||||
|
||||
prompt_generator.add_constraint(
|
||||
'Only output one correct JSON response at a time'
|
||||
)
|
||||
prompt_generator.add_constraint("Only output one correct JSON response at a time")
|
||||
prompt_generator.add_constraint(
|
||||
'Exclusively use the commands listed in double quotes e.g. "command name"'
|
||||
)
|
||||
prompt_generator.add_constraint(
|
||||
'If there is SQL in the args parameter, ensure to use the database and table definitions in Schema, and ensure that the fields and table names are in the definition'
|
||||
"If there is SQL in the args parameter, ensure to use the database and table definitions in Schema, and ensure that the fields and table names are in the definition"
|
||||
)
|
||||
prompt_generator.add_constraint(
|
||||
'The generated command args need to comply with the definition of the command'
|
||||
"The generated command args need to comply with the definition of the command"
|
||||
)
|
||||
|
||||
# Add resources to the PromptGenerator object
|
||||
|
@@ -1,26 +1,24 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import json
|
||||
import transformers
|
||||
from transformers import LlamaTokenizer, LlamaForCausalLM
|
||||
import os
|
||||
|
||||
from typing import List
|
||||
import pandas as pd
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
from peft import (
|
||||
LoraConfig,
|
||||
get_peft_model,
|
||||
get_peft_model_state_dict,
|
||||
prepare_model_for_int8_training,
|
||||
)
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
import pandas as pd
|
||||
from pilot.configs.config import Config
|
||||
|
||||
|
||||
from pilot.configs.model_config import DATA_DIR, LLM_MODEL_CONFIG
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
CUTOFF_LEN = 50
|
||||
|
||||
@@ -28,6 +26,7 @@ df = pd.read_csv(os.path.join(DATA_DIR, "BTC_Tweets_Updated.csv"))
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
def sentiment_score_to_name(score: float):
|
||||
if score > 0:
|
||||
return "Positive"
|
||||
@@ -40,16 +39,18 @@ dataset_data = [
|
||||
{
|
||||
"instruction": "Detect the sentiment of the tweet.",
|
||||
"input": row_dict["Tweet"],
|
||||
"output": sentiment_score_to_name(row_dict["New_Sentiment_State"])
|
||||
}
|
||||
"output": sentiment_score_to_name(row_dict["New_Sentiment_State"]),
|
||||
}
|
||||
for row_dict in df.to_dict(orient="records")
|
||||
]
|
||||
|
||||
with open(os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"), "w") as f:
|
||||
json.dump(dataset_data, f)
|
||||
json.dump(dataset_data, f)
|
||||
|
||||
|
||||
data = load_dataset("json", data_files=os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"))
|
||||
data = load_dataset(
|
||||
"json", data_files=os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json")
|
||||
)
|
||||
print(data["train"])
|
||||
|
||||
BASE_MODEL = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
||||
@@ -57,13 +58,14 @@ model = LlamaForCausalLM.from_pretrained(
|
||||
BASE_MODEL,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
offload_folder=os.path.join(DATA_DIR, "vicuna-lora")
|
||||
)
|
||||
offload_folder=os.path.join(DATA_DIR, "vicuna-lora"),
|
||||
)
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
|
||||
tokenizer.pad_token_id = (0)
|
||||
tokenizer.pad_token_id = 0
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
|
||||
def generate_prompt(data_point):
|
||||
return f"""Blow is an instruction that describes a task, paired with an input that provide future context.
|
||||
Write a response that appropriately completes the request. #noqa:
|
||||
@@ -76,6 +78,7 @@ def generate_prompt(data_point):
|
||||
{data_point["output"]}
|
||||
"""
|
||||
|
||||
|
||||
def tokenize(prompt, add_eos_token=True):
|
||||
result = tokenizer(
|
||||
prompt,
|
||||
@@ -85,30 +88,29 @@ def tokenize(prompt, add_eos_token=True):
|
||||
return_tensors=None,
|
||||
)
|
||||
|
||||
if (result["input_ids"][-1] != tokenizer.eos_token_id and len(result["input_ids"]) < CUTOFF_LEN and add_eos_token):
|
||||
if (
|
||||
result["input_ids"][-1] != tokenizer.eos_token_id
|
||||
and len(result["input_ids"]) < CUTOFF_LEN
|
||||
and add_eos_token
|
||||
):
|
||||
result["input_ids"].append(tokenizer.eos_token_id)
|
||||
result["attention_mask"].append(1)
|
||||
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
|
||||
|
||||
def generate_and_tokenize_prompt(data_point):
|
||||
full_prompt = generate_prompt(data_point)
|
||||
tokenized_full_prompt = tokenize(full_prompt)
|
||||
return tokenized_full_prompt
|
||||
|
||||
|
||||
train_val = data["train"].train_test_split(
|
||||
test_size=200, shuffle=True, seed=42
|
||||
)
|
||||
train_val = data["train"].train_test_split(test_size=200, shuffle=True, seed=42)
|
||||
|
||||
train_data = (
|
||||
train_val["train"].map(generate_and_tokenize_prompt)
|
||||
)
|
||||
train_data = train_val["train"].map(generate_and_tokenize_prompt)
|
||||
|
||||
val_data = (
|
||||
train_val["test"].map(generate_and_tokenize_prompt)
|
||||
)
|
||||
val_data = train_val["test"].map(generate_and_tokenize_prompt)
|
||||
|
||||
# Training
|
||||
LORA_R = 8
|
||||
@@ -129,7 +131,7 @@ OUTPUT_DIR = "experiments"
|
||||
# We can now prepare model for training
|
||||
model = prepare_model_for_int8_training(model)
|
||||
config = LoraConfig(
|
||||
r = LORA_R,
|
||||
r=LORA_R,
|
||||
lora_alpha=LORA_ALPHA,
|
||||
target_modules=LORA_TARGET_MODULES,
|
||||
lora_dropout=LORA_DROPOUT,
|
||||
@@ -156,7 +158,7 @@ training_arguments = transformers.TrainingArguments(
|
||||
output_dir=OUTPUT_DIR,
|
||||
save_total_limit=3,
|
||||
load_best_model_at_end=True,
|
||||
report_to="tensorboard"
|
||||
report_to="tensorboard",
|
||||
)
|
||||
|
||||
data_collector = transformers.DataCollatorForSeq2Seq(
|
||||
@@ -168,15 +170,13 @@ trainer = transformers.Trainer(
|
||||
train_dataset=train_data,
|
||||
eval_dataset=val_data,
|
||||
args=training_arguments,
|
||||
data_collector=data_collector
|
||||
data_collector=data_collector,
|
||||
)
|
||||
|
||||
model.config.use_cache = False
|
||||
old_state_dict = model.state_dict
|
||||
model.state_dict = (
|
||||
lambda self, *_, **__: get_peft_model_state_dict(
|
||||
self, old_state_dict()
|
||||
)
|
||||
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
||||
).__get__(model, type(model))
|
||||
|
||||
trainer.train()
|
||||
|
90
pilot/server/chat_adapter.py
Normal file
90
pilot/server/chat_adapter.py
Normal file
@@ -0,0 +1,90 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from functools import cache
|
||||
from typing import List
|
||||
|
||||
from pilot.model.inference import generate_stream
|
||||
|
||||
|
||||
class BaseChatAdpter:
|
||||
"""The Base class for chat with llm models. it will match the model,
|
||||
and fetch output from model"""
|
||||
|
||||
def match(self, model_path: str):
|
||||
return True
|
||||
|
||||
def get_generate_stream_func(self):
|
||||
"""Return the generate stream handler func"""
|
||||
pass
|
||||
|
||||
|
||||
llm_model_chat_adapters: List[BaseChatAdpter] = []
|
||||
|
||||
|
||||
def register_llm_model_chat_adapter(cls):
|
||||
"""Register a chat adapter"""
|
||||
llm_model_chat_adapters.append(cls())
|
||||
|
||||
|
||||
@cache
|
||||
def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
|
||||
"""Get a chat generate func for a model"""
|
||||
for adapter in llm_model_chat_adapters:
|
||||
if adapter.match(model_path):
|
||||
return adapter
|
||||
|
||||
raise ValueError(f"Invalid model for chat adapter {model_path}")
|
||||
|
||||
|
||||
class VicunaChatAdapter(BaseChatAdpter):
|
||||
|
||||
"""Model chat Adapter for vicuna"""
|
||||
|
||||
def match(self, model_path: str):
|
||||
return "vicuna" in model_path
|
||||
|
||||
def get_generate_stream_func(self):
|
||||
return generate_stream
|
||||
|
||||
|
||||
class ChatGLMChatAdapter(BaseChatAdpter):
|
||||
"""Model chat Adapter for ChatGLM"""
|
||||
|
||||
def match(self, model_path: str):
|
||||
return "chatglm" in model_path
|
||||
|
||||
def get_generate_stream_func(self):
|
||||
from pilot.model.chatglm_llm import chatglm_generate_stream
|
||||
|
||||
return chatglm_generate_stream
|
||||
|
||||
|
||||
class CodeT5ChatAdapter(BaseChatAdpter):
|
||||
|
||||
"""Model chat adapter for CodeT5"""
|
||||
|
||||
def match(self, model_path: str):
|
||||
return "codet5" in model_path
|
||||
|
||||
def get_generate_stream_func(self):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
|
||||
class CodeGenChatAdapter(BaseChatAdpter):
|
||||
|
||||
"""Model chat adapter for CodeGen"""
|
||||
|
||||
def match(self, model_path: str):
|
||||
return "codegen" in model_path
|
||||
|
||||
def get_generate_stream_func(self):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
|
||||
register_llm_model_chat_adapter(VicunaChatAdapter)
|
||||
register_llm_model_chat_adapter(ChatGLMChatAdapter)
|
||||
|
||||
register_llm_model_chat_adapter(BaseChatAdpter)
|
@@ -1,8 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
code_highlight_css = (
|
||||
"""
|
||||
code_highlight_css = """
|
||||
#chatbot .hll { background-color: #ffffcc }
|
||||
#chatbot .c { color: #408080; font-style: italic }
|
||||
#chatbot .err { border: 1px solid #FF0000 }
|
||||
@@ -71,6 +70,5 @@ code_highlight_css = (
|
||||
#chatbot .vi { color: #19177C }
|
||||
#chatbot .vm { color: #19177C }
|
||||
#chatbot .il { color: #666666 }
|
||||
""")
|
||||
#.highlight { background: #f8f8f8; }
|
||||
|
||||
"""
|
||||
# .highlight { background: #f8f8f8; }
|
||||
|
@@ -49,7 +49,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
|
||||
warnings.warn(
|
||||
"The 'color_map' parameter has been deprecated.",
|
||||
)
|
||||
#self.md = utils.get_markdown_parser()
|
||||
# self.md = utils.get_markdown_parser()
|
||||
self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])
|
||||
self.select: EventListenerMethod
|
||||
"""
|
||||
@@ -112,7 +112,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
|
||||
): # This happens for previously processed messages
|
||||
return chat_message
|
||||
elif isinstance(chat_message, str):
|
||||
#return self.md.render(chat_message)
|
||||
# return self.md.render(chat_message)
|
||||
return str(self.md.convert(chat_message))
|
||||
else:
|
||||
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
|
||||
@@ -141,9 +141,10 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
|
||||
), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
|
||||
processed_messages.append(
|
||||
(
|
||||
#self._process_chat_messages(message_pair[0]),
|
||||
'<pre style="font-family: var(--font)">' +
|
||||
message_pair[0] + "</pre>",
|
||||
# self._process_chat_messages(message_pair[0]),
|
||||
'<pre style="font-family: var(--font)">'
|
||||
+ message_pair[0]
|
||||
+ "</pre>",
|
||||
self._process_chat_messages(message_pair[1]),
|
||||
)
|
||||
)
|
||||
@@ -163,5 +164,3 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
|
||||
**kwargs,
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
|
@@ -1,40 +1,91 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import uvicorn
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Optional, List
|
||||
from fastapi import FastAPI, Request, BackgroundTasks
|
||||
import os
|
||||
import sys
|
||||
|
||||
import uvicorn
|
||||
from fastapi import BackgroundTasks, FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pilot.model.inference import generate_stream
|
||||
from pydantic import BaseModel
|
||||
from pilot.model.inference import generate_output, get_embeddings
|
||||
|
||||
from pilot.model.loader import ModelLoader
|
||||
from pilot.configs.model_config import *
|
||||
from pilot.configs.config import Config
|
||||
|
||||
|
||||
CFG = Config()
|
||||
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
||||
|
||||
|
||||
global_counter = 0
|
||||
model_semaphore = None
|
||||
|
||||
ml = ModelLoader(model_path=model_path)
|
||||
model, tokenizer = ml.loader(num_gpus=1, load_8bit=ISLOAD_8BIT, debug=ISDEBUG)
|
||||
#model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(ROOT_PATH)
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import *
|
||||
from pilot.model.inference import generate_output, generate_stream, get_embeddings
|
||||
from pilot.model.loader import ModelLoader
|
||||
from pilot.server.chat_adapter import get_llm_chat_adapter
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class ModelWorker:
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self, model_path, model_name, device, num_gpus=1):
|
||||
if model_path.endswith("/"):
|
||||
model_path = model_path[:-1]
|
||||
self.model_name = model_name or model_path.split("/")[-1]
|
||||
self.device = device
|
||||
|
||||
self.ml = ModelLoader(model_path=model_path)
|
||||
self.model, self.tokenizer = self.ml.loader(
|
||||
num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
|
||||
)
|
||||
|
||||
if hasattr(self.model.config, "max_sequence_length"):
|
||||
self.context_len = self.model.config.max_sequence_length
|
||||
elif hasattr(self.model.config, "max_position_embeddings"):
|
||||
self.context_len = self.model.config.max_position_embeddings
|
||||
|
||||
else:
|
||||
self.context_len = 2048
|
||||
|
||||
self.llm_chat_adapter = get_llm_chat_adapter(model_path)
|
||||
self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func()
|
||||
|
||||
def get_queue_length(self):
|
||||
if (
|
||||
model_semaphore is None
|
||||
or model_semaphore._value is None
|
||||
or model_semaphore._waiters is None
|
||||
):
|
||||
return 0
|
||||
else:
|
||||
(
|
||||
CFG.LIMIT_MODEL_CONCURRENCY
|
||||
- model_semaphore._value
|
||||
+ len(model_semaphore._waiters)
|
||||
)
|
||||
|
||||
def generate_stream_gate(self, params):
|
||||
try:
|
||||
for output in self.generate_stream_func(
|
||||
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
|
||||
):
|
||||
print("output: ", output)
|
||||
ret = {
|
||||
"text": output,
|
||||
"error_code": 0,
|
||||
}
|
||||
yield json.dumps(ret).encode() + b"\0"
|
||||
|
||||
except torch.cuda.CudaError:
|
||||
ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
|
||||
yield json.dumps(ret).encode() + b"\0"
|
||||
|
||||
def get_embeddings(self, prompt):
|
||||
return get_embeddings(self.model, self.tokenizer, prompt)
|
||||
|
||||
# TODO
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class PromptRequest(BaseModel):
|
||||
prompt: str
|
||||
temperature: float
|
||||
@@ -42,6 +93,7 @@ class PromptRequest(BaseModel):
|
||||
model: str
|
||||
stop: str = None
|
||||
|
||||
|
||||
class StreamRequest(BaseModel):
|
||||
model: str
|
||||
prompt: str
|
||||
@@ -49,64 +101,43 @@ class StreamRequest(BaseModel):
|
||||
max_new_tokens: int
|
||||
stop: str
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
prompt: str
|
||||
|
||||
|
||||
def release_model_semaphore():
|
||||
model_semaphore.release()
|
||||
|
||||
|
||||
def generate_stream_gate(params):
|
||||
try:
|
||||
for output in generate_stream(
|
||||
model,
|
||||
tokenizer,
|
||||
params,
|
||||
DEVICE,
|
||||
CFG.MAX_POSITION_EMBEDDINGS,
|
||||
):
|
||||
print("output: ", output)
|
||||
ret = {
|
||||
"text": output,
|
||||
"error_code": 0,
|
||||
}
|
||||
yield json.dumps(ret).encode() + b"\0"
|
||||
except torch.cuda.CudaError:
|
||||
ret = {
|
||||
"text": "**GPU OutOfMemory, Please Refresh.**",
|
||||
"error_code": 0
|
||||
}
|
||||
yield json.dumps(ret).encode() + b"\0"
|
||||
|
||||
|
||||
@app.post("/generate_stream")
|
||||
async def api_generate_stream(request: Request):
|
||||
global model_semaphore, global_counter
|
||||
global_counter += 1
|
||||
params = await request.json()
|
||||
print(model, tokenizer, params, DEVICE)
|
||||
|
||||
if model_semaphore is None:
|
||||
model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
|
||||
await model_semaphore.acquire()
|
||||
await model_semaphore.acquire()
|
||||
|
||||
generator = generate_stream_gate(params)
|
||||
generator = worker.generate_stream_gate(params)
|
||||
background_tasks = BackgroundTasks()
|
||||
background_tasks.add_task(release_model_semaphore)
|
||||
return StreamingResponse(generator, background=background_tasks)
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
def generate(prompt_request: PromptRequest):
|
||||
params = {
|
||||
"prompt": prompt_request.prompt,
|
||||
"temperature": prompt_request.temperature,
|
||||
"max_new_tokens": prompt_request.max_new_tokens,
|
||||
"stop": prompt_request.stop
|
||||
"stop": prompt_request.stop,
|
||||
}
|
||||
|
||||
response = []
|
||||
response = []
|
||||
rsp_str = ""
|
||||
output = generate_stream_gate(params)
|
||||
output = worker.generate_stream_gate(params)
|
||||
for rsp in output:
|
||||
# rsp = rsp.decode("utf-8")
|
||||
rsp_str = str(rsp, "utf-8")
|
||||
@@ -114,15 +145,22 @@ def generate(prompt_request: PromptRequest):
|
||||
response.append(rsp_str)
|
||||
|
||||
return {"response": rsp_str}
|
||||
|
||||
|
||||
|
||||
@app.post("/embedding")
|
||||
def embeddings(prompt_request: EmbeddingRequest):
|
||||
params = {"prompt": prompt_request.prompt}
|
||||
print("Received prompt: ", params["prompt"])
|
||||
output = get_embeddings(model, tokenizer, params["prompt"])
|
||||
output = worker.get_embeddings(params["prompt"])
|
||||
return {"response": [float(x) for x in output]}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", log_level="info")
|
||||
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
||||
print(model_path, DEVICE)
|
||||
|
||||
worker = ModelWorker(
|
||||
model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE, num_gpus=1
|
||||
)
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info")
|
||||
|
@@ -1,29 +1,30 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from pilot.vector_store.file_loader import KnownLedge2Vector
|
||||
from langchain.prompts import PromptTemplate
|
||||
from pilot.conversation import conv_qa_prompt_template
|
||||
|
||||
from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
|
||||
from pilot.conversation import conv_qa_prompt_template
|
||||
from pilot.model.vicuna_llm import VicunaLLM
|
||||
from pilot.vector_store.file_loader import KnownLedge2Vector
|
||||
|
||||
|
||||
class KnownLedgeBaseQA:
|
||||
|
||||
def __init__(self) -> None:
|
||||
k2v = KnownLedge2Vector()
|
||||
self.vector_store = k2v.init_vector_store()
|
||||
self.llm = VicunaLLM()
|
||||
|
||||
|
||||
def get_similar_answer(self, query):
|
||||
|
||||
prompt = PromptTemplate(
|
||||
template=conv_qa_prompt_template,
|
||||
input_variables=["context", "question"]
|
||||
template=conv_qa_prompt_template, input_variables=["context", "question"]
|
||||
)
|
||||
|
||||
retriever = self.vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K})
|
||||
retriever = self.vector_store.as_retriever(
|
||||
search_kwargs={"k": VECTOR_SEARCH_TOP_K}
|
||||
)
|
||||
docs = retriever.get_relevant_documents(query=query)
|
||||
|
||||
context = [d.page_content for d in docs]
|
||||
context = [d.page_content for d in docs]
|
||||
result = prompt.format(context="\n".join(context), question=query)
|
||||
return result
|
||||
|
@@ -2,50 +2,63 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import gradio as gr
|
||||
import datetime
|
||||
import requests
|
||||
import uuid
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import gradio as gr
|
||||
import requests
|
||||
from langchain import PromptTemplate
|
||||
|
||||
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
|
||||
from pilot.server.vectordb_qa import KnownLedgeBaseQA
|
||||
from pilot.connections.mysql import MySQLOperator
|
||||
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
||||
from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(ROOT_PATH)
|
||||
|
||||
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
|
||||
|
||||
from pilot.plugins import scan_plugins
|
||||
from pilot.configs.config import Config
|
||||
from pilot.commands.command import execute_ai_response_json
|
||||
from pilot.commands.command_mange import CommandRegistry
|
||||
from pilot.prompts.auto_mode_prompt import AutoModePrompt
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
|
||||
from pilot.commands.exception_not_commands import NotCommands
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import (
|
||||
DATASETS_DIR,
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
LLM_MODEL_CONFIG,
|
||||
LOGDIR,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
)
|
||||
from pilot.connections.mysql import MySQLOperator
|
||||
from pilot.conversation import (
|
||||
default_conversation,
|
||||
SeparatorStyle,
|
||||
conv_qa_prompt_template,
|
||||
conv_templates,
|
||||
conversation_types,
|
||||
conversation_sql_mode,
|
||||
SeparatorStyle, conv_qa_prompt_template
|
||||
conversation_types,
|
||||
default_conversation,
|
||||
)
|
||||
|
||||
from pilot.utils import (
|
||||
build_logger,
|
||||
server_error_msg,
|
||||
)
|
||||
|
||||
from pilot.plugins import scan_plugins
|
||||
from pilot.prompts.auto_mode_prompt import AutoModePrompt
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
from pilot.server.gradio_css import code_highlight_css
|
||||
from pilot.server.gradio_patch import Chatbot as grChatbot
|
||||
from pilot.server.vectordb_qa import KnownLedgeBaseQA
|
||||
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
||||
from pilot.utils import build_logger, server_error_msg
|
||||
from pilot.vector_store.extract_tovec import (
|
||||
get_vector_storelist,
|
||||
knownledge_tovec_st,
|
||||
load_knownledge_from_doc,
|
||||
)
|
||||
|
||||
from pilot.commands.command import execute_ai_response_json
|
||||
from pilot.scene.base import ChatScene
|
||||
@@ -66,9 +79,7 @@ autogpt = False
|
||||
vector_store_client = None
|
||||
vector_store_name = {"vs_name": ""}
|
||||
|
||||
priority = {
|
||||
"vicuna-13b": "aaa"
|
||||
}
|
||||
priority = {"vicuna-13b": "aaa"}
|
||||
|
||||
# 加载插件
|
||||
CFG = Config()
|
||||
@@ -76,10 +87,12 @@ CHAT_FACTORY = ChatFactory()
|
||||
|
||||
DB_SETTINGS = {
|
||||
"user": CFG.LOCAL_DB_USER,
|
||||
"password": CFG.LOCAL_DB_PASSWORD,
|
||||
"password": CFG.LOCAL_DB_PASSWORD,
|
||||
"host": CFG.LOCAL_DB_HOST,
|
||||
"port": CFG.LOCAL_DB_PORT
|
||||
"port": CFG.LOCAL_DB_PORT,
|
||||
}
|
||||
|
||||
|
||||
def get_simlar(q):
|
||||
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
|
||||
docs = docsearch.similarity_search_with_score(q, k=1)
|
||||
@@ -89,9 +102,7 @@ def get_simlar(q):
|
||||
|
||||
|
||||
def gen_sqlgen_conversation(dbname):
|
||||
mo = MySQLOperator(
|
||||
**DB_SETTINGS
|
||||
)
|
||||
mo = MySQLOperator(**DB_SETTINGS)
|
||||
|
||||
message = ""
|
||||
|
||||
@@ -334,8 +345,8 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
||||
|
||||
|
||||
block_css = (
|
||||
code_highlight_css
|
||||
+ """
|
||||
code_highlight_css
|
||||
+ """
|
||||
pre {
|
||||
white-space: pre-wrap; /* Since CSS 2.1 */
|
||||
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
|
||||
@@ -372,7 +383,7 @@ def build_single_model_ui():
|
||||
notice_markdown = """
|
||||
# DB-GPT
|
||||
|
||||
[DB-GPT](https://github.com/csunny/DB-GPT) 是一个实验性的开源应用程序,它基于[FastChat](https://github.com/lm-sys/FastChat),并使用vicuna-13b作为基础模型。此外,此程序结合了langchain和llama-index基于现有知识库进行In-Context Learning来对其进行数据库相关知识的增强。它可以进行SQL生成、SQL诊断、数据库知识问答等一系列的工作。 总的来说,它是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施DB-GPT有任何具体问题,请联系我, 我会尽力提供帮助, 同时也欢迎大家参与到项目建设中, 做一些有趣的事情。
|
||||
[DB-GPT](https://github.com/csunny/DB-GPT) 是一个开源的以数据库为基础的GPT实验项目,使用本地化的GPT大模型与您的数据和环境进行交互,无数据泄露风险,100% 私密,100% 安全。
|
||||
"""
|
||||
learn_more_markdown = """
|
||||
### Licence
|
||||
@@ -396,7 +407,7 @@ def build_single_model_ui():
|
||||
max_output_tokens = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=1024,
|
||||
value=1024,
|
||||
value=512,
|
||||
step=64,
|
||||
interactive=True,
|
||||
label="最大输出Token数",
|
||||
@@ -412,7 +423,8 @@ def build_single_model_ui():
|
||||
choices=dbs,
|
||||
value=dbs[0] if len(models) > 0 else "",
|
||||
interactive=True,
|
||||
show_label=True).style(container=False)
|
||||
show_label=True,
|
||||
).style(container=False)
|
||||
|
||||
sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
|
||||
sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
|
||||
@@ -420,7 +432,9 @@ def build_single_model_ui():
|
||||
|
||||
tab_qa = gr.TabItem("知识问答", elem_id="QA")
|
||||
with tab_qa:
|
||||
mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
|
||||
mode = gr.Radio(
|
||||
["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话"
|
||||
)
|
||||
vs_setting = gr.Accordion("配置知识库", open=False)
|
||||
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
|
||||
with vs_setting:
|
||||
@@ -429,18 +443,22 @@ def build_single_model_ui():
|
||||
with gr.Column() as doc2vec:
|
||||
gr.Markdown("向知识库中添加文件")
|
||||
with gr.Tab("上传文件"):
|
||||
files = gr.File(label="添加文件",
|
||||
file_types=[".txt", ".md", ".docx", ".pdf"],
|
||||
file_count="multiple",
|
||||
show_label=False
|
||||
)
|
||||
files = gr.File(
|
||||
label="添加文件",
|
||||
file_types=[".txt", ".md", ".docx", ".pdf"],
|
||||
file_count="multiple",
|
||||
allow_flagged_uploads=True,
|
||||
show_label=False,
|
||||
)
|
||||
|
||||
load_file_button = gr.Button("上传并加载到知识库")
|
||||
with gr.Tab("上传文件夹"):
|
||||
folder_files = gr.File(label="添加文件夹",
|
||||
accept_multiple_files=True,
|
||||
file_count="directory",
|
||||
show_label=False)
|
||||
folder_files = gr.File(
|
||||
label="添加文件夹",
|
||||
accept_multiple_files=True,
|
||||
file_count="directory",
|
||||
show_label=False,
|
||||
)
|
||||
load_folder_button = gr.Button("上传并加载到知识库")
|
||||
|
||||
with gr.Blocks():
|
||||
@@ -481,28 +499,32 @@ def build_single_model_ui():
|
||||
).then(
|
||||
http_bot,
|
||||
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
|
||||
[state, chatbot] + btn_list
|
||||
[state, chatbot] + btn_list,
|
||||
)
|
||||
vs_add.click(
|
||||
fn=save_vs_name, show_progress=True, inputs=[vs_name], outputs=[vs_name]
|
||||
)
|
||||
load_file_button.click(
|
||||
fn=knowledge_embedding_store,
|
||||
show_progress=True,
|
||||
inputs=[vs_name, files],
|
||||
outputs=[vs_name],
|
||||
)
|
||||
load_folder_button.click(
|
||||
fn=knowledge_embedding_store,
|
||||
show_progress=True,
|
||||
inputs=[vs_name, folder_files],
|
||||
outputs=[vs_name],
|
||||
)
|
||||
vs_add.click(fn=save_vs_name, show_progress=True,
|
||||
inputs=[vs_name],
|
||||
outputs=[vs_name])
|
||||
load_file_button.click(fn=knowledge_embedding_store,
|
||||
show_progress=True,
|
||||
inputs=[vs_name, files],
|
||||
outputs=[vs_name])
|
||||
load_folder_button.click(fn=knowledge_embedding_store,
|
||||
show_progress=True,
|
||||
inputs=[vs_name, folder_files],
|
||||
outputs=[vs_name])
|
||||
return state, chatbot, textbox, send_btn, button_row, parameter_row
|
||||
|
||||
|
||||
def build_webdemo():
|
||||
with gr.Blocks(
|
||||
title="数据库智能助手",
|
||||
# theme=gr.themes.Base(),
|
||||
theme=gr.themes.Default(),
|
||||
css=block_css,
|
||||
title="数据库智能助手",
|
||||
# theme=gr.themes.Base(),
|
||||
theme=gr.themes.Default(),
|
||||
css=block_css,
|
||||
) as demo:
|
||||
url_params = gr.JSON(visible=False)
|
||||
(
|
||||
@@ -544,15 +566,21 @@ def knowledge_embedding_store(vs_id, files):
|
||||
os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id))
|
||||
for file in files:
|
||||
filename = os.path.split(file.name)[-1]
|
||||
shutil.move(file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename))
|
||||
shutil.move(
|
||||
file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)
|
||||
)
|
||||
knowledge_embedding_client = KnowledgeEmbedding(
|
||||
file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
|
||||
model_name=LLM_MODEL_CONFIG["sentence-transforms"],
|
||||
model_name=LLM_MODEL_CONFIG["text2vec"],
|
||||
local_persist=False,
|
||||
vector_store_config={
|
||||
"vector_store_name": vector_store_name["vs_name"],
|
||||
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
|
||||
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
},
|
||||
)
|
||||
knowledge_embedding_client.knowledge_embedding()
|
||||
|
||||
|
||||
logger.info("knowledge embedding success")
|
||||
return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")
|
||||
|
||||
@@ -596,5 +624,8 @@ if __name__ == "__main__":
|
||||
demo.queue(
|
||||
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
|
||||
).launch(
|
||||
server_name=args.host, server_port=args.port, share=args.share, max_threads=200,
|
||||
server_name=args.host,
|
||||
server_port=args.port,
|
||||
share=args.share,
|
||||
max_threads=200,
|
||||
)
|
||||
|
@@ -5,10 +5,12 @@
|
||||
import abc
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Singleton(abc.ABCMeta, type):
|
||||
""" Singleton metaclass for ensuring only one instance of a class"""
|
||||
"""Singleton metaclass for ensuring only one instance of a class"""
|
||||
|
||||
_instances = {}
|
||||
|
||||
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Call method for the singleton metaclass"""
|
||||
if cls not in cls._instances:
|
||||
@@ -18,4 +20,5 @@ class Singleton(abc.ABCMeta, type):
|
||||
|
||||
class AbstractSingleton(abc.ABC, metaclass=Singleton):
|
||||
"""Abstract singleton class for ensuring only one instance of a class"""
|
||||
pass
|
||||
|
||||
pass
|
||||
|
@@ -1,8 +1,3 @@
|
||||
from pilot.source_embedding.source_embedding import SourceEmbedding
|
||||
from pilot.source_embedding.source_embedding import register
|
||||
from pilot.source_embedding.source_embedding import SourceEmbedding, register
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SourceEmbedding",
|
||||
"register"
|
||||
]
|
||||
__all__ = ["SourceEmbedding", "register"]
|
||||
|
55
pilot/source_embedding/chn_document_splitter.py
Normal file
55
pilot/source_embedding/chn_document_splitter.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
|
||||
|
||||
class CHNDocumentSplitter(CharacterTextSplitter):
|
||||
def __init__(self, pdf: bool = False, sentence_size: int = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.pdf = pdf
|
||||
self.sentence_size = sentence_size
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
if self.pdf:
|
||||
text = re.sub(r"\n{3,}", r"\n", text)
|
||||
text = re.sub("\s", " ", text)
|
||||
text = re.sub("\n\n", "", text)
|
||||
|
||||
text = re.sub(r"([;;.!?。!?\?])([^”’])", r"\1\n\2", text)
|
||||
text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)
|
||||
text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)
|
||||
text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r"\1\n\2", text)
|
||||
text = text.rstrip()
|
||||
ls = [i for i in text.split("\n") if i]
|
||||
for ele in ls:
|
||||
if len(ele) > self.sentence_size:
|
||||
ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r"\1\n\2", ele)
|
||||
ele1_ls = ele1.split("\n")
|
||||
for ele_ele1 in ele1_ls:
|
||||
if len(ele_ele1) > self.sentence_size:
|
||||
ele_ele2 = re.sub(
|
||||
r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r"\1\n\2", ele_ele1
|
||||
)
|
||||
ele2_ls = ele_ele2.split("\n")
|
||||
for ele_ele2 in ele2_ls:
|
||||
if len(ele_ele2) > self.sentence_size:
|
||||
ele_ele3 = re.sub(
|
||||
'( ["’”」』]{0,2})([^ ])', r"\1\n\2", ele_ele2
|
||||
)
|
||||
ele2_id = ele2_ls.index(ele_ele2)
|
||||
ele2_ls = (
|
||||
ele2_ls[:ele2_id]
|
||||
+ [i for i in ele_ele3.split("\n") if i]
|
||||
+ ele2_ls[ele2_id + 1 :]
|
||||
)
|
||||
ele_id = ele1_ls.index(ele_ele1)
|
||||
ele1_ls = (
|
||||
ele1_ls[:ele_id]
|
||||
+ [i for i in ele2_ls if i]
|
||||
+ ele1_ls[ele_id + 1 :]
|
||||
)
|
||||
|
||||
id = ls.index(ele)
|
||||
ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1 :]
|
||||
return ls
|
@@ -1,14 +1,21 @@
|
||||
from typing import List, Optional, Dict
|
||||
from pilot.source_embedding import SourceEmbedding, register
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain.document_loaders import CSVLoader
|
||||
from langchain.schema import Document
|
||||
|
||||
from pilot.source_embedding import SourceEmbedding, register
|
||||
|
||||
|
||||
class CSVEmbedding(SourceEmbedding):
|
||||
"""csv embedding for read csv document."""
|
||||
|
||||
def __init__(self, file_path, model_name, vector_store_config, embedding_args: Optional[Dict] = None):
|
||||
def __init__(
|
||||
self,
|
||||
file_path,
|
||||
model_name,
|
||||
vector_store_config,
|
||||
embedding_args: Optional[Dict] = None,
|
||||
):
|
||||
"""Initialize with csv path."""
|
||||
super().__init__(file_path, model_name, vector_store_config)
|
||||
self.file_path = file_path
|
||||
@@ -29,6 +36,3 @@ class CSVEmbedding(SourceEmbedding):
|
||||
documents[i].page_content = d.page_content.replace("\n", "")
|
||||
i += 1
|
||||
return documents
|
||||
|
||||
|
||||
|
||||
|
@@ -1,34 +1,126 @@
|
||||
import os
|
||||
|
||||
import markdown
|
||||
from bs4 import BeautifulSoup
|
||||
from langchain.document_loaders import PyPDFLoader, TextLoader, markdown
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE
|
||||
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
|
||||
from pilot.source_embedding.csv_embedding import CSVEmbedding
|
||||
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
|
||||
from pilot.source_embedding.pdf_embedding import PDFEmbedding
|
||||
from pilot.vector_store.connector import VectorStoreConnector
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class KnowledgeEmbedding:
|
||||
def __init__(self, file_path, model_name, vector_store_config):
|
||||
def __init__(self, file_path, model_name, vector_store_config, local_persist=True):
|
||||
"""Initialize with Loader url, model_name, vector_store_config"""
|
||||
self.file_path = file_path
|
||||
self.model_name = model_name
|
||||
self.vector_store_config = vector_store_config
|
||||
self.vector_store_type = "default"
|
||||
self.knowledge_embedding_client = self.init_knowledge_embedding()
|
||||
self.file_type = "default"
|
||||
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
||||
self.vector_store_config["embeddings"] = self.embeddings
|
||||
self.local_persist = local_persist
|
||||
if not self.local_persist:
|
||||
self.knowledge_embedding_client = self.init_knowledge_embedding()
|
||||
|
||||
def knowledge_embedding(self):
|
||||
self.knowledge_embedding_client.source_embedding()
|
||||
|
||||
def knowledge_embedding_batch(self):
|
||||
self.knowledge_embedding_client.batch_embedding()
|
||||
|
||||
def init_knowledge_embedding(self):
|
||||
if self.file_path.endswith(".pdf"):
|
||||
embedding = PDFEmbedding(file_path=self.file_path, model_name=self.model_name,
|
||||
vector_store_config=self.vector_store_config)
|
||||
embedding = PDFEmbedding(
|
||||
file_path=self.file_path,
|
||||
model_name=self.model_name,
|
||||
vector_store_config=self.vector_store_config,
|
||||
)
|
||||
elif self.file_path.endswith(".md"):
|
||||
embedding = MarkdownEmbedding(file_path=self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config)
|
||||
embedding = MarkdownEmbedding(
|
||||
file_path=self.file_path,
|
||||
model_name=self.model_name,
|
||||
vector_store_config=self.vector_store_config,
|
||||
)
|
||||
|
||||
elif self.file_path.endswith(".csv"):
|
||||
embedding = CSVEmbedding(file_path=self.file_path, model_name=self.model_name,
|
||||
vector_store_config=self.vector_store_config)
|
||||
elif self.vector_store_type == "default":
|
||||
embedding = MarkdownEmbedding(file_path=self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config)
|
||||
embedding = CSVEmbedding(
|
||||
file_path=self.file_path,
|
||||
model_name=self.model_name,
|
||||
vector_store_config=self.vector_store_config,
|
||||
)
|
||||
elif self.file_type == "default":
|
||||
embedding = MarkdownEmbedding(
|
||||
file_path=self.file_path,
|
||||
model_name=self.model_name,
|
||||
vector_store_config=self.vector_store_config,
|
||||
)
|
||||
|
||||
return embedding
|
||||
|
||||
def similar_search(self, text, topk):
|
||||
return self.knowledge_embedding_client.similar_search(text, topk)
|
||||
return self.knowledge_embedding_client.similar_search(text, topk)
|
||||
|
||||
def knowledge_persist_initialization(self, append_mode):
|
||||
documents = self._load_knownlege(self.file_path)
|
||||
self.vector_client = VectorStoreConnector(
|
||||
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
||||
)
|
||||
self.vector_client.load_document(documents)
|
||||
return self.vector_client
|
||||
|
||||
def _load_knownlege(self, path):
|
||||
docments = []
|
||||
for root, _, files in os.walk(path, topdown=False):
|
||||
for file in files:
|
||||
filename = os.path.join(root, file)
|
||||
docs = self._load_file(filename)
|
||||
new_docs = []
|
||||
for doc in docs:
|
||||
doc.metadata = {
|
||||
"source": doc.metadata["source"].replace(DATASETS_DIR, "")
|
||||
}
|
||||
print("doc is embedding...", doc.metadata)
|
||||
new_docs.append(doc)
|
||||
docments += new_docs
|
||||
return docments
|
||||
|
||||
def _load_file(self, filename):
|
||||
if filename.lower().endswith(".md"):
|
||||
loader = TextLoader(filename)
|
||||
text_splitter = CHNDocumentSplitter(
|
||||
pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
|
||||
)
|
||||
docs = loader.load_and_split(text_splitter)
|
||||
i = 0
|
||||
for d in docs:
|
||||
content = markdown.markdown(d.page_content)
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
for tag in soup(["!doctype", "meta", "i.fa"]):
|
||||
tag.extract()
|
||||
docs[i].page_content = soup.get_text()
|
||||
docs[i].page_content = docs[i].page_content.replace("\n", " ")
|
||||
i += 1
|
||||
elif filename.lower().endswith(".pdf"):
|
||||
loader = PyPDFLoader(filename)
|
||||
textsplitter = CHNDocumentSplitter(
|
||||
pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
|
||||
)
|
||||
docs = loader.load_and_split(textsplitter)
|
||||
i = 0
|
||||
for d in docs:
|
||||
docs[i].page_content = d.page_content.replace("\n", " ").replace(
|
||||
"<EFBFBD>", ""
|
||||
)
|
||||
i += 1
|
||||
else:
|
||||
loader = TextLoader(filename)
|
||||
text_splitor = CHNDocumentSplitter(sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE)
|
||||
docs = loader.load_and_split(text_splitor)
|
||||
return docs
|
||||
|
@@ -1,13 +1,16 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import markdown
|
||||
from bs4 import BeautifulSoup
|
||||
from langchain.document_loaders import TextLoader
|
||||
from langchain.schema import Document
|
||||
import markdown
|
||||
|
||||
from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
|
||||
from pilot.source_embedding import SourceEmbedding, register
|
||||
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
|
||||
|
||||
|
||||
class MarkdownEmbedding(SourceEmbedding):
|
||||
@@ -24,20 +27,42 @@ class MarkdownEmbedding(SourceEmbedding):
|
||||
def read(self):
|
||||
"""Load from markdown path."""
|
||||
loader = TextLoader(self.file_path)
|
||||
return loader.load()
|
||||
text_splitter = CHNDocumentSplitter(
|
||||
pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
|
||||
)
|
||||
return loader.load_and_split(text_splitter)
|
||||
|
||||
@register
|
||||
def read_batch(self):
|
||||
"""Load from markdown path."""
|
||||
docments = []
|
||||
for root, _, files in os.walk(self.file_path, topdown=False):
|
||||
for file in files:
|
||||
filename = os.path.join(root, file)
|
||||
loader = TextLoader(filename)
|
||||
# text_splitor = CHNDocumentSplitter(chunk_size=1000, chunk_overlap=20, length_function=len)
|
||||
# docs = loader.load_and_split()
|
||||
docs = loader.load()
|
||||
# 更新metadata数据
|
||||
new_docs = []
|
||||
for doc in docs:
|
||||
doc.metadata = {
|
||||
"source": doc.metadata["source"].replace(self.file_path, "")
|
||||
}
|
||||
print("doc is embedding ... ", doc.metadata)
|
||||
new_docs.append(doc)
|
||||
docments += new_docs
|
||||
return docments
|
||||
|
||||
@register
|
||||
def data_process(self, documents: List[Document]):
|
||||
i = 0
|
||||
for d in documents:
|
||||
content = markdown.markdown(d.page_content)
|
||||
soup = BeautifulSoup(content, 'html.parser')
|
||||
for tag in soup(['!doctype', 'meta', 'i.fa']):
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
for tag in soup(["!doctype", "meta", "i.fa"]):
|
||||
tag.extract()
|
||||
documents[i].page_content = soup.get_text()
|
||||
documents[i].page_content = documents[i].page_content.replace(" ", "").replace("\n", " ")
|
||||
documents[i].page_content = documents[i].page_content.replace("\n", " ")
|
||||
i += 1
|
||||
return documents
|
||||
|
||||
|
||||
|
||||
|
@@ -5,7 +5,9 @@ from typing import List
|
||||
from langchain.document_loaders import PyPDFLoader
|
||||
from langchain.schema import Document
|
||||
|
||||
from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
|
||||
from pilot.source_embedding import SourceEmbedding, register
|
||||
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
|
||||
|
||||
|
||||
class PDFEmbedding(SourceEmbedding):
|
||||
@@ -17,22 +19,21 @@ class PDFEmbedding(SourceEmbedding):
|
||||
self.file_path = file_path
|
||||
self.model_name = model_name
|
||||
self.vector_store_config = vector_store_config
|
||||
# SourceEmbedding(file_path =file_path, );
|
||||
SourceEmbedding(file_path, model_name, vector_store_config)
|
||||
|
||||
@register
|
||||
def read(self):
|
||||
"""Load from pdf path."""
|
||||
# loader = UnstructuredPaddlePDFLoader(self.file_path)
|
||||
loader = PyPDFLoader(self.file_path)
|
||||
return loader.load()
|
||||
textsplitter = CHNDocumentSplitter(
|
||||
pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
|
||||
)
|
||||
return loader.load_and_split(textsplitter)
|
||||
|
||||
@register
|
||||
def data_process(self, documents: List[Document]):
|
||||
i = 0
|
||||
for d in documents:
|
||||
documents[i].page_content = d.page_content.replace(" ", "").replace("\n", "")
|
||||
documents[i].page_content = d.page_content.replace("\n", "")
|
||||
i += 1
|
||||
return documents
|
||||
|
||||
|
||||
|
||||
|
55
pilot/source_embedding/pdf_loader.py
Normal file
55
pilot/source_embedding/pdf_loader.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Loader that loads image files."""
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import fitz
|
||||
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
|
||||
class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):
|
||||
"""Loader that uses unstructured to load image files, such as PNGs and JPGs."""
|
||||
|
||||
def _get_elements(self) -> List:
|
||||
def pdf_ocr_txt(filepath, dir_path="tmp_files"):
|
||||
full_dir_path = os.path.join(os.path.dirname(filepath), dir_path)
|
||||
if not os.path.exists(full_dir_path):
|
||||
os.makedirs(full_dir_path)
|
||||
filename = os.path.split(filepath)[-1]
|
||||
ocr = PaddleOCR(lang="ch", use_gpu=False, show_log=False)
|
||||
doc = fitz.open(filepath)
|
||||
txt_file_path = os.path.join(full_dir_path, "%s.txt" % (filename))
|
||||
img_name = os.path.join(full_dir_path, ".tmp.png")
|
||||
with open(txt_file_path, "w", encoding="utf-8") as fout:
|
||||
for i in range(doc.page_count):
|
||||
page = doc[i]
|
||||
text = page.get_text("")
|
||||
fout.write(text)
|
||||
fout.write("\n")
|
||||
|
||||
img_list = page.get_images()
|
||||
for img in img_list:
|
||||
pix = fitz.Pixmap(doc, img[0])
|
||||
|
||||
pix.save(img_name)
|
||||
|
||||
result = ocr.ocr(img_name)
|
||||
ocr_result = [i[1][0] for line in result for i in line]
|
||||
fout.write("\n".join(ocr_result))
|
||||
os.remove(img_name)
|
||||
return txt_file_path
|
||||
|
||||
txt_file_path = pdf_ocr_txt(self.file_path)
|
||||
from unstructured.partition.text import partition_text
|
||||
|
||||
return partition_text(filename=txt_file_path, **self.unstructured_kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
filepath = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test.pdf"
|
||||
)
|
||||
loader = UnstructuredPaddlePDFLoader(filepath, mode="elements")
|
||||
docs = loader.load()
|
||||
for doc in docs:
|
||||
print(doc)
|
@@ -50,7 +50,7 @@
|
||||
#
|
||||
# # text_embeddings = Text2Vectors()
|
||||
# mivuls = MilvusStore(cfg={"url": "127.0.0.1", "port": "19530", "alias": "default", "table_name": "test_k"})
|
||||
#
|
||||
#
|
||||
# mivuls.insert(["textc","tezt2"])
|
||||
# print("success")
|
||||
# ct
|
||||
@@ -58,4 +58,4 @@
|
||||
# # docs,
|
||||
# # embedding=embeddings,
|
||||
# # connection_args={"host": "127.0.0.1", "port": "19530", "alias": "default"}
|
||||
# # )
|
||||
# # )
|
||||
|
@@ -1,14 +1,15 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from langchain.vectorstores import Chroma
|
||||
|
||||
from typing import List, Optional, Dict
|
||||
from pilot.configs.config import Config
|
||||
from pilot.vector_store.connector import VectorStoreConnector
|
||||
|
||||
registered_methods = []
|
||||
CFG = Config()
|
||||
|
||||
|
||||
def register(method):
|
||||
@@ -22,21 +23,30 @@ class SourceEmbedding(ABC):
|
||||
Implementations should implement the method
|
||||
"""
|
||||
|
||||
def __init__(self, file_path, model_name, vector_store_config, embedding_args: Optional[Dict] = None):
|
||||
def __init__(
|
||||
self,
|
||||
file_path,
|
||||
model_name,
|
||||
vector_store_config,
|
||||
embedding_args: Optional[Dict] = None,
|
||||
):
|
||||
"""Initialize with Loader url, model_name, vector_store_config"""
|
||||
self.file_path = file_path
|
||||
self.model_name = model_name
|
||||
self.vector_store_config = vector_store_config
|
||||
self.embedding_args = embedding_args
|
||||
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
||||
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
|
||||
self.vector_store_config["vector_store_name"] + ".vectordb")
|
||||
self.vector_store_client = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
|
||||
|
||||
vector_store_config["embeddings"] = self.embeddings
|
||||
self.vector_client = VectorStoreConnector(
|
||||
CFG.VECTOR_STORE_TYPE, vector_store_config
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@register
|
||||
def read(self) -> List[ABC]:
|
||||
"""read datasource into document objects."""
|
||||
|
||||
@register
|
||||
def data_process(self, text):
|
||||
"""pre process data."""
|
||||
@@ -54,25 +64,33 @@ class SourceEmbedding(ABC):
|
||||
@register
|
||||
def index_to_store(self, docs):
|
||||
"""index to vector store"""
|
||||
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
|
||||
self.vector_store_config["vector_store_name"] + ".vectordb")
|
||||
self.vector_store = Chroma.from_documents(docs, self.embeddings, persist_directory=persist_dir)
|
||||
self.vector_store.persist()
|
||||
self.vector_client.load_document(docs)
|
||||
|
||||
@register
|
||||
def similar_search(self, doc, topk):
|
||||
"""vector store similarity_search"""
|
||||
|
||||
return self.vector_store_client.similarity_search(doc, topk)
|
||||
return self.vector_client.similar_search(doc, topk)
|
||||
|
||||
def source_embedding(self):
|
||||
if 'read' in registered_methods:
|
||||
if "read" in registered_methods:
|
||||
text = self.read()
|
||||
if 'data_process' in registered_methods:
|
||||
if "data_process" in registered_methods:
|
||||
text = self.data_process(text)
|
||||
if 'text_split' in registered_methods:
|
||||
if "text_split" in registered_methods:
|
||||
self.text_split(text)
|
||||
if 'text_to_vector' in registered_methods:
|
||||
if "text_to_vector" in registered_methods:
|
||||
self.text_to_vector(text)
|
||||
if 'index_to_store' in registered_methods:
|
||||
if "index_to_store" in registered_methods:
|
||||
self.index_to_store(text)
|
||||
|
||||
def batch_embedding(self):
|
||||
if "read_batch" in registered_methods:
|
||||
text = self.read_batch()
|
||||
if "data_process" in registered_methods:
|
||||
text = self.data_process(text)
|
||||
if "text_split" in registered_methods:
|
||||
self.text_split(text)
|
||||
if "text_to_vector" in registered_methods:
|
||||
self.text_to_vector(text)
|
||||
if "index_to_store" in registered_methods:
|
||||
self.index_to_store(text)
|
||||
|
@@ -1,10 +1,11 @@
|
||||
from typing import List
|
||||
from pilot.source_embedding import SourceEmbedding, register
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from langchain.document_loaders import WebBaseLoader
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
|
||||
from pilot.source_embedding import SourceEmbedding, register
|
||||
|
||||
|
||||
class URLEmbedding(SourceEmbedding):
|
||||
@@ -20,19 +21,19 @@ class URLEmbedding(SourceEmbedding):
|
||||
def read(self):
|
||||
"""Load from url path."""
|
||||
loader = WebBaseLoader(web_path=self.file_path)
|
||||
return loader.load()
|
||||
text_splitor = CharacterTextSplitter(
|
||||
chunk_size=1000, chunk_overlap=20, length_function=len
|
||||
)
|
||||
return loader.load_and_split(text_splitor)
|
||||
|
||||
@register
|
||||
def data_process(self, documents: List[Document]):
|
||||
i = 0
|
||||
for d in documents:
|
||||
content = d.page_content.replace("\n", "")
|
||||
soup = BeautifulSoup(content, 'html.parser')
|
||||
for tag in soup(['!doctype', 'meta']):
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
for tag in soup(["!doctype", "meta"]):
|
||||
tag.extract()
|
||||
documents[i].page_content = soup.get_text()
|
||||
i += 1
|
||||
return documents
|
||||
|
||||
|
||||
|
||||
|
@@ -1,27 +1,28 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import sys
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
||||
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
|
||||
server_error_msg = (
|
||||
"**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
|
||||
)
|
||||
|
||||
handler = None
|
||||
|
||||
|
||||
def get_gpu_memory(max_gpus=None):
|
||||
gpu_memory = []
|
||||
num_gpus = (
|
||||
torch.cuda.device_count()
|
||||
if max_gpus is None
|
||||
if max_gpus is None
|
||||
else min(max_gpus, torch.cuda.device_count())
|
||||
)
|
||||
|
||||
@@ -29,12 +30,11 @@ def get_gpu_memory(max_gpus=None):
|
||||
with torch.cuda.device(gpu_id):
|
||||
device = torch.cuda.current_device()
|
||||
gpu_properties = torch.cuda.get_device_properties(device)
|
||||
total_memory = gpu_properties.total_memory / (1024 ** 3)
|
||||
allocated_memory = torch.cuda.memory_allocated() / (1024 ** 3)
|
||||
total_memory = gpu_properties.total_memory / (1024**3)
|
||||
allocated_memory = torch.cuda.memory_allocated() / (1024**3)
|
||||
available_memory = total_memory - allocated_memory
|
||||
gpu_memory.append(available_memory)
|
||||
return gpu_memory
|
||||
|
||||
return gpu_memory
|
||||
|
||||
|
||||
def build_logger(logger_name, logger_filename):
|
||||
@@ -47,7 +47,7 @@ def build_logger(logger_name, logger_filename):
|
||||
|
||||
# Set the format of root handlers
|
||||
if not logging.getLogger().handlers:
|
||||
logging.basicConfig(level=logging.INFO, encoding='utf-8')
|
||||
logging.basicConfig(level=logging.INFO, encoding="utf-8")
|
||||
logging.getLogger().handlers[0].setFormatter(formatter)
|
||||
|
||||
# Redirect stdout and stderr to loggers
|
||||
@@ -70,7 +70,8 @@ def build_logger(logger_name, logger_filename):
|
||||
os.makedirs(LOGDIR, exist_ok=True)
|
||||
filename = os.path.join(LOGDIR, logger_filename)
|
||||
handler = logging.handlers.TimedRotatingFileHandler(
|
||||
filename, when='D', utc=True)
|
||||
filename, when="D", utc=True
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
for name, item in logging.root.manager.loggerDict.items():
|
||||
@@ -84,35 +85,36 @@ class StreamToLogger(object):
|
||||
"""
|
||||
Fake file-like stream object that redirects writes to a logger instance.
|
||||
"""
|
||||
|
||||
def __init__(self, logger, log_level=logging.INFO):
|
||||
self.terminal = sys.stdout
|
||||
self.logger = logger
|
||||
self.log_level = log_level
|
||||
self.linebuf = ''
|
||||
self.linebuf = ""
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.terminal, attr)
|
||||
|
||||
def write(self, buf):
|
||||
temp_linebuf = self.linebuf + buf
|
||||
self.linebuf = ''
|
||||
self.linebuf = ""
|
||||
for line in temp_linebuf.splitlines(True):
|
||||
# From the io.TextIOWrapper docs:
|
||||
# On output, if newline is None, any '\n' characters written
|
||||
# are translated to the system default line separator.
|
||||
# By default sys.stdout.write() expects '\n' newlines and then
|
||||
# translates them so this is still cross platform.
|
||||
if line[-1] == '\n':
|
||||
encoded_message = line.encode('utf-8', 'ignore').decode('utf-8')
|
||||
if line[-1] == "\n":
|
||||
encoded_message = line.encode("utf-8", "ignore").decode("utf-8")
|
||||
self.logger.log(self.log_level, encoded_message.rstrip())
|
||||
else:
|
||||
self.linebuf += line
|
||||
|
||||
def flush(self):
|
||||
if self.linebuf != '':
|
||||
encoded_message = self.linebuf.encode('utf-8', 'ignore').decode('utf-8')
|
||||
if self.linebuf != "":
|
||||
encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
|
||||
self.logger.log(self.log_level, encoded_message.rstrip())
|
||||
self.linebuf = ''
|
||||
self.linebuf = ""
|
||||
|
||||
|
||||
def disable_torch_init():
|
||||
@@ -120,6 +122,7 @@ def disable_torch_init():
|
||||
Disable the redundant torch default initialization to accelerate model creation.
|
||||
"""
|
||||
import torch
|
||||
|
||||
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
||||
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
||||
|
||||
@@ -128,4 +131,3 @@ def pretty_print_semaphore(semaphore):
|
||||
if semaphore is None:
|
||||
return "None"
|
||||
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
|
||||
|
||||
|
32
pilot/vector_store/chroma_store.py
Normal file
32
pilot/vector_store/chroma_store.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import os
|
||||
|
||||
from langchain.vectorstores import Chroma
|
||||
|
||||
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||
from pilot.logs import logger
|
||||
from pilot.vector_store.vector_store_base import VectorStoreBase
|
||||
|
||||
|
||||
class ChromaStore(VectorStoreBase):
|
||||
"""chroma database"""
|
||||
|
||||
def __init__(self, ctx: {}) -> None:
|
||||
self.ctx = ctx
|
||||
self.embeddings = ctx["embeddings"]
|
||||
self.persist_dir = os.path.join(
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH, ctx["vector_store_name"] + ".vectordb"
|
||||
)
|
||||
self.vector_store_client = Chroma(
|
||||
persist_directory=self.persist_dir, embedding_function=self.embeddings
|
||||
)
|
||||
|
||||
def similar_search(self, text, topk) -> None:
|
||||
logger.info("ChromaStore similar search")
|
||||
return self.vector_store_client.similarity_search(text, topk)
|
||||
|
||||
def load_document(self, documents):
|
||||
logger.info("ChromaStore load document")
|
||||
texts = [doc.page_content for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents]
|
||||
self.vector_store_client.add_texts(texts=texts, metadatas=metadatas)
|
||||
self.vector_store_client.persist()
|
19
pilot/vector_store/connector.py
Normal file
19
pilot/vector_store/connector.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from pilot.vector_store.chroma_store import ChromaStore
|
||||
from pilot.vector_store.milvus_store import MilvusStore
|
||||
|
||||
connector = {"Chroma": ChromaStore, "Milvus": MilvusStore}
|
||||
|
||||
|
||||
class VectorStoreConnector:
|
||||
"""vector store connector, can connect different vector db provided load document api and similar search api"""
|
||||
|
||||
def __init__(self, vector_store_type, ctx: {}) -> None:
|
||||
self.ctx = ctx
|
||||
self.connector_class = connector[vector_store_type]
|
||||
self.client = self.connector_class(ctx)
|
||||
|
||||
def load_document(self, docs):
|
||||
self.client.load_document(docs)
|
||||
|
||||
def similar_search(self, docs, topk):
|
||||
return self.client.similar_search(docs, topk)
|
@@ -3,14 +3,16 @@
|
||||
|
||||
import os
|
||||
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain.vectorstores import Chroma
|
||||
|
||||
from pilot.configs.model_config import DATASETS_DIR, VECTORE_PATH
|
||||
from pilot.model.vicuna_llm import VicunaEmbeddingLLM
|
||||
from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
embeddings = VicunaEmbeddingLLM()
|
||||
|
||||
|
||||
def knownledge_tovec(filename):
|
||||
with open(filename, "r") as f:
|
||||
knownledge = f.read()
|
||||
@@ -22,48 +24,64 @@ def knownledge_tovec(filename):
|
||||
)
|
||||
return docsearch
|
||||
|
||||
|
||||
def knownledge_tovec_st(filename):
|
||||
""" Use sentence transformers to embedding the document.
|
||||
https://github.com/UKPLab/sentence-transformers
|
||||
"""Use sentence transformers to embedding the document.
|
||||
https://github.com/UKPLab/sentence-transformers
|
||||
"""
|
||||
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
||||
embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
|
||||
|
||||
embeddings = HuggingFaceEmbeddings(
|
||||
model_name=LLM_MODEL_CONFIG["sentence-transforms"]
|
||||
)
|
||||
|
||||
with open(filename, "r") as f:
|
||||
knownledge = f.read()
|
||||
|
||||
|
||||
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
|
||||
|
||||
texts = text_splitter.split_text(knownledge)
|
||||
docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])
|
||||
docsearch = Chroma.from_texts(
|
||||
texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))]
|
||||
)
|
||||
return docsearch
|
||||
|
||||
|
||||
def load_knownledge_from_doc():
|
||||
"""Loader Knownledge from current datasets
|
||||
# TODO if the vector store is exists, just use it.
|
||||
# TODO if the vector store is exists, just use it.
|
||||
"""
|
||||
|
||||
if not os.path.exists(DATASETS_DIR):
|
||||
print("Not Exists Local DataSets, We will answers the Question use model default.")
|
||||
print(
|
||||
"Not Exists Local DataSets, We will answers the Question use model default."
|
||||
)
|
||||
|
||||
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
||||
embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
|
||||
|
||||
embeddings = HuggingFaceEmbeddings(
|
||||
model_name=LLM_MODEL_CONFIG["sentence-transforms"]
|
||||
)
|
||||
|
||||
files = os.listdir(DATASETS_DIR)
|
||||
for file in files:
|
||||
if not os.path.isdir(file):
|
||||
if not os.path.isdir(file):
|
||||
filename = os.path.join(DATASETS_DIR, file)
|
||||
with open(filename, "r") as f:
|
||||
knownledge = f.read()
|
||||
knownledge = f.read()
|
||||
|
||||
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_owerlap=0)
|
||||
texts = text_splitter.split_text(knownledge)
|
||||
docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))],
|
||||
persist_directory=os.path.join(VECTORE_PATH, ".vectore"))
|
||||
docsearch = Chroma.from_texts(
|
||||
texts,
|
||||
embeddings,
|
||||
metadatas=[{"source": str(i)} for i in range(len(texts))],
|
||||
persist_directory=os.path.join(VECTORE_PATH, ".vectore"),
|
||||
)
|
||||
return docsearch
|
||||
|
||||
|
||||
def get_vector_storelist():
|
||||
if not os.path.exists(VECTORE_PATH):
|
||||
return []
|
||||
return os.listdir(VECTORE_PATH)
|
||||
return os.listdir(VECTORE_PATH)
|
||||
|
@@ -2,58 +2,72 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import copy
|
||||
from typing import Optional, List, Dict
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.vectorstores import Chroma
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader, TextLoader
|
||||
|
||||
from langchain.chains import VectorDBQA
|
||||
from langchain.document_loaders import (
|
||||
TextLoader,
|
||||
UnstructuredFileLoader,
|
||||
UnstructuredPDFLoader,
|
||||
)
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain.vectorstores import Chroma
|
||||
|
||||
from pilot.configs.model_config import (
|
||||
DATASETS_DIR,
|
||||
LLM_MODEL_CONFIG,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
VECTORE_PATH,
|
||||
)
|
||||
|
||||
|
||||
class KnownLedge2Vector:
|
||||
|
||||
"""KnownLedge2Vector class is order to load document to vector
|
||||
"""KnownLedge2Vector class is order to load document to vector
|
||||
and persist to vector store.
|
||||
|
||||
Args:
|
||||
|
||||
Args:
|
||||
- model_name
|
||||
|
||||
Usage:
|
||||
k2v = KnownLedge2Vector()
|
||||
persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
|
||||
persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
|
||||
print(persist_dir)
|
||||
for s, dc in k2v.query("what is oceanbase?"):
|
||||
print(s, dc.page_content, dc.metadata)
|
||||
|
||||
"""
|
||||
embeddings: object = None
|
||||
|
||||
embeddings: object = None
|
||||
model_name = LLM_MODEL_CONFIG["sentence-transforms"]
|
||||
top_k: int = VECTOR_SEARCH_TOP_K
|
||||
|
||||
def __init__(self, model_name=None) -> None:
|
||||
if not model_name:
|
||||
# use default embedding model
|
||||
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
||||
|
||||
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
||||
|
||||
def init_vector_store(self):
|
||||
persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
|
||||
print("Vector store Persist address is: ", persist_dir)
|
||||
if os.path.exists(persist_dir):
|
||||
# Loader from local file.
|
||||
print("Loader data from local persist vector file...")
|
||||
vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
|
||||
vector_store = Chroma(
|
||||
persist_directory=persist_dir, embedding_function=self.embeddings
|
||||
)
|
||||
# vector_store.add_documents(documents=documents)
|
||||
else:
|
||||
documents = self.load_knownlege()
|
||||
# reinit
|
||||
vector_store = Chroma.from_documents(documents=documents,
|
||||
embedding=self.embeddings,
|
||||
persist_directory=persist_dir)
|
||||
# reinit
|
||||
vector_store = Chroma.from_documents(
|
||||
documents=documents,
|
||||
embedding=self.embeddings,
|
||||
persist_directory=persist_dir,
|
||||
)
|
||||
vector_store.persist()
|
||||
return vector_store
|
||||
return vector_store
|
||||
|
||||
def load_knownlege(self):
|
||||
docments = []
|
||||
@@ -61,10 +75,12 @@ class KnownLedge2Vector:
|
||||
for file in files:
|
||||
filename = os.path.join(root, file)
|
||||
docs = self._load_file(filename)
|
||||
# update metadata.
|
||||
new_docs = []
|
||||
# update metadata.
|
||||
new_docs = []
|
||||
for doc in docs:
|
||||
doc.metadata = {"source": doc.metadata["source"].replace(DATASETS_DIR, "")}
|
||||
doc.metadata = {
|
||||
"source": doc.metadata["source"].replace(DATASETS_DIR, "")
|
||||
}
|
||||
print("Documents to vector running, please wait...", doc.metadata)
|
||||
new_docs.append(doc)
|
||||
docments += new_docs
|
||||
@@ -73,7 +89,7 @@ class KnownLedge2Vector:
|
||||
def _load_file(self, filename):
|
||||
# Loader file
|
||||
if filename.lower().endswith(".pdf"):
|
||||
loader = UnstructuredFileLoader(filename)
|
||||
loader = UnstructuredFileLoader(filename)
|
||||
text_splitor = CharacterTextSplitter()
|
||||
docs = loader.load_and_split(text_splitor)
|
||||
else:
|
||||
@@ -86,13 +102,10 @@ class KnownLedge2Vector:
|
||||
"""Load data from url address"""
|
||||
pass
|
||||
|
||||
|
||||
def query(self, q):
|
||||
"""Query similar doc from Vector """
|
||||
"""Query similar doc from Vector"""
|
||||
vector_store = self.init_vector_store()
|
||||
docs = vector_store.similarity_search_with_score(q, k=self.top_k)
|
||||
for doc in docs:
|
||||
dc, s = doc
|
||||
yield s, dc
|
||||
|
||||
|
@@ -1,33 +1,55 @@
|
||||
from typing import Any, Iterable, List, Optional, Tuple
|
||||
|
||||
from pymilvus import DataType, FieldSchema, CollectionSchema, connections, Collection
|
||||
from langchain.docstore.document import Document
|
||||
from pymilvus import Collection, DataType, connections
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.vector_store.vector_store_base import VectorStoreBase
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class MilvusStore(VectorStoreBase):
|
||||
def __init__(self, cfg: {}) -> None:
|
||||
"""Construct a milvus memory storage connection.
|
||||
"""Milvus database"""
|
||||
|
||||
def __init__(self, ctx: {}) -> None:
|
||||
"""init a milvus storage connection.
|
||||
|
||||
Args:
|
||||
cfg (Config): Auto-GPT global config.
|
||||
ctx ({}): MilvusStore global config.
|
||||
"""
|
||||
# self.configure(cfg)
|
||||
|
||||
connect_kwargs = {}
|
||||
self.uri = None
|
||||
self.uri = cfg["url"]
|
||||
self.port = cfg["port"]
|
||||
self.username = cfg.get("username", None)
|
||||
self.password = cfg.get("password", None)
|
||||
self.collection_name = cfg["table_name"]
|
||||
self.password = cfg.get("secure", None)
|
||||
self.uri = CFG.MILVUS_URL
|
||||
self.port = CFG.MILVUS_PORT
|
||||
self.username = CFG.MILVUS_USERNAME
|
||||
self.password = CFG.MILVUS_PASSWORD
|
||||
self.collection_name = ctx.get("vector_store_name", None)
|
||||
self.secure = ctx.get("secure", None)
|
||||
self.embedding = ctx.get("embeddings", None)
|
||||
self.fields = []
|
||||
|
||||
# use HNSW by default.
|
||||
self.index_params = {
|
||||
"metric_type": "IP",
|
||||
"metric_type": "L2",
|
||||
"index_type": "HNSW",
|
||||
"params": {"M": 8, "efConstruction": 64},
|
||||
}
|
||||
# use HNSW by default.
|
||||
self.index_params_map = {
|
||||
"IVF_FLAT": {"params": {"nprobe": 10}},
|
||||
"IVF_SQ8": {"params": {"nprobe": 10}},
|
||||
"IVF_PQ": {"params": {"nprobe": 10}},
|
||||
"HNSW": {"params": {"ef": 10}},
|
||||
"RHNSW_FLAT": {"params": {"ef": 10}},
|
||||
"RHNSW_SQ": {"params": {"ef": 10}},
|
||||
"RHNSW_PQ": {"params": {"ef": 10}},
|
||||
"IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
|
||||
"ANNOY": {"params": {"search_k": 10}},
|
||||
}
|
||||
|
||||
self.text_field = "content"
|
||||
|
||||
if (self.username is None) != (self.password is None):
|
||||
raise ValueError(
|
||||
@@ -38,54 +60,263 @@ class MilvusStore(VectorStoreBase):
|
||||
connect_kwargs["password"] = self.password
|
||||
|
||||
connections.connect(
|
||||
**connect_kwargs,
|
||||
host=self.uri or "127.0.0.1",
|
||||
port=self.port or "19530",
|
||||
alias="default"
|
||||
# secure=self.secure,
|
||||
)
|
||||
|
||||
self.init_schema()
|
||||
|
||||
def init_schema(self) -> None:
|
||||
"""Initialize collection in milvus database."""
|
||||
fields = [
|
||||
FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
|
||||
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=384),
|
||||
FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
|
||||
]
|
||||
|
||||
# create collection if not exist and load it.
|
||||
self.schema = CollectionSchema(fields, "db-gpt memory storage")
|
||||
self.collection = Collection(self.collection_name, self.schema)
|
||||
self.index_params = {
|
||||
"metric_type": "IP",
|
||||
"index_type": "HNSW",
|
||||
"params": {"M": 8, "efConstruction": 64},
|
||||
}
|
||||
# create index if not exist.
|
||||
if not self.collection.has_index():
|
||||
self.collection.release()
|
||||
self.collection.create_index(
|
||||
"vector",
|
||||
self.index_params,
|
||||
index_name="vector",
|
||||
def init_schema_and_load(self, vector_name, documents):
|
||||
"""Create a Milvus collection, indexes it with HNSW, load document.
|
||||
Args:
|
||||
vector_name (Embeddings): your collection name.
|
||||
documents (List[str]): Text to insert.
|
||||
Returns:
|
||||
VectorStore: The MilvusStore vector store.
|
||||
"""
|
||||
try:
|
||||
from pymilvus import (
|
||||
Collection,
|
||||
CollectionSchema,
|
||||
DataType,
|
||||
FieldSchema,
|
||||
connections,
|
||||
)
|
||||
self.collection.load()
|
||||
from pymilvus.orm.types import infer_dtype_bydata
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import pymilvus python package. "
|
||||
"Please install it with `pip install pymilvus`."
|
||||
)
|
||||
if not connections.has_connection("default"):
|
||||
connections.connect(
|
||||
host=self.uri or "127.0.0.1",
|
||||
port=self.port or "19530",
|
||||
alias="default"
|
||||
# secure=self.secure,
|
||||
)
|
||||
texts = [d.page_content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
embeddings = self.embedding.embed_query(texts[0])
|
||||
dim = len(embeddings)
|
||||
# Generate unique names
|
||||
primary_field = "pk_id"
|
||||
vector_field = "vector"
|
||||
text_field = "content"
|
||||
self.text_field = text_field
|
||||
collection_name = vector_name
|
||||
fields = []
|
||||
# Determine metadata schema
|
||||
# if metadatas:
|
||||
# # Check if all metadata keys line up
|
||||
# key = metadatas[0].keys()
|
||||
# for x in metadatas:
|
||||
# if key != x.keys():
|
||||
# raise ValueError(
|
||||
# "Mismatched metadata. "
|
||||
# "Make sure all metadata has the same keys and datatype."
|
||||
# )
|
||||
# # Create FieldSchema for each entry in singular metadata.
|
||||
# for key, value in metadatas[0].items():
|
||||
# # Infer the corresponding datatype of the metadata
|
||||
# dtype = infer_dtype_bydata(value)
|
||||
# if dtype == DataType.UNKNOWN:
|
||||
# raise ValueError(f"Unrecognized datatype for {key}.")
|
||||
# elif dtype == DataType.VARCHAR:
|
||||
# # Find out max length text based metadata
|
||||
# max_length = 0
|
||||
# for subvalues in metadatas:
|
||||
# max_length = max(max_length, len(subvalues[key]))
|
||||
# fields.append(
|
||||
# FieldSchema(key, DataType.VARCHAR, max_length=max_length + 1)
|
||||
# )
|
||||
# else:
|
||||
# fields.append(FieldSchema(key, dtype))
|
||||
|
||||
# def add(self, data) -> str:
|
||||
# Find out max length of texts
|
||||
max_length = 0
|
||||
for y in texts:
|
||||
max_length = max(max_length, len(y))
|
||||
# Create the text field
|
||||
fields.append(
|
||||
FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 1)
|
||||
)
|
||||
# primary key field
|
||||
fields.append(
|
||||
FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
|
||||
)
|
||||
# vector field
|
||||
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
|
||||
# milvus the schema for the collection
|
||||
schema = CollectionSchema(fields)
|
||||
# Create the collection
|
||||
collection = Collection(collection_name, schema)
|
||||
self.col = collection
|
||||
# index parameters for the collection
|
||||
index = self.index_params
|
||||
# milvus index
|
||||
collection.create_index(vector_field, index)
|
||||
schema = collection.schema
|
||||
for x in schema.fields:
|
||||
self.fields.append(x.name)
|
||||
if x.auto_id:
|
||||
self.fields.remove(x.name)
|
||||
if x.is_primary:
|
||||
self.primary_field = x.name
|
||||
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
||||
self.vector_field = x.name
|
||||
self._add_texts(texts, metadatas)
|
||||
|
||||
return self.collection_name
|
||||
|
||||
# def init_schema(self) -> None:
|
||||
# """Initialize collection in milvus database."""
|
||||
# fields = [
|
||||
# FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
|
||||
# FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.model_config["dim"]),
|
||||
# FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
|
||||
# ]
|
||||
#
|
||||
# # create collection if not exist and load it.
|
||||
# self.schema = CollectionSchema(fields, "db-gpt memory storage")
|
||||
# self.collection = Collection(self.collection_name, self.schema)
|
||||
# self.index_params_map = {
|
||||
# "IVF_FLAT": {"params": {"nprobe": 10}},
|
||||
# "IVF_SQ8": {"params": {"nprobe": 10}},
|
||||
# "IVF_PQ": {"params": {"nprobe": 10}},
|
||||
# "HNSW": {"params": {"ef": 10}},
|
||||
# "RHNSW_FLAT": {"params": {"ef": 10}},
|
||||
# "RHNSW_SQ": {"params": {"ef": 10}},
|
||||
# "RHNSW_PQ": {"params": {"ef": 10}},
|
||||
# "IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
|
||||
# "ANNOY": {"params": {"search_k": 10}},
|
||||
# }
|
||||
#
|
||||
# self.index_params = {
|
||||
# "metric_type": "IP",
|
||||
# "index_type": "HNSW",
|
||||
# "params": {"M": 8, "efConstruction": 64},
|
||||
# }
|
||||
# # create index if not exist.
|
||||
# if not self.collection.has_index():
|
||||
# self.collection.release()
|
||||
# self.collection.create_index(
|
||||
# "vector",
|
||||
# self.index_params,
|
||||
# index_name="vector",
|
||||
# )
|
||||
# info = self.collection.describe()
|
||||
# self.collection.load()
|
||||
|
||||
# def insert(self, text, model_config) -> str:
|
||||
# """Add an embedding of data into milvus.
|
||||
#
|
||||
# Args:
|
||||
# data (str): The raw text to construct embedding index.
|
||||
#
|
||||
# text (str): The raw text to construct embedding index.
|
||||
# Returns:
|
||||
# str: log.
|
||||
# """
|
||||
# embedding = get_ada_embedding(data)
|
||||
# result = self.collection.insert([[embedding], [data]])
|
||||
# # embedding = get_ada_embedding(data)
|
||||
# embeddings = HuggingFaceEmbeddings(model_name=self.model_config["model_name"])
|
||||
# result = self.collection.insert([embeddings.embed_documents(text), text])
|
||||
# _text = (
|
||||
# "Inserting data into memory at primary key: "
|
||||
# f"{result.primary_keys[0]}:\n data: {data}"
|
||||
# f"{result.primary_keys[0]}:\n data: {text}"
|
||||
# )
|
||||
# return _text
|
||||
# return _text
|
||||
|
||||
def _add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
partition_name: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
"""add text data into Milvus."""
|
||||
insert_dict: Any = {self.text_field: list(texts)}
|
||||
try:
|
||||
insert_dict[self.vector_field] = self.embedding.embed_documents(list(texts))
|
||||
except NotImplementedError:
|
||||
insert_dict[self.vector_field] = [
|
||||
self.embedding.embed_query(x) for x in texts
|
||||
]
|
||||
# Collect the metadata into the insert dict.
|
||||
if len(self.fields) > 2 and metadatas is not None:
|
||||
for d in metadatas:
|
||||
for key, value in d.items():
|
||||
if key in self.fields:
|
||||
insert_dict.setdefault(key, []).append(value)
|
||||
# Convert dict to list of lists for insertion
|
||||
insert_list = [insert_dict[x] for x in self.fields]
|
||||
# Insert into the collection.
|
||||
res = self.col.insert(
|
||||
insert_list, partition_name=partition_name, timeout=timeout
|
||||
)
|
||||
# make sure data is searchable.
|
||||
self.col.flush()
|
||||
return res.primary_keys
|
||||
|
||||
def load_document(self, documents) -> None:
|
||||
"""load document in vector database."""
|
||||
self.init_schema_and_load(self.collection_name, documents)
|
||||
|
||||
def similar_search(self, text, topk) -> None:
|
||||
"""similar_search in vector database."""
|
||||
self.col = Collection(self.collection_name)
|
||||
schema = self.col.schema
|
||||
for x in schema.fields:
|
||||
self.fields.append(x.name)
|
||||
if x.auto_id:
|
||||
self.fields.remove(x.name)
|
||||
if x.is_primary:
|
||||
self.primary_field = x.name
|
||||
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
||||
self.vector_field = x.name
|
||||
_, docs_and_scores = self._search(text, topk)
|
||||
return [doc for doc, _, _ in docs_and_scores]
|
||||
|
||||
def _search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
partition_names: Optional[List[str]] = None,
|
||||
round_decimal: int = -1,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> Tuple[List[float], List[Tuple[Document, Any, Any]]]:
|
||||
self.col.load()
|
||||
# use default index params.
|
||||
if param is None:
|
||||
index_type = self.col.indexes[0].params["index_type"]
|
||||
param = self.index_params_map[index_type]
|
||||
# query text embedding.
|
||||
data = [self.embedding.embed_query(query)]
|
||||
# Determine result metadata fields.
|
||||
output_fields = self.fields[:]
|
||||
output_fields.remove(self.vector_field)
|
||||
# milvus search.
|
||||
res = self.col.search(
|
||||
data,
|
||||
self.vector_field,
|
||||
param,
|
||||
k,
|
||||
expr=expr,
|
||||
output_fields=output_fields,
|
||||
partition_names=partition_names,
|
||||
round_decimal=round_decimal,
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
ret = []
|
||||
for result in res[0]:
|
||||
meta = {x: result.entity.get(x) for x in output_fields}
|
||||
ret.append(
|
||||
(
|
||||
Document(page_content=meta.pop(self.text_field), metadata=meta),
|
||||
result.distance,
|
||||
result.id,
|
||||
)
|
||||
)
|
||||
|
||||
return data[0], ret
|
||||
|
@@ -2,8 +2,14 @@ from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class VectorStoreBase(ABC):
|
||||
"""base class for vector store database"""
|
||||
|
||||
@abstractmethod
|
||||
def init_schema(self) -> None:
|
||||
def load_document(self, documents) -> None:
|
||||
"""load document in vector database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def similar_search(self, text, topk) -> None:
|
||||
"""Initialize schema in vector database."""
|
||||
pass
|
||||
pass
|
||||
|
@@ -42,6 +42,7 @@ tenacity==8.2.2
|
||||
peft
|
||||
pycocoevalcap
|
||||
sentence-transformers
|
||||
cpm_kernels
|
||||
umap-learn
|
||||
notebook
|
||||
gradio==3.23
|
||||
@@ -60,6 +61,15 @@ gTTS==2.3.1
|
||||
langchain
|
||||
nltk
|
||||
python-dotenv==1.0.0
|
||||
pymilvus==2.2.1
|
||||
vcrpy
|
||||
chromadb
|
||||
markdown2
|
||||
colorama
|
||||
playsound
|
||||
distro
|
||||
pypdf
|
||||
milvus-cli==0.3.2
|
||||
|
||||
# Testing dependencies
|
||||
pytest
|
||||
@@ -69,10 +79,4 @@ pytest-benchmark
|
||||
pytest-cov
|
||||
pytest-integration
|
||||
pytest-mock
|
||||
vcrpy
|
||||
pytest-recording
|
||||
chromadb
|
||||
markdown2
|
||||
colorama
|
||||
playsound
|
||||
distro
|
||||
pytest-recording
|
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.plugins import (
|
||||
@@ -15,10 +15,13 @@ PLUGIN_TEST_ZIP_FILE = "Auto-GPT-Plugin-Test-master.zip"
|
||||
PLUGIN_TEST_INIT_PY = "Auto-GPT-Plugin-Test-master/src/auto_gpt_vicuna/__init__.py"
|
||||
PLUGIN_TEST_OPENAI = "https://weathergpt.vercel.app/"
|
||||
|
||||
|
||||
def test_inspect_zip_for_modules():
|
||||
current_dir = os.getcwd()
|
||||
print(current_dir)
|
||||
result = inspect_zip_for_modules(str(f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/{PLUGIN_TEST_ZIP_FILE}"))
|
||||
result = inspect_zip_for_modules(
|
||||
str(f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/{PLUGIN_TEST_ZIP_FILE}")
|
||||
)
|
||||
assert result == [PLUGIN_TEST_INIT_PY]
|
||||
|
||||
|
||||
@@ -99,6 +102,7 @@ def mock_config_openai_plugin():
|
||||
|
||||
class MockConfig:
|
||||
"""Mock config object for testing the scan_plugins function"""
|
||||
|
||||
current_dir = os.getcwd()
|
||||
plugins_dir = f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/"
|
||||
plugins_openai = [PLUGIN_TEST_OPENAI]
|
||||
|
60
tools/knowlege_init.py
Normal file
60
tools/knowlege_init.py
Normal file
@@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import (
|
||||
DATASETS_DIR,
|
||||
LLM_MODEL_CONFIG,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
)
|
||||
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class LocalKnowledgeInit:
|
||||
embeddings: object = None
|
||||
model_name = LLM_MODEL_CONFIG["text2vec"]
|
||||
top_k: int = VECTOR_SEARCH_TOP_K
|
||||
|
||||
def __init__(self, vector_store_config) -> None:
|
||||
self.vector_store_config = vector_store_config
|
||||
|
||||
def knowledge_persist(self, file_path, append_mode):
|
||||
"""knowledge persist"""
|
||||
kv = KnowledgeEmbedding(
|
||||
file_path=file_path,
|
||||
model_name=LLM_MODEL_CONFIG["text2vec"],
|
||||
vector_store_config=self.vector_store_config,
|
||||
)
|
||||
vector_store = kv.knowledge_persist_initialization(append_mode)
|
||||
return vector_store
|
||||
|
||||
def query(self, q):
|
||||
"""Query similar doc from Vector"""
|
||||
vector_store = self.init_vector_store()
|
||||
docs = vector_store.similarity_search_with_score(q, k=self.top_k)
|
||||
for doc in docs:
|
||||
dc, s = doc
|
||||
yield s, dc
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--vector_name", type=str, default="default")
|
||||
parser.add_argument("--append", type=bool, default=False)
|
||||
parser.add_argument("--store_type", type=str, default="Chroma")
|
||||
args = parser.parse_args()
|
||||
vector_name = args.vector_name
|
||||
append_mode = args.append
|
||||
store_type = CFG.VECTOR_STORE_TYPE
|
||||
vector_store_config = {"vector_store_name": vector_name}
|
||||
print(vector_store_config)
|
||||
kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
|
||||
vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
|
||||
print("your knowledge embedding success...")
|
Reference in New Issue
Block a user