mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-23 20:26:15 +00:00)

fix some error

This commit is contained in: commit 8fb917c8ab
.gitignore (vendored, 4 lines changed)

@@ -131,4 +131,8 @@ dmypy.json
.pyre/
.DS_Store
logs
<<<<<<< HEAD
=======
>>>>>>> 2eeb22ccae5b2a5d695ac267044129f3b672c3dc
.vectordb
README.md (28 lines changed)

@@ -1,8 +1,10 @@
# DB-GPT

An Open Database-GPT Experiment

An Open Database-GPT Experiment, a fully localized project.

A database-focused GPT experiment project. Both the model and the data are deployed entirely locally, so data privacy and security are fully protected. The project can also be deployed on-premises and connected directly to a private database to process private data.

[DB-GPT](https://github.com/csunny/DB-GPT) is an experimental open-source application built on [FastChat](https://github.com/lm-sys/FastChat), with [vicuna-13b](https://huggingface.co/Tribbiani/vicuna-13b) as the base model. It also combines [langchain](https://github.com/hwchase17/langchain) and [llama-index](https://github.com/jerryjliu/llama_index) to enhance the model with database knowledge through [In-Context Learning](https://arxiv.org/abs/2301.00234) over an existing knowledge base. It can perform tasks such as SQL generation, SQL diagnosis, and database knowledge Q&A.

@@ -30,6 +32,26 @@ Run on an RTX 4090 GPU (The origin mov not sped up!, [YouTube link](https://www

<img src="https://github.com/csunny/DB-GPT/blob/main/asserts/DB_QA.png" margin-left="auto" margin-right="auto" width="600">

# Dependencies
1. First you need to install the Python requirements.
```
python>=3.9
pip install -r requirements
```
Or, if you use a conda environment, you can use this command:
```
cd DB-GPT
conda env create -f environment.yml
```
2. MySQL Install

The examples in this project connect to MySQL and run SQL generation, so you need a local MySQL instance for testing. Docker is recommended:
```
docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa123456 -dit mysql:latest
```
The password is only for testing; change it if necessary.
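Once the container is up, a quick sanity check with pymysql (already listed in the project's requirements) might look like the sketch below; host, port and password mirror the docker command above and should be adjusted if you changed them:
```
# Illustrative connectivity check for the test MySQL container above.
import pymysql

conn = pymysql.connect(host="127.0.0.1", port=3306, user="root", password="aa123456")
try:
    with conn.cursor() as cursor:
        cursor.execute("SELECT VERSION()")
        print("MySQL version:", cursor.fetchone()[0])
finally:
    conn.close()
```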
# Install
1. Base model download

The base model can be obtained by following the [vicuna](https://github.com/lm-sys/FastChat/blob/main/README.md#model-weights) weight-merging guide.
@@ -52,3 +74,7 @@ python webserver.py
- SQL-diagnosis

Overall, it is a sophisticated and innovative AI tool for databases. If you have any specific questions about how to use or implement DB-GPT in your work, please contact me and I will do my best to help. Everyone is also welcome to contribute to the project and build something interesting.

# Licence
[MIT](https://github.com/csunny/DB-GPT/blob/main/LICENSE)
asserts/exeable.png (new binary file, 277 KiB, not shown)
environment.yml

@@ -61,3 +61,7 @@ dependencies:
  - gradio-client==0.0.8
  - wandb
  - fschat=0.1.10
  - llama-index=0.5.27
  - pymysql
  - unstructured==0.6.3
  - pytesseract==0.3.10
pilot/chain/audio.py (new file, 2 lines)

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
pilot/chain/visual.py (new file, 2 lines)

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
@@ -1,3 +1,3 @@
#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
+# -*- coding: utf-8 -*-
pilot/client/chart.py (new file, 2 lines)

@@ -0,0 +1,2 @@
#/usr/bin/env python3
# -*- coding: utf-8 -*-
pilot/configs/model_config.py

@@ -11,6 +11,7 @@ PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
LOGDIR = os.path.join(ROOT_PATH, "logs")
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
DATA_DIR = os.path.join(PILOT_PATH, "data")

nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path

@@ -22,6 +23,7 @@ LLM_MODEL_CONFIG = {
}

VECTOR_SEARCH_TOP_K = 3
LLM_MODEL = "vicuna-13b"
LIMIT_MODEL_CONCURRENCY = 5
MAX_POSITION_EMBEDDINGS = 2048
pilot/conversation.py

@@ -146,8 +146,24 @@ conv_vicuna_v1 = Conversation(
    sep2="</s>",
)


conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议:

已知内容:
{context}
问题:
{question}
"""

default_conversation = conv_one_shot

conversation_types = {
    "native": "LLM原生对话",
    "default_knownledge": "默认知识库对话",
    "custome": "新增知识库对话",
}

conv_templates = {
    "conv_one_shot": conv_one_shot,
    "vicuna_v1": conv_vicuna_v1,
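For illustration only (this snippet is not part of the commit), the new QA template can be rendered with langchain's PromptTemplate, which is how the knowledge-base QA code added later in this commit uses it:
```
# Illustrative: fill conv_qa_prompt_template with a retrieved context and a user question.
from langchain.prompts import PromptTemplate
from pilot.conversation import conv_qa_prompt_template

pt = PromptTemplate(template=conv_qa_prompt_template, input_variables=["context", "question"])
print(pt.format(context="OceanBase 是一个分布式关系型数据库。", question="什么是OceanBase?"))
```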
pilot/model/vicuna_llm.py

@@ -10,33 +10,29 @@ from typing import Any, Mapping, Optional, List
from langchain.llms.base import LLM
from pilot.configs.model_config import *

class VicunaRequestLLM(LLM):
class VicunaLLM(LLM):

    vicuna_generate_path = "generate_stream"
    def _call(self, prompt: str, temperature: float, max_new_tokens: int, stop: Optional[List[str]] = None) -> str:

    vicuna_generate_path = "generate"
    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        if isinstance(stop, list):
            stop = stop + ["Observation:"]

        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
        params = {
            "prompt": prompt,
            "temperature": 0.7,
            "max_new_tokens": 1024,
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "stop": stop
        }
        response = requests.post(
            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
            data=json.dumps(params),
        )
        response.raise_for_status()
        # for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        #     if chunk:
        #         data = json.loads(chunk.decode())
        #         if data["error_code"] == 0:
        #             output = data["text"][skip_echo_len:].strip()
        #             output = self.post_process_code(output)
        #             yield output
        return response.json()["response"]

        skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][skip_echo_len:].strip()
                    yield output

    @property
    def _llm_type(self) -> str:
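A minimal usage sketch for the reworked streaming call (assuming a vicuna worker is running behind VICUNA_MODEL_SERVER; the prompt and sampling values are illustrative, not from this commit):
```
# Illustrative only: consume the streaming _call shown above.
from pilot.model.vicuna_llm import VicunaLLM

llm = VicunaLLM()
answer = ""
for partial in llm._call("Explain what DB-GPT does in one sentence.", temperature=0.7, max_new_tokens=256):
    answer = partial  # each chunk carries the decoded answer so far
print(answer)
```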
pilot/pturning/lora/finetune.py (new file, 180 lines)

@@ -0,0 +1,180 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import json
import transformers
from transformers import LlamaTokenizer, LlamaForCausalLM

from typing import List
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
)

import torch
from datasets import load_dataset
import pandas as pd


from pilot.configs.model_config import DATA_DIR, LLM_MODEL, LLM_MODEL_CONFIG
device = "cuda" if torch.cuda.is_available() else "cpu"
CUTOFF_LEN = 50

df = pd.read_csv(os.path.join(DATA_DIR, "BTC_Tweets_Updated.csv"))

def sentiment_score_to_name(score: float):
    if score > 0:
        return "Positive"
    elif score < 0:
        return "Negative"
    return "Neutral"


dataset_data = [
    {
        "instruction": "Detect the sentiment of the tweet.",
        "input": row_dict["Tweet"],
        "output": sentiment_score_to_name(row_dict["New_Sentiment_State"])
    }
    for row_dict in df.to_dict(orient="records")
]

with open(os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"), "w") as f:
    json.dump(dataset_data, f)


data = load_dataset("json", data_files=os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"))
print(data["train"])

BASE_MODEL = LLM_MODEL_CONFIG[LLM_MODEL]
model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    device_map="auto",
    offload_folder=os.path.join(DATA_DIR, "vicuna-lora")
)

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token_id = (0)
tokenizer.padding_side = "left"

def generate_prompt(data_point):
    return f"""Blow is an instruction that describes a task, paired with an input that provide future context.
Write a response that appropriately completes the request. #noqa:

### Instruct:
{data_point["instruction"]}
### Input
{data_point["input"]}
### Response
{data_point["output"]}
"""

def tokenize(prompt, add_eos_token=True):
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=CUTOFF_LEN,
        padding=False,
        return_tensors=None,
    )

    if (result["input_ids"][-1] != tokenizer.eos_token_id and len(result["input_ids"]) < CUTOFF_LEN and add_eos_token):
        result["input_ids"].append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)

    result["labels"] = result["input_ids"].copy()
    return result

def generate_and_tokenize_prompt(data_point):
    full_prompt = generate_prompt(data_point)
    tokenized_full_prompt = tokenize(full_prompt)
    return tokenized_full_prompt


train_val = data["train"].train_test_split(
    test_size=200, shuffle=True, seed=42
)

train_data = (
    train_val["train"].map(generate_and_tokenize_prompt)
)

val_data = (
    train_val["test"].map(generate_and_tokenize_prompt)
)

# Training
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = [
    "q_proj",
    "v_proj",
]

BATCH_SIZE = 128
MICRO_BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
LEARNING_RATE = 3e-4
TRAIN_STEPS = 300
OUTPUT_DIR = "experiments"

# We can now prepare model for training
model = prepare_model_for_int8_training(model)
config = LoraConfig(
    r = LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=LORA_TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
model.print_trainable_parameters()

training_arguments = transformers.TrainingArguments(
    per_device_train_batch_size=MICRO_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    warmup_steps=100,
    max_steps=TRAIN_STEPS,
    no_cuda=True,
    learning_rate=LEARNING_RATE,
    logging_steps=10,
    optim="adamw_torch",
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=50,
    save_steps=50,
    output_dir=OUTPUT_DIR,
    save_total_limit=3,
    load_best_model_at_end=True,
    report_to="tensorboard"
)

data_collector = transformers.DataCollatorForSeq2Seq(
    tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
)

trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=training_arguments,
    data_collector=data_collector
)

model.config.use_cache = False
old_state_dict = model.state_dict
model.state_dict = (
    lambda self, *_, **__: get_peft_model_state_dict(
        self, old_state_dict()
    )
).__get__(model, type(model))

trainer.train()
model.save_pretrained(OUTPUT_DIR)
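After training, the adapter written to OUTPUT_DIR can be reloaded on top of the base model for inference. A rough sketch using peft's PeftModel (not part of this commit; paths and the prompt are placeholders):
```
# Illustrative sketch: reload the LoRA adapter saved by the script above.
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

BASE_MODEL = "path/to/vicuna-13b"  # hypothetical path; finetune.py reads it from LLM_MODEL_CONFIG
OUTPUT_DIR = "experiments"         # same directory the script saves to

base = LlamaForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float16, device_map="auto")
model = PeftModel.from_pretrained(base, OUTPUT_DIR)  # attach the trained LoRA weights
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

prompt = "### Instruct:\nDetect the sentiment of the tweet.\n### Input\nBitcoin is going up!\n### Response\n"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```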
@@ -4,29 +4,44 @@
import requests
import json
import time
import uuid
from urllib.parse import urljoin
import gradio as gr
from pilot.configs.model_config import *
vicuna_base_uri = "http://192.168.31.114:21002/"
vicuna_stream_path = "worker_generate_stream"
vicuna_status_path = "worker_get_status"
from pilot.conversation import conv_qa_prompt_template, conv_templates
from langchain.prompts import PromptTemplate

def generate(prompt):
vicuna_stream_path = "generate_stream"

def generate(query):

    template_name = "conv_one_shot"
    state = conv_templates[template_name].copy()

    pt = PromptTemplate(
        template=conv_qa_prompt_template,
        input_variables=["context", "question"]
    )

    result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
                       question=query)

    print(result)

    state.append_message(state.roles[0], result)
    state.append_message(state.roles[1], None)

    prompt = state.get_prompt()
    params = {
        "model": "vicuna-13b",
        "prompt": prompt,
        "temperature": 0.7,
        "max_new_tokens": 512,
        "max_new_tokens": 1024,
        "stop": "###"
    }

    sts_response = requests.post(
        url=urljoin(vicuna_base_uri, vicuna_status_path)
    )
    print(sts_response.text)

    response = requests.post(
        url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
        url=urljoin(VICUNA_MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
    )

    skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3

@@ -34,11 +49,10 @@ def generate(prompt):
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] == 0:
                output = data["text"]
                output = data["text"][skip_echo_len:].strip()
                state.messages[-1][-1] = output + "▌"
                yield(output)

            time.sleep(0.02)


if __name__ == "__main__":
    print(LLM_MODEL)
    with gr.Blocks() as demo:
pilot/server/vectordb_qa.py (new file, 29 lines)

@@ -0,0 +1,29 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from pilot.vector_store.file_loader import KnownLedge2Vector
from langchain.prompts import PromptTemplate
from pilot.conversation import conv_qa_prompt_template
from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
from pilot.model.vicuna_llm import VicunaLLM

class KnownLedgeBaseQA:

    def __init__(self) -> None:
        k2v = KnownLedge2Vector()
        self.vector_store = k2v.init_vector_store()
        self.llm = VicunaLLM()

    def get_similar_answer(self, query):

        prompt = PromptTemplate(
            template=conv_qa_prompt_template,
            input_variables=["context", "question"]
        )

        retriever = self.vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K})
        docs = retriever.get_relevant_documents(query=query)

        context = [d.page_content for d in docs]
        result = prompt.format(context="\n".join(context), question=query)
        return result
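A rough end-to-end sketch of how the pieces added in this commit could be combined (assumes the knowledge vector store can be built and the vicuna worker is reachable; not part of the commit):
```
# Illustrative only: retrieve context, fill the QA template, then stream an answer.
from pilot.server.vectordb_qa import KnownLedgeBaseQA

qa = KnownLedgeBaseQA()
question = "How do I use Chroma with LangChain?"
prompt = qa.get_similar_answer(question)   # top-k retrieval + conv_qa_prompt_template
answer = ""
for partial in qa.llm._call(prompt, temperature=0.7, max_new_tokens=512):
    answer = partial
print(answer)
```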
pilot/server/webserver.py

@@ -11,6 +11,7 @@ import datetime
import requests
from urllib.parse import urljoin
from pilot.configs.model_config import DB_SETTINGS
from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.connections.mysql_conn import MySQLOperator
from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st

@@ -19,6 +20,7 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
from pilot.conversation import (
    default_conversation,
    conv_templates,
    conversation_types,
    SeparatorStyle
)

@@ -149,7 +151,7 @@ def post_process_code(code):
        code = sep.join(blocks)
    return code

def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
    start_tstamp = time.time()
    model_name = LLM_MODEL

@@ -170,7 +172,8 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques

        query = state.messages[-2][1]

        # Add context hints to the prompt
        # Add context hints to the prompt for knowledge-grounded dialogue. Should the context hint only go in the first round, or be added in every round?
        # If the user's questions vary widely, the hint should be added in every round.
        if db_selector:
            new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
            new_state.append_message(new_state.roles[1], None)
@@ -180,13 +183,11 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
            new_state.append_message(new_state.roles[1], None)
        state = new_state

    # try:
    #     if not db_selector:
    #         sim_q = get_simlar(query)
    #         print("********vector similar info*************: ", sim_q)
    #         state.append_message(new_state.roles[0], sim_q + query)
    # except Exception as e:
    #     print(e)
    if mode == conversation_types["default_knownledge"] and not db_selector:
        query = state.messages[-2][1]
        knqa = KnownLedgeBaseQA()
        state.messages[-2][1] = knqa.get_similar_answer(query)

    prompt = state.get_prompt()

@@ -222,7 +223,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                    return
                time.sleep(0.02)

    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
@@ -231,6 +232,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    # Record the run log
    finish_tstamp = time.time()
    logger.info(f"{output}")

@@ -266,7 +268,7 @@ def change_tab(tab):
    pass

def change_mode(mode):
    if mode == "默认知识库对话":
    if mode in ["默认知识库对话", "LLM原生对话"]:
        return gr.update(visible=False)
    else:
        return gr.update(visible=True)
@@ -281,7 +283,7 @@ def build_single_model_ui():
    """
    learn_more_markdown = """
    ### Licence
    The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA
    The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of Vicuna-13B
    """

    state = gr.State()
@@ -318,7 +320,8 @@ def build_single_model_ui():
                    show_label=True).style(container=False)

            with gr.TabItem("知识问答", elem_id="QA"):
                mode = gr.Radio(["默认知识库对话", "新增知识库"], show_label=False, value="默认知识库对话")

                mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
                vs_setting = gr.Accordion("配置知识库", open=False)
                mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
                with vs_setting:
@@ -363,7 +366,7 @@ def build_single_model_ui():
    btn_list = [regenerate_btn, clear_btn]
    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
        http_bot,
        [state, db_selector, temperature, max_output_tokens],
        [state, mode, db_selector, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@@ -372,7 +375,7 @@ def build_single_model_ui():
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, db_selector, temperature, max_output_tokens],
        [state, mode, db_selector, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )

@@ -380,7 +383,7 @@ def build_single_model_ui():
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, db_selector, temperature, max_output_tokens],
        [state, mode, db_selector, temperature, max_output_tokens],
        [state, chatbot] + btn_list
    )
pilot/vector_store/file_loader.py

@@ -10,9 +10,8 @@ from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader, TextLoader
from langchain.chains import VectorDBQA
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG
from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K

VECTOR_SEARCH_TOP_K = 5

class KnownLedge2Vector:

@@ -26,21 +25,21 @@ class KnownLedge2Vector:
        self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)

    def init_vector_store(self):
        documents = self.load_knownlege()
        persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
        print("向量数据库持久化地址: ", persist_dir)
        if os.path.exists(persist_dir):
            # Load from the locally persisted vector files
            print("从本地向量加载数据...")
            vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
            vector_store.add_documents(documents=documents)
            # vector_store.add_documents(documents=documents)
        else:
            documents = self.load_knownlege()
            # Re-initialize the store
            vector_store = Chroma.from_documents(documents=documents,
                                                 embedding=self.embeddings,
                                                 persist_directory=persist_dir)
            vector_store.persist()
        vector_store = None
        return persist_dir
        return vector_store

    def load_knownlege(self):
        docments = []
@@ -70,10 +69,23 @@ class KnownLedge2Vector:
        return docs

    def _load_from_url(self, url):
        """Load data from url address"""
        pass

    def query(self, q):
        """Query similar doc from Vector """
        vector_store = self.init_vector_store()
        docs = vector_store.similarity_search_with_score(q, k=self.top_k)
        for doc in docs:
            dc, s = doc
            yield s, dc

if __name__ == "__main__":
    k2v = KnownLedge2Vector()
    k2v.init_vector_store()

    persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
    print(persist_dir)
    for s, dc in k2v.query("什么是OceanBase"):
        print(s, dc.page_content, dc.metadata)
@@ -52,4 +52,6 @@ gradio-client==0.0.8
wandb
fschat=0.1.10
llama-index=0.5.27
pymysql
pymysql
unstructured==0.6.3
pytesseract==0.3.10