diff --git a/pilot/connections/manages/connection_manager.py b/pilot/connections/manages/connection_manager.py index f7184858e..711f1d6db 100644 --- a/pilot/connections/manages/connection_manager.py +++ b/pilot/connections/manages/connection_manager.py @@ -1,4 +1,5 @@ import threading +import asyncio from pilot.configs.config import Config from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig @@ -138,6 +139,27 @@ class ConnectManager: host=db_host, port=db_port, user=db_user, pwd=db_pwd, db_name=db_name ) + def test_connect(self, db_info: DBConfig): + try: + db_type = DBType.of_db_type(db_info.db_type) + connect_instance = self.get_cls_by_dbtype(db_type.value()) + if db_type.is_file_db(): + db_path = db_info.db_path + return connect_instance.from_file_path(db_path) + else: + db_name = db_info.db_name + db_host = db_info.db_host + db_port = db_info.db_port + db_user = db_info.db_user + db_pwd = db_info.db_pwd + return connect_instance.from_uri_db( + host=db_host, port=db_port, user=db_user, pwd=db_pwd, db_name=db_name + ) + except Exception as e: + print(f'{db_info.db_name} Test connect Failure!{str(e)}') + raise ValueError(f'{db_info.db_name} Test connect Failure!{str(e)}') + + def get_db_list(self): return self.storage.get_db_list() @@ -159,6 +181,10 @@ class ConnectManager: db_info.comment, ) + async def async_db_summary_embedding(self, db_name, db_type): + # 在这里执行需要异步运行的代码 + self.db_summary_client.db_summary_embedding(db_name, db_type) + def add_db(self, db_info: DBConfig): print(f"add_db:{db_info.__dict__}") try: diff --git a/pilot/connections/rdbms/base.py b/pilot/connections/rdbms/base.py index 51d70d386..4ed276283 100644 --- a/pilot/connections/rdbms/base.py +++ b/pilot/connections/rdbms/base.py @@ -1,5 +1,5 @@ from __future__ import annotations - +from urllib.parse import quote import warnings import sqlparse import regex as re @@ -95,9 +95,9 @@ class RDBMSDatabase(BaseConnect): db_url: str = ( cls.driver + "://" - + user + + quote(user) + ":" - + pwd + + quote(pwd) + "@" + host + ":" @@ -493,9 +493,13 @@ class RDBMSDatabase(BaseConnect): def get_users(self): """Get user info.""" - cursor = self.session.execute(text(f"SELECT user, host FROM mysql.user")) - users = cursor.fetchall() - return [(user[0], user[1]) for user in users] + try: + cursor = self.session.execute(text(f"SELECT user, host FROM mysql.user")) + users = cursor.fetchall() + return [(user[0], user[1]) for user in users] + except Exception as e: + return [] + def get_table_comments(self, db_name): cursor = self.session.execute( diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 33e226468..f4b555e68 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -38,6 +38,7 @@ from pilot.common.schema import DBType from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.scene.message import OnceConversation from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH +from pilot.summary.db_summary_client import DBSummaryClient router = APIRouter() CFG = Config() @@ -108,9 +109,29 @@ async def db_connect_delete(db_name: str = None): return Result.succ(CFG.LOCAL_DB_MANAGE.delete_db(db_name)) +async def async_db_summary_embedding(db_name, db_type): + # 在这里执行需要异步运行的代码 + db_summary_client = DBSummaryClient() + db_summary_client.db_summary_embedding(db_name, db_type) + + +@router.post("/v1/chat/db/test/connect", response_model=Result[bool]) +async def test_connect(db_config: DBConfig = Body()): + try: + CFG.LOCAL_DB_MANAGE.test_connect(db_config) + return Result.succ(True) + except Exception as e: + return Result.faild(code="E1001", msg=str(e)) + + +@router.post("/v1/chat/db/summary", response_model=Result[bool]) +async def db_summary(db_name: str, db_type: str): + async_db_summary_embedding(db_name, db_type) + return Result.succ(True) + + @router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo]) async def db_support_types(): - support_types = CFG.LOCAL_DB_MANAGE.get_all_completed_types() db_type_infos = [] for type in support_types: @@ -229,7 +250,8 @@ async def dialogue_delete(con_uid: str): history_mem.delete() return Result.succ(None) -def get_hist_messages(conv_uid:str): + +def get_hist_messages(conv_uid: str): message_vos: List[MessageVo] = [] history_mem = DuckdbHistoryMemory(conv_uid) history_messages: List[OnceConversation] = history_mem.get_messages()