From 4970c9f813600ff79b99c3f18fd3135d9ad1c6f0 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Mon, 18 Mar 2024 18:06:40 +0800 Subject: [PATCH] refactor: Refactor datasource module (#1309) --- .mypy.ini | 19 +- Makefile | 10 +- dbgpt/_private/config.py | 11 +- dbgpt/app/base.py | 7 - dbgpt/app/component_configs.py | 2 + dbgpt/app/knowledge/service.py | 2 +- dbgpt/app/openapi/api_v1/api_v1.py | 16 +- .../openapi/api_v1/editor/api_editor_v1.py | 10 +- dbgpt/app/openapi/api_v1/editor/service.py | 6 +- dbgpt/app/scene/chat_dashboard/chat.py | 2 +- dbgpt/app/scene/chat_dashboard/data_loader.py | 2 +- dbgpt/app/scene/chat_db/auto_execute/chat.py | 2 +- .../app/scene/chat_db/professional_qa/chat.py | 2 +- dbgpt/component.py | 1 + dbgpt/core/__init__.py | 3 + .../chunk.py => core/interface/knowledge.py} | 0 dbgpt/datasource/__init__.py | 7 +- dbgpt/datasource/base.py | 193 ++++++++++----- dbgpt/datasource/conn_spark.py | 85 ++++--- dbgpt/datasource/db_conn_info.py | 27 ++- dbgpt/datasource/manages/__init__.py | 5 + dbgpt/datasource/manages/connect_config_db.py | 89 +++---- .../manages/connect_storage_duckdb.py | 151 ------------ .../datasource/manages/connection_manager.py | 152 ------------ dbgpt/datasource/manages/connector_manager.py | 228 ++++++++++++++++++ dbgpt/datasource/nosql/__init__.py | 1 + dbgpt/datasource/operators/__init__.py | 1 + .../operators/datasource_operator.py | 17 +- dbgpt/datasource/rdbms/__init__.py | 1 + dbgpt/datasource/rdbms/_base_dao.py | 86 ------- dbgpt/datasource/rdbms/base.py | 216 ++++++++++------- dbgpt/datasource/rdbms/conn_clickhouse.py | 101 +++++--- dbgpt/datasource/rdbms/conn_doris.py | 56 +++-- dbgpt/datasource/rdbms/conn_duckdb.py | 26 +- dbgpt/datasource/rdbms/conn_hive.py | 34 ++- dbgpt/datasource/rdbms/conn_mssql.py | 25 +- dbgpt/datasource/rdbms/conn_mysql.py | 13 +- dbgpt/datasource/rdbms/conn_postgresql.py | 85 ++++--- dbgpt/datasource/rdbms/conn_sqlite.py | 64 ++--- dbgpt/datasource/rdbms/conn_starrocks.py | 55 +++-- 
dbgpt/datasource/rdbms/dialect/__init__.py | 1 + .../rdbms/dialect/starrocks/__init__.py | 2 +- .../dialect/starrocks/sqlalchemy/__init__.py | 2 +- .../dialect/starrocks/sqlalchemy/datatype.py | 30 ++- .../dialect/starrocks/sqlalchemy/dialect.py | 109 ++++++--- .../rdbms/tests/test_conn_duckdb.py | 4 +- .../rdbms/tests/test_conn_sqlite.py | 14 +- dbgpt/datasource/redis.py | 8 +- dbgpt/rag/chunk_manager.py | 2 +- dbgpt/rag/extractor/base.py | 3 +- dbgpt/rag/extractor/summary.py | 3 +- .../extractor/tests/test_summary_extractor.py | 2 +- dbgpt/rag/knowledge/base.py | 2 +- dbgpt/rag/knowledge/csv.py | 2 +- dbgpt/rag/knowledge/docx.py | 2 +- dbgpt/rag/knowledge/html.py | 2 +- dbgpt/rag/knowledge/markdown.py | 2 +- dbgpt/rag/knowledge/pdf.py | 2 +- dbgpt/rag/knowledge/pptx.py | 2 +- dbgpt/rag/knowledge/string.py | 2 +- dbgpt/rag/knowledge/txt.py | 2 +- dbgpt/rag/knowledge/url.py | 2 +- dbgpt/rag/operators/datasource.py | 4 +- dbgpt/rag/operators/db_schema.py | 6 +- dbgpt/rag/operators/embedding.py | 2 +- dbgpt/rag/operators/evaluation.py | 3 +- dbgpt/rag/operators/rerank.py | 2 +- dbgpt/rag/operators/schema_linking.py | 6 +- dbgpt/rag/retriever/base.py | 2 +- dbgpt/rag/retriever/db_schema.py | 16 +- dbgpt/rag/retriever/embedding.py | 2 +- dbgpt/rag/retriever/rerank.py | 2 +- dbgpt/rag/retriever/tests/test_db_struct.py | 4 +- dbgpt/rag/retriever/tests/test_embedding.py | 2 +- dbgpt/rag/schemalinker/schema_linking.py | 15 +- dbgpt/rag/summary/db_summary_client.py | 2 +- dbgpt/rag/summary/rdbms_db_summary.py | 26 +- dbgpt/rag/summary/tests/test_rdbms_summary.py | 18 +- dbgpt/rag/text_splitter/pre_text_splitter.py | 2 +- .../rag/text_splitter/tests/test_splitters.py | 2 +- dbgpt/rag/text_splitter/text_splitter.py | 2 +- dbgpt/serve/agent/app/controller.py | 2 +- .../resource_loader/datasource_load_client.py | 10 +- .../knowledge_space_load_client.py | 12 +- dbgpt/serve/rag/assembler/base.py | 4 +- dbgpt/serve/rag/assembler/db_schema.py | 17 +- 
dbgpt/serve/rag/assembler/embedding.py | 2 +- dbgpt/serve/rag/assembler/summary.py | 3 +- .../tests/test_db_struct_assembler.py | 4 +- .../tests/test_embedding_assembler.py | 4 +- dbgpt/serve/rag/operators/db_schema.py | 6 +- dbgpt/serve/rag/retriever/knowledge_space.py | 2 +- dbgpt/storage/vector_store/base.py | 3 +- dbgpt/storage/vector_store/chroma_store.py | 2 +- dbgpt/storage/vector_store/connector.py | 2 +- dbgpt/storage/vector_store/milvus_store.py | 3 +- dbgpt/storage/vector_store/pgvector_store.py | 2 +- dbgpt/storage/vector_store/weaviate_store.py | 2 +- .../simple_nl_schema_sql_chart_example.py | 12 +- examples/rag/db_schema_rag_example.py | 4 +- .../rag/simple_dbschema_retriever_example.py | 6 +- examples/rag/simple_rag_retriever_example.py | 2 +- examples/sdk/simple_sdk_llm_sql_example.py | 4 +- requirements/lint-requirements.txt | 3 +- .../datasource/test_conn_clickhouse.py | 4 +- .../datasource/test_conn_doris.py | 4 +- .../datasource/test_conn_mysql.py | 6 +- .../datasource/test_conn_starrocks.py | 4 +- 108 files changed, 1194 insertions(+), 1066 deletions(-) rename dbgpt/{rag/chunk.py => core/interface/knowledge.py} (100%) delete mode 100644 dbgpt/datasource/manages/connect_storage_duckdb.py delete mode 100644 dbgpt/datasource/manages/connection_manager.py create mode 100644 dbgpt/datasource/manages/connector_manager.py delete mode 100644 dbgpt/datasource/rdbms/_base_dao.py diff --git a/.mypy.ini b/.mypy.ini index 7b927c52c..bba90802a 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -5,12 +5,6 @@ exclude = /tests/ [mypy-dbgpt.app.*] follow_imports = skip -[mypy-dbgpt.datasource.*] -follow_imports = skip - -# [mypy-dbgpt.storage.*] -# follow_imports = skip - [mypy-dbgpt.serve.*] follow_imports = skip @@ -74,3 +68,16 @@ ignore_missing_imports = True [mypy-cryptography.*] ignore_missing_imports = True + +# Datasource +[mypy-pyspark.*] +ignore_missing_imports = True + +[mypy-regex.*] +ignore_missing_imports = True + +[mypy-sqlparse.*] +ignore_missing_imports 
= True + +[mypy-clickhouse_connect.*] +ignore_missing_imports = True \ No newline at end of file diff --git a/Makefile b/Makefile index 298133819..72cc22c01 100644 --- a/Makefile +++ b/Makefile @@ -48,9 +48,7 @@ fmt: setup ## Format Python code $(VENV_BIN)/blackdoc examples # TODO: Use flake8 to enforce Python style guide. # https://flake8.pycqa.org/en/latest/ - $(VENV_BIN)/flake8 dbgpt/core/ - $(VENV_BIN)/flake8 dbgpt/rag/ - $(VENV_BIN)/flake8 dbgpt/storage/ + $(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/ # TODO: More package checks with flake8. .PHONY: fmt-check @@ -59,9 +57,7 @@ fmt-check: setup ## Check Python code formatting and style without making change $(VENV_BIN)/isort --check-only --extend-skip="examples/notebook" examples $(VENV_BIN)/black --check --extend-exclude="examples/notebook" . $(VENV_BIN)/blackdoc --check dbgpt examples - $(VENV_BIN)/flake8 dbgpt/core/ - $(VENV_BIN)/flake8 dbgpt/rag/ - $(VENV_BIN)/flake8 dbgpt/storage/ + $(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/ .PHONY: pre-commit pre-commit: fmt-check test test-doc mypy ## Run formatting and unit tests before committing @@ -77,7 +73,7 @@ test-doc: $(VENV)/.testenv ## Run doctests .PHONY: mypy mypy: $(VENV)/.testenv ## Run mypy checks # https://github.com/python/mypy - $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/rag/ + $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/rag/ dbgpt/datasource/ # rag depends on core and storage, so we not need to check it again. 
# $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/storage/ # $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/core/ diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index da09c7d6e..e573da7ad 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from auto_gpt_plugin_template import AutoGPTPluginTemplate from dbgpt.component import SystemApp + from dbgpt.datasource.manages import ConnectorManager class Config(metaclass=Singleton): @@ -185,8 +186,6 @@ class Config(metaclass=Singleton): os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True").lower() == "true" ) - self.LOCAL_DB_MANAGE = None - ###dbgpt meta info database connection configuration self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST") self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "data/default_sqlite.db") @@ -287,3 +286,11 @@ class Config(metaclass=Singleton): self.MODEL_CACHE_STORAGE_DISK_DIR: Optional[str] = os.getenv( "MODEL_CACHE_STORAGE_DISK_DIR" ) + + @property + def local_db_manager(self) -> "ConnectorManager": + from dbgpt.datasource.manages import ConnectorManager + + if not self.SYSTEM_APP: + raise ValueError("SYSTEM_APP is not set") + return ConnectorManager.get_instance(self.SYSTEM_APP) diff --git a/dbgpt/app/base.py b/dbgpt/app/base.py index d92dda5ad..6aaa491d2 100644 --- a/dbgpt/app/base.py +++ b/dbgpt/app/base.py @@ -70,15 +70,8 @@ def server_init(param: "WebServerParameters", system_app: SystemApp): def _create_model_start_listener(system_app: SystemApp): - from dbgpt.datasource.manages.connection_manager import ConnectManager - - cfg = Config() - def startup_event(wh): - # init connect manage print("begin run _add_app_startup_event") - conn_manage = ConnectManager(system_app) - cfg.LOCAL_DB_MANAGE = conn_manage async_db_summary(system_app) return startup_event diff --git a/dbgpt/app/component_configs.py b/dbgpt/app/component_configs.py index 7ee28be71..6078a2a12 100644 --- a/dbgpt/app/component_configs.py +++ 
b/dbgpt/app/component_configs.py @@ -23,6 +23,7 @@ def initialize_components( from dbgpt.app.initialization.embedding_component import _initialize_embedding_model from dbgpt.app.initialization.scheduler import DefaultScheduler from dbgpt.app.initialization.serve_initialization import register_serve_apps + from dbgpt.datasource.manages.connector_manager import ConnectorManager from dbgpt.model.cluster.controller.controller import controller # Register global default executor factory first @@ -31,6 +32,7 @@ def initialize_components( ) system_app.register(DefaultScheduler) system_app.register_instance(controller) + system_app.register(ConnectorManager) from dbgpt.serve.agent.hub.controller import module_plugin diff --git a/dbgpt/app/knowledge/service.py b/dbgpt/app/knowledge/service.py index 7f42f6c05..c6bf61fc8 100644 --- a/dbgpt/app/knowledge/service.py +++ b/dbgpt/app/knowledge/service.py @@ -28,8 +28,8 @@ from dbgpt.app.knowledge.request.response import ( from dbgpt.app.knowledge.space_db import KnowledgeSpaceDao, KnowledgeSpaceEntity from dbgpt.component import ComponentType from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG +from dbgpt.core import Chunk from dbgpt.model import DefaultLLMClient -from dbgpt.rag.chunk import Chunk from dbgpt.rag.chunk_manager import ChunkParameters from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.knowledge.base import ChunkStrategy, KnowledgeType diff --git a/dbgpt/app/openapi/api_v1/api_v1.py b/dbgpt/app/openapi/api_v1/api_v1.py index 0762feb17..3441ab8e0 100644 --- a/dbgpt/app/openapi/api_v1/api_v1.py +++ b/dbgpt/app/openapi/api_v1/api_v1.py @@ -67,7 +67,7 @@ def __new_conversation(chat_mode, user_name: str, sys_code: str) -> Conversation def get_db_list(): - dbs = CFG.LOCAL_DB_MANAGE.get_db_list() + dbs = CFG.local_db_manager.get_db_list() db_params = [] for item in dbs: params: dict = {} @@ -85,7 +85,7 @@ def plugins_select_info(): def get_db_list_info(): - dbs = 
CFG.LOCAL_DB_MANAGE.get_db_list() + dbs = CFG.local_db_manager.get_db_list() params: dict = {} for item in dbs: comment = item["comment"] @@ -147,22 +147,22 @@ def get_executor() -> Executor: @router.get("/v1/chat/db/list", response_model=Result[DBConfig]) async def db_connect_list(): - return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list()) + return Result.succ(CFG.local_db_manager.get_db_list()) @router.post("/v1/chat/db/add", response_model=Result[bool]) async def db_connect_add(db_config: DBConfig = Body()): - return Result.succ(CFG.LOCAL_DB_MANAGE.add_db(db_config)) + return Result.succ(CFG.local_db_manager.add_db(db_config)) @router.post("/v1/chat/db/edit", response_model=Result[bool]) async def db_connect_edit(db_config: DBConfig = Body()): - return Result.succ(CFG.LOCAL_DB_MANAGE.edit_db(db_config)) + return Result.succ(CFG.local_db_manager.edit_db(db_config)) @router.post("/v1/chat/db/delete", response_model=Result[bool]) async def db_connect_delete(db_name: str = None): - return Result.succ(CFG.LOCAL_DB_MANAGE.delete_db(db_name)) + return Result.succ(CFG.local_db_manager.delete_db(db_name)) async def async_db_summary_embedding(db_name, db_type): @@ -174,7 +174,7 @@ async def async_db_summary_embedding(db_name, db_type): async def test_connect(db_config: DBConfig = Body()): try: # TODO Change the synchronous call to the asynchronous call - CFG.LOCAL_DB_MANAGE.test_connect(db_config) + CFG.local_db_manager.test_connect(db_config) return Result.succ(True) except Exception as e: return Result.failed(code="E1001", msg=str(e)) @@ -189,7 +189,7 @@ async def db_summary(db_name: str, db_type: str): @router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo]) async def db_support_types(): - support_types = CFG.LOCAL_DB_MANAGE.get_all_completed_types() + support_types = CFG.local_db_manager.get_all_completed_types() db_type_infos = [] for type in support_types: db_type_infos.append( diff --git a/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py 
b/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py index 5421dd4e0..ad95a59bf 100644 --- a/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py +++ b/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py @@ -47,7 +47,7 @@ async def get_editor_tables( db_name: str, page_index: int, page_size: int, search_str: str = "" ): logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}") - db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) + db_conn = CFG.local_db_manager.get_connector(db_name) tables = db_conn.get_table_names() db_node: DataNode = DataNode(title=db_name, key=db_name, type="db") for table in tables: @@ -95,7 +95,7 @@ async def editor_sql_run(run_param: dict = Body()): sql = run_param["sql"] if not db_name and not sql: return Result.failed(msg="SQL run param error!") - conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) + conn = CFG.local_db_manager.get_connector(db_name) try: start_time = time.time() * 1000 @@ -125,7 +125,7 @@ async def sql_editor_submit( ): logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}") - conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name) + conn = CFG.local_db_manager.get_connector(sql_edit_context.db_name) try: editor_service.sql_editor_submit_and_save(sql_edit_context, conn) return Result.succ(None) @@ -168,7 +168,7 @@ async def editor_chart_run(run_param: dict = Body()): return Result.failed("SQL run param error!") try: dashboard_data_loader: DashboardDataLoader = DashboardDataLoader() - db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) + db_conn = CFG.local_db_manager.get_connector(db_name) colunms, sql_result = db_conn.query_ex(sql) field_names, chart_values = dashboard_data_loader.get_chart_values_by_data( colunms, sql_result, sql @@ -204,7 +204,7 @@ async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()) history_messages: List[Dict] = history_mem.get_messages() if history_messages: dashboard_data_loader: DashboardDataLoader = DashboardDataLoader() - db_conn = 
CFG.LOCAL_DB_MANAGE.get_connect(chart_edit_context.db_name) + db_conn = CFG.local_db_manager.get_connector(chart_edit_context.db_name) edit_round = max(history_messages, key=lambda x: x["chat_order"]) if edit_round: diff --git a/dbgpt/app/openapi/api_v1/editor/service.py b/dbgpt/app/openapi/api_v1/editor/service.py index ef22adebd..bb27e0a92 100644 --- a/dbgpt/app/openapi/api_v1/editor/service.py +++ b/dbgpt/app/openapi/api_v1/editor/service.py @@ -22,7 +22,7 @@ from dbgpt.core.interface.message import ( from dbgpt.serve.conversation.serve import Serve as ConversationServe if TYPE_CHECKING: - from dbgpt.datasource.base import BaseConnect + from dbgpt.datasource.base import BaseConnector logger = logging.getLogger(__name__) @@ -86,7 +86,7 @@ class EditorService(BaseComponent): return None def sql_editor_submit_and_save( - self, sql_edit_context: ChatSqlEditContext, connection: BaseConnect + self, sql_edit_context: ChatSqlEditContext, connection: BaseConnector ): storage_conv: StorageConversation = self.get_storage_conv( sql_edit_context.conv_uid @@ -169,7 +169,7 @@ class EditorService(BaseComponent): filter(lambda x: x["chart_name"] == chart_title, charts) )[0] - conn = cfg.LOCAL_DB_MANAGE.get_connect(db_name) + conn = cfg.local_db_manager.get_connector(db_name) detail: ChartDetail = ChartDetail( chart_uid=find_chart["chart_uid"], chart_type=find_chart["chart_type"], diff --git a/dbgpt/app/scene/chat_dashboard/chat.py b/dbgpt/app/scene/chat_dashboard/chat.py index 3748ce32e..bb35de7cc 100644 --- a/dbgpt/app/scene/chat_dashboard/chat.py +++ b/dbgpt/app/scene/chat_dashboard/chat.py @@ -38,7 +38,7 @@ class ChatDashboard(BaseChat): self.db_name = self.db_name self.report_name = chat_param.get("report_name", "report") - self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name) + self.database = CFG.local_db_manager.get_connector(self.db_name) self.top_k: int = 5 self.dashboard_template = self.__load_dashboard_template(self.report_name) diff --git 
a/dbgpt/app/scene/chat_dashboard/data_loader.py b/dbgpt/app/scene/chat_dashboard/data_loader.py index 05cc55970..954e7da44 100644 --- a/dbgpt/app/scene/chat_dashboard/data_loader.py +++ b/dbgpt/app/scene/chat_dashboard/data_loader.py @@ -57,5 +57,5 @@ class DashboardDataLoader: def get_chart_values_by_db(self, db_name: str, chart_sql: str): logger.info(f"get_chart_values_by_db:{db_name},{chart_sql}") - db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) + db_conn = CFG.local_db_manager.get_connector(db_name) return self.get_chart_values_by_conn(db_conn, chart_sql) diff --git a/dbgpt/app/scene/chat_db/auto_execute/chat.py b/dbgpt/app/scene/chat_db/auto_execute/chat.py index 478ccc482..6a2bc3da2 100644 --- a/dbgpt/app/scene/chat_db/auto_execute/chat.py +++ b/dbgpt/app/scene/chat_db/auto_execute/chat.py @@ -37,7 +37,7 @@ class ChatWithDbAutoExecute(BaseChat): with root_tracer.start_span( "ChatWithDbAutoExecute.get_connect", metadata={"db_name": self.db_name} ): - self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name) + self.database = CFG.local_db_manager.get_connector(self.db_name) self.top_k: int = 50 self.api_call = ApiCall(display_registry=CFG.command_display) diff --git a/dbgpt/app/scene/chat_db/professional_qa/chat.py b/dbgpt/app/scene/chat_db/professional_qa/chat.py index fb616cf50..06411319c 100644 --- a/dbgpt/app/scene/chat_db/professional_qa/chat.py +++ b/dbgpt/app/scene/chat_db/professional_qa/chat.py @@ -29,7 +29,7 @@ class ChatWithDbQA(BaseChat): super().__init__(chat_param=chat_param) if self.db_name: - self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name) + self.database = CFG.local_db_manager.get_connector(self.db_name) self.tables = self.database.get_table_names() self.top_k = ( diff --git a/dbgpt/component.py b/dbgpt/component.py index 7f96fbdc3..8ca3dd9d4 100644 --- a/dbgpt/component.py +++ b/dbgpt/component.py @@ -83,6 +83,7 @@ class ComponentType(str, Enum): AWEL_TRIGGER_MANAGER = "dbgpt_awel_trigger_manager" AWEL_DAG_MANAGER = 
"dbgpt_awel_dag_manager" UNIFIED_METADATA_DB_MANAGER_FACTORY = "dbgpt_unified_metadata_db_manager_factory" + CONNECTOR_MANAGER = "dbgpt_connector_manager" _EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT" diff --git a/dbgpt/core/__init__.py b/dbgpt/core/__init__.py index 04c069d94..718a23158 100644 --- a/dbgpt/core/__init__.py +++ b/dbgpt/core/__init__.py @@ -8,6 +8,7 @@ from dbgpt.core.interface.cache import ( # noqa: F401 CacheValue, ) from dbgpt.core.interface.embeddings import Embeddings # noqa: F401 +from dbgpt.core.interface.knowledge import Chunk, Document # noqa: F401 from dbgpt.core.interface.llm import ( # noqa: F401 DefaultMessageConverter, LLMClient, @@ -105,4 +106,6 @@ __ALL__ = [ "QuerySpec", "StorageError", "Embeddings", + "Chunk", + "Document", ] diff --git a/dbgpt/rag/chunk.py b/dbgpt/core/interface/knowledge.py similarity index 100% rename from dbgpt/rag/chunk.py rename to dbgpt/core/interface/knowledge.py diff --git a/dbgpt/datasource/__init__.py b/dbgpt/datasource/__init__.py index 5c619dc41..264fd164c 100644 --- a/dbgpt/datasource/__init__.py +++ b/dbgpt/datasource/__init__.py @@ -1 +1,6 @@ -from .manages.connect_config_db import ConnectConfigDao, ConnectConfigEntity +"""Module to define the data source connectors.""" + +from .base import BaseConnector # noqa: F401 +from .rdbms.base import RDBMSConnector # noqa: F401 + +__ALL__ = ["BaseConnector", "RDBMSConnector"] diff --git a/dbgpt/datasource/base.py b/dbgpt/datasource/base.py index d95e173a0..ce125e140 100644 --- a/dbgpt/datasource/base.py +++ b/dbgpt/datasource/base.py @@ -1,23 +1,25 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- - -"""We need to design a base class. 
That other connector can Write with this""" -from abc import ABC -from typing import Any, Dict, Iterable, List, Optional +"""Base class for all connectors.""" +from abc import ABC, abstractmethod +from typing import Dict, Iterable, List, Optional, Tuple -class BaseConnect(ABC): +class BaseConnector(ABC): + """Base class for all connectors.""" + + db_type: str = "__abstract__db_type__" + driver: str = "" + def get_table_names(self) -> Iterable[str]: - """Get all table names""" - pass + """Get all table names.""" + raise NotImplementedError("Current connector does not support get_table_names") def get_table_info(self, table_names: Optional[List[str]] = None) -> str: - """Get table info about specified table. + r"""Get table info about specified table. Returns: - str: Table information joined by '\n\n' + str: Table information joined by "\n\n" """ - pass + raise NotImplementedError("Current connector does not support get_table_info") def get_index_info(self, table_names: Optional[List[str]] = None) -> str: """Get index info about specified table. @@ -25,7 +27,7 @@ class BaseConnect(ABC): Args: table_names (Optional[List[str]]): table names """ - pass + raise NotImplementedError("Current connector does not support get_index_info") def get_example_data(self, table: str, count: int = 3): """Get example data about specified table. @@ -36,104 +38,179 @@ class BaseConnect(ABC): table (str): table name count (int): example data count """ - pass + raise NotImplementedError("Current connector does not support get_example_data") - def get_database_list(self) -> List[str]: - """Get database list. + def get_database_names(self) -> List[str]: + """Return database names. + + Examples: + .. 
code-block:: python + + print(conn.get_database_names()) + # ['db1', 'db2'] Returns: List[str]: database list """ - pass + raise NotImplementedError( + "Current connector does not support get_database_names" + ) - def get_database_names(self): - """Get database names.""" - pass - - def get_table_comments(self, db_name: str): + def get_table_comments(self, db_name: str) -> List[Tuple[str, str]]: """Get table comments. Args: db_name (str): database name + + Returns: + List[Tuple[str, str]]: Table comments, first element is table name, second + element is comment """ - pass + raise NotImplementedError( + "Current connector does not support get_table_comments" + ) def get_table_comment(self, table_name: str) -> Dict: - """Get table comment. + """Return table comment with table name. + Args: table_name (str): table name + Returns: comment: Dict, which contains text: Optional[str], eg:["text": "comment"] """ - pass + raise NotImplementedError( + "Current connector does not support get_table_comment" + ) def get_columns(self, table_name: str) -> List[Dict]: - """Get columns. + """Return columns with table name. + Args: table_name (str): table name - Returns: - columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str - eg:[{'name': 'id', 'type': 'int', 'default_expression': '', 'is_in_primary_key': True, 'comment': 'id'}, ...] - """ - pass - def get_column_comments(self, db_name, table_name): - """Get column comments. + Returns: + List[Dict]: columns of table, which contains name: str, type: str, + default_expression: str, is_in_primary_key: bool, comment: str + eg: [{'name': 'id', 'type': 'int', 'default_expression': '', + 'is_in_primary_key': True, 'comment': 'id'}, ...] + """ + raise NotImplementedError("Current connector does not support get_columns") + + def get_column_comments(self, db_name: str, table_name: str): + """Return column comments with db name and table name. 
Args: + db_name (str): database name table_name (_type_): _description_ """ - pass + raise NotImplementedError( + "Current connector does not support get_column_comments" + ) + @abstractmethod def run(self, command: str, fetch: str = "all") -> List: """Execute sql command. Args: command (str): sql command fetch (str): fetch type + + Returns: + List: result list """ - pass def run_to_df(self, command: str, fetch: str = "all"): - """Execute sql command and return dataframe.""" - pass + """Execute sql command and return result as dataframe. - def get_users(self): - """Get user info.""" + Args: + command (str): sql command + fetch (str): fetch type + + Returns: + DataFrame: result dataframe + """ + raise NotImplementedError("Current connector does not support run_to_df") + + def get_users(self) -> List[Tuple[str, str]]: + """Return user information. + + Returns: + List[Tuple[str, str]]: user list, which contains username and host + """ return [] - def get_grants(self): - """Get grant info.""" + def get_grants(self) -> List[Tuple]: + """Return grant information. + + Examples: + .. code-block:: python + + print(conn.get_grants()) + # [(('GRANT SELECT, INSERT, UPDATE, DROP ROLE ON *.* TO `root`@`%` + # WITH GRANT OPTION',)] + + Returns: + List[Tuple]: grant list, which contains grant information + """ return [] - def get_collation(self): - """Get collation.""" + def get_collation(self) -> Optional[str]: + """Return collation. + + Returns: + Optional[str]: collation + """ return None def get_charset(self) -> str: """Get character_set of current database.""" return "utf-8" - def get_fields(self, table_name): - """Get column fields about specified table.""" - pass + def get_fields(self, table_name: str) -> List[Tuple]: + """Get column fields about specified table. 
- def get_simple_fields(self, table_name): - """Get column fields about specified table.""" - return self.get_fields(table_name) - - def get_show_create_table(self, table_name): - """Get the creation table sql about specified table.""" - pass - - def get_indexes(self, table_name: str) -> List[Dict]: - """Get table indexes about specified table. Args: table_name (str): table name + Returns: - indexes: List[Dict], eg:[{'name': 'idx_key', 'column_names': ['id']}] + List[Tuple]: column fields, which contains column name, column type, + column default, is nullable, column comment """ - pass + raise NotImplementedError("Current connector does not support get_fields") + + def get_simple_fields(self, table_name: str) -> List[Tuple]: + """Return simple fields about specified table. + + Args: + table_name (str): table name + + Returns: + List[Tuple]: simple fields, which contains column name, column type, + is nullable, column key, default value, extra. + """ + return self.get_fields(table_name) + + def get_show_create_table(self, table_name: str) -> str: + """Return show create table about specified table. + + Returns: + str: show create table + """ + raise NotImplementedError( + "Current connector does not support get_show_create_table" + ) + + def get_indexes(self, table_name: str) -> List[Dict]: + """Return indexes about specified table. 
+ + Args: + table_name (str): table name + + Returns: + List[Dict], eg:[{'name': 'idx_key', 'column_names': ['id']}] + """ + raise NotImplementedError("Current connector does not support get_indexes") @classmethod def is_normal_type(cls) -> bool: diff --git a/dbgpt/datasource/conn_spark.py b/dbgpt/datasource/conn_spark.py index cc8107f4b..a44bea6c7 100644 --- a/dbgpt/datasource/conn_spark.py +++ b/dbgpt/datasource/conn_spark.py @@ -1,13 +1,24 @@ -from typing import Any, Optional +"""Spark Connector.""" +import logging +from typing import TYPE_CHECKING, Any, Optional -from dbgpt.datasource.base import BaseConnect +from .base import BaseConnector + +if TYPE_CHECKING: + from pyspark.sql import SparkSession + +logger = logging.getLogger(__name__) -class SparkConnect(BaseConnect): - """ - Spark Connect supports operating on a variety of data sources through the DataFrame interface. - A DataFrame can be operated on using relational transformations and can also be used to create a temporary view. - Registering a DataFrame as a temporary view allows you to run SQL queries over its data. +class SparkConnector(BaseConnector): + """Spark Connector. + + Spark Connect supports operating on a variety of data sources through the DataFrame + interface. + A DataFrame can be operated on using relational transformations and can also be + used to create a temporary view.Registering a DataFrame as a temporary view allows + you to run SQL queries over its data. + Datasource now support parquet, jdbc, orc, libsvm, csv, text, json. """ @@ -21,12 +32,15 @@ class SparkConnect(BaseConnect): def __init__( self, file_path: str, - spark_session: Optional = None, - engine_args: Optional[dict] = None, + spark_session: Optional["SparkSession"] = None, **kwargs: Any, ) -> None: - """Initialize the Spark DataFrame from Datasource path - return: Spark DataFrame + """Create a Spark Connector. 
+ + Args: + file_path: file path + spark_session: spark session + kwargs: other args """ from pyspark.sql import SparkSession @@ -40,15 +54,21 @@ class SparkConnect(BaseConnect): @classmethod def from_file_path( cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any - ): + ) -> "SparkConnector": + """Create a new SparkConnector from file path.""" try: - return cls(file_path=file_path, engine_args=engine_args) + return cls(file_path=file_path, engine_args=engine_args, **kwargs) except Exception as e: - print("load spark datasource error" + str(e)) + logger.error("load spark datasource error" + str(e)) + raise e def create_df(self, path): - """Create a Spark DataFrame from Datasource path(now support parquet, jdbc, orc, libsvm, csv, text, json.). + """Create a Spark DataFrame. + + Create a Spark DataFrame from Datasource path(now support parquet, jdbc, + orc, libsvm, csv, text, json.). + return: Spark DataFrame reference:https://spark.apache.org/docs/latest/sql-data-sources-load-save-functions.html """ @@ -59,8 +79,9 @@ class SparkConnect(BaseConnect): path, format=extension, inferSchema="true", header="true" ) - def run(self, sql): - print(f"spark sql to run is {sql}") + def run(self, sql: str, fetch: str = "all"): + """Execute sql command.""" + logger.info(f"spark sql to run is {sql}") self.df.createOrReplaceTempView(self.table_name) df = self.spark_session.sql(sql) first_row = df.first() @@ -69,7 +90,8 @@ class SparkConnect(BaseConnect): rows.append(row) return rows - def query_ex(self, sql): + def query_ex(self, sql: str): + """Execute sql command.""" rows = self.run(sql) field_names = rows[0] return field_names, rows @@ -80,40 +102,31 @@ class SparkConnect(BaseConnect): def get_show_create_table(self, table_name): """Get table show create table about specified table.""" - return "ans" - def get_fields(self): - """Get column meta about dataframe.""" + def get_fields(self, table_name: str): + """Get column meta about dataframe. 
+ + TODO: Support table_name. + """ return ",".join([f"({name}: {dtype})" for name, dtype in self.df.dtypes]) - def get_users(self): - return [] - - def get_grants(self): - return [] - def get_collation(self): """Get collation.""" return "UTF-8" - def get_charset(self): - return "UTF-8" - - def get_db_list(self): - return ["default"] - def get_db_names(self): + """Get database names.""" return ["default"] - def get_database_list(self): - return [] - def get_database_names(self): + """Get database names.""" return [] def table_simple_info(self): + """Get table simple info.""" return f"{self.table_name}{self.get_fields()}" def get_table_comments(self, db_name): + """Get table comments.""" return "" diff --git a/dbgpt/datasource/db_conn_info.py b/dbgpt/datasource/db_conn_info.py index 56daaae7c..46e48068f 100644 --- a/dbgpt/datasource/db_conn_info.py +++ b/dbgpt/datasource/db_conn_info.py @@ -1,17 +1,22 @@ -from dbgpt._private.pydantic import BaseModel +"""Configuration for database connection.""" +from dbgpt._private.pydantic import BaseModel, Field class DBConfig(BaseModel): - db_type: str - db_name: str - file_path: str = "" - db_host: str = "" - db_port: int = 0 - db_user: str = "" - db_pwd: str = "" - comment: str = "" + """Database connection configuration.""" + + db_type: str = Field(..., description="Database type, e.g. 
sqlite, mysql, etc.") + db_name: str = Field(..., description="Database name.") + file_path: str = Field("", description="File path for file-based database.") + db_host: str = Field("", description="Database host.") + db_port: int = Field(0, description="Database port.") + db_user: str = Field("", description="Database user.") + db_pwd: str = Field("", description="Database password.") + comment: str = Field("", description="Comment for the database.") class DbTypeInfo(BaseModel): - db_type: str - is_file_db: bool = False + """Database type information.""" + + db_type: str = Field(..., description="Database type.") + is_file_db: bool = Field(False, description="Whether the database is file-based.") diff --git a/dbgpt/datasource/manages/__init__.py b/dbgpt/datasource/manages/__init__.py index e69de29bb..733124ada 100644 --- a/dbgpt/datasource/manages/__init__.py +++ b/dbgpt/datasource/manages/__init__.py @@ -0,0 +1,5 @@ +"""This module is responsible for managing the connectors.""" + +from .connector_manager import ConnectorManager # noqa: F401 + +__ALL__ = ["ConnectorManager"] diff --git a/dbgpt/datasource/manages/connect_config_db.py b/dbgpt/datasource/manages/connect_config_db.py index 7c2cc2ff6..866bc1065 100644 --- a/dbgpt/datasource/manages/connect_config_db.py +++ b/dbgpt/datasource/manages/connect_config_db.py @@ -1,10 +1,17 @@ +"""DB Model for connect_config.""" + +import logging +from typing import Optional + from sqlalchemy import Column, Index, Integer, String, Text, UniqueConstraint, text from dbgpt.storage.metadata import BaseDao, Model +logger = logging.getLogger(__name__) + class ConnectConfigEntity(Model): - """db connect config entity""" + """DB connector config entity.""" __tablename__ = "connect_config" id = Column( @@ -28,32 +35,10 @@ class ConnectConfigEntity(Model): class ConnectConfigDao(BaseDao): - """db connect config dao""" + """DB connector config dao.""" - def update(self, entity: ConnectConfigEntity): - """update db connect info""" - 
session = self.get_raw_session() - try: - updated = session.merge(entity) - session.commit() - return updated.id - finally: - session.close() - - def delete(self, db_name: int): - """ "delete db connect info""" - session = self.get_raw_session() - if db_name is None: - raise Exception("db_name is None") - - db_connect = session.query(ConnectConfigEntity) - db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name) - db_connect.delete() - session.commit() - session.close() - - def get_by_names(self, db_name: str) -> ConnectConfigEntity: - """get db connect info by name""" + def get_by_names(self, db_name: str) -> Optional[ConnectConfigEntity]: + """Get db connect info by name.""" session = self.get_raw_session() db_connect = session.query(ConnectConfigEntity) db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name) @@ -71,8 +56,8 @@ class ConnectConfigDao(BaseDao): db_pwd: str, comment: str = "", ): - """ - add db connect info + """Add db connect info. + Args: db_name: db name db_type: db type @@ -90,9 +75,9 @@ class ConnectConfigDao(BaseDao): insert_statement = text( """ INSERT INTO connect_config ( - db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment - ) VALUES ( - :db_name, :db_type, :db_path, :db_host, :db_port, :db_user, :db_pwd, :comment + db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, + comment) VALUES (:db_name, :db_type, :db_path, :db_host, :db_port + , :db_user, :db_pwd, :comment ) """ ) @@ -111,7 +96,7 @@ class ConnectConfigDao(BaseDao): session.commit() session.close() except Exception as e: - print("add db connect info error!" + str(e)) + logger.warning("add db connect info error!" 
+ str(e)) def update_db_info( self, @@ -124,37 +109,43 @@ class ConnectConfigDao(BaseDao): db_pwd: str = "", comment: str = "", ): - """update db connect info""" + """Update db connect info.""" old_db_conf = self.get_db_config(db_name) if old_db_conf: try: session = self.get_raw_session() if not db_path: update_statement = text( - f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'" + f"UPDATE connect_config set db_type='{db_type}', " + f"db_host='{db_host}', db_port={db_port}, db_user='{db_user}', " + f"db_pwd='{db_pwd}', comment='{comment}' where " + f"db_name='{db_name}'" ) else: update_statement = text( - f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'" + f"UPDATE connect_config set db_type='{db_type}', " + f"db_path='{db_path}', comment='{comment}' where " + f"db_name='{db_name}'" ) session.execute(update_statement) session.commit() session.close() except Exception as e: - print("edit db connect info error!" + str(e)) + logger.warning("edit db connect info error!" + str(e)) return True raise ValueError(f"{db_name} not have config info!") def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""): - """add file db connect info""" + """Add file db connect info.""" try: session = self.get_raw_session() insert_statement = text( """ INSERT INTO connect_config( - db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment - ) VALUES ( - :db_name, :db_type, :db_path, :db_host, :db_port, :db_user, :db_pwd, :comment + db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, + comment) VALUES ( + :db_name, :db_type, :db_path, :db_host, :db_port, :db_user, :db_pwd + , :comment ) """ ) @@ -174,19 +165,19 @@ class ConnectConfigDao(BaseDao): session.commit() session.close() except Exception as e: - print("add db connect info error!" 
+ str(e)) + logger.warning("add db connect info error!" + str(e)) def get_db_config(self, db_name): - """get db config by name""" + """Return db connect info by name.""" session = self.get_raw_session() if db_name: select_statement = text( """ - SELECT + SELECT * - FROM - connect_config - WHERE + FROM + connect_config + WHERE db_name = :db_name """ ) @@ -196,7 +187,7 @@ class ConnectConfigDao(BaseDao): else: raise ValueError("Cannot get database by name" + db_name) - print(result) + logger.info(f"Result: {result}") fields = [field[0] for field in result.cursor.description] row_dict = {} row_1 = list(result.cursor.fetchall()[0]) @@ -205,7 +196,7 @@ class ConnectConfigDao(BaseDao): return row_dict def get_db_list(self): - """get db list""" + """Get db list.""" session = self.get_raw_session() result = session.execute(text("SELECT * FROM connect_config")) @@ -219,7 +210,7 @@ class ConnectConfigDao(BaseDao): return data def delete_db(self, db_name): - """delete db connect info""" + """Delete db connect info.""" session = self.get_raw_session() delete_statement = text("""DELETE FROM connect_config where db_name=:db_name""") params = {"db_name": db_name} diff --git a/dbgpt/datasource/manages/connect_storage_duckdb.py b/dbgpt/datasource/manages/connect_storage_duckdb.py deleted file mode 100644 index a08d33094..000000000 --- a/dbgpt/datasource/manages/connect_storage_duckdb.py +++ /dev/null @@ -1,151 +0,0 @@ -import os - -import duckdb - -from dbgpt.configs.model_config import PILOT_PATH - -default_db_path = os.path.join(PILOT_PATH, "message") - -duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/connect_config.db") -table_name = "connect_config" - - -class DuckdbConnectConfig: - def __init__(self): - os.makedirs(default_db_path, exist_ok=True) - self.connect = duckdb.connect(duckdb_path) - self.__init_config_tables() - - def __init_config_tables(self): - # check config table - result = self.connect.execute( - "SELECT name FROM sqlite_master WHERE type='table' 
AND name=?", [table_name] - ).fetchall() - - if not result: - # create config table - self.connect.execute( - "CREATE TABLE connect_config (id integer primary key, db_name VARCHAR(100) UNIQUE, db_type VARCHAR(50), db_path VARCHAR(255) NULL, db_host VARCHAR(255) NULL, db_port INTEGER NULL, db_user VARCHAR(255) NULL, db_pwd VARCHAR(255) NULL, comment TEXT NULL)" - ) - self.connect.execute("CREATE SEQUENCE seq_id START 1;") - - def add_url_db( - self, - db_name, - db_type, - db_host: str, - db_port: int, - db_user: str, - db_pwd: str, - comment: str = "", - ): - try: - cursor = self.connect.cursor() - cursor.execute( - "INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment)VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)", - [db_name, db_type, "", db_host, db_port, db_user, db_pwd, comment], - ) - cursor.commit() - self.connect.commit() - except Exception as e: - print("add db connect info error1!" + str(e)) - - def update_db_info( - self, - db_name, - db_type, - db_path: str = "", - db_host: str = "", - db_port: int = 0, - db_user: str = "", - db_pwd: str = "", - comment: str = "", - ): - old_db_conf = self.get_db_config(db_name) - if old_db_conf: - try: - cursor = self.connect.cursor() - if not db_path: - cursor.execute( - f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'" - ) - else: - cursor.execute( - f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'" - ) - cursor.commit() - self.connect.commit() - except Exception as e: - print("edit db connect info error2!" 
+ str(e)) - return True - raise ValueError(f"{db_name} not have config info!") - - def get_file_db_name(self, path): - try: - conn = duckdb.connect(path) - result = conn.execute("SELECT current_database()").fetchone()[0] - return result - except Exception as e: - raise "Unusable duckdb database path:" + path - - def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""): - try: - cursor = self.connect.cursor() - cursor.execute( - "INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment)VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)", - [db_name, db_type, db_path, "", 0, "", "", comment], - ) - cursor.commit() - self.connect.commit() - except Exception as e: - print("add db connect info error2!" + str(e)) - - def delete_db(self, db_name): - cursor = self.connect.cursor() - cursor.execute("DELETE FROM connect_config where db_name=?", [db_name]) - cursor.commit() - return True - - def get_db_config(self, db_name): - if os.path.isfile(duckdb_path): - cursor = duckdb.connect(duckdb_path).cursor() - if db_name: - cursor.execute( - "SELECT * FROM connect_config where db_name=? 
", [db_name] - ) - else: - raise ValueError("Cannot get database by name" + db_name) - - fields = [field[0] for field in cursor.description] - row_dict = {} - row_1 = list(cursor.fetchall()[0]) - for i, field in enumerate(fields): - row_dict[field] = row_1[i] - return row_dict - return None - - def get_db_list(self): - if os.path.isfile(duckdb_path): - cursor = duckdb.connect(duckdb_path).cursor() - cursor.execute("SELECT * FROM connect_config ") - - fields = [field[0] for field in cursor.description] - data = [] - for row in cursor.fetchall(): - row_dict = {} - for i, field in enumerate(fields): - row_dict[field] = row[i] - data.append(row_dict) - return data - - return [] - - def get_db_names(self): - if os.path.isfile(duckdb_path): - cursor = duckdb.connect(duckdb_path).cursor() - cursor.execute("SELECT db_name FROM connect_config ") - data = [] - for row in cursor.fetchall(): - data.append(row[0]) - return data - return [] diff --git a/dbgpt/datasource/manages/connection_manager.py b/dbgpt/datasource/manages/connection_manager.py deleted file mode 100644 index 74d2a1c55..000000000 --- a/dbgpt/datasource/manages/connection_manager.py +++ /dev/null @@ -1,152 +0,0 @@ -from typing import List, Type - -from dbgpt.component import ComponentType, SystemApp -from dbgpt.datasource import ConnectConfigDao -from dbgpt.datasource.base import BaseConnect -from dbgpt.datasource.conn_spark import SparkConnect -from dbgpt.datasource.db_conn_info import DBConfig -from dbgpt.datasource.rdbms.base import RDBMSDatabase -from dbgpt.datasource.rdbms.conn_clickhouse import ClickhouseConnect -from dbgpt.datasource.rdbms.conn_doris import DorisConnect -from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnect -from dbgpt.datasource.rdbms.conn_hive import HiveConnect -from dbgpt.datasource.rdbms.conn_mssql import MSSQLConnect -from dbgpt.datasource.rdbms.conn_mysql import MySQLConnect -from dbgpt.datasource.rdbms.conn_postgresql import PostgreSQLDatabase -from 
dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnect -from dbgpt.datasource.rdbms.conn_starrocks import StarRocksConnect -from dbgpt.rag.summary.db_summary_client import DBSummaryClient -from dbgpt.storage.schema import DBType -from dbgpt.util.executor_utils import ExecutorFactory - - -class ConnectManager: - """db connect manager""" - - def get_all_subclasses(self, cls: Type[BaseConnect]) -> List[Type[BaseConnect]]: - subclasses = cls.__subclasses__() - for subclass in subclasses: - subclasses += self.get_all_subclasses(subclass) - return subclasses - - def get_all_completed_types(self): - chat_classes = self.get_all_subclasses(BaseConnect) - support_types = [] - for cls in chat_classes: - if cls.db_type and cls.is_normal_type(): - support_types.append(DBType.of_db_type(cls.db_type)) - return support_types - - def get_cls_by_dbtype(self, db_type): - chat_classes = self.get_all_subclasses(BaseConnect) - result = None - for cls in chat_classes: - if cls.db_type == db_type and cls.is_normal_type(): - result = cls - if not result: - raise ValueError("Unsupported Db Type!" 
+ db_type) - return result - - def __init__(self, system_app: SystemApp): - """metadata database management initialization""" - # self.storage = DuckdbConnectConfig() - self.storage = ConnectConfigDao() - self.system_app = system_app - self.db_summary_client = DBSummaryClient(system_app) - - def get_connect(self, db_name): - db_config = self.storage.get_db_config(db_name) - db_type = DBType.of_db_type(db_config.get("db_type")) - connect_instance = self.get_cls_by_dbtype(db_type.value()) - if db_type.is_file_db(): - db_path = db_config.get("db_path") - return connect_instance.from_file_path(db_path) - else: - db_host = db_config.get("db_host") - db_port = db_config.get("db_port") - db_user = db_config.get("db_user") - db_pwd = db_config.get("db_pwd") - return connect_instance.from_uri_db( - host=db_host, port=db_port, user=db_user, pwd=db_pwd, db_name=db_name - ) - - def test_connect(self, db_info: DBConfig): - try: - db_type = DBType.of_db_type(db_info.db_type) - connect_instance = self.get_cls_by_dbtype(db_type.value()) - if db_type.is_file_db(): - db_path = db_info.file_path - return connect_instance.from_file_path(db_path) - else: - db_name = db_info.db_name - db_host = db_info.db_host - db_port = db_info.db_port - db_user = db_info.db_user - db_pwd = db_info.db_pwd - return connect_instance.from_uri_db( - host=db_host, - port=db_port, - user=db_user, - pwd=db_pwd, - db_name=db_name, - ) - except Exception as e: - print(f"{db_info.db_name} Test connect Failure!{str(e)}") - raise ValueError(f"{db_info.db_name} Test connect Failure!{str(e)}") - - def get_db_list(self): - return self.storage.get_db_list() - - def get_db_names(self): - return self.storage.get_by_name() - - def delete_db(self, db_name: str): - return self.storage.delete_db(db_name) - - def edit_db(self, db_info: DBConfig): - return self.storage.update_db_info( - db_info.db_name, - db_info.db_type, - db_info.file_path, - db_info.db_host, - db_info.db_port, - db_info.db_user, - db_info.db_pwd, - 
db_info.comment, - ) - - async def async_db_summary_embedding(self, db_name, db_type): - # 在这里执行需要异步运行的代码 - self.db_summary_client.db_summary_embedding(db_name, db_type) - - def add_db(self, db_info: DBConfig): - print(f"add_db:{db_info.__dict__}") - try: - db_type = DBType.of_db_type(db_info.db_type) - if db_type.is_file_db(): - self.storage.add_file_db( - db_info.db_name, db_info.db_type, db_info.file_path - ) - else: - self.storage.add_url_db( - db_info.db_name, - db_info.db_type, - db_info.db_host, - db_info.db_port, - db_info.db_user, - db_info.db_pwd, - db_info.comment, - ) - # async embedding - executor = self.system_app.get_component( - ComponentType.EXECUTOR_DEFAULT, ExecutorFactory - ).create() - executor.submit( - self.db_summary_client.db_summary_embedding, - db_info.db_name, - db_info.db_type, - ) - except Exception as e: - raise ValueError("Add db connect info error!" + str(e)) - - return True diff --git a/dbgpt/datasource/manages/connector_manager.py b/dbgpt/datasource/manages/connector_manager.py new file mode 100644 index 000000000..f3a75e3f3 --- /dev/null +++ b/dbgpt/datasource/manages/connector_manager.py @@ -0,0 +1,228 @@ +"""Connection manager.""" +import logging +from typing import TYPE_CHECKING, List, Optional, Type + +from dbgpt.component import BaseComponent, ComponentType, SystemApp +from dbgpt.storage.schema import DBType +from dbgpt.util.executor_utils import ExecutorFactory + +from ..base import BaseConnector +from ..db_conn_info import DBConfig +from .connect_config_db import ConnectConfigDao + +if TYPE_CHECKING: + # TODO: Don't depend on the rag module. 
+ from dbgpt.rag.summary.db_summary_client import DBSummaryClient + +logger = logging.getLogger(__name__) + + +class ConnectorManager(BaseComponent): + """Connector manager.""" + + name = ComponentType.CONNECTOR_MANAGER + + def __init__(self, system_app: SystemApp): + """Create a new ConnectorManager.""" + self.storage = ConnectConfigDao() + self.system_app = system_app + self._db_summary_client: Optional["DBSummaryClient"] = None + super().__init__(system_app) + + def init_app(self, system_app: SystemApp): + """Init component.""" + self.system_app = system_app + + def on_init(self): + """Execute on init. + + Load all connector classes. + """ + from dbgpt.datasource.conn_spark import SparkConnector # noqa: F401 + from dbgpt.datasource.rdbms.base import RDBMSConnector # noqa: F401 + from dbgpt.datasource.rdbms.conn_clickhouse import ( # noqa: F401 + ClickhouseConnector, + ) + from dbgpt.datasource.rdbms.conn_doris import DorisConnector # noqa: F401 + from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnector # noqa: F401 + from dbgpt.datasource.rdbms.conn_hive import HiveConnector # noqa: F401 + from dbgpt.datasource.rdbms.conn_mssql import MSSQLConnector # noqa: F401 + from dbgpt.datasource.rdbms.conn_mysql import MySQLConnector # noqa: F401 + from dbgpt.datasource.rdbms.conn_postgresql import ( # noqa: F401 + PostgreSQLConnector, + ) + from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnector # noqa: F401 + from dbgpt.datasource.rdbms.conn_starrocks import ( # noqa: F401 + StarRocksConnector, + ) + + from .connect_config_db import ConnectConfigEntity # noqa: F401 + + def before_start(self): + """Execute before start.""" + from dbgpt.rag.summary.db_summary_client import DBSummaryClient + + self._db_summary_client = DBSummaryClient(self.system_app) + + @property + def db_summary_client(self) -> "DBSummaryClient": + """Get DBSummaryClient.""" + if not self._db_summary_client: + raise ValueError("DBSummaryClient is not initialized") + return 
self._db_summary_client + + def _get_all_subclasses( + self, cls: Type[BaseConnector] + ) -> List[Type[BaseConnector]]: + """Get all subclasses of cls.""" + subclasses = cls.__subclasses__() + for subclass in subclasses: + subclasses += self._get_all_subclasses(subclass) + return subclasses + + def get_all_completed_types(self) -> List[DBType]: + """Get all completed types.""" + chat_classes = self._get_all_subclasses(BaseConnector) # type: ignore + support_types = [] + for cls in chat_classes: + if cls.db_type and cls.is_normal_type(): + db_type = DBType.of_db_type(cls.db_type) + if db_type: + support_types.append(db_type) + return support_types + + def get_cls_by_dbtype(self, db_type) -> Type[BaseConnector]: + """Get class by db type.""" + chat_classes = self._get_all_subclasses(BaseConnector) # type: ignore + result = None + for cls in chat_classes: + if cls.db_type == db_type and cls.is_normal_type(): + result = cls + if not result: + raise ValueError("Unsupported Db Type!" + db_type) + return result + + def get_connector(self, db_name: str): + """Create a new connection instance. + + Args: + db_name (str): database name + """ + db_config = self.storage.get_db_config(db_name) + db_type = DBType.of_db_type(db_config.get("db_type")) + if not db_type: + raise ValueError("Unsupported Db Type!" + db_config.get("db_type")) + connect_instance = self.get_cls_by_dbtype(db_type.value()) + if db_type.is_file_db(): + db_path = db_config.get("db_path") + return connect_instance.from_file_path(db_path) # type: ignore + else: + db_host = db_config.get("db_host") + db_port = db_config.get("db_port") + db_user = db_config.get("db_user") + db_pwd = db_config.get("db_pwd") + return connect_instance.from_uri_db( # type: ignore + host=db_host, port=db_port, user=db_user, pwd=db_pwd, db_name=db_name + ) + + def test_connect(self, db_info: DBConfig) -> BaseConnector: + """Test connectivity. + + Args: + db_info (DBConfig): db connect info. 
+ + Returns: + BaseConnector: connector instance. + + Raises: + ValueError: Test connect Failure. + """ + try: + db_type = DBType.of_db_type(db_info.db_type) + if not db_type: + raise ValueError("Unsupported Db Type!" + db_info.db_type) + connect_instance = self.get_cls_by_dbtype(db_type.value()) + if db_type.is_file_db(): + db_path = db_info.file_path + return connect_instance.from_file_path(db_path) # type: ignore + else: + db_name = db_info.db_name + db_host = db_info.db_host + db_port = db_info.db_port + db_user = db_info.db_user + db_pwd = db_info.db_pwd + return connect_instance.from_uri_db( # type: ignore + host=db_host, + port=db_port, + user=db_user, + pwd=db_pwd, + db_name=db_name, + ) + except Exception as e: + logger.error(f"{db_info.db_name} Test connect Failure!{str(e)}") + raise ValueError(f"{db_info.db_name} Test connect Failure!{str(e)}") + + def get_db_list(self): + """Get db list.""" + return self.storage.get_db_list() + + def delete_db(self, db_name: str): + """Delete db connect info.""" + return self.storage.delete_db(db_name) + + def edit_db(self, db_info: DBConfig): + """Edit db connect info.""" + return self.storage.update_db_info( + db_info.db_name, + db_info.db_type, + db_info.file_path, + db_info.db_host, + db_info.db_port, + db_info.db_user, + db_info.db_pwd, + db_info.comment, + ) + + async def async_db_summary_embedding(self, db_name, db_type): + """Async db summary embedding.""" + # TODO: async embedding + self.db_summary_client.db_summary_embedding(db_name, db_type) + + def add_db(self, db_info: DBConfig): + """Add db connect info. + + Args: + db_info (DBConfig): db connect info. + """ + logger.info(f"add_db:{db_info.__dict__}") + try: + db_type = DBType.of_db_type(db_info.db_type) + if not db_type: + raise ValueError("Unsupported Db Type!" 
+ db_info.db_type) + if db_type.is_file_db(): + self.storage.add_file_db( + db_info.db_name, db_info.db_type, db_info.file_path + ) + else: + self.storage.add_url_db( + db_info.db_name, + db_info.db_type, + db_info.db_host, + db_info.db_port, + db_info.db_user, + db_info.db_pwd, + db_info.comment, + ) + # async embedding + executor = self.system_app.get_component( + ComponentType.EXECUTOR_DEFAULT, ExecutorFactory + ).create() # type: ignore + executor.submit( + self.db_summary_client.db_summary_embedding, + db_info.db_name, + db_info.db_type, + ) + except Exception as e: + raise ValueError("Add db connect info error!" + str(e)) + + return True diff --git a/dbgpt/datasource/nosql/__init__.py b/dbgpt/datasource/nosql/__init__.py index e69de29bb..09e99698d 100644 --- a/dbgpt/datasource/nosql/__init__.py +++ b/dbgpt/datasource/nosql/__init__.py @@ -0,0 +1 @@ +"""NoSQL data source package.""" diff --git a/dbgpt/datasource/operators/__init__.py b/dbgpt/datasource/operators/__init__.py index e69de29bb..744784b16 100644 --- a/dbgpt/datasource/operators/__init__.py +++ b/dbgpt/datasource/operators/__init__.py @@ -0,0 +1 @@ +"""Datasource operators.""" diff --git a/dbgpt/datasource/operators/datasource_operator.py b/dbgpt/datasource/operators/datasource_operator.py index ef0a03e65..d6a3d3b54 100644 --- a/dbgpt/datasource/operators/datasource_operator.py +++ b/dbgpt/datasource/operators/datasource_operator.py @@ -1,17 +1,26 @@ +"""DatasourceOperator class. + +Warning: This operator is in development and is not yet ready for production use. 
+""" from typing import Any from dbgpt.core.awel import MapOperator -from dbgpt.core.awel.task.base import IN, OUT -from dbgpt.datasource.base import BaseConnect + +from ..base import BaseConnector class DatasourceOperator(MapOperator[str, Any]): - def __init__(self, connection: BaseConnect, **kwargs): + """The Datasource Operator.""" + + def __init__(self, connection: BaseConnector, **kwargs): + """Create the datasource operator.""" super().__init__(**kwargs) self._connection = connection - async def map(self, input_value: IN) -> OUT: + async def map(self, input_value: str) -> Any: + """Execute the query.""" return await self.blocking_func_to_async(self.query, input_value) def query(self, input_value: str) -> Any: + """Execute the query.""" return self._connection.run_to_df(input_value) diff --git a/dbgpt/datasource/rdbms/__init__.py b/dbgpt/datasource/rdbms/__init__.py index e69de29bb..71a8c7a2b 100644 --- a/dbgpt/datasource/rdbms/__init__.py +++ b/dbgpt/datasource/rdbms/__init__.py @@ -0,0 +1 @@ +"""RDBMS Connector Module.""" diff --git a/dbgpt/datasource/rdbms/_base_dao.py b/dbgpt/datasource/rdbms/_base_dao.py deleted file mode 100644 index 9a4221bec..000000000 --- a/dbgpt/datasource/rdbms/_base_dao.py +++ /dev/null @@ -1,86 +0,0 @@ -import logging - -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - -from dbgpt._private.config import Config -from dbgpt.datasource.rdbms.base import RDBMSDatabase -from dbgpt.storage.schema import DBType - -logger = logging.getLogger(__name__) -CFG = Config() - - -class BaseDao: - def __init__( - self, orm_base=None, database: str = None, create_not_exist_table: bool = False - ) -> None: - """BaseDAO, If the current database is a file database and create_not_exist_table=True, we will automatically create a table that does not exist""" - self._orm_base = orm_base - self._database = database - self._create_not_exist_table = create_not_exist_table - - self._db_engine = None - self._session = None - 
self._connection = None - - @property - def db_engine(self): - if not self._db_engine: - # lazy loading - db_engine, connection = _get_db_engine( - self._orm_base, self._database, self._create_not_exist_table - ) - self._db_engine = db_engine - self._connection = connection - return self._db_engine - - @property - def Session(self): - if not self._session: - self._session = sessionmaker(bind=self.db_engine) - return self._session - - -def _get_db_engine( - orm_base=None, database: str = None, create_not_exist_table: bool = False -): - db_engine = None - connection: RDBMSDatabase = None - - db_type = DBType.of_db_type(CFG.LOCAL_DB_TYPE) - if db_type is None or db_type == DBType.Mysql: - # default database - db_engine = create_engine( - f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}", - echo=True, - ) - else: - db_namager = CFG.LOCAL_DB_MANAGE - if not db_namager: - raise Exception( - "LOCAL_DB_MANAGE is not initialized, please check the system configuration" - ) - if db_type.is_file_db(): - db_path = CFG.LOCAL_DB_PATH - if db_path is None or db_path == "": - raise ValueError( - "You LOCAL_DB_TYPE is file db, but LOCAL_DB_PATH is not configured, please configure LOCAL_DB_PATH in you .env file" - ) - _, database = db_namager._parse_file_db_info(db_type.value(), db_path) - logger.info( - f"Current DAO database is file database, db_type: {db_type.value()}, db_path: {db_path}, db_name: {database}" - ) - logger.info(f"Get DAO database connection with database name {database}") - connection: RDBMSDatabase = db_namager.get_connect(database) - if not isinstance(connection, RDBMSDatabase): - raise ValueError( - "Currently only supports `RDBMSDatabase` database as the underlying database of BaseDao, please check your database configuration" - ) - db_engine = connection._engine - - if db_type.is_file_db() and orm_base is not None and create_not_exist_table: - logger.info("Current database is file database, create 
not exist table") - orm_base.metadata.create_all(db_engine) - - return db_engine, connection diff --git a/dbgpt/datasource/rdbms/base.py b/dbgpt/datasource/rdbms/base.py index f890cba6c..3197e211d 100644 --- a/dbgpt/datasource/rdbms/base.py +++ b/dbgpt/datasource/rdbms/base.py @@ -1,6 +1,9 @@ +"""Base class for RDBMS connectors.""" + from __future__ import annotations -from typing import Any, Dict, Iterable, List, Optional, Tuple +import logging +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, cast from urllib.parse import quote from urllib.parse import quote_plus as urlquote @@ -13,11 +16,10 @@ from sqlalchemy.exc import ProgrammingError, SQLAlchemyError from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.schema import CreateTable -from dbgpt._private.config import Config -from dbgpt.datasource.base import BaseConnect +from dbgpt.datasource.base import BaseConnector from dbgpt.storage.schema import DBType -CFG = Config() +logger = logging.getLogger(__name__) def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str: @@ -27,11 +29,9 @@ def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str: ) -class RDBMSDatabase(BaseConnect): +class RDBMSConnector(BaseConnector): """SQLAlchemy wrapper around a database.""" - db_type: str = None - def __init__( self, engine, @@ -41,10 +41,11 @@ class RDBMSDatabase(BaseConnect): include_tables: Optional[List[str]] = None, sample_rows_in_table_info: int = 3, indexes_in_table_info: bool = False, - custom_table_info: Optional[dict] = None, + custom_table_info: Optional[Dict[str, str]] = None, view_support: bool = False, ): """Create engine from database URI. + Args: - engine: Engine sqlalchemy.engine - schema: Optional[str]. 
@@ -61,28 +62,27 @@ class RDBMSDatabase(BaseConnect): if include_tables and ignore_tables: raise ValueError("Cannot specify both include_tables and ignore_tables") + if not custom_table_info: + custom_table_info = {} + self._inspector = inspect(engine) session_factory = sessionmaker(bind=engine) Session_Manages = scoped_session(session_factory) self._db_sessions = Session_Manages self.session = self.get_session() - self._all_tables = set() - self.view_support = False - self._usable_tables = set() - self._include_tables = set() - self._ignore_tables = set() - self._custom_table_info = set() - self._indexes_in_table_info = set() - self._usable_tables = set() - self._usable_tables = set() - self._sample_rows_in_table_info = set() + self.view_support = view_support + self._usable_tables: Set[str] = set() + self._include_tables: Set[str] = set() + self._ignore_tables: Set[str] = set() + self._custom_table_info = custom_table_info + self._sample_rows_in_table_info = sample_rows_in_table_info self._indexes_in_table_info = indexes_in_table_info - self._metadata = MetaData() + self._metadata = metadata or MetaData() self._metadata.reflect(bind=self._engine) - self._all_tables = self._sync_tables_from_db() + self._all_tables: Set[str] = cast(Set[str], self._sync_tables_from_db()) @classmethod def from_uri_db( @@ -94,8 +94,9 @@ class RDBMSDatabase(BaseConnect): db_name: str, engine_args: Optional[dict] = None, **kwargs: Any, - ) -> RDBMSDatabase: + ) -> RDBMSConnector: """Construct a SQLAlchemy engine from uri database. + Args: host (str): database host. port (int): database port. 
@@ -112,7 +113,7 @@ class RDBMSDatabase(BaseConnect): @classmethod def from_uri( cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any - ) -> RDBMSDatabase: + ) -> RDBMSConnector: """Construct a SQLAlchemy engine from URI.""" _engine_args = engine_args or {} return cls(create_engine(database_uri, **_engine_args), **kwargs) @@ -123,7 +124,7 @@ class RDBMSDatabase(BaseConnect): return self._engine.dialect.name def _sync_tables_from_db(self) -> Iterable[str]: - """Read table information from database""" + """Read table information from database.""" # TODO Use a background thread to refresh periodically # SQL will raise error with schema @@ -153,16 +154,25 @@ class RDBMSDatabase(BaseConnect): return self.get_usable_table_names() def get_session(self): + """Get session.""" session = self._db_sessions() return session def get_current_db_name(self) -> str: + """Get current database name. + + Returns: + str: database name + """ return self.session.execute(text("SELECT DATABASE()")).scalar() def table_simple_info(self): + """Return table simple info.""" _sql = f""" - select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name()}" group by TABLE_NAME; + select concat(table_name, "(" , group_concat(column_name), ")") + as schema_info from information_schema.COLUMNS where + table_schema="{self.get_current_db_name()}" group by TABLE_NAME; """ cursor = self.session.execute(text(_sql)) results = cursor.fetchall() @@ -222,12 +232,16 @@ class RDBMSDatabase(BaseConnect): return final_str def get_columns(self, table_name: str) -> List[Dict]: - """Get columns. + """Get columns about specified table. 
+ Args: table_name (str): table name + Returns: - columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str - eg:[{'name': 'id', 'type': 'int', 'default_expression': '', 'is_in_primary_key': True, 'comment': 'id'}, ...] + columns: List[Dict], which contains name: str, type: str, + default_expression: str, is_in_primary_key: bool, comment: str + eg:[{'name': 'id', 'type': 'int', 'default_expression': '', + 'is_in_primary_key': True, 'comment': 'id'}, ...] """ return self._inspector.get_columns(table_name) @@ -280,13 +294,14 @@ class RDBMSDatabase(BaseConnect): Args: write_sql (str): SQL write command to run """ - print(f"Write[{write_sql}]") + logger.info(f"Write[{write_sql}]") db_cache = self._engine.url.database result = self.session.execute(text(write_sql)) self.session.commit() - # TODO Subsequent optimization of dynamically specified database submission loss target problem + # TODO Subsequent optimization of dynamically specified database submission + # loss target problem self.session.execute(text(f"use `{db_cache}`")) - print(f"SQL[{write_sql}], result:{result.rowcount}") + logger.info(f"SQL[{write_sql}], result:{result.rowcount}") return result.rowcount def _query(self, query: str, fetch: str = "all"): @@ -296,9 +311,9 @@ class RDBMSDatabase(BaseConnect): query (str): SQL query to run fetch (str): fetch type """ - result = [] + result: List[Any] = [] - print(f"Query[{query}]") + logger.info(f"Query[{query}]") if not query: return result cursor = self.session.execute(text(query)) @@ -314,20 +329,28 @@ class RDBMSDatabase(BaseConnect): result.insert(0, field_names) return result - def query_table_schema(self, table_name): + def query_table_schema(self, table_name: str): + """Query table schema. 
+ + Args: + table_name (str): table name + """ sql = f"select * from {table_name} limit 1" return self._query(sql) - def query_ex(self, query, fetch: str = "all"): - """ - only for query + def query_ex(self, query: str, fetch: str = "all"): + """Execute a SQL command and return the results. + + Only for query command. + Args: - session: - query: - fetch: + query (str): SQL query to run + fetch (str): fetch type + Returns: + List: result list """ - print(f"Query[{query}]") + logger.info(f"Query[{query}]") if not query: return [], None cursor = self.session.execute(text(query)) @@ -338,7 +361,7 @@ class RDBMSDatabase(BaseConnect): result = cursor.fetchone() # type: ignore else: raise ValueError("Fetch parameter must be either 'one' or 'all'") - field_names = list(i[0:] for i in cursor.keys()) + field_names = list(cursor.keys()) result = list(result) return field_names, result @@ -346,7 +369,7 @@ class RDBMSDatabase(BaseConnect): def run(self, command: str, fetch: str = "all") -> List: """Execute a SQL command and return a string representing the results.""" - print("SQL:" + command) + logger.info("SQL:" + command) if not command or len(command) < 0: return [] parsed, ttype, sql_type, table_name = self.__sql_parse(command) @@ -356,11 +379,13 @@ class RDBMSDatabase(BaseConnect): else: self._write(command) select_sql = self.convert_sql_write_to_select(command) - print(f"write result query:{select_sql}") + logger.info(f"write result query:{select_sql}") return self._query(select_sql) else: - print(f"DDL execution determines whether to enable through configuration ") + logger.info( + "DDL execution determines whether to enable through configuration " + ) cursor = self.session.execute(text(command)) self.session.commit() if cursor.returns_rows: @@ -368,7 +393,7 @@ class RDBMSDatabase(BaseConnect): field_names = tuple(i[0:] for i in cursor.keys()) result = list(result) result.insert(0, field_names) - print("DDL Result:" + str(result)) + logger.info("DDL Result:" + 
str(result)) if not result: # return self._query(f"SHOW COLUMNS FROM {table_name}") return self.get_simple_fields(table_name) @@ -377,6 +402,7 @@ class RDBMSDatabase(BaseConnect): return self.get_simple_fields(table_name) def run_to_df(self, command: str, fetch: str = "all"): + """Execute sql command and return result as dataframe.""" import pandas as pd # Pandas has too much dependence and the import time is too long @@ -398,44 +424,45 @@ class RDBMSDatabase(BaseConnect): return self.run(command, fetch) except SQLAlchemyError as e: """Format the error message""" - return f"Error: {e}" + logger.warning(f"Run SQL command failed: {e}") + return [] - def get_database_list(self): - session = self._db_sessions() - cursor = session.execute(text(" show databases;")) - results = cursor.fetchall() - return [ - d[0] - for d in results - if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"] - ] + def convert_sql_write_to_select(self, write_sql: str) -> str: + """Convert SQL write command to a SELECT command. - def convert_sql_write_to_select(self, write_sql): - """ SQL classification processing author:xiangh8 + + Examples: + .. 
code-block:: python + + write_sql = "insert into test(id) values (1)" + select_sql = convert_sql_write_to_select(write_sql) + print(select_sql) + # SELECT * FROM test WHERE id=1 Args: - sql: + write_sql (str): SQL write command Returns: - + str: SELECT command corresponding to the write command """ - # 将SQL命令转换为小写,并按空格拆分 + # Convert the SQL command to lowercase and split by space parts = write_sql.lower().split() - # 获取命令类型(insert, delete, update) + # Get the command type (insert, delete, update) cmd_type = parts[0] - # 根据命令类型进行处理 + # Handle according to command type if cmd_type == "insert": match = re.match( - r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower() + r"insert\s+into\s+(\w+)\s*\(([^)]+)\)\s*values\s*\(([^)]+)\)", + write_sql.lower(), ) if match: + # Get the table name, columns, and values table_name, columns, values = match.groups() - # 将字段列表和值列表分割为单独的字段和值 columns = columns.split(",") values = values.split(",") - # 构造 WHERE 子句 + # Build the WHERE clause where_clause = " AND ".join( [ f"{col.strip()}={val.strip()}" @@ -443,21 +470,23 @@ class RDBMSDatabase(BaseConnect): ] ) return f"SELECT * FROM {table_name} WHERE {where_clause}" + else: + raise ValueError(f"Unsupported SQL command: {write_sql}") elif cmd_type == "delete": table_name = parts[2] # delete from ... 
- # 返回一个select语句,它选择该表的所有数据 + # Return a SELECT statement that selects all data from the table return f"SELECT * FROM {table_name} " elif cmd_type == "update": table_name = parts[1] set_idx = parts.index("set") where_idx = parts.index("where") - # 截取 `set` 子句中的字段名 + # Get the field name in the `set` clause set_clause = parts[set_idx + 1 : where_idx][0].split("=")[0].strip() - # 截取 `where` 之后的条件语句 + # Get the condition statement after the `where` where_clause = " ".join(parts[where_idx + 1 :]) - # 返回一个select语句,它选择更新的数据 + # Return a SELECT statement that selects the updated data return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}" else: raise ValueError(f"Unsupported SQL command type: {cmd_type}") @@ -473,7 +502,9 @@ class RDBMSDatabase(BaseConnect): first_token = parsed.token_first(skip_ws=True, skip_cm=False) ttype = first_token.ttype - print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}") + logger.info( + f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}" + ) return parsed, ttype, sql_type, table_name def _extract_table_name_from_ddl(self, parsed): @@ -485,8 +516,10 @@ class RDBMSDatabase(BaseConnect): def get_indexes(self, table_name: str) -> List[Dict]: """Get table indexes about specified table. 
+ Args: table_name:(str) table name + Returns: List[Dict]:eg:[{'name': 'idx_key', 'column_names': ['id']}] """ @@ -504,9 +537,9 @@ class RDBMSDatabase(BaseConnect): session = self._db_sessions() cursor = session.execute( text( - f"SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT from information_schema.COLUMNS where table_name='{table_name}'".format( - table_name - ) + "SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, " + "COLUMN_COMMENT from information_schema.COLUMNS where " + f"table_name='{table_name}'".format(table_name) ) ) fields = cursor.fetchall() @@ -516,40 +549,41 @@ class RDBMSDatabase(BaseConnect): """Get column fields about specified table.""" return self._query(f"SHOW COLUMNS FROM {table_name}") - def get_charset(self): + def get_charset(self) -> str: """Get character_set.""" session = self._db_sessions() - cursor = session.execute(text(f"SELECT @@character_set_database")) - character_set = cursor.fetchone()[0] + cursor = session.execute(text("SELECT @@character_set_database")) + character_set = cursor.fetchone()[0] # type: ignore return character_set def get_collation(self): """Get collation.""" session = self._db_sessions() - cursor = session.execute(text(f"SELECT @@collation_database")) + cursor = session.execute(text("SELECT @@collation_database")) collation = cursor.fetchone()[0] return collation def get_grants(self): """Get grant info.""" session = self._db_sessions() - cursor = session.execute(text(f"SHOW GRANTS")) + cursor = session.execute(text("SHOW GRANTS")) grants = cursor.fetchall() return grants def get_users(self): """Get user info.""" try: - cursor = self.session.execute(text(f"SELECT user, host FROM mysql.user")) + cursor = self.session.execute(text("SELECT user, host FROM mysql.user")) users = cursor.fetchall() return [(user[0], user[1]) for user in users] - except Exception as e: + except Exception: return [] def get_table_comments(self, db_name: str): + """Return table comments.""" cursor = 
self.session.execute( text( - f"""SELECT table_name, table_comment FROM information_schema.tables + f"""SELECT table_name, table_comment FROM information_schema.tables WHERE table_schema = '{db_name}'""".format( db_name ) @@ -570,10 +604,11 @@ class RDBMSDatabase(BaseConnect): """ return self._inspector.get_table_comment(table_name) - def get_column_comments(self, db_name, table_name): + def get_column_comments(self, db_name: str, table_name: str): + """Return column comments.""" cursor = self.session.execute( text( - f"""SELECT column_name, column_comment FROM information_schema.columns + f"""SELECT column_name, column_comment FROM information_schema.columns WHERE table_schema = '{db_name}' and table_name = '{table_name}' """.format( db_name, table_name @@ -585,17 +620,12 @@ class RDBMSDatabase(BaseConnect): (column_comment[0], column_comment[1]) for column_comment in column_comments ] - def get_database_list(self): - session = self._db_sessions() - cursor = session.execute(text(" show databases;")) - results = cursor.fetchall() - return [ - d[0] - for d in results - if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"] - ] + def get_database_names(self) -> List[str]: + """Return a list of database names available in the database. 
- def get_database_names(self): + Returns: + List[str]: database list + """ session = self._db_sessions() cursor = session.execute(text(" show databases;")) results = cursor.fetchall() diff --git a/dbgpt/datasource/rdbms/conn_clickhouse.py b/dbgpt/datasource/rdbms/conn_clickhouse.py index ed1c1dbb1..55be461b8 100644 --- a/dbgpt/datasource/rdbms/conn_clickhouse.py +++ b/dbgpt/datasource/rdbms/conn_clickhouse.py @@ -1,18 +1,20 @@ +"""Clickhouse connector.""" +import logging import re from typing import Any, Dict, Iterable, List, Optional, Tuple import sqlparse from sqlalchemy import MetaData, text -from dbgpt.datasource.rdbms.base import RDBMSDatabase from dbgpt.storage.schema import DBType +from .base import RDBMSConnector -class ClickhouseConnect(RDBMSDatabase): - """Connect Clickhouse Database fetch MetaData - Args: - Usage: - """ +logger = logging.getLogger(__name__) + + +class ClickhouseConnector(RDBMSConnector): + """Clickhouse connector.""" """db type""" db_type: str = "clickhouse" @@ -24,6 +26,7 @@ class ClickhouseConnect(RDBMSDatabase): client: Any = None def __init__(self, client, **kwargs): + """Create a new ClickhouseConnector from client.""" self.client = client self._all_tables = set() @@ -49,7 +52,8 @@ class ClickhouseConnect(RDBMSDatabase): db_name: str, engine_args: Optional[dict] = None, **kwargs: Any, - ) -> RDBMSDatabase: + ) -> "ClickhouseConnector": + """Create a new ClickhouseConnector from host, port, user, pwd, db_name.""" import clickhouse_connect from clickhouse_connect.driver import httputil @@ -70,11 +74,6 @@ class ClickhouseConnect(RDBMSDatabase): cls.client = client return cls(client, **kwargs) - @property - def dialect(self) -> str: - """Return string representation of dialect to use.""" - pass - def get_table_names(self): """Get all table names.""" session = self.client @@ -85,6 +84,7 @@ class ClickhouseConnect(RDBMSDatabase): def get_indexes(self, table_name: str) -> List[Dict]: """Get table indexes about specified table. 
+ Args: table_name (str): table name Returns: @@ -93,10 +93,11 @@ class ClickhouseConnect(RDBMSDatabase): session = self.client _query_sql = f""" - SELECT name AS table, primary_key, from system.tables where database ='{self.client.database}' and table = '{table_name}' + SELECT name AS table, primary_key, from system.tables where + database ='{self.client.database}' and table = '{table_name}' """ with session.query_row_block_stream(_query_sql) as stream: - indexes = [block for block in stream] + indexes = [block for block in stream] # noqa return [ {"name": "primary_key", "column_names": column_names.split(",")} for table, column_names in indexes[0] @@ -104,6 +105,7 @@ class ClickhouseConnect(RDBMSDatabase): @property def table_info(self) -> str: + """Get table info.""" return self.get_table_info() def get_table_info(self, table_names: Optional[List[str]] = None) -> str: @@ -117,7 +119,7 @@ class ClickhouseConnect(RDBMSDatabase): demonstrated in the paper. """ # TODO: - pass + return "" def get_show_create_table(self, table_name): """Get table show create table about specified table.""" @@ -133,11 +135,14 @@ class ClickhouseConnect(RDBMSDatabase): def get_columns(self, table_name: str) -> List[Dict]: """Get columns. + Args: table_name (str): str Returns: - columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str - eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '', 'is_in_primary_key': True, 'comment': 'id'}, ...] + List[Dict], which contains name: str, type: str, + default_expression: str, is_in_primary_key: bool, comment: str + eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '', + 'is_in_primary_key': True, 'comment': 'id'}, ...] 
""" fields = self.get_fields(table_name) return [ @@ -150,18 +155,21 @@ class ClickhouseConnect(RDBMSDatabase): session = self.client _query_sql = f""" - SELECT name, type, default_expression, is_in_primary_key, comment from system.columns where table='{table_name}' + SELECT name, type, default_expression, is_in_primary_key, comment + from system.columns where table='{table_name}' """.format( table_name ) with session.query_row_block_stream(_query_sql) as stream: - fields = [block for block in stream] + fields = [block for block in stream] # noqa return fields def get_users(self): + """Get user info.""" return [] def get_grants(self): + """Get grants.""" return [] def get_collation(self): @@ -169,9 +177,11 @@ class ClickhouseConnect(RDBMSDatabase): return "UTF-8" def get_charset(self): + """Get character_set.""" return "UTF-8" - def get_database_list(self): + def get_database_names(self): + """Get database names.""" session = self.client with session.command("SHOW DATABASES") as stream: @@ -184,12 +194,10 @@ class ClickhouseConnect(RDBMSDatabase): ] return databases - def get_database_names(self): - return self.get_database_list() - def run(self, command: str, fetch: str = "all") -> List: + """Execute sql command.""" # TODO need to be implemented - print("SQL:" + command) + logger.info("SQL:" + command) if not command or len(command) < 0: return [] _, ttype, sql_type, table_name = self.__sql_parse(command) @@ -199,10 +207,12 @@ class ClickhouseConnect(RDBMSDatabase): else: self._write(command) select_sql = self.convert_sql_write_to_select(command) - print(f"write result query:{select_sql}") + logger.info(f"write result query:{select_sql}") return self._query(select_sql) else: - print(f"DDL execution determines whether to enable through configuration ") + logger.info( + "DDL execution determines whether to enable through configuration " + ) cursor = self.client.command(command) @@ -212,7 +222,7 @@ class ClickhouseConnect(RDBMSDatabase): result = list(result) 
result.insert(0, field_names) - print("DDL Result:" + str(result)) + logger.info("DDL Result:" + str(result)) if not result: # return self._query(f"SHOW COLUMNS FROM {table_name}") return self.get_simple_fields(table_name) @@ -225,13 +235,16 @@ class ClickhouseConnect(RDBMSDatabase): return self._query(f"SHOW COLUMNS FROM {table_name}") def get_current_db_name(self): + """Get current database name.""" return self.client.database def get_table_comments(self, db_name: str): + """Get table comments.""" session = self.client _query_sql = f""" - SELECT table, comment FROM system.tables WHERE database = '{db_name}'""".format( + SELECT table, comment FROM system.tables WHERE database = '{db_name}' + """.format( db_name ) @@ -241,6 +254,7 @@ class ClickhouseConnect(RDBMSDatabase): def get_table_comment(self, table_name: str) -> Dict: """Get table comment. + Args: table_name (str): table name Returns: @@ -249,7 +263,9 @@ class ClickhouseConnect(RDBMSDatabase): session = self.client _query_sql = f""" - SELECT table, comment FROM system.tables WHERE database = '{self.client.database}'and table = '{table_name}'""".format( + SELECT table, comment FROM system.tables WHERE + database = '{self.client.database}'and table = '{table_name}' + """.format( self.client.database ) @@ -258,9 +274,11 @@ class ClickhouseConnect(RDBMSDatabase): return [{"text": comment} for table_name, comment in table_comments][0] def get_column_comments(self, db_name, table_name): + """Get column comments.""" session = self.client _query_sql = f""" - select name column, comment from system.columns where database='{db_name}' and table='{table_name}' + select name column, comment from system.columns where database='{db_name}' + and table='{table_name}' """.format( db_name, table_name ) @@ -270,10 +288,13 @@ class ClickhouseConnect(RDBMSDatabase): return column_comments def table_simple_info(self): - # group_concat() not supported in clickhouse, use arrayStringConcat+groupArray instead; and quotes need to be 
escaped + """Get table simple info.""" + # group_concat() not supported in clickhouse, use arrayStringConcat+groupArray + # instead; and quotes need to be escaped _sql = f""" - SELECT concat(TABLE_NAME, '(', arrayStringConcat(groupArray(column_name), '-'), ')') AS schema_info + SELECT concat(TABLE_NAME, '(', arrayStringConcat( + groupArray(column_name), '-'), ')') AS schema_info FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = '{self.get_current_db_name()}' GROUP BY TABLE_NAME @@ -282,18 +303,18 @@ class ClickhouseConnect(RDBMSDatabase): return [row[0] for block in stream for row in block] def _write(self, write_sql: str): - """write data + """Execute write sql. Args: write_sql (str): sql string """ # TODO need to be implemented - print(f"Write[{write_sql}]") + logger.info(f"Write[{write_sql}]") result = self.client.command(write_sql) - print(f"SQL[{write_sql}], result:{result.written_rows}") + logger.info(f"SQL[{write_sql}], result:{result.written_rows}") def _query(self, query: str, fetch: str = "all"): - """Query data from clickhouse + """Query data from clickhouse. 
Args: query (str): sql string @@ -306,7 +327,7 @@ class ClickhouseConnect(RDBMSDatabase): _type_: List """ # TODO need to be implemented - print(f"Query[{query}]") + logger.info(f"Query[{query}]") if not query: return [] @@ -334,11 +355,13 @@ class ClickhouseConnect(RDBMSDatabase): first_token = parsed.token_first(skip_ws=True, skip_cm=False) ttype = first_token.ttype - print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}") + logger.info( + f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}" + ) return parsed, ttype, sql_type, table_name def _sync_tables_from_db(self) -> Iterable[str]: - """Read table information from database""" + """Read table information from database.""" # TODO Use a background thread to refresh periodically # SQL will raise error with schema diff --git a/dbgpt/datasource/rdbms/conn_doris.py b/dbgpt/datasource/rdbms/conn_doris.py index 47aaea2a5..9f1aeece2 100644 --- a/dbgpt/datasource/rdbms/conn_doris.py +++ b/dbgpt/datasource/rdbms/conn_doris.py @@ -1,13 +1,16 @@ -from typing import Any, Dict, Iterable, List, Optional, Tuple +"""Doris connector.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple, cast from urllib.parse import quote from urllib.parse import quote_plus as urlquote from sqlalchemy import text -from dbgpt.datasource.rdbms.base import RDBMSDatabase +from .base import RDBMSConnector -class DorisConnect(RDBMSDatabase): +class DorisConnector(RDBMSConnector): + """Doris connector.""" + driver = "doris" db_type = "doris" db_dialect = "doris" @@ -22,24 +25,27 @@ class DorisConnect(RDBMSDatabase): db_name: str, engine_args: Optional[dict] = None, **kwargs: Any, - ) -> RDBMSDatabase: + ) -> "DorisConnector": + """Create a new DorisConnector from host, port, user, pwd, db_name.""" db_url: str = ( f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}" ) - return cls.from_uri(db_url, engine_args, **kwargs) + return cast(DorisConnector, cls.from_uri(db_url, 
engine_args, **kwargs)) def _sync_tables_from_db(self) -> Iterable[str]: table_results = self.get_session().execute( text( - f"SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=database()" + "SELECT TABLE_NAME FROM information_schema.tables where " + "TABLE_SCHEMA=database()" ) ) - table_results = set(row[0] for row in table_results) + table_results = set(row[0] for row in table_results) # noqa: C401 self._all_tables = table_results self._metadata.reflect(bind=self._engine) return self._all_tables def get_grants(self): + """Get grants.""" cursor = self.get_session().execute(text("SHOW GRANTS")) grants = cursor.fetchall() if len(grants) == 0: @@ -51,14 +57,17 @@ class DorisConnect(RDBMSDatabase): return grants_list def _get_current_version(self): - """Get database current version""" + """Get database current version.""" return int( self.get_session().execute(text("select current_version()")).scalar() ) def get_collation(self): """Get collation. - ref: https://doris.apache.org/zh-CN/docs/dev/sql-manual/sql-reference/Show-Statements/SHOW-COLLATION/ + + ref `SHOW COLLATION `_ + """ cursor = self.get_session().execute(text("SHOW COLLATION")) results = cursor.fetchall() @@ -70,11 +79,14 @@ class DorisConnect(RDBMSDatabase): def get_columns(self, table_name: str) -> List[Dict]: """Get columns. + Args: table_name (str): str Returns: - columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str - eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '', 'is_in_primary_key': True, 'comment': 'id'}, ...] + columns: List[Dict], which contains name: str, type: str, + default_expression: str, is_in_primary_key: bool, comment: str + eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '', + 'is_in_primary_key': True, 'comment': 'id'}, ...] 
""" fields = self.get_fields(table_name) return [ @@ -92,8 +104,8 @@ class DorisConnect(RDBMSDatabase): """Get column fields about specified table.""" cursor = self.get_session().execute( text( - f"select COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT " - f"from information_schema.columns " + "select COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, " + "COLUMN_COMMENT from information_schema.columns " f'where TABLE_NAME="{table_name}" and TABLE_SCHEMA=database()' ) ) @@ -104,7 +116,8 @@ class DorisConnect(RDBMSDatabase): """Get character_set.""" return "utf-8" - def get_show_create_table(self, table_name): + def get_show_create_table(self, table_name) -> str: + """Get show create table.""" # cur = self.get_session().execute( # text( # f"""show create table {table_name}""" @@ -128,6 +141,7 @@ class DorisConnect(RDBMSDatabase): return "" def get_table_comments(self, db_name=None): + """Get table comments.""" db_name = "database()" if not db_name else f"'{db_name}'" cursor = self.get_session().execute( text( @@ -139,10 +153,8 @@ class DorisConnect(RDBMSDatabase): tables = cursor.fetchall() return [(table[0], table[1]) for table in tables] - def get_database_list(self): - return self.get_database_names() - def get_database_names(self): + """Get database names.""" cursor = self.get_session().execute(text("SHOW DATABASES")) results = cursor.fetchall() return [ @@ -160,15 +172,17 @@ class DorisConnect(RDBMSDatabase): ] def get_current_db_name(self) -> str: + """Get current database name.""" return self.get_session().execute(text("select database()")).scalar() def table_simple_info(self): + """Get table simple info.""" cursor = self.get_session().execute( text( - f"SELECT concat(TABLE_NAME,'(',group_concat(COLUMN_NAME,','),');') " - f"FROM information_schema.columns " - f"where TABLE_SCHEMA=database() " - f"GROUP BY TABLE_NAME" + "SELECT concat(TABLE_NAME,'(',group_concat(COLUMN_NAME,','),');') " + "FROM information_schema.columns " + "where 
TABLE_SCHEMA=database() " + "GROUP BY TABLE_NAME" ) ) results = cursor.fetchall() diff --git a/dbgpt/datasource/rdbms/conn_duckdb.py b/dbgpt/datasource/rdbms/conn_duckdb.py index f2f53a2a8..2e0c40ac0 100644 --- a/dbgpt/datasource/rdbms/conn_duckdb.py +++ b/dbgpt/datasource/rdbms/conn_duckdb.py @@ -1,15 +1,13 @@ +"""DuckDB connector.""" from typing import Any, Iterable, Optional from sqlalchemy import create_engine, text -from dbgpt.datasource.rdbms.base import RDBMSDatabase +from .base import RDBMSConnector -class DuckDbConnect(RDBMSDatabase): - """Connect Duckdb Database fetch MetaData - Args: - Usage: - """ +class DuckDbConnector(RDBMSConnector): + """DuckDB connector.""" db_type: str = "duckdb" db_dialect: str = "duckdb" @@ -17,21 +15,24 @@ class DuckDbConnect(RDBMSDatabase): @classmethod def from_file_path( cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any - ) -> RDBMSDatabase: + ) -> RDBMSConnector: """Construct a SQLAlchemy engine from URI.""" _engine_args = engine_args or {} return cls(create_engine("duckdb:///" + file_path, **_engine_args), **kwargs) def get_users(self): + """Get users.""" cursor = self.session.execute( text( - f"SELECT * FROM sqlite_master WHERE type = 'table' AND name = 'duckdb_sys_users';" + "SELECT * FROM sqlite_master WHERE type = 'table' AND " + "name = 'duckdb_sys_users';" ) ) users = cursor.fetchall() return [(user[0], user[1]) for user in users] def get_grants(self): + """Get grants.""" return [] def get_collation(self): @@ -39,12 +40,14 @@ class DuckDbConnect(RDBMSDatabase): return "UTF-8" def get_charset(self): + """Get character_set of current database.""" return "UTF-8" - def get_table_comments(self, db_name): + def get_table_comments(self, db_name: str): + """Get table comments.""" cursor = self.session.execute( text( - f""" + """ SELECT name, sql FROM sqlite_master WHERE type='table' """ ) @@ -55,7 +58,8 @@ class DuckDbConnect(RDBMSDatabase): ] def table_simple_info(self) -> Iterable[str]: - _tables_sql = 
f""" + """Get table simple info.""" + _tables_sql = """ SELECT name FROM sqlite_master WHERE type='table' """ cursor = self.session.execute(text(_tables_sql)) diff --git a/dbgpt/datasource/rdbms/conn_hive.py b/dbgpt/datasource/rdbms/conn_hive.py index 423f50bf8..335f315f3 100644 --- a/dbgpt/datasource/rdbms/conn_hive.py +++ b/dbgpt/datasource/rdbms/conn_hive.py @@ -1,14 +1,13 @@ -from typing import Any, Optional +"""Hive Connector.""" +from typing import Any, Optional, cast from urllib.parse import quote from urllib.parse import quote_plus as urlquote -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -from dbgpt.datasource.rdbms.base import RDBMSDatabase +from .base import RDBMSConnector -class HiveConnect(RDBMSDatabase): - """db type""" +class HiveConnector(RDBMSConnector): + """Hive connector.""" db_type: str = "hive" """db driver""" @@ -26,28 +25,26 @@ class HiveConnect(RDBMSDatabase): db_name: str, engine_args: Optional[dict] = None, **kwargs: Any, - ) -> RDBMSDatabase: - """Construct a SQLAlchemy engine from uri database. - Args: - host (str): database host. - port (int): database port. - user (str): database user. - pwd (str): database password. - db_name (str): database name. - engine_args (Optional[dict]):other engine_args. 
- """ + ) -> "HiveConnector": + """Create a new HiveConnector from host, port, user, pwd, db_name.""" db_url: str = f"{cls.driver}://{host}:{str(port)}/{db_name}" if user and pwd: - db_url: str = f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}" - return cls.from_uri(db_url, engine_args, **kwargs) + db_url = ( + f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/" + f"{db_name}" + ) + return cast(HiveConnector, cls.from_uri(db_url, engine_args, **kwargs)) def table_simple_info(self): + """Get table simple info.""" return [] def get_users(self): + """Get users.""" return [] def get_grants(self): + """Get grants.""" return [] def get_collation(self): @@ -55,4 +52,5 @@ class HiveConnect(RDBMSDatabase): return "UTF-8" def get_charset(self): + """Get character_set of current database.""" return "UTF-8" diff --git a/dbgpt/datasource/rdbms/conn_mssql.py b/dbgpt/datasource/rdbms/conn_mssql.py index 961ff20d1..e9a25422d 100644 --- a/dbgpt/datasource/rdbms/conn_mssql.py +++ b/dbgpt/datasource/rdbms/conn_mssql.py @@ -1,17 +1,13 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -from typing import Any, Iterable, Optional +"""MSSQL connector.""" +from typing import Iterable -from sqlalchemy import MetaData, Table, create_engine, inspect, select, text +from sqlalchemy import text -from dbgpt.datasource.rdbms.base import RDBMSDatabase +from .base import RDBMSConnector -class MSSQLConnect(RDBMSDatabase): - """Connect MSSQL Database fetch MetaData - Args: - Usage: - """ +class MSSQLConnector(RDBMSConnector): + """MSSQL connector.""" db_type: str = "mssql" db_dialect: str = "mssql" @@ -20,8 +16,10 @@ class MSSQLConnect(RDBMSDatabase): default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource", "sys"] def table_simple_info(self) -> Iterable[str]: - _tables_sql = f""" - SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' + """Get table simple info.""" + _tables_sql = """ + SELECT TABLE_NAME FROM 
INFORMATION_SCHEMA.TABLES WHERE + TABLE_TYPE='BASE TABLE' """ cursor = self.session.execute(text(_tables_sql)) tables_results = cursor.fetchall() @@ -29,7 +27,8 @@ class MSSQLConnect(RDBMSDatabase): for row in tables_results: table_name = row[0] _sql = f""" - SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME='{table_name}' + SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE + TABLE_NAME='{table_name}' """ cursor_colums = self.session.execute(text(_sql)) colum_results = cursor_colums.fetchall() diff --git a/dbgpt/datasource/rdbms/conn_mysql.py b/dbgpt/datasource/rdbms/conn_mysql.py index 24db25c87..8bd42140b 100644 --- a/dbgpt/datasource/rdbms/conn_mysql.py +++ b/dbgpt/datasource/rdbms/conn_mysql.py @@ -1,13 +1,10 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -from dbgpt.datasource.rdbms.base import RDBMSDatabase +"""MySQL connector.""" + +from .base import RDBMSConnector -class MySQLConnect(RDBMSDatabase): - """Connect MySQL Database fetch MetaData - Args: - Usage: - """ +class MySQLConnector(RDBMSConnector): + """MySQL connector.""" db_type: str = "mysql" db_dialect: str = "mysql" diff --git a/dbgpt/datasource/rdbms/conn_postgresql.py b/dbgpt/datasource/rdbms/conn_postgresql.py index 71b63da07..4829c725c 100644 --- a/dbgpt/datasource/rdbms/conn_postgresql.py +++ b/dbgpt/datasource/rdbms/conn_postgresql.py @@ -1,13 +1,19 @@ -from typing import Any, Iterable, List, Optional, Tuple +"""PostgreSQL connector.""" +import logging +from typing import Any, Iterable, List, Optional, Tuple, cast from urllib.parse import quote from urllib.parse import quote_plus as urlquote from sqlalchemy import text -from dbgpt.datasource.rdbms.base import RDBMSDatabase +from .base import RDBMSConnector + +logger = logging.getLogger(__name__) -class PostgreSQLDatabase(RDBMSDatabase): +class PostgreSQLConnector(RDBMSConnector): + """PostgreSQL connector.""" + driver = "postgresql+psycopg2" db_type = "postgresql" db_dialect = "postgresql" 
@@ -22,34 +28,38 @@ class PostgreSQLDatabase(RDBMSDatabase): db_name: str, engine_args: Optional[dict] = None, **kwargs: Any, - ) -> RDBMSDatabase: + ) -> "PostgreSQLConnector": + """Create a new PostgreSQLConnector from host, port, user, pwd, db_name.""" db_url: str = ( f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}" ) - return cls.from_uri(db_url, engine_args, **kwargs) + return cast(PostgreSQLConnector, cls.from_uri(db_url, engine_args, **kwargs)) def _sync_tables_from_db(self) -> Iterable[str]: table_results = self.session.execute( text( - "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'" + "SELECT tablename FROM pg_catalog.pg_tables WHERE " + "schemaname != 'pg_catalog' AND schemaname != 'information_schema'" ) ) view_results = self.session.execute( text( - "SELECT viewname FROM pg_catalog.pg_views WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'" + "SELECT viewname FROM pg_catalog.pg_views WHERE " + "schemaname != 'pg_catalog' AND schemaname != 'information_schema'" ) ) - table_results = set(row[0] for row in table_results) - view_results = set(row[0] for row in view_results) + table_results = set(row[0] for row in table_results) # noqa: C401 + view_results = set(row[0] for row in view_results) # noqa: C401 self._all_tables = table_results.union(view_results) self._metadata.reflect(bind=self._engine) return self._all_tables def get_grants(self): + """Get grants.""" session = self._db_sessions() cursor = session.execute( text( - f""" + """ SELECT DISTINCT grantee, privilege_type FROM information_schema.role_table_grants WHERE grantee = CURRENT_USER;""" @@ -64,13 +74,14 @@ class PostgreSQLDatabase(RDBMSDatabase): session = self._db_sessions() cursor = session.execute( text( - "SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();" + "SELECT datcollate AS collation FROM pg_database WHERE " + "datname = 
current_database();" ) ) collation = cursor.fetchone()[0] return collation except Exception as e: - print("postgresql get collation error: ", e) + logger.warning(f"postgresql get collation error: {str(e)}") return None def get_users(self): @@ -82,7 +93,7 @@ class PostgreSQLDatabase(RDBMSDatabase): users = cursor.fetchall() return [user[0] for user in users] except Exception as e: - print("postgresql get users error: ", e) + logger.warning(f"postgresql get users error: {str(e)}") return [] def get_fields(self, table_name) -> List[Tuple]: @@ -90,7 +101,8 @@ class PostgreSQLDatabase(RDBMSDatabase): session = self._db_sessions() cursor = session.execute( text( - f"SELECT column_name, data_type, column_default, is_nullable, column_name as column_comment \ + "SELECT column_name, data_type, column_default, is_nullable, " + "column_name as column_comment \ FROM information_schema.columns WHERE table_name = :table_name", ), {"table_name": table_name}, @@ -103,23 +115,28 @@ class PostgreSQLDatabase(RDBMSDatabase): session = self._db_sessions() cursor = session.execute( text( - "SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();" + "SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE " + "datname = current_database();" ) ) character_set = cursor.fetchone()[0] return character_set - def get_show_create_table(self, table_name): + def get_show_create_table(self, table_name: str): + """Return show create table.""" cur = self.session.execute( text( f""" - SELECT a.attname as column_name, pg_catalog.format_type(a.atttypid, a.atttypmod) as data_type + SELECT a.attname as column_name, + pg_catalog.format_type(a.atttypid, a.atttypmod) as data_type FROM pg_catalog.pg_attribute a WHERE a.attnum > 0 AND NOT a.attisdropped AND a.attnum <= ( SELECT max(a.attnum) FROM pg_catalog.pg_attribute a - WHERE a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname='{table_name}') - ) AND a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE 
relname='{table_name}') + WHERE a.attrelid = (SELECT oid FROM pg_catalog.pg_class + WHERE relname='{table_name}') + ) AND a.attrelid = (SELECT oid FROM pg_catalog.pg_class + WHERE relname='{table_name}') """ ) ) @@ -133,6 +150,7 @@ class PostgreSQLDatabase(RDBMSDatabase): return create_table_query def get_table_comments(self, db_name=None): + """Get table comments.""" tablses = self.table_simple_info() comments = [] for table in tablses: @@ -141,15 +159,8 @@ class PostgreSQLDatabase(RDBMSDatabase): comments.append((table_name, table_comment)) return comments - def get_database_list(self): - session = self._db_sessions() - cursor = session.execute(text("SELECT datname FROM pg_database;")) - results = cursor.fetchall() - return [ - d[0] for d in results if d[0] not in ["template0", "template1", "postgres"] - ] - def get_database_names(self): + """Get database names.""" session = self._db_sessions() cursor = session.execute(text("SELECT datname FROM pg_database;")) results = cursor.fetchall() @@ -158,10 +169,12 @@ class PostgreSQLDatabase(RDBMSDatabase): ] def get_current_db_name(self) -> str: + """Get current database name.""" return self.session.execute(text("SELECT current_database()")).scalar() def table_simple_info(self): - _sql = f""" + """Get table simple info.""" + _sql = """ SELECT table_name, string_agg(column_name, ', ') AS schema_info FROM ( SELECT c.relname AS table_name, a.attname AS column_name @@ -181,17 +194,18 @@ class PostgreSQLDatabase(RDBMSDatabase): results = cursor.fetchall() return results - def get_fields(self, table_name, schema_name="public"): + def get_fields_wit_schema(self, table_name, schema_name="public"): """Get column fields about specified table.""" session = self._db_sessions() cursor = session.execute( text( f""" - SELECT c.column_name, c.data_type, c.column_default, c.is_nullable, d.description - FROM information_schema.columns c - LEFT JOIN pg_catalog.pg_description d - ON (c.table_schema || '.' 
|| c.table_name)::regclass::oid = d.objoid AND c.ordinal_position = d.objsubid - WHERE c.table_name='{table_name}' AND c.table_schema='{schema_name}' + SELECT c.column_name, c.data_type, c.column_default, c.is_nullable, + d.description FROM information_schema.columns c + LEFT JOIN pg_catalog.pg_description d + ON (c.table_schema || '.' || c.table_name)::regclass::oid = d.objoid + AND c.ordinal_position = d.objsubid + WHERE c.table_name='{table_name}' AND c.table_schema='{schema_name}' """ ) ) @@ -203,7 +217,8 @@ class PostgreSQLDatabase(RDBMSDatabase): session = self._db_sessions() cursor = session.execute( text( - f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'" + f"SELECT indexname, indexdef FROM pg_indexes WHERE " + f"tablename = '{table_name}'" ) ) indexes = cursor.fetchall() diff --git a/dbgpt/datasource/rdbms/conn_sqlite.py b/dbgpt/datasource/rdbms/conn_sqlite.py index c9f442d89..e950e0843 100644 --- a/dbgpt/datasource/rdbms/conn_sqlite.py +++ b/dbgpt/datasource/rdbms/conn_sqlite.py @@ -1,6 +1,4 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - +"""SQLite connector.""" import logging import os import tempfile @@ -8,16 +6,13 @@ from typing import Any, Iterable, List, Optional, Tuple from sqlalchemy import create_engine, text -from dbgpt.datasource.rdbms.base import RDBMSDatabase +from .base import RDBMSConnector logger = logging.getLogger(__name__) -class SQLiteConnect(RDBMSDatabase): - """Connect SQLite Database fetch MetaData - Args: - Usage: - """ +class SQLiteConnector(RDBMSConnector): + """SQLite connector.""" db_type: str = "sqlite" db_dialect: str = "sqlite" @@ -25,8 +20,8 @@ class SQLiteConnect(RDBMSDatabase): @classmethod def from_file_path( cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any - ) -> RDBMSDatabase: - """Construct a SQLAlchemy engine from URI.""" + ) -> "SQLiteConnector": + """Create a new SQLiteConnector from file path.""" _engine_args = engine_args or {} _engine_args["connect_args"] 
= {"check_same_thread": False} # _engine_args["echo"] = True @@ -52,7 +47,8 @@ class SQLiteConnect(RDBMSDatabase): """Get table show create table about specified table.""" cursor = self.session.execute( text( - f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}'" + "SELECT sql FROM sqlite_master WHERE type='table' " + f"AND name='{table_name}'" ) ) ans = cursor.fetchall() @@ -62,7 +58,7 @@ class SQLiteConnect(RDBMSDatabase): """Get column fields about specified table.""" cursor = self.session.execute(text(f"PRAGMA table_info('{table_name}')")) fields = cursor.fetchall() - print(fields) + logger.info(fields) return [(field[1], field[2], field[3], field[4], field[5]) for field in fields] def get_simple_fields(self, table_name): @@ -70,9 +66,11 @@ class SQLiteConnect(RDBMSDatabase): return self.get_fields(table_name) def get_users(self): + """Get user info.""" return [] def get_grants(self): + """Get grants.""" return [] def get_collation(self): @@ -80,12 +78,11 @@ class SQLiteConnect(RDBMSDatabase): return "UTF-8" def get_charset(self): + """Get character_set of current database.""" return "UTF-8" - def get_database_list(self): - return [] - def get_database_names(self): + """Get database names.""" return [] def _sync_tables_from_db(self) -> Iterable[str]: @@ -95,25 +92,27 @@ class SQLiteConnect(RDBMSDatabase): view_results = self.session.execute( text("SELECT name FROM sqlite_master WHERE type='view'") ) - table_results = set(row[0] for row in table_results) - view_results = set(row[0] for row in view_results) + table_results = set(row[0] for row in table_results) # noqa + view_results = set(row[0] for row in view_results) # noqa self._all_tables = table_results.union(view_results) self._metadata.reflect(bind=self._engine) return self._all_tables def _write(self, write_sql): - print(f"Write[{write_sql}]") + logger.info(f"Write[{write_sql}]") session = self.session result = session.execute(text(write_sql)) session.commit() - # TODO Subsequent 
optimization of dynamically specified database submission loss target problem - print(f"SQL[{write_sql}], result:{result.rowcount}") + # TODO Subsequent optimization of dynamically specified database submission + # loss target problem + logger.info(f"SQL[{write_sql}], result:{result.rowcount}") return result.rowcount def get_table_comments(self, db_name=None): + """Get table comments.""" cursor = self.session.execute( text( - f""" + """ SELECT name, sql FROM sqlite_master WHERE type='table' """ ) ) @@ -124,7 +123,8 @@ class SQLiteConnect(RDBMSDatabase): ] def table_simple_info(self) -> Iterable[str]: - _tables_sql = f""" + """Get table simple info.""" + _tables_sql = """ SELECT name FROM sqlite_master WHERE type='table' """ cursor = self.session.execute(text(_tables_sql)) @@ -146,10 +146,14 @@ class SQLiteConnect(RDBMSDatabase): return results -class SQLiteTempConnect(SQLiteConnect): - """A temporary SQLite database connection. The database file will be deleted when the connection is closed.""" +class SQLiteTempConnector(SQLiteConnector): + """A temporary SQLite database connection. + + The database file will be deleted when the connection is closed. + """ def __init__(self, engine, temp_file_path, *args, **kwargs): + """Construct a temporary SQLite database connection.""" super().__init__(engine, *args, **kwargs) self.temp_file_path = temp_file_path self._is_closed = False @@ -157,7 +161,7 @@ class SQLiteTempConnect(SQLiteConnect): @classmethod def create_temporary_db( cls, engine_args: Optional[dict] = None, **kwargs: Any - ) -> "SQLiteTempConnect": + ) -> "SQLiteTempConnector": """Create a temporary SQLite database with a temporary file. Examples: @@ -175,7 +179,7 @@ class SQLiteTempConnect(SQLiteConnect): engine_args (Optional[dict]): SQLAlchemy engine arguments. Returns: - SQLiteTempConnect: A SQLiteTempConnect instance. + SQLiteTempConnector: A SQLiteTempConnector instance. 
""" _engine_args = engine_args or {} _engine_args["connect_args"] = {"check_same_thread": False} @@ -219,7 +223,7 @@ class SQLiteTempConnect(SQLiteConnect): ], }, } - with SQLiteTempConnect.create_temporary_db() as db: + with SQLiteTempConnector.create_temporary_db() as db: db.create_temp_tables(tables_info) field_names, result = db.query_ex(db.session, "select * from test") assert field_names == ["id", "name", "age"] @@ -248,14 +252,18 @@ class SQLiteTempConnect(SQLiteConnect): self._sync_tables_from_db() def __enter__(self): + """Return the connection when entering the context manager.""" return self def __exit__(self, exc_type, exc_val, exc_tb): + """Close the connection when exiting the context manager.""" self.close() def __del__(self): + """Close the connection when the object is deleted.""" self.close() @classmethod def is_normal_type(cls) -> bool: + """Return whether the connector is a normal type.""" return False diff --git a/dbgpt/datasource/rdbms/conn_starrocks.py b/dbgpt/datasource/rdbms/conn_starrocks.py index 26079b1fb..e3a96f5b0 100644 --- a/dbgpt/datasource/rdbms/conn_starrocks.py +++ b/dbgpt/datasource/rdbms/conn_starrocks.py @@ -1,21 +1,24 @@ -from typing import Any, Iterable, List, Optional, Tuple +"""StarRocks connector.""" +from typing import Any, Iterable, List, Optional, Tuple, Type, cast from urllib.parse import quote from urllib.parse import quote_plus as urlquote from sqlalchemy import text -from dbgpt.datasource.rdbms.base import RDBMSDatabase -from dbgpt.datasource.rdbms.dialect.starrocks.sqlalchemy import * +from .base import RDBMSConnector +from .dialect.starrocks.sqlalchemy import * # noqa -class StarRocksConnect(RDBMSDatabase): +class StarRocksConnector(RDBMSConnector): + """StarRocks connector.""" + driver = "starrocks" db_type = "starrocks" db_dialect = "starrocks" @classmethod def from_uri_db( - cls, + cls: Type["StarRocksConnector"], host: str, port: int, user: str, @@ -23,27 +26,31 @@ class StarRocksConnect(RDBMSDatabase): 
db_name: str, engine_args: Optional[dict] = None, **kwargs: Any, - ) -> RDBMSDatabase: + ) -> "StarRocksConnector": + """Create a new StarRocksConnector from host, port, user, pwd, db_name.""" db_url: str = ( f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}" ) - return cls.from_uri(db_url, engine_args, **kwargs) + return cast(StarRocksConnector, cls.from_uri(db_url, engine_args, **kwargs)) def _sync_tables_from_db(self) -> Iterable[str]: db_name = self.get_current_db_name() table_results = self.session.execute( text( - f'SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA="{db_name}"' + "SELECT TABLE_NAME FROM information_schema.tables where " + f'TABLE_SCHEMA="{db_name}"' ) ) - # view_results = self.session.execute(text(f'SELECT TABLE_NAME from information_schema.materialized_views where TABLE_SCHEMA="{db_name}"')) - table_results = set(row[0] for row in table_results) + # view_results = self.session.execute(text(f'SELECT TABLE_NAME from + # information_schema.materialized_views where TABLE_SCHEMA="{db_name}"')) + table_results = set(row[0] for row in table_results) # noqa: C401 # view_results = set(row[0] for row in view_results) self._all_tables = table_results self._metadata.reflect(bind=self._engine) return self._all_tables def get_grants(self): + """Get grants.""" session = self._db_sessions() cursor = session.execute(text("SHOW GRANTS")) grants = cursor.fetchall() @@ -56,7 +63,7 @@ class StarRocksConnect(RDBMSDatabase): return grants_list def _get_current_version(self): - """Get database current version""" + """Get database current version.""" return int(self.session.execute(text("select current_version()")).scalar()) def get_collation(self): @@ -75,7 +82,9 @@ class StarRocksConnect(RDBMSDatabase): db_name = f'"{db_name}"' cursor = session.execute( text( - f'select COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT from information_schema.columns where TABLE_NAME="{table_name}" and TABLE_SCHEMA = 
{db_name}' + "select COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, " + "COLUMN_COMMENT from information_schema.columns where " + f'TABLE_NAME="{table_name}" and TABLE_SCHEMA = {db_name}' ) ) fields = cursor.fetchall() @@ -83,10 +92,10 @@ class StarRocksConnect(RDBMSDatabase): def get_charset(self): """Get character_set.""" - return "utf-8" - def get_show_create_table(self, table_name): + def get_show_create_table(self, table_name: str): + """Get show create table.""" # cur = self.session.execute( # text( # f"""show create table {table_name}""" @@ -99,7 +108,8 @@ class StarRocksConnect(RDBMSDatabase): # 这里是要表描述, 返回建表语句会导致token过长而失败 cur = self.session.execute( text( - f'SELECT TABLE_COMMENT FROM information_schema.tables where TABLE_NAME="{table_name}" and TABLE_SCHEMA=database()' + "SELECT TABLE_COMMENT FROM information_schema.tables where " + f'TABLE_NAME="{table_name}" and TABLE_SCHEMA=database()' ) ) table = cur.fetchone() @@ -109,20 +119,20 @@ class StarRocksConnect(RDBMSDatabase): return "" def get_table_comments(self, db_name=None): + """Get table comments.""" if not db_name: db_name = self.get_current_db_name() cur = self.session.execute( text( - f'SELECT TABLE_NAME,TABLE_COMMENT FROM information_schema.tables where TABLE_SCHEMA="{db_name}"' + "SELECT TABLE_NAME,TABLE_COMMENT FROM information_schema.tables " + f'where TABLE_SCHEMA="{db_name}"' ) ) tables = cur.fetchall() return [(table[0], table[1]) for table in tables] - def get_database_list(self): - return self.get_database_names() - def get_database_names(self): + """Get database names.""" session = self._db_sessions() cursor = session.execute(text("SHOW DATABASES;")) results = cursor.fetchall() @@ -133,11 +143,14 @@ class StarRocksConnect(RDBMSDatabase): ] def get_current_db_name(self) -> str: + """Get current database name.""" return self.session.execute(text("select database()")).scalar() def table_simple_info(self): - _sql = f""" - SELECT 
concat(TABLE_NAME,"(",group_concat(COLUMN_NAME,","),");") FROM information_schema.columns where TABLE_SCHEMA=database() + """Get table simple info.""" + _sql = """ + SELECT concat(TABLE_NAME,"(",group_concat(COLUMN_NAME,","),");") + FROM information_schema.columns where TABLE_SCHEMA=database() GROUP BY TABLE_NAME """ cursor = self.session.execute(text(_sql)) diff --git a/dbgpt/datasource/rdbms/dialect/__init__.py b/dbgpt/datasource/rdbms/dialect/__init__.py index e69de29bb..6669e8507 100644 --- a/dbgpt/datasource/rdbms/dialect/__init__.py +++ b/dbgpt/datasource/rdbms/dialect/__init__.py @@ -0,0 +1 @@ +"""Module for RDBMS dialects.""" diff --git a/dbgpt/datasource/rdbms/dialect/starrocks/__init__.py b/dbgpt/datasource/rdbms/dialect/starrocks/__init__.py index 20fa42f53..4518b71b0 100644 --- a/dbgpt/datasource/rdbms/dialect/starrocks/__init__.py +++ b/dbgpt/datasource/rdbms/dialect/starrocks/__init__.py @@ -1,4 +1,4 @@ -#! /usr/bin/python3 +"""StarRocks dialect for SQLAlchemy.""" # Copyright 2021-present StarRocks, Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/__init__.py b/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/__init__.py index 884fb87e2..553f8c3bb 100644 --- a/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/__init__.py +++ b/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/__init__.py @@ -1,4 +1,4 @@ -#! /usr/bin/python3 +"""SQLAlchemy dialect for StarRocks.""" # Copyright 2021-present StarRocks, Inc. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/datatype.py b/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/datatype.py index 4f479e79b..0388fef80 100644 --- a/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/datatype.py +++ b/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/datatype.py @@ -1,3 +1,5 @@ +"""SQLAlchemy data types for StarRocks.""" + import logging import re from typing import Any, Dict, List, Optional, Type @@ -10,50 +12,71 @@ logger = logging.getLogger(__name__) class TINYINT(Integer): # pylint: disable=no-init + """StarRocks TINYINT type.""" + __visit_name__ = "TINYINT" class LARGEINT(Integer): # pylint: disable=no-init + """StarRocks LARGEINT type.""" + __visit_name__ = "LARGEINT" class DOUBLE(Float): # pylint: disable=no-init + """StarRocks DOUBLE type.""" + __visit_name__ = "DOUBLE" class HLL(Numeric): # pylint: disable=no-init + """StarRocks HLL type.""" + __visit_name__ = "HLL" class BITMAP(Numeric): # pylint: disable=no-init + """StarRocks BITMAP type.""" + __visit_name__ = "BITMAP" class PERCENTILE(Numeric): # pylint: disable=no-init + """StarRocks PERCENTILE type.""" + __visit_name__ = "PERCENTILE" class ARRAY(TypeEngine): # pylint: disable=no-init + """StarRocks ARRAY type.""" + __visit_name__ = "ARRAY" @property - def python_type(self) -> Optional[Type[List[Any]]]: + def python_type(self) -> Optional[Type[List[Any]]]: # type: ignore + """Return the Python type for this SQL type.""" return list class MAP(TypeEngine): # pylint: disable=no-init + """StarRocks MAP type.""" + __visit_name__ = "MAP" @property - def python_type(self) -> Optional[Type[Dict[Any, Any]]]: + def python_type(self) -> Optional[Type[Dict[Any, Any]]]: # type: ignore + """Return the Python type for this SQL type.""" return dict class STRUCT(TypeEngine): # pylint: disable=no-init + """StarRocks STRUCT type.""" + __visit_name__ = "STRUCT" @property - def python_type(self) -> 
Optional[Type[Any]]: + def python_type(self) -> Optional[Type[Any]]: # type: ignore + """Return the Python type for this SQL type.""" return None @@ -90,6 +113,7 @@ _type_map = { def parse_sqltype(type_str: str) -> TypeEngine: + """Parse a SQL type string into a SQLAlchemy type object.""" type_str = type_str.strip().lower() match = re.match(r"^(?P\w+)\s*(?:\((?P.*)\))?", type_str) if not match: diff --git a/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/dialect.py b/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/dialect.py index d563b9603..392918611 100644 --- a/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/dialect.py +++ b/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/dialect.py @@ -1,4 +1,4 @@ -#! /usr/bin/python3 +"""StarRocks dialect for SQLAlchemy.""" # Copyright 2021-present StarRocks, Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, cast from sqlalchemy import exc, log, text from sqlalchemy.dialects.mysql.pymysql import MySQLDialect_pymysql @@ -25,7 +25,9 @@ logger = logging.getLogger(__name__) @log.class_logger -class StarRocksDialect(MySQLDialect_pymysql): +class StarRocksDialect(MySQLDialect_pymysql): # type: ignore + """StarRocks dialect for SQLAlchemy.""" + # Caching # Warnings are generated by SQLAlchmey if this flag is not explicitly set # and tests are needed before being enabled @@ -34,9 +36,11 @@ class StarRocksDialect(MySQLDialect_pymysql): name = "starrocks" def __init__(self, *args, **kw): + """Create a new StarRocks dialect.""" super(StarRocksDialect, self).__init__(*args, **kw) - def has_table(self, connection, table_name, schema=None, **kw): + def has_table(self, connection, table_name, schema: Optional[str] = None, **kw): + """Return True if the given table is present in the database.""" self._ensure_has_table_connection(connection) if schema is None: @@ -53,15 +57,13 @@ class StarRocksDialect(MySQLDialect_pymysql): return res.first() is not None def get_schema_names(self, connection, **kw): + """Return a list of schema names available in the database.""" rp = connection.exec_driver_sql("SHOW schemas") return [r[0] for r in rp] - def get_table_names(self, connection, schema=None, **kw): + def get_table_names(self, connection, schema: Optional[str] = None, **kw): """Return a Unicode SHOW TABLES from a given schema.""" - if schema is not None: - current_schema = schema - else: - current_schema = self.default_schema_name + current_schema: str = cast(str, schema or self.default_schema_name) charset = self._connection_charset @@ -76,13 +78,15 @@ class StarRocksDialect(MySQLDialect_pymysql): if row[1] == "BASE TABLE" ] - def get_view_names(self, connection, schema=None, **kw): + def get_view_names(self, connection, schema: Optional[str] = None, **kw): + """Return a Unicode SHOW TABLES 
from a given schema.""" if schema is None: schema = self.default_schema_name + current_schema = cast(str, schema) charset = self._connection_charset rp = connection.exec_driver_sql( "SHOW FULL TABLES FROM %s" - % self.identifier_preparer.quote_identifier(schema) + % self.identifier_preparer.quote_identifier(current_schema) ) return [ row[0] @@ -90,9 +94,14 @@ class StarRocksDialect(MySQLDialect_pymysql): if row[1] in ("VIEW", "SYSTEM VIEW") ] - def get_columns( - self, connection: Connection, table_name: str, schema: str = None, **kw - ) -> List[Dict[str, Any]]: + def get_columns( # type: ignore + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw, + ) -> List[Dict[str, Any]]: # type: ignore + """Return information about columns in `table_name`.""" if not self.has_table(connection, table_name, schema): raise exc.NoSuchTableError(f"schema={schema}, table={table_name}") schema = schema or self._get_default_schema_name(connection) @@ -114,60 +123,100 @@ class StarRocksDialect(MySQLDialect_pymysql): columns.append(column) return columns - def get_pk_constraint(self, connection, table_name, schema=None, **kw): + def get_pk_constraint( + self, connection, table_name, schema: Optional[str] = None, **kw + ): + """Return information about the primary key constraint.""" return { # type: ignore # pep-655 not supported "name": None, "constrained_columns": [], } - def get_unique_constraints( - self, connection: Connection, table_name: str, schema: str = None, **kw + def get_unique_constraints( # type: ignore + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw, ) -> List[Dict[str, Any]]: + """Return information about unique constraints.""" return [] - def get_check_constraints( - self, connection: Connection, table_name: str, schema: str = None, **kw + def get_check_constraints( # type: ignore + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw, ) -> 
List[Dict[str, Any]]: + """Return information about check constraints.""" return [] - def get_foreign_keys( - self, connection: Connection, table_name: str, schema: str = None, **kw + def get_foreign_keys( # type: ignore + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw, ) -> List[Dict[str, Any]]: + """Return information about foreign keys.""" return [] def get_primary_keys( - self, connection: Connection, table_name: str, schema: str = None, **kw + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw, ) -> List[str]: + """Return the primary key columns of the given table.""" pk = self.get_pk_constraint(connection, table_name, schema) return pk.get("constrained_columns") # type: ignore - def get_indexes(self, connection, table_name, schema=None, **kw): + def get_indexes(self, connection, table_name, schema: Optional[str] = None, **kw): + """Get table indexes about specified table.""" return [] def has_sequence( - self, connection: Connection, sequence_name: str, schema: str = None, **kw + self, + connection: Connection, + sequence_name: str, + schema: Optional[str] = None, + **kw, ) -> bool: + """Return True if the given sequence is present in the database.""" return False def get_sequence_names( - self, connection: Connection, schema: str = None, **kw + self, connection: Connection, schema: Optional[str] = None, **kw ) -> List[str]: + """Return a list of sequence names.""" return [] def get_temp_view_names( - self, connection: Connection, schema: str = None, **kw + self, connection: Connection, schema: Optional[str] = None, **kw ) -> List[str]: + """Return a list of temporary view names.""" return [] def get_temp_table_names( - self, connection: Connection, schema: str = None, **kw + self, connection: Connection, schema: Optional[str] = None, **kw ) -> List[str]: + """Return a list of temporary table names.""" return [] - def get_table_options(self, connection, table_name, schema=None, 
**kw): + def get_table_options( + self, connection, table_name, schema: Optional[str] = None, **kw + ): + """Return a dictionary of options specified when the table was created.""" return {} - def get_table_comment( - self, connection: Connection, table_name: str, schema: str = None, **kw + def get_table_comment( # type: ignore + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw, ) -> Dict[str, Any]: + """Return the comment for a table.""" return dict(text=None) diff --git a/dbgpt/datasource/rdbms/tests/test_conn_duckdb.py b/dbgpt/datasource/rdbms/tests/test_conn_duckdb.py index 1b08f5506..6cfb919f8 100644 --- a/dbgpt/datasource/rdbms/tests/test_conn_duckdb.py +++ b/dbgpt/datasource/rdbms/tests/test_conn_duckdb.py @@ -6,14 +6,14 @@ import tempfile import pytest -from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnect +from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnector @pytest.fixture def db(): temp_db_file = tempfile.NamedTemporaryFile(delete=False) temp_db_file.close() - conn = DuckDbConnect.from_file_path(temp_db_file.name + "duckdb.db") + conn = DuckDbConnector.from_file_path(temp_db_file.name + "duckdb.db") yield conn diff --git a/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py b/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py index f741158e4..3a940c158 100644 --- a/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py +++ b/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py @@ -6,14 +6,14 @@ import tempfile import pytest -from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnect +from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnector @pytest.fixture def db(): temp_db_file = tempfile.NamedTemporaryFile(delete=False) temp_db_file.close() - conn = SQLiteConnect.from_file_path(temp_db_file.name) + conn = SQLiteConnector.from_file_path(temp_db_file.name) yield conn try: # TODO: Failed on windows @@ -43,7 +43,7 @@ def test_run_sql(db): def test_run_no_throw(db): - assert db.run_no_throw("this is a error 
sql").startswith("Error:") + assert db.run_no_throw("this is a error sql") == [] def test_get_indexes(db): @@ -122,10 +122,6 @@ def test_get_table_comments(db): ] -def test_get_database_list(db): - db.get_database_list() == [] - - def test_get_database_names(db): db.get_database_names() == [] @@ -134,11 +130,11 @@ def test_db_dir_exist_dir(): with tempfile.TemporaryDirectory() as temp_dir: new_dir = os.path.join(temp_dir, "new_dir") file_path = os.path.join(new_dir, "sqlite.db") - db = SQLiteConnect.from_file_path(file_path) + db = SQLiteConnector.from_file_path(file_path) assert os.path.exists(new_dir) == True assert list(db.get_table_names()) == [] with tempfile.TemporaryDirectory() as existing_dir: file_path = os.path.join(existing_dir, "sqlite.db") - db = SQLiteConnect.from_file_path(file_path) + db = SQLiteConnector.from_file_path(file_path) assert os.path.exists(existing_dir) == True assert list(db.get_table_names()) == [] diff --git a/dbgpt/datasource/redis.py b/dbgpt/datasource/redis.py index 2502562a7..e411a8913 100644 --- a/dbgpt/datasource/redis.py +++ b/dbgpt/datasource/redis.py @@ -1,8 +1,10 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- +"""RedisConnector. + +TODO: Implement RedisConnector. 
+""" class RedisConnector: - """RedisConnector""" + """RedisConnector.""" pass diff --git a/dbgpt/rag/chunk_manager.py b/dbgpt/rag/chunk_manager.py index be43ce96c..80f3316f0 100644 --- a/dbgpt/rag/chunk_manager.py +++ b/dbgpt/rag/chunk_manager.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Any, List, Optional from dbgpt._private.pydantic import BaseModel, Field -from dbgpt.rag.chunk import Chunk, Document +from dbgpt.core import Chunk, Document from dbgpt.rag.extractor.base import Extractor from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge from dbgpt.rag.text_splitter import TextSplitter diff --git a/dbgpt/rag/extractor/base.py b/dbgpt/rag/extractor/base.py index 6ba57f231..3763caab2 100644 --- a/dbgpt/rag/extractor/base.py +++ b/dbgpt/rag/extractor/base.py @@ -2,8 +2,7 @@ from abc import ABC, abstractmethod from typing import List -from dbgpt.core import LLMClient -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk, LLMClient class Extractor(ABC): diff --git a/dbgpt/rag/extractor/summary.py b/dbgpt/rag/extractor/summary.py index 28669a0f7..fef204679 100644 --- a/dbgpt/rag/extractor/summary.py +++ b/dbgpt/rag/extractor/summary.py @@ -3,8 +3,7 @@ from typing import List, Optional from dbgpt._private.llm_metadata import LLMMetadata -from dbgpt.core import LLMClient, ModelMessageRoleType, ModelRequest -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk, LLMClient, ModelMessageRoleType, ModelRequest from dbgpt.rag.extractor.base import Extractor from dbgpt.util import utils from dbgpt.util.chat_util import run_async_tasks diff --git a/dbgpt/rag/extractor/tests/test_summary_extractor.py b/dbgpt/rag/extractor/tests/test_summary_extractor.py index 0beba44c1..11d28af1b 100644 --- a/dbgpt/rag/extractor/tests/test_summary_extractor.py +++ b/dbgpt/rag/extractor/tests/test_summary_extractor.py @@ -2,7 +2,7 @@ import unittest from unittest.mock import AsyncMock, MagicMock from dbgpt._private.llm_metadata import LLMMetadata 
-from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk from dbgpt.rag.extractor.summary import SummaryExtractor diff --git a/dbgpt/rag/knowledge/base.py b/dbgpt/rag/knowledge/base.py index e59138fff..46b4977ce 100644 --- a/dbgpt/rag/knowledge/base.py +++ b/dbgpt/rag/knowledge/base.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from enum import Enum from typing import Any, List, Optional, Tuple, Type -from dbgpt.rag.chunk import Document +from dbgpt.core import Document from dbgpt.rag.text_splitter.text_splitter import ( MarkdownHeaderTextSplitter, PageTextSplitter, diff --git a/dbgpt/rag/knowledge/csv.py b/dbgpt/rag/knowledge/csv.py index 2276dfb94..6e6fb94fb 100644 --- a/dbgpt/rag/knowledge/csv.py +++ b/dbgpt/rag/knowledge/csv.py @@ -2,7 +2,7 @@ import csv from typing import Any, List, Optional -from dbgpt.rag.chunk import Document +from dbgpt.core import Document from dbgpt.rag.knowledge.base import ( ChunkStrategy, DocumentType, diff --git a/dbgpt/rag/knowledge/docx.py b/dbgpt/rag/knowledge/docx.py index b1c7e48b7..d75b5fcd3 100644 --- a/dbgpt/rag/knowledge/docx.py +++ b/dbgpt/rag/knowledge/docx.py @@ -3,7 +3,7 @@ from typing import Any, List, Optional import docx -from dbgpt.rag.chunk import Document +from dbgpt.core import Document from dbgpt.rag.knowledge.base import ( ChunkStrategy, DocumentType, diff --git a/dbgpt/rag/knowledge/html.py b/dbgpt/rag/knowledge/html.py index 3af6fccd2..0d75592c7 100644 --- a/dbgpt/rag/knowledge/html.py +++ b/dbgpt/rag/knowledge/html.py @@ -3,7 +3,7 @@ from typing import Any, List, Optional import chardet -from dbgpt.rag.chunk import Document +from dbgpt.core import Document from dbgpt.rag.knowledge.base import ( ChunkStrategy, DocumentType, diff --git a/dbgpt/rag/knowledge/markdown.py b/dbgpt/rag/knowledge/markdown.py index a2e92706f..c1ea6c6a3 100644 --- a/dbgpt/rag/knowledge/markdown.py +++ b/dbgpt/rag/knowledge/markdown.py @@ -1,7 +1,7 @@ """Markdown Knowledge.""" from typing import Any, List, Optional 
-from dbgpt.rag.chunk import Document +from dbgpt.core import Document from dbgpt.rag.knowledge.base import ( ChunkStrategy, DocumentType, diff --git a/dbgpt/rag/knowledge/pdf.py b/dbgpt/rag/knowledge/pdf.py index fab8e26e2..c50750369 100644 --- a/dbgpt/rag/knowledge/pdf.py +++ b/dbgpt/rag/knowledge/pdf.py @@ -1,7 +1,7 @@ """PDF Knowledge.""" from typing import Any, List, Optional -from dbgpt.rag.chunk import Document +from dbgpt.core import Document from dbgpt.rag.knowledge.base import ( ChunkStrategy, DocumentType, diff --git a/dbgpt/rag/knowledge/pptx.py b/dbgpt/rag/knowledge/pptx.py index 4f4a35e08..2232206f3 100644 --- a/dbgpt/rag/knowledge/pptx.py +++ b/dbgpt/rag/knowledge/pptx.py @@ -1,7 +1,7 @@ """PPTX Knowledge.""" from typing import Any, List, Optional -from dbgpt.rag.chunk import Document +from dbgpt.core import Document from dbgpt.rag.knowledge.base import ( ChunkStrategy, DocumentType, diff --git a/dbgpt/rag/knowledge/string.py b/dbgpt/rag/knowledge/string.py index fa007c6eb..3ec11266b 100644 --- a/dbgpt/rag/knowledge/string.py +++ b/dbgpt/rag/knowledge/string.py @@ -1,7 +1,7 @@ """String Knowledge.""" from typing import Any, List, Optional -from dbgpt.rag.chunk import Document +from dbgpt.core import Document from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge, KnowledgeType diff --git a/dbgpt/rag/knowledge/txt.py b/dbgpt/rag/knowledge/txt.py index 6c33b136b..391440011 100644 --- a/dbgpt/rag/knowledge/txt.py +++ b/dbgpt/rag/knowledge/txt.py @@ -3,7 +3,7 @@ from typing import Any, List, Optional import chardet -from dbgpt.rag.chunk import Document +from dbgpt.core import Document from dbgpt.rag.knowledge.base import ( ChunkStrategy, DocumentType, diff --git a/dbgpt/rag/knowledge/url.py b/dbgpt/rag/knowledge/url.py index 38670463b..98a403a94 100644 --- a/dbgpt/rag/knowledge/url.py +++ b/dbgpt/rag/knowledge/url.py @@ -1,7 +1,7 @@ """URL Knowledge.""" from typing import Any, List, Optional -from dbgpt.rag.chunk import Document +from dbgpt.core 
import Document from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge, KnowledgeType diff --git a/dbgpt/rag/operators/datasource.py b/dbgpt/rag/operators/datasource.py index b236c2cff..ac2ccc9b6 100644 --- a/dbgpt/rag/operators/datasource.py +++ b/dbgpt/rag/operators/datasource.py @@ -3,14 +3,14 @@ from typing import Any from dbgpt.core.interface.operators.retriever import RetrieverOperator -from dbgpt.datasource.rdbms.base import RDBMSDatabase +from dbgpt.datasource.rdbms.base import RDBMSConnector from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary class DatasourceRetrieverOperator(RetrieverOperator[Any, Any]): """The Datasource Retriever Operator.""" - def __init__(self, connection: RDBMSDatabase, **kwargs): + def __init__(self, connection: RDBMSConnector, **kwargs): """Create a new DatasourceRetrieverOperator.""" super().__init__(**kwargs) self._connection = connection diff --git a/dbgpt/rag/operators/db_schema.py b/dbgpt/rag/operators/db_schema.py index 49fffa3e4..b7ffe1abe 100644 --- a/dbgpt/rag/operators/db_schema.py +++ b/dbgpt/rag/operators/db_schema.py @@ -3,7 +3,7 @@ from typing import Any, Optional from dbgpt.core.interface.operators.retriever import RetrieverOperator -from dbgpt.datasource.rdbms.base import RDBMSDatabase +from dbgpt.datasource.rdbms.base import RDBMSConnector from dbgpt.rag.retriever.db_schema import DBSchemaRetriever from dbgpt.storage.vector_store.connector import VectorStoreConnector @@ -12,7 +12,7 @@ class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]): """The DBSchema Retriever Operator. Args: - connection (RDBMSDatabase): The connection. + connection (RDBMSConnector): The connection. top_k (int, optional): The top k. Defaults to 4. vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None. 
@@ -22,7 +22,7 @@ class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]): self, vector_store_connector: VectorStoreConnector, top_k: int = 4, - connection: Optional[RDBMSDatabase] = None, + connection: Optional[RDBMSConnector] = None, **kwargs ): """Create a new DBSchemaRetrieverOperator.""" diff --git a/dbgpt/rag/operators/embedding.py b/dbgpt/rag/operators/embedding.py index d21401be2..7e926768a 100644 --- a/dbgpt/rag/operators/embedding.py +++ b/dbgpt/rag/operators/embedding.py @@ -3,8 +3,8 @@ from functools import reduce from typing import List, Optional, Union +from dbgpt.core import Chunk from dbgpt.core.interface.operators.retriever import RetrieverOperator -from dbgpt.rag.chunk import Chunk from dbgpt.rag.retriever.embedding import EmbeddingRetriever from dbgpt.rag.retriever.rerank import Ranker from dbgpt.rag.retriever.rewrite import QueryRewrite diff --git a/dbgpt/rag/operators/evaluation.py b/dbgpt/rag/operators/evaluation.py index 81c5ac52c..71c0aff1a 100644 --- a/dbgpt/rag/operators/evaluation.py +++ b/dbgpt/rag/operators/evaluation.py @@ -2,12 +2,11 @@ import asyncio from typing import Any, List, Optional +from dbgpt.core import Chunk from dbgpt.core.awel import JoinOperator from dbgpt.core.interface.evaluation import EvaluationMetric, EvaluationResult from dbgpt.core.interface.llm import LLMClient -from ..chunk import Chunk - class RetrieverEvaluatorOperator(JoinOperator[List[EvaluationResult]]): """Evaluator for retriever.""" diff --git a/dbgpt/rag/operators/rerank.py b/dbgpt/rag/operators/rerank.py index cde2c8b08..b9d5eb859 100644 --- a/dbgpt/rag/operators/rerank.py +++ b/dbgpt/rag/operators/rerank.py @@ -1,8 +1,8 @@ """The Rerank Operator.""" from typing import Any, List, Optional +from dbgpt.core import Chunk from dbgpt.core.awel import MapOperator -from dbgpt.rag.chunk import Chunk from dbgpt.rag.retriever.rerank import RANK_FUNC, DefaultRanker diff --git a/dbgpt/rag/operators/schema_linking.py b/dbgpt/rag/operators/schema_linking.py 
index 59f9c7185..c4355ec5c 100644 --- a/dbgpt/rag/operators/schema_linking.py +++ b/dbgpt/rag/operators/schema_linking.py @@ -7,7 +7,7 @@ from typing import Any, Optional from dbgpt.core import LLMClient from dbgpt.core.awel import MapOperator -from dbgpt.datasource.rdbms.base import RDBMSDatabase +from dbgpt.datasource.rdbms.base import RDBMSConnector from dbgpt.rag.schemalinker.schema_linking import SchemaLinking from dbgpt.storage.vector_store.connector import VectorStoreConnector @@ -17,7 +17,7 @@ class SchemaLinkingOperator(MapOperator[Any, Any]): def __init__( self, - connection: RDBMSDatabase, + connection: RDBMSConnector, model_name: str, llm: LLMClient, top_k: int = 5, @@ -27,7 +27,7 @@ class SchemaLinkingOperator(MapOperator[Any, Any]): """Create the schema linking operator. Args: - connection (RDBMSDatabase): The connection. + connection (RDBMSConnector): The connection. llm (Optional[LLMClient]): base llm """ super().__init__(**kwargs) diff --git a/dbgpt/rag/retriever/base.py b/dbgpt/rag/retriever/base.py index 0ecf28e85..737f6c514 100644 --- a/dbgpt/rag/retriever/base.py +++ b/dbgpt/rag/retriever/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from enum import Enum from typing import List -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk class RetrieverStrategy(str, Enum): diff --git a/dbgpt/rag/retriever/db_schema.py b/dbgpt/rag/retriever/db_schema.py index bfd74dd1f..3d9e9bd88 100644 --- a/dbgpt/rag/retriever/db_schema.py +++ b/dbgpt/rag/retriever/db_schema.py @@ -2,8 +2,8 @@ from functools import reduce from typing import List, Optional, cast -from dbgpt.datasource.rdbms.base import RDBMSDatabase -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk +from dbgpt.datasource.rdbms.base import RDBMSConnector from dbgpt.rag.retriever.base import BaseRetriever from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary @@ -18,7 +18,7 @@ class 
DBSchemaRetriever(BaseRetriever): self, vector_store_connector: VectorStoreConnector, top_k: int = 4, - connection: Optional[RDBMSDatabase] = None, + connection: Optional[RDBMSConnector] = None, query_rewrite: bool = False, rerank: Optional[Ranker] = None, **kwargs @@ -28,14 +28,14 @@ class DBSchemaRetriever(BaseRetriever): Args: vector_store_connector (VectorStoreConnector): vector store connector top_k (int): top k - connection (Optional[RDBMSDatabase]): RDBMSDatabase connection. + connection (Optional[RDBMSConnector]): RDBMSConnector connection. query_rewrite (bool): query rewrite rerank (Ranker): rerank Examples: .. code-block:: python - from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect + from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler from dbgpt.storage.vector_store.connector import VectorStoreConnector from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig @@ -43,7 +43,7 @@ class DBSchemaRetriever(BaseRetriever): def _create_temporary_connection(): - connect = SQLiteTempConnect.create_temporary_db() + connect = SQLiteTempConnector.create_temporary_db() connect.create_temp_tables( { "user": { @@ -109,7 +109,7 @@ class DBSchemaRetriever(BaseRetriever): return cast(List[Chunk], reduce(lambda x, y: x + y, candidates)) else: if not self._connection: - raise RuntimeError("RDBMSDatabase connection is required.") + raise RuntimeError("RDBMSConnector connection is required.") table_summaries = _parse_db_summary(self._connection) return [Chunk(content=table_summary) for table_summary in table_summaries] @@ -174,5 +174,5 @@ class DBSchemaRetriever(BaseRetriever): from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary if not self._connection: - raise RuntimeError("RDBMSDatabase connection is required.") + raise RuntimeError("RDBMSConnector connection is required.") return _parse_db_summary(self._connection) diff --git 
a/dbgpt/rag/retriever/embedding.py b/dbgpt/rag/retriever/embedding.py index e55c38465..a97764bf6 100644 --- a/dbgpt/rag/retriever/embedding.py +++ b/dbgpt/rag/retriever/embedding.py @@ -2,7 +2,7 @@ from functools import reduce from typing import List, Optional, cast -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk from dbgpt.rag.retriever.base import BaseRetriever from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker from dbgpt.rag.retriever.rewrite import QueryRewrite diff --git a/dbgpt/rag/retriever/rerank.py b/dbgpt/rag/retriever/rerank.py index cd29db551..59cadd081 100644 --- a/dbgpt/rag/retriever/rerank.py +++ b/dbgpt/rag/retriever/rerank.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from typing import Callable, List, Optional -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk RANK_FUNC = Callable[[List[Chunk]], List[Chunk]] diff --git a/dbgpt/rag/retriever/tests/test_db_struct.py b/dbgpt/rag/retriever/tests/test_db_struct.py index 349111309..109229a41 100644 --- a/dbgpt/rag/retriever/tests/test_db_struct.py +++ b/dbgpt/rag/retriever/tests/test_db_struct.py @@ -1,10 +1,10 @@ from typing import List -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest import dbgpt -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk from dbgpt.rag.retriever.db_schema import DBSchemaRetriever from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary diff --git a/dbgpt/rag/retriever/tests/test_embedding.py b/dbgpt/rag/retriever/tests/test_embedding.py index 7c9f79dee..0e8b965ea 100644 --- a/dbgpt/rag/retriever/tests/test_embedding.py +++ b/dbgpt/rag/retriever/tests/test_embedding.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock import pytest -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk from dbgpt.rag.retriever.embedding import EmbeddingRetriever diff --git a/dbgpt/rag/schemalinker/schema_linking.py 
b/dbgpt/rag/schemalinker/schema_linking.py index 748b3a928..4897aed49 100644 --- a/dbgpt/rag/schemalinker/schema_linking.py +++ b/dbgpt/rag/schemalinker/schema_linking.py @@ -3,9 +3,14 @@ from functools import reduce from typing import List, Optional, cast -from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest -from dbgpt.datasource.rdbms.base import RDBMSDatabase -from dbgpt.rag.chunk import Chunk +from dbgpt.core import ( + Chunk, + LLMClient, + ModelMessage, + ModelMessageRoleType, + ModelRequest, +) +from dbgpt.datasource.rdbms.base import RDBMSConnector from dbgpt.rag.schemalinker.base_linker import BaseSchemaLinker from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary from dbgpt.storage.vector_store.connector import VectorStoreConnector @@ -37,7 +42,7 @@ class SchemaLinking(BaseSchemaLinker): def __init__( self, - connection: RDBMSDatabase, + connection: RDBMSConnector, model_name: str, llm: LLMClient, top_k: int = 5, @@ -47,7 +52,7 @@ class SchemaLinking(BaseSchemaLinker): """Create the schema linking instance. Args: - connection (Optional[RDBMSDatabase]): RDBMSDatabase connection. + connection (Optional[RDBMSConnector]): RDBMSConnector connection. 
llm (Optional[LLMClient]): base llm """ super().__init__(**kwargs) diff --git a/dbgpt/rag/summary/db_summary_client.py b/dbgpt/rag/summary/db_summary_client.py index 6cb725b0b..ab88a3940 100644 --- a/dbgpt/rag/summary/db_summary_client.py +++ b/dbgpt/rag/summary/db_summary_client.py @@ -67,7 +67,7 @@ class DBSummaryClient: def init_db_summary(self): """Initialize db summary profile.""" - db_mange = CFG.LOCAL_DB_MANAGE + db_mange = CFG.local_db_manager dbs = db_mange.get_db_list() for item in dbs: try: diff --git a/dbgpt/rag/summary/rdbms_db_summary.py b/dbgpt/rag/summary/rdbms_db_summary.py index ceae66527..54e30d20b 100644 --- a/dbgpt/rag/summary/rdbms_db_summary.py +++ b/dbgpt/rag/summary/rdbms_db_summary.py @@ -1,11 +1,14 @@ """Summary for rdbms database.""" -from typing import List +from typing import TYPE_CHECKING, List, Optional from dbgpt._private.config import Config -from dbgpt.datasource.rdbms.base import RDBMSDatabase +from dbgpt.datasource.rdbms.base import RDBMSConnector from dbgpt.rag.summary.db_summary import DBSummary +if TYPE_CHECKING: + from dbgpt.datasource.manages import ConnectorManager + CFG = Config() @@ -17,7 +20,9 @@ class RdbmsSummary(DBSummary): column3(column3 comment) and index keys, and table comment is {table_comment}) """ - def __init__(self, name: str, type: str): + def __init__( + self, name: str, type: str, manager: Optional["ConnectorManager"] = None + ): """Create a new RdbmsSummary.""" self.name = name self.type = type @@ -26,10 +31,11 @@ class RdbmsSummary(DBSummary): # self.tables_info = [] # self.vector_tables_info = [] - if not CFG.LOCAL_DB_MANAGE: - raise ValueError("Local db manage is not initialized.") # TODO: Don't use the global variable. 
- self.db = CFG.LOCAL_DB_MANAGE.get_connect(name) + db_manager = manager or CFG.local_db_manager + if not db_manager: + raise ValueError("Local db manage is not initialized.") + self.db = db_manager.get_connector(name) self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format( @@ -58,12 +64,12 @@ class RdbmsSummary(DBSummary): def _parse_db_summary( - conn: RDBMSDatabase, summary_template: str = "{table_name}({columns})" + conn: RDBMSConnector, summary_template: str = "{table_name}({columns})" ) -> List[str]: """Get db summary for database. Args: - conn (RDBMSDatabase): database connection + conn (RDBMSConnector): database connection summary_template (str): summary template """ tables = conn.get_table_names() @@ -75,12 +81,12 @@ def _parse_db_summary( def _parse_table_summary( - conn: RDBMSDatabase, summary_template: str, table_name: str + conn: RDBMSConnector, summary_template: str, table_name: str ) -> str: """Get table summary for table. 
Args: - conn (RDBMSDatabase): database connection + conn (RDBMSConnector): database connection summary_template (str): summary template table_name (str): table name diff --git a/dbgpt/rag/summary/tests/test_rdbms_summary.py b/dbgpt/rag/summary/tests/test_rdbms_summary.py index 1ccc8aac2..c9028a714 100644 --- a/dbgpt/rag/summary/tests/test_rdbms_summary.py +++ b/dbgpt/rag/summary/tests/test_rdbms_summary.py @@ -1,10 +1,11 @@ import unittest +from typing import List from unittest.mock import Mock, patch from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary -class MockRDBMSDatabase(object): +class MockRDBMSConnector(object): def get_users(self): return "user1, user2" @@ -35,15 +36,12 @@ class MockRDBMSDatabase(object): class TestRdbmsSummary(unittest.TestCase): def setUp(self): self.mock_local_db_manage = Mock() - self.mock_local_db_manage.get_connect.return_value = MockRDBMSDatabase() - self.patcher = patch( - "dbgpt.rag.summary.rdbms_db_summary.CFG.LOCAL_DB_MANAGE", - new=self.mock_local_db_manage, - ) - self.patcher.start() + self.mock_local_db_manage.get_connector.return_value = MockRDBMSConnector() def test_rdbms_summary_initialization(self): - rdbms_summary = RdbmsSummary(name="test_db", type="test_type") + rdbms_summary = RdbmsSummary( + name="test_db", type="test_type", manager=self.mock_local_db_manage + ) self.assertEqual(rdbms_summary.name, "test_db") self.assertEqual(rdbms_summary.type, "test_type") self.assertTrue("user info :user1, user2" in rdbms_summary.metadata) @@ -52,7 +50,9 @@ class TestRdbmsSummary(unittest.TestCase): self.assertTrue("collation:utf8_general_ci" in rdbms_summary.metadata) def test_table_summaries(self): - rdbms_summary = RdbmsSummary(name="test_db", type="test_type") + rdbms_summary = RdbmsSummary( + name="test_db", type="test_type", manager=self.mock_local_db_manage + ) summaries = rdbms_summary.table_summaries() self.assertTrue( "table1(column1 (first column), column2), and index keys: index1(`column1`) , and table comment: 
table1 comment" diff --git a/dbgpt/rag/text_splitter/pre_text_splitter.py b/dbgpt/rag/text_splitter/pre_text_splitter.py index 3c43cf8d2..fff8185fb 100644 --- a/dbgpt/rag/text_splitter/pre_text_splitter.py +++ b/dbgpt/rag/text_splitter/pre_text_splitter.py @@ -1,7 +1,7 @@ """Pre text splitter.""" from typing import Iterable, List -from dbgpt.rag.chunk import Chunk, Document +from dbgpt.core import Chunk, Document from dbgpt.rag.text_splitter.text_splitter import TextSplitter diff --git a/dbgpt/rag/text_splitter/tests/test_splitters.py b/dbgpt/rag/text_splitter/tests/test_splitters.py index 3d3ebb764..fd32956ce 100644 --- a/dbgpt/rag/text_splitter/tests/test_splitters.py +++ b/dbgpt/rag/text_splitter/tests/test_splitters.py @@ -1,4 +1,4 @@ -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk from dbgpt.rag.text_splitter.text_splitter import ( CharacterTextSplitter, MarkdownHeaderTextSplitter, diff --git a/dbgpt/rag/text_splitter/text_splitter.py b/dbgpt/rag/text_splitter/text_splitter.py index 8516dc279..fd8c78e2c 100644 --- a/dbgpt/rag/text_splitter/text_splitter.py +++ b/dbgpt/rag/text_splitter/text_splitter.py @@ -5,7 +5,7 @@ import logging from abc import ABC, abstractmethod from typing import Any, Callable, Dict, Iterable, List, Optional, TypedDict, Union, cast -from dbgpt.rag.chunk import Chunk, Document +from dbgpt.core import Chunk, Document logger = logging.getLogger(__name__) diff --git a/dbgpt/serve/agent/app/controller.py b/dbgpt/serve/agent/app/controller.py index 6dc829802..eae45f973 100644 --- a/dbgpt/serve/agent/app/controller.py +++ b/dbgpt/serve/agent/app/controller.py @@ -148,7 +148,7 @@ async def app_resources( results = [] match type: case ResourceType.DB.value: - dbs = CFG.LOCAL_DB_MANAGE.get_db_list() + dbs = CFG.local_db_manager.get_db_list() results = [db["db_name"] for db in dbs] if name: results = [r for r in results if name in r] diff --git a/dbgpt/serve/agent/resource_loader/datasource_load_client.py 
b/dbgpt/serve/agent/resource_loader/datasource_load_client.py index edc34e626..a32210ab0 100644 --- a/dbgpt/serve/agent/resource_loader/datasource_load_client.py +++ b/dbgpt/serve/agent/resource_loader/datasource_load_client.py @@ -22,7 +22,7 @@ class DatasourceLoadClient(ResourceDbClient): ).create() def get_data_type(self, resource: AgentResource) -> str: - conn = CFG.LOCAL_DB_MANAGE.get_connect(resource.value) + conn = CFG.local_db_manager.get_connector(resource.value) return conn.db_type async def a_get_schema_link(self, db: str, question: Optional[str] = None) -> str: @@ -44,7 +44,7 @@ class DatasourceLoadClient(ResourceDbClient): except Exception as e: print("db summary find error!" + str(e)) if not table_infos: - conn = CFG.LOCAL_DB_MANAGE.get_connect(db) + conn = CFG.local_db_manager.get_connector(db) table_infos = await blocking_func_to_async( self._executor, conn.table_simple_info ) @@ -52,13 +52,13 @@ class DatasourceLoadClient(ResourceDbClient): return table_infos async def a_query_to_df(self, db: str, sql: str): - conn = CFG.LOCAL_DB_MANAGE.get_connect(db) + conn = CFG.local_db_manager.get_connector(db) return conn.run_to_df(sql) async def a_query(self, db: str, sql: str): - conn = CFG.LOCAL_DB_MANAGE.get_connect(db) + conn = CFG.local_db_manager.get_connector(db) return conn.query_ex(sql) async def a_run_sql(self, db: str, sql: str): - conn = CFG.LOCAL_DB_MANAGE.get_connect(db) + conn = CFG.local_db_manager.get_connector(db) return conn.run(sql) diff --git a/dbgpt/serve/agent/resource_loader/knowledge_space_load_client.py b/dbgpt/serve/agent/resource_loader/knowledge_space_load_client.py index d90d21b4f..b0f216147 100644 --- a/dbgpt/serve/agent/resource_loader/knowledge_space_load_client.py +++ b/dbgpt/serve/agent/resource_loader/knowledge_space_load_client.py @@ -1,17 +1,11 @@ -import json import logging -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, List, Optional from dbgpt._private.config import Config -from 
dbgpt.agent.plugin.generator import PluginPromptGenerator -from dbgpt.agent.resource.resource_api import AgentResource, ResourceType +from dbgpt.agent.resource.resource_api import AgentResource from dbgpt.agent.resource.resource_knowledge_api import ResourceKnowledgeClient -from dbgpt.component import ComponentType -from dbgpt.rag.chunk import Chunk -from dbgpt.serve.agent.hub.controller import ModulePlugin +from dbgpt.core import Chunk from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever -from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async -from dbgpt.util.tracer import root_tracer, trace CFG = Config() diff --git a/dbgpt/serve/rag/assembler/base.py b/dbgpt/serve/rag/assembler/base.py index 72501f5b7..20bb439bc 100644 --- a/dbgpt/serve/rag/assembler/base.py +++ b/dbgpt/serve/rag/assembler/base.py @@ -1,12 +1,12 @@ from abc import ABC, abstractmethod from typing import Any, List, Optional -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk from dbgpt.rag.chunk_manager import ChunkManager, ChunkParameters from dbgpt.rag.extractor.base import Extractor from dbgpt.rag.knowledge.base import Knowledge from dbgpt.rag.retriever.base import BaseRetriever -from dbgpt.util.tracer import root_tracer, trace +from dbgpt.util.tracer import root_tracer class BaseAssembler(ABC): diff --git a/dbgpt/serve/rag/assembler/db_schema.py b/dbgpt/serve/rag/assembler/db_schema.py index eed45ba34..b5fe877bf 100644 --- a/dbgpt/serve/rag/assembler/db_schema.py +++ b/dbgpt/serve/rag/assembler/db_schema.py @@ -1,8 +1,7 @@ -import os from typing import Any, List, Optional -from dbgpt.datasource.rdbms.base import RDBMSDatabase -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk +from dbgpt.datasource.rdbms.base import RDBMSConnector from dbgpt.rag.chunk_manager import ChunkManager, ChunkParameters from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.knowledge.base import ChunkStrategy, 
Knowledge @@ -18,12 +17,12 @@ class DBSchemaAssembler(BaseAssembler): Example: .. code-block:: python - from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect + from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector from dbgpt.serve.rag.assembler.db_struct import DBSchemaAssembler from dbgpt.storage.vector_store.connector import VectorStoreConnector from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig - connection = SQLiteTempConnect.create_temporary_db() + connection = SQLiteTempConnector.create_temporary_db() assembler = DBSchemaAssembler.load_from_connection( connection=connection, embedding_model=embedding_model_path, @@ -35,7 +34,7 @@ class DBSchemaAssembler(BaseAssembler): def __init__( self, - connection: RDBMSDatabase = None, + connection: RDBMSConnector = None, chunk_parameters: Optional[ChunkParameters] = None, embedding_model: Optional[str] = None, embedding_factory: Optional[EmbeddingFactory] = None, @@ -44,7 +43,7 @@ class DBSchemaAssembler(BaseAssembler): ) -> None: """Initialize with Embedding Assembler arguments. Args: - connection: (RDBMSDatabase) RDBMSDatabase connection. + connection: (RDBMSConnector) RDBMSConnector connection. knowledge: (Knowledge) Knowledge datasource. chunk_manager: (Optional[ChunkManager]) ChunkManager to use for chunking. embedding_model: (Optional[str]) Embedding model to use. @@ -76,7 +75,7 @@ class DBSchemaAssembler(BaseAssembler): @classmethod def load_from_connection( cls, - connection: RDBMSDatabase = None, + connection: RDBMSConnector = None, knowledge: Optional[Knowledge] = None, chunk_parameters: Optional[ChunkParameters] = None, embedding_model: Optional[str] = None, @@ -85,7 +84,7 @@ class DBSchemaAssembler(BaseAssembler): ) -> "DBSchemaAssembler": """Load document embedding into vector store from path. Args: - connection: (RDBMSDatabase) RDBMSDatabase connection. + connection: (RDBMSConnector) RDBMSDatabase connection. knowledge: (Knowledge) Knowledge datasource. 
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for chunking. embedding_model: (Optional[str]) Embedding model to use. diff --git a/dbgpt/serve/rag/assembler/embedding.py b/dbgpt/serve/rag/assembler/embedding.py index b43803a42..a21d7bd85 100644 --- a/dbgpt/serve/rag/assembler/embedding.py +++ b/dbgpt/serve/rag/assembler/embedding.py @@ -1,7 +1,7 @@ import os from typing import Any, List, Optional -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk from dbgpt.rag.chunk_manager import ChunkParameters from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.knowledge.base import Knowledge diff --git a/dbgpt/serve/rag/assembler/summary.py b/dbgpt/serve/rag/assembler/summary.py index 5ff3fa8a3..188a1e6de 100644 --- a/dbgpt/serve/rag/assembler/summary.py +++ b/dbgpt/serve/rag/assembler/summary.py @@ -1,8 +1,7 @@ import os from typing import Any, List, Optional -from dbgpt.core import LLMClient -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk, LLMClient from dbgpt.rag.chunk_manager import ChunkParameters from dbgpt.rag.extractor.base import Extractor from dbgpt.rag.knowledge.base import Knowledge diff --git a/dbgpt/serve/rag/assembler/tests/test_db_struct_assembler.py b/dbgpt/serve/rag/assembler/tests/test_db_struct_assembler.py index 34f15b458..c6fbe0539 100644 --- a/dbgpt/serve/rag/assembler/tests/test_db_struct_assembler.py +++ b/dbgpt/serve/rag/assembler/tests/test_db_struct_assembler.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock import pytest -from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect +from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.knowledge.base import Knowledge @@ -14,7 +14,7 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector @pytest.fixture def mock_db_connection(): 
"""Create a temporary database connection for testing.""" - connect = SQLiteTempConnect.create_temporary_db() + connect = SQLiteTempConnector.create_temporary_db() connect.create_temp_tables( { "user": { diff --git a/dbgpt/serve/rag/assembler/tests/test_embedding_assembler.py b/dbgpt/serve/rag/assembler/tests/test_embedding_assembler.py index fc77c18c9..1570774fd 100644 --- a/dbgpt/serve/rag/assembler/tests/test_embedding_assembler.py +++ b/dbgpt/serve/rag/assembler/tests/test_embedding_assembler.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock, patch import pytest -from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect +from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.knowledge.base import Knowledge @@ -14,7 +14,7 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector @pytest.fixture def mock_db_connection(): """Create a temporary database connection for testing.""" - connect = SQLiteTempConnect.create_temporary_db() + connect = SQLiteTempConnector.create_temporary_db() connect.create_temp_tables( { "user": { diff --git a/dbgpt/serve/rag/operators/db_schema.py b/dbgpt/serve/rag/operators/db_schema.py index 995a05f7e..b0affd0a3 100644 --- a/dbgpt/serve/rag/operators/db_schema.py +++ b/dbgpt/serve/rag/operators/db_schema.py @@ -1,7 +1,7 @@ from typing import Any, Optional from dbgpt.core.awel.task.base import IN -from dbgpt.datasource.rdbms.base import RDBMSDatabase +from dbgpt.datasource.rdbms.base import RDBMSConnector from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler from dbgpt.serve.rag.operators.base import AssemblerOperator from dbgpt.storage.vector_store.connector import VectorStoreConnector @@ -10,14 +10,14 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector class DBSchemaAssemblerOperator(AssemblerOperator[Any, Any]): """The 
DBSchema Assembler Operator. Args: - connection (RDBMSDatabase): The connection. + connection (RDBMSConnector): The connection. chunk_parameters (Optional[ChunkParameters], optional): The chunk parameters. Defaults to None. vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None. """ def __init__( self, - connection: RDBMSDatabase = None, + connection: RDBMSConnector = None, vector_store_connector: Optional[VectorStoreConnector] = None, **kwargs ): diff --git a/dbgpt/serve/rag/retriever/knowledge_space.py b/dbgpt/serve/rag/retriever/knowledge_space.py index 34c3915f5..72ac30580 100644 --- a/dbgpt/serve/rag/retriever/knowledge_space.py +++ b/dbgpt/serve/rag/retriever/knowledge_space.py @@ -3,7 +3,7 @@ from typing import List, Optional from dbgpt._private.config import Config from dbgpt.component import ComponentType from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.retriever.base import BaseRetriever from dbgpt.storage.vector_store.connector import VectorStoreConnector diff --git a/dbgpt/storage/vector_store/base.py b/dbgpt/storage/vector_store/base.py index 2de3c3223..28412967c 100644 --- a/dbgpt/storage/vector_store/base.py +++ b/dbgpt/storage/vector_store/base.py @@ -7,8 +7,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import List, Optional from dbgpt._private.pydantic import BaseModel, Field -from dbgpt.core import Embeddings -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk, Embeddings logger = logging.getLogger(__name__) diff --git a/dbgpt/storage/vector_store/chroma_store.py b/dbgpt/storage/vector_store/chroma_store.py index 93c4c239d..453e45691 100644 --- a/dbgpt/storage/vector_store/chroma_store.py +++ b/dbgpt/storage/vector_store/chroma_store.py @@ -10,7 +10,7 @@ from dbgpt._private.pydantic import Field from 
dbgpt.configs.model_config import PILOT_PATH # TODO: Recycle dependency on rag and storage -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk from .base import VectorStoreBase, VectorStoreConfig diff --git a/dbgpt/storage/vector_store/connector.py b/dbgpt/storage/vector_store/connector.py index 851f07337..ae18a7c90 100644 --- a/dbgpt/storage/vector_store/connector.py +++ b/dbgpt/storage/vector_store/connector.py @@ -3,7 +3,7 @@ import os from typing import Any, Dict, List, Optional, Type, cast -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk from dbgpt.storage import vector_store from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig diff --git a/dbgpt/storage/vector_store/milvus_store.py b/dbgpt/storage/vector_store/milvus_store.py index 58281d2c2..65e056329 100644 --- a/dbgpt/storage/vector_store/milvus_store.py +++ b/dbgpt/storage/vector_store/milvus_store.py @@ -7,8 +7,7 @@ import os from typing import Any, Iterable, List, Optional from dbgpt._private.pydantic import Field -from dbgpt.core import Embeddings -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk, Embeddings from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig from dbgpt.util import string_utils diff --git a/dbgpt/storage/vector_store/pgvector_store.py b/dbgpt/storage/vector_store/pgvector_store.py index 7038e4c7e..cf71c5801 100644 --- a/dbgpt/storage/vector_store/pgvector_store.py +++ b/dbgpt/storage/vector_store/pgvector_store.py @@ -4,7 +4,7 @@ from typing import Any, List from dbgpt._private.config import Config from dbgpt._private.pydantic import Field -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig logger = logging.getLogger(__name__) diff --git a/dbgpt/storage/vector_store/weaviate_store.py b/dbgpt/storage/vector_store/weaviate_store.py index 776d40e08..226f1f672 100644 --- 
a/dbgpt/storage/vector_store/weaviate_store.py +++ b/dbgpt/storage/vector_store/weaviate_store.py @@ -5,7 +5,7 @@ from typing import List from dbgpt._private.config import Config from dbgpt._private.pydantic import Field -from dbgpt.rag.chunk import Chunk +from dbgpt.core import Chunk from .base import VectorStoreBase, VectorStoreConfig diff --git a/examples/awel/simple_nl_schema_sql_chart_example.py b/examples/awel/simple_nl_schema_sql_chart_example.py index 93332ec51..b430c1307 100644 --- a/examples/awel/simple_nl_schema_sql_chart_example.py +++ b/examples/awel/simple_nl_schema_sql_chart_example.py @@ -7,8 +7,8 @@ from pydantic import BaseModel, Field from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator -from dbgpt.datasource.rdbms.base import RDBMSDatabase -from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect +from dbgpt.datasource.rdbms.base import RDBMSConnector +from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector from dbgpt.model.proxy import OpenAILLMClient from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory from dbgpt.rag.operators.schema_linking import SchemaLinkingOperator @@ -66,7 +66,7 @@ def _create_vector_connector(): def _create_temporary_connection(): """Create a temporary database connection for testing.""" - connect = SQLiteTempConnect.create_temporary_db() + connect = SQLiteTempConnector.create_temporary_db() connect.create_temp_tables( { "user": { @@ -181,10 +181,10 @@ class SqlGenOperator(MapOperator[Any, Any]): class SqlExecOperator(MapOperator[Any, Any]): """The Sql Execution Operator.""" - def __init__(self, connection: Optional[RDBMSDatabase] = None, **kwargs): + def __init__(self, connection: Optional[RDBMSConnector] = None, **kwargs): """ Args: - connection (Optional[RDBMSDatabase]): RDBMSDatabase connection + connection 
(Optional[RDBMSConnector]): RDBMSConnector connection """ super().__init__(**kwargs) self._connection = connection @@ -207,7 +207,7 @@ class ChartDrawOperator(MapOperator[Any, Any]): def __init__(self, **kwargs): """ Args: - connection (RDBMSDatabase): The connection. + connection (RDBMSConnector): The connection. """ super().__init__(**kwargs) diff --git a/examples/rag/db_schema_rag_example.py b/examples/rag/db_schema_rag_example.py index 71efdfac5..3e7835fcf 100644 --- a/examples/rag/db_schema_rag_example.py +++ b/examples/rag/db_schema_rag_example.py @@ -1,7 +1,7 @@ import os from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH -from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect +from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector from dbgpt.rag.embedding import DefaultEmbeddingFactory from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig @@ -22,7 +22,7 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector def _create_temporary_connection(): """Create a temporary database connection for testing.""" - connect = SQLiteTempConnect.create_temporary_db() + connect = SQLiteTempConnector.create_temporary_db() connect.create_temp_tables( { "user": { diff --git a/examples/rag/simple_dbschema_retriever_example.py b/examples/rag/simple_dbschema_retriever_example.py index 354c72d78..90cea6a4c 100644 --- a/examples/rag/simple_dbschema_retriever_example.py +++ b/examples/rag/simple_dbschema_retriever_example.py @@ -29,9 +29,9 @@ from pydantic import BaseModel, Field from dbgpt._private.config import Config from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, PILOT_PATH +from dbgpt.core import Chunk from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator -from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect -from dbgpt.rag.chunk import Chunk +from dbgpt.datasource.rdbms.conn_sqlite import 
SQLiteTempConnector from dbgpt.rag.embedding import DefaultEmbeddingFactory from dbgpt.rag.operators import DBSchemaRetrieverOperator from dbgpt.serve.rag.operators.db_schema import DBSchemaAssemblerOperator @@ -57,7 +57,7 @@ def _create_vector_connector(): def _create_temporary_connection(): """Create a temporary database connection for testing.""" - connect = SQLiteTempConnect.create_temporary_db() + connect = SQLiteTempConnector.create_temporary_db() connect.create_temp_tables( { "user": { diff --git a/examples/rag/simple_rag_retriever_example.py b/examples/rag/simple_rag_retriever_example.py index 19ab78666..798f53468 100644 --- a/examples/rag/simple_rag_retriever_example.py +++ b/examples/rag/simple_rag_retriever_example.py @@ -33,9 +33,9 @@ from pydantic import BaseModel, Field from dbgpt._private.config import Config from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, PILOT_PATH +from dbgpt.core import Chunk from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator from dbgpt.model.proxy import OpenAILLMClient -from dbgpt.rag.chunk import Chunk from dbgpt.rag.embedding import DefaultEmbeddingFactory from dbgpt.rag.operators import ( EmbeddingRetrieverOperator, diff --git a/examples/sdk/simple_sdk_llm_sql_example.py b/examples/sdk/simple_sdk_llm_sql_example.py index d518395ca..43bc24e97 100644 --- a/examples/sdk/simple_sdk_llm_sql_example.py +++ b/examples/sdk/simple_sdk_llm_sql_example.py @@ -16,14 +16,14 @@ from dbgpt.core.operators import ( RequestBuilderOperator, ) from dbgpt.datasource.operators.datasource_operator import DatasourceOperator -from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect +from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector from dbgpt.model.proxy import OpenAILLMClient from dbgpt.rag.operators.datasource import DatasourceRetrieverOperator def _create_temporary_connection(): """Create a temporary database connection for testing.""" - connect = SQLiteTempConnect.create_temporary_db() 
+ connect = SQLiteTempConnector.create_temporary_db() connect.create_temp_tables( { "user": { diff --git a/requirements/lint-requirements.txt b/requirements/lint-requirements.txt index 4e9e454e8..733bf89f4 100644 --- a/requirements/lint-requirements.txt +++ b/requirements/lint-requirements.txt @@ -12,4 +12,5 @@ pyupgrade==3.1.0 types-requests types-beautifulsoup4 types-Markdown -types-tqdm \ No newline at end of file +types-tqdm +pandas-stubs \ No newline at end of file diff --git a/tests/intetration_tests/datasource/test_conn_clickhouse.py b/tests/intetration_tests/datasource/test_conn_clickhouse.py index a90de0039..a25777b5a 100644 --- a/tests/intetration_tests/datasource/test_conn_clickhouse.py +++ b/tests/intetration_tests/datasource/test_conn_clickhouse.py @@ -20,12 +20,12 @@ from typing import Dict, List import pytest -from dbgpt.datasource.rdbms.conn_clickhouse import ClickhouseConnect +from dbgpt.datasource.rdbms.conn_clickhouse import ClickhouseConnector @pytest.fixture def db(): - conn = ClickhouseConnect.from_uri_db("localhost", 8123, "default", "", "default") + conn = ClickhouseConnector.from_uri_db("localhost", 8123, "default", "", "default") yield conn diff --git a/tests/intetration_tests/datasource/test_conn_doris.py b/tests/intetration_tests/datasource/test_conn_doris.py index 2a9fcfd2f..33364f413 100644 --- a/tests/intetration_tests/datasource/test_conn_doris.py +++ b/tests/intetration_tests/datasource/test_conn_doris.py @@ -4,10 +4,10 @@ import pytest -from dbgpt.datasource.rdbms.conn_doris import DorisConnect +from dbgpt.datasource.rdbms.conn_doris import DorisConnector @pytest.fixture def db(): - conn = DorisConnect.from_uri_db("localhost", 9030, "root", "", "test") + conn = DorisConnector.from_uri_db("localhost", 9030, "root", "", "test") yield conn diff --git a/tests/intetration_tests/datasource/test_conn_mysql.py b/tests/intetration_tests/datasource/test_conn_mysql.py index 738d1470a..4ccb15ff1 100644 --- 
a/tests/intetration_tests/datasource/test_conn_mysql.py +++ b/tests/intetration_tests/datasource/test_conn_mysql.py @@ -20,7 +20,7 @@ import pytest -from dbgpt.datasource.rdbms.conn_mysql import MySQLConnect +from dbgpt.datasource.rdbms.conn_mysql import MySQLConnector _create_table_sql = """ CREATE TABLE IF NOT EXISTS `test` ( @@ -31,7 +31,7 @@ _create_table_sql = """ @pytest.fixture def db(): - conn = MySQLConnect.from_uri_db( + conn = MySQLConnector.from_uri_db( "localhost", 3307, "root", @@ -89,4 +89,4 @@ def test_get_users(db): def test_get_database_lists(db): - assert db.get_database_list() == ["test"] + assert db.get_database_names() == ["test"] diff --git a/tests/intetration_tests/datasource/test_conn_starrocks.py b/tests/intetration_tests/datasource/test_conn_starrocks.py index ec4d9b0ec..9c4ee3dc4 100644 --- a/tests/intetration_tests/datasource/test_conn_starrocks.py +++ b/tests/intetration_tests/datasource/test_conn_starrocks.py @@ -21,12 +21,12 @@ import pytest -from dbgpt.datasource.rdbms.conn_starrocks import StarRocksConnect +from dbgpt.datasource.rdbms.conn_starrocks import StarRocksConnector @pytest.fixture def db(): - conn = StarRocksConnect.from_uri_db("localhost", 9030, "root", "", "test") + conn = StarRocksConnector.from_uri_db("localhost", 9030, "root", "", "test") yield conn