mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-16 23:37:52 +00:00
85 lines
3.0 KiB
Python
85 lines
3.0 KiB
Python
from typing import Dict
|
|
|
|
from dbgpt._private.config import Config
|
|
from dbgpt.app.scene import BaseChat, ChatScene
|
|
from dbgpt.util.executor_utils import blocking_func_to_async
|
|
from dbgpt.util.tracer import trace
|
|
|
|
# Module-level configuration object. Code in this file reads
# `local_db_manager`, `KNOWLEDGE_SEARCH_TOP_SIZE` and `SYSTEM_APP` from it.
CFG = Config()
|
|
|
|
|
|
class ChatWithDbQA(BaseChat):
    """Chat scene that answers questions as a DBA, using the selected
    database's schema metadata as prompt context.
    """

    chat_scene: str = ChatScene.ChatWithDbQA.value()

    # NOTE(review): not read anywhere in this class; presumably consumed by
    # BaseChat to limit how many trailing conversation rounds are kept --
    # confirm against BaseChat.
    keep_end_rounds = 5

    def __init__(self, chat_param: Dict):
        """Chat DB Module Initialization

        Args:
            chat_param (Dict): chat parameters, expected to contain:
                - chat_session_id: (str) chat session_id
                - current_user_input: (str) current user input
                - model_name: (str) llm model name
                - select_param: (str) dbname

        Raises:
            KeyError: if ``chat_param`` has no ``"select_param"`` entry.

        Note:
            ``chat_param`` is modified in place: its ``"chat_mode"`` entry is
            set to ``ChatScene.ChatWithDbQA`` before it is handed to the base
            class. When no database name is given, ``database``, ``tables``
            and ``top_k`` are not assigned by this initializer.
        """
        self.db_name = chat_param["select_param"]
        chat_param["chat_mode"] = ChatScene.ChatWithDbQA
        super().__init__(chat_param=chat_param)

        if self.db_name:
            self.database = CFG.local_db_manager.get_connector(self.db_name)
            self.tables = self.database.get_table_names()
            if self.database.is_graph_type():
                # For a graph database the table listing is keyed by
                # "vertex_tables" / "edge_tables"; top_k is the total number
                # of vertex and edge tables.
                self.top_k = len(self.tables["vertex_tables"]) + len(
                    self.tables["edge_tables"]
                )
            else:
                print(self.database.db_type)
                # Never ask for more summaries than there are tables, and
                # never more than the configured retrieval limit.
                self.top_k = min(CFG.KNOWLEDGE_SEARCH_TOP_SIZE, len(self.tables))

    @trace()
    async def generate_input_values(self) -> Dict:
        """Build the prompt input values for the current user question.

        Returns:
            Dict: ``{"input": <current user input>, "table_info": <schema
            description>}``. ``table_info`` is an empty string when no
            database was selected.

        Raises:
            ValueError: if ``DBSummaryClient`` cannot be imported.
        """
        try:
            from dbgpt.rag.summary.db_summary_client import DBSummaryClient
        except ImportError as e:
            # Keep the ValueError callers may already catch, but preserve the
            # underlying import failure as the cause.
            raise ValueError("Could not import DBSummaryClient. ") from e

        # Start from an empty schema description so the result can still be
        # built when no database is selected (previously this path failed
        # with an unbound local variable).
        table_infos = ""
        if self.db_name:
            client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
            try:
                # The summary lookup is a blocking call, so it is dispatched
                # through blocking_func_to_async with this chat's executor.
                table_infos = await blocking_func_to_async(
                    self._executor,
                    client.get_db_summary,
                    self.db_name,
                    self.current_user_input,
                    self.top_k,
                )
            except Exception as e:
                # Best-effort fallback: if the summary lookup fails for any
                # reason, use the connector's own simple table description.
                print("db summary find error!" + str(e))
                table_infos = await blocking_func_to_async(
                    self._executor, self.database.table_simple_info
                )

        return {
            "input": self.current_user_input,
            "table_info": table_infos,
        }
|