Fix ChatWithDbQA param validate (#2569)

Co-authored-by: alan.cl <alan.cl@antgroup.com>
This commit is contained in:
alanchen 2025-04-03 09:49:48 +08:00 committed by GitHub
parent e7b23f6425
commit 5dbfb24a86
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 10 additions and 4 deletions

View File

@ -76,7 +76,7 @@ async def chat_completions(
"""Chat V2 completions """Chat V2 completions
Args: Args:
request (ChatCompletionRequestBody): The chat request. request (ChatCompletionRequestBody): The chat request.
flow_service (FlowService): The flow service. service (FlowService): The flow service.
Raises: Raises:
HTTPException: If the request is invalid. HTTPException: If the request is invalid.
""" """

View File

@ -28,19 +28,25 @@ class ChatWithDbQA(BaseChat):
- select_param:(str) dbname - select_param:(str) dbname
""" """
self.db_name = chat_param.select_param self.db_name = chat_param.select_param
self.database = None
self.curr_config = chat_param.real_app_config(ChatWithDBQAConfig) self.curr_config = chat_param.real_app_config(ChatWithDBQAConfig)
super().__init__(chat_param=chat_param, system_app=system_app) super().__init__(chat_param=chat_param, system_app=system_app)
if self.db_name is None:
raise Exception(f"Database: {self.db_name} not found")
if self.db_name: if self.db_name:
local_db_manager = ConnectorManager.get_instance(self.system_app) local_db_manager = ConnectorManager.get_instance(self.system_app)
self.database = local_db_manager.get_connector(self.db_name) self.database = local_db_manager.get_connector(self.db_name)
self.tables = self.database.get_table_names() self.tables = self.database.get_table_names()
if self.database.is_graph_type(): if self.database is not None and self.database.is_graph_type():
# When the current graph database retrieves source data from ChatDB, the # When the current graph database retrieves source data from ChatDB, the
# topk uses the sum of node table and edge table. # topk uses the sum of node table and edge table.
self.top_k = len(list(self.tables)) self.top_k = len(list(self.tables))
else: else:
logger.info(f"Dialect: {self.database.db_type}") logger.info(
"Dialect: "
f"{self.database.db_type if self.database is not None else None}"
)
self.top_k = self.curr_config.schema_retrieve_top_k self.top_k = self.curr_config.schema_retrieve_top_k
@trace() @trace()

View File

@ -1391,7 +1391,7 @@ def adapt_native_app_model(dialogue: ConversationVo):
ChatScene.ChatWithDbQA.value(), ChatScene.ChatWithDbQA.value(),
ChatScene.ChatWithDbExecute.value(), ChatScene.ChatWithDbExecute.value(),
ChatScene.ChatDashboard.value(), ChatScene.ChatDashboard.value(),
ChatScene.ChatNormal.value, ChatScene.ChatNormal.value(),
]: ]:
return dialogue return dialogue
gpts_dao = GptsAppDao() gpts_dao = GptsAppDao()