diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py index 1f4597789..73c732713 100644 --- a/pilot/scene/chat_db/auto_execute/chat.py +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -47,12 +47,13 @@ class ChatWithDbAutoExecute(BaseChat): from pilot.summary.db_summary_client import DBSummaryClient except ImportError: raise ValueError("Could not import DBSummaryClient. ") + client = DBSummaryClient() input_values = { "input": self.current_user_input, "top_k": str(self.top_k), "dialect": self.database.dialect, "table_info": self.database.table_simple_info(self.db_connect) - # "table_info": DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) + # "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) } return input_values diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py index 66b751533..faffcc146 100644 --- a/pilot/scene/chat_db/professional_qa/chat.py +++ b/pilot/scene/chat_db/professional_qa/chat.py @@ -45,7 +45,8 @@ class ChatWithDbQA(BaseChat): except ImportError: raise ValueError("Could not import DBSummaryClient. ") if self.db_name: - table_info = DBSummaryClient.get_similar_tables( + client = DBSummaryClient() + table_info = client.get_similar_tables( dbname=self.db_name, query=self.current_user_input, topk=self.top_k ) # table_info = self.database.table_simple_info(self.db_connect) diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 239fc5d9e..f8626f7b4 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import threading import traceback import argparse import datetime @@ -414,7 +415,7 @@ def build_single_model_ui(): show_label=True, ).style(container=False) - db_selector.change(fn=db_selector_changed, inputs=db_selector) + # db_selector.change(fn=db_selector_changed, inputs=db_selector) sql_mode = gr.Radio( [ @@ -618,10 +619,6 @@ def save_vs_name(vs_name): return vs_name -def db_selector_changed(dbname): - DBSummaryClient.db_summary_embedding(dbname) - - def knowledge_embedding_store(vs_id, files): # vs_path = os.path.join(VS_ROOT_PATH, vs_id) if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)): @@ -645,6 +642,12 @@ def knowledge_embedding_store(vs_id, files): return vs_id +def async_db_summery(): + client = DBSummaryClient() + thread = threading.Thread(target=client.init_db_summary) + thread.start() + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="0.0.0.0") @@ -661,7 +664,7 @@ if __name__ == "__main__": cfg = Config() dbs = cfg.local_db.get_database_list() - + async_db_summery() cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode)) # 加载插件可执行命令 diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py index 1e072c861..27297111b 100644 --- a/pilot/source_embedding/knowledge_embedding.py +++ b/pilot/source_embedding/knowledge_embedding.py @@ -1,14 +1,8 @@ -import os from typing import Optional -import markdown -from bs4 import BeautifulSoup -from langchain.document_loaders import PyPDFLoader, TextLoader from langchain.embeddings import HuggingFaceEmbeddings from pilot.configs.config import Config -from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE -from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter from pilot.source_embedding.csv_embedding import CSVEmbedding from pilot.source_embedding.markdown_embedding import MarkdownEmbedding from pilot.source_embedding.pdf_embedding import PDFEmbedding @@ -82,61 +76,3 @@ class KnowledgeEmbedding: CFG.VECTOR_STORE_TYPE, self.vector_store_config ) return vector_client.vector_name_exists() - - def knowledge_persist_initialization(self, append_mode): - documents = self._load_knownlege(self.file_path) - self.vector_client = VectorStoreConnector( - CFG.VECTOR_STORE_TYPE, self.vector_store_config - ) - self.vector_client.load_document(documents) - return self.vector_client - - def _load_knownlege(self, path): - docments = [] - for root, _, files in os.walk(path, topdown=False): - for file in files: - filename = os.path.join(root, file) - docs = self._load_file(filename) - new_docs = [] - for doc in docs: - doc.metadata = { - "source": doc.metadata["source"].replace(DATASETS_DIR, "") - } - print("doc is embedding...", doc.metadata) - new_docs.append(doc) - docments += new_docs - return docments - - def _load_file(self, filename): - if filename.lower().endswith(".md"): - loader = TextLoader(filename) - text_splitter = CHNDocumentSplitter( - pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE - ) - docs = loader.load_and_split(text_splitter) - i = 0 - for d in docs: - content = markdown.markdown(d.page_content) - soup = BeautifulSoup(content, "html.parser") - for tag in soup(["!doctype", "meta", "i.fa"]): - tag.extract() - docs[i].page_content = soup.get_text() - docs[i].page_content = docs[i].page_content.replace("\n", " ") - i += 1 - elif filename.lower().endswith(".pdf"): - loader = PyPDFLoader(filename) - textsplitter = CHNDocumentSplitter( - pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE - ) - docs = loader.load_and_split(textsplitter) - i = 0 - for d in docs: - docs[i].page_content = d.page_content.replace("\n", " ").replace( - "�", "" - ) - i += 1 - else: - loader = TextLoader(filename) - text_splitor = CHNDocumentSplitter(sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE) - docs = loader.load_and_split(text_splitor) - return docs diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index 3dfbede72..51f124f62 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -21,8 +21,10 @@ class DBSummaryClient: , get_similar_tables method(get user query related tables info) """ - @staticmethod - def db_summary_embedding(dbname): + def __init__(self): + pass + + def db_summary_embedding(self, dbname): """put db profile and table profile summary into vector store""" if CFG.LOCAL_DB_HOST is not None and CFG.LOCAL_DB_PORT is not None: db_summary_client = MysqlSummary(dbname) @@ -56,7 +58,7 @@ class DBSummaryClient: table_summary, ) in db_summary_client.get_table_summary().items(): table_vector_store_config = { - "vector_store_name": table_name + "_ts", + "vector_store_name": dbname + "_" + table_name + "_ts", "embeddings": embeddings, } embedding = StringEmbedding( @@ -67,8 +69,7 @@ class DBSummaryClient: logger.info("db summary embedding success") - @staticmethod - def get_similar_tables(dbname, query, topk): + def get_similar_tables(self, dbname, query, topk): """get user query related tables info""" vector_store_config = { "vector_store_name": dbname + "_profile", @@ -94,7 +95,7 @@ class DBSummaryClient: related_table_summaries = [] for table in related_tables: vector_store_config = { - "vector_store_name": table + "_ts", + "vector_store_name": dbname + "_" + table + "_ts", } knowledge_embedding_client = KnowledgeEmbedding( file_path="", @@ -105,6 +106,12 @@ class DBSummaryClient: related_table_summaries.append(table_summery[0].page_content) return related_table_summaries + def init_db_summary(self): + db = CFG.local_db + dbs = db.get_database_list() + for dbname in dbs: + self.db_summary_embedding(dbname) + def _get_llm_response(query, db_input, dbsummary): chat_param = {