feat(core): Support simple DB query for sdk (#917)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
FangYin Cheng
2023-12-11 18:33:54 +08:00
committed by GitHub
parent 43190ca333
commit cbba50ab1b
18 changed files with 467 additions and 74 deletions

View File

@@ -4,9 +4,12 @@
import os
from typing import Optional, Any, Iterable
from sqlalchemy import create_engine, text
import tempfile
import logging
from dbgpt.datasource.rdbms.base import RDBMSDatabase
logger = logging.getLogger(__name__)
class SQLiteConnect(RDBMSDatabase):
"""Connect SQLite Database fetch MetaData
@@ -127,3 +130,116 @@ class SQLiteConnect(RDBMSDatabase):
results.append(f"{table_name}({','.join(table_colums)});")
return results
class SQLiteTempConnect(SQLiteConnect):
"""A temporary SQLite database connection. The database file will be deleted when the connection is closed."""
def __init__(self, engine, temp_file_path, *args, **kwargs):
super().__init__(engine, *args, **kwargs)
self.temp_file_path = temp_file_path
self._is_closed = False
@classmethod
def create_temporary_db(
cls, engine_args: Optional[dict] = None, **kwargs: Any
) -> "SQLiteTempConnect":
"""Create a temporary SQLite database with a temporary file.
Examples:
.. code-block:: python
with SQLiteTempConnect.create_temporary_db() as db:
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run(db.session, "insert into test(id) values (1)")
db.run(db.session, "insert into test(id) values (2)")
field_names, result = db.query_ex(db.session, "select * from test")
assert field_names == ["id"]
assert result == [(1,), (2,)]
Args:
engine_args (Optional[dict]): SQLAlchemy engine arguments.
Returns:
SQLiteTempConnect: A SQLiteTempConnect instance.
"""
_engine_args = engine_args or {}
_engine_args["connect_args"] = {"check_same_thread": False}
temp_file = tempfile.NamedTemporaryFile(delete=False)
temp_file_path = temp_file.name
temp_file.close()
engine = create_engine(f"sqlite:///{temp_file_path}", **_engine_args)
return cls(engine, temp_file_path, **kwargs)
def close(self):
"""Close the connection."""
if not self._is_closed:
if self._engine:
self._engine.dispose()
try:
if os.path.exists(self.temp_file_path):
os.remove(self.temp_file_path)
except Exception as e:
logger.error(f"Error removing temporary database file: {e}")
self._is_closed = True
def create_temp_tables(self, tables_info):
"""Create temporary tables with data.
Examples:
.. code-block:: python
tables_info = {
"test": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "Tom", 20),
(2, "Jack", 21),
(3, "Alice", 22),
],
},
}
with SQLiteTempConnect.create_temporary_db() as db:
db.create_temp_tables(tables_info)
field_names, result = db.query_ex(db.session, "select * from test")
assert field_names == ["id", "name", "age"]
assert result == [(1, "Tom", 20), (2, "Jack", 21), (3, "Alice", 22)]
Args:
tables_info (dict): A dictionary of table information.
"""
for table_name, table_data in tables_info.items():
columns = ", ".join(
[f"{col} {dtype}" for col, dtype in table_data["columns"].items()]
)
create_sql = f"CREATE TABLE {table_name} ({columns});"
self.session.execute(text(create_sql))
for row in table_data.get("data", []):
placeholders = ", ".join(
[":param" + str(index) for index, _ in enumerate(row)]
)
insert_sql = f"INSERT INTO {table_name} VALUES ({placeholders});"
param_dict = {
"param" + str(index): value for index, value in enumerate(row)
}
self.session.execute(text(insert_sql), param_dict)
self.session.commit()
self._sync_tables_from_db()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def __del__(self):
self.close()
@classmethod
def is_normal_type(cls) -> bool:
return False