DB-GPT/dbgpt/app/scene/chat_db/auto_execute/chat.py

90 lines
3.0 KiB
Python

from typing import Dict
from dbgpt._private.config import Config
from dbgpt.agent.util.api_call import ApiCall
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.util.tracer import root_tracer, trace
CFG = Config()
class ChatWithDbAutoExecute(BaseChat):
chat_scene: str = ChatScene.ChatWithDbExecute.value()
"""Number of results to return from the query"""
def __init__(self, chat_param: Dict):
"""Chat Data Module Initialization
Args:
- chat_param: Dict
- chat_session_id: (str) chat session_id
- current_user_input: (str) current user input
- model_name:(str) llm model name
- select_param:(str) dbname
"""
chat_mode = ChatScene.ChatWithDbExecute
self.db_name = chat_param["select_param"]
chat_param["chat_mode"] = chat_mode
""" """
super().__init__(
chat_param=chat_param,
)
if not self.db_name:
raise ValueError(
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
)
with root_tracer.start_span(
"ChatWithDbAutoExecute.get_connect", metadata={"db_name": self.db_name}
):
self.database = CFG.local_db_manager.get_connector(self.db_name)
self.top_k: int = 50
self.api_call = ApiCall()
@trace()
async def generate_input_values(self) -> Dict:
"""
generate input values
"""
try:
from dbgpt.rag.summary.db_summary_client import DBSummaryClient
except ImportError:
raise ValueError("Could not import DBSummaryClient. ")
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
table_infos = None
try:
with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"):
table_infos = await blocking_func_to_async(
self._executor,
client.get_db_summary,
self.db_name,
self.current_user_input,
CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
)
except Exception as e:
print("db summary find error!" + str(e))
if not table_infos:
table_infos = await blocking_func_to_async(
self._executor, self.database.table_simple_info
)
input_values = {
"db_name": self.db_name,
"user_input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": self.database.dialect,
"table_info": table_infos,
"display_type": self._generate_numbered_list(),
}
return input_values
def stream_plugin_call(self, text):
text = text.replace("\n", " ")
print(f"stream_plugin_call:{text}")
return self.api_call.display_sql_llmvis(text, self.database.run_to_df)
def do_action(self, prompt_response):
print(f"do_action:{prompt_response}")
return self.database.run_to_df