feat: merge DB summary to support multiple DB connection types

This commit is contained in:
yhjun1026
2023-08-01 15:35:59 +08:00
parent 941237ac89
commit 77f6b2f458
14 changed files with 237 additions and 96 deletions

View File

@@ -131,8 +131,6 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
# Generic plugins
plugins_path_path = Path(PLUGINS_DIR)
logger.debug(f"Allowlisted Plugins: {cfg.plugins_allowlist}")
logger.debug(f"Denylisted Plugins: {cfg.plugins_denylist}")
for plugin in plugins_path_path.glob("*.zip"):
if moduleList := inspect_zip_for_modules(str(plugin), debug):

View File

@@ -7,7 +7,7 @@ from pilot.connections.base import BaseConnect
from pilot.connections.rdbms.conn_mysql import MySQLConnect
from pilot.connections.rdbms.conn_duckdb import DuckDbConnect
from pilot.connections.rdbms.conn_mssql import MSSQLConnect
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
from pilot.connections.rdbms.base import RDBMSDatabase
from pilot.singleton import Singleton
from pilot.common.sql_database import Database
from pilot.connections.db_conn_info import DBConfig

View File

@@ -1,38 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Any, Optional

from pilot.configs.config import Config
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
CFG = Config()
class ClickHouseConnector(RDBMSDatabase):
    """Connector built on ``RDBMSDatabase`` for the path in ``CFG.LOCAL_DB_PATH``.

    NOTE(review): despite the class name, ``db_type`` and ``driver`` are both
    "duckdb" and ``default_db`` lists MySQL-looking schema names. This appears
    to be copied from another connector -- confirm the intended values.
    """

    db_type: str = "duckdb"
    driver: str = "duckdb"
    file_path: str
    default_db = ["information_schema", "performance_schema", "sys", "mysql"]

    @classmethod
    def from_config(cls) -> RDBMSDatabase:
        """Build a connector from the global configuration.

        TODO: password encryption.

        Returns:
            Whatever ``from_uri_db`` returns for ``CFG.LOCAL_DB_PATH``.
        """
        # ``from_uri_db`` is a classmethod, so ``cls`` is bound automatically.
        # Passing it again shifted every argument by one and raised
        # "TypeError: got multiple values for argument 'engine_args'".
        return cls.from_uri_db(
            CFG.LOCAL_DB_PATH,
            engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
        )

    @classmethod
    def from_uri_db(
        cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> RDBMSDatabase:
        """Build a connector from a database path.

        Args:
            db_path: Path appended after ``<connect_driver>://`` to form the URL.
            engine_args: Optional engine options forwarded to ``from_uri``.
            **kwargs: Extra keyword arguments forwarded to ``from_uri``.

        Returns:
            The result of ``cls.from_uri`` for the assembled URL.
        """
        # NOTE(review): ``connect_driver`` is not defined in this class (only
        # ``driver`` is); presumably it is inherited -- verify against
        # RDBMSDatabase.
        db_url: str = cls.connect_driver + "://" + db_path
        return cls.from_uri(db_url, engine_args, **kwargs)

View File

@@ -9,7 +9,7 @@ from sqlalchemy import (
)
from sqlalchemy.ext.declarative import declarative_base
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
from pilot.connections.rdbms.base import RDBMSDatabase
class DuckDbConnect(RDBMSDatabase):
@@ -29,6 +29,33 @@ class DuckDbConnect(RDBMSDatabase):
_engine_args = engine_args or {}
return cls(create_engine("duckdb:///" + file_path, **_engine_args), **kwargs)
def get_users(self):
    """Return ``(column 0, column 1)`` pairs from the user-table lookup.

    Runs a query against ``sqlite_master`` for a table named
    ``duckdb_sys_users`` and keeps the first two columns of every row found.
    """
    query = text(
        "SELECT * FROM sqlite_master WHERE type = 'table' AND name = 'duckdb_sys_users';"
    )
    rows = self.session.execute(query).fetchall()
    return [(row[0], row[1]) for row in rows]
def get_grants(self):
    """Return grant information; this connector always reports none."""
    no_grants = []
    return no_grants
def get_collation(self):
    """Return the collation name, which is the fixed string "UTF-8"."""
    collation = "UTF-8"
    return collation
def get_charset(self):
    """Return the character set name, which is the fixed string "UTF-8"."""
    charset = "UTF-8"
    return charset
def get_table_comments(self, db_name):
    """Return ``(name, sql)`` pairs for every table listed in ``sqlite_master``.

    Args:
        db_name: Accepted for interface compatibility; the query does not use it.

    Returns:
        A list of 2-tuples built from the first two columns of each row.
    """
    query = text("\nSELECT name, sql FROM sqlite_master WHERE type='table'\n")
    rows = self.session.execute(query).fetchall()
    return [(row[0], row[1]) for row in rows]
def table_simple_info(self) -> Iterable[str]:
_tables_sql = f"""
SELECT name FROM sqlite_master WHERE type='table'

View File

@@ -10,7 +10,7 @@ from sqlalchemy import (
select,
text,
)
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
from pilot.connections.rdbms.base import RDBMSDatabase
class MSSQLConnect(RDBMSDatabase):

View File

@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
from typing import Optional, Any
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
from pilot.connections.rdbms.base import RDBMSDatabase
class MySQLConnect(RDBMSDatabase):

View File

@@ -1,13 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class OracleConnector(RDBMSDatabase):
    """Oracle flavour of ``RDBMSDatabase``; only overrides class-level settings."""

    # Identifier and driver name are both the literal "oracle".
    db_type: str = "oracle"
    driver: str = "oracle"
    # Presumably built-in schemas to be skipped -- verify against RDBMSDatabase.
    default_db = [
        "SYS",
        "SYSTEM",
        "OUTLN",
        "ORDDATA",
        "XDB",
    ]

View File

@@ -1,12 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class PostgresConnector(RDBMSDatabase):
    """PostgreSQL flavour of ``RDBMSDatabase``; only overrides class-level settings."""

    # Identifier and driver name are both the literal "postgresql".
    db_type: str = "postgresql"
    driver: str = "postgresql"
    # NOTE(review): these names look like MySQL system schemas rather than
    # PostgreSQL ones -- confirm the intended list.
    default_db = [
        "information_schema",
        "performance_schema",
        "sys",
        "mysql",
    ]

View File

@@ -4,17 +4,17 @@ import os
import shutil
import argparse
import sys
import logging
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
import signal
from pilot.configs.config import Config
from pilot.configs.model_config import (
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
)
# from pilot.configs.model_config import (
# DATASETS_DIR,
# KNOWLEDGE_UPLOAD_ROOT_PATH,
# LLM_MODEL_CONFIG,
# LOGDIR,
# )
from pilot.utils import build_logger
from pilot.server.webserver_base import server_init
@@ -30,11 +30,13 @@ from pilot.server.knowledge.api import router as knowledge_router
from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
logging.basicConfig(level=logging.INFO)
static_file_path = os.path.join(os.getcwd(), "server/static")
CFG = Config()
logger = build_logger("webserver", LOGDIR + "webserver.log")
# logger = build_logger("webserver", LOGDIR + "webserver.log")
def signal_handler():
@@ -113,5 +115,6 @@ if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=args.port)
logging.basicConfig(level=logging.INFO)
uvicorn.run(app, host="0.0.0.0", port=args.port, log_level=0)
signal.signal(signal.SIGINT, signal_handler())

View File

@@ -7,12 +7,12 @@ import sys
from pilot.summary.db_summary_client import DBSummaryClient
from pilot.commands.command_mange import CommandRegistry
from pilot.configs.config import Config
from pilot.configs.model_config import (
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
)
# from pilot.configs.model_config import (
# DATASETS_DIR,
# KNOWLEDGE_UPLOAD_ROOT_PATH,
# LLM_MODEL_CONFIG,
# LOGDIR,
# )
from pilot.common.plugins import scan_plugins, load_native_plugins
from pilot.utils import build_logger
from pilot.connections.manages.connection_manager import ConnectManager
@@ -20,7 +20,7 @@ from pilot.connections.manages.connection_manager import ConnectManager
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
logger = build_logger("webserver", LOGDIR + "webserver.log")
# logger = build_logger("webserver", LOGDIR + "webserver.log")
def signal_handler(sig, frame):
@@ -35,7 +35,7 @@ def async_db_summery():
def server_init(args):
logger.info(f"args: {args}")
# logger.info(f"args: {args}")
# init config
cfg = Config()
@@ -43,7 +43,7 @@ def server_init(args):
conn_manage = ConnectManager()
cfg.LOCAL_DB_MANAGE = conn_manage
load_native_plugins(cfg)
# load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)
async_db_summery()
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))

View File

@@ -9,11 +9,10 @@ from pilot.scene.base import ChatScene
from pilot.scene.base_chat import BaseChat
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.string_embedding import StringEmbedding
from pilot.summary.mysql_db_summary import MysqlSummary
from pilot.summary.rdbms_db_summary import RdbmsSummary
from pilot.scene.chat_factory import ChatFactory
from pilot.common.schema import DBType
CFG = Config()
chat_factory = ChatFactory()
@@ -28,11 +27,8 @@ class DBSummaryClient:
def db_summary_embedding(self, dbname, db_type):
"""put db profile and table profile summary into vector store"""
if DBType.Mysql.value() == db_type:
db_summary_client = MysqlSummary(dbname)
else:
raise ValueError("Unsupport summary DbType" + db_type)
db_summary_client = RdbmsSummary(dbname, db_type)
embeddings = HuggingFaceEmbeddings(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
@@ -62,8 +58,8 @@ class DBSummaryClient:
)
embedding.source_embedding()
for (
table_name,
table_summary,
table_name,
table_summary,
) in db_summary_client.get_table_summary().items():
table_vector_store_config = {
"vector_store_name": dbname + "_" + table_name + "_ts",

View File

@@ -0,0 +1,180 @@
import json
from pilot.configs.config import Config
from pilot.summary.db_summary import DBSummary, TableSummary, FieldSummary, IndexSummary
CFG = Config()
class RdbmsSummary(DBSummary):
    """Build text summaries of one relational database.

    The connection is looked up by name through ``CFG.LOCAL_DB_MANAGE``. All
    summaries are computed eagerly in ``__init__``; the accessor methods only
    return the cached results.
    """

    def __init__(self, name, type):
        """Collect metadata, table descriptions and column lists.

        Args:
            name: Database name; used to fetch the connection and passed to
                ``get_table_comments``.
            type: Database type label; stored and echoed in the summaries.
        """
        self.name = name
        self.type = type
        self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
        self.tables = {}
        self.tables_info = []
        self.vector_tables_info = []
        self.db = CFG.LOCAL_DB_MANAGE.get_connect(name)

        self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format(
            users=self.db.get_users(),
            grant=self.db.get_grants(),
            charset=self.db.get_charset(),
            collation=self.db.get_collation(),
        )
        tables = self.db.get_table_names()
        self.table_comments = self.db.get_table_comments(name)

        comment_map = {}
        for table_comment in self.table_comments:
            table_name, description = table_comment[0], table_comment[1]
            self.tables_info.append(
                "table name:{table_name},table description:{table_comment}".format(
                    table_name=table_name, table_comment=description
                )
            )
            comment_map[table_name] = description
            # ensure_ascii=False keeps non-ASCII text readable while leaving
            # the result valid JSON. The previous encode/decode
            # ("unicode_escape") round-trip also un-escaped quotes, newlines
            # and backslashes and split non-BMP characters into surrogates.
            self.vector_tables_info.append(
                json.dumps(
                    {"table_name": table_name, "table_description": description},
                    ensure_ascii=False,
                )
            )

        self.table_columns_info = []
        self.table_columns_json = []
        for table_name in tables:
            table_summary = RdbmsTableSummary(self.db, name, table_name, comment_map)
            columns = table_summary.get_columns()
            self.tables[table_name] = columns
            self.table_columns_info.append(columns)
            # The "table description" here is the output of
            # get_show_create_table, not the comment fetched above.
            table_profile = (
                "table name:{table_name},table description:{table_comment}".format(
                    table_name=table_name,
                    table_comment=self.db.get_show_create_table(table_name),
                )
            )
            self.table_columns_json.append(table_profile)

    def get_summery(self):
        """Return the database summary.

        Returns:
            The list of per-table JSON strings when ``CFG.SUMMARY_CONFIG`` is
            "FAST"; otherwise the formatted summary string built from the
            plain-text table descriptions.
        """
        if CFG.SUMMARY_CONFIG == "FAST":
            return self.vector_tables_info
        # The template requires tables/qps/tps. The earlier call passed an
        # unused ``table_info`` keyword instead and raised KeyError.
        return self.summery.format(
            name=self.name,
            type=self.type,
            tables=";".join(self.tables_info),
            qps=1000,
            tps=1000,
        )

    def get_db_summery(self):
        """Return the summary string built from the per-table JSON strings."""
        return self.summery.format(
            name=self.name,
            type=self.type,
            tables=";".join(self.vector_tables_info),
            qps=1000,
            tps=1000,
        )

    def get_table_summary(self):
        """Return the mapping of table name to its ``name(col1,col2,...)`` string."""
        return self.tables

    def get_table_comments(self):
        """Return the raw rows obtained from ``db.get_table_comments``."""
        return self.table_comments

    def table_info_json(self):
        """Return the per-table profile strings built from ``get_show_create_table``."""
        return self.table_columns_json
class RdbmsTableSummary(TableSummary):
    """Build text summaries of a single table: its columns and indexes."""

    def __init__(self, instance, dbname, name, comment_map):
        """Read field and index rows for ``name`` and pre-render the summaries.

        Args:
            instance: Database connection object providing ``get_fields`` and
                ``get_indexes``.
            dbname: Name of the database the table belongs to.
            name: Table name.
            comment_map: Mapping of table name to table comment. A table
                missing from the mapping is summarised with an empty comment.
        """
        self.name = name
        self.dbname = dbname
        self.summery = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}"""
        self.json_summery_template = """{{"table_name": "{name}", "comment": "{comment}", "columns": "{fields}", "indexes": "{indexes}", "size_in_bytes": {size_in_bytes}, "rows": {rows}}}"""
        self.fields = []
        self.fields_info = []
        self.indexes = []
        self.indexes_info = []
        self.db = instance

        fields = self.db.get_fields(name)
        indexes = self.db.get_indexes(name)

        field_names = []
        for field in fields:
            field_summary = RdbmsFieldsSummary(field)
            self.fields.append(field_summary)
            self.fields_info.append(field_summary.get_summery())
            field_names.append(field[0])

        self.column_summery = """{name}({columns_info})""".format(
            name=name, columns_info=",".join(field_names)
        )

        for index in indexes:
            index_summary = RdbmsIndexSummary(index)
            self.indexes.append(index_summary)
            self.indexes_info.append(index_summary.get_summery())

        self.json_summery = self.json_summery_template.format(
            name=name,
            # A table without a comment row must not abort the whole database
            # summary with KeyError.
            comment=comment_map.get(name, ""),
            fields=self.fields_info,
            indexes=self.indexes_info,
            size_in_bytes=1000,
            rows=1000,
        )

    def get_summery(self):
        """Return the plain-text summary with ';'-joined field and index info."""
        return self.summery.format(
            name=self.name,
            dbname=self.dbname,
            fields=";".join(self.fields_info),
            indexes=";".join(self.indexes_info),
        )

    def get_columns(self):
        """Return the ``name(col1,col2,...)`` column listing."""
        return self.column_summery

    def get_summary_json(self):
        """Return the JSON-style summary string rendered in ``__init__``."""
        return self.json_summery
class RdbmsFieldsSummary(FieldSummary):
    """Render one column row as a JSON-style string."""

    def __init__(self, field):
        """Store the first five positions of ``field`` as named attributes.

        Args:
            field: Indexable row read as (name, data type, default value,
                nullable flag, comment) by position.
        """
        self.name = field[0]
        self.data_type = field[1]
        self.default_value = field[2]
        self.is_nullable = field[3]
        self.comment = field[4]

    def get_summery(self):
        """Return the column description as a JSON-style string."""
        # NOTE(review): the "is_primary_key" key is filled with the
        # is_nullable value -- confirm which meaning consumers expect.
        template = '{{"name": "{name}", "type": "{data_type}", "is_primary_key": "{is_nullable}", "comment": "{comment}", "default": "{default_value}"}}'
        values = {
            "name": self.name,
            "data_type": self.data_type,
            "is_nullable": self.is_nullable,
            "default_value": self.default_value,
            "comment": self.comment,
        }
        return template.format(**values)
class RdbmsIndexSummary(IndexSummary):
    """Render one index row as a JSON-style string."""

    def __init__(self, index):
        """Store the index name and its bound columns.

        Args:
            index: Indexable row read as (index name, bound columns) by position.
        """
        self.summery_template = '{{"name": "{name}", "columns": {bind_fields}}}'
        self.name = index[0]
        self.bind_fields = index[1]

    def get_summery(self):
        """Return the index description rendered from the stored template."""
        values = {"name": self.name, "bind_fields": self.bind_fields}
        return self.summery_template.format(**values)

View File

@@ -77,7 +77,7 @@ def build_logger(logger_name, logger_filename):
for name, item in logging.root.manager.loggerDict.items():
if isinstance(item, logging.Logger):
item.addHandler(handler)
logging.basicConfig(level=logging.INFO)
return logger