mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-12 05:32:32 +00:00
feature:db_summary bootstrap load
This commit is contained in:
parent
e29fa37cde
commit
4b41842277
@ -47,12 +47,13 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
from pilot.summary.db_summary_client import DBSummaryClient
|
from pilot.summary.db_summary_client import DBSummaryClient
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ValueError("Could not import DBSummaryClient. ")
|
raise ValueError("Could not import DBSummaryClient. ")
|
||||||
|
client = DBSummaryClient()
|
||||||
input_values = {
|
input_values = {
|
||||||
"input": self.current_user_input,
|
"input": self.current_user_input,
|
||||||
"top_k": str(self.top_k),
|
"top_k": str(self.top_k),
|
||||||
"dialect": self.database.dialect,
|
"dialect": self.database.dialect,
|
||||||
"table_info": self.database.table_simple_info(self.db_connect)
|
"table_info": self.database.table_simple_info(self.db_connect)
|
||||||
# "table_info": DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
|
# "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
|
||||||
}
|
}
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
|
@ -45,7 +45,8 @@ class ChatWithDbQA(BaseChat):
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
raise ValueError("Could not import DBSummaryClient. ")
|
raise ValueError("Could not import DBSummaryClient. ")
|
||||||
if self.db_name:
|
if self.db_name:
|
||||||
table_info = DBSummaryClient.get_similar_tables(
|
client = DBSummaryClient()
|
||||||
|
table_info = client.get_similar_tables(
|
||||||
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
|
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
|
||||||
)
|
)
|
||||||
# table_info = self.database.table_simple_info(self.db_connect)
|
# table_info = self.database.table_simple_info(self.db_connect)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
import threading
|
||||||
import traceback
|
import traceback
|
||||||
import argparse
|
import argparse
|
||||||
import datetime
|
import datetime
|
||||||
@ -414,7 +415,7 @@ def build_single_model_ui():
|
|||||||
show_label=True,
|
show_label=True,
|
||||||
).style(container=False)
|
).style(container=False)
|
||||||
|
|
||||||
db_selector.change(fn=db_selector_changed, inputs=db_selector)
|
# db_selector.change(fn=db_selector_changed, inputs=db_selector)
|
||||||
|
|
||||||
sql_mode = gr.Radio(
|
sql_mode = gr.Radio(
|
||||||
[
|
[
|
||||||
@ -618,10 +619,6 @@ def save_vs_name(vs_name):
|
|||||||
return vs_name
|
return vs_name
|
||||||
|
|
||||||
|
|
||||||
def db_selector_changed(dbname):
|
|
||||||
DBSummaryClient.db_summary_embedding(dbname)
|
|
||||||
|
|
||||||
|
|
||||||
def knowledge_embedding_store(vs_id, files):
|
def knowledge_embedding_store(vs_id, files):
|
||||||
# vs_path = os.path.join(VS_ROOT_PATH, vs_id)
|
# vs_path = os.path.join(VS_ROOT_PATH, vs_id)
|
||||||
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)):
|
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)):
|
||||||
@ -645,6 +642,12 @@ def knowledge_embedding_store(vs_id, files):
|
|||||||
return vs_id
|
return vs_id
|
||||||
|
|
||||||
|
|
||||||
|
def async_db_summery():
|
||||||
|
client = DBSummaryClient()
|
||||||
|
thread = threading.Thread(target=client.init_db_summary)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||||
@ -661,7 +664,7 @@ if __name__ == "__main__":
|
|||||||
cfg = Config()
|
cfg = Config()
|
||||||
|
|
||||||
dbs = cfg.local_db.get_database_list()
|
dbs = cfg.local_db.get_database_list()
|
||||||
|
async_db_summery()
|
||||||
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
|
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
|
||||||
|
|
||||||
# 加载插件可执行命令
|
# 加载插件可执行命令
|
||||||
|
@ -1,14 +1,8 @@
|
|||||||
import os
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import markdown
|
|
||||||
from bs4 import BeautifulSoup
|
|
||||||
from langchain.document_loaders import PyPDFLoader, TextLoader
|
|
||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE
|
|
||||||
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
|
|
||||||
from pilot.source_embedding.csv_embedding import CSVEmbedding
|
from pilot.source_embedding.csv_embedding import CSVEmbedding
|
||||||
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
|
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
|
||||||
from pilot.source_embedding.pdf_embedding import PDFEmbedding
|
from pilot.source_embedding.pdf_embedding import PDFEmbedding
|
||||||
@ -82,61 +76,3 @@ class KnowledgeEmbedding:
|
|||||||
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
||||||
)
|
)
|
||||||
return vector_client.vector_name_exists()
|
return vector_client.vector_name_exists()
|
||||||
|
|
||||||
def knowledge_persist_initialization(self, append_mode):
|
|
||||||
documents = self._load_knownlege(self.file_path)
|
|
||||||
self.vector_client = VectorStoreConnector(
|
|
||||||
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
|
||||||
)
|
|
||||||
self.vector_client.load_document(documents)
|
|
||||||
return self.vector_client
|
|
||||||
|
|
||||||
def _load_knownlege(self, path):
|
|
||||||
docments = []
|
|
||||||
for root, _, files in os.walk(path, topdown=False):
|
|
||||||
for file in files:
|
|
||||||
filename = os.path.join(root, file)
|
|
||||||
docs = self._load_file(filename)
|
|
||||||
new_docs = []
|
|
||||||
for doc in docs:
|
|
||||||
doc.metadata = {
|
|
||||||
"source": doc.metadata["source"].replace(DATASETS_DIR, "")
|
|
||||||
}
|
|
||||||
print("doc is embedding...", doc.metadata)
|
|
||||||
new_docs.append(doc)
|
|
||||||
docments += new_docs
|
|
||||||
return docments
|
|
||||||
|
|
||||||
def _load_file(self, filename):
|
|
||||||
if filename.lower().endswith(".md"):
|
|
||||||
loader = TextLoader(filename)
|
|
||||||
text_splitter = CHNDocumentSplitter(
|
|
||||||
pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
|
|
||||||
)
|
|
||||||
docs = loader.load_and_split(text_splitter)
|
|
||||||
i = 0
|
|
||||||
for d in docs:
|
|
||||||
content = markdown.markdown(d.page_content)
|
|
||||||
soup = BeautifulSoup(content, "html.parser")
|
|
||||||
for tag in soup(["!doctype", "meta", "i.fa"]):
|
|
||||||
tag.extract()
|
|
||||||
docs[i].page_content = soup.get_text()
|
|
||||||
docs[i].page_content = docs[i].page_content.replace("\n", " ")
|
|
||||||
i += 1
|
|
||||||
elif filename.lower().endswith(".pdf"):
|
|
||||||
loader = PyPDFLoader(filename)
|
|
||||||
textsplitter = CHNDocumentSplitter(
|
|
||||||
pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
|
|
||||||
)
|
|
||||||
docs = loader.load_and_split(textsplitter)
|
|
||||||
i = 0
|
|
||||||
for d in docs:
|
|
||||||
docs[i].page_content = d.page_content.replace("\n", " ").replace(
|
|
||||||
"<EFBFBD>", ""
|
|
||||||
)
|
|
||||||
i += 1
|
|
||||||
else:
|
|
||||||
loader = TextLoader(filename)
|
|
||||||
text_splitor = CHNDocumentSplitter(sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE)
|
|
||||||
docs = loader.load_and_split(text_splitor)
|
|
||||||
return docs
|
|
||||||
|
@ -21,8 +21,10 @@ class DBSummaryClient:
|
|||||||
, get_similar_tables method(get user query related tables info)
|
, get_similar_tables method(get user query related tables info)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
def __init__(self):
|
||||||
def db_summary_embedding(dbname):
|
pass
|
||||||
|
|
||||||
|
def db_summary_embedding(self, dbname):
|
||||||
"""put db profile and table profile summary into vector store"""
|
"""put db profile and table profile summary into vector store"""
|
||||||
if CFG.LOCAL_DB_HOST is not None and CFG.LOCAL_DB_PORT is not None:
|
if CFG.LOCAL_DB_HOST is not None and CFG.LOCAL_DB_PORT is not None:
|
||||||
db_summary_client = MysqlSummary(dbname)
|
db_summary_client = MysqlSummary(dbname)
|
||||||
@ -56,7 +58,7 @@ class DBSummaryClient:
|
|||||||
table_summary,
|
table_summary,
|
||||||
) in db_summary_client.get_table_summary().items():
|
) in db_summary_client.get_table_summary().items():
|
||||||
table_vector_store_config = {
|
table_vector_store_config = {
|
||||||
"vector_store_name": table_name + "_ts",
|
"vector_store_name": dbname + "_" + table_name + "_ts",
|
||||||
"embeddings": embeddings,
|
"embeddings": embeddings,
|
||||||
}
|
}
|
||||||
embedding = StringEmbedding(
|
embedding = StringEmbedding(
|
||||||
@ -67,8 +69,7 @@ class DBSummaryClient:
|
|||||||
|
|
||||||
logger.info("db summary embedding success")
|
logger.info("db summary embedding success")
|
||||||
|
|
||||||
@staticmethod
|
def get_similar_tables(self, dbname, query, topk):
|
||||||
def get_similar_tables(dbname, query, topk):
|
|
||||||
"""get user query related tables info"""
|
"""get user query related tables info"""
|
||||||
vector_store_config = {
|
vector_store_config = {
|
||||||
"vector_store_name": dbname + "_profile",
|
"vector_store_name": dbname + "_profile",
|
||||||
@ -94,7 +95,7 @@ class DBSummaryClient:
|
|||||||
related_table_summaries = []
|
related_table_summaries = []
|
||||||
for table in related_tables:
|
for table in related_tables:
|
||||||
vector_store_config = {
|
vector_store_config = {
|
||||||
"vector_store_name": table + "_ts",
|
"vector_store_name": dbname + "_" + table + "_ts",
|
||||||
}
|
}
|
||||||
knowledge_embedding_client = KnowledgeEmbedding(
|
knowledge_embedding_client = KnowledgeEmbedding(
|
||||||
file_path="",
|
file_path="",
|
||||||
@ -105,6 +106,12 @@ class DBSummaryClient:
|
|||||||
related_table_summaries.append(table_summery[0].page_content)
|
related_table_summaries.append(table_summery[0].page_content)
|
||||||
return related_table_summaries
|
return related_table_summaries
|
||||||
|
|
||||||
|
def init_db_summary(self):
|
||||||
|
db = CFG.local_db
|
||||||
|
dbs = db.get_database_list()
|
||||||
|
for dbname in dbs:
|
||||||
|
self.db_summary_embedding(dbname)
|
||||||
|
|
||||||
|
|
||||||
def _get_llm_response(query, db_input, dbsummary):
|
def _get_llm_response(query, db_input, dbsummary):
|
||||||
chat_param = {
|
chat_param = {
|
||||||
|
Loading…
Reference in New Issue
Block a user