From a2f087c429c22b59640f1d6d3f8ad53a3b316880 Mon Sep 17 00:00:00 2001 From: Aries-ckt <916701291@qq.com> Date: Mon, 4 Dec 2023 22:16:28 +0800 Subject: [PATCH] refactor(ChatData):update rdbms db summary (#885) --- .../connections/manages/connection_manager.py | 67 ------ pilot/summary/db_summary_client.py | 118 +---------- pilot/summary/rdbms_db_summary.py | 193 ++++-------------- 3 files changed, 48 insertions(+), 330 deletions(-) diff --git a/pilot/connections/manages/connection_manager.py b/pilot/connections/manages/connection_manager.py index 261f9b5cf..b7f4c749d 100644 --- a/pilot/connections/manages/connection_manager.py +++ b/pilot/connections/manages/connection_manager.py @@ -60,73 +60,6 @@ class ConnectManager: # self.storage = DuckdbConnectConfig() self.storage = ConnectConfigDao() self.db_summary_client = DBSummaryClient(system_app) - # self.__load_config_db() - - # def __load_config_db(self): - # if CFG.LOCAL_DB_HOST: - # # default mysql - # if CFG.LOCAL_DB_NAME: - # self.storage.add_url_db( - # CFG.LOCAL_DB_NAME, - # DBType.Mysql.value(), - # CFG.LOCAL_DB_HOST, - # CFG.LOCAL_DB_PORT, - # CFG.LOCAL_DB_USER, - # CFG.LOCAL_DB_PASSWORD, - # "", - # ) - # else: - # # get all default mysql database - # default_mysql = Database.from_uri( - # "mysql+pymysql://" - # + CFG.LOCAL_DB_USER - # + ":" - # + CFG.LOCAL_DB_PASSWORD - # + "@" - # + CFG.LOCAL_DB_HOST - # + ":" - # + str(CFG.LOCAL_DB_PORT), - # engine_args={ - # "pool_size": CFG.LOCAL_DB_POOL_SIZE, - # "pool_recycle": 3600, - # "echo": True, - # }, - # ) - # dbs = default_mysql.get_database_list() - # for name in dbs: - # self.storage.add_url_db( - # name, - # DBType.Mysql.value(), - # CFG.LOCAL_DB_HOST, - # CFG.LOCAL_DB_PORT, - # CFG.LOCAL_DB_USER, - # CFG.LOCAL_DB_PASSWORD, - # "", - # ) - # db_type = DBType.of_db_type(CFG.LOCAL_DB_TYPE) - # if db_type.is_file_db(): - # db_name = CFG.LOCAL_DB_NAME - # db_type = CFG.LOCAL_DB_TYPE - # db_path = CFG.LOCAL_DB_PATH - # if not db_type: - # # Default file database type - # db_type = DBType.DuckDb.value() - # if not db_name: - # db_type, db_name = self._parse_file_db_info(db_type, db_path) - # if db_name: - # print( - # f"Add file db, db_name: {db_name}, db_type: {db_type}, db_path: {db_path}" - # ) - # self.storage.add_file_db(db_name, db_type, db_path) - - # def _parse_file_db_info(self, db_type: str, db_path: str): - # if db_type is None or db_type == DBType.DuckDb.value(): - # # file db is duckdb - # db_name = self.storage.get_file_db_name(db_path) - # db_type = DBType.DuckDb.value() - # else: - # db_name = DBType.parse_file_db_name_from_path(db_type, db_path) - # return db_type, db_name def get_connect(self, db_name): db_config = self.storage.get_db_config(db_name) diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index 76c72ec1b..579ff9545 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -2,16 +2,12 @@ import json import uuid import logging -from pilot.common.schema import DBType from pilot.component import SystemApp from pilot.configs.config import Config from pilot.configs.model_config import ( - KNOWLEDGE_UPLOAD_ROOT_PATH, EMBEDDING_MODEL_CONFIG, ) -from pilot.scene.base import ChatScene -from pilot.scene.base_chat import BaseChat from pilot.scene.chat_factory import ChatFactory from pilot.summary.rdbms_db_summary import RdbmsSummary @@ -33,7 +29,6 @@ class DBSummaryClient: def db_summary_embedding(self, dbname, db_type): """put db profile and table profile summary into vector store""" - from pilot.embedding_engine.string_embedding import StringEmbedding from pilot.embedding_engine.embedding_factory import EmbeddingFactory db_summary_client = RdbmsSummary(dbname, db_type) @@ -43,48 +38,12 @@ class DBSummaryClient: embeddings = embedding_factory.create( model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL] ) - vector_store_config = { - "vector_store_name": dbname + "_summary", - "vector_store_type": CFG.VECTOR_STORE_TYPE, - "embeddings": embeddings, - } - embedding = StringEmbedding( - file_path=db_summary_client.get_summary(), - vector_store_config=vector_store_config, - ) self.init_db_profile(db_summary_client, dbname, embeddings) - if not embedding.vector_name_exist(): - if CFG.SUMMARY_CONFIG == "FAST": - for vector_table_info in db_summary_client.get_summary(): - embedding = StringEmbedding( - vector_table_info, - vector_store_config, - ) - embedding.source_embedding() - else: - embedding = StringEmbedding( - file_path=db_summary_client.get_summary(), - vector_store_config=vector_store_config, - ) - embedding.source_embedding() - for ( - table_name, - table_summary, - ) in db_summary_client.get_table_summary().items(): - table_vector_store_config = { - "vector_store_name": dbname + "_" + table_name + "_ts", - "vector_store_type": CFG.VECTOR_STORE_TYPE, - "embeddings": embeddings, - } - embedding = StringEmbedding( - table_summary, - table_vector_store_config, - ) - embedding.source_embedding() logger.info("db summary embedding success") def get_db_summary(self, dbname, query, topk): + """get user query related tables info""" from pilot.embedding_engine.embedding_engine import EmbeddingEngine from pilot.embedding_engine.embedding_factory import EmbeddingFactory @@ -104,53 +63,8 @@ class DBSummaryClient: ans = [d.page_content for d in table_docs] return ans - def get_similar_tables(self, dbname, query, topk): - """get user query related tables info""" - from pilot.embedding_engine.embedding_engine import EmbeddingEngine - from pilot.embedding_engine.embedding_factory import EmbeddingFactory - - vector_store_config = { - "vector_store_name": dbname + "_summary", - "vector_store_type": CFG.VECTOR_STORE_TYPE, - } - embedding_factory = CFG.SYSTEM_APP.get_component( - "embedding_factory", EmbeddingFactory - ) - knowledge_embedding_client = EmbeddingEngine( - model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL], - vector_store_config=vector_store_config, - embedding_factory=embedding_factory, - ) - if CFG.SUMMARY_CONFIG == "FAST": - table_docs = knowledge_embedding_client.similar_search(query, topk) - related_tables = [ - json.loads(table_doc.page_content)["table_name"] - for table_doc in table_docs - ] - else: - table_docs = knowledge_embedding_client.similar_search(query, 1) - # prompt = KnownLedgeBaseQA.build_db_summary_prompt( - # query, table_docs[0].page_content - # ) - related_tables = _get_llm_response( - query, dbname, table_docs[0].page_content - ) - related_table_summaries = [] - for table in related_tables: - vector_store_config = { - "vector_store_name": dbname + "_" + table + "_ts", - "vector_store_type": CFG.VECTOR_STORE_TYPE, - } - knowledge_embedding_client = EmbeddingEngine( - model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL], - vector_store_config=vector_store_config, - embedding_factory=embedding_factory, - ) - table_summary = knowledge_embedding_client.similar_search(query, 1) - related_table_summaries.append(table_summary[0].page_content) - return related_table_summaries - def init_db_summary(self): + """init db summary""" db_mange = CFG.LOCAL_DB_MANAGE dbs = db_mange.get_db_list() for item in dbs: @@ -177,17 +91,16 @@ class DBSummaryClient: "embeddings": embeddings, } embedding = StringEmbedding( - file_path=db_summary_client.get_db_summary(), + file_path=None, vector_store_config=profile_store_config, ) if not embedding.vector_name_exist(): docs = [] - docs.extend(embedding.read_batch()) - for table_summary in db_summary_client.table_info_json(): + for table_summary in db_summary_client.table_summaries(): from langchain.text_splitter import RecursiveCharacterTextSplitter text_splitter = RecursiveCharacterTextSplitter( - chunk_size=len(table_summary), chunk_overlap=100 + chunk_size=len(table_summary), chunk_overlap=0 ) embedding = StringEmbedding( file_path=table_summary, @@ -195,23 +108,8 @@ class DBSummaryClient: text_splitter=text_splitter, ) docs.extend(embedding.read_batch()) - embedding.index_to_store(docs) + if len(docs) > 0: + embedding.index_to_store(docs) else: logger.info(f"Vector store name {vector_store_name} exist") - logger.info("init db profile success...") - - -def _get_llm_response(query, db_input, dbsummary): - chat_param = { - "temperature": 0.7, - "max_new_tokens": 512, - "chat_session_id": uuid.uuid1(), - "user_input": query, - "db_select": db_input, - "db_summary": dbsummary, - } - chat: BaseChat = chat_factory.get_implementation( - ChatScene.InnerChatDBSummary.value, **chat_param - ) - res = chat._blocking_nostream_call() - return json.loads(res)["table"] + logger.info("initialize db summary profile success...") diff --git a/pilot/summary/rdbms_db_summary.py b/pilot/summary/rdbms_db_summary.py index 95c603df8..e5c9c0414 100644 --- a/pilot/summary/rdbms_db_summary.py +++ b/pilot/summary/rdbms_db_summary.py @@ -1,22 +1,22 @@ -import json - from pilot.configs.config import Config -from pilot.summary.db_summary import DBSummary, TableSummary, FieldSummary, IndexSummary +from pilot.summary.db_summary import DBSummary CFG = Config() class RdbmsSummary(DBSummary): - """Get mysql summary template.""" + """Get rdbms db table summary template. + summary example: + table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment is {table_comment}) + """ def __init__(self, name, type): self.name = name self.type = type - self.summary = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}""" + self.summary_template = "{table_name}({columns})" self.tables = {} self.tables_info = [] self.vector_tables_info = [] - # self.tables_summary = {} self.db = CFG.LOCAL_DB_MANAGE.get_connect(name) @@ -27,154 +27,41 @@ class RdbmsSummary(DBSummary): collation=self.db.get_collation(), ) tables = self.db.get_table_names() - self.table_comments = self.db.get_table_comments(name) - comment_map = {} - for table_comment in self.table_comments: - self.tables_info.append( - "table name:{table_name},table description:{table_comment}".format( - table_name=table_comment[0], table_comment=table_comment[1] - ) - ) - comment_map[table_comment[0]] = table_comment[1] + self.table_info_summaries = [ + self.get_table_summary(table_name) for table_name in tables + ] - vector_table = json.dumps( - {"table_name": table_comment[0], "table_description": table_comment[1]} - ) - self.vector_tables_info.append( - vector_table.encode("utf-8").decode("unicode_escape") - ) - self.table_columns_info = [] - self.table_columns_json = [] + def get_table_summary(self, table_name): + """Get table summary for table. + example: + table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment: {table_comment}) + """ + columns = [] + for column in self.db._inspector.get_columns(table_name): + if column.get("comment"): + columns.append((f"{column['name']} ({column.get('comment')})")) + else: + columns.append(f"{column['name']}") - for table_name in tables: - table_summary = RdbmsTableSummary(self.db, name, table_name, comment_map) - # self.tables[table_name] = table_summary.get_summary() - self.tables[table_name] = table_summary.get_columns() - self.table_columns_info.append(table_summary.get_columns()) - # self.table_columns_json.append(table_summary.get_summary_json()) - table_profile = ( - "table name:{table_name},table description:{table_comment}".format( - table_name=table_name, - table_comment=self.db.get_show_create_table(table_name), - ) - ) - self.table_columns_json.append(table_profile) - # self.tables_info.append(table_summary.get_summary()) - - def get_summary(self): - if CFG.SUMMARY_CONFIG == "FAST": - return self.vector_tables_info - else: - return self.summary.format( - name=self.name, type=self.type, table_info=";".join(self.tables_info) - ) - - def get_db_summary(self): - return self.summary.format( - name=self.name, - type=self.type, - tables=";".join(self.vector_tables_info), - qps=1000, - tps=1000, + column_str = ", ".join(columns) + index_keys = [] + for index_key in self.db._inspector.get_indexes(table_name): + key_str = ", ".join(index_key["column_names"]) + index_keys.append(f"{index_key['name']}(`{key_str}`) ") + table_str = self.summary_template.format( + table_name=table_name, columns=column_str ) + if len(index_keys) > 0: + index_key_str = ", ".join(index_keys) + table_str += f", and index keys: {index_key_str}" + try: + comment = self.db._inspector.get_table_comment(table_name) + except Exception: + comment = dict(text=None) + if comment.get("text"): + table_str += f", and table comment: {comment.get('text')}" + return table_str - def get_table_summary(self): - return self.tables - - def get_table_comments(self): - return self.table_comments - - def table_info_json(self): - return self.table_columns_json - - -class RdbmsTableSummary(TableSummary): - """Get mysql table summary template.""" - - def __init__(self, instance, dbname, name, comment_map): - self.name = name - self.dbname = dbname - self.summary = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}""" - self.json_summary_template = """{{"table_name": "{name}", "comment": "{comment}", "columns": "{fields}", "indexes": "{indexes}", "size_in_bytes": {size_in_bytes}, "rows": {rows}}}""" - self.fields = [] - self.fields_info = [] - self.indexes = [] - self.indexes_info = [] - self.db = instance - fields = self.db.get_fields(name) - indexes = self.db.get_indexes(name) - field_names = [] - for field in fields: - field_summary = RdbmsFieldsSummary(field) - self.fields.append(field_summary) - self.fields_info.append(field_summary.get_summary()) - field_names.append(field[0]) - - self.column_summary = """{name}({columns_info})""".format( - name=name, columns_info=",".join(field_names) - ) - - for index in indexes: - index_summary = RdbmsIndexSummary(index) - self.indexes.append(index_summary) - self.indexes_info.append(index_summary.get_summary()) - - self.json_summary = self.json_summary_template.format( - name=name, - comment=comment_map[name], - fields=self.fields_info, - indexes=self.indexes_info, - size_in_bytes=1000, - rows=1000, - ) - - def get_summary(self): - return self.summary.format( - name=self.name, - dbname=self.dbname, - fields=";".join(self.fields_info), - indexes=";".join(self.indexes_info), - ) - - def get_columns(self): - return self.column_summary - - def get_summary_json(self): - return self.json_summary - - -class RdbmsFieldsSummary(FieldSummary): - """Get mysql field summary template.""" - - def __init__(self, field): - self.name = field[0] - # self.summary = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """ - # self.summary = """{"name": {name}, "type": {data_type}, "is_primary_key": {is_nullable}, "comment":{comment}, "default":{default_value}}""" - self.data_type = field[1] - self.default_value = field[2] - self.is_nullable = field[3] - self.comment = field[4] - - def get_summary(self): - return '{{"name": "{name}", "type": "{data_type}", "is_primary_key": "{is_nullable}", "comment": "{comment}", "default": "{default_value}"}}'.format( - name=self.name, - data_type=self.data_type, - is_nullable=self.is_nullable, - default_value=self.default_value, - comment=self.comment, - ) - - -class RdbmsIndexSummary(IndexSummary): - """Get mysql index summary template.""" - - def __init__(self, index): - self.name = index[0] - # self.summary = """index name:{name}, index bind columns:{bind_fields}""" - self.summary_template = '{{"name": "{name}", "columns": {bind_fields}}}' - self.bind_fields = index[1] - - def get_summary(self): - return self.summary_template.format( - name=self.name, bind_fields=self.bind_fields - ) + def table_summaries(self): + """Get table summaries.""" + return self.table_info_summaries