mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-31 16:39:48 +00:00
feat:Db summary merge multi db connect
This commit is contained in:
@@ -131,8 +131,6 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
|
||||
# Generic plugins
|
||||
plugins_path_path = Path(PLUGINS_DIR)
|
||||
|
||||
logger.debug(f"Allowlisted Plugins: {cfg.plugins_allowlist}")
|
||||
logger.debug(f"Denylisted Plugins: {cfg.plugins_denylist}")
|
||||
|
||||
for plugin in plugins_path_path.glob("*.zip"):
|
||||
if moduleList := inspect_zip_for_modules(str(plugin), debug):
|
||||
|
@@ -7,7 +7,7 @@ from pilot.connections.base import BaseConnect
|
||||
from pilot.connections.rdbms.conn_mysql import MySQLConnect
|
||||
from pilot.connections.rdbms.conn_duckdb import DuckDbConnect
|
||||
from pilot.connections.rdbms.conn_mssql import MSSQLConnect
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
from pilot.connections.rdbms.base import RDBMSDatabase
|
||||
from pilot.singleton import Singleton
|
||||
from pilot.common.sql_database import Database
|
||||
from pilot.connections.db_conn_info import DBConfig
|
||||
|
@@ -1,38 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
|
||||
from pilot.configs.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class ClickHouseConnector(RDBMSDatabase):
    """Connector that builds an ``RDBMSDatabase`` from the configured local path.

    NOTE(review): despite the class name, ``db_type`` and ``driver`` are both
    "duckdb" -- confirm whether that is intentional.
    """

    db_type: str = "duckdb"

    driver: str = "duckdb"

    file_path: str

    # NOTE(review): these look like MySQL system schemas -- confirm they are
    # the right defaults for this connector.
    default_db = ["information_schema", "performance_schema", "sys", "mysql"]

    @classmethod
    def from_config(cls) -> RDBMSDatabase:
        """Create a connector for ``CFG.LOCAL_DB_PATH``.

        TODO: password encryption.

        Returns:
            The object produced by ``from_uri_db`` for the configured path.
        """
        # ``from_uri_db`` is a classmethod, so ``cls`` is already bound.
        # Passing it again shifted every positional argument by one and made
        # the call fail with "multiple values for argument 'engine_args'".
        return cls.from_uri_db(
            CFG.LOCAL_DB_PATH,
            engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
        )

    @classmethod
    def from_uri_db(
        cls, db_path: str, engine_args: "Optional[dict]" = None, **kwargs: "Any"
    ) -> RDBMSDatabase:
        """Create a connector from a database path.

        Args:
            db_path: path appended after ``<driver>://`` to form the URL.
            engine_args: optional engine options forwarded to ``from_uri``.
            **kwargs: extra keyword arguments forwarded to ``from_uri``.

        Returns:
            The object produced by ``cls.from_uri``.
        """
        # The annotations above are quoted because this module never imports
        # ``Optional`` or ``Any``; unquoted they raise NameError on import.
        # The class declares ``driver`` (there is no ``connect_driver`` here).
        db_url: str = cls.driver + "://" + db_path
        return cls.from_uri(db_url, engine_args, **kwargs)
|
@@ -9,7 +9,7 @@ from sqlalchemy import (
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
from pilot.connections.rdbms.base import RDBMSDatabase
|
||||
|
||||
|
||||
class DuckDbConnect(RDBMSDatabase):
|
||||
@@ -29,6 +29,33 @@ class DuckDbConnect(RDBMSDatabase):
|
||||
_engine_args = engine_args or {}
|
||||
return cls(create_engine("duckdb:///" + file_path, **_engine_args), **kwargs)
|
||||
|
||||
def get_users(self):
    """Return the first two columns of the matching ``sqlite_master`` rows.

    NOTE(review): the query selects the catalog entry of a table named
    ``duckdb_sys_users`` rather than rows from that table, so the tuples
    hold catalog columns, not user records -- confirm this is intended.
    """
    query = text(
        "SELECT * FROM sqlite_master WHERE type = 'table' AND name = 'duckdb_sys_users';"
    )
    rows = self.session.execute(query).fetchall()
    return [(row[0], row[1]) for row in rows]
|
||||
|
||||
def get_grants(self):
    """Return the grant list; this connector always reports none."""
    no_grants = []
    return no_grants
|
||||
|
||||
def get_collation(self):
    """Return the collation name, which is the fixed string "UTF-8"."""
    collation = "UTF-8"
    return collation
|
||||
def get_charset(self):
    """Return the character set name, which is the fixed string "UTF-8"."""
    charset = "UTF-8"
    return charset
|
||||
|
||||
def get_table_comments(self, db_name):
    """Return ``(name, sql)`` pairs for every table listed in ``sqlite_master``.

    Args:
        db_name: accepted for signature compatibility; not used by the query.

    Returns:
        A list of 2-tuples built from the first two columns of each row.
    """
    query = text(
        """
SELECT name, sql FROM sqlite_master WHERE type='table'
"""
    )
    rows = self.session.execute(query).fetchall()
    return [(row[0], row[1]) for row in rows]
|
||||
|
||||
def table_simple_info(self) -> Iterable[str]:
|
||||
_tables_sql = f"""
|
||||
SELECT name FROM sqlite_master WHERE type='table'
|
||||
|
@@ -10,7 +10,7 @@ from sqlalchemy import (
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
from pilot.connections.rdbms.base import RDBMSDatabase
|
||||
|
||||
|
||||
class MSSQLConnect(RDBMSDatabase):
|
||||
|
@@ -2,7 +2,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Optional, Any
|
||||
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
from pilot.connections.rdbms.base import RDBMSDatabase
|
||||
|
||||
|
||||
class MySQLConnect(RDBMSDatabase):
|
||||
|
@@ -1,13 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
|
||||
|
||||
class OracleConnector(RDBMSDatabase):
    """``RDBMSDatabase`` subclass carrying the Oracle type and driver names."""

    # Database type label for this connector.
    db_type: str = "oracle"

    # Driver name for this connector.
    driver: str = "oracle"

    # Presumably built-in Oracle accounts/schemas to be treated as defaults;
    # verify against the callers that read ``default_db``.
    default_db = [
        "SYS",
        "SYSTEM",
        "OUTLN",
        "ORDDATA",
        "XDB",
    ]
|
@@ -1,12 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
|
||||
|
||||
class PostgresConnector(RDBMSDatabase):
    """``RDBMSDatabase`` subclass carrying the PostgreSQL type and driver names."""

    # Database type label for this connector.
    db_type: str = "postgresql"

    # Driver name for this connector.
    driver: str = "postgresql"

    # NOTE(review): these names look like MySQL system schemas rather than
    # PostgreSQL ones -- confirm against the callers that read ``default_db``.
    default_db = [
        "information_schema",
        "performance_schema",
        "sys",
        "mysql",
    ]
|
@@ -4,17 +4,17 @@ import os
|
||||
import shutil
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import logging
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(ROOT_PATH)
|
||||
import signal
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import (
|
||||
DATASETS_DIR,
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
LLM_MODEL_CONFIG,
|
||||
LOGDIR,
|
||||
)
|
||||
# from pilot.configs.model_config import (
|
||||
# DATASETS_DIR,
|
||||
# KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
# LLM_MODEL_CONFIG,
|
||||
# LOGDIR,
|
||||
# )
|
||||
from pilot.utils import build_logger
|
||||
|
||||
from pilot.server.webserver_base import server_init
|
||||
@@ -30,11 +30,13 @@ from pilot.server.knowledge.api import router as knowledge_router
|
||||
from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
static_file_path = os.path.join(os.getcwd(), "server/static")
|
||||
|
||||
|
||||
CFG = Config()
|
||||
logger = build_logger("webserver", LOGDIR + "webserver.log")
|
||||
# logger = build_logger("webserver", LOGDIR + "webserver.log")
|
||||
|
||||
|
||||
def signal_handler():
|
||||
@@ -113,5 +115,6 @@ if __name__ == "__main__":
|
||||
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port, log_level=0)
|
||||
signal.signal(signal.SIGINT, signal_handler())
|
||||
|
@@ -7,12 +7,12 @@ import sys
|
||||
from pilot.summary.db_summary_client import DBSummaryClient
|
||||
from pilot.commands.command_mange import CommandRegistry
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import (
|
||||
DATASETS_DIR,
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
LLM_MODEL_CONFIG,
|
||||
LOGDIR,
|
||||
)
|
||||
# from pilot.configs.model_config import (
|
||||
# DATASETS_DIR,
|
||||
# KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
# LLM_MODEL_CONFIG,
|
||||
# LOGDIR,
|
||||
# )
|
||||
from pilot.common.plugins import scan_plugins, load_native_plugins
|
||||
from pilot.utils import build_logger
|
||||
from pilot.connections.manages.connection_manager import ConnectManager
|
||||
@@ -20,7 +20,7 @@ from pilot.connections.manages.connection_manager import ConnectManager
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(ROOT_PATH)
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "webserver.log")
|
||||
# logger = build_logger("webserver", LOGDIR + "webserver.log")
|
||||
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
@@ -35,7 +35,7 @@ def async_db_summery():
|
||||
|
||||
|
||||
def server_init(args):
|
||||
logger.info(f"args: {args}")
|
||||
# logger.info(f"args: {args}")
|
||||
|
||||
# init config
|
||||
cfg = Config()
|
||||
@@ -43,7 +43,7 @@ def server_init(args):
|
||||
conn_manage = ConnectManager()
|
||||
cfg.LOCAL_DB_MANAGE = conn_manage
|
||||
|
||||
load_native_plugins(cfg)
|
||||
# load_native_plugins(cfg)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
async_db_summery()
|
||||
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
|
||||
|
@@ -9,11 +9,10 @@ from pilot.scene.base import ChatScene
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
from pilot.embedding_engine.string_embedding import StringEmbedding
|
||||
from pilot.summary.mysql_db_summary import MysqlSummary
|
||||
from pilot.summary.rdbms_db_summary import RdbmsSummary
|
||||
from pilot.scene.chat_factory import ChatFactory
|
||||
from pilot.common.schema import DBType
|
||||
|
||||
|
||||
CFG = Config()
|
||||
chat_factory = ChatFactory()
|
||||
|
||||
@@ -28,11 +27,8 @@ class DBSummaryClient:
|
||||
|
||||
def db_summary_embedding(self, dbname, db_type):
|
||||
"""put db profile and table profile summary into vector store"""
|
||||
if DBType.Mysql.value() == db_type:
|
||||
db_summary_client = MysqlSummary(dbname)
|
||||
else:
|
||||
raise ValueError("Unsupport summary DbType!" + db_type)
|
||||
|
||||
db_summary_client = RdbmsSummary(dbname, db_type)
|
||||
embeddings = HuggingFaceEmbeddings(
|
||||
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
|
||||
)
|
||||
@@ -62,8 +58,8 @@ class DBSummaryClient:
|
||||
)
|
||||
embedding.source_embedding()
|
||||
for (
|
||||
table_name,
|
||||
table_summary,
|
||||
table_name,
|
||||
table_summary,
|
||||
) in db_summary_client.get_table_summary().items():
|
||||
table_vector_store_config = {
|
||||
"vector_store_name": dbname + "_" + table_name + "_ts",
|
||||
|
180
pilot/summary/rdbms_db_summary.py
Normal file
180
pilot/summary/rdbms_db_summary.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import json
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.summary.db_summary import DBSummary, TableSummary, FieldSummary, IndexSummary
|
||||
|
||||
CFG = Config()
|
||||
|
||||
class RdbmsSummary(DBSummary):
    """Collect a textual summary of one relational database and its tables.

    The connection is obtained by name from ``CFG.LOCAL_DB_MANAGE``. All
    metadata is gathered eagerly in ``__init__``; the accessor methods only
    return or format what was collected.
    """

    def __init__(self, name, type):
        """Query the database and build every summary structure.

        Args:
            name: connection name passed to ``CFG.LOCAL_DB_MANAGE.get_connect``
                and to ``get_table_comments``.
            type: database type label, echoed in the formatted summary.
        """
        self.name = name
        self.type = type
        self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
        # table name -> column summary string, e.g. "users(id,name)".
        self.tables = {}
        # Plain-text "table name / description" lines.
        self.tables_info = []
        # JSON strings with the same name/description information.
        self.vector_tables_info = []

        self.db = CFG.LOCAL_DB_MANAGE.get_connect(name)

        self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format(
            users=self.db.get_users(),
            grant=self.db.get_grants(),
            charset=self.db.get_charset(),
            collation=self.db.get_collation(),
        )
        tables = self.db.get_table_names()
        self.table_comments = self.db.get_table_comments(name)

        comment_map = {}
        for table_comment in self.table_comments:
            commented_table = table_comment[0]
            description = table_comment[1]
            self.tables_info.append(
                "table name:{table_name},table description:{table_comment}".format(
                    table_name=commented_table, table_comment=description
                )
            )
            comment_map[commented_table] = description

            # ``ensure_ascii=False`` keeps non-ASCII text readable directly.
            # The previous ``encode().decode("unicode_escape")`` round trip
            # turned non-BMP characters into lone surrogates and also
            # stripped the JSON escaping of quotes and backslashes.
            self.vector_tables_info.append(
                json.dumps(
                    {
                        "table_name": commented_table,
                        "table_description": description,
                    },
                    ensure_ascii=False,
                )
            )

        self.table_columns_info = []
        self.table_columns_json = []

        for table_name in tables:
            table_summary = RdbmsTableSummary(self.db, name, table_name, comment_map)
            columns = table_summary.get_columns()
            self.tables[table_name] = columns
            self.table_columns_info.append(columns)
            # The "description" here is the table's CREATE statement as
            # returned by the connection, not its comment.
            self.table_columns_json.append(
                "table name:{table_name},table description:{table_comment}".format(
                    table_name=table_name,
                    table_comment=self.db.get_show_create_table(table_name),
                )
            )

    def get_summery(self):
        """Return the database summary.

        Returns:
            The list of per-table JSON strings when ``CFG.SUMMARY_CONFIG`` is
            "FAST"; otherwise the formatted summary string built from the
            plain-text table descriptions.
        """
        if CFG.SUMMARY_CONFIG == "FAST":
            return self.vector_tables_info
        # The template needs ``tables``, ``qps`` and ``tps``. The old call
        # passed ``table_info`` only, so this branch always raised KeyError.
        # ``qps``/``tps`` use the same placeholder values as get_db_summery.
        return self.summery.format(
            name=self.name,
            type=self.type,
            tables=";".join(self.tables_info),
            qps=1000,
            tps=1000,
        )

    def get_db_summery(self):
        """Return the formatted summary built from the per-table JSON strings."""
        return self.summery.format(
            name=self.name,
            type=self.type,
            tables=";".join(self.vector_tables_info),
            qps=1000,
            tps=1000,
        )

    def get_table_summary(self):
        """Return the mapping of table name to column summary string."""
        return self.tables

    def get_table_comments(self):
        """Return the table comments exactly as the connection returned them."""
        return self.table_comments

    def table_info_json(self):
        """Return the per-table "name / CREATE statement" profile strings."""
        return self.table_columns_json
|
||||
|
||||
|
||||
class RdbmsTableSummary(TableSummary):
    """Collect column and index summaries for a single table."""

    def __init__(self, instance, dbname, name, comment_map):
        """Read the table's fields and indexes and pre-build its summaries.

        Args:
            instance: connection object providing ``get_fields`` and
                ``get_indexes``.
            dbname: name of the database the table belongs to.
            name: table name.
            comment_map: mapping of table name to table comment; a table
                without an entry is summarised with an empty comment.
        """
        self.name = name
        self.dbname = dbname
        self.summery = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}"""
        self.json_summery_template = """{{"table_name": "{name}", "comment": "{comment}", "columns": "{fields}", "indexes": "{indexes}", "size_in_bytes": {size_in_bytes}, "rows": {rows}}}"""
        self.fields = []
        self.fields_info = []
        self.indexes = []
        self.indexes_info = []
        self.db = instance

        fields = self.db.get_fields(name)
        indexes = self.db.get_indexes(name)

        field_names = []
        for field in fields:
            field_summary = RdbmsFieldsSummary(field)
            self.fields.append(field_summary)
            self.fields_info.append(field_summary.get_summery())
            # The first element of each field row is the column name.
            field_names.append(field[0])

        # Compact form: "table(col1,col2,...)".
        self.column_summery = """{name}({columns_info})""".format(
            name=name, columns_info=",".join(field_names)
        )

        for index in indexes:
            index_summary = RdbmsIndexSummary(index)
            self.indexes.append(index_summary)
            self.indexes_info.append(index_summary.get_summery())

        self.json_summery = self.json_summery_template.format(
            name=name,
            # The table list and the comment list come from separate queries,
            # so a table may have no comment row. Indexing directly raised
            # KeyError and aborted the whole database summary.
            comment=comment_map.get(name, ""),
            fields=self.fields_info,
            indexes=self.indexes_info,
            size_in_bytes=1000,
            rows=1000,
        )

    def get_summery(self):
        """Return the plain-text summary with ";"-joined field and index info."""
        return self.summery.format(
            name=self.name,
            dbname=self.dbname,
            fields=";".join(self.fields_info),
            indexes=";".join(self.indexes_info),
        )

    def get_columns(self):
        """Return the compact "table(col1,col2,...)" string."""
        return self.column_summery

    def get_summary_json(self):
        """Return the JSON-style summary string built in ``__init__``."""
        return self.json_summery
|
||||
|
||||
|
||||
class RdbmsFieldsSummary(FieldSummary):
    """Hold one column's attributes and render them as a JSON-style string."""

    def __init__(self, field):
        """Unpack a field row by position.

        Args:
            field: indexable row read as name, data type, default value,
                nullable flag and comment (positions 0 to 4).
        """
        (
            self.name,
            self.data_type,
            self.default_value,
            self.is_nullable,
            self.comment,
        ) = (field[0], field[1], field[2], field[3], field[4])

    def get_summery(self):
        """Return the column description as a JSON-style string."""
        template = '{{"name": "{name}", "type": "{data_type}", "is_primary_key": "{is_nullable}", "comment": "{comment}", "default": "{default_value}"}}'
        # NOTE(review): the "is_primary_key" key is filled with the nullable
        # flag -- confirm whether the key name or the value is the intended one.
        values = {
            "name": self.name,
            "data_type": self.data_type,
            "is_nullable": self.is_nullable,
            "default_value": self.default_value,
            "comment": self.comment,
        }
        return template.format(**values)
|
||||
|
||||
|
||||
class RdbmsIndexSummary(IndexSummary):
    """Hold one index's name and bound columns and render them as a string."""

    def __init__(self, index):
        """Unpack an index row by position.

        Args:
            index: indexable row; position 0 is the index name and position 1
                the columns it is bound to.
        """
        index_name = index[0]
        bound_columns = index[1]
        self.name = index_name
        self.summery_template = '{{"name": "{name}", "columns": {bind_fields}}}'
        self.bind_fields = bound_columns

    def get_summery(self):
        """Return the index description rendered through the stored template."""
        values = {"name": self.name, "bind_fields": self.bind_fields}
        return self.summery_template.format(**values)
|
||||
|
@@ -77,7 +77,7 @@ def build_logger(logger_name, logger_filename):
|
||||
for name, item in logging.root.manager.loggerDict.items():
|
||||
if isinstance(item, logging.Logger):
|
||||
item.addHandler(handler)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
return logger
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user