refactor: Refactor storage system (#937)

This commit is contained in:
Fangyin Cheng
2023-12-15 16:35:45 +08:00
committed by GitHub
parent a1e415d68d
commit aed1c3fb2b
55 changed files with 3780 additions and 680 deletions

View File

@@ -270,7 +270,12 @@ class RDBMSDatabase(BaseConnect):
"""Format the error message"""
return f"Error: {e}"
def __write(self, write_sql):
def _write(self, write_sql: str):
"""Run a SQL write command and return the results as a list of tuples.
Args:
write_sql (str): SQL write command to run
"""
print(f"Write[{write_sql}]")
db_cache = self._engine.url.database
result = self.session.execute(text(write_sql))
@@ -280,16 +285,12 @@ class RDBMSDatabase(BaseConnect):
print(f"SQL[{write_sql}], result:{result.rowcount}")
return result.rowcount
def __query(self, query, fetch: str = "all"):
"""
only for query
def _query(self, query: str, fetch: str = "all"):
"""Run a SQL query and return the results as a list of tuples.
Args:
session:
query:
fetch:
Returns:
query (str): SQL query to run
fetch (str): fetch type
"""
print(f"Query[{query}]")
if not query:
@@ -308,6 +309,10 @@ class RDBMSDatabase(BaseConnect):
result.insert(0, field_names)
return result
def query_table_schema(self, table_name):
sql = f"select * from {table_name} limit 1"
return self._query(sql)
def query_ex(self, query, fetch: str = "all"):
"""
only for query
@@ -325,7 +330,7 @@ class RDBMSDatabase(BaseConnect):
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone()[0] # type: ignore
result = cursor.fetchone() # type: ignore
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
field_names = list(i[0:] for i in cursor.keys())
@@ -342,12 +347,12 @@ class RDBMSDatabase(BaseConnect):
parsed, ttype, sql_type, table_name = self.__sql_parse(command)
if ttype == sqlparse.tokens.DML:
if sql_type == "SELECT":
return self.__query(command, fetch)
return self._query(command, fetch)
else:
self.__write(command)
self._write(command)
select_sql = self.convert_sql_write_to_select(command)
print(f"write result query:{select_sql}")
return self.__query(select_sql)
return self._query(select_sql)
else:
print(f"DDL execution determines whether to enable through configuration ")
@@ -360,10 +365,11 @@ class RDBMSDatabase(BaseConnect):
result.insert(0, field_names)
print("DDL Result:" + str(result))
if not result:
return self.__query(f"SHOW COLUMNS FROM {table_name}")
# return self._query(f"SHOW COLUMNS FROM {table_name}")
return self.get_simple_fields(table_name)
return result
else:
return self.__query(f"SHOW COLUMNS FROM {table_name}")
return self.get_simple_fields(table_name)
def run_to_df(self, command: str, fetch: str = "all"):
result_lst = self.run(command, fetch)
@@ -451,13 +457,23 @@ class RDBMSDatabase(BaseConnect):
sql = sql.strip()
parsed = sqlparse.parse(sql)[0]
sql_type = parsed.get_type()
table_name = parsed.get_name()
if sql_type == "CREATE":
table_name = self._extract_table_name_from_ddl(parsed)
else:
table_name = parsed.get_name()
first_token = parsed.token_first(skip_ws=True, skip_cm=False)
ttype = first_token.ttype
print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}")
return parsed, ttype, sql_type, table_name
def _extract_table_name_from_ddl(self, parsed):
"""Extract table name from CREATE TABLE statement.""" ""
for token in parsed.tokens:
if token.ttype is None and isinstance(token, sqlparse.sql.Identifier):
return token.get_real_name()
return None
def get_indexes(self, table_name):
"""Get table indexes about specified table."""
session = self._db_sessions()
@@ -485,6 +501,10 @@ class RDBMSDatabase(BaseConnect):
fields = cursor.fetchall()
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
def get_simple_fields(self, table_name):
"""Get column fields about specified table."""
return self._query(f"SHOW COLUMNS FROM {table_name}")
def get_charset(self):
"""Get character_set."""
session = self._db_sessions()

View File

@@ -56,6 +56,10 @@ class SQLiteConnect(RDBMSDatabase):
print(fields)
return [(field[1], field[2], field[3], field[4], field[5]) for field in fields]
def get_simple_fields(self, table_name):
"""Get column fields about specified table."""
return self.get_fields(table_name)
def get_users(self):
return []
@@ -88,8 +92,9 @@ class SQLiteConnect(RDBMSDatabase):
self._metadata.reflect(bind=self._engine)
return self._all_tables
def _write(self, session, write_sql):
def _write(self, write_sql):
print(f"Write[{write_sql}]")
session = self.session
result = session.execute(text(write_sql))
session.commit()
# TODO Subsequent optimization of dynamically specified database submission loss target problem

View File

@@ -25,41 +25,41 @@ def test_get_table_info(db):
def test_get_table_info_with_table(db):
db.run(db.session, "CREATE TABLE test (id INTEGER);")
db.run("CREATE TABLE test (id INTEGER);")
print(db._sync_tables_from_db())
table_info = db.get_table_info()
assert "CREATE TABLE test" in table_info
def test_run_sql(db):
result = db.run(db.session, "CREATE TABLE test (id INTEGER);")
assert result[0] == ("cid", "name", "type", "notnull", "dflt_value", "pk")
result = db.run("CREATE TABLE test(id INTEGER);")
assert result[0] == ("id", "INTEGER", 0, None, 0)
def test_run_no_throw(db):
assert db.run_no_throw(db.session, "this is a error sql").startswith("Error:")
assert db.run_no_throw("this is a error sql").startswith("Error:")
def test_get_indexes(db):
db.run(db.session, "CREATE TABLE test (name TEXT);")
db.run(db.session, "CREATE INDEX idx_name ON test(name);")
db.run("CREATE TABLE test (name TEXT);")
db.run("CREATE INDEX idx_name ON test(name);")
assert db.get_indexes("test") == [("idx_name", "c")]
def test_get_indexes_empty(db):
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
assert db.get_indexes("test") == []
def test_get_show_create_table(db):
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
assert (
db.get_show_create_table("test") == "CREATE TABLE test (id INTEGER PRIMARY KEY)"
)
def test_get_fields(db):
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
assert db.get_fields("test") == [("id", "INTEGER", 0, None, 1)]
@@ -72,26 +72,26 @@ def test_get_collation(db):
def test_table_simple_info(db):
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
assert db.table_simple_info() == ["test(id);"]
def test_get_table_info_no_throw(db):
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
assert db.get_table_info_no_throw("xxxx_table").startswith("Error:")
def test_query_ex(db):
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run(db.session, "insert into test(id) values (1)")
db.run(db.session, "insert into test(id) values (2)")
field_names, result = db.query_ex(db.session, "select * from test")
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run("insert into test(id) values (1)")
db.run("insert into test(id) values (2)")
field_names, result = db.query_ex("select * from test")
assert field_names == ["id"]
assert result == [(1,), (2,)]
field_names, result = db.query_ex(db.session, "select * from test", fetch="one")
field_names, result = db.query_ex("select * from test", fetch="one")
assert field_names == ["id"]
assert result == [(1,)]
assert result == [1]
def test_convert_sql_write_to_select(db):
@@ -109,7 +109,7 @@ def test_get_users(db):
def test_get_table_comments(db):
assert db.get_table_comments() == []
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
assert db.get_table_comments() == [
("test", "CREATE TABLE test (id INTEGER PRIMARY KEY)")
]