mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-03 09:34:04 +00:00
283 lines
9.8 KiB
Python
283 lines
9.8 KiB
Python
"""SQLite connector."""
|
|
|
|
import logging
|
|
import os
|
|
import tempfile
|
|
from typing import Any, Iterable, List, Optional, Tuple
|
|
|
|
from sqlalchemy import create_engine, text
|
|
|
|
from .base import RDBMSConnector
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SQLiteConnector(RDBMSConnector):
    """SQLite connector.

    SQLite PRAGMA statements cannot take bound parameters, so table and index
    names embedded in them are quoted with ``_quote_sqlite_identifier``.
    Ordinary SELECT statements bind user-supplied values as parameters.
    """

    db_type: str = "sqlite"
    db_dialect: str = "sqlite"

    @staticmethod
    def _quote_sqlite_identifier(name: str) -> str:
        """Return ``name`` as a double-quoted SQLite identifier.

        Embedded double quotes are doubled, which is SQLite's escaping rule for
        quoted identifiers. Names containing spaces, keywords or quote
        characters can then be placed in SQL text without breaking or
        altering the statement.
        """
        return '"' + str(name).replace('"', '""') + '"'

    @classmethod
    def from_file_path(
        cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> "SQLiteConnector":
        """Create a new SQLiteConnector from file path.

        Args:
            file_path (str): Path of the SQLite database file. Missing parent
                directories are created first.
            engine_args (Optional[dict]): Extra keyword arguments for
                ``create_engine``. The dict is copied, never modified in place.
            **kwargs: Forwarded to the connector constructor.

        Returns:
            SQLiteConnector: A connector bound to ``sqlite:///<file_path>``.
        """
        _engine_args = dict(engine_args or {})
        # Keep any caller-supplied connect_args, but always disable sqlite3's
        # check_same_thread so the connection may be used from other threads.
        _engine_args["connect_args"] = {
            **(_engine_args.get("connect_args") or {}),
            "check_same_thread": False,
        }
        directory = os.path.dirname(file_path)
        # dirname() is "" for a bare file name such as "test.db". Calling
        # os.makedirs("") would raise, and the current directory already exists.
        if directory and not os.path.exists(directory):
            os.makedirs(directory, exist_ok=True)
        return cls(create_engine("sqlite:///" + file_path, **_engine_args), **kwargs)

    def get_indexes(self, table_name):
        """Get table indexes about specified table.

        Returns:
            list[dict]: One ``{"name": ..., "column_names": [...]}`` entry per
            index reported by ``PRAGMA index_list``.
        """
        quoted_table = self._quote_sqlite_identifier(table_name)
        cursor = self.session.execute(text(f"PRAGMA index_list({quoted_table})"))
        indexes = cursor.fetchall()
        result = []
        for idx in indexes:
            # Column 1 of an index_list row is the index name.
            index_name = idx[1]
            quoted_index = self._quote_sqlite_identifier(index_name)
            cursor = self.session.execute(text(f"PRAGMA index_info({quoted_index})"))
            # Column 2 of an index_info row is the indexed column's name.
            column_names = [index_info[2] for index_info in cursor.fetchall()]
            result.append({"name": index_name, "column_names": column_names})
        return result

    def get_show_create_table(self, table_name):
        """Return the ``CREATE TABLE`` statement stored for ``table_name``.

        The name is passed as a bound parameter, not formatted into the SQL,
        so names containing quotes cannot break or alter the query.

        Raises:
            IndexError: If no table with that name exists in ``sqlite_master``.
        """
        cursor = self.session.execute(
            text(
                "SELECT sql FROM sqlite_master WHERE type='table' "
                "AND name=:table_name"
            ),
            {"table_name": table_name},
        )
        ans = cursor.fetchall()
        return ans[0][0]

    def get_fields(self, table_name, db_name=None) -> List[Tuple]:
        """Get column fields about specified table.

        Returns:
            List[Tuple]: ``(name, type, notnull, dflt_value, pk)`` for each
            column, i.e. the ``PRAGMA table_info`` row without its ``cid``.
        """
        quoted_table = self._quote_sqlite_identifier(table_name)
        cursor = self.session.execute(text(f"PRAGMA table_info({quoted_table})"))
        fields = cursor.fetchall()
        logger.info(fields)
        return [(field[1], field[2], field[3], field[4], field[5]) for field in fields]

    def get_simple_fields(self, table_name):
        """Get column fields about specified table (same as ``get_fields``)."""
        return self.get_fields(table_name)

    def get_users(self):
        """Get user info (SQLite has no users, so always empty)."""
        return []

    def get_grants(self):
        """Get grants (SQLite has no grants, so always empty)."""
        return []

    def get_collation(self):
        """Get collation."""
        return "UTF-8"

    def get_charset(self):
        """Get character_set of current database."""
        return "UTF-8"

    def get_database_names(self):
        """Get database names (always empty for SQLite)."""
        return []

    def _sync_tables_from_db(self) -> Iterable[str]:
        """Refresh ``self._all_tables`` with every table and view name.

        Also re-reflects ``self._metadata`` from the engine.
        """
        table_results = self.session.execute(
            text("SELECT name FROM sqlite_master WHERE type='table'")
        )
        view_results = self.session.execute(
            text("SELECT name FROM sqlite_master WHERE type='view'")
        )
        table_names = {row[0] for row in table_results}
        view_names = {row[0] for row in view_results}
        self._all_tables = table_names.union(view_names)
        self._metadata.reflect(bind=self._engine)
        return self._all_tables

    def _write(self, write_sql):
        """Execute a write statement, commit it, and return the row count."""
        logger.info(f"Write[{write_sql}]")
        session = self.session
        result = session.execute(text(write_sql))
        session.commit()
        # TODO Subsequent optimization of dynamically specified database submission
        # loss target problem
        logger.info(f"SQL[{write_sql}], result:{result.rowcount}")
        return result.rowcount

    def get_table_comments(self, db_name=None):
        """Get table comments.

        SQLite has no table comments, so each entry is ``(table_name, sql)``,
        where ``sql`` is the table's stored ``CREATE TABLE`` text.
        """
        cursor = self.session.execute(
            text(
                """
                SELECT name, sql FROM sqlite_master WHERE type='table'
                """
            )
        )
        table_comments = cursor.fetchall()
        return [
            (table_comment[0], table_comment[1]) for table_comment in table_comments
        ]

    def get_current_db_name(self) -> str:
        """Get current database name.

        The name is the database file's base name with any ``.db`` suffix
        removed. URLs without a database path (e.g. in-memory ``sqlite://``)
        yield an empty string.

        Returns:
            str: database name
        """
        full_path = self._engine.url.database
        if not full_path:
            return ""
        db_name = os.path.basename(full_path)
        if db_name.endswith(".db"):
            db_name = db_name[:-3]
        return db_name

    def table_simple_info(self) -> Iterable[str]:
        """Get table simple info.

        Returns:
            list[str]: One ``"table(col1,col2,...);"`` string per table.
        """
        _tables_sql = """
            SELECT name FROM sqlite_master WHERE type='table'
        """
        cursor = self.session.execute(text(_tables_sql))
        tables_results = cursor.fetchall()
        results = []
        for row in tables_results:
            table_name = row[0]
            quoted_table = self._quote_sqlite_identifier(table_name)
            cursor_columns = self.session.execute(
                text(f"PRAGMA table_info({quoted_table})")
            )
            # Column 1 of a table_info row is the column name.
            table_columns = [str(col[1]) for col in cursor_columns.fetchall()]
            results.append(f"{table_name}({','.join(table_columns)});")
        return results
|
|
|
|
|
|
class SQLiteTempConnector(SQLiteConnector):
    """A temporary SQLite database connection.

    The database file will be deleted when the connection is closed.
    """

    def __init__(self, engine, temp_file_path, *args, **kwargs):
        """Construct a temporary SQLite database connection.

        Args:
            engine: SQLAlchemy engine bound to the temporary database file.
            temp_file_path (str): Path of the temporary file; ``close`` removes it.
        """
        super().__init__(engine, *args, **kwargs)
        self.temp_file_path = temp_file_path
        self._is_closed = False

    @classmethod
    def create_temporary_db(
        cls, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> "SQLiteTempConnector":
        """Create a temporary SQLite database with a temporary file.

        Examples:
            .. code-block:: python

                with SQLiteTempConnector.create_temporary_db() as db:
                    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
                    db.run(db.session, "insert into test(id) values (1)")
                    db.run(db.session, "insert into test(id) values (2)")
                    field_names, result = db.query_ex(db.session, "select * from test")
                    assert field_names == ["id"]
                    assert result == [(1,), (2,)]

        Args:
            engine_args (Optional[dict]): SQLAlchemy engine arguments. The dict
                is copied, never modified in place.

        Returns:
            SQLiteTempConnector: A SQLiteTempConnector instance.
        """
        _engine_args = dict(engine_args or {})
        # Keep caller connect_args, but always allow cross-thread use.
        _engine_args["connect_args"] = {
            **(_engine_args.get("connect_args") or {}),
            "check_same_thread": False,
        }

        temp_file = tempfile.NamedTemporaryFile(delete=False)
        temp_file_path = temp_file.name
        temp_file.close()

        engine = None
        try:
            engine = create_engine(f"sqlite:///{temp_file_path}", **_engine_args)
            return cls(engine, temp_file_path, **kwargs)
        except Exception:
            # No connector object owns the file yet, so nobody else would ever
            # delete it. Clean up here before propagating the error.
            if engine is not None:
                engine.dispose()
            try:
                if os.path.exists(temp_file_path):
                    os.remove(temp_file_path)
            except OSError as e:
                logger.error(f"Error removing temporary database file: {e}")
            raise

    def close(self):
        """Dispose the engine and delete the temporary database file.

        Idempotent. Safe on a partially constructed instance (for example
        when ``__del__`` runs after ``__init__`` raised), because missing
        attributes are read with ``getattr`` defaults.
        """
        if getattr(self, "_is_closed", False):
            return
        # Mark closed first so a failing dispose() is not retried by __del__.
        self._is_closed = True
        try:
            engine = getattr(self, "_engine", None)
            if engine is not None:
                engine.dispose()
        finally:
            # Remove the file even if dispose() raised.
            temp_file_path = getattr(self, "temp_file_path", None)
            if temp_file_path:
                try:
                    if os.path.exists(temp_file_path):
                        os.remove(temp_file_path)
                except Exception as e:
                    logger.error(f"Error removing temporary database file: {e}")

    def create_temp_tables(self, tables_info):
        """Create temporary tables with data.

        Each table is created and filled, then committed on its own. If any
        statement for a table fails, the session is rolled back and the error
        re-raised. Table names and column definitions are inserted into the
        SQL verbatim, so they must come from trusted code. Row values are
        bound as parameters.

        Examples:
            .. code-block:: python

                tables_info = {
                    "test": {
                        "columns": {
                            "id": "INTEGER PRIMARY KEY",
                            "name": "TEXT",
                            "age": "INTEGER",
                        },
                        "data": [
                            (1, "Tom", 20),
                            (2, "Jack", 21),
                            (3, "Alice", 22),
                        ],
                    },
                }
                with SQLiteTempConnector.create_temporary_db() as db:
                    db.create_temp_tables(tables_info)
                    field_names, result = db.query_ex(db.session, "select * from test")
                    assert field_names == ["id", "name", "age"]
                    assert result == [(1, "Tom", 20), (2, "Jack", 21), (3, "Alice", 22)]

        Args:
            tables_info (dict): Maps table name to a dict with a ``"columns"``
                mapping (column name -> SQL type) and optional ``"data"`` rows.
        """
        for table_name, table_data in tables_info.items():
            columns = ", ".join(
                f"{col} {dtype}" for col, dtype in table_data["columns"].items()
            )
            create_sql = f"CREATE TABLE {table_name} ({columns});"
            try:
                self.session.execute(text(create_sql))
                for row in table_data.get("data", []):
                    placeholders = ", ".join(
                        ":param" + str(index) for index in range(len(row))
                    )
                    insert_sql = f"INSERT INTO {table_name} VALUES ({placeholders});"
                    param_dict = {
                        "param" + str(index): value for index, value in enumerate(row)
                    }
                    self.session.execute(text(insert_sql), param_dict)
                self.session.commit()
            except Exception:
                # Don't leave a half-built table pending in the session.
                self.session.rollback()
                raise
        self._sync_tables_from_db()

    def __enter__(self):
        """Return the connection when entering the context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection when exiting the context manager."""
        self.close()

    def __del__(self):
        """Close the connection when the object is deleted."""
        self.close()

    @classmethod
    def is_normal_type(cls) -> bool:
        """Return whether the connector is a normal type."""
        return False
|