Merge branch 'dev' of https://github.com/csunny/DB-GPT into dev
Conflicts: plugins/DB-GPT-Plugin-ByteBase.zip
@@ -28,8 +28,12 @@ MAX_POSITION_EMBEDDINGS=4096
# FAST_LLM_MODEL=chatglm-6b

### EMBEDDINGS
## EMBEDDING_MODEL - Model to use for creating embeddings
#*******************************************************************#
#** EMBEDDING SETTINGS **#
#*******************************************************************#
EMBEDDING_MODEL=text2vec
KNOWLEDGE_CHUNK_SIZE=500
KNOWLEDGE_SEARCH_TOP_SIZE=5
## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs
## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs
# EMBEDDING_MODEL=all-MiniLM-L6-v2

@@ -41,7 +45,7 @@ MAX_POSITION_EMBEDDINGS=4096
#** DATABASE SETTINGS **#
#*******************************************************************#
LOCAL_DB_USER=root
LOCAL_DB_PASSWORD=aa12345678
LOCAL_DB_PASSWORD=aa123456
LOCAL_DB_HOST=127.0.0.1
LOCAL_DB_PORT=3306

@@ -108,4 +112,4 @@ PROXY_SERVER_URL=http://127.0.0.1:3000/proxy_address
#*******************************************************************#
#** SUMMARY_CONFIG
#*******************************************************************#
SUMMARY_CONFIG=VECTOR
SUMMARY_CONFIG=FAST
README.md
@@ -1,8 +1,16 @@
# DB-GPT
# DB-GPT: A LLM Tool for Multi Databases
<div align="center">
  <p>
    <a href="https://github.com/csunny/DB-GPT">
        <img alt="stars" src="https://img.shields.io/github/stars/csunny/db-gpt?style=social" />
    </a>
    <a href="https://github.com/csunny/DB-GPT">
        <img alt="forks" src="https://img.shields.io/github/forks/csunny/db-gpt?style=social" />
    </a>
  </p>

---

[简体中文](README.zh.md)
[**简体中文**](README.zh.md)|[**Discord**](https://discord.gg/ea6BnZkY)
</div>

[](https://star-history.com/#csunny/DB-GPT)

@@ -12,6 +20,15 @@ As large models are released and iterated upon, they are becoming increasingly intelligent

DB-GPT is an experimental open-source project that uses localized GPT large models to interact with your data and environment. With this solution, you can be assured that there is no risk of data leakage, and your data is 100% private and secure.

## News

- [2023/06/01]🔥 Task-chain calls are implemented through plugins on top of the Vicuna-13B base model, for example creating a database from a single sentence. [demo](./assets/auto_plugin.gif)
- [2023/06/01]🔥 QLoRA Guanaco (7b, 13b, 33b) support.
- [2023/05/28]🔥 Learn from data crawled from the Internet. [demo](./assets/chaturl_en.gif)
- [2023/05/21] Generate SQL and execute it automatically. [demo](./assets/auto_sql_en.gif)
- [2023/05/15] Chat with documents. [demo](./assets/new_knownledge_en.gif)
- [2023/05/06] SQL generation and diagnosis. [demo](./assets/demo_en.gif)

## Features

Currently, we have released multiple key features, which are listed below to demonstrate our current capabilities:

@@ -30,7 +47,7 @@ Currently, we have released multiple key features, which are listed below to demonstrate our current capabilities:
- Support for unstructured data such as PDF, Markdown, CSV, and WebURL

- Multi LLMs Support
  - Supports multiple large language models, currently supporting Vicuna (7b, 13b), ChatGLM-6b (int4, int8)
  - Supports multiple large language models, currently supporting Vicuna (7b, 13b), ChatGLM-6b (int4, int8), Guanaco (7b, 13b, 33b)
  - TODO: codegen2, codet5p

@@ -38,53 +55,6 @@ Currently, we have released multiple key features, which are listed below to demonstrate our current capabilities:

Run on an RTX 4090 GPU. [YouTube](https://www.youtube.com/watch?v=1PWI6F89LPo)

### Run

<p align="center">
  <img src="./assets/demo_en.gif" width="600px" />
</p>

### Run Plugin
<p align="center">
  <img src="./assets/auto_sql_en.gif" width="600px" />
</p>

### SQL Generation

1. Generate Create Table SQL

<p align="center">
  <img src="./assets/SQL_Gen_CreateTable_en.png" width="600px" />
</p>

2. Generate executable SQL: first select the corresponding database, and the model can then generate SQL based on that database's schema information. A successful run looks like this:
<p align="center">
  <img src="./assets/exeable_en.png" width="600px" />
</p>

### Q&A

<p align="center">
  <img src="./assets/DB_QA_en.png" width="600px" />
</p>

1. Question answering based on the default built-in knowledge base.

<p align="center">
  <img src="./assets/Knownledge_based_QA_en.png" width="600px" />
</p>

2. Add your own knowledge base.

<p align="center">
  <img src="./assets/new_knownledge_en.gif" width="600px" />
</p>

3. Learn from data crawled from the Internet

- TODO

## Introduction
DB-GPT builds a large-model operating environment using [FastChat](https://github.com/lm-sys/FastChat) and offers a large language model powered by [Vicuna](https://huggingface.co/Tribbiani/vicuna-7b). In addition, we provide private-domain knowledge base question answering through LangChain. We also support additional plugins, and our design natively supports the Auto-GPT plugin.

@@ -153,7 +123,7 @@ As our project has the ability to achieve ChatGPT performance of over 85%, there
This project relies on a local MySQL database service, which you need to install locally. We recommend using Docker for installation.

```bash
$ docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa12345678 -dit mysql:latest
$ docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa123456 -dit mysql:latest
```
We use [Chroma embedding database](https://github.com/chroma-core/chroma) as the default for our vector database, so there is no need for special installation. If you choose to connect to other databases, you can follow our tutorial for installation and configuration.
For the entire installation process of DB-GPT, we use the miniconda3 virtual environment. Create a virtual environment and install the Python dependencies.
README.zh.md
@@ -1,14 +1,33 @@
# DB-GPT
# DB-GPT: 数据库的 LLM 工具
<div align="center">
  <p>
    <a href="https://github.com/csunny/DB-GPT">
        <img alt="stars" src="https://img.shields.io/github/stars/csunny/db-gpt?style=social" />
    </a>
    <a href="https://github.com/csunny/DB-GPT">
        <img alt="forks" src="https://img.shields.io/github/forks/csunny/db-gpt?style=social" />
    </a>
  </p>

[English](README.zh.md)
[**English**](README.md)|[**Discord**](https://discord.gg/ea6BnZkY)
</div>

[](https://star-history.com/#csunny/DB-GPT)

## DB-GPT 是什么?

随着大模型的发布迭代,大模型变得越来越智能,在使用大模型的过程当中,遇到极大的数据安全与隐私挑战。在利用大模型能力的过程中我们的私密数据跟环境需要掌握自己的手里,完全可控,避免任何的数据隐私泄露以及安全风险。基于此,我们发起了DB-GPT项目,为所有以数据库为基础的场景,构建一套完整的私有大模型解决方案。此方案因为支持本地部署,所以不仅仅可以应用于独立私有环境,而且还可以根据业务模块独立部署隔离,让大模型的能力绝对私有、安全、可控。

DB-GPT 是一个开源的以数据库为基础的GPT实验项目,使用本地化的GPT大模型与您的数据和环境进行交互,无数据泄露风险,100% 私密,100% 安全。

## 最新发布

- [2023/06/01]🔥 在Vicuna-13B基础模型的基础上,通过插件实现任务链调用。例如单句创建数据库的实现. [演示](./assets/dbgpt_bytebase_plugin.gif)
- [2023/06/01]🔥 QLoRA guanaco(原驼)支持, 支持4090运行33B
- [2023/05/28]🔥 根据URL进行对话 [演示](./assets/chat_url_zh.gif)
- [2023/05/21] SQL生成与自动执行. [演示](./assets/auto_sql.gif)
- [2023/05/15] 知识库对话 [演示](./assets/new_knownledge.gif)
- [2023/05/06] SQL生成与诊断 [演示](./assets/演示.gif)

## 特性一览

@@ -33,59 +52,6 @@ DB-GPT 是一个开源的以数据库为基础的GPT实验项目,使用本地
## 效果演示

示例通过 RTX 4090 GPU 演示,[YouTube 地址](https://www.youtube.com/watch?v=1PWI6F89LPo)
### 运行环境演示

<p align="center">
  <img src="./assets/演示.gif" width="600px" />
</p>

### SQL 插件化执行
<p align="center">
  <img src="./assets/auto_sql.gif" width="600px" />
</p>

### SQL 生成

1. 生成建表语句

<p align="center">
  <img src="./assets/SQL_Gen_CreateTable.png" width="600px" />
</p>

2. 生成可运行SQL
首先选择对应的数据库, 然后模型即可根据对应的数据库 Schema 信息生成 SQL, 运行成功的效果如下面的演示:

<p align="center">
  <img src="./assets/exeable.png" width="600px" />
</p>

3. 自动分析执行SQL输出运行结果

<p align="center">
  <img src="./assets/Auto-DB-GPT.png" width="600px" />
</p>

### 数据库问答

<p align="center">
  <img src="./assets/DB_QA.png" width="600px" />
</p>

1. 基于默认内置知识库问答

<p align="center">
  <img src="./assets/VectorDBQA.png" width="600px" />
</p>

2. 自己新增知识库

<p align="center">
  <img src="./assets/new_knownledge.gif" width="600px" />
</p>

3. 从网络自己爬取数据学习
- TODO

## 架构方案
DB-GPT基于 [FastChat](https://github.com/lm-sys/FastChat) 构建大模型运行环境,并提供 vicuna 作为基础的大语言模型。此外,我们通过LangChain提供私域知识库问答能力。同时我们支持插件模式, 在设计上原生支持Auto-GPT插件。

@@ -147,7 +113,7 @@ TODO: 在终端展示上,我们将提供多端产品界面。包括PC、手机

本项目依赖一个本地的 MySQL 数据库服务,你需要本地安装,推荐直接使用 Docker 安装。
```
docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa12345678 -dit mysql:latest
docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa123456 -dit mysql:latest
```
向量数据库我们默认使用的是Chroma内存数据库,所以无需特殊安装,如果有需要连接其他的同学,可以按照我们的教程进行安装配置。整个DB-GPT的安装过程,我们使用的是miniconda3的虚拟环境。创建虚拟环境,并安装python依赖包
Binary asset changes: assets/auto_plugin.gif, assets/chat_url_zh.gif, assets/chaturl_en.gif and assets/dbgpt_bytebase_plugin.gif are new files; assets/DB_QA.png, assets/pilot.png and a number of other demo images and GIFs under assets/ were updated, resized, or removed.
@@ -1 +1,43 @@
# Knowledge-based QA

Chatting with your own knowledge base is a very interesting capability. In this chapter we introduce how to build your own knowledge base through the knowledge base API. A knowledge store can currently be initialized by executing `python tool/knowledge_init.py`, which loads your own knowledge base content as introduced in the previous knowledge base module. You can also call the provided knowledge embedding API to store knowledge.

We currently support four document formats: txt, pdf, url, and md.
```
# name: the name of your knowledge store
vector_store_config = {
    "vector_store_name": name
}

file_path = "your file path"

knowledge_embedding_client = KnowledgeEmbedding(
    file_path=file_path,
    model_name=LLM_MODEL_CONFIG["text2vec"],
    local_persist=False,
    vector_store_config=vector_store_config,
)

knowledge_embedding_client.knowledge_embedding()
```
Two vector databases are currently supported: Chroma (the default) and Milvus. You can switch between them by modifying the "VECTOR_STORE_TYPE" field in the .env file.
```
#*******************************************************************#
#** VECTOR STORE SETTINGS **#
#*******************************************************************#
VECTOR_STORE_TYPE=Chroma
#MILVUS_URL=127.0.0.1
#MILVUS_PORT=19530
```
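Internally, KnowledgeEmbedding builds a VectorStoreConnector from CFG.VECTOR_STORE_TYPE and delegates storage and retrieval to it, so the same calling code works for both Chroma and Milvus. Below is a minimal sketch of that lower-level path; the store name "my_kb" and the query string are illustrative placeholders.

```python
from langchain.embeddings import HuggingFaceEmbeddings

from pilot.configs.config import Config
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
from pilot.vector_store.connector import VectorStoreConnector

CFG = Config()

# vector_store_config mirrors what KnowledgeEmbedding builds internally;
# "my_kb" is a placeholder knowledge store name.
vector_store_config = {
    "vector_store_name": "my_kb",
    "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
    "embeddings": HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["text2vec"]),
}

# VECTOR_STORE_TYPE ("Chroma" or "Milvus") selects the backend the connector creates.
vector_client = VectorStoreConnector(CFG.VECTOR_STORE_TYPE, vector_store_config)

# Return the chunks most similar to the query from the selected store.
docs = vector_client.similar_search("your query", 5)
```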
Below is an example of using the knowledge base API to query knowledge:

```
vector_store_config = {
    "vector_store_name": name
}

query = "your query"

knowledge_embedding_client = KnowledgeEmbedding(
    file_path="",
    model_name=LLM_MODEL_CONFIG["text2vec"],
    local_persist=False,
    vector_store_config=vector_store_config,
)

knowledge_embedding_client.similar_search(query, 10)
```
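similar_search returns the matching document chunks as LangChain Document objects whose text is in page_content, so a caller typically joins or truncates them before prompting the model. A minimal usage sketch follows; the import path for KnowledgeEmbedding, the store name and the example query are assumptions based on the repository layout.

```python
# NOTE: the import path for KnowledgeEmbedding is assumed from the repository
# layout (pilot/source_embedding/); adjust it if the module lives elsewhere.
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding

vector_store_config = {"vector_store_name": "my_kb"}  # placeholder store name

client = KnowledgeEmbedding(
    file_path="",
    model_name=LLM_MODEL_CONFIG["text2vec"],
    local_persist=False,
    vector_store_config=vector_store_config,
)

# Each returned item is a LangChain Document; its text lives in page_content.
docs = client.similar_search("How do I create a table?", 5)
for doc in docs:
    print(doc.page_content[:200])
```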
@@ -148,6 +148,8 @@ class Config(metaclass=Singleton):

        ### EMBEDDING Configuration
        self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
        self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 500))
        self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 10))
        ### SUMMARY_CONFIG Configuration
        self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "VECTOR")
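These values are read from the environment with os.getenv when the singleton Config is first constructed, so any module can access them through a shared instance. A minimal sketch of reading them:

```python
from pilot.configs.config import Config

# Config is a Singleton, so every caller sees the same values that were read
# from the environment when the first instance was created.
CFG = Config()

print(CFG.EMBEDDING_MODEL)            # "text2vec" unless EMBEDDING_MODEL is set
print(CFG.KNOWLEDGE_CHUNK_SIZE)       # 500 by default
print(CFG.KNOWLEDGE_SEARCH_TOP_SIZE)  # 10 by default
print(CFG.SUMMARY_CONFIG)             # "VECTOR" by default; the .env template above sets FAST
```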
@@ -0,0 +1 @@
LlamaIndex是一个数据框架,旨在帮助您构建LLM应用程序。它包括一个向量存储索引和一个简单的目录阅读器,可以帮助您处理和操作数据。此外,LlamaIndex还提供了一个GPT Index,可以用于数据增强和生成更好的LM模型。
@ -97,6 +97,20 @@ class GuanacoAdapter(BaseLLMAdaper):
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
class GuanacoAdapter(BaseLLMAdaper):
|
||||
"""TODO Support guanaco"""
|
||||
|
||||
def match(self, model_path: str):
|
||||
return "guanaco" in model_path
|
||||
|
||||
def loader(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
tokenizer = LlamaTokenizer.from_pretrained(model_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
|
||||
)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
class CodeGenAdapter(BaseLLMAdaper):
|
||||
pass
|
||||
|
||||
|
pilot/model/guanaco_stream_llm.py (new file)
@ -0,0 +1,54 @@
|
||||
import torch
|
||||
from threading import Thread
|
||||
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
|
||||
|
||||
|
||||
def guanaco_stream_generate_output(model, tokenizer, params, device, context_len=2048):
|
||||
"""Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
|
||||
tokenizer.bos_token_id = 1
|
||||
print(params)
|
||||
stop = params.get("stop", "###")
|
||||
prompt = params["prompt"]
|
||||
query = prompt
|
||||
print("Query Message: ", query)
|
||||
|
||||
input_ids = tokenizer(query, return_tensors="pt").input_ids
|
||||
input_ids = input_ids.to(model.device)
|
||||
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
|
||||
)
|
||||
|
||||
tokenizer.bos_token_id = 1
|
||||
stop_token_ids = [0]
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
) -> bool:
|
||||
for stop_id in stop_token_ids:
|
||||
if input_ids[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
stop = StopOnTokens()
|
||||
|
||||
generate_kwargs = dict(
|
||||
input_ids=input_ids,
|
||||
max_new_tokens=512,
|
||||
temperature=1.0,
|
||||
do_sample=True,
|
||||
top_k=1,
|
||||
streamer=streamer,
|
||||
repetition_penalty=1.7,
|
||||
stopping_criteria=StoppingCriteriaList([stop]),
|
||||
)
|
||||
|
||||
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
||||
t.start()
|
||||
|
||||
out = ""
|
||||
for new_text in streamer:
|
||||
out += new_text
|
||||
yield new_text
|
||||
return out
|
@ -1,9 +1,9 @@
|
||||
import torch
|
||||
import copy
|
||||
from threading import Thread
|
||||
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
|
||||
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
|
||||
|
||||
|
||||
def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
|
||||
"""Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
|
||||
|
||||
@ -16,15 +16,20 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
|
||||
input_ids = tokenizer(query, return_tensors="pt").input_ids
|
||||
input_ids = input_ids.to(model.device)
|
||||
|
||||
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
|
||||
)
|
||||
stop_token_ids = [0]
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
) -> bool:
|
||||
for stop_id in stop_token_ids:
|
||||
if input_ids[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
stop = StopOnTokens()
|
||||
|
||||
generate_kwargs = dict(
|
||||
@ -32,17 +37,16 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
|
||||
max_new_tokens=512,
|
||||
temperature=1.0,
|
||||
do_sample=True,
|
||||
top_k=1,
|
||||
top_k=1,
|
||||
streamer=streamer,
|
||||
repetition_penalty=1.7,
|
||||
stopping_criteria=StoppingCriteriaList([stop])
|
||||
stopping_criteria=StoppingCriteriaList([stop]),
|
||||
)
|
||||
|
||||
|
||||
t1 = Thread(target=model.generate, kwargs=generate_kwargs)
|
||||
t1.start()
|
||||
|
||||
generator = model.generate(**generate_kwargs)
|
||||
generator = model.generate(**generate_kwargs)
|
||||
for output in generator:
|
||||
# new_tokens = len(output) - len(input_ids[0])
|
||||
decoded_output = tokenizer.decode(output)
|
||||
@ -53,3 +57,54 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
|
||||
|
||||
yield out
|
||||
|
||||
|
||||
def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
|
||||
"""Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
|
||||
tokenizer.bos_token_id = 1
|
||||
print(params)
|
||||
stop = params.get("stop", "###")
|
||||
prompt = params["prompt"]
|
||||
max_new_tokens = params.get("max_new_tokens", 512)
|
||||
temperature = params.get("temperature", 1.0)
|
||||
|
||||
query = prompt
|
||||
print("Query Message: ", query)
|
||||
|
||||
input_ids = tokenizer(query, return_tensors="pt").input_ids
|
||||
input_ids = input_ids.to(model.device)
|
||||
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
|
||||
)
|
||||
|
||||
tokenizer.bos_token_id = 1
|
||||
stop_token_ids = [0]
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
) -> bool:
|
||||
for stop_id in stop_token_ids:
|
||||
if input_ids[-1][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
stop = StopOnTokens()
|
||||
|
||||
generate_kwargs = dict(
|
||||
input_ids=input_ids,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
do_sample=True,
|
||||
top_k=1,
|
||||
streamer=streamer,
|
||||
repetition_penalty=1.7,
|
||||
stopping_criteria=StoppingCriteriaList([stop]),
|
||||
)
|
||||
|
||||
model.generate(**generate_kwargs)
|
||||
|
||||
out = ""
|
||||
for new_text in streamer:
|
||||
out += new_text
|
||||
yield out
|
||||
|
@ -68,15 +68,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
|
||||
"max_tokens": params.get("max_new_tokens"),
|
||||
}
|
||||
|
||||
print(payloads)
|
||||
print(headers)
|
||||
res = requests.post(
|
||||
CFG.proxy_server_url, headers=headers, json=payloads, stream=True
|
||||
)
|
||||
|
||||
text = ""
|
||||
print("====================================res================")
|
||||
print(res)
|
||||
for line in res.iter_lines():
|
||||
if line:
|
||||
decoded_line = line.decode("utf-8")
|
||||
|
@ -118,6 +118,8 @@ class ModelLoader(metaclass=Singleton):
|
||||
model.to(self.device)
|
||||
except ValueError:
|
||||
pass
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if debug:
|
||||
print(model)
|
||||
|
@ -53,9 +53,14 @@ class BaseOutputParser(ABC):
|
||||
"""
|
||||
if data["error_code"] == 0:
|
||||
if "vicuna" in CFG.LLM_MODEL:
|
||||
output = data["text"][skip_echo_len + 11:].strip()
|
||||
# output = data["text"][skip_echo_len + 11:].strip()
|
||||
output = data["text"][skip_echo_len:].strip()
|
||||
elif "guanaco" in CFG.LLM_MODEL:
|
||||
output = data["text"][skip_echo_len + 14:].replace("<s>", "").strip()
|
||||
# NO stream output
|
||||
# output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
|
||||
|
||||
# stream out output
|
||||
output = data["text"][11:].replace("<s>", "").strip()
|
||||
else:
|
||||
output = data["text"].strip()
|
||||
|
||||
@ -66,8 +71,7 @@ class BaseOutputParser(ABC):
|
||||
return output
|
||||
|
||||
# TODO: bind this to the specific model later
|
||||
def parse_model_stream_resp(self, response, skip_echo_len):
|
||||
|
||||
def parse_model_stream_resp(self, response, skip_echo_len):
|
||||
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||
if chunk:
|
||||
data = json.loads(chunk.decode())
|
||||
@ -75,7 +79,7 @@ class BaseOutputParser(ABC):
|
||||
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
|
||||
"""
|
||||
if data["error_code"] == 0:
|
||||
if "vicuna" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL:
|
||||
if "vicuna" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL:
|
||||
output = data["text"][skip_echo_len:].strip()
|
||||
else:
|
||||
output = data["text"].strip()
|
||||
@ -113,7 +117,6 @@ class BaseOutputParser(ABC):
|
||||
else:
|
||||
raise ValueError("Model server error!code=" + respObj_ex["error_code"])
|
||||
|
||||
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
"""
|
||||
parse model out text to prompt define response
|
||||
@ -129,9 +132,9 @@ class BaseOutputParser(ABC):
|
||||
# if "```" in cleaned_output:
|
||||
# cleaned_output, _ = cleaned_output.split("```")
|
||||
if cleaned_output.startswith("```json"):
|
||||
cleaned_output = cleaned_output[len("```json"):]
|
||||
cleaned_output = cleaned_output[len("```json") :]
|
||||
if cleaned_output.startswith("```"):
|
||||
cleaned_output = cleaned_output[len("```"):]
|
||||
cleaned_output = cleaned_output[len("```") :]
|
||||
if cleaned_output.endswith("```"):
|
||||
cleaned_output = cleaned_output[: -len("```")]
|
||||
cleaned_output = cleaned_output.strip()
|
||||
@ -143,7 +146,13 @@ class BaseOutputParser(ABC):
|
||||
cleaned_output = m.group(0)
|
||||
else:
|
||||
raise ValueError("model server out not fllow the prompt!")
|
||||
cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '').replace('\\', '')
|
||||
cleaned_output = (
|
||||
cleaned_output.strip()
|
||||
.replace("\n", "")
|
||||
.replace("\\n", "")
|
||||
.replace("\\", "")
|
||||
.replace("\\", "")
|
||||
)
|
||||
return cleaned_output
|
||||
|
||||
def parse_view_response(self, ai_text, data) -> str:
|
||||
|
@ -57,7 +57,14 @@ class BaseChat(ABC):
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self,temperature, max_new_tokens, chat_mode, chat_session_id, current_user_input):
|
||||
def __init__(
|
||||
self,
|
||||
temperature,
|
||||
max_new_tokens,
|
||||
chat_mode,
|
||||
chat_session_id,
|
||||
current_user_input,
|
||||
):
|
||||
self.chat_session_id = chat_session_id
|
||||
self.chat_mode = chat_mode
|
||||
self.current_user_input: str = current_user_input
|
||||
@ -68,7 +75,9 @@ class BaseChat(ABC):
|
||||
## TEST
|
||||
self.memory = FileHistoryMemory(chat_session_id)
|
||||
### load prompt template
|
||||
self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value]
|
||||
self.prompt_template: PromptTemplate = CFG.prompt_templates[
|
||||
self.chat_mode.value
|
||||
]
|
||||
self.history_message: List[OnceConversation] = []
|
||||
self.current_message: OnceConversation = OnceConversation()
|
||||
self.current_tokens_used: int = 0
|
||||
@ -129,7 +138,7 @@ class BaseChat(ABC):
|
||||
def stream_call(self):
|
||||
payload = self.__call_base()
|
||||
|
||||
self.skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 11
|
||||
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
|
||||
logger.info(f"Requert: \n{payload}")
|
||||
ai_response_text = ""
|
||||
try:
|
||||
@ -175,29 +184,37 @@ class BaseChat(ABC):
|
||||
|
||||
### output parse
|
||||
ai_response_text = (
|
||||
self.prompt_template.output_parser.parse_model_nostream_resp(response, self.prompt_template.sep)
|
||||
self.prompt_template.output_parser.parse_model_nostream_resp(
|
||||
response, self.prompt_template.sep
|
||||
)
|
||||
)
|
||||
self.current_message.add_ai_message(ai_response_text)
|
||||
prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
|
||||
prompt_define_response = (
|
||||
self.prompt_template.output_parser.parse_prompt_response(
|
||||
ai_response_text
|
||||
)
|
||||
)
|
||||
|
||||
result = self.do_with_prompt_response(prompt_define_response)
|
||||
|
||||
if hasattr(prompt_define_response, "thoughts"):
|
||||
if isinstance(prompt_define_response.thoughts, dict):
|
||||
if isinstance(prompt_define_response.thoughts, dict):
|
||||
if "speak" in prompt_define_response.thoughts:
|
||||
speak_to_user = prompt_define_response.thoughts.get("speak")
|
||||
else:
|
||||
speak_to_user = str(prompt_define_response.thoughts)
|
||||
else:
|
||||
if hasattr(prompt_define_response.thoughts, "speak"):
|
||||
if hasattr(prompt_define_response.thoughts, "speak"):
|
||||
speak_to_user = prompt_define_response.thoughts.get("speak")
|
||||
elif hasattr(prompt_define_response.thoughts, "reasoning"):
|
||||
elif hasattr(prompt_define_response.thoughts, "reasoning"):
|
||||
speak_to_user = prompt_define_response.thoughts.get("reasoning")
|
||||
else:
|
||||
speak_to_user = prompt_define_response.thoughts
|
||||
else:
|
||||
speak_to_user = prompt_define_response
|
||||
view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result)
|
||||
view_message = self.prompt_template.output_parser.parse_view_response(
|
||||
speak_to_user, result
|
||||
)
|
||||
self.current_message.add_view_message(view_message)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
@ -226,20 +243,20 @@ class BaseChat(ABC):
|
||||
for first_message in self.history_message[0].messages:
|
||||
if not isinstance(first_message, ViewMessage):
|
||||
text += (
|
||||
first_message.type
|
||||
+ ":"
|
||||
+ first_message.content
|
||||
+ self.prompt_template.sep
|
||||
first_message.type
|
||||
+ ":"
|
||||
+ first_message.content
|
||||
+ self.prompt_template.sep
|
||||
)
|
||||
|
||||
index = self.chat_retention_rounds - 1
|
||||
for last_message in self.history_message[-index:].messages:
|
||||
if not isinstance(last_message, ViewMessage):
|
||||
text += (
|
||||
last_message.type
|
||||
+ ":"
|
||||
+ last_message.content
|
||||
+ self.prompt_template.sep
|
||||
last_message.type
|
||||
+ ":"
|
||||
+ last_message.content
|
||||
+ self.prompt_template.sep
|
||||
)
|
||||
|
||||
else:
|
||||
@ -248,16 +265,16 @@ class BaseChat(ABC):
|
||||
for message in conversation.messages:
|
||||
if not isinstance(message, ViewMessage):
|
||||
text += (
|
||||
message.type
|
||||
+ ":"
|
||||
+ message.content
|
||||
+ self.prompt_template.sep
|
||||
message.type
|
||||
+ ":"
|
||||
+ message.content
|
||||
+ self.prompt_template.sep
|
||||
)
|
||||
### current conversation
|
||||
|
||||
for now_message in self.current_message.messages:
|
||||
text += (
|
||||
now_message.type + ":" + now_message.content + self.prompt_template.sep
|
||||
now_message.type + ":" + now_message.content + self.prompt_template.sep
|
||||
)
|
||||
|
||||
return text
|
||||
@ -288,4 +305,3 @@ class BaseChat(ABC):
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@ -47,12 +47,13 @@ class ChatWithDbAutoExecute(BaseChat):
|
||||
from pilot.summary.db_summary_client import DBSummaryClient
|
||||
except ImportError:
|
||||
raise ValueError("Could not import DBSummaryClient. ")
|
||||
client = DBSummaryClient()
|
||||
input_values = {
|
||||
"input": self.current_user_input,
|
||||
"top_k": str(self.top_k),
|
||||
"dialect": self.database.dialect,
|
||||
"table_info": self.database.table_simple_info(self.db_connect)
|
||||
# "table_info": DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
|
||||
# "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
|
||||
}
|
||||
return input_values
|
||||
|
||||
|
@ -35,7 +35,13 @@ class ChatWithDbQA(BaseChat):
|
||||
self.database = CFG.local_db
|
||||
# Prepare DB info (get a connection to the specified database)
|
||||
self.db_connect = self.database.get_session(self.db_name)
|
||||
self.top_k: int = 5
|
||||
self.tables = self.database.get_table_names()
|
||||
|
||||
self.top_k = (
|
||||
CFG.KNOWLEDGE_SEARCH_TOP_SIZE
|
||||
if len(self.tables) > CFG.KNOWLEDGE_SEARCH_TOP_SIZE
|
||||
else len(self.tables)
|
||||
)
|
||||
|
||||
def generate_input_values(self):
|
||||
table_info = ""
|
||||
@ -45,7 +51,8 @@ class ChatWithDbQA(BaseChat):
|
||||
except ImportError:
|
||||
raise ValueError("Could not import DBSummaryClient. ")
|
||||
if self.db_name:
|
||||
table_info = DBSummaryClient.get_similar_tables(
|
||||
client = DBSummaryClient()
|
||||
table_info = client.get_similar_tables(
|
||||
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
|
||||
)
|
||||
# table_info = self.database.table_simple_info(self.db_connect)
|
||||
|
@ -14,14 +14,19 @@ logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
class PluginAction(NamedTuple):
|
||||
command: Dict
|
||||
speak: str
|
||||
reasoning:str
|
||||
reasoning: str
|
||||
thoughts: str
|
||||
|
||||
|
||||
class PluginChatOutputParser(BaseOutputParser):
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
response = json.loads(super().parse_prompt_response(model_out_text))
|
||||
command, thoughts, speak, reasoning = response["command"], response["thoughts"], response["speak"], response["reasoning"]
|
||||
command, thoughts, speak, reasoning = (
|
||||
response["command"],
|
||||
response["thoughts"],
|
||||
response["speak"],
|
||||
response["reasoning"],
|
||||
)
|
||||
return PluginAction(command, speak, reasoning, thoughts)
|
||||
|
||||
def parse_view_response(self, speak, data) -> str:
|
||||
|
@ -14,7 +14,6 @@ from pilot.configs.model_config import (
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
LLM_MODEL_CONFIG,
|
||||
LOGDIR,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
)
|
||||
|
||||
from pilot.scene.chat_knowledge.custom.prompt import prompt
|
||||
@ -46,15 +45,13 @@ class ChatNewKnowledge(BaseChat):
|
||||
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
}
|
||||
self.knowledge_embedding_client = KnowledgeEmbedding(
|
||||
file_path="",
|
||||
model_name=LLM_MODEL_CONFIG["text2vec"],
|
||||
local_persist=False,
|
||||
vector_store_config=vector_store_config,
|
||||
)
|
||||
|
||||
def generate_input_values(self):
|
||||
docs = self.knowledge_embedding_client.similar_search(
|
||||
self.current_user_input, VECTOR_SEARCH_TOP_K
|
||||
self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
|
||||
)
|
||||
context = [d.page_content for d in docs]
|
||||
context = context[:2000]
|
||||
|
@ -14,13 +14,23 @@ CFG = Config()
|
||||
PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""
|
||||
|
||||
|
||||
_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
|
||||
_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
|
||||
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
|
||||
已知内容:
|
||||
{context}
|
||||
问题:
|
||||
{question}
|
||||
"""
|
||||
_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
|
||||
known information:
|
||||
{context}
|
||||
question:
|
||||
{question}
|
||||
"""
|
||||
|
||||
_DEFAULT_TEMPLATE = (
|
||||
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||
)
|
||||
|
||||
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
@ -31,7 +41,7 @@ prompt = PromptTemplate(
|
||||
template_scene=ChatScene.ChatNewKnowledge.value,
|
||||
input_variables=["context", "question"],
|
||||
response_format=None,
|
||||
template_define=None,
|
||||
template_define=PROMPT_SCENE_DEFINE,
|
||||
template=_DEFAULT_TEMPLATE,
|
||||
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
|
||||
output_parser=NormalChatOutputParser(
|
||||
|
@ -14,7 +14,6 @@ from pilot.configs.model_config import (
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
LLM_MODEL_CONFIG,
|
||||
LOGDIR,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
)
|
||||
|
||||
from pilot.scene.chat_knowledge.default.prompt import prompt
|
||||
@ -42,15 +41,13 @@ class ChatDefaultKnowledge(BaseChat):
|
||||
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
}
|
||||
self.knowledge_embedding_client = KnowledgeEmbedding(
|
||||
file_path="",
|
||||
model_name=LLM_MODEL_CONFIG["text2vec"],
|
||||
local_persist=False,
|
||||
vector_store_config=vector_store_config,
|
||||
)
|
||||
|
||||
def generate_input_values(self):
|
||||
docs = self.knowledge_embedding_client.similar_search(
|
||||
self.current_user_input, VECTOR_SEARCH_TOP_K
|
||||
self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
|
||||
)
|
||||
context = [d.page_content for d in docs]
|
||||
context = context[:2000]
|
||||
|
@ -11,13 +11,27 @@ from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
|
||||
|
||||
CFG = Config()
|
||||
|
||||
_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
|
||||
PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge.
|
||||
The assistant gives helpful, detailed, professional and polite answers to the user's questions. """
|
||||
|
||||
|
||||
_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
|
||||
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
|
||||
已知内容:
|
||||
{context}
|
||||
问题:
|
||||
{question}
|
||||
"""
|
||||
_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
|
||||
known information:
|
||||
{context}
|
||||
question:
|
||||
{question}
|
||||
"""
|
||||
|
||||
_DEFAULT_TEMPLATE = (
|
||||
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||
)
|
||||
|
||||
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
@ -28,7 +42,7 @@ prompt = PromptTemplate(
|
||||
template_scene=ChatScene.ChatKnowledge.value,
|
||||
input_variables=["context", "question"],
|
||||
response_format=None,
|
||||
template_define=None,
|
||||
template_define=PROMPT_SCENE_DEFINE,
|
||||
template=_DEFAULT_TEMPLATE,
|
||||
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
|
||||
output_parser=NormalChatOutputParser(
|
||||
|
@ -14,7 +14,6 @@ from pilot.configs.model_config import (
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
LLM_MODEL_CONFIG,
|
||||
LOGDIR,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
)
|
||||
|
||||
from pilot.scene.chat_knowledge.url.prompt import prompt
|
||||
@ -40,15 +39,13 @@ class ChatUrlKnowledge(BaseChat):
|
||||
self.url = url
|
||||
vector_store_config = {
|
||||
"vector_store_name": url,
|
||||
"text_field": "content",
|
||||
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
}
|
||||
self.knowledge_embedding_client = KnowledgeEmbedding(
|
||||
file_path=url,
|
||||
file_type="url",
|
||||
model_name=LLM_MODEL_CONFIG["text2vec"],
|
||||
local_persist=False,
|
||||
vector_store_config=vector_store_config,
|
||||
file_type="url",
|
||||
file_path=url,
|
||||
)
|
||||
|
||||
# url source in vector
|
||||
@ -58,7 +55,7 @@ class ChatUrlKnowledge(BaseChat):
|
||||
|
||||
def generate_input_values(self):
|
||||
docs = self.knowledge_embedding_client.similar_search(
|
||||
self.current_user_input, VECTOR_SEARCH_TOP_K
|
||||
self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
|
||||
)
|
||||
context = [d.page_content for d in docs]
|
||||
context = context[:2000]
|
||||
|
@ -11,13 +11,27 @@ from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
|
||||
|
||||
CFG = Config()
|
||||
|
||||
_DEFAULT_TEMPLATE = """ Based on the known information, provide professional and concise answers to the user's questions. If the answer cannot be obtained from the provided content, please say: 'The information provided in the knowledge base is not sufficient to answer this question.' Fabrication is prohibited.。
|
||||
PROMPT_SCENE_DEFINE = """A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge.
|
||||
The assistant gives helpful, detailed, professional and polite answers to the user's questions. """
|
||||
|
||||
_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
|
||||
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
|
||||
已知内容:
|
||||
{context}
|
||||
问题:
|
||||
{question}
|
||||
"""
|
||||
_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
|
||||
known information:
|
||||
{context}
|
||||
question:
|
||||
{question}
|
||||
"""
|
||||
|
||||
_DEFAULT_TEMPLATE = (
|
||||
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||
)
|
||||
|
||||
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
|
||||
@ -27,7 +41,7 @@ prompt = PromptTemplate(
|
||||
template_scene=ChatScene.ChatUrlKnowledge.value,
|
||||
input_variables=["context", "question"],
|
||||
response_format=None,
|
||||
template_define=None,
|
||||
template_define=PROMPT_SCENE_DEFINE,
|
||||
template=_DEFAULT_TEMPLATE,
|
||||
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
|
||||
output_parser=NormalChatOutputParser(
|
||||
|
@ -60,6 +60,20 @@ class ChatGLMChatAdapter(BaseChatAdpter):
|
||||
return chatglm_generate_stream
|
||||
|
||||
|
||||
class GuanacoChatAdapter(BaseChatAdpter):
|
||||
"""Model chat adapter for Guanaco"""
|
||||
|
||||
def match(self, model_path: str):
|
||||
return "guanaco" in model_path
|
||||
|
||||
def get_generate_stream_func(self):
|
||||
from pilot.model.llm_out.guanaco_stream_llm import (
|
||||
guanaco_stream_generate_output,
|
||||
)
|
||||
|
||||
return guanaco_generate_output
|
||||
|
||||
|
||||
class CodeT5ChatAdapter(BaseChatAdpter):
|
||||
|
||||
"""Model chat adapter for CodeT5"""
|
||||
@ -91,9 +105,9 @@ class GuanacoChatAdapter(BaseChatAdpter):
|
||||
return "guanaco" in model_path
|
||||
|
||||
def get_generate_stream_func(self):
|
||||
from pilot.model.llm_out.guanaco_llm import guanaco_generate_output
|
||||
from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream
|
||||
|
||||
return guanaco_generate_output
|
||||
return guanaco_generate_stream
|
||||
|
||||
|
||||
class ProxyllmChatAdapter(BaseChatAdpter):
|
||||
@ -110,6 +124,7 @@ register_llm_model_chat_adapter(VicunaChatAdapter)
|
||||
register_llm_model_chat_adapter(ChatGLMChatAdapter)
|
||||
register_llm_model_chat_adapter(GuanacoChatAdapter)
|
||||
|
||||
|
||||
# Proxy model for test and develop, it's cheap for us now.
|
||||
register_llm_model_chat_adapter(ProxyllmChatAdapter)
|
||||
|
||||
|
@ -3,12 +3,14 @@
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
|
||||
from pilot.configs.config import Config
|
||||
from pilot.conversation import conv_qa_prompt_template, conv_db_summary_templates
|
||||
from pilot.logs import logger
|
||||
from pilot.model.llm_out.vicuna_llm import VicunaLLM
|
||||
from pilot.vector_store.file_loader import KnownLedge2Vector
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class KnownLedgeBaseQA:
|
||||
def __init__(self) -> None:
|
||||
@ -22,7 +24,7 @@ class KnownLedgeBaseQA:
|
||||
)
|
||||
|
||||
retriever = self.vector_store.as_retriever(
|
||||
search_kwargs={"k": VECTOR_SEARCH_TOP_K}
|
||||
search_kwargs={"k": CFG.KNOWLEDGE_SEARCH_TOP_SIZE}
|
||||
)
|
||||
docs = retriever.get_relevant_documents(query=query)
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import threading
|
||||
import traceback
|
||||
import argparse
|
||||
import datetime
|
||||
@ -414,7 +415,7 @@ def build_single_model_ui():
|
||||
show_label=True,
|
||||
).style(container=False)
|
||||
|
||||
db_selector.change(fn=db_selector_changed, inputs=db_selector)
|
||||
# db_selector.change(fn=db_selector_changed, inputs=db_selector)
|
||||
|
||||
sql_mode = gr.Radio(
|
||||
[
|
||||
@ -618,10 +619,6 @@ def save_vs_name(vs_name):
|
||||
return vs_name
|
||||
|
||||
|
||||
def db_selector_changed(dbname):
|
||||
DBSummaryClient.db_summary_embedding(dbname)
|
||||
|
||||
|
||||
def knowledge_embedding_store(vs_id, files):
|
||||
# vs_path = os.path.join(VS_ROOT_PATH, vs_id)
|
||||
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)):
|
||||
@ -634,7 +631,6 @@ def knowledge_embedding_store(vs_id, files):
|
||||
knowledge_embedding_client = KnowledgeEmbedding(
|
||||
file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
|
||||
model_name=LLM_MODEL_CONFIG["text2vec"],
|
||||
local_persist=False,
|
||||
vector_store_config={
|
||||
"vector_store_name": vector_store_name["vs_name"],
|
||||
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
@ -646,6 +642,12 @@ def knowledge_embedding_store(vs_id, files):
|
||||
return vs_id
|
||||
|
||||
|
||||
def async_db_summery():
|
||||
client = DBSummaryClient()
|
||||
thread = threading.Thread(target=client.init_db_summary)
|
||||
thread.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||
@ -662,7 +664,7 @@ if __name__ == "__main__":
|
||||
cfg = Config()
|
||||
|
||||
dbs = cfg.local_db.get_database_list()
|
||||
|
||||
async_db_summery()
|
||||
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
|
||||
|
||||
# Load the executable commands provided by plugins
|
||||
|
pilot/source_embedding/EncodeTextLoader.py (new file)
@ -0,0 +1,26 @@
|
||||
from typing import List, Optional
|
||||
import chardet
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
|
||||
|
||||
class EncodeTextLoader(BaseLoader):
|
||||
"""Load text files."""
|
||||
|
||||
def __init__(self, file_path: str, encoding: Optional[str] = None):
|
||||
"""Initialize with file path."""
|
||||
self.file_path = file_path
|
||||
self.encoding = encoding
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load from file path."""
|
||||
with open(self.file_path, "rb") as f:
|
||||
raw_text = f.read()
|
||||
result = chardet.detect(raw_text)
|
||||
if result["encoding"] is None:
|
||||
text = raw_text.decode("utf-8")
|
||||
else:
|
||||
text = raw_text.decode(result["encoding"])
|
||||
metadata = {"source": self.file_path}
|
||||
return [Document(page_content=text, metadata=metadata)]
|
@ -12,14 +12,12 @@ class CSVEmbedding(SourceEmbedding):
|
||||
def __init__(
|
||||
self,
|
||||
file_path,
|
||||
model_name,
|
||||
vector_store_config,
|
||||
embedding_args: Optional[Dict] = None,
|
||||
):
|
||||
"""Initialize with csv path."""
|
||||
super().__init__(file_path, model_name, vector_store_config)
|
||||
super().__init__(file_path, vector_store_config)
|
||||
self.file_path = file_path
|
||||
self.model_name = model_name
|
||||
self.vector_store_config = vector_store_config
|
||||
self.embedding_args = embedding_args
|
||||
|
||||
|
@ -1,30 +1,34 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import markdown
|
||||
from bs4 import BeautifulSoup
|
||||
from langchain.document_loaders import PyPDFLoader, TextLoader
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE
|
||||
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
|
||||
from pilot.source_embedding.csv_embedding import CSVEmbedding
|
||||
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
|
||||
from pilot.source_embedding.pdf_embedding import PDFEmbedding
|
||||
from pilot.source_embedding.url_embedding import URLEmbedding
|
||||
from pilot.source_embedding.word_embedding import WordEmbedding
|
||||
from pilot.vector_store.connector import VectorStoreConnector
|
||||
|
||||
CFG = Config()
|
||||
|
||||
KnowledgeEmbeddingType = {
|
||||
".txt": (MarkdownEmbedding, {}),
|
||||
".md": (MarkdownEmbedding, {}),
|
||||
".pdf": (PDFEmbedding, {}),
|
||||
".doc": (WordEmbedding, {}),
|
||||
".docx": (WordEmbedding, {}),
|
||||
".csv": (CSVEmbedding, {}),
|
||||
}
|
||||
|
||||
|
||||
class KnowledgeEmbedding:
|
||||
def __init__(
|
||||
self,
|
||||
file_path,
|
||||
model_name,
|
||||
vector_store_config,
|
||||
local_persist=True,
|
||||
file_type="default",
|
||||
file_type: Optional[str] = "default",
|
||||
file_path: Optional[str] = None,
|
||||
):
|
||||
"""Initialize with Loader url, model_name, vector_store_config"""
|
||||
self.file_path = file_path
|
||||
@ -33,11 +37,9 @@ class KnowledgeEmbedding:
|
||||
self.file_type = file_type
|
||||
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
||||
self.vector_store_config["embeddings"] = self.embeddings
|
||||
self.local_persist = local_persist
|
||||
if not self.local_persist:
|
||||
self.knowledge_embedding_client = self.init_knowledge_embedding()
|
||||
|
||||
def knowledge_embedding(self):
|
||||
self.knowledge_embedding_client = self.init_knowledge_embedding()
|
||||
self.knowledge_embedding_client.source_embedding()
|
||||
|
||||
def knowledge_embedding_batch(self):
|
||||
@ -50,95 +52,27 @@ class KnowledgeEmbedding:
|
||||
model_name=self.model_name,
|
||||
vector_store_config=self.vector_store_config,
|
||||
)
|
||||
elif self.file_path.endswith(".pdf"):
|
||||
embedding = PDFEmbedding(
|
||||
file_path=self.file_path,
|
||||
model_name=self.model_name,
|
||||
return embedding
|
||||
extension = "." + self.file_path.rsplit(".", 1)[-1]
|
||||
if extension in KnowledgeEmbeddingType:
|
||||
knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
|
||||
embedding = knowledge_class(
|
||||
self.file_path,
|
||||
vector_store_config=self.vector_store_config,
|
||||
**knowledge_args,
|
||||
)
|
||||
elif self.file_path.endswith(".md"):
|
||||
embedding = MarkdownEmbedding(
|
||||
file_path=self.file_path,
|
||||
model_name=self.model_name,
|
||||
vector_store_config=self.vector_store_config,
|
||||
)
|
||||
|
||||
elif self.file_path.endswith(".csv"):
|
||||
embedding = CSVEmbedding(
|
||||
file_path=self.file_path,
|
||||
model_name=self.model_name,
|
||||
vector_store_config=self.vector_store_config,
|
||||
)
|
||||
|
||||
elif self.file_type == "default":
|
||||
embedding = MarkdownEmbedding(
|
||||
file_path=self.file_path,
|
||||
model_name=self.model_name,
|
||||
vector_store_config=self.vector_store_config,
|
||||
)
|
||||
|
||||
return embedding
|
||||
raise ValueError(f"Unsupported knowledge file type '{extension}'")
|
||||
return embedding
|
||||
|
||||
def similar_search(self, text, topk):
|
||||
return self.knowledge_embedding_client.similar_search(text, topk)
|
||||
|
||||
def vector_exist(self):
|
||||
return self.knowledge_embedding_client.vector_name_exist()
|
||||
|
||||
def knowledge_persist_initialization(self, append_mode):
|
||||
documents = self._load_knownlege(self.file_path)
|
||||
self.vector_client = VectorStoreConnector(
|
||||
vector_client = VectorStoreConnector(
|
||||
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
||||
)
|
||||
self.vector_client.load_document(documents)
|
||||
return self.vector_client
|
||||
return vector_client.similar_search(text, topk)
|
||||
|
||||
def _load_knownlege(self, path):
|
||||
docments = []
|
||||
for root, _, files in os.walk(path, topdown=False):
|
||||
for file in files:
|
||||
filename = os.path.join(root, file)
|
||||
docs = self._load_file(filename)
|
||||
new_docs = []
|
||||
for doc in docs:
|
||||
doc.metadata = {
|
||||
"source": doc.metadata["source"].replace(DATASETS_DIR, "")
|
||||
}
|
||||
print("doc is embedding...", doc.metadata)
|
||||
new_docs.append(doc)
|
||||
docments += new_docs
|
||||
return docments
|
||||
|
||||
def _load_file(self, filename):
|
||||
if filename.lower().endswith(".md"):
|
||||
loader = TextLoader(filename)
|
||||
text_splitter = CHNDocumentSplitter(
|
||||
pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
|
||||
)
|
||||
docs = loader.load_and_split(text_splitter)
|
||||
i = 0
|
||||
for d in docs:
|
||||
content = markdown.markdown(d.page_content)
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
for tag in soup(["!doctype", "meta", "i.fa"]):
|
||||
tag.extract()
|
||||
docs[i].page_content = soup.get_text()
|
||||
docs[i].page_content = docs[i].page_content.replace("\n", " ")
|
||||
i += 1
|
||||
elif filename.lower().endswith(".pdf"):
|
||||
loader = PyPDFLoader(filename)
|
||||
textsplitter = CHNDocumentSplitter(
|
||||
pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
|
||||
)
|
||||
docs = loader.load_and_split(textsplitter)
|
||||
i = 0
|
||||
for d in docs:
|
||||
docs[i].page_content = d.page_content.replace("\n", " ").replace(
|
||||
"<EFBFBD>", ""
|
||||
)
|
||||
i += 1
|
||||
else:
|
||||
loader = TextLoader(filename)
|
||||
text_splitor = CHNDocumentSplitter(sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE)
|
||||
docs = loader.load_and_split(text_splitor)
|
||||
return docs
|
||||
def vector_exist(self):
|
||||
vector_client = VectorStoreConnector(
|
||||
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
||||
)
|
||||
return vector_client.vector_name_exists()
|
||||
|
@ -8,27 +8,30 @@ from bs4 import BeautifulSoup
|
||||
from langchain.document_loaders import TextLoader
|
||||
from langchain.schema import Document
|
||||
|
||||
from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
|
||||
from pilot.configs.config import Config
|
||||
from pilot.source_embedding import SourceEmbedding, register
|
||||
from pilot.source_embedding.EncodeTextLoader import EncodeTextLoader
|
||||
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class MarkdownEmbedding(SourceEmbedding):
|
||||
"""markdown embedding for read markdown document."""
|
||||
|
||||
def __init__(self, file_path, model_name, vector_store_config):
|
||||
def __init__(self, file_path, vector_store_config):
|
||||
"""Initialize with markdown path."""
|
||||
super().__init__(file_path, model_name, vector_store_config)
|
||||
super().__init__(file_path, vector_store_config)
|
||||
self.file_path = file_path
|
||||
self.model_name = model_name
|
||||
self.vector_store_config = vector_store_config
|
||||
# self.encoding = encoding
|
||||
|
||||
@register
|
||||
def read(self):
|
||||
"""Load from markdown path."""
|
||||
loader = TextLoader(self.file_path)
|
||||
loader = EncodeTextLoader(self.file_path)
|
||||
text_splitter = CHNDocumentSplitter(
|
||||
pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
|
||||
pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
|
||||
)
|
||||
return loader.load_and_split(text_splitter)
|
||||
|
||||
|
@ -5,20 +5,22 @@ from typing import List
|
||||
from langchain.document_loaders import PyPDFLoader
|
||||
from langchain.schema import Document
|
||||
|
||||
from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
|
||||
from pilot.configs.config import Config
|
||||
from pilot.source_embedding import SourceEmbedding, register
|
||||
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class PDFEmbedding(SourceEmbedding):
|
||||
"""pdf embedding for read pdf document."""
|
||||
|
||||
def __init__(self, file_path, model_name, vector_store_config):
|
||||
def __init__(self, file_path, vector_store_config, encoding):
|
||||
"""Initialize with pdf path."""
|
||||
super().__init__(file_path, model_name, vector_store_config)
|
||||
super().__init__(file_path, vector_store_config)
|
||||
self.file_path = file_path
|
||||
self.model_name = model_name
|
||||
self.vector_store_config = vector_store_config
|
||||
self.encoding = encoding
|
||||
|
||||
@register
|
||||
def read(self):
|
||||
@ -26,7 +28,7 @@ class PDFEmbedding(SourceEmbedding):
|
||||
# loader = UnstructuredPaddlePDFLoader(self.file_path)
|
||||
loader = PyPDFLoader(self.file_path)
|
||||
textsplitter = CHNDocumentSplitter(
|
||||
pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
|
||||
pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
|
||||
)
|
||||
return loader.load_and_split(textsplitter)
|
||||
|
||||
|
@ -23,13 +23,11 @@ class SourceEmbedding(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
file_path,
|
||||
model_name,
|
||||
vector_store_config,
|
||||
embedding_args: Optional[Dict] = None,
|
||||
):
|
||||
"""Initialize with Loader url, model_name, vector_store_config"""
|
||||
self.file_path = file_path
|
||||
self.model_name = model_name
|
||||
self.vector_store_config = vector_store_config
|
||||
self.embedding_args = embedding_args
|
||||
self.embeddings = vector_store_config["embeddings"]
|
||||
|
@ -8,11 +8,10 @@ from pilot import SourceEmbedding, register
|
||||
class StringEmbedding(SourceEmbedding):
|
||||
"""string embedding for read string document."""
|
||||
|
||||
def __init__(self, file_path, model_name, vector_store_config):
|
||||
def __init__(self, file_path, vector_store_config):
|
||||
"""Initialize with pdf path."""
|
||||
super().__init__(file_path, model_name, vector_store_config)
|
||||
super().__init__(file_path, vector_store_config)
|
||||
self.file_path = file_path
|
||||
self.model_name = model_name
|
||||
self.vector_store_config = vector_store_config
|
||||
|
||||
@register
|
||||
|
@@ -5,27 +5,36 @@ from langchain.document_loaders import WebBaseLoader
 from langchain.schema import Document
 from langchain.text_splitter import CharacterTextSplitter
 
+from pilot.configs.config import Config
-from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
 from pilot.source_embedding import SourceEmbedding, register
 from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
 
+CFG = Config()
 
 
 class URLEmbedding(SourceEmbedding):
     """url embedding for read url document."""
 
-    def __init__(self, file_path, model_name, vector_store_config):
+    def __init__(self, file_path, vector_store_config):
         """Initialize with url path."""
-        super().__init__(file_path, model_name, vector_store_config)
+        super().__init__(file_path, vector_store_config)
         self.file_path = file_path
-        self.model_name = model_name
         self.vector_store_config = vector_store_config
 
     @register
     def read(self):
         """Load from url path."""
         loader = WebBaseLoader(web_path=self.file_path)
-        text_splitor = CharacterTextSplitter(
-            chunk_size=1000, chunk_overlap=20, length_function=len
-        )
-        return loader.load_and_split(text_splitor)
+        if CFG.LANGUAGE == "en":
+            text_splitter = CharacterTextSplitter(
+                chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
+                chunk_overlap=20,
+                length_function=len,
+            )
+        else:
+            text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=1000)
+        return loader.load_and_split(text_splitter)
 
     @register
     def data_process(self, documents: List[Document]):
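The reworked read() now picks its splitter from CFG.LANGUAGE, so English pages are chunked by KNOWLEDGE_CHUNK_SIZE while other languages go through the sentence-aware Chinese splitter. A standalone sketch of that selection logic (pick_text_splitter is an illustrative helper, not part of the commit):

    from langchain.text_splitter import CharacterTextSplitter

    from pilot.configs.config import Config
    from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter

    CFG = Config()


    def pick_text_splitter():
        """Choose a splitter the same way URLEmbedding.read() now does."""
        if CFG.LANGUAGE == "en":
            return CharacterTextSplitter(
                chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
                chunk_overlap=20,
                length_function=len,
            )
        return CHNDocumentSplitter(pdf=True, sentence_size=1000)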
39  pilot/source_embedding/word_embedding.py  (new file)
@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from typing import List
+
+from langchain.document_loaders import PyPDFLoader, UnstructuredWordDocumentLoader
+from langchain.schema import Document
+
+from pilot.configs.config import Config
+from pilot.source_embedding import SourceEmbedding, register
+from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
+
+CFG = Config()
+
+
+class WordEmbedding(SourceEmbedding):
+    """word embedding for read word document."""
+
+    def __init__(self, file_path, vector_store_config):
+        """Initialize with word path."""
+        super().__init__(file_path, vector_store_config)
+        self.file_path = file_path
+        self.vector_store_config = vector_store_config
+
+    @register
+    def read(self):
+        """Load from word path."""
+        loader = UnstructuredWordDocumentLoader(self.file_path)
+        textsplitter = CHNDocumentSplitter(
+            pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
+        )
+        return loader.load_and_split(textsplitter)
+
+    @register
+    def data_process(self, documents: List[Document]):
+        i = 0
+        for d in documents:
+            documents[i].page_content = d.page_content.replace("\n", "")
+            i += 1
+        return documents
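A usage sketch for the new loader, mirroring the PDF flow above; the .docx path and store name are illustrative, and UnstructuredWordDocumentLoader additionally needs the unstructured package (with its docx extras) installed:

    from langchain.embeddings import HuggingFaceEmbeddings

    from pilot.source_embedding.word_embedding import WordEmbedding

    word_embedding = WordEmbedding(
        file_path="./datasets/design_doc.docx",  # illustrative path
        vector_store_config={
            "vector_store_name": "design_doc",  # illustrative collection name
            "embeddings": HuggingFaceEmbeddings(),
        },
    )
    docs = word_embedding.read()              # UnstructuredWordDocumentLoader + CHN splitter
    docs = word_embedding.data_process(docs)  # strips newlines from every chunk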
@@ -21,8 +21,10 @@ class DBSummaryClient:
     , get_similar_tables method(get user query related tables info)
     """
 
-    @staticmethod
-    def db_summary_embedding(dbname):
+    def __init__(self):
+        pass
+
+    def db_summary_embedding(self, dbname):
         """put db profile and table profile summary into vector store"""
         if CFG.LOCAL_DB_HOST is not None and CFG.LOCAL_DB_PORT is not None:
             db_summary_client = MysqlSummary(dbname)
@@ -34,24 +36,21 @@
             "embeddings": embeddings,
         }
         embedding = StringEmbedding(
-            db_summary_client.get_summery(),
-            LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
-            vector_store_config,
+            file_path=db_summary_client.get_summery(),
+            vector_store_config=vector_store_config,
         )
         if not embedding.vector_name_exist():
             if CFG.SUMMARY_CONFIG == "FAST":
                 for vector_table_info in db_summary_client.get_summery():
                     embedding = StringEmbedding(
                         vector_table_info,
-                        LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                         vector_store_config,
                     )
                     embedding.source_embedding()
             else:
                 embedding = StringEmbedding(
-                    db_summary_client.get_summery(),
-                    LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
-                    vector_store_config,
+                    file_path=db_summary_client.get_summery(),
+                    vector_store_config=vector_store_config,
                 )
                 embedding.source_embedding()
         for (
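With @staticmethod gone, DBSummaryClient is now instantiated, and when SUMMARY_CONFIG is set to FAST each table summary string is embedded as its own document instead of one whole-profile blob. A sketch of the call site, assuming the client lives in pilot.summary.db_summary_client as in the repository layout and using an illustrative database name:

    from pilot.summary.db_summary_client import DBSummaryClient

    client = DBSummaryClient()               # instance method now, no longer static
    client.db_summary_embedding("db_test")   # "db_test" is an illustrative database name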
@@ -59,32 +58,24 @@
             table_summary,
         ) in db_summary_client.get_table_summary().items():
             table_vector_store_config = {
-                "vector_store_name": table_name + "_ts",
+                "vector_store_name": dbname + "_" + table_name + "_ts",
                 "embeddings": embeddings,
             }
             embedding = StringEmbedding(
                 table_summary,
-                LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                 table_vector_store_config,
             )
             embedding.source_embedding()
 
         logger.info("db summary embedding success")
 
-    @staticmethod
-    def get_similar_tables(dbname, query, topk):
+    def get_similar_tables(self, dbname, query, topk):
         """get user query related tables info"""
         embeddings = HuggingFaceEmbeddings(
             model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
         )
         vector_store_config = {
             "vector_store_name": dbname + "_profile",
             "embeddings": embeddings,
         }
         knowledge_embedding_client = KnowledgeEmbedding(
             file_path="",
             model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
             local_persist=False,
             vector_store_config=vector_store_config,
         )
         if CFG.SUMMARY_CONFIG == "FAST":
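get_similar_tables() is likewise an instance method now, and the per-table collections are namespaced by database (dbname + "_" + table_name + "_ts"), so two databases with identically named tables no longer collide. A sketch of retrieving the top related table summaries, reusing the sample query from the commented-out block removed further down (translation in the comment):

    from pilot.summary.db_summary_client import DBSummaryClient

    client = DBSummaryClient()
    related = client.get_similar_tables(
        dbname="db_test",              # illustrative database name
        query="查询在线用户的购物车",     # "look up the shopping carts of online users"
        topk=5,
    )
    for table_summary in related:
        print(table_summary)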
@@ -104,19 +95,23 @@
         related_table_summaries = []
         for table in related_tables:
             vector_store_config = {
-                "vector_store_name": table + "_ts",
-                "embeddings": embeddings,
+                "vector_store_name": dbname + "_" + table + "_ts",
             }
             knowledge_embedding_client = KnowledgeEmbedding(
                 file_path="",
                 model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                 local_persist=False,
                 vector_store_config=vector_store_config,
             )
             table_summery = knowledge_embedding_client.similar_search(query, 1)
             related_table_summaries.append(table_summery[0].page_content)
         return related_table_summaries
 
+    def init_db_summary(self):
+        db = CFG.local_db
+        dbs = db.get_database_list()
+        for dbname in dbs:
+            self.db_summary_embedding(dbname)
+
 
 def _get_llm_response(query, db_input, dbsummary):
     chat_param = {
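The new init_db_summary() walks every database returned by CFG.local_db.get_database_list() and runs the embedding above for each one, which makes it suitable to call once at server startup; a minimal sketch:

    from pilot.summary.db_summary_client import DBSummaryClient

    client = DBSummaryClient()
    client.init_db_summary()  # embeds a schema summary for every database CFG.local_db can see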
@@ -132,30 +127,3 @@
     )
     res = chat.nostream_call()
     return json.loads(res)["table"]
-
-
-# if __name__ == "__main__":
-#     # summary = DBSummaryClient.get_similar_tables("db_test", "查询在线用户的购物车", 10)
-#
-#     text= """Based on the input "查询在线聊天的用户好友" and the known database information, the tables involved in the user input are "chat_users" and "friends".
-#     Response:
-#
-#     {
-#         "table": ["chat_users"]
-#     }"""
-#     text = text.rstrip().replace("\n","")
-#     start = text.find("{")
-#     end = text.find("}") + 1
-#
-#     # Extract the JSON payload from the string
-#     json_str = text[start:end]
-#
-#     # Convert the JSON payload into a Python dict
-#     data = json.loads(json_str)
-#     # pattern = r'{s*"table"s*:s*[[^]]*]s*}'
-#     # match = re.search(pattern, text)
-#     # if match:
-#     #     json_string = match.group(0)
-#     #     # Convert the JSON string into a Python object
-#     #     json_obj = json.loads(json_string)
-#     # print(summary)
@@ -17,7 +17,6 @@ from langchain.vectorstores import Chroma
from pilot.configs.model_config import (
    DATASETS_DIR,
    LLM_MODEL_CONFIG,
    VECTOR_SEARCH_TOP_K,
    VECTORE_PATH,
)

@@ -41,7 +40,6 @@ class KnownLedge2Vector:

    embeddings: object = None
    model_name = LLM_MODEL_CONFIG["sentence-transforms"]
    top_k: int = VECTOR_SEARCH_TOP_K

    def __init__(self, model_name=None) -> None:
        if not model_name:
@@ -10,7 +10,6 @@ from pilot.configs.config import Config
 from pilot.configs.model_config import (
     DATASETS_DIR,
     LLM_MODEL_CONFIG,
-    VECTOR_SEARCH_TOP_K,
 )
 from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 
@@ -19,36 +18,30 @@ CFG = Config()
 
 class LocalKnowledgeInit:
     embeddings: object = None
-    model_name = LLM_MODEL_CONFIG["text2vec"]
-    top_k: int = VECTOR_SEARCH_TOP_K
 
     def __init__(self, vector_store_config) -> None:
         self.vector_store_config = vector_store_config
+        self.model_name = LLM_MODEL_CONFIG["text2vec"]
 
     def knowledge_persist(self, file_path, append_mode):
         """knowledge persist"""
-        kv = KnowledgeEmbedding(
-            file_path=file_path,
-            model_name=LLM_MODEL_CONFIG["text2vec"],
-            vector_store_config=self.vector_store_config,
-        )
-        vector_store = kv.knowledge_persist_initialization(append_mode)
-        return vector_store
-
-    def query(self, q):
-        """Query similar doc from Vector"""
-        vector_store = self.init_vector_store()
-        docs = vector_store.similarity_search_with_score(q, k=self.top_k)
-        for doc in docs:
-            dc, s = doc
-            yield s, dc
+        for root, _, files in os.walk(file_path, topdown=False):
+            for file in files:
+                filename = os.path.join(root, file)
+                # docs = self._load_file(filename)
+                ke = KnowledgeEmbedding(
+                    file_path=filename,
+                    model_name=self.model_name,
+                    vector_store_config=self.vector_store_config,
+                )
+                client = ke.init_knowledge_embedding()
+                client.source_embedding()
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--vector_name", type=str, default="default")
     parser.add_argument("--append", type=bool, default=False)
     parser.add_argument("--store_type", type=str, default="Chroma")
     args = parser.parse_args()
     vector_name = args.vector_name
     append_mode = args.append
@@ -56,5 +49,5 @@ if __name__ == "__main__":
     vector_store_config = {"vector_store_name": vector_name}
     print(vector_store_config)
     kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
-    vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
+    kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
     print("your knowledge embedding success...")
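knowledge_persist() now walks the dataset directory and embeds each file through its own KnowledgeEmbedding client (which presumably dispatches to the PDF/word/URL readers above by file type), rather than making a single batch persist call. A sketch of what that traversal picks up, assuming the default DATASETS_DIR layout:

    import os

    from pilot.configs.model_config import DATASETS_DIR

    # Every file found here gets its own KnowledgeEmbedding + source_embedding() run.
    for root, _, files in os.walk(DATASETS_DIR, topdown=False):
        for name in files:
            print(os.path.join(root, name))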