mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-12 12:37:14 +00:00
refactor: Refactor storage system (#937)
This commit is contained in:
@@ -91,6 +91,10 @@ class BaseConnect(ABC):
|
||||
"""Get column fields about specified table."""
|
||||
pass
|
||||
|
||||
def get_simple_fields(self, table_name):
|
||||
"""Get column fields about specified table."""
|
||||
return self.get_fields(table_name)
|
||||
|
||||
def get_show_create_table(self, table_name):
|
||||
"""Get the creation table sql about specified table."""
|
||||
pass
|
||||
|
@@ -1,16 +1,10 @@
|
||||
from sqlalchemy import Column, Integer, String, Index, Text, text
|
||||
from sqlalchemy import UniqueConstraint
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata.meta_data import (
|
||||
Base,
|
||||
engine,
|
||||
session,
|
||||
META_DATA_DATABASE,
|
||||
)
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
|
||||
|
||||
class ConnectConfigEntity(Base):
|
||||
class ConnectConfigEntity(Model):
|
||||
"""db connect config entity"""
|
||||
|
||||
__tablename__ = "connect_config"
|
||||
@@ -38,17 +32,9 @@ class ConnectConfigEntity(Base):
|
||||
class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
|
||||
"""db connect config dao"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database=META_DATA_DATABASE,
|
||||
orm_base=Base,
|
||||
db_engine=engine,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def update(self, entity: ConnectConfigEntity):
|
||||
"""update db connect info"""
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
try:
|
||||
updated = session.merge(entity)
|
||||
session.commit()
|
||||
@@ -58,7 +44,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
|
||||
|
||||
def delete(self, db_name: int):
|
||||
""" "delete db connect info"""
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
if db_name is None:
|
||||
raise Exception("db_name is None")
|
||||
|
||||
@@ -70,7 +56,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
|
||||
|
||||
def get_by_names(self, db_name: str) -> ConnectConfigEntity:
|
||||
"""get db connect info by name"""
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
db_connect = session.query(ConnectConfigEntity)
|
||||
db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
|
||||
result = db_connect.first()
|
||||
@@ -99,7 +85,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
|
||||
comment: comment
|
||||
"""
|
||||
try:
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
@@ -144,7 +130,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
|
||||
old_db_conf = self.get_db_config(db_name)
|
||||
if old_db_conf:
|
||||
try:
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
if not db_path:
|
||||
update_statement = text(
|
||||
f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'"
|
||||
@@ -164,7 +150,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
|
||||
def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""):
|
||||
"""add file db connect info"""
|
||||
try:
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
insert_statement = text(
|
||||
"""
|
||||
INSERT INTO connect_config(
|
||||
@@ -194,7 +180,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
|
||||
|
||||
def get_db_config(self, db_name):
|
||||
"""get db config by name"""
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
if db_name:
|
||||
select_statement = text(
|
||||
"""
|
||||
@@ -221,7 +207,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
|
||||
|
||||
def get_db_list(self):
|
||||
"""get db list"""
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
result = session.execute(text("SELECT * FROM connect_config"))
|
||||
|
||||
fields = [field[0] for field in result.cursor.description]
|
||||
@@ -235,7 +221,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
|
||||
|
||||
def delete_db(self, db_name):
|
||||
"""delete db connect info"""
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
delete_statement = text("""DELETE FROM connect_config where db_name=:db_name""")
|
||||
params = {"db_name": db_name}
|
||||
session.execute(delete_statement, params)
|
||||
|
@@ -270,7 +270,12 @@ class RDBMSDatabase(BaseConnect):
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}"
|
||||
|
||||
def __write(self, write_sql):
|
||||
def _write(self, write_sql: str):
|
||||
"""Run a SQL write command and return the results as a list of tuples.
|
||||
|
||||
Args:
|
||||
write_sql (str): SQL write command to run
|
||||
"""
|
||||
print(f"Write[{write_sql}]")
|
||||
db_cache = self._engine.url.database
|
||||
result = self.session.execute(text(write_sql))
|
||||
@@ -280,16 +285,12 @@ class RDBMSDatabase(BaseConnect):
|
||||
print(f"SQL[{write_sql}], result:{result.rowcount}")
|
||||
return result.rowcount
|
||||
|
||||
def __query(self, query, fetch: str = "all"):
|
||||
"""
|
||||
only for query
|
||||
def _query(self, query: str, fetch: str = "all"):
|
||||
"""Run a SQL query and return the results as a list of tuples.
|
||||
|
||||
Args:
|
||||
session:
|
||||
query:
|
||||
fetch:
|
||||
|
||||
Returns:
|
||||
|
||||
query (str): SQL query to run
|
||||
fetch (str): fetch type
|
||||
"""
|
||||
print(f"Query[{query}]")
|
||||
if not query:
|
||||
@@ -308,6 +309,10 @@ class RDBMSDatabase(BaseConnect):
|
||||
result.insert(0, field_names)
|
||||
return result
|
||||
|
||||
def query_table_schema(self, table_name):
|
||||
sql = f"select * from {table_name} limit 1"
|
||||
return self._query(sql)
|
||||
|
||||
def query_ex(self, query, fetch: str = "all"):
|
||||
"""
|
||||
only for query
|
||||
@@ -325,7 +330,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
if fetch == "all":
|
||||
result = cursor.fetchall()
|
||||
elif fetch == "one":
|
||||
result = cursor.fetchone()[0] # type: ignore
|
||||
result = cursor.fetchone() # type: ignore
|
||||
else:
|
||||
raise ValueError("Fetch parameter must be either 'one' or 'all'")
|
||||
field_names = list(i[0:] for i in cursor.keys())
|
||||
@@ -342,12 +347,12 @@ class RDBMSDatabase(BaseConnect):
|
||||
parsed, ttype, sql_type, table_name = self.__sql_parse(command)
|
||||
if ttype == sqlparse.tokens.DML:
|
||||
if sql_type == "SELECT":
|
||||
return self.__query(command, fetch)
|
||||
return self._query(command, fetch)
|
||||
else:
|
||||
self.__write(command)
|
||||
self._write(command)
|
||||
select_sql = self.convert_sql_write_to_select(command)
|
||||
print(f"write result query:{select_sql}")
|
||||
return self.__query(select_sql)
|
||||
return self._query(select_sql)
|
||||
|
||||
else:
|
||||
print(f"DDL execution determines whether to enable through configuration ")
|
||||
@@ -360,10 +365,11 @@ class RDBMSDatabase(BaseConnect):
|
||||
result.insert(0, field_names)
|
||||
print("DDL Result:" + str(result))
|
||||
if not result:
|
||||
return self.__query(f"SHOW COLUMNS FROM {table_name}")
|
||||
# return self._query(f"SHOW COLUMNS FROM {table_name}")
|
||||
return self.get_simple_fields(table_name)
|
||||
return result
|
||||
else:
|
||||
return self.__query(f"SHOW COLUMNS FROM {table_name}")
|
||||
return self.get_simple_fields(table_name)
|
||||
|
||||
def run_to_df(self, command: str, fetch: str = "all"):
|
||||
result_lst = self.run(command, fetch)
|
||||
@@ -451,13 +457,23 @@ class RDBMSDatabase(BaseConnect):
|
||||
sql = sql.strip()
|
||||
parsed = sqlparse.parse(sql)[0]
|
||||
sql_type = parsed.get_type()
|
||||
table_name = parsed.get_name()
|
||||
if sql_type == "CREATE":
|
||||
table_name = self._extract_table_name_from_ddl(parsed)
|
||||
else:
|
||||
table_name = parsed.get_name()
|
||||
|
||||
first_token = parsed.token_first(skip_ws=True, skip_cm=False)
|
||||
ttype = first_token.ttype
|
||||
print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}")
|
||||
return parsed, ttype, sql_type, table_name
|
||||
|
||||
def _extract_table_name_from_ddl(self, parsed):
|
||||
"""Extract table name from CREATE TABLE statement.""" ""
|
||||
for token in parsed.tokens:
|
||||
if token.ttype is None and isinstance(token, sqlparse.sql.Identifier):
|
||||
return token.get_real_name()
|
||||
return None
|
||||
|
||||
def get_indexes(self, table_name):
|
||||
"""Get table indexes about specified table."""
|
||||
session = self._db_sessions()
|
||||
@@ -485,6 +501,10 @@ class RDBMSDatabase(BaseConnect):
|
||||
fields = cursor.fetchall()
|
||||
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
|
||||
|
||||
def get_simple_fields(self, table_name):
|
||||
"""Get column fields about specified table."""
|
||||
return self._query(f"SHOW COLUMNS FROM {table_name}")
|
||||
|
||||
def get_charset(self):
|
||||
"""Get character_set."""
|
||||
session = self._db_sessions()
|
||||
|
@@ -56,6 +56,10 @@ class SQLiteConnect(RDBMSDatabase):
|
||||
print(fields)
|
||||
return [(field[1], field[2], field[3], field[4], field[5]) for field in fields]
|
||||
|
||||
def get_simple_fields(self, table_name):
|
||||
"""Get column fields about specified table."""
|
||||
return self.get_fields(table_name)
|
||||
|
||||
def get_users(self):
|
||||
return []
|
||||
|
||||
@@ -88,8 +92,9 @@ class SQLiteConnect(RDBMSDatabase):
|
||||
self._metadata.reflect(bind=self._engine)
|
||||
return self._all_tables
|
||||
|
||||
def _write(self, session, write_sql):
|
||||
def _write(self, write_sql):
|
||||
print(f"Write[{write_sql}]")
|
||||
session = self.session
|
||||
result = session.execute(text(write_sql))
|
||||
session.commit()
|
||||
# TODO Subsequent optimization of dynamically specified database submission loss target problem
|
||||
|
@@ -25,41 +25,41 @@ def test_get_table_info(db):
|
||||
|
||||
|
||||
def test_get_table_info_with_table(db):
|
||||
db.run(db.session, "CREATE TABLE test (id INTEGER);")
|
||||
db.run("CREATE TABLE test (id INTEGER);")
|
||||
print(db._sync_tables_from_db())
|
||||
table_info = db.get_table_info()
|
||||
assert "CREATE TABLE test" in table_info
|
||||
|
||||
|
||||
def test_run_sql(db):
|
||||
result = db.run(db.session, "CREATE TABLE test (id INTEGER);")
|
||||
assert result[0] == ("cid", "name", "type", "notnull", "dflt_value", "pk")
|
||||
result = db.run("CREATE TABLE test(id INTEGER);")
|
||||
assert result[0] == ("id", "INTEGER", 0, None, 0)
|
||||
|
||||
|
||||
def test_run_no_throw(db):
|
||||
assert db.run_no_throw(db.session, "this is a error sql").startswith("Error:")
|
||||
assert db.run_no_throw("this is a error sql").startswith("Error:")
|
||||
|
||||
|
||||
def test_get_indexes(db):
|
||||
db.run(db.session, "CREATE TABLE test (name TEXT);")
|
||||
db.run(db.session, "CREATE INDEX idx_name ON test(name);")
|
||||
db.run("CREATE TABLE test (name TEXT);")
|
||||
db.run("CREATE INDEX idx_name ON test(name);")
|
||||
assert db.get_indexes("test") == [("idx_name", "c")]
|
||||
|
||||
|
||||
def test_get_indexes_empty(db):
|
||||
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
assert db.get_indexes("test") == []
|
||||
|
||||
|
||||
def test_get_show_create_table(db):
|
||||
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
assert (
|
||||
db.get_show_create_table("test") == "CREATE TABLE test (id INTEGER PRIMARY KEY)"
|
||||
)
|
||||
|
||||
|
||||
def test_get_fields(db):
|
||||
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
assert db.get_fields("test") == [("id", "INTEGER", 0, None, 1)]
|
||||
|
||||
|
||||
@@ -72,26 +72,26 @@ def test_get_collation(db):
|
||||
|
||||
|
||||
def test_table_simple_info(db):
|
||||
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
assert db.table_simple_info() == ["test(id);"]
|
||||
|
||||
|
||||
def test_get_table_info_no_throw(db):
|
||||
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
assert db.get_table_info_no_throw("xxxx_table").startswith("Error:")
|
||||
|
||||
|
||||
def test_query_ex(db):
|
||||
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
db.run(db.session, "insert into test(id) values (1)")
|
||||
db.run(db.session, "insert into test(id) values (2)")
|
||||
field_names, result = db.query_ex(db.session, "select * from test")
|
||||
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
db.run("insert into test(id) values (1)")
|
||||
db.run("insert into test(id) values (2)")
|
||||
field_names, result = db.query_ex("select * from test")
|
||||
assert field_names == ["id"]
|
||||
assert result == [(1,), (2,)]
|
||||
|
||||
field_names, result = db.query_ex(db.session, "select * from test", fetch="one")
|
||||
field_names, result = db.query_ex("select * from test", fetch="one")
|
||||
assert field_names == ["id"]
|
||||
assert result == [(1,)]
|
||||
assert result == [1]
|
||||
|
||||
|
||||
def test_convert_sql_write_to_select(db):
|
||||
@@ -109,7 +109,7 @@ def test_get_users(db):
|
||||
|
||||
def test_get_table_comments(db):
|
||||
assert db.get_table_comments() == []
|
||||
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
assert db.get_table_comments() == [
|
||||
("test", "CREATE TABLE test (id INTEGER PRIMARY KEY)")
|
||||
]
|
||||
|
Reference in New Issue
Block a user