mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-05 02:20:08 +00:00
load base knownledge
This commit is contained in:
parent
d9f5130db4
commit
529f077409
@ -8,7 +8,7 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
|
|||||||
MODEL_PATH = os.path.join(ROOT_PATH, "models")
|
MODEL_PATH = os.path.join(ROOT_PATH, "models")
|
||||||
VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
|
VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
|
||||||
LOGDIR = os.path.join(ROOT_PATH, "logs")
|
LOGDIR = os.path.join(ROOT_PATH, "logs")
|
||||||
|
DATASETS_DIR = os.path.join(ROOT_PATH, "datasets")
|
||||||
|
|
||||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
LLM_MODEL_CONFIG = {
|
LLM_MODEL_CONFIG = {
|
||||||
|
@ -12,7 +12,7 @@ import requests
|
|||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
from pilot.configs.model_config import DB_SETTINGS
|
from pilot.configs.model_config import DB_SETTINGS
|
||||||
from pilot.connections.mysql_conn import MySQLOperator
|
from pilot.connections.mysql_conn import MySQLOperator
|
||||||
from pilot.vector_store.extract_tovec import get_vector_storelist
|
from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc
|
||||||
|
|
||||||
from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL
|
from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL
|
||||||
|
|
||||||
@ -48,6 +48,16 @@ priority = {
|
|||||||
"vicuna-13b": "aaa"
|
"vicuna-13b": "aaa"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_simlar(q):
|
||||||
|
|
||||||
|
docsearch = load_knownledge_from_doc()
|
||||||
|
docs = docsearch.similarity_search_with_score(q, k=1)
|
||||||
|
|
||||||
|
contents = [dc.page_content for dc, _ in docs]
|
||||||
|
return "\n".join(contents)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def gen_sqlgen_conversation(dbname):
|
def gen_sqlgen_conversation(dbname):
|
||||||
mo = MySQLOperator(
|
mo = MySQLOperator(
|
||||||
**DB_SETTINGS
|
**DB_SETTINGS
|
||||||
@ -150,6 +160,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
|
|||||||
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
||||||
return
|
return
|
||||||
|
|
||||||
|
query = state.messages[-2][1]
|
||||||
if len(state.messages) == state.offset + 2:
|
if len(state.messages) == state.offset + 2:
|
||||||
# 第一轮对话需要加入提示Prompt
|
# 第一轮对话需要加入提示Prompt
|
||||||
|
|
||||||
@ -158,10 +169,14 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
|
|||||||
new_state.conv_id = uuid.uuid4().hex
|
new_state.conv_id = uuid.uuid4().hex
|
||||||
|
|
||||||
# prompt 中添加上下文提示
|
# prompt 中添加上下文提示
|
||||||
new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + state.messages[-2][1])
|
if db_selector:
|
||||||
|
new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
|
||||||
|
|
||||||
new_state.append_message(new_state.roles[1], None)
|
new_state.append_message(new_state.roles[1], None)
|
||||||
state = new_state
|
state = new_state
|
||||||
|
|
||||||
|
if not db_selector:
|
||||||
|
state.append_message(new_state.roles[0], get_simlar(query) + query)
|
||||||
|
|
||||||
prompt = state.get_prompt()
|
prompt = state.get_prompt()
|
||||||
|
|
||||||
@ -237,6 +252,9 @@ pre {
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def change_tab(tab):
|
||||||
|
pass
|
||||||
|
|
||||||
def change_mode(mode):
|
def change_mode(mode):
|
||||||
if mode == "默认知识库对话":
|
if mode == "默认知识库对话":
|
||||||
return gr.update(visible=False)
|
return gr.update(visible=False)
|
||||||
@ -256,7 +274,6 @@ def build_single_model_ui():
|
|||||||
The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA
|
The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA
|
||||||
"""
|
"""
|
||||||
|
|
||||||
vs_path, file_status, vs_list = gr.State(""), gr.State(""), gr.State()
|
|
||||||
state = gr.State()
|
state = gr.State()
|
||||||
gr.Markdown(notice_markdown, elem_id="notice_markdown")
|
gr.Markdown(notice_markdown, elem_id="notice_markdown")
|
||||||
|
|
||||||
@ -291,7 +308,6 @@ def build_single_model_ui():
|
|||||||
show_label=True).style(container=False)
|
show_label=True).style(container=False)
|
||||||
|
|
||||||
with gr.TabItem("知识问答", elem_id="QA"):
|
with gr.TabItem("知识问答", elem_id="QA"):
|
||||||
|
|
||||||
mode = gr.Radio(["默认知识库对话", "新增知识库"], show_label=False, value="默认知识库对话")
|
mode = gr.Radio(["默认知识库对话", "新增知识库"], show_label=False, value="默认知识库对话")
|
||||||
vs_setting = gr.Accordion("配置知识库", open=False)
|
vs_setting = gr.Accordion("配置知识库", open=False)
|
||||||
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
|
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
|
||||||
@ -331,9 +347,6 @@ def build_single_model_ui():
|
|||||||
regenerate_btn = gr.Button(value="重新生成", interactive=False)
|
regenerate_btn = gr.Button(value="重新生成", interactive=False)
|
||||||
clear_btn = gr.Button(value="清理", interactive=False)
|
clear_btn = gr.Button(value="清理", interactive=False)
|
||||||
|
|
||||||
# QA 模式下清空数据库选项
|
|
||||||
if tabs.elem_id == "QA":
|
|
||||||
db_selector = ""
|
|
||||||
|
|
||||||
gr.Markdown(learn_more_markdown)
|
gr.Markdown(learn_more_markdown)
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ import os
|
|||||||
from langchain.text_splitter import CharacterTextSplitter
|
from langchain.text_splitter import CharacterTextSplitter
|
||||||
from langchain.vectorstores import Chroma
|
from langchain.vectorstores import Chroma
|
||||||
from pilot.model.vicuna_llm import VicunaEmbeddingLLM
|
from pilot.model.vicuna_llm import VicunaEmbeddingLLM
|
||||||
from pilot.configs.model_config import VECTORE_PATH
|
from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR
|
||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
|
|
||||||
embeddings = VicunaEmbeddingLLM()
|
embeddings = VicunaEmbeddingLLM()
|
||||||
@ -38,6 +38,33 @@ def knownledge_tovec_st(filename):
|
|||||||
docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])
|
docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])
|
||||||
return docsearch
|
return docsearch
|
||||||
|
|
||||||
|
|
||||||
|
def load_knownledge_from_doc():
|
||||||
|
"""从数据集当中加载知识
|
||||||
|
# TODO 如果向量存储已经存在, 则无需初始化
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not os.path.exists(DATASETS_DIR):
|
||||||
|
print("Not Exists Local DataSets, We will answers the Question use model default.")
|
||||||
|
|
||||||
|
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
||||||
|
embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
|
||||||
|
|
||||||
|
docs = []
|
||||||
|
files = os.listdir(DATASETS_DIR)
|
||||||
|
for file in files:
|
||||||
|
if not os.path.isdir(file):
|
||||||
|
with open(file, "r") as f:
|
||||||
|
doc = f.read()
|
||||||
|
docs.append(docs)
|
||||||
|
|
||||||
|
print(doc)
|
||||||
|
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_owerlap=0)
|
||||||
|
texts = text_splitter.split_text("\n".join(docs))
|
||||||
|
docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))],
|
||||||
|
persist_directory=os.path.join(VECTORE_PATH, ".vectore"))
|
||||||
|
return docsearch
|
||||||
|
|
||||||
def get_vector_storelist():
|
def get_vector_storelist():
|
||||||
if not os.path.exists(VECTORE_PATH):
|
if not os.path.exists(VECTORE_PATH):
|
||||||
return []
|
return []
|
||||||
|
Loading…
Reference in New Issue
Block a user