fix some errors

This commit is contained in:
csunny 2023-05-08 21:04:50 +08:00
commit 8fb917c8ab
17 changed files with 353 additions and 59 deletions

4
.gitignore vendored
View File

@ -131,4 +131,8 @@ dmypy.json
.pyre/ .pyre/
.DS_Store .DS_Store
logs logs
.vectordb .vectordb

View File

@ -1,8 +1,10 @@
# DB-GPT # DB-GPT
A Open Database-GPT Experiment An Open Database-GPT Experiment, a fully localized project.
![GitHub Repo stars](https://img.shields.io/github/stars/csunny/db-gpt?style=social) ![GitHub Repo stars](https://img.shields.io/github/stars/csunny/db-gpt?style=social)
A database-related GPT experiment project: the model and data are deployed entirely locally, so data privacy is fully protected. The project can also be deployed locally and connected directly to a private database for private data processing.
[DB-GPT](https://github.com/csunny/DB-GPT) is an experimental open-source application built on [FastChat](https://github.com/lm-sys/FastChat), using [vicuna-13b](https://huggingface.co/Tribbiani/vicuna-13b) as the base model. It also combines [langchain](https://github.com/hwchase17/langchain) and [llama-index](https://github.com/jerryjliu/llama_index) to apply [In-Context Learning](https://arxiv.org/abs/2301.00234) over an existing knowledge base, augmenting the model with database-related knowledge. It can handle tasks such as SQL generation, SQL diagnosis, and database knowledge Q&A.
@ -30,6 +32,26 @@ Run on an RTX 4090 GPU (The origin mov not sped up!, [YouTube地址](https://www
<img src="https://github.com/csunny/DB-GPT/blob/main/asserts/DB_QA.png" margin-left="auto" margin-right="auto" width="600"> <img src="https://github.com/csunny/DB-GPT/blob/main/asserts/DB_QA.png" margin-left="auto" margin-right="auto" width="600">
# Dependencies
1. First, install the Python requirements.
```
python>=3.9
pip install -r requirements.txt
```
Or, if you use a conda environment, you can use this command:
```
cd DB-GPT
conda env create -f environment.yml
```
2. Install MySQL
The examples in this project connect to MySQL to run SQL generation, so you need a local MySQL instance for testing. Docker is recommended:
```
docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa123456 -dit mysql:latest
```
The password is only for testing; change it if necessary.
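Before wiring the database into the project, it may help to confirm the container is reachable with pymysql (already in the dependency list). This is only a sanity-check sketch; the host, port, and password below assume the docker command above and should be adjusted to your setup.
```
# Minimal connectivity check (values assume the docker command above).
import pymysql

conn = pymysql.connect(host="127.0.0.1", port=3306, user="root", password="aa123456")
try:
    with conn.cursor() as cur:
        cur.execute("SELECT VERSION()")
        print(cur.fetchone())   # e.g. ('8.0.x',)
finally:
    conn.close()
```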
# Install # Install
1. Download the base model
For the base model, you can merge the weights by following the [vicuna](https://github.com/lm-sys/FastChat/blob/main/README.md#model-weights) weight-merging tutorial.
@ -52,3 +74,7 @@ python webserver.py
- SQL-diagnosis - SQL-diagnosis
Overall, it is a complex and innovative AI tool for databases. If you have any specific questions about how to use or adopt DB-GPT in your work, please contact me and I will do my best to help. Everyone is also welcome to take part in building the project and do something interesting.
# Licence
[MIT](https://github.com/csunny/DB-GPT/blob/main/LICENSE)

BIN
asserts/exeable.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 277 KiB

View File

@ -61,3 +61,7 @@ dependencies:
- gradio-client==0.0.8 - gradio-client==0.0.8
- wandb - wandb
- fschat=0.1.10 - fschat=0.1.10
- llama-index==0.5.27
- pymysql
- unstructured==0.6.3
- pytesseract==0.3.10

2
pilot/chain/audio.py Normal file
View File

@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

2
pilot/chain/visual.py Normal file
View File

@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

2
pilot/client/chart.py Normal file
View File

@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

View File

@ -11,6 +11,7 @@ PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store") VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
LOGDIR = os.path.join(ROOT_PATH, "logs") LOGDIR = os.path.join(ROOT_PATH, "logs")
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets") DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
DATA_DIR = os.path.join(PILOT_PATH, "data")
nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
@ -22,6 +23,7 @@ LLM_MODEL_CONFIG = {
} }
VECTOR_SEARCH_TOP_K = 3
LLM_MODEL = "vicuna-13b" LLM_MODEL = "vicuna-13b"
LIMIT_MODEL_CONCURRENCY = 5 LIMIT_MODEL_CONCURRENCY = 5
MAX_POSITION_EMBEDDINGS = 2048 MAX_POSITION_EMBEDDINGS = 2048

View File

@ -146,8 +146,24 @@ conv_vicuna_v1 = Conversation(
sep2="</s>", sep2="</s>",
) )
conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议:
已知内容:
{context}
问题:
{question}
"""
default_conversation = conv_one_shot default_conversation = conv_one_shot
conversation_types = {
"native": "LLM原生对话",
"default_knownledge": "默认知识库对话",
"custome": "新增知识库对话",
}
conv_templates = { conv_templates = {
"conv_one_shot": conv_one_shot, "conv_one_shot": conv_one_shot,
"vicuna_v1": conv_vicuna_v1, "vicuna_v1": conv_vicuna_v1,

View File

@ -10,33 +10,29 @@ from typing import Any, Mapping, Optional, List
from langchain.llms.base import LLM from langchain.llms.base import LLM
from pilot.configs.model_config import * from pilot.configs.model_config import *
class VicunaRequestLLM(LLM): class VicunaLLM(LLM):
vicuna_generate_path = "generate" vicuna_generate_path = "generate_stream"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: def _call(self, prompt: str, temperature: float, max_new_tokens: int, stop: Optional[List[str]] = None) -> str:
if isinstance(stop, list):
stop = stop + ["Observation:"]
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
params = { params = {
"prompt": prompt, "prompt": prompt,
"temperature": 0.7, "temperature": temperature,
"max_new_tokens": 1024, "max_new_tokens": max_new_tokens,
"stop": stop "stop": stop
} }
response = requests.post( response = requests.post(
url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path), url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
data=json.dumps(params), data=json.dumps(params),
) )
response.raise_for_status()
# for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
# if chunk: for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
# data = json.loads(chunk.decode()) if chunk:
# if data["error_code"] == 0: data = json.loads(chunk.decode())
# output = data["text"][skip_echo_len:].strip() if data["error_code"] == 0:
# output = self.post_process_code(output) output = data["text"][skip_echo_len:].strip()
# yield output yield output
return response.json()["response"]
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:
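Since the rewritten `_call` now yields streamed chunks from the `generate_stream` endpoint instead of returning a single string, callers iterate over it. A minimal consumption sketch, assuming `VICUNA_MODEL_SERVER` is configured and the worker is running (query text and generation parameters are illustrative):
```
# Sketch: consume the streaming generator returned by VicunaLLM._call.
from pilot.model.vicuna_llm import VicunaLLM

llm = VicunaLLM()
for partial in llm._call("What is an index in MySQL?",
                         temperature=0.7, max_new_tokens=512, stop=["###"]):
    print(partial, end="\r")  # each chunk carries the accumulated output so far
```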

View File

@ -0,0 +1,180 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import json
import transformers
from transformers import LlamaTokenizer, LlamaForCausalLM
from typing import List
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
)
import torch
from datasets import load_dataset
import pandas as pd
from pilot.configs.model_config import DATA_DIR, LLM_MODEL, LLM_MODEL_CONFIG
device = "cuda" if torch.cuda.is_available() else "cpu"
CUTOFF_LEN = 50
df = pd.read_csv(os.path.join(DATA_DIR, "BTC_Tweets_Updated.csv"))
def sentiment_score_to_name(score: float):
if score > 0:
return "Positive"
elif score < 0:
return "Negative"
return "Neutral"
dataset_data = [
{
"instruction": "Detect the sentiment of the tweet.",
"input": row_dict["Tweet"],
"output": sentiment_score_to_name(row_dict["New_Sentiment_State"])
}
for row_dict in df.to_dict(orient="records")
]
with open(os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"), "w") as f:
json.dump(dataset_data, f)
data = load_dataset("json", data_files=os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"))
print(data["train"])
BASE_MODEL = LLM_MODEL_CONFIG[LLM_MODEL]
model = LlamaForCausalLM.from_pretrained(
BASE_MODEL,
torch_dtype=torch.float16,
device_map="auto",
offload_folder=os.path.join(DATA_DIR, "vicuna-lora")
)
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token_id = (0)
tokenizer.padding_side = "left"
def generate_prompt(data_point):
return f"""Blow is an instruction that describes a task, paired with an input that provide future context.
Write a response that appropriately completes the request. #noqa:
### Instruct:
{data_point["instruction"]}
### Input
{data_point["input"]}
### Response
{data_point["output"]}
"""
def tokenize(prompt, add_eos_token=True):
result = tokenizer(
prompt,
truncation=True,
max_length=CUTOFF_LEN,
padding=False,
return_tensors=None,
)
if (result["input_ids"][-1] != tokenizer.eos_token_id and len(result["input_ids"]) < CUTOFF_LEN and add_eos_token):
result["input_ids"].append(tokenizer.eos_token_id)
result["attention_mask"].append(1)
result["labels"] = result["input_ids"].copy()
return result
def generate_and_tokenize_prompt(data_point):
full_prompt = generate_prompt(data_point)
tokenized_full_prompt = tokenize(full_prompt)
return tokenized_full_prompt
train_val = data["train"].train_test_split(
test_size=200, shuffle=True, seed=42
)
train_data = (
train_val["train"].map(generate_and_tokenize_prompt)
)
val_data = (
train_val["test"].map(generate_and_tokenize_prompt)
)
# Training
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = [
"q_proj",
"v_proj",
]
BATCH_SIZE = 128
MICRO_BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
LEARNING_RATE = 3e-4
TRAIN_STEPS = 300
OUTPUT_DIR = "experiments"
# We can now prepare model for training
model = prepare_model_for_int8_training(model)
config = LoraConfig(
r = LORA_R,
lora_alpha=LORA_ALPHA,
target_modules=LORA_TARGET_MODULES,
lora_dropout=LORA_DROPOUT,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
model.print_trainable_parameters()
training_arguments = transformers.TrainingArguments(
per_device_train_batch_size=MICRO_BATCH_SIZE,
gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
warmup_steps=100,
max_steps=TRAIN_STEPS,
no_cuda=True,
learning_rate=LEARNING_RATE,
logging_steps=10,
optim="adamw_torch",
evaluation_strategy="steps",
save_strategy="steps",
eval_steps=50,
save_steps=50,
output_dir=OUTPUT_DIR,
save_total_limit=3,
load_best_model_at_end=True,
report_to="tensorboard"
)
data_collator = transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
)
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=training_arguments,
data_collator=data_collator
)
model.config.use_cache = False
old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(
self, old_state_dict()
)
).__get__(model, type(model))
trainer.train()
model.save_pretrained(OUTPUT_DIR)
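After training, the adapter written to `experiments` can be loaded back onto the base model for a quick smoke test. A minimal sketch, assuming the same `BASE_MODEL` and `OUTPUT_DIR` values as the script above (prompt text is illustrative):
```
# Sketch: reload the LoRA adapter saved above and run one generation.
# BASE_MODEL / OUTPUT_DIR mirror the training script (assumptions).
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer
from pilot.configs.model_config import LLM_MODEL, LLM_MODEL_CONFIG

BASE_MODEL = LLM_MODEL_CONFIG[LLM_MODEL]
OUTPUT_DIR = "experiments"

base = LlamaForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float16, device_map="auto")
model = PeftModel.from_pretrained(base, OUTPUT_DIR)
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

prompt = "### Instruction:\nDetect the sentiment of the tweet.\n### Input:\nBitcoin is pumping today!\n### Response:\n"
inputs = tokenizer(prompt, return_tensors="pt").to(base.device)
output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```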

View File

@ -4,29 +4,44 @@
import requests import requests
import json import json
import time import time
import uuid
from urllib.parse import urljoin from urllib.parse import urljoin
import gradio as gr import gradio as gr
from pilot.configs.model_config import * from pilot.configs.model_config import *
vicuna_base_uri = "http://192.168.31.114:21002/" from pilot.conversation import conv_qa_prompt_template, conv_templates
vicuna_stream_path = "worker_generate_stream" from langchain.prompts import PromptTemplate
vicuna_status_path = "worker_get_status"
def generate(prompt): vicuna_stream_path = "generate_stream"
def generate(query):
template_name = "conv_one_shot"
state = conv_templates[template_name].copy()
pt = PromptTemplate(
template=conv_qa_prompt_template,
input_variables=["context", "question"]
)
result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
question=query)
print(result)
state.append_message(state.roles[0], result)
state.append_message(state.roles[1], None)
prompt = state.get_prompt()
params = { params = {
"model": "vicuna-13b", "model": "vicuna-13b",
"prompt": prompt, "prompt": prompt,
"temperature": 0.7, "temperature": 0.7,
"max_new_tokens": 512, "max_new_tokens": 1024,
"stop": "###" "stop": "###"
} }
sts_response = requests.post(
url=urljoin(vicuna_base_uri, vicuna_status_path)
)
print(sts_response.text)
response = requests.post( response = requests.post(
url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params) url=urljoin(VICUNA_MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
) )
skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3 skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
@ -34,11 +49,10 @@ def generate(prompt):
if chunk: if chunk:
data = json.loads(chunk.decode()) data = json.loads(chunk.decode())
if data["error_code"] == 0: if data["error_code"] == 0:
output = data["text"] output = data["text"][skip_echo_len:].strip()
state.messages[-1][-1] = output + ""
yield(output) yield(output)
time.sleep(0.02)
if __name__ == "__main__": if __name__ == "__main__":
print(LLM_MODEL) print(LLM_MODEL)
with gr.Blocks() as demo: with gr.Blocks() as demo:

View File

@ -0,0 +1,29 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pilot.vector_store.file_loader import KnownLedge2Vector
from langchain.prompts import PromptTemplate
from pilot.conversation import conv_qa_prompt_template
from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
from pilot.model.vicuna_llm import VicunaLLM
class KnownLedgeBaseQA:
def __init__(self) -> None:
k2v = KnownLedge2Vector()
self.vector_store = k2v.init_vector_store()
self.llm = VicunaLLM()
def get_similar_answer(self, query):
prompt = PromptTemplate(
template=conv_qa_prompt_template,
input_variables=["context", "question"]
)
retriever = self.vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K})
docs = retriever.get_relevant_documents(query=query)
context = [d.page_content for d in docs]
result = prompt.format(context="\n".join(context), question=query)
return result
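A short sketch of how this class might be exercised on its own; the query string is only an example (the same one appears in the file_loader test below), and the printed prompt is what would be handed to the LLM:
```
# Sketch: retrieve knowledge and build the grounded prompt (example query).
from pilot.server.vectordb_qa import KnownLedgeBaseQA

qa = KnownLedgeBaseQA()                       # loads the persisted Chroma store
prompt = qa.get_similar_answer("什么是OceanBase")
print(prompt)  # retrieved context plus the question, ready for qa.llm
```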

View File

@ -11,6 +11,7 @@ import datetime
import requests import requests
from urllib.parse import urljoin from urllib.parse import urljoin
from pilot.configs.model_config import DB_SETTINGS from pilot.configs.model_config import DB_SETTINGS
from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.connections.mysql_conn import MySQLOperator from pilot.connections.mysql_conn import MySQLOperator
from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
@ -19,6 +20,7 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
from pilot.conversation import ( from pilot.conversation import (
default_conversation, default_conversation,
conv_templates, conv_templates,
conversation_types,
SeparatorStyle SeparatorStyle
) )
@ -149,7 +151,7 @@ def post_process_code(code):
code = sep.join(blocks) code = sep.join(blocks)
return code return code
def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request): def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
start_tstamp = time.time() start_tstamp = time.time()
model_name = LLM_MODEL model_name = LLM_MODEL
@ -170,7 +172,8 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
query = state.messages[-2][1] query = state.messages[-2][1]
# Add a context hint to the prompt # Add a context hint to the prompt for knowledge-grounded conversation. Should the hint go only in the first round, or be added in every round?
# If the user's questions vary widely, the hint should be added in every round.
if db_selector: if db_selector:
new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query) new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
new_state.append_message(new_state.roles[1], None) new_state.append_message(new_state.roles[1], None)
@ -180,13 +183,11 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
new_state.append_message(new_state.roles[1], None) new_state.append_message(new_state.roles[1], None)
state = new_state state = new_state
# try: if mode == conversation_types["default_knownledge"] and not db_selector:
# if not db_selector: query = state.messages[-2][1]
# sim_q = get_simlar(query) knqa = KnownLedgeBaseQA()
# print("********vector similar info*************: ", sim_q) state.messages[-2][1] = knqa.get_similar_answer(query)
# state.append_message(new_state.roles[0], sim_q + query)
# except Exception as e:
# print(e)
prompt = state.get_prompt() prompt = state.get_prompt()
@ -222,7 +223,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
state.messages[-1][-1] = output state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
return return
time.sleep(0.02)
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)" state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
@ -231,6 +232,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
state.messages[-1][-1] = state.messages[-1][-1][:-1] state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
# Log this run
finish_tstamp = time.time() finish_tstamp = time.time()
logger.info(f"{output}") logger.info(f"{output}")
@ -266,7 +268,7 @@ def change_tab(tab):
pass pass
def change_mode(mode): def change_mode(mode):
if mode == "默认知识库对话": if mode in ["默认知识库对话", "LLM原生对话"]:
return gr.update(visible=False) return gr.update(visible=False)
else: else:
return gr.update(visible=True) return gr.update(visible=True)
@ -281,7 +283,7 @@ def build_single_model_ui():
""" """
learn_more_markdown = """ learn_more_markdown = """
### Licence ### Licence
The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of Vicuna-13B
""" """
state = gr.State() state = gr.State()
@ -318,7 +320,8 @@ def build_single_model_ui():
show_label=True).style(container=False) show_label=True).style(container=False)
with gr.TabItem("知识问答", elem_id="QA"): with gr.TabItem("知识问答", elem_id="QA"):
mode = gr.Radio(["默认知识库对话", "新增知识库"], show_label=False, value="默认知识库对话")
mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
vs_setting = gr.Accordion("配置知识库", open=False) vs_setting = gr.Accordion("配置知识库", open=False)
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting) mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
with vs_setting: with vs_setting:
@ -363,7 +366,7 @@ def build_single_model_ui():
btn_list = [regenerate_btn, clear_btn] btn_list = [regenerate_btn, clear_btn]
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
http_bot, http_bot,
[state, db_selector, temperature, max_output_tokens], [state, mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list, [state, chatbot] + btn_list,
) )
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@ -372,7 +375,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then( ).then(
http_bot, http_bot,
[state, db_selector, temperature, max_output_tokens], [state, mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list, [state, chatbot] + btn_list,
) )
@ -380,7 +383,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then( ).then(
http_bot, http_bot,
[state, db_selector, temperature, max_output_tokens], [state, mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list [state, chatbot] + btn_list
) )

View File

@ -10,9 +10,8 @@ from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader, TextLoader from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader, TextLoader
from langchain.chains import VectorDBQA from langchain.chains import VectorDBQA
from langchain.embeddings import HuggingFaceEmbeddings from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K
VECTOR_SEARCH_TOP_K = 5
class KnownLedge2Vector: class KnownLedge2Vector:
@ -26,21 +25,21 @@ class KnownLedge2Vector:
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
def init_vector_store(self): def init_vector_store(self):
documents = self.load_knownlege()
persist_dir = os.path.join(VECTORE_PATH, ".vectordb") persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
print("向量数据库持久化地址: ", persist_dir) print("向量数据库持久化地址: ", persist_dir)
if os.path.exists(persist_dir): if os.path.exists(persist_dir):
# Load from the locally persisted files # Load from the locally persisted files
print("从本地向量加载数据...")
vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings) vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
vector_store.add_documents(documents=documents) # vector_store.add_documents(documents=documents)
else: else:
documents = self.load_knownlege()
# Re-initialize # Re-initialize
vector_store = Chroma.from_documents(documents=documents, vector_store = Chroma.from_documents(documents=documents,
embedding=self.embeddings, embedding=self.embeddings,
persist_directory=persist_dir) persist_directory=persist_dir)
vector_store.persist() vector_store.persist()
vector_store = None return vector_store
return persist_dir
def load_knownlege(self): def load_knownlege(self):
docments = [] docments = []
@ -70,10 +69,23 @@ class KnownLedge2Vector:
return docs return docs
def _load_from_url(self, url): def _load_from_url(self, url):
"""Load data from url address"""
pass pass
def query(self, q):
"""Query similar doc from Vector """
vector_store = self.init_vector_store()
docs = vector_store.similarity_search_with_score(q, k=self.top_k)
for doc in docs:
dc, s = doc
yield s, dc
if __name__ == "__main__": if __name__ == "__main__":
k2v = KnownLedge2Vector() k2v = KnownLedge2Vector()
k2v.init_vector_store()
persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
print(persist_dir)
for s, dc in k2v.query("什么是OceanBase"):
print(s, dc.page_content, dc.metadata)

View File

@ -53,3 +53,5 @@ wandb
fschat=0.1.10 fschat=0.1.10
llama-index=0.5.27 llama-index=0.5.27
pymysql pymysql
unstructured==0.6.3
pytesseract==0.3.10