DB-GPT/dbgpt/datasource/manages/connector_manager.py
明天 b124ecc10b
feat: (0.6)New UI (#1855)
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com>
Co-authored-by: csunny <cfqsunny@163.com>
2024-08-21 17:37:45 +08:00

243 lines
9.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Connection manager."""
import logging
from typing import TYPE_CHECKING, List, Optional, Type
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.storage.schema import DBType
from dbgpt.util.executor_utils import ExecutorFactory
from ..base import BaseConnector
from ..db_conn_info import DBConfig
from .connect_config_db import ConnectConfigDao
if TYPE_CHECKING:
    # Imported for type annotations only; the real import happens lazily in
    # ``ConnectorManager.before_start``.
    # TODO: Don't depend on the rag module.
    from dbgpt.rag.summary.db_summary_client import DBSummaryClient
logger = logging.getLogger(__name__)
class ConnectorManager(BaseComponent):
    """Connector manager.

    Looks up connector classes among the subclasses of ``BaseConnector``,
    builds connector instances from stored or user supplied connection info,
    and persists connection info through ``ConnectConfigDao``.
    """

    name = ComponentType.CONNECTOR_MANAGER

    def __init__(self, system_app: SystemApp):
        """Create a new ConnectorManager.

        Args:
            system_app (SystemApp): the application this component belongs to.
        """
        self.storage = ConnectConfigDao()
        self.system_app = system_app
        # Stays ``None`` until ``before_start`` creates the client.
        self._db_summary_client: Optional["DBSummaryClient"] = None
        super().__init__(system_app)

    def init_app(self, system_app: SystemApp):
        """Init component.

        Args:
            system_app (SystemApp): the application this component belongs to.
        """
        self.system_app = system_app

    def on_init(self):
        """Execute on init.

        Load all connector classes. The imports below are unused on purpose:
        importing a module defines its connector class, and only defined
        classes show up in ``BaseConnector.__subclasses__()``, which
        ``_get_all_subclasses`` relies on.
        """
        from dbgpt.datasource.conn_spark import SparkConnector  # noqa: F401
        from dbgpt.datasource.conn_tugraph import TuGraphConnector  # noqa: F401
        from dbgpt.datasource.rdbms.base import RDBMSConnector  # noqa: F401
        from dbgpt.datasource.rdbms.conn_clickhouse import (  # noqa: F401
            ClickhouseConnector,
        )
        from dbgpt.datasource.rdbms.conn_doris import DorisConnector  # noqa: F401
        from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnector  # noqa: F401
        from dbgpt.datasource.rdbms.conn_hive import HiveConnector  # noqa: F401
        from dbgpt.datasource.rdbms.conn_mssql import MSSQLConnector  # noqa: F401
        from dbgpt.datasource.rdbms.conn_mysql import MySQLConnector  # noqa: F401
        from dbgpt.datasource.rdbms.conn_oceanbase import OceanBaseConnect  # noqa: F401
        from dbgpt.datasource.rdbms.conn_postgresql import (  # noqa: F401
            PostgreSQLConnector,
        )
        from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnector  # noqa: F401
        from dbgpt.datasource.rdbms.conn_starrocks import (  # noqa: F401
            StarRocksConnector,
        )
        from dbgpt.datasource.rdbms.conn_vertica import VerticaConnector  # noqa: F401
        from dbgpt.datasource.rdbms.dialect.oceanbase.ob_dialect import (  # noqa: F401
            OBDialect,
        )

        from .connect_config_db import ConnectConfigEntity  # noqa: F401

    def before_start(self):
        """Execute before start.

        Creates the ``DBSummaryClient``. The import is local so that this
        module does not import the rag package at import time.
        """
        from dbgpt.rag.summary.db_summary_client import DBSummaryClient

        self._db_summary_client = DBSummaryClient(self.system_app)

    @property
    def db_summary_client(self) -> "DBSummaryClient":
        """Get DBSummaryClient.

        Raises:
            ValueError: if the client has not been created yet (it is created
                in ``before_start``).
        """
        if not self._db_summary_client:
            raise ValueError("DBSummaryClient is not initialized")
        return self._db_summary_client

    def _get_all_subclasses(
        self, cls: Type[BaseConnector]
    ) -> List[Type[BaseConnector]]:
        """Get all direct and indirect subclasses of cls.

        Args:
            cls (Type[BaseConnector]): the root class (not included in the
                result).

        Returns:
            List[Type[BaseConnector]]: direct subclasses first, followed by
                the descendants of each direct subclass in turn. Every class
                appears exactly once.
        """
        direct = cls.__subclasses__()
        collected = list(direct)
        # Walk the snapshot ``direct`` rather than ``collected``: extending the
        # list that is being iterated revisits descendants and yields
        # duplicates for hierarchies three or more levels deep.
        for subclass in direct:
            for descendant in self._get_all_subclasses(subclass):
                if descendant not in collected:
                    collected.append(descendant)
        return collected

    @staticmethod
    def _require_db_type(raw_db_type) -> DBType:
        """Resolve a raw db type through ``DBType.of_db_type`` or fail.

        Args:
            raw_db_type: the db type value as stored or submitted.

        Returns:
            DBType: the resolved type.

        Raises:
            ValueError: if ``DBType.of_db_type`` returns a falsy value.
        """
        db_type = DBType.of_db_type(raw_db_type)
        if not db_type:
            # f-string instead of ``+`` so that ``None`` or another non-string
            # value still produces a ValueError rather than a TypeError.
            raise ValueError(f"Unsupported Db Type{raw_db_type}")
        return db_type

    def _submit_db_summary_embedding(self, db_name, db_type) -> None:
        """Submit ``db_summary_embedding`` for a database to the executor.

        The executor comes from the ``ExecutorFactory`` component registered
        under ``ComponentType.EXECUTOR_DEFAULT``. The call only submits the
        task; it does not wait for the result.
        """
        executor = self.system_app.get_component(
            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
        ).create()  # type: ignore
        executor.submit(self.db_summary_client.db_summary_embedding, db_name, db_type)

    def get_all_completed_types(self) -> List[DBType]:
        """Get all completed types.

        Returns:
            List[DBType]: the ``DBType`` of every connector class that has a
                ``db_type``, reports ``is_normal_type()`` and is known to
                ``DBType.of_db_type``. Each type is listed once, in discovery
                order.
        """
        support_types: List[DBType] = []
        for connector_cls in self._get_all_subclasses(BaseConnector):  # type: ignore
            if not (connector_cls.db_type and connector_cls.is_normal_type()):
                continue
            db_type = DBType.of_db_type(connector_cls.db_type)
            # Skip unknown types and types already reported by another class.
            if db_type and db_type not in support_types:
                support_types.append(db_type)
        return support_types

    def get_cls_by_dbtype(self, db_type) -> Type[BaseConnector]:
        """Get class by db type.

        Args:
            db_type: value compared against each connector class' ``db_type``.

        Returns:
            Type[BaseConnector]: the matching connector class. If several
                classes match, the last one discovered is returned.

        Raises:
            ValueError: if no connector class matches.
        """
        matched = None
        for connector_cls in self._get_all_subclasses(BaseConnector):  # type: ignore
            if connector_cls.db_type == db_type and connector_cls.is_normal_type():
                # No ``break``: the last matching class wins.
                matched = connector_cls
        if matched is None:
            # f-string so a non-string ``db_type`` cannot turn this into a
            # TypeError.
            raise ValueError(f"Unsupported Db Type{db_type}")
        return matched

    def get_connector(self, db_name: str):
        """Create a new connection instance.

        Args:
            db_name (str): database name

        Returns:
            The connector built by ``from_file_path`` (file databases) or
            ``from_uri_db`` (all others) of the matching connector class.

        Raises:
            ValueError: if the stored db type is unsupported.
        """
        db_config = self.storage.get_db_config(db_name)
        db_type = self._require_db_type(db_config.get("db_type"))
        connector_cls = self.get_cls_by_dbtype(db_type.value())
        if db_type.is_file_db():
            return connector_cls.from_file_path(  # type: ignore
                db_config.get("db_path")
            )
        return connector_cls.from_uri_db(  # type: ignore
            host=db_config.get("db_host"),
            port=db_config.get("db_port"),
            user=db_config.get("db_user"),
            pwd=db_config.get("db_pwd"),
            db_name=db_name,
        )

    def test_connect(self, db_info: DBConfig) -> BaseConnector:
        """Test connectivity.

        Args:
            db_info (DBConfig): db connect info.

        Returns:
            BaseConnector: connector instance.

        Raises:
            ValueError: Test connect Failure. The original error is attached
                as the exception cause.
        """
        try:
            db_type = self._require_db_type(db_info.db_type)
            connector_cls = self.get_cls_by_dbtype(db_type.value())
            if db_type.is_file_db():
                return connector_cls.from_file_path(  # type: ignore
                    db_info.file_path
                )
            return connector_cls.from_uri_db(  # type: ignore
                host=db_info.db_host,
                port=db_info.db_port,
                user=db_info.db_user,
                pwd=db_info.db_pwd,
                db_name=db_info.db_name,
            )
        except Exception as e:
            message = f"{db_info.db_name} Test connect Failure!{str(e)}"
            logger.error(message)
            raise ValueError(message) from e

    def get_db_list(
        self, db_name: Optional[str] = None, user_id: Optional[str] = None
    ):
        """Get db list.

        Both arguments are passed straight to ``ConnectConfigDao.get_db_list``.
        """
        return self.storage.get_db_list(db_name, user_id)

    def delete_db(self, db_name: str):
        """Delete db connect info.

        Args:
            db_name (str): database name
        """
        return self.storage.delete_db(db_name)

    def edit_db(self, db_info: DBConfig):
        """Edit db connect info.

        Args:
            db_info (DBConfig): the new connect info; forwarded field by field
                to ``ConnectConfigDao.update_db_info``.
        """
        return self.storage.update_db_info(
            db_info.db_name,
            db_info.db_type,
            db_info.file_path,
            db_info.db_host,
            db_info.db_port,
            db_info.db_user,
            db_info.db_pwd,
            db_info.comment,
        )

    async def async_db_summary_embedding(self, db_name, db_type):
        """Async db summary embedding.

        Submits the embedding task to the default executor and returns
        without waiting for it.

        Returns:
            bool: always True once the task has been submitted.
        """
        self._submit_db_summary_embedding(db_name, db_type)
        return True

    def add_db(self, db_info: DBConfig, user_id: Optional[str] = None):
        """Add db connect info.

        Stores the connect info and then submits the db summary embedding
        task to the default executor.

        Args:
            db_info (DBConfig): db connect info.
            user_id (Optional[str]): passed through to the storage layer.

        Returns:
            bool: True on success.

        Raises:
            ValueError: if the db type is unsupported or any step fails. The
                original error is attached as the exception cause.
        """
        # Never write the clear-text password to the log.
        loggable = dict(db_info.__dict__)
        if loggable.get("db_pwd"):
            loggable["db_pwd"] = "******"
        logger.info(f"add_db:{loggable}")
        try:
            db_type = self._require_db_type(db_info.db_type)
            if db_type.is_file_db():
                self.storage.add_file_db(
                    db_info.db_name,
                    db_info.db_type,
                    db_info.file_path,
                    db_info.comment,
                    user_id,
                )
            else:
                self.storage.add_url_db(
                    db_info.db_name,
                    db_info.db_type,
                    db_info.db_host,
                    db_info.db_port,
                    db_info.db_user,
                    db_info.db_pwd,
                    db_info.comment,
                    user_id,
                )
            # async embedding
            self._submit_db_summary_embedding(db_info.db_name, db_info.db_type)
        except Exception as e:
            raise ValueError("Add db connect info error!" + str(e)) from e
        return True