DB-GPT/dbgpt/app/scene/chat_db/professional_qa/chat.py
lipengfei a5666b3120
feat: Added support for TuGraph graph database (#1451)
Co-authored-by: aries_ckt <916701291@qq.com>
2024-04-26 09:48:40 +08:00

85 lines
3.0 KiB
Python

import logging
from typing import Dict

from dbgpt._private.config import Config
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.util.tracer import trace
# Module-level configuration object. ChatWithDbQA reads `local_db_manager`,
# `KNOWLEDGE_SEARCH_TOP_SIZE` and `SYSTEM_APP` from it.
CFG = Config()
class ChatWithDbQA(BaseChat):
    """As a DBA, Chat DB Module, chat with combine DB meta schema.

    The scene answers questions about a selected database by feeding the
    model a summary of the database's tables together with the user input.
    """

    chat_scene: str = ChatScene.ChatWithDbQA.value()

    keep_end_rounds = 5

    def __init__(self, chat_param: Dict):
        """Chat DB Module Initialization.

        Args:
            chat_param: Dict with the following entries:
                - chat_session_id: (str) chat session_id
                - current_user_input: (str) current user input
                - model_name: (str) llm model name
                - select_param: (str) dbname

        Raises:
            KeyError: if ``chat_param`` has no ``"select_param"`` entry.

        Note:
            ``database``, ``tables`` and ``top_k`` are only set when a
            database name was supplied in ``select_param``.
        """
        self.db_name = chat_param["select_param"]
        chat_param["chat_mode"] = ChatScene.ChatWithDbQA
        super().__init__(chat_param=chat_param)

        # Without a selected database there is no connector to inspect, so
        # every connector-dependent attribute is skipped.
        if not self.db_name:
            return

        self.database = CFG.local_db_manager.get_connector(self.db_name)
        self.tables = self.database.get_table_names()
        if self.database.is_graph_type():
            # When the current graph database retrieves source data from
            # ChatDB, the topk uses the sum of node table and edge table.
            self.top_k = len(self.tables["vertex_tables"]) + len(
                self.tables["edge_tables"]
            )
        else:
            logging.getLogger(__name__).debug(
                "ChatWithDbQA database type: %s", self.database.db_type
            )
            # Never ask for more summaries than there are tables, and never
            # more than the configured search size.
            self.top_k = min(len(self.tables), CFG.KNOWLEDGE_SEARCH_TOP_SIZE)

    @trace()
    async def generate_input_values(self) -> Dict:
        """Build the prompt input values for the current user question.

        When a database is selected, a summary of its tables is requested
        from ``DBSummaryClient``; if that fails for any reason, the
        connector's ``table_simple_info()`` is used instead. Both calls are
        dispatched through ``blocking_func_to_async`` on ``self._executor``
        (presumably so the blocking work stays off the event loop).

        Returns:
            Dict with ``"input"`` (the current user input) and
            ``"table_info"`` (the table summary, or ``""`` when no database
            is selected).

        Raises:
            ValueError: if ``DBSummaryClient`` cannot be imported.
        """
        try:
            from dbgpt.rag.summary.db_summary_client import DBSummaryClient
        except ImportError as exc:
            # Keep the original ImportError attached as the cause.
            raise ValueError("Could not import DBSummaryClient. ") from exc

        # Default used when no database is selected; previously the name was
        # left unbound in that case and the return statement crashed.
        table_infos = ""
        if self.db_name:
            client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
            try:
                table_infos = await blocking_func_to_async(
                    self._executor,
                    client.get_db_summary,
                    self.db_name,
                    self.current_user_input,
                    self.top_k,
                )
            except Exception as e:
                # Deliberate best-effort: any summary failure falls back to
                # the connector's own table description.
                logging.getLogger(__name__).warning(
                    "db summary find error!%s", e
                )
                table_infos = await blocking_func_to_async(
                    self._executor, self.database.table_simple_info
                )

        return {
            "input": self.current_user_input,
            "table_info": table_infos,
        }