diff --git a/packages/dbgpt-app/src/dbgpt_app/openapi/api_v2.py b/packages/dbgpt-app/src/dbgpt_app/openapi/api_v2.py index ec8865eb6..e762f1cc3 100644 --- a/packages/dbgpt-app/src/dbgpt_app/openapi/api_v2.py +++ b/packages/dbgpt-app/src/dbgpt_app/openapi/api_v2.py @@ -128,6 +128,7 @@ async def chat_completions( or request.chat_mode == ChatMode.CHAT_KNOWLEDGE.value or request.chat_mode == ChatMode.CHAT_DATA.value or request.chat_mode == ChatMode.CHAT_DB_QA.value + or request.chat_mode == ChatMode.CHAT_DASHBOARD.value ): with root_tracer.start_span( "get_chat_instance", @@ -157,7 +158,7 @@ async def chat_completions( detail={ "error": { "message": "chat mode now only support chat_normal, chat_app, " - "chat_flow, chat_knowledge, chat_data", + "chat_flow, chat_knowledge, chat_data, chat_dashboard", "type": "invalid_request_error", "param": None, "code": "invalid_chat_mode", diff --git a/packages/dbgpt-client/src/dbgpt_client/schema.py b/packages/dbgpt-client/src/dbgpt_client/schema.py index ed9d86f3e..6c0d0ac18 100644 --- a/packages/dbgpt-client/src/dbgpt_client/schema.py +++ b/packages/dbgpt-client/src/dbgpt_client/schema.py @@ -94,6 +94,7 @@ class ChatMode(Enum): CHAT_KNOWLEDGE = "chat_knowledge" CHAT_DATA = "chat_data" CHAT_DB_QA = "chat_with_db_qa" + CHAT_DASHBOARD = "chat_dashboard" class AWELTeamModel(BaseModel): diff --git a/packages/dbgpt-serve/src/dbgpt_serve/datasource/manages/connect_config_db.py b/packages/dbgpt-serve/src/dbgpt_serve/datasource/manages/connect_config_db.py index 4e8082b3a..8fe5eddec 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/datasource/manages/connect_config_db.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/datasource/manages/connect_config_db.py @@ -202,7 +202,10 @@ class ConnectConfigDao(BaseDao): def get_db_config(self, db_name): """Return db connect info by name.""" session = self.get_raw_session() - if db_name: + try: + if not db_name: + raise ValueError("Database name cannot be empty") + select_statement = text( """ SELECT @@ -216,16 +219,16 @@ class ConnectConfigDao(BaseDao): params = {"db_name": db_name} result = session.execute(select_statement, params) - else: - raise ValueError("Cannot get database by name" + db_name) + fields = [field[0] for field in result.cursor.description] - logger.info(f"Result: {result}") - fields = [field[0] for field in result.cursor.description] - row_dict = {} - row_1 = list(result.cursor.fetchall()[0]) - for i, field in enumerate(fields): - row_dict[field] = row_1[i] - return row_dict + row = result.cursor.fetchone() + if not row: + logger.error(f"No database config found for db_name: {db_name}") + raise ValueError(f"Database config not found for: {db_name}") + + return {fields[i]: row[i] for i in range(len(fields))} + finally: + session.close() def get_db_list(self, db_name: Optional[str] = None, user_id: Optional[str] = None): """Get db list."""