mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-07 19:34:04 +00:00
90 lines
3.0 KiB
Python
90 lines
3.0 KiB
Python
from typing import Dict
|
|
|
|
from dbgpt._private.config import Config
|
|
from dbgpt.agent.util.api_call import ApiCall
|
|
from dbgpt.app.scene import BaseChat, ChatScene
|
|
from dbgpt.util.executor_utils import blocking_func_to_async
|
|
from dbgpt.util.tracer import root_tracer, trace
|
|
|
|
CFG = Config()
|
|
|
|
|
|
class ChatWithDbAutoExecute(BaseChat):
|
|
chat_scene: str = ChatScene.ChatWithDbExecute.value()
|
|
|
|
"""Number of results to return from the query"""
|
|
|
|
def __init__(self, chat_param: Dict):
|
|
"""Chat Data Module Initialization
|
|
Args:
|
|
- chat_param: Dict
|
|
- chat_session_id: (str) chat session_id
|
|
- current_user_input: (str) current user input
|
|
- model_name:(str) llm model name
|
|
- select_param:(str) dbname
|
|
"""
|
|
chat_mode = ChatScene.ChatWithDbExecute
|
|
self.db_name = chat_param["select_param"]
|
|
chat_param["chat_mode"] = chat_mode
|
|
""" """
|
|
super().__init__(
|
|
chat_param=chat_param,
|
|
)
|
|
if not self.db_name:
|
|
raise ValueError(
|
|
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
|
|
)
|
|
with root_tracer.start_span(
|
|
"ChatWithDbAutoExecute.get_connect", metadata={"db_name": self.db_name}
|
|
):
|
|
self.database = CFG.local_db_manager.get_connector(self.db_name)
|
|
|
|
self.top_k: int = 50
|
|
self.api_call = ApiCall()
|
|
|
|
@trace()
|
|
async def generate_input_values(self) -> Dict:
|
|
"""
|
|
generate input values
|
|
"""
|
|
try:
|
|
from dbgpt.rag.summary.db_summary_client import DBSummaryClient
|
|
except ImportError:
|
|
raise ValueError("Could not import DBSummaryClient. ")
|
|
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
|
|
table_infos = None
|
|
try:
|
|
with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"):
|
|
table_infos = await blocking_func_to_async(
|
|
self._executor,
|
|
client.get_db_summary,
|
|
self.db_name,
|
|
self.current_user_input,
|
|
CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
|
|
)
|
|
except Exception as e:
|
|
print("db summary find error!" + str(e))
|
|
if not table_infos:
|
|
table_infos = await blocking_func_to_async(
|
|
self._executor, self.database.table_simple_info
|
|
)
|
|
|
|
input_values = {
|
|
"db_name": self.db_name,
|
|
"user_input": self.current_user_input,
|
|
"top_k": str(self.top_k),
|
|
"dialect": self.database.dialect,
|
|
"table_info": table_infos,
|
|
"display_type": self._generate_numbered_list(),
|
|
}
|
|
return input_values
|
|
|
|
def stream_plugin_call(self, text):
|
|
text = text.replace("\n", " ")
|
|
print(f"stream_plugin_call:{text}")
|
|
return self.api_call.display_sql_llmvis(text, self.database.run_to_df)
|
|
|
|
def do_action(self, prompt_response):
|
|
print(f"do_action:{prompt_response}")
|
|
return self.database.run_to_df
|