add plugin mode

This commit is contained in:
yhjun1026
2023-05-29 19:32:20 +08:00
parent 52da74c54a
commit 20edf6daaa
45 changed files with 1202 additions and 804 deletions

View File

@@ -30,16 +30,16 @@ class Database:
"""SQLAlchemy wrapper around a database."""
def __init__(
self,
engine,
schema: Optional[str] = None,
metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3,
indexes_in_table_info: bool = False,
custom_table_info: Optional[dict] = None,
view_support: bool = False,
self,
engine,
schema: Optional[str] = None,
metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3,
indexes_in_table_info: bool = False,
custom_table_info: Optional[dict] = None,
view_support: bool = False,
):
"""Create engine from database URI."""
self._engine = engine
@@ -119,7 +119,7 @@ class Database:
@classmethod
def from_uri(
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> Database:
"""Construct a SQLAlchemy engine from URI."""
_engine_args = engine_args or {}
@@ -148,7 +148,7 @@ class Database:
self._metadata = MetaData()
# sql = f"use {db_name}"
sql = text(f'use `{db_name}`')
sql = text(f"use `{db_name}`")
session.execute(sql)
# 处理表信息数据
@@ -159,13 +159,17 @@ class Database:
# tables list if view_support is True
self._all_tables = set(
self._inspector.get_table_names(schema=db_name)
+ (self._inspector.get_view_names(schema=db_name) if self.view_support else [])
+ (
self._inspector.get_view_names(schema=db_name)
if self.view_support
else []
)
)
return session
def get_current_db_name(self, session) -> str:
return session.execute(text('SELECT DATABASE()')).scalar()
return session.execute(text("SELECT DATABASE()")).scalar()
def table_simple_info(self, session):
_sql = f"""
@@ -201,7 +205,7 @@ class Database:
tbl
for tbl in self._metadata.sorted_tables
if tbl.name in set(all_table_names)
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
]
tables = []
@@ -214,7 +218,7 @@ class Database:
create_table = str(CreateTable(table).compile(self._engine))
table_info = f"{create_table.rstrip()}"
has_extra_info = (
self._indexes_in_table_info or self._sample_rows_in_table_info
self._indexes_in_table_info or self._sample_rows_in_table_info
)
if has_extra_info:
table_info += "\n\n/*"
@@ -303,6 +307,10 @@ class Database:
def get_database_list(self):
session = self._db_sessions()
cursor = session.execute(text(' show databases;'))
cursor = session.execute(text(" show databases;"))
results = cursor.fetchall()
return [d[0] for d in results if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]]
return [
d[0]
for d in results
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
]