diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py index faffcc146..cb2425ea9 100644 --- a/pilot/scene/chat_db/professional_qa/chat.py +++ b/pilot/scene/chat_db/professional_qa/chat.py @@ -35,7 +35,13 @@ class ChatWithDbQA(BaseChat): self.database = CFG.local_db # 准备DB信息(拿到指定库的链接) self.db_connect = self.database.get_session(self.db_name) - self.top_k: int = 5 + self.tables = self.database.get_table_names() + + self.top_k = ( + CFG.KNOWLEDGE_SEARCH_TOP_SIZE + if len(self.tables) > CFG.KNOWLEDGE_SEARCH_TOP_SIZE + else len(self.tables) + ) def generate_input_values(self): table_info = ""