mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 05:31:40 +00:00
refactor: Refactor datasource module (#1309)
This commit is contained in:
@@ -67,7 +67,7 @@ def __new_conversation(chat_mode, user_name: str, sys_code: str) -> Conversation
|
||||
|
||||
|
||||
def get_db_list():
|
||||
dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
|
||||
dbs = CFG.local_db_manager.get_db_list()
|
||||
db_params = []
|
||||
for item in dbs:
|
||||
params: dict = {}
|
||||
@@ -85,7 +85,7 @@ def plugins_select_info():
|
||||
|
||||
|
||||
def get_db_list_info():
|
||||
dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
|
||||
dbs = CFG.local_db_manager.get_db_list()
|
||||
params: dict = {}
|
||||
for item in dbs:
|
||||
comment = item["comment"]
|
||||
@@ -147,22 +147,22 @@ def get_executor() -> Executor:
|
||||
|
||||
@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
|
||||
async def db_connect_list():
|
||||
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
|
||||
return Result.succ(CFG.local_db_manager.get_db_list())
|
||||
|
||||
|
||||
@router.post("/v1/chat/db/add", response_model=Result[bool])
|
||||
async def db_connect_add(db_config: DBConfig = Body()):
|
||||
return Result.succ(CFG.LOCAL_DB_MANAGE.add_db(db_config))
|
||||
return Result.succ(CFG.local_db_manager.add_db(db_config))
|
||||
|
||||
|
||||
@router.post("/v1/chat/db/edit", response_model=Result[bool])
|
||||
async def db_connect_edit(db_config: DBConfig = Body()):
|
||||
return Result.succ(CFG.LOCAL_DB_MANAGE.edit_db(db_config))
|
||||
return Result.succ(CFG.local_db_manager.edit_db(db_config))
|
||||
|
||||
|
||||
@router.post("/v1/chat/db/delete", response_model=Result[bool])
|
||||
async def db_connect_delete(db_name: str = None):
|
||||
return Result.succ(CFG.LOCAL_DB_MANAGE.delete_db(db_name))
|
||||
return Result.succ(CFG.local_db_manager.delete_db(db_name))
|
||||
|
||||
|
||||
async def async_db_summary_embedding(db_name, db_type):
|
||||
@@ -174,7 +174,7 @@ async def async_db_summary_embedding(db_name, db_type):
|
||||
async def test_connect(db_config: DBConfig = Body()):
|
||||
try:
|
||||
# TODO Change the synchronous call to the asynchronous call
|
||||
CFG.LOCAL_DB_MANAGE.test_connect(db_config)
|
||||
CFG.local_db_manager.test_connect(db_config)
|
||||
return Result.succ(True)
|
||||
except Exception as e:
|
||||
return Result.failed(code="E1001", msg=str(e))
|
||||
@@ -189,7 +189,7 @@ async def db_summary(db_name: str, db_type: str):
|
||||
|
||||
@router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo])
|
||||
async def db_support_types():
|
||||
support_types = CFG.LOCAL_DB_MANAGE.get_all_completed_types()
|
||||
support_types = CFG.local_db_manager.get_all_completed_types()
|
||||
db_type_infos = []
|
||||
for type in support_types:
|
||||
db_type_infos.append(
|
||||
|
@@ -47,7 +47,7 @@ async def get_editor_tables(
|
||||
db_name: str, page_index: int, page_size: int, search_str: str = ""
|
||||
):
|
||||
logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}")
|
||||
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
db_conn = CFG.local_db_manager.get_connector(db_name)
|
||||
tables = db_conn.get_table_names()
|
||||
db_node: DataNode = DataNode(title=db_name, key=db_name, type="db")
|
||||
for table in tables:
|
||||
@@ -95,7 +95,7 @@ async def editor_sql_run(run_param: dict = Body()):
|
||||
sql = run_param["sql"]
|
||||
if not db_name and not sql:
|
||||
return Result.failed(msg="SQL run param error!")
|
||||
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
conn = CFG.local_db_manager.get_connector(db_name)
|
||||
|
||||
try:
|
||||
start_time = time.time() * 1000
|
||||
@@ -125,7 +125,7 @@ async def sql_editor_submit(
|
||||
):
|
||||
logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")
|
||||
|
||||
conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)
|
||||
conn = CFG.local_db_manager.get_connector(sql_edit_context.db_name)
|
||||
try:
|
||||
editor_service.sql_editor_submit_and_save(sql_edit_context, conn)
|
||||
return Result.succ(None)
|
||||
@@ -168,7 +168,7 @@ async def editor_chart_run(run_param: dict = Body()):
|
||||
return Result.failed("SQL run param error!")
|
||||
try:
|
||||
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
|
||||
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
db_conn = CFG.local_db_manager.get_connector(db_name)
|
||||
colunms, sql_result = db_conn.query_ex(sql)
|
||||
field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(
|
||||
colunms, sql_result, sql
|
||||
@@ -204,7 +204,7 @@ async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body())
|
||||
history_messages: List[Dict] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
|
||||
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(chart_edit_context.db_name)
|
||||
db_conn = CFG.local_db_manager.get_connector(chart_edit_context.db_name)
|
||||
|
||||
edit_round = max(history_messages, key=lambda x: x["chat_order"])
|
||||
if edit_round:
|
||||
|
@@ -22,7 +22,7 @@ from dbgpt.core.interface.message import (
|
||||
from dbgpt.serve.conversation.serve import Serve as ConversationServe
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dbgpt.datasource.base import BaseConnect
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -86,7 +86,7 @@ class EditorService(BaseComponent):
|
||||
return None
|
||||
|
||||
def sql_editor_submit_and_save(
|
||||
self, sql_edit_context: ChatSqlEditContext, connection: BaseConnect
|
||||
self, sql_edit_context: ChatSqlEditContext, connection: BaseConnector
|
||||
):
|
||||
storage_conv: StorageConversation = self.get_storage_conv(
|
||||
sql_edit_context.conv_uid
|
||||
@@ -169,7 +169,7 @@ class EditorService(BaseComponent):
|
||||
filter(lambda x: x["chart_name"] == chart_title, charts)
|
||||
)[0]
|
||||
|
||||
conn = cfg.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
conn = cfg.local_db_manager.get_connector(db_name)
|
||||
detail: ChartDetail = ChartDetail(
|
||||
chart_uid=find_chart["chart_uid"],
|
||||
chart_type=find_chart["chart_type"],
|
||||
|
Reference in New Issue
Block a user