feat(core): Support simple DB query for sdk (#917)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
FangYin Cheng
2023-12-11 18:33:54 +08:00
committed by GitHub
parent 43190ca333
commit cbba50ab1b
18 changed files with 467 additions and 74 deletions

View File

@@ -2,58 +2,104 @@
# -*- coding:utf-8 -*-
"""We need to design a base class. That other connector can Write with this"""
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional
from abc import ABC
from typing import Iterable, List, Optional
class BaseConnect(ABC):
    """Common interface for database connectors.

    Most methods of this base class do nothing and return ``None``; a few
    provide simple defaults (``get_users``, ``get_grants``, ``get_collation``,
    ``get_charset`` and ``is_normal_type``).
    """

    def get_connect(self, db_name: str):
        """Get a connection for the given database.

        Args:
            db_name (str): database name
        """

    def get_table_names(self) -> Iterable[str]:
        """Get all table names."""

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get table info about specified table.

        Args:
            table_names (Optional[List[str]]): table names

        Returns:
            str: Table information joined by '\\n\\n'
        """

    def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get index info about specified table.

        Args:
            table_names (Optional[List[str]]): table names
        """

    def get_example_data(self, table: str, count: int = 3):
        """Get example data about specified table.

        Not used now.

        Args:
            table (str): table name
            count (int): example data count
        """

    def get_database_list(self) -> List[str]:
        """Get database list.

        Returns:
            List[str]: database list
        """

    def get_database_names(self):
        """Get database names."""

    def get_table_comments(self, db_name):
        """Get table comments.

        Args:
            db_name (str): database name
        """

    def run(self, command: str, fetch: str = "all") -> List:
        """Execute sql command.

        Args:
            command (str): sql command
            fetch (str): fetch type
        """

    def run_to_df(self, command: str, fetch: str = "all"):
        """Execute sql command and return dataframe.

        Args:
            command (str): sql command
            fetch (str): fetch type
        """

    def get_users(self):
        """Get user info.

        Returns:
            list: an empty list in this base implementation.
        """
        return []

    def get_grants(self):
        """Get grant info.

        Returns:
            list: an empty list in this base implementation.
        """
        return []

    def get_collation(self):
        """Get collation.

        Returns:
            None in this base implementation.
        """
        return None

    def get_charset(self) -> str:
        """Get character_set of current database.

        Returns:
            str: ``"utf-8"`` in this base implementation.
        """
        return "utf-8"

    def get_fields(self, table_name):
        """Get column fields about specified table.

        Args:
            table_name: table name
        """

    def get_show_create_table(self, table_name):
        """Get the creation table sql about specified table.

        Args:
            table_name: table name
        """

    def get_indexes(self, table_name):
        """Get table indexes about specified table.

        Args:
            table_name: table name
        """

    @classmethod
    def is_normal_type(cls) -> bool:
        """Return whether the connector is a normal type.

        Returns:
            bool: ``True`` in this base implementation.
        """
        return True

View File

@@ -1,4 +1,4 @@
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel
class DBConfig(BaseModel):

View File

@@ -1,3 +1,4 @@
from typing import List, Type
from dbgpt.datasource import ConnectConfigDao
from dbgpt.storage.schema import DBType
from dbgpt.component import SystemApp, ComponentType
@@ -21,7 +22,7 @@ from dbgpt.datasource.rdbms.conn_doris import DorisConnect
class ConnectManager:
"""db connect manager"""
def get_all_subclasses(self, cls):
def get_all_subclasses(self, cls: Type[BaseConnect]) -> List[Type[BaseConnect]]:
subclasses = cls.__subclasses__()
for subclass in subclasses:
subclasses += self.get_all_subclasses(subclass)
@@ -31,7 +32,7 @@ class ConnectManager:
chat_classes = self.get_all_subclasses(BaseConnect)
support_types = []
for cls in chat_classes:
if cls.db_type:
if cls.db_type and cls.is_normal_type():
support_types.append(DBType.of_db_type(cls.db_type))
return support_types
@@ -39,7 +40,7 @@ class ConnectManager:
chat_classes = self.get_all_subclasses(BaseConnect)
result = None
for cls in chat_classes:
if cls.db_type == db_type:
if cls.db_type == db_type and cls.is_normal_type():
result = cls
if not result:
raise ValueError("Unsupported Db Type" + db_type)

View File

View File

@@ -0,0 +1,16 @@
from typing import Any
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN, OUT
from dbgpt.datasource.base import BaseConnect
class DatasourceOperator(MapOperator[str, Any]):
    """Map operator that runs its input as a command on a datasource connection."""

    def __init__(self, connection: BaseConnect, **kwargs):
        """Create the operator.

        Args:
            connection (BaseConnect): connection the commands are executed on.
            **kwargs: forwarded unchanged to ``MapOperator.__init__``.
        """
        super().__init__(**kwargs)
        self._connection = connection

    async def map(self, input_value: IN) -> OUT:
        """Run :meth:`query` for ``input_value`` and return its result.

        The synchronous ``query`` call is handed to ``blocking_func_to_async``
        so that it can be awaited from this coroutine.
        """
        query_result = await self.blocking_func_to_async(self.query, input_value)
        return query_result

    def query(self, input_value: str) -> Any:
        """Execute ``input_value`` through the connection's ``run_to_df``.

        Args:
            input_value (str): command passed to ``run_to_df``.

        Returns:
            Any: whatever ``run_to_df`` returns for the command.
        """
        connection = self._connection
        return connection.run_to_df(input_value)

View File

@@ -4,9 +4,12 @@
import os
from typing import Optional, Any, Iterable
from sqlalchemy import create_engine, text
import tempfile
import logging
from dbgpt.datasource.rdbms.base import RDBMSDatabase
logger = logging.getLogger(__name__)
class SQLiteConnect(RDBMSDatabase):
"""Connect SQLite Database fetch MetaData
@@ -127,3 +130,116 @@ class SQLiteConnect(RDBMSDatabase):
results.append(f"{table_name}({','.join(table_colums)});")
return results
class SQLiteTempConnect(SQLiteConnect):
    """A temporary SQLite database connection.

    The database file will be deleted when the connection is closed.
    """

    def __init__(self, engine, temp_file_path, *args, **kwargs):
        """Wrap ``engine`` and remember the temporary file backing it.

        Args:
            engine: SQLAlchemy engine, forwarded to the parent constructor.
            temp_file_path: path of the temporary database file; it is removed
                by :meth:`close`.
            *args: forwarded to the parent constructor.
            **kwargs: forwarded to the parent constructor.
        """
        super().__init__(engine, *args, **kwargs)
        self.temp_file_path = temp_file_path
        self._is_closed = False

    @classmethod
    def create_temporary_db(
        cls, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> "SQLiteTempConnect":
        """Create a temporary SQLite database with a temporary file.

        Examples:
            .. code-block:: python

                with SQLiteTempConnect.create_temporary_db() as db:
                    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
                    db.run(db.session, "insert into test(id) values (1)")
                    db.run(db.session, "insert into test(id) values (2)")
                    field_names, result = db.query_ex(db.session, "select * from test")
                    assert field_names == ["id"]
                    assert result == [(1,), (2,)]

        Args:
            engine_args (Optional[dict]): SQLAlchemy engine arguments. The dict
                is not modified. Entries of its ``connect_args`` are kept, but
                ``check_same_thread`` is always forced to ``False``.
            **kwargs: forwarded to the class constructor.

        Returns:
            SQLiteTempConnect: A SQLiteTempConnect instance.
        """
        # NOTE(review): the example above passes ``db.session`` to ``run``;
        # confirm this still matches the current ``run`` signature.

        # Work on a copy so the caller's dict is never mutated, and merge
        # instead of overwriting so caller-supplied connect args survive.
        _engine_args = dict(engine_args) if engine_args else {}
        connect_args = dict(_engine_args.get("connect_args") or {})
        connect_args["check_same_thread"] = False
        _engine_args["connect_args"] = connect_args

        # delete=False: the file must outlive this handle; close() removes it.
        temp_file = tempfile.NamedTemporaryFile(delete=False)
        temp_file_path = temp_file.name
        temp_file.close()

        engine = None
        try:
            engine = create_engine(f"sqlite:///{temp_file_path}", **_engine_args)
            return cls(engine, temp_file_path, **kwargs)
        except Exception:
            # No instance exists to clean up after itself, so release the
            # engine and the temporary file here before re-raising.
            if engine is not None:
                engine.dispose()
            try:
                if os.path.exists(temp_file_path):
                    os.remove(temp_file_path)
            except Exception as e:
                logger.error(f"Error removing temporary database file: {e}")
            raise

    def close(self):
        """Close the connection.

        Disposes the engine and removes the temporary database file. Calling
        it again after a successful close does nothing. Errors while removing
        the file are logged, not raised.
        """
        # __del__ may run on an instance whose __init__ did not finish; in that
        # case the attributes below do not exist and there is nothing to close.
        if getattr(self, "_is_closed", True):
            return
        engine = getattr(self, "_engine", None)
        if engine:
            engine.dispose()
        try:
            if os.path.exists(self.temp_file_path):
                os.remove(self.temp_file_path)
        except Exception as e:
            logger.error(f"Error removing temporary database file: {e}")
        self._is_closed = True

    def create_temp_tables(self, tables_info):
        """Create temporary tables with data.

        Examples:
            .. code-block:: python

                tables_info = {
                    "test": {
                        "columns": {
                            "id": "INTEGER PRIMARY KEY",
                            "name": "TEXT",
                            "age": "INTEGER",
                        },
                        "data": [
                            (1, "Tom", 20),
                            (2, "Jack", 21),
                            (3, "Alice", 22),
                        ],
                    },
                }
                with SQLiteTempConnect.create_temporary_db() as db:
                    db.create_temp_tables(tables_info)
                    field_names, result = db.query_ex(db.session, "select * from test")
                    assert field_names == ["id", "name", "age"]
                    assert result == [(1, "Tom", 20), (2, "Jack", 21), (3, "Alice", 22)]

        Args:
            tables_info (dict): A dictionary of table information. Each key is
                a table name; each value has a ``"columns"`` mapping of column
                name to column type and an optional ``"data"`` list of rows.
                Table names, column names and column types are interpolated
                into the SQL text as-is, so they must come from a trusted
                source; row values are bound as parameters.

        Raises:
            Exception: whatever the session raises; the session is rolled back
                before the error is re-raised.
        """
        try:
            for table_name, table_data in tables_info.items():
                columns = ", ".join(
                    f"{col} {dtype}" for col, dtype in table_data["columns"].items()
                )
                create_sql = f"CREATE TABLE {table_name} ({columns});"
                self.session.execute(text(create_sql))
                for row in table_data.get("data", []):
                    # One named bind parameter per positional value of the row.
                    param_dict = {
                        "param" + str(index): value for index, value in enumerate(row)
                    }
                    placeholders = ", ".join(":" + name for name in param_dict)
                    insert_sql = f"INSERT INTO {table_name} VALUES ({placeholders});"
                    self.session.execute(text(insert_sql), param_dict)
            self.session.commit()
        except Exception:
            # Do not leave the session stuck in a failed transaction.
            self.session.rollback()
            raise
        self._sync_tables_from_db()

    def __enter__(self):
        """Return this connection for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection when the ``with`` block ends."""
        self.close()

    def __del__(self):
        """Close the connection when the object is garbage collected."""
        self.close()

    @classmethod
    def is_normal_type(cls) -> bool:
        """Return ``False``: this temporary connector is not a normal type."""
        return False