refactor: Refactor storage system (#937)

This commit is contained in:
Fangyin Cheng
2023-12-15 16:35:45 +08:00
committed by GitHub
parent a1e415d68d
commit aed1c3fb2b
55 changed files with 3780 additions and 680 deletions

View File

@@ -91,6 +91,10 @@ class BaseConnect(ABC):
"""Get column fields about specified table."""
pass
def get_simple_fields(self, table_name):
"""Get column fields about specified table."""
return self.get_fields(table_name)
def get_show_create_table(self, table_name):
"""Get the creation table sql about specified table."""
pass

View File

@@ -1,16 +1,10 @@
from sqlalchemy import Column, Integer, String, Index, Text, text
from sqlalchemy import UniqueConstraint
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
from dbgpt.storage.metadata import BaseDao, Model
class ConnectConfigEntity(Base):
class ConnectConfigEntity(Model):
"""db connect config entity"""
__tablename__ = "connect_config"
@@ -38,17 +32,9 @@ class ConnectConfigEntity(Base):
class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
"""db connect config dao"""
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)
def update(self, entity: ConnectConfigEntity):
"""update db connect info"""
session = self.get_session()
session = self.get_raw_session()
try:
updated = session.merge(entity)
session.commit()
@@ -58,7 +44,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
def delete(self, db_name: int):
""" "delete db connect info"""
session = self.get_session()
session = self.get_raw_session()
if db_name is None:
raise Exception("db_name is None")
@@ -70,7 +56,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
def get_by_names(self, db_name: str) -> ConnectConfigEntity:
"""get db connect info by name"""
session = self.get_session()
session = self.get_raw_session()
db_connect = session.query(ConnectConfigEntity)
db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
result = db_connect.first()
@@ -99,7 +85,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
comment: comment
"""
try:
session = self.get_session()
session = self.get_raw_session()
from sqlalchemy import text
@@ -144,7 +130,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
old_db_conf = self.get_db_config(db_name)
if old_db_conf:
try:
session = self.get_session()
session = self.get_raw_session()
if not db_path:
update_statement = text(
f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'"
@@ -164,7 +150,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""):
"""add file db connect info"""
try:
session = self.get_session()
session = self.get_raw_session()
insert_statement = text(
"""
INSERT INTO connect_config(
@@ -194,7 +180,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
def get_db_config(self, db_name):
"""get db config by name"""
session = self.get_session()
session = self.get_raw_session()
if db_name:
select_statement = text(
"""
@@ -221,7 +207,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
def get_db_list(self):
"""get db list"""
session = self.get_session()
session = self.get_raw_session()
result = session.execute(text("SELECT * FROM connect_config"))
fields = [field[0] for field in result.cursor.description]
@@ -235,7 +221,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
def delete_db(self, db_name):
"""delete db connect info"""
session = self.get_session()
session = self.get_raw_session()
delete_statement = text("""DELETE FROM connect_config where db_name=:db_name""")
params = {"db_name": db_name}
session.execute(delete_statement, params)

View File

@@ -270,7 +270,12 @@ class RDBMSDatabase(BaseConnect):
"""Format the error message"""
return f"Error: {e}"
def __write(self, write_sql):
def _write(self, write_sql: str):
"""Run a SQL write command and return the results as a list of tuples.
Args:
write_sql (str): SQL write command to run
"""
print(f"Write[{write_sql}]")
db_cache = self._engine.url.database
result = self.session.execute(text(write_sql))
@@ -280,16 +285,12 @@ class RDBMSDatabase(BaseConnect):
print(f"SQL[{write_sql}], result:{result.rowcount}")
return result.rowcount
def __query(self, query, fetch: str = "all"):
"""
only for query
def _query(self, query: str, fetch: str = "all"):
"""Run a SQL query and return the results as a list of tuples.
Args:
session:
query:
fetch:
Returns:
query (str): SQL query to run
fetch (str): fetch type
"""
print(f"Query[{query}]")
if not query:
@@ -308,6 +309,10 @@ class RDBMSDatabase(BaseConnect):
result.insert(0, field_names)
return result
def query_table_schema(self, table_name):
sql = f"select * from {table_name} limit 1"
return self._query(sql)
def query_ex(self, query, fetch: str = "all"):
"""
only for query
@@ -325,7 +330,7 @@ class RDBMSDatabase(BaseConnect):
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone()[0] # type: ignore
result = cursor.fetchone() # type: ignore
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
field_names = list(i[0:] for i in cursor.keys())
@@ -342,12 +347,12 @@ class RDBMSDatabase(BaseConnect):
parsed, ttype, sql_type, table_name = self.__sql_parse(command)
if ttype == sqlparse.tokens.DML:
if sql_type == "SELECT":
return self.__query(command, fetch)
return self._query(command, fetch)
else:
self.__write(command)
self._write(command)
select_sql = self.convert_sql_write_to_select(command)
print(f"write result query:{select_sql}")
return self.__query(select_sql)
return self._query(select_sql)
else:
print(f"DDL execution determines whether to enable through configuration ")
@@ -360,10 +365,11 @@ class RDBMSDatabase(BaseConnect):
result.insert(0, field_names)
print("DDL Result:" + str(result))
if not result:
return self.__query(f"SHOW COLUMNS FROM {table_name}")
# return self._query(f"SHOW COLUMNS FROM {table_name}")
return self.get_simple_fields(table_name)
return result
else:
return self.__query(f"SHOW COLUMNS FROM {table_name}")
return self.get_simple_fields(table_name)
def run_to_df(self, command: str, fetch: str = "all"):
result_lst = self.run(command, fetch)
@@ -451,13 +457,23 @@ class RDBMSDatabase(BaseConnect):
sql = sql.strip()
parsed = sqlparse.parse(sql)[0]
sql_type = parsed.get_type()
table_name = parsed.get_name()
if sql_type == "CREATE":
table_name = self._extract_table_name_from_ddl(parsed)
else:
table_name = parsed.get_name()
first_token = parsed.token_first(skip_ws=True, skip_cm=False)
ttype = first_token.ttype
print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}")
return parsed, ttype, sql_type, table_name
def _extract_table_name_from_ddl(self, parsed):
"""Extract table name from CREATE TABLE statement.""" ""
for token in parsed.tokens:
if token.ttype is None and isinstance(token, sqlparse.sql.Identifier):
return token.get_real_name()
return None
def get_indexes(self, table_name):
"""Get table indexes about specified table."""
session = self._db_sessions()
@@ -485,6 +501,10 @@ class RDBMSDatabase(BaseConnect):
fields = cursor.fetchall()
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
def get_simple_fields(self, table_name):
"""Get column fields about specified table."""
return self._query(f"SHOW COLUMNS FROM {table_name}")
def get_charset(self):
"""Get character_set."""
session = self._db_sessions()

View File

@@ -56,6 +56,10 @@ class SQLiteConnect(RDBMSDatabase):
print(fields)
return [(field[1], field[2], field[3], field[4], field[5]) for field in fields]
def get_simple_fields(self, table_name):
"""Get column fields about specified table."""
return self.get_fields(table_name)
def get_users(self):
return []
@@ -88,8 +92,9 @@ class SQLiteConnect(RDBMSDatabase):
self._metadata.reflect(bind=self._engine)
return self._all_tables
def _write(self, session, write_sql):
def _write(self, write_sql):
print(f"Write[{write_sql}]")
session = self.session
result = session.execute(text(write_sql))
session.commit()
# TODO Subsequent optimization of dynamically specified database submission loss target problem

View File

@@ -25,41 +25,41 @@ def test_get_table_info(db):
def test_get_table_info_with_table(db):
db.run(db.session, "CREATE TABLE test (id INTEGER);")
db.run("CREATE TABLE test (id INTEGER);")
print(db._sync_tables_from_db())
table_info = db.get_table_info()
assert "CREATE TABLE test" in table_info
def test_run_sql(db):
result = db.run(db.session, "CREATE TABLE test (id INTEGER);")
assert result[0] == ("cid", "name", "type", "notnull", "dflt_value", "pk")
result = db.run("CREATE TABLE test(id INTEGER);")
assert result[0] == ("id", "INTEGER", 0, None, 0)
def test_run_no_throw(db):
assert db.run_no_throw(db.session, "this is a error sql").startswith("Error:")
assert db.run_no_throw("this is a error sql").startswith("Error:")
def test_get_indexes(db):
db.run(db.session, "CREATE TABLE test (name TEXT);")
db.run(db.session, "CREATE INDEX idx_name ON test(name);")
db.run("CREATE TABLE test (name TEXT);")
db.run("CREATE INDEX idx_name ON test(name);")
assert db.get_indexes("test") == [("idx_name", "c")]
def test_get_indexes_empty(db):
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
assert db.get_indexes("test") == []
def test_get_show_create_table(db):
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
assert (
db.get_show_create_table("test") == "CREATE TABLE test (id INTEGER PRIMARY KEY)"
)
def test_get_fields(db):
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
assert db.get_fields("test") == [("id", "INTEGER", 0, None, 1)]
@@ -72,26 +72,26 @@ def test_get_collation(db):
def test_table_simple_info(db):
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
assert db.table_simple_info() == ["test(id);"]
def test_get_table_info_no_throw(db):
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
assert db.get_table_info_no_throw("xxxx_table").startswith("Error:")
def test_query_ex(db):
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run(db.session, "insert into test(id) values (1)")
db.run(db.session, "insert into test(id) values (2)")
field_names, result = db.query_ex(db.session, "select * from test")
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run("insert into test(id) values (1)")
db.run("insert into test(id) values (2)")
field_names, result = db.query_ex("select * from test")
assert field_names == ["id"]
assert result == [(1,), (2,)]
field_names, result = db.query_ex(db.session, "select * from test", fetch="one")
field_names, result = db.query_ex("select * from test", fetch="one")
assert field_names == ["id"]
assert result == [(1,)]
assert result == [1]
def test_convert_sql_write_to_select(db):
@@ -109,7 +109,7 @@ def test_get_users(db):
def test_get_table_comments(db):
assert db.get_table_comments() == []
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
assert db.get_table_comments() == [
("test", "CREATE TABLE test (id INTEGER PRIMARY KEY)")
]