diff --git a/pilot/connections/manages/connect_storage_duckdb.py b/pilot/connections/manages/connect_storage_duckdb.py index 7946e7cee..3285168d4 100644 --- a/pilot/connections/manages/connect_storage_duckdb.py +++ b/pilot/connections/manages/connect_storage_duckdb.py @@ -46,12 +46,12 @@ class DuckdbConnectConfig: except Exception as e: raise "Unusable duckdb database path:" + path - def add_file_db(self, db_name, db_type, db_path: str): + def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""): try: cursor = self.connect.cursor() cursor.execute( "INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment)VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)", - [db_name, db_type, db_path, "", "", "", "", ""], + [db_name, db_type, db_path, "", "", "", "", comment], ) cursor.commit() self.connect.commit() diff --git a/pilot/connections/manages/connection_manager.py b/pilot/connections/manages/connection_manager.py index 3b7755b75..122db1eb3 100644 --- a/pilot/connections/manages/connection_manager.py +++ b/pilot/connections/manages/connection_manager.py @@ -6,16 +6,17 @@ from pilot.connections.base import BaseConnect from pilot.connections.rdbms.conn_mysql import MySQLConnect from pilot.connections.rdbms.conn_duckdb import DuckDbConnect +from pilot.connections.rdbms.conn_mssql import MSSQLConnect from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase from pilot.singleton import Singleton from pilot.common.sql_database import Database +from pilot.connections.db_conn_info import DBConfig CFG = Config() class ConnectManager: - def get_all_subclasses(self, cls): subclasses = cls.__subclasses__() for subclass in subclasses: @@ -93,5 +94,17 @@ class ConnectManager: def get_db_list(self): return self.storage.get_db_list() + def get_db_names(self): return self.storage.get_db_names() + + def delete_db(self, db_name: str): + return self.storage.delete_db(db_name) + + def add_db(self, db_info: DBConfig): + db_type = DBType.of_db_type(db_info.db_type) + if db_type.is_file_db(): + self.storage.add_file_db(db_info.db_name, db_info.db_type, db_info.file_path) + else: + self.storage.add_url_db(db_info.db_name, db_info.db_type, db_info.db_host, db_info.db_port, db_info.db_user, db_info.db_pwd, db_info.comment) + return True \ No newline at end of file diff --git a/pilot/connections/rdbms/conn_duckdb.py b/pilot/connections/rdbms/conn_duckdb.py index b238d468a..0da04234e 100644 --- a/pilot/connections/rdbms/conn_duckdb.py +++ b/pilot/connections/rdbms/conn_duckdb.py @@ -10,10 +10,6 @@ from sqlalchemy import ( from sqlalchemy.ext.declarative import declarative_base from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase -from pilot.configs.config import Config - -CFG = Config() -Base = declarative_base() class DuckDbConnect(RDBMSDatabase): """Connect Duckdb Database fetch MetaData diff --git a/pilot/connections/rdbms/conn_mssql.py b/pilot/connections/rdbms/conn_mssql.py new file mode 100644 index 000000000..8fb6843f1 --- /dev/null +++ b/pilot/connections/rdbms/conn_mssql.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from typing import Optional, Any, Iterable + +from sqlalchemy import ( + MetaData, + Table, + create_engine, + inspect, + select, + text, +) +from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase + + +class MSSQLConnect(RDBMSDatabase): + """Connect MSSQL Database fetch MetaData + Args: + Usage: + """ + + db_type: str = "mssql" + db_dialect: str = "mssql" + driver: str = "mssql+pymssql" + + default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource", "sys"] + + + def table_simple_info(self) -> Iterable[str]: + _tables_sql = f""" + SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' + """ + cursor = self.session.execute(text(_tables_sql)) + tables_results = cursor.fetchall() + results =[] + for row in tables_results: + table_name = row[0] + _sql = f""" + SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME='{table_name}' + """ + cursor_colums = self.session.execute(text(_sql)) + colum_results = cursor_colums.fetchall() + table_colums = [] + for row_col in colum_results: + field_info = list(row_col) + table_colums.append(field_info[0]) + results.append(f"{table_name}({','.join(table_colums)});") + return results diff --git a/pilot/connections/rdbms/mssql.py b/pilot/connections/rdbms/mssql.py deleted file mode 100644 index f70f859c4..000000000 --- a/pilot/connections/rdbms/mssql.py +++ /dev/null @@ -1,18 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -from typing import Optional, Any - -from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase - - -class MSSQLConnect(RDBMSDatabase): - """Connect MSSQL Database fetch MetaData - Args: - Usage: - """ - - db_type: str = "mssql" - db_dialect: str = "mssql" - driver: str = "pyodbc" - - default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"] diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index f2fff4cef..b2e5a34fb 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -23,8 +23,9 @@ from pilot.openapi.api_v1.api_view_model import ( Result, ConversationVo, MessageVo, - ChatSceneVo, + ChatSceneVo ) +from pilot.connections.db_conn_info import DBConfig from pilot.configs.config import Config from pilot.server.knowledge.service import KnowledgeService from pilot.server.knowledge.request.request import KnowledgeSpaceRequest @@ -95,6 +96,25 @@ def knowledge_list(): return params + + +@router.get("/v1/chat/db/list", response_model=Result[DBConfig]) +async def dialogue_list(): + return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list()) + +@router.post("/v1/chat/db/add", response_model=Result[bool]) +async def dialogue_list(db_config: DBConfig = Body() ): + return Result.succ(CFG.LOCAL_DB_MANAGE.add_db(db_config)) + +@router.post("/v1/chat/db/delete", response_model=Result[bool]) +async def dialogue_list(db_name: str = None): + return Result.succ(CFG.LOCAL_DB_MANAGE.delete_db(db_name)) + + +@router.get("/v1/chat/db/support/type", response_model=Result[str]) +async def db_support_types(): + return Result[str].succ(["mysql", "mssql", "duckdb"]) + @router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo]) async def dialogue_list(user_id: str = None): dialogues: List = [] diff --git a/requirements.txt b/requirements.txt index 1190a755d..961b5c16a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -73,6 +73,7 @@ weaviate-client pymysql duckdb duckdb-engine +pymssql # Testing dependencies pytest