refactor: Refactor datasource module (#1309)

This commit is contained in:
Fangyin Cheng
2024-03-18 18:06:40 +08:00
committed by GitHub
parent 84bedee306
commit 4970c9f813
108 changed files with 1194 additions and 1066 deletions

View File

@@ -70,15 +70,8 @@ def server_init(param: "WebServerParameters", system_app: SystemApp):
def _create_model_start_listener(system_app: SystemApp):
from dbgpt.datasource.manages.connection_manager import ConnectManager
cfg = Config()
def startup_event(wh):
# init connect manage
print("begin run _add_app_startup_event")
conn_manage = ConnectManager(system_app)
cfg.LOCAL_DB_MANAGE = conn_manage
async_db_summary(system_app)
return startup_event

View File

@@ -23,6 +23,7 @@ def initialize_components(
from dbgpt.app.initialization.embedding_component import _initialize_embedding_model
from dbgpt.app.initialization.scheduler import DefaultScheduler
from dbgpt.app.initialization.serve_initialization import register_serve_apps
from dbgpt.datasource.manages.connector_manager import ConnectorManager
from dbgpt.model.cluster.controller.controller import controller
# Register global default executor factory first
@@ -31,6 +32,7 @@ def initialize_components(
)
system_app.register(DefaultScheduler)
system_app.register_instance(controller)
system_app.register(ConnectorManager)
from dbgpt.serve.agent.hub.controller import module_plugin

View File

@@ -28,8 +28,8 @@ from dbgpt.app.knowledge.request.response import (
from dbgpt.app.knowledge.space_db import KnowledgeSpaceDao, KnowledgeSpaceEntity
from dbgpt.component import ComponentType
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.core import Chunk
from dbgpt.model import DefaultLLMClient
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import ChunkStrategy, KnowledgeType

View File

@@ -67,7 +67,7 @@ def __new_conversation(chat_mode, user_name: str, sys_code: str) -> Conversation
def get_db_list():
dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
dbs = CFG.local_db_manager.get_db_list()
db_params = []
for item in dbs:
params: dict = {}
@@ -85,7 +85,7 @@ def plugins_select_info():
def get_db_list_info():
dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
dbs = CFG.local_db_manager.get_db_list()
params: dict = {}
for item in dbs:
comment = item["comment"]
@@ -147,22 +147,22 @@ def get_executor() -> Executor:
@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
async def db_connect_list():
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
return Result.succ(CFG.local_db_manager.get_db_list())
@router.post("/v1/chat/db/add", response_model=Result[bool])
async def db_connect_add(db_config: DBConfig = Body()):
return Result.succ(CFG.LOCAL_DB_MANAGE.add_db(db_config))
return Result.succ(CFG.local_db_manager.add_db(db_config))
@router.post("/v1/chat/db/edit", response_model=Result[bool])
async def db_connect_edit(db_config: DBConfig = Body()):
return Result.succ(CFG.LOCAL_DB_MANAGE.edit_db(db_config))
return Result.succ(CFG.local_db_manager.edit_db(db_config))
@router.post("/v1/chat/db/delete", response_model=Result[bool])
async def db_connect_delete(db_name: str = None):
return Result.succ(CFG.LOCAL_DB_MANAGE.delete_db(db_name))
return Result.succ(CFG.local_db_manager.delete_db(db_name))
async def async_db_summary_embedding(db_name, db_type):
@@ -174,7 +174,7 @@ async def async_db_summary_embedding(db_name, db_type):
async def test_connect(db_config: DBConfig = Body()):
try:
# TODO Change the synchronous call to the asynchronous call
CFG.LOCAL_DB_MANAGE.test_connect(db_config)
CFG.local_db_manager.test_connect(db_config)
return Result.succ(True)
except Exception as e:
return Result.failed(code="E1001", msg=str(e))
@@ -189,7 +189,7 @@ async def db_summary(db_name: str, db_type: str):
@router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo])
async def db_support_types():
support_types = CFG.LOCAL_DB_MANAGE.get_all_completed_types()
support_types = CFG.local_db_manager.get_all_completed_types()
db_type_infos = []
for type in support_types:
db_type_infos.append(

View File

@@ -47,7 +47,7 @@ async def get_editor_tables(
db_name: str, page_index: int, page_size: int, search_str: str = ""
):
logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}")
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
db_conn = CFG.local_db_manager.get_connector(db_name)
tables = db_conn.get_table_names()
db_node: DataNode = DataNode(title=db_name, key=db_name, type="db")
for table in tables:
@@ -95,7 +95,7 @@ async def editor_sql_run(run_param: dict = Body()):
sql = run_param["sql"]
if not db_name and not sql:
return Result.failed(msg="SQL run param error")
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
conn = CFG.local_db_manager.get_connector(db_name)
try:
start_time = time.time() * 1000
@@ -125,7 +125,7 @@ async def sql_editor_submit(
):
logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")
conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)
conn = CFG.local_db_manager.get_connector(sql_edit_context.db_name)
try:
editor_service.sql_editor_submit_and_save(sql_edit_context, conn)
return Result.succ(None)
@@ -168,7 +168,7 @@ async def editor_chart_run(run_param: dict = Body()):
return Result.failed("SQL run param error")
try:
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
db_conn = CFG.local_db_manager.get_connector(db_name)
colunms, sql_result = db_conn.query_ex(sql)
field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(
colunms, sql_result, sql
@@ -204,7 +204,7 @@ async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body())
history_messages: List[Dict] = history_mem.get_messages()
if history_messages:
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(chart_edit_context.db_name)
db_conn = CFG.local_db_manager.get_connector(chart_edit_context.db_name)
edit_round = max(history_messages, key=lambda x: x["chat_order"])
if edit_round:

View File

@@ -22,7 +22,7 @@ from dbgpt.core.interface.message import (
from dbgpt.serve.conversation.serve import Serve as ConversationServe
if TYPE_CHECKING:
from dbgpt.datasource.base import BaseConnect
from dbgpt.datasource.base import BaseConnector
logger = logging.getLogger(__name__)
@@ -86,7 +86,7 @@ class EditorService(BaseComponent):
return None
def sql_editor_submit_and_save(
self, sql_edit_context: ChatSqlEditContext, connection: BaseConnect
self, sql_edit_context: ChatSqlEditContext, connection: BaseConnector
):
storage_conv: StorageConversation = self.get_storage_conv(
sql_edit_context.conv_uid
@@ -169,7 +169,7 @@ class EditorService(BaseComponent):
filter(lambda x: x["chart_name"] == chart_title, charts)
)[0]
conn = cfg.LOCAL_DB_MANAGE.get_connect(db_name)
conn = cfg.local_db_manager.get_connector(db_name)
detail: ChartDetail = ChartDetail(
chart_uid=find_chart["chart_uid"],
chart_type=find_chart["chart_type"],

View File

@@ -38,7 +38,7 @@ class ChatDashboard(BaseChat):
self.db_name = self.db_name
self.report_name = chat_param.get("report_name", "report")
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.database = CFG.local_db_manager.get_connector(self.db_name)
self.top_k: int = 5
self.dashboard_template = self.__load_dashboard_template(self.report_name)

View File

@@ -57,5 +57,5 @@ class DashboardDataLoader:
def get_chart_values_by_db(self, db_name: str, chart_sql: str):
logger.info(f"get_chart_values_by_db:{db_name},{chart_sql}")
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
db_conn = CFG.local_db_manager.get_connector(db_name)
return self.get_chart_values_by_conn(db_conn, chart_sql)

View File

@@ -37,7 +37,7 @@ class ChatWithDbAutoExecute(BaseChat):
with root_tracer.start_span(
"ChatWithDbAutoExecute.get_connect", metadata={"db_name": self.db_name}
):
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.database = CFG.local_db_manager.get_connector(self.db_name)
self.top_k: int = 50
self.api_call = ApiCall(display_registry=CFG.command_display)

View File

@@ -29,7 +29,7 @@ class ChatWithDbQA(BaseChat):
super().__init__(chat_param=chat_param)
if self.db_name:
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.database = CFG.local_db_manager.get_connector(self.db_name)
self.tables = self.database.get_table_names()
self.top_k = (