diff --git a/pilot/common/schema.py b/pilot/common/schema.py index 86f936c4b..f0334c553 100644 --- a/pilot/common/schema.py +++ b/pilot/common/schema.py @@ -36,7 +36,7 @@ class DBType(Enum): @staticmethod def of_db_type(db_type: str): - for item in DBType.__members__: - if item.value().name == db_type: + for item in DBType: + if item.value() == db_type: return item return None diff --git a/pilot/connections/manages/connect_storage_duckdb.py b/pilot/connections/manages/connect_storage_duckdb.py index e6612751a..7946e7cee 100644 --- a/pilot/connections/manages/connect_storage_duckdb.py +++ b/pilot/connections/manages/connect_storage_duckdb.py @@ -78,9 +78,9 @@ class DuckdbConnectConfig: fields = [field[0] for field in cursor.description] row_dict = {} - for row in cursor.fetchall()[0]: - for i, field in enumerate(fields): - row_dict[field] = row[i] + row_1 = list(cursor.fetchall()[0]) + for i, field in enumerate(fields): + row_dict[field] = row_1[i] return row_dict return {} diff --git a/pilot/connections/manages/connection_manager.py b/pilot/connections/manages/connection_manager.py index e03b0a496..3b7755b75 100644 --- a/pilot/connections/manages/connection_manager.py +++ b/pilot/connections/manages/connection_manager.py @@ -1,8 +1,11 @@ from pilot.configs.config import Config from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig from pilot.common.schema import DBType -from pilot.connections.rdbms.mysql import MySQLConnect +from pilot.connections.rdbms.conn_mysql import MySQLConnect from pilot.connections.base import BaseConnect + +from pilot.connections.rdbms.conn_mysql import MySQLConnect +from pilot.connections.rdbms.conn_duckdb import DuckDbConnect from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase from pilot.singleton import Singleton from pilot.common.sql_database import Database @@ -12,15 +15,22 @@ CFG = Config() class ConnectManager: - def get_instance_by_dbtype(db_type, **kwargs): - chat_classes = BaseConnect.__subclasses__() - implementation = None + + def get_all_subclasses(self, cls): + subclasses = cls.__subclasses__() + for subclass in subclasses: + subclasses += self.get_all_subclasses(subclass) + return subclasses + + def get_cls_by_dbtype(self, db_type): + chat_classes = self.get_all_subclasses(BaseConnect) + result = None for cls in chat_classes: if cls.db_type == db_type: - implementation = cls(**kwargs) - if implementation == None: - raise Exception(f"Invalid db connect implementation!DbType:{db_type}") - return implementation + result = cls + if not result: + raise ValueError("Unsupport Db Type!" + db_type) + return result def __init__(self): self.storage = DuckdbConnectConfig() @@ -69,10 +79,10 @@ class ConnectManager: def get_connect(self, db_name): db_config = self.storage.get_db_config(db_name) db_type = DBType.of_db_type(db_config.get('db_type')) - connect_instance = self.get_instance_by_dbtype(db_type) + connect_instance = self.get_cls_by_dbtype(db_type.value()) if db_type.is_file_db(): db_path = db_config.get('db_path') - return connect_instance.from_file(db_path) + return connect_instance.from_file_path(db_path) else: db_host = db_config.get('db_host') db_port = db_config.get('db_port') diff --git a/pilot/connections/rdbms/conn_duckdb.py b/pilot/connections/rdbms/conn_duckdb.py new file mode 100644 index 000000000..b238d468a --- /dev/null +++ b/pilot/connections/rdbms/conn_duckdb.py @@ -0,0 +1,69 @@ +from typing import Optional, Any, Iterable +from sqlalchemy import ( + MetaData, + Table, + create_engine, + inspect, + select, + text, +) +from sqlalchemy.ext.declarative import declarative_base + +from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase +from pilot.configs.config import Config + +CFG = Config() +Base = declarative_base() + +class DuckDbConnect(RDBMSDatabase): + """Connect Duckdb Database fetch MetaData + Args: + Usage: + """ + db_type: str = "duckdb" + db_dialect: str = "duckdb" + + @classmethod + def from_file_path( + cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any + ) -> RDBMSDatabase: + """Construct a SQLAlchemy engine from URI.""" + _engine_args = engine_args or {} + return cls(create_engine("duckdb:///" + file_path, **_engine_args), **kwargs) + + def table_simple_info(self) -> Iterable[str]: + _tables_sql = f""" + SELECT name FROM sqlite_master WHERE type='table' + """ + cursor = self.session.execute(text(_tables_sql)) + tables_results = cursor.fetchall() + results =[] + for row in tables_results: + table_name = row[0] + _sql = f""" + PRAGMA table_info({table_name}) + """ + cursor_colums = self.session.execute(text(_sql)) + colum_results = cursor_colums.fetchall() + table_colums = [] + for row_col in colum_results: + field_info = list(row_col) + table_colums.append(field_info[1]) + + results.append(f"{table_name}({','.join(table_colums)});") + return results + +if __name__ == "__main__": + engine = create_engine('duckdb:////Users/tuyang.yhj/Code/PycharmProjects/DB-GPT/pilot/mock_datas/db-gpt-test.db') + metadata = MetaData(engine) + + results = engine.connect().execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall() + + print(str(results)) + + fields = [] + results2 = engine.connect().execute(f"""PRAGMA table_info(user)""").fetchall() + for row_col in results2: + field_info = list(row_col) + fields.append(field_info[1]) + print(str(fields)) diff --git a/pilot/connections/rdbms/mysql.py b/pilot/connections/rdbms/conn_mysql.py similarity index 92% rename from pilot/connections/rdbms/mysql.py rename to pilot/connections/rdbms/conn_mysql.py index ecd369c3d..196919dbe 100644 --- a/pilot/connections/rdbms/mysql.py +++ b/pilot/connections/rdbms/conn_mysql.py @@ -13,6 +13,6 @@ class MySQLConnect(RDBMSDatabase): db_type: str = "mysql" db_dialect: str = "mysql" - driver: str = "pymysql" + driver: str = "mysql+pymysql" default_db = ["information_schema", "performance_schema", "sys", "mysql"] diff --git a/pilot/connections/rdbms/duckdb.py b/pilot/connections/rdbms/duckdb.py deleted file mode 100644 index ad4752ef4..000000000 --- a/pilot/connections/rdbms/duckdb.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import Optional, Any, Iterable -from sqlalchemy import ( - MetaData, - Table, - create_engine, - inspect, - select, - text, -) - -from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase -from pilot.configs.config import Config - -CFG = Config() - - -class DuckDbConnect(RDBMSDatabase): - """Connect Duckdb Database fetch MetaData - Args: - Usage: - """ - - def table_simple_info(self) -> Iterable[str]: - return super().get_table_names() - - db_type: str = "duckdb" - - @classmethod - def from_file_path( - cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any - ) -> RDBMSDatabase: - """Construct a SQLAlchemy engine from URI.""" - _engine_args = engine_args or {} - return cls(create_engine("duckdb://" + file_path, **_engine_args), **kwargs) diff --git a/pilot/connections/rdbms/rdbms_connect.py b/pilot/connections/rdbms/rdbms_connect.py index 3164733d2..b9d6ed2df 100644 --- a/pilot/connections/rdbms/rdbms_connect.py +++ b/pilot/connections/rdbms/rdbms_connect.py @@ -36,6 +36,7 @@ def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str: class RDBMSDatabase(BaseConnect): """SQLAlchemy wrapper around a database.""" + db_type: str = None def __init__( self, @@ -69,7 +70,7 @@ class RDBMSDatabase(BaseConnect): **kwargs: Any, ) -> RDBMSDatabase: db_url: str = ( - cls.connect_driver + cls.driver + "://" + CFG.LOCAL_DB_USER + ":" @@ -114,17 +115,6 @@ class RDBMSDatabase(BaseConnect): self._metadata = MetaData() self._metadata.reflect(bind=self._engine) - # including view support by adding the views as well as tables to the all - # tables list if view_support is True - self._all_tables = set( - self._inspector.get_table_names() - + ( - self._inspector.get_view_names() - if self.view_support - else [] - ) - ) - return session def get_current_db_name(self) -> str: diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index f2e5d9462..93f09e32e 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -74,7 +74,7 @@ def get_db_list(): dbs = CFG.LOCAL_DB_MANAGE.get_db_list() params: dict = {} for item in dbs: - params.update({item["db_name"]: item["comment"]}) + params.update({item["db_name"]: item["db_name"]}) return params diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py index 4bb689d3e..6b17d33f2 100644 --- a/pilot/scene/chat_dashboard/chat.py +++ b/pilot/scene/chat_dashboard/chat.py @@ -43,9 +43,10 @@ class ChatDashboard(BaseChat): ) self.db_name = db_name self.report_name = report_name + self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name) - # 准备DB信息(拿到指定库的链接) - self.db_connect = self.database.get_session(self.db_name) + self.db_connect = self.database.session + self.top_k: int = 5 self.dashboard_template = self.__load_dashboard_template(report_name) @@ -64,13 +65,19 @@ class ChatDashboard(BaseChat): from pilot.summary.db_summary_client import DBSummaryClient except ImportError: raise ValueError("Could not import DBSummaryClient. ") + client = DBSummaryClient() + try: + table_infos = client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) + print("dashboard vector find tables:{}", table_infos) + except Exception as e: + print("db summary find error!" + str(e)) + input_values = { "input": self.current_user_input, "dialect": self.database.dialect, - "table_info": self.database.table_simple_info(self.db_connect), + "table_info": self.database.table_simple_info(), "supported_chat_type": self.dashboard_template['supported_chart_type'] - # "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) } return input_values @@ -91,14 +98,16 @@ class ChatDashboard(BaseChat): if not data_map[field_name]: field_map.update({f"{field_name}": False}) else: - field_map.update({f"{field_name}": all(isinstance(item, (int, float, Decimal)) for item in data_map[field_name])}) + field_map.update({f"{field_name}": all( + isinstance(item, (int, float, Decimal)) for item in data_map[field_name])}) for field_name in field_names[1:]: if not field_map[field_name]: print("more than 2 non-numeric column") else: for data in datas: - value_item = ValueItem(name=data[0], type=field_name, value=data[field_names.index(field_name)]) + value_item = ValueItem(name=data[0], type=field_name, + value=data[field_names.index(field_name)]) values.append(value_item) chart_datas.append(ChartData(chart_uid=str(uuid.uuid1()), diff --git a/pilot/scene/chat_dashboard/prompt.py b/pilot/scene/chat_dashboard/prompt.py index b339d6334..22c585cb8 100644 --- a/pilot/scene/chat_dashboard/prompt.py +++ b/pilot/scene/chat_dashboard/prompt.py @@ -23,12 +23,12 @@ According to the characteristics of the analyzed data, choose the most suitable {supported_chat_type} Pay attention to the length of the output content of the analysis result, do not exceed 4000tokens -Do not use unprovided fields and field value in data analysis SQL, Do not use column pay_status as a query condition in SQL. +Do not use unprovided fields and field value in analysis SQL, Do not use column pay_status as a query condition in SQL. According to the characteristics of the analyzed data, choose the best one from the charts provided below to display, use different types of charts as much as possible,chart types: {supported_chat_type} -Give {dialect} data analysis SQL, analysis title, display method and analytical thinking,respond in the following json format: +Give {dialect} data analysis SQL(Do not use data not provided as field value), analysis title, display method and analytical thinking,respond in the following json format: {response} Ensure the response is correct json and can be parsed by Python json.loads diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py index 2477cbde6..588cff855 100644 --- a/pilot/scene/chat_db/auto_execute/chat.py +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -43,12 +43,17 @@ class ChatWithDbAutoExecute(BaseChat): except ImportError: raise ValueError("Could not import DBSummaryClient. ") client = DBSummaryClient() + try: + table_infos = client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) + except Exception as e: + print("db summary find error!" + str(e)) + table_infos = self.database.table_simple_info() + input_values = { "input": self.current_user_input, "top_k": str(self.top_k), "dialect": self.database.dialect, - # "table_info": self.database.table_simple_info(self.db_connect) - "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) + "table_info": table_infos } return input_values diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py index 1424d7cef..a0a571988 100644 --- a/pilot/scene/chat_db/professional_qa/chat.py +++ b/pilot/scene/chat_db/professional_qa/chat.py @@ -47,9 +47,13 @@ class ChatWithDbQA(BaseChat): raise ValueError("Could not import DBSummaryClient. ") if self.db_name: client = DBSummaryClient() - table_info = client.get_db_summary( - dbname=self.db_name, query=self.current_user_input, topk=self.top_k - ) + try: + table_infos = client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, + topk=self.top_k) + except Exception as e: + print("db summary find error!" + str(e)) + table_infos = self.database.table_simple_info() + # table_info = self.database.table_simple_info(self.db_connect) dialect = self.database.dialect @@ -57,6 +61,6 @@ class ChatWithDbQA(BaseChat): "input": self.current_user_input, # "top_k": str(self.top_k), # "dialect": dialect, - "table_info": table_info, + "table_info": table_infos, } return input_values diff --git a/pilot/server/static/404.html b/pilot/server/static/404.html index cc9696c0e..e666deaac 100644 --- a/pilot/server/static/404.html +++ b/pilot/server/static/404.html @@ -1 +1 @@ -