From 77f6b2f4585032ee7765866cb119e96a90f4f359 Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Tue, 1 Aug 2023 15:35:59 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9ADb=20summary=20merge=20multi=20db?= =?UTF-8?q?=20connect?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pilot/common/plugins.py | 2 - .../connections/manages/connection_manager.py | 2 +- .../rdbms/{rdbms_connect.py => base.py} | 0 pilot/connections/rdbms/clickhouse.py | 38 ---- pilot/connections/rdbms/conn_duckdb.py | 29 ++- pilot/connections/rdbms/conn_mssql.py | 2 +- pilot/connections/rdbms/conn_mysql.py | 2 +- pilot/connections/rdbms/oracle.py | 13 -- pilot/connections/rdbms/postgres.py | 12 -- pilot/server/dbgpt_server.py | 21 +- pilot/server/webserver_base.py | 18 +- pilot/summary/db_summary_client.py | 12 +- pilot/summary/rdbms_db_summary.py | 180 ++++++++++++++++++ pilot/utils.py | 2 +- 14 files changed, 237 insertions(+), 96 deletions(-) rename pilot/connections/rdbms/{rdbms_connect.py => base.py} (100%) delete mode 100644 pilot/connections/rdbms/clickhouse.py delete mode 100644 pilot/connections/rdbms/oracle.py delete mode 100644 pilot/connections/rdbms/postgres.py create mode 100644 pilot/summary/rdbms_db_summary.py diff --git a/pilot/common/plugins.py b/pilot/common/plugins.py index e22224399..50ec56606 100644 --- a/pilot/common/plugins.py +++ b/pilot/common/plugins.py @@ -131,8 +131,6 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate # Generic plugins plugins_path_path = Path(PLUGINS_DIR) - logger.debug(f"Allowlisted Plugins: {cfg.plugins_allowlist}") - logger.debug(f"Denylisted Plugins: {cfg.plugins_denylist}") for plugin in plugins_path_path.glob("*.zip"): if moduleList := inspect_zip_for_modules(str(plugin), debug): diff --git a/pilot/connections/manages/connection_manager.py b/pilot/connections/manages/connection_manager.py index 291127127..43a4dce2c 100644 --- a/pilot/connections/manages/connection_manager.py +++ b/pilot/connections/manages/connection_manager.py @@ -7,7 +7,7 @@ from pilot.connections.base import BaseConnect from pilot.connections.rdbms.conn_mysql import MySQLConnect from pilot.connections.rdbms.conn_duckdb import DuckDbConnect from pilot.connections.rdbms.conn_mssql import MSSQLConnect -from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase +from pilot.connections.rdbms.base import RDBMSDatabase from pilot.singleton import Singleton from pilot.common.sql_database import Database from pilot.connections.db_conn_info import DBConfig diff --git a/pilot/connections/rdbms/rdbms_connect.py b/pilot/connections/rdbms/base.py similarity index 100% rename from pilot/connections/rdbms/rdbms_connect.py rename to pilot/connections/rdbms/base.py diff --git a/pilot/connections/rdbms/clickhouse.py b/pilot/connections/rdbms/clickhouse.py deleted file mode 100644 index fdf8e803b..000000000 --- a/pilot/connections/rdbms/clickhouse.py +++ /dev/null @@ -1,38 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase - -from pilot.configs.config import Config - -CFG = Config() - - -class ClickHouseConnector(RDBMSDatabase): - """ClickHouseConnector""" - - db_type: str = "duckdb" - - driver: str = "duckdb" - - file_path: str - - default_db = ["information_schema", "performance_schema", "sys", "mysql"] - - @classmethod - def from_config(cls) -> RDBMSDatabase: - """ - Todo password encryption - Returns: - """ - return cls.from_uri_db( - cls, - CFG.LOCAL_DB_PATH, - engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True}, - ) - - @classmethod - def from_uri_db( - cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any - ) -> RDBMSDatabase: - db_url: str = cls.connect_driver + "://" + db_path - return cls.from_uri(db_url, engine_args, **kwargs) diff --git a/pilot/connections/rdbms/conn_duckdb.py b/pilot/connections/rdbms/conn_duckdb.py index 90d2fc6e9..928d3e7c0 100644 --- a/pilot/connections/rdbms/conn_duckdb.py +++ b/pilot/connections/rdbms/conn_duckdb.py @@ -9,7 +9,7 @@ from sqlalchemy import ( ) from sqlalchemy.ext.declarative import declarative_base -from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase +from pilot.connections.rdbms.base import RDBMSDatabase class DuckDbConnect(RDBMSDatabase): @@ -29,6 +29,33 @@ class DuckDbConnect(RDBMSDatabase): _engine_args = engine_args or {} return cls(create_engine("duckdb:///" + file_path, **_engine_args), **kwargs) + def get_users(self): + cursor = self.session.execute(text(f"SELECT * FROM sqlite_master WHERE type = 'table' AND name = 'duckdb_sys_users';")) + users = cursor.fetchall() + return [(user[0], user[1]) for user in users] + + def get_grants(self): + return [] + + def get_collation(self): + """Get collation.""" + return "UTF-8" + def get_charset(self): + return "UTF-8" + + def get_table_comments(self, db_name): + cursor = self.session.execute( + text( + f""" + SELECT name, sql FROM sqlite_master WHERE type='table' + """ + ) + ) + table_comments = cursor.fetchall() + return [ + (table_comment[0], table_comment[1]) for table_comment in table_comments + ] + def table_simple_info(self) -> Iterable[str]: _tables_sql = f""" SELECT name FROM sqlite_master WHERE type='table' diff --git a/pilot/connections/rdbms/conn_mssql.py b/pilot/connections/rdbms/conn_mssql.py index a248f8f66..c2eb947aa 100644 --- a/pilot/connections/rdbms/conn_mssql.py +++ b/pilot/connections/rdbms/conn_mssql.py @@ -10,7 +10,7 @@ from sqlalchemy import ( select, text, ) -from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase +from pilot.connections.rdbms.base import RDBMSDatabase class MSSQLConnect(RDBMSDatabase): diff --git a/pilot/connections/rdbms/conn_mysql.py b/pilot/connections/rdbms/conn_mysql.py index 196919dbe..3bdd1ee44 100644 --- a/pilot/connections/rdbms/conn_mysql.py +++ b/pilot/connections/rdbms/conn_mysql.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- from typing import Optional, Any -from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase +from pilot.connections.rdbms.base import RDBMSDatabase class MySQLConnect(RDBMSDatabase): diff --git a/pilot/connections/rdbms/oracle.py b/pilot/connections/rdbms/oracle.py deleted file mode 100644 index e28c17fe5..000000000 --- a/pilot/connections/rdbms/oracle.py +++ /dev/null @@ -1,13 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase - - -class OracleConnector(RDBMSDatabase): - """OracleConnector""" - - db_type: str = "oracle" - - driver: str = "oracle" - - default_db = ["SYS", "SYSTEM", "OUTLN", "ORDDATA", "XDB"] diff --git a/pilot/connections/rdbms/postgres.py b/pilot/connections/rdbms/postgres.py deleted file mode 100644 index 7873eaaf4..000000000 --- a/pilot/connections/rdbms/postgres.py +++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase - - -class PostgresConnector(RDBMSDatabase): - """PostgresConnector is a class which Connector""" - - db_type: str = "postgresql" - driver: str = "postgresql" - - default_db = ["information_schema", "performance_schema", "sys", "mysql"] diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index def585069..045e0981c 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -4,17 +4,17 @@ import os import shutil import argparse import sys - +import logging ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) import signal from pilot.configs.config import Config -from pilot.configs.model_config import ( - DATASETS_DIR, - KNOWLEDGE_UPLOAD_ROOT_PATH, - LLM_MODEL_CONFIG, - LOGDIR, -) +# from pilot.configs.model_config import ( +# DATASETS_DIR, +# KNOWLEDGE_UPLOAD_ROOT_PATH, +# LLM_MODEL_CONFIG, +# LOGDIR, +# ) from pilot.utils import build_logger from pilot.server.webserver_base import server_init @@ -30,11 +30,13 @@ from pilot.server.knowledge.api import router as knowledge_router from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler +logging.basicConfig(level=logging.INFO) + static_file_path = os.path.join(os.getcwd(), "server/static") CFG = Config() -logger = build_logger("webserver", LOGDIR + "webserver.log") +# logger = build_logger("webserver", LOGDIR + "webserver.log") def signal_handler(): @@ -113,5 +115,6 @@ if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="0.0.0.0", port=args.port) + logging.basicConfig(level=logging.INFO) + uvicorn.run(app, host="0.0.0.0", port=args.port, log_level=0) signal.signal(signal.SIGINT, signal_handler()) diff --git a/pilot/server/webserver_base.py b/pilot/server/webserver_base.py index 243b81bd4..1d39d7bfa 100644 --- a/pilot/server/webserver_base.py +++ b/pilot/server/webserver_base.py @@ -7,12 +7,12 @@ import sys from pilot.summary.db_summary_client import DBSummaryClient from pilot.commands.command_mange import CommandRegistry from pilot.configs.config import Config -from pilot.configs.model_config import ( - DATASETS_DIR, - KNOWLEDGE_UPLOAD_ROOT_PATH, - LLM_MODEL_CONFIG, - LOGDIR, -) +# from pilot.configs.model_config import ( +# DATASETS_DIR, +# KNOWLEDGE_UPLOAD_ROOT_PATH, +# LLM_MODEL_CONFIG, +# LOGDIR, +# ) from pilot.common.plugins import scan_plugins, load_native_plugins from pilot.utils import build_logger from pilot.connections.manages.connection_manager import ConnectManager @@ -20,7 +20,7 @@ from pilot.connections.manages.connection_manager import ConnectManager ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) -logger = build_logger("webserver", LOGDIR + "webserver.log") +# logger = build_logger("webserver", LOGDIR + "webserver.log") def signal_handler(sig, frame): @@ -35,7 +35,7 @@ def async_db_summery(): def server_init(args): - logger.info(f"args: {args}") + # logger.info(f"args: {args}") # init config cfg = Config() @@ -43,7 +43,7 @@ def server_init(args): conn_manage = ConnectManager() cfg.LOCAL_DB_MANAGE = conn_manage - load_native_plugins(cfg) + # load_native_plugins(cfg) signal.signal(signal.SIGINT, signal_handler) async_db_summery() cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode)) diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index dfb02237b..2aa1b11db 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -9,11 +9,10 @@ from pilot.scene.base import ChatScene from pilot.scene.base_chat import BaseChat from pilot.embedding_engine.embedding_engine import EmbeddingEngine from pilot.embedding_engine.string_embedding import StringEmbedding -from pilot.summary.mysql_db_summary import MysqlSummary +from pilot.summary.rdbms_db_summary import RdbmsSummary from pilot.scene.chat_factory import ChatFactory from pilot.common.schema import DBType - CFG = Config() chat_factory = ChatFactory() @@ -28,11 +27,8 @@ class DBSummaryClient: def db_summary_embedding(self, dbname, db_type): """put db profile and table profile summary into vector store""" - if DBType.Mysql.value() == db_type: - db_summary_client = MysqlSummary(dbname) - else: - raise ValueError("Unsupport summary DbType!" + db_type) + db_summary_client = RdbmsSummary(dbname, db_type) embeddings = HuggingFaceEmbeddings( model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL] ) @@ -62,8 +58,8 @@ class DBSummaryClient: ) embedding.source_embedding() for ( - table_name, - table_summary, + table_name, + table_summary, ) in db_summary_client.get_table_summary().items(): table_vector_store_config = { "vector_store_name": dbname + "_" + table_name + "_ts", diff --git a/pilot/summary/rdbms_db_summary.py b/pilot/summary/rdbms_db_summary.py new file mode 100644 index 000000000..8a3b77506 --- /dev/null +++ b/pilot/summary/rdbms_db_summary.py @@ -0,0 +1,180 @@ +import json + +from pilot.configs.config import Config +from pilot.summary.db_summary import DBSummary, TableSummary, FieldSummary, IndexSummary + +CFG = Config() + +class RdbmsSummary(DBSummary): + """Get mysql summary template.""" + + def __init__(self, name, type): + self.name = name + self.type = type + self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}""" + self.tables = {} + self.tables_info = [] + self.vector_tables_info = [] + # self.tables_summary = {} + + self.db = CFG.LOCAL_DB_MANAGE.get_connect(name) + + self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format( + users=self.db.get_users(), + grant=self.db.get_grants(), + charset=self.db.get_charset(), + collation=self.db.get_collation(), + ) + tables = self.db.get_table_names() + self.table_comments = self.db.get_table_comments(name) + comment_map = {} + for table_comment in self.table_comments: + self.tables_info.append( + "table name:{table_name},table description:{table_comment}".format( + table_name=table_comment[0], table_comment=table_comment[1] + ) + ) + comment_map[table_comment[0]] = table_comment[1] + + vector_table = json.dumps( + {"table_name": table_comment[0], "table_description": table_comment[1]} + ) + self.vector_tables_info.append( + vector_table.encode("utf-8").decode("unicode_escape") + ) + self.table_columns_info = [] + self.table_columns_json = [] + + for table_name in tables: + table_summary = RdbmsTableSummary(self.db, name, table_name, comment_map) + # self.tables[table_name] = table_summary.get_summery() + self.tables[table_name] = table_summary.get_columns() + self.table_columns_info.append(table_summary.get_columns()) + # self.table_columns_json.append(table_summary.get_summary_json()) + table_profile = ( + "table name:{table_name},table description:{table_comment}".format( + table_name=table_name, + table_comment=self.db.get_show_create_table(table_name), + ) + ) + self.table_columns_json.append(table_profile) + # self.tables_info.append(table_summary.get_summery()) + + def get_summery(self): + if CFG.SUMMARY_CONFIG == "FAST": + return self.vector_tables_info + else: + return self.summery.format( + name=self.name, type=self.type, table_info=";".join(self.tables_info) + ) + + def get_db_summery(self): + return self.summery.format( + name=self.name, + type=self.type, + tables=";".join(self.vector_tables_info), + qps=1000, + tps=1000, + ) + + def get_table_summary(self): + return self.tables + + def get_table_comments(self): + return self.table_comments + + def table_info_json(self): + return self.table_columns_json + + +class RdbmsTableSummary(TableSummary): + """Get mysql table summary template.""" + + def __init__(self, instance, dbname, name, comment_map): + self.name = name + self.dbname = dbname + self.summery = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}""" + self.json_summery_template = """{{"table_name": "{name}", "comment": "{comment}", "columns": "{fields}", "indexes": "{indexes}", "size_in_bytes": {size_in_bytes}, "rows": {rows}}}""" + self.fields = [] + self.fields_info = [] + self.indexes = [] + self.indexes_info = [] + self.db = instance + fields = self.db.get_fields(name) + indexes = self.db.get_indexes(name) + field_names = [] + for field in fields: + field_summary = RdbmsFieldsSummary(field) + self.fields.append(field_summary) + self.fields_info.append(field_summary.get_summery()) + field_names.append(field[0]) + + self.column_summery = """{name}({columns_info})""".format( + name=name, columns_info=",".join(field_names) + ) + + for index in indexes: + index_summary = RdbmsIndexSummary(index) + self.indexes.append(index_summary) + self.indexes_info.append(index_summary.get_summery()) + + self.json_summery = self.json_summery_template.format( + name=name, + comment=comment_map[name], + fields=self.fields_info, + indexes=self.indexes_info, + size_in_bytes=1000, + rows=1000, + ) + + def get_summery(self): + return self.summery.format( + name=self.name, + dbname=self.dbname, + fields=";".join(self.fields_info), + indexes=";".join(self.indexes_info), + ) + + def get_columns(self): + return self.column_summery + + def get_summary_json(self): + return self.json_summery + + +class RdbmsFieldsSummary(FieldSummary): + """Get mysql field summary template.""" + + def __init__(self, field): + self.name = field[0] + # self.summery = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """ + # self.summery = """{"name": {name}, "type": {data_type}, "is_primary_key": {is_nullable}, "comment":{comment}, "default":{default_value}}""" + self.data_type = field[1] + self.default_value = field[2] + self.is_nullable = field[3] + self.comment = field[4] + + def get_summery(self): + return '{{"name": "{name}", "type": "{data_type}", "is_primary_key": "{is_nullable}", "comment": "{comment}", "default": "{default_value}"}}'.format( + name=self.name, + data_type=self.data_type, + is_nullable=self.is_nullable, + default_value=self.default_value, + comment=self.comment, + ) + + +class RdbmsIndexSummary(IndexSummary): + """Get mysql index summary template.""" + + def __init__(self, index): + self.name = index[0] + # self.summery = """index name:{name}, index bind columns:{bind_fields}""" + self.summery_template = '{{"name": "{name}", "columns": {bind_fields}}}' + self.bind_fields = index[1] + + def get_summery(self): + return self.summery_template.format( + name=self.name, bind_fields=self.bind_fields + ) + diff --git a/pilot/utils.py b/pilot/utils.py index 41e42fd55..b15d3af21 100644 --- a/pilot/utils.py +++ b/pilot/utils.py @@ -77,7 +77,7 @@ def build_logger(logger_name, logger_filename): for name, item in logging.root.manager.loggerDict.items(): if isinstance(item, logging.Logger): item.addHandler(handler) - + logging.basicConfig(level=logging.INFO) return logger